step),
affineMap(map) {}
- intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyAffineMapExprList, PyAffineExpr>;
+
+ intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); }
- PyAffineExpr getElement(intptr_t pos) {
+ PyAffineExpr getRawElement(intptr_t pos) {
return PyAffineExpr(affineMap.getContext(),
mlirAffineMapGetResult(affineMap, pos));
}
return PyAffineMapExprList(affineMap, startIndex, length, step);
}
-private:
PyAffineMap affineMap;
};
} // namespace
step),
set(set) {}
- intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint>;
+
+ intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); }
- PyIntegerSetConstraint getElement(intptr_t pos) {
+ PyIntegerSetConstraint getRawElement(intptr_t pos) {
return PyIntegerSetConstraint(set, pos);
}
return PyIntegerSetConstraintList(set, startIndex, length, step);
}
-private:
PyIntegerSet set;
};
} // namespace
static std::vector<PyType> getValueTypes(Container &container,
PyMlirContextRef &context) {
std::vector<PyType> result;
- result.reserve(container.getNumElements());
- for (int i = 0, e = container.getNumElements(); i < e; ++i) {
+ result.reserve(container.size());
+ for (int i = 0, e = container.size(); i < e; ++i) {
result.push_back(
PyType(context, mlirValueGetType(container.getElement(i).get())));
}
step),
operation(std::move(operation)), block(block) {}
+ static void bindDerived(ClassTy &c) {
+ c.def_property_readonly("types", [](PyBlockArgumentList &self) {
+ return getValueTypes(self, self.operation->getContext());
+ });
+ }
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
+
/// Returns the number of arguments in the list.
- intptr_t getNumElements() {
+ intptr_t getRawNumElements() {
operation->checkValid();
return mlirBlockGetNumArguments(block);
}
- /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
- PyBlockArgument getElement(intptr_t pos) {
+ /// Returns `pos`-the element in the list.
+ PyBlockArgument getRawElement(intptr_t pos) {
MlirValue argument = mlirBlockGetArgument(block, pos);
return PyBlockArgument(operation, argument);
}
return PyBlockArgumentList(operation, block, startIndex, length, step);
}
- static void bindDerived(ClassTy &c) {
- c.def_property_readonly("types", [](PyBlockArgumentList &self) {
- return getValueTypes(self, self.operation->getContext());
- });
- }
-
-private:
PyOperationRef operation;
MlirBlock block;
};
step),
operation(operation) {}
- intptr_t getNumElements() {
+ void dunderSetItem(intptr_t index, PyValue value) {
+ index = wrapIndex(index);
+ mlirOperationSetOperand(operation->get(), index, value.get());
+ }
+
+ static void bindDerived(ClassTy &c) {
+ c.def("__setitem__", &PyOpOperandList::dunderSetItem);
+ }
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyOpOperandList, PyValue>;
+
+ intptr_t getRawNumElements() {
operation->checkValid();
return mlirOperationGetNumOperands(operation->get());
}
- PyValue getElement(intptr_t pos) {
+ PyValue getRawElement(intptr_t pos) {
MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
MlirOperation owner;
if (mlirValueIsAOpResult(operand))
return PyOpOperandList(operation, startIndex, length, step);
}
- void dunderSetItem(intptr_t index, PyValue value) {
- index = wrapIndex(index);
- mlirOperationSetOperand(operation->get(), index, value.get());
- }
-
- static void bindDerived(ClassTy &c) {
- c.def("__setitem__", &PyOpOperandList::dunderSetItem);
- }
-
-private:
PyOperationRef operation;
};
step),
operation(operation) {}
- intptr_t getNumElements() {
+ static void bindDerived(ClassTy &c) {
+ c.def_property_readonly("types", [](PyOpResultList &self) {
+ return getValueTypes(self, self.operation->getContext());
+ });
+ }
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyOpResultList, PyOpResult>;
+
+ intptr_t getRawNumElements() {
operation->checkValid();
return mlirOperationGetNumResults(operation->get());
}
- PyOpResult getElement(intptr_t index) {
+ PyOpResult getRawElement(intptr_t index) {
PyValue value(operation, mlirOperationGetResult(operation->get(), index));
return PyOpResult(value);
}
return PyOpResultList(operation, startIndex, length, step);
}
- static void bindDerived(ClassTy &c) {
- c.def_property_readonly("types", [](PyOpResultList &self) {
- return getValueTypes(self, self.operation->getContext());
- });
- }
-
-private:
PyOperationRef operation;
};
/// A derived class must provide the following:
/// - a `static const char *pyClassName ` field containing the name of the
/// Python class to bind;
-/// - an instance method `intptr_t getNumElements()` that returns the number
+/// - an instance method `intptr_t getRawNumElements()` that returns the
+/// number
/// of elements in the backing container (NOT that of the slice);
-/// - an instance method `ElementTy getElement(intptr_t)` that returns a
-/// single element at the given index.
+/// - an instance method `ElementTy getRawElement(intptr_t)` that returns a
+/// single element at the given linear index (NOT slice index);
/// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that
/// constructs a new instance of the derived pseudo-container with the
/// given slice parameters (to be forwarded to the Sliceable constructor).
///
-/// The getNumElements() and getElement(intptr_t) callbacks must not throw.
+/// The getRawNumElements() and getRawElement(intptr_t) callbacks must not
+/// throw.
///
/// A derived class may additionally define:
/// - a `static void bindDerived(ClassTy &)` method to bind additional methods
protected:
using ClassTy = pybind11::class_<Derived>;
- // Transforms `index` into a legal value to access the underlying sequence.
- // Returns <0 on failure.
+ /// Transforms `index` into a legal value to access the underlying sequence.
+ /// Returns <0 on failure.
intptr_t wrapIndex(intptr_t index) {
if (index < 0)
index = length + index;
return index;
}
+ /// Computes the linear index given the current slice properties.
+ intptr_t linearizeIndex(intptr_t index) {
+ intptr_t linearIndex = index * step + startIndex;
+ assert(linearIndex >= 0 &&
+ linearIndex < static_cast<Derived *>(this)->getRawNumElements() &&
+ "linear index out of bounds, the slice is ill-formed");
+ return linearIndex;
+ }
+
/// Returns the element at the given slice index. Supports negative indices
/// by taking elements in inverse order. Returns a nullptr object if out
/// of bounds.
return {};
}
- // Compute the linear index given the current slice properties.
- int linearIndex = index * step + startIndex;
- assert(linearIndex >= 0 &&
- linearIndex < static_cast<Derived *>(this)->getNumElements() &&
- "linear index out of bounds, the slice is ill-formed");
return pybind11::cast(
- static_cast<Derived *>(this)->getElement(linearIndex));
+ static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
}
/// Returns a new instance of the pseudo-container restricted to the given
assert(length >= 0 && "expected non-negative slice length");
}
+ /// Returns the `index`-th element in the slice, supports negative indices.
+ /// Throws if the index is out of bounds.
+ ElementTy getElement(intptr_t index) {
+ // Negative indices mean we count from the end.
+ index = wrapIndex(index);
+ if (index < 0) {
+ throw pybind11::index_error("index out of range");
+ }
+
+ return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
+ }
+
+ /// Returns the size of slice.
+ intptr_t size() { return length; }
+
/// Returns a new vector (mapped to Python list) containing elements from two
/// slices. The new vector is necessary because slices may not be contiguous
/// or even come from the same original sequence.
elements.push_back(static_cast<Derived *>(this)->getElement(i));
}
for (intptr_t i = 0; i < other.length; ++i) {
- elements.push_back(static_cast<Derived *>(this)->getElement(i));
+ elements.push_back(static_cast<Derived *>(&other)->getElement(i));
}
return elements;
}
for t in entry_block.arguments.types:
print("Type: ", t)
+ # Check that slicing and type access compose.
+ # CHECK: Sliced type: i16
+ # CHECK: Sliced type: i24
+ for t in entry_block.arguments[1:].types:
+ print("Sliced type: ", t)
+
+ # Check that slice addition works as expected.
+ # CHECK: Argument 2, type i24
+ # CHECK: Argument 0, type i8
+ restructured = entry_block.arguments[-1:] + entry_block.arguments[:1]
+ for arg in restructured:
+ print(f"Argument {arg.arg_number}, type {arg.type}")
+
# CHECK-LABEL: TEST: testOperationOperands
@run