Python source examples: torch.conv2d()

Example 1
def __patched_conv_ops(op, x, y, *args, **kwargs):
    # Snippet from inside the CUDALongTensor class (the double-underscore
    # names are resolved by Python's private-name mangling).
    # Each int64 tensor is encoded as three fp64 "blocks" so the convolution
    # can run on CUDA, which lacks int64 convolution kernels.
    x_encoded = CUDALongTensor.__encode_as_fp64(x).data
    y_encoded = CUDALongTensor.__encode_as_fp64(y).data

    # Tile x and y so that all 3 x 3 = 9 block pairs get convolved.
    repeat_idx = [1] * (x_encoded.dim() - 1)
    x_enc_span = x_encoded.repeat(3, *repeat_idx)
    y_enc_span = torch.repeat_interleave(y_encoded, repeats=3, dim=0)

    bs, c, *img = x.size()
    c_out, c_in, *ks = y.size()

    x_enc_span = x_enc_span.transpose_(0, 1).reshape(bs, 9 * c, *img)
    y_enc_span = y_enc_span.reshape(9 * c_out, c_in, *ks)

    # Transposed convolutions swap the channel roles of input and weight.
    c_z = c_out if op in ["conv1d", "conv2d"] else c_in

    # A single grouped convolution computes all 9 block-pair convolutions.
    z_encoded = getattr(torch, op)(
        x_enc_span, y_enc_span, *args, **kwargs, groups=9
    )
    z_encoded = z_encoded.reshape(bs, 9, c_z, *z_encoded.size()[2:]).transpose_(
        0, 1
    )

    # Recombine the 9 partial results into int64 values.
    return CUDALongTensor.__decode_as_int64(z_encoded)
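The groups=9 trick mirrors long-integer multiplication by blocks: each operand is split into three blocks, and the nine pairwise block products are recombined afterwards. A minimal scalar sketch of that idea (the 16-bit block width here is illustrative, not CrypTen's actual parameterization):

# Hypothetical 3-block decomposition; B = 16 is an assumption for this sketch.
B = 16
mask = (1 << B) - 1
a, b = 123456789, 987654321
a_blk = [(a >> (B * i)) & mask for i in range(3)]
b_blk = [(b >> (B * i)) & mask for i in range(3)]
# Shifting the 9 pairwise block products back into place recovers a * b,
# which is what the grouped convolution above computes elementwise.
prod = sum((a_blk[i] * b_blk[j]) << (B * (i + j))
           for i in range(3) for j in range(3))
assert prod == a * b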
Example 2
def cross_correlation_loss(I, J, n):
    # Local (windowed) normalized cross-correlation between two image
    # batches. Assumes NHWC inputs, numpy imported as np, and a
    # module-level use_gpu flag.
    I = I.permute(0, 3, 1, 2)  # NHWC -> NCHW for conv2d
    J = J.permute(0, 3, 1, 2)
    batch_size, channels, xdim, ydim = I.shape
    I2 = torch.mul(I, I)
    J2 = torch.mul(J, J)
    IJ = torch.mul(I, J)
    # An all-ones kernel turns conv2d into a windowed sum.
    sum_filter = torch.ones((1, channels, n, n))
    if use_gpu:
        sum_filter = sum_filter.cuda()
    # Note: the hardcoded padding=1 matches the window only when n == 3.
    I_sum = torch.conv2d(I, sum_filter, padding=1, stride=(1, 1))
    J_sum = torch.conv2d(J, sum_filter, padding=1, stride=(1, 1))
    I2_sum = torch.conv2d(I2, sum_filter, padding=1, stride=(1, 1))
    J2_sum = torch.conv2d(J2, sum_filter, padding=1, stride=(1, 1))
    IJ_sum = torch.conv2d(IJ, sum_filter, padding=1, stride=(1, 1))
    win_size = n ** 2
    u_I = I_sum / win_size
    u_J = J_sum / win_size
    # Windowed cross-covariance and variances.
    cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
    I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
    J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size
    cc = cross * cross / (I_var * J_var + np.finfo(float).eps)
    return torch.mean(cc)
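A hypothetical usage sketch (assumes a module-level use_gpu flag and numpy imported as np; n=3 so that the hardcoded padding=1 matches the window):

import numpy as np
import torch

use_gpu = torch.cuda.is_available()
device = 'cuda' if use_gpu else 'cpu'

I = torch.rand(2, 64, 64, 1, device=device)  # moving image batch, NHWC
J = torch.rand(2, 64, 64, 1, device=device)  # fixed image batch, NHWC
print(cross_correlation_loss(I, J, n=3).item())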
Example 3
def region_cohesion(classes):
  """Computes a measure of region cohesion, that is, the ratio between
  class edges and surface.

  Expects `classes` as a float tensor of shape (N, 1, H, W).
  """
  # conv2d needs 4D float weights: (out_channels, in_channels, kH, kW).
  kx = torch.tensor([[[[1., -2., 1.]]]])      # 1x3 second-difference kernel
  ky = torch.tensor([[[[1.], [-2.], [1.]]]])  # 3x1 second-difference kernel
  # Padding keeps dx and dy the same shape so they can be combined.
  dx = torch.conv2d(classes, kx, padding=(0, 1))
  dy = torch.conv2d(classes, ky, padding=(1, 0))
  mag = torch.sqrt(dx ** 2 + dy ** 2)
  total_mag = mag.sum(dim=-2)
  total_class = classes.sum(dim=-2)
  return total_mag / total_class
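A minimal usage sketch, assuming `classes` holds soft per-pixel class scores in a single channel:

import torch

classes = torch.rand(4, 1, 32, 32)  # hypothetical soft class map, (N, 1, H, W)
cohesion = region_cohesion(classes)
print(cohesion.shape)               # torch.Size([4, 1, 32])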
Example 4
def conv2d(input, weight, *args, **kwargs):
    # Thin static wrapper: dispatches to the patched grouped-convolution
    # implementation shown in Example 1.
    return CUDALongTensor.__patched_conv_ops(
        "conv2d", input, weight, *args, **kwargs
    )
Example 5
def conv2d(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None, stride=1, padding=0, dilation=1, groups=1, mode=None):
    """Standard conv2d. Returns the input if weight is None."""

    if weight is None:
        return input

    ind = None
    if mode is not None:
        if padding != 0:
            raise ValueError('Cannot specify both padding and mode.')
        if mode == 'same':
            padding = (weight.shape[2]//2, weight.shape[3]//2)
            if weight.shape[2] % 2 == 0 or weight.shape[3] % 2 == 0:
                ind = (slice(-1) if weight.shape[2] % 2 == 0 else slice(None),
                       slice(-1) if weight.shape[3] % 2 == 0 else slice(None))
        elif mode == 'valid':
            padding = (0, 0)
        elif mode == 'full':
            padding = (weight.shape[2]-1, weight.shape[3]-1)
        else:
            raise ValueError('Unknown mode for padding.')

    out = F.conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
    if ind is None:
        return out
    # An even-sized kernel makes 'same' output one element too large;
    # trim the trailing row/column selected by `ind`.
    return out[:, :, ind[0], ind[1]]
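A usage sketch of the three padding modes (shape comments assume this example's definitions; note mode='same' preserves spatial size even for even kernels thanks to the trailing trim):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 32, 32)
w = torch.randn(8, 3, 4, 4)              # even-sized 4x4 kernel
print(conv2d(x, w, mode='same').shape)   # torch.Size([1, 8, 32, 32])
print(conv2d(x, w, mode='valid').shape)  # torch.Size([1, 8, 29, 29])
print(conv2d(x, w, mode='full').shape)   # torch.Size([1, 8, 35, 35])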
Example 6
def conv1x1(input: torch.Tensor, weight: torch.Tensor):
    """Do a convolution with 1x1 kernel weights, i.e. a per-pixel linear
    map over channels. Returns the input if weight is None."""

    if weight is None:
        return input

    return torch.conv2d(input, weight)
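Since a 1x1 convolution is a per-pixel linear map over channels, it agrees with an explicit einsum; a quick check:

import torch

x = torch.randn(2, 16, 8, 8)
w = torch.randn(32, 16, 1, 1)
y1 = conv1x1(x, w)
y2 = torch.einsum('oi,bihw->bohw', w[:, :, 0, 0], x)
print(torch.allclose(y1, y2, atol=1e-5))  # True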
Example 7
def _downsample(x, f, direction, shift):
  """Downsample by a factor of 2 using reflecting boundary conditions.

  This function convolves `x` with filter `f` with reflecting boundary
  conditions, and then decimates by a factor of 2. This is usually done to
  downsample `x`, assuming `f` is some smoothing filter, but will also be used
  for wavelet transformations in which `f` is not a smoothing filter.

  Args:
    x: The input tensor (numpy or TF), of size (num_channels, width, height).
    f: The input filter, which must be an odd-length 1D numpy array.
    direction: The spatial direction in [0, 1] along which `x` will be convolved
      with `f` and then decimated. Because `x` has a batch/channels dimension,
      `direction` == 0 corresponds to downsampling along axis 1 in `x`, and
      `direction` == 1 corresponds to downsampling along axis 2 in `x`.
    shift: A shift amount in [0, 1] by which `x` will be shifted along the axis
      specified by `direction` before filtering.

  Returns:
    `x` convolved with `f` along the spatial dimension `direction` with
    reflection boundary conditions with an offset of `shift`.
  """
  _check_resample_inputs(x, f, direction, shift)
  # The above and below padding amounts are different so as to support odd
  # and even length filters. An odd-length filter of length n causes a padding
  # of (n-1)/2 on both sides, while an even-length filter will pad by one less
  # below than above.
  x = torch.as_tensor(x)
  f = torch.tensor(f).to(x)
  x_padded = pad_reflecting(x, (len(f) - 1) // 2, len(f) // 2, direction + 1)
  if direction == 0:
    x_padded = x_padded[:, shift:, :]
    f_ex = torch.as_tensor(f)[:, np.newaxis]
    stride = [2, 1]
  elif direction == 1:
    x_padded = x_padded[:, :, shift:]
    f_ex = torch.as_tensor(f)[np.newaxis, :]
    stride = [1, 2]
  y = torch.conv2d(
      x_padded[:, np.newaxis, :, :],
      f_ex[np.newaxis, np.newaxis].type(x.dtype),
      stride=stride)[:, 0, :, :]
  return y
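A self-contained sketch of the same decimation step, substituting torch's built-in reflection padding for the repo's pad_reflecting (an assumption; the two only agree for pads smaller than the input):

import torch
import torch.nn.functional as F

x = torch.rand(3, 16, 16)            # (num_channels, width, height)
f = torch.tensor([0.25, 0.5, 0.25])  # odd-length smoothing filter
# Pad axis 1 by (len(f) - 1) // 2 = 1 on each side, then stride-2 convolve.
xp = F.pad(x[:, None], (0, 0, 1, 1), mode='reflect')
y = torch.conv2d(xp, f.reshape(1, 1, -1, 1), stride=[2, 1])[:, 0]
print(y.shape)                       # torch.Size([3, 8, 16])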