return success();
}
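+// Fuses transpose(transpose(A)) into a single transpose with the composed
+// permutation.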
-struct TransposeNoOp : public OpRewritePattern<tosa::TransposeOp> {
+struct ConsolidateTransposeOptimization
+ : public OpRewritePattern<tosa::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(tosa::TransposeOp op,
+ LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
- auto perm = op.getPerms();
+ // Input is also TransposeOp - transpose(transpose(A)).
+ auto innerTranspose =
+ transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
+ if (!innerTranspose)
+ return rewriter.notifyMatchFailure(transposeOp,
+ "input must be transpose operation");
+
+ SmallVector<int64_t> transposePerms, innerTransposePerms;
+ if (transposeOp.getConstantPerms(transposePerms).failed())
+ return rewriter.notifyMatchFailure(transposeOp,
+ "transpose perms must be constant");
+ if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
+ return rewriter.notifyMatchFailure(
+ transposeOp, "inner transpose perms must be constant");
+ if (transposePerms.size() != innerTransposePerms.size())
+ return rewriter.notifyMatchFailure(
+ transposeOp,
+ "transpose and inner transpose perms sizes must be equal");
+    if (transposePerms.empty())
+      return rewriter.notifyMatchFailure(
+          transposeOp, "transpose perms must be non-empty");
- DenseIntElementsAttr permAttr;
- if (!matchPattern(perm, m_Constant(&permAttr))) {
- return failure();
- }
+ // Consolidate transposes into one transpose.
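+    // The outer transpose reads dimension transposePerms[i], which the inner
+    // transpose in turn sourced from innerTransposePerms[transposePerms[i]].
+    // E.g. inner perms [1, 2, 0] followed by outer perms [2, 0, 1] compose
+    // to [0, 1, 2], the identity.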
+ SmallVector<int32_t> perms(transposePerms.size());
+ for (int i = 0, s = transposePerms.size(); i < s; ++i)
+ perms[i] = innerTransposePerms[transposePerms[i]];
- SmallVector<int64_t> permValues = llvm::to_vector<6>(
- llvm::map_range(permAttr.getValues<APInt>(),
- [](const APInt &val) { return val.getSExtValue(); }));
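+    // Materialize the composed permutation as an i32 constant tensor for the
+    // replacement transpose.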
+ auto permsTy =
+ RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
+ auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
+ Value permsValue =
+ rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);
- for (int i = 0, s = permValues.size(); i < s; i++) {
- if (i != permValues[i]) {
- return failure();
- }
- }
+ rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
+ transposeOp, transposeOp.getResult().getType(),
+ innerTranspose.getInput1(), permsValue);
- rewriter.replaceOp(op, op.getInput1());
return success();
}
};
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<TransposeNoOp, TransposeIsReshape>(context);
+ results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
- if (!operands[1])
- return {};
-
auto inputTy = getInput1().getType().cast<ShapedType>();
auto resultTy = getType().cast<ShapedType>();
- if (inputTy.getElementType() != resultTy.getElementType())
- return {};
// Transposing splat values just means reshaping.
if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
- if (input.isSplat())
- return input.reshape(getType().cast<ShapedType>());
+ if (input.isSplat() && resultTy.hasStaticShape() &&
+ inputTy.getElementType() == resultTy.getElementType())
+ return input.reshape(resultTy);
}
- auto perms = llvm::to_vector<6>(llvm::map_range(
- operands[1].cast<DenseIntElementsAttr>().getValues<APInt>(),
- [](const APInt &val) { return val.getSExtValue(); }));
+  // The remaining folds require that the transpose leaves the type unchanged.
+ if (getInput1().getType() != getType())
+ return {};
- if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
- getInput1().getType() == getType())
- return getInput1();
- return {};
+  // Fold to the input only when the perms are the identity permutation.
+ SmallVector<int64_t> perms;
+ if (getConstantPerms(perms).failed())
+ return {};
+
+ if (!llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms))
+ return {};
+
+ return getInput1();
}
--- /dev/null
+// RUN: mlir-opt %s --canonicalize --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @test_cancel_transpose_transpose(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> {
+// CHECK: return %[[VAL_0]] : tensor<1x2x3xi32>
+// CHECK: }
+
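+// The perms [1, 2, 0] and [2, 0, 1] compose to the identity, so both
+// transposes fold away.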
+func.func @test_cancel_transpose_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
+ %0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
+ %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<2x3x1xi32>)
+ %2 = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
+ %3 = "tosa.transpose"(%1, %2) : (tensor<2x3x1xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
+ return %3 : tensor<1x2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_remove_identity_transpose(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3xi32>) -> tensor<1x2x3xi32> {
+// CHECK: return %[[VAL_0]] : tensor<1x2x3xi32>
+// CHECK: }
+
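+// Perms [0, 1, 2] are the identity permutation, so the transpose folds to
+// its input.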
+func.func @test_remove_identity_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
+ %0 = arith.constant dense<[0, 1, 2]> : tensor<3xi32>
+ %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<1x2x3xi32>)
+ return %1 : tensor<1x2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_do_not_cancel_different_transpose(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3x4x5xi32>) -> tensor<5x4x3x2xi32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<[3, 2, 1, 0]> : tensor<4xi32>
+// CHECK: %[[VAL_2:.*]] = "tosa.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
+// CHECK: return %[[VAL_2]] : tensor<5x4x3x2xi32>
+// CHECK: }
+
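+// The perms [1, 2, 0, 3] and [3, 1, 0, 2] compose to [3, 2, 1, 0], which is
+// not the identity, so the two transposes are merged into one instead of
+// being removed.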
+func.func @test_do_not_cancel_different_transpose(%arg0: tensor<2x3x4x5xi32>) -> (tensor<5x4x3x2xi32>) {
+ %0 = arith.constant dense<[1, 2, 0, 3]> : tensor<4xi32>
+ %1 = "tosa.transpose"(%arg0, %0) : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> (tensor<3x4x2x5xi32>)
+ %2 = arith.constant dense<[3, 1, 0, 2]> : tensor<4xi32>
+ %3 = "tosa.transpose"(%1, %2) : (tensor<3x4x2x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
+ return %3 : tensor<5x4x3x2xi32>
+}