From d9affadd0070d4ec67f9d7a7b07ddb4387f07a66 Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Fri, 14 Oct 2022 16:36:14 -0700 Subject: [PATCH] [mlir][sparse] rename the values of the runtime DimLevelType This change is to make way for reusing the DimLevelType enum in lieu of the SparseTensorEncodingAttr::DimLevelType enum, but broken out to make it quick and easy to review Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D135995 --- .../mlir/ExecutionEngine/SparseTensor/Enums.h | 96 +++++++++++----------- .../SparseTensor/Transforms/CodegenUtils.cpp | 18 ++-- mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp | 4 +- 3 files changed, 59 insertions(+), 59 deletions(-) diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h index 1ead713..8dd0a2e 100644 --- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h +++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h @@ -131,32 +131,32 @@ enum class Action : uint32_t { /// breaking dependency cycles. `SparseTensorEncodingAttr::DimLevelType` /// is the source of truth and this enum should be kept consistent with it. enum class DimLevelType : uint8_t { - kDense = 4, // 0b001_00 - kCompressed = 8, // 0b010_00 - kCompressedNu = 9, // 0b010_01 - kCompressedNo = 10, // 0b010_10 - kCompressedNuNo = 11, // 0b010_11 - kSingleton = 16, // 0b100_00 - kSingletonNu = 17, // 0b100_01 - kSingletonNo = 18, // 0b100_10 - kSingletonNuNo = 19, // 0b100_11 + Dense = 4, // 0b001_00 + Compressed = 8, // 0b010_00 + CompressedNu = 9, // 0b010_01 + CompressedNo = 10, // 0b010_10 + CompressedNuNo = 11, // 0b010_11 + Singleton = 16, // 0b100_00 + SingletonNu = 17, // 0b100_01 + SingletonNo = 18, // 0b100_10 + SingletonNuNo = 19, // 0b100_11 }; /// Check if the `DimLevelType` is dense. constexpr bool isDenseDLT(DimLevelType dlt) { - return dlt == DimLevelType::kDense; + return dlt == DimLevelType::Dense; } /// Check if the `DimLevelType` is compressed (regardless of properties). constexpr bool isCompressedDLT(DimLevelType dlt) { return static_cast(dlt) & - static_cast(DimLevelType::kCompressed); + static_cast(DimLevelType::Compressed); } /// Check if the `DimLevelType` is singleton (regardless of properties). constexpr bool isSingletonDLT(DimLevelType dlt) { return static_cast(dlt) & - static_cast(DimLevelType::kSingleton); + static_cast(DimLevelType::Singleton); } /// Check if the `DimLevelType` is ordered (regardless of storage format). @@ -170,48 +170,48 @@ constexpr bool isUniqueDLT(DimLevelType dlt) { } // Ensure the above predicates work as intended. -static_assert((!isCompressedDLT(DimLevelType::kDense) && - isCompressedDLT(DimLevelType::kCompressed) && - isCompressedDLT(DimLevelType::kCompressedNu) && - isCompressedDLT(DimLevelType::kCompressedNo) && - isCompressedDLT(DimLevelType::kCompressedNuNo) && - !isCompressedDLT(DimLevelType::kSingleton) && - !isCompressedDLT(DimLevelType::kSingletonNu) && - !isCompressedDLT(DimLevelType::kSingletonNo) && - !isCompressedDLT(DimLevelType::kSingletonNuNo)), +static_assert((!isCompressedDLT(DimLevelType::Dense) && + isCompressedDLT(DimLevelType::Compressed) && + isCompressedDLT(DimLevelType::CompressedNu) && + isCompressedDLT(DimLevelType::CompressedNo) && + isCompressedDLT(DimLevelType::CompressedNuNo) && + !isCompressedDLT(DimLevelType::Singleton) && + !isCompressedDLT(DimLevelType::SingletonNu) && + !isCompressedDLT(DimLevelType::SingletonNo) && + !isCompressedDLT(DimLevelType::SingletonNuNo)), "isCompressedDLT definition is broken"); -static_assert((!isSingletonDLT(DimLevelType::kDense) && - !isSingletonDLT(DimLevelType::kCompressed) && - !isSingletonDLT(DimLevelType::kCompressedNu) && - !isSingletonDLT(DimLevelType::kCompressedNo) && - !isSingletonDLT(DimLevelType::kCompressedNuNo) && - isSingletonDLT(DimLevelType::kSingleton) && - isSingletonDLT(DimLevelType::kSingletonNu) && - isSingletonDLT(DimLevelType::kSingletonNo) && - isSingletonDLT(DimLevelType::kSingletonNuNo)), +static_assert((!isSingletonDLT(DimLevelType::Dense) && + !isSingletonDLT(DimLevelType::Compressed) && + !isSingletonDLT(DimLevelType::CompressedNu) && + !isSingletonDLT(DimLevelType::CompressedNo) && + !isSingletonDLT(DimLevelType::CompressedNuNo) && + isSingletonDLT(DimLevelType::Singleton) && + isSingletonDLT(DimLevelType::SingletonNu) && + isSingletonDLT(DimLevelType::SingletonNo) && + isSingletonDLT(DimLevelType::SingletonNuNo)), "isSingletonDLT definition is broken"); -static_assert((isOrderedDLT(DimLevelType::kDense) && - isOrderedDLT(DimLevelType::kCompressed) && - isOrderedDLT(DimLevelType::kCompressedNu) && - !isOrderedDLT(DimLevelType::kCompressedNo) && - !isOrderedDLT(DimLevelType::kCompressedNuNo) && - isOrderedDLT(DimLevelType::kSingleton) && - isOrderedDLT(DimLevelType::kSingletonNu) && - !isOrderedDLT(DimLevelType::kSingletonNo) && - !isOrderedDLT(DimLevelType::kSingletonNuNo)), +static_assert((isOrderedDLT(DimLevelType::Dense) && + isOrderedDLT(DimLevelType::Compressed) && + isOrderedDLT(DimLevelType::CompressedNu) && + !isOrderedDLT(DimLevelType::CompressedNo) && + !isOrderedDLT(DimLevelType::CompressedNuNo) && + isOrderedDLT(DimLevelType::Singleton) && + isOrderedDLT(DimLevelType::SingletonNu) && + !isOrderedDLT(DimLevelType::SingletonNo) && + !isOrderedDLT(DimLevelType::SingletonNuNo)), "isOrderedDLT definition is broken"); -static_assert((isUniqueDLT(DimLevelType::kDense) && - isUniqueDLT(DimLevelType::kCompressed) && - !isUniqueDLT(DimLevelType::kCompressedNu) && - isUniqueDLT(DimLevelType::kCompressedNo) && - !isUniqueDLT(DimLevelType::kCompressedNuNo) && - isUniqueDLT(DimLevelType::kSingleton) && - !isUniqueDLT(DimLevelType::kSingletonNu) && - isUniqueDLT(DimLevelType::kSingletonNo) && - !isUniqueDLT(DimLevelType::kSingletonNuNo)), +static_assert((isUniqueDLT(DimLevelType::Dense) && + isUniqueDLT(DimLevelType::Compressed) && + !isUniqueDLT(DimLevelType::CompressedNu) && + isUniqueDLT(DimLevelType::CompressedNo) && + !isUniqueDLT(DimLevelType::CompressedNuNo) && + isUniqueDLT(DimLevelType::Singleton) && + !isUniqueDLT(DimLevelType::SingletonNu) && + isUniqueDLT(DimLevelType::SingletonNo) && + !isUniqueDLT(DimLevelType::SingletonNuNo)), "isUniqueDLT definition is broken"); } // namespace sparse_tensor diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp index 54e206e..e7f0043 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -392,23 +392,23 @@ DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding( SparseTensorEncodingAttr::DimLevelType dlt) { switch (dlt) { case SparseTensorEncodingAttr::DimLevelType::Dense: - return DimLevelType::kDense; + return DimLevelType::Dense; case SparseTensorEncodingAttr::DimLevelType::Compressed: - return DimLevelType::kCompressed; + return DimLevelType::Compressed; case SparseTensorEncodingAttr::DimLevelType::CompressedNu: - return DimLevelType::kCompressedNu; + return DimLevelType::CompressedNu; case SparseTensorEncodingAttr::DimLevelType::CompressedNo: - return DimLevelType::kCompressedNo; + return DimLevelType::CompressedNo; case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo: - return DimLevelType::kCompressedNuNo; + return DimLevelType::CompressedNuNo; case SparseTensorEncodingAttr::DimLevelType::Singleton: - return DimLevelType::kSingleton; + return DimLevelType::Singleton; case SparseTensorEncodingAttr::DimLevelType::SingletonNu: - return DimLevelType::kSingletonNu; + return DimLevelType::SingletonNu; case SparseTensorEncodingAttr::DimLevelType::SingletonNo: - return DimLevelType::kSingletonNo; + return DimLevelType::SingletonNo; case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo: - return DimLevelType::kSingletonNuNo; + return DimLevelType::SingletonNuNo; } llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); } diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp index 721241e..0307f2f 100644 --- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp +++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp @@ -127,8 +127,8 @@ toMLIRSparseTensor(uint64_t rank, uint64_t nse, const uint64_t *shape, // Verify that the sparsity values are supported. // TODO: update this check to match what we actually support. for (uint64_t i = 0; i < rank; ++i) - if (sparsity[i] != DimLevelType::kDense && - sparsity[i] != DimLevelType::kCompressed) + if (sparsity[i] != DimLevelType::Dense && + sparsity[i] != DimLevelType::Compressed) MLIR_SPARSETENSOR_FATAL("unsupported dimension level type: %d\n", static_cast(sparsity[i])); #endif -- 2.7.4