[mlir][sparse] Adding {pointer,index}OverheadTypeEncoding
authorwren romano <2998727+wrengr@users.noreply.github.com>
Mon, 21 Mar 2022 23:49:54 +0000 (16:49 -0700)
committerwren romano <2998727+wrengr@users.noreply.github.com>
Wed, 23 Mar 2022 19:04:47 +0000 (12:04 -0700)
Work towards: https://github.com/llvm/llvm-project/issues/51652

The new functions fill the gap between `overheadTypeEncoding` and `get{Pointer,Index}OverheadType`.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D122056

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h

index ea9be3bddb5479e45de8033ca304dba46c450346..38f15a1ee7c0111770f34a086d6e30f851f894df 100644 (file)
@@ -58,15 +58,24 @@ Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
   llvm_unreachable("Unknown OverheadType");
 }
 
+OverheadType mlir::sparse_tensor::pointerOverheadTypeEncoding(
+    const SparseTensorEncodingAttr &enc) {
+  return overheadTypeEncoding(enc.getPointerBitWidth());
+}
+
+OverheadType mlir::sparse_tensor::indexOverheadTypeEncoding(
+    const SparseTensorEncodingAttr &enc) {
+  return overheadTypeEncoding(enc.getIndexBitWidth());
+}
+
 Type mlir::sparse_tensor::getPointerOverheadType(
     Builder &builder, const SparseTensorEncodingAttr &enc) {
-  return getOverheadType(builder,
-                         overheadTypeEncoding(enc.getPointerBitWidth()));
+  return getOverheadType(builder, pointerOverheadTypeEncoding(enc));
 }
 
 Type mlir::sparse_tensor::getIndexOverheadType(
     Builder &builder, const SparseTensorEncodingAttr &enc) {
-  return getOverheadType(builder, overheadTypeEncoding(enc.getIndexBitWidth()));
+  return getOverheadType(builder, indexOverheadTypeEncoding(enc));
 }
 
 StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
index 9286cca808aa8daafbc4879bcd75f4d2e3d454e9..944605bab4b9fadb0e5158480eb5f7207bd57a46 100644 (file)
@@ -38,6 +38,12 @@ OverheadType overheadTypeEncoding(Type tp);
 /// Converts the internal type-encoding for overhead storage to an mlir::Type.
 Type getOverheadType(Builder &builder, OverheadType ot);
 
+/// Returns the OverheadType for pointer overhead storage.
+OverheadType pointerOverheadTypeEncoding(const SparseTensorEncodingAttr &enc);
+
+/// Returns the OverheadType for index overhead storage.
+OverheadType indexOverheadTypeEncoding(const SparseTensorEncodingAttr &enc);
+
 /// Returns the mlir::Type for pointer overhead storage.
 Type getPointerOverheadType(Builder &builder,
                             const SparseTensorEncodingAttr &enc);