[mlir][tosa] Canonicalize tosa.transpose to tosa.reshape
author    Rob Suderman <suderman@google.com>
          Tue, 3 Jan 2023 19:06:04 +0000 (11:06 -0800)
committer Rob Suderman <suderman@google.com>
          Tue, 3 Jan 2023 19:19:55 +0000 (11:19 -0800)
Added a tosa.transpose canonicalization for the case where a tosa.transpose
is equivalent to a tosa.reshape. This occurs when the permutation does not
reorder any non-unit (size > 1) dimensions.
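
For illustration, a minimal sketch of the rewrite, mirroring the new test in
canonicalize.mlir: the permutation [3, 1, 0, 2] moves only the two unit
dimensions of a tensor<1x4x5x1xf32>, while the non-unit dimensions (sizes 4
and 5) keep their relative order, so no data movement is needed:

  %perms = "tosa.const"() {value = dense<[3, 1, 0, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
  %0 = "tosa.transpose"(%arg0, %perms) : (tensor<1x4x5x1xf32>, tensor<4xi32>) -> tensor<1x4x1x5xf32>

which canonicalizes to:

  %0 = "tosa.reshape"(%arg0) {new_shape = [1, 4, 1, 5]} : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32>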

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D140356

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
mlir/test/Dialect/Tosa/constant-op-fold.mlir

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 215a4cc..1f91d43 100644
@@ -88,16 +88,22 @@ struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
   LogicalResult matchAndRewrite(tosa::ReshapeOp op,
                                 PatternRewriter &rewriter) const override {
     Value input = op.getInput1();
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType resultTy = op.getType().cast<ShapedType>();
     ArrayAttr newShape = op.getNewShape();
 
+    if (inputTy.getElementType() != resultTy.getElementType())
+      return rewriter.notifyMatchFailure(op, "element type does not match.");
+
     // Check if input is constant
     DenseElementsAttr inputAttr;
     if (!matchPattern(input, m_Constant(&inputAttr)))
-      return failure();
+      return rewriter.notifyMatchFailure(op, "Non-constant input.");
 
     // Check if the input has more than one consumer and is not a splat
     if (!input.hasOneUse() && !inputAttr.isSplat())
-      return failure();
+      return rewriter.notifyMatchFailure(op,
+                                         "Used more than once or not a splat");
 
     // Grab the new shape
     SmallVector<int64_t> newShapeValues = llvm::to_vector<6>(
@@ -132,7 +138,7 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
   return success();
 }
 
-struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
+struct TransposeNoOp : public OpRewritePattern<tosa::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(tosa::TransposeOp op,
@@ -159,9 +165,62 @@ struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
   }
 };
 
+// Handles the case where a tosa.transpose is equivalent to a tosa.reshape.
+struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    DenseIntElementsAttr permAttr;
+    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
+      return rewriter.notifyMatchFailure(op, "Non-constant permutation");
+
+    auto input = op.getInput1();
+    auto inputTy = input.getType().cast<ShapedType>();
+    if (!inputTy.hasRank())
+      return rewriter.notifyMatchFailure(op, "Unranked input.");
+
+    int64_t numDynDims = 0;
+    for (int i = 0; i < inputTy.getRank(); ++i)
+      if (inputTy.isDynamicDim(i))
+        numDynDims++;
+
+    if (numDynDims > 1)
+      return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
+
+    SmallVector<int64_t> permValues = llvm::to_vector<6>(
+        llvm::map_range(permAttr.getValues<APInt>(),
+                        [](const APInt &val) { return val.getSExtValue(); }));
+
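+    // Collect the source indices of all non-unit dimensions in output order;
+    // the transpose is a pure reshape iff these stay in ascending order.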
+    SmallVector<int64_t> nonZeroPerms;
+    nonZeroPerms.reserve(permValues.size());
+    for (auto idx : permValues) {
+      auto sz = inputTy.getDimSize(idx);
+      if (sz != 1)
+        nonZeroPerms.push_back(idx);
+    }
+
+    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
+      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
+        return rewriter.notifyMatchFailure(op,
+                                           "Transpose changes memory layout.");
+
+    SmallVector<int64_t> newShape;
+    newShape.reserve(inputTy.getRank());
+    for (int i = 0, s = inputTy.getRank(); i < s; ++i)
+      newShape.push_back(inputTy.getDimSize(permValues[i]));
+
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+        op, op.getType(), op.getInput1(), rewriter.getI64ArrayAttr(newShape));
+    return success();
+  }
+};
+
 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<NoOpOptimization>(context);
+  results.add<TransposeNoOp, TransposeIsReshape>(context);
 }
 
 struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
@@ -958,6 +1015,13 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
   if (!operands[1])
     return {};
 
+  auto inputTy = getInput1().getType().cast<ShapedType>();
+  auto resultTy = getType().cast<ShapedType>();
+  if (inputTy.getElementType() != resultTy.getElementType())
+    return {};
+
   // Transposing splat values just means reshaping.
   if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
     if (input.isSplat())
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index ca25100..7eea232 100644
@@ -400,6 +400,14 @@ func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
   return %1 : tensor<3x4x5x6xf32>
 }
 
+// CHECK-LABEL: @transpose_is_reshape
+func.func @transpose_is_reshape(%arg0: tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32> {
+  // CHECK: "tosa.reshape"(%arg0) {new_shape = [1, 4, 1, 5]} : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32>
+  %perms = "tosa.const"() {value = dense<[3, 1, 0, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %0 = "tosa.transpose"(%arg0, %perms) : (tensor<1x4x5x1xf32>, tensor<4xi32>) -> tensor<1x4x1x5xf32>
+  return %0 : tensor<1x4x1x5xf32>
+}
+
 // CHECK-LABEL: @single_bit_reshape
 // https://github.com/llvm/llvm-project/issues/55440
 func.func @single_bit_reshape() -> tensor<1xi1> {
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 0811578..1ca93fe 100644
@@ -90,12 +90,12 @@ func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>)
 }
 
 // CHECK-LABEL: @transpose_nofold_quantized_types
-func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> {
+func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> {
   %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
-  %input = "tosa.const"() {value = dense<[[[[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x1x1x16xi8>} : () -> tensor<1x1x1x16xi8>
+  %input = "tosa.const"() {value = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2xi8>
   // CHECK: tosa.transpose
-  %0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
-  return %0: tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
+  %0 = "tosa.transpose"(%input, %perms) : (tensor<2x1x1x2xi8>, tensor<4xi32>) -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
+  return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
 }
 
 // -----