From d7d7ffe254d53cf0860126ab4c3f5db18c927892 Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Mon, 22 Nov 2021 13:14:17 -0800 Subject: [PATCH] [mlir][sparse] Adding wrappers for constantOverheadTypeEncoding Minor code cleanup Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D114392 --- .../Transforms/SparseTensorConversion.cpp | 28 +++++++++++++++------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 3633ff0..88c0561 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -85,6 +85,22 @@ static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter, return constantI32(rewriter, loc, static_cast(sec)); } +/// Generates a constant of the internal type encoding for pointer +/// overhead storage. +static Value constantPointerTypeEncoding(ConversionPatternRewriter &rewriter, + Location loc, + SparseTensorEncodingAttr &enc) { + return constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()); +} + +/// Generates a constant of the internal type encoding for index overhead +/// storage. +static Value constantIndexTypeEncoding(ConversionPatternRewriter &rewriter, + Location loc, + SparseTensorEncodingAttr &enc) { + return constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()); +} + /// Generates a constant of the internal type encoding for primary storage. static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter, Location loc, Type tp) { @@ -277,10 +293,8 @@ static void newParams(ConversionPatternRewriter &rewriter, params.push_back(genBuffer(rewriter, loc, rev)); // Secondary and primary types encoding. ShapedType resType = op->getResult(0).getType().cast(); - params.push_back( - constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth())); - params.push_back( - constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth())); + params.push_back(constantPointerTypeEncoding(rewriter, loc, enc)); + params.push_back(constantIndexTypeEncoding(rewriter, loc, enc)); params.push_back( constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType())); // User action and pointer. @@ -598,10 +612,8 @@ class SparseTensorConvertConverter : public OpConversionPattern { encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src); Value coo = genNewCall(rewriter, op, params); - params[3] = constantOverheadTypeEncoding(rewriter, loc, - encDst.getPointerBitWidth()); - params[4] = constantOverheadTypeEncoding(rewriter, loc, - encDst.getIndexBitWidth()); + params[3] = constantPointerTypeEncoding(rewriter, loc, encDst); + params[4] = constantIndexTypeEncoding(rewriter, loc, encDst); params[6] = constantAction(rewriter, loc, Action::kFromCOO); params[7] = coo; rewriter.replaceOp(op, genNewCall(rewriter, op, params)); -- 2.7.4