#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
namespace mlir {
MemorySlotPromotionInfo info;
};
+/// Pattern applying mem2reg to the regions of the operations on which it
+/// matches.
+class Mem2RegPattern : public RewritePattern {
+public:
+ using RewritePattern::RewritePattern;
+
+ Mem2RegPattern(MLIRContext *ctx, PatternBenefit benefit = 1)
+ : RewritePattern(MatchAnyOpTypeTag(), benefit, ctx) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
+};
+
/// Attempts to promote the memory slots of the provided allocators. Succeeds if
/// at least one memory slot was promoted.
LogicalResult
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
+#define DEBUG_TYPE "mem2reg"
+
using namespace mlir;
/// mem2reg
}
}
+ LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
+ << "\n");
+
allocator.handlePromotionComplete(slot, defaultValue);
}
return success(!toPromote.empty());
}
-namespace {
+LogicalResult Mem2RegPattern::matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ hasBoundedRewriteRecursion();
-struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
- void runOnOperation() override {
- Operation *scopeOp = getOperation();
- bool changed = false;
+ if (op->getNumRegions() == 0)
+ return failure();
- for (Region ®ion : scopeOp->getRegions()) {
- if (region.getBlocks().empty())
- continue;
+ DominanceInfo dominance;
- OpBuilder builder(®ion.front(), region.front().begin());
+ SmallVector<PromotableAllocationOpInterface> allocators;
+ // Build a list of allocators to attempt to promote the slots of.
+ for (Region ®ion : op->getRegions())
+ for (auto allocator : region.getOps<PromotableAllocationOpInterface>())
+ allocators.emplace_back(allocator);
- // Promoting a slot can allow for further promotion of other slots,
- // promotion is tried until no promotion succeeds.
- while (true) {
- DominanceInfo &dominance = getAnalysis<DominanceInfo>();
+ // Because pattern rewriters are normally not expressive enough to support a
+ // transformation like mem2reg, this uses an escape hatch to mark modified
+ // operations manually and operate outside of its context.
+ rewriter.startRootUpdate(op);
- SmallVector<PromotableAllocationOpInterface> allocators;
- // Build a list of allocators to attempt to promote the slots of.
- for (Block &block : region)
- for (Operation &op : block.getOperations())
- if (auto allocator = dyn_cast<PromotableAllocationOpInterface>(op))
- allocators.emplace_back(allocator);
+ OpBuilder builder(rewriter.getContext());
- // Attempt promoting until no promotion succeeds.
- if (failed(tryToPromoteMemorySlots(allocators, builder, dominance)))
- break;
+ if (failed(tryToPromoteMemorySlots(allocators, builder, dominance))) {
+ rewriter.cancelRootUpdate(op);
+ return failure();
+ }
- changed = true;
- getAnalysisManager().invalidate({});
- }
- }
+ rewriter.finalizeRootUpdate(op);
+ return success();
+}
+
+namespace {
+
+struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
+ void runOnOperation() override {
+ Operation *scopeOp = getOperation();
+ bool changed = false;
+
+ RewritePatternSet rewritePatterns(&getContext());
+ rewritePatterns.add<Mem2RegPattern>(&getContext());
+ FrozenRewritePatternSet frozen(std::move(rewritePatterns));
+ (void)applyOpPatternsAndFold({scopeOp}, frozen, GreedyRewriteConfig(),
+ &changed);
if (!changed)
markAllAnalysesPreserved();