[mlir] add utilites for DiagnosedSilenceableFailure
authorAlex Zinenko <zinenko@google.com>
Mon, 17 Oct 2022 12:47:03 +0000 (12:47 +0000)
committerAlex Zinenko <zinenko@google.com>
Mon, 17 Oct 2022 15:31:28 +0000 (15:31 +0000)
This class adds helper functions similar to `emitError` for the
DiagnosedSilenceableFailure class in both the silenceable and definite
failure cases. These helpers simplify the use of said class and make
tranfsorm op application code idiomatic.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D136072

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/transform-state-extension.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp

index 2c81985..b56a5dd 100644 (file)
@@ -181,6 +181,99 @@ private:
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 };
 
+class DiagnosedDefiniteFailure;
+
+DiagnosedDefiniteFailure emitDefiniteFailure(Location loc,
+                                             const Twine &message = {});
+
+/// A compatibility class connecting `InFlightDiagnostic` to
+/// `DiagnosedSilenceableFailure` while providing an interface similar to the
+/// former. Implicitly convertible to `DiagnosticSilenceableFailure` in definite
+/// failure state and to `LogicalResult` failure. Reports the error on
+/// conversion or on destruction. Instances of this class can be created by
+/// `emitDefiniteFailure()`.
+class DiagnosedDefiniteFailure {
+  friend DiagnosedDefiniteFailure emitDefiniteFailure(Location loc,
+                                                      const Twine &message);
+
+public:
+  /// Only move-constructible because it carries an in-flight diagnostic.
+  DiagnosedDefiniteFailure(DiagnosedDefiniteFailure &&) = default;
+
+  /// Forward the message to the diagnostic.
+  template <typename T>
+  DiagnosedDefiniteFailure &operator<<(T &&value) & {
+    diag << std::forward<T>(value);
+    return *this;
+  }
+  template <typename T>
+  DiagnosedDefiniteFailure &&operator<<(T &&value) && {
+    return std::move(this->operator<<(std::forward<T>(value)));
+  }
+
+  /// Attaches a note to the error.
+  Diagnostic &attachNote(Optional<Location> loc = llvm::None) {
+    return diag.attachNote(loc);
+  }
+
+  /// Implicit conversion to DiagnosedSilenceableFailure in the definite failure
+  /// state. Reports the error.
+  operator DiagnosedSilenceableFailure() {
+    diag.report();
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  /// Implicit conversion to LogicalResult in the failure state. Reports the
+  /// error.
+  operator LogicalResult() {
+    diag.report();
+    return failure();
+  }
+
+private:
+  /// Constructs a definite failure at the given location with the given
+  /// message.
+  explicit DiagnosedDefiniteFailure(Location loc, const Twine &message)
+      : diag(emitError(loc, message)) {}
+
+  /// Copy-construction and any assignment is disallowed to prevent repeated
+  /// error reporting.
+  DiagnosedDefiniteFailure(const DiagnosedDefiniteFailure &) = delete;
+  DiagnosedDefiniteFailure &
+  operator=(const DiagnosedDefiniteFailure &) = delete;
+  DiagnosedDefiniteFailure &operator=(DiagnosedDefiniteFailure &&) = delete;
+
+  /// The error message.
+  InFlightDiagnostic diag;
+};
+
+/// Emits a definite failure with the given message. The returned object allows
+/// for last-minute modification to the error message, such as attaching notes
+/// and completing the message. It will be reported when the object is
+/// destructed or converted.
+inline DiagnosedDefiniteFailure emitDefiniteFailure(Location loc,
+                                                    const Twine &message) {
+  return DiagnosedDefiniteFailure(loc, message);
+}
+inline DiagnosedDefiniteFailure emitDefiniteFailure(Operation *op,
+                                                    const Twine &message = {}) {
+  return emitDefiniteFailure(op->getLoc(), message);
+}
+
+/// Emits a silenceable failure with the given message. A silenceable failure
+/// must be either suppressed or converted into a definite failure and reported
+/// to the user.
+inline DiagnosedSilenceableFailure
+emitSilenceableFailure(Location loc, const Twine &message = {}) {
+  Diagnostic diag(loc, DiagnosticSeverity::Error);
+  diag << message;
+  return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+}
+inline DiagnosedSilenceableFailure
+emitSilenceableFailure(Operation *op, const Twine &message = {}) {
+  return emitSilenceableFailure(op->getLoc(), message);
+}
+
 namespace transform {
 
 class TransformOpInterface;
index 238825d..fe29f30 100644 (file)
@@ -63,20 +63,29 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
     }
 
     /// Creates the silenceable failure object with a diagnostic located at the
-    /// current operation.
-    DiagnosedSilenceableFailure emitSilenceableError() {
-      Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error);
-      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+    /// current operation. Silenceable failure must be suppressed or reported
+    /// explicitly at some later time.
+    DiagnosedSilenceableFailure
+    emitSilenceableError(const ::llvm::Twine &message = {}) {
+      return ::mlir::emitSilenceableFailure($_op);
+    }
+
+    /// Creates the definite failure object with a diagnostic located at the
+    /// current operation. Definite failure will be reported when the object
+    /// is destroyed or converted.
+    DiagnosedDefiniteFailure
+    emitDefiniteFailure(const ::llvm::Twine &message = {}) {
+      return ::mlir::emitDefiniteFailure($_op, message);
     }
 
     /// Creates the default silenceable failure for a transform op that failed
     /// to properly apply to a target.
     DiagnosedSilenceableFailure emitDefaultSilenceableFailure(
         Operation *target) {
-      Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error);
+      DiagnosedSilenceableFailure diag = emitSilenceableFailure($_op->getLoc());
       diag << $_op->getName() << " failed to apply";
       diag.attachNote(target->getLoc()) << "when applied to this op";
-      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+      return diag;
     }
   }];
 }
index 92b0b5e..280dbf0 100644 (file)
@@ -324,8 +324,7 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     if (transformOp.has_value()) {
       return transformOp->emitSilenceableError() << message;
     }
-    foreachThreadOp->emitError() << message;
-    return DiagnosedSilenceableFailure::definiteFailure();
+    return emitDefiniteFailure(foreachThreadOp, message);
   };
 
   if (foreachThreadOp.getNumResults() > 0)
index e47e8e5..1030f0a 100644 (file)
@@ -470,10 +470,9 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
   }
   ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
   if (containingOps.size() != 1) {
-    // Definite failure.
-    return DiagnosedSilenceableFailure(
-        this->emitOpError("requires exactly one containing_op handle (got ")
-        << containingOps.size() << ")");
+    return emitDefiniteFailure()
+           << "requires exactly one containing_op handle (got "
+           << containingOps.size() << ")";
   }
   Operation *containingOp = containingOps.front();
 
@@ -925,11 +924,11 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
     }
 
     if (splitPoints.size() != payload.size()) {
-      emitError() << "expected the dynamic split point handle to point to as "
-                     "many operations ("
-                  << splitPoints.size() << ") as the target handle ("
-                  << payload.size() << ")";
-      return DiagnosedSilenceableFailure::definiteFailure();
+      return emitDefiniteFailure()
+             << "expected the dynamic split point handle to point to as "
+                "many operations ("
+             << splitPoints.size() << ") as the target handle ("
+             << payload.size() << ")";
     }
   } else {
     splitPoints.resize(payload.size(),
index 6d32c40..e632be0 100644 (file)
@@ -177,17 +177,16 @@ transform::AlternativesOp::apply(transform::TransformResults &results,
 
   for (Operation *original : originals) {
     if (original->isAncestor(getOperation())) {
-      InFlightDiagnostic diag =
-          emitError() << "scope must not contain the transforms being applied";
+      auto diag = emitDefiniteFailure()
+                  << "scope must not contain the transforms being applied";
       diag.attachNote(original->getLoc()) << "scope";
-      return DiagnosedSilenceableFailure::definiteFailure();
+      return diag;
     }
     if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
-      InFlightDiagnostic diag =
-          emitError()
-          << "only isolated-from-above ops can be alternative scopes";
+      auto diag = emitDefiniteFailure()
+                  << "only isolated-from-above ops can be alternative scopes";
       diag.attachNote(original->getLoc()) << "scope";
-      return DiagnosedSilenceableFailure(std::move(diag));
+      return diag;
     }
   }
 
@@ -523,8 +522,8 @@ transform::PDLMatchOp::apply(transform::TransformResults &results,
   for (Operation *root : state.getPayloadOps(getRoot())) {
     if (failed(extension->findAllMatches(
             getPatternName().getLeafReference().getValue(), root, targets))) {
-      emitOpError() << "could not find pattern '" << getPatternName() << "'";
-      return DiagnosedSilenceableFailure::definiteFailure();
+      emitDefiniteFailure()
+          << "could not find pattern '" << getPatternName() << "'";
     }
   }
   results.set(getResult().cast<OpResult>(), targets);
index f7f678e..1f29684 100644 (file)
@@ -44,3 +44,13 @@ module {
     test_check_if_test_extension_present %arg0
   }
 }
+
+// -----
+
+module {
+  transform.sequence failures(suppress) {
+  ^bb0(%arg0: !pdl.operation):
+    // expected-error @below {{TestTransformStateExtension missing}}
+    test_remap_operand_to_self %arg0
+  }
+}
index b890af5..483d4fe 100644 (file)
@@ -188,10 +188,8 @@ mlir::test::TestCheckIfTestExtensionPresentOp::apply(
 DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   auto *extension = state.getExtension<TestTransformStateExtension>();
-  if (!extension) {
-    emitError() << "TestTransformStateExtension missing";
-    return DiagnosedSilenceableFailure::definiteFailure();
-  }
+  if (!extension)
+    return emitDefiniteFailure("TestTransformStateExtension missing");
 
   if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(),
                                       getOperation())))