#include "IRModules.h"
#include "PybindUtils.h"
+#include "mlir-c/StandardTypes.h"
+
namespace py = pybind11;
+using namespace mlir;
using namespace mlir::python;
//------------------------------------------------------------------------------
R"(Parses a module's assembly format from a string.
Returns a new MlirModule or raises a ValueError if the parsing fails.
+
+See also: https://mlir.llvm.org/docs/LangRef/
+)";
+
+static const char kContextParseType[] = R"(Parses the assembly form of a type.
+
+Returns a Type object or raises a ValueError if the type cannot be parsed.
+
+See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";
static const char kOperationStrDunderDocstring[] =
behavior.
)";
+static const char kTypeStrDunderDocstring[] =
+ R"(Prints the assembly form of the type.)";
+
static const char kDumpDocstring[] =
R"(Dumps a debug representation of the object to stderr.)";
} // namespace
//------------------------------------------------------------------------------
-// Context Wrapper Class.
+// PyType.
//------------------------------------------------------------------------------
-PyMlirModule PyMlirContext::parse(const std::string &module) {
- auto moduleRef = mlirModuleCreateParse(context, module.c_str());
- if (!moduleRef.ptr) {
- throw SetPyError(PyExc_ValueError,
- "Unable to parse module assembly (see diagnostics)");
- }
- return PyMlirModule(moduleRef);
+bool PyType::operator==(const PyType &other) {
+ return mlirTypeEqual(type, other.type);
}
//------------------------------------------------------------------------------
-// Module Wrapper Class.
+// Standard type subclasses.
//------------------------------------------------------------------------------
-void PyMlirModule::dump() { mlirOperationDump(mlirModuleGetOperation(module)); }
+namespace {
+
+/// CRTP base classes for Python types that subclass Type and should be
+/// castable from it (i.e. via something like IntegerType(t)).
+template <typename T>
+class PyConcreteType : public PyType {
+public:
+ // Derived classes must define statics for:
+ // IsAFunctionTy isaFunction
+ // const char *pyClassName
+ using ClassTy = py::class_<T, PyType>;
+ using IsAFunctionTy = int (*)(MlirType);
+
+ PyConcreteType() = default;
+ PyConcreteType(MlirType t) : PyType(t) {}
+ PyConcreteType(PyType &orig) : PyType(castFrom(orig)) {}
+
+ static MlirType castFrom(PyType &orig) {
+ if (!T::isaFunction(orig.type)) {
+ auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
+ throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
+ T::pyClassName + " (from " +
+ origRepr + ")");
+ }
+ return orig.type;
+ }
+
+ static void bind(py::module &m) {
+ auto class_ = ClassTy(m, T::pyClassName);
+ class_.def(py::init<PyType &>(), py::keep_alive<0, 1>());
+ T::bindDerived(class_);
+ }
+
+ /// Implemented by derived classes to add methods to the Python subclass.
+ static void bindDerived(ClassTy &m) {}
+};
+
+class PyIntegerType : public PyConcreteType<PyIntegerType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
+ static constexpr const char *pyClassName = "IntegerType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "signless",
+ [](PyMlirContext &context, unsigned width) {
+ MlirType t = mlirIntegerTypeGet(context.context, width);
+ return PyIntegerType(t);
+ },
+ py::keep_alive<0, 1>(), "Create a signless integer type");
+ c.def_static(
+ "signed",
+ [](PyMlirContext &context, unsigned width) {
+ MlirType t = mlirIntegerTypeSignedGet(context.context, width);
+ return PyIntegerType(t);
+ },
+ py::keep_alive<0, 1>(), "Create a signed integer type");
+ c.def_static(
+ "unsigned",
+ [](PyMlirContext &context, unsigned width) {
+ MlirType t = mlirIntegerTypeUnsignedGet(context.context, width);
+ return PyIntegerType(t);
+ },
+ py::keep_alive<0, 1>(), "Create an unsigned integer type");
+ c.def_property_readonly(
+ "width",
+ [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self.type); },
+ "Returns the width of the integer type");
+ c.def_property_readonly(
+ "is_signless",
+ [](PyIntegerType &self) -> bool {
+ return mlirIntegerTypeIsSignless(self.type);
+ },
+ "Returns whether this is a signless integer");
+ c.def_property_readonly(
+ "is_signed",
+ [](PyIntegerType &self) -> bool {
+ return mlirIntegerTypeIsSigned(self.type);
+ },
+ "Returns whether this is a signed integer");
+ c.def_property_readonly(
+ "is_unsigned",
+ [](PyIntegerType &self) -> bool {
+ return mlirIntegerTypeIsUnsigned(self.type);
+ },
+ "Returns whether this is an unsigned integer");
+ }
+};
+
+} // namespace
//------------------------------------------------------------------------------
// Populates the pybind11 IR submodule.
//------------------------------------------------------------------------------
void mlir::python::populateIRSubmodule(py::module &m) {
- py::class_<PyMlirContext>(m, "MlirContext")
+ // Mapping of MlirContext
+ py::class_<PyMlirContext>(m, "Context")
.def(py::init<>())
- .def("parse", &PyMlirContext::parse, py::keep_alive<0, 1>(),
- kContextParseDocstring);
+ .def(
+ "parse_module",
+ [](PyMlirContext &self, const std::string module) {
+ auto moduleRef =
+ mlirModuleCreateParse(self.context, module.c_str());
+ if (mlirModuleIsNull(moduleRef)) {
+ throw SetPyError(
+ PyExc_ValueError,
+ "Unable to parse module assembly (see diagnostics)");
+ }
+ return PyModule(moduleRef);
+ },
+ py::keep_alive<0, 1>(), kContextParseDocstring)
+ .def(
+ "parse_type",
+ [](PyMlirContext &self, std::string typeSpec) {
+ MlirType type = mlirTypeParseGet(self.context, typeSpec.c_str());
+ if (mlirTypeIsNull(type)) {
+ throw SetPyError(PyExc_ValueError,
+ llvm::Twine("Unable to parse type: '") +
+ typeSpec + "'");
+ }
+ return PyType(type);
+ },
+ py::keep_alive<0, 1>(), kContextParseType);
- py::class_<PyMlirModule>(m, "MlirModule")
- .def("dump", &PyMlirModule::dump, kDumpDocstring)
+ // Mapping of Module
+ py::class_<PyModule>(m, "Module")
+ .def(
+ "dump",
+ [](PyModule &self) {
+ mlirOperationDump(mlirModuleGetOperation(self.module));
+ },
+ kDumpDocstring)
.def(
"__str__",
- [](PyMlirModule &self) {
+ [](PyModule &self) {
auto operation = mlirModuleGetOperation(self.module);
PyPrintAccumulator printAccum;
mlirOperationPrint(operation, printAccum.getCallback(),
return printAccum.join();
},
kOperationStrDunderDocstring);
+
+ // Mapping of Type.
+ py::class_<PyType>(m, "Type")
+ .def("__eq__",
+ [](PyType &self, py::object &other) {
+ try {
+ PyType otherType = other.cast<PyType>();
+ return self == otherType;
+ } catch (std::exception &e) {
+ return false;
+ }
+ })
+ .def(
+ "dump", [](PyType &self) { mlirTypeDump(self.type); }, kDumpDocstring)
+ .def(
+ "__str__",
+ [](PyType &self) {
+ PyPrintAccumulator printAccum;
+ mlirTypePrint(self.type, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ kTypeStrDunderDocstring)
+ .def("__repr__", [](PyType &self) {
+ // Generally, assembly formats are not printed for __repr__ because
+ // this can cause exceptionally long debug output and exceptions.
+ // However, types are an exception as they typically have compact
+ // assembly forms and printing them is useful.
+ PyPrintAccumulator printAccum;
+ printAccum.parts.append("Type(");
+ mlirTypePrint(self.type, printAccum.getCallback(),
+ printAccum.getUserData());
+ printAccum.parts.append(")");
+ return printAccum.join();
+ });
+
+ // Standard type bindings.
+ PyIntegerType::bind(m);
}
--- /dev/null
+# RUN: %PYTHON %s | FileCheck %s
+
+import mlir
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+
+
+# CHECK-LABEL: TEST: testParsePrint
+def testParsePrint():
+ ctx = mlir.ir.Context()
+ t = ctx.parse_type("i32")
+ # CHECK: i32
+ print(str(t))
+ # CHECK: Type(i32)
+ print(repr(t))
+
+run(testParsePrint)
+
+
+# CHECK-LABEL: TEST: testParseError
+# TODO: Hook the diagnostic manager to capture a more meaningful error
+# message.
+def testParseError():
+ ctx = mlir.ir.Context()
+ try:
+ t = ctx.parse_type("BAD_TYPE_DOES_NOT_EXIST")
+ except ValueError as e:
+ # CHECK: Unable to parse type: 'BAD_TYPE_DOES_NOT_EXIST'
+ print("testParseError:", e)
+ else:
+ print("Exception not produced")
+
+run(testParseError)
+
+
+# CHECK-LABEL: TEST: testTypeEq
+def testTypeEq():
+ ctx = mlir.ir.Context()
+ t1 = ctx.parse_type("i32")
+ t2 = ctx.parse_type("f32")
+ t3 = ctx.parse_type("i32")
+ # CHECK: t1 == t1: True
+ print("t1 == t1:", t1 == t1)
+ # CHECK: t1 == t2: False
+ print("t1 == t2:", t1 == t2)
+ # CHECK: t1 == t3: True
+ print("t1 == t3:", t1 == t3)
+ # CHECK: t1 == None: False
+ print("t1 == None:", t1 == None)
+
+run(testTypeEq)
+
+
+# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
+def testTypeEqDoesNotRaise():
+ ctx = mlir.ir.Context()
+ t1 = ctx.parse_type("i32")
+ not_a_type = "foo"
+ # CHECK: False
+ print(t1 == not_a_type)
+ # CHECK: False
+ print(t1 == None)
+ # CHECK: True
+ print(t1 != None)
+
+run(testTypeEqDoesNotRaise)
+
+
+# CHECK-LABEL: TEST: testStandardTypeCasts
+def testStandardTypeCasts():
+ ctx = mlir.ir.Context()
+ t1 = ctx.parse_type("i32")
+ tint = mlir.ir.IntegerType(t1)
+ tself = mlir.ir.IntegerType(tint)
+ # CHECK: Type(i32)
+ print(repr(tint))
+ try:
+ tillegal = mlir.ir.IntegerType(ctx.parse_type("f32"))
+ except ValueError as e:
+ # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
+ print("ValueError:", e)
+ else:
+ print("Exception not produced")
+
+run(testStandardTypeCasts)
+
+
+# CHECK-LABEL: TEST: testIntegerType
+def testIntegerType():
+ ctx = mlir.ir.Context()
+ i32 = mlir.ir.IntegerType(ctx.parse_type("i32"))
+ # CHECK: i32 width: 32
+ print("i32 width:", i32.width)
+ # CHECK: i32 signless: True
+ print("i32 signless:", i32.is_signless)
+ # CHECK: i32 signed: False
+ print("i32 signed:", i32.is_signed)
+ # CHECK: i32 unsigned: False
+ print("i32 unsigned:", i32.is_unsigned)
+
+ s32 = mlir.ir.IntegerType(ctx.parse_type("si32"))
+ # CHECK: s32 signless: False
+ print("s32 signless:", s32.is_signless)
+ # CHECK: s32 signed: True
+ print("s32 signed:", s32.is_signed)
+ # CHECK: s32 unsigned: False
+ print("s32 unsigned:", s32.is_unsigned)
+
+ u32 = mlir.ir.IntegerType(ctx.parse_type("ui32"))
+ # CHECK: u32 signless: False
+ print("u32 signless:", u32.is_signless)
+ # CHECK: u32 signed: False
+ print("u32 signed:", u32.is_signed)
+ # CHECK: u32 unsigned: True
+ print("u32 unsigned:", u32.is_unsigned)
+
+ # CHECK: signless: i16
+ print("signless:", mlir.ir.IntegerType.signless(ctx, 16))
+ # CHECK: signed: si8
+ print("signed:", mlir.ir.IntegerType.signed(ctx, 8))
+ # CHECK: unsigned: ui64
+ print("unsigned:", mlir.ir.IntegerType.unsigned(ctx, 64))
+
+run(testIntegerType)