[mlir][tosa] Add RFFT2d operation
authorLuke Hutton <luke.hutton@arm.com>
Tue, 24 Jan 2023 23:42:02 +0000 (15:42 -0800)
committerJacques Pienaar <jpienaar@google.com>
Tue, 24 Jan 2023 23:42:02 +0000 (15:42 -0800)
Adds the RFFT2d TOSA operation and supporting
shape inference function.

Signed-off-by: Luke Hutton <luke.hutton@arm.com>
Change-Id: I7e49c47cdd846cdc1b187545ef76d5cda2d5d9ad

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D142336

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/ops.mlir
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

index afdd801..35a0996 100644 (file)
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_TOSA_IR_TOSAOPS_H
 
 #include "mlir/Dialect/Traits.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
index 6609c6b..f922166 100644 (file)
@@ -271,6 +271,34 @@ def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
 }
 
 //===----------------------------------------------------------------------===//
+// Operator: rfft2d
+//===----------------------------------------------------------------------===//
+def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    Pure]> {
+  let summary = "Performs RFFT2D operation on the input.";
+
+  let description = [{
+    Performs a batched 2D real-valued Fast Fourier Transform over the input where
+    the input tensor consists of real values producing complex valued output. The
+    complex output values will be split into the output_real and output_imag
+    tensor arguments. RFFT2D takes advantage of Hermitian symmetry to only
+    calculate the first half of the final output axis. Imaginary values with
+    locations (0,0), (0,W/2), (H/2,0) and (H/2,W/2) are zero.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor3D:$input
+  );
+
+  let results = (outs
+    Tosa_Tensor3D:$output_real,
+    Tosa_Tensor3D:$output_imag
+  );
+}
+
+//===----------------------------------------------------------------------===//
 // Operator: transpose_conv2d
 //===----------------------------------------------------------------------===//
 def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [
index 338f530..e78a0b1 100644 (file)
@@ -387,6 +387,31 @@ LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapeAdaptor inputShape = operands.getShape(0);
+
+  if (!inputShape.hasRank())
+    return failure();
+
+  llvm::SmallVector<int64_t> outputShape;
+  outputShape.resize(3, ShapedType::kDynamic);
+  outputShape[0] = inputShape.getDimSize(0);
+  outputShape[1] = inputShape.getDimSize(1);
+  int64_t inWidth = inputShape.getDimSize(2);
+
+  // Note that we can support this calculation symbolically
+  // in the future e.g. [x, y, z] -> [x, y, z / 2 - 1]
+  if (inWidth != ShapedType::kDynamic)
+    outputShape[2] = inWidth / 2 + 1;
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
index 4599ca8..fa82577 100644 (file)
@@ -73,6 +73,13 @@ func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32
 }
 
 // -----
+// CHECK-LABEL: rfft2d
+func.func @test_rfft2d(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
+  %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
+  return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
+}
+
+// -----
 // CHECK-LABEL: transpose_conv2d
 func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
   %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
index f6740db..c955d57 100644 (file)
@@ -1189,3 +1189,30 @@ func.func @while_test(%arg0 : tensor<i32>, %arg1 : tensor<1xi32>) -> () {
   }) : (tensor<i32>, tensor<1xi32>) -> (tensor<*xi32>, tensor<*xi32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: @test_static_rfft2d
+func.func @test_static_rfft2d(%arg0: tensor<5x2x8xf32>) -> () {
+  // CHECK: -> (tensor<5x2x5xf32>, tensor<5x2x5xf32>)
+  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x2x8xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_batch_rfft2d
+func.func @test_dynamic_batch_rfft2d(%arg0 : tensor<?x2x4xf32>) -> () {
+  // CHECK: -> (tensor<?x2x3xf32>, tensor<?x2x3xf32>)
+  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<?x2x4xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_width_rfft2d
+func.func @test_dynamic_width_rfft2d(%arg0 : tensor<5x2x?xf32>) -> () {
+  // CHECK: -> (tensor<5x2x?xf32>, tensor<5x2x?xf32>)
+  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x2x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return
+}