[mlir][sparse] add print methods to Merger (for debugging)
authorAart Bik <ajcbik@google.com>
Fri, 25 Jun 2021 19:02:08 +0000 (12:02 -0700)
committerAart Bik <ajcbik@google.com>
Fri, 25 Jun 2021 22:10:06 +0000 (15:10 -0700)
Reviewed By: gussmith23

Differential Revision: https://reviews.llvm.org/D104939

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index 0ffd001..9457276 100644 (file)
 namespace mlir {
 namespace sparse_tensor {
 
+/// Tensor expression kind.
 enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
+
+/// Dimension level type for a tensor (undef means index does not appear).
 enum class Dim { kSparse, kDense, kSingle, kUndef };
 
 /// Tensor expression. Represents a MLIR expression in tensor index notation.
@@ -33,10 +36,14 @@ struct TensorExp {
            (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
            (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
   }
+
+  /// Tensor expression kind.
   Kind kind;
+
   /// Indices of children expression(s).
   unsigned e0;
   unsigned e1;
+
   /// Direct link to IR for an invariant. During code generation,
   /// field is used to cache "hoisted" loop invariant tensor loads.
   Value val;
@@ -50,13 +57,16 @@ struct LatPoint {
     bits.set(b);
   }
   LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
+
   /// Conjunction of tensor loop indices as bitvector. This represents
   /// all indices involved in the tensor expression
   llvm::BitVector bits;
+
   /// Simplified conjunction of tensor loop indices as bitvector. This
   /// represents a simplified condition under which this tensor expression
   /// must execute. Pre-computed during codegen to avoid repeated eval.
   llvm::BitVector simple;
+
   /// Index of the tensor expresssion.
   unsigned exp;
 };
@@ -146,6 +156,14 @@ public:
   LatPoint &lat(unsigned l) { return latPoints[l]; }
   SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
 
+#ifndef NDEBUG
+  /// Print methods (for debugging).
+  void dumpExp(unsigned e) const;
+  void dumpLat(unsigned p) const;
+  void dumpSet(unsigned s) const;
+  void dumpBits(const llvm::BitVector &bits) const;
+#endif
+
 private:
   const unsigned outTensor;
   const unsigned numTensors;
index 0d1d345..af864b7 100644 (file)
@@ -8,6 +8,9 @@
 
 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
 
+#include "mlir/IR/Operation.h"
+#include "llvm/Support/Debug.h"
+
 namespace mlir {
 namespace sparse_tensor {
 
@@ -134,5 +137,83 @@ bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
   return false;
 }
 
+#ifndef NDEBUG
+
+//
+// Print methods (for debugging).
+//
+
+void Merger::dumpExp(unsigned e) const {
+  switch (tensorExps[e].kind) {
+  case Kind::kTensor:
+    llvm::dbgs() << "tensor_" << tensorExps[e].e0;
+    break;
+  case Kind::kInvariant:
+    llvm::dbgs() << "invariant";
+    break;
+  default:
+  case Kind::kMulI:
+    llvm::dbgs() << "(";
+    dumpExp(tensorExps[e].e0);
+    llvm::dbgs() << " * ";
+    dumpExp(tensorExps[e].e1);
+    llvm::dbgs() << ")";
+    break;
+  case Kind::kAddF:
+  case Kind::kAddI:
+    llvm::dbgs() << "(";
+    dumpExp(tensorExps[e].e0);
+    llvm::dbgs() << " + ";
+    dumpExp(tensorExps[e].e1);
+    llvm::dbgs() << ")";
+    break;
+  }
+}
+
+void Merger::dumpLat(unsigned p) const {
+  llvm::dbgs() << "lat(";
+  dumpBits(latPoints[p].bits);
+  llvm::dbgs() << " :";
+  dumpBits(latPoints[p].simple);
+  llvm::dbgs() << " / ";
+  dumpExp(latPoints[p].exp);
+  llvm::dbgs() << " )\n";
+}
+
+void Merger::dumpSet(unsigned s) const {
+  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
+  for (unsigned p : latSets[s]) {
+    llvm::dbgs() << "  ";
+    dumpLat(p);
+  }
+  llvm::dbgs() << "}\n";
+}
+
+void Merger::dumpBits(const llvm::BitVector &bits) const {
+  for (unsigned b = 0, be = bits.size(); b < be; b++) {
+    if (bits[b]) {
+      unsigned t = tensor(b);
+      unsigned i = index(b);
+      llvm::dbgs() << " i_" << t << "_" << i << "_";
+      switch (dims[t][i]) {
+      case Dim::kSparse:
+        llvm::dbgs() << "S";
+        break;
+      case Dim::kDense:
+        llvm::dbgs() << "D";
+        break;
+      case Dim::kSingle:
+        llvm::dbgs() << "T";
+        break;
+      case Dim::kUndef:
+        llvm::dbgs() << "U";
+        break;
+      }
+    }
+  }
+}
+
+#endif // NDEBUG
+
 } // namespace sparse_tensor
 } // namespace mlir
index d35628c..486f8ad 100644 (file)
@@ -1497,7 +1497,9 @@ gentbl_cc_library(
 
 cc_library(
     name = "SparseTensor",
-    srcs = glob(["lib/Dialect/SparseTensor/IR/*.cpp"]),
+    srcs = glob([
+        "lib/Dialect/SparseTensor/IR/*.cpp",
+    ]),
     hdrs = ["include/mlir/Dialect/SparseTensor/IR/SparseTensor.h"],
     includes = ["include"],
     deps = [
@@ -1505,6 +1507,22 @@ cc_library(
         ":SideEffectInterfaces",
         ":SparseTensorAttrDefsIncGen",
         ":SparseTensorOpsIncGen",
+        ":SparseTensorUtils",
+        ":StandardOps",
+        "//llvm:Support",
+    ],
+)
+
+cc_library(
+    name = "SparseTensorUtils",
+    srcs = glob(["lib/Dialect/SparseTensor/Utils/*.cpp"]),
+    hdrs = glob(["include/mlir/Dialect/SparseTensor/Utils/*.h"]),
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":SideEffectInterfaces",
+        ":SparseTensorAttrDefsIncGen",
+        ":SparseTensorOpsIncGen",
         ":StandardOps",
         "//llvm:Support",
     ],
@@ -1535,17 +1553,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "SparseTensorUtils",
-    srcs = glob(["lib/Dialect/SparseTensor/Utils/*.cpp"]),
-    hdrs = glob(["include/mlir/Dialect/SparseTensor/Utils/*.h"]),
-    includes = ["include"],
-    deps = [
-        ":IR",
-        "//llvm:Support",
-    ],
-)
-
 td_library(
     name = "StdOpsTdFiles",
     srcs = [