From f2696e469a5ca1fa3efeebef56e77507e73b5047 Mon Sep 17 00:00:00 2001
From: wren romano <2998727+wrengr@users.noreply.github.com>
Date: Wed, 29 Mar 2023 18:36:24 -0700
Subject: [PATCH] [mlir][sparse] Cleaning up some usage of SparseTensorType

This is a followup to D147192.

Reviewed By: aartbik, Peiming

Differential Revision: https://reviews.llvm.org/D147196
---
 .../Transforms/SparseTensorRewriting.cpp     | 28 ++++++++++------------
 1 file changed, 12 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index dc5755b..52281bf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -356,16 +356,10 @@ public:
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value srcTensor = op.getSrc();
-    auto srcTp = getRankedTensorType(srcTensor);
-    auto dstTp = getRankedTensorType(op.getResult());
-
-    SparseTensorType srcStt(srcTp);
-    SparseTensorType dstStt(dstTp);
-
-    const auto encSrc = srcStt.getEncoding();
-    if (!srcStt.hasEncoding() || !dstStt.hasEncoding()) {
+    const auto srcTp = getSparseTensorType(srcTensor);
+    const auto dstTp = getSparseTensorType(op.getResult());
+    if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
       return failure();
-    }
 
     // Generate code to represent the static dimension constants or compute
     // the dynamic dimension values.
@@ -373,11 +367,11 @@ public:
     sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
     SmallVector<Value> dstSizes;
     SmallVector<Value> dstDynSizes;
-    if (dstTp.hasStaticShape()) {
-      for (auto d : dstTp.getShape())
+    if (dstTp.hasStaticDimShape()) {
+      for (Dimension d : dstTp.getDimShape())
         dstSizes.push_back(constantIndex(rewriter, loc, d));
     } else {
-      ArrayRef dstShape = dstTp.getShape();
+      ArrayRef dstShape = dstTp.getDimShape();
       genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
                          op.getReassociationIndices());
       for (auto [idx, shape] : llvm::enumerate(dstShape)) {
@@ -389,8 +383,8 @@ public:
     // Only need a unordered COO buffer if input and output are not sorted
     // in the same way.
     Type bufferTp =
-        srcStt.isAllOrdered() && srcStt.isIdentity() && dstStt.isIdentity()
-            ? dstTp
+        srcTp.isAllOrdered() && srcTp.isIdentity() && dstTp.isIdentity()
+            ? dstTp.getRankedTensorType()
             : getUnorderedCOOFromType(dstTp);
 
     Value buffer =
@@ -406,11 +400,12 @@ public:
     //   followed by an optional
     //     %t = sparse_tensor.cast %tmp
     //   depending on whether the input/output are sorted in the same way.
+    const auto encSrc = srcTp.getEncoding();
     ForeachOp foreachOp = rewriter.create<ForeachOp>(
         loc, srcTensor, buffer,
         [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
             ValueRange reduc) {
-          const Dimension dimRank = srcTp.getRank();
+          const Dimension dimRank = srcTp.getDimRank();
          SmallVector<Value> srcDcvs;
           srcDcvs.reserve(dimRank);
           for (Dimension d = 0; d < dimRank; d++) {
@@ -427,7 +422,8 @@ public:
 
     Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
     if (bufferTp != dstTp) {
-      Value converted = rewriter.create<ConvertOp>(loc, dstTp, t).getResult();
+      auto dstRTT = dstTp.getRankedTensorType();
+      Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
       rewriter.create<DeallocTensorOp>(loc, t);
       t = converted;
     }
-- 
2.7.4
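
Note (not part of the patch): the sketch below only illustrates the usage pattern
this cleanup moves toward, i.e. wrapping a value's type once via getSparseTensorType()
and querying dimension-level properties through the SparseTensorType wrapper,
unwrapping to RankedTensorType only where a builtin tensor type is needed. The helper
name inspectReshapeTypes and the include paths are assumptions for illustration, not
code from the LLVM tree; the queried methods are the ones the diff itself uses.

    #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
    #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

    using namespace mlir;
    using namespace mlir::sparse_tensor;

    // Hypothetical helper: mirrors the early-exit and type queries used by the
    // rewritten reshape pattern in SparseTensorRewriting.cpp.
    static LogicalResult inspectReshapeTypes(Value src, Value dst) {
      // One wrapper per value; no separate RankedTensorType lookup needed.
      const auto srcTp = getSparseTensorType(src);
      const auto dstTp = getSparseTensorType(dst);
      if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
        return failure();
      // Dimension-level queries go through the wrapper directly.
      const Dimension dimRank = srcTp.getDimRank();
      const bool staticShape = dstTp.hasStaticDimShape();
      (void)dimRank;
      (void)staticShape;
      // Unwrap only where an op or comparison needs the builtin tensor type.
      RankedTensorType dstRTT = dstTp.getRankedTensorType();
      (void)dstRTT;
      return success();
    }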