From 984c2c8cb343e9a9d43b085f27f2f2ac3253cae7 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 9 Jan 2023 14:01:25 +0100 Subject: [PATCH] [mlir] verify against nullptr payload in transform dialect When establishing the correspondence between transform values and payload operations or parameters, check that the latter are non-null and report errors. This was previously allowed for exotic cases of partially successfull transformations with "apply each" trait, but was dangerous. The "apply each" implementation was reworked to remove the need for this functionality, so this can now be hardned to avoid null pointer dereferences. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D141142 --- .../Dialect/Transform/IR/TransformInterfaces.h | 18 ++++++------------ .../Dialect/Transform/IR/TransformInterfaces.cpp | 14 ++++++++++++++ mlir/test/Dialect/Transform/test-interpreter.mlir | 16 ++++++++++++++++ .../Transform/TestTransformDialectExtension.cpp | 22 ++++++++++++++++++++++ .../Transform/TestTransformDialectExtension.td | 18 ++++++++++++++++++ 5 files changed, 76 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index 0ac2c45..b2c3827 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -832,36 +832,30 @@ applyTransformToEach(TransformOpTy transformOp, ArrayRef targets, SmallVector silenceableStack; unsigned expectedNumResults = transformOp->getNumResults(); for (Operation *target : targets) { - // Emplace back a placeholder for the returned new ops and params. - // This is filled with `expectedNumResults` if the op fails to apply. - ApplyToEachResultList placeholder; - placeholder.reserve(expectedNumResults); - results.push_back(std::move(placeholder)); - auto specificOp = dyn_cast(target); if (!specificOp) { Diagnostic diag(transformOp->getLoc(), DiagnosticSeverity::Error); diag << "transform applied to the wrong op kind"; diag.attachNote(target->getLoc()) << "when applied to this op"; - // Producing `expectedNumResults` nullptr is a silenceableFailure mode. - // TODO: encode this implicit `expectedNumResults` nullptr == - // silenceableFailure with a proper trait. - results.back().assign(expectedNumResults, nullptr); silenceableStack.push_back(std::move(diag)); continue; } + ApplyToEachResultList partialResults; + partialResults.reserve(expectedNumResults); Location specificOpLoc = specificOp->getLoc(); DiagnosedSilenceableFailure res = - transformOp.applyToOne(specificOp, results.back(), state); + transformOp.applyToOne(specificOp, partialResults, state); if (res.isDefiniteFailure() || failed(detail::checkApplyToOne(transformOp, specificOpLoc, - results.back()))) { + partialResults))) { return DiagnosedSilenceableFailure::definiteFailure(); } if (res.isSilenceableFailure()) res.takeDiagnostics(silenceableStack); + else + results.push_back(std::move(partialResults)); } if (!silenceableStack.empty()) { return DiagnosedSilenceableFailure::silenceableFailure( diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index 40e4d79..e8ea213 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -80,6 +80,13 @@ transform::TransformState::setPayloadOps(Value value, assert(!value.getType().isa() && "cannot associate payload ops with a value of parameter type"); + for (Operation *target : targets) { + if (target) + continue; + return emitError(value.getLoc()) + << "attempting to assign a null payload op to this transform value"; + } + auto iface = value.getType().cast(); DiagnosedSilenceableFailure result = iface.checkPayload(value.getLoc(), targets); @@ -105,6 +112,13 @@ LogicalResult transform::TransformState::setParams(Value value, ArrayRef params) { assert(value != nullptr && "attempting to set params for a null value"); + for (Attribute attr : params) { + if (attr) + continue; + return emitError(value.getLoc()) + << "attempting to assign a null parameter to this transform value"; + } + auto valueType = value.getType().dyn_cast(); assert(value && "cannot associate parameter with a value of non-parameter type"); diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index ca327c6..da48fe2 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1024,3 +1024,19 @@ transform.sequence failures(propagate) { { second_result_is_handle } : (!transform.any_op) -> (!transform.any_op, !transform.param) } + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{attempting to assign a null payload op to this transform value}} + %0 = transform.test_produce_null_payload : !transform.any_op +} + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{attempting to assign a null parameter to this transform value}} + %0 = transform.test_produce_null_param : !transform.param +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index 71bf51d..338d72e 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -458,6 +458,28 @@ mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne( return DiagnosedSilenceableFailure::success(); } +void mlir::test::TestProduceNullPayloadOp::getEffects( + SmallVectorImpl &effects) { + transform::producesHandle(getOut(), effects); +} + +DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply( + transform::TransformResults &results, transform::TransformState &state) { + SmallVector null({nullptr}); + results.set(getOut().cast(), null); + return DiagnosedSilenceableFailure::success(); +} + +void mlir::test::TestProduceNullParamOp::getEffects( + SmallVectorImpl &effects) {} + +DiagnosedSilenceableFailure +mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + results.setParams(getOut().cast(), Attribute()); + return DiagnosedSilenceableFailure::success(); +} + namespace { /// Test extension of the Transform dialect. Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td index dbe058c..9ff5e30 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -334,4 +334,22 @@ def TestProduceTransformParamOrForwardOperandOp }]; } +def TestProduceNullPayloadOp + : Op, + DeclareOpInterfaceMethods]> { + let results = (outs TransformHandleTypeInterface:$out); + let assemblyFormat = "attr-dict `:` type($out)"; + let cppNamespace = "::mlir::test"; +} + +def TestProduceNullParamOp + : Op, + DeclareOpInterfaceMethods]> { + let results = (outs TransformParamTypeInterface:$out); + let assemblyFormat = "attr-dict `:` type($out)"; + let cppNamespace = "::mlir::test"; +} + #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD -- 2.7.4