[mlir][sparse] code refactoring, move <tid, loop id> -> dim map to Merger.
authorPeiming Liu <peiming@google.com>
Wed, 26 Oct 2022 19:07:25 +0000 (19:07 +0000)
committerPeiming Liu <peiming@google.com>
Thu, 27 Oct 2022 21:01:06 +0000 (21:01 +0000)
To address unresolved comments in D136185

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D136780

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

index 7d1b770..ebc33f3 100644 (file)
@@ -156,7 +156,8 @@ public:
   Merger(unsigned t, unsigned l)
       : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l),
         hasSparseOut(false),
-        dimTypes(t + 1, std::vector<DimLevelType>(l, DimLevelType::Undef)) {}
+        dimTypes(t + 1, std::vector<DimLevelType>(l, DimLevelType::Undef)),
+        loopIdxToDim(t + 1, std::vector<Optional<unsigned>>(l, llvm::None)) {}
 
   /// Adds a tensor expression. Returns its index.
   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
@@ -246,7 +247,7 @@ public:
   /// Returns true if any set bit corresponds to sparse dimension level type.
   bool hasAnySparse(const BitVector &bits) const;
 
-  /// Gets the dimension level type of the `i`th loop of the `t`th tensor.
+  /// Gets the dimension level type of the `t`th tensor on `i`th loop.
   DimLevelType getDimLevelType(unsigned t, unsigned i) const {
     assert(t < numTensors && i < numLoops);
     return dimTypes[t][i];
@@ -257,10 +258,35 @@ public:
     return getDimLevelType(tensor(b), index(b));
   }
 
-  /// Sets the dimension level type of the `i`th loop of the `t`th tensor.
-  void setDimLevelType(unsigned t, unsigned i, DimLevelType d) {
-    assert(isValidDLT(d));
-    dimTypes[t][i] = d;
+  /// Gets the dimension number of the the `t`th tensor on `i`th loop.
+  Optional<unsigned> getDimNum(unsigned t, unsigned i) const {
+    assert(t < numTensors && i < numLoops);
+    return loopIdxToDim[t][i];
+  }
+
+  /// Gets the dimension number of `b`.
+  Optional<unsigned> getDimNum(unsigned b) const {
+    return getDimNum(tensor(b), index(b));
+  }
+
+  /// Sets the dimension and dimension level type of the `t`th tensor on `i`th
+  /// loop.
+  void setDimAndDimLevelType(unsigned t, unsigned i, unsigned dim,
+                             DimLevelType dlt) {
+    assert(isValidDLT(dlt));
+    dimTypes[t][i] = dlt;
+    loopIdxToDim[t][i] = dim;
+  }
+
+  // Iterates the bits of a lattice, for each set bit, converts it into the
+  // corresponding tensor dimension and invokes the callback.
+  void foreachTidDimPairInBits(
+      const BitVector &bits,
+      function_ref<void(unsigned b, unsigned tid, Optional<unsigned> dim,
+                        DimLevelType dlt)>
+          cb) {
+    for (unsigned b : bits.set_bits())
+      cb(b, tensor(b), getDimNum(b), getDimLevelType(b));
   }
 
   // Has sparse output tensor setter.
@@ -310,7 +336,11 @@ private:
   const unsigned numTensors;
   const unsigned numLoops;
   bool hasSparseOut;
+  // Map that converts pair<tensor id, loop id> to the corresponding dimension
+  // level type.
   std::vector<std::vector<DimLevelType>> dimTypes;
+  // Map that converts pair<tensor id, loop id> to the corresponding dimension.
+  std::vector<std::vector<Optional<unsigned>>> loopIdxToDim;
   llvm::SmallVector<TensorExp, 32> tensorExps;
   llvm::SmallVector<LatPoint, 16> latPoints;
   llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
index a936d98..efef4ff 100644 (file)
@@ -40,8 +40,6 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
-constexpr unsigned INVALID_ID = std::numeric_limits<unsigned>::max();
-
 // Iteration graph sorting.
 enum SortMask {
   kSparseOnly = 0x0,
@@ -83,14 +81,6 @@ struct CodeGen {
   // Topsort (reference should remain in scope).
   std::vector<unsigned> &topSort;
 
-  // From tensor id + loop id => dim id.
-  // TODO: This map should probably be maintained by Merger (it can be set up
-  // together with dimLvlType Map).
-  std::vector<std::vector<unsigned>> loopIdxToDim;
-
-  // Initialize the above two mapping.
-  void buildLoopIdxToDimMap(linalg::GenericOp op);
-
   Value getLoopIdxValue(size_t loopIdx) const {
     for (unsigned lv = 0; lv < topSort.size(); lv++)
       if (topSort[lv] == loopIdx)
@@ -100,30 +90,6 @@ struct CodeGen {
   }
 };
 
-void CodeGen::buildLoopIdxToDimMap(linalg::GenericOp op) {
-  size_t numLoops = op.getNumLoops();
-  size_t numTensors = op.getNumOperands();
-  loopIdxToDim.assign(numTensors, std::vector<unsigned>(numLoops, INVALID_ID));
-
-  for (OpOperand &t : op->getOpOperands()) {
-    auto map = op.getMatchingIndexingMap(&t);
-    auto enc = getSparseTensorEncoding(t.get().getType());
-    // Scan all dimensions of current tensor.
-    unsigned tid = t.getOperandNumber();
-    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
-      auto a = map.getResult(toOrigDim(enc, d)).dyn_cast<AffineDimExpr>();
-      if (a) {
-        unsigned loopId = a.getPosition();
-        // Fills the mapping.
-        loopIdxToDim[tid][loopId] = d;
-      }
-      // Else a compound affine, do nothing. (at least we are good for
-      // now, as we only support compound affine expr on non-annoated dense
-      // tensors).
-    }
-  }
-}
-
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -151,8 +117,9 @@ static AffineMap permute(MLIRContext *context, AffineMap m,
 /// Helper method to inspect affine expressions. Rejects cases where the
 /// same index is used more than once. Also rejects compound affine
 /// expressions in sparse dimensions.
-static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
-                       DimLevelType dim, bool setLvlFormat = true) {
+static bool findAffine(Merger &merger, unsigned tensor, unsigned dim,
+                       AffineExpr a, DimLevelType dlt,
+                       bool setLvlFormat = true) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     unsigned idx = a.cast<AffineDimExpr>().getPosition();
@@ -160,21 +127,21 @@ static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
       return false; // used more than once
 
     if (setLvlFormat)
-      merger.setDimLevelType(tensor, idx, dim);
+      merger.setDimAndDimLevelType(tensor, idx, dim, dlt);
     return true;
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
-    if (!isDenseDLT(dim))
+    if (!isDenseDLT(dlt))
       return false; // compound only in dense dim
     auto binOp = a.cast<AffineBinaryOpExpr>();
     // We do not set dim level format for affine expresssion like d0 + d1 on
     // both loop index at d0 and d1,
-    return findAffine(merger, tensor, binOp.getLHS(), dim, false) &&
-           findAffine(merger, tensor, binOp.getRHS(), dim, false);
+    return findAffine(merger, tensor, dim, binOp.getLHS(), dlt, false) &&
+           findAffine(merger, tensor, dim, binOp.getRHS(), dlt, false);
   }
   case AffineExprKind::Constant:
-    return isDenseDLT(dim); // const only in dense dim
+    return isDenseDLT(dlt); // const only in dense dim
   default:
     return false;
   }
@@ -196,7 +163,7 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       unsigned tensor = t.getOperandNumber();
       AffineExpr a = map.getResult(toOrigDim(enc, d));
-      if (!findAffine(merger, tensor, a, getDimLevelType(enc, d)))
+      if (!findAffine(merger, tensor, d, a, getDimLevelType(enc, d)))
         return false; // inadmissible affine expression
     }
   }
@@ -1024,8 +991,7 @@ static scf::IfOp genIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
     Value clause;
     if (isCompressedDLT(merger.getDimLevelType(b)) ||
         isSingletonDLT(merger.getDimLevelType(b))) {
-      auto dim = codegen.loopIdxToDim[tensor][idx];
-      assert(dim != INVALID_ID);
+      auto dim = merger.getDimNum(tensor, idx).value();
       Value op1 = codegen.loopEmitter.getCoord()[tensor][dim];
       Value op2 = codegen.getLoopIdxValue(idx);
       clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, op1,
@@ -1082,23 +1048,22 @@ static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   unsigned l0 = merger.set(lts)[0];
   bool needsUniv = false;
 
-  SmallVector<size_t, 4> ts;
-  SmallVector<size_t, 4> ds;
-  for (auto b : merger.lat(l0).bits.set_bits()) {
-    if (isDenseDLT(merger.getDimLevelType(b)) ||
-        isUndefDLT(merger.getDimLevelType(b))) {
-      needsUniv = true;
-    } else {
-      unsigned tensor = merger.tensor(b);
-      assert(idx == merger.index(b));
-      size_t dim = codegen.loopIdxToDim[tensor][idx];
-      assert(dim != INVALID_ID);
-      ts.push_back(tensor);
-      ds.push_back(dim);
-    }
-  }
+  SmallVector<size_t> tids;
+  SmallVector<size_t> dims;
+  merger.foreachTidDimPairInBits(
+      merger.lat(l0).bits,
+      [&](unsigned b, unsigned tid, Optional<unsigned> dim, DimLevelType dlt) {
+        assert(merger.index(b) == idx);
+        if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
+          needsUniv = true;
+        } else {
+          // sparse/singleton dim levels.
+          tids.push_back(tid);
+          dims.push_back(dim.value());
+        }
+      });
 
-  codegen.loopEmitter.enterNewLoopSeq(builder, op.getLoc(), ts, ds);
+  codegen.loopEmitter.enterNewLoopSeq(builder, op.getLoc(), tids, dims);
 
   // Maintain the universal index only if it is actually
   // consumed by a subsequent lattice point.
@@ -1119,17 +1084,15 @@ static void translateBitsToTidDimPairs(Merger &merger, CodeGen &codegen,
                                        SmallVectorImpl<size_t> &condDims,
                                        SmallVectorImpl<size_t> &extraTids,
                                        SmallVectorImpl<size_t> &extraDims) {
-  const BitVector &simple = merger.lat(li).simple;
   const BitVector &all = merger.lat(li).bits;
-  assert(simple.size() == all.size());
-  // First converts bits to array + dim pair
-  for (unsigned b = 0, e = simple.size(); b < e; b++) {
-    size_t tid = merger.tensor(b);
+  const BitVector &simple = merger.lat(li).simple;
+
+  // Converts bits to array + dim pair
+  merger.foreachTidDimPairInBits(all, [&, idx](unsigned b, unsigned tid,
+                                               Optional<unsigned> dim,
+                                               DimLevelType dlt) {
     if (simple.test(b)) {
-      // the simplified condition must be a subset of the original condition.
-      assert(all.test(b));
-      assert(merger.index(b) == idx);
-      if (isUndefDLT(merger.getDimLevelType(b))) {
+      if (isUndefDLT(dlt)) {
         // An undefined dlt in the lattices, we probably mean to iterate based
         // on the dim of output tensor.
         // E.g., this could be a synthetic tensor (for invariants and sparse
@@ -1137,26 +1100,28 @@ static void translateBitsToTidDimPairs(Merger &merger, CodeGen &codegen,
         // out[i][j] = invariant; or a broadcast
         // out[i][j] = in[i] (j is undef for input)
         tid = merger.getOutTensorID();
+        dim = merger.getDimNum(tid, idx);
+        // Skips invalid dim (e.g., when this is a zero ranked tensor).
+        if (!dim)
+          return;
       }
-      auto dim = codegen.loopIdxToDim[tid][idx];
-      if (dim != INVALID_ID) {
-        // dim could be invalid if this is a zero ranked tensor
-        condTids.push_back(tid);
-        condDims.push_back(dim);
-      }
-    } else if ((all.test(b) || merger.isOutTensor(b, idx)) &&
-               isDenseDLT(merger.getDimLevelType(b))) {
-      assert(merger.index(b) == idx);
-      // Note that we generate dense indices of the output tensor
-      // unconditionally, since they may not appear in the lattice, but may be
-      // needed for linearized codegen.
-      // Only dense dimensions should be optimized from conditions.
-      assert(isDenseDLT(merger.getDimLevelType(b)));
-      auto dim = codegen.loopIdxToDim[tid][idx];
-      assert(dim != INVALID_ID);
+      condTids.push_back(tid);
+      condDims.push_back(dim.value());
+    } else if (isDenseDLT(dlt)) {
+      // TODO: get rid of extraTids and extraDims.
       extraTids.push_back(tid);
-      extraDims.push_back(dim);
+      extraDims.push_back(dim.value());
     }
+  });
+
+  if (isDenseDLT(merger.getDimLevelType(merger.getOutTensorID(), idx))) {
+    // Note that we generate dense indices of the output tensor
+    // unconditionally, since they may not appear in the lattice, but may be
+    // needed for linearized codegen.
+    // Only dense dimensions should be optimized from conditions.
+    auto dim = merger.getDimNum(merger.getOutTensorID(), idx).value();
+    extraTids.push_back(merger.getOutTensorID());
+    extraDims.push_back(dim);
   }
 }
 
@@ -1370,8 +1335,6 @@ public:
     // Recursively generates code if admissible.
     CodeGen codegen(options, tensors, numTensors, numLoops, sparseOut,
                     outerParNest, topSort);
-    // TODO: maybe merger should be responsible of maintaining the map.
-    codegen.buildLoopIdxToDimMap(op);
     genBuffers(merger, codegen, rewriter, op);
     genStmt(merger, codegen, rewriter, op, exp, 0);
     genResult(merger, codegen, rewriter, op);
index 0ce2740..a9a3725 100644 (file)
@@ -313,15 +313,15 @@ protected:
 
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Compressed);
 
     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Compressed);
 
     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Dense);
   }
 };
 
@@ -338,19 +338,19 @@ protected:
 
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Compressed);
 
     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Compressed);
 
     // Tensor 2: sparse input vector
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Compressed);
 
     // Tensor 3: dense output vector
     merger.addExp(Kind::kTensor, t3, -1u);
-    merger.setDimLevelType(t3, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t3, l0, 0, DimLevelType::Dense);
   }
 };
 
@@ -371,15 +371,15 @@ protected:
 
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Compressed);
 
     // Tensor 1: dense input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Dense);
 
     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Dense);
   }
 };
 
@@ -400,19 +400,19 @@ protected:
 
     // Tensor 0: undef input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Undef);
+    merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Undef);
 
     // Tensor 1: dense input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Dense);
 
     // Tensor 2: undef input vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Undef);
+    merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Undef);
 
     // Tensor 3: dense output vector.
     merger.addExp(Kind::kTensor, t3, -1u);
-    merger.setDimLevelType(t3, l0, DimLevelType::Dense);
+    merger.setDimAndDimLevelType(t3, l0, 0, DimLevelType::Dense);
   }
 };
 
@@ -436,15 +436,15 @@ protected:
 
     // Tensor 0: undef input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelType(t0, l0, DimLevelType::Undef);
+    merger.setDimAndDimLevelType(t0, l0, 0, DimLevelType::Undef);
 
     // Tensor 1: undef input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelType(t1, l0, DimLevelType::Undef);
+    merger.setDimAndDimLevelType(t1, l0, 0, DimLevelType::Undef);
 
     // Tensor 2: sparse output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelType(t2, l0, DimLevelType::Compressed);
+    merger.setDimAndDimLevelType(t2, l0, 0, DimLevelType::Compressed);
   }
 };