[mlir][transform][NFC] Store all Mappings in region stack
authorMatthias Springer <me@m-sp.org>
Mon, 26 Jun 2023 15:29:57 +0000 (17:29 +0200)
committerMatthias Springer <me@m-sp.org>
Mon, 26 Jun 2023 15:35:26 +0000 (17:35 +0200)
Do not swap the Mappings when entering a region that is isolated from above. Simply push another Mappings struct to the stack and prevent invalid accesses during lookups.

Differential Revision: https://reviews.llvm.org/D153765

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp

index d54f9c4..21ddd71 100644 (file)
@@ -262,12 +262,6 @@ public:
   // class body to comply with visibility and full-declaration requirements.
   inline RegionScope make_region_scope(Region &region);
 
-  /// Creates a new region scope for the given isolated-from-above region.
-  /// Unlike the non-isolated counterpart, there is no nesting expectation.
-  // Implementation note: this method is inline but implemented outside of the
-  // class body to comply with visibility and full-declaration requirements
-  inline RegionScope make_isolated_region_scope(Region &region);
-
   /// A RAII object maintaining a "stack frame" for a transform IR region. When
   /// applying a transform IR operation that contains a region, the caller is
   /// expected to create a RegionScope before applying the ops contained in the
@@ -282,51 +276,25 @@ public:
     ~RegionScope();
 
   private:
-    /// Tag structure for differentiating the constructor for isolated regions.
-    struct Isolated {};
-
     /// Creates a new scope for mappings between values defined in the given
     /// transform IR region and payload IR objects.
     RegionScope(TransformState &state, Region &region)
         : state(state), region(&region) {
-      auto res = state.mappings.try_emplace(this->region);
+      auto res = state.mappings.insert(std::make_pair(&region, Mappings()));
       assert(res.second && "the region scope is already present");
       (void)res;
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
-      assert(((state.regionStack.size() == 1 && !state.regionStack.back()) ||
-              state.regionStack.back()->isProperAncestor(&region)) &&
-             "scope started at a non-nested region");
       state.regionStack.push_back(&region);
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
     }
 
-    /// Creates a new scope for mappings between values defined in the given
-    /// isolated-from-above transform IR region and payload IR objects.
-    RegionScope(TransformState &state, Region &region, Isolated)
-        : state(state), region(&region) {
-      // Store the previous mapping stack locally.
-      storedMappings = llvm::SmallDenseMap<Region *, Mappings>();
-      storedMappings->swap(state.mappings);
-      state.mappings.try_emplace(this->region);
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-      state.regionStack.push_back(this->region);
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
-    }
-
     /// Back-reference to the transform state.
     TransformState &state;
 
     /// The region this scope is associated with.
     Region *region;
 
-    /// Local copy of the mappings that existed before entering the current
-    /// region. Used only when the current region is isolated so we don't
-    /// accidentally look up the values defined outside the isolated region.
-    std::optional<llvm::SmallDenseMap<Region *, Mappings>> storedMappings =
-        std::nullopt;
-
     friend RegionScope TransformState::make_region_scope(Region &);
-    friend RegionScope TransformState::make_isolated_region_scope(Region &);
   };
   friend class RegionScope;
 
@@ -446,9 +414,19 @@ private:
     return const_cast<TransformState *>(this)->getMapping(value);
   }
   Mappings &getMapping(Value value) {
-    auto it = mappings.find(value.getParentRegion());
+    Region *region = value.getParentRegion();
+    auto it = mappings.find(region);
     assert(it != mappings.end() &&
            "trying to find a mapping for a value from an unmapped region");
+#ifndef NDEBUG
+    for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
+      if (r == region)
+        break;
+      if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+        llvm_unreachable(
+            "trying to get mapping beyond region that is isolated from above");
+    }
+#endif // NDEBUG
     return it->second;
   }
 
@@ -457,9 +435,19 @@ private:
     return const_cast<TransformState *>(this)->getMapping(operation);
   }
   Mappings &getMapping(Operation *operation) {
-    auto it = mappings.find(operation->getParentRegion());
+    Region *region = operation->getParentRegion();
+    auto it = mappings.find(region);
     assert(it != mappings.end() &&
            "trying to find a mapping for an operation from an unmapped region");
+#ifndef NDEBUG
+    for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
+      if (r == region)
+        break;
+      if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+        llvm_unreachable(
+            "trying to get mapping beyond region that is isolated from above");
+    }
+#endif // NDEBUG
     return it->second;
   }
 
@@ -676,9 +664,9 @@ private:
   /// Remove all nullptrs from op handles that were added by `replacePayloadOp`.
   void compactOpHandles();
 
-  /// The mappings between transform IR values and payload IR ops, aggregated by
-  /// the region in which the transform IR values are defined.
-  llvm::SmallDenseMap<Region *, Mappings> mappings;
+  /// A stack of mappings between transform IR values and payload IR ops,
+  /// aggregated by the region in which the transform IR values are defined.
+  llvm::MapVector<Region *, Mappings> mappings;
 
   /// Op handles may be temporarily mapped to nullptr to avoid invalidating
   /// payload op iterators. This set contains all op handles with nullptrs.
@@ -834,14 +822,6 @@ TransformState::RegionScope TransformState::make_region_scope(Region &region) {
   return RegionScope(*this, region);
 }
 
-/// Creates a RAII object the lifetime of which corresponds to the new mapping
-/// for transform IR values defined in the given isolated-from-above region.
-/// Values defined in surrounding regions cannot be accessed.
-TransformState::RegionScope
-TransformState::make_isolated_region_scope(Region &region) {
-  return RegionScope(*this, region, RegionScope::Isolated());
-}
-
 /// A listener that updates a TransformState based on IR modifications. This
 /// listener can be used during a greedy pattern rewrite to keep the transform
 /// state up-to-date.
index 27794e6..ce6ec5a 100644 (file)
@@ -45,7 +45,7 @@ transform::TransformState::TransformState(
   for (ArrayRef<MappedValue> mapping : extraMappings)
     topLevelMappedValues.push_back(mapping);
 
-  auto result = mappings.try_emplace(region);
+  auto result = mappings.insert(std::make_pair(region, Mappings()));
   assert(result.second && "the region scope is already present");
   (void)result;
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -85,12 +85,15 @@ transform::TransformState::getPayloadValues(Value handleValue) const {
 LogicalResult transform::TransformState::getHandlesForPayloadOp(
     Operation *op, SmallVectorImpl<Value> &handles) const {
   bool found = false;
-  for (const Mappings &mapping : llvm::make_second_range(mappings)) {
+  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
     auto iterator = mapping.reverse.find(op);
     if (iterator != mapping.reverse.end()) {
       llvm::append_range(handles, iterator->getSecond());
       found = true;
     }
+    // Stop looking when reaching a region that is isolated from above.
+    if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+      break;
   }
 
   return success(found);
@@ -99,12 +102,15 @@ LogicalResult transform::TransformState::getHandlesForPayloadOp(
 LogicalResult transform::TransformState::getHandlesForPayloadValue(
     Value payloadValue, SmallVectorImpl<Value> &handles) const {
   bool found = false;
-  for (const Mappings &mapping : llvm::make_second_range(mappings)) {
+  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
     auto iterator = mapping.reverseValues.find(payloadValue);
     if (iterator != mapping.reverseValues.end()) {
       llvm::append_range(handles, iterator->getSecond());
       found = true;
     }
+    // Stop looking when reaching a region that is isolated from above.
+    if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+      break;
   }
 
   return success(found);
@@ -590,8 +596,10 @@ void transform::TransformState::recordOpHandleInvalidation(
   // number of IR objects (operations and values). Alternatively, we could walk
   // the IR nested in each payload op associated with the given handle and look
   // for handles associated with each operation and value.
-  for (const transform::TransformState::Mappings &mapping :
-       llvm::make_second_range(mappings)) {
+  for (const auto &[region, mapping] : llvm::reverse(mappings)) {
+    // Stop lookup when reaching a region that is isolated from above.
+    if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+      break;
     // Go over all op handle mappings and mark as invalidated any handle
     // pointing to any of the payload ops associated with the given handle or
     // any op nested in them.
@@ -1102,8 +1110,6 @@ transform::TransformState::RegionScope::~RegionScope() {
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 
   state.mappings.erase(region);
-  if (storedMappings.has_value())
-    state.mappings.swap(*storedMappings);
 
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   // If the last handle to a payload op has gone out of scope, we no longer
index 8742489..1d89ff4 100644 (file)
@@ -446,7 +446,7 @@ static DiagnosedSilenceableFailure
 matchBlock(Block &block, Operation *op, transform::TransformState &state,
            SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
   assert(block.getParent() && "cannot match using a detached block");
-  auto matchScope = state.make_isolated_region_scope(*block.getParent());
+  auto matchScope = state.make_region_scope(*block.getParent());
   if (failed(state.mapBlockArgument(block.getArgument(0), {op})))
     return DiagnosedSilenceableFailure::definiteFailure();
 
@@ -524,7 +524,7 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
           continue;
         }
 
-        auto scope = state.make_isolated_region_scope(action.getFunctionBody());
+        auto scope = state.make_region_scope(action.getFunctionBody());
         for (auto &&[arg, map] : llvm::zip_equal(
                  action.getFunctionBody().front().getArguments(), mappings)) {
           if (failed(state.mapBlockArgument(arg, map)))
@@ -1029,7 +1029,7 @@ transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
   // Map operands to block arguments.
   SmallVector<SmallVector<MappedValue>> mappings;
   detail::prepareValueMappings(mappings, getOperands(), state);
-  auto scope = state.make_isolated_region_scope(callee.getBody());
+  auto scope = state.make_region_scope(callee.getBody());
   for (auto &&[arg, map] :
        llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
     if (failed(state.mapBlockArgument(arg, map)))