[mlir][tosa] Add FFT2d operation
authorLuke Hutton <luke.hutton@arm.com>
Tue, 14 Mar 2023 18:55:39 +0000 (18:55 +0000)
committerRobert Suderman <suderman@google.com>
Tue, 14 Mar 2023 19:04:52 +0000 (19:04 +0000)
Adds the FFT2d TOSA operation and supporting
shape inference function.

Signed-off-by: Luke Hutton <luke.hutton@arm.com>
Reviewed By: rsuderman, eric-k256

Differential Revision: https://reviews.llvm.org/D144784

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/ops.mlir
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

index fc4cf7d..be5720c 100644 (file)
@@ -184,6 +184,36 @@ def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
 }
 
 //===----------------------------------------------------------------------===//
+// Operator: fft2d
+//===----------------------------------------------------------------------===//
+def Tosa_FFT2dOp : Tosa_Op<"fft2d", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    Pure]> {
+  let summary = "Performs FFT2D operation on the input.";
+
+  let description = [{
+    Performs a batched complex 2D Fast Fourier Transform over the input. The
+    complex input values are constructed from the corresponding values in the
+    input_real and input_imag tensors. The resulting values in the output are
+    split into the output_real and output_imag tensors. No normalization is
+    applied on either the forward or inverse versions of the operation.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor3D:$input_real,
+    Tosa_Tensor3D:$input_imag,
+
+    BoolAttr:$inverse
+  );
+
+  let results = (outs
+    Tosa_Tensor3D:$output_real,
+    Tosa_Tensor3D:$output_imag
+  );
+}
+
+//===----------------------------------------------------------------------===//
 // Operator: fully_connected
 //===----------------------------------------------------------------------===//
 def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [
index c3b161a..d7bb6d0 100644 (file)
@@ -409,6 +409,16 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
 
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+
+  return success();
+}
+
+LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(0)));
+  inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(1)));
   return success();
 }
 
index fa82577..68eca32 100644 (file)
@@ -52,6 +52,13 @@ func.func @test_depthwise_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4
 }
 
 // -----
+// CHECK-LABEL: fft2d
+func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
+  %0, %1 = "tosa.fft2d"(%arg0, %arg1) {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+  return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32>
+}
+
+// -----
 // CHECK-LABEL: fully_connected
 func.func @test_fully_connected(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>, %arg2: tensor<28xf32>) -> tensor<14x28xf32> {
   %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<14x19xf32>, tensor<19x28xf32>, tensor<28xf32>) -> tensor<14x28xf32>
index c955d57..94eea3b 100644 (file)
@@ -1216,3 +1216,21 @@ func.func @test_dynamic_width_rfft2d(%arg0 : tensor<5x2x?xf32>) -> () {
   %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x2x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: @test_static_fft2d
+func.func @test_static_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
+  // CHECK: -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+  %output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+  return %output_real, %output_imag : tensor<1x4x8xf32>, tensor<1x4x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_batch_fft2d
+func.func @test_dynamic_batch_fft2d(%arg0: tensor<?x4x8xf32>, %arg1: tensor<?x4x8xf32>) -> (tensor<?x4x8xf32>, tensor<?x4x8xf32>) {
+  // CHECK: -> (tensor<?x4x8xf32>, tensor<?x4x8xf32>)
+  %output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = false} : (tensor<?x4x8xf32>, tensor<?x4x8xf32>) -> (tensor<?x4x8xf32>, tensor<?x4x8xf32>)
+  return %output_real, %output_imag : tensor<?x4x8xf32>, tensor<?x4x8xf32>
+}