[mlir][tosa] Added tosa.reverse folder
authorRob Suderman <suderman@google.com>
Tue, 13 Sep 2022 00:12:33 +0000 (17:12 -0700)
committerRob Suderman <suderman@google.com>
Tue, 13 Sep 2022 00:15:17 +0000 (17:15 -0700)
Fold cases where a tosa.reverse is a splat or reversing a dim
of length-1.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D133144

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/constant-op-fold.mlir

index c50fee2..8518d6b 100644 (file)
@@ -1464,6 +1464,8 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
   let results = (outs
     Tosa_Tensor1Dto4D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
index abe33f8..b346b11 100644 (file)
@@ -846,6 +846,21 @@ OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
 
+OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
+  auto operand = getInput();
+  auto operandTy = operand.getType().cast<ShapedType>();
+  auto axis = getAxis();
+  auto operandAttr = operands[0].dyn_cast_or_null<SplatElementsAttr>();
+  if (operandAttr)
+    return operandAttr;
+
+  // If the dim-length is 1, tosa.reverse is a no-op.
+  if (operandTy.hasRank() && operandTy.getDimSize(axis) == 1)
+    return operand;
+
+  return {};
+}
+
 OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
   auto inputTy = getInput().getType().dyn_cast<RankedTensorType>();
   auto outputTy = getType().dyn_cast<RankedTensorType>();
index 28be1ff..0811578 100644 (file)
@@ -540,3 +540,25 @@ func.func @cast_int_to_int_sign() -> tensor<i32> {
   // CHECK: return %[[SPLAT]]
   return %cast : tensor<i32>
 }
+
+// -----
+
+// CHECK-LABEL: @reverse_splat
+func.func @reverse_splat() -> tensor<10xi32> {
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<42> : tensor<10xi32>}
+  %splat = "tosa.const"() {value = dense<42> : tensor<10xi32>} : () -> tensor<10xi32>
+  %reverse = "tosa.reverse"(%splat) { axis = 0 : i64 } : (tensor<10xi32>) -> tensor<10xi32>
+  // CHECK: return %[[SPLAT]]
+  return %reverse : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reverse_length_one
+func.func @reverse_length_one(%arg0 : tensor<10x1xi32>) -> (tensor<10x1xi32>, tensor<10x1xi32>) {
+  %nofold = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<10x1xi32>) -> tensor<10x1xi32>
+  %fold = "tosa.reverse"(%arg0) { axis = 1 : i64 } : (tensor<10x1xi32>) -> tensor<10x1xi32>
+  // CHECK: %[[NOFOLD:.+]] = "tosa.reverse"(%arg0) {axis = 0 : i64}
+  // CHECK: return %[[NOFOLD]], %arg0
+  return %nofold, %fold : tensor<10x1xi32>, tensor<10x1xi32>
+}