[mlir] RewriterBase::Listener: Add notifyOperationModified callback
authorMatthias Springer <springerm@google.com>
Wed, 22 Feb 2023 09:41:22 +0000 (10:41 +0100)
committerMatthias Springer <springerm@google.com>
Wed, 22 Feb 2023 09:41:57 +0000 (10:41 +0100)
This callback is triggered by `finalizeRootUpdate`. This allows listeners to listen for in-place op modifications without creating a new RewriterBase subclass.

Differential Revision: https://reviews.llvm.org/D143380

mlir/include/mlir/IR/PatternMatch.h
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

index 7845cb1..daade92 100644 (file)
@@ -402,6 +402,9 @@ public:
     Listener()
         : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
 
+    /// Notify the listener that the specified operation was modified in-place.
+    virtual void notifyOperationModified(Operation *op) {}
+
     /// Notify the listener that the specified operation is about to be replaced
     /// with the set of values potentially produced by new operations. This is
     /// called before the uses of the operation have been changed.
@@ -514,7 +517,7 @@ public:
   /// This method is used to signal the end of a root update on the given
   /// operation. This can only be called on operations that were provided to a
   /// call to `startRootUpdate`.
-  virtual void finalizeRootUpdate(Operation *op) {}
+  virtual void finalizeRootUpdate(Operation *op);
 
   /// This method cancels a pending root update. This can only be called on
   /// operations that were provided to a call to `startRootUpdate`.
index 10baea6..958d82a 100644 (file)
@@ -294,6 +294,12 @@ void RewriterBase::eraseBlock(Block *block) {
   block->erase();
 }
 
+void RewriterBase::finalizeRootUpdate(Operation *op) {
+  // Notify the listener that the operation was modified.
+  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+    rewriteListener->notifyOperationModified(op);
+}
+
 /// Merge the operations of block 'source' into the end of block 'dest'.
 /// 'source's predecessors must be empty or only contain 'dest`.
 /// 'argValues' is used to replace the block arguments of 'source' after
index 0d78362..52ebd54 100644 (file)
@@ -1654,6 +1654,7 @@ void ConversionPatternRewriter::startRootUpdate(Operation *op) {
 }
 
 void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
+  PatternRewriter::finalizeRootUpdate(op);
   // There is nothing to do here, we only need to track the operation at the
   // start of the update.
 #ifndef NDEBUG
index 89088d5..15495ef 100644 (file)
@@ -54,7 +54,7 @@ protected:
 
   /// Notify the driver that the specified operation may have been modified
   /// in-place. The operation is added to the worklist.
-  void finalizeRootUpdate(Operation *op) override;
+  void notifyOperationModified(Operation *op) override;
 
   /// Notify the driver that the specified operation was inserted. Update the
   /// worklist as needed: The operation is enqueued depending on scope and
@@ -335,7 +335,7 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
   addToWorklist(op);
 }
 
-void GreedyPatternRewriteDriver::finalizeRootUpdate(Operation *op) {
+void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
   LLVM_DEBUG({
     logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
                        << ")\n";