/// elements, random access is cheap. The argument list is associated with the
/// operation that contains the block (detached blocks are not allowed in
/// Python bindings) and extends its lifetime.
-class PyBlockArgumentList {
+class PyBlockArgumentList
+ : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
public:
- PyBlockArgumentList(PyOperationRef operation, MlirBlock block)
- : operation(std::move(operation)), block(block) {}
+ static constexpr const char *pyClassName = "BlockArgumentList";
- /// Returns the length of the block argument list.
- intptr_t dunderLen() {
+ PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
+ intptr_t startIndex = 0, intptr_t length = -1,
+ intptr_t step = 1)
+ : Sliceable(startIndex,
+ length == -1 ? mlirBlockGetNumArguments(block) : length,
+ step),
+ operation(std::move(operation)), block(block) {}
+
+ /// Returns the number of arguments in the list.
+ intptr_t getNumElements() {
operation->checkValid();
return mlirBlockGetNumArguments(block);
}
- /// Returns `index`-th element of the block argument list.
- PyBlockArgument dunderGetItem(intptr_t index) {
- if (index < 0 || index >= dunderLen()) {
- throw SetPyError(PyExc_IndexError,
- "attempt to access out of bounds region");
- }
- PyValue value(operation, mlirBlockGetArgument(block, index));
- return PyBlockArgument(value);
+ /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
+ PyBlockArgument getElement(intptr_t pos) {
+ MlirValue argument = mlirBlockGetArgument(block, pos);
+ return PyBlockArgument(operation, argument);
}
- /// Defines a Python class in the bindings.
- static void bind(py::module &m) {
- py::class_<PyBlockArgumentList>(m, "BlockArgumentList", py::module_local())
- .def("__len__", &PyBlockArgumentList::dunderLen)
- .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
+ /// Returns a sublist of this list.
+ PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
+ intptr_t step) {
+ return PyBlockArgumentList(operation, block, startIndex, length, step);
}
private:
sliceLength, step * extraStep);
}
+ /// Returns a new vector (mapped to Python list) containing elements from two
+ /// slices. The new vector is necessary because slices may not be contiguous
+ /// or even come from the same original sequence.
+ std::vector<ElementTy> dunderAdd(Derived &other) {
+ std::vector<ElementTy> elements;
+ elements.reserve(length + other.length);
+ for (intptr_t i = 0; i < length; ++i) {
+ elements.push_back(dunderGetItem(i));
+ }
+ for (intptr_t i = 0; i < other.length; ++i) {
+ elements.push_back(other.dunderGetItem(i));
+ }
+ return elements;
+ }
+
/// Binds the indexing and length methods in the Python class.
static void bind(pybind11::module &m) {
auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
pybind11::module_local())
.def("__len__", &Sliceable::dunderLen)
.def("__getitem__", &Sliceable::dunderGetItem)
- .def("__getitem__", &Sliceable::dunderGetItemSlice);
+ .def("__getitem__", &Sliceable::dunderGetItemSlice)
+ .def("__add__", &Sliceable::dunderAdd);
Derived::bindDerived(clazz);
}
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
+ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
+RESULT_ATTRIBUTE_NAME = "res_attrs"
class ModuleOp:
"""Specialization for the module op class."""
self.body.blocks.append(*self.type.inputs)
return self.body.blocks[0]
+ @property
+ def arg_attrs(self):
+ return self.attributes[ARGUMENT_ATTRIBUTE_NAME]
+
+ @arg_attrs.setter
+ def arg_attrs(self, attribute: ArrayAttr):
+ self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+
+ @property
+ def arguments(self):
+ return self.entry_block.arguments
+
+ @property
+ def result_attrs(self):
+ return self.attributes[RESULT_ATTRIBUTE_NAME]
+
+ @result_attrs.setter
+ def result_attrs(self, attribute: ArrayAttr):
+ self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
+
@classmethod
def from_py_func(FuncOp,
*inputs: Type,
# CHECK: return %arg0 : tensor<2x3x4xf32>
# CHECK: }
print(m)
+
+
+# CHECK-LABEL: TEST: testFuncArgumentAccess
+@run
+def testFuncArgumentAccess():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ f64 = F64Type.get()
+ with InsertionPoint(module.body):
+ func = builtin.FuncOp("some_func", ([f32, f32], [f64, f64]))
+ with InsertionPoint(func.add_entry_block()):
+ std.ReturnOp(func.arguments)
+ func.arg_attrs = ArrayAttr.get([
+ DictAttr.get({
+ "foo": StringAttr.get("bar"),
+ "baz": UnitAttr.get()
+ }),
+ DictAttr.get({"qux": ArrayAttr.get([])})
+ ])
+ func.result_attrs = ArrayAttr.get([
+ DictAttr.get({"res1": FloatAttr.get(f32, 42.0)}),
+ DictAttr.get({"res2": FloatAttr.get(f64, 256.0)})
+ ])
+
+ # CHECK: [{baz, foo = "bar"}, {qux = []}]
+ print(func.arg_attrs)
+
+ # CHECK: [{res1 = 4.200000e+01 : f32}, {res2 = 2.560000e+02 : f64}]
+ print(func.result_attrs)
+
+ # CHECK: func @some_func(
+ # CHECK: %[[ARG0:.*]]: f32 {baz, foo = "bar"},
+ # CHECK: %[[ARG1:.*]]: f32 {qux = []}) ->
+ # CHECK: f64 {res1 = 4.200000e+01 : f32},
+ # CHECK: f64 {res2 = 2.560000e+02 : f64})
+ # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
+ print(module)
for arg in entry_block.arguments:
print(f"Argument {arg.arg_number}, type {arg.type}")
+ # Check that slicing works for block argument lists.
+ # CHECK: Argument 1, type i16
+ # CHECK: Argument 2, type i24
+ for arg in entry_block.arguments[1:]:
+ print(f"Argument {arg.arg_number}, type {arg.type}")
+
+ # Check that we can concatenate slices of argument lists.
+ # CHECK: Length: 4
+ print("Length: ",
+ len(entry_block.arguments[:2] + entry_block.arguments[1:]))
+
run(testBlockArgumentList)
ctx = Context()
with Location.unknown(ctx):
try:
- Operation.create("builtin.module", attributes={None:StringAttr.get("name")})
+ Operation.create(
+ "builtin.module", attributes={None: StringAttr.get("name")})
except Exception as e:
# CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
print(e)
try:
- Operation.create("builtin.module", attributes={42:StringAttr.get("name")})
+ Operation.create(
+ "builtin.module", attributes={42: StringAttr.get("name")})
except Exception as e:
# CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
print(e)
try:
- Operation.create("builtin.module", attributes={"some_key":ctx})
+ Operation.create("builtin.module", attributes={"some_key": ctx})
except Exception as e:
# CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
print(e)
try:
- Operation.create("builtin.module", attributes={"some_key":None})
+ Operation.create("builtin.module", attributes={"some_key": None})
except Exception as e:
# CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
print(e)