From d950bdc73eb23a79cd4cf35fd4c8cb198e00b2d0 Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Wed, 15 Feb 2023 13:28:11 -0800 Subject: [PATCH] [mlir][sparse] misc code cleanup * Flattening/simplifying some nested conditionals * const-ifying some local variables Depends On D143800 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D143949 --- .../Transforms/SparseTensorConversion.cpp | 22 +++++++-------- .../Transforms/SparseTensorRewriting.cpp | 33 ++++++++++------------ 2 files changed, 25 insertions(+), 30 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index cfd7bca..7622554 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -1414,20 +1414,18 @@ public: createOrFoldDimCall(rewriter, loc, srcTp, adaptedOp, concatDim); offset = rewriter.create(loc, offset, curDim); } - if (dstTp.hasEncoding()) { - if (!allDense) { - // In sparse output case, the destination holds the COO. - Value coo = dst; - dst = params.genNewCall(Action::kFromCOO, coo); - // Release resources. - genDelCOOCall(rewriter, loc, elemTp, coo); - } else { - dst = dstTensor; - } - rewriter.replaceOp(op, dst); - } else { + if (!dstTp.hasEncoding()) { rewriter.replaceOpWithNewOp( op, dstTp.getRankedTensorType(), dst); + } else if (allDense) { + rewriter.replaceOp(op, dstTensor); + } else { + // In sparse output case, the destination holds the COO. + Value coo = dst; + dst = params.genNewCall(Action::kFromCOO, coo); + // Release resources. + genDelCOOCall(rewriter, loc, elemTp, coo); + rewriter.replaceOp(op, dst); } return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 7046306..3506494 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -513,13 +513,11 @@ struct ConcatenateRewriter : public OpRewritePattern { } needTmpCOO = !allDense && !allOrdered; + const RankedTensorType tp = needTmpCOO ? getUnorderedCOOFromType(dstTp) + : dstTp.getRankedTensorType(); + encDst = needTmpCOO ? getSparseTensorEncoding(tp) : encDst; SmallVector dynSizes; getDynamicSizes(dstTp, sizes, dynSizes); - RankedTensorType tp = dstTp; - if (needTmpCOO) { - tp = getUnorderedCOOFromType(dstTp); - encDst = getSparseTensorEncoding(tp); - } dst = rewriter.create(loc, tp, dynSizes).getResult(); if (allDense) { // Create a view of the values buffer to match the unannotated dense @@ -592,21 +590,20 @@ struct ConcatenateRewriter : public OpRewritePattern { // Temp variable to avoid needing to call `getRankedTensorType` // in the three use-sites below. const RankedTensorType dstRTT = dstTp; - if (encDst) { - if (!allDense) { - dst = rewriter.create(loc, dst, true); - if (needTmpCOO) { - Value tmpCoo = dst; - dst = rewriter.create(loc, dstRTT, tmpCoo).getResult(); - rewriter.create(loc, tmpCoo); - } - } else { - dst = rewriter.create(loc, dstRTT, annotatedDenseDst) - .getResult(); + if (!encDst) { + rewriter.replaceOpWithNewOp(op, dstRTT, dst); + } else if (allDense) { + rewriter.replaceOp( + op, rewriter.create(loc, dstRTT, annotatedDenseDst) + .getResult()); + } else { + dst = rewriter.create(loc, dst, true); + if (needTmpCOO) { + Value tmpCoo = dst; + dst = rewriter.create(loc, dstRTT, tmpCoo).getResult(); + rewriter.create(loc, tmpCoo); } rewriter.replaceOp(op, dst); - } else { - rewriter.replaceOpWithNewOp(op, dstRTT, dst); } return success(); } -- 2.7.4