From bb4fc6b6d6b41da9985db0f9b294189e25da4a72 Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Wed, 15 Feb 2023 12:03:52 -0800 Subject: [PATCH] [mlir][sparse] Adding `SparseTensorType::{operator==, hasSameDimToLvlMap}` Depends On D143800 Reviewed By: aartbik, Peiming Differential Revision: https://reviews.llvm.org/D144052 --- .../Dialect/SparseTensor/IR/SparseTensorType.h | 23 ++++++++++++++++++++++ .../Transforms/SparseTensorRewriting.cpp | 13 ++++-------- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h index ba31fa7..4eeaa39 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h @@ -105,6 +105,16 @@ public: /// implicit conversion. RankedTensorType getRankedTensorType() const { return rtp; } + bool operator==(const SparseTensorType &other) const { + // All other fields are derived from `rtp` and therefore don't need + // to be checked. + return rtp == other.rtp; + } + + bool operator!=(const SparseTensorType &other) const { + return !(*this == other); + } + MLIRContext *getContext() const { return rtp.getContext(); } Type getElementType() const { return rtp.getElementType(); } @@ -130,6 +140,8 @@ public: bool isIdentity() const { return !dim2lvl; } /// Returns the dimToLvl mapping (or the null-map for the identity). + /// If you intend to compare the results of this method for equality, + /// see `hasSameDimToLvlMap` instead. AffineMap getDimToLvlMap() const { return dim2lvl; } /// Returns the dimToLvl mapping, where the identity map is expanded out @@ -142,6 +154,17 @@ public: : AffineMap::getMultiDimIdentityMap(getDimRank(), getContext()); } + /// Returns true iff the two types have the same mapping. This method + /// takes care to handle identity maps properly, so it should be preferred + /// over using `getDimToLvlMap` followed by `AffineMap::operator==`. + bool hasSameDimToLvlMap(const SparseTensorType &other) const { + // If the maps are the identity, then we need to check the rank + // to be sure they're the same size identity. (And since identity + // means dimRank==lvlRank, we use lvlRank as a minor optimization.) + return isIdentity() ? (other.isIdentity() && lvlRank == other.lvlRank) + : (dim2lvl == other.dim2lvl); + } + /// Returns the dimension-rank. Dimension getDimRank() const { return rtp.getRank(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 9ec1d7c..7046306 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -48,12 +48,6 @@ static bool isSparseTensor(OpOperand *op) { llvm::is_contained(enc.getDimLevelType(), DimLevelType::Compressed); } -static bool hasSameDimOrdering(RankedTensorType rtp1, RankedTensorType rtp2) { - assert(rtp1.getRank() == rtp2.getRank()); - return SparseTensorType(rtp1).getDimToLvlMap() == - SparseTensorType(rtp2).getDimToLvlMap(); -} - // Helper method to find zero/uninitialized allocation. static bool isAlloc(OpOperand *op, bool isZero) { Value val = op->get(); @@ -796,8 +790,9 @@ private: // 2. the src tensor is not ordered in the same way as the target // tensor (e.g., src tensor is not ordered or src tensor haves a different // dimOrdering). - if (!isUniqueCOOType(srcRTT) && !(SparseTensorType(srcRTT).isAllOrdered() && - hasSameDimOrdering(srcRTT, dstTp))) { + if (const SparseTensorType srcTp(srcRTT); + !isUniqueCOOType(srcRTT) && + !(srcTp.isAllOrdered() && srcTp.hasSameDimToLvlMap(dstTp))) { // Construct a COO tensor from the src tensor. // TODO: there may be cases for which more efficiently without // going through an intermediate COO, such as cases that only change @@ -841,7 +836,7 @@ private: // Sort the COO tensor so that its elements are ordered via increasing // indices for the storage ordering of the dst tensor. Use SortCoo if the // COO tensor has the same dim ordering as the dst tensor. - if (dimRank > 1 && hasSameDimOrdering(srcTp, dstTp)) { + if (dimRank > 1 && srcTp.hasSameDimToLvlMap(dstTp)) { MemRefType indTp = get1DMemRefType(getIndexOverheadType(rewriter, encSrc), /*withLayout=*/false); -- 2.7.4