From 9fc35c306cfd61d069633a676fe2a6c56df11fe1 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 11 Oct 2021 17:11:13 -0700 Subject: [PATCH] [mlir][sparse] add action enum, group enums together, few asserts Some random changes that were hanging around in my workspace. Also, a tiny step towards creating a header file for the sparse utils lib. Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D111589 --- mlir/lib/ExecutionEngine/SparseUtils.cpp | 53 ++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp index 77cff10..b30804e 100644 --- a/mlir/lib/ExecutionEngine/SparseUtils.cpp +++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp @@ -507,24 +507,41 @@ char *getTensorFilename(uint64_t id) { // //===----------------------------------------------------------------------===// +enum OverheadTypeEnum : uint64_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 }; + +enum PrimaryTypeEnum : uint64_t { + kF64 = 1, + kF32 = 2, + kI64 = 3, + kI32 = 4, + kI16 = 5, + kI8 = 6 +}; + +enum Action : uint32_t { kFromFile = 0, kFromCOO = 1, kNewCOO = 2, kToCOO = 3 }; + #define CASE(p, i, v, P, I, V) \ if (ptrTp == (p) && indTp == (i) && valTp == (v)) { \ SparseTensorCOO *tensor = nullptr; \ - if (action == 0) \ + if (action == kFromFile) \ tensor = \ openSparseTensorCOO(static_cast(ptr), size, sizes, perm); \ - else if (action == 1) \ + else if (action == kFromCOO) \ tensor = static_cast *>(ptr); \ - else if (action == 2) \ + else if (action == kNewCOO) \ return SparseTensorCOO::newSparseTensorCOO(size, sizes, perm); \ - else \ + else if (action == kToCOO) \ return static_cast *>(ptr)->toCOO(perm); \ + else \ + assert(0); \ return SparseTensorStorage::newSparseTensor(tensor, sparsity, \ perm); \ } #define IMPL1(NAME, TYPE, LIB) \ void _mlir_ciface_##NAME(StridedMemRefType *ref, void *tensor) { \ + assert(ref); \ + assert(tensor); \ std::vector *v; \ static_cast(tensor)->LIB(&v); \ ref->basePtr = ref->data = v->data(); \ @@ -536,6 +553,8 @@ char *getTensorFilename(uint64_t id) { #define IMPL2(NAME, TYPE, LIB) \ void _mlir_ciface_##NAME(StridedMemRefType *ref, void *tensor, \ uint64_t d) { \ + assert(ref); \ + assert(tensor); \ std::vector *v; \ static_cast(tensor)->LIB(&v, d); \ ref->basePtr = ref->data = v->data(); \ @@ -548,6 +567,9 @@ char *getTensorFilename(uint64_t id) { void *_mlir_ciface_##NAME(void *tensor, TYPE value, \ StridedMemRefType *iref, \ StridedMemRefType *pref) { \ + assert(tensor); \ + assert(iref); \ + assert(pref); \ assert(iref->strides[0] == 1 && pref->strides[0] == 1); \ assert(iref->sizes[0] == pref->sizes[0]); \ const uint64_t *indx = iref->data + iref->offset; \ @@ -560,30 +582,21 @@ char *getTensorFilename(uint64_t id) { return tensor; \ } -enum OverheadTypeEnum : uint64_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 }; - -enum PrimaryTypeEnum : uint64_t { - kF64 = 1, - kF32 = 2, - kI64 = 3, - kI32 = 4, - kI16 = 5, - kI8 = 6 -}; - /// Constructs a new sparse tensor. This is the "swiss army knife" /// method for materializing sparse tensors into the computation. -/// action -/// 0 : ptr contains filename to read into storage -/// 1 : ptr contains coordinate scheme to assign to new storage -/// 2 : returns empty coordinate scheme to fill (call back 1 to setup) -/// 3 : returns coordinate scheme from storage in ptr (call back 1 to convert) +/// +/// action: +/// kFromFile = ptr contains filename to read into storage +/// kFromCOO = ptr contains coordinate scheme to assign to new storage +/// kNewCOO = returns empty coordinate scheme to fill and use with kFromCOO +/// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO void * _mlir_ciface_newSparseTensor(StridedMemRefType *aref, // NOLINT StridedMemRefType *sref, StridedMemRefType *pref, uint64_t ptrTp, uint64_t indTp, uint64_t valTp, uint32_t action, void *ptr) { + assert(aref && sref && pref); assert(aref->strides[0] == 1 && sref->strides[0] == 1 && pref->strides[0] == 1); assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]); -- 2.7.4