[mlir][transform] Remove redundant handle check in `replacePayload...`
authorMatthias Springer <me@m-sp.org>
Mon, 26 Jun 2023 15:49:31 +0000 (17:49 +0200)
committerMatthias Springer <me@m-sp.org>
Mon, 26 Jun 2023 15:59:06 +0000 (17:59 +0200)
Differential Revision: https://reviews.llvm.org/D153766

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

index 21ddd71..20f9b21 100644 (file)
@@ -555,7 +555,8 @@ private:
                           ArrayRef<Operation *> payloadOperations);
 
   /// Replaces the given payload op with another op. If the replacement op is
-  /// null, removes the association of the payload op with its handle.
+  /// null, removes the association of the payload op with its handle. Returns
+  /// failure if the op is not associated with any handle.
   ///
   /// Note: This function does not update value handles. None of the original
   /// op's results are allowed to be mapped to any value handle.
@@ -563,7 +564,7 @@ private:
 
   /// Replaces the given payload value with another value. If the replacement
   /// value is null, removes the association of the payload value with its
-  /// handle.
+  /// handle. Returns failure if the value is not associated with any handle.
   LogicalResult replacePayloadValue(Value value, Value replacement);
 
   /// Records handle invalidation reporters into `newlyInvalidated`.
index ce6ec5a..5bd9516 100644 (file)
@@ -338,14 +338,7 @@ void transform::TransformState::forgetValueMapping(
 LogicalResult
 transform::TransformState::replacePayloadOp(Operation *op,
                                             Operation *replacement) {
-  // Drop the mapping between the op and all handles that point to it. Don't
-  // care if there are on such handles.
-  SmallVector<Value> opHandles;
-  (void)getHandlesForPayloadOp(op, opHandles);
-  for (Value handle : opHandles) {
-    Mappings &mappings = getMapping(handle);
-    dropMappingEntry(mappings.reverse, op, handle);
-  }
+  // TODO: consider invalidating the handles to nested objects here.
 
 #ifndef NDEBUG
   for (Value opResult : op->getResults()) {
@@ -355,23 +348,29 @@ transform::TransformState::replacePayloadOp(Operation *op,
   }
 #endif // NDEBUG
 
+  // Drop the mapping between the op and all handles that point to it. Fail if
+  // there are no handles.
+  SmallVector<Value> opHandles;
+  if (failed(getHandlesForPayloadOp(op, opHandles)))
+    return failure();
+  for (Value handle : opHandles) {
+    Mappings &mappings = getMapping(handle);
+    dropMappingEntry(mappings.reverse, op, handle);
+  }
+
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   if (options.getExpensiveChecksEnabled()) {
     auto it = cachedNames.find(op);
     assert(it != cachedNames.end() && "entry not found");
     assert(it->second == op->getName() && "operation name mismatch");
     cachedNames.erase(it);
-  }
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
-
-  // TODO: consider invalidating the handles to nested objects here.
-
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-  if (replacement && options.getExpensiveChecksEnabled()) {
-    auto insertion = cachedNames.insert({replacement, replacement->getName()});
-    if (!insertion.second) {
-      assert(insertion.first->second == replacement->getName() &&
-             "operation is already cached with a different name");
+    if (replacement) {
+      auto insertion =
+          cachedNames.insert({replacement, replacement->getName()});
+      if (!insertion.second) {
+        assert(insertion.first->second == replacement->getName() &&
+               "operation is already cached with a different name");
+      }
     }
   }
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -411,7 +410,8 @@ transform::TransformState::replacePayloadOp(Operation *op,
 LogicalResult
 transform::TransformState::replacePayloadValue(Value value, Value replacement) {
   SmallVector<Value> valueHandles;
-  (void)getHandlesForPayloadValue(value, valueHandles);
+  if (failed(getHandlesForPayloadValue(value, valueHandles)))
+    return failure();
 
   for (Value handle : valueHandles) {
     Mappings &mappings = getMapping(handle);
@@ -537,30 +537,30 @@ void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
     Location ancestorLoc = ancestor->getLoc();
     Location opLoc = definingOp->getLoc();
     Location valueLoc = payloadValue.getLoc();
-    newlyInvalidated[valueHandle] =
-        [valueHandle, owner, operandNo, resultNo, argumentNo, blockNo, regionNo,
-         ancestorLoc, opLoc, valueLoc](Location currentLoc) {
-          InFlightDiagnostic diag = emitError(currentLoc)
-                                    << "op uses a handle invalidated by a "
-                                       "previously executed transform op";
-          diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
-          diag.attachNote(owner->getLoc())
-              << "invalidated by this transform op that consumes its operand #"
-              << operandNo
-              << " and invalidates all handles to payload IR entities "
-                 "associated with this operand and entities nested in them";
-          diag.attachNote(ancestorLoc)
-              << "ancestor op associated with the consumed handle";
-          if (resultNo) {
-            diag.attachNote(opLoc)
-                << "op defining the value as result #" << *resultNo;
-          } else {
-            diag.attachNote(opLoc)
-                << "op defining the value as block argument #" << argumentNo
-                << " of block #" << blockNo << " in region #" << regionNo;
-          }
-          diag.attachNote(valueLoc) << "payload value";
-        };
+    newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
+                                     argumentNo, blockNo, regionNo, ancestorLoc,
+                                     opLoc, valueLoc](Location currentLoc) {
+      InFlightDiagnostic diag = emitError(currentLoc)
+                                << "op uses a handle invalidated by a "
+                                   "previously executed transform op";
+      diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
+      diag.attachNote(owner->getLoc())
+          << "invalidated by this transform op that consumes its operand #"
+          << operandNo
+          << " and invalidates all handles to payload IR entities "
+             "associated with this operand and entities nested in them";
+      diag.attachNote(ancestorLoc)
+          << "ancestor op associated with the consumed handle";
+      if (resultNo) {
+        diag.attachNote(opLoc)
+            << "op defining the value as result #" << *resultNo;
+      } else {
+        diag.attachNote(opLoc)
+            << "op defining the value as block argument #" << argumentNo
+            << " of block #" << blockNo << " in region #" << regionNo;
+      }
+      diag.attachNote(valueLoc) << "payload value";
+    };
   }
 }
 
@@ -1064,10 +1064,6 @@ transform::TransformState::Extension::~Extension() = default;
 LogicalResult
 transform::TransformState::Extension::replacePayloadOp(Operation *op,
                                                        Operation *replacement) {
-  SmallVector<Value> handles;
-  if (failed(state.getHandlesForPayloadOp(op, handles)))
-    return failure();
-
   // TODO: we may need to invalidate handles to operations and values nested in
   // the operation being replaced.
   return state.replacePayloadOp(op, replacement);
@@ -1076,10 +1072,6 @@ transform::TransformState::Extension::replacePayloadOp(Operation *op,
 LogicalResult
 transform::TransformState::Extension::replacePayloadValue(Value value,
                                                           Value replacement) {
-  SmallVector<Value> handles;
-  if (failed(state.getHandlesForPayloadValue(value, handles)))
-    return failure();
-
   return state.replacePayloadValue(value, replacement);
 }