From b0bf7ffffc3a2a65a22da4afb13feb855baf2042 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 17 Oct 2022 12:47:03 +0000 Subject: [PATCH] [mlir] add utilites for DiagnosedSilenceableFailure This class adds helper functions similar to `emitError` for the DiagnosedSilenceableFailure class in both the silenceable and definite failure cases. These helpers simplify the use of said class and make tranfsorm op application code idiomatic. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D136072 --- .../Dialect/Transform/IR/TransformInterfaces.h | 93 ++++++++++++++++++++++ .../Dialect/Transform/IR/TransformInterfaces.td | 21 +++-- .../Dialect/GPU/TransformOps/GPUTransformOps.cpp | 3 +- .../Linalg/TransformOps/LinalgTransformOps.cpp | 17 ++-- mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 17 ++-- .../Transform/transform-state-extension.mlir | 10 +++ .../Transform/TestTransformDialectExtension.cpp | 6 +- 7 files changed, 137 insertions(+), 30 deletions(-) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index 2c81985..b56a5dd 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -181,6 +181,99 @@ private: #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS }; +class DiagnosedDefiniteFailure; + +DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, + const Twine &message = {}); + +/// A compatibility class connecting `InFlightDiagnostic` to +/// `DiagnosedSilenceableFailure` while providing an interface similar to the +/// former. Implicitly convertible to `DiagnosticSilenceableFailure` in definite +/// failure state and to `LogicalResult` failure. Reports the error on +/// conversion or on destruction. Instances of this class can be created by +/// `emitDefiniteFailure()`. +class DiagnosedDefiniteFailure { + friend DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, + const Twine &message); + +public: + /// Only move-constructible because it carries an in-flight diagnostic. + DiagnosedDefiniteFailure(DiagnosedDefiniteFailure &&) = default; + + /// Forward the message to the diagnostic. + template + DiagnosedDefiniteFailure &operator<<(T &&value) & { + diag << std::forward(value); + return *this; + } + template + DiagnosedDefiniteFailure &&operator<<(T &&value) && { + return std::move(this->operator<<(std::forward(value))); + } + + /// Attaches a note to the error. + Diagnostic &attachNote(Optional loc = llvm::None) { + return diag.attachNote(loc); + } + + /// Implicit conversion to DiagnosedSilenceableFailure in the definite failure + /// state. Reports the error. + operator DiagnosedSilenceableFailure() { + diag.report(); + return DiagnosedSilenceableFailure::definiteFailure(); + } + + /// Implicit conversion to LogicalResult in the failure state. Reports the + /// error. + operator LogicalResult() { + diag.report(); + return failure(); + } + +private: + /// Constructs a definite failure at the given location with the given + /// message. + explicit DiagnosedDefiniteFailure(Location loc, const Twine &message) + : diag(emitError(loc, message)) {} + + /// Copy-construction and any assignment is disallowed to prevent repeated + /// error reporting. + DiagnosedDefiniteFailure(const DiagnosedDefiniteFailure &) = delete; + DiagnosedDefiniteFailure & + operator=(const DiagnosedDefiniteFailure &) = delete; + DiagnosedDefiniteFailure &operator=(DiagnosedDefiniteFailure &&) = delete; + + /// The error message. + InFlightDiagnostic diag; +}; + +/// Emits a definite failure with the given message. The returned object allows +/// for last-minute modification to the error message, such as attaching notes +/// and completing the message. It will be reported when the object is +/// destructed or converted. +inline DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, + const Twine &message) { + return DiagnosedDefiniteFailure(loc, message); +} +inline DiagnosedDefiniteFailure emitDefiniteFailure(Operation *op, + const Twine &message = {}) { + return emitDefiniteFailure(op->getLoc(), message); +} + +/// Emits a silenceable failure with the given message. A silenceable failure +/// must be either suppressed or converted into a definite failure and reported +/// to the user. +inline DiagnosedSilenceableFailure +emitSilenceableFailure(Location loc, const Twine &message = {}) { + Diagnostic diag(loc, DiagnosticSeverity::Error); + diag << message; + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); +} +inline DiagnosedSilenceableFailure +emitSilenceableFailure(Operation *op, const Twine &message = {}) { + return emitSilenceableFailure(op->getLoc(), message); +} + namespace transform { class TransformOpInterface; diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td index 238825d..fe29f30 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -63,20 +63,29 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> { } /// Creates the silenceable failure object with a diagnostic located at the - /// current operation. - DiagnosedSilenceableFailure emitSilenceableError() { - Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error); - return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + /// current operation. Silenceable failure must be suppressed or reported + /// explicitly at some later time. + DiagnosedSilenceableFailure + emitSilenceableError(const ::llvm::Twine &message = {}) { + return ::mlir::emitSilenceableFailure($_op); + } + + /// Creates the definite failure object with a diagnostic located at the + /// current operation. Definite failure will be reported when the object + /// is destroyed or converted. + DiagnosedDefiniteFailure + emitDefiniteFailure(const ::llvm::Twine &message = {}) { + return ::mlir::emitDefiniteFailure($_op, message); } /// Creates the default silenceable failure for a transform op that failed /// to properly apply to a target. DiagnosedSilenceableFailure emitDefaultSilenceableFailure( Operation *target) { - Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error); + DiagnosedSilenceableFailure diag = emitSilenceableFailure($_op->getLoc()); diag << $_op->getName() << " failed to apply"; diag.attachNote(target->getLoc()) << "when applied to this op"; - return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + return diag; } }]; } diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index 92b0b5e..280dbf0 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -324,8 +324,7 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads( if (transformOp.has_value()) { return transformOp->emitSilenceableError() << message; } - foreachThreadOp->emitError() << message; - return DiagnosedSilenceableFailure::definiteFailure(); + return emitDefiniteFailure(foreachThreadOp, message); }; if (foreachThreadOp.getNumResults() > 0) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index e47e8e5..1030f0a 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -470,10 +470,9 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results, } ArrayRef containingOps = state.getPayloadOps(getContainingOp()); if (containingOps.size() != 1) { - // Definite failure. - return DiagnosedSilenceableFailure( - this->emitOpError("requires exactly one containing_op handle (got ") - << containingOps.size() << ")"); + return emitDefiniteFailure() + << "requires exactly one containing_op handle (got " + << containingOps.size() << ")"; } Operation *containingOp = containingOps.front(); @@ -925,11 +924,11 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, } if (splitPoints.size() != payload.size()) { - emitError() << "expected the dynamic split point handle to point to as " - "many operations (" - << splitPoints.size() << ") as the target handle (" - << payload.size() << ")"; - return DiagnosedSilenceableFailure::definiteFailure(); + return emitDefiniteFailure() + << "expected the dynamic split point handle to point to as " + "many operations (" + << splitPoints.size() << ") as the target handle (" + << payload.size() << ")"; } } else { splitPoints.resize(payload.size(), diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 6d32c40..e632be0 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -177,17 +177,16 @@ transform::AlternativesOp::apply(transform::TransformResults &results, for (Operation *original : originals) { if (original->isAncestor(getOperation())) { - InFlightDiagnostic diag = - emitError() << "scope must not contain the transforms being applied"; + auto diag = emitDefiniteFailure() + << "scope must not contain the transforms being applied"; diag.attachNote(original->getLoc()) << "scope"; - return DiagnosedSilenceableFailure::definiteFailure(); + return diag; } if (!original->hasTrait()) { - InFlightDiagnostic diag = - emitError() - << "only isolated-from-above ops can be alternative scopes"; + auto diag = emitDefiniteFailure() + << "only isolated-from-above ops can be alternative scopes"; diag.attachNote(original->getLoc()) << "scope"; - return DiagnosedSilenceableFailure(std::move(diag)); + return diag; } } @@ -523,8 +522,8 @@ transform::PDLMatchOp::apply(transform::TransformResults &results, for (Operation *root : state.getPayloadOps(getRoot())) { if (failed(extension->findAllMatches( getPatternName().getLeafReference().getValue(), root, targets))) { - emitOpError() << "could not find pattern '" << getPatternName() << "'"; - return DiagnosedSilenceableFailure::definiteFailure(); + emitDefiniteFailure() + << "could not find pattern '" << getPatternName() << "'"; } } results.set(getResult().cast(), targets); diff --git a/mlir/test/Dialect/Transform/transform-state-extension.mlir b/mlir/test/Dialect/Transform/transform-state-extension.mlir index f7f678e..1f29684 100644 --- a/mlir/test/Dialect/Transform/transform-state-extension.mlir +++ b/mlir/test/Dialect/Transform/transform-state-extension.mlir @@ -44,3 +44,13 @@ module { test_check_if_test_extension_present %arg0 } } + +// ----- + +module { + transform.sequence failures(suppress) { + ^bb0(%arg0: !pdl.operation): + // expected-error @below {{TestTransformStateExtension missing}} + test_remap_operand_to_self %arg0 + } +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index b890af5..483d4fe 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -188,10 +188,8 @@ mlir::test::TestCheckIfTestExtensionPresentOp::apply( DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply( transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); - if (!extension) { - emitError() << "TestTransformStateExtension missing"; - return DiagnosedSilenceableFailure::definiteFailure(); - } + if (!extension) + return emitDefiniteFailure("TestTransformStateExtension missing"); if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(), getOperation()))) -- 2.7.4