[mlir][sparse] Updating `Merger::foreachTensorLoopId` to take `LatPointId`
authorwren romano <2998727+wrengr@users.noreply.github.com>
Tue, 14 Mar 2023 20:00:29 +0000 (13:00 -0700)
committerwren romano <2998727+wrengr@users.noreply.github.com>
Wed, 15 Mar 2023 19:27:47 +0000 (12:27 -0700)
Since all callsites of `foreachTensorLoopId` would simply look up the `LatPointId` to extract its `BitVector`, it's cleaner to let the `Merger` handle that instead.  This seems to better capture the intent of the `foreachTensorLoopId` method, and improves decoupling (since it removes a place that leaks the implementation detail that we use `BitVector`).

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D146082

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

index 3c5d2d3..59c5b78 100644 (file)
@@ -437,11 +437,11 @@ public:
   /// for each `TensorLoopId` and passing it the corresponding tensor
   /// identifier, level, and level-type.
   void
-  foreachTensorLoopId(const BitVector &bits,
+  foreachTensorLoopId(LatPointId p,
                       function_ref<void(TensorLoopId, TensorId,
                                         std::optional<Level>, DimLevelType)>
                           callback) const {
-    for (const TensorLoopId b : bits.set_bits())
+    for (const TensorLoopId b : latPoints[p].bits.set_bits())
       callback(b, tensor(b), getLvl(b), getDimLevelType(b));
   }
 
index 9fedd5a..2779e2d 100644 (file)
@@ -1273,18 +1273,18 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
 
   SmallVector<TensorId> tids;
   SmallVector<Level> lvls;
-  env.merger().foreachTensorLoopId(
-      env.lat(l0).bits, [&](TensorLoopId b, TensorId tid,
-                            std::optional<Level> lvl, DimLevelType dlt) {
-        assert(env.merger().loop(b) == idx);
-        if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
-          needsUniv = true;
-        } else {
-          // sparse/singleton levels.
-          tids.push_back(tid);
-          lvls.push_back(*lvl);
-        }
-      });
+  env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
+                                           std::optional<Level> lvl,
+                                           DimLevelType dlt) {
+    assert(env.merger().loop(b) == idx);
+    if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
+      needsUniv = true;
+    } else {
+      // sparse/singleton levels.
+      tids.push_back(tid);
+      lvls.push_back(*lvl);
+    }
+  });
 
   env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tids, lvls);
 
@@ -1342,7 +1342,6 @@ static bool translateBitsToTidLvlPairs(
     CodegenEnv &env, LatPointId li, LoopId ldx, SmallVectorImpl<TensorId> &tids,
     SmallVectorImpl<Level> &lvls, SmallVectorImpl<TensorId> &affineTids,
     SmallVectorImpl<Level> &affineLvls, SmallVectorImpl<AffineExpr> &exps) {
-  const BitVector &all = env.lat(li).bits;
   const BitVector &simple = env.lat(li).simple;
   const TensorId outTid = env.merger().getOutTensorID();
   const std::optional<Level> outLvl = env.merger().getLvl(outTid, ldx);
@@ -1350,8 +1349,8 @@ static bool translateBitsToTidLvlPairs(
   unsigned numloopCond = 0;
   bool hasNonUnique = false;
   env.merger().foreachTensorLoopId(
-      all, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
-                    DimLevelType dlt) {
+      li, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
+                   DimLevelType dlt) {
         if (simple.test(b)) {
           if (isUndefDLT(dlt)) {
             // An undefined dlt in the lattices, we probably mean to