From 96fef4dc3c9313d476fafe96d4c380988cf0ecda Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Thu, 15 Dec 2022 14:32:58 -0800 Subject: [PATCH] [mlir][sparse] Added new SparseTensorEncodingAttr::withoutOrdering factory Reviewed By: aartbik, Peiming Differential Revision: https://reviews.llvm.org/D140171 --- .../mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td | 4 ++++ mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 6 ++++++ .../SparseTensor/Transforms/SparseTensorConversion.cpp | 12 +++--------- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td index 47993b0..70c74bf 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -169,6 +169,10 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding", /// Returns the type for index storage based on indexBitWidth Type getIndexType() const; + + /// Constructs a new encoding with the dimOrdering and higherOrdering + /// reset to the default/identity. + SparseTensorEncodingAttr withoutOrdering() const; }]; let genVerifyDecl = 1; diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 06a451f..aecde7d 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -57,6 +57,12 @@ Type SparseTensorEncodingAttr::getIndexType() const { return idxWidth ? IntegerType::get(getContext(), idxWidth) : indexType; } +SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const { + return SparseTensorEncodingAttr::get( + getContext(), getDimLevelType(), AffineMap(), AffineMap(), + getPointerBitWidth(), getIndexBitWidth()); +} + Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { if (failed(parser.parseLess())) return {}; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 4f4dbe4..b4d1491 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -529,9 +529,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor, assert(elemTp == dstTp.getElementType() && "reshape should not change element type"); // Start an iterator over the source tensor (in original index order). - auto noPerm = SparseTensorEncodingAttr::get( - op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(), - encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); + const auto noPerm = encSrc.withoutOrdering(); SmallVector srcDimSizes = getDimSizes(rewriter, loc, encSrc, srcTp, adaptor.getSrc()); NewCallParams params(rewriter, loc); @@ -596,9 +594,7 @@ static void genSparseCOOIterationLoop( Type elemTp = tensorTp.getElementType(); // Start an iterator over the tensor (in original index order). - auto noPerm = SparseTensorEncodingAttr::get( - rewriter.getContext(), enc.getDimLevelType(), AffineMap(), AffineMap(), - enc.getPointerBitWidth(), enc.getIndexBitWidth()); + const auto noPerm = enc.withoutOrdering(); SmallVector dimSizes = getDimSizes(rewriter, loc, noPerm, tensorTp, t); Value iter = NewCallParams(rewriter, loc) .genBuffers(noPerm, dimSizes, tensorTp) @@ -1485,9 +1481,7 @@ public: auto encSrc = getSparseTensorEncoding(srcType); SmallVector dimSizes = getDimSizes(rewriter, loc, encSrc, srcType, src); - auto enc = SparseTensorEncodingAttr::get( - op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(), - encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); + const auto enc = encSrc.withoutOrdering(); Value coo = NewCallParams(rewriter, loc) .genBuffers(enc, dimSizes, srcType) .genNewCall(Action::kToCOO, src); -- 2.7.4