[mlir] [mem2reg] Adapt to be pattern-friendly.
authorThéo Degioanni <theo.degioanni@nextsilicon.com>
Tue, 16 May 2023 08:35:00 +0000 (08:35 +0000)
committerTobias Gysi <tobias.gysi@nextsilicon.com>
Tue, 16 May 2023 08:35:13 +0000 (08:35 +0000)
This revision modifies the mem2reg interfaces and algorithm to be more
omfortable to use as a pattern. The motivation behind this is that
currently the pattern needs to be applied to the scope op of the region
in which allocators should be promoted. However, a more natural way to
apply the pattern would be to apply it on the allocator directly. This
is not only clearer but easier to parallelize.

This revision changes the mem2reg pattern to operate this way. This
required restraining the interfaces to only mutate IR using
RewriterBase, as the previously used escape hatch is not granular enough
to match on the region that is modified only. This has the unfortunate
cost of preventing batching allocator promotion and making the block
argument adding logic more complex. Because batching no longer made any
sense, I made the internal analyzer/promoter decoupling private again.

This also adds statistics to the mem2reg infrastructure.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D150432

mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
mlir/include/mlir/Transforms/Mem2Reg.h
mlir/include/mlir/Transforms/Passes.td
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
mlir/lib/Transforms/Mem2Reg.cpp
mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir
mlir/test/Dialect/LLVMIR/mem2reg.mlir
mlir/test/Dialect/MemRef/mem2reg-statistics.mlir [new file with mode: 0644]
mlir/test/Dialect/MemRef/mem2reg.mlir

index be761ae..c0f8b2f 100644 (file)
@@ -11,6 +11,7 @@
 
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
 
index f98bdba..73061f7 100644 (file)
@@ -31,6 +31,8 @@ def PromotableAllocationOpInterface
 
         Promotion of the slot will lead to the slot pointer no longer being
         used, leaving the content of the memory slot unreachable.
+
+        No IR mutation is allowed in this method.
       }], "::llvm::SmallVector<::mlir::MemorySlot>", "getPromotableSlots",
       (ins)
     >,
@@ -38,34 +40,42 @@ def PromotableAllocationOpInterface
         Provides the default Value of this memory slot. The provided Value
         will be used as the reaching definition of loads done before any store.
         This Value must outlive the promotion and dominate all the uses of this
-        slot's pointer. The provided builder can be used to create the default
+        slot's pointer. The provided rewriter can be used to create the default
         value on the fly.
 
-        The builder is located at the beginning of the block where the slot
-        pointer is defined.
+        The rewriter is located at the beginning of the block where the slot
+        pointer is defined. All IR mutations must happen through the rewriter.
       }], "::mlir::Value", "getDefaultValue",
-      (ins "const ::mlir::MemorySlot &":$slot, "::mlir::OpBuilder &":$builder)
+      (ins
+        "const ::mlir::MemorySlot &":$slot,
+        "::mlir::RewriterBase &":$rewriter)
     >,
     InterfaceMethod<[{
         Hook triggered for every new block argument added to a block.
         This will only be called for slots declared by this operation.
 
-        The builder is located at the beginning of the block on call.
+        The rewriter is located at the beginning of the block on call. All IR
+        mutations must happen through the rewriter.
       }],
       "void", "handleBlockArgument",
       (ins
         "const ::mlir::MemorySlot &":$slot,
         "::mlir::BlockArgument":$argument,
-        "::mlir::OpBuilder &":$builder
+        "::mlir::RewriterBase &":$rewriter
       )
     >,
     InterfaceMethod<[{
         Hook triggered once the promotion of a slot is complete. This can
         also clean up the created default value if necessary.
         This will only be called for slots declared by this operation.
+
+        All IR mutations must happen through the rewriter.
       }],
       "void", "handlePromotionComplete",
-      (ins "const ::mlir::MemorySlot &":$slot, "::mlir::Value":$defaultValue)
+      (ins
+        "const ::mlir::MemorySlot &":$slot, 
+        "::mlir::Value":$defaultValue,
+        "::mlir::RewriterBase &":$rewriter)
     >,
   ];
 }
@@ -87,6 +97,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
   let methods = [
     InterfaceMethod<[{
         Gets whether this operation loads from the specified slot.
+
+        No IR mutation is allowed in this method.
       }],
       "bool", "loadsFrom",
       (ins "const ::mlir::MemorySlot &":$slot)
@@ -96,6 +108,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
         value if this operation does not store to this slot. An operation
         storing a value to a slot must always be able to provide the value it
         stores. This method is only called on operations that use the slot.
+
+        No IR mutation is allowed in this method.
       }],
       "::mlir::Value", "getStored",
       (ins "const ::mlir::MemorySlot &":$slot)
@@ -107,6 +121,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
         If the removal procedure of the use will require that other uses get
         removed, that dependency should be added to the `newBlockingUses`
         argument. Dependent uses must only be uses of results of this operation.
+
+        No IR mutation is allowed in this method.
       }], "bool", "canUsesBeRemoved",
       (ins "const ::mlir::MemorySlot &":$slot,
            "const ::llvm::SmallPtrSetImpl<::mlir::OpOperand *> &":$blockingUses,
@@ -132,13 +148,14 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
         have been done at the point of calling this method, but it will be done
         eventually.
 
-        The builder is located after the promotable operation on call.
+        The rewriter is located after the promotable operation on call. All IR
+        mutations must happen through the rewriter.
       }],
       "::mlir::DeletionKind",
       "removeBlockingUses",
       (ins "const ::mlir::MemorySlot &":$slot,
            "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
-           "::mlir::OpBuilder &":$builder,
+           "::mlir::RewriterBase &":$rewriter,
            "::mlir::Value":$reachingDefinition)
     >,
   ];
@@ -160,6 +177,8 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
         If the removal procedure of the use will require that other uses get
         removed, that dependency should be added to the `newBlockingUses`
         argument. Dependent uses must only be uses of results of this operation.
+
+        No IR mutation is allowed in this method.
       }], "bool", "canUsesBeRemoved",
       (ins "const ::llvm::SmallPtrSetImpl<::mlir::OpOperand *> &":$blockingUses,
            "::llvm::SmallVectorImpl<::mlir::OpOperand *> &":$newBlockingUses)
@@ -185,12 +204,13 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
         have been done at the point of calling this method, but it will be done
         eventually.
 
-        The builder is located after the promotable operation on call.
+        The rewriter is located after the promotable operation on call. All IR
+        mutations must happen through the rewriter.
       }],
       "::mlir::DeletionKind",
       "removeBlockingUses",
       (ins "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
-           "::mlir::OpBuilder &":$builder)
+           "::mlir::RewriterBase &":$rewriter)
     >,
   ];
 }
index a34ea68..46b2a1f 100644 (file)
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "llvm/ADT/Statistic.h"
 
 namespace mlir {
 
-/// Information computed during promotion analysis used to perform actual
-/// promotion.
-struct MemorySlotPromotionInfo {
-  /// Blocks for which at least two definitions of the slot values clash.
-  SmallPtrSet<Block *, 8> mergePoints;
-  /// Contains, for each operation, which uses must be eliminated by promotion.
-  /// This is a DAG structure because if an operation must eliminate some of
-  /// its uses, it is because the defining ops of the blocking uses requested
-  /// it. The defining ops therefore must also have blocking uses or be the
-  /// starting point of the bloccking uses.
-  DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
-};
-
-/// Computes information for basic slot promotion. This will check that direct
-/// slot promotion can be performed, and provide the information to execute the
-/// promotion. This does not mutate IR.
-class MemorySlotPromotionAnalyzer {
-public:
-  MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance)
-      : slot(slot), dominance(dominance) {}
-
-  /// Computes the information for slot promotion if promotion is possible,
-  /// returns nothing otherwise.
-  std::optional<MemorySlotPromotionInfo> computeInfo();
-
-private:
-  /// Computes the transitive uses of the slot that block promotion. This finds
-  /// uses that would block the promotion, checks that the operation has a
-  /// solution to remove the blocking use, and potentially forwards the analysis
-  /// if the operation needs further blocking uses resolved to resolve its own
-  /// uses (typically, removing its users because it will delete itself to
-  /// resolve its own blocking uses). This will fail if one of the transitive
-  /// users cannot remove a requested use, and should prevent promotion.
-  LogicalResult computeBlockingUses(
-      DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses);
-
-  /// Computes in which blocks the value stored in the slot is actually used,
-  /// meaning blocks leading to a load. This method uses `definingBlocks`, the
-  /// set of blocks containing a store to the slot (defining the value of the
-  /// slot).
-  SmallPtrSet<Block *, 16>
-  computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);
-
-  /// Computes the points in which multiple re-definitions of the slot's value
-  /// (stores) may conflict.
-  void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints);
-
-  /// Ensures predecessors of merge points can properly provide their current
-  /// definition of the value stored in the slot to the merge point. This can
-  /// notably be an issue if the terminator used does not have the ability to
-  /// forward values through block operands.
-  bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints);
-
-  MemorySlot slot;
-  DominanceInfo &dominance;
-};
-
-/// The MemorySlotPromoter handles the state of promoting a memory slot. It
-/// wraps a slot and its associated allocator. This will perform the mutation of
-/// IR.
-class MemorySlotPromoter {
-public:
-  MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
-                     OpBuilder &builder, DominanceInfo &dominance,
-                     MemorySlotPromotionInfo info);
-
-  /// Actually promotes the slot by mutating IR. Promoting a slot does not
-  /// invalidate the MemorySlotPromotionInfo of other slots.
-  void promoteSlot();
-
-private:
-  /// Computes the reaching definition for all the operations that require
-  /// promotion. `reachingDef` is the value the slot should contain at the
-  /// beginning of the block. This method returns the reached definition at the
-  /// end of the block.
-  Value computeReachingDefInBlock(Block *block, Value reachingDef);
-
-  /// Computes the reaching definition for all the operations that require
-  /// promotion. `reachingDef` corresponds to the initial value the
-  /// slot will contain before any write, typically a poison value.
-  void computeReachingDefInRegion(Region *region, Value reachingDef);
-
-  /// Removes the blocking uses of the slot, in topological order.
-  void removeBlockingUses();
-
-  /// Lazily-constructed default value representing the content of the slot when
-  /// no store has been executed. This function may mutate IR.
-  Value getLazyDefaultValue();
-
-  MemorySlot slot;
-  PromotableAllocationOpInterface allocator;
-  OpBuilder &builder;
-  /// Potentially non-initialized default value. Use `getLazyDefaultValue` to
-  /// initialize it on demand.
-  Value defaultValue;
-  /// Contains the reaching definition at this operation. Reaching definitions
-  /// are only computed for promotable memory operations with blocking uses.
-  DenseMap<PromotableMemOpInterface, Value> reachingDefs;
-  DominanceInfo &dominance;
-  MemorySlotPromotionInfo info;
+struct Mem2RegStatistics {
+  llvm::Statistic *promotedAmount = nullptr;
+  llvm::Statistic *newBlockArgumentAmount = nullptr;
 };
 
 /// Pattern applying mem2reg to the regions of the operations on which it
 /// matches.
-class Mem2RegPattern : public RewritePattern {
+class Mem2RegPattern
+    : public OpInterfaceRewritePattern<PromotableAllocationOpInterface> {
 public:
-  using RewritePattern::RewritePattern;
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
 
-  Mem2RegPattern(MLIRContext *ctx, PatternBenefit benefit = 1)
-      : RewritePattern(MatchAnyOpTypeTag(), benefit, ctx) {}
+  Mem2RegPattern(MLIRContext *context, Mem2RegStatistics statistics = {},
+                 PatternBenefit benefit = 1)
+      : OpInterfaceRewritePattern(context, benefit), statistics(statistics) {}
 
-  LogicalResult matchAndRewrite(Operation *op,
+  LogicalResult matchAndRewrite(PromotableAllocationOpInterface allocator,
                                 PatternRewriter &rewriter) const override;
+
+private:
+  Mem2RegStatistics statistics;
 };
 
 /// Attempts to promote the memory slots of the provided allocators. Succeeds if
 /// at least one memory slot was promoted.
 LogicalResult
 tryToPromoteMemorySlots(ArrayRef<PromotableAllocationOpInterface> allocators,
-                        OpBuilder &builder, DominanceInfo &dominance);
+                        RewriterBase &rewriter,
+                        Mem2RegStatistics statistics = {});
 
 } // namespace mlir
 
index 1cc357c..62b8dd0 100644 (file)
@@ -189,6 +189,21 @@ def Mem2Reg : Pass<"mem2reg"> {
     This pass only supports unstructured control-flow. Promotion of operations
     within subregions will not happen.
   }];
+
+  let options = [
+    Option<"enableRegionSimplification", "region-simplify", "bool",
+       /*default=*/"true",
+       "Perform control flow optimizations to the region tree">,
+  ];
+
+  let statistics = [
+    Statistic<"promotedAmount",
+              "promoted slots",
+              "Number of promoted memory slot">,
+    Statistic<"newBlockArgumentAmount",
+              "new block args",
+              "Total number of block arguments added">,
+  ];
 }
 
 def PrintOpStats : Pass<"print-op-stats"> {
index e4fd2a7..51c4989 100644 (file)
@@ -35,24 +35,25 @@ llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() {
 }
 
 Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot,
-                                      OpBuilder &builder) {
-  return builder.create<LLVM::UndefOp>(getLoc(), slot.elemType);
+                                      RewriterBase &rewriter) {
+  return rewriter.create<LLVM::UndefOp>(getLoc(), slot.elemType);
 }
 
 void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
                                          BlockArgument argument,
-                                         OpBuilder &builder) {
+                                         RewriterBase &rewriter) {
   for (Operation *user : getOperation()->getUsers())
     if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
-      builder.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
-                                       declareOp.getVarInfo());
+      rewriter.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
+                                        declareOp.getVarInfo());
 }
 
 void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
-                                             Value defaultValue) {
+                                             Value defaultValue,
+                                             RewriterBase &rewriter) {
   if (defaultValue && defaultValue.use_empty())
-    defaultValue.getDefiningOp()->erase();
-  erase();
+    rewriter.eraseOp(defaultValue.getDefiningOp());
+  rewriter.eraseOp(*this);
 }
 
 //===----------------------------------------------------------------------===//
@@ -87,10 +88,10 @@ bool LLVM::LoadOp::canUsesBeRemoved(
 
 DeletionKind LLVM::LoadOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    OpBuilder &builder, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
-  getResult().replaceAllUsesWith(reachingDefinition);
+  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
   return DeletionKind::Delete;
 }
 
@@ -110,13 +111,13 @@ bool LLVM::StoreOp::canUsesBeRemoved(
 
 DeletionKind LLVM::StoreOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    OpBuilder &builder, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition) {
   // `canUsesBeRemoved` checked this blocking use must be the stored slot
   // pointer.
   for (Operation *user : slot.ptr.getUsers())
     if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
-      builder.create<LLVM::DbgValueOp>(declareOp->getLoc(), getValue(),
-                                       declareOp.getVarInfo());
+      rewriter.create<LLVM::DbgValueOp>(declareOp->getLoc(), getValue(),
+                                        declareOp.getVarInfo());
   return DeletionKind::Delete;
 }
 
@@ -140,7 +141,7 @@ bool LLVM::BitcastOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::BitcastOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
   return DeletionKind::Delete;
 }
 
@@ -151,7 +152,7 @@ bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
   return DeletionKind::Delete;
 }
 
@@ -162,7 +163,7 @@ bool LLVM::LifetimeStartOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::LifetimeStartOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
   return DeletionKind::Delete;
 }
 
@@ -173,7 +174,7 @@ bool LLVM::LifetimeEndOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::LifetimeEndOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
   return DeletionKind::Delete;
 }
 
@@ -184,7 +185,7 @@ bool LLVM::DbgDeclareOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
   return DeletionKind::Delete;
 }
 
@@ -209,6 +210,6 @@ bool LLVM::GEPOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::GEPOp::removeBlockingUses(
-    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
   return DeletionKind::Delete;
 }
index b5f5272..12d9ebd 100644 (file)
@@ -40,29 +40,30 @@ SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
 }
 
 Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
-                                        OpBuilder &builder) {
+                                        RewriterBase &rewriter) {
   assert(isSupportedElementType(slot.elemType));
   // TODO: support more types.
   return TypeSwitch<Type, Value>(slot.elemType)
       .Case([&](MemRefType t) {
-        return builder.create<memref::AllocaOp>(getLoc(), t);
+        return rewriter.create<memref::AllocaOp>(getLoc(), t);
       })
       .Default([&](Type t) {
-        return builder.create<arith::ConstantOp>(getLoc(), t,
-                                                 builder.getZeroAttr(t));
+        return rewriter.create<arith::ConstantOp>(getLoc(), t,
+                                                  rewriter.getZeroAttr(t));
       });
 }
 
 void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
-                                               Value defaultValue) {
+                                               Value defaultValue,
+                                               RewriterBase &rewriter) {
   if (defaultValue.use_empty())
-    defaultValue.getDefiningOp()->erase();
-  erase();
+    rewriter.eraseOp(defaultValue.getDefiningOp());
+  rewriter.eraseOp(*this);
 }
 
 void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
                                            BlockArgument argument,
-                                           OpBuilder &builder) {}
+                                           RewriterBase &rewriter) {}
 
 //===----------------------------------------------------------------------===//
 //  LoadOp/StoreOp interfaces
@@ -86,10 +87,10 @@ bool memref::LoadOp::canUsesBeRemoved(
 
 DeletionKind memref::LoadOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    OpBuilder &builder, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
-  getResult().replaceAllUsesWith(reachingDefinition);
+  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
   return DeletionKind::Delete;
 }
 
@@ -113,6 +114,6 @@ bool memref::StoreOp::canUsesBeRemoved(
 
 DeletionKind memref::StoreOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    OpBuilder &builder, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition) {
   return DeletionKind::Delete;
 }
index 45d6f7d..3b303f9 100644 (file)
@@ -10,6 +10,8 @@
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -92,11 +94,121 @@ using namespace mlir;
 /// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022),
 ///      Springer.
 
+namespace {
+
+/// Information computed during promotion analysis used to perform actual
+/// promotion.
+struct MemorySlotPromotionInfo {
+  /// Blocks for which at least two definitions of the slot values clash.
+  SmallPtrSet<Block *, 8> mergePoints;
+  /// Contains, for each operation, which uses must be eliminated by promotion.
+  /// This is a DAG structure because if an operation must eliminate some of
+  /// its uses, it is because the defining ops of the blocking uses requested
+  /// it. The defining ops therefore must also have blocking uses or be the
+  /// starting point of the bloccking uses.
+  DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
+};
+
+/// Computes information for basic slot promotion. This will check that direct
+/// slot promotion can be performed, and provide the information to execute the
+/// promotion. This does not mutate IR.
+class MemorySlotPromotionAnalyzer {
+public:
+  MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance)
+      : slot(slot), dominance(dominance) {}
+
+  /// Computes the information for slot promotion if promotion is possible,
+  /// returns nothing otherwise.
+  std::optional<MemorySlotPromotionInfo> computeInfo();
+
+private:
+  /// Computes the transitive uses of the slot that block promotion. This finds
+  /// uses that would block the promotion, checks that the operation has a
+  /// solution to remove the blocking use, and potentially forwards the analysis
+  /// if the operation needs further blocking uses resolved to resolve its own
+  /// uses (typically, removing its users because it will delete itself to
+  /// resolve its own blocking uses). This will fail if one of the transitive
+  /// users cannot remove a requested use, and should prevent promotion.
+  LogicalResult computeBlockingUses(
+      DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses);
+
+  /// Computes in which blocks the value stored in the slot is actually used,
+  /// meaning blocks leading to a load. This method uses `definingBlocks`, the
+  /// set of blocks containing a store to the slot (defining the value of the
+  /// slot).
+  SmallPtrSet<Block *, 16>
+  computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);
+
+  /// Computes the points in which multiple re-definitions of the slot's value
+  /// (stores) may conflict.
+  void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints);
+
+  /// Ensures predecessors of merge points can properly provide their current
+  /// definition of the value stored in the slot to the merge point. This can
+  /// notably be an issue if the terminator used does not have the ability to
+  /// forward values through block operands.
+  bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints);
+
+  MemorySlot slot;
+  DominanceInfo &dominance;
+};
+
+/// The MemorySlotPromoter handles the state of promoting a memory slot. It
+/// wraps a slot and its associated allocator. This will perform the mutation of
+/// IR.
+class MemorySlotPromoter {
+public:
+  MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
+                     RewriterBase &rewriter, DominanceInfo &dominance,
+                     MemorySlotPromotionInfo info,
+                     const Mem2RegStatistics &statistics);
+
+  /// Actually promotes the slot by mutating IR. Promoting a slot DOES
+  /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
+  /// promotion info should NOT be performed in batches.
+  void promoteSlot();
+
+private:
+  /// Computes the reaching definition for all the operations that require
+  /// promotion. `reachingDef` is the value the slot should contain at the
+  /// beginning of the block. This method returns the reached definition at the
+  /// end of the block.
+  Value computeReachingDefInBlock(Block *block, Value reachingDef);
+
+  /// Computes the reaching definition for all the operations that require
+  /// promotion. `reachingDef` corresponds to the initial value the
+  /// slot will contain before any write, typically a poison value.
+  void computeReachingDefInRegion(Region *region, Value reachingDef);
+
+  /// Removes the blocking uses of the slot, in topological order.
+  void removeBlockingUses();
+
+  /// Lazily-constructed default value representing the content of the slot when
+  /// no store has been executed. This function may mutate IR.
+  Value getLazyDefaultValue();
+
+  MemorySlot slot;
+  PromotableAllocationOpInterface allocator;
+  RewriterBase &rewriter;
+  /// Potentially non-initialized default value. Use `getLazyDefaultValue` to
+  /// initialize it on demand.
+  Value defaultValue;
+  /// Contains the reaching definition at this operation. Reaching definitions
+  /// are only computed for promotable memory operations with blocking uses.
+  DenseMap<PromotableMemOpInterface, Value> reachingDefs;
+  DominanceInfo &dominance;
+  MemorySlotPromotionInfo info;
+  const Mem2RegStatistics &statistics;
+};
+
+} // namespace
+
 MemorySlotPromoter::MemorySlotPromoter(
     MemorySlot slot, PromotableAllocationOpInterface allocator,
-    OpBuilder &builder, DominanceInfo &dominance, MemorySlotPromotionInfo info)
-    : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
-      info(std::move(info)) {
+    RewriterBase &rewriter, DominanceInfo &dominance,
+    MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
+    : slot(slot), allocator(allocator), rewriter(rewriter),
+      dominance(dominance), info(std::move(info)), statistics(statistics) {
 #ifndef NDEBUG
   auto isResultOrNewBlockArgument = [&]() {
     if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
@@ -114,9 +226,9 @@ Value MemorySlotPromoter::getLazyDefaultValue() {
   if (defaultValue)
     return defaultValue;
 
-  OpBuilder::InsertionGuard guard(builder);
-  builder.setInsertionPointToStart(slot.ptr.getParentBlock());
-  return defaultValue = allocator.getDefaultValue(slot, builder);
+  RewriterBase::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(slot.ptr.getParentBlock());
+  return defaultValue = allocator.getDefaultValue(slot, rewriter);
 }
 
 LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
@@ -341,11 +453,37 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
     Block *block = job.block->getBlock();
 
     if (info.mergePoints.contains(block)) {
-      BlockArgument blockArgument =
-          block->addArgument(slot.elemType, slot.ptr.getLoc());
-      builder.setInsertionPointToStart(block);
-      allocator.handleBlockArgument(slot, blockArgument, builder);
+      // If the block is a merge point, we need to add a block argument to hold
+      // the selected reaching definition. This has to be a bit complicated
+      // because of RewriterBase limitations: we need to create a new block with
+      // the extra block argument, move the content of the block to the new
+      // block, and replace the block with the new block in the merge point set.
+      SmallVector<Type> argTypes;
+      SmallVector<Location> argLocs;
+      for (BlockArgument arg : block->getArguments()) {
+        argTypes.push_back(arg.getType());
+        argLocs.push_back(arg.getLoc());
+      }
+      argTypes.push_back(slot.elemType);
+      argLocs.push_back(slot.ptr.getLoc());
+      Block *newBlock = rewriter.createBlock(block, argTypes, argLocs);
+
+      info.mergePoints.erase(block);
+      info.mergePoints.insert(newBlock);
+
+      rewriter.replaceAllUsesWith(block, newBlock);
+      rewriter.mergeBlocks(block, newBlock,
+                           newBlock->getArguments().drop_back());
+
+      block = newBlock;
+
+      BlockArgument blockArgument = block->getArguments().back();
+      rewriter.setInsertionPointToStart(block);
+      allocator.handleBlockArgument(slot, blockArgument, rewriter);
       job.reachingDef = blockArgument;
+
+      if (statistics.newBlockArgumentAmount)
+        (*statistics.newBlockArgumentAmount)++;
     }
 
     job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
@@ -355,8 +493,10 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
         if (info.mergePoints.contains(blockOperand.get())) {
           if (!job.reachingDef)
             job.reachingDef = getLazyDefaultValue();
-          terminator.getSuccessorOperands(blockOperand.getOperandNumber())
-              .append(job.reachingDef);
+          rewriter.updateRootInPlace(terminator, [&]() {
+            terminator.getSuccessorOperands(blockOperand.getOperandNumber())
+                .append(job.reachingDef);
+          });
         }
       }
     }
@@ -382,24 +522,24 @@ void MemorySlotPromoter::removeBlockingUses() {
       if (!reachingDef)
         reachingDef = getLazyDefaultValue();
 
-      builder.setInsertionPointAfter(toPromote);
+      rewriter.setInsertionPointAfter(toPromote);
       if (toPromoteMemOp.removeBlockingUses(
-              slot, info.userToBlockingUses[toPromote], builder, reachingDef) ==
-          DeletionKind::Delete)
+              slot, info.userToBlockingUses[toPromote], rewriter,
+              reachingDef) == DeletionKind::Delete)
         toErase.push_back(toPromote);
 
       continue;
     }
 
     auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
-    builder.setInsertionPointAfter(toPromote);
+    rewriter.setInsertionPointAfter(toPromote);
     if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
-                                          builder) == DeletionKind::Delete)
+                                          rewriter) == DeletionKind::Delete)
       toErase.push_back(toPromote);
   }
 
   for (Operation *toEraseOp : toErase)
-    toEraseOp->erase();
+    rewriter.eraseOp(toEraseOp);
 
   assert(slot.ptr.use_empty() &&
          "after promotion, the slot pointer should not be used anymore");
@@ -421,87 +561,73 @@ void MemorySlotPromoter::promoteSlot() {
       assert(succOperands.size() == mergePoint->getNumArguments() ||
              succOperands.size() + 1 == mergePoint->getNumArguments());
       if (succOperands.size() + 1 == mergePoint->getNumArguments())
-        succOperands.append(getLazyDefaultValue());
+        rewriter.updateRootInPlace(
+            user, [&]() { succOperands.append(getLazyDefaultValue()); });
     }
   }
 
   LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
                           << "\n");
 
-  allocator.handlePromotionComplete(slot, defaultValue);
+  if (statistics.promotedAmount)
+    (*statistics.promotedAmount)++;
+
+  allocator.handlePromotionComplete(slot, defaultValue, rewriter);
 }
 
 LogicalResult mlir::tryToPromoteMemorySlots(
-    ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
-    DominanceInfo &dominance) {
-  // Actual promotion may invalidate the dominance analysis, so slot promotion
-  // is prepated in batches.
-  SmallVector<MemorySlotPromoter> toPromote;
+    ArrayRef<PromotableAllocationOpInterface> allocators,
+    RewriterBase &rewriter, Mem2RegStatistics statistics) {
+  DominanceInfo dominance;
+
+  bool promotedAny = false;
+
   for (PromotableAllocationOpInterface allocator : allocators) {
     for (MemorySlot slot : allocator.getPromotableSlots()) {
       if (slot.ptr.use_empty())
         continue;
 
+      DominanceInfo dominance;
       MemorySlotPromotionAnalyzer analyzer(slot, dominance);
       std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
-      if (info)
-        toPromote.emplace_back(slot, allocator, builder, dominance,
-                               std::move(*info));
+      if (info) {
+        MemorySlotPromoter(slot, allocator, rewriter, dominance,
+                           std::move(*info), statistics)
+            .promoteSlot();
+        promotedAny = true;
+      }
     }
   }
 
-  for (MemorySlotPromoter &promoter : toPromote)
-    promoter.promoteSlot();
-
-  return success(!toPromote.empty());
+  return success(promotedAny);
 }
 
-LogicalResult Mem2RegPattern::matchAndRewrite(Operation *op,
-                                              PatternRewriter &rewriter) const {
+LogicalResult
+Mem2RegPattern::matchAndRewrite(PromotableAllocationOpInterface allocator,
+                                PatternRewriter &rewriter) const {
   hasBoundedRewriteRecursion();
-
-  if (op->getNumRegions() == 0)
-    return failure();
-
-  DominanceInfo dominance;
-
-  SmallVector<PromotableAllocationOpInterface> allocators;
-  // Build a list of allocators to attempt to promote the slots of.
-  for (Region &region : op->getRegions())
-    for (auto allocator : region.getOps<PromotableAllocationOpInterface>())
-      allocators.emplace_back(allocator);
-
-  // Because pattern rewriters are normally not expressive enough to support a
-  // transformation like mem2reg, this uses an escape hatch to mark modified
-  // operations manually and operate outside of its context.
-  rewriter.startRootUpdate(op);
-
-  OpBuilder builder(rewriter.getContext());
-
-  if (failed(tryToPromoteMemorySlots(allocators, builder, dominance))) {
-    rewriter.cancelRootUpdate(op);
-    return failure();
-  }
-
-  rewriter.finalizeRootUpdate(op);
-  return success();
+  return tryToPromoteMemorySlots({allocator}, rewriter, statistics);
 }
 
 namespace {
 
 struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
+  using impl::Mem2RegBase<Mem2Reg>::Mem2RegBase;
+
   void runOnOperation() override {
     Operation *scopeOp = getOperation();
-    bool changed = false;
+
+    Mem2RegStatistics statictics{&promotedAmount, &newBlockArgumentAmount};
+
+    GreedyRewriteConfig config;
+    config.enableRegionSimplification = enableRegionSimplification;
 
     RewritePatternSet rewritePatterns(&getContext());
-    rewritePatterns.add<Mem2RegPattern>(&getContext());
+    rewritePatterns.add<Mem2RegPattern>(&getContext(), statictics);
     FrozenRewritePatternSet frozen(std::move(rewritePatterns));
-    (void)applyOpPatternsAndFold({scopeOp}, frozen, GreedyRewriteConfig(),
-                                 &changed);
 
-    if (!changed)
-      markAllAnalysesPreserved();
+    if (failed(applyPatternsAndFoldGreedily(scopeOp, frozen, config)))
+      signalPassFailure();
   }
 };
 
index d8d04df..0c1908e 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline='builtin.module(llvm.func(mem2reg))' | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(llvm.func(mem2reg{region-simplify=false}))' | FileCheck %s
 
 llvm.func @use(i64)
 llvm.func @use_ptr(!llvm.ptr)
index 090f913..fc696c5 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg))" --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s
 
 // CHECK-LABEL: llvm.func @default_value
 llvm.func @default_value() -> i32 {
diff --git a/mlir/test/Dialect/MemRef/mem2reg-statistics.mlir b/mlir/test/Dialect/MemRef/mem2reg-statistics.mlir
new file mode 100644 (file)
index 0000000..29ca511
--- /dev/null
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file --mlir-pass-statistics 2>&1 >/dev/null | FileCheck %s
+
+// CHECK: Mem2Reg
+// CHECK-NEXT: (S) 0 new block args
+// CHECK-NEXT: (S) 1 promoted slots
+func.func @basic() -> i32 {
+  %0 = arith.constant 5 : i32
+  %1 = memref.alloca() : memref<i32>
+  memref.store %0, %1[] : memref<i32>
+  %2 = memref.load %1[] : memref<i32>
+  return %2 : i32
+}
+
+// -----
+
+// CHECK: Mem2Reg
+// CHECK-NEXT: (S) 0 new block args
+// CHECK-NEXT: (S) 0 promoted slots
+func.func @no_alloca() -> i32 {
+  %0 = arith.constant 5 : i32
+  return %0 : i32
+}
+
+// -----
+
+// CHECK: Mem2Reg
+// CHECK-NEXT: (S) 2 new block args
+// CHECK-NEXT: (S) 1 promoted slots
+func.func @cycle(%arg0: i64, %arg1: i1, %arg2: i64) {
+  %alloca = memref.alloca() : memref<i64>
+  memref.store %arg2, %alloca[] : memref<i64>
+  cf.cond_br %arg1, ^bb1, ^bb2
+^bb1:
+  %use = memref.load %alloca[] : memref<i64>
+  call @use(%use) : (i64) -> ()
+  memref.store %arg0, %alloca[] : memref<i64>
+  cf.br ^bb2
+^bb2:
+  cf.br ^bb1
+}
+
+func.func @use(%arg: i64) { return }
+
+// -----
+
+// CHECK: Mem2Reg
+// CHECK-NEXT: (S) 0 new block args
+// CHECK-NEXT: (S) 3 promoted slots
+func.func @recursive(%arg: i64) -> i64 {
+  %alloca0 = memref.alloca() : memref<memref<memref<i64>>>
+  %alloca1 = memref.alloca() : memref<memref<i64>>
+  %alloca2 = memref.alloca() : memref<i64>
+  memref.store %arg, %alloca2[] : memref<i64>
+  memref.store %alloca2, %alloca1[] : memref<memref<i64>>
+  memref.store %alloca1, %alloca0[] : memref<memref<memref<i64>>>
+  %load0 = memref.load %alloca0[] : memref<memref<memref<i64>>>
+  %load1 = memref.load %load0[] : memref<memref<i64>>
+  %load2 = memref.load %load1[] : memref<i64>
+  return %load2 : i64
+}
index 86707ac..d300699 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg{region-simplify=false}))' --split-input-file | FileCheck %s
 
 // CHECK-LABEL: func.func @basic
 func.func @basic() -> i32 {
@@ -148,20 +148,18 @@ func.func @deny_store_of_alloca(%arg: memref<memref<i32>>) -> i32 {
 
 // CHECK-LABEL: func.func @promotable_nonpromotable_intertwined
 func.func @promotable_nonpromotable_intertwined() -> i32 {
-  // CHECK: %[[VAL:.*]] = arith.constant 5 : i32
-  %0 = arith.constant 5 : i32
   // CHECK: %[[NON_PROMOTED:.*]] = memref.alloca() : memref<i32>
-  %1 = memref.alloca() : memref<i32>
+  %0 = memref.alloca() : memref<i32>
   // CHECK-NOT: = memref.alloca() : memref<memref<i32>>
-  %2 = memref.alloca() : memref<memref<i32>>
-  memref.store %1, %2[] : memref<memref<i32>>
-  %3 = memref.load %2[] : memref<memref<i32>>
+  %1 = memref.alloca() : memref<memref<i32>>
+  memref.store %0, %1[] : memref<memref<i32>>
+  %2 = memref.load %1[] : memref<memref<i32>>
   // CHECK: call @use(%[[NON_PROMOTED]])
-  call @use(%1) : (memref<i32>) -> ()
+  call @use(%0) : (memref<i32>) -> ()
   // CHECK: %[[RES:.*]] = memref.load %[[NON_PROMOTED]][]
-  %4 = memref.load %1[] : memref<i32>
+  %3 = memref.load %0[] : memref<i32>
   // CHECK: return %[[RES]] : i32
-  return %4 : i32
+  return %3 : i32
 }
 
 func.func @use(%arg: memref<i32>) { return }