[mlir][sparse] minor merger API simplification
authorAart Bik <ajcbik@google.com>
Wed, 14 Sep 2022 00:10:42 +0000 (17:10 -0700)
committerAart Bik <ajcbik@google.com>
Wed, 14 Sep 2022 01:07:24 +0000 (18:07 -0700)
Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D133821

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

index 584df84..f9be476 100644 (file)
@@ -261,8 +261,8 @@ public:
     return getDimLevelFormat(t, i).levelType == tp;
   }
 
-  /// Returns true if any set bit corresponds to given dimension level type.
-  bool hasAnyDimLevelTypeOf(const BitVector &bits, DimLvlType tp) const;
+  /// Returns true if any set bit corresponds to sparse dimension level type.
+  bool hasAnySparse(const BitVector &bits) const;
 
   /// Dimension level format getter.
   DimLevelFormat getDimLevelFormat(unsigned t, unsigned i) const {
index 642c666..918018a 100644 (file)
@@ -1641,8 +1641,7 @@ static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
     unsigned lsize = merger.set(lts).size();
     for (unsigned i = 1; i < lsize; i++) {
       unsigned li = merger.set(lts)[i];
-      if (!merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kCompressed) &&
-          !merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kSingleton))
+      if (!merger.hasAnySparse(merger.lat(li).simple))
         return true;
     }
   }
index cd5d134..bd8c8f2 100644 (file)
@@ -262,13 +262,8 @@ BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
     }
   }
   // Now apply the two basic rules.
-  //
-  // TODO: improve for singleton and properties
-  //
   BitVector simple = latPoints[p0].bits;
-  bool reset = isSingleton &&
-      (hasAnyDimLevelTypeOf(simple, DimLvlType::kCompressed) ||
-       hasAnyDimLevelTypeOf(simple, DimLvlType::kSingleton));
+  bool reset = isSingleton && hasAnySparse(simple);
   for (unsigned b = 0, be = simple.size(); b < be; b++) {
     if (simple[b] &&
         (!isDimLevelType(b, DimLvlType::kCompressed) &&
@@ -297,8 +292,7 @@ bool Merger::latGT(unsigned i, unsigned j) const {
 bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
   BitVector tmp = latPoints[j].bits;
   tmp ^= latPoints[i].bits;
-  return !hasAnyDimLevelTypeOf(tmp, DimLvlType::kCompressed) &&
-         !hasAnyDimLevelTypeOf(tmp, DimLvlType::kSingleton);
+  return !hasAnySparse(tmp);
 }
 
 bool Merger::isSingleCondition(unsigned t, unsigned e) const {
@@ -384,9 +378,10 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
   llvm_unreachable("unexpected kind");
 }
 
-bool Merger::hasAnyDimLevelTypeOf(const BitVector &bits, DimLvlType tp) const {
+bool Merger::hasAnySparse(const BitVector &bits) const {
   for (unsigned b = 0, be = bits.size(); b < be; b++)
-    if (bits[b] && isDimLevelType(b, tp))
+    if (bits[b] && (isDimLevelType(b, DimLvlType::kCompressed) ||
+                    isDimLevelType(b, DimLvlType::kSingleton)))
       return true;
   return false;
 }