mapping between transform IR values and payload IR operations.
- An `Allocate` effect from this resource means creating a new mapping
- entry, it is always accompanied by a `Write` effet.
+ entry, it is always accompanied by a `Write` effect.
- A `Read` effect from this resource means accessing the mapping.
transform operations can return _new_ handles that can be read or consumed
by subsequent operations.
+ ## Handle Invalidation
+
+ The execution model of the transform dialect expects that a payload IR
+ operation is associated with _at most one_ transform IR handle. This avoids
+ the situation when a handle to an operation outlives the operation itself
+ that can be erased during a transformation triggered through another handle.
+
+ Handles pointing to operations nested in each other are allowed to co-exist
+ in the transform IR. However, a transform IR operation that consumes such a
+ handle automatically _invalidates_ all the other handles that are associated
+ with operations nested in the operations associated with the consumed
+ handle. Any use of the invalidated handle results in undefined behavior
+ since the payload IR operations associated with it are likely to have been
+ mutated or erased. The mere fact of the handle being invalidated does _not_
+ trigger undefined behavior, only its appearance as an operand does.
+ Invalidation applies to the entire handle, even if some of the payload IR
+ operations associated with it are not nested in payload IR operations
+ associated with another, consumed handle.
+
+ Note: the restriction on two handles not pointing to the same operation may
+ be relaxed in the future to follow the invalidation model for nested
+ operation.
+
+ The Transform dialect infrastructure has the capability of checking whether
+ the transform IR op operand is invalidated before applying the
+ transformation. However, such a check is computationally expensive and
+ must be enabled explicitly through `TransformOptions`. Additionally, the
+ `transform-dialect-check-uses` pass emits warnings when a handle may be used
+ after it has been consumed, but does so abstractly, without processing the
+ payload IR.
+
## Intended Use and Integrations
The transformation control infrastructure provided by this dialect is
differentiate between the parts of the loop produced by the previous pass
(both are the same operation, and it is likely undesirable to pollute the
operation with pass-specific information). Implementing passes that run the
- combined transfomration would have run into the combinatorial explosion
+ combined transformation would have run into the combinatorial explosion
issue due to multiple possible transform compositions or into the need for
deep pass parameterization, the ultimate form of which is an ad-hoc dialect
to specify which transformations the pass should run. The transform dialect
takes care of bookkeeping. As such, the transform dialect does not provide
the interpreter pass. Instead, it provides a set of utilities that can be
used by clients to define their own interpreter passes or as part of a more
- complex pass. For example, the mapping between values in the tranfsorm IR
+ complex pass. For example, the mapping between values in the transform IR
and operations in the payload IR, or the function that applies the
transformations specified by ops in the given block sequentially. Note that
a transform op may have regions with further transform ops in them, with
class TransformOpInterface;
+/// Options controlling the application of transform operations by the
+/// TransformState.
+class TransformOptions {
+public:
+ TransformOptions() {}
+
+ /// Requests computationally expensive checks of the transform and payload IR
+ /// well-formedness to be performed before each transformation. In particular,
+ /// these ensure that the handles still point to valid operations when used.
+ TransformOptions &enableExpensiveChecks(bool enable = true) {
+ expensiveChecksEnabled = enable;
+ return *this;
+ }
+
+ /// Returns true if the expensive checks are requested.
+ bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }
+
+private:
+ bool expensiveChecksEnabled = true;
+};
+
/// The state maintained across applications of various ops implementing the
/// TransformOpInterface. The operations implementing this interface and the
/// surrounding structure are referred to as transform IR. The operations to
/// Creates a state for transform ops living in the given region. The parent
/// operation of the region. The second argument points to the root operation
/// in the payload IR beind transformed, which may or may not contain the
- /// region with transform ops.
- TransformState(Region ®ion, Operation *root);
+ /// region with transform ops. Additional options can be provided through the
+ /// trailing configuration object.
+ TransformState(Region ®ion, Operation *root,
+ const TransformOptions &options = TransformOptions());
/// Returns the op at which the transformation state is rooted. This is
/// typically helpful for transformations that apply globally.
static LogicalResult tryEmplaceReverseMapping(Mappings &map, Operation *op,
Value handle);
+ /// If the operand is a handle consumed by the operation, i.e. has the "free"
+ /// memory effect associated with it, identifies other handles that are
+ /// pointing to payload IR operations nested in the operations pointed to by
+ /// the consumed handle. Marks all such handles as invalidated so trigger
+ /// errors if they are used.
+ void recordHandleInvalidation(OpOperand &handle);
+
+ /// Checks that the operation does not use invalidated handles as operands.
+ /// Reports errors and returns failure if it does. Otherwise, invalidates the
+ /// handles consumed by the operation as well as any handles pointing to
+ /// payload IR operations nested in the operations associated with the
+ /// consumed handles.
+ LogicalResult
+ checkAndRecordHandleInvalidation(TransformOpInterface transform);
+
/// The mappings between transform IR values and payload IR ops, aggregated by
/// the region in which the transform IR values are defined.
llvm::SmallDenseMap<Region *, Mappings> mappings;
/// The top-level operation that contains all payload IR, typically a module.
Operation *topLevel;
+ /// Additional options controlling the transformation state behavior.
+ TransformOptions options;
+
+ /// The mapping from invalidated handles to the error-reporting functions that
+ /// describe when the handles were invalidated. Calling such a function emits
+ /// a user-visible diagnostic.
+ DenseMap<Value, std::function<void()>> invalidatedHandles;
+
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// A stack of nested regions that are being processed in the transform IR.
/// Each region must be an ancestor of the following regions in this list.
constexpr const Value transform::TransformState::kTopLevelValue;
-transform::TransformState::TransformState(Region ®ion, Operation *root)
- : topLevel(root) {
+transform::TransformState::TransformState(Region ®ion, Operation *root,
+ const TransformOptions &options)
+ : topLevel(root), options(options) {
auto result = mappings.try_emplace(®ion);
assert(result.second && "the region scope is already present");
(void)result;
return success();
}
+void transform::TransformState::recordHandleInvalidation(OpOperand &handle) {
+ ArrayRef<Operation *> potentialAncestors = getPayloadOps(handle.get());
+ for (const Mappings &mapping : llvm::make_second_range(mappings)) {
+ for (const auto &kvp : mapping.reverse) {
+ // If the op is associated with invalidated handle, skip the check as it
+ // may be reading invalid IR.
+ Operation *op = kvp.first;
+ Value otherHandle = kvp.second;
+ if (invalidatedHandles.count(otherHandle))
+ continue;
+
+ for (Operation *ancestor : potentialAncestors) {
+ if (!ancestor->isProperAncestor(op))
+ continue;
+
+ // Make sure the error-reporting lambda doesn't capture anything
+ // by-reference because it will go out of scope. Additionally, extract
+ // location from Payload IR ops because the ops themselves may be
+ // deleted before the lambda gets called.
+ Location ancestorLoc = ancestor->getLoc();
+ Location opLoc = op->getLoc();
+ Operation *owner = handle.getOwner();
+ unsigned operandNo = handle.getOperandNumber();
+ invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
+ otherHandle]() {
+ InFlightDiagnostic diag =
+ owner->emitOpError()
+ << "invalidated the handle to payload operations nested in the "
+ "payload operation associated with its operand #"
+ << operandNo;
+ diag.attachNote(ancestorLoc) << "ancestor op";
+ diag.attachNote(opLoc) << "nested op";
+ diag.attachNote(otherHandle.getLoc()) << "other handle";
+ };
+ }
+ }
+ }
+}
+
+LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
+ TransformOpInterface transform) {
+ auto memoryEffectsIface =
+ cast<MemoryEffectOpInterface>(transform.getOperation());
+ SmallVector<MemoryEffects::EffectInstance> effects;
+ memoryEffectsIface.getEffectsOnResource(
+ transform::TransformMappingResource::get(), effects);
+
+ for (OpOperand &target : transform->getOpOperands()) {
+ // If the operand uses an invalidated handle, report it.
+ auto it = invalidatedHandles.find(target.get());
+ if (it != invalidatedHandles.end())
+ return it->getSecond()(), failure();
+
+ // Invalidate handles pointing to the operations nested in the operation
+ // associated with the handle consumed by this operation.
+ auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
+ return isa<MemoryEffects::Free>(effect.getEffect()) &&
+ effect.getValue() == target.get();
+ };
+ if (llvm::find_if(effects, consumesTarget) != effects.end())
+ recordHandleInvalidation(target);
+ }
+ return success();
+}
+
LogicalResult
transform::TransformState::applyTransform(TransformOpInterface transform) {
+ if (options.getExpensiveChecksEnabled() &&
+ failed(checkAndRecordHandleInvalidation(transform))) {
+ return failure();
+ }
+
transform::TransformResults results(transform->getNumResults());
if (failed(transform.apply(results, *this)))
return failure();
auto memEffectInterface =
cast<MemoryEffectOpInterface>(transform.getOperation());
SmallVector<MemoryEffects::EffectInstance, 2> effects;
- for (Value target : transform->getOperands()) {
+ for (OpOperand &target : transform->getOpOperands()) {
effects.clear();
- memEffectInterface.getEffectsOnValue(target, effects);
+ memEffectInterface.getEffectsOnValue(target.get(), effects);
if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
return isa<transform::TransformMappingResource>(
effect.getResource()) &&
isa<MemoryEffects::Free>(effect.getEffect());
})) {
- removePayloadOps(target);
+ removePayloadOps(target.get());
}
}
- for (auto &en : llvm::enumerate(transform->getResults())) {
- assert(en.value().getDefiningOp() == transform.getOperation() &&
+ for (OpResult result : transform->getResults()) {
+ assert(result.getDefiningOp() == transform.getOperation() &&
"payload IR association for a value other than the result of the "
"current transform op");
- if (failed(setPayloadOps(en.value(), results.get(en.index()))))
+ if (failed(setPayloadOps(result, results.get(result.getResultNumber()))))
return failure();
}
--- /dev/null
+// RUN: mlir-opt --test-transform-dialect-interpreter='enable-expensive-checks=1' --split-input-file --verify-diagnostics %s
+
+// expected-note @below {{ancestor op}}
+func.func @func() {
+ // expected-note @below {{nested op}}
+ return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @return : benefit(1) {
+ %0 = operands
+ %1 = types
+ %2 = operation "func.return"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ rewrite %2 with "transform.dialect"
+ }
+
+ sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ // expected-note @below {{other handle}}
+ %0 = pdl_match @return in %arg1
+ %1 = get_closest_isolated_parent %0
+ // expected-error @below {{invalidated the handle to payload operations nested in the payload operation associated with its operand #0}}
+ test_consume_operand %1
+ test_print_remark_at_operand %0, "remark"
+ }
+}
return success();
}
+LogicalResult
+mlir::test::TestConsumeOperand::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ return success();
+}
+
LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
transform::TransformResults &results, transform::TransformState &state) {
ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
let hasVerifier = 1;
}
+def TestConsumeOperand : Op<Transform_Dialect, "test_consume_operand",
+ [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let arguments = (ins
+ Arg<PDL_Operation, "",
+ [TransformMappingRead, TransformMappingFree]>:$operand);
+ let assemblyFormat = "$operand attr-dict";
+ let cppNamespace = "::mlir::test";
+}
+
def TestConsumeOperandIfMatchesParamOrFail
: Op<Transform_Dialect, "test_consume_operand_if_matches_param_or_fail",
[DeclareOpInterfaceMethods<TransformOpInterface>]> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestTransformDialectInterpreterPass)
+ TestTransformDialectInterpreterPass() = default;
+ TestTransformDialectInterpreterPass(
+ const TestTransformDialectInterpreterPass &) {}
+
StringRef getArgument() const override {
return "test-transform-dialect-interpreter";
}
void runOnOperation() override {
ModuleOp module = getOperation();
- transform::TransformState state(module.getBodyRegion(), module);
+ transform::TransformState state(
+ module.getBodyRegion(), module,
+ transform::TransformOptions().enableExpensiveChecks(
+ enableExpensiveChecks));
for (auto op :
module.getBody()->getOps<transform::TransformOpInterface>()) {
if (failed(state.applyTransform(op)))
return signalPassFailure();
}
}
+
+ Option<bool> enableExpensiveChecks{
+ *this, "enable-expensive-checks", llvm::cl::init(false),
+ llvm::cl::desc("perform expensive checks to better report errors in the "
+ "transform IR")};
};
} // namespace