[mlir] GreedyPatternRewriteDriver: Fix termination criteria in OpPatternRewriteDriver
authorMatthias Springer <springerm@google.com>
Wed, 18 Jan 2023 14:10:14 +0000 (15:10 +0100)
committerMatthias Springer <springerm@google.com>
Wed, 18 Jan 2023 14:11:06 +0000 (15:11 +0100)
This driver should iterate until convergence or until the specified op was erased. However, it used to stop when any op was erased.

Differential Revision: https://reviews.llvm.org/D141921

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

index 6bd3994..b7ea592 100644 (file)
@@ -481,7 +481,8 @@ protected:
   /// If an operation is about to be removed, mark it so that we can let clients
   /// know.
   void notifyOperationRemoved(Operation *op) override {
-    opErasedViaPatternRewrites = true;
+    if (this->op == op)
+      opErasedViaPatternRewrites = true;
   }
 
   // When a root is going to be replaced, its removal will be notified as well.
@@ -495,6 +496,9 @@ private:
   /// Non-pattern based folder for operations.
   OperationFolder folder;
 
+  /// Op that is being processed.
+  Operation *op = nullptr;
+
   /// Set to true if the operation has been erased via pattern rewrites.
   bool opErasedViaPatternRewrites = false;
 };
@@ -509,6 +513,7 @@ private:
 LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
                                                       int64_t maxNumRewrites,
                                                       bool &erased) {
+  this->op = op;
   bool changed = false;
   erased = false;
   opErasedViaPatternRewrites = false;