c.def_static(
"get",
[](PyType &type, double value, DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirAttributeIsNull(attr)) {
- throw SetPyError(PyExc_ValueError,
- Twine("invalid '") +
- py::repr(py::cast(type)).cast<std::string>() +
- "' and expected floating point type.");
- }
+ if (mlirAttributeIsNull(attr))
+ throw MLIRError("Invalid attribute", errors.take());
return PyFloatAttribute(type.getContext(), attr);
},
py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Debug.h"
+#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
//#include "mlir-c/Registration.h"
#include "llvm/ADT/ArrayRef.h"
static const char kContextParseTypeDocstring[] =
R"(Parses the assembly form of a type.
-Returns a Type object or raises a ValueError if the type cannot be parsed.
+Returns a Type object or raises an MLIRError if the type cannot be parsed.
See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";
static const char kModuleParseDocstring[] =
R"(Parses a module's assembly format from a string.
-Returns a new MlirModule or raises a ValueError if the parsing fails.
+Returns a new MlirModule or raises an MLIRError if the parsing fails.
See also: https://mlir.llvm.org/docs/LangRef/
)";
return pyHandlerObject;
}
+MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
+ void *userData) {
+ auto *self = static_cast<ErrorCapture *>(userData);
+ // Check if the context requested we emit errors instead of capturing them.
+ if (self->ctx->emitErrorDiagnostics)
+ return mlirLogicalResultFailure();
+
+ if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError)
+ return mlirLogicalResultFailure();
+
+ self->errors.emplace_back(PyDiagnostic(diag).getInfo());
+ return mlirLogicalResultSuccess();
+}
+
PyMlirContext &DefaultingPyMlirContext::resolve() {
PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
if (!context) {
return *materializedNotes;
}
+PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() {
+ std::vector<DiagnosticInfo> notes;
+ for (py::handle n : getNotes())
+ notes.emplace_back(n.cast<PyDiagnostic>().getInfo());
+ return {getSeverity(), getLocation(), getMessage(), std::move(notes)};
+}
+
//------------------------------------------------------------------------------
// PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
//------------------------------------------------------------------------------
PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
const std::string &sourceStr,
const std::string &sourceName) {
+ PyMlirContext::ErrorCapture errors(contextRef);
MlirOperation op =
mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
toMlirStringRef(sourceName));
- // TODO: Include error diagnostic messages in the exception message
if (mlirOperationIsNull(op))
- throw py::value_error(
- "Unable to parse operation assembly (see diagnostics)");
+ throw MLIRError("Unable to parse operation assembly", errors.take());
return PyOperation::createDetached(std::move(contextRef), op);
}
operation.parentKeepAlive = otherOp.parentKeepAlive;
}
+bool PyOperationBase::verify() {
+ PyOperation &op = getOperation();
+ PyMlirContext::ErrorCapture errors(op.getContext());
+ if (!mlirOperationVerify(op.get()))
+ throw MLIRError("Verification failed", errors.take());
+ return true;
+}
+
std::optional<PyOperationRef> PyOperation::getParentOperation() {
checkValid();
if (!isAttached())
return self.getMessage();
});
+ py::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo",
+ py::module_local())
+ .def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); }))
+ .def_readonly("severity", &PyDiagnostic::DiagnosticInfo::severity)
+ .def_readonly("location", &PyDiagnostic::DiagnosticInfo::location)
+ .def_readonly("message", &PyDiagnostic::DiagnosticInfo::message)
+ .def_readonly("notes", &PyDiagnostic::DiagnosticInfo::notes)
+ .def("__str__",
+ [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
+
py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
.def("detach", &PyDiagnosticHandler::detach)
.def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
mlirContextAppendDialectRegistry(self.get(), registry);
},
py::arg("registry"))
+ .def_property("emit_error_diagnostics", nullptr,
+ &PyMlirContext::setEmitErrorDiagnostics,
+ "Emit error diagnostics to diagnostic handlers. By default "
+ "error diagnostics are captured and reported through "
+ "MLIRError exceptions.")
.def("load_all_available_dialects", [](PyMlirContext &self) {
mlirContextLoadAllAvailableDialects(self.get());
});
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
.def_static(
"parse",
- [](const std::string moduleAsm, DefaultingPyMlirContext context) {
+ [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
MlirModule module = mlirModuleCreateParse(
context->get(), toMlirStringRef(moduleAsm));
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirModuleIsNull(module)) {
- throw SetPyError(
- PyExc_ValueError,
- "Unable to parse module assembly (see diagnostics)");
- }
+ if (mlirModuleIsNull(module))
+ throw MLIRError("Unable to parse module assembly", errors.take());
return PyModule::forModule(module).releaseObject();
},
py::arg("asm"), py::arg("context") = py::none(),
py::arg("print_generic_op_form") = false,
py::arg("use_local_scope") = false,
py::arg("assume_verified") = false, kOperationGetAsmDocstring)
- .def(
- "verify",
- [](PyOperationBase &self) {
- return mlirOperationVerify(self.getOperation());
- },
- "Verify the operation and return true if it passes, false if it "
- "fails.")
+ .def("verify", &PyOperationBase::verify,
+ "Verify the operation. Raises MLIRError if verification fails, and "
+ "returns true otherwise.")
.def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
"Puts self immediately after the other operation in its parent "
"block.")
// directly.
std::string clsOpName =
py::cast<std::string>(cls.attr("OPERATION_NAME"));
- MlirStringRef parsedOpName =
+ MlirStringRef identifier =
mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
- if (!mlirStringRefEqual(parsedOpName, toMlirStringRef(clsOpName)))
- throw py::value_error(
- "Expected a '" + clsOpName + "' op, got: '" +
- std::string(parsedOpName.data, parsedOpName.length) + "'");
+ std::string_view parsedOpName(identifier.data, identifier.length);
+ if (clsOpName != parsedOpName)
+ throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
+ parsedOpName + "'");
return PyOpView::constructDerived(cls, *parsed.get());
},
py::arg("cls"), py::arg("source"), py::kw_only(),
.def_static(
"parse",
[](std::string attrSpec, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
MlirAttribute type = mlirAttributeParseGet(
context->get(), toMlirStringRef(attrSpec));
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirAttributeIsNull(type)) {
- throw SetPyError(PyExc_ValueError,
- Twine("Unable to parse attribute: '") +
- attrSpec + "'");
- }
+ if (mlirAttributeIsNull(type))
+ throw MLIRError("Unable to parse attribute", errors.take());
return PyAttribute(context->getRef(), type);
},
py::arg("asm"), py::arg("context") = py::none(),
- "Parses an attribute from an assembly form")
+ "Parses an attribute from an assembly form. Raises an MLIRError on "
+ "failure.")
.def_property_readonly(
"context",
[](PyAttribute &self) { return self.getContext().getObject(); },
.def_static(
"parse",
[](std::string typeSpec, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
MlirType type =
mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirTypeIsNull(type)) {
- throw SetPyError(PyExc_ValueError,
- Twine("Unable to parse type: '") + typeSpec +
- "'");
- }
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Unable to parse type", errors.take());
return PyType(context->getRef(), type);
},
py::arg("asm"), py::arg("context") = py::none(),
// Attribute builder getter.
PyAttrBuilderMap::bind(m);
+
+ py::register_local_exception_translator([](std::exception_ptr p) {
+ // We can't define exceptions with custom fields through pybind, so instead
+ // the exception class is defined in python and imported here.
+ try {
+ if (p)
+ std::rethrow_exception(p);
+ } catch (const MLIRError &e) {
+ py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("MLIRError")(e.message, e.errorDiagnostics);
+ PyErr_SetObject(PyExc_Exception, obj.ptr());
+ }
+ });
}
/// registration object (internally a PyDiagnosticHandler).
pybind11::object attachDiagnosticHandler(pybind11::object callback);
+ /// Controls whether error diagnostics should be propagated to diagnostic
+ /// handlers, instead of being captured by `ErrorCapture`.
+ void setEmitErrorDiagnostics(bool value) { emitErrorDiagnostics = value; }
+ struct ErrorCapture;
+
private:
PyMlirContext(MlirContext context);
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
LiveOperationMap liveOperations;
+ bool emitErrorDiagnostics = false;
+
MlirContext context;
friend class PyModule;
friend class PyOperation;
PyMlirContextRef contextRef;
};
+/// Wrapper around an MlirLocation.
+class PyLocation : public BaseContextObject {
+public:
+ PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
+ : BaseContextObject(std::move(contextRef)), loc(loc) {}
+
+ operator MlirLocation() const { return loc; }
+ MlirLocation get() const { return loc; }
+
+ /// Enter and exit the context manager.
+ pybind11::object contextEnter();
+ void contextExit(const pybind11::object &excType,
+ const pybind11::object &excVal,
+ const pybind11::object &excTb);
+
+ /// Gets a capsule wrapping the void* within the MlirLocation.
+ pybind11::object getCapsule();
+
+ /// Creates a PyLocation from the MlirLocation wrapped by a capsule.
+ /// Note that PyLocation instances are uniqued, so the returned object
+ /// may be a pre-existing object. Ownership of the underlying MlirLocation
+ /// is taken by calling this function.
+ static PyLocation createFromCapsule(pybind11::object capsule);
+
+private:
+ MlirLocation loc;
+};
+
/// Python class mirroring the C MlirDiagnostic struct. Note that these structs
/// are only valid for the duration of a diagnostic callback and attempting
/// to access them outside of that will raise an exception. This applies to
pybind11::str getMessage();
pybind11::tuple getNotes();
+ /// Materialized diagnostic information. This is safe to access outside the
+ /// diagnostic callback.
+ struct DiagnosticInfo {
+ MlirDiagnosticSeverity severity;
+ PyLocation location;
+ std::string message;
+ std::vector<DiagnosticInfo> notes;
+ };
+ DiagnosticInfo getInfo();
+
private:
MlirDiagnostic diagnostic;
friend class PyMlirContext;
};
+/// RAII object that captures any error diagnostics emitted to the provided
+/// context.
+struct PyMlirContext::ErrorCapture {
+ ErrorCapture(PyMlirContextRef ctx)
+ : ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler(
+ ctx->get(), handler, /*userData=*/this,
+ /*deleteUserData=*/nullptr)) {}
+ ~ErrorCapture() {
+ mlirContextDetachDiagnosticHandler(ctx->get(), handlerID);
+ assert(errors.empty() && "unhandled captured errors");
+ }
+
+ std::vector<PyDiagnostic::DiagnosticInfo> take() {
+ return std::move(errors);
+ };
+
+private:
+ PyMlirContextRef ctx;
+ MlirDiagnosticHandlerID handlerID;
+ std::vector<PyDiagnostic::DiagnosticInfo> errors;
+
+ static MlirLogicalResult handler(MlirDiagnostic diag, void *userData);
+};
+
/// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
/// order to differentiate it from the `Dialect` base class which is extended by
/// plugins which extend dialect functionality through extension python code.
MlirDialectRegistry registry;
};
-/// Wrapper around an MlirLocation.
-class PyLocation : public BaseContextObject {
-public:
- PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
- : BaseContextObject(std::move(contextRef)), loc(loc) {}
-
- operator MlirLocation() const { return loc; }
- MlirLocation get() const { return loc; }
-
- /// Enter and exit the context manager.
- pybind11::object contextEnter();
- void contextExit(const pybind11::object &excType,
- const pybind11::object &excVal,
- const pybind11::object &excTb);
-
- /// Gets a capsule wrapping the void* within the MlirLocation.
- pybind11::object getCapsule();
-
- /// Creates a PyLocation from the MlirLocation wrapped by a capsule.
- /// Note that PyLocation instances are uniqued, so the returned object
- /// may be a pre-existing object. Ownership of the underlying MlirLocation
- /// is taken by calling this function.
- static PyLocation createFromCapsule(pybind11::object capsule);
-
-private:
- MlirLocation loc;
-};
-
/// Used in function arguments when None should resolve to the current context
/// manager set instance.
class DefaultingPyLocation
void moveAfter(PyOperationBase &other);
void moveBefore(PyOperationBase &other);
+ /// Verify the operation. Throws `MLIRError` if verification fails, and
+ /// returns `true` otherwise.
+ bool verify();
+
/// Each must provide access to the raw Operation.
virtual PyOperation &getOperation() = 0;
};
MlirSymbolTable symbolTable;
};
+/// Custom exception that allows access to error diagnostic information. This is
+/// converted to the `ir.MLIRError` python exception when thrown.
+struct MLIRError {
+ MLIRError(llvm::Twine message,
+ std::vector<PyDiagnostic::DiagnosticInfo> &&errorDiagnostics = {})
+ : message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {}
+ std::string message;
+ std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
+};
+
void populateIRAffine(pybind11::module &m);
void populateIRAttributes(pybind11::module &m);
void populateIRCore(pybind11::module &m);
"get",
[](std::vector<int64_t> shape, PyType &elementType,
DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
elementType);
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirTypeIsNull(t)) {
- throw SetPyError(
- PyExc_ValueError,
- Twine("invalid '") +
- py::repr(py::cast(elementType)).cast<std::string>() +
- "' and expected floating point or integer type.");
- }
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
return PyVectorType(elementType.getContext(), t);
},
py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
"get",
[](std::vector<int64_t> shape, PyType &elementType,
std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
MlirType t = mlirRankedTensorTypeGetChecked(
loc, shape.size(), shape.data(), elementType,
encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirTypeIsNull(t)) {
- throw SetPyError(
- PyExc_ValueError,
- Twine("invalid '") +
- py::repr(py::cast(elementType)).cast<std::string>() +
- "' and expected floating point, integer, vector or "
- "complex "
- "type.");
- }
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
return PyRankedTensorType(elementType.getContext(), t);
},
py::arg("shape"), py::arg("element_type"),
c.def_static(
"get",
[](PyType &elementType, DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirTypeIsNull(t)) {
- throw SetPyError(
- PyExc_ValueError,
- Twine("invalid '") +
- py::repr(py::cast(elementType)).cast<std::string>() +
- "' and expected floating point, integer, vector or "
- "complex "
- "type.");
- }
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
return PyUnrankedTensorType(elementType.getContext(), t);
},
py::arg("element_type"), py::arg("loc") = py::none(),
[](std::vector<int64_t> shape, PyType &elementType,
PyAttribute *layout, PyAttribute *memorySpace,
DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
MlirAttribute memSpaceAttr =
memorySpace ? *memorySpace : mlirAttributeGetNull();
MlirType t =
mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
shape.data(), layoutAttr, memSpaceAttr);
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirTypeIsNull(t)) {
- throw SetPyError(
- PyExc_ValueError,
- Twine("invalid '") +
- py::repr(py::cast(elementType)).cast<std::string>() +
- "' and expected floating point, integer, vector or "
- "complex "
- "type.");
- }
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
return PyMemRefType(elementType.getContext(), t);
},
py::arg("shape"), py::arg("element_type"),
"get",
[](PyType &elementType, PyAttribute *memorySpace,
DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
MlirAttribute memSpaceAttr = {};
if (memorySpace)
memSpaceAttr = *memorySpace;
MlirType t =
mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirTypeIsNull(t)) {
- throw SetPyError(
- PyExc_ValueError,
- Twine("invalid '") +
- py::repr(py::cast(elementType)).cast<std::string>() +
- "' and expected floating point, integer, vector or "
- "complex "
- "type.");
- }
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
return PyUnrankedMemRefType(elementType.getContext(), t);
},
py::arg("element_type"), py::arg("memory_space"),
.def(
"run",
[](PyPassManager &passManager, PyOperationBase &op) {
+ PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
MlirLogicalResult status = mlirPassManagerRunOnOp(
passManager.get(), op.getOperation().get());
if (mlirLogicalResultIsFailure(status))
- throw SetPyError(PyExc_RuntimeError,
- "Failure while executing pass pipeline.");
+ throw MLIRError("Failure while executing pass pipeline",
+ errors.take());
},
py::arg("operation"),
- "Run the pass manager on the provided operation, throw a "
- "RuntimeError on failure.")
+ "Run the pass manager on the provided operation, raising an "
+ "MLIRError on failure.")
.def(
"__str__",
[](PyPassManager &self) {
# all dialects. It is being done here in order to preserve existing
# behavior. See: https://github.com/llvm/llvm-project/issues/56037
self.load_all_available_dialects()
-
ir.Context = Context
+ class MLIRError(Exception):
+ """
+ An exception with diagnostic information. Has the following fields:
+ message: str
+ error_diagnostics: List[ir.DiagnosticInfo]
+ """
+ def __init__(self, message, error_diagnostics):
+ self.message = message
+ self.error_diagnostics = error_diagnostics
+ super().__init__(message, error_diagnostics)
+
+ def __str__(self):
+ s = self.message
+ if self.error_diagnostics:
+ s += ':'
+ for diag in self.error_diagnostics:
+ s += "\nerror: " + str(diag.location)[4:-1] + ": " + diag.message.replace('\n', '\n ')
+ for note in diag.notes:
+ s += "\n note: " + str(note.location)[4:-1] + ": " + note.message.replace('\n', '\n ')
+ return s
+ ir.MLIRError = MLIRError
+
_site_initialize()
# CHECK-LABEL: TEST: testParseError
-# TODO: Hook the diagnostic manager to capture a more meaningful error
-# message.
@run
def testParseError():
with Context():
try:
t = Attribute.parse("BAD_ATTR_DOES_NOT_EXIST")
- except ValueError as e:
- # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
- print("testParseError:", e)
+ except MLIRError as e:
+ # CHECK: testParseError: <
+ # CHECK: Unable to parse attribute:
+ # CHECK: error: "BAD_ATTR_DOES_NOT_EXIST":1:1: expected attribute value
+ # CHECK: >
+ print(f"testParseError: <{e}>")
else:
print("Exception not produced")
try:
fattr_invalid = FloatAttr.get(
IntegerType.get_signless(32), 42)
- except ValueError as e:
- # CHECK: invalid 'Type(i32)' and expected floating point type.
+ except MLIRError as e:
+ # CHECK: Invalid attribute:
+ # CHECK: error: unknown: expected floating point type
print(e)
else:
print("Exception not produced")
# CHECK-LABEL: TEST: testParseError
-# TODO: Hook the diagnostic manager to capture a more meaningful error
-# message.
@run
def testParseError():
ctx = Context()
try:
t = Type.parse("BAD_TYPE_DOES_NOT_EXIST", ctx)
- except ValueError as e:
- # CHECK: Unable to parse type: 'BAD_TYPE_DOES_NOT_EXIST'
- print("testParseError:", e)
+ except MLIRError as e:
+ # CHECK: testParseError: <
+ # CHECK: Unable to parse type:
+ # CHECK: error: "BAD_TYPE_DOES_NOT_EXIST":1:1: expected non-function type
+ # CHECK: >
+ print(f"testParseError: <{e}>")
else:
print("Exception not produced")
none = NoneType.get()
try:
vector_invalid = VectorType.get(shape, none)
- except ValueError as e:
- # CHECK: invalid 'Type(none)' and expected floating point or integer type.
+ except MLIRError as e:
+ # CHECK: Invalid type:
+ # CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
print(e)
else:
print("Exception not produced")
none = NoneType.get()
try:
tensor_invalid = RankedTensorType.get(shape, none)
- except ValueError as e:
- # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
- # CHECK: or complex type.
+ except MLIRError as e:
+ # CHECK: Invalid type:
+ # CHECK: error: unknown: invalid tensor element type: 'none'
print(e)
else:
print("Exception not produced")
none = NoneType.get()
try:
tensor_invalid = UnrankedTensorType.get(none)
- except ValueError as e:
- # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
- # CHECK: or complex type.
+ except MLIRError as e:
+ # CHECK: Invalid type:
+ # CHECK: error: unknown: invalid tensor element type: 'none'
print(e)
else:
print("Exception not produced")
none = NoneType.get()
try:
memref_invalid = MemRefType.get(shape, none)
- except ValueError as e:
- # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
- # CHECK: or complex type.
+ except MLIRError as e:
+ # CHECK: Invalid type:
+ # CHECK: error: unknown: invalid memref element type
print(e)
else:
print("Exception not produced")
none = NoneType.get()
try:
memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2"))
- except ValueError as e:
- # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
- # CHECK: or complex type.
+ except MLIRError as e:
+ # CHECK: Invalid type:
+ # CHECK: error: unknown: invalid memref element type
print(e)
else:
print("Exception not produced")
@run
def testDiagnosticNonEmptyNotes():
ctx = Context()
+ ctx.emit_error_diagnostics = True
def callback(d):
# CHECK: DIAGNOSTIC:
# CHECK: message='arith.addi' op requires one result
return True
handler = ctx.attach_diagnostic_handler(callback)
loc = Location.unknown(ctx)
- Operation.create('arith.addi', loc=loc).verify()
+ try:
+ Operation.create('arith.addi', loc=loc).verify()
+ except MLIRError:
+ pass
assert not handler.had_error
# CHECK-LABEL: TEST: testDiagnosticCallbackException
--- /dev/null
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+from mlir.ir import *
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
+
+
+# CHECK-LABEL: TEST: test_exception
+@run
+def test_exception():
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ try:
+ Operation.parse("""
+ func.func @foo() {
+ "test.use"(%0) : (i64) -> () loc("use")
+ %0 = "test.def"() : () -> i64 loc("def")
+ return
+ }
+ """, context=ctx)
+ except MLIRError as e:
+ # CHECK: Exception: <
+ # CHECK: Unable to parse operation assembly:
+ # CHECK: error: "use": operand #0 does not dominate this use
+ # CHECK: note: "use": see current operation: "test.use"(%0) : (i64) -> ()
+ # CHECK: note: "def": operand defined here (op in the same block)
+ # CHECK: >
+ print(f"Exception: <{e}>")
+
+ # CHECK: message: Unable to parse operation assembly
+ print(f"message: {e.message}")
+
+ # CHECK: error_diagnostics[0]: loc("use") operand #0 does not dominate this use
+ # CHECK: error_diagnostics[0].notes[0]: loc("use") see current operation: "test.use"(%0) : (i64) -> ()
+ # CHECK: error_diagnostics[0].notes[1]: loc("def") operand defined here (op in the same block)
+ print("error_diagnostics[0]: ", e.error_diagnostics[0].location, e.error_diagnostics[0].message)
+ print("error_diagnostics[0].notes[0]: ", e.error_diagnostics[0].notes[0].location, e.error_diagnostics[0].notes[0].message)
+ print("error_diagnostics[0].notes[1]: ", e.error_diagnostics[0].notes[1].location, e.error_diagnostics[0].notes[1].message)
+
+
+# CHECK-LABEL: test_emit_error_diagnostics
+@run
+def test_emit_error_diagnostics():
+ ctx = Context()
+ loc = Location.unknown(ctx)
+ handler_diags = []
+ def handler(d):
+ handler_diags.append(str(d))
+ return True
+ ctx.attach_diagnostic_handler(handler)
+
+ try:
+ Attribute.parse("not an attr", ctx)
+ except MLIRError as e:
+ # CHECK: emit_error_diagnostics=False:
+ # CHECK: e.error_diagnostics: ['expected attribute value']
+ # CHECK: handler_diags: []
+ print(f"emit_error_diagnostics=False:")
+ print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}")
+ print(f"handler_diags: {handler_diags}")
+
+ ctx.emit_error_diagnostics = True
+ try:
+ Attribute.parse("not an attr", ctx)
+ except MLIRError as e:
+ # CHECK: emit_error_diagnostics=True:
+ # CHECK: e.error_diagnostics: []
+ # CHECK: handler_diags: ['expected attribute value']
+ print(f"emit_error_diagnostics=True:")
+ print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}")
+ print(f"handler_diags: {handler_diags}")
# Verify parse error.
# CHECK-LABEL: TEST: testParseError
-# CHECK: testParseError: Unable to parse module assembly (see diagnostics)
+# CHECK: testParseError: <
+# CHECK: Unable to parse module assembly:
+# CHECK: error: "-":1:1: expected operation name in quotes
+# CHECK: >
@run
def testParseError():
ctx = Context()
try:
module = Module.parse(r"""}SYNTAX ERROR{""", ctx)
- except ValueError as e:
- print("testParseError:", e)
+ except MLIRError as e:
+ print(f"testParseError: <{e}>")
else:
print("Exception not produced")
# CHECK: "builtin.module"() ({
# CHECK: }) : () -> ()
print(invalid_op)
- # CHECK: .verify = False
- print(f".verify = {invalid_op.operation.verify()}")
+ try:
+ invalid_op.verify()
+ except MLIRError as e:
+ # CHECK: Exception: <
+ # CHECK: Verification failed:
+ # CHECK: error: unknown: 'builtin.module' op requires one region
+ # CHECK: note: unknown: see current operation:
+ # CHECK: "builtin.module"() ({
+ # CHECK: ^bb0:
+ # CHECK: }, {
+ # CHECK: }) : () -> ()
+ # CHECK: >
+ print(f"Exception: <{e}>")
# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
assert isinstance(m, ModuleOp)
try:
ModuleOp.parse('"test.foo"() : () -> ()')
- except ValueError as e:
+ except MLIRError as e:
# CHECK: error: Expected a 'builtin.module' op, got: 'test.foo'
print(f"error: {e}")
else:
# Verify that a pass manager can execute on IR
-# CHECK-LABEL: TEST: testRun
+# CHECK-LABEL: TEST: testRunPipeline
def testRunPipeline():
with Context():
pm = PassManager.parse("any(print-op-stats{json=false})")
# CHECK: func.func , 1
# CHECK: func.return , 1
run(testRunPipeline)
+
+# CHECK-LABEL: TEST: testRunPipelineError
+@run
+def testRunPipelineError():
+ with Context() as ctx:
+ ctx.allow_unregistered_dialects = True
+ op = Operation.parse('"test.op"() : () -> ()')
+ pm = PassManager.parse("any(cse)")
+ try:
+ pm.run(op)
+ except MLIRError as e:
+ # CHECK: Exception: <
+ # CHECK: Failure while executing pass pipeline:
+ # CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
+ # CHECK: note: "-":1:1: see current operation: "test.op"() : () -> ()
+ # CHECK: >
+ print(f"Exception: <{e}>")