From c8177f845b4132f2838d169cee04270051235140 Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Thu, 29 Sep 2022 13:09:24 -0700 Subject: [PATCH] [mlir][sparse] Factoring out SparseTensorFile::canReadAs predicate This is a followup to the refactoring of D133462, D133830, D133831, and D133833. Depends On D133833 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D133835 --- .../mlir/ExecutionEngine/SparseTensor/File.h | 12 ++++------ mlir/lib/ExecutionEngine/SparseTensor/File.cpp | 27 ++++++++++++++++++++++ 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h index 3d82b49..5dd6c17 100644 --- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h +++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h @@ -86,6 +86,10 @@ public: /// Checks if a header has been successfully read. bool isValid() const { return valueKind_ != ValueKind::kInvalid; } + /// Checks if the file's ValueKind can be converted into the given + /// tensor PrimaryType. Is only valid after parsing the header. + bool canReadAs(PrimaryType valTy) const; + /// Gets the MME "pattern" property setting. Is only valid after /// parsing the header. bool isPattern() const { @@ -208,16 +212,10 @@ openSparseTensorCOO(const char *filename, uint64_t rank, const uint64_t *shape, stfile.openFile(); stfile.readHeader(); // Check tensor element type against the value type in the input file. - SparseTensorFile::ValueKind valueKind = stfile.getValueKind(); - bool tensorIsInteger = - (valTp >= PrimaryType::kI64 && valTp <= PrimaryType::kI8); - bool tensorIsReal = (valTp >= PrimaryType::kF64 && valTp <= PrimaryType::kI8); - if ((valueKind == SparseTensorFile::ValueKind::kReal && tensorIsInteger) || - (valueKind == SparseTensorFile::ValueKind::kComplex && tensorIsReal)) { + if (!stfile.canReadAs(valTp)) MLIR_SPARSETENSOR_FATAL( "Tensor element type %d not compatible with values in file %s\n", static_cast(valTp), filename); - } stfile.assertMatchesShape(rank, shape); // Prepare sparse tensor object with per-dimension sizes // and the number of nonzeros as initial capacity. diff --git a/mlir/lib/ExecutionEngine/SparseTensor/File.cpp b/mlir/lib/ExecutionEngine/SparseTensor/File.cpp index b105f93..5a01b89 100644 --- a/mlir/lib/ExecutionEngine/SparseTensor/File.cpp +++ b/mlir/lib/ExecutionEngine/SparseTensor/File.cpp @@ -83,6 +83,33 @@ void SparseTensorFile::assertMatchesShape(uint64_t rank, "Dimension size mismatch"); } +bool SparseTensorFile::canReadAs(PrimaryType valTy) const { + switch (valueKind_) { + case ValueKind::kInvalid: + assert(false && "Must readHeader() before calling canReadAs()"); + return false; // In case assertions are disabled. + case ValueKind::kPattern: + return true; + case ValueKind::kInteger: + // When the file is specified to store integer values, we still + // allow implicitly converting those to floating primary-types. + return isRealPrimaryType(valTy); + case ValueKind::kReal: + // When the file is specified to store real/floating values, then + // we disallow implicit conversion to integer primary-types. + return isFloatingPrimaryType(valTy); + case ValueKind::kComplex: + // When the file is specified to store complex values, then we + // require a complex primary-type. + return isComplexPrimaryType(valTy); + case ValueKind::kUndefined: + // The "extended" FROSTT format doesn't specify a ValueKind. + // So we allow implicitly converting the stored values to both + // integer and floating primary-types. + return isRealPrimaryType(valTy); + } +} + /// Helper to convert C-style strings (i.e., '\0' terminated) to lower case. static inline char *toLower(char *token) { for (char *c = token; *c; ++c) -- 2.7.4