From 1d9a1139fd2c29189f2e2b9b149dfbd1a6b931bb Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 26 May 2023 15:50:59 +0000 Subject: [PATCH] [mlir] harden expensive-checks mode against ops with repeated operands Transform operations may indicate that they may accept and consume several handles pointing to the same or nested payload entities. The initial implementation of the expensive-checks mode was simply ignoring such cases as consuming the second handle would fail the check after the first handle invalidated it by consuming the same payload. Additional checks had been added since then, which could now trigger assertions in the expensive-checks module itself (instead of or in addition to use-after-free assertions down the road), specifically because the payload associations for invalidated handles is removed from the state to enable other kinds of checking. Rework the handling of transform operations with repeated handles so use-after-consume is still reported properly if the consumption happened by a preceding operation, as opposed to the a preceding operand of the same operation that is still (corretly) ignored if the op requests that. Depends on: D151560 Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D151569 --- .../Dialect/Transform/IR/TransformInterfaces.h | 104 +++++++++++++++---- .../Dialect/Transform/IR/TransformInterfaces.cpp | 111 ++++++++++++++------- mlir/test/Dialect/Transform/expensive-checks.mlir | 22 ++++ .../Transform/TestTransformDialectExtension.cpp | 4 + .../Transform/TestTransformDialectExtension.td | 5 +- 5 files changed, 192 insertions(+), 54 deletions(-) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index 4c07791..fc1ffeb 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -153,6 +153,10 @@ private: /// values in the payload IR. Also works for reverse mappings. using ValueMapping = DenseMap>; + /// Mapping between a Value in the transform IR and an error message that + /// should be emitted when the value is used. + using InvalidatedHandleMap = DenseMap>; + /// The bidirectional mappings between transform IR values and payload IR /// operations, and the mapping between transform IR values and parameters. struct Mappings { @@ -567,26 +571,85 @@ private: /// handle. LogicalResult replacePayloadValue(Value value, Value replacement); - /// If the operand is a handle consumed by the operation, i.e. has the "free" - /// memory effect associated with it, identifies other handles that are - /// pointing to payload IR operations nested in the operations pointed to by - /// the consumed handle. Marks all such handles as invalidated to trigger - /// errors if they are used. If `throughValue` is passed, record the fact that - /// an op handle was invalidated because a value handle associated with - /// results of the payload op or its block arguments was invalidated. + /// Records handle invalidation reporters into `newlyInvalidated`. + /// Specifically, + /// - `handle` is the op operand that consumes the handle, + /// - `potentialAncestors` is a list of ancestors of the payload operation + /// that the consumed handle is associated with, including itself, + /// - `throughValue` is the payload value the handle to which is consumed, + /// when it is the case, null when the operation handle is consumed + /// directly. + /// Iterates over all known operation and value handles and records reporters + /// for any potential future use of `handle` or any other handle that is + /// invalidated by its consumption, i.e., any handle pointing to any payload + /// IR entity (operation or value) associated with the same payload IR entity + /// as the consumed handle, or any nested payload IR entity. If + /// `potentialAncestors` is empty, records the reporter anyway. Does not + /// override existing reporters. This must remain a const method so it doesn't + /// inadvertently mutate `invalidatedHandles` too early. void recordOpHandleInvalidation(OpOperand &consumingHandle, ArrayRef potentialAncestors, - Value throughValue = nullptr); - void recordOpHandleInvalidationOne(OpOperand &handle, - ArrayRef potentialAncestors, - Operation *payloadOp, Value otherHandle, - Value throughValue = nullptr); - + Value throughValue, + InvalidatedHandleMap &newlyInvalidated) const; + + /// Records handle invalidation reporters into `newlyInvalidated`. + /// Specifically, + /// - `consumingHandle` is the op operand that consumes the handle, + /// - `potentialAncestors` is a list of ancestors of the payload operation + /// that the consumed handle is associated with, including itself, + /// - `payloadOp` is the operation itself, + /// - `otherHandle` is another that may be associated with the affected + /// payload operations + /// - `throughValue` is the payload value the handle to which is consumed, + /// when it is the case, null when the operation handle is consumed + /// directly. + /// Looks at the payload opreations associated with `otherHandle` and if any + /// of these operations has an ancestor (or is itself) listed in + /// `potentialAncestors`, records the error message describing the use of the + /// invalidated handle. Does nothing if `otherHandle` already has a reporter + /// associated with it. This must remain a const method so it doesn't + /// inadvertently mutate `invalidatedHandles` too early. + void recordOpHandleInvalidationOne( + OpOperand &consumingHandle, ArrayRef potentialAncestors, + Operation *payloadOp, Value otherHandle, Value throughValue, + InvalidatedHandleMap &newlyInvalidated) const; + + /// Records handle invalidation reporters into `newlyInvalidated`. + /// Specifically, + /// - `opHandle` is the op operand that consumes the handle; + /// - `potentialAncestors` is a list of ancestors of the payload operation + /// that the consumed handle is associated with, including itself; + /// - `payloadValue` is the value defined by the operation associated with + /// the consuming handle as either op result or block argument; + /// - `valueHandle` is another that may be associated with the payload value. + /// Looks at the payload values associated with `valueHandle` and if any of + /// these values is defined, as op result or block argument, by an operation + /// whose ancestor (or the operation itself) is listed in + /// `potentialAncestors`, records the error message describing the use of the + /// invalidated handle. Does nothing if `valueHandle` already has a reporter + /// associated with it. This must remain a const method so it doesn't + /// inadvertently mutate `invalidatedHandles` too early. void recordValueHandleInvalidationByOpHandleOne( OpOperand &opHandle, ArrayRef potentialAncestors, - Value payloadValue, Value valueHandle); - - void recordValueHandleInvalidation(OpOperand &valueHandle); + Value payloadValue, Value valueHandle, + InvalidatedHandleMap &newlyInvalidated) const; + + /// Records handle invalidation reporters into `newlyInvalidated`. + /// Specifically, + /// - `valueHandle` is the op operand that consumes the handle, + /// - `throughValue` is the payload value the handle to which is consumed, + /// when it is the case, null when the operation handle is consumed + /// directly. + /// Iterates over all known operation and value handles and records reporters + /// for any potential future use of `handle` or any other handle that is + /// invalidated by its consumption, i.e., any handle pointing to any payload + /// IR entity (operation or value) associated with the same payload IR entity + /// as the consumed handle, or any nested payload IR entity. Does not override + /// existing reporters. This must remain a const method so it doesn't + /// inadvertently mutate `invalidatedHandles` too early. + void + recordValueHandleInvalidation(OpOperand &valueHandle, + InvalidatedHandleMap &newlyInvalidated) const; /// Checks that the operation does not use invalidated handles as operands. /// Reports errors and returns failure if it does. Otherwise, invalidates the @@ -596,6 +659,13 @@ private: LogicalResult checkAndRecordHandleInvalidation(TransformOpInterface transform); + /// Implementation of the checkAndRecordHandleInvalidation. This must remain a + /// const method so it doesn't inadvertently mutate `invalidatedHandles` too + /// early. + LogicalResult checkAndRecordHandleInvalidationImpl( + transform::TransformOpInterface transform, + transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const; + /// Remove all nullptrs from op handles that were added by `replacePayloadOp`. void compactOpHandles(); @@ -628,7 +698,7 @@ private: /// describe when the handles were invalidated. Calling such a function emits /// a user-visible diagnostic with an additional note pointing to the given /// location. - DenseMap> invalidatedHandles; + InvalidatedHandleMap invalidatedHandles; #if LLVM_ENABLE_ABI_BREAKING_CHECKS /// A stack of nested regions that are being processed in the transform IR. diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index 85535c7..b1dc668 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -431,10 +431,13 @@ transform::TransformState::replacePayloadValue(Value value, Value replacement) { void transform::TransformState::recordOpHandleInvalidationOne( OpOperand &consumingHandle, ArrayRef potentialAncestors, - Operation *payloadOp, Value otherHandle, Value throughValue) { + Operation *payloadOp, Value otherHandle, Value throughValue, + transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { // If the op is associated with invalidated handle, skip the check as it - // may be reading invalid IR. - if (invalidatedHandles.count(otherHandle)) + // may be reading invalid IR. This also ensures we report the first + // invalidation and not the last one. + if (invalidatedHandles.count(otherHandle) || + newlyInvalidated.count(otherHandle)) return; FULL_LDBG("--recordOpHandleInvalidationOne\n"); @@ -467,9 +470,9 @@ void transform::TransformState::recordOpHandleInvalidationOne( Location opLoc = payloadOp->getLoc(); std::optional throughValueLoc = throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt; - invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, - otherHandle, - throughValueLoc](Location currentLoc) { + newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, + otherHandle, + throughValueLoc](Location currentLoc) { InFlightDiagnostic diag = emitError(currentLoc) << "op uses a handle invalidated by a " "previously executed transform op"; @@ -490,11 +493,14 @@ void transform::TransformState::recordOpHandleInvalidationOne( } void transform::TransformState::recordValueHandleInvalidationByOpHandleOne( - OpOperand &consumingHandle, ArrayRef potentialAncestors, - Value payloadValue, Value valueHandle) { + OpOperand &opHandle, ArrayRef potentialAncestors, + Value payloadValue, Value valueHandle, + transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { // If the op is associated with invalidated handle, skip the check as it - // may be reading invalid IR. - if (invalidatedHandles.count(valueHandle)) + // may be reading invalid IR. This also ensures we report the first + // invalidation and not the last one. + if (invalidatedHandles.count(valueHandle) || + newlyInvalidated.count(valueHandle)) return; for (Operation *ancestor : potentialAncestors) { @@ -517,12 +523,12 @@ void transform::TransformState::recordValueHandleInvalidationByOpHandleOne( if (!ancestor->isAncestor(definingOp)) continue; - Operation *owner = consumingHandle.getOwner(); - unsigned operandNo = consumingHandle.getOperandNumber(); + Operation *owner = opHandle.getOwner(); + unsigned operandNo = opHandle.getOperandNumber(); Location ancestorLoc = ancestor->getLoc(); Location opLoc = definingOp->getLoc(); Location valueLoc = payloadValue.getLoc(); - invalidatedHandles[valueHandle] = + newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo, argumentNo, blockNo, regionNo, ancestorLoc, opLoc, valueLoc](Location currentLoc) { InFlightDiagnostic diag = emitError(currentLoc) @@ -551,7 +557,8 @@ void transform::TransformState::recordValueHandleInvalidationByOpHandleOne( void transform::TransformState::recordOpHandleInvalidation( OpOperand &handle, ArrayRef potentialAncestors, - Value throughValue) { + Value throughValue, + transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { if (potentialAncestors.empty()) { DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { @@ -561,7 +568,7 @@ void transform::TransformState::recordOpHandleInvalidation( Operation *owner = handle.getOwner(); unsigned operandNo = handle.getOperandNumber(); - invalidatedHandles[handle.get()] = [owner, operandNo](Location currentLoc) { + newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) { InFlightDiagnostic diag = emitError(currentLoc) << "op uses a handle associated with empty " "payload and invalidated by a " @@ -580,14 +587,16 @@ void transform::TransformState::recordOpHandleInvalidation( // number of IR objects (operations and values). Alternatively, we could walk // the IR nested in each payload op associated with the given handle and look // for handles associated with each operation and value. - for (const Mappings &mapping : llvm::make_second_range(mappings)) { + for (const transform::TransformState::Mappings &mapping : + llvm::make_second_range(mappings)) { // Go over all op handle mappings and mark as invalidated any handle // pointing to any of the payload ops associated with the given handle or // any op nested in them. for (const auto &[payloadOp, otherHandles] : mapping.reverse) { for (Value otherHandle : otherHandles) recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp, - otherHandle, throughValue); + otherHandle, throughValue, + newlyInvalidated); } // Go over all value handle mappings and mark as invalidated any handle // pointing to any result of the payload op associated with the given handle @@ -597,13 +606,15 @@ void transform::TransformState::recordOpHandleInvalidation( for (const auto &[payloadValue, valueHandles] : mapping.reverseValues) { for (Value valueHandle : valueHandles) recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors, - payloadValue, valueHandle); + payloadValue, valueHandle, + newlyInvalidated); } } } void transform::TransformState::recordValueHandleInvalidation( - OpOperand &valueHandle) { + OpOperand &valueHandle, + transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { // Invalidate other handles to the same value. for (Value payloadValue : getPayloadValues(valueHandle.get())) { SmallVector otherValueHandles; @@ -612,8 +623,8 @@ void transform::TransformState::recordValueHandleInvalidation( Operation *owner = valueHandle.getOwner(); unsigned operandNo = valueHandle.getOperandNumber(); Location valueLoc = payloadValue.getLoc(); - invalidatedHandles[otherHandle] = [otherHandle, owner, operandNo, - valueLoc](Location currentLoc) { + newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo, + valueLoc](Location currentLoc) { InFlightDiagnostic diag = emitError(currentLoc) << "op uses a handle invalidated by a " "previously executed transform op"; @@ -629,17 +640,24 @@ void transform::TransformState::recordValueHandleInvalidation( if (auto opResult = llvm::dyn_cast(payloadValue)) { Operation *payloadOp = opResult.getOwner(); - recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue); + recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue, + newlyInvalidated); } else { auto arg = llvm::dyn_cast(payloadValue); for (Operation &payloadOp : *arg.getOwner()) - recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue); + recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue, + newlyInvalidated); } } } -LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( - TransformOpInterface transform) { +/// Checks that the operation does not use invalidated handles as operands. +/// Reports errors and returns failure if it does. Otherwise, invalidates the +/// handles consumed by the operation as well as any handles pointing to payload +/// IR operations nested in the operations associated with the consumed handles. +LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl( + transform::TransformOpInterface transform, + transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { FULL_LDBG("--Start checkAndRecordHandleInvalidation\n"); auto memoryEffectsIface = cast(transform.getOperation()); @@ -651,13 +669,23 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { (DBGS() << "----iterate on handle: " << target.get() << "\n"); }); - // If the operand uses an invalidated handle, report it. + // If the operand uses an invalidated handle, report it. If the operation + // allows handles to point to repeated payload operations, only report + // pre-existing invalidation errors. Otherwise, also report invalidations + // caused by the current transform operation affecting its other operands. auto it = invalidatedHandles.find(target.get()); - if (!transform.allowsRepeatedHandleOperands() && - it != invalidatedHandles.end()) { - FULL_LDBG("--End checkAndRecordHandleInvalidation -> FAILURE\n"); + auto nit = newlyInvalidated.find(target.get()); + if (it != invalidatedHandles.end()) { + FULL_LDBG("--End checkAndRecordHandleInvalidation, found already " + "invalidated -> FAILURE\n"); return it->getSecond()(transform->getLoc()), failure(); } + if (!transform.allowsRepeatedHandleOperands() && + nit != newlyInvalidated.end()) { + FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly " + "invalidated (by this op) -> FAILURE\n"); + return nit->getSecond()(transform->getLoc()), failure(); + } // Invalidate handles pointing to the operations nested in the operation // associated with the handle consumed by this operation. @@ -666,15 +694,18 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( effect.getValue() == target.get(); }; if (llvm::any_of(effects, consumesTarget)) { - FULL_LDBG("----found consume effect -> SKIP\n"); - if (llvm::isa(target.get().getType())) { + FULL_LDBG("----found consume effect\n"); + if (llvm::isa( + target.get().getType())) { FULL_LDBG("----recordOpHandleInvalidation\n"); - ArrayRef payloadOps = getPayloadOpsView(target.get()); - recordOpHandleInvalidation(target, payloadOps); - } else if (llvm::isa( + SmallVector payloadOps = + llvm::to_vector(getPayloadOps(target.get())); + recordOpHandleInvalidation(target, payloadOps, nullptr, + newlyInvalidated); + } else if (llvm::isa( target.get().getType())) { FULL_LDBG("----recordValueHandleInvalidation\n"); - recordValueHandleInvalidation(target); + recordValueHandleInvalidation(target, newlyInvalidated); } else { FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n"); } @@ -687,6 +718,16 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( return success(); } +LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( + transform::TransformOpInterface transform) { + InvalidatedHandleMap newlyInvalidated; + LogicalResult checkResult = + checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated); + invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()), + std::make_move_iterator(newlyInvalidated.end())); + return checkResult; +} + template DiagnosedSilenceableFailure checkRepeatedConsumptionInOperand(ArrayRef payload, diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir index 4cbaad8..e35c179 100644 --- a/mlir/test/Dialect/Transform/expensive-checks.mlir +++ b/mlir/test/Dialect/Transform/expensive-checks.mlir @@ -342,3 +342,25 @@ transform.sequence failures(propagate) { // expected-error @below {{uses a handle associated with empty payload and invalidated by a previously executed transform op}} transform.test_print_remark_at_operand %0, "remark" : !transform.any_op } + +// ----- + +// Make sure we properly report a use-after-consume error when repeated handles +// are allowed in the consuming op. We still want to report handles consumed by +// _previous_ operations, just not by this one. To bypass the quick static check +// of repeated consumption, create a handle to the transform operation and +// invalidate the handle to the root module thus invalidating all other handles. + +// expected-note @below {{ancestor payload op}} +module { + transform.sequence failures(propagate) { + ^bb0(%arg0: !transform.any_op): + // expected-note @below {{handle to invalidated ops}} + // expected-note @below {{nested payload op}} + %0 = transform.test_produce_self_handle_or_forward_operand : () -> !transform.any_op + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates all handles to payload IR entities associated with this operand and entities nested in them}} + transform.test_consume_operand %arg0 : !transform.any_op + // expected-error @below {{uses a handle invalidated by a previously executed transform op}} + transform.test_consume_operand %0 { allow_repeated_handles } : !transform.any_op + } +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index 5bf488e..f3b6c19 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -178,6 +178,10 @@ void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects( transform::onlyReadsPayload(effects); } +bool mlir::test::TestConsumeOperand::allowsRepeatedHandleOperands() { + return getAllowRepeatedHandles(); +} + DiagnosedSilenceableFailure mlir::test::TestConsumeOperand::apply(transform::TransformResults &results, transform::TransformState &state) { diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td index b1129ea..c02e2d9 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -97,11 +97,12 @@ def TestProduceValueHandleToArgumentOfParentBlock } def TestConsumeOperand : Op, + [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let arguments = (ins Transform_AnyHandleOrParamType:$operand, - Optional:$second_operand); + Optional:$second_operand, + UnitAttr:$allow_repeated_handles); let assemblyFormat = "$operand (`,` $second_operand^)? attr-dict `:` type($operand)" "(`,` type($second_operand)^)?"; -- 2.7.4