From 4ada6c2aafffd90c87900cab0adbb4d43c874b9b Mon Sep 17 00:00:00 2001 From: not-jenni Date: Mon, 18 Oct 2021 16:22:01 -0700 Subject: [PATCH] [mlir][tosa] Adds a canonicalization to the transpose op if the perms are a no op Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D112037 --- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 28 ++++++++++++++++++++++++++++ mlir/test/Dialect/Tosa/canonicalize.mlir | 13 ++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 9c8d4ac..2ad14f5 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -222,9 +222,37 @@ struct ConstantTransposeOptimization } }; +struct NoOpOptimization : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::TransposeOp op, + PatternRewriter &rewriter) const override { + auto perm = op.perms(); + + DenseIntElementsAttr permAttr; + if (!matchPattern(perm, m_Constant(&permAttr))) { + return failure(); + } + + SmallVector permValues = llvm::to_vector<6>( + llvm::map_range(permAttr.getValues(), + [](const APInt &val) { return val.getSExtValue(); })); + + for (int i = 0, s = permValues.size(); i < s; i++) { + if (i != permValues[i]) { + return failure(); + } + } + + rewriter.replaceOp(op, op.input1()); + return success(); + } +}; + void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 5fe5bd4..983ce58 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -233,7 +233,7 @@ func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { // CHECK-LABEL: @transpose_nofold_shape func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor { // CHECK: "tosa.transpose" - %0 = arith.constant dense<[0, 1]> : tensor<2xi32> + %0 = arith.constant dense<[1, 0]> : tensor<2xi32> %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor return %1 : tensor } @@ -325,3 +325,14 @@ func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> return %0: tensor<1x1x16x1x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> } + +// ----- + +// CHECK-LABEL: @transpose_no_op +func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> { + // CHECK: return %arg0 + // CHECK-NOT: tosa.transpose + %perms = "tosa.const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<4xi32> + %1 = "tosa.transpose"(%arg0, %perms) : (tensor<3x4x5x6xf32>, tensor<4xi32>) -> tensor<3x4x5x6xf32> + return %1 : tensor<3x4x5x6xf32> +} -- 2.7.4