Python源码示例:torch.argmin()

示例1
def kmeans(input, n_clusters=16, tol=1e-6, max_iter=100):
  """Run batched Lloyd k-means over the last dimension of ``input``.

  Each batch element is clustered independently. The samples are the
  slices ``input[b, :, i]``; the returned centres have the same layout.

  Args:
    input: floating-point tensor indexed as (batch, channels, points).
    n_clusters: number of centres to fit per batch element.
    tol: stop once the largest squared centre displacement is below this.
    max_iter: upper bound on the number of refinement iterations, so the
      loop always terminates.

  Returns:
    Tensor of shape (batch, channels, n_clusters) holding the centres.

  Raises:
    ValueError: if ``n_clusters`` exceeds the number of points.
  """
  n_points = input.size(-1)
  if n_clusters > n_points:
    raise ValueError(
      "n_clusters ({}) cannot exceed the number of points ({})".format(
        n_clusters, n_points))

  # Seed the centres with distinct samples; indices must be integer typed
  # to be usable for indexing.
  indices = torch.randperm(n_points, device=input.device)[:n_clusters]
  values = input[:, :, indices]    # (batch, channels, n_clusters)
  points = input.transpose(1, 2)   # (batch, points, channels)

  for _ in range(max_iter):
    # Assignment step: nearest centre for every point.
    dist = torch.cdist(points, values.transpose(1, 2))  # (batch, points, k)
    choice_cluster = torch.argmin(dist, dim=-1)         # (batch, points)

    # Update step: per-cluster mean via a one-hot membership matrix.
    membership = torch.nn.functional.one_hot(
      choice_cluster, n_clusters).to(input.dtype)       # (batch, points, k)
    counts = membership.sum(dim=1).unsqueeze(1)         # (batch, 1, k)
    sums = torch.bmm(input, membership)                 # (batch, channels, k)

    old_values = values
    # A cluster that received no points keeps its previous centre.
    values = torch.where(counts > 0, sums / counts.clamp(min=1), old_values)

    shift = (old_values - values).norm(dim=1)
    if shift.max() ** 2 < tol:
      break

  return values
示例2
def forward(self, grammian):
    """Pick the best two-vector solution of the min-norm problem.

    Every index pair stored in ``self.ii_triu`` / ``self.jj_triu`` is
    handed to ``self.line_solver_vectorized``; the pair with the lowest
    returned cost wins.

    Args:
      grammian: square tensor with G[i, j] = <Vi, Vj>.

    Returns:
      sol: tensor of length ``self.n`` that is zero everywhere except at
        the winning pair, which receives ``gamma`` and ``1 - gamma``.
    """
    norm_i = grammian[self.ii_triu, self.ii_triu]
    cross = grammian[self.ii_triu, self.jj_triu]
    norm_j = grammian[self.jj_triu, self.jj_triu]

    gamma, cost = self.line_solver_vectorized(norm_i, cross, norm_j)
    best = torch.argmin(cost)
    best_gamma = gamma[best]

    sol = torch.zeros(self.n, device=grammian.device)
    sol[self.i_triu[best]] = best_gamma
    sol[self.j_triu[best]] = 1. - best_gamma
    return sol
示例3
def forward(self, XS, YS, XQ, YQ):
    '''
        Classify every query example with the label of its nearest
        support example and report the accuracy.

        @param XS (support x): support_size x ebd_dim
        @param YS (support y): support_size
        @param XQ (query x): query_size x ebd_dim
        @param YQ (query y): query_size

        @return acc
        @return None (a placeholder for loss)
    '''
    metric = self.args.nn_distance
    if metric not in ('l2', 'cos'):
        raise ValueError("nn_distance can only be l2 or cos.")

    if metric == 'l2':
        dist = self._compute_l2(XS, XQ)
    else:
        dist = self._compute_cos(XS, XQ)

    # 1-nearest-neighbour: the argmin along dim 1 indexes into the support labels
    pred = YS[torch.argmin(dist, dim=1)]
    hits = (pred == YQ).float()

    return torch.mean(hits).item(), None
示例4
def forward(self, coord_volumes_batch, volumes_batch_pred, keypoints_gt, keypoints_binary_validity):
    """Cross-entropy style loss on the voxel closest to each ground-truth joint.

    For every sample and joint, the voxel whose coordinates are nearest
    (Euclidean) to the ground-truth keypoint is located, and
    ``-log(p + 1e-6)`` of the predicted value at that voxel is accumulated,
    weighted by the first entry of the joint's validity.

    Returns the accumulated loss divided by the number of (sample, joint)
    terms.
    """
    total = 0.0
    n_terms = 0
    grid_shape = volumes_batch_pred.shape[-3:]

    for b in range(volumes_batch_pred.shape[0]):
        # broadcast joints against the voxel grid: one distance volume per joint
        voxels = coord_volumes_batch[b].unsqueeze(0)
        joints = keypoints_gt[b][:, None, None, None]
        dists = torch.sqrt(((voxels - joints) ** 2).sum(-1))
        dists = dists.view(dists.shape[0], -1)

        # flat argmin -> (i, j, k) voxel index per joint
        nearest = torch.argmin(dists, dim=-1).detach().cpu().numpy()
        ii, jj, kk = np.unravel_index(nearest, grid_shape)

        for joint, (i0, i1, i2) in enumerate(zip(ii, jj, kk)):
            validity = keypoints_binary_validity[b, joint]
            prob = volumes_batch_pred[b, joint, i0, i1, i2]
            total += validity[0] * (-torch.log(prob + 1e-6))
            n_terms += 1

    return total / n_terms
示例5
def assign(self, points, distance='euclid', greedy=False):
    """Compute responsibilities of ``self.centroids`` for each point.

    Squared Euclidean distances are used; with ``distance='cosine'`` both
    points and centroids are L2-normalised first. The first three rows of
    the distance matrix are printed.

    Args:
        points: 2-D tensor, one point per row.
        distance: 'cosine' to normalise the inputs, anything else for
            plain Euclidean.
        greedy: if True return a hard one-hot assignment to the nearest
            centroid, otherwise a Gumbel-softmax sample of the negated
            distances (using ``self.tau`` and ``self.hard``).

    Returns:
        Tensor with one row per point and one column per centroid.
    """
    centroids = self.centroids
    if distance == 'cosine':
        # nearest neighbour among the centroids (cosine distance)
        points = F.normalize(points, dim=-1)
        centroids = F.normalize(centroids, dim=-1)

    # ||p||^2 + ||c||^2 - 2 <p, c>
    point_sq = torch.sum(points**2, dim=1, keepdim=True)
    centroid_sq = torch.sum(centroids**2, dim=1, keepdim=True).t()
    cross = torch.matmul(points, centroids.t())
    distances = point_sq + centroid_sq - 2 * cross
    print('Distances:', distances[:3])

    if greedy:
        # non-differentiable one-hot responsibilities
        nearest = torch.argmin(distances, dim=-1)
        resp = torch.zeros(points.size(0), self.ne).type_as(points)
        resp.scatter_(1, nearest.unsqueeze(1), 1)
        return resp

    return F.gumbel_softmax(-distances, tau=self.tau, hard=self.hard)
示例6
def assign(self, points, distance='euclid', greedy=False):
    """Score or assign points against a dropout-perturbed copy of the centroids.

    ``F.dropout`` with p=0.3 is applied to ``self.centroids`` on every
    call. Squared Euclidean distances are then computed (after
    L2-normalisation when ``distance='cosine'``) and the first three rows
    are printed.

    Args:
        points: 2-D tensor, one point per row.
        distance: 'cosine' to normalise the inputs, anything else for
            plain Euclidean.
        greedy: if True return one-hot assignments to the nearest centroid;
            otherwise return
            ``-tau/2 * d - reduce_dim/2 * log(2*pi*tau) + log(prior)``.

    Returns:
        Tensor with one row per point and one column per centroid.
    """
    centroids = F.dropout(self.centroids, p=0.3)
    if distance == 'cosine':
        # nearest neighbour among the centroids (cosine distance)
        points = F.normalize(points, dim=-1)
        centroids = F.normalize(centroids, dim=-1)

    point_sq = torch.sum(points**2, dim=1, keepdim=True)
    centroid_sq = torch.sum(centroids**2, dim=1, keepdim=True).t()
    distances = point_sq + centroid_sq - 2 * torch.matmul(points, centroids.t())
    print('Distances:', distances[:3])

    if greedy:
        # non-differentiable one-hot responsibilities
        nearest = torch.argmin(distances, dim=-1)
        resp = torch.zeros(points.size(0), self.ne).type_as(points)
        resp.scatter_(1, nearest.unsqueeze(1), 1)
        return resp

    quadratic = - .5 * self.tau * distances
    log_norm = self.reduce_dim / 2 * math.log(2 * math.pi * self.tau)
    return quadratic - log_norm + torch.log(self.prior)
示例7
def assign(self, points, distance='euclid', greedy=False):
    """Compute centroid responsibilities for points detached from autograd.

    ``points.data`` is used, so no gradient flows back into ``points``.
    Distances are squared Euclidean; with ``distance='cosine'`` points and
    centroids are L2-normalised first.

    Args:
        points: 2-D tensor, one point per row.
        distance: 'cosine' to normalise the inputs, anything else for
            plain Euclidean.
        greedy: if True return a one-hot assignment to the nearest
            centroid, otherwise a Gumbel-softmax sample over the negated
            distances (``self.tau``, ``self.hard``).

    Returns:
        Tensor with one row per point and one column per centroid.
    """
    detached = points.data
    codebook = self.centroids
    if distance == 'cosine':
        # nearest neighbour among the centroids (cosine distance)
        detached = F.normalize(detached, dim=-1)
        codebook = F.normalize(codebook, dim=-1)

    sq_points = torch.sum(detached**2, dim=1, keepdim=True)
    sq_codes = torch.sum(codebook**2, dim=1, keepdim=True).t()
    distances = sq_points + sq_codes - 2 * torch.matmul(detached, codebook.t())

    if greedy:
        # non-differentiable one-hot responsibilities
        winners = torch.argmin(distances, dim=-1).unsqueeze(1)
        resp = torch.zeros(detached.size(0), self.ne).type_as(detached)
        resp.scatter_(1, winners, 1)
    else:
        resp = F.gumbel_softmax(-distances, tau=self.tau, hard=self.hard)
    return resp
示例8
def assign(self, points, distance='euclid', greedy=False):
    """Compute centroid responsibilities; ``points`` stays attached to autograd.

    Distances are squared Euclidean; with ``distance='cosine'`` points and
    centroids are L2-normalised first.

    Args:
        points: 2-D tensor, one point per row.
        distance: 'cosine' to normalise the inputs, anything else for
            plain Euclidean.
        greedy: if True return a one-hot assignment to the nearest
            centroid, otherwise a Gumbel-softmax sample over the negated
            distances (``self.tau``, ``self.hard``).

    Returns:
        Tensor with one row per point and one column per centroid.
    """
    centroids = self.centroids
    if distance == 'cosine':
        # nearest neighbour among the centroids (cosine distance)
        points = F.normalize(points, dim=-1)
        centroids = F.normalize(centroids, dim=-1)

    # expand ||p - c||^2 into norms and an inner product
    norms_p = torch.sum(points**2, dim=1, keepdim=True)
    norms_c = torch.sum(centroids**2, dim=1, keepdim=True).t()
    inner = torch.matmul(points, centroids.t())
    distances = norms_p + norms_c - 2 * inner

    if not greedy:
        return F.gumbel_softmax(-distances, tau=self.tau, hard=self.hard)

    # non-differentiable one-hot responsibilities
    closest = torch.argmin(distances, dim=-1)
    resp = torch.zeros(points.size(0), self.ne).type_as(points)
    resp.scatter_(1, closest.unsqueeze(1), 1)
    return resp
示例9
def assign(self, points, distance='euclid', greedy=False):
    """Compute centroid responsibilities for points detached from autograd.

    ``points.data`` is used, so no gradient flows back into ``points``.
    Distances are squared Euclidean; with ``distance='cosine'`` points and
    centroids are L2-normalised first.

    Args:
        points: 2-D tensor, one point per row.
        distance: 'cosine' to normalise the inputs, anything else for
            plain Euclidean.
        greedy: if True return a one-hot assignment to the nearest
            centroid, otherwise a Gumbel-softmax sample over the negated
            distances (``self.tau``, ``self.hard``).

    Returns:
        Tensor with one row per point and one column per centroid.
    """
    points = points.data
    centroids = self.centroids
    use_cosine = distance == 'cosine'
    if use_cosine:
        # nearest neighbour among the centroids (cosine distance)
        points = F.normalize(points, dim=-1)
        centroids = F.normalize(centroids, dim=-1)

    p_norm2 = torch.sum(points**2, dim=1, keepdim=True)
    c_norm2 = torch.sum(centroids**2, dim=1, keepdim=True).t()
    distances = p_norm2 + c_norm2 - 2 * torch.matmul(points, centroids.t())

    if greedy:
        # non-differentiable one-hot responsibilities
        argmins = torch.argmin(distances, dim=-1)
        one_hot = torch.zeros(points.size(0), self.ne).type_as(points)
        one_hot.scatter_(1, argmins.unsqueeze(1), 1)
        return one_hot

    return F.gumbel_softmax(-distances, tau=self.tau, hard=self.hard)
示例10
def assign(self, points, distance='euclid', greedy=False):
    """Compute centroid responsibilities; ``points`` stays attached to autograd.

    Distances are squared Euclidean; with ``distance='cosine'`` points and
    centroids are L2-normalised first.

    Args:
        points: 2-D tensor, one point per row.
        distance: 'cosine' to normalise the inputs, anything else for
            plain Euclidean.
        greedy: if True return a one-hot assignment to the nearest
            centroid, otherwise a Gumbel-softmax sample over the negated
            distances (``self.tau``, ``self.hard``).

    Returns:
        Tensor with one row per point and one column per centroid.
    """
    codes = self.centroids
    if distance == 'cosine':
        # nearest neighbour among the centroids (cosine distance)
        points = F.normalize(points, dim=-1)
        codes = F.normalize(codes, dim=-1)

    sq_p = torch.sum(points**2, dim=1, keepdim=True)
    sq_c = torch.sum(codes**2, dim=1, keepdim=True).t()
    dot = torch.matmul(points, codes.t())
    distances = sq_p + sq_c - 2 * dot

    if greedy:
        # non-differentiable one-hot responsibilities
        picked = torch.argmin(distances, dim=-1)
        resp = torch.zeros(points.size(0), self.ne).type_as(points)
        resp.scatter_(1, picked.unsqueeze(1), 1)
    else:
        resp = F.gumbel_softmax(-distances, tau=self.tau, hard=self.hard)
    return resp
示例11
def calculate_partitions(partitions_count, cluster_partitions, types):
    """Greedily spread whole clusters over partitions, balancing type counts.

    Clusters are visited in the order returned by ``torch.unique``. Each
    one goes to the partition whose per-type counts would grow the most in
    relative terms (smallest sum of old/new count ratios).

    Args:
        partitions_count: number of partitions to fill.
        cluster_partitions: 1-D tensor with the cluster id of every sample.
        types: tensor with the type of every sample; the values are used
            directly as column indices of the count table.

    Returns:
        Long tensor giving the partition index assigned to each sample.
    """
    n_types = len(torch.unique(types))
    # Counts start at one so the ratio below never divides by zero.
    partition_distribution = torch.ones((partitions_count, n_types), dtype=torch.long)
    partition_assignments = torch.zeros(cluster_partitions.shape[0], dtype=torch.long)

    for cluster_id in torch.unique(cluster_partitions):
        members = (cluster_partitions == cluster_id).nonzero()
        present_types, present_counts = torch.unique(types[members], return_counts=True)

        # What every partition would look like if it received this cluster.
        projected = partition_distribution.clone()
        projected[:, present_types] += present_counts
        ratio = partition_distribution.double() / projected.double()

        target = torch.argmin(torch.sum(ratio, dim=1))
        partition_distribution[target, present_types] += present_counts
        partition_assignments[members] = target

    write_out("Loaded data into the following partitions")
    write_out("[[  TM  SP+TM  SP Glob]")
    # Remove the initial ones before reporting the real counts.
    write_out(partition_distribution - torch.ones_like(partition_distribution))
    return partition_assignments
示例12
def _nnef_argminmax_reduce(input, axes, argmin=False):
    # type:(torch.Tensor, List[int], bool)->torch.Tensor
    """Arg-min / arg-max reduction over one axis or a run of consecutive axes.

    A single axis is passed straight to ``_nnef_generic_reduce``. Several
    axes must be consecutive once sorted: they are merged into one axis
    with ``nnef_reshape``, reduced, and the result is reshaped so that
    every reduced axis has size 1.

    Raises:
        utils.NNEFToolsException: if the axes are not consecutive.
    """
    reducer = torch.argmin if argmin else torch.argmax
    if len(axes) == 1:
        return _nnef_generic_reduce(input=input, axes=axes, f=reducer)

    axes = sorted(axes)
    first, count = axes[0], len(axes)
    if axes != list(range(first, first + count)):
        raise utils.NNEFToolsException(
            "{} is only implemented for consecutive axes.".format("argmin_reduce" if argmin else "argmax_reduce"))

    in_shape = list(input.shape)
    merged = nnef_reshape(input, shape=in_shape[:first] + [-1] + in_shape[first + count:])
    reduced = _nnef_generic_reduce(input=merged, axes=[first], f=reducer)
    out_shape = [1 if axis in axes else dim for axis, dim in enumerate(input.shape)]
    return nnef_reshape(reduced, shape=out_shape)
示例13
def step(self, i):
    """
    Run one EM iteration of the quantizer.

    E-step: every column of W is assigned to the centroid with the smallest
    distance (as returned by ``self.compute_distances()``), after which
    empty clusters are resolved. M-step: each centroid becomes the mean of
    the columns assigned to it. The L2 norm of the reconstruction residual
    is appended to ``self.objective``.

    Args:
        - i: step number (only used for logging)
    """

    # E-step: distances are (n_centroids x out_features)
    self.assignments = self.compute_distances().argmin(dim=0)  # (out_features)
    n_empty_clusters = self.resolve_empty_clusters()

    # M-step: exact solution, one centroid at a time
    for k in range(self.n_centroids):
        members = self.assignments == k
        self.centroids[k] = self.W[:, members].mean(dim=1)  # (in_features)

    # book-keeping
    residual = self.centroids[self.assignments].t() - self.W
    obj = residual.norm(p=2).item()
    self.objective.append(obj)
    if not self.verbose:
        return
    logging.info(
        f"Iteration: {i},\t"
        f"objective: {obj:.6f},\t"
        f"resolved empty clusters: {n_empty_clusters}"
    )
示例14
def resolve_empty_clusters(self):
    """
    Repopulate clusters that received no assignment.

    While some cluster is empty, a random empty cluster is re-seeded from
    the most populated one: both centroids are set to the populated
    centroid shifted by +/- a small Gaussian perturbation scaled by
    ``self.eps``, and the assignments are recomputed.

    Returns the number of clusters that were empty on entry. Raises
    EmptyClusterResolveError once the attempt counter reaches
    ``self.max_tentatives``.
    """

    def tally():
        # population of each used cluster, plus the set of unused ones
        population = Counter(a.item() for a in self.assignments)
        return population, set(range(self.n_centroids)) - set(population)

    counts, empty_clusters = tally()
    n_empty_clusters = len(empty_clusters)

    tentatives = 0
    while empty_clusters:
        # split the most populated cluster into the chosen empty slot
        k = random.choice(list(empty_clusters))
        m = counts.most_common(1)[0][0]
        e = torch.randn_like(self.centroids[m]) * self.eps
        self.centroids[k] = self.centroids[m] + e
        self.centroids[m] -= e

        # recompute assignments from (n_centroids x out_features) distances
        self.assignments = torch.argmin(self.compute_distances(), dim=0)
        counts, empty_clusters = tally()

        # the budget check happens after each attempt
        if tentatives == self.max_tentatives:
            logging.info(
                f"Could not resolve all empty clusters, {len(empty_clusters)} remaining"
            )
            raise EmptyClusterResolveError
        tentatives += 1

    return n_empty_clusters
示例15
def assign(self):
    """
    Store in ``self.assignments`` the index of the closest centroid for
    each column of the distance matrix, i.e. the E-step of train().

    Remarks:
        - Call after train() or after loading centroids with self.load();
          otherwise, per the original note, it will return empty tensors
    """

    # distances: (n_centroids x out_features) -> assignments: (out_features)
    self.assignments = self.compute_distances().argmin(dim=0)
示例16
def append(self, new_k, new_v):
    """Overwrite the slot with the smallest weight by a new key/value pair.

    The replaced slot's weight is reset to the mean of ``self.weight_buff``
    (computed before the reset).
    """
    slot = self.weight_buff.argmin().item()
    self.keys[slot, :] = new_k
    self.values[slot, :] = new_v
    mean_weight = self.weight_buff.mean()
    self.weight_buff[slot] = mean_weight
示例17
def append(self, new_k, new_v):
    """
    Replace the lowest-weight memory slot with a new key/value pair and
    reset that slot's weight to the current mean weight.

    :param new_k: expecting a vector of dimensionality [Num Key Chan]
    :param new_v: expecting a vector of dimensionality [Num Value Chan]
    :return: None
    """
    weights = self.weight_buff
    victim = torch.argmin(weights).item()
    self.keys[victim, :] = new_k
    self.values[victim, :] = new_v
    weights[victim] = torch.mean(weights)
示例18
def forward(self, vecs):
    """Iteratively solve the min-norm problem over the rows of ``vecs``.

    Special cases: with one task the single vector is returned; with two
    tasks the combination ``gamma * v0 + (1 - gamma) * v1`` given by
    ``self.line_solver`` is returned.

    Otherwise the Gram matrix G[i, j] = <v_i, v_j> is written into
    ``self.grammian`` and the coefficient vector is refined for at most
    ``self.MAX_ITER`` iterations: pick the coordinate with the smallest
    entry of G @ sol, line-search towards it, and stop early once the
    L1 change drops below ``self.STOP_CRIT``. In this case the
    coefficient vector (not a combined vector) is returned.
    """
    if self.n_tasks == 1:
        return vecs[0]
    if self.n_tasks == 2:
        first, second = vecs[0], vecs[1]
        v1v1 = torch.dot(first, first)
        v1v2 = torch.dot(first, second)
        v2v2 = torch.dot(second, second)
        gamma = self.line_solver(v1v1, v1v2, v2v2)
        return gamma * first + (1. - gamma) * second

    # start from the uniform combination
    self.sol.fill_(1. / self.n)
    self.new_sol.copy_(self.sol)
    torch.mm(vecs, vecs.t(), out=self.grammian)

    for _ in range(self.MAX_ITER):
        gram_dot_sol = self.grammian.mv(self.sol)
        t_iter = gram_dot_sol.argmin()

        gamma = self.line_solver(
            torch.dot(self.sol, gram_dot_sol),
            torch.dot(self.sol, self.grammian[:, t_iter]),
            self.grammian[t_iter, t_iter],
        )
        # move towards the selected vertex
        self.new_sol.mul_(gamma)
        self.new_sol[t_iter] += 1. - gamma

        if (self.new_sol - self.sol).abs().sum() < self.STOP_CRIT:
            return self.new_sol
        self.sol.copy_(self.new_sol)
    return self.sol
示例19
def _proj_val(x, set):
    """
    Compute the projection from x onto the set given.

    :param x: Input pytorch tensor.
    :param set: Input pytorch vector used to perform the projection.
    """
    x = x.repeat((set.size()[0],)+(1,)*len(x.size()))
    x = x.permute(*(tuple(range(len(x.size())))[1:]  +(0,) ))
    x = torch.abs(x-set)
    x = torch.argmin(x, dim=len(x.size())-1, keepdim=False)
    return set[x] 
示例20
def least_used_cuda_device() -> Generator:
    """Run the enclosed code on the CUDA device with the least allocated memory.

    The device is the argmin of ``get_cuda_max_memory_allocations()``.
    Written as a single-yield generator, i.e. meant to be wrapped as a
    context manager (the decorator is not visible in this snippet).
    """
    device_index = torch.argmin(get_cuda_max_memory_allocations()).item()
    with torch.cuda.device(device_index):
        yield
示例21
def subgraph_filter(x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle, args):
    """Randomly cut one atom and one of its close neighbours out of each molecule.

    Works on detached clones, so the tensors passed in are not modified.
    For every molecule in the batch, with the randomness described below,
    two atom indices are chosen and every atom, position, bond and triplet
    row that refers to them is zeroed.

    Args:
        x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle:
            batched molecule tensors, batch on dim 0.
            NOTE(review): bond columns 3/4 and triplet columns 2/3/4 are
            compared against atom indices, so they presumably hold atom
            ids; confirm against the data loader.
        args: needs ``cutout`` (threshold compared against a uniform sample)
            and ``use_quad`` (must be falsy).

    Returns:
        The six (possibly modified) clones followed by ``bonds_mask``, a
        (batch, n_bonds, 1) tensor of ones with zeros at dropped bonds.
    """
    # Distances are computed from the first three position components,
    # before cloning.
    # NOTE(review): sqdist is defined elsewhere; presumably pairwise
    # squared distances.
    D = sqdist(x_atom_pos[:,:,:3], x_atom_pos[:,:,:3])
    x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle = \
        x_atom.clone().detach(), x_atom_pos.clone().detach(), x_bond.clone().detach(), x_bond_dist.clone().detach(), x_triplet.clone().detach(), x_triplet_angle.clone().detach()
    bsz = x_atom.shape[0]
    bonds_mask = torch.ones(bsz, x_bond.shape[1], 1).to(x_atom.device)
    for mol_id in range(bsz):
        # Leave the molecule untouched when the uniform sample exceeds args.cutout.
        if np.random.uniform(0,1) > args.cutout:
            continue
        assert not args.use_quad, "Quads are NOT cut out yet"
        atom_dists = D[mol_id]
        atoms = x_atom[mol_id, :, 0]
        # Atoms whose first feature is > 0 count as valid.
        n_valid_atoms = (atoms > 0).sum().item()
        # Molecules with fewer than 10 valid atoms are skipped.
        if n_valid_atoms < 10:
            continue
        # randint's upper bound is exclusive: idx is in [0, n_valid_atoms - 2].
        idx_to_drop = np.random.randint(n_valid_atoms-1)
        dist_row = atom_dists[idx_to_drop]
        # Closest atom among the strictly positive distances, limited to the
        # first n_valid_atoms - 1 of them.
        neighbor_to_drop = torch.argmin((dist_row[dist_row>0])[:n_valid_atoms-1]).item()
        # The filtered row has no entry for idx_to_drop itself (presumably its
        # zero self-distance), so indices at or after it are shifted by one.
        if neighbor_to_drop >= idx_to_drop: 
            neighbor_to_drop += 1
        x_atom[mol_id, idx_to_drop] = 0
        x_atom[mol_id, neighbor_to_drop] = 0
        x_atom_pos[mol_id, idx_to_drop] = 0
        x_atom_pos[mol_id, neighbor_to_drop] = 0
        # Rows of bonds / triplets that reference either dropped atom.
        bond_pos_to_drop = (x_bond[mol_id, :, 3] == idx_to_drop) | (x_bond[mol_id, :, 3] == neighbor_to_drop) \
                         | (x_bond[mol_id, :, 4] == idx_to_drop) | (x_bond[mol_id, :, 4] == neighbor_to_drop)
        trip_pos_to_drop = (x_triplet[mol_id, :, 2] == idx_to_drop) | (x_triplet[mol_id, :, 2] == neighbor_to_drop) \
                         | (x_triplet[mol_id, :, 3] == idx_to_drop) | (x_triplet[mol_id, :, 3] == neighbor_to_drop) \
                         | (x_triplet[mol_id, :, 4] == idx_to_drop) | (x_triplet[mol_id, :, 4] == neighbor_to_drop)
        x_bond[mol_id, bond_pos_to_drop] = 0
        x_bond_dist[mol_id, bond_pos_to_drop] = 0
        bonds_mask[mol_id, bond_pos_to_drop] = 0
        x_triplet[mol_id, trip_pos_to_drop] = 0
        x_triplet_angle[mol_id, trip_pos_to_drop] = 0
    return x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle, bonds_mask 
示例22
def step(self, i):
    """
    Perform a single EM iteration.

    Assignment (E-step): each column of W goes to the centroid with the
    smallest distance according to ``self.compute_distances()``; empty
    clusters are then resolved. Update (M-step): every centroid is replaced
    by the mean of its assigned columns. Finally the L2 norm of the
    reconstruction residual is recorded in ``self.objective``.

    Args:
        - i: step number (only used for logging)
    """

    # E-step
    distances = self.compute_distances()  # (n_centroids x out_features)
    self.assignments = distances.argmin(dim=0)  # (out_features)
    n_empty_clusters = self.resolve_empty_clusters()

    # M-step
    for k in range(self.n_centroids):
        in_cluster = self.assignments == k
        cluster_columns = self.W[:, in_cluster]  # (in_features x size_of_cluster_k)
        self.centroids[k] = cluster_columns.mean(dim=1)

    # book-keeping
    reconstruction = self.centroids[self.assignments].t()
    obj = (reconstruction - self.W).norm(p=2).item()
    self.objective.append(obj)
    if not self.verbose:
        return
    logging.info(
        f"Iteration: {i},\t"
        f"objective: {obj:.6f},\t"
        f"resolved empty clusters: {n_empty_clusters}"
    )
示例23
def resolve_empty_clusters(self):
    """
    Fill clusters that currently have no members.

    Repeatedly picks a random empty cluster and the most populated cluster,
    places both centroids symmetrically around the populated centroid
    (offset by Gaussian noise scaled with ``self.eps``) and recomputes the
    assignments.

    Returns the number of clusters that were empty on entry. Raises
    EmptyClusterResolveError once the attempt counter reaches
    ``self.max_tentatives``.
    """

    def survey():
        # member count per used cluster and the set of unused cluster ids
        sizes = Counter(a.item() for a in self.assignments)
        unused = set(range(self.n_centroids)) - set(sizes)
        return sizes, unused

    counts, empty_clusters = survey()
    n_empty_clusters = len(empty_clusters)

    tentatives = 0
    while empty_clusters:
        k = random.choice(list(empty_clusters))
        m = counts.most_common(1)[0][0]
        e = torch.randn_like(self.centroids[m]) * self.eps
        # both halves of the split start from the populated centroid
        self.centroids[k] = self.centroids[m] + e
        self.centroids[m] -= e

        distances = self.compute_distances()  # (n_centroids x out_features)
        self.assignments = distances.argmin(dim=0)  # (out_features)
        counts, empty_clusters = survey()

        # the budget check happens after each attempt
        if tentatives == self.max_tentatives:
            logging.info(
                f"Could not resolve all empty clusters, {len(empty_clusters)} remaining"
            )
            raise EmptyClusterResolveError
        tentatives += 1

    return n_empty_clusters
示例24
def assign(self):
    """
    Recompute ``self.assignments``: for each column of the distance matrix,
    the index of the centroid with the smallest distance (the E-step of
    train()).

    Remarks:
        - Call after train() or after loading centroids with self.load();
          otherwise, per the original note, it will return empty tensors
    """

    dist_matrix = self.compute_distances()  # (n_centroids x out_features)
    self.assignments = dist_matrix.argmin(dim=0)  # (out_features)
示例25
def unravel_index(tensor, cols):
    """Locate the minimum of each row and express it as (row, col).

    Each row of ``tensor`` is a flattened rows*cols grid. The flat argmin
    is split into a row index (floor division by ``cols``) and a column
    index (remainder).

    args:
        tensor : 2D tensor, [nb, rows*cols]
        cols : int
    return 2D integer tensor nb * [rowIndex, colIndex]
    """
    index = torch.argmin(tensor, dim=1).view(-1, 1)
    # `/` is true division on integer tensors and would yield fractional
    # float rows, so floor-divide explicitly to keep integer indices.
    rIndex = torch.div(index, cols, rounding_mode='floor')
    cIndex = index % cols
    return torch.cat([rIndex, cIndex], dim=1)
示例26
def _compute_basic_stats(data):
        # compute on non-zero data:
        data = data[data > 0]
        out = [("total", torch.sum(data).item() if len(data) > 0 else "not yet triggered")]
        if len(data) > 1:
            out += [
                ("min/index", (torch.min(data).item(), torch.argmin(data).item())),
                ("max/index", (torch.max(data).item(), torch.argmax(data).item())),
                ("mean", torch.mean(data).item()),
                ("std", torch.std(data).item()),
            ]
        return OrderedDict(out) 
示例27
def nnef_argmax_reduce(input, axes):
    # type:(torch.Tensor, List[int])->torch.Tensor
    """Arg-max reduction of ``input`` over ``axes``; delegates to the shared helper."""
    return _nnef_argminmax_reduce(input=input, axes=axes, argmin=False)
示例28
def nnef_argmin_reduce(input, axes):
    # type:(torch.Tensor, List[int])->torch.Tensor
    """Arg-min reduction of ``input`` over ``axes``; delegates to the shared helper."""
    return _nnef_argminmax_reduce(input=input, axes=axes, argmin=True)
示例29
def forward(self, segmentation, prob, gt_instance, gt_plane_num):
    """
    Greedy matching of predicted segments to ground-truth instances.

    For every predicted segment the binary cross-entropy against each of
    the first ``gt_plane_num`` ground-truth masks is averaged over all
    pixels; the segment is matched to the mask with the lowest value.

    :param segmentation: tensor with size (N, K), N = h*w pixels
    :param prob: tensor with size (N, 1); only its length is checked
    :param gt_instance: tensor with size (num_instances, h, w)
    :param gt_plane_num: int, number of leading masks that are planes
    :return: a (K, 1) long tensor with the closest ground-truth instance id, starting from 0
    """
    n, k = segmentation.size()
    _, h, w = gt_instance.size()
    n_pixels = h * w
    assert (prob.size(0) == n and n_pixels == n)

    # ignore the non-planar region: keep only the first gt_plane_num masks
    targets = gt_instance[:gt_plane_num, :, :].view(1, -1, n_pixels)   # (1, gt_plane_num, h*w)
    targets = targets.type(torch.float32)
    preds = segmentation.t().view(k, 1, n_pixels)                      # (k, 1, h*w)

    # instance-wise cross entropy, broadcast to (k, gt_plane_num, h*w)
    ce_loss = - (targets * torch.log(preds + 1e-6) +
                 (1 - targets) * torch.log(1 - preds + 1e-6))
    per_pair = torch.mean(ce_loss, dim=2)                              # (k, gt_plane_num)

    return torch.argmin(per_pair, dim=1, keepdim=True)
示例30
def choose_best_match_batch(Rrois, gt_rois):
    """
    Choose, for every ground-truth box, the equivalent representation whose
    angle is closest to the paired Rroi's angle.

    A rotated box can be written in four ways: the angle shifted by
    0, pi/2, pi or 3*pi/2, with width and height swapped for the odd
    multiples. The candidate with the smallest circular angular distance
    to the Rroi is selected row by row.

    :param Rrois: (x_ctr, y_ctr, w, h, angle)
            shape: (n, 5)
    :param gt_rois: (x_ctr, y_ctr, w, h, angle)
            shape: (n, 5)
    :return: gt_roi_news: gt_roi with new representation, angle wrapped
            into [0, 2*pi)
            shape: (n, 5)
    """
    two_pi = 2 * np.pi
    Rroi_angles = Rrois[:, 4].unsqueeze(1)

    # Column views are only read, never written, so no defensive copy is needed.
    gt_xs, gt_ys, gt_ws, gt_hs, gt_angles = (
        gt_rois[:, col].unsqueeze(1) for col in range(5))

    # (n, 4): the four equivalent angles of each ground-truth box
    gt_angle_extent = torch.cat((gt_angles,
                                 gt_angles + np.pi / 2.,
                                 gt_angles + np.pi,
                                 gt_angles + np.pi * 3 / 2.), 1)
    # circular distance between the Rroi angle and each candidate
    dist = (Rroi_angles - gt_angle_extent) % two_pi
    dist = torch.min(dist, two_pi - dist)
    min_index = torch.argmin(dist, 1)

    # (n, 4, 5): the four equivalent box encodings, in the same order
    gt_rois_extent = torch.stack((
        gt_rois,
        torch.cat((gt_xs, gt_ys, gt_hs, gt_ws, gt_angles + np.pi / 2.), 1),
        torch.cat((gt_xs, gt_ys, gt_ws, gt_hs, gt_angles + np.pi), 1),
        torch.cat((gt_xs, gt_ys, gt_hs, gt_ws, gt_angles + np.pi * 3 / 2.), 1),
    ), 1)

    # Select the winning encoding of every row in one indexing operation
    # instead of a Python loop; advanced indexing returns a fresh tensor.
    rows = torch.arange(gt_rois.size(0), device=gt_rois.device)
    gt_rois_new = gt_rois_extent[rows, min_index]

    gt_rois_new[:, 4] = gt_rois_new[:, 4] % two_pi

    return gt_rois_new