[mlir][sparse] assert fail on mismatch between rank and annotations array
authorAart Bik <ajcbik@google.com>
Fri, 19 Feb 2021 06:01:39 +0000 (22:01 -0800)
committerAart Bik <ajcbik@google.com>
Fri, 19 Feb 2021 07:22:14 +0000 (23:22 -0800)
Rationale:
Providing the wrong number of sparse/dense annotations was silently
ignored or caused unrelated crashes. This minor change verifies that
the provided number matches the rank.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D97034

mlir/lib/ExecutionEngine/SparseUtils.cpp

index 0ff6f7d..903b9f1 100644 (file)
@@ -76,8 +76,8 @@ public:
   }
   /// Adds element as indices and value.
   void add(const std::vector<uint64_t> &ind, double val) {
-    assert(sizes.size() == ind.size());
-    for (int64_t r = 0, rank = sizes.size(); r < rank; r++)
+    assert(getRank() == ind.size());
+    for (int64_t r = 0, rank = getRank(); r < rank; r++)
       assert(ind[r] < sizes[r]); // within bounds
     elements.emplace_back(Element(ind, val));
   }
@@ -85,6 +85,8 @@ public:
   void sort() { std::sort(elements.begin(), elements.end(), lexOrder); }
   /// Primitive one-time iteration.
   const Element &next() { return elements[pos++]; }
+  /// Returns rank.
+  uint64_t getRank() const { return sizes.size(); }
   /// Getter for sizes array.
   const std::vector<uint64_t> &getSizes() const { return sizes; }
   /// Getter for elements array.
@@ -139,13 +141,13 @@ public:
   /// Constructs sparse tensor storage scheme following the given
   /// per-rank dimension dense/sparse annotations.
   SparseTensorStorage(SparseTensor *tensor, bool *sparsity)
-      : sizes(tensor->getSizes()), pointers(sizes.size()),
-        indices(sizes.size()) {
+      : sizes(tensor->getSizes()), pointers(getRank()), indices(getRank()) {
     // Provide hints on capacity.
     // TODO: needs fine-tuning based on sparsity
-    values.reserve(tensor->getElements().size());
-    for (uint64_t d = 0, s = 1, rank = sizes.size(); d < rank; d++) {
-      s *= tensor->getSizes()[d];
+    uint64_t nnz = tensor->getElements().size();
+    values.reserve(nnz);
+    for (uint64_t d = 0, s = 1, rank = getRank(); d < rank; d++) {
+      s *= sizes[d];
       if (sparsity[d]) {
         pointers[d].reserve(s + 1);
         indices[d].reserve(s);
@@ -153,12 +155,16 @@ public:
       }
     }
     // Then setup the tensor.
-    traverse(tensor, sparsity, 0, tensor->getElements().size(), 0);
+    traverse(tensor, sparsity, 0, nnz, 0);
   }
 
   virtual ~SparseTensorStorage() {}
 
+  uint64_t getRank() const { return sizes.size(); }
+
   uint64_t getDimSize(uint64_t d) override { return sizes[d]; }
+
+  // Partially specialize these three methods based on template types.
   void getPointers(std::vector<P> **out, uint64_t d) override {
     *out = &pointers[d];
   }
@@ -176,7 +182,7 @@ private:
                 uint64_t d) {
     const std::vector<Element> &elements = tensor->getElements();
     // Once dimensions are exhausted, insert the numerical values.
-    if (d == sizes.size()) {
+    if (d == getRank()) {
       values.push_back(lo < hi ? elements[lo].value : 0.0);
       return;
     }
@@ -221,9 +227,10 @@ private:
 
 /// Templated reader.
 template <typename P, typename I, typename V>
-void *newSparseTensor(char *filename, bool *sparsity) {
+void *newSparseTensor(char *filename, bool *sparsity, uint64_t size) {
   uint64_t idata[64];
   SparseTensor *t = static_cast<SparseTensor *>(openTensorC(filename, idata));
+  assert(size == t->getRank()); // sparsity array must match rank
   SparseTensorStorageBase *tensor =
       new SparseTensorStorage<P, I, V>(t, sparsity);
   delete t;
@@ -481,21 +488,29 @@ void *newSparseTensor(char *filename, bool *abase, bool *adata, uint64_t aoff,
   assert(astride == 1);
   bool *sparsity = abase + aoff;
   if (ptrTp == kU64 && indTp == kU64 && valTp == kF64)
-    return newSparseTensor<uint64_t, uint64_t, double>(filename, sparsity);
+    return newSparseTensor<uint64_t, uint64_t, double>(filename, sparsity,
+                                                       asize);
   if (ptrTp == kU64 && indTp == kU64 && valTp == kF32)
-    return newSparseTensor<uint64_t, uint64_t, float>(filename, sparsity);
+    return newSparseTensor<uint64_t, uint64_t, float>(filename, sparsity,
+                                                      asize);
   if (ptrTp == kU64 && indTp == kU32 && valTp == kF64)
-    return newSparseTensor<uint64_t, uint32_t, double>(filename, sparsity);
+    return newSparseTensor<uint64_t, uint32_t, double>(filename, sparsity,
+                                                       asize);
   if (ptrTp == kU64 && indTp == kU32 && valTp == kF32)
-    return newSparseTensor<uint64_t, uint32_t, float>(filename, sparsity);
+    return newSparseTensor<uint64_t, uint32_t, float>(filename, sparsity,
+                                                      asize);
   if (ptrTp == kU32 && indTp == kU64 && valTp == kF64)
-    return newSparseTensor<uint32_t, uint64_t, double>(filename, sparsity);
+    return newSparseTensor<uint32_t, uint64_t, double>(filename, sparsity,
+                                                       asize);
   if (ptrTp == kU32 && indTp == kU64 && valTp == kF32)
-    return newSparseTensor<uint32_t, uint64_t, float>(filename, sparsity);
+    return newSparseTensor<uint32_t, uint64_t, float>(filename, sparsity,
+                                                      asize);
   if (ptrTp == kU32 && indTp == kU32 && valTp == kF64)
-    return newSparseTensor<uint32_t, uint32_t, double>(filename, sparsity);
+    return newSparseTensor<uint32_t, uint32_t, double>(filename, sparsity,
+                                                       asize);
   if (ptrTp == kU32 && indTp == kU32 && valTp == kF32)
-    return newSparseTensor<uint32_t, uint32_t, float>(filename, sparsity);
+    return newSparseTensor<uint32_t, uint32_t, float>(filename, sparsity,
+                                                      asize);
   fputs("unsupported combination of types\n", stderr);
   exit(1);
 }