[mlir] teach expensive-checks transform mode about empty handle
authorAlex Zinenko <zinenko@google.com>
Fri, 26 May 2023 13:07:13 +0000 (13:07 +0000)
committerAlex Zinenko <zinenko@google.com>
Fri, 26 May 2023 16:01:09 +0000 (16:01 +0000)
The transform dialect interpreter features the expensive-checks mode
that acts as an embedded sanitizer to track use-after-consume of
transform handles. Its logic is based on the relations between payload
operations, which made it silently ignore empty handles that are
consumed. Also catch and report this case because the remaining code may
hit an assertion on attempting to access a consumed handle (that is
removed from the mapping).

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D151560

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/test/Dialect/Transform/expensive-checks.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

index 695b4a3..85535c7 100644 (file)
@@ -443,6 +443,9 @@ void transform::TransformState::recordOpHandleInvalidationOne(
       llvm::interleaveComma(potentialAncestors, DBGS() << "--ancestors: ",
                             [](Operation *op) { llvm::dbgs() << *op; });
       llvm::dbgs() << "\n");
+
+  Operation *owner = consumingHandle.getOwner();
+  unsigned operandNo = consumingHandle.getOperandNumber();
   for (Operation *ancestor : potentialAncestors) {
     // clang-format off
     DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, 
@@ -462,8 +465,6 @@ void transform::TransformState::recordOpHandleInvalidationOne(
     // deleted before the lambda gets called.
     Location ancestorLoc = ancestor->getLoc();
     Location opLoc = payloadOp->getLoc();
-    Operation *owner = consumingHandle.getOwner();
-    unsigned operandNo = consumingHandle.getOperandNumber();
     std::optional<Location> throughValueLoc =
         throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
     invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
@@ -551,6 +552,27 @@ void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
 void transform::TransformState::recordOpHandleInvalidation(
     OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
     Value throughValue) {
+
+  if (potentialAncestors.empty()) {
+    DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
+      (DBGS() << "----recording invalidation for empty handle: " << handle.get()
+              << "\n");
+    });
+
+    Operation *owner = handle.getOwner();
+    unsigned operandNo = handle.getOperandNumber();
+    invalidatedHandles[handle.get()] = [owner, operandNo](Location currentLoc) {
+      InFlightDiagnostic diag = emitError(currentLoc)
+                                << "op uses a handle associated with empty "
+                                   "payload and invalidated by a "
+                                   "previously executed transform op";
+      diag.attachNote(owner->getLoc())
+          << "invalidated by this transform op that consumes its operand #"
+          << operandNo;
+    };
+    return;
+  }
+
   // Iterate over the mapping and invalidate aliasing handles. This is quite
   // expensive and only necessary for error reporting in case of transform
   // dialect misuse with dangling handles. Iteration over the handles is based
index 3a7a3df..4cbaad8 100644 (file)
@@ -331,3 +331,14 @@ transform.sequence failures(propagate) {
   test_consume_operand %3 : !transform.any_value
   test_consume_operand %2 : !transform.any_op
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.test_produce_empty_payload : !transform.any_op
+  // expected-note @below {{invalidated by this transform op that consumes its operand #0}}
+  transform.test_consume_operand %0 : !transform.any_op
+  // expected-error @below {{uses a handle associated with empty payload and invalidated by a previously executed transform op}}
+  transform.test_print_remark_at_operand %0, "remark" : !transform.any_op
+}
index 2b23b88..5bf488e 100644 (file)
@@ -627,6 +627,12 @@ DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply(
   return DiagnosedSilenceableFailure::success();
 }
 
+DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  results.set(cast<OpResult>(getOut()), {});
+  return DiagnosedSilenceableFailure::success();
+}
+
 void mlir::test::TestProduceNullParamOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   transform::producesHandle(getOut(), effects);
index 3e81125..b1129ea 100644 (file)
@@ -427,6 +427,15 @@ def TestProduceNullPayloadOp
   let cppNamespace = "::mlir::test";
 }
 
+def TestProduceEmptyPayloadOp
+  : Op<Transform_Dialect, "test_produce_empty_payload",
+      [DeclareOpInterfaceMethods<TransformOpInterface>,
+       MemoryEffectsOpInterface, FunctionalStyleTransformOpTrait]> {
+  let results = (outs TransformHandleTypeInterface:$out);
+  let assemblyFormat = "attr-dict `:` type($out)";
+  let cppNamespace = "::mlir::test";
+}
+
 def TestProduceNullParamOp
   : Op<Transform_Dialect, "test_produce_null_param",
       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,