From ded988ed0c00e033aa7fa9ea42d7ad19f3dd983e Mon Sep 17 00:00:00 2001 From: Aviad Cohen Date: Tue, 10 Jan 2023 12:35:50 -0800 Subject: [PATCH] [mlir][tosa] Remove redundant "tosa.transpose" operations We can fold redundant Tosa::TransposeOp actions like identity transpose/transpose(transpose). Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D140466 --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 4 ++ mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 83 ++++++++++++++-------- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 14 ++++ mlir/test/IR/transpose-fold.mlir | 44 ++++++++++++ 4 files changed, 114 insertions(+), 31 deletions(-) create mode 100644 mlir/test/IR/transpose-fold.mlir diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index b73368f..6609c6b 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1542,6 +1542,10 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [ outs Tosa_Tensor1Dto6D:$output ); + let extraClassDeclaration = [{ + LogicalResult getConstantPerms(llvm::SmallVector &perms); + }]; + let hasCanonicalizer = 1; let hasFolder = 1; } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index f8b48f1..5f44634 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -131,29 +131,49 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) { return success(); } -struct TransposeNoOp : public OpRewritePattern { +struct ConsolidateTransposeOptimization + : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(tosa::TransposeOp op, + LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp, PatternRewriter &rewriter) const override { - auto perm = op.getPerms(); + // Input is also TransposeOp - transpose(transpose(A)). 
+ auto innerTranspose = + transposeOp.getInput1().getDefiningOp(); + if (!innerTranspose) + return rewriter.notifyMatchFailure(transposeOp, + "input must be transpose operation"); + + SmallVector transposePerms, innerTransposePerms; + if (transposeOp.getConstantPerms(transposePerms).failed()) + return rewriter.notifyMatchFailure(transposeOp, + "transpose perms must be constant"); + if (innerTranspose.getConstantPerms(innerTransposePerms).failed()) + return rewriter.notifyMatchFailure( + transposeOp, "inner transpose perms must be constant"); + if (transposePerms.size() != innerTransposePerms.size()) + return rewriter.notifyMatchFailure( + transposeOp, + "transpose and inner transpose perms sizes must be equal"); + if (transposePerms.empty()) + return rewriter.notifyMatchFailure( + transposeOp, "transpose perms sizes must be positive"); - DenseIntElementsAttr permAttr; - if (!matchPattern(perm, m_Constant(&permAttr))) { - return failure(); - } + // Consolidate transposes into one transpose. 
+ SmallVector perms(transposePerms.size()); + for (int i = 0, s = transposePerms.size(); i < s; ++i) + perms[i] = innerTransposePerms[transposePerms[i]]; - SmallVector permValues = llvm::to_vector<6>( - llvm::map_range(permAttr.getValues(), - [](const APInt &val) { return val.getSExtValue(); })); + auto permsTy = + RankedTensorType::get(transposePerms.size(), rewriter.getI32Type()); + auto permsAttr = DenseIntElementsAttr::get(permsTy, perms); + Value permsValue = + rewriter.create(transposeOp.getLoc(), permsAttr); - for (int i = 0, s = permValues.size(); i < s; i++) { - if (i != permValues[i]) { - return failure(); - } - } + rewriter.replaceOpWithNewOp( + transposeOp, transposeOp.getResult().getType(), + innerTranspose.getInput1(), permsValue); - rewriter.replaceOp(op, op.getInput1()); return success(); } }; @@ -212,7 +232,7 @@ struct TransposeIsReshape : public OpRewritePattern { void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } struct AddZeroOptimization : public OpRewritePattern { @@ -997,26 +1017,27 @@ OpFoldResult TileOp::fold(ArrayRef operands) { } OpFoldResult TransposeOp::fold(ArrayRef operands) { - if (!operands[1]) - return {}; - auto inputTy = getInput1().getType().cast(); auto resultTy = getType().cast(); - if (inputTy.getElementType() != resultTy.getElementType()) - return {}; // Transposing splat values just means reshaping. if (auto input = operands[0].dyn_cast_or_null()) { - if (input.isSplat()) - return input.reshape(getType().cast()); + if (input.isSplat() && resultTy.hasStaticShape() && + inputTy.getElementType() == resultTy.getElementType()) + return input.reshape(resultTy); } - auto perms = llvm::to_vector<6>(llvm::map_range( - operands[1].cast().getValues(), - [](const APInt &val) { return val.getSExtValue(); })); + // Transpose does not change the input type. 
+ if (getInput1().getType() != getType()) + return {}; - if (llvm::equal(llvm::seq(0, perms.size()), perms) && - getInput1().getType() == getType()) - return getInput1(); - return {}; + // Transpose is not the identity transpose. + SmallVector perms; + if (getConstantPerms(perms).failed()) + return {}; + + if (!llvm::equal(llvm::seq(0, perms.size()), perms)) + return {}; + + return getInput1(); } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 7ce0081..82f34f9 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -688,6 +688,20 @@ mlir::LogicalResult tosa::ReshapeOp::verify() { return mlir::success(); } +LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector &perms) { + // Perms must be constants. + DenseIntElementsAttr permsAttr; + if (!matchPattern(getPerms(), m_Constant(&permsAttr))) + return failure(); + + // Transpose is not the identity transpose. + perms = llvm::to_vector( + llvm::map_range(permsAttr.getValues(), + [](const APInt &val) { return val.getSExtValue(); })); + + return success(); +} + LogicalResult tosa::TransposeOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, diff --git a/mlir/test/IR/transpose-fold.mlir b/mlir/test/IR/transpose-fold.mlir new file mode 100644 index 0000000..1079bf3e --- /dev/null +++ b/mlir/test/IR/transpose-fold.mlir @@ -0,0 +1,44 @@ +// RUN: mlir-opt %s --canonicalize -split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @test_cancel_transpose_transpose( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> { +// CHECK: return %[[VAL_0]] : tensor<1x2x3xi32> +// CHECK: } + +func.func @test_cancel_transpose_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) { + %0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32> + %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> 
(tensor<2x3x1xi32>) + %2 = arith.constant dense<[2, 0, 1]> : tensor<3xi32> + %3 = "tosa.transpose"(%1, %2) : (tensor<2x3x1xi32>, tensor<3xi32>) -> tensor<1x2x3xi32> + return %3 : tensor<1x2x3xi32> +} + +// ----- + +// CHECK-LABEL: func.func @test_remove_identity_transpose( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> { +// CHECK: return %[[VAL_0]] : tensor<1x2x3xi32> +// CHECK: } + +func.func @test_remove_identity_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) { + %0 = arith.constant dense<[0, 1, 2]> : tensor<3xi32> + %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<1x2x3xi32>) + return %1 : tensor<1x2x3xi32> +} + +// ----- + +// CHECK-LABEL: func.func @test_do_not_cancel_different_transpose( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3x4x5xi32>) -> tensor<5x4x3x2xi32> { +// CHECK: %[[VAL_1:.*]] = arith.constant dense<[3, 2, 1, 0]> : tensor<4xi32> +// CHECK: %[[VAL_2:.*]] = "tosa.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32> +// CHECK: return %[[VAL_2]] : tensor<5x4x3x2xi32> +// CHECK: } + +func.func @test_do_not_cancel_different_transpose(%arg0: tensor<2x3x4x5xi32>) -> (tensor<5x4x3x2xi32>) { + %0 = arith.constant dense<[1, 2, 0, 3]> : tensor<4xi32> + %1 = "tosa.transpose"(%arg0, %0) : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> (tensor<3x4x2x5xi32>) + %2 = arith.constant dense<[3, 1, 0, 2]> : tensor<4xi32> + %3 = "tosa.transpose"(%1, %2) : (tensor<3x4x2x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32> + return %3 : tensor<5x4x3x2xi32> +} -- 2.7.4