[mlir][sparse] minor merger code cleanup
authorAart Bik <ajcbik@google.com>
Mon, 19 Dec 2022 20:55:22 +0000 (12:55 -0800)
committerAart Bik <ajcbik@google.com>
Mon, 19 Dec 2022 21:31:25 +0000 (13:31 -0800)
Moved larger constructor from header to CPP file.
Used toMLIRString() for DimLvlType debug.
Minor layout changes.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D140342

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

index 1ff9f69..9b3b695 100644 (file)
@@ -150,33 +150,28 @@ class Merger {
 public:
   /// Constructs a merger for the given number of tensors, native loops, and
   /// filter loops. The user supplies the number of tensors involved in the
-  /// kernel, with the last tensor in this set denoting the output tensor. The
-  /// merger adds an additional synthetic tensor at the end of this set to
-  /// represent all invariant expressions in the kernel.
-  /// In addition to natives
-  /// loops (which are specified by the GenericOp), extra filter loops are
-  /// needed in order to handle affine expressions on sparse dimensions.
-  /// E.g., (d0, d1, d2) => (d0 + d1, d2), a naive implementation of the filter
-  /// loop could be generated as:
+  /// kernel, with the last tensor in this set denoting the output tensor.
+  /// The merger adds an additional synthetic tensor at the end of this set
+  /// to represent all invariant expressions in the kernel.
+  ///
+  /// In addition to natives loops (which are specified by the GenericOp),
+  /// extra filter loops are needed in order to handle affine expressions on
+  /// sparse dimensions. E.g., (d0, d1, d2) => (d0 + d1, d2), a naive
+  /// implementation of the filter loop could be generated as
+  ///
   /// for (coord : sparse_dim[0])
   ///   if (coord == d0 + d1) {
   ///      generated_code;
   ///   }
   /// }
-  /// to filter out coordinates that are not equal to the affine expression
-  /// result.
+  ///
+  /// to filter out coordinates that are not equal to the affine expression.
+  ///
   /// TODO: we want to make the filter loop more efficient in the future, e.g.,
   /// by avoiding scanning the full stored index sparse (keeping the last
   /// position in ordered list) or even apply binary search to find the index.
-  Merger(unsigned t, unsigned l, unsigned fl)
-      : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1),
-        numNativeLoops(l), numLoops(l + fl), hasSparseOut(false),
-        dimTypes(numTensors,
-                 std::vector<DimLevelType>(numLoops, DimLevelType::Undef)),
-        loopIdxToDim(numTensors,
-                     std::vector<Optional<unsigned>>(numLoops, std::nullopt)),
-        dimToLoopIdx(numTensors,
-                     std::vector<Optional<unsigned>>(numLoops, std::nullopt)) {}
+  ///
+  Merger(unsigned t, unsigned l, unsigned fl);
 
   /// Adds a tensor expression. Returns its index.
   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
@@ -386,14 +381,18 @@ private:
   const unsigned numNativeLoops;
   const unsigned numLoops;
   bool hasSparseOut;
+
   // Map that converts pair<tensor id, loop id> to the corresponding dimension
   // level type.
   std::vector<std::vector<DimLevelType>> dimTypes;
+
   // Map that converts pair<tensor id, loop id> to the corresponding
   // dimension.
   std::vector<std::vector<Optional<unsigned>>> loopIdxToDim;
+
   // Map that converts pair<tensor id, dim> to the corresponding loop id.
   std::vector<std::vector<Optional<unsigned>>> dimToLoopIdx;
+
   llvm::SmallVector<TensorExp> tensorExps;
   llvm::SmallVector<LatPoint> latPoints;
   llvm::SmallVector<SmallVector<unsigned>> latSets;
index b9f8d26..f82ff44 100644 (file)
@@ -205,6 +205,16 @@ LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
 
 LatPoint::LatPoint(const BitVector &b, unsigned e) : bits(b), exp(e) {}
 
+Merger::Merger(unsigned t, unsigned l, unsigned fl)
+    : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1),
+      numNativeLoops(l), numLoops(l + fl), hasSparseOut(false),
+      dimTypes(numTensors,
+               std::vector<DimLevelType>(numLoops, DimLevelType::Undef)),
+      loopIdxToDim(numTensors,
+                   std::vector<Optional<unsigned>>(numLoops, std::nullopt)),
+      dimToLoopIdx(numTensors,
+                   std::vector<Optional<unsigned>>(numLoops, std::nullopt)) {}
+
 //===----------------------------------------------------------------------===//
 // Lattice methods.
 //===----------------------------------------------------------------------===//
@@ -740,17 +750,7 @@ void Merger::dumpBits(const BitVector &bits) const {
       unsigned t = tensor(b);
       unsigned i = index(b);
       DimLevelType dlt = dimTypes[t][i];
-      llvm::dbgs() << " i_" << t << "_" << i << "_";
-      if (isDenseDLT(dlt))
-        llvm::dbgs() << "D";
-      else if (isCompressedDLT(dlt))
-        llvm::dbgs() << "C";
-      else if (isSingletonDLT(dlt))
-        llvm::dbgs() << "S";
-      else if (isUndefDLT(dlt))
-        llvm::dbgs() << "U";
-      llvm::dbgs() << "[O=" << isOrderedDLT(dlt) << ",U=" << isUniqueDLT(dlt)
-                   << "]";
+      llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(dlt);
     }
   }
 }