[mlir][python] Capture error diagnostics in exceptions
authorRahul Kayaith <rkayaith@gmail.com>
Tue, 7 Feb 2023 21:07:50 +0000 (16:07 -0500)
committerRahul Kayaith <rkayaith@gmail.com>
Tue, 7 Mar 2023 19:59:22 +0000 (14:59 -0500)
This updates most (all?) error-diagnostic-emitting python APIs to
capture error diagnostics and include them in the raised exception's
message:
```
>>> Operation.parse('"arith.addi"() : () -> ()'))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
mlir._mlir_libs.MLIRError: Unable to parse operation assembly:
error: "-":1:1: 'arith.addi' op requires one result
 note: "-":1:1: see current operation: "arith.addi"() : () -> ()
```

The diagnostic information is available on the exception for users who
may want to customize the error message:
```
>>> try:
...   Operation.parse('"arith.addi"() : () -> ()')
... except MLIRError as e:
...   print(e.message)
...   print(e.error_diagnostics)
...   print(e.error_diagnostics[0].message)
...
Unable to parse operation assembly
[<mlir._mlir_libs._mlir.ir.DiagnosticInfo object at 0x7fed32bd6b70>]
'arith.addi' op requires one result
```

Error diagnostics captured in exceptions aren't propagated to diagnostic
handlers, to avoid double-reporting of errors. The context-level
`emit_error_diagnostics` option can be used to revert to the old
behaviour, causing error diagnostics to be reported to handlers instead
of as part of exceptions.

API changes:
- `Operation.verify` now raises an exception on verification failure,
  instead of returning `false`
- The exception raised by the following methods has been changed to
  `MLIRError`:
  - `PassManager.run`
  - `{Module,Operation,Type,Attribute}.parse`
  - `{RankedTensorType,UnrankedTensorType}.get`
  - `{MemRefType,UnrankedMemRefType}.get`
  - `VectorType.get`
  - `FloatAttr.get`

closes #60595

depends on D144804, D143830

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D143869

13 files changed:
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/lib/Bindings/Python/Pass.cpp
mlir/python/mlir/_mlir_libs/__init__.py
mlir/test/python/ir/attributes.py
mlir/test/python/ir/builtin_types.py
mlir/test/python/ir/diagnostic_handler.py
mlir/test/python/ir/exception.py [new file with mode: 0644]
mlir/test/python/ir/module.py
mlir/test/python/ir/operation.py
mlir/test/python/pass_manager.py

index c8ede8b..b0c35ff 100644 (file)
@@ -344,15 +344,10 @@ public:
     c.def_static(
         "get",
         [](PyType &type, double value, DefaultingPyLocation loc) {
+          PyMlirContext::ErrorCapture errors(loc->getContext());
           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
-          // TODO: Rework error reporting once diagnostic engine is exposed
-          // in C API.
-          if (mlirAttributeIsNull(attr)) {
-            throw SetPyError(PyExc_ValueError,
-                             Twine("invalid '") +
-                                 py::repr(py::cast(type)).cast<std::string>() +
-                                 "' and expected floating point type.");
-          }
+          if (mlirAttributeIsNull(attr))
+            throw MLIRError("Invalid attribute", errors.take());
           return PyFloatAttribute(type.getContext(), attr);
         },
         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
index e03b647..8d637ea 100644 (file)
@@ -15,6 +15,7 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Debug.h"
+#include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
 //#include "mlir-c/Registration.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -38,7 +39,7 @@ using llvm::Twine;
 static const char kContextParseTypeDocstring[] =
     R"(Parses the assembly form of a type.
 
-Returns a Type object or raises a ValueError if the type cannot be parsed.
+Returns a Type object or raises an MLIRError if the type cannot be parsed.
 
 See also: https://mlir.llvm.org/docs/LangRef/#type-system
 )";
@@ -58,7 +59,7 @@ static const char kContextGetNameLocationDocString[] =
 static const char kModuleParseDocstring[] =
     R"(Parses a module's assembly format from a string.
 
-Returns a new MlirModule or raises a ValueError if the parsing fails.
+Returns a new MlirModule or raises an MLIRError if the parsing fails.
 
 See also: https://mlir.llvm.org/docs/LangRef/
 )";
@@ -654,6 +655,20 @@ py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
   return pyHandlerObject;
 }
 
+MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
+                                                       void *userData) {
+  auto *self = static_cast<ErrorCapture *>(userData);
+  // Check if the context requested we emit errors instead of capturing them.
+  if (self->ctx->emitErrorDiagnostics)
+    return mlirLogicalResultFailure();
+
+  if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError)
+    return mlirLogicalResultFailure();
+
+  self->errors.emplace_back(PyDiagnostic(diag).getInfo());
+  return mlirLogicalResultSuccess();
+}
+
 PyMlirContext &DefaultingPyMlirContext::resolve() {
   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
   if (!context) {
@@ -870,6 +885,13 @@ py::tuple PyDiagnostic::getNotes() {
   return *materializedNotes;
 }
 
+PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() {
+  std::vector<DiagnosticInfo> notes;
+  for (py::handle n : getNotes())
+    notes.emplace_back(n.cast<PyDiagnostic>().getInfo());
+  return {getSeverity(), getLocation(), getMessage(), std::move(notes)};
+}
+
 //------------------------------------------------------------------------------
 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
 //------------------------------------------------------------------------------
@@ -1062,13 +1084,12 @@ PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
 PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
                                   const std::string &sourceStr,
                                   const std::string &sourceName) {
+  PyMlirContext::ErrorCapture errors(contextRef);
   MlirOperation op =
       mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
                                toMlirStringRef(sourceName));
-  // TODO: Include error diagnostic messages in the exception message
   if (mlirOperationIsNull(op))
-    throw py::value_error(
-        "Unable to parse operation assembly (see diagnostics)");
+    throw MLIRError("Unable to parse operation assembly", errors.take());
   return PyOperation::createDetached(std::move(contextRef), op);
 }
 
@@ -1155,6 +1176,14 @@ void PyOperationBase::moveBefore(PyOperationBase &other) {
   operation.parentKeepAlive = otherOp.parentKeepAlive;
 }
 
+bool PyOperationBase::verify() {
+  PyOperation &op = getOperation();
+  PyMlirContext::ErrorCapture errors(op.getContext());
+  if (!mlirOperationVerify(op.get()))
+    throw MLIRError("Verification failed", errors.take());
+  return true;
+}
+
 std::optional<PyOperationRef> PyOperation::getParentOperation() {
   checkValid();
   if (!isAttached())
@@ -2287,6 +2316,16 @@ void mlir::python::populateIRCore(py::module &m) {
         return self.getMessage();
       });
 
+  py::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo",
+                                           py::module_local())
+      .def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); }))
+      .def_readonly("severity", &PyDiagnostic::DiagnosticInfo::severity)
+      .def_readonly("location", &PyDiagnostic::DiagnosticInfo::location)
+      .def_readonly("message", &PyDiagnostic::DiagnosticInfo::message)
+      .def_readonly("notes", &PyDiagnostic::DiagnosticInfo::notes)
+      .def("__str__",
+           [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
+
   py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
       .def("detach", &PyDiagnosticHandler::detach)
       .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
@@ -2375,6 +2414,11 @@ void mlir::python::populateIRCore(py::module &m) {
             mlirContextAppendDialectRegistry(self.get(), registry);
           },
           py::arg("registry"))
+      .def_property("emit_error_diagnostics", nullptr,
+                    &PyMlirContext::setEmitErrorDiagnostics,
+                    "Emit error diagnostics to diagnostic handlers. By default "
+                    "error diagnostics are captured and reported through "
+                    "MLIRError exceptions.")
       .def("load_all_available_dialects", [](PyMlirContext &self) {
         mlirContextLoadAllAvailableDialects(self.get());
       });
@@ -2566,16 +2610,12 @@ void mlir::python::populateIRCore(py::module &m) {
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
       .def_static(
           "parse",
-          [](const std::string moduleAsm, DefaultingPyMlirContext context) {
+          [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
+            PyMlirContext::ErrorCapture errors(context->getRef());
             MlirModule module = mlirModuleCreateParse(
                 context->get(), toMlirStringRef(moduleAsm));
-            // TODO: Rework error reporting once diagnostic engine is exposed
-            // in C API.
-            if (mlirModuleIsNull(module)) {
-              throw SetPyError(
-                  PyExc_ValueError,
-                  "Unable to parse module assembly (see diagnostics)");
-            }
+            if (mlirModuleIsNull(module))
+              throw MLIRError("Unable to parse module assembly", errors.take());
             return PyModule::forModule(module).releaseObject();
           },
           py::arg("asm"), py::arg("context") = py::none(),
@@ -2724,13 +2764,9 @@ void mlir::python::populateIRCore(py::module &m) {
            py::arg("print_generic_op_form") = false,
            py::arg("use_local_scope") = false,
            py::arg("assume_verified") = false, kOperationGetAsmDocstring)
-      .def(
-          "verify",
-          [](PyOperationBase &self) {
-            return mlirOperationVerify(self.getOperation());
-          },
-          "Verify the operation and return true if it passes, false if it "
-          "fails.")
+      .def("verify", &PyOperationBase::verify,
+           "Verify the operation. Raises MLIRError if verification fails, and "
+           "returns true otherwise.")
       .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
            "Puts self immediately after the other operation in its parent "
            "block.")
@@ -2833,12 +2869,12 @@ void mlir::python::populateIRCore(py::module &m) {
         // directly.
         std::string clsOpName =
             py::cast<std::string>(cls.attr("OPERATION_NAME"));
-        MlirStringRef parsedOpName =
+        MlirStringRef identifier =
             mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
-        if (!mlirStringRefEqual(parsedOpName, toMlirStringRef(clsOpName)))
-          throw py::value_error(
-              "Expected a '" + clsOpName + "' op, got: '" +
-              std::string(parsedOpName.data, parsedOpName.length) + "'");
+        std::string_view parsedOpName(identifier.data, identifier.length);
+        if (clsOpName != parsedOpName)
+          throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
+                          parsedOpName + "'");
         return PyOpView::constructDerived(cls, *parsed.get());
       },
       py::arg("cls"), py::arg("source"), py::kw_only(),
@@ -3071,19 +3107,16 @@ void mlir::python::populateIRCore(py::module &m) {
       .def_static(
           "parse",
           [](std::string attrSpec, DefaultingPyMlirContext context) {
+            PyMlirContext::ErrorCapture errors(context->getRef());
             MlirAttribute type = mlirAttributeParseGet(
                 context->get(), toMlirStringRef(attrSpec));
-            // TODO: Rework error reporting once diagnostic engine is exposed
-            // in C API.
-            if (mlirAttributeIsNull(type)) {
-              throw SetPyError(PyExc_ValueError,
-                               Twine("Unable to parse attribute: '") +
-                                   attrSpec + "'");
-            }
+            if (mlirAttributeIsNull(type))
+              throw MLIRError("Unable to parse attribute", errors.take());
             return PyAttribute(context->getRef(), type);
           },
           py::arg("asm"), py::arg("context") = py::none(),
-          "Parses an attribute from an assembly form")
+          "Parses an attribute from an assembly form. Raises an MLIRError on "
+          "failure.")
       .def_property_readonly(
           "context",
           [](PyAttribute &self) { return self.getContext().getObject(); },
@@ -3182,15 +3215,11 @@ void mlir::python::populateIRCore(py::module &m) {
       .def_static(
           "parse",
           [](std::string typeSpec, DefaultingPyMlirContext context) {
+            PyMlirContext::ErrorCapture errors(context->getRef());
             MlirType type =
                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
-            // TODO: Rework error reporting once diagnostic engine is exposed
-            // in C API.
-            if (mlirTypeIsNull(type)) {
-              throw SetPyError(PyExc_ValueError,
-                               Twine("Unable to parse type: '") + typeSpec +
-                                   "'");
-            }
+            if (mlirTypeIsNull(type))
+              throw MLIRError("Unable to parse type", errors.take());
             return PyType(context->getRef(), type);
           },
           py::arg("asm"), py::arg("context") = py::none(),
@@ -3342,4 +3371,17 @@ void mlir::python::populateIRCore(py::module &m) {
 
   // Attribute builder getter.
   PyAttrBuilderMap::bind(m);
+
+  py::register_local_exception_translator([](std::exception_ptr p) {
+    // We can't define exceptions with custom fields through pybind, so instead
+    // the exception class is defined in python and imported here.
+    try {
+      if (p)
+        std::rethrow_exception(p);
+    } catch (const MLIRError &e) {
+      py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+                           .attr("MLIRError")(e.message, e.errorDiagnostics);
+      PyErr_SetObject(PyExc_Exception, obj.ptr());
+    }
+  });
 }
index 4aced36..fc236b1 100644 (file)
@@ -221,6 +221,11 @@ public:
   /// registration object (internally a PyDiagnosticHandler).
   pybind11::object attachDiagnosticHandler(pybind11::object callback);
 
+  /// Controls whether error diagnostics should be propagated to diagnostic
+  /// handlers, instead of being captured by `ErrorCapture`.
+  void setEmitErrorDiagnostics(bool value) { emitErrorDiagnostics = value; }
+  struct ErrorCapture;
+
 private:
   PyMlirContext(MlirContext context);
   // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
@@ -248,6 +253,8 @@ private:
       llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
   LiveOperationMap liveOperations;
 
+  bool emitErrorDiagnostics = false;
+
   MlirContext context;
   friend class PyModule;
   friend class PyOperation;
@@ -281,6 +288,34 @@ private:
   PyMlirContextRef contextRef;
 };
 
+/// Wrapper around an MlirLocation.
+class PyLocation : public BaseContextObject {
+public:
+  PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
+      : BaseContextObject(std::move(contextRef)), loc(loc) {}
+
+  operator MlirLocation() const { return loc; }
+  MlirLocation get() const { return loc; }
+
+  /// Enter and exit the context manager.
+  pybind11::object contextEnter();
+  void contextExit(const pybind11::object &excType,
+                   const pybind11::object &excVal,
+                   const pybind11::object &excTb);
+
+  /// Gets a capsule wrapping the void* within the MlirLocation.
+  pybind11::object getCapsule();
+
+  /// Creates a PyLocation from the MlirLocation wrapped by a capsule.
+  /// Note that PyLocation instances are uniqued, so the returned object
+  /// may be a pre-existing object. Ownership of the underlying MlirLocation
+  /// is taken by calling this function.
+  static PyLocation createFromCapsule(pybind11::object capsule);
+
+private:
+  MlirLocation loc;
+};
+
 /// Python class mirroring the C MlirDiagnostic struct. Note that these structs
 /// are only valid for the duration of a diagnostic callback and attempting
 /// to access them outside of that will raise an exception. This applies to
@@ -295,6 +330,16 @@ public:
   pybind11::str getMessage();
   pybind11::tuple getNotes();
 
+  /// Materialized diagnostic information. This is safe to access outside the
+  /// diagnostic callback.
+  struct DiagnosticInfo {
+    MlirDiagnosticSeverity severity;
+    PyLocation location;
+    std::string message;
+    std::vector<DiagnosticInfo> notes;
+  };
+  DiagnosticInfo getInfo();
+
 private:
   MlirDiagnostic diagnostic;
 
@@ -351,6 +396,30 @@ private:
   friend class PyMlirContext;
 };
 
+/// RAII object that captures any error diagnostics emitted to the provided
+/// context.
+struct PyMlirContext::ErrorCapture {
+  ErrorCapture(PyMlirContextRef ctx)
+      : ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler(
+                      ctx->get(), handler, /*userData=*/this,
+                      /*deleteUserData=*/nullptr)) {}
+  ~ErrorCapture() {
+    mlirContextDetachDiagnosticHandler(ctx->get(), handlerID);
+    assert(errors.empty() && "unhandled captured errors");
+  }
+
+  std::vector<PyDiagnostic::DiagnosticInfo> take() {
+    return std::move(errors);
+  };
+
+private:
+  PyMlirContextRef ctx;
+  MlirDiagnosticHandlerID handlerID;
+  std::vector<PyDiagnostic::DiagnosticInfo> errors;
+
+  static MlirLogicalResult handler(MlirDiagnostic diag, void *userData);
+};
+
 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
 /// order to differentiate it from the `Dialect` base class which is extended by
 /// plugins which extend dialect functionality through extension python code.
@@ -416,34 +485,6 @@ private:
   MlirDialectRegistry registry;
 };
 
-/// Wrapper around an MlirLocation.
-class PyLocation : public BaseContextObject {
-public:
-  PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
-      : BaseContextObject(std::move(contextRef)), loc(loc) {}
-
-  operator MlirLocation() const { return loc; }
-  MlirLocation get() const { return loc; }
-
-  /// Enter and exit the context manager.
-  pybind11::object contextEnter();
-  void contextExit(const pybind11::object &excType,
-                   const pybind11::object &excVal,
-                   const pybind11::object &excTb);
-
-  /// Gets a capsule wrapping the void* within the MlirLocation.
-  pybind11::object getCapsule();
-
-  /// Creates a PyLocation from the MlirLocation wrapped by a capsule.
-  /// Note that PyLocation instances are uniqued, so the returned object
-  /// may be a pre-existing object. Ownership of the underlying MlirLocation
-  /// is taken by calling this function.
-  static PyLocation createFromCapsule(pybind11::object capsule);
-
-private:
-  MlirLocation loc;
-};
-
 /// Used in function arguments when None should resolve to the current context
 /// manager set instance.
 class DefaultingPyLocation
@@ -519,6 +560,10 @@ public:
   void moveAfter(PyOperationBase &other);
   void moveBefore(PyOperationBase &other);
 
+  /// Verify the operation. Throws `MLIRError` if verification fails, and
+  /// returns `true` otherwise.
+  bool verify();
+
   /// Each must provide access to the raw Operation.
   virtual PyOperation &getOperation() = 0;
 };
@@ -1073,6 +1118,16 @@ private:
   MlirSymbolTable symbolTable;
 };
 
+/// Custom exception that allows access to error diagnostic information. This is
+/// converted to the `ir.MLIRError` python exception when thrown.
+struct MLIRError {
+  MLIRError(llvm::Twine message,
+            std::vector<PyDiagnostic::DiagnosticInfo> &&errorDiagnostics = {})
+      : message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {}
+  std::string message;
+  std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
+};
+
 void populateIRAffine(pybind11::module &m);
 void populateIRAttributes(pybind11::module &m);
 void populateIRCore(pybind11::module &m);
index 87ffe59..2166bab 100644 (file)
@@ -407,17 +407,11 @@ public:
         "get",
         [](std::vector<int64_t> shape, PyType &elementType,
            DefaultingPyLocation loc) {
+          PyMlirContext::ErrorCapture errors(loc->getContext());
           MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
                                                 elementType);
-          // TODO: Rework error reporting once diagnostic engine is exposed
-          // in C API.
-          if (mlirTypeIsNull(t)) {
-            throw SetPyError(
-                PyExc_ValueError,
-                Twine("invalid '") +
-                    py::repr(py::cast(elementType)).cast<std::string>() +
-                    "' and expected floating point or integer type.");
-          }
+          if (mlirTypeIsNull(t))
+            throw MLIRError("Invalid type", errors.take());
           return PyVectorType(elementType.getContext(), t);
         },
         py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
@@ -438,20 +432,12 @@ public:
         "get",
         [](std::vector<int64_t> shape, PyType &elementType,
            std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
+          PyMlirContext::ErrorCapture errors(loc->getContext());
           MlirType t = mlirRankedTensorTypeGetChecked(
               loc, shape.size(), shape.data(), elementType,
               encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
-          // TODO: Rework error reporting once diagnostic engine is exposed
-          // in C API.
-          if (mlirTypeIsNull(t)) {
-            throw SetPyError(
-                PyExc_ValueError,
-                Twine("invalid '") +
-                    py::repr(py::cast(elementType)).cast<std::string>() +
-                    "' and expected floating point, integer, vector or "
-                    "complex "
-                    "type.");
-          }
+          if (mlirTypeIsNull(t))
+            throw MLIRError("Invalid type", errors.take());
           return PyRankedTensorType(elementType.getContext(), t);
         },
         py::arg("shape"), py::arg("element_type"),
@@ -479,18 +465,10 @@ public:
     c.def_static(
         "get",
         [](PyType &elementType, DefaultingPyLocation loc) {
+          PyMlirContext::ErrorCapture errors(loc->getContext());
           MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
-          // TODO: Rework error reporting once diagnostic engine is exposed
-          // in C API.
-          if (mlirTypeIsNull(t)) {
-            throw SetPyError(
-                PyExc_ValueError,
-                Twine("invalid '") +
-                    py::repr(py::cast(elementType)).cast<std::string>() +
-                    "' and expected floating point, integer, vector or "
-                    "complex "
-                    "type.");
-          }
+          if (mlirTypeIsNull(t))
+            throw MLIRError("Invalid type", errors.take());
           return PyUnrankedTensorType(elementType.getContext(), t);
         },
         py::arg("element_type"), py::arg("loc") = py::none(),
@@ -511,23 +489,15 @@ public:
          [](std::vector<int64_t> shape, PyType &elementType,
             PyAttribute *layout, PyAttribute *memorySpace,
             DefaultingPyLocation loc) {
+           PyMlirContext::ErrorCapture errors(loc->getContext());
            MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
            MlirAttribute memSpaceAttr =
                memorySpace ? *memorySpace : mlirAttributeGetNull();
            MlirType t =
                mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
                                         shape.data(), layoutAttr, memSpaceAttr);
-           // TODO: Rework error reporting once diagnostic engine is exposed
-           // in C API.
-           if (mlirTypeIsNull(t)) {
-             throw SetPyError(
-                 PyExc_ValueError,
-                 Twine("invalid '") +
-                     py::repr(py::cast(elementType)).cast<std::string>() +
-                     "' and expected floating point, integer, vector or "
-                     "complex "
-                     "type.");
-           }
+           if (mlirTypeIsNull(t))
+             throw MLIRError("Invalid type", errors.take());
            return PyMemRefType(elementType.getContext(), t);
          },
          py::arg("shape"), py::arg("element_type"),
@@ -570,23 +540,15 @@ public:
          "get",
          [](PyType &elementType, PyAttribute *memorySpace,
             DefaultingPyLocation loc) {
+           PyMlirContext::ErrorCapture errors(loc->getContext());
            MlirAttribute memSpaceAttr = {};
            if (memorySpace)
              memSpaceAttr = *memorySpace;
 
            MlirType t =
                mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
-           // TODO: Rework error reporting once diagnostic engine is exposed
-           // in C API.
-           if (mlirTypeIsNull(t)) {
-             throw SetPyError(
-                 PyExc_ValueError,
-                 Twine("invalid '") +
-                     py::repr(py::cast(elementType)).cast<std::string>() +
-                     "' and expected floating point, integer, vector or "
-                     "complex "
-                     "type.");
-           }
+           if (mlirTypeIsNull(t))
+             throw MLIRError("Invalid type", errors.take());
            return PyUnrankedMemRefType(elementType.getContext(), t);
          },
          py::arg("element_type"), py::arg("memory_space"),
index 7e90d8b..79c5308 100644 (file)
@@ -117,15 +117,16 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
       .def(
           "run",
           [](PyPassManager &passManager, PyOperationBase &op) {
+            PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
             MlirLogicalResult status = mlirPassManagerRunOnOp(
                 passManager.get(), op.getOperation().get());
             if (mlirLogicalResultIsFailure(status))
-              throw SetPyError(PyExc_RuntimeError,
-                               "Failure while executing pass pipeline.");
+              throw MLIRError("Failure while executing pass pipeline",
+                              errors.take());
           },
           py::arg("operation"),
-          "Run the pass manager on the provided operation, throw a "
-          "RuntimeError on failure.")
+          "Run the pass manager on the provided operation, raising an "
+          "MLIRError on failure.")
       .def(
           "__str__",
           [](PyPassManager &self) {
index 9ceeef8..7d3d1f6 100644 (file)
@@ -100,8 +100,29 @@ def _site_initialize():
       # all dialects. It is being done here in order to preserve existing
       # behavior. See: https://github.com/llvm/llvm-project/issues/56037
       self.load_all_available_dialects()
-
   ir.Context = Context
 
+  class MLIRError(Exception):
+    """
+    An exception with diagnostic information. Has the following fields:
+      message: str
+      error_diagnostics: List[ir.DiagnosticInfo]
+    """
+    def __init__(self, message, error_diagnostics):
+      self.message = message
+      self.error_diagnostics = error_diagnostics
+      super().__init__(message, error_diagnostics)
+
+    def __str__(self):
+      s = self.message
+      if self.error_diagnostics:
+        s += ':'
+      for diag in self.error_diagnostics:
+        s += "\nerror: "  + str(diag.location)[4:-1] + ": " + diag.message.replace('\n', '\n  ')
+        for note in diag.notes:
+          s += "\n note: "  + str(note.location)[4:-1] + ": " + note.message.replace('\n', '\n  ')
+      return s
+  ir.MLIRError = MLIRError
+
 
 _site_initialize()
index 684d52c..1e1589d 100644 (file)
@@ -28,16 +28,17 @@ def testParsePrint():
 
 
 # CHECK-LABEL: TEST: testParseError
-# TODO: Hook the diagnostic manager to capture a more meaningful error
-# message.
 @run
 def testParseError():
   with Context():
     try:
       t = Attribute.parse("BAD_ATTR_DOES_NOT_EXIST")
-    except ValueError as e:
-      # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
-      print("testParseError:", e)
+    except MLIRError as e:
+      # CHECK: testParseError: <
+      # CHECK:   Unable to parse attribute:
+      # CHECK:   error: "BAD_ATTR_DOES_NOT_EXIST":1:1: expected attribute value
+      # CHECK: >
+      print(f"testParseError: <{e}>")
     else:
       print("Exception not produced")
 
@@ -180,8 +181,9 @@ def testFloatAttr():
     try:
       fattr_invalid = FloatAttr.get(
           IntegerType.get_signless(32), 42)
-    except ValueError as e:
-      # CHECK: invalid 'Type(i32)' and expected floating point type.
+    except MLIRError as e:
+      # CHECK: Invalid attribute:
+      # CHECK: error: unknown: expected floating point type
       print(e)
     else:
       print("Exception not produced")
index 7af8185..594cc66 100644 (file)
@@ -26,16 +26,17 @@ def testParsePrint():
 
 
 # CHECK-LABEL: TEST: testParseError
-# TODO: Hook the diagnostic manager to capture a more meaningful error
-# message.
 @run
 def testParseError():
   ctx = Context()
   try:
     t = Type.parse("BAD_TYPE_DOES_NOT_EXIST", ctx)
-  except ValueError as e:
-    # CHECK: Unable to parse type: 'BAD_TYPE_DOES_NOT_EXIST'
-    print("testParseError:", e)
+  except MLIRError as e:
+    # CHECK: testParseError: <
+    # CHECK:   Unable to parse type:
+    # CHECK:   error: "BAD_TYPE_DOES_NOT_EXIST":1:1: expected non-function type
+    # CHECK: >
+    print(f"testParseError: <{e}>")
   else:
     print("Exception not produced")
 
@@ -292,8 +293,9 @@ def testVectorType():
     none = NoneType.get()
     try:
       vector_invalid = VectorType.get(shape, none)
-    except ValueError as e:
-      # CHECK: invalid 'Type(none)' and expected floating point or integer type.
+    except MLIRError as e:
+      # CHECK: Invalid type:
+      # CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
       print(e)
     else:
       print("Exception not produced")
@@ -313,9 +315,9 @@ def testRankedTensorType():
     none = NoneType.get()
     try:
       tensor_invalid = RankedTensorType.get(shape, none)
-    except ValueError as e:
-      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
-      # CHECK: or complex type.
+    except MLIRError as e:
+      # CHECK: Invalid type:
+      # CHECK: error: unknown: invalid tensor element type: 'none'
       print(e)
     else:
       print("Exception not produced")
@@ -361,9 +363,9 @@ def testUnrankedTensorType():
     none = NoneType.get()
     try:
       tensor_invalid = UnrankedTensorType.get(none)
-    except ValueError as e:
-      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
-      # CHECK: or complex type.
+    except MLIRError as e:
+      # CHECK: Invalid type:
+      # CHECK: error: unknown: invalid tensor element type: 'none'
       print(e)
     else:
       print("Exception not produced")
@@ -400,9 +402,9 @@ def testMemRefType():
     none = NoneType.get()
     try:
       memref_invalid = MemRefType.get(shape, none)
-    except ValueError as e:
-      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
-      # CHECK: or complex type.
+    except MLIRError as e:
+      # CHECK: Invalid type:
+      # CHECK: error: unknown: invalid memref element type
       print(e)
     else:
       print("Exception not produced")
@@ -444,9 +446,9 @@ def testUnrankedMemRefType():
     none = NoneType.get()
     try:
       memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2"))
-    except ValueError as e:
-      # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
-      # CHECK: or complex type.
+    except MLIRError as e:
+      # CHECK: Invalid type:
+      # CHECK: error: unknown: invalid memref element type
       print(e)
     else:
       print("Exception not produced")
index d973db2..cc07f6e 100644 (file)
@@ -89,6 +89,7 @@ def testDiagnosticEmptyNotes():
 @run
 def testDiagnosticNonEmptyNotes():
   ctx = Context()
+  ctx.emit_error_diagnostics = True
   def callback(d):
     # CHECK: DIAGNOSTIC:
     # CHECK:   message='arith.addi' op requires one result
@@ -99,7 +100,10 @@ def testDiagnosticNonEmptyNotes():
     return True
   handler = ctx.attach_diagnostic_handler(callback)
   loc = Location.unknown(ctx)
-  Operation.create('arith.addi', loc=loc).verify()
+  try:
+    Operation.create('arith.addi', loc=loc).verify()
+  except MLIRError:
+    pass
   assert not handler.had_error
 
 # CHECK-LABEL: TEST: testDiagnosticCallbackException
diff --git a/mlir/test/python/ir/exception.py b/mlir/test/python/ir/exception.py
new file mode 100644 (file)
index 0000000..6cb2375
--- /dev/null
@@ -0,0 +1,77 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+from mlir.ir import *
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert Context._get_live_count() == 0
+  return f
+
+
+# CHECK-LABEL: TEST: test_exception
+@run
+def test_exception():
+  ctx =  Context()
+  ctx.allow_unregistered_dialects = True
+  try:
+    Operation.parse("""
+      func.func @foo() {
+          "test.use"(%0) : (i64) -> ()  loc("use")
+          %0 = "test.def"() : () -> i64 loc("def")
+          return
+      }
+    """, context=ctx)
+  except MLIRError as e:
+    # CHECK: Exception: <
+    # CHECK:   Unable to parse operation assembly:
+    # CHECK:   error: "use": operand #0 does not dominate this use
+    # CHECK:    note: "use": see current operation: "test.use"(%0) : (i64) -> ()
+    # CHECK:    note: "def": operand defined here (op in the same block)
+    # CHECK: >
+    print(f"Exception: <{e}>")
+
+    # CHECK: message: Unable to parse operation assembly
+    print(f"message: {e.message}")
+
+    # CHECK: error_diagnostics[0]:           loc("use") operand #0 does not dominate this use
+    # CHECK: error_diagnostics[0].notes[0]:  loc("use") see current operation: "test.use"(%0) : (i64) -> ()
+    # CHECK: error_diagnostics[0].notes[1]:  loc("def") operand defined here (op in the same block)
+    print("error_diagnostics[0]:          ", e.error_diagnostics[0].location, e.error_diagnostics[0].message)
+    print("error_diagnostics[0].notes[0]: ", e.error_diagnostics[0].notes[0].location, e.error_diagnostics[0].notes[0].message)
+    print("error_diagnostics[0].notes[1]: ", e.error_diagnostics[0].notes[1].location, e.error_diagnostics[0].notes[1].message)
+
+
+# CHECK-LABEL: test_emit_error_diagnostics
+@run
+def test_emit_error_diagnostics():
+  ctx = Context()
+  loc = Location.unknown(ctx)
+  handler_diags = []
+  def handler(d):
+    handler_diags.append(str(d))
+    return True
+  ctx.attach_diagnostic_handler(handler)
+
+  try:
+    Attribute.parse("not an attr", ctx)
+  except MLIRError as e:
+    # CHECK: emit_error_diagnostics=False:
+    # CHECK: e.error_diagnostics: ['expected attribute value']
+    # CHECK: handler_diags: []
+    print(f"emit_error_diagnostics=False:")
+    print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}")
+    print(f"handler_diags: {handler_diags}")
+
+  ctx.emit_error_diagnostics = True
+  try:
+    Attribute.parse("not an attr", ctx)
+  except MLIRError as e:
+    # CHECK: emit_error_diagnostics=True:
+    # CHECK: e.error_diagnostics: []
+    # CHECK: handler_diags: ['expected attribute value']
+    print(f"emit_error_diagnostics=True:")
+    print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}")
+    print(f"handler_diags: {handler_diags}")
index f0b6243..2d00923 100644 (file)
@@ -28,14 +28,17 @@ def testParseSuccess():
 
 # Verify parse error.
 # CHECK-LABEL: TEST: testParseError
-# CHECK: testParseError: Unable to parse module assembly (see diagnostics)
+# CHECK: testParseError: <
+# CHECK:   Unable to parse module assembly:
+# CHECK:   error: "-":1:1: expected operation name in quotes
+# CHECK: >
 @run
 def testParseError():
   ctx = Context()
   try:
     module = Module.parse(r"""}SYNTAX ERROR{""", ctx)
-  except ValueError as e:
-    print("testParseError:", e)
+  except MLIRError as e:
+    print(f"testParseError: <{e}>")
   else:
     print("Exception not produced")
 
index be7467d..941420e 100644 (file)
@@ -685,8 +685,19 @@ def testInvalidOperationStrSoftFails():
     # CHECK: "builtin.module"() ({
     # CHECK: }) : () -> ()
     print(invalid_op)
-    # CHECK: .verify = False
-    print(f".verify = {invalid_op.operation.verify()}")
+    try:
+      invalid_op.verify()
+    except MLIRError as e:
+      # CHECK: Exception: <
+      # CHECK:   Verification failed:
+      # CHECK:   error: unknown: 'builtin.module' op requires one region
+      # CHECK:    note: unknown: see current operation:
+      # CHECK:     "builtin.module"() ({
+      # CHECK:     ^bb0:
+      # CHECK:     }, {
+      # CHECK:     }) : () -> ()
+      # CHECK: >
+      print(f"Exception: <{e}>")
 
 
 # CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
@@ -920,7 +931,7 @@ def testOperationParse():
     assert isinstance(m, ModuleOp)
     try:
       ModuleOp.parse('"test.foo"() : () -> ()')
-    except ValueError as e:
+    except MLIRError as e:
       # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo'
       print(f"error: {e}")
     else:
index b3acd35..8b27653 100644 (file)
@@ -118,7 +118,7 @@ run(testInvalidNesting)
 
 
 # Verify that a pass manager can execute on IR
-# CHECK-LABEL: TEST: testRun
+# CHECK-LABEL: TEST: testRunPipeline
 def testRunPipeline():
   with Context():
     pm = PassManager.parse("any(print-op-stats{json=false})")
@@ -128,3 +128,20 @@ def testRunPipeline():
 # CHECK: func.func      , 1
 # CHECK: func.return        , 1
 run(testRunPipeline)
+
+# CHECK-LABEL: TEST: testRunPipelineError
+@run
+def testRunPipelineError():
+  with Context() as ctx:
+    ctx.allow_unregistered_dialects = True
+    op = Operation.parse('"test.op"() : () -> ()')
+    pm = PassManager.parse("any(cse)")
+    try:
+      pm.run(op)
+    except MLIRError as e:
+      # CHECK: Exception: <
+      # CHECK:   Failure while executing pass pipeline:
+      # CHECK:   error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
+      # CHECK:    note: "-":1:1: see current operation: "test.op"() : () -> ()
+      # CHECK: >
+      print(f"Exception: <{e}>")