Python source code examples: torch.repeat_interleave()

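torch.repeat_interleave(input, repeats, dim) repeats each element (or slice) of input the given number of times along dim, unlike Tensor.repeat, which tiles the whole tensor; called with a single counts tensor, it expands the counts into group indices. A minimal demonstration of the three calling forms used throughout the examples below:

import torch

x = torch.tensor([[1, 2], [3, 4]])
torch.repeat_interleave(x, 2, dim=0)                     # [[1, 2], [1, 2], [3, 4], [3, 4]]
torch.repeat_interleave(x, torch.tensor([1, 2]), dim=0)  # row 0 once, row 1 twice
torch.repeat_interleave(torch.tensor([2, 3]))            # [0, 0, 1, 1, 1]: counts -> group ids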
Example 1
def hard_k_hot(logits, k, temperature=0.1):
  r"""Returns a hard k-hot sample given a categorical
  distribution defined by a tensor of unnormalized
  log-likelihoods.

  This is useful for example to sample a set of pixels in an
  image to move from a grid-structured data representation to a
  set- or graph-structured representation within a network.

  Args:
    logits (torch.Tensor): unnormalized log-likelihood tensor.
    k (int): number of items to sample without replacement.
    temperature (float): temperature of the soft distribution.

  Returns:
    Hard k-hot vector from the relaxed k-hot distribution
    defined by logits and temperature.
  """
  soft = soft_k_hot(logits, k, temperature=temperature)
  hard = torch.zeros_like(soft)
  _, top_k = torch.topk(logits, k)
  index = torch.repeat_interleave(torch.arange(0, hard.size(0)), k)
  hard[index, top_k.view(-1)] = 1.0
  return replace_gradient(hard, soft) 
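soft_k_hot and replace_gradient are helpers from the same source module and are not shown here. replace_gradient presumably implements the standard straight-through estimator; a minimal sketch under that assumption:

def replace_gradient(value, surrogate):
  # Straight-through: the forward pass yields `value`, while gradients
  # flow through `surrogate` (an assumed implementation, not the original).
  return surrogate + (value - surrogate).detach()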
Example 2
def pairwise_no_pad(op, data, indices):
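  # Applies `op` to every within-group pair (i, j) with i < j, where groups
  # are given by `indices`; avoids padding via flat index arithmetic.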
  unique, counts = indices.unique(return_counts=True)
  expansion = torch.cumsum(counts, dim=0)
  expansion = torch.repeat_interleave(expansion, counts)
  offset = torch.arange(0, counts.sum(), device=data.device)
  expansion = expansion - offset - 1
  expanded = torch.repeat_interleave(data, expansion.to(data.device), dim=0)

  expansion_offset = counts.roll(1)
  expansion_offset[0] = 0
  expansion_offset = expansion_offset.cumsum(dim=0)
  expansion_offset = torch.repeat_interleave(expansion_offset, counts)
  expansion_offset = torch.repeat_interleave(expansion_offset, expansion)
  off_start = torch.repeat_interleave(torch.repeat_interleave(counts, counts) - expansion, expansion)
  access = torch.arange(expansion.sum(), device=data.device)
  access = access - torch.repeat_interleave(expansion.roll(1).cumsum(dim=0), expansion) + off_start + expansion_offset

  result = op(expanded, data[access.to(data.device)])
  return result, torch.repeat_interleave(indices, expansion, dim=0) 
Example 3
def __patched_conv_ops(op, x, y, *args, **kwargs):
        x_encoded = CUDALongTensor.__encode_as_fp64(x).data
        y_encoded = CUDALongTensor.__encode_as_fp64(y).data

        repeat_idx = [1] * (x_encoded.dim() - 1)
        x_enc_span = x_encoded.repeat(3, *repeat_idx)
        y_enc_span = torch.repeat_interleave(y_encoded, repeats=3, dim=0)

        bs, c, *img = x.size()
        c_out, c_in, *ks = y.size()

        x_enc_span = x_enc_span.transpose_(0, 1).reshape(bs, 9 * c, *img)
        y_enc_span = y_enc_span.reshape(9 * c_out, c_in, *ks)

        c_z = c_out if op in ["conv1d", "conv2d"] else c_in

        z_encoded = getattr(torch, op)(
            x_enc_span, y_enc_span, *args, **kwargs, groups=9
        )
        z_encoded = z_encoded.reshape(bs, 9, c_z, *z_encoded.size()[2:]).transpose_(
            0, 1
        )

        return CUDALongTensor.__decode_as_int64(z_encoded) 
Example 4
def get_tiled_batch(self, num_tiles: int):
        assert (
            self.has_float_features_only
        ), f"only works for float features now: {self}"
        """
        tiled_feature should be (batch_size * num_tiles, feature_dim)
        forall i in [batch_size],
        tiled_feature[i*num_tiles:(i+1)*num_tiles] should be feat[i]
        """
        feat = self.float_features
        assert (
            len(feat.shape) == 2
        ), f"Need feat shape to be (batch_size, feature_dim), got {feat.shape}."
        batch_size, _ = feat.shape
        # pyre-fixme[16]: `Tensor` has no attribute `repeat_interleave`.
        tiled_feat = feat.repeat_interleave(repeats=num_tiles, dim=0)
        return FeatureData(float_features=tiled_feat) 
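Note that repeat_interleave tiles each row consecutively, which is exactly the contract stated above; Tensor.repeat would interleave the batch instead. A quick comparison on a toy tensor:

import torch

feat = torch.tensor([[1., 1.], [2., 2.]])
feat.repeat_interleave(2, dim=0)  # rows [1, 1, 2, 2]: tiled_feat[i*2:(i+1)*2] == feat[i]
feat.repeat(2, 1)                 # rows [1, 2, 1, 2]: would violate the contract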
Example 5
def get_output(encoder_output, duration_predictor_output, alpha, mel_max_length=None):
        output = list()
        dec_pos = list()

        for i in range(encoder_output.size(0)):
            repeats = duration_predictor_output[i].float() * alpha
            repeats = torch.round(repeats).long()
            output.append(torch.repeat_interleave(encoder_output[i], repeats, dim=0))
            dec_pos.append(torch.from_numpy(np.indices((output[i].shape[0],))[0] + 1))

        output = torch.nn.utils.rnn.pad_sequence(output, batch_first=True)
        dec_pos = torch.nn.utils.rnn.pad_sequence(dec_pos, batch_first=True)

        dec_pos = dec_pos.to(output.device, non_blocking=True)

        if mel_max_length:
            output = output[:, :mel_max_length]
            dec_pos = dec_pos[:, :mel_max_length]

        return output, dec_pos 
Example 6
def test_instance_norm():
    batch = torch.repeat_interleave(torch.full((10, ), 10, dtype=torch.long))

    norm = InstanceNorm(16)
    assert norm.__repr__() == (
        'InstanceNorm(16, eps=1e-05, momentum=0.1, affine=False, '
        'track_running_stats=False)')
    out = norm(torch.randn(100, 16), batch)
    assert out.size() == (100, 16)

    norm = InstanceNorm(16, affine=True, track_running_stats=True)
    out = norm(torch.randn(100, 16), batch)
    assert out.size() == (100, 16)

    # Should behave equally to `BatchNorm` for mini-batches of size 1.
    x = torch.randn(100, 16)
    norm1 = InstanceNorm(16, affine=False, track_running_stats=False)
    norm2 = BatchNorm(16, affine=False, track_running_stats=False)
    assert torch.allclose(norm1(x), norm2(x), atol=1e-6)

    norm1 = InstanceNorm(16, affine=False, track_running_stats=True)
    norm2 = BatchNorm(16, affine=False, track_running_stats=True)
    assert torch.allclose(norm1(x), norm2(x), atol=1e-6)
    assert torch.allclose(norm1.running_mean, norm2.running_mean, atol=1e-6)
    assert torch.allclose(norm1.running_var, norm2.running_var, atol=1e-6)
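    # A second forward pass updates the running statistics; both modules must still agree.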
    assert torch.allclose(norm1(x), norm2(x), atol=1e-6)
    assert torch.allclose(norm1.running_mean, norm2.running_mean, atol=1e-6)
    assert torch.allclose(norm1.running_var, norm2.running_var, atol=1e-6)
    norm1.eval()
    norm2.eval()
    assert torch.allclose(norm1(x), norm2(x), atol=1e-6) 
Example 7
def test_graph_size_norm():
    batch = torch.repeat_interleave(torch.full((10, ), 10, dtype=torch.long))
    norm = GraphSizeNorm()
    out = norm(torch.randn(100, 16), batch)
    assert out.size() == (100, 16) 
Example 8
def forward(self, seq_value_len_list):
        if self.supports_masking:
            uiseq_embed_list, mask = seq_value_len_list  # [B, T, E], [B, 1]
            mask = mask.float()
            user_behavior_length = torch.sum(mask, dim=-1, keepdim=True)
            mask = mask.unsqueeze(2)
        else:
            uiseq_embed_list, user_behavior_length = seq_value_len_list  # [B, T, E], [B, 1]
            mask = self._sequence_mask(user_behavior_length, maxlen=uiseq_embed_list.shape[1],
                                       dtype=torch.float32)  # [B, 1, maxlen]
            mask = torch.transpose(mask, 1, 2)  # [B, maxlen, 1]

        embedding_size = uiseq_embed_list.shape[-1]

        mask = torch.repeat_interleave(mask, embedding_size, dim=2)  # [B, maxlen, E]

        if self.mode == 'max':
            hist = uiseq_embed_list - (1 - mask) * 1e9
            hist = torch.max(hist, dim=1, keepdim=True)[0]
            return hist
        hist = uiseq_embed_list * mask.float()
        hist = torch.sum(hist, dim=1, keepdim=False)

        if self.mode == 'mean':
            hist = torch.div(hist, user_behavior_length.type(torch.float32) + self.eps)

        hist = torch.unsqueeze(hist, dim=1)
        return hist 
Example 9
def repeat(input, repeats, dim):
    # return th.repeat_interleave(input, repeats, dim) # PyTorch 1.1
    if dim < 0:
        dim += input.dim()
    return th.flatten(th.stack([input] * repeats, dim=dim+1), dim, dim+1) 
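Stacking `repeats` copies along a new axis right after `dim` and then flattening the two axes reproduces torch.repeat_interleave with a scalar repeat count; a quick equivalence check:

import torch as th

x = th.arange(6).reshape(2, 3)
assert th.equal(repeat(x, 2, 1), th.repeat_interleave(x, 2, dim=1))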
Example 10
def sample(self, data):
    support, values = data
    mean, logvar = self.condition(support)
    distribution = Normal(mean, torch.exp(0.5 * logvar))
    latent_sample = distribution.rsample()
    latent_sample = torch.repeat_interleave(latent_sample, self.size, dim=0)
    local_samples = torch.randn(support.size(0) * self.size, 16)
    sample = torch.cat((latent_sample, local_samples), dim=1)
    return (support, sample), (mean, logvar) 
Example 11
def forward(self, data):
    support, values = data
    mean, logvar = self.encoder(support)
    distribution = Normal(mean, torch.exp(0.5 * logvar))
    latent_sample = distribution.rsample()
    latent_sample = torch.repeat_interleave(latent_sample, self.size, dim=0)
    combined = torch.cat((values.view(-1, 28 * 28), latent_sample), dim=1)
    return self.verdict(combined) 
Example 12
def forward(self, image, condition):
    image = image.view(-1, 3, 64, 64)
    out = self.input_process(self.input(image))
    mean, logvar = self.condition(condition)
    #distribution = Normal(mean, torch.exp(0.5 * logvar))
    sample = mean + torch.randn_like(mean) * torch.exp(0.5 * logvar)#distribution.rsample()
    cond = self.postprocess(sample)
    cond = torch.repeat_interleave(cond, 5, dim=0)
    result = self.combine(torch.cat((out, cond), dim=1))
    return result, (mean, logvar) 
Example 13
def forward(self, image, condition):
    image = image.view(-1, 28 * 28)
    out = self.input_process(self.input(image))
    mean, logvar = self.condition(condition)
    #distribution = Normal(mean, torch.exp(0.5 * logvar))
    sample = mean + torch.randn_like(mean) * torch.exp(0.5 * logvar)#distribution.rsample()
    cond = self.postprocess(sample)
    cond = torch.repeat_interleave(cond, 5, dim=0)
    result = self.combine(torch.cat((out, cond), dim=1))
    return result, (mean, logvar) 
Example 14
def repack(data, indices, target_indices):
  out = torch.zeros(
    target_indices.size(0), *data.shape[1:],
    dtype=data.dtype, device=data.device
  )
  unique, lengths = indices.unique(return_counts=True)
  unique, target_lengths = target_indices.unique(return_counts=True)
  offset = target_lengths - lengths
  offset = offset.roll(1, 0)
  offset[0] = 0
  offset = torch.repeat_interleave(offset.cumsum(dim=0), lengths, dim=0)
  index = offset + torch.arange(len(indices)).to(data.device)

  out[index] = data
  return out, target_indices
Example 15
def pairwise(op, data, indices, padding_value=0):
  padded, _, _, counts = pad(data, indices, value=padding_value)
  padded = padded.transpose(1, 2)
  reference = padded.unsqueeze(-1)
  padded = padded.unsqueeze(-2)
  op_result = op(padded, reference)

  # batch indices into pairwise tensor:
  batch_indices = torch.arange(counts.size(0))
  batch_indices = torch.repeat_interleave(batch_indices, counts ** 2)

  # first dimension indices:
  first_offset = counts.roll(1)
  first_offset[0] = 0
  first_offset = torch.cumsum(first_offset, dim=0)
  first_offset = torch.repeat_interleave(first_offset, counts)
  first_indices = torch.arange(counts.sum()) - first_offset
  first_indices = torch.repeat_interleave(
    first_indices,
    torch.repeat_interleave(counts, counts)
  )

  # second dimension indices:
  second_offset = torch.repeat_interleave(counts, counts).roll(1)
  second_offset[0] = 0
  second_offset = torch.cumsum(second_offset, dim=0)
  second_offset = torch.repeat_interleave(second_offset, torch.repeat_interleave(counts, counts))
  second_indices = torch.arange((counts ** 2).sum()) - second_offset

  # extract tensor from padded result using indices:
  result = op_result[batch_indices, first_indices, second_indices]

  # access: cumsum(counts ** 2)[idx] + counts[idx] * idy + idz
  access_batch = (counts ** 2).roll(1)
  access_batch[0] = 0
  access_batch = torch.cumsum(access_batch, dim=0)
  access_first = counts

  access = (access_batch, access_first)

  return result, batch_indices, first_indices, second_indices, access 
Example 16
def __init__(self, indices):
    unique, counts = indices.unique(return_counts=True)
    structure_indices = torch.arange(counts.sum(), device=indices.device)
    structure_indices = torch.repeat_interleave(
      structure_indices, torch.repeat_interleave(
        counts, counts
      )
    )

    # prepare offsets of connections:
    repeated_counts = torch.repeat_interleave(counts, counts)
    other_counts = repeated_counts.roll(1)
    other_counts[0] = 0
    other_counts = other_counts.cumsum(dim=0)
    offset_factors = other_counts
    offset = torch.repeat_interleave(offset_factors, repeated_counts)
    base = counts.roll(1)
    base[0] = 0
    base = base.cumsum(dim=0)
    base = torch.repeat_interleave(torch.repeat_interleave(base, counts), repeated_counts)

    structure_connections = torch.arange((counts * counts).sum(), device=indices.device)
    structure_connections = structure_connections - offset + base

    super(FullyConnectedScatter, self).__init__(
      0, 0,
      structure_indices,
      structure_connections
    ) 
Example 17
def __init__(self, batch, width):
    # prepare offsets of connections:
    offset_factors = torch.arange(width * batch) * width
    offset = torch.repeat_interleave(offset_factors, width)

    structure_connections = torch.arange(width * width * batch) - offset
    structure_connections = structure_connections.reshape(batch * width, width)
    super(FullyConnectedConstant, self).__init__(
      0, 0,
      structure_connections
    ) 
Example 18
def matmul(x, y, *args, **kwargs):
        # Prepend 1 to the dimension of x or y if it is 1-dimensional
        remove_x, remove_y = False, False
        if x.dim() == 1:
            x = x.view(1, x.shape[0])
            remove_x = True
        if y.dim() == 1:
            y = y.view(y.shape[0], 1)
            remove_y = True

        x_encoded = CUDALongTensor.__encode_as_fp64(x).data
        y_encoded = CUDALongTensor.__encode_as_fp64(y).data

        # Span x and y for cross multiplication
        repeat_idx = [1] * (x_encoded.dim() - 1)
        x_enc_span = x_encoded.repeat(3, *repeat_idx)
        y_enc_span = torch.repeat_interleave(y_encoded, repeats=3, dim=0)

        # Broadcasting
        for _ in range(abs(x_enc_span.ndim - y_enc_span.ndim)):
            if x_enc_span.ndim > y_enc_span.ndim:
                y_enc_span.unsqueeze_(1)
            else:
                x_enc_span.unsqueeze_(1)

        z_encoded = torch.matmul(x_enc_span, y_enc_span, *args, **kwargs)

        if remove_x:
            z_encoded.squeeze_(-2)
        if remove_y:
            z_encoded.squeeze_(-1)

        return CUDALongTensor.__decode_as_int64(z_encoded) 
Example 19
def extract_video(self, img):
        buffer = self.transform(img)
        buffer = torch.repeat_interleave(torch.unsqueeze(buffer, 1), self.clip_len, 1)
        buffer = torch.repeat_interleave(torch.unsqueeze(buffer, 0), self.n_clips, 0)
        return buffer 
Example 20
def torch_batch_ideal_err(batch_sorted_labels, k=10, gpu=False, point=True):
	assert batch_sorted_labels.size(1) > k

	batch_max, _ = torch.max(batch_sorted_labels, dim=1, keepdim=True)  # torch.max with dim returns (values, indices)

	batch_labels = batch_sorted_labels[:, 0:k]
	batch_satis_pros = (torch.pow(2.0, batch_labels) - 1.0) / torch.pow(2.0, batch_max)

	batch_unsatis_pros = torch.ones_like(batch_satis_pros) - batch_satis_pros
	batch_cum_unsatis_pros = torch.cumprod(batch_unsatis_pros, dim=1)

	positions = torch.arange(k) + 1.0
	positions = positions.view(1, -1)
	positions = torch.repeat_interleave(positions, batch_sorted_labels.size(0), dim=0)

	batch_expt_ranks = 1.0 / positions

	cascad_unsatis_pros = torch.ones_like(positions)  # rank 1 is always reached; avoids aliasing positions
	cascad_unsatis_pros[:, 1:k] = batch_cum_unsatis_pros[:, 0:k-1]

	expt_satis_ranks = batch_expt_ranks * batch_satis_pros * cascad_unsatis_pros  # w.r.t. all rank positions

	if point:
		batch_errs = torch.sum(expt_satis_ranks, dim=1)
		return batch_errs
	else:
		batch_err_at_ks = torch.cumsum(expt_satis_ranks, dim=1)
		return batch_err_at_ks 
Example 21
def __init__(
        self, data_encoder=None, data_transform=None, lbl_transform=None, repeats=1
    ):
        self.data_encoder = data_encoder
        self.data_transform = data_transform
        self.lbl_transform = lbl_transform
        self.data = torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=torch.float)
        self.data = torch.repeat_interleave(self.data, int(repeats), dim=1) 
Example 22
def triple_by_molecule(atom_index12: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """Input: indices for pairs of atoms that are close to each other.
    Each pair appears only once, i.e. only one of the pairs (1, 2) and
    (2, 1) exists.

    Output: indices for all central atoms and their pairs of neighbors.
    For example, if the input has pairs (0, 1), (0, 2), (0, 3), (0, 4),
    (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), then the output
    would have central atoms 0, 1, 2, 3, 4, and for central atom 0 its
    pairs of neighbors are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4).
    """
    # convert representation from pair to central-others
    ai1 = atom_index12.view(-1)
    sorted_ai1, rev_indices = ai1.sort()

    # sort and compute unique key
    uniqued_central_atom_index, counts = torch.unique_consecutive(sorted_ai1, return_inverse=False, return_counts=True)

    # compute central_atom_index
    pair_sizes = counts * (counts - 1) // 2
    pair_indices = torch.repeat_interleave(pair_sizes)
    central_atom_index = uniqued_central_atom_index.index_select(0, pair_indices)

    # do local combinations within unique key, assuming sorted
    m = counts.max().item() if counts.numel() > 0 else 0
    n = pair_sizes.shape[0]
    intra_pair_indices = torch.tril_indices(m, m, -1, device=ai1.device).unsqueeze(1).expand(-1, n, -1)
    mask = (torch.arange(intra_pair_indices.shape[2], device=ai1.device) < pair_sizes.unsqueeze(1)).flatten()
    sorted_local_index12 = intra_pair_indices.flatten(1, 2)[:, mask]
    sorted_local_index12 += cumsum_from_zero(counts).index_select(0, pair_indices)

    # unsort result from last part
    local_index12 = rev_indices[sorted_local_index12]

    # compute mapping between representation of central-other to pair
    n = atom_index12.shape[1]
    sign12 = ((local_index12 < n).to(torch.int8) * 2) - 1
    return central_atom_index, local_index12 % n, sign12 
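cumsum_from_zero is a helper defined elsewhere in the source (torchani); judging by its use here it computes an exclusive cumulative sum, the same roll/zero/cumsum idiom seen in earlier examples. A minimal sketch under that assumption:

def cumsum_from_zero(input_: torch.Tensor) -> torch.Tensor:
    # Exclusive cumulative sum: [a, b, c] -> [0, a, a + b]
    cumsum = torch.zeros_like(input_)
    cumsum[1:] = torch.cumsum(input_[:-1], dim=0)
    return cumsum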
Example 23
def sample_weights(self) -> Tensor:
        if self._sample_weights is None:
            samples = self.queries[:, 2]
            self._sample_weights = torch.repeat_interleave(
                samples.to(dtype=torch.float).reciprocal(), samples.to(dtype=torch.long)
            )
        return self._sample_weights 
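Each query contributes the weight 1/n repeated n times, so the weights of any single query's samples sum to one. A quick check of the pattern, with hypothetical counts:

import torch

samples = torch.tensor([2, 3])  # per-query sample counts
torch.repeat_interleave(samples.float().reciprocal(), samples)
# tensor([0.5000, 0.5000, 0.3333, 0.3333, 0.3333]); each group sums to 1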
Example 24
def select_slate(self, action: torch.Tensor):
        row_idx = torch.repeat_interleave(
            torch.arange(action.shape[0]).unsqueeze(1), action.shape[1], dim=1
        )
        mask = self.mask[row_idx, action]
        # Make sure the indices are in the right range
        assert mask.to(torch.bool).all()
        float_features = self.float_features[row_idx, action]
        value = self.value[row_idx, action]
        return DocList(float_features, mask, value) 
Example 25
def _generate_text_rep(text, dur):
        text_rep = []
        for t, d in zip(text, dur):
            text_rep.append(torch.repeat_interleave(t, d))

        text_rep = Ops.merge(text_rep)

        return text_rep 
Example 26
def forward(self, inputs):
        feats = [self.head(inputs, self.knn(inputs[:, 0:3]))]
        for i in range(self.n_blocks-1):
            feats.append(self.backbone[i](feats[-1]))
        feats = torch.cat(feats, 1)
        fusion = self.fusion_block(feats)

        x1 = F.adaptive_max_pool2d(fusion, 1)
        x2 = F.adaptive_avg_pool2d(fusion, 1)
        feat_global_pool = torch.cat((x1, x2), dim=1)
        feat_global_pool = torch.repeat_interleave(feat_global_pool, repeats=fusion.shape[2], dim=2)
        cat_pooled = torch.cat((feat_global_pool, fusion), dim=1)
        out = self.prediction(cat_pooled).squeeze(-1)
        return F.log_softmax(out, dim=1) 
Example 27
def forward(self, inputs):
        feats = [self.head(inputs, self.knn(inputs[:, 0:3]))]
        for i in range(self.n_blocks-1):
            feats.append(self.backbone[i](feats[-1]))
        feats = torch.cat(feats, dim=1)

        fusion = torch.max_pool2d(self.fusion_block(feats), kernel_size=[feats.shape[2], feats.shape[3]])
        fusion = torch.repeat_interleave(fusion, repeats=feats.shape[2], dim=2)
        return self.prediction(torch.cat((fusion, feats), dim=1)).squeeze(-1) 
Example 28
def forward(self, data):
        corr, color, batch = data.pos, data.x, data.batch
        x = torch.cat((corr, color), dim=1)
        feats = [self.head(x, self.knn(x[:, 0:3], batch))]
        for i in range(self.n_blocks-1):
            feats.append(self.backbone[i](feats[-1], batch)[0])
        feats = torch.cat(feats, dim=1)

        fusion = tg.utils.scatter_('max', self.fusion_block(feats), batch)
        fusion = torch.repeat_interleave(fusion, repeats=feats.shape[0]//fusion.shape[0], dim=0)
        return self.prediction(torch.cat((fusion, feats), dim=1)) 
Example 29
def train(self, replay_buffer, iterations, batch_size=100):

		for it in range(iterations):
			# Sample replay buffer / batch
			state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)

			# Variational Auto-Encoder Training
			recon, mean, std = self.vae(state, action)
			recon_loss = F.mse_loss(recon, action)
			KL_loss	= -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
			vae_loss = recon_loss + 0.5 * KL_loss

			self.vae_optimizer.zero_grad()
			vae_loss.backward()
			self.vae_optimizer.step()


			# Critic Training
			with torch.no_grad():
				# Duplicate next state 10 times
				next_state = torch.repeat_interleave(next_state, 10, 0)

				# Compute value of perturbed actions sampled from the VAE
				target_Q1, target_Q2 = self.critic_target(next_state, self.actor_target(next_state, self.vae.decode(next_state)))

				# Soft Clipped Double Q-learning 
				target_Q = self.lmbda * torch.min(target_Q1, target_Q2) + (1. - self.lmbda) * torch.max(target_Q1, target_Q2)
				# Take max over each action sampled from the VAE
				target_Q = target_Q.reshape(batch_size, -1).max(1)[0].reshape(-1, 1)

				target_Q = reward + not_done * self.discount * target_Q

			current_Q1, current_Q2 = self.critic(state, action)
			critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

			self.critic_optimizer.zero_grad()
			critic_loss.backward()
			self.critic_optimizer.step()


			# Perturbation Model / Action Training
			sampled_actions = self.vae.decode(state)
			perturbed_actions = self.actor(state, sampled_actions)

			# Update through DPG
			actor_loss = -self.critic.q1(state, perturbed_actions).mean()

			self.actor_optimizer.zero_grad()
			actor_loss.backward()
			self.actor_optimizer.step()


			# Update Target Networks 
			for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
				target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

			for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
				target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 
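Because repeat_interleave duplicates each next state in consecutive blocks, the later reshape(batch_size, -1).max(1) correctly groups the 10 candidate Q-values by state. A minimal shape check of that trick:

import torch

B, d = 4, 3
next_state = torch.randn(B, d)
tiled = torch.repeat_interleave(next_state, 10, 0)  # (40, 3): rows 0-9 all come from state 0
q = torch.randn(40, 1)                              # stand-in for the critic's output
best = q.reshape(B, -1).max(1)[0].reshape(-1, 1)    # (4, 1): max over the 10 candidates per state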
Example 30
def __call__(self, img):
        """
        Args:
            img: shape must be (batch_size, C, spatial_dim1[, spatial_dim2, ...]).

        Returns:
            A PyTorch Tensor with shape (batch_size, C, spatial_dim1[, spatial_dim2, ...]).
        """
        channel_dim = 1
        if img.shape[channel_dim] == 1:

            img = torch.squeeze(img, dim=channel_dim)

            if self.independent:
                for i in self.applied_labels:
                    foreground = (img == i).type(torch.uint8)
                    mask = get_largest_connected_component_mask(foreground, self.connectivity)
                    img[foreground != mask] = 0
            else:
                foreground = torch.zeros_like(img)
                for i in self.applied_labels:
                    foreground += (img == i).type(torch.uint8)
                mask = get_largest_connected_component_mask(foreground, self.connectivity)
                img[foreground != mask] = 0
            output = torch.unsqueeze(img, dim=channel_dim)
        else:
            # one-hot data is assumed to have binary value in each channel
            if self.independent:
                for i in self.applied_labels:
                    foreground = img[:, i, ...].type(torch.uint8)
                    mask = get_largest_connected_component_mask(foreground, self.connectivity)
                    img[:, i, ...][foreground != mask] = 0
            else:
                applied_img = img[:, self.applied_labels, ...].type(torch.uint8)
                foreground = torch.any(applied_img, dim=channel_dim)
                mask = get_largest_connected_component_mask(foreground, self.connectivity)
                background_mask = torch.unsqueeze(foreground != mask, dim=channel_dim)
                background_mask = torch.repeat_interleave(background_mask, len(self.applied_labels), dim=channel_dim)
                applied_img[background_mask] = 0
                img[:, self.applied_labels, ...] = applied_img.type(img.type())
            output = img

        return output