From ebc8466481f938d799d0170e0caf1c7e64f52c53 Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Mon, 21 Mar 2022 16:49:54 -0700 Subject: [PATCH] [mlir][sparse] Adding {pointer,index}OverheadTypeEncoding Work towards: https://github.com/llvm/llvm-project/issues/51652 The new functions fill the gap between `overheadTypeEncoding` and `get{Pointer,Index}OverheadType`. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D122056 --- .../SparseTensor/Transforms/CodegenUtils.cpp | 15 ++++++++++++--- .../SparseTensor/Transforms/CodegenUtils.h | 6 ++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp index ea9be3bddb54..38f15a1ee7c0 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -58,15 +58,24 @@ Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) { llvm_unreachable("Unknown OverheadType"); } +OverheadType mlir::sparse_tensor::pointerOverheadTypeEncoding( + const SparseTensorEncodingAttr &enc) { + return overheadTypeEncoding(enc.getPointerBitWidth()); +} + +OverheadType mlir::sparse_tensor::indexOverheadTypeEncoding( + const SparseTensorEncodingAttr &enc) { + return overheadTypeEncoding(enc.getIndexBitWidth()); +} + Type mlir::sparse_tensor::getPointerOverheadType( Builder &builder, const SparseTensorEncodingAttr &enc) { - return getOverheadType(builder, - overheadTypeEncoding(enc.getPointerBitWidth())); + return getOverheadType(builder, pointerOverheadTypeEncoding(enc)); } Type mlir::sparse_tensor::getIndexOverheadType( Builder &builder, const SparseTensorEncodingAttr &enc) { - return getOverheadType(builder, overheadTypeEncoding(enc.getIndexBitWidth())); + return getOverheadType(builder, indexOverheadTypeEncoding(enc)); } StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h index 9286cca808aa..944605bab4b9 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -38,6 +38,12 @@ OverheadType overheadTypeEncoding(Type tp); /// Converts the internal type-encoding for overhead storage to an mlir::Type. Type getOverheadType(Builder &builder, OverheadType ot); +/// Returns the OverheadType for pointer overhead storage. +OverheadType pointerOverheadTypeEncoding(const SparseTensorEncodingAttr &enc); + +/// Returns the OverheadType for index overhead storage. +OverheadType indexOverheadTypeEncoding(const SparseTensorEncodingAttr &enc); + /// Returns the mlir::Type for pointer overhead storage. Type getPointerOverheadType(Builder &builder, const SparseTensorEncodingAttr &enc); -- 2.34.1