[mlir] Add fully dynamic constructor to StridedLayoutAttr bindings
authorDenys Shabalin <shabalin@google.com>
Tue, 4 Oct 2022 10:58:38 +0000 (10:58 +0000)
committerDenys Shabalin <shabalin@google.com>
Tue, 4 Oct 2022 13:02:55 +0000 (13:02 +0000)
Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D135139

mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/test/python/ir/attributes.py

index e62f155..0c8c9b8 100644 (file)
@@ -1050,6 +1050,19 @@ public:
         },
         py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
         "Gets a strided layout attribute.");
+    c.def_static(
+        "get_fully_dynamic",
+        [](int64_t rank, DefaultingPyMlirContext ctx) {
+          auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
+          std::vector<int64_t> strides(rank);
+          std::fill(strides.begin(), strides.end(), dynamic);
+          MlirAttribute attr = mlirStridedLayoutAttrGet(
+              ctx->get(), dynamic, strides.size(), strides.data());
+          return PyStridedLayoutAttribute(ctx->getRef(), attr);
+        },
+        py::arg("rank"), py::arg("context") = py::none(),
+        "Gets a strided layout attribute with dynamic offset and strides of a "
+        "given rank.");
     c.def_property_readonly(
         "offset",
         [](PyStridedLayoutAttribute &self) {
index e0960e3..684d52c 100644 (file)
@@ -542,3 +542,14 @@ def testStridedLayoutAttr():
     print(attr.strides[1])
     # CHECK: 13
     print(attr.strides[2])
+
+    attr = StridedLayoutAttr.get_fully_dynamic(3)
+    dynamic = ShapedType.get_dynamic_stride_or_offset()
+    # CHECK: strided<[?, ?, ?], offset: ?>
+    print(attr)
+    # CHECK: offset is dynamic: True
+    print(f"offset is dynamic: {attr.offset == dynamic}")
+    # CHECK: rank: 3
+    print(f"rank: {len(attr.strides)}")
+    # CHECK: strides are dynamic: [True, True, True]
+    print(f"strides are dynamic: {[s == dynamic for s in attr.strides]}")