[mlir:MultiOpDriver] Don't add ops which are not in the allowed list
authorChia-hung Duan <chiahungduan@google.com>
Thu, 2 Jun 2022 18:27:36 +0000 (18:27 +0000)
committerChia-hung Duan <chiahungduan@google.com>
Thu, 2 Jun 2022 18:27:37 +0000 (18:27 +0000)
In strict mode, only the new inserted operation is allowed to add to the
worklist. Before this change, it would add the users of a replaced op
and it didn't check if the users are allowed to be pushed into the
worklist

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D126899

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

index a08657c..be7924e 100644 (file)
@@ -539,12 +539,27 @@ private:
     }
   }
 
+  void notifyOperationInserted(Operation *op) override {
+    GreedyPatternRewriteDriver::notifyOperationInserted(op);
+    if (strictMode)
+      strictModeFilteredOps.insert(op);
+  }
+
   void notifyOperationRemoved(Operation *op) override {
     GreedyPatternRewriteDriver::notifyOperationRemoved(op);
     if (strictMode)
       strictModeFilteredOps.erase(op);
   }
 
+  void notifyRootReplaced(Operation *op) override {
+    for (auto result : op->getResults()) {
+      for (auto *user : result.getUsers()) {
+        if (!strictMode || strictModeFilteredOps.contains(user))
+          addToWorklist(user);
+      }
+    }
+  }
+
   /// If `strictMode` is true, any pre-existing ops outside of
   /// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
   /// If `strictMode` is false, operations that use results of (or supply
@@ -592,6 +607,8 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
   SmallVector<Value, 8> originalOperands, resultValues;
   while (!worklist.empty()) {
     Operation *op = popFromWorklist();
+    assert((!strictMode || strictModeFilteredOps.contains(op)) &&
+           "unexpected op was inserted under strict mode");
 
     // Nulls get added to the worklist when operations are removed, ignore
     // them.