Python源码示例:torch.stft()

示例1
def __init__(self,
             n_fft: int = 400,
             win_length: Optional[int] = None,
             hop_length: Optional[int] = None,
             pad: int = 0,
             window_fn: Callable[..., Tensor] = torch.hann_window,
             power: Optional[float] = 2.,
             normalized: bool = False,
             wkwargs: Optional[dict] = None) -> None:
    """Store the spectrogram settings and register the analysis window.

    Args:
        n_fft: Number of FFT bins.
        win_length: Window size; falls back to ``n_fft`` when ``None``.
        hop_length: Hop between frames; falls back to half the window
            size when ``None``.
        pad: Stored as ``self.pad``.
        window_fn: Callable that builds the window tensor from its length.
        power: Stored as ``self.power``.
        normalized: Stored as ``self.normalized``.
        wkwargs: Extra keyword arguments forwarded to ``window_fn``.
    """
    super(Spectrogram, self).__init__()

    # Number of FFT bins. Per the original note, the STFT result has
    # n_fft // 2 + 1 frequencies because torch.stft is used with onesided=True.
    self.n_fft = n_fft

    if win_length is None:
        win_length = n_fft
    self.win_length = win_length

    if hop_length is None:
        hop_length = win_length // 2
    self.hop_length = hop_length

    window_kwargs = {} if wkwargs is None else wkwargs
    self.register_buffer('window', window_fn(win_length, **window_kwargs))

    self.pad = pad
    self.power = power
    self.normalized = normalized
示例2
def test_istft_requires_nola(self):
    """istft accepts an all-ones window but must reject an all-zeros one."""
    stft = torch.zeros((3, 5, 2))
    shared = {'n_fft': 4, 'win_length': 4}
    kwargs_ok = dict(shared, window=torch.ones(4))
    kwargs_not_ok = dict(shared, window=torch.zeros(4))

    # A window of ones meets NOLA, so this call must succeed.
    torchaudio.functional.istft(stft, **kwargs_ok)

    # A window of zeros does not meet NOLA, so this call must raise.
    with self.assertRaises(RuntimeError):
        torchaudio.functional.istft(stft, **kwargs_not_ok)
示例3
def _test_istft_of_sine(self, amplitude, L, n):
    """Check istft against amplitude*sin(2*pi/L*n*x) from a hand-built STFT.

    The hop length and window size both equal ``L``.
    """
    samples = torch.arange(2 * L + 1, dtype=torch.get_default_dtype())
    sound = amplitude * torch.sin(2 * math.pi / L * samples * n)

    # Hand-written stand-in for:
    #   torch.stft(sound, L, hop_length=L, win_length=L,
    #              window=torch.ones(L), center=False, normalized=False)
    n_bins = L // 2 + 1
    stft = torch.zeros((n_bins, 2, 2))
    peak = (amplitude * L) / 2.0
    if n < n_bins:
        stft[n, :, 1] = -peak

    mirrored = L - n
    if 0 <= mirrored < n_bins:
        # symmetric about L // 2
        stft[mirrored, :, 1] = peak

    estimate = torchaudio.functional.istft(
        stft, L, hop_length=L, win_length=L,
        window=torch.ones(L), center=False, normalized=False)
    # Looser tolerance: the error grows with the amplitude scaling.
    _compare_estimate(sound, estimate, atol=1e-3)
示例4
def compute_torch_stft(audio, descriptor):
    """Return the STFT magnitude of ``audio`` as configured by ``descriptor``.

    Args:
        audio: Real-valued tensor accepted by ``torch.stft``.
        descriptor: Underscore-separated string of the form
            ``"<name>_<n_fft>_<hop_size>[_...]"``. Only the second and third
            fields are used; the name and any trailing fields are ignored.

    Returns:
        Tensor of magnitudes, ``sqrt(real ** 2 + imag ** 2)``.

    Raises:
        ValueError: If the descriptor has fewer than three fields or the
            sizes are not integers.

    Requires PyTorch >= 1.7 (``return_complex`` / ``torch.view_as_real``).
    """
    _name, n_fft, hop_size, *_rest = descriptor.split("_")
    n_fft = int(n_fft)
    hop_size = int(hop_size)

    # Real input requires an explicit return_complex on current PyTorch;
    # view_as_real restores the trailing (real, imag) axis used below.
    spec = torch.stft(
        audio,
        n_fft=n_fft,
        hop_length=hop_size,
        window=torch.hann_window(n_fft, device=audio.device),
        return_complex=True,
    )
    real_imag = torch.view_as_real(spec)

    return torch.sqrt((real_imag ** 2).sum(-1))
示例5
def forward(self, x):
    """Turn a waveform batch into a normalised, dB-scaled mel spectrogram.

    NOTE(review): the ``fft_size=`` keyword and the positional
    (win_length, hop_length) arguments look like the legacy PyTorch 0.4.0
    ``torch.stft`` signature -- confirm the targeted PyTorch version.
    """
    if self.preemp is not None:
        # Pre-emphasis runs on a (batch, 1, time) view of the signal.
        x = self.preemp(x.unsqueeze(1)).squeeze(1)

    stft = torch.stft(x,
                      self.win_length,
                      self.hop_length,
                      fft_size=self.n_fft,
                      window=self.win)
    real_part = stft[:, :, :, 0]
    imag_part = stft[:, :, :, 1]
    spec = torch.sqrt(real_part.pow(2) + imag_part.pow(2))

    # linear magnitudes -> mel bands -> dB relative to the reference level
    mel = torch.matmul(spec, self.mel_basis)
    mel_db = _amp_to_db(mel) - hparams.ref_level_db
    return _normalize(mel_db)
示例6
def __call__(self, wav):
    """Compute a window-energy-normalised STFT of ``wav``.

    The STFT uses ``self.nfft``, ``self.window_shift`` (hop),
    ``self.window_size`` and ``self.window``; the result is divided by
    ``sqrt(sum(window ** 2))``. No gradients are tracked.

    Returns:
        Real tensor whose (real, imag) axis is moved to the front; for an
        ``F x T x 2`` STFT the result is ``2 x F x T``.

    Requires PyTorch >= 1.7 (``return_complex`` / ``torch.view_as_real``).
    """
    with torch.no_grad():
        # Real input requires an explicit return_complex on current PyTorch;
        # view_as_real restores the trailing (real, imag) axis.
        spec = torch.stft(wav, n_fft=self.nfft, hop_length=self.window_shift,
                          win_length=self.window_size, window=self.window,
                          return_complex=True)
        data = torch.view_as_real(spec)
        data /= self.window.pow(2).sum().sqrt_()
        # FxTx2 -> 2xFxT
        return data.transpose(1, 2).transpose(0, 1)


# transformer: frame splitter 
示例7
def forward(self, audio):
    """Compute a log10 mel spectrogram of ``audio``.

    The signal is reflect-padded by ``(n_fft - hop_length) // 2`` samples on
    each side, its axis 1 is squeezed, and an un-centred STFT is taken. The
    magnitudes are projected with ``self.mel_basis`` and clamped to at least
    1e-5 before the logarithm.

    Requires PyTorch >= 1.7 (``return_complex`` / ``torch.view_as_real``).
    """
    p = (self.n_fft - self.hop_length) // 2
    audio = F.pad(audio, (p, p), "reflect").squeeze(1)
    # Real input requires an explicit return_complex on current PyTorch;
    # view_as_real restores the trailing (real, imag) axis.
    spec = torch.stft(
        audio,
        n_fft=self.n_fft,
        hop_length=self.hop_length,
        win_length=self.win_length,
        window=self.window,
        center=False,
        return_complex=True,
    )
    real_part, imag_part = torch.view_as_real(spec).unbind(-1)
    magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
    mel_output = torch.matmul(self.mel_basis, magnitude)
    return torch.log10(torch.clamp(mel_output, min=1e-5))
示例8
def forward(self, x):
    """
    Input: (nb_samples, nb_channels, nb_timesteps)
    Output:(nb_samples, nb_channels, nb_bins, nb_frames, 2)

    The last output axis holds the (real, imag) parts.
    Requires PyTorch >= 1.7 (``return_complex`` / ``torch.view_as_real``).
    """
    nb_samples, nb_channels, _ = x.size()

    # merge nb_samples and nb_channels for multichannel stft
    x = x.reshape(nb_samples * nb_channels, -1)

    # compute stft with parameters as close as possible scipy settings.
    # Real input requires an explicit return_complex on current PyTorch;
    # view_as_real restores the trailing (real, imag) axis.
    spec = torch.stft(
        x,
        n_fft=self.n_fft, hop_length=self.n_hop,
        window=self.window, center=self.center,
        normalized=False, onesided=True,
        pad_mode='reflect',
        return_complex=True,
    )
    stft_f = torch.view_as_real(spec)

    # reshape back to channel dimension
    return stft_f.contiguous().view(
        nb_samples, nb_channels, self.n_fft // 2 + 1, -1, 2
    )
示例9
def test_istft(self):
    """Run the batch-consistency check for istft on a fixed 3x5 STFT."""
    # Every frame carries 4 + 0j in bin 0; all other bins are zero.
    stft = torch.zeros(3, 5, 2)
    stft[0, :, 0] = 4.
    self.assert_batch_consistencies(F.istft, stft, n_fft=4, length=4)
示例10
def test_batch_TimeStretch(self):
    """Stretching a tiled batch must equal tiling the stretched single item."""
    test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
    waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

    hop_length = 512
    complex_specgrams = torch.stft(
        waveform,
        n_fft=2048,
        hop_length=hop_length,
        win_length=2048,
        window=torch.hann_window(2048),
        center=True,
        pad_mode='reflect',
        normalized=True,
        onesided=True,
    )

    def build_stretch():
        # A fresh transform per use, as in the two separate constructions.
        return torchaudio.transforms.TimeStretch(
            fixed_rate=2,
            n_freq=1025,
            hop_length=hop_length,
        )

    tiling = (3, 1, 1, 1, 1)

    # Single then transform then batch
    expected = build_stretch()(complex_specgrams).repeat(*tiling)

    # Batch then transform
    computed = build_stretch()(complex_specgrams.repeat(*tiling))

    self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)
示例11
def test_istft_of_ones(self):
    """istft of the hand-written STFT of four ones must give four ones."""
    # Stand-in for: stft = torch.stft(torch.ones(4), 4)
    # Bin 0 of every frame is 4 + 0j; the other bins are zero.
    stft = torch.zeros(3, 5, 2)
    stft[0, :, 0] = 4.

    expected = torch.ones(4)
    estimate = torchaudio.functional.istft(stft, n_fft=4, length=4)
    _compare_estimate(expected, estimate)
示例12
def test_istft_of_zeros(self):
    """istft of an all-zero STFT must give an all-zero signal."""
    # Stand-in for: stft = torch.stft(torch.zeros(4), 4)
    all_zero_stft = torch.zeros(3, 5, 2)

    expected = torch.zeros(4)
    estimate = torchaudio.functional.istft(all_zero_stft, n_fft=4, length=4)
    _compare_estimate(expected, estimate)
示例13
def test_istft_requires_overlap_windows(self):
    """istft must raise when consecutive windows leave a gap."""
    # The window has size 1 but the hop is 20, so frames never overlap.
    stft = torch.zeros((3, 5, 2))
    window = torch.ones(1)
    with self.assertRaises(RuntimeError):
        torchaudio.functional.istft(stft, n_fft=4, hop_length=20,
                                    win_length=1, window=window)
示例14
def forward(self, x):
    """Apply STFT magnitude, ``self.f2m`` and PCEN, keeping the PCEN state.

    The state returned by the PCEN routine is detached and stored in
    ``self.last_state``, and ``self.empty`` is cleared.

    NOTE(review): ``self.stft_kwargs`` is forwarded to ``torch.stft`` as-is;
    the ``norm(dim=-1)`` call presumes a trailing (real, imag) axis -- confirm
    against the targeted PyTorch version.
    """
    magnitude = torch.stft(x, self.n_fft, **self.stft_kwargs).norm(dim=-1, p=2)
    mel = self.f2m(magnitude.permute(0, 2, 1))

    # The CUDA kernel takes the trainable flag directly; the Python path
    # additionally requires the module to be in training mode.
    if self.use_cuda_kernel:
        kernel, trainable = pcen_cuda_kernel, self.trainable
    else:
        kernel, trainable = pcen, self.training and self.trainable

    out, state = kernel(mel, self.eps, self.s, self.alpha, self.delta, self.r,
                        trainable, self.last_state, self.empty)
    self.last_state = state.detach()
    self.empty = False
    return out
示例15
def __call__(self, pkg, cached_file=None):
    """Attach log-power-spectrum features to ``pkg`` and return it.

    Args:
        pkg: Package passed through ``format_package``; ``pkg['chunk']`` is
            the waveform. With a cache, ``pkg['chunk_beg_i']`` and
            ``pkg['chunk_end_i']`` select the frame range.
        cached_file: Optional path to pre-computed features loaded with
            ``torch.load`` and sliced along axis 1.

    Returns:
        The same package, with the features added and
        ``pkg['dec_resolution']`` set to ``self.hop``.

    Requires PyTorch >= 1.7 (``return_complex`` / ``torch.view_as_real``).
    """
    pkg = format_package(pkg)
    wav = pkg['chunk']
    max_frames = wav.size(0) // self.hop
    if cached_file is not None:
        # load pre-computed data
        # NOTE(review): torch.load unpickles its input; only pass trusted
        # cache files here.
        X = torch.load(cached_file)
        beg_i = pkg['chunk_beg_i'] // self.hop
        end_i = pkg['chunk_end_i'] // self.hop
        X = X[:, beg_i:end_i]
        # NOTE(review): the cached branch stores under the fixed key 'lps'
        # while the computed branch uses self.name -- confirm this is intended.
        pkg['lps'] = X
    else:
        wav = wav.to(self.device)
        # self.win is passed positionally, i.e. as torch.stft's win_length.
        # Real input requires an explicit return_complex on current PyTorch;
        # view_as_real restores the trailing (real, imag) axis normed below.
        spec = torch.stft(wav, self.n_fft,
                          self.hop, self.win,
                          return_complex=True)
        X = torch.view_as_real(spec)
        X = torch.norm(X, 2, dim=2).cpu()[:, :max_frames]
        X = 10 * torch.log10(X ** 2 + 10e-20).cpu()
        if self.der_order > 0:
            # Stack the base features with their delta features of
            # order 1..der_order along axis 0.
            deltas = [X]
            for n in range(1, self.der_order + 1):
                deltas.append(librosa.feature.delta(X.numpy(), order=n))
            X = torch.from_numpy(np.concatenate(deltas))

        pkg[self.name] = X
    # Overwrite resolution to hop length
    pkg['dec_resolution'] = self.hop
    return pkg
示例16
def __call__(self, pkg):
    """Compute log-power-spectrum features and store them in ``pkg['lps']``.

    ``pkg['chunk']`` is the waveform. The result is truncated to
    ``wav.size(0) // self.hop`` frames and moved to the CPU.

    Requires PyTorch >= 1.7 (``return_complex`` / ``torch.view_as_real``).
    """
    pkg = format_package(pkg)
    wav = pkg['chunk']
    max_frames = wav.size(0) // self.hop
    wav = wav.to(self.device)
    # self.win is passed positionally, i.e. as torch.stft's win_length.
    # Real input requires an explicit return_complex on current PyTorch;
    # view_as_real restores the trailing (real, imag) axis normed below.
    spec = torch.stft(wav, self.n_fft,
                      self.hop, self.win,
                      return_complex=True)
    X = torch.view_as_real(spec)
    X = torch.norm(X, 2, dim=2).cpu()[:, :max_frames]
    pkg['lps'] = 10 * torch.log10(X ** 2 + 10e-20).cpu()
    return pkg
示例17
def is_stft(descriptor):
    """Return True when ``descriptor`` begins with the ``"stft"`` prefix."""
    prefix = "stft"
    return descriptor.startswith(prefix)
示例18
def get_mel(self, x):
    """Project the STFT magnitude of ``x`` with ``self.mel_basis``.

    STFT sizes come from ``self.hp.audio`` (``n_fft``, ``hop_length``,
    ``win_length``) and the window from ``self.window``.

    Returns:
        ``self.mel_basis @ |STFT(x)|``.

    Requires PyTorch >= 1.7 (``return_complex`` / ``torch.view_as_real``).
    """
    # Real input requires an explicit return_complex on current PyTorch;
    # view_as_real restores the trailing (real, imag) axis normed below.
    spec = torch.stft(
        input=x,
        n_fft=self.hp.audio.n_fft,
        hop_length=self.hp.audio.hop_length,
        win_length=self.hp.audio.win_length,
        window=self.window,
        return_complex=True,
    )
    mag = torch.norm(torch.view_as_real(spec), p=2, dim=-1)
    return torch.matmul(self.mel_basis, mag)
示例19
def stft(y, scale='linear'):
    """Return the STFT magnitude of ``y`` on a linear or log scale.

    Uses ``n_fft=1024``, ``hop_length=256`` and a 1024-sample Hann window
    created on the same device as ``y`` (previously hard-coded to CUDA, which
    failed for CPU input).

    Args:
        y: Real-valued signal tensor.
        scale: ``'linear'`` for the magnitude ``D``; ``'log'`` for
            ``2 * log(D)``.

    Returns:
        The requested spectrogram, or ``None`` for any other ``scale`` value.

    Requires PyTorch >= 1.7 (``return_complex`` / ``torch.view_as_real``).
    """
    window = torch.hann_window(1024, device=y.device)
    # Real input requires an explicit return_complex on current PyTorch;
    # view_as_real restores the trailing (real, imag) axis summed below.
    spec = torch.stft(y, n_fft=1024, hop_length=256, win_length=1024,
                      window=window, return_complex=True)
    # The 1e-10 offset keeps the square root away from an exact zero.
    D = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-10)
    if scale == 'linear':
        return D
    if scale == 'log':
        return 2 * torch.log(torch.clamp(D, 1e-10, float("inf")))
    # Unknown scale: kept as a silent None for existing callers.
    return None

# STFT code is adapted from: https://github.com/pseeth/pytorch-stft 
示例20
def _power_loss(self, p_y, t_y):
    """Spectral-magnitude loss between a prediction and a target signal.

    Each input is reshaped to 1-D with length ``shape[0]`` and transformed
    with a 512-point Hann-windowed STFT. The loss is the sum over frequency
    bins of the squared L2 norm (taken along axis 1) of the magnitude
    difference, divided by the number of magnitude entries.

    The window is created on each signal's own device rather than on a
    module-level ``device`` global.

    Args:
        p_y: Predicted signal.
        t_y: Target signal.

    Returns:
        Scalar tensor with the loss value.

    Requires PyTorch >= 1.7 (``return_complex`` / ``torch.view_as_real``).
    """
    def magnitude(signal):
        flat = signal.reshape(signal.shape[0])
        window = torch.hann_window(window_length=512, device=flat.device)
        # Real input requires an explicit return_complex on current PyTorch;
        # view_as_real restores the trailing (real, imag) axis.
        spec = torch.view_as_real(
            torch.stft(flat, n_fft=512, window=window, return_complex=True))
        real = spec[:, :, 0]
        imag = spec[:, :, 1]
        return torch.sqrt(torch.pow(real, 2) + torch.pow(imag, 2))

    power_orig = magnitude(t_y)
    power_pred = magnitude(p_y)
    diff = torch.abs(power_pred) - torch.abs(power_orig)
    total = torch.sum(torch.pow(torch.norm(diff, p=2, dim=1), 2))
    return total / (power_pred.shape[0] * power_pred.shape[1])
示例21
def forward(self, x, stfts_orig):
    """Multi-scale spectral loss between ``x`` and reference spectra.

    For every scale in ``self.scales`` an un-centred STFT of ``x`` is taken
    and passed through ``amp``. The loss sums, over every scale ``i`` and
    item ``j``, the mean absolute difference of the spectra and of their
    logarithms (offset by 1e-4).

    Args:
        x: Input signal batch.
        stfts_orig: Reference spectra indexed as ``stfts_orig[i][j]``.

    Returns:
        ``lin_loss + log_loss``.

    Requires PyTorch >= 1.7 (``return_complex`` / ``torch.view_as_real``).
    """
    stfts = []
    # First compute multiple STFT for x
    for i, scale in enumerate(self.scales):
        # Real input requires an explicit return_complex on current PyTorch;
        # view_as_real restores the trailing (real, imag) axis.
        cur_fft = torch.stft(x, n_fft=scale, window=self.windows[i],
                             hop_length=int((1 - self.overlap) * scale),
                             center=False, return_complex=True)
        stfts.append(amp(torch.view_as_real(cur_fft)))

    # Compute loss. The scale loop must be the outer clause: the previous
    # ordering evaluated len(stfts[i]) with the `i` leaked from the loop above.
    pairs = [(i, j) for i in range(len(stfts)) for j in range(len(stfts[i]))]
    lin_loss = sum(torch.mean(abs(stfts_orig[i][j] - stfts[i][j]))
                   for i, j in pairs)
    log_loss = sum(torch.mean(abs(torch.log(stfts_orig[i][j] + 1e-4)
                                  - torch.log(stfts[i][j] + 1e-4)))
                   for i, j in pairs)
    return lin_loss + log_loss
示例22
def forward(self, x):
    """Compute the ``amp`` spectra of ``x`` at every scale in ``self.scales``.

    Returns:
        A list with one entry per scale. When ``self.reshape`` is truthy the
        nesting is swapped to ``result[b][s]``: one list per batch item, each
        holding that item's spectrum at every scale.

    Requires PyTorch >= 1.7 (``return_complex`` / ``torch.view_as_real``).
    """
    stfts = []
    for i, scale in enumerate(self.scales):
        # Real input requires an explicit return_complex on current PyTorch;
        # view_as_real restores the trailing (real, imag) axis.
        cur_fft = torch.stft(x, n_fft=scale, window=self.windows[i],
                             hop_length=int((1 - self.overlap) * scale),
                             center=False, return_complex=True)
        stfts.append(amp(torch.view_as_real(cur_fft)))

    if self.reshape:
        # Regroup from [scale][batch] to [batch][scale].
        stfts = [[stfts[s][b] for s in range(len(self.scales))]
                 for b in range(x.shape[0])]
    return stfts
示例23
def compute_stft(audio, n_fft=1024, win_length=1024, hop_length=256):
    """
    Computes the STFT of the given audio and returns magnitude, real and
    imaginary parts.

    The input is padded with ``win_length`` samples on each side (constant
    mode) before the transform.

    NOTE(review): the ``torch.stft(audio, win_length, hop_length,
    fft_size=...)`` call uses a ``fft_size`` keyword that looks like the
    legacy PyTorch 0.4.0 signature -- confirm the targeted PyTorch version
    before reusing this function.

    Args:
        audio (Tensor): B x T, batch of audio
        n_fft (int): FFT size, forwarded as ``fft_size``
        win_length (int): Window length; also the padding added per side
        hop_length (int): Hop between frames

    Returns:
        power (Tensor): STFT magnitudes, ``sqrt(real ** 2 + im ** 2)``
        real (Tensor): Real part of STFT transformation result
        im (Tensor): Imaginary part of STFT transformation result
    """
    # The Hann window is moved to CUDA unconditionally.
    win = torch.hann_window(win_length).cuda()

    # Pad manually because the targeted torch version does not
    # (original note said "torch 4.0"; presumably PyTorch 0.4 -- TODO confirm).
    # The signal is viewed as 3-D for F.pad, then restored to its original rank.
    signal_dim = audio.dim()
    extended_shape = [1] * (3 - signal_dim) + list(audio.size())
    # pad = int(self.n_fft // 2)
    pad = win_length
    audio = F.pad(audio.view(extended_shape), (pad, pad), 'constant')
    audio = audio.view(audio.shape[-signal_dim:])

    stft = torch.stft(audio, win_length, hop_length, fft_size=n_fft, window=win)
    # The last axis holds (real, imag); the indexing requires a 4-D result.
    real = stft[:, :, :, 0]
    im = stft[:, :, :, 1]
    power = torch.sqrt(torch.pow(real, 2) + torch.pow(im, 2))
    return power, real, im
示例24
def forward(self, x):
    """Return the dB-scaled spectrogram of ``x`` and its linear magnitude.

    NOTE(review): the ``fft_size=`` keyword and the positional
    (win_length, hop_length) arguments look like the legacy PyTorch 0.4.0
    ``torch.stft`` signature -- confirm the targeted PyTorch version.

    Returns:
        ``(spec, p)``: ``spec`` is ``_amp_to_db(p) - hparams.ref_level_db``
        and ``p`` is the STFT magnitude.
    """
    if self.preemp is not None:
        # conv on a (batch, 1, time) view, then remove the last padding sample
        x = self.preemp(x.unsqueeze(1))[:, :, :-1].squeeze(1)

    # center=True
    # torch 0.4 does not support centring like librosa, so pad by hand:
    # view as 3-D for F.pad, then restore the original rank.
    ndim = x.dim()
    leading = [1] * (3 - ndim)
    # pad = int(self.n_fft // 2)
    pad = self.win_length
    padded = F.pad(x.view(leading + list(x.size())), (pad, pad), 'constant')
    x = padded.view(padded.shape[-ndim:])

    stft = torch.stft(x,
                      self.win_length,
                      self.hop_length,
                      window=self.win,
                      fft_size=self.n_fft)
    real_part = stft[:, :, :, 0]
    imag_part = stft[:, :, :, 1]
    p = torch.sqrt(real_part.pow(2) + imag_part.pow(2))

    # convert volume to db
    spec = _amp_to_db(p) - hparams.ref_level_db
    return spec, p
示例25
def __call__(self, signal):
    """Return ``log1p`` of the STFT magnitude of ``signal`` as a NumPy array.

    The signal is converted with ``torch.FloatTensor`` and transformed with
    an un-centred, one-sided STFT using a Hamming window of ``self.n_fft``
    samples and hop ``self.hop_length``.

    Requires PyTorch >= 1.7 (``return_complex`` / ``torch.view_as_real``).
    """
    # Real input requires an explicit return_complex on current PyTorch;
    # view_as_real restores the trailing (real, imag) axis indexed below.
    spec = torch.stft(
        torch.FloatTensor(signal),
        self.n_fft,
        hop_length=self.hop_length,
        win_length=self.n_fft,
        window=torch.hamming_window(self.n_fft),
        center=False,
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    real_imag = torch.view_as_real(spec)
    magnitude = (real_imag[:, :, 0].pow(2) + real_imag[:, :, 1].pow(2)).pow(0.5)
    return np.log1p(magnitude.numpy())
示例26
def _stft(self, data: torch.Tensor, n_fft: int, hop_length: int):
    """Centred, reflect-padded STFT of ``data`` using ``self._stft_window``.

    The window length equals ``n_fft``.

    Returns:
        Real tensor whose last axis holds the (real, imag) parts.

    Requires PyTorch >= 1.7 (``return_complex`` / ``torch.view_as_real``).
    """
    # Real input requires an explicit return_complex on current PyTorch;
    # view_as_real restores the trailing (real, imag) axis for callers.
    spec = torch.stft(
        data,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=n_fft,
        window=self._stft_window,
        center=True,
        pad_mode='reflect',
        normalized=False,
        return_complex=True,
    )
    return torch.view_as_real(spec)
示例27
def stft(y):
    """Return the linear and log STFT magnitudes of ``y``.

    Uses ``n_fft=1024``, ``hop_length=256`` and a 1024-sample Hann window
    created on the same device as ``y`` (previously hard-coded to CUDA, which
    failed for CPU input).

    Returns:
        ``(D, S)``: ``D`` is the magnitude and ``S`` is ``2 * log(D)``.

    Requires PyTorch >= 1.7 (``return_complex`` / ``torch.view_as_real``).
    """
    window = torch.hann_window(1024, device=y.device)
    # Real input requires an explicit return_complex on current PyTorch;
    # view_as_real restores the trailing (real, imag) axis summed below.
    spec = torch.stft(y, n_fft=1024, hop_length=256, win_length=1024,
                      window=window, return_complex=True)
    # The 1e-10 offset keeps the square root away from an exact zero.
    D = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-10)
    S = 2 * torch.log(torch.clamp(D, 1e-10, float("inf")))
    return D, S
示例28
def stft(y, scale='linear'):
    """Return the STFT magnitude of ``y`` on a linear or log scale.

    Uses ``n_fft=1024``, ``hop_length=256``, ``win_length=1024`` and no
    explicit window (the Hann window was deliberately left disabled).

    Args:
        y: Real-valued signal tensor.
        scale: ``'linear'`` for the magnitude ``D``; ``'log'`` for
            ``2 * log(D)``.

    Returns:
        The requested spectrogram, or ``None`` for any other ``scale`` value.

    Requires PyTorch >= 1.7 (``return_complex`` / ``torch.view_as_real``).
    """
    # Real input requires an explicit return_complex on current PyTorch;
    # view_as_real restores the trailing (real, imag) axis summed below.
    spec = torch.stft(y, n_fft=1024, hop_length=256, win_length=1024,
                      return_complex=True)
    # The 1e-10 offset keeps the square root away from an exact zero.
    D = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-10)
    if scale == 'linear':
        return D
    if scale == 'log':
        return 2 * torch.log(torch.clamp(D, 1e-10, float("inf")))
    # Unknown scale: kept as a silent None for existing callers.
    return None
示例29
def spectrogram(
        waveform: Tensor,
        pad: int,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        power: Optional[float],
        normalized: bool
) -> Tensor:
    r"""Compute a magnitude or complex spectrogram for one signal or a batch.

    Args:
        waveform (Tensor): Audio of dimension (..., time)
        pad (int): Number of samples padded on both sides of the signal
        window (Tensor): Window multiplied onto each frame
        n_fft (int): Size of FFT
        hop_length (int): Hop between STFT windows
        win_length (int): Window size
        power (float or None): Exponent passed to ``complex_norm`` for the
            magnitude spectrogram. If None, the STFT output is returned
            without taking the norm.
        normalized (bool): Whether to divide the STFT by
            ``sqrt(sum(window ** 2))``

    Returns:
        Tensor: Dimension (..., freq, time) for a magnitude spectrogram,
        where freq is ``n_fft // 2 + 1`` and time is the number of
        window hops (n_frame).
    """
    if pad > 0:
        # TODO add "with torch.no_grad():" back when JIT supports it
        waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")

    # Collapse all leading dimensions into a single batch axis.
    batch_shape = waveform.size()
    flat = waveform.reshape(-1, batch_shape[-1])

    # default values are consistent with librosa.core.spectrum._spectrogram
    spec_f = torch.stft(
        flat, n_fft, hop_length, win_length, window,
        center=True, pad_mode="reflect", normalized=False, onesided=True
    )

    # Restore the leading dimensions in front of the last three STFT axes.
    spec_f = spec_f.reshape(batch_shape[:-1] + spec_f.shape[-3:])

    if normalized:
        spec_f /= window.pow(2.).sum().sqrt()
    if power is None:
        return spec_f
    return complex_norm(spec_f, power=power)
示例30
def create_fb_matrix(
        n_freqs: int,
        f_min: float,
        f_max: float,
        n_mels: int,
        sample_rate: int,
        norm: Optional[str] = None
) -> Tensor:
    r"""Build a triangular mel filterbank matrix.

    Args:
        n_freqs (int): Number of linearly spaced frequency bins
        f_min (float): Lowest filter edge (Hz)
        f_max (float): Highest filter edge (Hz)
        n_mels (int): Number of mel filters
        sample_rate (int): Sample rate of the audio waveform
        norm (Optional[str]): If 'slaney', each filter is scaled by
            ``2 / (upper_edge - lower_edge)`` (area normalization).
            (Default: ``None``)

    Returns:
        Tensor: Filterbank of size (``n_freqs``, ``n_mels``); each column is
        one filter, so a matrix ``A`` of size (..., ``n_freqs``) is projected
        with ``A * create_fb_matrix(A.size(-1), ...)``.

    Raises:
        ValueError: If ``norm`` is neither ``None`` nor ``'slaney'``.
    """
    if norm not in (None, "slaney"):
        raise ValueError("norm must be one of None or 'slaney'")

    # Linear frequency grid, equivalent to Librosa's construction.
    stft_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # Filter edges are equally spaced on the mel scale:
    #   hertz -> mel: 2595 * log10(1 + f / 700)
    #   mel -> hertz: 700 * (10 ** (mel / 2595) - 1)
    mel_lo = 2595.0 * math.log10(1.0 + (f_min / 700.0))
    mel_hi = 2595.0 * math.log10(1.0 + (f_max / 700.0))
    mel_edges = torch.linspace(mel_lo, mel_hi, n_mels + 2)
    hz_edges = 700.0 * (10 ** (mel_edges / 2595.0) - 1.0)

    # Distance in hertz between neighbouring edges and between every edge
    # and every grid frequency.
    edge_gaps = hz_edges[1:] - hz_edges[:-1]  # (n_mels + 1)
    offsets = hz_edges.unsqueeze(0) - stft_freqs.unsqueeze(1)  # (n_freqs, n_mels + 2)

    # Each triangle is the smaller of its two linear flanks, floored at zero.
    flank_a = (-1.0 * offsets[:, :-2]) / edge_gaps[:-1]  # (n_freqs, n_mels)
    flank_b = offsets[:, 2:] / edge_gaps[1:]  # (n_freqs, n_mels)
    floor = torch.zeros(1)
    fb = torch.max(floor, torch.min(flank_a, flank_b))

    if norm == "slaney":
        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (hz_edges[2:n_mels + 2] - hz_edges[:n_mels])
        fb *= enorm.unsqueeze(0)

    return fb