[mlir][tosa] Set explicit benefit for tosa.reshape to linalg
authorRob Suderman <suderman@google.com>
Wed, 18 Jan 2023 16:37:05 +0000 (08:37 -0800)
committerRob Suderman <suderman@google.com>
Wed, 18 Jan 2023 16:38:02 +0000 (08:38 -0800)
The patterns used to lower tosa.reshape to linalg are order
dependent which varies depending on platform. Setting the
benefit appropriately guarantees compilation executes the
rewriters in the correct ordering.

Reviewed By: benvanik

Differential Revision: https://reviews.llvm.org/D141981

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

index bb4072e..aa41c96 100644 (file)
@@ -963,11 +963,6 @@ public:
           reshape, "Cannot collapse dynamic dims to more than one dimension");
     }
 
-    if (operandTy == resultTy) {
-      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
-      return success();
-    }
-
     SmallVector<ReassociationExprs, 4> reassociationMap;
     if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
                                             resultTy.getShape(),
@@ -1001,11 +996,6 @@ public:
     ShapedType resultTy = reshape.getType().template cast<ShapedType>();
     bool isDynamic = !operandTy.hasStaticShape();
 
-    if (operandTy == resultTy) {
-      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
-      return success();
-    }
-
     if (isDynamic && operandTy.getRank() != 1) {
       return rewriter.notifyMatchFailure(
           reshape, "Cannot expand dynamic dims from more than one dimension");
@@ -1045,11 +1035,6 @@ public:
     ShapedType resultTy = reshape.getType().template cast<ShapedType>();
     bool isDynamic = !operandTy.hasStaticShape();
 
-    if (operandTy == resultTy) {
-      rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
-      return success();
-    }
-
     SmallVector<int64_t> intermediateShape;
     if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
                                intermediateShape, isDynamic)) {
@@ -2310,6 +2295,13 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
   patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
                                             /*benefit=*/300);
 
+  patterns->add<ReshapeConverterCollapse>(patterns->getContext(),
+                                          /*benefit=*/100);
+  patterns->add<ReshapeConverterExpand>(patterns->getContext(),
+                                        /*benefit=*/200);
+  patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext(),
+                                                /*benefit=*/300);
+
   patterns->add<
       // clang-format off
       PointwiseConverter<tosa::AddOp>,
@@ -2357,9 +2349,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       ArgMaxConverter,
       ConcatConverter,
       GatherConverter,
-      ReshapeConverterCollapse,
-      ReshapeConverterExpand,
-      ReshapeConverterCollapseExpand,
       RescaleConverter,
       ReverseConverter,
       TableConverter,