detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
public:
+ const TransformOptions &getOptions() const { return options; }
+
/// Returns the op at which the transformation state is rooted. This is
/// typically helpful for transformations that apply globally.
Operation *getTopLevel() const;
return DiagnosedSilenceableFailure::success();
}
+/// Reports an error and returns failure if `targets` contains an ancestor
+/// operation before its descendant (or a copy of itself). Implementation detail
+/// for expensive checks during `TransformEachOpTrait::apply`.
+LogicalResult checkNestedConsumption(Location loc,
+ ArrayRef<Operation *> targets);
+
} // namespace detail
} // namespace transform
} // namespace mlir
mlir::DiagnosedSilenceableFailure
mlir::transform::TransformEachOpTrait<OpTy>::apply(
TransformResults &transformResults, TransformState &state) {
- auto targets = state.getPayloadOps(this->getOperation()->getOperand(0));
+ Value handle = this->getOperation()->getOperand(0);
+ auto targets = state.getPayloadOps(handle);
+
+ // If the operand is consumed, check if it is associated with operations that
+ // may be erased before their nested operations are.
+ if (state.getOptions().getExpensiveChecksEnabled() &&
+ isHandleConsumed(handle, cast<transform::TransformOpInterface>(
+ this->getOperation())) &&
+ failed(detail::checkNestedConsumption(this->getOperation()->getLoc(),
+ llvm::to_vector(targets)))) {
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
// Step 1. Handle the corner case where no target is specified.
// This is typically the case when the matcher fails to apply and we need to
//===----------------------------------------------------------------------===//
LogicalResult
+transform::detail::checkNestedConsumption(Location loc,
+ ArrayRef<Operation *> targets) {
+ for (auto &&[position, parent] : llvm::enumerate(targets)) {
+ for (Operation *child : targets.drop_front(position + 1)) {
+ if (parent->isAncestor(child)) {
+ InFlightDiagnostic diag =
+ emitError(loc)
+ << "transform operation consumes a handle pointing to an ancestor "
+ "payload operation before its descendant";
+ diag.attachNote()
+ << "the ancestor is likely erased or rewritten before the "
+ "descendant is accessed, leading to undefined behavior";
+ diag.attachNote(parent->getLoc()) << "ancestor payload op";
+ diag.attachNote(child->getLoc()) << "descendant payload op";
+ return diag;
+ }
+ }
+ }
+ return success();
+}
+
+LogicalResult
transform::detail::checkApplyToOne(Operation *transformOp,
Location payloadOpLoc,
const ApplyToEachResultList &partialResult) {
--- /dev/null
+// RUN: mlir-opt %s --split-input-file --verify-diagnostics \
+// RUN: --pass-pipeline="builtin.module(test-transform-dialect-interpreter{enable-expensive-checks=1 bind-first-extra-to-ops=scf.for})"
+
+func.func private @bar()
+
+func.func @foo() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ // expected-note @below {{ancestor payload op}}
+ scf.for %i = %c0 to %c1 step %c10 {
+ // expected-note @below {{descendant payload op}}
+ scf.for %j = %c0 to %c1 step %c10 {
+ func.call @bar() : () -> ()
+ }
+ }
+ return
+}
+
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.op<"scf.for">):
+ %1 = transform.test_reverse_payload_ops %arg1 : (!transform.op<"scf.for">) -> !transform.op<"scf.for">
+ // expected-error @below {{transform operation consumes a handle pointing to an ancestor payload operation before its descendant}}
+ // expected-note @below {{the ancestor is likely erased or rewritten before the descendant is accessed, leading to undefined behavior}}
+ transform.test_consume_operand_each %1 : !transform.op<"scf.for">
+}
+
+// -----
+
+func.func private @bar()
+
+func.func @foo() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ scf.for %i = %c0 to %c1 step %c10 {
+ scf.for %j = %c0 to %c1 step %c10 {
+ func.call @bar() : () -> ()
+ }
+ }
+ return
+}
+
+// No error here, processing ancestors before descendants.
+transform.sequence failures(suppress) {
+^bb0(%arg0: !transform.any_op, %arg1: !transform.op<"scf.for">):
+ transform.test_consume_operand_each %arg1 : !transform.op<"scf.for">
+}
let cppNamespace = "::mlir::test";
}
+def TestConsumeOperandEach : Op<Transform_Dialect, "test_consume_operand_each",
+ [TransformOpInterface, TransformEachOpTrait,
+ MemoryEffectsOpInterface, FunctionalStyleTransformOpTrait]> {
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let assemblyFormat = "$target attr-dict `:` type($target)";
+ let cppNamespace = "::mlir::test";
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state) {
+ return ::mlir::DiagnosedSilenceableFailure::success();
+ }
+ }];
+}
+
def TestConsumeOperandOfOpKindOrFail
: Op<Transform_Dialect, "test_consume_operand_of_op_kind_or_fail",
[DeclareOpInterfaceMethods<TransformOpInterface>,