From 95f495c7b3a6bc2ed47bb6f9977256ec0f841e52 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 9 Jun 2023 09:49:35 +0000 Subject: [PATCH] [mlir][transform] add a check for nested consumption in ApplyEachOpTrait ApplyEachOpTrait applies to payload ops associated with its operand handle one-by-one in order. If a handle is consumed, this usually indicates that the associated payload ops are erased or rewritten. Add a check that we don't consume an ancestor payload operation before consuming its descendant, as the latter is likely to be a dangling pointer. Transform operations for which this is a legitimate behavior (i.e., they consume the handle but don't actually erase or rewrite the payload operation) should implement the interface directly and allow for repeated handles. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D152510 --- .../Dialect/Transform/IR/TransformInterfaces.h | 21 +++++++++- .../Dialect/Transform/IR/TransformInterfaces.cpp | 22 ++++++++++ .../Dialect/Transform/apply-foreach-nested.mlir | 48 ++++++++++++++++++++++ .../Transform/TestTransformDialectExtension.td | 16 ++++++++ 4 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Dialect/Transform/apply-foreach-nested.mlir diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index 28972f1..1afac0a 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -187,6 +187,8 @@ private: detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot); public: + const TransformOptions &getOptions() const { return options; } + /// Returns the op at which the transformation state is rooted. This is /// typically helpful for transformations that apply globally. Operation *getTopLevel() const; @@ -1352,6 +1354,12 @@ applyTransformToEach(TransformOpTy transformOp, Range &&targets, return DiagnosedSilenceableFailure::success(); } +/// Reports an error and returns failure if `targets` contains an ancestor +/// operation before its descendant (or a copy of itself). Implementation detail +/// for expensive checks during `TransformEachOpTrait::apply`. +LogicalResult checkNestedConsumption(Location loc, + ArrayRef targets); + } // namespace detail } // namespace transform } // namespace mlir @@ -1360,7 +1368,18 @@ template mlir::DiagnosedSilenceableFailure mlir::transform::TransformEachOpTrait::apply( TransformResults &transformResults, TransformState &state) { - auto targets = state.getPayloadOps(this->getOperation()->getOperand(0)); + Value handle = this->getOperation()->getOperand(0); + auto targets = state.getPayloadOps(handle); + + // If the operand is consumed, check if it is associated with operations that + // may be erased before their nested operations are. + if (state.getOptions().getExpensiveChecksEnabled() && + isHandleConsumed(handle, cast( + this->getOperation())) && + failed(detail::checkNestedConsumption(this->getOperation()->getLoc(), + llvm::to_vector(targets)))) { + return DiagnosedSilenceableFailure::definiteFailure(); + } // Step 1. Handle the corner case where no target is specified. // This is typically the case when the matcher fails to apply and we need to diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index 068fa0b..a6c5629 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -1331,6 +1331,28 @@ void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( //===----------------------------------------------------------------------===// LogicalResult +transform::detail::checkNestedConsumption(Location loc, + ArrayRef targets) { + for (auto &&[position, parent] : llvm::enumerate(targets)) { + for (Operation *child : targets.drop_front(position + 1)) { + if (parent->isAncestor(child)) { + InFlightDiagnostic diag = + emitError(loc) + << "transform operation consumes a handle pointing to an ancestor " + "payload operation before its descendant"; + diag.attachNote() + << "the ancestor is likely erased or rewritten before the " + "descendant is accessed, leading to undefined behavior"; + diag.attachNote(parent->getLoc()) << "ancestor payload op"; + diag.attachNote(child->getLoc()) << "descendant payload op"; + return diag; + } + } + } + return success(); +} + +LogicalResult transform::detail::checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult) { diff --git a/mlir/test/Dialect/Transform/apply-foreach-nested.mlir b/mlir/test/Dialect/Transform/apply-foreach-nested.mlir new file mode 100644 index 0000000..d2f7164 --- /dev/null +++ b/mlir/test/Dialect/Transform/apply-foreach-nested.mlir @@ -0,0 +1,48 @@ +// RUN: mlir-opt %s --split-input-file --verify-diagnostics \ +// RUN: --pass-pipeline="builtin.module(test-transform-dialect-interpreter{enable-expensive-checks=1 bind-first-extra-to-ops=scf.for})" + +func.func private @bar() + +func.func @foo() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + // expected-note @below {{ancestor payload op}} + scf.for %i = %c0 to %c1 step %c10 { + // expected-note @below {{descendant payload op}} + scf.for %j = %c0 to %c1 step %c10 { + func.call @bar() : () -> () + } + } + return +} + +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op, %arg1: !transform.op<"scf.for">): + %1 = transform.test_reverse_payload_ops %arg1 : (!transform.op<"scf.for">) -> !transform.op<"scf.for"> + // expected-error @below {{transform operation consumes a handle pointing to an ancestor payload operation before its descendant}} + // expected-note @below {{the ancestor is likely erased or rewritten before the descendant is accessed, leading to undefined behavior}} + transform.test_consume_operand_each %1 : !transform.op<"scf.for"> +} + +// ----- + +func.func private @bar() + +func.func @foo() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + scf.for %i = %c0 to %c1 step %c10 { + scf.for %j = %c0 to %c1 step %c10 { + func.call @bar() : () -> () + } + } + return +} + +// No error here, processing ancestors before descendants. +transform.sequence failures(suppress) { +^bb0(%arg0: !transform.any_op, %arg1: !transform.op<"scf.for">): + transform.test_consume_operand_each %arg1 : !transform.op<"scf.for"> +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td index 85b0440..7611165 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -108,6 +108,22 @@ def TestConsumeOperand : Op { + let arguments = (ins TransformHandleTypeInterface:$target); + let assemblyFormat = "$target attr-dict `:` type($target)"; + let cppNamespace = "::mlir::test"; + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + return ::mlir::DiagnosedSilenceableFailure::success(); + } + }]; +} + def TestConsumeOperandOfOpKindOrFail : Op, -- 2.7.4