[mlir][mem2reg] Add mem2reg rewrite pattern.
authorThéo Degioanni <theo.degioanni@nextsilicon.com>
Tue, 9 May 2023 14:01:31 +0000 (14:01 +0000)
committerTobias Gysi <tobias.gysi@nextsilicon.com>
Tue, 9 May 2023 14:01:45 +0000 (14:01 +0000)
This revision introduces the ability to invoke mem2reg as a rewrite pattern. This also modified the canonical mem2reg pass to use the rewrite pattern approach.

Depends on D149825

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D149958

mlir/include/mlir/Transforms/Mem2Reg.h
mlir/lib/Transforms/Mem2Reg.cpp

index 1593b12..a34ea68 100644 (file)
@@ -11,6 +11,7 @@
 
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 
 namespace mlir {
@@ -117,6 +118,19 @@ private:
   MemorySlotPromotionInfo info;
 };
 
+/// Pattern applying mem2reg to the regions of the operations on which it
+/// matches.
+class Mem2RegPattern : public RewritePattern {
+public:
+  using RewritePattern::RewritePattern;
+
+  Mem2RegPattern(MLIRContext *ctx, PatternBenefit benefit = 1)
+      : RewritePattern(MatchAnyOpTypeTag(), benefit, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+};
+
 /// Attempts to promote the memory slots of the provided allocators. Succeeds if
 /// at least one memory slot was promoted.
 LogicalResult
index 633813f..a4bf97c 100644 (file)
@@ -12,6 +12,7 @@
 #include "mlir/IR/Dominance.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
@@ -22,6 +23,8 @@ namespace mlir {
 #include "mlir/Transforms/Passes.h.inc"
 } // namespace mlir
 
+#define DEBUG_TYPE "mem2reg"
+
 using namespace mlir;
 
 /// mem2reg
@@ -422,6 +425,9 @@ void MemorySlotPromoter::promoteSlot() {
     }
   }
 
+  LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
+                          << "\n");
+
   allocator.handlePromotionComplete(slot, defaultValue);
 }
 
@@ -450,39 +456,49 @@ LogicalResult mlir::tryToPromoteMemorySlots(
   return success(!toPromote.empty());
 }
 
-namespace {
+LogicalResult Mem2RegPattern::matchAndRewrite(Operation *op,
+                                              PatternRewriter &rewriter) const {
+  hasBoundedRewriteRecursion();
 
-struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
-  void runOnOperation() override {
-    Operation *scopeOp = getOperation();
-    bool changed = false;
+  if (op->getNumRegions() == 0)
+    return failure();
 
-    for (Region &region : scopeOp->getRegions()) {
-      if (region.getBlocks().empty())
-        continue;
+  DominanceInfo dominance;
 
-      OpBuilder builder(&region.front(), region.front().begin());
+  SmallVector<PromotableAllocationOpInterface> allocators;
+  // Build a list of allocators to attempt to promote the slots of.
+  for (Region &region : op->getRegions())
+    for (auto allocator : region.getOps<PromotableAllocationOpInterface>())
+      allocators.emplace_back(allocator);
 
-      // Promoting a slot can allow for further promotion of other slots,
-      // promotion is tried until no promotion succeeds.
-      while (true) {
-        DominanceInfo &dominance = getAnalysis<DominanceInfo>();
+  // Because pattern rewriters are normally not expressive enough to support a
+  // transformation like mem2reg, this uses an escape hatch to mark modified
+  // operations manually and operate outside of its context.
+  rewriter.startRootUpdate(op);
 
-        SmallVector<PromotableAllocationOpInterface> allocators;
-        // Build a list of allocators to attempt to promote the slots of.
-        for (Block &block : region)
-          for (Operation &op : block.getOperations())
-            if (auto allocator = dyn_cast<PromotableAllocationOpInterface>(op))
-              allocators.emplace_back(allocator);
+  OpBuilder builder(rewriter.getContext());
 
-        // Attempt promoting until no promotion succeeds.
-        if (failed(tryToPromoteMemorySlots(allocators, builder, dominance)))
-          break;
+  if (failed(tryToPromoteMemorySlots(allocators, builder, dominance))) {
+    rewriter.cancelRootUpdate(op);
+    return failure();
+  }
 
-        changed = true;
-        getAnalysisManager().invalidate({});
-      }
-    }
+  rewriter.finalizeRootUpdate(op);
+  return success();
+}
+
+namespace {
+
+struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
+  void runOnOperation() override {
+    Operation *scopeOp = getOperation();
+    bool changed = false;
+
+    RewritePatternSet rewritePatterns(&getContext());
+    rewritePatterns.add<Mem2RegPattern>(&getContext());
+    FrozenRewritePatternSet frozen(std::move(rewritePatterns));
+    (void)applyOpPatternsAndFold({scopeOp}, frozen, GreedyRewriteConfig(),
+                                 &changed);
 
     if (!changed)
       markAllAnalysesPreserved();