From e56f398dd3740d97ac3b7ec1c69a12b951efd9a3 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Sat, 5 Dec 2020 02:08:38 +0000 Subject: [PATCH] Add Python binding for MLIR Type Attribute Differential Revision: https://reviews.llvm.org/D92711 --- mlir/lib/Bindings/Python/IRModules.cpp | 23 +++++++++++++++++++++++ mlir/test/Bindings/Python/ir_attributes.py | 14 ++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp index 39a17d0..cffebf6 100644 --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -1922,6 +1922,28 @@ public: } }; +class PyTypeAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; + static constexpr const char *pyClassName = "TypeAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType value, DefaultingPyMlirContext context) { + MlirAttribute attr = mlirTypeAttrGet(value.get()); + return PyTypeAttribute(context->getRef(), attr); + }, + py::arg("value"), py::arg("context") = py::none(), + "Gets a uniqued Type attribute"); + c.def_property_readonly("value", [](PyTypeAttribute &self) { + return PyType(self.getContext()->getRef(), + mlirTypeAttrGetValue(self.get())); + }); + } +}; + /// Unit Attribute subclass. Unit attributes don't have values. class PyUnitAttribute : public PyConcreteAttribute { public: @@ -3073,6 +3095,7 @@ void mlir::python::populateIRSubmodule(py::module &m) { PyDenseElementsAttribute::bind(m); PyDenseIntElementsAttribute::bind(m); PyDenseFPElementsAttribute::bind(m); + PyTypeAttribute::bind(m); PyUnitAttribute::bind(m); //---------------------------------------------------------------------------- diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py index 0572220..4ad180b 100644 --- a/mlir/test/Bindings/Python/ir_attributes.py +++ b/mlir/test/Bindings/Python/ir_attributes.py @@ -255,3 +255,17 @@ def testDenseFPAttr(): run(testDenseFPAttr) + + +# CHECK-LABEL: TEST: testTypeAttr +def testTypeAttr(): + with Context(): + raw = Attribute.parse("vector<4xf32>") + # CHECK: attr: vector<4xf32> + print("attr:", raw) + type_attr = TypeAttr(raw) + # CHECK: f32 + print(ShapedType(type_attr.value).element_type) + + +run(testTypeAttr) -- 2.7.4