From: Rob Suderman Date: Tue, 13 Sep 2022 00:12:33 +0000 (-0700) Subject: [mlir][tosa] Added tosa.reverse folder X-Git-Tag: upstream/17.0.6~33724 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=3e49a3e89dbd44bfeb7611d76c23e137e1676abe;p=platform%2Fupstream%2Fllvm.git [mlir][tosa] Added tosa.reverse folder Fold cases where a tosa.reverse is a splat or reversing a dim of length-1. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D133144 --- diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index c50fee2..8518d6b 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1464,6 +1464,8 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [ let results = (outs Tosa_Tensor1Dto4D:$output ); + + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index abe33f8..b346b11 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -846,6 +846,21 @@ OpFoldResult PadOp::fold(ArrayRef operands) { return {}; } +OpFoldResult ReverseOp::fold(ArrayRef operands) { + auto operand = getInput(); + auto operandTy = operand.getType().cast(); + auto axis = getAxis(); + auto operandAttr = operands[0].dyn_cast_or_null(); + if (operandAttr) + return operandAttr; + + // If the dim-length is 1, tosa.reverse is a no-op. + if (operandTy.hasRank() && operandTy.getDimSize(axis) == 1) + return operand; + + return {}; +} + OpFoldResult SliceOp::fold(ArrayRef operands) { auto inputTy = getInput().getType().dyn_cast(); auto outputTy = getType().dyn_cast(); diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index 28be1ff..0811578 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -540,3 +540,25 @@ func.func @cast_int_to_int_sign() -> tensor { // CHECK: return %[[SPLAT]] return %cast : tensor } + +// ----- + +// CHECK-LABEL: @reverse_splat +func.func @reverse_splat() -> tensor<10xi32> { + // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<42> : tensor<10xi32>} + %splat = "tosa.const"() {value = dense<42> : tensor<10xi32>} : () -> tensor<10xi32> + %reverse = "tosa.reverse"(%splat) { axis = 0 : i64 } : (tensor<10xi32>) -> tensor<10xi32> + // CHECK: return %[[SPLAT]] + return %reverse : tensor<10xi32> +} + +// ----- + +// CHECK-LABEL: @reverse_length_one +func.func @reverse_length_one(%arg0 : tensor<10x1xi32>) -> (tensor<10x1xi32>, tensor<10x1xi32>) { + %nofold = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<10x1xi32>) -> tensor<10x1xi32> + %fold = "tosa.reverse"(%arg0) { axis = 1 : i64 } : (tensor<10x1xi32>) -> tensor<10x1xi32> + // CHECK: %[[NOFOLD:.+]] = "tosa.reverse"(%arg0) {axis = 0 : i64} + // CHECK: return %[[NOFOLD]], %arg0 + return %nofold, %fold : tensor<10x1xi32>, tensor<10x1xi32> +}