[mlir][python] Add FlatSymbolRef attribute.
authorStella Laurenzo <stellaraccident@gmail.com>
Tue, 29 Dec 2020 20:07:57 +0000 (12:07 -0800)
committerStella Laurenzo <stellaraccident@gmail.com>
Tue, 29 Dec 2020 20:24:28 +0000 (12:24 -0800)
Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D93909

mlir/lib/Bindings/Python/IRModules.cpp
mlir/test/Bindings/Python/ir_attributes.py

index 8a77d60..86d6f82 100644 (file)
@@ -1643,6 +1643,33 @@ public:
   }
 };
 
+class PyFlatSymbolRefAttribute
+    : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
+  static constexpr const char *pyClassName = "FlatSymbolRefAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](std::string value, DefaultingPyMlirContext context) {
+          MlirAttribute attr =
+              mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
+          return PyFlatSymbolRefAttribute(context->getRef(), attr);
+        },
+        py::arg("value"), py::arg("context") = py::none(),
+        "Gets a uniqued FlatSymbolRef attribute");
+    c.def_property_readonly(
+        "value",
+        [](PyFlatSymbolRefAttribute &self) {
+          MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
+          return py::str(stringRef.data, stringRef.length);
+        },
+        "Returns the value of the FlatSymbolRef attribute as a string");
+  }
+};
+
 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
@@ -3229,6 +3256,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
   PyArrayAttribute::PyArrayAttributeIterator::bind(m);
   PyIntegerAttribute::bind(m);
   PyBoolAttribute::bind(m);
+  PyFlatSymbolRefAttribute::bind(m);
   PyStringAttribute::bind(m);
   PyDenseElementsAttribute::bind(m);
   PyDenseIntElementsAttribute::bind(m);
index 84f3139..ce85dc3 100644 (file)
@@ -165,6 +165,20 @@ def testBoolAttr():
 run(testBoolAttr)
 
 
+# CHECK-LABEL: TEST: testFlatSymbolRefAttr
+def testFlatSymbolRefAttr():
+  with Context() as ctx:
+    sattr = FlatSymbolRefAttr(Attribute.parse('@symbol'))
+    # CHECK: symattr value: symbol
+    print("symattr value:", sattr.value)
+
+    # Test factory methods.
+    # CHECK: default_get: @foobar
+    print("default_get:", FlatSymbolRefAttr.get("foobar"))
+
+run(testFlatSymbolRefAttr)
+
+
 # CHECK-LABEL: TEST: testStringAttr
 def testStringAttr():
   with Context() as ctx: