[mlir] Add C bindings for StridedArrayAttr
authorDenys Shabalin <shabalin@google.com>
Wed, 28 Sep 2022 13:40:31 +0000 (13:40 +0000)
committerDenys Shabalin <shabalin@google.com>
Thu, 29 Sep 2022 09:52:57 +0000 (11:52 +0200)
Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D134808

mlir/include/mlir-c/BuiltinAttributes.h
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/test/CAPI/ir.c

index c75db95..b2e32f6 100644 (file)
@@ -535,6 +535,30 @@ mlirSparseElementsAttrGetIndices(MlirAttribute attr);
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirSparseElementsAttrGetValues(MlirAttribute attr);
 
+//===----------------------------------------------------------------------===//
+// Strided layout attribute.
+//===----------------------------------------------------------------------===//
+
+// Checks wheather the given attribute is a strided layout attribute.
+MLIR_CAPI_EXPORTED bool mlirAttributeIsAStridedLayout(MlirAttribute attr);
+
+// Creates a strided layout attribute from given strides and offset.
+MLIR_CAPI_EXPORTED MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx,
+                                                          int64_t offset,
+                                                          intptr_t numStrides,
+                                                          int64_t *strides);
+
+// Returns the offset in the given strided layout layout attribute.
+MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr);
+
+// Returns the number of strides in the given strided layout attribute.
+MLIR_CAPI_EXPORTED intptr_t
+mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr);
+
+// Returns the pos-th stride stored in the given strided layout attribute.
+MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr,
+                                                          intptr_t pos);
+
 #ifdef __cplusplus
 }
 #endif
index b02484a..1ae2a2b 100644 (file)
@@ -722,3 +722,30 @@ MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) {
 MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
   return wrap(unwrap(attr).cast<SparseElementsAttr>().getValues());
 }
+
+//===----------------------------------------------------------------------===//
+// Strided layout attribute.
+//===----------------------------------------------------------------------===//
+
+bool mlirAttributeIsAStridedLayout(MlirAttribute attr) {
+  return unwrap(attr).isa<StridedLayoutAttr>();
+}
+
+MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset,
+                                       intptr_t numStrides, int64_t *strides) {
+  return wrap(StridedLayoutAttr::get(unwrap(ctx), offset,
+                                     ArrayRef<int64_t>(strides, numStrides)));
+}
+
+int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr) {
+  return unwrap(attr).cast<StridedLayoutAttr>().getOffset();
+}
+
+intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) {
+  return static_cast<intptr_t>(
+      unwrap(attr).cast<StridedLayoutAttr>().getStrides().size());
+}
+
+int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<StridedLayoutAttr>().getStrides()[pos];
+}
index 6f1764b..e1d6133 100644 (file)
@@ -1220,6 +1220,20 @@ int printBuiltinAttributes(MlirContext ctx) {
       fabs(mlirDenseF64ArrayGetElement(doubleArray, 1) - 1.0) > 1E-6)
     return 21;
 
+  int64_t layoutStrides[3] = {5, 7, 13};
+  MlirAttribute stridedLayoutAttr =
+      mlirStridedLayoutAttrGet(ctx, 42, 3, &layoutStrides[0]);
+
+  // CHECK: strided<[5, 7, 13], offset: 42>
+  mlirAttributeDump(stridedLayoutAttr);
+
+  if (mlirStridedLayoutAttrGetOffset(stridedLayoutAttr) != 42 ||
+      mlirStridedLayoutAttrGetNumStrides(stridedLayoutAttr) != 3 ||
+      mlirStridedLayoutAttrGetStride(stridedLayoutAttr, 0) != 5 ||
+      mlirStridedLayoutAttrGetStride(stridedLayoutAttr, 1) != 7 ||
+      mlirStridedLayoutAttrGetStride(stridedLayoutAttr, 2) != 13)
+    return 22;
+
   return 0;
 }