[mlir][sparse] Adding wrappers for constantOverheadTypeEncoding
authorwren romano <2998727+wrengr@users.noreply.github.com>
Mon, 22 Nov 2021 21:14:17 +0000 (13:14 -0800)
committerwren romano <2998727+wrengr@users.noreply.github.com>
Wed, 24 Nov 2021 02:30:06 +0000 (18:30 -0800)
Minor code cleanup

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D114392

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

index 3633ff0..88c0561 100644 (file)
@@ -85,6 +85,22 @@ static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter,
   return constantI32(rewriter, loc, static_cast<uint32_t>(sec));
 }
 
+/// Generates a constant of the internal type encoding for pointer
+/// overhead storage.
+static Value constantPointerTypeEncoding(ConversionPatternRewriter &rewriter,
+                                         Location loc,
+                                         SparseTensorEncodingAttr &enc) {
+  return constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth());
+}
+
+/// Generates a constant of the internal type encoding for index overhead
+/// storage.
+static Value constantIndexTypeEncoding(ConversionPatternRewriter &rewriter,
+                                       Location loc,
+                                       SparseTensorEncodingAttr &enc) {
+  return constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth());
+}
+
 /// Generates a constant of the internal type encoding for primary storage.
 static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter,
                                          Location loc, Type tp) {
@@ -277,10 +293,8 @@ static void newParams(ConversionPatternRewriter &rewriter,
   params.push_back(genBuffer(rewriter, loc, rev));
   // Secondary and primary types encoding.
   ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
-  params.push_back(
-      constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()));
-  params.push_back(
-      constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()));
+  params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
+  params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
   params.push_back(
       constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
   // User action and pointer.
@@ -598,10 +612,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
           encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
       newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
       Value coo = genNewCall(rewriter, op, params);
-      params[3] = constantOverheadTypeEncoding(rewriter, loc,
-                                               encDst.getPointerBitWidth());
-      params[4] = constantOverheadTypeEncoding(rewriter, loc,
-                                               encDst.getIndexBitWidth());
+      params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
+      params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
       params[6] = constantAction(rewriter, loc, Action::kFromCOO);
       params[7] = coo;
       rewriter.replaceOp(op, genNewCall(rewriter, op, params));