/// values in the payload IR. Also works for reverse mappings.
using ValueMapping = DenseMap<Value, SmallVector<Value>>;
+ /// Mapping between a Value in the transform IR and an error message that
+ /// should be emitted when the value is used.
+ using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;
+
/// The bidirectional mappings between transform IR values and payload IR
/// operations, and the mapping between transform IR values and parameters.
struct Mappings {
/// handle.
LogicalResult replacePayloadValue(Value value, Value replacement);
- /// If the operand is a handle consumed by the operation, i.e. has the "free"
- /// memory effect associated with it, identifies other handles that are
- /// pointing to payload IR operations nested in the operations pointed to by
- /// the consumed handle. Marks all such handles as invalidated to trigger
- /// errors if they are used. If `throughValue` is passed, record the fact that
- /// an op handle was invalidated because a value handle associated with
- /// results of the payload op or its block arguments was invalidated.
+ /// Records handle invalidation reporters into `newlyInvalidated`.
+ /// Specifically,
+ /// - `handle` is the op operand that consumes the handle,
+ /// - `potentialAncestors` is a list of ancestors of the payload operation
+ /// that the consumed handle is associated with, including itself,
+ /// - `throughValue` is the payload value the handle to which is consumed,
+ /// when it is the case, null when the operation handle is consumed
+ /// directly.
+ /// Iterates over all known operation and value handles and records reporters
+ /// for any potential future use of `handle` or any other handle that is
+ /// invalidated by its consumption, i.e., any handle pointing to any payload
+ /// IR entity (operation or value) associated with the same payload IR entity
+ /// as the consumed handle, or any nested payload IR entity. If
+ /// `potentialAncestors` is empty, records the reporter anyway. Does not
+ /// override existing reporters. This must remain a const method so it doesn't
+ /// inadvertently mutate `invalidatedHandles` too early.
void recordOpHandleInvalidation(OpOperand &consumingHandle,
ArrayRef<Operation *> potentialAncestors,
- Value throughValue = nullptr);
- void recordOpHandleInvalidationOne(OpOperand &handle,
- ArrayRef<Operation *> potentialAncestors,
- Operation *payloadOp, Value otherHandle,
- Value throughValue = nullptr);
-
+ Value throughValue,
+ InvalidatedHandleMap &newlyInvalidated) const;
+
+ /// Records handle invalidation reporters into `newlyInvalidated`.
+ /// Specifically,
+ /// - `consumingHandle` is the op operand that consumes the handle,
+ /// - `potentialAncestors` is a list of ancestors of the payload operation
+ /// that the consumed handle is associated with, including itself,
+ /// - `payloadOp` is the operation itself,
+ /// - `otherHandle` is another that may be associated with the affected
+ /// payload operations
+ /// - `throughValue` is the payload value the handle to which is consumed,
+ /// when it is the case, null when the operation handle is consumed
+ /// directly.
+ /// Looks at the payload opreations associated with `otherHandle` and if any
+ /// of these operations has an ancestor (or is itself) listed in
+ /// `potentialAncestors`, records the error message describing the use of the
+ /// invalidated handle. Does nothing if `otherHandle` already has a reporter
+ /// associated with it. This must remain a const method so it doesn't
+ /// inadvertently mutate `invalidatedHandles` too early.
+ void recordOpHandleInvalidationOne(
+ OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
+ Operation *payloadOp, Value otherHandle, Value throughValue,
+ InvalidatedHandleMap &newlyInvalidated) const;
+
+ /// Records handle invalidation reporters into `newlyInvalidated`.
+ /// Specifically,
+ /// - `opHandle` is the op operand that consumes the handle;
+ /// - `potentialAncestors` is a list of ancestors of the payload operation
+ /// that the consumed handle is associated with, including itself;
+ /// - `payloadValue` is the value defined by the operation associated with
+ /// the consuming handle as either op result or block argument;
+ /// - `valueHandle` is another that may be associated with the payload value.
+ /// Looks at the payload values associated with `valueHandle` and if any of
+ /// these values is defined, as op result or block argument, by an operation
+ /// whose ancestor (or the operation itself) is listed in
+ /// `potentialAncestors`, records the error message describing the use of the
+ /// invalidated handle. Does nothing if `valueHandle` already has a reporter
+ /// associated with it. This must remain a const method so it doesn't
+ /// inadvertently mutate `invalidatedHandles` too early.
void recordValueHandleInvalidationByOpHandleOne(
OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
- Value payloadValue, Value valueHandle);
-
- void recordValueHandleInvalidation(OpOperand &valueHandle);
+ Value payloadValue, Value valueHandle,
+ InvalidatedHandleMap &newlyInvalidated) const;
+
+ /// Records handle invalidation reporters into `newlyInvalidated`.
+ /// Specifically,
+ /// - `valueHandle` is the op operand that consumes the handle,
+ /// - `throughValue` is the payload value the handle to which is consumed,
+ /// when it is the case, null when the operation handle is consumed
+ /// directly.
+ /// Iterates over all known operation and value handles and records reporters
+ /// for any potential future use of `handle` or any other handle that is
+ /// invalidated by its consumption, i.e., any handle pointing to any payload
+ /// IR entity (operation or value) associated with the same payload IR entity
+ /// as the consumed handle, or any nested payload IR entity. Does not override
+ /// existing reporters. This must remain a const method so it doesn't
+ /// inadvertently mutate `invalidatedHandles` too early.
+ void
+ recordValueHandleInvalidation(OpOperand &valueHandle,
+ InvalidatedHandleMap &newlyInvalidated) const;
/// Checks that the operation does not use invalidated handles as operands.
/// Reports errors and returns failure if it does. Otherwise, invalidates the
LogicalResult
checkAndRecordHandleInvalidation(TransformOpInterface transform);
+ /// Implementation of the checkAndRecordHandleInvalidation. This must remain a
+ /// const method so it doesn't inadvertently mutate `invalidatedHandles` too
+ /// early.
+ LogicalResult checkAndRecordHandleInvalidationImpl(
+ transform::TransformOpInterface transform,
+ transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const;
+
/// Remove all nullptrs from op handles that were added by `replacePayloadOp`.
void compactOpHandles();
/// describe when the handles were invalidated. Calling such a function emits
/// a user-visible diagnostic with an additional note pointing to the given
/// location.
- DenseMap<Value, std::function<void(Location)>> invalidatedHandles;
+ InvalidatedHandleMap invalidatedHandles;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// A stack of nested regions that are being processed in the transform IR.
void transform::TransformState::recordOpHandleInvalidationOne(
OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
- Operation *payloadOp, Value otherHandle, Value throughValue) {
+ Operation *payloadOp, Value otherHandle, Value throughValue,
+ transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
// If the op is associated with invalidated handle, skip the check as it
- // may be reading invalid IR.
- if (invalidatedHandles.count(otherHandle))
+ // may be reading invalid IR. This also ensures we report the first
+ // invalidation and not the last one.
+ if (invalidatedHandles.count(otherHandle) ||
+ newlyInvalidated.count(otherHandle))
return;
FULL_LDBG("--recordOpHandleInvalidationOne\n");
Location opLoc = payloadOp->getLoc();
std::optional<Location> throughValueLoc =
throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
- invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
- otherHandle,
- throughValueLoc](Location currentLoc) {
+ newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
+ otherHandle,
+ throughValueLoc](Location currentLoc) {
InFlightDiagnostic diag = emitError(currentLoc)
<< "op uses a handle invalidated by a "
"previously executed transform op";
}
void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
- OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
- Value payloadValue, Value valueHandle) {
+ OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
+ Value payloadValue, Value valueHandle,
+ transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
// If the op is associated with invalidated handle, skip the check as it
- // may be reading invalid IR.
- if (invalidatedHandles.count(valueHandle))
+ // may be reading invalid IR. This also ensures we report the first
+ // invalidation and not the last one.
+ if (invalidatedHandles.count(valueHandle) ||
+ newlyInvalidated.count(valueHandle))
return;
for (Operation *ancestor : potentialAncestors) {
if (!ancestor->isAncestor(definingOp))
continue;
- Operation *owner = consumingHandle.getOwner();
- unsigned operandNo = consumingHandle.getOperandNumber();
+ Operation *owner = opHandle.getOwner();
+ unsigned operandNo = opHandle.getOperandNumber();
Location ancestorLoc = ancestor->getLoc();
Location opLoc = definingOp->getLoc();
Location valueLoc = payloadValue.getLoc();
- invalidatedHandles[valueHandle] =
+ newlyInvalidated[valueHandle] =
[valueHandle, owner, operandNo, resultNo, argumentNo, blockNo, regionNo,
ancestorLoc, opLoc, valueLoc](Location currentLoc) {
InFlightDiagnostic diag = emitError(currentLoc)
void transform::TransformState::recordOpHandleInvalidation(
OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
- Value throughValue) {
+ Value throughValue,
+ transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
if (potentialAncestors.empty()) {
DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
Operation *owner = handle.getOwner();
unsigned operandNo = handle.getOperandNumber();
- invalidatedHandles[handle.get()] = [owner, operandNo](Location currentLoc) {
+ newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
InFlightDiagnostic diag = emitError(currentLoc)
<< "op uses a handle associated with empty "
"payload and invalidated by a "
// number of IR objects (operations and values). Alternatively, we could walk
// the IR nested in each payload op associated with the given handle and look
// for handles associated with each operation and value.
- for (const Mappings &mapping : llvm::make_second_range(mappings)) {
+ for (const transform::TransformState::Mappings &mapping :
+ llvm::make_second_range(mappings)) {
// Go over all op handle mappings and mark as invalidated any handle
// pointing to any of the payload ops associated with the given handle or
// any op nested in them.
for (const auto &[payloadOp, otherHandles] : mapping.reverse) {
for (Value otherHandle : otherHandles)
recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
- otherHandle, throughValue);
+ otherHandle, throughValue,
+ newlyInvalidated);
}
// Go over all value handle mappings and mark as invalidated any handle
// pointing to any result of the payload op associated with the given handle
for (const auto &[payloadValue, valueHandles] : mapping.reverseValues) {
for (Value valueHandle : valueHandles)
recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
- payloadValue, valueHandle);
+ payloadValue, valueHandle,
+ newlyInvalidated);
}
}
}
void transform::TransformState::recordValueHandleInvalidation(
- OpOperand &valueHandle) {
+ OpOperand &valueHandle,
+ transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
// Invalidate other handles to the same value.
for (Value payloadValue : getPayloadValues(valueHandle.get())) {
SmallVector<Value> otherValueHandles;
Operation *owner = valueHandle.getOwner();
unsigned operandNo = valueHandle.getOperandNumber();
Location valueLoc = payloadValue.getLoc();
- invalidatedHandles[otherHandle] = [otherHandle, owner, operandNo,
- valueLoc](Location currentLoc) {
+ newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
+ valueLoc](Location currentLoc) {
InFlightDiagnostic diag = emitError(currentLoc)
<< "op uses a handle invalidated by a "
"previously executed transform op";
if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
Operation *payloadOp = opResult.getOwner();
- recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue);
+ recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
+ newlyInvalidated);
} else {
auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
for (Operation &payloadOp : *arg.getOwner())
- recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue);
+ recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
+ newlyInvalidated);
}
}
}
-LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
- TransformOpInterface transform) {
+/// Checks that the operation does not use invalidated handles as operands.
+/// Reports errors and returns failure if it does. Otherwise, invalidates the
+/// handles consumed by the operation as well as any handles pointing to payload
+/// IR operations nested in the operations associated with the consumed handles.
+LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
+ transform::TransformOpInterface transform,
+ transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
FULL_LDBG("--Start checkAndRecordHandleInvalidation\n");
auto memoryEffectsIface =
cast<MemoryEffectOpInterface>(transform.getOperation());
DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
(DBGS() << "----iterate on handle: " << target.get() << "\n");
});
- // If the operand uses an invalidated handle, report it.
+ // If the operand uses an invalidated handle, report it. If the operation
+ // allows handles to point to repeated payload operations, only report
+ // pre-existing invalidation errors. Otherwise, also report invalidations
+ // caused by the current transform operation affecting its other operands.
auto it = invalidatedHandles.find(target.get());
- if (!transform.allowsRepeatedHandleOperands() &&
- it != invalidatedHandles.end()) {
- FULL_LDBG("--End checkAndRecordHandleInvalidation -> FAILURE\n");
+ auto nit = newlyInvalidated.find(target.get());
+ if (it != invalidatedHandles.end()) {
+ FULL_LDBG("--End checkAndRecordHandleInvalidation, found already "
+ "invalidated -> FAILURE\n");
return it->getSecond()(transform->getLoc()), failure();
}
+ if (!transform.allowsRepeatedHandleOperands() &&
+ nit != newlyInvalidated.end()) {
+ FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly "
+ "invalidated (by this op) -> FAILURE\n");
+ return nit->getSecond()(transform->getLoc()), failure();
+ }
// Invalidate handles pointing to the operations nested in the operation
// associated with the handle consumed by this operation.
effect.getValue() == target.get();
};
if (llvm::any_of(effects, consumesTarget)) {
- FULL_LDBG("----found consume effect -> SKIP\n");
- if (llvm::isa<TransformHandleTypeInterface>(target.get().getType())) {
+ FULL_LDBG("----found consume effect\n");
+ if (llvm::isa<transform::TransformHandleTypeInterface>(
+ target.get().getType())) {
FULL_LDBG("----recordOpHandleInvalidation\n");
- ArrayRef<Operation *> payloadOps = getPayloadOpsView(target.get());
- recordOpHandleInvalidation(target, payloadOps);
- } else if (llvm::isa<TransformValueHandleTypeInterface>(
+ SmallVector<Operation *> payloadOps =
+ llvm::to_vector(getPayloadOps(target.get()));
+ recordOpHandleInvalidation(target, payloadOps, nullptr,
+ newlyInvalidated);
+ } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
target.get().getType())) {
FULL_LDBG("----recordValueHandleInvalidation\n");
- recordValueHandleInvalidation(target);
+ recordValueHandleInvalidation(target, newlyInvalidated);
} else {
FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
}
return success();
}
+LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
+ transform::TransformOpInterface transform) {
+ InvalidatedHandleMap newlyInvalidated;
+ LogicalResult checkResult =
+ checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
+ invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
+ std::make_move_iterator(newlyInvalidated.end()));
+ return checkResult;
+}
+
template <typename T>
DiagnosedSilenceableFailure
checkRepeatedConsumptionInOperand(ArrayRef<T> payload,