[mlir] Add Python binding for MLIR Dict Attribute
authorkweisamx <kweisamx0322@gmail.com>
Sun, 13 Dec 2020 03:06:32 +0000 (03:06 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 13 Dec 2020 04:30:35 +0000 (04:30 +0000)
Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D93004

mlir/lib/Bindings/Python/IRModules.cpp
mlir/test/Bindings/Python/ir_attributes.py

index 66443bf890721e3dda648fa9ca587618f5207a31..8a77d60741b478e9cfa9c41aaf32eb8e8a1eb43e 100644 (file)
@@ -1968,6 +1968,58 @@ public:
   }
 };
 
+class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
+  static constexpr const char *pyClassName = "DictAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__len__", &PyDictAttribute::dunderLen);
+    c.def_static(
+        "get",
+        [](py::dict attributes, DefaultingPyMlirContext context) {
+          SmallVector<MlirNamedAttribute> mlirNamedAttributes;
+          mlirNamedAttributes.reserve(attributes.size());
+          for (auto &it : attributes) {
+            auto &mlir_attr = it.second.cast<PyAttribute &>();
+            auto name = it.first.cast<std::string>();
+            mlirNamedAttributes.push_back(mlirNamedAttributeGet(
+                mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
+                                  toMlirStringRef(name)),
+                mlir_attr));
+          }
+          MlirAttribute attr =
+              mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
+                                    mlirNamedAttributes.data());
+          return PyDictAttribute(context->getRef(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets an uniqued dict attribute");
+    c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
+      MlirAttribute attr =
+          mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
+      if (mlirAttributeIsNull(attr)) {
+        throw SetPyError(PyExc_KeyError,
+                         "attempt to access a non-existent attribute");
+      }
+      return PyAttribute(self.getContext(), attr);
+    });
+    c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
+      if (index < 0 || index >= self.dunderLen()) {
+        throw SetPyError(PyExc_IndexError,
+                         "attempt to access out of bounds attribute");
+      }
+      MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
+      return PyNamedAttribute(
+          namedAttr.attribute,
+          std::string(mlirIdentifierStr(namedAttr.name).data));
+    });
+  }
+};
+
 /// Refinement of PyDenseElementsAttribute for attributes containing
 /// floating-point values. Supports element access.
 class PyDenseFPElementsAttribute
@@ -3181,6 +3233,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyDenseElementsAttribute::bind(m);
   PyDenseIntElementsAttribute::bind(m);
   PyDenseFPElementsAttribute::bind(m);
+  PyDictAttribute::bind(m);
   PyTypeAttribute::bind(m);
   PyUnitAttribute::bind(m);
 
index 642c1f6a836c9b8961de04adf5fc6e1c34d9a581..84f3139125476e184d807ed676a283956bd7e92d 100644 (file)
@@ -257,6 +257,47 @@ def testDenseFPAttr():
 run(testDenseFPAttr)
 
 
+# CHECK-LABEL: TEST: testDictAttr
+def testDictAttr():
+  with Context():
+    dict_attr = {
+      'stringattr':  StringAttr.get('string'),
+      'integerattr' : IntegerAttr.get(
+        IntegerType.get_signless(32), 42)
+    }
+
+    a = DictAttr.get(dict_attr)
+
+    # CHECK attr: {integerattr = 42 : i32, stringattr = "string"}
+    print("attr:", a)
+
+    assert len(a) == 2
+
+    # CHECK: 42 : i32
+    print(a['integerattr'])
+
+    # CHECK: "string"
+    print(a['stringattr'])
+
+    # Check that exceptions are raised as expected.
+    try:
+      _ = a['does_not_exist']
+    except KeyError:
+      pass
+    else:
+      assert False, "Exception not produced"
+
+    try:
+      _ = a[42]
+    except IndexError:
+      pass
+    else:
+      assert False, "expected IndexError on accessing an out-of-bounds attribute"
+
+
+
+run(testDictAttr)
+
 # CHECK-LABEL: TEST: testTypeAttr
 def testTypeAttr():
   with Context():