Add Python binding for MLIR Type Attribute
authorMehdi Amini <joker.eph@gmail.com>
Sat, 5 Dec 2020 02:08:38 +0000 (02:08 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 7 Dec 2020 23:06:58 +0000 (23:06 +0000)
Differential Revision: https://reviews.llvm.org/D92711

mlir/lib/Bindings/Python/IRModules.cpp
mlir/test/Bindings/Python/ir_attributes.py

index 39a17d0..cffebf6 100644 (file)
@@ -1922,6 +1922,28 @@ public:
   }
 };
 
+class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
+  static constexpr const char *pyClassName = "TypeAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyType value, DefaultingPyMlirContext context) {
+          MlirAttribute attr = mlirTypeAttrGet(value.get());
+          return PyTypeAttribute(context->getRef(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets a uniqued Type attribute");
+    c.def_property_readonly("value", [](PyTypeAttribute &self) {
+      return PyType(self.getContext()->getRef(),
+                    mlirTypeAttrGetValue(self.get()));
+    });
+  }
+};
+
 /// Unit Attribute subclass. Unit attributes don't have values.
 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
 public:
@@ -3073,6 +3095,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyDenseElementsAttribute::bind(m);
   PyDenseIntElementsAttribute::bind(m);
   PyDenseFPElementsAttribute::bind(m);
+  PyTypeAttribute::bind(m);
   PyUnitAttribute::bind(m);
 
   //----------------------------------------------------------------------------
index 0572220..4ad180b 100644 (file)
@@ -255,3 +255,17 @@ def testDenseFPAttr():
 
 
 run(testDenseFPAttr)
+
+
+# CHECK-LABEL: TEST: testTypeAttr
+def testTypeAttr():
+  with Context():
+    raw = Attribute.parse("vector<4xf32>")
+    # CHECK: attr: vector<4xf32>
+    print("attr:", raw)
+    type_attr = TypeAttr(raw)
+    # CHECK: f32
+    print(ShapedType(type_attr.value).element_type)
+
+
+run(testTypeAttr)