[mlir][sparse] Make getNumTensors() consistent between LoopEmitter and Merger.
authorPeiming Liu <peiming@google.com>
Mon, 5 Jun 2023 17:26:40 +0000 (17:26 +0000)
committerPeiming Liu <peiming@google.com>
Mon, 5 Jun 2023 17:56:08 +0000 (17:56 +0000)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D152178

mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h

index 3186889..dea9e74 100644 (file)
@@ -89,9 +89,11 @@ public:
 
   TensorLevel makeTensorLevel(TensorId t, Level l) const {
     // Make sure LoopEmitter, GenericOp, and Merger agree on the number of
-    // tensors. Merger has one more synthetic tensor for loop invariants.
-    assert(loopEmitter.getNumTensors() == linalgOp->getNumOperands() &&
-           loopEmitter.getNumTensors() == latticeMerger.getNumTensors() - 1);
+    // tensors.
+    assert(loopEmitter.getNumManifestTensors() == linalgOp->getNumOperands() &&
+           loopEmitter.getNumTensors() == latticeMerger.getNumTensors() &&
+           loopEmitter.getOutTensorId() == latticeMerger.getOutTensorID() &&
+           loopEmitter.getSynTensorId() == latticeMerger.getSynTensorID());
     return loopEmitter.makeTensorLevel(t, l);
   }
   std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tl) const {
index 67f3c30..6a639ef 100644 (file)
@@ -235,8 +235,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
   const unsigned numManifestTensors = ts.size();
   const unsigned synTensorId = numManifestTensors;
   const unsigned numTensors = numManifestTensors + 1;
-
+  // tensors array (len == numManifestTensor).
   this->tensors.assign(ts.begin(), ts.end());
+  // Arrays with len == numTensor.
   this->lvlTypes.assign(numTensors, std::vector<DimLevelType>());
   this->lvlSizes.assign(numTensors, std::vector<Value>());
   this->highs.assign(numTensors, std::vector<Value>());
@@ -355,13 +356,14 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
 
 void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
                                      LoopEmitter::OutputUpdater updater) {
-  // For every tensor:
+  // For every manifest tensor:
   // * get the values buffer.
   // * For every level:
   //   * get the positions and coordinates buffers
   //   * get/compute the level-size, which is also used as the upper-bound
   //     on positions.
-  for (TensorId t = 0, numTensors = getNumTensors(); t < numTensors; t++) {
+  for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
+       t++) {
     const Value tensor = tensors[t];
     const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
     if (!rtp)
index ded58f2..8fa7912 100644 (file)
@@ -191,21 +191,33 @@ public:
     return n < getCurrentDepth() ? loopStack[n].iv : Value();
   }
 
+  /// Gets the total number of manifest tensors (excluding the synthetic
+  /// tensor).
+  unsigned getNumManifestTensors() const { return tensors.size(); }
+
   /// Gets the total number of tensors that loopEmitter is operating on.
-  unsigned getNumTensors() const { return tensors.size(); }
+  unsigned getNumTensors() const {
+    // Manifest tensors with one synthetic tensor at the end.
+    return getNumManifestTensors() + 1;
+  }
 
   /// Gets the TensorId for synthetic tensor.
   TensorId getSynTensorId() const { return tensors.size(); }
 
+  /// Gets the TensorId for output tensor.
+  TensorId getOutTensorId() const {
+    assert(hasOutput);
+    return getNumManifestTensors() - 1;
+  }
+
   /// Compresses a TensorId and Level into a TensorLevel.
   TensorLevel makeTensorLevel(TensorId t, Level l) const {
-    // TODO: getNumTensor() should include synthetic tensor.
-    return l * (getNumTensors() + 1) + t;
+    return l * getNumTensors() + t;
   }
 
   /// De-compresses a TensorLevel back to a pair of TensorId and Level.
   std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tidLvl) const {
-    unsigned nt = getNumTensors() + 1;
+    unsigned nt = getNumTensors();
     return std::make_pair(tidLvl % nt, tidLvl / nt);
   }
 
@@ -323,10 +335,10 @@ private:
                                                  Location loc, Value crd,
                                                  TensorId tid, Level lvl);
 
-  bool isSynTensor(TensorId tid) const { return tid == getNumTensors(); }
+  bool isSynTensor(TensorId tid) const { return tid == getSynTensorId(); }
 
   bool isOutputTensor(TensorId tid) const {
-    return hasOutput && tid == getNumTensors() - 1;
+    return hasOutput && tid == getOutTensorId();
   }
 
   bool isSparseOutput(TensorId tid) const {
@@ -414,8 +426,8 @@ private:
   /// TODO: why not do this computation when we first store the reassoc,
   /// instead of doing it every time we look it up?
   SmallVector<Level, 2> getCollapseReassociation(TensorId tid, Level dstLvl) {
-    assert(tid < getNumTensors() + 1 && "Invalid TensorId");
-    assert(collapseReassoc.size() == getNumTensors() + 1);
+    assert(tid < getNumTensors() && "Invalid TensorId");
+    assert(collapseReassoc.size() == getNumTensors());
     if (const auto reassoc = collapseReassoc[tid]) {
       assert(!isSynTensor(tid) && !isOutputTensor(tid) &&
              "Output/Synthetic tensor should not have reassociation");