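# Python source-code examples: torch.rfft()
# NOTE: torch.rfft / torch.irfft were removed in PyTorch 1.8; the torch.fft module replaces them.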

# Example 1
def p2o(psf, shape):
    '''
    Convert a point-spread function (PSF) into an optical transfer function (OTF).

    The PSF is zero-padded to ``shape``, circularly shifted so that its centre
    lands on index (0, 0), and transformed with a full (two-sided) 2-D FFT.

    # psf: NxCxhxw
    # shape: [H,W]  (list, tuple or torch.Size)
    # otf: NxCxHxWx2  (last dim holds the real and imaginary parts)
    '''
    # tuple() on both operands: torch.Size + list would raise TypeError.
    otf = torch.zeros(tuple(psf.shape[:-2]) + tuple(shape)).type_as(psf)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    # Move the PSF centre to the origin so the OTF carries no linear phase.
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -(axis_size // 2), dims=axis + 2)
    if hasattr(torch, "rfft"):  # PyTorch < 1.8: legacy API
        otf = torch.rfft(otf, 2, onesided=False)
    else:  # torch.rfft was removed in PyTorch 1.8; same [..., 2] layout
        otf = torch.view_as_real(torch.fft.fft2(otf))
    # Zero imaginary parts below an estimated round-off level
    # (presumably modelled on MATLAB/Octave psf2otf).
    n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
    otf[..., 1][torch.abs(otf[..., 1]) < n_ops * 2.22e-16] = torch.tensor(0).type_as(psf)
    return otf



# otf2psf: origin unknown; possibly translated from the Octave source code. It is just math.
# Example 2
def p2o(psf, shape):
    '''
    Convert a PSF into an OTF via zero-padding, centring and a full 2-D FFT.

    Args:
        psf: NxCxhxw
        shape: [H,W] (list, tuple or torch.Size)

    Returns:
        otf: NxCxHxWx2 (last dim holds the real and imaginary parts)
    '''
    # tuple() on both operands: torch.Size + list would raise TypeError.
    otf = torch.zeros(tuple(psf.shape[:-2]) + tuple(shape)).type_as(psf)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    # Shift the PSF centre to index (0, 0).
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -(axis_size // 2), dims=axis + 2)
    if hasattr(torch, "rfft"):  # PyTorch < 1.8: legacy API
        otf = torch.rfft(otf, 2, onesided=False)
    else:  # torch.rfft was removed in PyTorch 1.8; same [..., 2] layout
        otf = torch.view_as_real(torch.fft.fft2(otf))
    # Clear imaginary parts that are below an estimated round-off level.
    n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
    otf[..., 1][torch.abs(otf[..., 1]) < n_ops * 2.22e-16] = torch.tensor(0).type_as(psf)
    return otf
# Example 3
def p2o(psf, shape):
    '''
    Convert point-spread function to optical transfer function.
    otf = p2o(psf) computes the Fast Fourier Transform (FFT) of the
    point-spread function (PSF) array and creates the optical transfer
    function (OTF) array that is not influenced by the PSF off-centering.

    Args:
        psf: NxCxhxw
        shape: [H, W] (list, tuple or torch.Size)

    Returns:
        otf: NxCxHxWx2 (last dim holds the real and imaginary parts)
    '''
    # tuple() on both operands: torch.Size + list would raise TypeError.
    otf = torch.zeros(tuple(psf.shape[:-2]) + tuple(shape)).type_as(psf)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    # Circularly shift so the PSF centre sits at index (0, 0).
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -(axis_size // 2), dims=axis + 2)
    if hasattr(torch, "rfft"):  # PyTorch < 1.8: legacy API
        otf = torch.rfft(otf, 2, onesided=False)
    else:  # torch.rfft was removed in PyTorch 1.8; same [..., 2] layout
        otf = torch.view_as_real(torch.fft.fft2(otf))
    # Zero imaginary parts smaller than the estimated FFT round-off.
    n_ops = torch.sum(torch.tensor(psf.shape).type_as(psf) * torch.log2(torch.tensor(psf.shape).type_as(psf)))
    otf[..., 1][torch.abs(otf[..., 1]) < n_ops * 2.22e-16] = torch.tensor(0).type_as(psf)
    return otf
# Example 4
def pad_rfft3(f, onesided=True):
    """
    Padded batch real FFT over the last three dimensions.

    A 1-D real FFT along the last axis is followed by complex FFTs along the
    second- and third-to-last axes. After each axis is transformed, the bin
    at index ``n // 2`` of that axis is set to zero.

    :param f: real tensor of shape [..., res0, res1, res2]
    :param onesided: if True keep only res2 // 2 + 1 bins on the last axis,
        otherwise keep all res2 bins
    :return: tensor of shape [..., res0, res1, res2 // 2 + 1 (or res2), 2]
        whose last dimension holds the real and imaginary parts
    """
    n0, n1, n2 = f.shape[-3:]
    h0, h1, h2 = n0 // 2, n1 // 2, n2 // 2

    if hasattr(torch, "rfft"):  # PyTorch < 1.8: legacy (real, imag) API
        F2 = torch.rfft(f, signal_ndim=1, onesided=onesided)  # [..., res0, res1, res2/2+1, 2]
        F2[..., h2, :] = 0

        F1 = torch.fft(F2.transpose(-3, -2), signal_ndim=1)
        F1[..., h1, :] = 0
        F1 = F1.transpose(-2, -3)

        F0 = torch.fft(F1.transpose(-4, -2), signal_ndim=1)
        F0[..., h0, :] = 0
        return F0.transpose(-2, -4)

    # torch.rfft / callable torch.fft were removed in PyTorch 1.8: work on
    # complex tensors along explicit dims, then expose the [..., 2] layout.
    F2 = torch.fft.rfft(f, dim=-1) if onesided else torch.fft.fft(f, dim=-1)
    F2[..., h2] = 0
    F1 = torch.fft.fft(F2, dim=-2)
    F1[..., h1, :] = 0
    F0 = torch.fft.fft(F1, dim=-3)
    F0[..., h0, :, :] = 0
    return torch.view_as_real(F0)
# Example 5
def __init__(self, shape, sd=0.01, decay_power=1, init_img=None):
        """Spectral (Fourier-domain) image parameterisation.

        Args:
            shape: (ch, h, w) of the parameterised image.
            sd: standard deviation of the random initial spectrum.
            decay_power: exponent applied to the frequency magnitudes when
                building the per-frequency scale.
            init_img: optional image whose one-sided 2-D FFT, divided by the
                scale, initialises the spectrum. Presumably (ch, h, w); TODO confirm.
        """
        super(SpectralImage, self).__init__()
        self.shape = shape
        ch, h, w = shape
        freqs = _rfft2d_freqs(h, w)
        fh, fw = freqs.shape
        self.decay_power = decay_power

        # Learnable spectrum stored as (ch, fh, fw, 2) = (real, imag).
        init_val = sd * torch.randn(ch, fh, fw, 2)
        spectrum_var = torch.nn.Parameter(init_val)
        self.spectrum_var = spectrum_var

        # Per-frequency scale; frequencies are clamped below at 1 / max(h, w).
        spectrum_scale = 1.0 / np.maximum(freqs, 1.0 / max(h, w)) ** self.decay_power
        spectrum_scale *= np.sqrt(w * h)
        spectrum_scale = torch.FloatTensor(spectrum_scale).unsqueeze(-1)
        # Buffer key keeps the original spelling so existing state_dicts still load.
        self.register_buffer('spertum_scale', spectrum_scale)

        if init_img is not None:
            # Odd size on dim 2: pad one column on the left of the last dim.
            if init_img.shape[2] % 2 == 1:
                init_img = nn.functional.pad(init_img, (1, 0, 0, 0))
            if hasattr(torch, "rfft"):  # PyTorch < 1.8: legacy API
                fft = torch.rfft(init_img * 4, 2, onesided=True, normalized=False)
            else:  # torch.rfft was removed in PyTorch 1.8; same [..., 2] layout
                fft = torch.view_as_real(torch.fft.rfft2(init_img * 4))
            self.spectrum_var.data.copy_(fft / spectrum_scale)
# Example 6
def forward(self, z):
        """Filter each block of ``z`` with a learned FIR filter in the frequency domain.

        ``self.filter_coef`` is treated as a zero-phase magnitude response with
        ``filter_size // 2 + 1`` bins per filter. It is turned into an impulse
        response, windowed by ``self.filter_window``, zero-padded to
        ``block_size``, and multiplied with the spectrum of ``z``.

        Args:
            z: pair ``(signal, cond)``. Only the signal is used; its last dim is
                presumably ``block_size`` (e.g. (batch, frames, block_size)) — TODO confirm.

        Returns:
            Filtered signal reshaped to (batch, -1).
        """
        z, cond = z
        n_bins = self.filter_size // 2 + 1
        n_block_bins = self.block_size // 2 + 1
        legacy_fft = hasattr(torch, "rfft")  # torch.rfft/irfft were removed in PyTorch 1.8
        # Impulse response of the real, zero-phase frequency response.
        coef = self.filter_coef.reshape(-1, n_bins)
        if legacy_fft:
            filter_coef = coef.unsqueeze(-1).expand(-1, n_bins, 2).contiguous()
            filter_coef[:, :, 1] = 0
            h = torch.irfft(filter_coef, 1, signal_sizes=(self.filter_size,))
        else:
            h = torch.fft.irfft(torch.complex(coef, torch.zeros_like(coef)), n=self.filter_size)
        h_w = self.filter_window.unsqueeze(0) * h
        h_w = nn.functional.pad(h_w, (0, self.block_size - self.filter_size), "constant", 0)
        if legacy_fft:
            S_sig = torch.rfft(z, 1).reshape(z.shape[0], -1, n_block_bins, 2)
            H = torch.rfft(h_w, 1).reshape(z.shape[0], -1, n_block_bins, 2)
            # Complex product written out on the (real, imag) channels.
            S_filtered = torch.zeros_like(H)
            S_filtered[:, :, :, 0] = H[:, :, :, 0] * S_sig[:, :, :, 0] - H[:, :, :, 1] * S_sig[:, :, :, 1]
            S_filtered[:, :, :, 1] = H[:, :, :, 0] * S_sig[:, :, :, 1] + H[:, :, :, 1] * S_sig[:, :, :, 0]
            S_filtered = S_filtered.reshape(-1, n_block_bins, 2)
            filtered_noise = torch.irfft(S_filtered, 1)
        else:
            S_sig = torch.fft.rfft(z).reshape(z.shape[0], -1, n_block_bins)
            H = torch.fft.rfft(h_w).reshape(z.shape[0], -1, n_block_bins)
            filtered_noise = torch.fft.irfft((H * S_sig).reshape(-1, n_block_bins))
        return filtered_noise[:, :self.block_size].reshape(z.shape[0], -1)
# Example 7
def forward(self, z):
        """Apply a learned reverb by FFT-based (circular) convolution.

        The input is zero-padded by ``self.size`` samples. The impulse response
        is ``sigmoid(wetdry) * identity + sigmoid(1 - wetdry) * impulse * decay``,
        zero-padded to the same length, and the two are multiplied in the
        frequency domain.

        Args:
            z: pair ``(signal, conditions)``. Only the signal is used; the legacy
                path indexes its spectrum as 3-D, so the signal is presumably
                (batch, time) — TODO confirm.

        Returns:
            Reverberated signal of length ``signal.shape[-1] + self.size``.
        """
        z, conditions = z
        y = nn.functional.pad(z, (0, self.size), "constant", 0)
        # Current impulse response: dry identity plus exponentially decaying wet part.
        idx = torch.sigmoid(self.wetdry) * self.identity
        imp = torch.sigmoid(1 - self.wetdry) * self.impulse
        dcy = torch.exp(-(torch.exp(self.decay) + 2) * torch.linspace(0, 1, self.size).to(z.device))
        final_impulse = idx + imp * dcy
        impulse = nn.functional.pad(final_impulse, (0, self.size), "constant", 0)
        if y.shape[-1] > self.size:
            impulse = nn.functional.pad(impulse, (0, y.shape[-1] - impulse.shape[-1]), "constant", 0)
        # NOTE(review): detach() stops gradients from reaching wetdry/decay/impulse;
        # confirm this is intended.
        if hasattr(torch, "rfft"):  # PyTorch < 1.8: legacy (real, imag) API
            Y_S = torch.rfft(y, 1)
            IR_S = torch.rfft(impulse.detach(), 1).expand_as(Y_S)
            Y_S_CONV = torch.zeros_like(IR_S)
            Y_S_CONV[:, :, 0] = Y_S[:, :, 0] * IR_S[:, :, 0] - Y_S[:, :, 1] * IR_S[:, :, 1]
            Y_S_CONV[:, :, 1] = Y_S[:, :, 0] * IR_S[:, :, 1] + Y_S[:, :, 1] * IR_S[:, :, 0]
            return torch.irfft(Y_S_CONV, 1, signal_sizes=(y.shape[-1],))
        # torch.rfft/irfft were removed in PyTorch 1.8.
        Y_S = torch.fft.rfft(y)
        IR_S = torch.fft.rfft(impulse.detach()).expand_as(Y_S)
        return torch.fft.irfft(Y_S * IR_S, n=y.shape[-1])
# Example 8
def forward(self, z):
        """Generate filtered noise shaped by a learned FIR filter.

        White noise with shape (batch, frames, block_size), scaled by
        ``self.noise_att``, is filtered block by block in the frequency domain.
        The filter is ``self.filter_coef`` treated as a zero-phase magnitude
        response, windowed by ``self.filter_window`` and zero-padded to
        ``block_size``.

        Args:
            z: pair ``(sig, conditions)``. Only ``sig.shape[0:2]`` and ``sig.device`` are used.

        Returns:
            Filtered noise reshaped to (batch, -1).
        """
        sig, conditions = z
        n_bins = self.filter_size // 2 + 1
        n_block_bins = self.block_size // 2 + 1
        legacy_fft = hasattr(torch, "rfft")  # torch.rfft/irfft were removed in PyTorch 1.8
        # Create the noise source.
        noise = torch.randn([sig.shape[0], sig.shape[1], self.block_size]).detach().to(sig.device).reshape(-1, self.block_size) * self.noise_att
        # Impulse response of the real, zero-phase frequency response.
        coef = self.filter_coef.reshape(-1, n_bins)
        if legacy_fft:
            filter_coef = coef.unsqueeze(-1).expand(-1, n_bins, 2).contiguous()
            filter_coef[:, :, 1] = 0
            h = torch.irfft(filter_coef, 1, signal_sizes=(self.filter_size,))
        else:
            h = torch.fft.irfft(torch.complex(coef, torch.zeros_like(coef)), n=self.filter_size)
        h_w = self.filter_window.unsqueeze(0) * h
        h_w = nn.functional.pad(h_w, (0, self.block_size - self.filter_size), "constant", 0)
        if legacy_fft:
            S_noise = torch.rfft(noise, 1).reshape(sig.shape[0], -1, n_block_bins, 2)
            H = torch.rfft(h_w, 1).reshape(sig.shape[0], -1, n_block_bins, 2)
            # Complex product written out on the (real, imag) channels.
            S_filtered = torch.zeros_like(H)
            S_filtered[:, :, :, 0] = H[:, :, :, 0] * S_noise[:, :, :, 0] - H[:, :, :, 1] * S_noise[:, :, :, 1]
            S_filtered[:, :, :, 1] = H[:, :, :, 0] * S_noise[:, :, :, 1] + H[:, :, :, 1] * S_noise[:, :, :, 0]
            S_filtered = S_filtered.reshape(-1, n_block_bins, 2)
            filtered_noise = torch.irfft(S_filtered, 1)
        else:
            S_noise = torch.fft.rfft(noise).reshape(sig.shape[0], -1, n_block_bins)
            H = torch.fft.rfft(h_w).reshape(sig.shape[0], -1, n_block_bins)
            filtered_noise = torch.fft.irfft((H * S_noise).reshape(-1, n_block_bins))
        return filtered_noise[:, :self.block_size].reshape(sig.shape[0], -1)
# Example 9
def forward(self, z, x):
        """Correlation-filter response of search features ``x`` given template ``z``.

        Both inputs go through ``self.feature`` and a one-sided 2-D FFT. A filter
        is solved from the template spectrum and the label spectrum ``self.yf``,
        then correlated with the search spectrum and summed over dim 1. The
        spectra use the (real, imag) trailing-dim layout expected by the
        ``tensor_complex_*`` helpers.

        Returns:
            Real response map. Its width is 2 * (W // 2), i.e. W for even feature width W.
        """
        z = self.feature(z)
        x = self.feature(x)

        legacy_fft = hasattr(torch, "rfft")  # torch.rfft/irfft were removed in PyTorch 1.8
        if legacy_fft:
            zf = torch.rfft(z, signal_ndim=2)
            xf = torch.rfft(x, signal_ndim=2)
        else:
            zf = torch.view_as_real(torch.fft.rfft2(z))
            xf = torch.view_as_real(torch.fft.rfft2(x))

        kzzf = torch.sum(tensor_complex_mulconj(zf, zf), dim=1, keepdim=True)

        kzyf = tensor_complex_mulconj(zf, self.yf.to(device=z.device))

        # NOTE(review): adding a scalar to the [..., 2] tensor adds lambda0 to the
        # imaginary channel as well as the real one; confirm that is intended.
        solution = tensor_complex_division(kzyf, kzzf + self.config.lambda0)

        response_f = torch.sum(tensor_complex_mulconj(xf, solution), dim=1, keepdim=True)
        if legacy_fft:
            return torch.irfft(response_f, signal_ndim=2)
        return torch.fft.irfft2(torch.view_as_complex(response_f.contiguous()))
# Example 10
def normalize_var(orig):
    """Boost low-contrast regions of each sample, then apply ``normalize``.

    The per-sample absolute deviation from the mean, ``imC``, is scaled to
    peak at 1. The deviation is then amplified by
    ``(minC + 1) / (minC + imC)`` with ``minC = 0.001``, and the result is
    passed to ``normalize``. The mean and the gain map are detached from the
    autograd graph.

    Args:
        orig: 4-D tensor, batch first (the mean is reshaped to (B, 1, 1, 1)).
    """
    batch_size = orig.size(0)

    mean = torch.mean(orig.view(batch_size, -1), 1).view(batch_size, 1, 1, 1)

    # The original code applied rfft2 to (orig - mean)**2 and immediately
    # inverted it with irfft2 at the original size. That round trip is an
    # identity, so sqrt(|(orig - mean)**2|) is simply |orig - mean|. Skipping it
    # also avoids torch.rfft/irfft, which were removed in PyTorch 1.8.
    imC = (orig - mean).abs()
    # Guard the per-sample peak: a constant sample (peak 0) would turn into NaN.
    peak = torch.max(imC.view(batch_size, -1), 1)[0].view(batch_size, 1, 1, 1)
    imC = imC / peak.clamp(min=torch.finfo(imC.dtype).tiny)
    minC = 0.001
    imK = (minC + 1) / (minC + imC)

    mean, imK = mean.detach(), imK.detach()
    img = mean + (orig - mean) * imK
    return normalize(img)
# Example 11
def forward(ctx, h1, s1, h2, s2, output_size, x, y, force_cpu_scatter_add=False):
        """Compact bilinear pooling: circular convolution of two count sketches.

        Each input is count-sketched to ``output_size`` bins with
        ``CountSketchFn_forward``. The two 1-D real spectra are multiplied with
        ``ComplexMultiply_forward``, and the product is brought back with an
        inverse real FFT of length ``output_size``. The inputs and sizes are
        stored on ``ctx`` for the backward pass.
        """
        ctx.save_for_backward(h1, s1, h2, s2, x, y)
        ctx.x_size = tuple(x.size())
        ctx.y_size = tuple(y.size())
        ctx.force_cpu_scatter_add = force_cpu_scatter_add
        ctx.output_size = output_size

        # 1-D real FFT helpers in the legacy (real, imag) trailing-dim layout.
        if hasattr(torch, "rfft"):  # PyTorch < 1.8
            def _rfft1(t):
                return torch.rfft(t, 1)

            def _irfft1(re, im):
                return torch.irfft(torch.stack((re, im), re.dim()), 1, signal_sizes=(output_size,))
        else:  # torch.rfft/irfft were removed in PyTorch 1.8
            def _rfft1(t):
                return torch.view_as_real(torch.fft.rfft(t))

            def _irfft1(re, im):
                return torch.fft.irfft(torch.complex(re, im), n=output_size)

        # Count sketch of each input and its FFT.
        px = CountSketchFn_forward(h1, s1, output_size, x, force_cpu_scatter_add)
        fx = _rfft1(px)
        re_fx = fx.select(-1, 0)
        im_fx = fx.select(-1, 1)
        del px
        py = CountSketchFn_forward(h2, s2, output_size, y, force_cpu_scatter_add)
        fy = _rfft1(py)
        re_fy = fy.select(-1, 0)
        im_fy = fy.select(-1, 1)
        del py

        # Convolution of the two sketches = complex product of their FFTs.
        re_prod, im_prod = ComplexMultiply_forward(re_fx, im_fx, re_fy, im_fy)

        # Back to the real domain; the imaginary part should be zero.
        return _irfft1(re_prod, im_prod)
# Example 12
def get_uperleft_denominator_pytorch(img, kernel):
    '''
    Compute the two frequency-domain terms of an FFT-based deconvolution.

    img: NxCxHxW
    kernel: Nx1xhxw
    denominator: Nx1xHxW    squared magnitude of the kernel OTF
    upperleft: NxCxHxWx2    conj(OTF) times the image spectrum
    '''
    otf = p2o(kernel, img.shape[-2:])  # Nx1xHxWx2
    otf_re, otf_im = otf[..., 0], otf[..., 1]
    denominator = otf_re ** 2 + otf_im ** 2  # Nx1xHxW
    otf_conj = cconj(otf)
    img_spectrum = rfft(img)  # NxCxHxWx2
    upperleft = cmul(otf_conj, img_spectrum)  # Nx1xHxWx2 * NxCxHxWx2
    return upperleft, denominator
# Example 13
def rfft(t):
    """Full (two-sided) 2-D FFT of a real tensor over its last two dimensions.

    Returns a real tensor of shape ``t.shape + (2,)``; the trailing dimension
    holds the real and imaginary parts.
    """
    if hasattr(torch, "rfft"):  # PyTorch < 1.8: legacy API
        return torch.rfft(t, 2, onesided=False)
    # torch.rfft was removed in PyTorch 1.8; fft2 + view_as_real gives the same layout.
    return torch.view_as_real(torch.fft.fft2(t))
# Example 14
def rfft(t):
    """Two-sided 2-D FFT over the last two dims of a real tensor.

    The output has one extra trailing dimension of size 2 holding the
    (real, imaginary) parts, i.e. shape ``t.shape + (2,)``.
    """
    if hasattr(torch, "rfft"):  # PyTorch < 1.8: legacy API
        return torch.rfft(t, 2, onesided=False)
    # torch.rfft no longer exists from PyTorch 1.8 on.
    return torch.view_as_real(torch.fft.fft2(t))
# Example 15
def forward(self, x, FB, FBC, F2B, FBFy, alpha, sf):
        """Closed-form FFT-domain data step.

        All spectra (``FB``, ``FBC``, ``F2B``, ``FBFy``) use the (real, imag)
        trailing-dim layout consumed by ``cmul``/``cdiv``/``csum``. ``sf`` is the
        factor passed to ``splits`` and to the ``repeat`` of the per-block result.

        Returns:
            Real-valued estimate with the spatial size of ``x``.
        """
        legacy_fft = hasattr(torch, "rfft")  # torch.rfft/irfft were removed in PyTorch 1.8
        if legacy_fft:
            FR = FBFy + torch.rfft(alpha * x, 2, onesided=False)
        else:
            FR = FBFy + torch.view_as_real(torch.fft.fft2(alpha * x))
        x1 = cmul(FB, FR)
        FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False)
        invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)
        invWBR = cdiv(FBR, csum(invW, alpha))
        FCBinvWBR = cmul(FBC, invWBR.repeat(1, 1, sf, sf, 1))
        FX = (FR - FCBinvWBR) / alpha.unsqueeze(-1)
        if legacy_fft:
            return torch.irfft(FX, 2, onesided=False)
        # Full-spectrum inverse to a real image. Like a complex-to-real transform,
        # only the non-negative half of the last axis is read; for the Hermitian
        # spectra of real images this equals ifft2(FX).real.
        FXc = torch.view_as_complex(FX.contiguous())
        height, width = FXc.shape[-2:]
        return torch.fft.irfft2(FXc[..., :width // 2 + 1], s=(height, width))
# Example 16
def forward(self, x, k, sf, sigma):
        '''
        Unfolded restoration: alternate the data step ``self.d`` and the prior step ``self.p``.

        x: tensor, NxCxWxH
        k: tensor, Nx(1,3)xwxh
        sf: integer, 1
        sigma: tensor, Nx1x1x1
        '''

        # initialization & pre-calculation
        w, h = x.shape[-2:]
        FB = p2o(k, (w*sf, h*sf))
        FBC = cconj(FB, inplace=False)
        F2B = r2c(cabs2(FB))
        STy = upsample(x, sf=sf)
        if hasattr(torch, "rfft"):  # PyTorch < 1.8: legacy API
            STy_f = torch.rfft(STy, 2, onesided=False)
        else:  # torch.rfft was removed in PyTorch 1.8; same [..., 2] layout
            STy_f = torch.view_as_real(torch.fft.fft2(STy))
        FBFy = cmul(FBC, STy_f)
        x = nn.functional.interpolate(x, scale_factor=sf, mode='nearest')

        # hyper-parameter, alpha & beta
        ab = self.h(torch.cat((sigma, torch.tensor(sf).type_as(sigma).expand_as(sigma)), dim=1))

        # unfolding: column i of ab is the data-step alpha, column i + n the prior-step input
        for i in range(self.n):
            x = self.d(x, FB, FBC, F2B, FBFy, ab[:, i:i+1, ...], sf)
            x = self.p(torch.cat((x, ab[:, i+self.n:i+self.n+1, ...].repeat(1, 1, x.size(2), x.size(3))), dim=1))

        return x
# Example 17
def _get_fft_basis(self):
        fourier_basis = torch.rfft(
            torch.eye(self.filter_length), 
            1, onesided=True
        )
        cutoff = 1 + self.filter_length // 2
        fourier_basis = torch.cat([
            fourier_basis[:, :cutoff, 0],
            fourier_basis[:, :cutoff, 1]
        ], dim=1)
        return fourier_basis.float() 
# Example 18
def rfft2(data):
    """Centred, orthonormal, two-sided 2-D FFT of a real tensor.

    The last two dims are ``ifftshift``-ed, transformed with ``norm="ortho"``
    scaling (the legacy ``normalized=True``), and the result is ``fftshift``-ed
    over dims (-3, -2), i.e. the spatial dims of the (real, imag) layout.
    """
    data = ifftshift(data, dim=(-2, -1))
    if hasattr(torch, "rfft"):  # PyTorch < 1.8: legacy API
        data = torch.rfft(data, 2, normalized=True, onesided=False)
    else:  # torch.rfft was removed in PyTorch 1.8; same [..., 2] layout
        data = torch.view_as_real(torch.fft.fft2(data, norm="ortho"))
    data = fftshift(data, dim=(-3, -2))
    return data
# Example 19
def forward(self, x):
        """Correlation-filter response for search input ``x`` using the stored model.

        Features are windowed by ``self.config.cos_window``, transformed with a
        one-sided 2-D FFT, and correlated with the filter
        ``model_alphaf / (model_betaf + lambda0)``. The result is summed over
        dim 1 and inverted.
        """
        x = self.feature(x)

        x = x * self.config.cos_window

        legacy_fft = hasattr(torch, "rfft")  # torch.rfft/irfft were removed in PyTorch 1.8
        if legacy_fft:
            xf = torch.rfft(x, signal_ndim=2)
        else:
            xf = torch.view_as_real(torch.fft.rfft2(x))
        solution = tensor_complex_division(self.model_alphaf, self.model_betaf + self.config.lambda0)
        response_f = torch.sum(tensor_complex_mulconj(xf, solution), dim=1, keepdim=True)
        if legacy_fft:
            return torch.irfft(response_f, signal_ndim=2)
        return torch.fft.irfft2(torch.view_as_complex(response_f.contiguous()))
# Example 20
def update(self, z, lr=1.):
        """Update the correlation-filter model from template ``z``.

        ``model_alphaf`` (numerator, from ``config.yf_online``) and
        ``model_betaf`` (denominator) are replaced when ``lr > 0.99``. Otherwise
        they are linearly interpolated with weight ``lr``.
        """
        z = self.feature(z)

        z = z * self.config.cos_window
        if hasattr(torch, "rfft"):  # PyTorch < 1.8: legacy API
            zf = torch.rfft(z, signal_ndim=2)
        else:  # torch.rfft was removed in PyTorch 1.8; same [..., 2] layout
            zf = torch.view_as_real(torch.fft.rfft2(z))

        kzzf = torch.sum(tensor_complex_mulconj(zf, zf), dim=1, keepdim=True)
        kzyf = tensor_complex_mulconj(zf, self.config.yf_online.to(device=z.device))

        if lr > 0.99:
            self.model_alphaf = kzyf
            self.model_betaf = kzzf
        else:
            self.model_alphaf = (1 - lr) * self.model_alphaf.data + lr * kzyf.data
            self.model_betaf = (1 - lr) * self.model_betaf.data + lr * kzzf.data
# Example 21
def parse_args(self, **kargs):
        """Build the tracker configuration and store it as ``self.cfg``.

        Defaults are overridden by ``kargs``. Derived entries (Gaussian labels
        and their one-sided 2-D spectra, scale factors and penalties, network
        input size, cosine window) are then added. The label spectra and the
        cosine window are created on CUDA. Finally the dict is converted with
        ``dict2tuple``.
        """
        if hasattr(torch, "rfft"):  # PyTorch < 1.8: legacy API
            def _rfft2(t):
                return torch.rfft(t, signal_ndim=2)
        else:  # torch.rfft was removed in PyTorch 1.8; same [..., 2] layout
            def _rfft2(t):
                return torch.view_as_real(torch.fft.rfft2(t))

        # default branch is AlexNetV1
        self.cfg = {
            'crop_sz': 125,
            'output_sz': 121,
            'lambda0': 1e-4,
            'padding': 2.0,
            'output_sigma_factor': 0.1,
            'initial_lr': 1e-2,
            'final_lr': 1e-5,
            'epoch_num': 50,
            'weight_decay': 5e-4,
            'batch_size': 32,

            'interp_factor': 0.01,
            'num_scale': 3,
            'scale_step': 1.0275,
            'min_scale_factor': 0.2,
            'max_scale_factor': 5,
            'scale_penalty': 0.9925,
            }

        self.cfg.update(kargs)

        self.cfg['output_sigma'] = self.cfg['crop_sz'] / (1 + self.cfg['padding']) * self.cfg['output_sigma_factor']
        self.cfg['y'] = gaussian_shaped_labels(self.cfg['output_sigma'], [self.cfg['output_sz'], self.cfg['output_sz']])
        self.cfg['yf'] = _rfft2(torch.Tensor(self.cfg['y']).view(1, 1, self.cfg['output_sz'], self.cfg['output_sz']).cuda())
        self.cfg['net_average_image'] = np.array([104, 117, 123]).reshape(1, 1, -1).astype(np.float32)

        # NOTE(review): `num_scale / 2` is float division, so for num_scale=3 the
        # exponents are [-1.5, -0.5, 0.5] (no unit scale). If a grid centred on
        # 1.0 was intended (Python 2 semantics), `// 2` is needed; confirm.
        self.cfg['scale_factor'] = self.cfg['scale_step'] ** (np.arange(self.cfg['num_scale']) - self.cfg['num_scale'] / 2)
        self.cfg['scale_penalties'] = self.cfg['scale_penalty'] ** (np.abs((np.arange(self.cfg['num_scale']) - self.cfg['num_scale'] / 2)))
        self.cfg['net_input_size'] = [self.cfg['crop_sz'], self.cfg['crop_sz']]
        self.cfg['cos_window'] = torch.Tensor(np.outer(np.hanning(self.cfg['crop_sz']), np.hanning(self.cfg['crop_sz']))).cuda()
        self.cfg['y_online'] = gaussian_shaped_labels(self.cfg['output_sigma'], self.cfg['net_input_size'])
        self.cfg['yf_online'] = _rfft2(torch.Tensor(self.cfg['y_online']).view(1, 1, self.cfg['crop_sz'], self.cfg['crop_sz']).cuda())

        self.cfg = dict2tuple(self.cfg)
# Example 22
def cfft2(a):
    """Do FFT and center the low frequency component.
    Always produces odd (full) output sizes.

    The one-sided 2-D FFT is taken over the last two dims in the (real, imag)
    trailing-dim layout and then passed to ``rfftshift2``."""
    if hasattr(torch, "rfft"):  # PyTorch < 1.8: legacy API
        spectrum = torch.rfft(a, 2)
    else:  # torch.rfft was removed in PyTorch 1.8; same [..., 2] layout
        spectrum = torch.view_as_real(torch.fft.rfft2(a))
    return rfftshift2(spectrum)
# Example 23
def dct1(x):
    """
    Discrete Cosine Transform, Type I

    Computed as the real part of the real FFT of the even (mirror) extension
    ``[x0, ..., x_{N-1}, x_{N-2}, ..., x1]`` of length 2N - 2.

    :param x: the input signal
    :return: the DCT-I of the signal over the last dimension (unnormalised)
    """
    x_shape = x.shape
    x = x.view(-1, x_shape[-1])

    extended = torch.cat([x, x.flip([1])[:, 1:-1]], dim=1)
    if hasattr(torch, "rfft"):  # PyTorch < 1.8: legacy API
        real_part = torch.rfft(extended, 1)[:, :, 0]
    else:  # torch.rfft was removed in PyTorch 1.8
        real_part = torch.fft.rfft(extended, dim=1).real
    return real_part.view(*x_shape)
# Example 24
def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    x_shape = x.shape
    N = x_shape[-1]
    x = x.contiguous().view(-1, N)

    # Reorder as even samples followed by reversed odd samples (Makhoul's N-point DCT via FFT).
    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    if hasattr(torch, "rfft"):  # PyTorch < 1.8: legacy API
        Vc = torch.rfft(v, 1, onesided=False)
    else:  # torch.rfft was removed in PyTorch 1.8; same [..., 2] layout
        Vc = torch.view_as_real(torch.fft.fft(v, dim=1))

    # Twiddle factors exp(-i * pi * k / (2N)); keep only the real part of the product.
    k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == 'ortho':
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    V = 2 * V.view(*x_shape)

    return V
# Example 25
def forward(ctx, h1, s1, h2, s2, output_size, x, y, force_cpu_scatter_add=False):
        """Compact bilinear pooling forward: FFT convolution of two count sketches.

        ``x`` and ``y`` are count-sketched to ``output_size`` bins. Their 1-D
        real spectra are multiplied with ``ComplexMultiply_forward`` and
        transformed back with an inverse real FFT of length ``output_size``.
        The inputs are saved on ``ctx`` for the backward pass.
        """
        ctx.save_for_backward(h1, s1, h2, s2, x, y)
        ctx.x_size = tuple(x.size())
        ctx.y_size = tuple(y.size())
        ctx.force_cpu_scatter_add = force_cpu_scatter_add
        ctx.output_size = output_size

        # 1-D real FFT helpers in the legacy (real, imag) trailing-dim layout.
        if hasattr(torch, "rfft"):  # PyTorch < 1.8
            def _rfft1(t):
                return torch.rfft(t, 1)

            def _irfft1(re, im):
                return torch.irfft(torch.stack((re, im), re.dim()), 1, signal_sizes=(output_size,))
        else:  # torch.rfft/irfft were removed in PyTorch 1.8
            def _rfft1(t):
                return torch.view_as_real(torch.fft.rfft(t))

            def _irfft1(re, im):
                return torch.fft.irfft(torch.complex(re, im), n=output_size)

        # Count sketch of each input and its FFT.
        px = CountSketchFn_forward(h1, s1, output_size, x, force_cpu_scatter_add)
        fx = _rfft1(px)
        re_fx = fx.select(-1, 0)
        im_fx = fx.select(-1, 1)
        del px
        py = CountSketchFn_forward(h2, s2, output_size, y, force_cpu_scatter_add)
        fy = _rfft1(py)
        re_fy = fy.select(-1, 0)
        im_fy = fy.select(-1, 1)
        del py

        # Convolution of the two sketches = complex product of their FFTs.
        re_prod, im_prod = ComplexMultiply_forward(re_fx, im_fx, re_fy, im_fy)

        # Back to the real domain; the imaginary part should be zero.
        return _irfft1(re_prod, im_prod)
# Example 26
def filters(self):
        """Return ``self._filters`` stacked with their quadrature (90°-shifted) copies.

        Each filter's one-sided spectrum (orthonormal scaling) is multiplied by
        -i, i.e. (re, im) -> (im, -re), and transformed back to
        ``self.kernel_size`` samples. The legacy path indexes the spectrum as
        4-D, so ``_filters`` is presumably (out, in, kernel_size) — TODO confirm.

        Returns:
            ``torch.cat([self._filters, shifted], dim=0)``.
        """
        if hasattr(torch, "rfft"):  # PyTorch < 1.8: legacy API
            ft_f = torch.rfft(self._filters, 1, normalized=True)
            hft_f = torch.stack([ft_f[:, :, :, 1], - ft_f[:, :, :, 0]], dim=-1)
            hft_f = torch.irfft(hft_f, 1, normalized=True,
                                signal_sizes=(self.kernel_size, ))
        else:  # torch.rfft/irfft were removed in PyTorch 1.8; normalized=True == norm="ortho"
            ft_f = torch.fft.rfft(self._filters, norm="ortho")
            hft_f = torch.fft.irfft(torch.complex(ft_f.imag, -ft_f.real),
                                    n=self.kernel_size, norm="ortho")
        return torch.cat([self._filters, hft_f], dim=0)
# Example 27
def backward(ctx, grad_output):
        """Backward of the compact bilinear pooling forward.

        The count sketches are recomputed from the saved inputs. The spectral
        gradient of each sketch is ``FFT(grad_output) * conj(FFT(other sketch))``;
        it is brought back with an inverse real FFT of length
        ``ctx.output_size`` and mapped through ``CountSketchFn_backward``.

        Returns:
            Gradients for (h1, s1, h2, s2, output_size, x, y, force_cpu_scatter_add);
            only x and y receive non-None values.
        """
        h1, s1, h2, s2, x, y = ctx.saved_tensors
        output_size = ctx.output_size

        # 1-D real FFT helpers in the legacy (real, imag) trailing-dim layout.
        if hasattr(torch, "rfft"):  # PyTorch < 1.8
            def _rfft1(t):
                return torch.rfft(t, 1)

            def _irfft1(re, im):
                return torch.irfft(torch.stack((re, im), re.dim()), 1, signal_sizes=(output_size,))
        else:  # torch.rfft/irfft were removed in PyTorch 1.8
            def _rfft1(t):
                return torch.view_as_real(torch.fft.rfft(t))

            def _irfft1(re, im):
                return torch.fft.irfft(torch.complex(re, im), n=output_size)

        # Recompute part of the forward pass to get the inputs of the complex product.
        px = CountSketchFn_forward(h1, s1, output_size, x, ctx.force_cpu_scatter_add)
        py = CountSketchFn_forward(h2, s2, output_size, y, ctx.force_cpu_scatter_add)

        # Then convert the output gradient to the Fourier domain.
        grad_output = grad_output.contiguous()
        grad_prod = _rfft1(grad_output)
        grad_re_prod = grad_prod.select(-1, 0)
        grad_im_prod = grad_prod.select(-1, 1)

        # Gradient of x: grad_prod * conj(fy)
        fy = _rfft1(py)
        re_fy = fy.select(-1, 0)
        im_fy = fy.select(-1, 1)
        del py
        # keyword `value=`: the positional-value addcmul overload is deprecated
        grad_re_fx = torch.addcmul(grad_re_prod * re_fy, grad_im_prod, im_fy, value=1)
        grad_im_fx = torch.addcmul(grad_im_prod * re_fy, grad_re_prod, im_fy, value=-1)
        grad_fx = _irfft1(grad_re_fx, grad_im_fx)
        grad_x = CountSketchFn_backward(h1, s1, ctx.x_size, grad_fx)
        del re_fy, im_fy, grad_re_fx, grad_im_fx, grad_fx

        # Gradient of y: grad_prod * conj(fx)
        fx = _rfft1(px)
        re_fx = fx.select(-1, 0)
        im_fx = fx.select(-1, 1)
        del px
        grad_re_fy = torch.addcmul(grad_re_prod * re_fx, grad_im_prod, im_fx, value=1)
        grad_im_fy = torch.addcmul(grad_im_prod * re_fx, grad_re_prod, im_fx, value=-1)
        grad_fy = _irfft1(grad_re_fy, grad_im_fy)
        grad_y = CountSketchFn_backward(h2, s2, ctx.y_size, grad_fy)
        del re_fx, im_fx, grad_re_fy, grad_im_fy, grad_fy

        return None, None, None, None, None, grad_x, grad_y, None
# Example 28
def so3_rfft(x, for_grad=False, b_out=None):
    '''
    Forward SO(3) Fourier transform of a real signal sampled on a 2b x 2b x 2b grid.

    :param x: [..., beta, alpha, gamma]
    :return: [l * m * n, ..., complex]
    '''
    b_in = x.size(-1) // 2
    assert x.size(-1) == 2 * b_in
    assert x.size(-2) == 2 * b_in
    assert x.size(-3) == 2 * b_in
    if b_out is None:
        b_out = b_in
    batch_size = x.size()[:-3]

    x = x.contiguous().view(-1, 2 * b_in, 2 * b_in, 2 * b_in)  # [batch, beta, alpha, gamma]

    # x: [batch, beta, alpha, gamma] (nbatch, 2 b_in, 2 b_in, 2 b_in)
    # returns: [l * m * n, batch, complex] (b_out (4 b_out**2 - 1) // 3, nbatch, 2)
    nspec = b_out * (4 * b_out ** 2 - 1) // 3
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad, device=x.device)

    # torch.rfft and the callable torch.fft(...) were removed in PyTorch 1.8.
    legacy_fft = hasattr(torch, "rfft")

    output = x.new_empty((nspec, nbatch, 2))
    if x.is_cuda and x.dtype == torch.float32:
        if legacy_fft:
            x = torch.rfft(x, 2)  # [batch, beta, m, n, complex]
        else:
            x = torch.view_as_real(torch.fft.rfft2(x))  # same layout
        cuda_kernel = _setup_so3fft_cuda_kernel(b_in=b_in, b_out=b_out, nbatch=nbatch, real_input=True, device=x.device.index)
        cuda_kernel(x, wigner, output)
    else:
        # TODO use a real-input FFT; the loop below also reads negative
        # frequencies on the last axis, so that needs care.
        if legacy_fft:
            x = torch.fft(torch.stack((x, torch.zeros_like(x)), dim=-1), 2)
        else:
            x = torch.view_as_real(torch.fft.fft2(x))  # [batch, beta, m, n, complex]
        if b_in < b_out:
            output.fill_(0)
        for l in range(b_out):
            s = slice(l * (4 * l**2 - 1) // 3, l * (4 * l**2 - 1) // 3 + (2 * l + 1) ** 2)
            l1 = min(l, b_in - 1)  # if b_out > b_in, consider high frequencies as null

            # Gather frequencies -l1..l1 on both axes into a centred (2l+1)x(2l+1) block.
            xx = x.new_zeros((x.size(0), x.size(1), 2 * l + 1, 2 * l + 1, 2))
            xx[:, :, l: l + l1 + 1, l: l + l1 + 1] = x[:, :, :l1 + 1, :l1 + 1]
            if l1 > 0:
                xx[:, :, l - l1:l, l: l + l1 + 1] = x[:, :, -l1:, :l1 + 1]
                xx[:, :, l: l + l1 + 1, l - l1:l] = x[:, :, :l1 + 1, -l1:]
                xx[:, :, l - l1:l, l - l1:l] = x[:, :, -l1:, -l1:]

            out = torch.einsum("bmn,zbmnc->mnzc", (wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1), xx))
            output[s] = out.view((2 * l + 1) ** 2, -1, 2)

    output = output.view(-1, *batch_size, 2)  # [l * m * n, ..., complex]
    return output
# Example 29
def get_reg_filter(sz: torch.Tensor, target_sz: torch.Tensor, params):
    """Computes regularization filter in CCOT and ECO.

    Without ``params.use_reg_window`` a 1x1x1x1 constant ``reg_window_min`` is
    returned. Otherwise a polynomial spatial window over a grid of size ``sz``
    is built and transformed with a one-sided 2-D FFT. DFT coefficients below
    ``reg_sparsity_threshold`` times the maximum magnitude are zeroed. The DC
    term is corrected so the sparse window's minimum is ``reg_window_min``,
    and the centred real DFT is cropped to its non-zero support.
    """

    if not params.use_reg_window:
        return params.reg_window_min * torch.ones(1,1,1,1)

    if getattr(params, 'reg_window_square', False):
        target_sz = target_sz.prod().sqrt() * torch.ones(2)

    # Normalization factor
    reg_scale = 0.5 * target_sz

    # Construct grid
    if getattr(params, 'reg_window_centered', True):
        wrg = torch.arange(-int((sz[0]-1)/2), int(sz[0]/2+1), dtype=torch.float32).view(1,1,-1,1)
        wcg = torch.arange(-int((sz[1]-1)/2), int(sz[1]/2+1), dtype=torch.float32).view(1,1,1,-1)
    else:
        wrg = torch.cat([torch.arange(0, int(sz[0]/2+1), dtype=torch.float32),
                         torch.arange(-int((sz[0] - 1) / 2), 0, dtype=torch.float32)]).view(1,1,-1,1)
        wcg = torch.cat([torch.arange(0, int(sz[1]/2+1), dtype=torch.float32),
                         torch.arange(-int((sz[1] - 1) / 2), 0, dtype=torch.float32)]).view(1,1,1,-1)

    # Construct regularization window
    reg_window = (params.reg_window_edge - params.reg_window_min) * \
                 (torch.abs(wrg/reg_scale[0])**params.reg_window_power +
                  torch.abs(wcg/reg_scale[1])**params.reg_window_power) + params.reg_window_min

    # torch.rfft/irfft were removed in PyTorch 1.8; keep the (real, imag) layout either way.
    legacy_fft = hasattr(torch, "rfft")

    # Compute DFT and enforce sparsity
    if legacy_fft:
        reg_window_dft = torch.rfft(reg_window, 2) / sz.prod()
    else:
        reg_window_dft = torch.view_as_real(torch.fft.rfft2(reg_window)) / sz.prod()
    reg_window_dft_abs = complex.abs(reg_window_dft)
    reg_window_dft[reg_window_dft_abs < params.reg_sparsity_threshold * reg_window_dft_abs.max(), :] = 0

    # Do the inverse transform to correct for the window minimum
    if legacy_fft:
        reg_window_sparse = torch.irfft(reg_window_dft, 2, signal_sizes=sz.long().tolist())
    else:
        reg_window_sparse = torch.fft.irfft2(torch.view_as_complex(reg_window_dft.contiguous()),
                                             s=sz.long().tolist())
    reg_window_dft[0,0,0,0,0] += params.reg_window_min - sz.prod() * reg_window_sparse.min()
    reg_window_dft = complex.real(fourier.rfftshift2(reg_window_dft))

    # Remove zeros
    max_inds,_ = reg_window_dft.nonzero().max(dim=0)
    mid_ind = int((reg_window_dft.shape[2]-1)/2)
    top = max_inds[-2].item() + 1
    bottom = 2*mid_ind - max_inds[-2].item()
    right = max_inds[-1].item() + 1
    reg_window_dft = reg_window_dft[..., bottom:top, :right]
    if reg_window_dft.shape[-1] > 1:
        reg_window_dft = torch.cat([reg_window_dft[..., 1:].flip((2, 3)), reg_window_dft], -1)

    return reg_window_dft
# Example 30
def backward(ctx, grad_output):
        """Backward of the compact bilinear pooling forward.

        The count sketches are recomputed from the saved tensors. For each
        input, the spectral gradient ``FFT(grad_output) * conj(FFT(other sketch))``
        is transformed back with an inverse real FFT of length ``ctx.output_size``
        and passed through ``CountSketchFn_backward``.

        Returns:
            Gradients for (h1, s1, h2, s2, output_size, x, y, force_cpu_scatter_add);
            only x and y receive non-None values.
        """
        h1, s1, h2, s2, x, y = ctx.saved_tensors
        output_size = ctx.output_size

        # 1-D real FFT helpers in the legacy (real, imag) trailing-dim layout.
        if hasattr(torch, "rfft"):  # PyTorch < 1.8
            def _rfft1(t):
                return torch.rfft(t, 1)

            def _irfft1(re, im):
                return torch.irfft(torch.stack((re, im), re.dim()), 1, signal_sizes=(output_size,))
        else:  # torch.rfft/irfft were removed in PyTorch 1.8
            def _rfft1(t):
                return torch.view_as_real(torch.fft.rfft(t))

            def _irfft1(re, im):
                return torch.fft.irfft(torch.complex(re, im), n=output_size)

        # Recompute part of the forward pass to get the inputs of the complex product.
        px = CountSketchFn_forward(h1, s1, output_size, x, ctx.force_cpu_scatter_add)
        py = CountSketchFn_forward(h2, s2, output_size, y, ctx.force_cpu_scatter_add)

        # Then convert the output gradient to the Fourier domain.
        grad_output = grad_output.contiguous()
        grad_prod = _rfft1(grad_output)
        grad_re_prod = grad_prod.select(-1, 0)
        grad_im_prod = grad_prod.select(-1, 1)

        # Gradient of x: grad_prod * conj(fy)
        fy = _rfft1(py)
        re_fy = fy.select(-1, 0)
        im_fy = fy.select(-1, 1)
        del py
        # keyword `value=`: the positional-value addcmul overload is deprecated
        grad_re_fx = torch.addcmul(grad_re_prod * re_fy, grad_im_prod, im_fy, value=1)
        grad_im_fx = torch.addcmul(grad_im_prod * re_fy, grad_re_prod, im_fy, value=-1)
        grad_fx = _irfft1(grad_re_fx, grad_im_fx)
        grad_x = CountSketchFn_backward(h1, s1, ctx.x_size, grad_fx)
        del re_fy, im_fy, grad_re_fx, grad_im_fx, grad_fx

        # Gradient of y: grad_prod * conj(fx)
        fx = _rfft1(px)
        re_fx = fx.select(-1, 0)
        im_fx = fx.select(-1, 1)
        del px
        grad_re_fy = torch.addcmul(grad_re_prod * re_fx, grad_im_prod, im_fx, value=1)
        grad_im_fy = torch.addcmul(grad_im_prod * re_fx, grad_re_prod, im_fx, value=-1)
        grad_fy = _irfft1(grad_re_fy, grad_im_fy)
        grad_y = CountSketchFn_backward(h2, s2, ctx.y_size, grad_fy)
        del re_fx, im_fx, grad_re_fy, grad_im_fy, grad_fy

        return None, None, None, None, None, grad_x, grad_y, None