From 3dfea727a4bb50bbc9fc1b5534bb067e34ee1fea Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 30 Sep 2022 14:11:34 +0000 Subject: [PATCH] [mlir] relax transform dialect multi-handle restriction Relax the restriction in the transform dialect interpreter utilities that expected a payload IR op to be assocaited with at most one transform IR handle value. This was useful during the initial bootstrapping to avoid use-after-free error equivalents when a payload IR op could be erased through one of the handles associated with it and then accessed through another. It was, however, possible to erase an ancestor of the payload IR operation in question. The expensive-checks mode of interpretation is able to detect both cases and has proven sufficiently robust in debugging use-after-free errors. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D134964 --- .../mlir/Dialect/Transform/IR/TransformDialect.td | 63 ++++---- .../Dialect/Transform/IR/TransformInterfaces.h | 99 +++++++------ .../Dialect/Transform/IR/TransformInterfaces.cpp | 163 +++++++++++---------- mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 12 +- mlir/test/Dialect/Transform/expensive-checks.mlir | 39 +++++ mlir/test/Dialect/Transform/test-interpreter.mlir | 15 +- .../Transform/TestTransformDialectExtension.cpp | 13 +- .../Transform/TestTransformDialectExtension.td | 16 +- 8 files changed, 240 insertions(+), 180 deletions(-) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td index 54ff0f0..fbceffb 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -66,11 +66,11 @@ def Transform_Dialect : Dialect { A Transform IR value such as `%0` may be associated with multiple payload operations. This is conceptually a set of operations and no assumptions - should be made about the order of ops. Most Transform IR ops support - operand values that are mapped to multiple operations. They usually apply - the respective transformation for every mapped op ("batched execution"). - Deviations from this convention are described in the documentation of - Transform IR ops. + should be made about the order of ops unless specified otherwise by the + operation. Most Transform IR ops support operand values that are mapped to + multiple operations. They usually apply the respective transformation for + every mapped op ("batched execution"). Deviations from this convention are + described in the documentation of Transform IR ops. Overall, Transform IR ops are expected to be contained in a single top-level op. Such top-level ops specify how to apply the transformations described @@ -161,17 +161,18 @@ def Transform_Dialect : Dialect { ## Execution Model - The transformation starts at the specifed top-level transform IR operation - and applies to some payload IR scope, identified by the payload IR op that - contains the IR to transform. It is the responsibility of the user to - properly select the scope and/or to avoid the transformations to modify the - IR outside of the given scope. The top-level transform IR operation may - contain further transform operations and execute them in the desired order. + The transformation starts at the user-specified top-level transform IR + operation and applies to some user-specified payload IR scope, identified by + the payload IR op that contains the IR to transform. It is the + responsibility of the user to properly select the scope and/or to avoid the + transformations to modify the IR outside of the given scope. The top-level + transform IR operation may contain further transform operations and execute + them in the desired order. Transformation application functions produce a tri-state status: - success; - - recoverable (silencable) failure; + - recoverable (silenceable) failure; - irrecoverable failure. Transformation container operations may intercept recoverable failures and @@ -180,9 +181,9 @@ def Transform_Dialect : Dialect { failures, the diagnostics are emitted immediately whereas their emission is postponed for recoverable failures. Transformation container operations may also fail to recover from a theoretically recoverable failure, in which case - they are expected to emit the diagnostic and turn the failure into an - irrecoverable one. A recoverable failure produced by applying the top-level - transform IR operation is considered irrecoverable. + they can either propagate it to their parent or emit the diagnostic and turn + the failure into an irrecoverable one. A recoverable failure produced by + applying the top-level transform IR operation is considered irrecoverable. Transformation container operations are allowed to "step over" some nested operations if the application of some previous operation produced a failure. @@ -193,26 +194,18 @@ def Transform_Dialect : Dialect { ## Handle Invalidation - The execution model of the transform dialect expects that a payload IR - operation is associated with _at most one_ transform IR handle. This avoids - the situation when a handle to an operation outlives the operation itself - that can be erased during a transformation triggered through another handle. - - Handles pointing to operations nested in each other are allowed to co-exist - in the transform IR. However, a transform IR operation that consumes such a - handle automatically _invalidates_ all the other handles that are associated - with operations nested in the operations associated with the consumed - handle. Any use of the invalidated handle results in undefined behavior - since the payload IR operations associated with it are likely to have been - mutated or erased. The mere fact of the handle being invalidated does _not_ - trigger undefined behavior, only its appearance as an operand does. - Invalidation applies to the entire handle, even if some of the payload IR - operations associated with it are not nested in payload IR operations - associated with another, consumed handle. - - Note: the restriction on two handles not pointing to the same operation may - be relaxed in the future to follow the invalidation model for nested - operation. + The execution model of the transform dialect allows a payload IR operation + to be associated with _multiple_ handles as well as nested payload IR + operations to be associated with different handles. A transform IR operation + that consumes a handle automatically _invalidates_ all the other handles + associated with the same payload IR operations, or with any of their + descendants, as the consumed handle. Note that the _entire_ handle is + invalidated, even if some of the payload IR operations associated with it + or their ancestors were not associated with the consumed handle. Any use of + the invalidated handle results in undefined behavior since the payload IR + operations associated with it are likely to have been mutated or erased. The + mere fact of the handle being invalidated does _not_ trigger undefined + behavior, only its appearance as an operand does. The Transform dialect infrastructure has the capability of checking whether the transform IR op operand is invalidated before applying the diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index 76fe4fa..abd10fe 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -209,11 +209,13 @@ private: /// TransformOpInterface. The operations implementing this interface and the /// surrounding structure are referred to as transform IR. The operations to /// which transformations apply are referred to as payload IR. The state thus -/// contains the mapping between values defined in the transform IR ops and -/// payload IR ops. It assumes that each value in the transform IR can be used -/// at most once (since transformations are likely to change the payload IR ops -/// the value corresponds to). Checks that transform IR values correspond to -/// disjoint sets of payload IR ops throughout the transformation. +/// contains the many-to-many mapping between values defined in the transform IR +/// ops and payload IR ops. The "expensive-checks" option can be passed to +/// the constructor at transformation execution time that transform IR values +/// used as operands by a transform IR operation are not associated with +/// dangling pointers to payload IR operations that are known to have been +/// erased by previous transformation through the same or a different transform +/// IR value. /// /// A reference to this class is passed as an argument to "apply" methods of the /// transform op interface. Thus the "apply" method can call @@ -235,9 +237,10 @@ class TransformState { /// operations in the payload IR. using TransformOpMapping = DenseMap>; - /// Mapping between a payload IR operation and the transform IR value it is - /// currently associated with. - using TransformOpReverseMapping = DenseMap; + /// Mapping between a payload IR operation and the transform IR values it is + /// associated with. + using TransformOpReverseMapping = + DenseMap>; /// Bidirectional mappings between transform IR values and payload IR /// operations. @@ -249,7 +252,7 @@ class TransformState { public: /// Creates a state for transform ops living in the given region. The parent /// operation of the region. The second argument points to the root operation - /// in the payload IR beind transformed, which may or may not contain the + /// in the payload IR being transformed, which may or may not contain the /// region with transform ops. Additional options can be provided through the /// trailing configuration object. TransformState(Region ®ion, Operation *root, @@ -263,9 +266,10 @@ public: /// This is helpful for transformations that apply to a particular handle. ArrayRef getPayloadOps(Value value) const; - /// Returns the Transform IR handle for the given Payload IR op if it exists - /// in the state, null otherwise. - Value getHandleForPayloadOp(Operation *op) const; + /// Populates `handles` with all handles pointing to the given Payload IR op. + /// Returns success if such handles exist, failure otherwise. + LogicalResult getHandlesForPayloadOp(Operation *op, + SmallVectorImpl &handles) const; /// Applies the transformation specified by the given transform op and updates /// the state accordingly. @@ -275,13 +279,13 @@ public: /// list of operations in the payload IR. The arguments must be defined in /// blocks of the currently processed transform IR region, typically after a /// region scope is defined. - LogicalResult mapBlockArguments(BlockArgument argument, - ArrayRef operations) { + void mapBlockArguments(BlockArgument argument, + ArrayRef operations) { #if LLVM_ENABLE_ABI_BREAKING_CHECKS assert(argument.getParentRegion() == regionStack.back() && "mapping block arguments from a region other than the active one"); #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - return setPayloadOps(argument, operations); + setPayloadOps(argument, operations); } // Forward declarations to support limited visibility. @@ -379,7 +383,8 @@ public: const TransformState &getTransformState() const { return state; } /// Replaces the given payload op with another op. If the replacement op is - /// null, removes the association of the payload op with its handle. + /// null, removes the association of the payload op with its handle. Returns + /// failure if the op is not associated with any handle. LogicalResult replacePayloadOp(Operation *op, Operation *replacement); private: @@ -451,20 +456,29 @@ private: return it->second; } - /// Sets the payload IR ops associated with the given transform IR value. - /// Fails if this would result in multiple transform IR values with uses - /// corresponding to the same payload IR ops. For example, a hypothetical - /// "find function by name" transform op would (indirectly) call this - /// function for its result. Having two such calls in a row with for different - /// values, e.g. coming from different ops: + /// Removes the mapping between the given payload IR operation and the given + /// transform IR value. + void dropReverseMapping(Mappings &mappings, Operation *op, Value value); + + /// Sets the payload IR ops associated with the given transform IR value + /// (handle). A payload op may be associated multiple handles as long as + /// at most one of them gets consumed by further transformations. + /// For example, a hypothetical "find function by name" may be called twice in + /// a row to produce two handles pointing to the same function: /// /// %0 = transform.find_func_by_name { name = "myfunc" } /// %1 = transform.find_func_by_name { name = "myfunc" } /// - /// would lead to both values pointing to the same operation. The second call - /// to setPayloadOps will fail, unless the association with the %0 value is - /// removed first by calling update/removePayloadOps. - LogicalResult setPayloadOps(Value value, ArrayRef targets); + /// which is valid by itself. However, calling a hypothetical "rewrite and + /// rename function" transform on both handles: + /// + /// transform.rewrite_and_rename %0 { new_name = "func" } + /// transform.rewrite_and_rename %1 { new_name = "func" } + /// + /// is invalid given the transformation "consumes" the handle as expressed + /// by side effects. Practically, a transformation consuming a handle means + /// that the associated payload operation may no longer exist. + void setPayloadOps(Value value, ArrayRef targets); /// Forgets the payload IR ops associated with the given transform IR value. void removePayloadOps(Value value); @@ -473,24 +487,18 @@ private: /// The callback function is called once per associated operation and is /// expected to return the modified operation or nullptr. In the latter case, /// the corresponding operation is no longer associated with the transform IR - /// value. May fail if the operation produced by the update callback is - /// already associated with a different Transform IR handle value. - LogicalResult - updatePayloadOps(Value value, - function_ref callback); - - /// Attempts to record the mapping between the given Payload IR operation and - /// the given Transform IR handle. Fails and reports an error if the operation - /// is already tracked by another handle. - static LogicalResult tryEmplaceReverseMapping(Mappings &map, Operation *op, - Value handle); + /// value. + void updatePayloadOps(Value value, + function_ref callback); /// If the operand is a handle consumed by the operation, i.e. has the "free" /// memory effect associated with it, identifies other handles that are /// pointing to payload IR operations nested in the operations pointed to by - /// the consumed handle. Marks all such handles as invalidated so trigger + /// the consumed handle. Marks all such handles as invalidated to trigger /// errors if they are used. void recordHandleInvalidation(OpOperand &handle); + void recordHandleInvalidationOne(OpOperand &handle, Operation *payloadOp, + Value otherHandle); /// Checks that the operation does not use invalidated handles as operands. /// Reports errors and returns failure if it does. Otherwise, invalidates the @@ -566,9 +574,9 @@ namespace detail { /// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait /// to either the list of operations associated with its operand or the root of /// the payload IR, depending on what is available in the context. -LogicalResult -mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, - Operation *op, Region ®ion); +void mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, + Operation *op, + Region ®ion); /// Verification hook for PossibleTopLevelTransformOpTrait. LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op); @@ -605,18 +613,17 @@ public: /// Sets up the mapping between the entry block of the given region of this op /// and the relevant list of Payload IR operations in the given state. The /// state is expected to be already scoped at the region of this operation. - /// Returns failure if the mapping failed, e.g., the value is already mapped. - LogicalResult mapBlockArguments(TransformState &state, Region ®ion) { + void mapBlockArguments(TransformState &state, Region ®ion) { assert(region.getParentOp() == this->getOperation() && "op comes from the wrong region"); - return detail::mapPossibleTopLevelTransformOpBlockArguments( + detail::mapPossibleTopLevelTransformOpBlockArguments( state, this->getOperation(), region); } - LogicalResult mapBlockArguments(TransformState &state) { + void mapBlockArguments(TransformState &state) { assert( this->getOperation()->getNumRegions() == 1 && "must indicate the region to map if the operation has more than one"); - return mapBlockArguments(state, this->getOperation()->getRegion(0)); + mapBlockArguments(state, this->getOperation()->getRegion(0)); } }; diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index 258420e..1d841f9 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Operation.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "transform-dialect" @@ -45,35 +46,28 @@ transform::TransformState::getPayloadOps(Value value) const { return iter->getSecond(); } -Value transform::TransformState::getHandleForPayloadOp(Operation *op) const { +LogicalResult transform::TransformState::getHandlesForPayloadOp( + Operation *op, SmallVectorImpl &handles) const { + bool found = false; for (const Mappings &mapping : llvm::make_second_range(mappings)) { - if (Value handle = mapping.reverse.lookup(op)) - return handle; + auto iterator = mapping.reverse.find(op); + if (iterator != mapping.reverse.end()) { + llvm::append_range(handles, iterator->getSecond()); + found = true; + } } - return Value(); -} -LogicalResult transform::TransformState::tryEmplaceReverseMapping( - Mappings &map, Operation *operation, Value handle) { - auto insertionResult = map.reverse.insert({operation, handle}); - if (!insertionResult.second && insertionResult.first->second != handle) { - InFlightDiagnostic diag = operation->emitError() - << "operation tracked by two handles"; - diag.attachNote(handle.getLoc()) << "handle"; - diag.attachNote(insertionResult.first->second.getLoc()) << "handle"; - return diag; - } - return success(); + return success(found); } -LogicalResult -transform::TransformState::setPayloadOps(Value value, - ArrayRef targets) { +void transform::TransformState::setPayloadOps(Value value, + ArrayRef targets) { assert(value != kTopLevelValue && "attempting to reset the transformation root"); + // TODO: this may go now if (value.use_empty()) - return success(); + return; // Setting new payload for the value without cleaning it first is a misuse of // the API, assert here. @@ -84,25 +78,29 @@ transform::TransformState::setPayloadOps(Value value, assert(inserted && "value is already associated with another list"); (void)inserted; - // Having multiple handles to the same operation is an error in the transform - // expressed using the dialect and may be constructed by valid API calls from - // valid IR. Emit an error here. - for (Operation *op : targets) { - if (failed(tryEmplaceReverseMapping(mappings, op, value))) - return failure(); - } + for (Operation *op : targets) + mappings.reverse[op].push_back(value); +} - return success(); +void transform::TransformState::dropReverseMapping(Mappings &mappings, + Operation *op, Value value) { + auto it = mappings.reverse.find(op); + if (it != mappings.reverse.end()) + return; + + llvm::erase_value(it->getSecond(), value); + if (it->getSecond().empty()) + mappings.reverse.erase(it); } void transform::TransformState::removePayloadOps(Value value) { Mappings &mappings = getMapping(value); for (Operation *op : mappings.direct[value]) - mappings.reverse.erase(op); + dropReverseMapping(mappings, op, value); mappings.direct.erase(value); } -LogicalResult transform::TransformState::updatePayloadOps( +void transform::TransformState::updatePayloadOps( Value value, function_ref callback) { Mappings &mappings = getMapping(value); auto it = mappings.direct.find(value); @@ -112,60 +110,60 @@ LogicalResult transform::TransformState::updatePayloadOps( updated.reserve(association.size()); for (Operation *op : association) { - mappings.reverse.erase(op); + dropReverseMapping(mappings, op, value); if (Operation *updatedOp = callback(op)) { updated.push_back(updatedOp); - if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value))) - return failure(); + mappings.reverse[updatedOp].push_back(value); } } std::swap(association, updated); - return success(); } -void transform::TransformState::recordHandleInvalidation(OpOperand &handle) { +void transform::TransformState::recordHandleInvalidationOne( + OpOperand &handle, Operation *payloadOp, Value otherHandle) { ArrayRef potentialAncestors = getPayloadOps(handle.get()); - for (const Mappings &mapping : llvm::make_second_range(mappings)) { - for (const auto &kvp : mapping.reverse) { - // If the op is associated with invalidated handle, skip the check as it - // may be reading invalid IR. - Operation *op = kvp.first; - Value otherHandle = kvp.second; - if (invalidatedHandles.count(otherHandle)) - continue; - - for (Operation *ancestor : potentialAncestors) { - if (!ancestor->isProperAncestor(op)) - continue; - - // Make sure the error-reporting lambda doesn't capture anything - // by-reference because it will go out of scope. Additionally, extract - // location from Payload IR ops because the ops themselves may be - // deleted before the lambda gets called. - Location ancestorLoc = ancestor->getLoc(); - Location opLoc = op->getLoc(); - Operation *owner = handle.getOwner(); - unsigned operandNo = handle.getOperandNumber(); - invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, - otherHandle](Location currentLoc) { - InFlightDiagnostic diag = emitError(currentLoc) - << "op uses a handle invalidated by a " - "previously executed transform op"; - diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops"; - diag.attachNote(owner->getLoc()) - << "invalidated by this transform op that consumes its operand #" - << operandNo - << " and invalidates handles to payload ops nested in payload " - "ops associated with the consumed handle"; - diag.attachNote(ancestorLoc) << "ancestor payload op"; - diag.attachNote(opLoc) << "nested payload op"; - }; - } - } + // If the op is associated with invalidated handle, skip the check as it + // may be reading invalid IR. + if (invalidatedHandles.count(otherHandle)) + return; + + for (Operation *ancestor : potentialAncestors) { + if (!ancestor->isAncestor(payloadOp)) + continue; + + // Make sure the error-reporting lambda doesn't capture anything + // by-reference because it will go out of scope. Additionally, extract + // location from Payload IR ops because the ops themselves may be + // deleted before the lambda gets called. + Location ancestorLoc = ancestor->getLoc(); + Location opLoc = payloadOp->getLoc(); + Operation *owner = handle.getOwner(); + unsigned operandNo = handle.getOperandNumber(); + invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, + otherHandle](Location currentLoc) { + InFlightDiagnostic diag = emitError(currentLoc) + << "op uses a handle invalidated by a " + "previously executed transform op"; + diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops"; + diag.attachNote(owner->getLoc()) + << "invalidated by this transform op that consumes its operand #" + << operandNo + << " and invalidates handles to payload ops nested in payload " + "ops associated with the consumed handle"; + diag.attachNote(ancestorLoc) << "ancestor payload op"; + diag.attachNote(opLoc) << "nested payload op"; + }; } } +void transform::TransformState::recordHandleInvalidation(OpOperand &handle) { + for (const Mappings &mapping : llvm::make_second_range(mappings)) + for (const auto &[payloadOp, otherHandles] : mapping.reverse) + for (Value otherHandle : otherHandles) + recordHandleInvalidationOne(handle, payloadOp, otherHandle); +} + LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( TransformOpInterface transform) { auto memoryEffectsIface = @@ -252,8 +250,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { assert(result.getDefiningOp() == transform.getOperation() && "payload IR association for a value other than the result of the " "current transform op"); - if (failed(setPayloadOps(result, results.get(result.getResultNumber())))) - return DiagnosedSilenceableFailure::definiteFailure(); + setPayloadOps(result, results.get(result.getResultNumber())); } printOnFailureRAII.release(); @@ -273,10 +270,16 @@ transform::TransformState::Extension::~Extension() = default; LogicalResult transform::TransformState::Extension::replacePayloadOp(Operation *op, Operation *replacement) { - return state.updatePayloadOps(state.getHandleForPayloadOp(op), - [&](Operation *current) { - return current == op ? replacement : current; - }); + SmallVector handles; + if (failed(state.getHandlesForPayloadOp(op, handles))) + return failure(); + + for (Value handle : handles) { + state.updatePayloadOps(handle, [&](Operation *current) { + return current == op ? replacement : current; + }); + } + return success(); } //===----------------------------------------------------------------------===// @@ -311,7 +314,7 @@ transform::TransformResults::get(unsigned resultNumber) const { // Utilities for PossibleTopLevelTransformOpTrait. //===----------------------------------------------------------------------===// -LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( +void transform::detail::mapPossibleTopLevelTransformOpBlockArguments( TransformState &state, Operation *op, Region ®ion) { SmallVector targets; if (op->getNumOperands() != 0) @@ -319,7 +322,7 @@ LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( else targets.push_back(state.getTopLevel()); - return state.mapBlockArguments(region.front().getArgument(0), targets); + state.mapBlockArguments(region.front().getArgument(0), targets); } LogicalResult diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index b9f9d44..116b6d9 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -211,8 +211,7 @@ transform::AlternativesOp::apply(transform::TransformResults &results, for (Operation *clone : clones) clone->erase(); }); - if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones))) - return DiagnosedSilenceableFailure::definiteFailure(); + state.mapBlockArguments(reg.front().getArgument(0), clones); bool failed = false; for (Operation &transform : reg.front().without_terminator()) { @@ -285,8 +284,7 @@ transform::ForeachOp::apply(transform::TransformResults &results, for (Operation *op : payloadOps) { auto scope = state.make_region_scope(getBody()); - if (failed(state.mapBlockArguments(getIterationVariable(), {op}))) - return DiagnosedSilenceableFailure::definiteFailure(); + state.mapBlockArguments(getIterationVariable(), {op}); // Execute loop body. for (Operation &transform : getBody().front().without_terminator()) { @@ -512,8 +510,7 @@ transform::SequenceOp::apply(transform::TransformResults &results, transform::TransformState &state) { // Map the entry block argument to the list of operations. auto scope = state.make_region_scope(*getBodyBlock()->getParent()); - if (failed(mapBlockArguments(state))) - return DiagnosedSilenceableFailure::definiteFailure(); + mapBlockArguments(state); // Apply the sequenced ops one by one. for (Operation &transform : getBodyBlock()->without_terminator()) { @@ -707,8 +704,7 @@ transform::WithPDLPatternsOp::apply(transform::TransformResults &results, [&]() { state.removeExtension(); }); auto scope = state.make_region_scope(getBody()); - if (failed(mapBlockArguments(state))) - return DiagnosedSilenceableFailure::definiteFailure(); + mapBlockArguments(state); return state.applyTransform(transformOp); } diff --git a/mlir/test/Dialect/Transform/expensive-checks.mlir b/mlir/test/Dialect/Transform/expensive-checks.mlir index ab151979..f4dca69 100644 --- a/mlir/test/Dialect/Transform/expensive-checks.mlir +++ b/mlir/test/Dialect/Transform/expensive-checks.mlir @@ -60,3 +60,42 @@ transform.with_pdl_patterns { test_print_remark_at_operand %0, "remark" } } + + +// ----- + +// expected-note @below {{ancestor payload op}} +// expected-note @below {{nested payload op}} +module { + + transform.sequence failures(propagate) { + ^bb0(%0: !pdl.operation): + %1 = transform.test_copy_payload %0 + // expected-note @below {{handle to invalidated ops}} + %2 = transform.test_copy_payload %0 + // expected-note @below {{invalidated by this transform op that consumes its operand #0}} + transform.test_consume_operand %1 + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + transform.test_consume_operand %2 + } +} + +// ----- + +// expected-note @below {{ancestor payload op}} +// expected-note @below {{nested payload op}} +module { + + transform.sequence failures(propagate) { + ^bb0(%0: !pdl.operation): + %1 = transform.test_copy_payload %0 + // expected-note @below {{handle to invalidated ops}} + %2 = transform.test_copy_payload %0 + // Consuming two handles in the same operation is invalid if they point + // to overlapping sets of payload IR ops. + // + // expected-error @below {{op uses a handle invalidated by a previously executed transform op}} + // expected-note @below {{invalidated by this transform op that consumes its operand #0 and invalidates handles}} + transform.test_consume_operand %1, %2 + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index e458864..7c0deb1 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -17,14 +17,13 @@ transform.test_consume_operand_if_matches_param_or_fail %0[21] // ----- -// expected-error @below {{operation tracked by two handles}} -%0 = transform.test_produce_param_or_forward_operand 42 -// expected-note @below {{handle}} -%1 = transform.test_produce_param_or_forward_operand from %0 -// expected-note @below {{handle}} -%2 = transform.test_produce_param_or_forward_operand from %0 -transform.test_consume_operand_if_matches_param_or_fail %1[42] -transform.test_consume_operand_if_matches_param_or_fail %2[42] +// It is okay to have multiple handles to the same payload op as long +// as only one of them is consumed. The expensive checks mode is necessary +// to detect double-consumption. +%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" } +%1 = transform.test_copy_payload %0 +// expected-remark @below {{succeeded}} +transform.test_consume_operand_if_matches_param_or_fail %0[42] // ----- diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index 382b342..752aaff 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -171,9 +171,13 @@ mlir::test::TestCheckIfTestExtensionPresentOp::apply( << "extension present, " << extension->getMessage(); for (Operation *payload : state.getPayloadOps(getOperand())) { diag.attachNote(payload->getLoc()) << "associated payload op"; - assert(state.getHandleForPayloadOp(payload) == getOperand() && +#ifndef NDEBUG + SmallVector handles; + assert(succeeded(state.getHandlesForPayloadOp(payload, handles))); + assert(llvm::is_contained(handles, getOperand()) && "inconsistent mapping between transform IR handles and payload IR " "operations"); +#endif // NDEBUG } return DiagnosedSilenceableFailure::success(); @@ -297,6 +301,13 @@ void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects( transform::onlyReadsHandle(getHandle(), effects); } +DiagnosedSilenceableFailure +mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + results.set(getCopy().cast(), state.getPayloadOps(getHandle())); + return DiagnosedSilenceableFailure::success(); +} + namespace { /// Test extension of the Transform dialect. Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td index 4cc6b41..fda71bc 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -38,8 +38,10 @@ def TestConsumeOperand : Op]> { let arguments = (ins Arg:$operand); - let assemblyFormat = "$operand attr-dict"; + [TransformMappingRead, TransformMappingFree]>:$operand, + Arg, "", + [TransformMappingRead, TransformMappingFree]>:$second_operand); + let assemblyFormat = "$operand (`,` $second_operand^)? attr-dict"; let cppNamespace = "::mlir::test"; } @@ -231,4 +233,14 @@ def TestPrintNumberOfAssociatedPayloadIROps let cppNamespace = "::mlir::test"; } +def TestCopyPayloadOp + : Op]> { + let arguments = (ins Arg:$handle); + let results = (outs Res:$copy); + let cppNamespace = "::mlir::test"; + let assemblyFormat = "$handle attr-dict"; +} + #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD -- 2.7.4