static void bindDerived(ClassTy &m) {}
};
+/// Float Point Attribute subclass - FloatAttr.
+class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
+ static constexpr const char *pyClassName = "FloatAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ // TODO: Make the location optional and create a default location.
+ [](PyType &type, double value, PyLocation &loc) {
+ MlirAttribute attr =
+ mlirFloatAttrDoubleGetChecked(type.type, value, loc.loc);
+ // TODO: Rework error reporting once diagnostic engine is exposed
+ // in C API.
+ if (mlirAttributeIsNull(attr)) {
+ throw SetPyError(PyExc_ValueError,
+ llvm::Twine("invalid '") +
+ py::repr(py::cast(type)).cast<std::string>() +
+ "' and expected floating point type.");
+ }
+ return PyFloatAttribute(type.getContext(), attr);
+ },
+ py::arg("type"), py::arg("value"), py::arg("loc"),
+ "Gets an uniqued float point attribute associated to a type");
+ c.def_static(
+ "get_f32",
+ [](PyMlirContext &context, double value) {
+ MlirAttribute attr = mlirFloatAttrDoubleGet(
+ context.get(), mlirF32TypeGet(context.get()), value);
+ return PyFloatAttribute(context.getRef(), attr);
+ },
+ py::arg("context"), py::arg("value"),
+ "Gets an uniqued float point attribute associated to a f32 type");
+ c.def_static(
+ "get_f64",
+ [](PyMlirContext &context, double value) {
+ MlirAttribute attr = mlirFloatAttrDoubleGet(
+ context.get(), mlirF64TypeGet(context.get()), value);
+ return PyFloatAttribute(context.getRef(), attr);
+ },
+ py::arg("context"), py::arg("value"),
+ "Gets an uniqued float point attribute associated to a f64 type");
+ c.def_property_readonly(
+ "value",
+ [](PyFloatAttribute &self) {
+ return mlirFloatAttrGetValueDouble(self.attr);
+ },
+ "Returns the value of the float point attribute");
+ }
+};
+
+/// Integer Attribute subclass - IntegerAttr.
+class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
+ static constexpr const char *pyClassName = "IntegerAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](PyType &type, int64_t value) {
+ MlirAttribute attr = mlirIntegerAttrGet(type.type, value);
+ return PyIntegerAttribute(type.getContext(), attr);
+ },
+ py::arg("type"), py::arg("value"),
+ "Gets an uniqued integer attribute associated to a type");
+ c.def_property_readonly(
+ "value",
+ [](PyIntegerAttribute &self) {
+ return mlirIntegerAttrGetValueInt(self.attr);
+ },
+ "Returns the value of the integer attribute");
+ }
+};
+
+/// Bool Attribute subclass - BoolAttr.
+class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
+ static constexpr const char *pyClassName = "BoolAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](PyMlirContext &context, bool value) {
+ MlirAttribute attr = mlirBoolAttrGet(context.get(), value);
+ return PyBoolAttribute(context.getRef(), attr);
+ },
+ py::arg("context"), py::arg("value"), "Gets an uniqued bool attribute");
+ c.def_property_readonly(
+ "value",
+ [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self.attr); },
+ "Returns the value of the bool attribute");
+ }
+};
+
class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
"The underlying generic attribute of the NamedAttribute binding");
// Standard attribute bindings.
+ PyFloatAttribute::bind(m);
+ PyIntegerAttribute::bind(m);
+ PyBoolAttribute::bind(m);
PyStringAttribute::bind(m);
// Mapping of Type.
run(testStandardAttrCasts)
+# CHECK-LABEL: TEST: testFloatAttr
+def testFloatAttr():
+ ctx = mlir.ir.Context()
+ fattr = mlir.ir.FloatAttr(ctx.parse_attr("42.0 : f32"))
+ # CHECK: fattr value: 42.0
+ print("fattr value:", fattr.value)
+
+ # Test factory methods.
+ loc = ctx.get_unknown_location()
+ # CHECK: default_get: 4.200000e+01 : f32
+ print("default_get:", mlir.ir.FloatAttr.get(
+ mlir.ir.F32Type(ctx), 42.0, loc))
+ # CHECK: f32_get: 4.200000e+01 : f32
+ print("f32_get:", mlir.ir.FloatAttr.get_f32(ctx, 42.0))
+ # CHECK: f64_get: 4.200000e+01 : f64
+ print("f64_get:", mlir.ir.FloatAttr.get_f64(ctx, 42.0))
+ try:
+ fattr_invalid = mlir.ir.FloatAttr.get(
+ mlir.ir.IntegerType.get_signless(ctx, 32), 42, loc)
+ except ValueError as e:
+ # CHECK: invalid 'Type(i32)' and expected floating point type.
+ print(e)
+ else:
+ print("Exception not produced")
+
+run(testFloatAttr)
+
+
+# CHECK-LABEL: TEST: testIntegerAttr
+def testIntegerAttr():
+ ctx = mlir.ir.Context()
+ iattr = mlir.ir.IntegerAttr(ctx.parse_attr("42"))
+ # CHECK: iattr value: 42
+ print("iattr value:", iattr.value)
+
+ # Test factory methods.
+ # CHECK: default_get: 42 : i32
+ print("default_get:", mlir.ir.IntegerAttr.get(
+ mlir.ir.IntegerType.get_signless(ctx, 32), 42))
+
+run(testIntegerAttr)
+
+
+# CHECK-LABEL: TEST: testBoolAttr
+def testBoolAttr():
+ ctx = mlir.ir.Context()
+ battr = mlir.ir.BoolAttr(ctx.parse_attr("true"))
+ # CHECK: iattr value: 1
+ print("iattr value:", battr.value)
+
+ # Test factory methods.
+ # CHECK: default_get: true
+ print("default_get:", mlir.ir.BoolAttr.get(ctx, True))
+
+run(testBoolAttr)
+
+
# CHECK-LABEL: TEST: testStringAttr
def testStringAttr():
ctx = mlir.ir.Context()