[mlir] Add Python bindings for StridedLayoutAttr
authorDenys Shabalin <shabalin@google.com>
Thu, 29 Sep 2022 09:41:42 +0000 (09:41 +0000)
committerDenys Shabalin <shabalin@google.com>
Thu, 29 Sep 2022 11:03:30 +0000 (11:03 +0000)
Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D134869

mlir/include/mlir-c/BuiltinAttributes.h
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/python/mlir/dialects/_structured_transform_ops_ext.py
mlir/test/python/ir/attributes.py
mlir/test/python/ir/builtin_types.py

index b2e32f6..79f2237 100644 (file)
@@ -543,10 +543,9 @@ mlirSparseElementsAttrGetValues(MlirAttribute attr);
 MLIR_CAPI_EXPORTED bool mlirAttributeIsAStridedLayout(MlirAttribute attr);
 
 // Creates a strided layout attribute from given strides and offset.
-MLIR_CAPI_EXPORTED MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx,
-                                                          int64_t offset,
-                                                          intptr_t numStrides,
-                                                          int64_t *strides);
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, intptr_t numStrides,
+                         const int64_t *strides);
 
 // Returns the offset in the given strided layout layout attribute.
 MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr);
index 8d8cea3..e62f155 100644 (file)
@@ -1031,6 +1031,45 @@ public:
   }
 };
 
+/// Strided layout attribute subclass.
+class PyStridedLayoutAttribute
+    : public PyConcreteAttribute<PyStridedLayoutAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
+  static constexpr const char *pyClassName = "StridedLayoutAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](int64_t offset, const std::vector<int64_t> strides,
+           DefaultingPyMlirContext ctx) {
+          MlirAttribute attr = mlirStridedLayoutAttrGet(
+              ctx->get(), offset, strides.size(), strides.data());
+          return PyStridedLayoutAttribute(ctx->getRef(), attr);
+        },
+        py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
+        "Gets a strided layout attribute.");
+    c.def_property_readonly(
+        "offset",
+        [](PyStridedLayoutAttribute &self) {
+          return mlirStridedLayoutAttrGetOffset(self);
+        },
+        "Returns the value of the float point attribute");
+    c.def_property_readonly(
+        "strides",
+        [](PyStridedLayoutAttribute &self) {
+          intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
+          std::vector<int64_t> strides(size);
+          for (intptr_t i = 0; i < size; i++) {
+            strides[i] = mlirStridedLayoutAttrGetStride(self, i);
+          }
+          return strides;
+        },
+        "Returns the value of the float point attribute");
+  }
+};
+
 } // namespace
 
 void mlir::python::populateIRAttributes(py::module &m) {
@@ -1065,4 +1104,6 @@ void mlir::python::populateIRAttributes(py::module &m) {
   PyStringAttribute::bind(m);
   PyTypeAttribute::bind(m);
   PyUnitAttribute::bind(m);
+
+  PyStridedLayoutAttribute::bind(m);
 }
index 153664d..379510c 100644 (file)
@@ -302,11 +302,11 @@ public:
         },
         "Returns the shape of the ranked shaped type as a list of integers.");
     c.def_static(
-        "_get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
+        "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
         "Returns the value used to indicate dynamic dimensions in shaped "
         "types.");
     c.def_static(
-        "_get_dynamic_stride_or_offset",
+        "get_dynamic_stride_or_offset",
         []() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
         "Returns the value used to indicate dynamic strides or offsets in "
         "shaped types.");
index 1ae2a2b..05ecb0f 100644 (file)
@@ -732,7 +732,8 @@ bool mlirAttributeIsAStridedLayout(MlirAttribute attr) {
 }
 
 MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset,
-                                       intptr_t numStrides, int64_t *strides) {
+                                       intptr_t numStrides,
+                                       const int64_t *strides) {
   return wrap(StridedLayoutAttr::get(unwrap(ctx), offset,
                                      ArrayRef<int64_t>(strides, numStrides)));
 }
index eddc384..527a865 100644 (file)
@@ -211,7 +211,7 @@ class SplitOp:
       static_split_point = split_point
       dynamic_split_point = None
     else:
-      static_split_point = _get_int64_attr(ShapedType._get_dynamic_size())
+      static_split_point = _get_int64_attr(ShapedType.get_dynamic_size())
       dynamic_split_point = _get_op_result_or_value(split_point)
 
     pdl_operation_type = pdl.OperationType.get()
@@ -255,7 +255,7 @@ class TileOp:
           static_sizes.append(size)
         else:
           static_sizes.append(
-              IntegerAttr.get(i64_type, ShapedType._get_dynamic_size()))
+              IntegerAttr.get(i64_type, ShapedType.get_dynamic_size()))
           dynamic_sizes.append(_get_op_result_or_value(size))
       sizes_attr = ArrayAttr.get(static_sizes)
 
index a958abf..e0960e3 100644 (file)
@@ -523,3 +523,22 @@ def testArrayAttr():
     array = array + [StringAttr.get("c")]
     # CHECK: concat: ["a", "b", "c"]
     print("concat: ", array)
+
+
+# CHECK-LABEL: TEST: testStridedLayoutAttr
+@run
+def testStridedLayoutAttr():
+  with Context():
+    attr = StridedLayoutAttr.get(42, [5, 7, 13])
+    # CHECK: strided<[5, 7, 13], offset: 42>
+    print(attr)
+    # CHECK: 42
+    print(attr.offset)
+    # CHECK: 3
+    print(len(attr.strides))
+    # CHECK: 5
+    print(attr.strides[0])
+    # CHECK: 7
+    print(attr.strides[1])
+    # CHECK: 13
+    print(attr.strides[2])
index 945ed7e..91c820f 100644 (file)
@@ -487,3 +487,13 @@ def testOpaqueType():
     print("dialect namespace:", opaque.dialect_namespace)
     # CHECK: data: type
     print("data:", opaque.data)
+
+
+# CHECK-LABEL: TEST: testShapedTypeConstants
+# Tests that ShapedType exposes magic value constants.
+@run
+def testShapedTypeConstants():
+  # CHECK: <class 'int'>
+  print(type(ShapedType.get_dynamic_size()))
+  # CHECK: <class 'int'>
+  print(type(ShapedType.get_dynamic_stride_or_offset()))