return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
}
+ intptr_t dunderLen() { return mlirElementsAttrGetNumElements(attr); }
+
static void bindDerived(ClassTy &c) {
- c.def_static("get", PyDenseElementsAttribute::getFromBuffer,
- py::arg("array"), py::arg("signless") = true,
- py::arg("context") = py::none(),
- "Gets from a buffer or ndarray")
+ c.def("__len__", &PyDenseElementsAttribute::dunderLen)
+ .def_static("get", PyDenseElementsAttribute::getFromBuffer,
+ py::arg("array"), py::arg("signless") = true,
+ py::arg("context") = py::none(),
+ "Gets from a buffer or ndarray")
.def_static("get_splat", PyDenseElementsAttribute::getSplat,
py::arg("shaped_type"), py::arg("element_attr"),
"Gets a DenseElementsAttr where all values are the same")
}
};
+/// Refinement of the PyDenseElementsAttribute for attributes containing integer
+/// (and boolean) values. Supports element access.
+class PyDenseIntElementsAttribute
+ : public PyConcreteAttribute<PyDenseIntElementsAttribute,
+ PyDenseElementsAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
+ static constexpr const char *pyClassName = "DenseIntElementsAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+
+ /// Returns the element at the given linear position. Asserts if the index is
+ /// out of range.
+ py::int_ dunderGetItem(intptr_t pos) {
+ if (pos < 0 || pos >= dunderLen()) {
+ throw SetPyError(PyExc_IndexError,
+ "attempt to access out of bounds element");
+ }
+
+ MlirType type = mlirAttributeGetType(attr);
+ type = mlirShapedTypeGetElementType(type);
+ assert(mlirTypeIsAInteger(type) &&
+ "expected integer element type in dense int elements attribute");
+ // Dispatch element extraction to an appropriate C function based on the
+ // elemental type of the attribute. py::int_ is implicitly constructible
+ // from any C++ integral type and handles bitwidth correctly.
+ // TODO: consider caching the type properties in the constructor to avoid
+ // querying them on each element access.
+ unsigned width = mlirIntegerTypeGetWidth(type);
+ bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
+ if (isUnsigned) {
+ if (width == 1) {
+ return mlirDenseElementsAttrGetBoolValue(attr, pos);
+ }
+ if (width == 32) {
+ return mlirDenseElementsAttrGetUInt32Value(attr, pos);
+ }
+ if (width == 64) {
+ return mlirDenseElementsAttrGetUInt64Value(attr, pos);
+ }
+ } else {
+ if (width == 1) {
+ return mlirDenseElementsAttrGetBoolValue(attr, pos);
+ }
+ if (width == 32) {
+ return mlirDenseElementsAttrGetInt32Value(attr, pos);
+ }
+ if (width == 64) {
+ return mlirDenseElementsAttrGetInt64Value(attr, pos);
+ }
+ }
+ throw SetPyError(PyExc_TypeError, "Unsupported integer type");
+ }
+
+ static void bindDerived(ClassTy &c) {
+ c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
+ }
+};
+
+/// Refinement of PyDenseElementsAttribute for attributes containing
+/// floating-point values. Supports element access.
+class PyDenseFPElementsAttribute
+ : public PyConcreteAttribute<PyDenseFPElementsAttribute,
+ PyDenseElementsAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
+ static constexpr const char *pyClassName = "DenseFPElementsAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+
+ py::float_ dunderGetItem(intptr_t pos) {
+ if (pos < 0 || pos >= dunderLen()) {
+ throw SetPyError(PyExc_IndexError,
+ "attempt to access out of bounds element");
+ }
+
+ MlirType type = mlirAttributeGetType(attr);
+ type = mlirShapedTypeGetElementType(type);
+ // Dispatch element extraction to an appropriate C function based on the
+ // elemental type of the attribute. py::float_ is implicitly constructible
+ // from float and double.
+ // TODO: consider caching the type properties in the constructor to avoid
+ // querying them on each element access.
+ if (mlirTypeIsAF32(type)) {
+ return mlirDenseElementsAttrGetFloatValue(attr, pos);
+ }
+ if (mlirTypeIsAF64(type)) {
+ return mlirDenseElementsAttrGetDoubleValue(attr, pos);
+ }
+ throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
+ }
+
+ static void bindDerived(ClassTy &c) {
+ c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
+ }
+};
+
} // namespace
//------------------------------------------------------------------------------
PyBoolAttribute::bind(m);
PyStringAttribute::bind(m);
PyDenseElementsAttribute::bind(m);
+ PyDenseIntElementsAttribute::bind(m);
+ PyDenseFPElementsAttribute::bind(m);
//----------------------------------------------------------------------------
// Mapping of PyType.
print("named:", named)
run(testNamedAttr)
+
+
+# CHECK-LABEL: TEST: testDenseIntAttr
+def testDenseIntAttr():
+ with Context():
+ raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
+ # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
+ print("attr:", raw)
+
+ a = DenseIntElementsAttr(raw)
+ assert len(a) == 6
+
+ # CHECK: 0 1 2 3 4 5
+ for value in a:
+ print(value, end=" ")
+ print()
+
+ # CHECK: i32
+ print(ShapedType(a.type).element_type)
+
+ raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
+ # CHECK: attr: dense<[true, false, true, false]>
+ print("attr:", raw)
+
+ a = DenseIntElementsAttr(raw)
+ assert len(a) == 4
+
+ # CHECK: 1 0 1 0
+ for value in a:
+ print(value, end=" ")
+ print()
+
+ # CHECK: i1
+ print(ShapedType(a.type).element_type)
+
+
+run(testDenseIntAttr)
+
+
+# CHECK-LABEL: TEST: testDenseFPAttr
+def testDenseFPAttr():
+ with Context():
+ raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
+ # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
+
+ print("attr:", raw)
+
+ a = DenseFPElementsAttr(raw)
+ assert len(a) == 4
+
+ # CHECK: 0.0 1.0 2.0 3.0
+ for value in a:
+ print(value, end=" ")
+ print()
+
+ # CHECK: f32
+ print(ShapedType(a.type).element_type)
+
+
+run(testDenseFPAttr)