// tensors being contiguous, and that the strides at the innermost signal
// dimension being unit (1) w.r.t. the corresponding data type.
-#pragma push
-#pragma diag_suppress 177 // diagnostic 177 (function was declared but never referenced) is suppressed for this helper until the matching #pragma pop
-static inline Tensor _run_cufft( // Executes the cuFFT plan held by `config` on `input` and returns a newly allocated output tensor.
- const CuFFTConfig &config, Tensor& input, int64_t signal_ndim, // `input` is a non-const reference: it may be rebound to a contiguous clone below
- bool complex_input, bool complex_output, bool inverse, // `!inverse` is what gets forwarded to exec_cufft_plan
- IntArrayRef checked_signal_sizes, fft_norm_mode norm, bool onesided, // `norm` selects the rescaling; `onesided` matters only for real input with complex output
- IntArrayRef output_sizes, bool input_was_cloned // `output_sizes` is the shape used to allocate the result
-) {
- if (config.should_clone_input() && !input_was_cloned) { // clone only when the config asks for it and the caller has not already cloned
- input = input.clone(at::MemoryFormat::Contiguous); // rebinds the caller's `input` handle to the contiguous clone
- }
-
- auto& plan = config.plan(); // plan handle used by the cuFFT calls below
-
- // allocate the output tensor with the requested sizes and the same options as the input
- auto output = at::empty(output_sizes, input.options());
-
- // bind the plan to the current CUDA stream
- CUFFT_CHECK(cufftSetStream(plan, at::cuda::getCurrentCUDAStream()));
-
- auto ws = at::empty({ config.workspace_size() }, at::device(at::kCUDA).dtype(at::kByte)); // CUDA byte buffer of config.workspace_size() elements
- CUFFT_CHECK(cufftSetWorkArea(plan, ws.data_ptr())); // hand that buffer to cuFFT as the plan's work area
-
- // execute the plan; the last argument is !inverse (presumably a "forward" flag -- confirm against exec_cufft_plan)
- exec_cufft_plan(config, input.data_ptr(), output.data_ptr(), !inverse);
-
- // rescale if requested: divide by sqrt(N) for by_root_n, otherwise by N, with N = c10::multiply_integers(checked_signal_sizes)
- auto size_last_signal_dim = checked_signal_sizes[signal_ndim - 1]; // only read by the real-to-complex, non-onesided branch below
- if (norm != fft_norm_mode::none) {
- auto signal_numel = c10::multiply_integers(checked_signal_sizes);
- double scale_denom;
- if (norm == fft_norm_mode::by_root_n) {
- scale_denom = std::sqrt(static_cast<double>(signal_numel));
- } else {
- scale_denom = static_cast<double>(signal_numel);
- }
- if (!complex_input && complex_output && !onesided) { // the other half is filled in afterwards (below), so only the leading slice is scaled here
- auto end_data_slice = infer_ft_real_to_complex_onesided_size(size_last_signal_dim);
- output.narrow(signal_ndim, 0, end_data_slice).div_(scale_denom); // scale in place the first end_data_slice entries along dim signal_ndim
- } else {
- output.div_(scale_denom); // scale the whole output in place
- }
- }
-
- // if needed (real input, complex output, not onesided), fill out the other half using conjugate symmetry
- if (!complex_input && complex_output && !onesided) {
- DimVector signal_dims(signal_ndim);
- std::iota(signal_dims.begin(), signal_dims.end(), 1); // signal dims are 1..signal_ndim
- auto out_as_complex = at::view_as_complex(output); // complex view of `output` handed to the fill helper
- at::native::_fft_fill_with_conjugate_symmetry_(out_as_complex, signal_dims);
- }
- return output;
-}
-#pragma pop
-
// The cuFFT plan caches: a file-local vector of owning pointers to
// CuFFTParamsLRUCache objects.
// Each slot is a unique_ptr so that it can be null, and so that references to
// an existing cache are not invalidated when the vector is resized.
// NOTE(review): presumably one slot per CUDA device index -- confirm against the code that indexes this vector.
static std::vector<std::unique_ptr<CuFFTParamsLRUCache>> plan_caches;