From 198d1d99769700d0136ac90275a8d6fa1871accf Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 18 Jan 2023 08:59:02 -0800 Subject: [PATCH] [mlir][tosa] Prefer tosa.transpose composition canonicalization to reshape It is preferred to merge tosa.transpose operations together rather than convert one to a tosa.reshape. This is to leverage the tosa.transpose -> tosa.transpose merging canonicalization. Reviewed By: AviadCo Differential Revision: https://reviews.llvm.org/D141434 --- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 11 +++++++++++ mlir/test/{IR => Dialect/Tosa}/transpose-fold.mlir | 17 +++++++++++++++++ 2 files changed, 28 insertions(+) rename mlir/test/{IR => Dialect/Tosa}/transpose-fold.mlir (70%) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 625c855..74325c8 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -188,6 +188,17 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> { if (!matchPattern(op.getPerms(), m_Constant(&permAttr))) return rewriter.notifyMatchFailure(op, "Non-constant permutation"); + if (op.getInput1().getDefiningOp<tosa::TransposeOp>()) + return rewriter.notifyMatchFailure( + op, "Src is from transpose, can compose transposes"); + + Value result = op.getResult(); + for (Operation *subop : result.getUsers()) { + if (dyn_cast_or_null<tosa::TransposeOp>(subop)) + return rewriter.notifyMatchFailure( + op, "Dest is used by transpose, can compose transposes"); + } + auto input = op.getInput1(); auto inputTy = input.getType().cast<ShapedType>(); if (!inputTy.hasRank()) diff --git a/mlir/test/IR/transpose-fold.mlir b/mlir/test/Dialect/Tosa/transpose-fold.mlir similarity index 70% rename from mlir/test/IR/transpose-fold.mlir rename to mlir/test/Dialect/Tosa/transpose-fold.mlir index 1079bf3e..df49b79 100644 --- a/mlir/test/IR/transpose-fold.mlir +++ b/mlir/test/Dialect/Tosa/transpose-fold.mlir @@ -42,3 +42,20 @@ func.func 
@test_do_not_cancel_different_transpose(%arg0: tensor<2x3x4x5xi32>) -> %3 = "tosa.transpose"(%1, %2) : (tensor<3x4x2x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32> return %3 : tensor<5x4x3x2xi32> } + +// ----- + +// CHECK-LABEL: func.func @test_prefer_compose_transpose( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4xi32>) -> tensor<4x3x2x1xi32> { +// CHECK: %[[VAL_1:.*]] = arith.constant dense<[3, 2, 1, 0]> : tensor<4xi32> +// CHECK: %[[VAL_2:.*]] = "tosa.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32> +// CHECK: return %[[VAL_2]] : tensor<4x3x2x1xi32> +// CHECK: } + +func.func @test_prefer_compose_transpose(%arg0: tensor<1x2x3x4xi32>) -> (tensor<4x3x2x1xi32>) { + %0 = arith.constant dense<[1, 2, 0, 3]> : tensor<4xi32> + %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> (tensor<2x3x1x4xi32>) + %2 = arith.constant dense<[3, 1, 0, 2]> : tensor<4xi32> + %3 = "tosa.transpose"(%1, %2) : (tensor<2x3x1x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32> + return %3 : tensor<4x3x2x1xi32> +} -- 2.7.4