From e568d0016eb8096d01eb30c485a054df786a69bf Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 3 Jan 2023 18:06:54 -0800 Subject: [PATCH] [mlir][sparse] minor code layout edits Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D140934 --- .../Transforms/SparseStorageSpecifierToLLVM.cpp | 32 +++++++++++++++++++--- .../Transforms/SparseTensorStorageLayout.cpp | 22 +++++++++++++++ .../Transforms/SparseTensorStorageLayout.h | 2 ++ 3 files changed, 52 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp index 5843aef..eaa4b42 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp @@ -14,6 +14,12 @@ using namespace mlir; using namespace sparse_tensor; +namespace { + +//===----------------------------------------------------------------------===// +// Helper methods. +//===----------------------------------------------------------------------===// + static SmallVector getSpecifierFields(StorageSpecifierType tp) { MLIRContext *ctx = tp.getContext(); auto enc = tp.getEncoding(); @@ -34,10 +40,9 @@ static Type convertSpecifier(StorageSpecifierType tp) { getSpecifierFields(tp)); } -StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() { - addConversion([](Type type) { return type; }); - addConversion([](StorageSpecifierType tp) { return convertSpecifier(tp); }); -} +//===----------------------------------------------------------------------===// +// Specifier struct builder. +//===----------------------------------------------------------------------===// constexpr uint64_t kDimSizePosInSpecifier = 0; constexpr uint64_t kMemSizePosInSpecifier = 1; @@ -102,6 +107,21 @@ void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc, loc, value, size, ArrayRef({kMemSizePosInSpecifier, pos})); } +} // namespace + +//===----------------------------------------------------------------------===// +// The sparse storage specifier type converter (defined in Passes.h). +//===----------------------------------------------------------------------===// + +StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion([](StorageSpecifierType tp) { return convertSpecifier(tp); }); +} + +//===----------------------------------------------------------------------===// +// Storage specifier conversion rules. +//===----------------------------------------------------------------------===// + template class SpecifierGetterSetterOpConverter : public OpConversionPattern { public: @@ -176,6 +196,10 @@ public: } }; +//===----------------------------------------------------------------------===// +// Public method for populating conversion rules. +//===----------------------------------------------------------------------===// + void mlir::populateStorageSpecifierToLLVMPatterns(TypeConverter &converter, RewritePatternSet &patterns) { patterns.add &fields) { return success(); } +//===----------------------------------------------------------------------===// +// The sparse tensor type converter (defined in Passes.h). +//===----------------------------------------------------------------------===// + SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { addConversion([](Type type) { return type; }); addConversion([&](RankedTensorType rtp, SmallVectorImpl &fields) { @@ -65,6 +75,10 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { }); } +//===----------------------------------------------------------------------===// +// StorageLayout methods. +//===----------------------------------------------------------------------===// + unsigned StorageLayout::getMemRefFieldIndex(SparseTensorFieldKind kind, std::optional dim) const { unsigned fieldIdx = -1u; @@ -89,6 +103,10 @@ unsigned StorageLayout::getMemRefFieldIndex(StorageSpecifierKind kind, return getMemRefFieldIndex(toFieldKind(kind), dim); } +//===----------------------------------------------------------------------===// +// StorageTensorSpecifier methods. +//===----------------------------------------------------------------------===// + Value SparseTensorSpecifier::getInitValue(OpBuilder &builder, Location loc, RankedTensorType rtp) { return builder.create( @@ -114,6 +132,10 @@ void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc, createIndexCast(builder, loc, v, getFieldType(kind, dim))); } +//===----------------------------------------------------------------------===// +// Public methods. +//===----------------------------------------------------------------------===// + constexpr uint64_t kDataFieldStartingIdx = 0; void sparse_tensor::foreachFieldInSparseTensor( diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h index d94aa1f..45aac6d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h @@ -50,6 +50,7 @@ namespace sparse_tensor { // }; // //===----------------------------------------------------------------------===// + enum class SparseTensorFieldKind : uint32_t { StorageSpec = 0, PtrMemRef = 1, @@ -355,4 +356,5 @@ getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl &fields) { } // namespace sparse_tensor } // namespace mlir + #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_ -- 2.7.4