[mlir] Add RewriterBase::replaceAllUsesWith for Blocks.
authorIngo Müller <ingomueller@google.com>
Tue, 24 Jan 2023 15:42:38 +0000 (15:42 +0000)
committerIngo Müller <ingomueller@google.com>
Wed, 15 Feb 2023 07:23:21 +0000 (07:23 +0000)
When changing IR in a RewriterPattern, all changes must go through the
rewriter. There are several convenience functions in RewriterBase that
help with high-level modifications, such as replaceAllUsesWith for
Values, but there is currently none to do the same task for Blocks.

Reviewed By: mehdi_amini, ingomueller-net

Differential Revision: https://reviews.llvm.org/D142525

mlir/include/mlir/IR/PatternMatch.h
mlir/lib/IR/PatternMatch.cpp
mlir/test/Transforms/test-strict-pattern-driver.mlir
mlir/test/lib/Dialect/Test/TestPatterns.cpp

index 64eb66b..187ce06 100644 (file)
@@ -505,7 +505,16 @@ public:
   /// Find uses of `from` and replace them with `to`. It also marks every
   /// modified uses and notifies the rewriter that an in-place operation
   /// modification is about to happen.
-  void replaceAllUsesWith(Value from, Value to);
+  void replaceAllUsesWith(Value from, Value to) {
+    return replaceAllUsesWith(from.getImpl(), to);
+  }
+  template <typename OperandType, typename ValueT>
+  void replaceAllUsesWith(IRObjectWithUseList<OperandType> *from, ValueT &&to) {
+    for (OperandType &operand : llvm::make_early_inc_range(from->getUses())) {
+      Operation *op = operand.getOwner();
+      updateRootInPlace(op, [&]() { operand.set(to); });
+    }
+  }
 
   /// Find uses of `from` and replace them with `to` if the `functor` returns
   /// true. It also marks every modified uses and notifies the rewriter that an
index b082b0d..1ca86cd 100644 (file)
@@ -309,14 +309,6 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
   source->erase();
 }
 
-/// Find uses of `from` and replace it with `to`
-void RewriterBase::replaceAllUsesWith(Value from, Value to) {
-  for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
-    Operation *op = operand.getOwner();
-    updateRootInPlace(op, [&]() { operand.set(to); });
-  }
-}
-
 /// Find uses of `from` and replace them with `to` if the `functor` returns
 /// true. It also marks every modified uses and notifies the rewriter that an
 /// in-place operation modification is about to happen.
index 9dbaea1..5df2d6d 100644 (file)
@@ -1,4 +1,8 @@
 // RUN: mlir-opt \
+// RUN:     -test-strict-pattern-driver="strictness=AnyOp" \
+// RUN:     --split-input-file %s | FileCheck %s --check-prefix=CHECK-AN
+
+// RUN: mlir-opt \
 // RUN:     -test-strict-pattern-driver="strictness=ExistingAndNewOps" \
 // RUN:     --split-input-file %s | FileCheck %s --check-prefix=CHECK-EN
 
@@ -58,3 +62,24 @@ func.func @test_replace_with_erase_op() {
   "test.replace_with_new_op"() {create_erase_op} : () -> ()
   return
 }
+
+// -----
+
+// CHECK-AN-LABEL: func @test_trigger_rewrite_through_block
+//       CHECK-AN: "test.change_block_op"()[^[[BB0:.*]], ^[[BB0]]]
+//       CHECK-AN: return
+//       CHECK-AN: ^[[BB1:[^:]*]]:
+//       CHECK-AN: "test.implicit_change_op"()[^[[BB1]]]
+func.func @test_trigger_rewrite_through_block() {
+  return
+^bb1:
+  // Uses bb1. ChangeBlockOp replaces that and all other usages of bb1 with bb2.
+  "test.change_block_op"() [^bb1, ^bb2] : () -> ()
+^bb2:
+  return
+^bb3:
+  // Also uses bb1. ChangeBlockOp replaces that usage with bb2. This triggers
+  // this op being put on the worklist, which triggers ImplicitChangeOp, which,
+  // in turn, replaces the successor with bb3.
+  "test.implicit_change_op"() [^bb1] : () -> ()
+}
index a15d307..4bfbb34 100644 (file)
@@ -256,11 +256,19 @@ public:
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
     mlir::RewritePatternSet patterns(ctx);
-    patterns.add<InsertSameOp, ReplaceWithNewOp, EraseOp>(ctx);
+    patterns.add<
+        // clang-format off
+        InsertSameOp,
+        ReplaceWithNewOp,
+        EraseOp,
+        ChangeBlockOp,
+        ImplicitChangeOp
+        // clang-format on
+        >(ctx);
     SmallVector<Operation *> ops;
     getOperation()->walk([&](Operation *op) {
       StringRef opName = op->getName().getStringRef();
-      if (opName == "test.insert_same_op" ||
+      if (opName == "test.insert_same_op" || opName == "test.change_block_op" ||
           opName == "test.replace_with_new_op" || opName == "test.erase_op") {
         ops.push_back(op);
       }
@@ -342,7 +350,7 @@ private:
     }
   };
 
-  // Remove an operation may introduce the re-visiting of its opreands.
+  // Remove an operation may introduce the re-visiting of its operands.
   class EraseOp : public RewritePattern {
   public:
     EraseOp(MLIRContext *context)
@@ -353,6 +361,55 @@ private:
       return success();
     }
   };
+
+  // The following two patterns test RewriterBase::replaceAllUsesWith.
+  //
+  // That function replaces all usages of a Block (or a Value) with another one
+  // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver
+  // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its
+  // worklist: when an op is modified, it is added to the worklist. The two
+  // patterns below make the tracking observable: ChangeBlockOp replaces all
+  // usages of a block and that pattern is applied because the corresponding ops
+  // are put on the initial worklist (see above). ImplicitChangeOp does an
+  // unrelated change but ops of the corresponding type are *not* on the initial
+  // worklist, so the effect of the second pattern is only visible if the
+  // tracking and subsequent adding to the worklist actually works.
+
+  // Replace all usages of the first successor with the second successor.
+  class ChangeBlockOp : public RewritePattern {
+  public:
+    ChangeBlockOp(MLIRContext *context)
+        : RewritePattern("test.change_block_op", /*benefit=*/1, context) {}
+    LogicalResult matchAndRewrite(Operation *op,
+                                  PatternRewriter &rewriter) const override {
+      if (op->getNumSuccessors() < 2)
+        return failure();
+      Block *firstSuccessor = op->getSuccessor(0);
+      Block *secondSuccessor = op->getSuccessor(1);
+      if (firstSuccessor == secondSuccessor)
+        return failure();
+      // This is the function being tested:
+      rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor);
+      // Using the following line instead would make the test fail:
+      // firstSuccessor->replaceAllUsesWith(secondSuccessor);
+      return success();
+    }
+  };
+
+  // Changes the successor to the parent block.
+  class ImplicitChangeOp : public RewritePattern {
+  public:
+    ImplicitChangeOp(MLIRContext *context)
+        : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {}
+    LogicalResult matchAndRewrite(Operation *op,
+                                  PatternRewriter &rewriter) const override {
+      if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
+        return failure();
+      rewriter.updateRootInPlace(
+          op, [&]() { op->setSuccessor(op->getBlock(), 0); });
+      return success();
+    }
+  };
 };
 
 } // namespace