[mlir] GreedyPatternRewriter: Reprocess modified ops
authorMatthias Springer <springerm@google.com>
Fri, 18 Nov 2022 10:18:19 +0000 (11:18 +0100)
committerMatthias Springer <springerm@google.com>
Fri, 18 Nov 2022 10:43:44 +0000 (11:43 +0100)
Ops that were modifed in-place (`finalizeRootUpdate` was called) should be reprocessed by the GreedyPatternRewriter. This is currently not happening with `GreedyRewriteConfig::maxIterations = 1`.

Note: If your project goes into an infinite loop because of this change, you likely have one or multiple faulty patterns that modify the same operations in-place (`updateRootInplace`) indefinitely.

Differential Revision: https://reviews.llvm.org/D138038

mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/IR/greedy-pattern-rewriter-driver.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Test/TestPatterns.cpp

index 386f296..52f2b83 100644 (file)
@@ -525,7 +525,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
   SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
   auto inBoundsAttr = b.getBoolArrayAttr(bools);
   if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
-    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+    b.updateRootInPlace(xferOp, [&]() {
+      xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+    });
     return success();
   }
 
@@ -596,7 +598,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
     for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
       xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
 
-    xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+    b.updateRootInPlace(xferOp, [&]() {
+      xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr);
+    });
 
     return success();
   }
@@ -623,7 +627,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
   else
     createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);
 
-  xferOp->erase();
+  b.eraseOp(xferOp);
 
   return success();
 }
@@ -634,11 +638,5 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
   if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
       failed(filter(xferOp)))
     return failure();
-  rewriter.startRootUpdate(xferOp);
-  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
-    rewriter.finalizeRootUpdate(xferOp);
-    return success();
-  }
-  rewriter.cancelRootUpdate(xferOp);
-  return failure();
+  return splitFullAndPartialTransfer(rewriter, xferOp, options);
 }
index 9c62d61..935ca2e 100644 (file)
@@ -51,6 +51,10 @@ public:
   /// If the specified operation is in the worklist, remove it.
   void removeFromWorklist(Operation *op);
 
+  /// Notifies the driver that the specified operation may have been modified
+  /// in-place.
+  void finalizeRootUpdate(Operation *op) override;
+
 protected:
   // Implement the hook for inserting operations, and make sure that newly
   // inserted ops are added to the worklist for processing.
@@ -326,6 +330,14 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
   addToWorklist(op);
 }
 
+void GreedyPatternRewriteDriver::finalizeRootUpdate(Operation *op) {
+  LLVM_DEBUG({
+    logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
+                       << ")\n";
+  });
+  addToWorklist(op);
+}
+
 void GreedyPatternRewriteDriver::addOperandsToWorklist(ValueRange operands) {
   for (Value operand : operands) {
     // If the use count of this operand is now < 2, we re-add the defining
diff --git a/mlir/test/IR/greedy-pattern-rewriter-driver.mlir b/mlir/test/IR/greedy-pattern-rewriter-driver.mlir
new file mode 100644 (file)
index 0000000..4f1a06f
--- /dev/null
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s -test-patterns="max-iterations=1" | FileCheck %s
+
+// CHECK-LABEL: func @add_to_worklist_after_inplace_update()
+func.func @add_to_worklist_after_inplace_update() {
+  // The following op is updated in-place and should be added back to the
+  // worklist of the GreedyPatternRewriteDriver (regardless of the value of
+  // config.max_iterations).
+
+  // CHECK: "test.any_attr_of_i32_str"() {attr = 3 : i32} : () -> ()
+  "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
+  return
+}
index 12f3747..2d2bf8d 100644 (file)
@@ -147,6 +147,26 @@ public:
   }
 };
 
+/// This pattern matches test.any_attr_of_i32_str ops. In case of an integer
+/// attribute with value smaller than MaxVal, it increments the value by 1.
+template <int MaxVal>
+struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
+  using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AnyAttrOfOp op,
+                                PatternRewriter &rewriter) const override {
+    auto intAttr = op.getAttr().dyn_cast<IntegerAttr>();
+    if (!intAttr)
+      return failure();
+    int64_t val = intAttr.getInt();
+    if (val >= MaxVal)
+      return failure();
+    rewriter.updateRootInPlace(
+        op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
+    return success();
+  }
+};
+
 struct TestPatternDriver
     : public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -165,8 +185,12 @@ struct TestPatternDriver
                  FolderInsertBeforePreviouslyFoldedConstantPattern,
                  FolderCommutativeOp2WithConstant>(&getContext());
 
+    // Additional patterns for testing the GreedyPatternRewriteDriver.
+    patterns.insert<IncrementIntAttribute<3>>(&getContext());
+
     GreedyRewriteConfig config;
     config.useTopDownTraversal = this->useTopDownTraversal;
+    config.maxIterations = this->maxIterations;
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
                                        config);
   }
@@ -175,6 +199,10 @@ struct TestPatternDriver
       *this, "top-down",
       llvm::cl::desc("Seed the worklist in general top-down order"),
       llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
+  Option<int> maxIterations{
+      *this, "max-iterations",
+      llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"),
+      llvm::cl::init(GreedyRewriteConfig().maxIterations)};
 };
 
 struct TestStrictPatternDriver