[mlir][sparse] properly record dimension level type and properties
authorAart Bik <ajcbik@google.com>
Fri, 9 Sep 2022 22:42:46 +0000 (15:42 -0700)
committerAart Bik <ajcbik@google.com>
Mon, 12 Sep 2022 16:59:53 +0000 (09:59 -0700)
A next step towards supporting the new dimension level types and
properties. This changes properly records the properties in the
Merger, so that subsequent computations (lattice optimizations)
and code generation (during sparsification) can do the right thing.

https://github.com/llvm/llvm-project/issues/51658

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D133620

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

index ea00b95..584df84 100644 (file)
@@ -21,7 +21,20 @@ namespace mlir {
 namespace sparse_tensor {
 
 /// Dimension level type for a tensor (undef means index does not appear).
-enum Dim { kSparse, kDense, kUndef };
+enum class DimLvlType { kDense, kCompressed, kSingleton, kUndef };
+
+/// Per-dimension level format (type and properties). Dense and undefined
+/// level types should always be marked ordered and unique.
+struct DimLevelFormat {
+  DimLevelFormat(DimLvlType tp, bool o = true, bool u = true)
+      : levelType(tp), isOrdered(o), isUnique(u) {
+    assert((tp == DimLvlType::kCompressed || tp == DimLvlType::kSingleton) ||
+           (o && u));
+  }
+  DimLvlType levelType;
+  bool isOrdered;
+  bool isUnique;
+};
 
 /// Tensor expression kind.
 enum Kind {
@@ -156,7 +169,9 @@ public:
   /// invariant expressions in the kernel.
   Merger(unsigned t, unsigned l)
       : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l),
-        hasSparseOut(false), dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
+        hasSparseOut(false),
+        dims(t + 1, std::vector<DimLevelFormat>(
+                        l, DimLevelFormat(DimLvlType::kUndef))) {}
 
   /// Adds a tensor expression. Returns its index.
   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
@@ -225,31 +240,40 @@ public:
   unsigned tensor(unsigned b) const { return b % numTensors; }
   unsigned index(unsigned b) const { return b / numTensors; }
 
-  /// Returns true if bit corresponds to queried dim.
-  bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
-
   /// Returns true if bit corresponds to index of output tensor.
   bool isOutTensor(unsigned b, unsigned i) const {
     return tensor(b) == outTensor && index(b) == i;
   }
 
-  /// Returns true if tensor access at given index has queried dim.
-  bool isDim(unsigned t, unsigned i, Dim d) const {
-    assert(t < numTensors && i < numLoops);
-    return dims[t][i] == d;
-  }
-
-  /// Returns true if any set bit corresponds to queried dim.
-  bool hasAnyDimOf(const BitVector &bits, Dim d) const;
-
   /// Returns true if given tensor iterates *only* in the given tensor
   /// expression. For the output tensor, this defines a "simply dynamic"
   /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for
   /// sparse vector a.
   bool isSingleCondition(unsigned t, unsigned e) const;
 
-  /// Dimension setter.
-  void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
+  /// Returns true if bit corresponds to given dimension level type.
+  bool isDimLevelType(unsigned b, DimLvlType tp) const {
+    return isDimLevelType(tensor(b), index(b), tp);
+  }
+
+  /// Returns true if tensor access at index has given dimension level type.
+  bool isDimLevelType(unsigned t, unsigned i, DimLvlType tp) const {
+    return getDimLevelFormat(t, i).levelType == tp;
+  }
+
+  /// Returns true if any set bit corresponds to given dimension level type.
+  bool hasAnyDimLevelTypeOf(const BitVector &bits, DimLvlType tp) const;
+
+  /// Dimension level format getter.
+  DimLevelFormat getDimLevelFormat(unsigned t, unsigned i) const {
+    assert(t < numTensors && i < numLoops);
+    return dims[t][i];
+  }
+
+  /// Dimension level format setter.
+  void setDimLevelFormat(unsigned t, unsigned i, DimLevelFormat d) {
+    dims[t][i] = d;
+  }
 
   // Has sparse output tensor setter.
   void setHasSparseOut(bool s) { hasSparseOut = s; }
@@ -298,7 +322,7 @@ private:
   const unsigned numTensors;
   const unsigned numLoops;
   bool hasSparseOut;
-  std::vector<std::vector<Dim>> dims;
+  std::vector<std::vector<DimLevelFormat>> dims;
   llvm::SmallVector<TensorExp, 32> tensorExps;
   llvm::SmallVector<LatPoint, 16> latPoints;
   llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
index aa44a7f..642c666 100644 (file)
@@ -141,40 +141,60 @@ static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
   return d;
 }
 
-/// Helper method to translate dim level type to internal representation.
-static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
+/// Helper method to obtain the dimension level format from the encoding.
+//
+//  TODO: note that we store, but currently completely *ignore* the properties
+//
+static DimLevelFormat toDimLevelFormat(const SparseTensorEncodingAttr &enc,
+                                       unsigned d) {
   if (enc) {
-    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
-    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
-      return Dim::kSparse;
+    switch (enc.getDimLevelType()[d]) {
+    case SparseTensorEncodingAttr::DimLevelType::Dense:
+      return DimLevelFormat(DimLvlType::kDense);
+    case SparseTensorEncodingAttr::DimLevelType::Compressed:
+      return DimLevelFormat(DimLvlType::kCompressed);
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
+      return DimLevelFormat(DimLvlType::kCompressed, true, false);
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
+      return DimLevelFormat(DimLvlType::kCompressed, false, true);
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+      return DimLevelFormat(DimLvlType::kCompressed, false, false);
+    case SparseTensorEncodingAttr::DimLevelType::Singleton:
+      return DimLevelFormat(DimLvlType::kSingleton);
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
+      return DimLevelFormat(DimLvlType::kSingleton, true, false);
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
+      return DimLevelFormat(DimLvlType::kSingleton, false, true);
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+      return DimLevelFormat(DimLvlType::kSingleton, false, false);
+    }
   }
-  return Dim::kDense;
+  return DimLevelFormat(DimLvlType::kDense);
 }
 
 /// Helper method to inspect affine expressions. Rejects cases where the
-/// same index is used more than once. Also rejects affine expressions
-/// that are not a direct index for annotated tensors.
-// TODO: accept more affine cases for sparse tensors
-static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
-                       bool isDense) {
+/// same index is used more than once. Also rejects compound affine
+/// expressions in sparse dimensions.
+static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
+                       DimLevelFormat dim) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     unsigned idx = a.cast<AffineDimExpr>().getPosition();
-    if (!merger.isDim(tensor, idx, Dim::kUndef))
+    if (!merger.isDimLevelType(tensor, idx, DimLvlType::kUndef))
       return false; // used more than once
-    merger.setDim(tensor, idx, dim);
+    merger.setDimLevelFormat(tensor, idx, dim);
     return true;
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
-    if (!isDense)
-      return false;
+    if (dim.levelType != DimLvlType::kDense)
+      return false; // compound only in dense dim
     auto binOp = a.cast<AffineBinaryOpExpr>();
-    return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
-           findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
+    return findAffine(merger, tensor, binOp.getLHS(), dim) &&
+           findAffine(merger, tensor, binOp.getRHS(), dim);
   }
   case AffineExprKind::Constant:
-    return isDense;
+    return dim.levelType == DimLvlType::kDense; // const only in dense dim
   default:
     return false;
   }
@@ -196,7 +216,7 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       unsigned tensor = t->getOperandNumber();
       AffineExpr a = map.getResult(perm(enc, d));
-      if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
+      if (!findAffine(merger, tensor, a, toDimLevelFormat(enc, d)))
         return false; // inadmissable affine expression
     }
   }
@@ -286,13 +306,13 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
     if (mask & SortMask::kIncludeUndef) {
       unsigned tensor = t->getOperandNumber();
       for (unsigned i = 0; i < n; i++)
-        if (merger.isDim(tensor, i, Dim::kSparse))
+        if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
+            merger.isDimLevelType(tensor, i, DimLvlType::kSingleton))
           for (unsigned j = 0; j < n; j++)
-            if (merger.isDim(tensor, j, Dim::kUndef))
+            if (merger.isDimLevelType(tensor, j, DimLvlType::kUndef))
               adjM[i][j] = true;
     }
   }
-
   // Topologically sort the iteration graph to determine loop order.
   // Report failure for a cyclic iteration graph.
   topSort.clear();
@@ -334,7 +354,8 @@ static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
   auto iteratorTypes = op.iterator_types().getValue();
   unsigned numLoops = iteratorTypes.size();
   for (unsigned i = 0; i < numLoops; i++)
-    if (merger.isDim(tensor, i, Dim::kSparse)) {
+    if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
+        merger.isDimLevelType(tensor, i, DimLvlType::kSingleton)) {
       allDense = false;
       break;
     }
@@ -519,7 +540,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
         continue; // compound
       unsigned idx = a.cast<AffineDimExpr>().getPosition();
       // Handle sparse storage schemes.
-      if (merger.isDim(tensor, idx, Dim::kSparse)) {
+      if (merger.isDimLevelType(tensor, idx, DimLvlType::kCompressed)) {
         auto dynShape = {ShapedType::kDynamicSize};
         auto ptrTp =
             MemRefType::get(dynShape, getPointerOverheadType(builder, enc));
@@ -531,6 +552,8 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
             builder.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
         codegen.indices[tensor][idx] =
             builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
+      } else if (merger.isDimLevelType(tensor, idx, DimLvlType::kSingleton)) {
+        llvm_unreachable("TODO: not implemented yet");
       }
       // Find upper bound in current dimension.
       unsigned p = perm(enc, d);
@@ -543,7 +566,6 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
     // Perform the required bufferization. Dense inputs materialize
     // from the input tensors. Dense outputs need special handling.
     // Sparse inputs use sparse primitives to obtain the values.
-    // We also accept in-place all-dense annotated "sparse" outputs.
     Type elementType = getElementTypeOrSelf(t->get().getType());
     if (!enc) {
       // Non-annotated dense tensors.
@@ -985,11 +1007,13 @@ static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
     return genInvariantValue(merger, codegen, rewriter, exp);
   if (merger.exp(exp).kind == Kind::kIndex)
     return genIndexValue(codegen, rewriter, merger.exp(exp).index, ldx);
+
   if (merger.exp(exp).kind == Kind::kReduce) {
     // Make custom reduction identity accessible for expanded access pattern.
     assert(codegen.redCustom == -1u);
     codegen.redCustom = exp;
   }
+
   Value v0 =
       genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx);
   Value v1 =
@@ -1000,8 +1024,12 @@ static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
              merger.exp(exp).kind == Kind::kBinaryBranch ||
              merger.exp(exp).kind == Kind::kReduce))
     ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx);
-  if (merger.exp(exp).kind == Kind::kReduce)
+
+  if (merger.exp(exp).kind == Kind::kReduce) {
+    assert(codegen.redCustom != -1u);
     codegen.redCustom = -1u;
+  }
+
   return ee;
 }
 
@@ -1029,7 +1057,7 @@ static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
 /// Hoists loop invariant tensor loads for which indices have been exhausted.
 static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                           linalg::GenericOp op, unsigned exp, unsigned ldx,
-                          bool atStart, unsigned last = 0) {
+                          bool atStart, unsigned last = -1u) {
   if (exp == -1u)
     return;
   if (merger.exp(exp).kind == Kind::kTensor) {
@@ -1131,7 +1159,7 @@ static bool genInit(Merger &merger, CodeGen &codegen, OpBuilder &builder,
     if (inits[b]) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
-      if (merger.isDim(b, Dim::kSparse)) {
+      if (merger.isDimLevelType(b, DimLvlType::kCompressed)) {
         // Initialize sparse index.
         unsigned pat = at;
         for (; pat != 0; pat--) {
@@ -1145,6 +1173,8 @@ static bool genInit(Merger &merger, CodeGen &codegen, OpBuilder &builder,
         codegen.pidxs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p0);
         Value p1 = builder.create<arith::AddIOp>(loc, p0, one);
         codegen.highs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p1);
+      } else if (merger.isDimLevelType(b, DimLvlType::kSingleton)) {
+        llvm_unreachable("TODO: not implemented yet");
       } else {
         // Dense index still in play.
         needsUniv = true;
@@ -1235,12 +1265,14 @@ static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   assert(idx == merger.index(fb));
   auto iteratorTypes = op.iterator_types().getValue();
   bool isReduction = linalg::isReductionIterator(iteratorTypes[idx]);
-  bool isSparse = merger.isDim(fb, Dim::kSparse);
+  bool isSparse = merger.isDimLevelType(fb, DimLvlType::kCompressed);
   bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) &&
                   denseUnitStrides(merger, op, idx);
   bool isParallel =
       isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
 
+  assert(!merger.isDimLevelType(fb, DimLvlType::kSingleton) && "TODO: implement");
+
   // Prepare vector length.
   if (isVector)
     codegen.curVecLength = codegen.options.vectorLength;
@@ -1308,7 +1340,7 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   // Construct the while-loop with a parameter for each index.
   Type indexType = builder.getIndexType();
   for (unsigned b = 0, be = indices.size(); b < be; b++) {
-    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
+    if (indices[b] && merger.isDimLevelType(b, DimLvlType::kCompressed)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       types.push_back(indexType);
@@ -1341,7 +1373,8 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   Value cond;
   unsigned o = 0;
   for (unsigned b = 0, be = indices.size(); b < be; b++) {
-    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
+    // TODO: singleton
+    if (indices[b] && merger.isDimLevelType(b, DimLvlType::kCompressed)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value op1 = before->getArgument(o);
@@ -1389,7 +1422,8 @@ static void genLocals(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   // Initialize sparse indices.
   Value min;
   for (unsigned b = 0, be = locals.size(); b < be; b++) {
-    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
+    // TODO: singleton
+    if (locals[b] && merger.isDimLevelType(b, DimLvlType::kCompressed)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value ptr = codegen.indices[tensor][idx];
@@ -1419,7 +1453,7 @@ static void genLocals(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   // but may be needed for linearized codegen.
   for (unsigned b = 0, be = locals.size(); b < be; b++) {
     if ((locals[b] || merger.isOutTensor(b, idx)) &&
-        merger.isDim(b, Dim::kDense)) {
+        merger.isDimLevelType(b, DimLvlType::kDense)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       unsigned pat = at;
@@ -1477,7 +1511,8 @@ static void genWhileInduction(Merger &merger, CodeGen &codegen,
   SmallVector<Value, 4> operands;
   Value one = constantIndex(builder, loc, 1);
   for (unsigned b = 0, be = induction.size(); b < be; b++) {
-    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
+    // TODO: singleton
+    if (induction[b] && merger.isDimLevelType(b, DimLvlType::kCompressed)) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value op1 = codegen.idxs[tensor][idx];
@@ -1541,7 +1576,8 @@ static scf::IfOp genIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value clause;
-      if (merger.isDim(b, Dim::kSparse)) {
+      // TODO: singleton
+      if (merger.isDimLevelType(b, DimLvlType::kCompressed)) {
         Value op1 = codegen.idxs[tensor][idx];
         Value op2 = codegen.loops[idx];
         clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
@@ -1605,7 +1641,8 @@ static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
     unsigned lsize = merger.set(lts).size();
     for (unsigned i = 1; i < lsize; i++) {
       unsigned li = merger.set(lts)[i];
-      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
+      if (!merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kCompressed) &&
+          !merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kSingleton))
         return true;
     }
   }
index eeaaa2e..cd5d134 100644 (file)
@@ -262,10 +262,17 @@ BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
     }
   }
   // Now apply the two basic rules.
+  //
+  // TODO: improve for singleton and properties
+  //
   BitVector simple = latPoints[p0].bits;
-  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
+  bool reset = isSingleton &&
+      (hasAnyDimLevelTypeOf(simple, DimLvlType::kCompressed) ||
+       hasAnyDimLevelTypeOf(simple, DimLvlType::kSingleton));
   for (unsigned b = 0, be = simple.size(); b < be; b++) {
-    if (simple[b] && !isDim(b, kSparse)) {
+    if (simple[b] &&
+        (!isDimLevelType(b, DimLvlType::kCompressed) &&
+         !isDimLevelType(b, DimLvlType::kSingleton))) {
       if (reset)
         simple.reset(b);
       reset = true;
@@ -290,14 +297,8 @@ bool Merger::latGT(unsigned i, unsigned j) const {
 bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
   BitVector tmp = latPoints[j].bits;
   tmp ^= latPoints[i].bits;
-  return !hasAnyDimOf(tmp, kSparse);
-}
-
-bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
-  for (unsigned b = 0, be = bits.size(); b < be; b++)
-    if (bits[b] && isDim(b, d))
-      return true;
-  return false;
+  return !hasAnyDimLevelTypeOf(tmp, DimLvlType::kCompressed) &&
+         !hasAnyDimLevelTypeOf(tmp, DimLvlType::kSingleton);
 }
 
 bool Merger::isSingleCondition(unsigned t, unsigned e) const {
@@ -383,6 +384,13 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
   llvm_unreachable("unexpected kind");
 }
 
+bool Merger::hasAnyDimLevelTypeOf(const BitVector &bits, DimLvlType tp) const {
+  for (unsigned b = 0, be = bits.size(); b < be; b++)
+    if (bits[b] && isDimLevelType(b, tp))
+      return true;
+  return false;
+}
+
 #ifndef NDEBUG
 
 //===----------------------------------------------------------------------===//
@@ -591,18 +599,23 @@ void Merger::dumpBits(const BitVector &bits) const {
     if (bits[b]) {
       unsigned t = tensor(b);
       unsigned i = index(b);
+      DimLevelFormat f = dims[t][i];
       llvm::dbgs() << " i_" << t << "_" << i << "_";
-      switch (dims[t][i]) {
-      case kSparse:
-        llvm::dbgs() << "S";
-        break;
-      case kDense:
+      switch (f.levelType) {
+      case DimLvlType::kDense:
         llvm::dbgs() << "D";
         break;
-      case kUndef:
+      case DimLvlType::kCompressed:
+        llvm::dbgs() << "C";
+        break;
+      case DimLvlType::kSingleton:
+        llvm::dbgs() << "S";
+        break;
+      case DimLvlType::kUndef:
         llvm::dbgs() << "U";
         break;
       }
+      llvm::dbgs() << "[O=" << f.isOrdered << ",U=" << f.isUnique << "]";
     }
   }
 }
@@ -855,9 +868,8 @@ static bool isAdmissableBranchExp(Operation *op, Block *block, Value v) {
   if (isa<linalg::IndexOp>(def))
     return true;
   // Operation defined outside branch.
-  if (def->getBlock() != block) {
+  if (def->getBlock() != block)
     return def->getBlock() != op->getBlock(); // invariant?
-  }
   // Operation defined within branch. Anything is accepted,
   // as long as all subexpressions are admissable.
   for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
@@ -1038,7 +1050,6 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
     if (x.has_value() && y.has_value() && z.has_value()) {
       unsigned e0 = x.value();
       unsigned e1 = y.value();
-      // unsigned e2 = z.getValue();
       if (auto redop = dyn_cast<sparse_tensor::ReduceOp>(def)) {
         if (isAdmissableBranch(redop, redop.getRegion()))
           return addExp(kReduce, e0, e1, Value(), def);
index 3bf424c..8d41558 100644 (file)
@@ -310,15 +310,15 @@ protected:
   MergerTest3T1L() : MergerTestBase(3, 1) {
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDim(t0, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kCompressed));
 
     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDim(t1, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kCompressed));
 
     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDim(t2, l0, Dim::kDense);
+    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kDense));
   }
 };
 
@@ -333,19 +333,19 @@ protected:
   MergerTest4T1L() : MergerTestBase(4, 1) {
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDim(t0, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kCompressed));
 
     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDim(t1, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kCompressed));
 
     // Tensor 2: sparse input vector
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDim(t2, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kCompressed));
 
     // Tensor 3: dense output vector
     merger.addExp(Kind::kTensor, t3, -1u);
-    merger.setDim(t3, l0, Dim::kDense);
+    merger.setDimLevelFormat(t3, l0, DimLevelFormat(DimLvlType::kDense));
   }
 };
 
@@ -364,15 +364,15 @@ protected:
   MergerTest3T1LD() : MergerTestBase(3, 1) {
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDim(t0, l0, Dim::kSparse);
+    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kCompressed));
 
     // Tensor 1: dense input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDim(t1, l0, Dim::kDense);
+    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kDense));
 
     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDim(t2, l0, Dim::kDense);
+    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kDense));
   }
 };