[mlir][sparse] Factoring out SparseTensorFile class for readSparseTensorShape
authorwren romano <2998727+wrengr@users.noreply.github.com>
Tue, 31 May 2022 19:55:01 +0000 (12:55 -0700)
committerwren romano <2998727+wrengr@users.noreply.github.com>
Tue, 31 May 2022 20:24:28 +0000 (13:24 -0700)
The primary goal of this change is to define readSparseTensorShape.  Whereas the SparseTensorFile class is merely introduced as a way to reduce code duplication along the way.

Depends On D126106

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D126233

mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
mlir/lib/ExecutionEngine/SparseTensorUtils.cpp

index e733544..51e78c3 100644 (file)
@@ -274,6 +274,11 @@ FOREVERY_V(DECL_DELCOO)
 /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
 MLIR_CRUNNERUTILS_EXPORT char *getTensorFilename(index_type id);
 
+/// Helper function to read the header of a file and return the
+/// shape/sizes, without parsing the elements of the file.
+MLIR_CRUNNERUTILS_EXPORT void readSparseTensorShape(char *filename,
+                                                    std::vector<uint64_t> *out);
+
 /// Initializes sparse tensor from a COO-flavored format expressed using
 /// C-style data structures.  The expected parameters are:
 ///
index 430ba85..da3b705 100644 (file)
@@ -88,9 +88,11 @@ static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
     exit(1);                                                                   \
   }
 
-// TODO: adjust this so it can be used by `openSparseTensorCOO` too.
-// That version doesn't have the permutation, and the `dimSizes` are
-// a pointer/C-array rather than `std::vector`.
+// TODO: try to unify this with `SparseTensorFile::assertMatchesShape`
+// which is used by `openSparseTensorCOO`.  It's easy enough to resolve
+// the `std::vector` vs pointer mismatch for `dimSizes`; but it's trickier
+// to resolve the presence/absence of `perm` (without introducing extra
+// overhead), so perhaps the code duplication is unavoidable.
 //
 /// Asserts that the `dimSizes` (in target-order) under the `perm` (mapping
 /// semantic-order to target-order) are a refinement of the desired `shape`
@@ -1099,9 +1101,128 @@ static char *toLower(char *token) {
   return token;
 }
 
+/// This class abstracts over the information stored in file headers,
+/// as well as providing the buffers and methods for parsing those headers.
+class SparseTensorFile final {
+public:
+  explicit SparseTensorFile(char *filename) : filename(filename) {
+    assert(filename && "Received nullptr for filename");
+  }
+
+  // Disallows copying, to avoid duplicating the `file` pointer.
+  SparseTensorFile(const SparseTensorFile &) = delete;
+  SparseTensorFile &operator=(const SparseTensorFile &) = delete;
+
+  // This dtor tries to avoid leaking the `file`.  (Though it's better
+  // to call `closeFile` explicitly when possible, since there are
+  // circumstances where dtors are not called reliably.)
+  ~SparseTensorFile() { closeFile(); }
+
+  /// Opens the file for reading.
+  void openFile() {
+    if (file)
+      FATAL("Already opened file %s\n", filename);
+    file = fopen(filename, "r");
+    if (!file)
+      FATAL("Cannot find file %s\n", filename);
+  }
+
+  /// Closes the file.
+  void closeFile() {
+    if (file) {
+      fclose(file);
+      file = nullptr;
+    }
+  }
+
+  // TODO(wrengr/bixia): figure out how to reorganize the element-parsing
+  // loop of `openSparseTensorCOO` into methods of this class, so we can
+  // avoid leaking access to the `line` pointer (both for general hygiene
+  // and because we can't mark it const due to the second argument of
+  // `strtoul`/`strtoud` being `char * *restrict` rather than
+  // `char const* *restrict`).
+  //
+  /// Attempts to read a line from the file.
+  char *readLine() {
+    if (fgets(line, kColWidth, file))
+      return line;
+    FATAL("Cannot read next line of %s\n", filename);
+  }
+
+  /// Reads and parses the file's header.
+  void readHeader() {
+    assert(file && "Attempt to readHeader() before openFile()");
+    if (strstr(filename, ".mtx"))
+      readMMEHeader();
+    else if (strstr(filename, ".tns"))
+      readExtFROSTTHeader();
+    else
+      FATAL("Unknown format %s\n", filename);
+    assert(isValid && "Failed to read the header");
+  }
+
+  /// Gets the MME "pattern" property setting.  Is only valid after
+  /// parsing the header.
+  bool isPattern() const {
+    assert(isValid && "Attempt to isPattern() before readHeader()");
+    return isPattern_;
+  }
+
+  /// Gets the MME "symmetric" property setting.  Is only valid after
+  /// parsing the header.
+  bool isSymmetric() const {
+    assert(isValid && "Attempt to isSymmetric() before readHeader()");
+    return isSymmetric_;
+  }
+
+  /// Gets the rank of the tensor.  Is only valid after parsing the header.
+  uint64_t getRank() const {
+    assert(isValid && "Attempt to getRank() before readHeader()");
+    return idata[0];
+  }
+
+  /// Gets the number of non-zeros.  Is only valid after parsing the header.
+  uint64_t getNNZ() const {
+    assert(isValid && "Attempt to getNNZ() before readHeader()");
+    return idata[1];
+  }
+
+  /// Gets the dimension-sizes array.  The pointer itself is always
+  /// valid; however, the values stored therein are only valid after
+  /// parsing the header.
+  const uint64_t *getDimSizes() const { return idata + 2; }
+
+  /// Safely gets the size of the given dimension.  Is only valid
+  /// after parsing the header.
+  uint64_t getDimSize(uint64_t d) const {
+    assert(d < getRank());
+    return idata[2 + d];
+  }
+
+  /// Asserts the shape subsumes the actual dimension sizes.  Is only
+  /// valid after parsing the header.
+  void assertMatchesShape(uint64_t rank, const uint64_t *shape) const {
+    assert(rank == getRank() && "Rank mismatch");
+    for (uint64_t r = 0; r < rank; r++)
+      assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
+             "Dimension size mismatch");
+  }
+
+private:
+  void readMMEHeader();
+  void readExtFROSTTHeader();
+
+  const char *filename;
+  FILE *file = nullptr;
+  bool isValid = false;
+  bool isPattern_ = false;
+  bool isSymmetric_ = false;
+  uint64_t idata[512];
+  char line[kColWidth];
+};
+
 /// Read the MME header of a general sparse matrix of type real.
-static void readMMEHeader(FILE *file, char *filename, char *line,
-                          uint64_t *idata, bool *isPattern, bool *isSymmetric) {
+void SparseTensorFile::readMMEHeader() {
   char header[64];
   char object[64];
   char format[64];
@@ -1112,19 +1233,18 @@ static void readMMEHeader(FILE *file, char *filename, char *line,
              symmetry) != 5)
     FATAL("Corrupt header in %s\n", filename);
   // Set properties
-  *isPattern = (strcmp(toLower(field), "pattern") == 0);
-  *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
+  isPattern_ = (strcmp(toLower(field), "pattern") == 0);
+  isSymmetric_ = (strcmp(toLower(symmetry), "symmetric") == 0);
   // Make sure this is a general sparse matrix.
   if (strcmp(toLower(header), "%%matrixmarket") ||
       strcmp(toLower(object), "matrix") ||
       strcmp(toLower(format), "coordinate") ||
-      (strcmp(toLower(field), "real") && !(*isPattern)) ||
-      (strcmp(toLower(symmetry), "general") && !(*isSymmetric)))
+      (strcmp(toLower(field), "real") && !isPattern_) ||
+      (strcmp(toLower(symmetry), "general") && !isSymmetric_))
     FATAL("Cannot find a general sparse matrix in %s\n", filename);
   // Skip comments.
   while (true) {
-    if (!fgets(line, kColWidth, file))
-      FATAL("Cannot find data in %s\n", filename);
+    readLine();
     if (line[0] != '%')
       break;
   }
@@ -1133,18 +1253,17 @@ static void readMMEHeader(FILE *file, char *filename, char *line,
   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
              idata + 1) != 3)
     FATAL("Cannot find size in %s\n", filename);
+  isValid = true;
 }
 
 /// Read the "extended" FROSTT header. Although not part of the documented
 /// format, we assume that the file starts with optional comments followed
 /// by two lines that define the rank, the number of nonzeros, and the
 /// dimensions sizes (one per rank) of the sparse tensor.
-static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
-                                uint64_t *idata) {
+void SparseTensorFile::readExtFROSTTHeader() {
   // Skip comments.
   while (true) {
-    if (!fgets(line, kColWidth, file))
-      FATAL("Cannot find data in %s\n", filename);
+    readLine();
     if (line[0] != '#')
       break;
   }
@@ -1155,7 +1274,8 @@ static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
   for (uint64_t r = 0; r < idata[0]; r++)
     if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
       FATAL("Cannot find dimension size %s\n", filename);
-  fgets(line, kColWidth, file); // end of line
+  readLine(); // end of line
+  isValid = true;
 }
 
 /// Reads a sparse tensor with the given filename into a memory-resident
@@ -1164,38 +1284,19 @@ template <typename V>
 static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
                                                const uint64_t *shape,
                                                const uint64_t *perm) {
-  // Open the file.
-  assert(filename && "Received nullptr for filename");
-  FILE *file = fopen(filename, "r");
-  if (!file)
-    FATAL("Cannot find file %s\n", filename);
-  // Perform some file format dependent set up.
-  char line[kColWidth];
-  uint64_t idata[512];
-  bool isPattern = false;
-  bool isSymmetric = false;
-  if (strstr(filename, ".mtx")) {
-    readMMEHeader(file, filename, line, idata, &isPattern, &isSymmetric);
-  } else if (strstr(filename, ".tns")) {
-    readExtFROSTTHeader(file, filename, line, idata);
-  } else {
-    FATAL("Unknown format %s\n", filename);
-  }
+  SparseTensorFile stfile(filename);
+  stfile.openFile();
+  stfile.readHeader();
+  stfile.assertMatchesShape(rank, shape);
   // Prepare sparse tensor object with per-dimension sizes
   // and the number of nonzeros as initial capacity.
-  assert(rank == idata[0] && "rank mismatch");
-  uint64_t nnz = idata[1];
-  for (uint64_t r = 0; r < rank; r++)
-    assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
-           "dimension size mismatch");
-  SparseTensorCOO<V> *tensor =
-      SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
+  uint64_t nnz = stfile.getNNZ();
+  auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, stfile.getDimSizes(),
+                                                     perm, nnz);
   // Read all nonzero elements.
   std::vector<uint64_t> indices(rank);
   for (uint64_t k = 0; k < nnz; k++) {
-    if (!fgets(line, kColWidth, file))
-      FATAL("Cannot find next line of data in %s\n", filename);
-    char *linePtr = line;
+    char *linePtr = stfile.readLine();
     for (uint64_t r = 0; r < rank; r++) {
       uint64_t idx = strtoul(linePtr, &linePtr, 10);
       // Add 0-based index.
@@ -1204,17 +1305,18 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
     // The external formats always store the numerical values with the type
     // double, but we cast these values to the sparse tensor object type.
     // For a pattern tensor, we arbitrarily pick the value 1 for all entries.
-    double value = isPattern ? 1.0 : strtod(linePtr, &linePtr);
-    tensor->add(indices, value);
+    double value = stfile.isPattern() ? 1.0 : strtod(linePtr, &linePtr);
+    // TODO: <https://github.com/llvm/llvm-project/issues/54179>
+    coo->add(indices, value);
     // We currently chose to deal with symmetric matrices by fully constructing
     // them. In the future, we may want to make symmetry implicit for storage
     // reasons.
-    if (isSymmetric && indices[0] != indices[1])
-      tensor->add({indices[1], indices[0]}, value);
+    if (stfile.isSymmetric() && indices[0] != indices[1])
+      coo->add({indices[1], indices[0]}, value);
   }
   // Close the file and return tensor.
-  fclose(file);
-  return tensor;
+  stfile.closeFile();
+  return coo;
 }
 
 /// Writes the sparse tensor to `dest` in extended FROSTT format.
@@ -1670,6 +1772,18 @@ char *getTensorFilename(index_type id) {
   return env;
 }
 
+void readSparseTensorShape(char *filename, std::vector<uint64_t> *out) {
+  assert(out && "Received nullptr for out-parameter");
+  SparseTensorFile stfile(filename);
+  stfile.openFile();
+  stfile.readHeader();
+  stfile.closeFile();
+  const uint64_t rank = stfile.getRank();
+  const uint64_t *dimSizes = stfile.getDimSizes();
+  out->reserve(rank);
+  out->assign(dimSizes, dimSizes + rank);
+}
+
 // TODO: generalize beyond 64-bit indices.
 #define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V)                               \
   void *convertToMLIRSparseTensor##VNAME(                                      \