--- /dev/null
+//
+// This implementation of 2D convolution
+// uses Fast Fourier Transform to speed up
+// computation on larger matrices.
+//
+// For principles of work refer to:
+// - Convolution theorem
+// - Discrete Fourier Transform
+// - Fast Fourier Transform
+// - https://arxiv.org/abs/1312.5851
+//
+// This implementation is for testing purposes and
+// speeding up the interpreter. After we decide on
+// CG IR, FG IR and code generation this will be
+// implemented as optimization pass.
+//
+//
+// No implementation yet, so the interfaces
+// of some methods are subject to change.
+//
+
+
+#ifndef _NNC_CORE_BACKEND_INTERPRETER_CONV2D_FFT_IMPL_
+#define _NNC_CORE_BACKEND_INTERPRETER_CONV2D_FFT_IMPL_
+
+#include <complex>
+
+#include "interpreter/ops/OperationImpl.h"
+#include "nnc/core/IR/model/operations/conv_2d_op.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace backend
+{
+namespace interpreter
+{
+namespace impl
+{
+using nncc::contrib::core::IR::model::ops::Conv2DOp;
+using nncc::contrib::core::IR::model::ops::PaddingType;
+using nncc::contrib::core::data::Tensor;
+
+typedef std::complex<float> FFT_complex;
+
+
+class Conv2D_FFT : public OperationImpl<float>
+{
+public:
+ explicit Conv2D_FFT(const TensorVariant &input, const Conv2DOp &op);
+ std::vector<TensorVariant> operator()() override;
+
+protected:
+ ///
+ /// Perform Fast Fourier transform on tensor and return the result.
+ /// No support for complex tensors yet, so we have to use something else
+ ///
+ std::vector<FFT_complex> fft(const Tensor<float> &tensor);
+
+ ///
+ /// Perform Inverse Fast Fourier transform on tensor and return the result.
+ ///
+ TensorVariant ifft(const std::vector<FFT_complex> &spectre,
+ const Shape &out_shape,
+ const Shape &strides,
+ const Index &paddings);
+
+private:
+ const Tensor<float> _input;
+ Tensor<float> _kernel;
+ const Shape _strides;
+ const PaddingType _padding;
+ const Shape &_out_shape;
+ const Conv2DOp &_op;
+};
+
+} // namespace impl
+} // namespace interpreter
+} // namespace backend
+} // namespace contrib
+} // namespace nncc
+
+#endif //_NNC_CORE_BACKEND_INTERPRETER_CONV2D_FFT_IMPL_
--- /dev/null
+#include <cmath>
+
+#include "nnc/core/linalg/ShapeRange.h"
+
+#include "interpreter/ops/conv_FFT.h"
+#include "interpreter/ops/common.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace backend
+{
+namespace interpreter
+{
+namespace impl
+{
+
+using namespace nncc::core::ADT::tensor;
+
+Index reduce(const Index &idx)
+{
+ Index res = idx;
+ res.resize(idx.rank() - 1);
+ return res;
+}
+
+// Mostly compatible with tensorflow implementation
+// Assuming input is in NHWC format with batch omitted( [in_height, in_width, in_channels] )
+// Kernel is in [filter_height, filter_width, in_channels, out_channels]
+// Refer to https://www.tensorflow.org/api_docs/python/tf/nn/conv2d for info
+std::vector<TensorVariant> Conv2D_FFT::operator()()
+{
+ auto res = allocate_tensor(_out_shape);
+
+ // TODO: implement
+ //
+ // 1. Pad input (clamp to zero, clamp to edge, wrap)
+ // 2. Pad kernel with zeroes to match padded input shape
+ // 3. FFT input and kernel
+ // 4. Elementwise production
+ // 5. IFFT input
+ // 6. Match output shape
+
+ return {res};
+}
+
+Conv2D_FFT::Conv2D_FFT(const TensorVariant &input, const Conv2DOp &op)
+ : _input(input), _kernel(op.getKernel()), _strides(op.getStrides()),
+ _padding(op.getPaddingType()),
+ _out_shape(op.getOutputShape(0)), _op(op)
+{
+ // Same assertions as in Conv2D
+ assert(_op.getInputShape(0).rank() == 3);
+ assert(input.getShape().rank() == 3);
+ assert(_kernel.getShape().rank() == 4);
+ assert(_strides.dim(2) == 1);
+ assert(_op.getPadding(2) == 0);
+}
+
+std::vector<FFT_complex> Conv2D_FFT::fft(const Tensor<float> &tensor)
+{
+ std::vector<FFT_complex> res;
+
+ // TODO: implement
+
+ return res;
+}
+
+TensorVariant Conv2D_FFT::ifft(const std::vector<FFT_complex> &spectre,
+ const Shape &out_shape,
+ const Shape &strides,
+ const Index &paddings)
+{
+ TensorVariant res = allocate_tensor(out_shape);
+
+ // TODO: implement
+
+ return res;
+}
+
+} // namespace impl
+} // namespace interpreter
+} // namespace backend
+} // namespace contrib
+} // namespace nncc