From 2c258d91cc1dc11c338e97d6970ac77a4f8978ec Mon Sep 17 00:00:00 2001 From: Zhaoheng Ni Date: Thu, 2 Sep 2021 08:59:53 -0700 Subject: [PATCH] Fix torch.istft length mismatch and window runtime error (#63469) Summary: The PR fixes two issues: - See https://github.com/pytorch/pytorch/issues/62747 and https://github.com/pytorch/audio/issues/1409. A length mismatch occurs when the given ``length`` parameter is longer than expected. Add padding logic consistent with librosa. - See https://github.com/pytorch/pytorch/issues/62323. The current implementation checks whether the min value of window_envelop.abs() is greater than zero. librosa instead normalizes the signal only at the non-zero values by indexing, like ``` approx_nonzero_indices = ifft_window_sum > util.tiny(ifft_window_sum) y[approx_nonzero_indices] /= ifft_window_sum[approx_nonzero_indices] ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/63469 Reviewed By: fmassa Differential Revision: D30695827 Pulled By: nateanl fbshipit-source-id: d034e53f0d65b3fd1dbd150c9c5acf3faf25a164 --- aten/src/ATen/native/SpectralOps.cpp | 10 +++++- test/test_spectral_ops.py | 64 ++++++++++++++++++++++++++++++++---- torch/functional.py | 3 +- 3 files changed, 69 insertions(+), 8 deletions(-) diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index cd04207..f9472b1 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -920,7 +920,7 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho // We need to trim the front padding away if centered const auto start = center ? n_fft / 2 : 0; - const auto end = lengthOpt.has_value()? start + lengthOpt.value() : - n_fft / 2; + const auto end = lengthOpt.has_value() ? start + lengthOpt.value() : (center ? 
- n_fft / 2 : -1); y = y.slice(2, start, end, 1); window_envelop = window_envelop.slice(2, start, end, 1); @@ -935,6 +935,14 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho if (input_dim == 3) { y = y.squeeze(0); } + // zero padding if the given lengthOpt is longer than expected + if(end > expected_output_signal_len) { + TORCH_WARN_ONCE( + "The length of signal is shorter than the length parameter. Result is being padded with zeros in the tail. " + "Please check your center and hop_length settings." + ); + y = at::constant_pad_nd(y, {0, end - expected_output_signal_len}, 0); + } return y; #undef REPR diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py index fdc8c01..f632e95 100644 --- a/test/test_spectral_ops.py +++ b/test/test_spectral_ops.py @@ -1126,9 +1126,6 @@ class TestFFT(TestCase): original = torch.randn(*sizes, dtype=dtype, device=device) stft = torch.stft(original, return_complex=True, **stft_kwargs) inversed = torch.istft(stft, length=original.size(1), **istft_kwargs) - - # trim the original for case when constructed signal is shorter than original - original = original[..., :inversed.size(-1)] self.assertEqual( inversed, original, msg='istft comparison against original', atol=7e-6, rtol=0, exact_dtype=True) @@ -1167,21 +1164,63 @@ class TestFFT(TestCase): 'normalized': True, 'onesided': False, }, - # hamming_window, not centered, not normalized, onesided + # hamming_window, centered, not normalized, onesided # window same size as n_fft { 'n_fft': 5, 'hop_length': 2, 'win_length': 5, 'window': torch.hamming_window(5, dtype=dtype, device=device), - 'center': False, + 'center': True, 'pad_mode': 'constant', 'normalized': False, 'onesided': True, }, + ] + for i, pattern in enumerate(patterns): + _test_istft_is_inverse_of_stft(pattern) + + @onlyOnCPUAndCUDA + @skipCPUIfNoFFT + @dtypes(torch.double) + def test_istft_round_trip_with_padding(self, device, dtype): + """long hop_length or not centered may cause length 
mismatch in the inversed signal""" + def _test_istft_is_inverse_of_stft_with_padding(stft_kwargs): + # generates a random sound signal for each tril and then does the stft/istft + # operation to check whether we can reconstruct signal + num_trials = 100 + sizes = stft_kwargs['size'] + del stft_kwargs['size'] + istft_kwargs = stft_kwargs.copy() + del istft_kwargs['pad_mode'] + for i in range(num_trials): + original = torch.randn(*sizes, dtype=dtype, device=device) + stft = torch.stft(original, return_complex=True, **stft_kwargs) + with self.assertWarnsOnceRegex(UserWarning, "The length of signal is shorter than the length parameter."): + inversed = torch.istft(stft, length=original.size(-1), **istft_kwargs) + n_frames = stft.size(-1) + if stft_kwargs["center"] is True: + len_expected = stft_kwargs["n_fft"] // 2 + stft_kwargs["hop_length"] * (n_frames - 1) + else: + len_expected = stft_kwargs["n_fft"] + stft_kwargs["hop_length"] * (n_frames - 1) + # trim the original for case when constructed signal is shorter than original + padding = inversed[..., len_expected:] + inversed = inversed[..., :len_expected] + original = original[..., :len_expected] + # test the padding points of the inversed signal are all zeros + zeros = torch.zeros_like(padding, device=padding.device) + self.assertEqual( + padding, zeros, msg='istft padding values against zeros', + atol=7e-6, rtol=0, exact_dtype=True) + self.assertEqual( + inversed, original, msg='istft comparison against original', + atol=7e-6, rtol=0, exact_dtype=True) + + patterns = [ # hamming_window, not centered, not normalized, not onesided # window same size as n_fft { + 'size': [2, 20], 'n_fft': 3, 'hop_length': 2, 'win_length': 3, @@ -1191,9 +1230,22 @@ class TestFFT(TestCase): 'normalized': False, 'onesided': False, }, + # hamming_window, centered, not normalized, onesided, long hop_length + # window same size as n_fft + { + 'size': [2, 500], + 'n_fft': 256, + 'hop_length': 254, + 'win_length': 256, + 'window': 
torch.hamming_window(256, dtype=dtype, device=device), + 'center': True, + 'pad_mode': 'constant', + 'normalized': False, + 'onesided': True, + }, ] for i, pattern in enumerate(patterns): - _test_istft_is_inverse_of_stft(pattern) + _test_istft_is_inverse_of_stft_with_padding(pattern) @onlyOnCPUAndCUDA def test_istft_throws(self, device): diff --git a/torch/functional.py b/torch/functional.py index 81b3de2..63470cf 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -569,7 +569,8 @@ def istft(input: Tensor, n_fft: int, hop_length: Optional[int] = None, Since :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame, ``istft`` may return a shorter signal than the original signal (can occur if :attr:`center` is False - since the signal isn't padded). + since the signal isn't padded). If `length` is given in the arguments and is longer than expected, + ``istft`` will pad zeros to the end of the returned signal. If :attr:`center` is ``True``, then there will be padding e.g. ``'constant'``, ``'reflect'``, etc. Left padding can be trimmed off exactly because they can be calculated but right padding cannot be -- 2.7.4