[mlir] Fix worklist bug in MultiOpPatternRewriteDriver
authorMatthias Springer <springerm@google.com>
Tue, 10 Jan 2023 14:30:49 +0000 (15:30 +0100)
committerMatthias Springer <springerm@google.com>
Tue, 10 Jan 2023 14:33:22 +0000 (15:33 +0100)
When `strict = true`, only pre-existing and newly-created ops are rewritten and/or folded. Such ops are stored in `strictModeFilteredOps`.

Newly-created ops were previously added to `strictModeFilteredOps` after calling `addToWorklist` (via `GreedyPatternRewriteDriver::notifyOperationInserted`). Therefore, newly-created ops were never added to the worklist.

Also fix a test case that should have gone into an infinite loop (`test.replace_with_new_op` was replaced with itself, which should have caused the op to be rewritten over and over), but did not due to this bug.

Differential Revision: https://reviews.llvm.org/D141141

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/Transforms/test-strict-pattern-driver.mlir
mlir/test/lib/Dialect/Test/TestPatterns.cpp

index 5005a08..cdb0b78 100644 (file)
@@ -558,9 +558,9 @@ public:
 
 private:
   void notifyOperationInserted(Operation *op) override {
-    GreedyPatternRewriteDriver::notifyOperationInserted(op);
     if (strictMode)
       strictModeFilteredOps.insert(op);
+    GreedyPatternRewriteDriver::notifyOperationInserted(op);
   }
 
   void notifyOperationRemoved(Operation *op) override {
index 51d2969..8c6eaf3 100644 (file)
@@ -1,6 +1,9 @@
 // RUN: mlir-opt -allow-unregistered-dialect -test-strict-pattern-driver %s | FileCheck %s
 
-// CHECK-LABEL: @test_erase
+// CHECK-LABEL: func @test_erase
+//       CHECK:   test.arg0
+//       CHECK:   test.arg1
+//   CHECK-NOT:   test.erase_op
 func.func @test_erase() {
   %0 = "test.arg0"() : () -> (i32)
   %1 = "test.arg1"() : () -> (i32)
@@ -8,16 +11,29 @@ func.func @test_erase() {
   return
 }
 
-// CHECK-LABEL: @test_insert_same_op
+// CHECK-LABEL: func @test_insert_same_op
+//       CHECK:   "test.insert_same_op"() {skip = true}
+//       CHECK:   "test.insert_same_op"() {skip = true}
 func.func @test_insert_same_op() {
   %0 = "test.insert_same_op"() : () -> (i32)
   return
 }
 
-// CHECK-LABEL: @test_replace_with_same_op
-func.func @test_replace_with_same_op() {
-  %0 = "test.replace_with_same_op"() : () -> (i32)
+// CHECK-LABEL: func @test_replace_with_new_op
+//       CHECK:   %[[n:.*]] = "test.new_op"
+//       CHECK:   "test.dummy_user"(%[[n]])
+//       CHECK:   "test.dummy_user"(%[[n]])
+func.func @test_replace_with_new_op() {
+  %0 = "test.replace_with_new_op"() : () -> (i32)
   %1 = "test.dummy_user"(%0) : (i32) -> (i32)
   %2 = "test.dummy_user"(%0) : (i32) -> (i32)
   return
 }
+
+// CHECK-LABEL: func @test_replace_with_erase_op
+//   CHECK-NOT:   test.replace_with_new_op
+//   CHECK-NOT:   test.erase_op
+func.func @test_replace_with_erase_op() {
+  "test.replace_with_new_op"() {create_erase_op} : () -> ()
+  return
+}
index 9b74e80..2573f76 100644 (file)
@@ -220,12 +220,12 @@ public:
 
   void runOnOperation() override {
     mlir::RewritePatternSet patterns(&getContext());
-    patterns.add<InsertSameOp, ReplaceWithSameOp, EraseOp>(&getContext());
+    patterns.add<InsertSameOp, ReplaceWithNewOp, EraseOp>(&getContext());
     SmallVector<Operation *> ops;
     getOperation()->walk([&](Operation *op) {
       StringRef opName = op->getName().getStringRef();
       if (opName == "test.insert_same_op" ||
-          opName == "test.replace_with_same_op" || opName == "test.erase_op") {
+          opName == "test.replace_with_new_op" || opName == "test.erase_op") {
         ops.push_back(op);
       }
     });
@@ -260,16 +260,25 @@ private:
   };
 
   // Replace an operation may introduce the re-visiting of its users.
-  class ReplaceWithSameOp : public RewritePattern {
+  class ReplaceWithNewOp : public RewritePattern {
   public:
-    ReplaceWithSameOp(MLIRContext *context)
-        : RewritePattern("test.replace_with_same_op", /*benefit=*/1, context) {}
+    ReplaceWithNewOp(MLIRContext *context)
+        : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
 
     LogicalResult matchAndRewrite(Operation *op,
                                   PatternRewriter &rewriter) const override {
-      Operation *newOp =
-          rewriter.create(op->getLoc(), op->getName().getIdentifier(),
-                          op->getOperands(), op->getResultTypes());
+      Operation *newOp;
+      if (op->hasAttr("create_erase_op")) {
+        newOp = rewriter.create(
+            op->getLoc(),
+            OperationName("test.erase_op", op->getContext()).getIdentifier(),
+            ValueRange(), TypeRange());
+      } else {
+        newOp = rewriter.create(
+            op->getLoc(),
+            OperationName("test.new_op", op->getContext()).getIdentifier(),
+            op->getOperands(), op->getResultTypes());
+      }
       rewriter.replaceOp(op, newOp->getResults());
       return success();
     }