MLIR_CAPI_EXPORTED bool mlirAttributeIsAStridedLayout(MlirAttribute attr);
// Creates a strided layout attribute from given strides and offset.
-MLIR_CAPI_EXPORTED MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx,
- int64_t offset,
- intptr_t numStrides,
- int64_t *strides);
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, intptr_t numStrides,
+ const int64_t *strides);
// Returns the offset in the given strided layout layout attribute.
MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr);
}
};
+/// Strided layout attribute subclass.
+class PyStridedLayoutAttribute
+ : public PyConcreteAttribute<PyStridedLayoutAttribute> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
+ static constexpr const char *pyClassName = "StridedLayoutAttr";
+ using PyConcreteAttribute::PyConcreteAttribute;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](int64_t offset, const std::vector<int64_t> strides,
+ DefaultingPyMlirContext ctx) {
+ MlirAttribute attr = mlirStridedLayoutAttrGet(
+ ctx->get(), offset, strides.size(), strides.data());
+ return PyStridedLayoutAttribute(ctx->getRef(), attr);
+ },
+ py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
+ "Gets a strided layout attribute.");
+ c.def_property_readonly(
+ "offset",
+ [](PyStridedLayoutAttribute &self) {
+ return mlirStridedLayoutAttrGetOffset(self);
+ },
+ "Returns the value of the float point attribute");
+ c.def_property_readonly(
+ "strides",
+ [](PyStridedLayoutAttribute &self) {
+ intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
+ std::vector<int64_t> strides(size);
+ for (intptr_t i = 0; i < size; i++) {
+ strides[i] = mlirStridedLayoutAttrGetStride(self, i);
+ }
+ return strides;
+ },
+ "Returns the value of the float point attribute");
+ }
+};
+
} // namespace
void mlir::python::populateIRAttributes(py::module &m) {
PyStringAttribute::bind(m);
PyTypeAttribute::bind(m);
PyUnitAttribute::bind(m);
+
+ PyStridedLayoutAttribute::bind(m);
}
},
"Returns the shape of the ranked shaped type as a list of integers.");
c.def_static(
- "_get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
+ "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
"Returns the value used to indicate dynamic dimensions in shaped "
"types.");
c.def_static(
- "_get_dynamic_stride_or_offset",
+ "get_dynamic_stride_or_offset",
[]() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
"Returns the value used to indicate dynamic strides or offsets in "
"shaped types.");
}
MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset,
- intptr_t numStrides, int64_t *strides) {
+ intptr_t numStrides,
+ const int64_t *strides) {
return wrap(StridedLayoutAttr::get(unwrap(ctx), offset,
ArrayRef<int64_t>(strides, numStrides)));
}
static_split_point = split_point
dynamic_split_point = None
else:
- static_split_point = _get_int64_attr(ShapedType._get_dynamic_size())
+ static_split_point = _get_int64_attr(ShapedType.get_dynamic_size())
dynamic_split_point = _get_op_result_or_value(split_point)
pdl_operation_type = pdl.OperationType.get()
static_sizes.append(size)
else:
static_sizes.append(
- IntegerAttr.get(i64_type, ShapedType._get_dynamic_size()))
+ IntegerAttr.get(i64_type, ShapedType.get_dynamic_size()))
dynamic_sizes.append(_get_op_result_or_value(size))
sizes_attr = ArrayAttr.get(static_sizes)
array = array + [StringAttr.get("c")]
# CHECK: concat: ["a", "b", "c"]
print("concat: ", array)
+
+
+# CHECK-LABEL: TEST: testStridedLayoutAttr
+@run
+def testStridedLayoutAttr():
+ with Context():
+ attr = StridedLayoutAttr.get(42, [5, 7, 13])
+ # CHECK: strided<[5, 7, 13], offset: 42>
+ print(attr)
+ # CHECK: 42
+ print(attr.offset)
+ # CHECK: 3
+ print(len(attr.strides))
+ # CHECK: 5
+ print(attr.strides[0])
+ # CHECK: 7
+ print(attr.strides[1])
+ # CHECK: 13
+ print(attr.strides[2])
print("dialect namespace:", opaque.dialect_namespace)
# CHECK: data: type
print("data:", opaque.data)
+
+
+# CHECK-LABEL: TEST: testShapedTypeConstants
+# Tests that ShapedType exposes magic value constants.
+@run
+def testShapedTypeConstants():
+ # CHECK: <class 'int'>
+ print(type(ShapedType.get_dynamic_size()))
+ # CHECK: <class 'int'>
+ print(type(ShapedType.get_dynamic_stride_or_offset()))