* delineated). */
#define MLIR_PYTHON_CAPI_FACTORY_ATTR "_CAPICreate"
+/** Attribute on MLIR Python objects that exposes a method for downcasting the
+ * corresponding Python object to a subclass (a ConcreteType or an
+ * mlir_type_subclass) of ir.Type, if the object is in fact an instance of such
+ * a subclass. The signature of the method is:
+ *   def maybe_downcast(self) -> object
+ * where the resulting object will (possibly) be an instance of the subclass.
+ */
+#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR "maybe_downcast"
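For orientation, a minimal Python-level sketch of what this attribute enables, mirroring the Type(typ).maybe_downcast() pattern used in the tests later in this change:

    # Sketch: recover the concrete wrapper from a type-erased ir.Type handle.
    from mlir.ir import Context, F32Type, Type

    with Context():
        t = Type(F32Type.get())   # type-erased ir.Type
        d = t.maybe_downcast()    # calls the caster registered for f32's TypeID
        assert isinstance(d, F32Type)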
+
+/** Attribute on the main C extension module (_mlir) that corresponds to the
+ * type caster registration binding. The signature of the function is:
+ *   def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster,
+ *                            bool replace)
+ * where replace indicates whether typeCaster should replace any existing
+ * registered type casters (such as those for upstream ConcreteTypes).
+ */
+#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster"
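A sketch of the Python side of this binding, assuming the register_type_caster re-export added to ir.py in this change; my_caster is illustrative only (the python_test changes below register TestIntegerRankedTensorType the same way):

    # Sketch: register a custom caster keyed by a type's TypeID.
    from mlir.ir import Context, F32Type, RankedTensorType, register_type_caster

    def my_caster(raw_type):               # receives the type-erased ir.Type
        return RankedTensorType(raw_type)  # return any wrapper object you like

    with Context():
        tid = RankedTensorType.get([2, 2], F32Type.get()).typeid
        # replace=True because upstream already registers a ConcreteType caster
        # for ranked tensors.
        register_type_caster(tid, my_caster, replace=True)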
+
/// Gets a void* from a wrapped struct. Needed because const cast is different
/// between C/C++.
#ifdef __cplusplus
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformOperationType(MlirType type);
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOperationTypeGetTypeID(void);
+
MLIR_CAPI_EXPORTED MlirType
mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);
/// Gets the type ID of the type.
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type);
+/// Gets the dialect a type belongs to.
+MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type);
+
/// Checks whether a type is null.
static inline bool mlirTypeIsNull(MlirType type) { return !type.ptr; }
#include "llvm/ADT/Twine.h"
namespace py = pybind11;
+using namespace py::literals;
// Raw CAPI type casters need to be declared before use, so always include them
// first.
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("Type")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
+ .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
.release();
}
};
class mlir_type_subclass : public pure_subclass {
public:
using IsAFunctionTy = bool (*)(MlirType);
+ using GetTypeIDFunctionTy = MlirTypeID (*)();
/// Subclasses by looking up the super-class dynamically.
mlir_type_subclass(py::handle scope, const char *typeClassName,
- IsAFunctionTy isaFunction)
+ IsAFunctionTy isaFunction,
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr)
: mlir_type_subclass(
scope, typeClassName, isaFunction,
- py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type")) {}
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type"),
+ getTypeIDFunction) {}
/// Subclasses with a provided mlir.ir.Type super-class. This must
/// be used if the subclass is being defined in the same extension module
/// as the mlir.ir class (otherwise, it will trigger a recursive
/// initialization).
mlir_type_subclass(py::handle scope, const char *typeClassName,
- IsAFunctionTy isaFunction, const py::object &superCls)
+ IsAFunctionTy isaFunction, const py::object &superCls,
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr)
: pure_subclass(scope, typeClassName, superCls) {
// Casting constructor. Note that it is hard, if not impossible, to properly
// call chain to parent `__init__` in pybind11 due to its special handling
"isinstance",
[isaFunction](MlirType other) { return isaFunction(other); },
py::arg("other_type"));
+ def("__repr__", [superCls, captureTypeName](py::object self) {
+ return py::repr(superCls(self))
+ .attr("replace")(superCls.attr("__name__"), captureTypeName);
+ });
+ if (getTypeIDFunction) {
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
+ getTypeIDFunction(),
+ pybind11::cpp_function(
+ [thisClass = thisClass](const py::object &mlirType) {
+ return thisClass(mlirType);
+ }));
+ }
}
};
DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator)
+namespace llvm {
+
+template <>
+struct DenseMapInfo<MlirTypeID> {
+ static inline MlirTypeID getEmptyKey() {
+ auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+ return mlirTypeIDCreate(pointer);
+ }
+ static inline MlirTypeID getTombstoneKey() {
+ auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+ return mlirTypeIDCreate(pointer);
+ }
+ static inline unsigned getHashValue(const MlirTypeID &val) {
+ return mlirTypeIDHashValue(val);
+ }
+ static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) {
+ return mlirTypeIDEqual(lhs, rhs);
+ }
+};
+} // namespace llvm
+
#endif // MLIR_CAPI_SUPPORT_H
//===-------------------------------------------------------------------===//
auto operationType =
- mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType);
+ mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType,
+ mlirTransformOperationTypeGetTypeID);
operationType.def_classmethod(
"get",
[](py::object cls, const std::string &operationName, MlirContext ctx) {
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
+#include <optional>
#include <string>
#include <vector>
-#include <optional>
#include "PybindUtils.h"
+#include "mlir-c/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
/// entities.
void loadDialectModule(llvm::StringRef dialectNamespace);
- /// Decorator for registering a custom Dialect class. The class object must
- /// have a DIALECT_NAMESPACE attribute.
- pybind11::object registerDialectDecorator(pybind11::object pyClass);
-
/// Adds a user-friendly Attribute builder.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
void registerAttributeBuilder(const std::string &attributeKind,
pybind11::function pyFunc);
+ /// Adds a user-friendly type caster. Raises an exception if the mapping
+ /// already exists and replace == false. This is intended to be called by
+ /// implementation code.
+ void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
+ bool replace = false);
+
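A small sketch of the duplicate-registration behavior described above, patterned after the testCustomTypeTypeCaster test in this change:

    # Sketch: registering twice for the same TypeID raises unless replace=True.
    from mlir.ir import Context, IntegerType, register_type_caster

    with Context():
        tid = IntegerType.get_signless(32).typeid
        try:
            register_type_caster(tid, lambda t: IntegerType(t))
        except RuntimeError as e:
            print(e)  # Type caster is already registered
        register_type_caster(tid, lambda t: IntegerType(t), replace=True)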
/// Adds a concrete implementation dialect class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
std::optional<pybind11::function>
lookupAttributeBuilder(const std::string &attributeKind);
+ /// Returns the custom type caster for MlirTypeID mlirTypeID, if any. May
+ /// load the module of the dialect to which the type belongs.
+ std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
+ MlirDialect dialect);
+
/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
std::optional<pybind11::object>
llvm::StringMap<pybind11::object> operationClassMap;
/// Map of attribute ODS name to custom builder.
llvm::StringMap<pybind11::object> attributeBuilderMap;
+ /// Map of MlirTypeID to custom type caster.
+ llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
+ /// Cache for map of MlirTypeID to custom type caster.
+ llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMapCache;
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
namespace py = pybind11;
using namespace mlir;
py::arg("value"), py::arg("context") = py::none(),
"Gets a uniqued Type attribute");
c.def_property_readonly("value", [](PyTypeAttribute &self) {
- return PyType(self.getContext()->getRef(),
- mlirTypeAttrGetValue(self.get()));
+ return mlirTypeAttrGetValue(self.get());
});
}
};
#include <utility>
namespace py = pybind11;
+using namespace py::literals;
using namespace mlir;
using namespace mlir::python;
/// Returns the list of types of the values held by container.
template <typename Container>
-static std::vector<PyType> getValueTypes(Container &container,
- PyMlirContextRef &context) {
- std::vector<PyType> result;
+static std::vector<MlirType> getValueTypes(Container &container,
+ PyMlirContextRef &context) {
+ std::vector<MlirType> result;
result.reserve(container.size());
for (int i = 0, e = container.size(); i < e; ++i) {
- result.push_back(
- PyType(context, mlirValueGetType(container.getElement(i).get())));
+ result.push_back(mlirValueGetType(container.getElement(i).get()));
}
return result;
}
"context",
[](PyAttribute &self) { return self.getContext().getObject(); },
"Context that owns the Attribute")
- .def_property_readonly("type",
- [](PyAttribute &self) {
- return PyType(self.getContext()->getRef(),
- mlirAttributeGetType(self));
- })
+ .def_property_readonly(
+ "type", [](PyAttribute &self) { return mlirAttributeGetType(self); })
.def(
"get_named",
[](PyAttribute &self, std::string name) {
mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
if (mlirTypeIsNull(type))
throw MLIRError("Unable to parse type", errors.take());
- return PyType(context->getRef(), type);
+ return type;
},
py::arg("asm"), py::arg("context") = py::none(),
kContextParseTypeDocstring)
printAccum.parts.append(")");
return printAccum.join();
})
+ .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyType &self) {
+ MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
+ assert(!mlirTypeIDIsNull(mlirTypeID) &&
+ "mlirTypeID was expected to be non-null.");
+ std::optional<pybind11::function> typeCaster =
+ PyGlobals::get().lookupTypeCaster(mlirTypeID,
+ mlirTypeGetDialect(self));
+ if (!typeCaster)
+ return py::cast(self);
+ return typeCaster.value()(self);
+ })
.def_property_readonly("typeid", [](PyType &self) -> MlirTypeID {
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
if (!mlirTypeIDIsNull(mlirTypeID))
return printAccum.join();
},
py::arg("use_local_scope") = false, kGetNameAsOperand)
- .def_property_readonly("type",
- [](PyValue &self) {
- return PyType(
- self.getParentOperation()->getContext(),
- mlirValueGetType(self.get()));
- })
+ .def_property_readonly(
+ "type", [](PyValue &self) { return mlirValueGetType(self.get()); })
.def(
"replace_all_uses_with",
[](PyValue &self, PyValue &with) {
py::module_local())
.def_property_readonly(
"element_type",
- [](PyShapedTypeComponents &self) {
- return PyType(PyMlirContext::forContext(
- mlirTypeGetContext(self.elementType)),
- self.elementType);
- },
+ [](PyShapedTypeComponents &self) { return self.elementType; },
"Returns the element type of the shaped type components.")
.def_static(
"get",
#include <vector>
#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/Support.h"
namespace py = pybind11;
using namespace mlir;
found = std::move(pyFunc);
}
+void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
+ pybind11::function typeCaster,
+ bool replace) {
+ pybind11::object &found = typeCasterMap[mlirTypeID];
+ if (found && !found.is_none() && !replace)
+ throw std::runtime_error("Type caster is already registered");
+ found = std::move(typeCaster);
+}
+
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
py::object pyClass) {
py::object &found = dialectClassMap[dialectNamespace];
return std::nullopt;
}
+std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
+ MlirDialect dialect) {
+ {
+ // Fast match against the cache first (common case).
+ const auto foundIt = typeCasterMapCache.find(mlirTypeID);
+ if (foundIt != typeCasterMapCache.end()) {
+ if (foundIt->second.is_none())
+ return std::nullopt;
+ assert(foundIt->second && "py::function is defined");
+ return foundIt->second;
+ }
+ }
+
+ // Not found in the cache. Load the module of the dialect the type belongs to.
+ loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
+
+ // Attempt to find from the canonical map and cache.
+ {
+ const auto foundIt = typeCasterMap.find(mlirTypeID);
+ if (foundIt != typeCasterMap.end()) {
+ if (foundIt->second.is_none())
+ return std::nullopt;
+ assert(foundIt->second && "py::object is defined");
+ // Positive cache.
+ typeCasterMapCache[mlirTypeID] = foundIt->second;
+ return foundIt->second;
+ }
+ // Negative cache.
+ typeCasterMap[mlirTypeID] = py::none();
+ return std::nullopt;
+ }
+}
+
std::optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
loadDialectModule(dialectNamespace);
void PyGlobals::clearImportCache() {
loadedDialectModulesCache.clear();
operationClassMapCache.clear();
+ typeCasterMapCache.clear();
}
#include <utility>
#include <vector>
+#include "Globals.h"
#include "PybindUtils.h"
#include "mlir-c/AffineExpr.h"
PyConcreteType() = default;
PyConcreteType(PyMlirContextRef contextRef, MlirType t)
- : BaseTy(std::move(contextRef), t) {
- pybind11::implicitly_convertible<PyType, DerivedTy>();
- }
+ : BaseTy(std::move(contextRef), t) {}
PyConcreteType(PyType &orig)
: PyConcreteType(orig.getContext(), castFrom(orig)) {}
return printAccum.join();
});
+ if (DerivedTy::getTypeIdFunction) {
+ PyGlobals::get().registerTypeCaster(
+ DerivedTy::getTypeIdFunction(),
+ pybind11::cpp_function(
+ [](PyType pyType) -> DerivedTy { return pyType; }));
+ }
+
DerivedTy::bindDerived(cls);
}
return DerivedTy::isaFunction(otherAttr);
},
pybind11::arg("other"));
- cls.def_property_readonly("type", [](PyAttribute &attr) {
- return PyType(attr.getContext(), mlirAttributeGetType(attr));
- });
+ cls.def_property_readonly(
+ "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); });
DerivedTy::bindDerived(cls);
}
"Create a complex type");
c.def_property_readonly(
"element_type",
- [](PyComplexType &self) -> PyType {
- MlirType t = mlirComplexTypeGetElementType(self);
- return PyType(self.getContext(), t);
- },
+ [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); },
"Returns element type.");
}
};
static void bindDerived(ClassTy &c) {
c.def_property_readonly(
"element_type",
- [](PyShapedType &self) {
- MlirType t = mlirShapedTypeGetElementType(self);
- return PyType(self.getContext(), t);
- },
+ [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
"Returns the element type of the shaped type.");
c.def_property_readonly(
"has_rank",
"Create a tuple type");
c.def(
"get_type",
- [](PyTupleType &self, intptr_t pos) -> PyType {
- MlirType t = mlirTupleTypeGetType(self, pos);
- return PyType(self.getContext(), t);
+ [](PyTupleType &self, intptr_t pos) {
+ return mlirTupleTypeGetType(self, pos);
},
py::arg("pos"), "Returns the pos-th type in the tuple type.");
c.def_property_readonly(
py::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
++i) {
- types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
+ types.append(mlirFunctionTypeGetInput(t, i));
}
return types;
},
py::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
++i) {
- types.append(
- PyType(contextRef, mlirFunctionTypeGetResult(self, i)));
+ types.append(mlirFunctionTypeGetResult(self, i));
}
return types;
},
namespace py = pybind11;
using namespace mlir;
+using namespace py::literals;
using namespace mlir::python;
// -----------------------------------------------------------------------------
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
self.clearImportCache();
},
- py::arg("module_name"))
+ "module_name"_a)
.def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
- py::arg("dialect_namespace"), py::arg("dialect_class"),
+ "dialect_namespace"_a, "dialect_class"_a,
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
- py::arg("operation_name"), py::arg("operation_class"),
+ "operation_name"_a, "operation_class"_a,
"Testing hook for directly registering an operation");
// Aside from making the globals accessible to python, having python manage
PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
return pyClass;
},
- py::arg("dialect_class"),
+ "dialect_class"_a,
"Class decorator for registering a custom Dialect wrapper");
m.def(
"register_operation",
- [](py::object dialectClass) -> py::cpp_function {
+ [](const py::object &dialectClass) -> py::cpp_function {
return py::cpp_function(
[dialectClass](py::object opClass) -> py::object {
std::string operationName =
return opClass;
});
},
- py::arg("dialect_class"),
+ "dialect_class"_a,
"Produce a class decorator for registering an Operation class as part of "
"a dialect");
+ m.def(
+ MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
+ [](MlirTypeID mlirTypeID, py::function typeCaster, bool replace) {
+ PyGlobals::get().registerTypeCaster(mlirTypeID, std::move(typeCaster),
+ replace);
+ },
+ "typeid"_a, "type_caster"_a, "replace"_a = false,
+ "Register a type caster for casting MLIR types to custom user types.");
// Define and populate IR submodule.
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
return isa<transform::OperationType>(unwrap(type));
}
+MlirTypeID mlirTransformOperationTypeGetTypeID(void) {
+ return wrap(transform::OperationType::getTypeID());
+}
+
MlirType mlirTransformOperationTypeGet(MlirContext ctx,
MlirStringRef operationName) {
return wrap(
return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
}
+MlirType mlirUnrankedTensorTypeGetElementType(MlirType type) {
+ return wrap(llvm::cast<UnrankedTensorType>(unwrap(type)).getElementType());
+}
+
//===----------------------------------------------------------------------===//
// Ranked / Unranked MemRef type.
//===----------------------------------------------------------------------===//
return wrap(unwrap(type).getTypeID());
}
+MlirDialect mlirTypeGetDialect(MlirType type) {
+ return wrap(&unwrap(type).getDialect());
+}
+
bool mlirTypeEqual(MlirType t1, MlirType t2) {
return unwrap(t1) == unwrap(t2);
}
//===----------------------------------------------------------------------===//
// TypeID API.
//===----------------------------------------------------------------------===//
-
MlirTypeID mlirTypeIDCreate(const void *ptr) {
assert(reinterpret_cast<uintptr_t>(ptr) % 8 == 0 &&
"ptr must be 8 byte aligned");
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._python_test_ops_gen import *
-from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType
+from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestIntegerRankedTensorType
def register_python_test_dialect(context, load=True):
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
+from ._mlir_libs._mlir import register_type_caster
# Convenience decorator for registering user-friendly Attribute builders.
# Classes of custom types that inherit from concrete types should have
# static_typeid
- assert isinstance(test.TestTensorType.static_typeid, TypeID)
+ assert isinstance(test.TestIntegerRankedTensorType.static_typeid, TypeID)
# And it should be equal to the in-tree concrete type
- assert test.TestTensorType.static_typeid == t.type.typeid
+ assert test.TestIntegerRankedTensorType.static_typeid == t.type.typeid
# CHECK-LABEL: TEST: inferReturnTypeComponents
print("rank:", shaped_type_components.rank)
print("element type:", shaped_type_components.element_type)
print("shape:", shaped_type_components.shape)
+
+
+# CHECK-LABEL: TEST: testCustomTypeTypeCaster
+@run
+def testCustomTypeTypeCaster():
+ with Context() as ctx, Location.unknown():
+ test.register_python_test_dialect(ctx)
+
+ a = test.TestType.get()
+ assert a.typeid is not None
+
+ b = Type.parse("!python_test.test_type")
+ # CHECK: !python_test.test_type
+ print(b)
+ # CHECK: TestType(!python_test.test_type)
+ print(repr(b))
+
+ c = test.TestIntegerRankedTensorType.get([10, 10], 5)
+ # CHECK: tensor<10x10xi5>
+ print(c)
+ # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
+ print(repr(c))
+
+ # CHECK: Type caster is already registered
+ try:
+
+ def type_caster(pytype):
+ return test.TestIntegerRankedTensorType(pytype)
+
+ register_type_caster(c.typeid, type_caster)
+ except RuntimeError as e:
+ print(e)
+
+ def type_caster(pytype):
+ return test.TestIntegerRankedTensorType(pytype)
+
+ register_type_caster(c.typeid, type_caster, replace=True)
+
+ d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
+ # CHECK: tensor<10x10xi5>
+ print(d.type)
+ # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
+ print(repr(d.type))
print(f"rank: {len(attr.strides)}")
# CHECK: strides are dynamic: [True, True, True]
print(f"strides are dynamic: {[s == dynamic for s in attr.strides]}")
+
+
+# CHECK-LABEL: TEST: testConcreteTypesRoundTrip
+@run
+def testConcreteTypesRoundTrip():
+ with Context(), Location.unknown():
+
+ def print_item(attr):
+ print(repr(attr.type))
+
+ # CHECK: F32Type(f32)
+ print_item(Attribute.parse("42.0 : f32"))
+ # CHECK: F32Type(f32)
+ print_item(FloatAttr.get_f32(42.0))
+ # CHECK: IntegerType(i64)
+ print_item(IntegerAttr.get(IntegerType.get_signless(64), 42))
+
+ def print_container_item(attr_asm):
+ attr = DenseElementsAttr(Attribute.parse(attr_asm))
+ print(repr(attr.type))
+ print(repr(attr.type.element_type))
+
+ # CHECK: RankedTensorType(tensor<i16>)
+ # CHECK: IntegerType(i16)
+ print_container_item("dense<123> : tensor<i16>")
+
+ # CHECK: RankedTensorType(tensor<f64>)
+ # CHECK: F64Type(f64)
+ print_container_item("dense<1.0> : tensor<f64>")
+
+ raw = Attribute.parse("vector<4xf32>")
+ # CHECK: attr: vector<4xf32>
+ print("attr:", raw)
+ type_attr = TypeAttr(raw)
+
+ # CHECK: VectorType(vector<4xf32>)
+ print(repr(type_attr.value))
+ # CHECK: F32Type(f32)
+ print(repr(type_attr.value.element_type))
import gc
from mlir.ir import *
+from mlir.dialects import arith, tensor, func, memref
def run(f):
f32 = F32Type.get()
shape = [2, 3]
loc = Location.unknown()
- memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
+ memref_f32 = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
# CHECK: memref type: memref<2x3xf32, 2>
- print("memref type:", memref)
+ print("memref type:", memref_f32)
# CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)>
- print("memref layout:", memref.layout)
+ print("memref layout:", memref_f32.layout)
# CHECK: memref affine map: (d0, d1) -> (d0, d1)
- print("memref affine map:", memref.affine_map)
+ print("memref affine map:", memref_f32.affine_map)
# CHECK: memory space: 2
- print("memory space:", memref.memory_space)
+ print("memory space:", memref_f32.memory_space)
layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
memref_layout = MemRefType.get(shape, f32, layout=layout)
else:
print("Exception not produced")
- assert memref.shape == shape
+ assert memref_f32.shape == shape
# CHECK-LABEL: TEST: testUnrankedMemRefType
input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)]
result_types = [IndexType.get()]
func = FunctionType.get(input_types, result_types)
- # CHECK: INPUTS: [Type(i32), Type(i16)]
+ # CHECK: INPUTS: [IntegerType(i32), IntegerType(i16)]
print("INPUTS:", func.inputs)
- # CHECK: RESULTS: [Type(index)]
+ # CHECK: RESULTS: [IndexType(index)]
print("RESULTS:", func.results)
vector_type = Type.parse("vector<2x3xf32>")
# CHECK: True
print(ShapedType(vector_type).typeid == vector_type.typeid)
+
+
+# CHECK-LABEL: TEST: testConcreteTypesRoundTrip
+@run
+def testConcreteTypesRoundTrip():
+ with Context() as ctx, Location.unknown():
+ ctx.allow_unregistered_dialects = True
+
+ def print_downcasted(typ):
+ downcasted = Type(typ).maybe_downcast()
+ print(type(downcasted).__name__)
+ print(repr(downcasted))
+
+ # CHECK: F16Type
+ # CHECK: F16Type(f16)
+ print_downcasted(F16Type.get())
+ # CHECK: F32Type
+ # CHECK: F32Type(f32)
+ print_downcasted(F32Type.get())
+ # CHECK: F64Type
+ # CHECK: F64Type(f64)
+ print_downcasted(F64Type.get())
+ # CHECK: Float8E4M3B11FNUZType
+ # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
+ print_downcasted(Float8E4M3B11FNUZType.get())
+ # CHECK: Float8E4M3FNType
+ # CHECK: Float8E4M3FNType(f8E4M3FN)
+ print_downcasted(Float8E4M3FNType.get())
+ # CHECK: Float8E4M3FNUZType
+ # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
+ print_downcasted(Float8E4M3FNUZType.get())
+ # CHECK: Float8E5M2Type
+ # CHECK: Float8E5M2Type(f8E5M2)
+ print_downcasted(Float8E5M2Type.get())
+ # CHECK: Float8E5M2FNUZType
+ # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
+ print_downcasted(Float8E5M2FNUZType.get())
+ # CHECK: BF16Type
+ # CHECK: BF16Type(bf16)
+ print_downcasted(BF16Type.get())
+ # CHECK: IndexType
+ # CHECK: IndexType(index)
+ print_downcasted(IndexType.get())
+ # CHECK: IntegerType
+ # CHECK: IntegerType(i32)
+ print_downcasted(IntegerType.get_signless(32))
+
+ f32 = F32Type.get()
+ ranked_tensor = tensor.EmptyOp([10, 10], f32).result
+ # CHECK: RankedTensorType
+ print(type(ranked_tensor.type).__name__)
+ # CHECK: RankedTensorType(tensor<10x10xf32>)
+ print(repr(ranked_tensor.type))
+
+ cf32 = ComplexType.get(f32)
+ # CHECK: ComplexType
+ print(type(cf32).__name__)
+ # CHECK: ComplexType(complex<f32>)
+ print(repr(cf32))
+
+ ranked_tensor = tensor.EmptyOp([10, 10], f32).result
+ # CHECK: RankedTensorType
+ print(type(ranked_tensor.type).__name__)
+ # CHECK: RankedTensorType(tensor<10x10xf32>)
+ print(repr(ranked_tensor.type))
+
+ vector = VectorType.get([10, 10], f32)
+ tuple_type = TupleType.get_tuple([f32, vector])
+ # CHECK: TupleType
+ print(type(tuple_type).__name__)
+ # CHECK: TupleType(tuple<f32, vector<10x10xf32>>)
+ print(repr(tuple_type))
+ # CHECK: F32Type(f32)
+ print(repr(tuple_type.get_type(0)))
+ # CHECK: VectorType(vector<10x10xf32>)
+ print(repr(tuple_type.get_type(1)))
+
+ index_type = IndexType.get()
+
+ @func.FuncOp.from_py_func()
+ def default_builder():
+ c0 = arith.ConstantOp(f32, 0.0)
+ unranked_tensor_type = UnrankedTensorType.get(f32)
+ unranked_tensor = tensor.FromElementsOp(unranked_tensor_type, [c0]).result
+ # CHECK: UnrankedTensorType
+ print(type(unranked_tensor.type).__name__)
+ # CHECK: UnrankedTensorType(tensor<*xf32>)
+ print(repr(unranked_tensor.type))
+
+ c10 = arith.ConstantOp(index_type, 10)
+ memref_f32_t = MemRefType.get([10, 10], f32)
+ memref_f32 = memref.AllocOp(memref_f32_t, [c10, c10], []).result
+ # CHECK: MemRefType
+ print(type(memref_f32.type).__name__)
+ # CHECK: MemRefType(memref<10x10xf32>)
+ print(repr(memref_f32.type))
+
+ unranked_memref_t = UnrankedMemRefType.get(f32, Attribute.parse("2"))
+ memref_f32 = memref.AllocOp(unranked_memref_t, [c10, c10], []).result
+ # CHECK: UnrankedMemRefType
+ print(type(memref_f32.type).__name__)
+ # CHECK: UnrankedMemRefType(memref<*xf32, 2>)
+ print(repr(memref_f32.type))
+
+ tuple_type = Operation.parse(
+ f'"test.make_tuple"() : () -> tuple<i32, f32>'
+ ).result
+ # CHECK: TupleType
+ print(type(tuple_type.type).__name__)
+ # CHECK: TupleType(tuple<i32, f32>)
+ print(repr(tuple_type.type))
+
+ return c0, c10
+
+
+# CHECK-LABEL: TEST: testCustomTypeTypeCaster
+# This tests that a type from a dialect can be materialized *and* that its
+# registered type caster is invoked without explicitly importing the dialect,
+# i.e. we get a transform.OperationType without importing the transform dialect.
+@run
+def testCustomTypeTypeCaster():
+ with Context() as ctx, Location.unknown():
+ t = Type.parse('!transform.op<"foo.bar">', Context())
+ # CHECK: !transform.op<"foo.bar">
+ print(t)
+ # CHECK: OperationType(!transform.op<"foo.bar">)
+ print(repr(t))
return wrap(python_test::TestTypeType::get(unwrap(context)));
}
+MlirTypeID mlirPythonTestTestTypeGetTypeID(void) {
+ return wrap(python_test::TestTypeType::getTypeID());
+}
+
bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value) {
return mlirTypeIsATensor(wrap(unwrap(value).getType()));
}
MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
+MLIR_CAPI_EXPORTED MlirTypeID mlirPythonTestTestTypeGetTypeID(void);
+
MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value);
#ifdef __cplusplus
//===----------------------------------------------------------------------===//
#include "PythonTestCAPI.h"
+#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
namespace py = pybind11;
using namespace mlir::python::adaptors;
+using namespace pybind11::literals;
+
+static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
+ return mlirTypeIsARankedTensor(t) &&
+ mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
+}
PYBIND11_MODULE(_mlirPythonTest, m) {
m.def(
return cls(mlirPythonTestTestAttributeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
- mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType)
+ mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
+ mlirPythonTestTestTypeGetTypeID)
.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
return cls(mlirPythonTestTestTypeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
- mlir_type_subclass(m, "TestTensorType", mlirTypeIsARankedTensor,
- py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("RankedTensorType"));
+ auto cls =
+ mlir_type_subclass(m, "TestIntegerRankedTensorType",
+ mlirTypeIsARankedIntegerTensor,
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("RankedTensorType"))
+ .def_classmethod(
+ "get",
+ [](const py::object &cls, std::vector<int64_t> shape,
+ unsigned width, MlirContext ctx) {
+ MlirAttribute encoding = mlirAttributeGetNull();
+ return cls(mlirRankedTensorTypeGet(
+ shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
+ encoding));
+ },
+ "cls"_a, "shape"_a, "width"_a, "context"_a = py::none());
+ assert(py::hasattr(cls.get_class(), "static_typeid") &&
+ "TestIntegerRankedTensorType has no static_typeid");
+ MlirTypeID mlirTypeID = mlirRankedTensorTypeGetTypeID();
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
+ mlirTypeID, pybind11::cpp_function([cls](const py::object &mlirType) {
+ return cls.get_class()(mlirType);
+ }),
+ /*replace=*/true);
mlir_value_subclass(m, "TestTensorValue",
mlirTypeIsAPythonTestTestTensorValue)
.def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });