namespace {
constexpr int64_t cufft_max_ndim = 3;
+// "Large" here means a prime factor not special-cased by cuFFT
+// Ref: https://docs.nvidia.com/cuda/cufft/index.html#accuracy-and-performance
+bool has_large_prime_factor(int64_t n) {
+ constexpr int64_t first_large_prime = 11;
+ const std::array<int64_t, 4> prime_radices{{2, 3, 5, 7}};
+ for (auto prime : prime_radices) {
+ if (n < first_large_prime) {
+ return false;
+ }
+
+ while (n % prime == 0) {
+ n /= prime;
+ }
+ }
+ return n != 1;
+}
+
// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r)
static const Tensor& _exec_fft(Tensor& out, const Tensor& self, IntArrayRef out_sizes,
IntArrayRef dim, bool forward) {
c10::optional<CuFFTConfig> uncached_plan;
const CuFFTConfig * config = nullptr;
- if (plan_cache.max_size() > 0) {
+ // Workaround for gh-63152, gh-58724
+ // Bluestein plans in CUDA 11.1 (cufft 10.3) cannot be re-used
+ // Bluestein's algorithm is only used when a size has large prime factors,
+ // sizes with only small prime factors can still be cached
+ bool use_caching = true;
+#ifdef CUFFT_VERSION
+ if (10300 <= CUFFT_VERSION && CUFFT_VERSION < 10400) {
+ // Only cache plans for transforms with small prime factors
+ use_caching = std::none_of(
+ signal_size.begin() + 1, signal_size.end(), [](int64_t dim_size) {
+ return has_large_prime_factor(dim_size);
+ });
+ }
+#endif
+
+ if (use_caching && plan_cache.max_size() > 0) {
guard.lock();
if (plan_cache.max_size() > 0) { // check again after acquiring the lock
config = &plan_cache.lookup(Params);
@onlyOnCPUAndCUDA
@skipCPUIfNoFFT
+ def test_fft_plan_repeatable(self, device):
+ # Regression test for gh-58724 and gh-63152
+ for n in [2048, 3199, 5999]:
+ a = torch.randn(n, device=device, dtype=torch.complex64)
+ res1 = torch.fft.fftn(a)
+ res2 = torch.fft.fftn(a.clone())
+ self.assertEqual(res1, res2)
+
+ a = torch.randn(n, device=device, dtype=torch.float64)
+ res1 = torch.fft.rfft(a)
+ res2 = torch.fft.rfft(a.clone())
+ self.assertEqual(res1, res2)
+
+ @onlyOnCPUAndCUDA
+ @skipCPUIfNoFFT
@dtypes(torch.double)
def test_istft_round_trip_simple_cases(self, device, dtype):
"""stft -> istft should recover the original signale"""