[mlir] specify the values when notifying about op replacement
authorAlex Zinenko <zinenko@google.com>
Tue, 27 Sep 2022 16:08:44 +0000 (16:08 +0000)
committerAlex Zinenko <zinenko@google.com>
Tue, 27 Sep 2022 16:22:35 +0000 (16:22 +0000)
It is useful for PatternRewriter listeners to know the values that are
replacing the op in addition to only the fact of the op being replaced
for being able to keep track of changes or for debugging.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D134748

mlir/include/mlir/IR/PatternMatch.h
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

index 8d62d31..fa21d0c 100644 (file)
@@ -546,9 +546,9 @@ protected:
   /// they would like to be notified about certain types of mutations.
 
   /// Notify the rewriter that the specified operation is about to be replaced
-  /// with another set of operations. This is called before the uses of the
-  /// operation have been changed.
-  virtual void notifyRootReplaced(Operation *op) {}
+  /// with the set of values potentially produced by new operations. This is
+  /// called before the uses of the operation have been changed.
+  virtual void notifyRootReplaced(Operation *op, ValueRange replacement) {}
 
   /// This is called on an operation that a rewrite is removing, right before
   /// the operation is deleted. At this point, the operation has zero uses.
index 56063d0..494d90f 100644 (file)
@@ -216,7 +216,7 @@ void RewriterBase::replaceOpWithIf(
          "incorrect number of values to replace operation");
 
   // Notify the rewriter subclass that we're about to replace this root.
-  notifyRootReplaced(op);
+  notifyRootReplaced(op, newValues);
 
   // Replace each use of the results when the functor is true.
   bool replacedAllUses = true;
@@ -244,7 +244,7 @@ void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
 /// the operation.
 void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
   // Notify the rewriter subclass that we're about to replace this root.
-  notifyRootReplaced(op);
+  notifyRootReplaced(op, newValues);
 
   assert(op->getNumResults() == newValues.size() &&
          "incorrect # of replacement values");
index 7305a37..9c62d61 100644 (file)
@@ -69,7 +69,7 @@ protected:
   // When the root of a pattern is about to be replaced, it can trigger
   // simplifications to its users - make sure to add them to the worklist
   // before the root is changed.
-  void notifyRootReplaced(Operation *op) override;
+  void notifyRootReplaced(Operation *op, ValueRange replacement) override;
 
   /// PatternRewriter hook for erasing a dead operation.
   void eraseOp(Operation *op) override;
@@ -348,7 +348,8 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
   });
 }
 
-void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op) {
+void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op,
+                                                    ValueRange replacement) {
   LLVM_DEBUG({
     logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
                        << ")\n";
@@ -437,7 +438,7 @@ protected:
 
   // When a root is going to be replaced, its removal will be notified as well.
   // So there is nothing to do here.
-  void notifyRootReplaced(Operation *op) override {}
+  void notifyRootReplaced(Operation *op, ValueRange replacement) override {}
 
 private:
   /// The low-level pattern applicator.