using namespace mlir::python;
using llvm::SmallVector;
-using llvm::StringRef;
using llvm::Twine;
namespace {
}
};
+template <typename T>
+static T pyTryCast(py::handle object) {
+ try {
+ return object.cast<T>();
+ } catch (py::cast_error &err) {
+ std::string msg =
+ std::string(
+ "Invalid attribute when attempting to create an ArrayAttribute (") +
+ err.what() + ")";
+ throw py::cast_error(msg);
+ } catch (py::reference_cast_error &err) {
+ std::string msg = std::string("Invalid attribute (None?) when attempting "
+ "to create an ArrayAttribute (") +
+ err.what() + ")";
+ throw py::cast_error(msg);
+ }
+}
+
class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
int nextIndex = 0;
};
+ PyAttribute getItem(intptr_t i) {
+ return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
+ }
+
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
SmallVector<MlirAttribute> mlirAttributes;
mlirAttributes.reserve(py::len(attributes));
for (auto attribute : attributes) {
- try {
- mlirAttributes.push_back(attribute.cast<PyAttribute>());
- } catch (py::cast_error &err) {
- std::string msg = std::string("Invalid attribute when attempting "
- "to create an ArrayAttribute (") +
- err.what() + ")";
- throw py::cast_error(msg);
- } catch (py::reference_cast_error &err) {
- // This exception seems thrown when the value is "None".
- std::string msg =
- std::string("Invalid attribute (None?) when attempting to "
- "create an ArrayAttribute (") +
- err.what() + ")";
- throw py::cast_error(msg);
- }
+ mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
}
MlirAttribute attr = mlirArrayAttrGet(
context->get(), mlirAttributes.size(), mlirAttributes.data());
[](PyArrayAttribute &arr, intptr_t i) {
if (i >= mlirArrayAttrGetNumElements(arr))
throw py::index_error("ArrayAttribute index out of range");
- return PyAttribute(arr.getContext(),
- mlirArrayAttrGetElement(arr, i));
+ return arr.getItem(i);
})
.def("__len__",
[](const PyArrayAttribute &arr) {
.def("__iter__", [](const PyArrayAttribute &arr) {
return PyArrayAttributeIterator(arr);
});
+ c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
+ std::vector<MlirAttribute> attributes;
+ intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
+ attributes.reserve(numOldElements + py::len(extras));
+ for (intptr_t i = 0; i < numOldElements; ++i)
+ attributes.push_back(arr.getItem(i));
+ for (py::handle attr : extras)
+ attributes.push_back(pyTryCast<PyAttribute>(attr));
+ MlirAttribute arrayAttr = mlirArrayAttrGet(
+ arr.getContext()->get(), attributes.size(), attributes.data());
+ return PyArrayAttribute(arr.getContext(), arrayAttr);
+ });
}
};
mlirNamedAttributes.data());
return PyDictAttribute(context->getRef(), attr);
},
- py::arg("value"), py::arg("context") = py::none(),
+ py::arg("value") = py::dict(), py::arg("context") = py::none(),
"Gets an uniqued dict attribute");
c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
MlirAttribute attr =
}
};
+/// Returns the list of types of the values held by container.
+template <typename Container>
+static std::vector<PyType> getValueTypes(Container &container,
+ PyMlirContextRef &context) {
+ std::vector<PyType> result;
+ result.reserve(container.getNumElements());
+ for (int i = 0, e = container.getNumElements(); i < e; ++i) {
+ result.push_back(
+ PyType(context, mlirValueGetType(container.getElement(i).get())));
+ }
+ return result;
+}
+
/// A list of block arguments. Internally, these are stored as consecutive
/// elements, random access is cheap. The argument list is associated with the
/// operation that contains the block (detached blocks are not allowed in
return PyBlockArgumentList(operation, block, startIndex, length, step);
}
+ static void bindDerived(ClassTy &c) {
+ c.def_property_readonly("types", [](PyBlockArgumentList &self) {
+ return getValueTypes(self, self.operation->getContext());
+ });
+ }
+
private:
PyOperationRef operation;
MlirBlock block;
return PyOpResultList(operation, startIndex, length, step);
}
+ static void bindDerived(ClassTy &c) {
+ c.def_property_readonly("types", [](PyOpResultList &self) {
+ return getValueTypes(self, self.operation->getContext());
+ });
+ }
+
private:
PyOperationRef operation;
};