[mlir][sparse] add action enum, group enums together, few asserts
authorAart Bik <ajcbik@google.com>
Tue, 12 Oct 2021 00:11:13 +0000 (17:11 -0700)
committerAart Bik <ajcbik@google.com>
Tue, 12 Oct 2021 15:50:51 +0000 (08:50 -0700)
Some random changes that were hanging around in my workspace. Also,
a tiny step towards creating a header file for the sparse utils lib.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D111589

mlir/lib/ExecutionEngine/SparseUtils.cpp

index 77cff10..b30804e 100644 (file)
@@ -507,24 +507,41 @@ char *getTensorFilename(uint64_t id) {
 //
 //===----------------------------------------------------------------------===//
 
+enum OverheadTypeEnum : uint64_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
+
+enum PrimaryTypeEnum : uint64_t {
+  kF64 = 1,
+  kF32 = 2,
+  kI64 = 3,
+  kI32 = 4,
+  kI16 = 5,
+  kI8 = 6
+};
+
+enum Action : uint32_t { kFromFile = 0, kFromCOO = 1, kNewCOO = 2, kToCOO = 3 };
+
 #define CASE(p, i, v, P, I, V)                                                 \
   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
     SparseTensorCOO<V> *tensor = nullptr;                                      \
-    if (action == 0)                                                           \
+    if (action == kFromFile)                                                   \
       tensor =                                                                 \
           openSparseTensorCOO<V>(static_cast<char *>(ptr), size, sizes, perm); \
-    else if (action == 1)                                                      \
+    else if (action == kFromCOO)                                               \
       tensor = static_cast<SparseTensorCOO<V> *>(ptr);                         \
-    else if (action == 2)                                                      \
+    else if (action == kNewCOO)                                                \
       return SparseTensorCOO<V>::newSparseTensorCOO(size, sizes, perm);        \
-    else                                                                       \
+    else if (action == kToCOO)                                                 \
       return static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);    \
+    else                                                                       \
+      assert(0);                                                               \
     return SparseTensorStorage<P, I, V>::newSparseTensor(tensor, sparsity,     \
                                                          perm);                \
   }
 
 #define IMPL1(NAME, TYPE, LIB)                                                 \
   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
+    assert(ref);                                                               \
+    assert(tensor);                                                            \
     std::vector<TYPE> *v;                                                      \
     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
     ref->basePtr = ref->data = v->data();                                      \
@@ -536,6 +553,8 @@ char *getTensorFilename(uint64_t id) {
 #define IMPL2(NAME, TYPE, LIB)                                                 \
   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor,      \
                            uint64_t d) {                                       \
+    assert(ref);                                                               \
+    assert(tensor);                                                            \
     std::vector<TYPE> *v;                                                      \
     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
     ref->basePtr = ref->data = v->data();                                      \
@@ -548,6 +567,9 @@ char *getTensorFilename(uint64_t id) {
   void *_mlir_ciface_##NAME(void *tensor, TYPE value,                          \
                             StridedMemRefType<uint64_t, 1> *iref,              \
                             StridedMemRefType<uint64_t, 1> *pref) {            \
+    assert(tensor);                                                            \
+    assert(iref);                                                              \
+    assert(pref);                                                              \
     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                    \
     assert(iref->sizes[0] == pref->sizes[0]);                                  \
     const uint64_t *indx = iref->data + iref->offset;                          \
@@ -560,30 +582,21 @@ char *getTensorFilename(uint64_t id) {
     return tensor;                                                             \
   }
 
-enum OverheadTypeEnum : uint64_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
-
-enum PrimaryTypeEnum : uint64_t {
-  kF64 = 1,
-  kF32 = 2,
-  kI64 = 3,
-  kI32 = 4,
-  kI16 = 5,
-  kI8 = 6
-};
-
 /// Constructs a new sparse tensor. This is the "swiss army knife"
 /// method for materializing sparse tensors into the computation.
-///  action
-///  0 : ptr contains filename to read into storage
-///  1 : ptr contains coordinate scheme to assign to new storage
-///  2 : returns empty coordinate scheme to fill (call back 1 to setup)
-///  3 : returns coordinate scheme from storage in ptr (call back 1 to convert)
+///
+/// action:
+/// kFromFile = ptr contains filename to read into storage
+/// kFromCOO = ptr contains coordinate scheme to assign to new storage
+/// kNewCOO = returns empty coordinate scheme to fill and use with kFromCOO
+/// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
 void *
 _mlir_ciface_newSparseTensor(StridedMemRefType<uint8_t, 1> *aref, // NOLINT
                              StridedMemRefType<uint64_t, 1> *sref,
                              StridedMemRefType<uint64_t, 1> *pref,
                              uint64_t ptrTp, uint64_t indTp, uint64_t valTp,
                              uint32_t action, void *ptr) {
+  assert(aref && sref && pref);
   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
          pref->strides[0] == 1);
   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);