[mlir][sparse] replace ad-hoc MemRef struct with CRunnerUtils definition
authorAart Bik <ajcbik@google.com>
Wed, 22 Sep 2021 05:56:00 +0000 (22:56 -0700)
committerAart Bik <ajcbik@google.com>
Wed, 22 Sep 2021 16:23:26 +0000 (09:23 -0700)
This revision removes the ad-hoc MemRefs that were needed using the old
ABI (when we still passed by value) and replaces them with the shared
StridedMemRef definitions of CRunnerUtils (possible now that we pass by
pointer). This avoids code duplication and makes sure we have a consistent
view of strided memory references in all our support libraries.

Reviewed By: jsetoain

Differential Revision: https://reviews.llvm.org/D110221

mlir/lib/ExecutionEngine/SparseUtils.cpp

index 7168946..a642a92 100644 (file)
@@ -502,32 +502,11 @@ char *getTensorFilename(uint64_t id) {
 // with sparse tensors, which are only visible as opaque pointers externally.
 // These methods should be used exclusively by MLIR compiler-generated code.
 //
-// Because we cannot use C++ templates with C linkage, some macro magic is used
-// to generate implementations for all required type combinations that can be
-// called from MLIR compiler-generated code.
+// Some macro magic is used to generate implementations for all required type
+// combinations that can be called from MLIR compiler-generated code.
 //
 //===----------------------------------------------------------------------===//
 
-#define TEMPLATE(NAME, TYPE)                                                   \
-  struct NAME {                                                                \
-    const TYPE *base;                                                          \
-    const TYPE *data;                                                          \
-    uint64_t off;                                                              \
-    uint64_t sizes[1];                                                         \
-    uint64_t strides[1];                                                       \
-  }
-
-TEMPLATE(MemRef1DU64, uint64_t);
-TEMPLATE(MemRef1DU32, uint32_t);
-TEMPLATE(MemRef1DU16, uint16_t);
-TEMPLATE(MemRef1DU8, uint8_t);
-TEMPLATE(MemRef1DI64, int64_t);
-TEMPLATE(MemRef1DI32, int32_t);
-TEMPLATE(MemRef1DI16, int16_t);
-TEMPLATE(MemRef1DI8, int8_t);
-TEMPLATE(MemRef1DF64, double);
-TEMPLATE(MemRef1DF32, float);
-
 #define CASE(p, i, v, P, I, V)                                                 \
   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
     SparseTensorCOO<V> *tensor = nullptr;                                      \
@@ -544,35 +523,37 @@ TEMPLATE(MemRef1DF32, float);
                                                          perm);                \
   }
 
-#define IMPL1(REF, NAME, TYPE, LIB)                                            \
-  void _mlir_ciface_##NAME(REF *ref, void *tensor) {                           \
+#define IMPL1(NAME, TYPE, LIB)                                                 \
+  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
     std::vector<TYPE> *v;                                                      \
     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
-    ref->base = ref->data = v->data();                                         \
-    ref->off = 0;                                                              \
+    ref->basePtr = ref->data = v->data();                                      \
+    ref->offset = 0;                                                           \
     ref->sizes[0] = v->size();                                                 \
     ref->strides[0] = 1;                                                       \
   }
 
-#define IMPL2(REF, NAME, TYPE, LIB)                                            \
-  void _mlir_ciface_##NAME(REF *ref, void *tensor, uint64_t d) {               \
+#define IMPL2(NAME, TYPE, LIB)                                                 \
+  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
+                           uint64_t d) {                                       \
     std::vector<TYPE> *v;                                                      \
     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
-    ref->base = ref->data = v->data();                                         \
-    ref->off = 0;                                                              \
+    ref->basePtr = ref->data = v->data();                                      \
+    ref->offset = 0;                                                           \
     ref->sizes[0] = v->size();                                                 \
     ref->strides[0] = 1;                                                       \
   }
 
 #define IMPL3(NAME, TYPE)                                                      \
-  void *_mlir_ciface_##NAME(void *tensor, TYPE value, MemRef1DU64 *iref,       \
-                            MemRef1DU64 *pref) {                               \
+  void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
+                            StridedMemRefType<uint64_t, 1> *iref,              \
+                            StridedMemRefType<uint64_t, 1> *pref) {            \
     if (!value)                                                                \
       return tensor;                                                           \
     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
     assert(iref->sizes[0] == pref->sizes[0]);                                  \
-    const uint64_t *indx = iref->data + iref->off;                             \
-    const uint64_t *perm = pref->data + pref->off;                             \
+    const uint64_t *indx = iref->data + iref->offset;                          \
+    const uint64_t *perm = pref->data + pref->offset;                          \
     uint64_t isize = iref->sizes[0];                                           \
     std::vector<uint64_t> indices(isize);                                      \
     for (uint64_t r = 0; r < isize; r++)                                       \
@@ -599,16 +580,18 @@ enum PrimaryTypeEnum : uint64_t {
 ///  1 : ptr contains coordinate scheme to assign to new storage
 ///  2 : returns empty coordinate scheme to fill (call back 1 to setup)
 ///  3 : returns coordinate scheme from storage in ptr (call back 1 to convert)
-void *_mlir_ciface_newSparseTensor(MemRef1DU8 *aref, // NOLINT
-                                   MemRef1DU64 *sref, MemRef1DU64 *pref,
-                                   uint64_t ptrTp, uint64_t indTp,
-                                   uint64_t valTp, uint32_t action, void *ptr) {
+void *
+_mlir_ciface_newSparseTensor(StridedMemRefType<uint8_t, 1> *aref, // NOLINT
+                             StridedMemRefType<uint64_t, 1> *sref,
+                             StridedMemRefType<uint64_t, 1> *pref,
+                             uint64_t ptrTp, uint64_t indTp, uint64_t valTp,
+                             uint32_t action, void *ptr) {
   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
          pref->strides[0] == 1);
   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
-  const uint8_t *sparsity = aref->data + aref->off;
-  const uint64_t *sizes = sref->data + sref->off;
-  const uint64_t *perm = pref->data + pref->off;
+  const uint8_t *sparsity = aref->data + aref->offset;
+  const uint64_t *sizes = sref->data + sref->offset;
+  const uint64_t *perm = pref->data + pref->offset;
   uint64_t size = aref->sizes[0];
 
   // Double matrices with all combinations of overhead storage.
@@ -668,22 +651,22 @@ void *_mlir_ciface_newSparseTensor(MemRef1DU8 *aref, // NOLINT
 }
 
 /// Methods that provide direct access to pointers, indices, and values.
-IMPL2(MemRef1DU64, sparsePointers, uint64_t, getPointers)
-IMPL2(MemRef1DU64, sparsePointers64, uint64_t, getPointers)
-IMPL2(MemRef1DU32, sparsePointers32, uint32_t, getPointers)
-IMPL2(MemRef1DU16, sparsePointers16, uint16_t, getPointers)
-IMPL2(MemRef1DU8, sparsePointers8, uint8_t, getPointers)
-IMPL2(MemRef1DU64, sparseIndices, uint64_t, getIndices)
-IMPL2(MemRef1DU64, sparseIndices64, uint64_t, getIndices)
-IMPL2(MemRef1DU32, sparseIndices32, uint32_t, getIndices)
-IMPL2(MemRef1DU16, sparseIndices16, uint16_t, getIndices)
-IMPL2(MemRef1DU8, sparseIndices8, uint8_t, getIndices)
-IMPL1(MemRef1DF64, sparseValuesF64, double, getValues)
-IMPL1(MemRef1DF32, sparseValuesF32, float, getValues)
-IMPL1(MemRef1DI64, sparseValuesI64, int64_t, getValues)
-IMPL1(MemRef1DI32, sparseValuesI32, int32_t, getValues)
-IMPL1(MemRef1DI16, sparseValuesI16, int16_t, getValues)
-IMPL1(MemRef1DI8, sparseValuesI8, int8_t, getValues)
+IMPL2(sparsePointers, uint64_t, getPointers)
+IMPL2(sparsePointers64, uint64_t, getPointers)
+IMPL2(sparsePointers32, uint32_t, getPointers)
+IMPL2(sparsePointers16, uint16_t, getPointers)
+IMPL2(sparsePointers8, uint8_t, getPointers)
+IMPL2(sparseIndices, uint64_t, getIndices)
+IMPL2(sparseIndices64, uint64_t, getIndices)
+IMPL2(sparseIndices32, uint32_t, getIndices)
+IMPL2(sparseIndices16, uint16_t, getIndices)
+IMPL2(sparseIndices8, uint8_t, getIndices)
+IMPL1(sparseValuesF64, double, getValues)
+IMPL1(sparseValuesF32, float, getValues)
+IMPL1(sparseValuesI64, int64_t, getValues)
+IMPL1(sparseValuesI32, int32_t, getValues)
+IMPL1(sparseValuesI16, int16_t, getValues)
+IMPL1(sparseValuesI8, int8_t, getValues)
 
 /// Helper to add value to coordinate scheme, one per value type.
 IMPL3(addEltF64, double)