From 938f419cf1910b79388896c9694e58efc5325cba Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Fri, 9 Sep 2022 17:29:16 +0000 Subject: [PATCH] [mlir][sparse] Avoid generating DimOp in conversion passes. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D133592 --- .../Transforms/SparseTensorCodegen.cpp | 60 ++++++++++++++-------- .../Transforms/SparseTensorConversion.cpp | 8 +-- 2 files changed, 42 insertions(+), 26 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 9ad37bf..4ac2d17 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -76,6 +76,31 @@ static void flattenOperands(ValueRange operands, } } +/// Gets the dimension size for the given sparse tensor at the given dim. +/// Returns None if no sparse encoding is attached to the tensor type. +static Optional sizeFromTensorAtDim(OpBuilder &rewriter, Location loc, + ShapedType tensorTp, + Value adaptedValue, unsigned dim) { + auto enc = getSparseTensorEncoding(tensorTp); + if (!enc) + return llvm::None; + + // Access into static dimension can query original type directly. + // Note that this is typically already done by DimOp's folding. + auto shape = tensorTp.getShape(); + if (!ShapedType::isDynamic(shape[dim])) + return constantIndex(rewriter, loc, shape[dim]); + + // Any other query can consult the dimSizes array at field 0 using, + // accounting for the reordering applied to the sparse storage. + auto tuple = + llvm::cast(adaptedValue.getDefiningOp()); + return rewriter + .create(loc, tuple.getInputs().front(), + constantIndex(rewriter, loc, toStored(enc, dim))) + .getResult(); +} + /// Maps a sparse tensor type to the appropriate compounded buffers. static Optional convertSparseTensorType(Type type, SmallVectorImpl &fields) { @@ -344,28 +369,17 @@ public: LogicalResult matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Only rewrite annotated DimOp with constant index. - auto enc = getSparseTensorEncoding(op.getSource().getType()); - if (!enc) - return failure(); Optional index = op.getConstantIndex(); if (!index) return failure(); - // Access into static dimension can query original type directly. - // Note that this is typically already done by DimOp's folding. - Location loc = op->getLoc(); - auto shape = op.getSource().getType().cast().getShape(); - if (!ShapedType::isDynamic(shape[*index])) { - rewriter.replaceOp(op, constantIndex(rewriter, loc, shape[*index])); - return success(); - } - // Any other query can consult the dimSizes array at field 0 using, - // accounting for the reordering applied to the sparse storage. - auto tuple = llvm::cast( - adaptor.getSource().getDefiningOp()); - rewriter.replaceOpWithNewOp( - op, tuple.getInputs().front(), - constantIndex(rewriter, loc, toStored(enc, *index))); + auto sz = + sizeFromTensorAtDim(rewriter, op.getLoc(), + op.getSource().getType().cast(), + adaptor.getSource(), *index); + if (!sz) + return failure(); + + rewriter.replaceOp(op, *sz); return success(); } }; @@ -496,11 +510,13 @@ public: unsigned innerDim = srcType.getRank() - 1; if (AffineMap p = enc.getDimOrdering()) innerDim = p.getDimPosition(innerDim); - Value sz = rewriter.create(loc, op.getTensor(), innerDim); + auto sz = sizeFromTensorAtDim(rewriter, loc, srcType, adaptor.getTensor(), + innerDim); + assert(sz); // This for sure is a sparse tensor // Generate a memref for `sz` elements of type `t`. auto genAlloc = [&](Type t) { auto memTp = MemRefType::get({ShapedType::kDynamicSize}, t); - return rewriter.create(loc, memTp, ValueRange{sz}); + return rewriter.create(loc, memTp, ValueRange{*sz}); }; // Allocate temporary buffers for values, filled-switch, and indices. // We do not use stack buffers for this, since the expanded size may @@ -590,5 +606,5 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter, SparseTensorAllocConverter, SparseTensorDeallocConverter, SparseToPointersConverter, SparseToIndicesConverter, SparseToValuesConverter, SparseTensorLoadConverter>( - typeConverter, patterns.getContext()); + typeConverter, patterns.getContext()); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index f2967c3..d6fe145 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -564,7 +564,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor, encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); SmallVector sizes; SmallVector params; - sizesFromSrc(rewriter, sizes, loc, op.getSrc()); + sizesFromPtr(rewriter, sizes, loc, encSrc, srcTp, adaptor.getSrc()); newParams(rewriter, params, loc, srcTp, noPerm, Action::kToIterator, sizes, adaptor.getSrc()); Value iter = genNewCall(rewriter, loc, params); @@ -1168,13 +1168,13 @@ public: // All initialization should be done on entry of the loop nest. rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp()); // Determine the size for access expansion (always the innermost stored - // dimension size, translated back to original dimension). Note that we - // recursively rewrite the new DimOp on the **original** tensor. + // dimension size, translated back to original dimension). auto enc = getSparseTensorEncoding(srcType); unsigned innerDim = srcType.getRank() - 1; if (AffineMap p = enc.getDimOrdering()) innerDim = p.getDimPosition(innerDim); - Value sz = rewriter.create(loc, op.getTensor(), innerDim); + auto sz = sizeFromPtrAtDim(rewriter, loc, enc, srcType, adaptor.getTensor(), + innerDim); // Allocate temporary buffers for values, filled-switch, and indices. // We do not use stack buffers for this, since the expanded size may // be rather large (as it envelops a single expanded dense dimension). -- 2.7.4