[mlir][transform] Fix TrackingListener in regions that are isolated from above
authorMatthias Springer <me@m-sp.org>
Mon, 26 Jun 2023 16:04:23 +0000 (18:04 +0200)
committerMatthias Springer <me@m-sp.org>
Mon, 26 Jun 2023 16:05:24 +0000 (18:05 +0200)
When an operation is removed/replaced, the TrackingListener updates the internal transform state mapping between handles and payload IR. All handles must be updated, even the ones that are defined in a region that is beyond the most recent region that is isolated from above.

This fixes a bug, where a payload op was erased in a named sequence. Not only handles defined inside of the named region must be updated, but also all other handles such as the ones where the sequence is included.

Differential Revision: https://reviews.llvm.org/D153767

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/test/Dialect/Transform/test-pattern-application.mlir

index 20f9b21..8411eb0 100644 (file)
@@ -222,14 +222,19 @@ public:
 
   /// Populates `handles` with all handles pointing to the given Payload IR op.
   /// Returns success if such handles exist, failure otherwise.
+  /// If `includeOutOfScope` is set to "true", handles that are defined in
+  /// regions beyond the most recent isolated from above region are included.
   LogicalResult getHandlesForPayloadOp(Operation *op,
-                                       SmallVectorImpl<Value> &handles) const;
+                                       SmallVectorImpl<Value> &handles,
+                                       bool includeOutOfScope = false) const;
 
   /// Populates `handles` with all handles pointing to the given payload IR
   /// value. Returns success if such handles exist, failure otherwise.
-  LogicalResult
-  getHandlesForPayloadValue(Value payloadValue,
-                            SmallVectorImpl<Value> &handles) const;
+  /// If `includeOutOfScope` is set to "true", handles that are defined in
+  /// regions beyond the most recent isolated from above region are included.
+  LogicalResult getHandlesForPayloadValue(Value payloadValue,
+                                          SmallVectorImpl<Value> &handles,
+                                          bool includeOutOfScope = false) const;
 
   /// Applies the transformation specified by the given transform op and updates
   /// the state accordingly.
@@ -410,42 +415,53 @@ private:
                  const TransformOptions &options = TransformOptions());
 
   /// Returns the mappings frame for the region in which the value is defined.
-  const Mappings &getMapping(Value value) const {
-    return const_cast<TransformState *>(this)->getMapping(value);
+  /// If `allowOutOfScope` is set to "false", asserts that the value is in
+  /// scope, based on the current stack of frames.
+  const Mappings &getMapping(Value value, bool allowOutOfScope = false) const {
+    return const_cast<TransformState *>(this)->getMapping(value,
+                                                          allowOutOfScope);
   }
-  Mappings &getMapping(Value value) {
+  Mappings &getMapping(Value value, bool allowOutOfScope = false) {
     Region *region = value.getParentRegion();
     auto it = mappings.find(region);
     assert(it != mappings.end() &&
            "trying to find a mapping for a value from an unmapped region");
 #ifndef NDEBUG
-    for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
-      if (r == region)
-        break;
-      if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
-        llvm_unreachable(
-            "trying to get mapping beyond region that is isolated from above");
+    if (!allowOutOfScope) {
+      for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
+        if (r == region)
+          break;
+        if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+          llvm_unreachable("trying to get mapping beyond region that is "
+                           "isolated from above");
+      }
     }
 #endif // NDEBUG
     return it->second;
   }
 
   /// Returns the mappings frame for the region in which the operation resides.
-  const Mappings &getMapping(Operation *operation) const {
-    return const_cast<TransformState *>(this)->getMapping(operation);
+  /// If `allowOutOfScope` is set to "false", asserts that the operation is in
+  /// scope, based on the current stack of frames.
+  const Mappings &getMapping(Operation *operation,
+                             bool allowOutOfScope = false) const {
+    return const_cast<TransformState *>(this)->getMapping(operation,
+                                                          allowOutOfScope);
   }
-  Mappings &getMapping(Operation *operation) {
+  Mappings &getMapping(Operation *operation, bool allowOutOfScope = false) {
     Region *region = operation->getParentRegion();
     auto it = mappings.find(region);
     assert(it != mappings.end() &&
            "trying to find a mapping for an operation from an unmapped region");
 #ifndef NDEBUG
-    for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
-      if (r == region)
-        break;
-      if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
-        llvm_unreachable(
-            "trying to get mapping beyond region that is isolated from above");
+    if (!allowOutOfScope) {
+      for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
+        if (r == region)
+          break;
+        if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+          llvm_unreachable("trying to get mapping beyond region that is "
+                           "isolated from above");
+      }
     }
 #endif // NDEBUG
     return it->second;
index 5bd9516..961cf34 100644 (file)
@@ -83,7 +83,8 @@ transform::TransformState::getPayloadValues(Value handleValue) const {
 }
 
 LogicalResult transform::TransformState::getHandlesForPayloadOp(
-    Operation *op, SmallVectorImpl<Value> &handles) const {
+    Operation *op, SmallVectorImpl<Value> &handles,
+    bool includeOutOfScope) const {
   bool found = false;
   for (const auto &[region, mapping] : llvm::reverse(mappings)) {
     auto iterator = mapping.reverse.find(op);
@@ -92,7 +93,8 @@ LogicalResult transform::TransformState::getHandlesForPayloadOp(
       found = true;
     }
     // Stop looking when reaching a region that is isolated from above.
-    if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+    if (!includeOutOfScope &&
+        region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
       break;
   }
 
@@ -100,7 +102,8 @@ LogicalResult transform::TransformState::getHandlesForPayloadOp(
 }
 
 LogicalResult transform::TransformState::getHandlesForPayloadValue(
-    Value payloadValue, SmallVectorImpl<Value> &handles) const {
+    Value payloadValue, SmallVectorImpl<Value> &handles,
+    bool includeOutOfScope) const {
   bool found = false;
   for (const auto &[region, mapping] : llvm::reverse(mappings)) {
     auto iterator = mapping.reverseValues.find(payloadValue);
@@ -109,7 +112,8 @@ LogicalResult transform::TransformState::getHandlesForPayloadValue(
       found = true;
     }
     // Stop looking when reaching a region that is isolated from above.
-    if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
+    if (!includeOutOfScope &&
+        region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
       break;
   }
 
@@ -343,7 +347,8 @@ transform::TransformState::replacePayloadOp(Operation *op,
 #ifndef NDEBUG
   for (Value opResult : op->getResults()) {
     SmallVector<Value> valueHandles;
-    (void)getHandlesForPayloadValue(opResult, valueHandles);
+    (void)getHandlesForPayloadValue(opResult, valueHandles,
+                                    /*includeOutOfScope=*/true);
     assert(valueHandles.empty() && "expected no mapping to old results");
   }
 #endif // NDEBUG
@@ -351,10 +356,10 @@ transform::TransformState::replacePayloadOp(Operation *op,
   // Drop the mapping between the op and all handles that point to it. Fail if
   // there are no handles.
   SmallVector<Value> opHandles;
-  if (failed(getHandlesForPayloadOp(op, opHandles)))
+  if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
     return failure();
   for (Value handle : opHandles) {
-    Mappings &mappings = getMapping(handle);
+    Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
     dropMappingEntry(mappings.reverse, op, handle);
   }
 
@@ -385,7 +390,7 @@ transform::TransformState::replacePayloadOp(Operation *op,
   // element from an array invalidates iterators; merely changing the value of
   // elements does not.
   for (Value handle : opHandles) {
-    Mappings &mappings = getMapping(handle);
+    Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
     auto it = mappings.direct.find(handle);
     if (it == mappings.direct.end())
       continue;
@@ -410,11 +415,12 @@ transform::TransformState::replacePayloadOp(Operation *op,
 LogicalResult
 transform::TransformState::replacePayloadValue(Value value, Value replacement) {
   SmallVector<Value> valueHandles;
-  if (failed(getHandlesForPayloadValue(value, valueHandles)))
+  if (failed(getHandlesForPayloadValue(value, valueHandles,
+                                       /*includeOutOfScope=*/true)))
     return failure();
 
   for (Value handle : valueHandles) {
-    Mappings &mappings = getMapping(handle);
+    Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
     dropMappingEntry(mappings.reverseValues, value, handle);
 
     // If replacing with null, that is erasing the mapping, drop the mapping
@@ -764,7 +770,7 @@ checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
 
 void transform::TransformState::compactOpHandles() {
   for (Value handle : opHandlesToCompact) {
-    Mappings &mappings = getMapping(handle);
+    Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
     llvm::erase_value(mappings.direct[handle], nullptr);
   }
   opHandlesToCompact.clear();
@@ -1346,7 +1352,8 @@ void transform::TrackingListener::notifyOperationReplaced(
 
   // Replace op handle.
   SmallVector<Value> opHandles;
-  if (failed(getTransformState().getHandlesForPayloadOp(op, opHandles))) {
+  if (failed(getTransformState().getHandlesForPayloadOp(
+          op, opHandles, /*includeOutOfScope=*/true))) {
     // Op is not tracked.
     return;
   }
index 062ec6e..992f786 100644 (file)
@@ -131,11 +131,47 @@ transform.sequence failures(propagate) {
   transform.apply_patterns to %0 {
     transform.apply_patterns.transform.test_patterns
   } : !transform.any_op
+  // No marker should be printed.
   transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op
 }
 
 // -----
 
+// CHECK-LABEL: func @erase_tracked_op_in_named_sequence()
+//       CHECK:   "test.container"() ({
+//  CHECK-NEXT:   ^bb0:
+//  CHECK-NEXT:   }) : () -> ()
+module {
+  func.func @erase_tracked_op_in_named_sequence() {
+    "test.container"() ({
+      // expected-remark @below {{matched op}}
+      %0 = "test.erase_op"() {replace_with_new_op = "test.foo"} : () -> (i32)
+    }) : () -> ()
+    return
+  }
+
+  module attributes { transform.with_named_sequence } {
+    transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> () {
+      transform.apply_patterns to %arg0 {
+        transform.apply_patterns.transform.test_patterns
+      } : !transform.any_op
+      transform.yield
+    }
+
+    transform.sequence failures(propagate) {
+    ^bb1(%arg1: !transform.any_op):
+      %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+      %1 = transform.structured.match ops{["test.erase_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+      transform.test_print_remark_at_operand %1, "matched op" : !transform.any_op
+      include @foo failures(propagate) (%0) : (!transform.any_op) -> ()
+      // No marker should be printed.
+      transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op
+    }
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @canonicalization(
 //       CHECK:   %[[c5:.*]] = arith.constant 5 : index
 //       CHECK:   return %[[c5]]