From 3c196f1658f3c5fd368fdaa3c2fb165ed6d7fefa Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 26 Jun 2023 17:49:31 +0200 Subject: [PATCH] [mlir][transform] Remove redundant handle check in `replacePayload...` Differential Revision: https://reviews.llvm.org/D153766 --- .../Dialect/Transform/IR/TransformInterfaces.h | 5 +- .../Dialect/Transform/IR/TransformInterfaces.cpp | 96 ++++++++++------------ 2 files changed, 47 insertions(+), 54 deletions(-) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index 21ddd71..20f9b21 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -555,7 +555,8 @@ private: ArrayRef payloadOperations); /// Replaces the given payload op with another op. If the replacement op is - /// null, removes the association of the payload op with its handle. + /// null, removes the association of the payload op with its handle. Returns + /// failure if the op is not associated with any handle. /// /// Note: This function does not update value handles. None of the original /// op's results are allowed to be mapped to any value handle. @@ -563,7 +564,7 @@ private: /// Replaces the given payload value with another value. If the replacement /// value is null, removes the association of the payload value with its - /// handle. + /// handle. Returns failure if the value is not associated with any handle. LogicalResult replacePayloadValue(Value value, Value replacement); /// Records handle invalidation reporters into `newlyInvalidated`. diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index ce6ec5a..5bd9516 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -338,14 +338,7 @@ void transform::TransformState::forgetValueMapping( LogicalResult transform::TransformState::replacePayloadOp(Operation *op, Operation *replacement) { - // Drop the mapping between the op and all handles that point to it. Don't - // care if there are on such handles. - SmallVector opHandles; - (void)getHandlesForPayloadOp(op, opHandles); - for (Value handle : opHandles) { - Mappings &mappings = getMapping(handle); - dropMappingEntry(mappings.reverse, op, handle); - } + // TODO: consider invalidating the handles to nested objects here. #ifndef NDEBUG for (Value opResult : op->getResults()) { @@ -355,23 +348,29 @@ transform::TransformState::replacePayloadOp(Operation *op, } #endif // NDEBUG + // Drop the mapping between the op and all handles that point to it. Fail if + // there are no handles. + SmallVector opHandles; + if (failed(getHandlesForPayloadOp(op, opHandles))) + return failure(); + for (Value handle : opHandles) { + Mappings &mappings = getMapping(handle); + dropMappingEntry(mappings.reverse, op, handle); + } + #if LLVM_ENABLE_ABI_BREAKING_CHECKS if (options.getExpensiveChecksEnabled()) { auto it = cachedNames.find(op); assert(it != cachedNames.end() && "entry not found"); assert(it->second == op->getName() && "operation name mismatch"); cachedNames.erase(it); - } -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - - // TODO: consider invalidating the handles to nested objects here. - -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - if (replacement && options.getExpensiveChecksEnabled()) { - auto insertion = cachedNames.insert({replacement, replacement->getName()}); - if (!insertion.second) { - assert(insertion.first->second == replacement->getName() && - "operation is already cached with a different name"); + if (replacement) { + auto insertion = + cachedNames.insert({replacement, replacement->getName()}); + if (!insertion.second) { + assert(insertion.first->second == replacement->getName() && + "operation is already cached with a different name"); + } } } #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS @@ -411,7 +410,8 @@ transform::TransformState::replacePayloadOp(Operation *op, LogicalResult transform::TransformState::replacePayloadValue(Value value, Value replacement) { SmallVector valueHandles; - (void)getHandlesForPayloadValue(value, valueHandles); + if (failed(getHandlesForPayloadValue(value, valueHandles))) + return failure(); for (Value handle : valueHandles) { Mappings &mappings = getMapping(handle); @@ -537,30 +537,30 @@ void transform::TransformState::recordValueHandleInvalidationByOpHandleOne( Location ancestorLoc = ancestor->getLoc(); Location opLoc = definingOp->getLoc(); Location valueLoc = payloadValue.getLoc(); - newlyInvalidated[valueHandle] = - [valueHandle, owner, operandNo, resultNo, argumentNo, blockNo, regionNo, - ancestorLoc, opLoc, valueLoc](Location currentLoc) { - InFlightDiagnostic diag = emitError(currentLoc) - << "op uses a handle invalidated by a " - "previously executed transform op"; - diag.attachNote(valueHandle.getLoc()) << "invalidated handle"; - diag.attachNote(owner->getLoc()) - << "invalidated by this transform op that consumes its operand #" - << operandNo - << " and invalidates all handles to payload IR entities " - "associated with this operand and entities nested in them"; - diag.attachNote(ancestorLoc) - << "ancestor op associated with the consumed handle"; - if (resultNo) { - diag.attachNote(opLoc) - << "op defining the value as result #" << *resultNo; - } else { - diag.attachNote(opLoc) - << "op defining the value as block argument #" << argumentNo - << " of block #" << blockNo << " in region #" << regionNo; - } - diag.attachNote(valueLoc) << "payload value"; - }; + newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo, + argumentNo, blockNo, regionNo, ancestorLoc, + opLoc, valueLoc](Location currentLoc) { + InFlightDiagnostic diag = emitError(currentLoc) + << "op uses a handle invalidated by a " + "previously executed transform op"; + diag.attachNote(valueHandle.getLoc()) << "invalidated handle"; + diag.attachNote(owner->getLoc()) + << "invalidated by this transform op that consumes its operand #" + << operandNo + << " and invalidates all handles to payload IR entities " + "associated with this operand and entities nested in them"; + diag.attachNote(ancestorLoc) + << "ancestor op associated with the consumed handle"; + if (resultNo) { + diag.attachNote(opLoc) + << "op defining the value as result #" << *resultNo; + } else { + diag.attachNote(opLoc) + << "op defining the value as block argument #" << argumentNo + << " of block #" << blockNo << " in region #" << regionNo; + } + diag.attachNote(valueLoc) << "payload value"; + }; } } @@ -1064,10 +1064,6 @@ transform::TransformState::Extension::~Extension() = default; LogicalResult transform::TransformState::Extension::replacePayloadOp(Operation *op, Operation *replacement) { - SmallVector handles; - if (failed(state.getHandlesForPayloadOp(op, handles))) - return failure(); - // TODO: we may need to invalidate handles to operations and values nested in // the operation being replaced. return state.replacePayloadOp(op, replacement); @@ -1076,10 +1072,6 @@ transform::TransformState::Extension::replacePayloadOp(Operation *op, LogicalResult transform::TransformState::Extension::replacePayloadValue(Value value, Value replacement) { - SmallVector handles; - if (failed(state.getHandlesForPayloadValue(value, handles))) - return failure(); - return state.replacePayloadValue(value, replacement); } -- 2.7.4