[mlir] Add Float Attribute, Integer Attribute and Bool Attribute subclasses to python...
authorzhanghb97 <zhanghb97@126.com>
Wed, 30 Sep 2020 06:11:46 +0000 (14:11 +0800)
committerzhanghb97 <zhanghb97@126.com>
Fri, 2 Oct 2020 16:32:51 +0000 (00:32 +0800)
Based on PyAttribute and PyConcreteAttribute classes, this patch implements the bindings of Float Attribute, Integer Attribute and Bool Attribute subclasses.
This patch also defines the `mlirFloatAttrDoubleGetChecked` C API which is bound with the `FloatAttr.get_typed` python method.

Differential Revision: https://reviews.llvm.org/D88531

mlir/include/mlir-c/StandardAttributes.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/CAPI/IR/StandardAttributes.cpp
mlir/test/Bindings/Python/ir_attributes.py

index e5d5aea..2fc2ecc 100644 (file)
@@ -93,6 +93,11 @@ int mlirAttributeIsAFloat(MlirAttribute attr);
 MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
                                      double value);
 
+/** Same as "mlirFloatAttrDoubleGet", but if the type is not valid for a
+ * construction of a FloatAttr, returns a null MlirAttribute. */
+MlirAttribute mlirFloatAttrDoubleGetChecked(MlirType type, double value,
+                                            MlirLocation loc);
+
 /** Returns the value stored in the given floating point attribute, interpreting
  * the value as double. */
 double mlirFloatAttrGetValueDouble(MlirAttribute attr);
index 8d64b2d..36e25ee 100644 (file)
@@ -742,6 +742,106 @@ public:
   static void bindDerived(ClassTy &m) {}
 };
 
+/// Float Point Attribute subclass - FloatAttr.
+class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
+  static constexpr const char *pyClassName = "FloatAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        // TODO: Make the location optional and create a default location.
+        [](PyType &type, double value, PyLocation &loc) {
+          MlirAttribute attr =
+              mlirFloatAttrDoubleGetChecked(type.type, value, loc.loc);
+          // TODO: Rework error reporting once diagnostic engine is exposed
+          // in C API.
+          if (mlirAttributeIsNull(attr)) {
+            throw SetPyError(PyExc_ValueError,
+                             llvm::Twine("invalid '") +
+                                 py::repr(py::cast(type)).cast<std::string>() +
+                                 "' and expected floating point type.");
+          }
+          return PyFloatAttribute(type.getContext(), attr);
+        },
+        py::arg("type"), py::arg("value"), py::arg("loc"),
+        "Gets an uniqued float point attribute associated to a type");
+    c.def_static(
+        "get_f32",
+        [](PyMlirContext &context, double value) {
+          MlirAttribute attr = mlirFloatAttrDoubleGet(
+              context.get(), mlirF32TypeGet(context.get()), value);
+          return PyFloatAttribute(context.getRef(), attr);
+        },
+        py::arg("context"), py::arg("value"),
+        "Gets an uniqued float point attribute associated to a f32 type");
+    c.def_static(
+        "get_f64",
+        [](PyMlirContext &context, double value) {
+          MlirAttribute attr = mlirFloatAttrDoubleGet(
+              context.get(), mlirF64TypeGet(context.get()), value);
+          return PyFloatAttribute(context.getRef(), attr);
+        },
+        py::arg("context"), py::arg("value"),
+        "Gets an uniqued float point attribute associated to a f64 type");
+    c.def_property_readonly(
+        "value",
+        [](PyFloatAttribute &self) {
+          return mlirFloatAttrGetValueDouble(self.attr);
+        },
+        "Returns the value of the float point attribute");
+  }
+};
+
+/// Integer Attribute subclass - IntegerAttr.
+class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
+  static constexpr const char *pyClassName = "IntegerAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyType &type, int64_t value) {
+          MlirAttribute attr = mlirIntegerAttrGet(type.type, value);
+          return PyIntegerAttribute(type.getContext(), attr);
+        },
+        py::arg("type"), py::arg("value"),
+        "Gets an uniqued integer attribute associated to a type");
+    c.def_property_readonly(
+        "value",
+        [](PyIntegerAttribute &self) {
+          return mlirIntegerAttrGetValueInt(self.attr);
+        },
+        "Returns the value of the integer attribute");
+  }
+};
+
+/// Bool Attribute subclass - BoolAttr.
+class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
+  static constexpr const char *pyClassName = "BoolAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyMlirContext &context, bool value) {
+          MlirAttribute attr = mlirBoolAttrGet(context.get(), value);
+          return PyBoolAttribute(context.getRef(), attr);
+        },
+        py::arg("context"), py::arg("value"), "Gets an uniqued bool attribute");
+    c.def_property_readonly(
+        "value",
+        [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self.attr); },
+        "Returns the value of the bool attribute");
+  }
+};
+
 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
@@ -1630,6 +1730,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           "The underlying generic attribute of the NamedAttribute binding");
 
   // Standard attribute bindings.
+  PyFloatAttribute::bind(m);
+  PyIntegerAttribute::bind(m);
+  PyBoolAttribute::bind(m);
   PyStringAttribute::bind(m);
 
   // Mapping of Type.
index 77d5fcb..1277d2b 100644 (file)
@@ -102,6 +102,11 @@ MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
   return wrap(FloatAttr::get(unwrap(type), value));
 }
 
+MlirAttribute mlirFloatAttrDoubleGetChecked(MlirType type, double value,
+                                            MlirLocation loc) {
+  return wrap(FloatAttr::getChecked(unwrap(type), value, unwrap(loc)));
+}
+
 double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
   return unwrap(attr).cast<FloatAttr>().getValueAsDouble();
 }
index a2fd500..dfdc819 100644 (file)
@@ -92,6 +92,63 @@ def testStandardAttrCasts():
 run(testStandardAttrCasts)
 
 
+# CHECK-LABEL: TEST: testFloatAttr
+def testFloatAttr():
+  ctx = mlir.ir.Context()
+  fattr = mlir.ir.FloatAttr(ctx.parse_attr("42.0 : f32"))
+  # CHECK: fattr value: 42.0
+  print("fattr value:", fattr.value)
+
+  # Test factory methods.
+  loc = ctx.get_unknown_location()
+  # CHECK: default_get: 4.200000e+01 : f32
+  print("default_get:", mlir.ir.FloatAttr.get(
+      mlir.ir.F32Type(ctx), 42.0, loc))
+  # CHECK: f32_get: 4.200000e+01 : f32
+  print("f32_get:", mlir.ir.FloatAttr.get_f32(ctx, 42.0))
+  # CHECK: f64_get: 4.200000e+01 : f64
+  print("f64_get:", mlir.ir.FloatAttr.get_f64(ctx, 42.0))
+  try:
+    fattr_invalid = mlir.ir.FloatAttr.get(
+        mlir.ir.IntegerType.get_signless(ctx, 32), 42, loc)
+  except ValueError as e:
+    # CHECK: invalid 'Type(i32)' and expected floating point type.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testFloatAttr)
+
+
+# CHECK-LABEL: TEST: testIntegerAttr
+def testIntegerAttr():
+  ctx = mlir.ir.Context()
+  iattr = mlir.ir.IntegerAttr(ctx.parse_attr("42"))
+  # CHECK: iattr value: 42
+  print("iattr value:", iattr.value)
+
+  # Test factory methods.
+  # CHECK: default_get: 42 : i32
+  print("default_get:", mlir.ir.IntegerAttr.get(
+      mlir.ir.IntegerType.get_signless(ctx, 32), 42))
+
+run(testIntegerAttr)
+
+
+# CHECK-LABEL: TEST: testBoolAttr
+def testBoolAttr():
+  ctx = mlir.ir.Context()
+  battr = mlir.ir.BoolAttr(ctx.parse_attr("true"))
+  # CHECK: iattr value: 1
+  print("iattr value:", battr.value)
+
+  # Test factory methods.
+  # CHECK: default_get: true
+  print("default_get:", mlir.ir.BoolAttr.get(ctx, True))
+
+run(testBoolAttr)
+
+
 # CHECK-LABEL: TEST: testStringAttr
 def testStringAttr():
   ctx = mlir.ir.Context()