Reland "[mlir][mem2reg] Expose algorithm internals."
authorThéo Degioanni <theo.degioanni@nextsilicon.com>
Tue, 9 May 2023 09:01:33 +0000 (09:01 +0000)
committerTobias Gysi <tobias.gysi@nextsilicon.com>
Tue, 9 May 2023 09:24:49 +0000 (09:24 +0000)
This patch refactors the Mem2Reg infrastructure. It decouples
analysis from promotion, allowing for more control over the execution of
the logic. It also adjusts the interfaces to be less coupled to mem2reg
and more general. This will be useful for an upcoming revision
introducing generic SROA.

This commit reverts f333977eb20a and relands 91cff8a71872.

The original commit was reverted accidentally due to a misinterpretation
of a bazel build bot failure.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D149825

15 files changed:
mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Interfaces/CMakeLists.txt
mlir/include/mlir/Interfaces/MemorySlotInterfaces.h [moved from mlir/include/mlir/Interfaces/Mem2RegInterfaces.h with 83% similarity]
mlir/include/mlir/Interfaces/MemorySlotInterfaces.td [moved from mlir/include/mlir/Interfaces/Mem2RegInterfaces.td with 91% similarity]
mlir/include/mlir/Transforms/Mem2Reg.h
mlir/lib/Dialect/LLVMIR/CMakeLists.txt
mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp [moved from mlir/lib/Dialect/LLVMIR/IR/LLVMMem2Reg.cpp with 81% similarity]
mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
mlir/lib/Interfaces/CMakeLists.txt
mlir/lib/Interfaces/MemorySlotInterfaces.cpp [moved from mlir/lib/Interfaces/Mem2RegInterfaces.cpp with 64% similarity]
mlir/lib/Transforms/CMakeLists.txt
mlir/lib/Transforms/Mem2Reg.cpp

index e8e1712..141ad6d 100644 (file)
@@ -6,7 +6,7 @@ include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td"
 include "mlir/Dialect/LLVMIR/LLVMEnums.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
-include "mlir/Interfaces/Mem2RegInterfaces.td"
+include "mlir/Interfaces/MemorySlotInterfaces.td"
 
 // Operations that correspond to LLVM intrinsics. With MLIR operation set being
 // extendable, there is no reason to introduce a hard boundary between "core"
index f9025e0..fefc602 100644 (file)
@@ -22,7 +22,7 @@ include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
-include "mlir/Interfaces/Mem2RegInterfaces.td"
+include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 class LLVM_Builder<string builder> {
index 41f130e..4c4515c 100644 (file)
@@ -16,7 +16,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
-#include "mlir/Interfaces/Mem2RegInterfaces.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
index 1ea0ef0..b0c8537 100644 (file)
@@ -15,7 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
-include "mlir/Interfaces/Mem2RegInterfaces.td"
+include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
index 4e0f7ac..0b42dfe 100644 (file)
@@ -7,7 +7,6 @@ add_mlir_interface(DestinationStyleOpInterface)
 add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
-add_mlir_interface(Mem2RegInterfaces)
 add_mlir_interface(ParallelCombiningOpInterface)
 add_mlir_interface(RuntimeVerifiableOpInterface)
 add_mlir_interface(ShapedOpInterfaces)
@@ -17,6 +16,12 @@ add_mlir_interface(ValueBoundsOpInterface)
 add_mlir_interface(VectorInterfaces)
 add_mlir_interface(ViewLikeInterface)
 
+set(LLVM_TARGET_DEFINITIONS MemorySlotInterfaces.td)
+mlir_tablegen(MemorySlotOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(MemorySlotOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRMemorySlotInterfacesIncGen)
+add_dependencies(mlir-generic-headers MLIRMemorySlotInterfacesIncGen)
+
 set(LLVM_TARGET_DEFINITIONS DataLayoutInterfaces.td)
 mlir_tablegen(DataLayoutAttrInterface.h.inc -gen-attr-interface-decls)
 mlir_tablegen(DataLayoutAttrInterface.cpp.inc -gen-attr-interface-defs)
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_INTERFACES_MEM2REGINTERFACES_H
-#define MLIR_INTERFACES_MEM2REGINTERFACES_H
+#ifndef MLIR_INTERFACES_MEMORYSLOTINTERFACES_H
+#define MLIR_INTERFACES_MEMORYSLOTINTERFACES_H
 
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/OpDefinition.h"
@@ -34,6 +34,6 @@ enum class DeletionKind {
 
 } // namespace mlir
 
-#include "mlir/Interfaces/Mem2RegInterfaces.h.inc"
+#include "mlir/Interfaces/MemorySlotOpInterfaces.h.inc"
 
-#endif // MLIR_INTERFACES_MEM2REGINTERFACES_H
+#endif // MLIR_INTERFACES_MEMORYSLOTINTERFACES_H
@@ -1,4 +1,4 @@
-//===-- Mem2RegInterfaces.td - Mem2Reg interfaces ----------*- tablegen -*-===//
+//===-- MemorySlotInterfaces.td - MemorySlot interfaces ----*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_INTERFACES_MEM2REGINTERFACES
-#define MLIR_INTERFACES_MEM2REGINTERFACES
+#ifndef MLIR_INTERFACES_MEMORYSLOTINTERFACES
+#define MLIR_INTERFACES_MEMORYSLOTINTERFACES
 
 include "mlir/IR/OpBase.td"
 
@@ -76,6 +76,9 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
     to memory slots. Loads and stores must be of whole values of the same
     type as the slot itself.
 
+    For a memory operation on a slot to be valid, it must operate on the slot
+    pointer *only as a pointer to an element of the type of the slot*.
+
     If the same operation does both loads and stores on the same slot, the
     load must semantically happen first.
   }];
@@ -152,21 +155,21 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
   let methods = [
     InterfaceMethod<[{
         Checks that this operation can be promoted to no longer use the provided
-        blocking uses, in the context of promoting `slot`.
+        blocking uses, in order to allow optimization.
 
         If the removal procedure of the use will require that other uses get
         removed, that dependency should be added to the `newBlockingUses`
         argument. Dependent uses must only be uses of results of this operation.
       }], "bool", "canUsesBeRemoved",
-      (ins "const ::mlir::MemorySlot &":$slot,
-           "const ::llvm::SmallPtrSetImpl<::mlir::OpOperand *> &":$blockingUses,
+      (ins "const ::llvm::SmallPtrSetImpl<::mlir::OpOperand *> &":$blockingUses,
            "::llvm::SmallVectorImpl<::mlir::OpOperand *> &":$newBlockingUses)
     >,
     InterfaceMethod<[{
         Transforms IR to ensure that the current operation does not use the
-        provided memory slot anymore. In contrast to `PromotableMemOpInterface`,
-        operations implementing this interface must not need access to the
-        reaching definition of the content of the slot.
+        provided blocking uses anymore. In contrast to
+        `PromotableMemOpInterface`, operations implementing this interface
+        must not need access to the reaching definition of the content of the
+        slot.
 
         During the transformation, *no operation should be deleted*.
         The operation can only schedule its own deletion by returning the
@@ -186,11 +189,10 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
       }],
       "::mlir::DeletionKind",
       "removeBlockingUses",
-      (ins "const ::mlir::MemorySlot &":$slot,
-           "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
+      (ins "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
            "::mlir::OpBuilder &":$builder)
     >,
   ];
 }
 
-#endif // MLIR_INTERFACES_MEM2REGINTERFACES
+#endif // MLIR_INTERFACES_MEMORYSLOTINTERFACES
index e2da88f..1593b12 100644 (file)
 
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/OpDefinition.h"
-#include "mlir/Interfaces/Mem2RegInterfaces.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
 
 namespace mlir {
 
+/// Information computed during promotion analysis used to perform actual
+/// promotion.
+struct MemorySlotPromotionInfo {
+  /// Blocks for which at least two definitions of the slot values clash.
+  SmallPtrSet<Block *, 8> mergePoints;
+  /// Contains, for each operation, which uses must be eliminated by promotion.
+  /// This is a DAG structure because if an operation must eliminate some of
+  /// its uses, it is because the defining ops of the blocking uses requested
+  /// it. The defining ops therefore must also have blocking uses or be the
+  /// starting point of the bloccking uses.
+  DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
+};
+
+/// Computes information for basic slot promotion. This will check that direct
+/// slot promotion can be performed, and provide the information to execute the
+/// promotion. This does not mutate IR.
+class MemorySlotPromotionAnalyzer {
+public:
+  MemorySlotPromotionAnalyzer(MemorySlot slot, DominanceInfo &dominance)
+      : slot(slot), dominance(dominance) {}
+
+  /// Computes the information for slot promotion if promotion is possible,
+  /// returns nothing otherwise.
+  std::optional<MemorySlotPromotionInfo> computeInfo();
+
+private:
+  /// Computes the transitive uses of the slot that block promotion. This finds
+  /// uses that would block the promotion, checks that the operation has a
+  /// solution to remove the blocking use, and potentially forwards the analysis
+  /// if the operation needs further blocking uses resolved to resolve its own
+  /// uses (typically, removing its users because it will delete itself to
+  /// resolve its own blocking uses). This will fail if one of the transitive
+  /// users cannot remove a requested use, and should prevent promotion.
+  LogicalResult computeBlockingUses(
+      DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses);
+
+  /// Computes in which blocks the value stored in the slot is actually used,
+  /// meaning blocks leading to a load. This method uses `definingBlocks`, the
+  /// set of blocks containing a store to the slot (defining the value of the
+  /// slot).
+  SmallPtrSet<Block *, 16>
+  computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);
+
+  /// Computes the points in which multiple re-definitions of the slot's value
+  /// (stores) may conflict.
+  void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints);
+
+  /// Ensures predecessors of merge points can properly provide their current
+  /// definition of the value stored in the slot to the merge point. This can
+  /// notably be an issue if the terminator used does not have the ability to
+  /// forward values through block operands.
+  bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints);
+
+  MemorySlot slot;
+  DominanceInfo &dominance;
+};
+
+/// The MemorySlotPromoter handles the state of promoting a memory slot. It
+/// wraps a slot and its associated allocator. This will perform the mutation of
+/// IR.
+class MemorySlotPromoter {
+public:
+  MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
+                     OpBuilder &builder, DominanceInfo &dominance,
+                     MemorySlotPromotionInfo info);
+
+  /// Actually promotes the slot by mutating IR. Promoting a slot does not
+  /// invalidate the MemorySlotPromotionInfo of other slots.
+  void promoteSlot();
+
+private:
+  /// Computes the reaching definition for all the operations that require
+  /// promotion. `reachingDef` is the value the slot should contain at the
+  /// beginning of the block. This method returns the reached definition at the
+  /// end of the block.
+  Value computeReachingDefInBlock(Block *block, Value reachingDef);
+
+  /// Computes the reaching definition for all the operations that require
+  /// promotion. `reachingDef` corresponds to the initial value the
+  /// slot will contain before any write, typically a poison value.
+  void computeReachingDefInRegion(Region *region, Value reachingDef);
+
+  /// Removes the blocking uses of the slot, in topological order.
+  void removeBlockingUses();
+
+  /// Lazily-constructed default value representing the content of the slot when
+  /// no store has been executed. This function may mutate IR.
+  Value getLazyDefaultValue();
+
+  MemorySlot slot;
+  PromotableAllocationOpInterface allocator;
+  OpBuilder &builder;
+  /// Potentially non-initialized default value. Use `getLazyDefaultValue` to
+  /// initialize it on demand.
+  Value defaultValue;
+  /// Contains the reaching definition at this operation. Reaching definitions
+  /// are only computed for promotable memory operations with blocking uses.
+  DenseMap<PromotableMemOpInterface, Value> reachingDefs;
+  DominanceInfo &dominance;
+  MemorySlotPromotionInfo info;
+};
+
 /// Attempts to promote the memory slots of the provided allocators. Succeeds if
 /// at least one memory slot was promoted.
 LogicalResult
index 9ba0315..7bf9595 100644 (file)
@@ -6,7 +6,7 @@ add_mlir_dialect_library(MLIRLLVMDialect
   IR/LLVMDialect.cpp
   IR/LLVMInlining.cpp
   IR/LLVMInterfaces.cpp
-  IR/LLVMMem2Reg.cpp
+  IR/LLVMMemorySlot.cpp
   IR/LLVMTypes.cpp
   IR/LLVMTypeSyntax.cpp
 
@@ -34,7 +34,7 @@ add_mlir_dialect_library(MLIRLLVMDialect
   MLIRDataLayoutInterfaces
   MLIRInferTypeOpInterface
   MLIRIR
-  MLIRMem2RegInterfaces
+  MLIRMemorySlotInterfaces
   MLIRSideEffectInterfaces
   MLIRSupport
   )
similarity index 81%
rename from mlir/lib/Dialect/LLVMIR/IR/LLVMMem2Reg.cpp
rename to mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 7fa8ebc..e4fd2a7 100644 (file)
@@ -1,4 +1,4 @@
-//===- LLVMMem2Reg.cpp - Mem2Reg Interfaces ---------------------*- C++ -*-===//
+//===- LLVMMemorySlot.cpp - MemorySlot interfaces ---------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +6,13 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements Mem2Reg-related interfaces for LLVM dialect operations.
+// This file implements MemorySlot-related interfaces for LLVM dialect
+// operations.
 //
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Interfaces/Mem2RegInterfaces.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
 
 using namespace mlir;
 
@@ -19,13 +20,18 @@ using namespace mlir;
 // Interfaces for AllocaOp
 //===----------------------------------------------------------------------===//
 
+static Type getAllocaElementType(LLVM::AllocaOp alloca) {
+  // This will become trivial once non-opaque pointers are gone.
+  return alloca.getElemType().has_value()
+             ? *alloca.getElemType()
+             : alloca.getResult().getType().getElementType();
+}
+
 llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() {
   if (!getOperation()->getBlock()->isEntryBlock())
     return {};
 
-  Type elemType =
-      getElemType() ? *getElemType() : getResult().getType().getElementType();
-  return {MemorySlot{getResult(), elemType}};
+  return {MemorySlot{getResult(), getAllocaElementType(*this)}};
 }
 
 Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot,
@@ -128,62 +134,57 @@ static bool forwardToUsers(Operation *op,
 }
 
 bool LLVM::BitcastOp::canUsesBeRemoved(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    const SmallPtrSetImpl<OpOperand *> &blockingUses,
     SmallVectorImpl<OpOperand *> &newBlockingUses) {
   return forwardToUsers(*this, newBlockingUses);
 }
 
 DeletionKind LLVM::BitcastOp::removeBlockingUses(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
   return DeletionKind::Delete;
 }
 
 bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    const SmallPtrSetImpl<OpOperand *> &blockingUses,
     SmallVectorImpl<OpOperand *> &newBlockingUses) {
   return forwardToUsers(*this, newBlockingUses);
 }
 
 DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
   return DeletionKind::Delete;
 }
 
 bool LLVM::LifetimeStartOp::canUsesBeRemoved(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    const SmallPtrSetImpl<OpOperand *> &blockingUses,
     SmallVectorImpl<OpOperand *> &newBlockingUses) {
   return true;
 }
 
 DeletionKind LLVM::LifetimeStartOp::removeBlockingUses(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
   return DeletionKind::Delete;
 }
 
 bool LLVM::LifetimeEndOp::canUsesBeRemoved(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    const SmallPtrSetImpl<OpOperand *> &blockingUses,
     SmallVectorImpl<OpOperand *> &newBlockingUses) {
   return true;
 }
 
 DeletionKind LLVM::LifetimeEndOp::removeBlockingUses(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
   return DeletionKind::Delete;
 }
 
 bool LLVM::DbgDeclareOp::canUsesBeRemoved(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    const SmallPtrSetImpl<OpOperand *> &blockingUses,
     SmallVectorImpl<OpOperand *> &newBlockingUses) {
   return true;
 }
 
 DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
   return DeletionKind::Delete;
 }
 
@@ -194,8 +195,12 @@ static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
   });
 }
 
+//===----------------------------------------------------------------------===//
+// Interfaces for GEPOp
+//===----------------------------------------------------------------------===//
+
 bool LLVM::GEPOp::canUsesBeRemoved(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    const SmallPtrSetImpl<OpOperand *> &blockingUses,
     SmallVectorImpl<OpOperand *> &newBlockingUses) {
   // GEP can be removed as long as it is a no-op and its users can be removed.
   if (!hasAllZeroIndices(*this))
@@ -204,7 +209,6 @@ bool LLVM::GEPOp::canUsesBeRemoved(
 }
 
 DeletionKind LLVM::GEPOp::removeBlockingUses(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    OpBuilder &builder) {
+    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
   return DeletionKind::Delete;
 }
index acc38b5..5b97f7a 100644 (file)
@@ -14,7 +14,6 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Interfaces/Mem2RegInterfaces.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
index dbf6e69..665e4c0 100644 (file)
@@ -9,7 +9,7 @@ set(LLVM_OPTIONAL_SOURCES
   InferIntRangeInterface.cpp
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
-  Mem2RegInterfaces.cpp
+  MemorySlotInterfaces.cpp
   ParallelCombiningOpInterface.cpp
   RuntimeVerifiableOpInterface.cpp
   ShapedOpInterfaces.cpp
@@ -46,7 +46,7 @@ add_mlir_interface_library(DestinationStyleOpInterface)
 add_mlir_interface_library(InferIntRangeInterface)
 add_mlir_interface_library(InferTypeOpInterface)
 add_mlir_interface_library(LoopLikeInterface)
-add_mlir_interface_library(Mem2RegInterfaces)
+add_mlir_interface_library(MemorySlotInterfaces)
 add_mlir_interface_library(ParallelCombiningOpInterface)
 add_mlir_interface_library(RuntimeVerifiableOpInterface)
 add_mlir_interface_library(ShapedOpInterfaces)
similarity index 64%
rename from mlir/lib/Interfaces/Mem2RegInterfaces.cpp
rename to mlir/lib/Interfaces/MemorySlotInterfaces.cpp
index aadd76b..431febf 100644 (file)
@@ -1,4 +1,4 @@
-//===-- Mem2RegInterfaces.cpp - Mem2Reg interfaces --------------*- C++ -*-===//
+//===-- MemorySlotInterfaces.cpp - MemorySlot interfaces --------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,6 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Interfaces/Mem2RegInterfaces.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
 
-#include "mlir/Interfaces/Mem2RegInterfaces.cpp.inc"
+#include "mlir/Interfaces/MemorySlotOpInterfaces.cpp.inc"
index 7b4fb4d..b7e1cd9 100644 (file)
@@ -28,7 +28,7 @@ add_mlir_library(MLIRTransforms
   MLIRAnalysis
   MLIRCopyOpInterface
   MLIRLoopLikeInterface
-  MLIRMem2RegInterfaces
+  MLIRMemorySlotInterfaces
   MLIRPass
   MLIRRuntimeVerifiableOpInterface
   MLIRSideEffectInterfaces
index 5b57283..633813f 100644 (file)
@@ -9,9 +9,12 @@
 #include "mlir/Transforms/Mem2Reg.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/GenericIteratedDominanceFrontier.h"
 
 namespace mlir {
@@ -42,7 +45,10 @@ using namespace mlir;
 /// this, the value stored can be well defined at block boundaries, allowing
 /// the propagation of replacement through blocks.
 ///
-/// This pass computes this transformation in four main steps:
+/// This pass computes this transformation in four main steps. The two first
+/// steps are performed during an analysis phase that does not mutate IR.
+///
+/// The two steps of the analysis phase are the following:
 /// - A first step computes the list of operations that transitively use the
 /// memory slot we would like to promote. The purpose of this phase is to
 /// identify which uses must be removed to promote the slot, either by rewiring
@@ -60,6 +66,9 @@ using namespace mlir;
 /// existing. Computing this information in advance allows making sure the
 /// terminators that will forward values are capable of doing so (inability to
 /// do so aborts promotion at this step).
+///
+/// At this point, promotion is guaranteed to happen, and the mutation phase can
+/// begin with the following steps:
 /// - A third step computes the reaching definition of the memory slot at each
 /// blocking user. This is the core of the mem2reg algorithm, also known as
 /// load-store forwarding. This analyses loads and stores and propagates which
@@ -73,10 +82,6 @@ using namespace mlir;
 /// - The final fourth step uses the reaching definition to remove blocking uses
 /// in topological order.
 ///
-/// The two first steps do not mutate IR because promotion can still be aborted
-/// at this point. Once the two last steps are reached, promotion is guaranteed
-/// to succeed, allowing to start mutating IR.
-///
 /// For further reading, chapter three of SSA-based Compiler Design [1]
 /// showcases SSA construction, where mem2reg is an adaptation of the same
 /// process.
@@ -84,100 +89,11 @@ using namespace mlir;
 /// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022),
 ///      Springer.
 
-namespace {
-
-/// The SlotPromoter handles the state of promoting a memory slot. It wraps a
-/// slot and its associated allocator, along with analysis results related to
-/// the slot.
-class SlotPromoter {
-public:
-  SlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
-               OpBuilder &builder, DominanceInfo &dominance);
-
-  /// Prepare data for the promotion of the slot while checking if it can be
-  /// promoted. Succeeds if the slot can be promoted. This method does not
-  /// mutate IR.
-  LogicalResult prepareSlotPromotion();
-
-  /// Actually promotes the slot by mutating IR. This method must only be
-  /// called after a successful call to `SlotPromoter::prepareSlotPromotion`.
-  /// Promoting a slot does not invalidate the preparation of other slots.
-  void promoteSlot();
-
-private:
-  /// This is the first step of the promotion algorithm.
-  /// Computes the transitive uses of the slot that block promotion. This finds
-  /// uses that would block the promotion, checks that the operation has a
-  /// solution to remove the blocking use, and potentially forwards the analysis
-  /// if the operation needs further blocking uses resolved to resolve its own
-  /// uses (typically, removing its users because it will delete itself to
-  /// resolve its own blocking uses). This will fail if one of the transitive
-  /// users cannot remove a requested use, and should prevent promotion.
-  LogicalResult computeBlockingUses();
-
-  /// Computes in which blocks the value stored in the slot is actually used,
-  /// meaning blocks leading to a load. This method uses `definingBlocks`, the
-  /// set of blocks containing a store to the slot (defining the value of the
-  /// slot).
-  SmallPtrSet<Block *, 16>
-  computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);
-
-  /// This is the second step of the promotion algorithm.
-  /// Computes the points in which multiple re-definitions of the slot's value
-  /// (stores) may conflict.
-  void computeMergePoints();
-
-  /// Ensures predecessors of merge points can properly provide their current
-  /// definition of the value stored in the slot to the merge point. This can
-  /// notably be an issue if the terminator used does not have the ability to
-  /// forward values through block operands.
-  bool areMergePointsUsable();
-
-  /// Computes the reaching definition for all the operations that require
-  /// promotion. `reachingDef` is the value the slot should contain at the
-  /// beginning of the block. This method returns the reached definition at the
-  /// end of the block.
-  Value computeReachingDefInBlock(Block *block, Value reachingDef);
-
-  /// This is the third step of the promotion algorithm.
-  /// Computes the reaching definition for all the operations that require
-  /// promotion. `reachingDef` corresponds to the initial value the
-  /// slot will contain before any write, typically a poison value.
-  void computeReachingDefInRegion(Region *region, Value reachingDef);
-
-  /// This is the fourth step of the promotion algorithm.
-  /// Removes the blocking uses of the slot, in topological order.
-  void removeBlockingUses();
-
-  /// Lazily-constructed default value representing the content of the slot when
-  /// no store has been executed. This function may mutate IR.
-  Value getLazyDefaultValue();
-
-  MemorySlot slot;
-  PromotableAllocationOpInterface allocator;
-  OpBuilder &builder;
-  /// Potentially non-initialized default value. Use `lazyDefaultValue` to
-  /// initialize it on demand.
-  Value defaultValue;
-  /// Blocks where multiple definitions of the slot value clash.
-  SmallPtrSet<Block *, 8> mergePoints;
-  /// Contains, for each operation, which uses must be eliminated by promotion.
-  /// This is a DAG structure because an operation that must eliminate some of
-  /// its uses always comes from a request from an operation that must
-  /// eliminate some of its own uses.
-  DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
-  /// Contains the reaching definition at this operation. Reaching definitions
-  /// are only computed for promotable memory operations with blocking uses.
-  DenseMap<PromotableMemOpInterface, Value> reachingDefs;
-  DominanceInfo &dominance;
-};
-
-} // namespace
-
-SlotPromoter::SlotPromoter(MemorySlot slot,
-                           PromotableAllocationOpInterface allocator,
-                           OpBuilder &builder, DominanceInfo &dominance)
-    : slot(slot), allocator(allocator), builder(builder), dominance(dominance) {
+MemorySlotPromoter::MemorySlotPromoter(
+    MemorySlot slot, PromotableAllocationOpInterface allocator,
+    OpBuilder &builder, DominanceInfo &dominance, MemorySlotPromotionInfo info)
+    : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
+      info(std::move(info)) {
 #ifndef NDEBUG
   auto isResultOrNewBlockArgument = [&]() {
     if (BlockArgument arg = slot.ptr.dyn_cast<BlockArgument>())
@@ -191,7 +107,7 @@ SlotPromoter::SlotPromoter(MemorySlot slot,
 #endif // NDEBUG
 }
 
-Value SlotPromoter::getLazyDefaultValue() {
+Value MemorySlotPromoter::getLazyDefaultValue() {
   if (defaultValue)
     return defaultValue;
 
@@ -200,7 +116,8 @@ Value SlotPromoter::getLazyDefaultValue() {
   return defaultValue = allocator.getDefaultValue(slot, builder);
 }
 
-LogicalResult SlotPromoter::computeBlockingUses() {
+LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
+    DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses) {
   // The promotion of an operation may require the promotion of further
   // operations (typically, removing operations that use an operation that must
   // delete itself). We thus need to start from the use of the slot pointer and
@@ -216,7 +133,7 @@ LogicalResult SlotPromoter::computeBlockingUses() {
 
   // Then, propagate the requirements for the removal of uses. The
   // topologically-sorted forward slice allows for all blocking uses of an
-  // operation to have been computed before we reach it. Operations are
+  // operation to have been computed before it is reached. Operations are
   // traversed in topological order of their uses, starting from the slot
   // pointer.
   SetVector<Operation *> forwardSlice;
@@ -232,7 +149,7 @@ LogicalResult SlotPromoter::computeBlockingUses() {
     // If the operation decides it cannot deal with removing the blocking uses,
     // promotion must fail.
     if (auto promotable = dyn_cast<PromotableOpInterface>(user)) {
-      if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses))
+      if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses))
         return failure();
     } else if (auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
       if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses))
@@ -254,9 +171,9 @@ LogicalResult SlotPromoter::computeBlockingUses() {
   }
 
   // Because this pass currently only supports analysing the parent region of
-  // the slot pointer, if a promotable memory op that needs promotion is
-  // outside of this region, promotion must fail because it will be impossible
-  // to provide a valid `reachingDef` for it.
+  // the slot pointer, if a promotable memory op that needs promotion is outside
+  // of this region, promotion must fail because it will be impossible to
+  // provide a valid `reachingDef` for it.
   for (auto &[toPromote, _] : userToBlockingUses)
     if (isa<PromotableMemOpInterface>(toPromote) &&
         toPromote->getParentRegion() != slot.ptr.getParentRegion())
@@ -265,8 +182,8 @@ LogicalResult SlotPromoter::computeBlockingUses() {
   return success();
 }
 
-SmallPtrSet<Block *, 16>
-SlotPromoter::computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks) {
+SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
+    SmallPtrSetImpl<Block *> &definingBlocks) {
   SmallPtrSet<Block *, 16> liveIn;
 
   // The worklist contains blocks in which it is known that the slot value is
@@ -323,7 +240,8 @@ SlotPromoter::computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks) {
 }
 
 using IDFCalculator = llvm::IDFCalculatorBase<Block, false>;
-void SlotPromoter::computeMergePoints() {
+void MemorySlotPromotionAnalyzer::computeMergePoints(
+    SmallPtrSetImpl<Block *> &mergePoints) {
   if (slot.ptr.getParentRegion()->hasOneBlock())
     return;
 
@@ -346,7 +264,8 @@ void SlotPromoter::computeMergePoints() {
   mergePoints.insert(mergePointsVec.begin(), mergePointsVec.end());
 }
 
-bool SlotPromoter::areMergePointsUsable() {
+bool MemorySlotPromotionAnalyzer::areMergePointsUsable(
+    SmallPtrSetImpl<Block *> &mergePoints) {
   for (Block *mergePoint : mergePoints)
     for (Block *pred : mergePoint->getPredecessors())
       if (!isa<BranchOpInterface>(pred->getTerminator()))
@@ -355,10 +274,36 @@ bool SlotPromoter::areMergePointsUsable() {
   return true;
 }
 
-Value SlotPromoter::computeReachingDefInBlock(Block *block, Value reachingDef) {
+std::optional<MemorySlotPromotionInfo>
+MemorySlotPromotionAnalyzer::computeInfo() {
+  MemorySlotPromotionInfo info;
+
+  // First, find the set of operations that will need to be changed for the
+  // promotion to happen. These operations need to resolve some of their uses,
+  // either by rewiring them or simply deleting themselves. If any of them
+  // cannot find a way to resolve their blocking uses, we abort the promotion.
+  if (failed(computeBlockingUses(info.userToBlockingUses)))
+    return {};
+
+  // Then, compute blocks in which two or more definitions of the allocated
+  // variable may conflict. These blocks will need a new block argument to
+  // accomodate this.
+  computeMergePoints(info.mergePoints);
+
+  // The slot can be promoted if the block arguments to be created can
+  // actually be populated with values, which may not be possible depending
+  // on their predecessors.
+  if (!areMergePointsUsable(info.mergePoints))
+    return {};
+
+  return info;
+}
+
+Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
+                                                    Value reachingDef) {
   for (Operation &op : block->getOperations()) {
     if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
-      if (userToBlockingUses.contains(memOp))
+      if (info.userToBlockingUses.contains(memOp))
         reachingDefs.insert({memOp, reachingDef});
 
       if (Value stored = memOp.getStored(slot))
@@ -369,8 +314,8 @@ Value SlotPromoter::computeReachingDefInBlock(Block *block, Value reachingDef) {
   return reachingDef;
 }
 
-void SlotPromoter::computeReachingDefInRegion(Region *region,
-                                              Value reachingDef) {
+void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
+                                                    Value reachingDef) {
   if (region->hasOneBlock()) {
     computeReachingDefInBlock(&region->front(), reachingDef);
     return;
@@ -392,7 +337,7 @@ void SlotPromoter::computeReachingDefInRegion(Region *region,
     DfsJob job = dfsStack.pop_back_val();
     Block *block = job.block->getBlock();
 
-    if (mergePoints.contains(block)) {
+    if (info.mergePoints.contains(block)) {
       BlockArgument blockArgument =
           block->addArgument(slot.elemType, slot.ptr.getLoc());
       builder.setInsertionPointToStart(block);
@@ -404,7 +349,7 @@ void SlotPromoter::computeReachingDefInRegion(Region *region,
 
     if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
       for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
-        if (mergePoints.contains(blockOperand.get())) {
+        if (info.mergePoints.contains(blockOperand.get())) {
           if (!job.reachingDef)
             job.reachingDef = getLazyDefaultValue();
           terminator.getSuccessorOperands(blockOperand.getOperandNumber())
@@ -418,9 +363,9 @@ void SlotPromoter::computeReachingDefInRegion(Region *region,
   }
 }
 
-void SlotPromoter::removeBlockingUses() {
+void MemorySlotPromoter::removeBlockingUses() {
   llvm::SetVector<Operation *> usersToRemoveUses;
-  for (auto &user : llvm::make_first_range(userToBlockingUses))
+  for (auto &user : llvm::make_first_range(info.userToBlockingUses))
     usersToRemoveUses.insert(user);
   SetVector<Operation *> sortedUsersToRemoveUses =
       mlir::topologicalSort(usersToRemoveUses);
@@ -435,8 +380,8 @@ void SlotPromoter::removeBlockingUses() {
         reachingDef = getLazyDefaultValue();
 
       builder.setInsertionPointAfter(toPromote);
-      if (toPromoteMemOp.removeBlockingUses(slot, userToBlockingUses[toPromote],
-                                            builder, reachingDef) ==
+      if (toPromoteMemOp.removeBlockingUses(
+              slot, info.userToBlockingUses[toPromote], builder, reachingDef) ==
           DeletionKind::Delete)
         toErase.push_back(toPromote);
 
@@ -445,7 +390,7 @@ void SlotPromoter::removeBlockingUses() {
 
     auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
     builder.setInsertionPointAfter(toPromote);
-    if (toPromoteBasic.removeBlockingUses(slot, userToBlockingUses[toPromote],
+    if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
                                           builder) == DeletionKind::Delete)
       toErase.push_back(toPromote);
   }
@@ -457,7 +402,7 @@ void SlotPromoter::removeBlockingUses() {
          "after promotion, the slot pointer should not be used anymore");
 }
 
-void SlotPromoter::promoteSlot() {
+void MemorySlotPromoter::promoteSlot() {
   computeReachingDefInRegion(slot.ptr.getParentRegion(), {});
 
   // Now that reaching definitions are known, remove all users.
@@ -465,7 +410,7 @@ void SlotPromoter::promoteSlot() {
 
   // Update terminators in dead branches to forward default if they are
   // succeeded by a merge points.
-  for (Block *mergePoint : mergePoints) {
+  for (Block *mergePoint : info.mergePoints) {
     for (BlockOperand &use : mergePoint->getUses()) {
       auto user = cast<BranchOpInterface>(use.getOwner());
       SuccessorOperands succOperands =
@@ -480,43 +425,26 @@ void SlotPromoter::promoteSlot() {
   allocator.handlePromotionComplete(slot, defaultValue);
 }
 
-LogicalResult SlotPromoter::prepareSlotPromotion() {
-  // First, find the set of operations that will need to be changed for the
-  // promotion to happen. These operations need to resolve some of their uses,
-  // either by rewiring them or simply deleting themselves. If any of them
-  // cannot find a way to resolve their blocking uses, we abort the promotion.
-  if (failed(computeBlockingUses()))
-    return failure();
-
-  // Then, compute blocks in which two or more definitions of the allocated
-  // variable may conflict. These blocks will need a new block argument to
-  // accomodate this.
-  computeMergePoints();
-
-  // The slot can be promoted if the block arguments to be created can
-  // actually be populated with values, which may not be possible depending
-  // on their predecessors.
-  return success(areMergePointsUsable());
-}
-
 LogicalResult mlir::tryToPromoteMemorySlots(
     ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
     DominanceInfo &dominance) {
   // Actual promotion may invalidate the dominance analysis, so slot promotion
   // is prepated in batches.
-  SmallVector<SlotPromoter> toPromote;
+  SmallVector<MemorySlotPromoter> toPromote;
   for (PromotableAllocationOpInterface allocator : allocators) {
     for (MemorySlot slot : allocator.getPromotableSlots()) {
       if (slot.ptr.use_empty())
         continue;
 
-      SlotPromoter promoter(slot, allocator, builder, dominance);
-      if (succeeded(promoter.prepareSlotPromotion()))
-        toPromote.emplace_back(std::move(promoter));
+      MemorySlotPromotionAnalyzer analyzer(slot, dominance);
+      std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
+      if (info)
+        toPromote.emplace_back(slot, allocator, builder, dominance,
+                               std::move(*info));
     }
   }
 
-  for (SlotPromoter &promoter : toPromote)
+  for (MemorySlotPromoter &promoter : toPromote)
     promoter.promoteSlot();
 
   return success(!toPromote.empty());