#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/IR/Action.h"
+
namespace mlir {
class PatternRewriter;
class PDLByteCodeMutableState;
} // namespace detail
+/// This is the type of Action that is dispatched when a pattern is applied.
+/// It captures the pattern to apply on top of the usual context.
+class ApplyPatternAction : public tracing::ActionImpl<ApplyPatternAction> {
+public:
+ using Base = tracing::ActionImpl<ApplyPatternAction>;
+ ApplyPatternAction(ArrayRef<IRUnit> irUnits, const Pattern &pattern)
+ : Base(irUnits), pattern(pattern) {}
+ static constexpr StringLiteral tag = "apply-pattern-action";
+ static constexpr StringLiteral desc =
+ "Encapsulate the application of rewrite patterns";
+
+ void print(raw_ostream &os) const override {
+ os << "`" << tag << "`\n"
+ << " pattern: " << pattern.getDebugName() << '\n';
+ }
+
+private:
+ const Pattern &pattern;
+};
+
/// This class manages the application of a group of rewrite patterns, with a
/// user-provided cost model.
class PatternApplicator {
// Try to match and rewrite this pattern. The patterns are sorted by
// benefit, so if we match we can immediately rewrite. For PDL patterns, the
// match has already been performed, we just need to rewrite.
- rewriter.setInsertionPoint(op);
+ bool matched = false;
+ op->getContext()->executeAction<ApplyPatternAction>(
+ [&]() {
+ rewriter.setInsertionPoint(op);
#ifndef NDEBUG
- // Operation `op` may be invalidated after applying the rewrite pattern.
- Operation *dumpRootOp = getDumpRootOp(op);
+ // Operation `op` may be invalidated after applying the rewrite
+ // pattern.
+ Operation *dumpRootOp = getDumpRootOp(op);
#endif
- if (pdlMatch) {
- result = bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
- } else {
- LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
- << bestPattern->getDebugName() << "\"\n");
-
- const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
- result = pattern->matchAndRewrite(op, rewriter);
-
- LLVM_DEBUG(llvm::dbgs() << "\"" << bestPattern->getDebugName()
- << "\" result " << succeeded(result) << "\n");
- }
-
- // Process the result of the pattern application.
- if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
- result = failure();
- if (succeeded(result)) {
- LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
+ if (pdlMatch) {
+ result =
+ bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
+ } else {
+ LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
+ << bestPattern->getDebugName() << "\"\n");
+
+ const auto *pattern =
+ static_cast<const RewritePattern *>(bestPattern);
+ result = pattern->matchAndRewrite(op, rewriter);
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "\"" << bestPattern->getDebugName() << "\" result "
+ << succeeded(result) << "\n");
+ }
+
+ // Process the result of the pattern application.
+ if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
+ result = failure();
+ if (succeeded(result)) {
+ LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
+ matched = true;
+ return;
+ }
+
+ // Perform any necessary cleanups.
+ if (onFailure)
+ onFailure(*bestPattern);
+ },
+ {op}, *bestPattern);
+ if (matched)
break;
- }
-
- // Perform any necessary cleanups.
- if (onFailure)
- onFailure(*bestPattern);
} while (true);
if (mutableByteCodeState)