results.insert<ReshapeReshapeOptimization>(context);
}
+// Canonicalization pattern: folds a tosa.transpose whose data input and
+// permutation are both compile-time constants into a single tosa.const
+// holding the pre-transposed values.
+struct ConstantTransposeOptimization
+ : public OpRewritePattern<tosa::TransposeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::TransposeOp op,
+ PatternRewriter &rewriter) const override {
+ DenseElementsAttr inputValues;
+ if (!matchPattern(op.input1(), m_Constant(&inputValues)))
+ return failure();
+ // Make sure the input is a constant that has a single user.
+ // If the constant had other users it would stay alive after the rewrite,
+ // so folding would duplicate the tensor data rather than replace it.
+ if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers()))
+ return failure();
+
+ DenseIntElementsAttr permAttr;
+ if (!matchPattern(op.perms(), m_Constant(&permAttr)))
+ return failure();
+ auto permValues = llvm::to_vector<6>(llvm::map_range(
+ // TOSA allows both 32- and 64-bit integer tensors here.
+ permAttr.getValues<APInt>(),
+ [](const APInt &val) { return val.getZExtValue(); }));
+
+ auto inputType = op.input1().getType().cast<ShapedType>();
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ int64_t numElements = inputType.getNumElements();
+
+ auto outputType = op.getType().cast<ShapedType>();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+
+ SmallVector<Attribute, 4> outputValues;
+ outputValues.resize(numElements);
+
+ // Transpose the input constant. Because we don't know its rank in advance,
+ // we need to loop over the range [0, element count) and delinearize the
+ // index.
+ for (int srcLinearIndex = 0; srcLinearIndex < numElements;
+ ++srcLinearIndex) {
+ // Delinearize: recover the multi-dimensional source index from the flat
+ // index, peeling dimensions from innermost (fastest-varying) outward.
+ SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
+ int totalCount = srcLinearIndex;
+ for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
+ srcIndices[dim] = totalCount % inputShape[dim];
+ totalCount /= inputShape[dim];
+ }
+
+ // Apply the permutation: output dimension `dim` is fed by input
+ // dimension `permValues[dim]`.
+ SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
+ for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
+ dstIndices[dim] = srcIndices[permValues[dim]];
+
+ // Relinearize the destination index in row-major order.
+ // NOTE(review): `dstIndices.front()` assumes rank >= 1 — presumably a
+ // rank-0 tosa.transpose cannot reach here; confirm the verifier rejects it.
+ uint64_t dstLinearIndex = dstIndices.front();
+ for (int dim = 1; dim < outputType.getRank(); ++dim)
+ dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
+
+ outputValues[dstLinearIndex] = inputValues.getValue(srcIndices);
+ }
+
+ // Replace the transpose (and, transitively, its now-unused constant input)
+ // with a single constant holding the transposed data.
+ rewriter.replaceOpWithNewOp<tosa::ConstOp>(
+ op, outputType, DenseElementsAttr::get(outputType, outputValues));
+ return success();
+ }
+};
+
+// Registers the constant-folding canonicalization pattern for tosa.transpose.
+void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<ConstantTransposeOptimization>(context);
+}
+
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
if (!operands[1])
return {};
- DenseIntElementsAttr perms = operands[1].cast<DenseIntElementsAttr>();
-
- bool isRange = true;
- for (auto it : llvm::enumerate(perms)) {
- isRange = isRange &&
- it.value().getSExtValue() == static_cast<int64_t>(it.index());
+ // Transposing splat values just means reshaping.
+ if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
+ if (input.isSplat())
+ return input.reshape(getType().cast<ShapedType>());
}
- if (isRange && input1().getType() == getType())
+ // Pull the permutation values out of the constant perms operand.
+ auto perms = llvm::to_vector<6>(llvm::map_range(
+ operands[1].cast<DenseIntElementsAttr>().getValues<APInt>(),
+ [](const APInt &val) { return val.getSExtValue(); }));
+
+ // An identity permutation [0, 1, ..., rank-1] with matching input/output
+ // types makes the transpose a no-op; fold to the input directly.
+ if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
+ input1().getType() == getType())
return input1();
return {};
}
-// RUN: mlir-opt --canonicalize %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s
// CHECK-LABEL: @argmax_nofold
func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
%1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_splat
+func @transpose_fold_splat() -> tensor<3x2xf32> {
+ %input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // A splat (uniform) tensor is unchanged by transposition, so the fold only
+ // needs to retype the same splat value to the output shape.
+ // CHECK: %[[CST:.+]] = "tosa.const"()
+ // CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32>
+ %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ // CHECK: return %[[CST]]
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_2d_float
+func @transpose_fold_2d_float() -> tensor<3x2xf32> {
+ %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // A non-splat constant input is folded to a new constant holding the
+ // 2x3 -> 3x2 transposed (matrix-transpose) data.
+ // CHECK: %[[CST:.+]] = "tosa.const"()
+ // CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
+ %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ // CHECK: return %[[CST]]
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_4d_int
+func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> {
+ %input = "tosa.const"() {value = dense<[[
+ [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
+ [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
+ ]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32>
+ // Perms given as i64 here (other tests use i32) to exercise the 64-bit
+ // integer permutation tensor path of the fold.
+ %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
+ // CHECK: %[[CST:.+]] = "tosa.const"()
+ // CHECK-SAME{LITERAL}: value = dense<[
+ // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
+ // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
+ // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
+ // CHECK-SAME{LITERAL}: ]>
+ %1 = "tosa.transpose"(%input, %perms) : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32>
+ // CHECK: return %[[CST]]
+ return %1 : tensor<3x1x4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_non_cst_input
+func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> {
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // The data input is a function argument, not a constant, so the transpose
+ // must remain.
+ // CHECK: tosa.transpose
+ %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_non_cst_perms
+func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
+ %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+ // The permutation is not a compile-time constant, so the fold cannot apply
+ // and the transpose must remain.
+ // CHECK: tosa.transpose
+ %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_multi_users
+func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) {
+ %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+ %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // %input is also returned directly, so it has more than one user; folding
+ // would duplicate the constant data, so the pattern bails out.
+ // CHECK: tosa.transpose
+ %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
+ return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
+}