Add tracing for pattern application in a ApplyPatternAction
authorMehdi Amini <joker.eph@gmail.com>
Tue, 11 Apr 2023 00:39:34 +0000 (18:39 -0600)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 11 Apr 2023 00:42:45 +0000 (18:42 -0600)
Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D144816

mlir/include/mlir/Rewrite/PatternApplicator.h
mlir/lib/Rewrite/PatternApplicator.cpp

index a2e2286..41ec95a 100644 (file)
@@ -16,6 +16,8 @@
 
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 
+#include "mlir/IR/Action.h"
+
 namespace mlir {
 class PatternRewriter;
 
@@ -23,6 +25,26 @@ namespace detail {
 class PDLByteCodeMutableState;
 } // namespace detail
 
+/// This is the type of Action that is dispatched when a pattern is applied.
+/// It captures the pattern to apply on top of the usual context.
+class ApplyPatternAction : public tracing::ActionImpl<ApplyPatternAction> {
+public:
+  using Base = tracing::ActionImpl<ApplyPatternAction>;
+  ApplyPatternAction(ArrayRef<IRUnit> irUnits, const Pattern &pattern)
+      : Base(irUnits), pattern(pattern) {}
+  static constexpr StringLiteral tag = "apply-pattern-action";
+  static constexpr StringLiteral desc =
+      "Encapsulate the application of rewrite patterns";
+
+  void print(raw_ostream &os) const override {
+    os << "`" << tag << "`\n"
+       << " pattern: " << pattern.getDebugName() << '\n';
+  }
+
+private:
+  const Pattern &pattern;
+};
+
 /// This class manages the application of a group of rewrite patterns, with a
 /// user-provided cost model.
 class PatternApplicator {
index 499a850..08d6ee6 100644 (file)
@@ -185,35 +185,47 @@ LogicalResult PatternApplicator::matchAndRewrite(
     // Try to match and rewrite this pattern. The patterns are sorted by
     // benefit, so if we match we can immediately rewrite. For PDL patterns, the
     // match has already been performed, we just need to rewrite.
-    rewriter.setInsertionPoint(op);
+    bool matched = false;
+    op->getContext()->executeAction<ApplyPatternAction>(
+        [&]() {
+          rewriter.setInsertionPoint(op);
 #ifndef NDEBUG
-    // Operation `op` may be invalidated after applying the rewrite pattern.
-    Operation *dumpRootOp = getDumpRootOp(op);
+          // Operation `op` may be invalidated after applying the rewrite
+          // pattern.
+          Operation *dumpRootOp = getDumpRootOp(op);
 #endif
-    if (pdlMatch) {
-      result = bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
-    } else {
-      LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
-                              << bestPattern->getDebugName() << "\"\n");
-
-      const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
-      result = pattern->matchAndRewrite(op, rewriter);
-
-      LLVM_DEBUG(llvm::dbgs() << "\"" << bestPattern->getDebugName()
-                              << "\" result " << succeeded(result) << "\n");
-    }
-
-    // Process the result of the pattern application.
-    if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
-      result = failure();
-    if (succeeded(result)) {
-      LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
+          if (pdlMatch) {
+            result =
+                bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
+          } else {
+            LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
+                                    << bestPattern->getDebugName() << "\"\n");
+
+            const auto *pattern =
+                static_cast<const RewritePattern *>(bestPattern);
+            result = pattern->matchAndRewrite(op, rewriter);
+
+            LLVM_DEBUG(llvm::dbgs()
+                       << "\"" << bestPattern->getDebugName() << "\" result "
+                       << succeeded(result) << "\n");
+          }
+
+          // Process the result of the pattern application.
+          if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
+            result = failure();
+          if (succeeded(result)) {
+            LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
+            matched = true;
+            return;
+          }
+
+          // Perform any necessary cleanups.
+          if (onFailure)
+            onFailure(*bestPattern);
+        },
+        {op}, *bestPattern);
+    if (matched)
       break;
-    }
-
-    // Perform any necessary cleanups.
-    if (onFailure)
-      onFailure(*bestPattern);
   } while (true);
 
   if (mutableByteCodeState)