[mlir][sparse] add some APIs for merger to query the tensor id for output tensor...
authorPeiming Liu <peiming@google.com>
Mon, 24 Oct 2022 18:41:19 +0000 (18:41 +0000)
committerPeiming Liu <peiming@google.com>
Mon, 24 Oct 2022 18:49:53 +0000 (18:49 +0000)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D136630

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

index 52456a2b06d5bd6aaa3117781fcbc501e44ef67b..7d1b770ca55f71c5d1cf109617f5594d9feb32b3 100644 (file)
@@ -221,8 +221,9 @@ public:
   /// Returns true if Li and Lj only differ in dense.
   bool onlyDenseDiff(unsigned i, unsigned j);
 
-  /// Bit translation.
+  /// Bit translation (get tensor ID).
   unsigned tensor(unsigned b) const { return b % numTensors; }
+  /// Bit translation (get loop index).
   unsigned index(unsigned b) const { return b / numTensors; }
 
   /// Returns true if bit corresponds to index of output tensor.
@@ -230,6 +231,12 @@ public:
     return tensor(b) == outTensor && index(b) == i;
   }
 
+  /// Gets tensor ID for the output tensor.
+  unsigned getOutTensorID() const { return outTensor; }
+  /// Gets tensor ID for the synthetic tensor (used for all invariant tensor
+  /// expressions).
+  unsigned getSynTensorID() const { return syntheticTensor; }
+
   /// Returns true if given tensor iterates *only* in the given tensor
   /// expression. For the output tensor, this defines a "simply dynamic"
   /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for
index 0e0882fa9fc1857eb2bbd35000ebe83f59637631..0ce2740485956ccb2dfaa606b762b1f8b45dee4e 100644 (file)
@@ -309,6 +309,8 @@ protected:
   const unsigned l0 = 0;
 
   MergerTest3T1L() : MergerTestBase(3, 1) {
+    EXPECT_TRUE(merger.getOutTensorID() == t2);
+
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
     merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
@@ -332,6 +334,8 @@ protected:
   const unsigned l0 = 0;
 
   MergerTest4T1L() : MergerTestBase(4, 1) {
+    EXPECT_TRUE(merger.getOutTensorID() == t3);
+
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
     merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
@@ -363,6 +367,8 @@ protected:
   const unsigned l0 = 0;
 
   MergerTest3T1LD() : MergerTestBase(3, 1) {
+    EXPECT_TRUE(merger.getOutTensorID() == t2);
+
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
     merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
@@ -383,13 +389,15 @@ protected:
 
 class MergerTest4T1LU : public MergerTestBase {
 protected:
-  // Our three tensors (two inputs, one output).
+  // Our three tensors (three inputs, one output).
   const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
 
   // Our single loop.
   const unsigned l0 = 0;
 
   MergerTest4T1LU() : MergerTestBase(4, 1) {
+    EXPECT_TRUE(merger.getOutTensorID() == t3);
+
     // Tensor 0: undef input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
     merger.setDimLevelType(t0, l0, DimLevelType::Undef);
@@ -421,6 +429,9 @@ protected:
   const unsigned l0 = 0;
 
   MergerTest3T1L_SO() : MergerTestBase(3, 1) {
+    EXPECT_TRUE(merger.getOutTensorID() == t2);
+    EXPECT_TRUE(merger.getSynTensorID() == t3);
+
     merger.setHasSparseOut(true);
 
     // Tensor 0: undef input vector.