[mlir][sparse] Adding new `Merger::addLat` overload
authorwren romano <2998727+wrengr@users.noreply.github.com>
Tue, 21 Mar 2023 20:13:42 +0000 (13:13 -0700)
committerwren romano <2998727+wrengr@users.noreply.github.com>
Tue, 21 Mar 2023 23:22:04 +0000 (16:22 -0700)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D146559

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

index 6e39404..991c920 100644 (file)
@@ -280,6 +280,7 @@ public:
 
   /// Constructs a new iteration lattice point, and returns its identifier.
   LatPointId addLat(TensorId t, LoopId i, ExprId e);
+  LatPointId addLat(const BitVector &bits, ExprId e);
 
   /// Constructs a new (initially empty) set, and returns its identifier.
   LatSetId addSet();
index 4a8c3cb..0691d25 100644 (file)
@@ -247,6 +247,13 @@ LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
   return p;
 }
 
+LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
+  assert(bits.size() == numLoops * numTensors);
+  const LatPointId p = latPoints.size();
+  latPoints.emplace_back(bits, e);
+  return p;
+}
+
 LatSetId Merger::addSet() {
   const LatSetId s = latSets.size();
   latSets.emplace_back();
@@ -322,8 +329,7 @@ LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
   const LatSetId s = addSet();
   for (const LatPointId p : latSets[s0]) {
     const ExprId e = addExp(kind, latPoints[p].exp, v, op);
-    latPoints.emplace_back(latPoints[p].bits, e);
-    latSets[s].push_back(latPoints.size() - 1);
+    latSets[s].push_back(addLat(latPoints[p].bits, e));
   }
   return s;
 }