[mlir] harden expensive-checks mode against ops with repeated operands
author Alex Zinenko <zinenko@google.com>
Fri, 26 May 2023 15:50:59 +0000 (15:50 +0000)
committer Alex Zinenko <zinenko@google.com>
Tue, 30 May 2023 07:53:39 +0000 (07:53 +0000)
Transform operations may indicate that they may accept and consume
several handles pointing to the same or nested payload entities. The
initial implementation of the expensive-checks mode was simply ignoring
such cases as consuming the second handle would fail the check after the
first handle invalidated it by consuming the same payload. Additional
checks had been added since then, which could now trigger assertions in
the expensive-checks module itself (instead of or in addition to
use-after-free assertions down the road), specifically because the
payload associations for invalidated handles are removed from the state
to enable other kinds of checking.

Rework the handling of transform operations with repeated handles so
use-after-consume is still reported properly if the consumption happened
by a preceding operation, as opposed to a preceding operand of the
same operation, which is still (correctly) ignored if the op requests that.

Depends on: D151560

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D151569

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/test/Dialect/Transform/expensive-checks.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

index 4c07791..fc1ffeb 100644 (file)
@@ -153,6 +153,10 @@ private:
   /// values in the payload IR. Also works for reverse mappings.
   using ValueMapping = DenseMap<Value, SmallVector<Value>>;
 
+  /// Mapping between a Value in the transform IR and an error message that
+  /// should be emitted when the value is used.
+  using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;
+
   /// The bidirectional mappings between transform IR values and payload IR
   /// operations, and the mapping between transform IR values and parameters.
   struct Mappings {
@@ -567,26 +571,85 @@ private:
   /// handle.
   LogicalResult replacePayloadValue(Value value, Value replacement);
 
-  /// If the operand is a handle consumed by the operation, i.e. has the "free"
-  /// memory effect associated with it, identifies other handles that are
-  /// pointing to payload IR operations nested in the operations pointed to by
-  /// the consumed handle. Marks all such handles as invalidated to trigger
-  /// errors if they are used. If `throughValue` is passed, record the fact that
-  /// an op handle was invalidated because a value handle associated with
-  /// results of the payload op or its block arguments was invalidated.
+  /// Records handle invalidation reporters into `newlyInvalidated`.
+  /// Specifically,
+  ///  - `handle` is the op operand that consumes the handle,
+  ///  - `potentialAncestors` is a list of ancestors of the payload operation
+  ///     that the consumed handle is associated with, including itself,
+  ///  - `throughValue` is the payload value the handle to which is consumed,
+  ///     when it is the case, null when the operation handle is consumed
+  ///     directly.
+  /// Iterates over all known operation and value handles and records reporters
+  /// for any potential future use of `handle` or any other handle that is
+  /// invalidated by its consumption, i.e., any handle pointing to any payload
+  /// IR entity (operation or value) associated with the same payload IR entity
+  /// as the consumed handle, or any nested payload IR entity. If
+  /// `potentialAncestors` is empty, records the reporter anyway. Does not
+  /// override existing reporters. This must remain a const method so it doesn't
+  /// inadvertently mutate `invalidatedHandles` too early.
   void recordOpHandleInvalidation(OpOperand &consumingHandle,
                                   ArrayRef<Operation *> potentialAncestors,
-                                  Value throughValue = nullptr);
-  void recordOpHandleInvalidationOne(OpOperand &handle,
-                                     ArrayRef<Operation *> potentialAncestors,
-                                     Operation *payloadOp, Value otherHandle,
-                                     Value throughValue = nullptr);
-
+                                  Value throughValue,
+                                  InvalidatedHandleMap &newlyInvalidated) const;
+
+  /// Records handle invalidation reporters into `newlyInvalidated`.
+  /// Specifically,
+  ///  - `consumingHandle` is the op operand that consumes the handle,
+  ///  - `potentialAncestors` is a list of ancestors of the payload operation
+  ///     that the consumed handle is associated with, including itself,
+  ///  - `payloadOp` is the operation itself,
+  ///  - `otherHandle` is another handle that may be associated with the
+  ///     affected payload operations,
+  ///  - `throughValue` is the payload value the handle to which is consumed,
+  ///     when it is the case, null when the operation handle is consumed
+  ///     directly.
+  /// Looks at the payload operations associated with `otherHandle` and if any
+  /// of these operations has an ancestor (or is itself) listed in
+  /// `potentialAncestors`, records the error message describing the use of the
+  /// invalidated handle. Does nothing if `otherHandle` already has a reporter
+  /// associated with it. This must remain a const method so it doesn't
+  /// inadvertently mutate `invalidatedHandles` too early.
+  void recordOpHandleInvalidationOne(
+      OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
+      Operation *payloadOp, Value otherHandle, Value throughValue,
+      InvalidatedHandleMap &newlyInvalidated) const;
+
+  /// Records handle invalidation reporters into `newlyInvalidated`.
+  /// Specifically,
+  ///  - `opHandle` is the op operand that consumes the handle;
+  ///  - `potentialAncestors` is a list of ancestors of the payload operation
+  ///     that the consumed handle is associated with, including itself;
+  ///  - `payloadValue` is the value defined by the operation associated with
+  ///     the consuming handle as either op result or block argument;
+  ///  - `valueHandle` is another that may be associated with the payload value.
+  /// Looks at the payload values associated with `valueHandle` and if any of
+  /// these values is defined, as op result or block argument, by an operation
+  /// whose ancestor (or the operation itself) is listed in
+  /// `potentialAncestors`, records the error message describing the use of the
+  /// invalidated handle. Does nothing if `valueHandle` already has a reporter
+  /// associated with it. This must remain a const method so it doesn't
+  /// inadvertently mutate `invalidatedHandles` too early.
   void recordValueHandleInvalidationByOpHandleOne(
       OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
-      Value payloadValue, Value valueHandle);
-
-  void recordValueHandleInvalidation(OpOperand &valueHandle);
+      Value payloadValue, Value valueHandle,
+      InvalidatedHandleMap &newlyInvalidated) const;
+
+  /// Records handle invalidation reporters into `newlyInvalidated`.
+  /// Specifically,
+  ///  - `valueHandle` is the op operand that consumes the handle,
+  ///  - `throughValue` is the payload value the handle to which is consumed,
+  ///     when it is the case, null when the operation handle is consumed
+  ///     directly.
+  /// Iterates over all known operation and value handles and records reporters
+  /// for any potential future use of `valueHandle` or any other handle that is
+  /// invalidated by its consumption, i.e., any handle pointing to any payload
+  /// IR entity (operation or value) associated with the same payload IR entity
+  /// as the consumed handle, or any nested payload IR entity. Does not override
+  /// existing reporters. This must remain a const method so it doesn't
+  /// inadvertently mutate `invalidatedHandles` too early.
+  void
+  recordValueHandleInvalidation(OpOperand &valueHandle,
+                                InvalidatedHandleMap &newlyInvalidated) const;
 
   /// Checks that the operation does not use invalidated handles as operands.
   /// Reports errors and returns failure if it does. Otherwise, invalidates the
@@ -596,6 +659,13 @@ private:
   LogicalResult
   checkAndRecordHandleInvalidation(TransformOpInterface transform);
 
+  /// Implementation of the checkAndRecordHandleInvalidation. This must remain a
+  /// const method so it doesn't inadvertently mutate `invalidatedHandles` too
+  /// early.
+  LogicalResult checkAndRecordHandleInvalidationImpl(
+      transform::TransformOpInterface transform,
+      transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const;
+
   /// Remove all nullptrs from op handles that were added by `replacePayloadOp`.
   void compactOpHandles();
 
@@ -628,7 +698,7 @@ private:
   /// describe when the handles were invalidated. Calling such a function emits
   /// a user-visible diagnostic with an additional note pointing to the given
   /// location.
-  DenseMap<Value, std::function<void(Location)>> invalidatedHandles;
+  InvalidatedHandleMap invalidatedHandles;
 
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   /// A stack of nested regions that are being processed in the transform IR.
index 85535c7..b1dc668 100644 (file)
@@ -431,10 +431,13 @@ transform::TransformState::replacePayloadValue(Value value, Value replacement) {
 
 void transform::TransformState::recordOpHandleInvalidationOne(
     OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
-    Operation *payloadOp, Value otherHandle, Value throughValue) {
+    Operation *payloadOp, Value otherHandle, Value throughValue,
+    transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
   // If the op is associated with invalidated handle, skip the check as it
-  // may be reading invalid IR.
-  if (invalidatedHandles.count(otherHandle))
+  // may be reading invalid IR. This also ensures we report the first
+  // invalidation and not the last one.
+  if (invalidatedHandles.count(otherHandle) ||
+      newlyInvalidated.count(otherHandle))
     return;
 
   FULL_LDBG("--recordOpHandleInvalidationOne\n");
@@ -467,9 +470,9 @@ void transform::TransformState::recordOpHandleInvalidationOne(
     Location opLoc = payloadOp->getLoc();
     std::optional<Location> throughValueLoc =
         throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
-    invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
-                                       otherHandle,
-                                       throughValueLoc](Location currentLoc) {
+    newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
+                                     otherHandle,
+                                     throughValueLoc](Location currentLoc) {
       InFlightDiagnostic diag = emitError(currentLoc)
                                 << "op uses a handle invalidated by a "
                                    "previously executed transform op";
@@ -490,11 +493,14 @@ void transform::TransformState::recordOpHandleInvalidationOne(
 }
 
 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
-    OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
-    Value payloadValue, Value valueHandle) {
+    OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
+    Value payloadValue, Value valueHandle,
+    transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
   // If the op is associated with invalidated handle, skip the check as it
-  // may be reading invalid IR.
-  if (invalidatedHandles.count(valueHandle))
+  // may be reading invalid IR. This also ensures we report the first
+  // invalidation and not the last one.
+  if (invalidatedHandles.count(valueHandle) ||
+      newlyInvalidated.count(valueHandle))
     return;
 
   for (Operation *ancestor : potentialAncestors) {
@@ -517,12 +523,12 @@ void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
     if (!ancestor->isAncestor(definingOp))
       continue;
 
-    Operation *owner = consumingHandle.getOwner();
-    unsigned operandNo = consumingHandle.getOperandNumber();
+    Operation *owner = opHandle.getOwner();
+    unsigned operandNo = opHandle.getOperandNumber();
     Location ancestorLoc = ancestor->getLoc();
     Location opLoc = definingOp->getLoc();
     Location valueLoc = payloadValue.getLoc();
-    invalidatedHandles[valueHandle] =
+    newlyInvalidated[valueHandle] =
         [valueHandle, owner, operandNo, resultNo, argumentNo, blockNo, regionNo,
          ancestorLoc, opLoc, valueLoc](Location currentLoc) {
           InFlightDiagnostic diag = emitError(currentLoc)
@@ -551,7 +557,8 @@ void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
 
 void transform::TransformState::recordOpHandleInvalidation(
     OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
-    Value throughValue) {
+    Value throughValue,
+    transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
 
   if (potentialAncestors.empty()) {
     DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
@@ -561,7 +568,7 @@ void transform::TransformState::recordOpHandleInvalidation(
 
     Operation *owner = handle.getOwner();
     unsigned operandNo = handle.getOperandNumber();
-    invalidatedHandles[handle.get()] = [owner, operandNo](Location currentLoc) {
+    newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
       InFlightDiagnostic diag = emitError(currentLoc)
                                 << "op uses a handle associated with empty "
                                    "payload and invalidated by a "
@@ -580,14 +587,16 @@ void transform::TransformState::recordOpHandleInvalidation(
   // number of IR objects (operations and values). Alternatively, we could walk
   // the IR nested in each payload op associated with the given handle and look
   // for handles associated with each operation and value.
-  for (const Mappings &mapping : llvm::make_second_range(mappings)) {
+  for (const transform::TransformState::Mappings &mapping :
+       llvm::make_second_range(mappings)) {
     // Go over all op handle mappings and mark as invalidated any handle
     // pointing to any of the payload ops associated with the given handle or
     // any op nested in them.
     for (const auto &[payloadOp, otherHandles] : mapping.reverse) {
       for (Value otherHandle : otherHandles)
         recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
-                                      otherHandle, throughValue);
+                                      otherHandle, throughValue,
+                                      newlyInvalidated);
     }
     // Go over all value handle mappings and mark as invalidated any handle
     // pointing to any result of the payload op associated with the given handle
@@ -597,13 +606,15 @@ void transform::TransformState::recordOpHandleInvalidation(
     for (const auto &[payloadValue, valueHandles] : mapping.reverseValues) {
       for (Value valueHandle : valueHandles)
         recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
-                                                   payloadValue, valueHandle);
+                                                   payloadValue, valueHandle,
+                                                   newlyInvalidated);
     }
   }
 }
 
 void transform::TransformState::recordValueHandleInvalidation(
-    OpOperand &valueHandle) {
+    OpOperand &valueHandle,
+    transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
   // Invalidate other handles to the same value.
   for (Value payloadValue : getPayloadValues(valueHandle.get())) {
     SmallVector<Value> otherValueHandles;
@@ -612,8 +623,8 @@ void transform::TransformState::recordValueHandleInvalidation(
       Operation *owner = valueHandle.getOwner();
       unsigned operandNo = valueHandle.getOperandNumber();
       Location valueLoc = payloadValue.getLoc();
-      invalidatedHandles[otherHandle] = [otherHandle, owner, operandNo,
-                                         valueLoc](Location currentLoc) {
+      newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
+                                       valueLoc](Location currentLoc) {
         InFlightDiagnostic diag = emitError(currentLoc)
                                   << "op uses a handle invalidated by a "
                                      "previously executed transform op";
@@ -629,17 +640,24 @@ void transform::TransformState::recordValueHandleInvalidation(
 
     if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
       Operation *payloadOp = opResult.getOwner();
-      recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue);
+      recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
+                                 newlyInvalidated);
     } else {
       auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
       for (Operation &payloadOp : *arg.getOwner())
-        recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue);
+        recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
+                                   newlyInvalidated);
     }
   }
 }
 
-LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
-    TransformOpInterface transform) {
+/// Checks that the operation does not use invalidated handles as operands.
+/// Reports errors and returns failure if it does. Otherwise, invalidates the
+/// handles consumed by the operation as well as any handles pointing to payload
+/// IR operations nested in the operations associated with the consumed handles.
+LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
+    transform::TransformOpInterface transform,
+    transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
   FULL_LDBG("--Start checkAndRecordHandleInvalidation\n");
   auto memoryEffectsIface =
       cast<MemoryEffectOpInterface>(transform.getOperation());
@@ -651,13 +669,23 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
     DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
       (DBGS() << "----iterate on handle: " << target.get() << "\n");
     });
-    // If the operand uses an invalidated handle, report it.
+    // If the operand uses an invalidated handle, report it. If the operation
+    // allows handles to point to repeated payload operations, only report
+    // pre-existing invalidation errors. Otherwise, also report invalidations
+    // caused by the current transform operation affecting its other operands.
     auto it = invalidatedHandles.find(target.get());
-    if (!transform.allowsRepeatedHandleOperands() &&
-        it != invalidatedHandles.end()) {
-      FULL_LDBG("--End checkAndRecordHandleInvalidation -> FAILURE\n");
+    auto nit = newlyInvalidated.find(target.get());
+    if (it != invalidatedHandles.end()) {
+      FULL_LDBG("--End checkAndRecordHandleInvalidation, found already "
+                "invalidated -> FAILURE\n");
       return it->getSecond()(transform->getLoc()), failure();
     }
+    if (!transform.allowsRepeatedHandleOperands() &&
+        nit != newlyInvalidated.end()) {
+      FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly "
+                "invalidated (by this op) -> FAILURE\n");
+      return nit->getSecond()(transform->getLoc()), failure();
+    }
 
     // Invalidate handles pointing to the operations nested in the operation
     // associated with the handle consumed by this operation.
@@ -666,15 +694,18 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
              effect.getValue() == target.get();
     };
     if (llvm::any_of(effects, consumesTarget)) {
-      FULL_LDBG("----found consume effect -> SKIP\n");
-      if (llvm::isa<TransformHandleTypeInterface>(target.get().getType())) {
+      FULL_LDBG("----found consume effect\n");
+      if (llvm::isa<transform::TransformHandleTypeInterface>(
+              target.get().getType())) {
         FULL_LDBG("----recordOpHandleInvalidation\n");
-        ArrayRef<Operation *> payloadOps = getPayloadOpsView(target.get());
-        recordOpHandleInvalidation(target, payloadOps);
-      } else if (llvm::isa<TransformValueHandleTypeInterface>(
+        SmallVector<Operation *> payloadOps =
+            llvm::to_vector(getPayloadOps(target.get()));
+        recordOpHandleInvalidation(target, payloadOps, nullptr,
+                                   newlyInvalidated);
+      } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
                      target.get().getType())) {
         FULL_LDBG("----recordValueHandleInvalidation\n");
-        recordValueHandleInvalidation(target);
+        recordValueHandleInvalidation(target, newlyInvalidated);
       } else {
         FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
       }
@@ -687,6 +718,16 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
   return success();
 }
 
+LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
+    transform::TransformOpInterface transform) {
+  InvalidatedHandleMap newlyInvalidated;
+  LogicalResult checkResult =
+      checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
+  invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
+                            std::make_move_iterator(newlyInvalidated.end()));
+  return checkResult;
+}
+
 template <typename T>
 DiagnosedSilenceableFailure
 checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
index 4cbaad8..e35c179 100644 (file)
@@ -342,3 +342,25 @@ transform.sequence failures(propagate) {
   // expected-error @below {{uses a handle associated with empty payload and invalidated by a previously executed transform op}}
   transform.test_print_remark_at_operand %0, "remark" : !transform.any_op
 }
+
+// -----
+
+// Make sure we properly report a use-after-consume error when repeated handles
+// are allowed in the consuming op. We still want to report handles consumed by
+// _previous_ operations, just not by this one. To bypass the quick static check
+// of repeated consumption, create a handle to the transform operation and
+// invalidate the handle to the root module thus invalidating all other handles.
+
+// expected-note @below {{ancestor payload op}}
+module {
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    // expected-note @below {{handle to invalidated ops}}
+    // expected-note @below {{nested payload op}}
+    %0 = transform.test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+    // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}}
+    transform.test_consume_operand %arg0 : !transform.any_op
+    // expected-error @below {{uses a handle invalidated by a previously executed transform op}}
+    transform.test_consume_operand %0 { allow_repeated_handles } : !transform.any_op
+  }
+}
index 5bf488e..f3b6c19 100644 (file)
@@ -178,6 +178,10 @@ void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects(
   transform::onlyReadsPayload(effects);
 }
 
+bool mlir::test::TestConsumeOperand::allowsRepeatedHandleOperands() {
+  return getAllowRepeatedHandles();
+}
+
 DiagnosedSilenceableFailure
 mlir::test::TestConsumeOperand::apply(transform::TransformResults &results,
                                       transform::TransformState &state) {
index b1129ea..c02e2d9 100644 (file)
@@ -97,11 +97,12 @@ def TestProduceValueHandleToArgumentOfParentBlock
 }
 
 def TestConsumeOperand : Op<Transform_Dialect, "test_consume_operand",
-     [DeclareOpInterfaceMethods<TransformOpInterface>,
+     [DeclareOpInterfaceMethods<TransformOpInterface, ["allowsRepeatedHandleOperands"]>,
       DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let arguments = (ins
     Transform_AnyHandleOrParamType:$operand,
-    Optional<TransformHandleTypeInterface>:$second_operand);
+    Optional<TransformHandleTypeInterface>:$second_operand,
+    UnitAttr:$allow_repeated_handles);
   let assemblyFormat = 
       "$operand (`,` $second_operand^)? attr-dict `:` type($operand)"
       "(`,` type($second_operand)^)?";