Python源码示例:torch.ifft()

示例1
def ifft2(data):
    """
    Centered 2-D inverse FFT over the spatial axes of a real-pair complex tensor.

    The input is ifftshift-ed over dims (-3, -2), inverse transformed over
    those two dims with ``torch.ifft(..., 2, normalized=True)``, and then
    fftshift-ed back over the same dims.

    Args:
        data (torch.Tensor): At least 3-D. Dims -3 and -2 are the spatial
            (transformed) axes. Dim -1 must have size 2 (real, imaginary).
            Any leading dims are treated as batch dims.

    Returns:
        torch.Tensor: The centered inverse FFT of ``data``.
    """
    assert data.size(-1) == 2
    spatial_dims = (-3, -2)
    centered = ifftshift(data, dim=spatial_dims)
    result = torch.ifft(centered, 2, normalized=True)
    return fftshift(result, dim=spatial_dims)
示例2
def ifft2(data):
    """
    Centered 2-D inverse FFT over the spatial axes of a real-pair complex tensor.

    The input is ifftshift-ed over dims (-3, -2), inverse transformed over
    those two dims with ``torch.ifft(..., 2, normalized=True)``, and then
    fftshift-ed back over the same dims.

    Args:
        data (torch.Tensor): At least 3-D. Dims -3 and -2 are the spatial
            (transformed) axes. Dim -1 must have size 2 (real, imaginary).
            Any leading dims are treated as batch dims.

    Returns:
        torch.Tensor: The centered inverse FFT of ``data``.
    """
    assert data.size(-1) == 2
    spatial_dims = (-3, -2)
    centered = ifftshift(data, dim=spatial_dims)
    result = torch.ifft(centered, 2, normalized=True)
    return fftshift(result, dim=spatial_dims)
示例3
def forward(self, x):
    """Pool each image's feature map into a single ``self.output_dim`` vector.

    Every spatial feature vector is projected with ``self.sparseM[0]`` and
    ``self.sparseM[1]``. The FFTs of the two projections are multiplied
    (a complex product, i.e. a circular convolution of the two sketches) and
    transformed back. The real part is then summed over all h*w positions of
    each image. The result is passed through ``self._signed_sqrt`` and then
    ``self._l2norm``.

    Args:
        x: 4-D tensor unpacked as (batch, dim, h, w).

    Returns:
        ``self._l2norm`` applied to a (batch, self.output_dim) tensor.
    """
    images_per_chunk = 1  # images sent through the FFTs together
    batch_size, dim, h, w = x.shape
    rows_per_image = h * w
    # (batch, h, w, dim) flattened to one row per spatial location.
    x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, dim)
    y = torch.ones(batch_size, self.output_dim, device=x.device)

    # Loop-invariant: move the projection matrices to x's device only once.
    proj1 = self.sparseM[0].to(x.device)
    proj2 = self.sparseM[1].to(x.device)

    def _sketch_fft(rows, proj):
        # Project, attach a zero imaginary part, and take the 1-D FFT along the
        # sketch dimension. zeros_like keeps the sketch's dtype/device; a plain
        # torch.zeros(size) would always use the default dtype.
        sketch = rows.mm(proj)
        return torch.fft(torch.stack((sketch, torch.zeros_like(sketch)), dim=2), 1)

    # Step by images_per_chunk so that a final partial chunk is processed too,
    # and bound the output slice by batch_size.
    for start in range(0, batch_size, images_per_chunk):
        stop = min(batch_size, start + images_per_chunk)
        rows = x_flat[start * rows_per_image:stop * rows_per_image]

        f1 = _sketch_fft(rows, proj1)
        f2 = _sketch_fft(rows, proj2)

        # Complex product f1 * f2 in real-pair layout.
        re = f1[:, :, 0] * f2[:, :, 0] - f1[:, :, 1] * f2[:, :, 1]
        im = f1[:, :, 0] * f2[:, :, 1] + f1[:, :, 1] * f2[:, :, 0]

        # Real part of the inverse FFT, summed over each image's h*w positions.
        conv = torch.ifft(torch.stack((re, im), dim=2), 1)[:, :, 0]
        y[start:stop] = conv.view(stop - start, h, w, self.output_dim).sum(dim=1).sum(dim=1)

    y = self._signed_sqrt(y)
    y = self._l2norm(y)
    return y
示例4
def ifft(t):
    """Two-dimensional inverse complex FFT using the legacy ``torch.ifft`` API.

    ``t`` is expected in real-pair layout (last dimension of size 2). The two
    dimensions before it are transformed with torch.ifft's default
    (non-``normalized``) scaling.
    """
    return torch.ifft(t, signal_ndim=2)
示例5
def ifft(t):
    """Two-dimensional inverse complex FFT using the legacy ``torch.ifft`` API.

    ``t`` is expected in real-pair layout (last dimension of size 2). The two
    dimensions before it are transformed with torch.ifft's default
    (non-``normalized``) scaling.
    """
    return torch.ifft(t, signal_ndim=2)
示例6
def ifft(t):
    """Complex-to-complex 2-D inverse discrete Fourier transform.

    ``t`` is expected in real-pair layout (last dimension of size 2). The two
    dimensions before it are transformed by ``torch.ifft`` without
    ``normalized`` scaling.
    """
    return torch.ifft(t, signal_ndim=2)
示例7
def pad_irfft3(F):
    """
    Padded batch inverse real FFT over three axes.

    Inverse complex FFTs are applied along res0 and then res1 (each axis is
    moved to the signal axis -2 and back). An inverse real FFT along the
    half-spectrum axis follows.

    :param F: tensor of shape [..., res0, res1, res2/2+1, 2]; the last axis
        holds (real, imaginary).
    :return: real tensor of shape [..., res0, res1, res] with res = F.shape[-3].
        NOTE(review): res is read from the res1 axis, so it equals res2 only
        when res1 == res2 (cubic grids). Confirm this with callers.
    """
    res = F.shape[-3]
    partial = F
    for axis in (-4, -3):
        partial = torch.ifft(partial.transpose(axis, -2), signal_ndim=1).transpose(-2, axis)
    return torch.irfft(partial, signal_ndim=1, signal_sizes=[res])  # [..., res0, res1, res]
示例8
def pad_ifft2(F):
    """
    Batched 2-D inverse complex FFT, computed as two 1-D passes.

    :param F: complex tensor in real-pair layout, shape [..., res0, res1, 2]
        (the last axis holds real/imaginary parts). It is transformed along
        res0 and res1.
    :return: tensor with the same shape as F.
    """
    # 1-D inverse FFT along res0: move that axis to the signal axis (-2) and back.
    f0 = torch.ifft(F.transpose(-3, -2), signal_ndim=1).transpose(-2, -3)
    # 1-D inverse FFT along res1, which is already the signal axis.
    f1 = torch.ifft(f0, signal_ndim=1)
    # The original returned the undefined name `f2`, which raised NameError.
    return f1
示例9
def so3_ifft(x, for_grad=False, b_out=None):
    '''
    Synthesize a complex signal on a (2 b_out)^3 grid from spectral coefficients
    indexed by (l, m, n). The function name suggests this is the inverse
    SO(3) FFT.

    In the non-CUDA path, each degree-l (2l+1) x (2l+1) coefficient block is
    multiplied elementwise by the matching Wigner slice, once per beta sample.
    The products are accumulated into an FFT-index-ordered grid, and a 2-D
    inverse FFT over the last two grid axes is applied at the end.

    :param x: [l * m * n, ..., complex]. The first axis has
        nspec = b_in * (4 * b_in ** 2 - 1) // 3 entries; b_in is recovered from
        nspec and the formula is asserted. The last axis has size 2
        (real, imaginary). The middle axes are batch axes.
    :param for_grad: passed to ``_setup_wigner`` as ``weighted``.
    :param b_out: output bandwidth; defaults to b_in.
    :return: [..., beta, alpha, gamma, complex]: the batch axes of x followed
        by (2 b_out, 2 b_out, 2 b_out, 2).
    '''
    assert x.size(-1) == 2  # last axis is (real, imag)
    nspec = x.size(0)
    # nspec = b_in (4 b_in^2 - 1) / 3 ~ (4/3) b_in^3: invert approximately, then check exactly
    b_in = round((3 / 4 * nspec) ** (1 / 3))
    assert nspec == b_in * (4 * b_in ** 2 - 1) // 3
    if b_out is None:
        b_out = b_in
    batch_size = x.size()[1:-1]  # batch axes, restored by the final view

    x = x.view(nspec, -1, 2)  # [l * m * n, batch, complex] (nspec, nbatch, 2)

    '''
    :param x: [l * m * n, batch, complex] (b_in (4 b_in**2 - 1) // 3, nbatch, 2)
    :return: [batch, beta, alpha, gamma, complex] (nbatch, 2 b_out, 2 b_out, 2 b_out, 2)
    '''
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)  # [beta, l * m * n] (2 * b_out, nspec)

    output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2 * b_out, 2))
    # float32 CUDA inputs go through a custom kernel; everything else uses the loop below
    if x.is_cuda and x.dtype == torch.float32:
        cuda_kernel = _setup_so3ifft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_output=False, device=x.device.index)
        cuda_kernel(x, wigner, output)  # [batch, beta, m, n, complex]
    else:
        output.fill_(0)  # new_empty is uninitialized and the loop accumulates with +=
        for l in range(min(b_in, b_out)):
            # degree-l block: (2l+1)**2 consecutive entries starting at
            # sum_{j<l} (2j+1)**2 = l (4 l**2 - 1) / 3
            s = slice(l * (4 * l**2 - 1) // 3, l * (4 * l**2 - 1) // 3 + (2 * l + 1) ** 2)
            # elementwise (no summation): out[z, b, m, n, c] = x[m, n, z, c] * wigner[b, m, n]
            out = torch.einsum("mnzc,bmn->zbmnc", (x[s].view(2 * l + 1, 2 * l + 1, -1, 2), wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1)))
            # NOTE(review): the loop bound min(b_in, b_out) already gives l <= b_out - 1, so l1 == l here
            l1 = min(l, b_out - 1)  # if b_out < b_in
            # out index l is order 0: non-negative orders 0..l1 go to the front of
            # the grid, negative orders -l1..-1 wrap to the end (FFT index order)
            output[:, :, :l1 + 1, :l1 + 1] += out[:, :, l: l + l1 + 1, l: l + l1 + 1]
            if l > 0:
                output[:, :, -l1:, :l1 + 1] += out[:, :, l - l1: l, l: l + l1 + 1]
                output[:, :, :l1 + 1, -l1:] += out[:, :, l: l + l1 + 1, l - l1: l]
                output[:, :, -l1:, -l1:] += out[:, :, l - l1: l, l - l1: l]

    # 2-D inverse FFT over the two (m, n) grid axes; the (2 b_out)**2 factor
    # presumably cancels torch.ifft's default 1/N normalization
    output = torch.ifft(output, 2) * output.size(-2) ** 2  # [batch, beta, alpha, gamma, complex]
    output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2 * b_out, 2)
    return output
示例10
def backward(self, grad_output):  # pylint: disable=W
    """Backward pass: return (real part of so3_ifft(grad_output), None).

    The inverse transform of ``grad_output`` is not necessarily real, so the
    real-output variant cannot be used. Instead, the complex ``so3_ifft`` is
    computed with ``for_grad=True`` and ``b_out=self.b_in``, and only its
    real component (index 0 of the last axis) is returned. The second
    returned value is ``None``.
    """
    complex_grad = so3_ifft(grad_output, for_grad=True, b_out=self.b_in)
    return complex_grad[..., 0], None
示例11
def forward(self, bottom):
    """Pool ``bottom`` with two sketch projections combined in the Fourier domain.

    Each spatial feature vector is projected with ``self.sparse_sketch_matrix1``
    and ``self.sparse_sketch_matrix2``. The FFTs of the two projections are
    multiplied (complex product), and the real part of the inverse FFT is
    kept, i.e. the two sketches are circularly convolved.

    Args:
        bottom: 4-D tensor unpacked as (batch, channels, height, width).
            ``channels`` must equal ``self.input_dim1`` for the flattening.

    Returns:
        Tensor of shape (batch, height, width, self.output_dim), summed over
        height and width to (batch, self.output_dim) when ``self.sum_pool``
        is true.
    """
    batch_size, _, height, width = bottom.size()

    # One row per (image, row, column) location.
    bottom_flat = bottom.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim1)

    sketch_1 = bottom_flat.mm(self.sparse_sketch_matrix1)
    sketch_2 = bottom_flat.mm(self.sparse_sketch_matrix2)

    # Attach zero imaginary parts. zeros_like matches each sketch's dtype and
    # device. torch.zeros(size) always used the default dtype, which mixes
    # dtypes for e.g. float64/float16 inputs, and it allocated on the CPU
    # before copying to the sketch's device.
    fft1 = torch.fft(torch.stack([sketch_1, torch.zeros_like(sketch_1)], dim=-1), 1)
    fft2 = torch.fft(torch.stack([sketch_2, torch.zeros_like(sketch_2)], dim=-1), 1)

    # Complex product fft1 * fft2 in real-pair layout.
    fft_product_real = fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1]
    fft_product_imag = fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]

    # Keep only the real part of the inverse FFT.
    cbp_flat = torch.ifft(
        torch.stack([fft_product_real, fft_product_imag], dim=-1), 1
    )[..., 0]

    cbp = cbp_flat.view(batch_size, height, width, self.output_dim)

    if self.sum_pool:
        cbp = cbp.sum(dim=[1, 2])

    return cbp
示例12
def _setup(self, config):
    """Build the model, optimizer, inputs and DFT target for one trial.

    Args:
        config: mapping read for 'seed', 'size', 'lr' and 'n_steps_per_epoch'.
    """
    torch.manual_seed(config['seed'])
    size = config['size']
    self.model = nn.Sequential(
        BlockPermProduct(size=size, complex=True, share_logit=False),
        Block2x2DiagProduct(size=size, complex=True),
    )
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    # presumably real_to_complex adds a zero imaginary part in real-pair layout
    identity = real_to_complex(torch.eye(size))
    # signal_ndim is a required argument of torch.fft (omitting it raised
    # TypeError). With signal_ndim=1 each row of the identity is transformed,
    # which yields the DFT matrix.
    self.target_matrix = torch.fft(identity, 1)
    # Alternative target (unnormalized inverse DFT): size * torch.ifft(identity, 1)
    self.input = real_to_complex(torch.eye(size))
示例13
def forward(self, x):
    """Pool each image's feature map into a single ``self.output_dim`` vector.

    Every spatial feature vector is projected with ``self.sparseM[0]`` and
    ``self.sparseM[1]``. The FFTs of the two projections are multiplied
    (a complex product, i.e. a circular convolution of the two sketches) and
    transformed back. The real part is then summed over all h*w positions of
    each image. The result is passed through ``self._signed_sqrt`` and then
    ``self._l2norm``.

    Args:
        x: 4-D tensor unpacked as (batch, dim, h, w).

    Returns:
        ``self._l2norm`` applied to a (batch, self.output_dim) tensor.
    """
    images_per_chunk = 1  # images sent through the FFTs together
    batch_size, dim, h, w = x.shape
    rows_per_image = h * w
    # (batch, h, w, dim) flattened to one row per spatial location.
    x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, dim)
    y = torch.ones(batch_size, self.output_dim, device=x.device)

    # Loop-invariant: move the projection matrices to x's device only once.
    proj1 = self.sparseM[0].to(x.device)
    proj2 = self.sparseM[1].to(x.device)

    def _sketch_fft(rows, proj):
        # Project, attach a zero imaginary part, and take the 1-D FFT along the
        # sketch dimension. zeros_like keeps the sketch's dtype/device; a plain
        # torch.zeros(size) would always use the default dtype.
        sketch = rows.mm(proj)
        return torch.fft(torch.stack((sketch, torch.zeros_like(sketch)), dim=2), 1)

    # Step by images_per_chunk so that a final partial chunk is processed too,
    # and bound the output slice by batch_size.
    for start in range(0, batch_size, images_per_chunk):
        stop = min(batch_size, start + images_per_chunk)
        rows = x_flat[start * rows_per_image:stop * rows_per_image]

        f1 = _sketch_fft(rows, proj1)
        f2 = _sketch_fft(rows, proj2)

        # Complex product f1 * f2 in real-pair layout.
        re = f1[:, :, 0] * f2[:, :, 0] - f1[:, :, 1] * f2[:, :, 1]
        im = f1[:, :, 0] * f2[:, :, 1] + f1[:, :, 1] * f2[:, :, 0]

        # Real part of the inverse FFT, summed over each image's h*w positions.
        conv = torch.ifft(torch.stack((re, im), dim=2), 1)[:, :, 0]
        y[start:stop] = conv.view(stop - start, h, w, self.output_dim).sum(dim=1).sum(dim=1)

    y = self._signed_sqrt(y)
    y = self._l2norm(y)
    return y
示例14
def fft(input, inverse=False):
    """Forward or inverse complex FFT over three signal dimensions.

    This is a thin wrapper around the legacy ``torch.fft`` / ``torch.ifft``
    functions with ``signal_ndim=3``.

    Example
    -------
    x = torch.randn(128, 32, 32, 32, 2)
    x_fft = fft(x)
    x_ifft = fft(x, inverse=True)

    Parameters
    ----------
    input : tensor
        Complex input; ``_is_complex`` presumably checks for a last
        dimension of size 2 (real, imaginary).
    inverse : bool
        If True, compute the inverse FFT instead of the forward FFT.

    Raises
    ------
    TypeError
        If ``_is_complex(input)`` is false.

    Returns
    -------
    output : tensor
        Result of the FFT or the inverse FFT.
    """
    if not _is_complex(input):
        raise TypeError('The input should be complex (e.g. last dimension is 2)')
    transform = torch.ifft if inverse else torch.fft
    return transform(input, 3)
示例15
def torch_ifft2(k, normalized=True):
    """Inverse 2-D FFT of the NumPy array ``k``, computed with ``torch.ifft``.

    ``k`` is converted with ``numpy_to_torch``, transformed with
    ``signal_ndim=2`` (``normalized`` is passed through), and converted back
    with ``torch_to_complex_numpy``. The transformed axes are presumably the
    last two axes of ``k``; confirm this against ``numpy_to_torch``.
    """
    return torch_to_complex_numpy(torch.ifft(numpy_to_torch(k), 2, normalized))
示例16
def torch_ifft2c(x, normalized=True):
    """Centered 2-D inverse FFT of a NumPy array over axes (-2, -1).

    ``x`` is ifftshift-ed over axes (-2, -1), converted with
    ``numpy_to_torch``, and transformed with ``torch.ifft(..., 2, normalized)``.
    It is then converted back with ``torch_to_complex_numpy`` and fftshift-ed
    over the same axes.

    Args:
        x: NumPy array; presumably its last two axes map to torch.ifft's two
            signal dims (confirm against ``numpy_to_torch``).
        normalized: forwarded to ``torch.ifft``. The original hard-coded
            ``normalized=True`` and silently ignored this argument.
    """
    shifted = np.fft.ifftshift(x, axes=(-2, -1))
    transformed = torch.ifft(numpy_to_torch(shifted), 2, normalized=normalized)
    return np.fft.fftshift(torch_to_complex_numpy(transformed), axes=(-2, -1))
示例17
def ifft2(data):
    """Uncentered 2-D inverse FFT of a real-pair complex tensor.

    Dim -1 must have size 2 (real, imaginary). The two dims before it
    (-3, -2) are transformed with ``normalized=True``.
    """
    assert data.size(-1) == 2
    return torch.ifft(data, signal_ndim=2, normalized=True)
示例18
def ifft2c(data):
    """
    Centered 2-D inverse FFT over the spatial axes of a real-pair complex tensor.

    The input is ifftshift-ed over dims (-3, -2), inverse transformed over
    those two dims with ``torch.ifft(..., 2, normalized=True)``, and then
    fftshift-ed back over the same dims.

    Args:
        data (torch.Tensor): At least 3-D. Dims -3 and -2 are the spatial
            (transformed) axes. Dim -1 must have size 2 (real, imaginary).
            Any leading dims are treated as batch dims.

    Returns:
        torch.Tensor: The centered inverse FFT of ``data``.
    """
    assert data.size(-1) == 2
    spatial_dims = (-3, -2)
    centered = ifftshift(data, dim=spatial_dims)
    result = torch.ifft(centered, 2, normalized=True)
    return fftshift(result, dim=spatial_dims)
示例19
def fft_adj(x, ndim=2):
    """Adjoint of the normalized FFT: a ``normalized`` ``ndim``-D inverse FFT of ``x``."""
    transformed = torch.ifft(x, ndim, normalized=True)
    return transformed
示例20
def fft_filter(x, kern, norm=None):
    """FFT-based filtering on a 2-size oversampled grid.

    ``x`` is zero-padded to twice its spatial size and transformed with an
    n-D FFT. It is then multiplied by ``kern`` with
    ``complex_mult(..., dim=2)``, inverse transformed, and cropped back to
    its original spatial size.

    Args:
        x: tensor laid out as (d0, d1, 2, *im_size). Dim 2 holds the
            real/imaginary pair and the trailing dims are spatial.
            NOTE(review): d0/d1 are presumably batch and coil; confirm.
        kern: frequency-domain filter on the padded grid, combined with the
            transformed ``x`` along dim 2 by ``complex_mult``.
        norm: if ``'ortho'``, the forward FFT output and the cropped result
            are each divided by sqrt(prod(grid_size)). Otherwise no extra
            scaling is applied.

    Returns:
        Filtered tensor with the same shape as ``x``.
    """
    im_size = torch.tensor(x.shape).to(torch.long)[3:]
    grid_size = im_size * 2
    ndim = grid_size.numel()

    # set up n-dimensional zero pad: F.pad takes (before, after) pairs starting
    # from the last dim, so the zeros are appended after each spatial dim
    pad_sizes = []
    permute_dims = [0, 1]  # moves the real/imag dim 2 to the end (legacy fft layout)
    inv_permute_dims = [0, 1, 2 + ndim]  # moves it back to position 2
    for i in range(ndim):
        pad_sizes.append(0)
        pad_sizes.append(int(grid_size[-1 - i] - im_size[-1 - i]))
        permute_dims.append(3 + i)
        inv_permute_dims.append(2 + i)
    permute_dims.append(2)
    pad_sizes = tuple(pad_sizes)
    permute_dims = tuple(permute_dims)
    inv_permute_dims = tuple(inv_permute_dims)

    # computed once, used for both 'ortho' scalings below
    ortho_scale = torch.sqrt(torch.prod(grid_size.to(torch.double)))

    # zero pad and fft
    x = F.pad(x, pad_sizes)
    x = x.permute(permute_dims)
    x = torch.fft(x, ndim)
    if norm == 'ortho':
        x = x / ortho_scale
    x = x.permute(inv_permute_dims)

    # apply the filter
    x = complex_mult(x, kern, dim=2)

    # inverse fft
    x = x.permute(permute_dims)
    x = torch.ifft(x, ndim)
    x = x.permute(inv_permute_dims)

    # crop each spatial dim back to its input size. Plain slices replace the
    # former np.int-based index arithmetic (np.int was removed in NumPy 1.24).
    crop = (slice(None),) * 3 + tuple(slice(0, int(n)) for n in im_size)
    x = x[crop]

    # scaling, assume user handled adjoint scaling with their kernel
    if norm == 'ortho':
        x = x / ortho_scale

    return x
示例21
def s2_ifft(x, for_grad=False, b_out=None):
    '''
    Synthesize a complex signal on a 2 b_out x 2 b_out grid from coefficients
    indexed by (l, m). The function name suggests this is the inverse FFT on
    the sphere.

    In the non-CUDA path, the 2l+1 coefficients of each degree l are
    multiplied elementwise by the matching Wigner slice, once per beta
    sample. The products are accumulated into an FFT-index-ordered m axis,
    and a 1-D inverse FFT over that axis is applied at the end.

    :param x: [l * m, ..., complex]. The first axis has nspec = b_in ** 2
        entries (asserted), the last axis has size 2 (real, imaginary), and
        the middle axes are batch axes.
    :param for_grad: passed to ``_setup_wigner`` as ``weighted``.
    :param b_out: output bandwidth; defaults to b_in and must be >= b_in.
    :return: [..., beta, alpha, complex]: the batch axes of x followed by
        (2 b_out, 2 b_out, 2).
    '''
    assert x.size(-1) == 2  # last axis is (real, imag)
    nspec = x.size(0)
    b_in = round(nspec ** 0.5)
    assert nspec == b_in ** 2
    if b_out is None:
        b_out = b_in
    assert b_out >= b_in
    batch_size = x.size()[1:-1]  # batch axes, restored by the final view

    x = x.view(nspec, -1, 2)  # [l * m, batch, complex] (nspec, nbatch, 2)

    '''
    :param x: [l * m, batch, complex] (b_in**2, nbatch, 2)
    :return: [batch, beta, alpha, complex] (nbatch, 2 b_out, 2 * b_out, 2)
    '''
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)
    wigner = wigner.view(2 * b_out, -1)  # [beta, l * m] (2 * b_out, nspec)

    # float32 CUDA inputs go through a custom kernel; everything else uses the loop below
    if x.is_cuda and x.dtype == torch.float32:
        import s2cnn.utils.cuda as cuda_utils
        cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out, nl=b_in, nbatch=nbatch, device=x.device.index)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2))
        # 1024-thread blocks; the grid is sized for nbatch * (2 b_out)**2 work items
        # (presumably one per output (batch, beta, m) entry)
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nbatch * (2 * b_out) ** 2, 1024), 1, 1),
                    args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
                    stream=stream)
        # [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    else:
        output = x.new_zeros((nbatch, 2 * b_out, 2 * b_out, 2))
        for l in range(b_in):
            # degree-l coefficients: 2l+1 consecutive entries starting at l**2
            s = slice(l ** 2, l ** 2 + 2 * l + 1)
            # elementwise (no summation): out[z, b, m, c] = x[m, z, c] * wigner[b, m]
            out = torch.einsum("mzc,bm->zbmc", (x[s], wigner[:, s]))
            # orders 0..l (the last l+1 entries of out) go to the front of the m axis
            output[:, :, :l + 1] += out[:, :, -l - 1:]
            if l > 0:
                # negative orders -l..-1 wrap to the end (FFT index order)
                output[:, :, -l:] += out[:, :, :l]

    # 1-D inverse FFT over the m axis; the factor 2 b_out presumably cancels
    # torch.ifft's default 1/N normalization
    output = torch.ifft(output, 1) * output.size(-2)  # [batch, beta, alpha, complex]
    output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2)
    return output
示例22
def so3_rifft(x, for_grad=False, b_out=None):
    '''
    Synthesize a real signal on a (2 b_out)^3 grid from spectral coefficients
    indexed by (l, m, n). It keeps only the real part of the complex result.
    The function name suggests this is the real-output inverse SO(3) FFT.

    In the non-CUDA path, each degree-l (2l+1) x (2l+1) coefficient block is
    multiplied elementwise by the matching Wigner slice, once per beta sample.
    The products are accumulated into an FFT-index-ordered grid, a 2-D
    inverse FFT is applied, and the real component is returned.

    :param x: [l * m * n, ..., complex]. The first axis has
        nspec = b_in * (4 * b_in ** 2 - 1) // 3 entries; b_in is recovered from
        nspec and the formula is asserted. The last axis has size 2
        (real, imaginary). The middle axes are batch axes.
    :param for_grad: passed to ``_setup_wigner`` as ``weighted``.
    :param b_out: output bandwidth; defaults to b_in.
    :return: [..., beta, alpha, gamma]: the batch axes of x followed by
        (2 b_out, 2 b_out, 2 b_out), real-valued.
    '''
    assert x.size(-1) == 2  # last axis is (real, imag)
    nspec = x.size(0)
    # nspec = b_in (4 b_in^2 - 1) / 3 ~ (4/3) b_in^3: invert approximately, then check exactly
    b_in = round((3 / 4 * nspec) ** (1 / 3))
    assert nspec == b_in * (4 * b_in ** 2 - 1) // 3
    if b_out is None:
        b_out = b_in
    batch_size = x.size()[1:-1]  # batch axes, restored by the final view

    x = x.view(nspec, -1, 2)  # [l * m * n, batch, complex] (nspec, nbatch, 2)

    '''
    :param x: [l * m * n, batch, complex] (b_in (4 b_in**2 - 1) // 3, nbatch, 2)
    :return: [batch, beta, alpha, gamma] (nbatch, 2 b_out, 2 b_out, 2 b_out)
    '''
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)  # [beta, l * m * n] (2 * b_out, nspec)

    output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2 * b_out, 2))
    # float32 CUDA inputs go through a custom kernel (built with real_output=True);
    # everything else uses the loop below
    if x.is_cuda and x.dtype == torch.float32:
        cuda_kernel = _setup_so3ifft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_output=True, device=x.device.index)
        cuda_kernel(x, wigner, output)  # [batch, beta, m, n, complex]
    else:
        # TODO can be optimized knowing that the output is real, like in _setup_so3ifft_cuda_kernel(real_output=True)
        output.fill_(0)  # new_empty is uninitialized and the loop accumulates with +=
        for l in range(min(b_in, b_out)):
            # degree-l block: (2l+1)**2 consecutive entries starting at
            # sum_{j<l} (2j+1)**2 = l (4 l**2 - 1) / 3
            s = slice(l * (4 * l**2 - 1) // 3, l * (4 * l**2 - 1) // 3 + (2 * l + 1) ** 2)
            # elementwise (no summation): out[z, b, m, n, c] = x[m, n, z, c] * wigner[b, m, n]
            out = torch.einsum("mnzc,bmn->zbmnc", (x[s].view(2 * l + 1, 2 * l + 1, -1, 2), wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1)))
            # NOTE(review): the loop bound min(b_in, b_out) already gives l <= b_out - 1, so l1 == l here
            l1 = min(l, b_out - 1)  # if b_out < b_in
            # out index l is order 0: non-negative orders 0..l1 go to the front of
            # the grid, negative orders -l1..-1 wrap to the end (FFT index order)
            output[:, :, :l1 + 1, :l1 + 1] += out[:, :, l: l + l1 + 1, l: l + l1 + 1]
            if l > 0:
                output[:, :, -l1:, :l1 + 1] += out[:, :, l - l1: l, l: l + l1 + 1]
                output[:, :, :l1 + 1, -l1:] += out[:, :, l: l + l1 + 1, l - l1: l]
                output[:, :, -l1:, -l1:] += out[:, :, l - l1: l, l - l1: l]

    # 2-D inverse FFT over the two (m, n) grid axes; the (2 b_out)**2 factor
    # presumably cancels torch.ifft's default 1/N normalization
    output = torch.ifft(output, 2) * output.size(-2) ** 2  # [batch, beta, alpha, gamma, complex]
    output = output[..., 0]  # [batch, beta, alpha, gamma]
    output = output.contiguous()  # the real-part slice is strided; make it contiguous before view

    output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2 * b_out)
    return output