[mlir][sparse] Removing the DimLvlType and DimLevelFormat types
authorwren romano <2998727+wrengr@users.noreply.github.com>
Tue, 18 Oct 2022 02:11:20 +0000 (19:11 -0700)
committerwren romano <2998727+wrengr@users.noreply.github.com>
Tue, 18 Oct 2022 22:47:40 +0000 (15:47 -0700)
This removes another massive source of redundancy, and instead has the Merger.{h,cpp} reuse the SparseTensorEnums library.

Depends On D136005

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D136123

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index 7697f0b..2f475b4 100644 (file)
@@ -138,7 +138,13 @@ enum class Action : uint32_t {
 /// versions; consequently, client code should use the predicate functions
 /// defined below, rather than relying on knowledge about the particular
 /// binary encoding.
+///
+/// The `Undef` "format" is a special value used internally for cases
+/// where we need to store an undefined or indeterminate `DimLevelType`.
+/// It should not be used externally, since it does not indicate an
+/// actual/representable format.
 enum class DimLevelType : uint8_t {
+  Undef = 0,           // 0b000_00
   Dense = 4,           // 0b001_00
   Compressed = 8,      // 0b010_00
   CompressedNu = 9,    // 0b010_01
@@ -150,20 +156,39 @@ enum class DimLevelType : uint8_t {
   SingletonNuNo = 19,  // 0b100_11
 };
 
+/// Check that the `DimLevelType` contains a valid (possibly undefined) value.
+constexpr bool isValidDLT(DimLevelType dlt) {
+  const uint8_t formatBits = static_cast<uint8_t>(dlt) >> 2;
+  const uint8_t propertyBits = static_cast<uint8_t>(dlt) & 3;
+  // If undefined or dense, then must be unique and ordered.
+  // Otherwise, the format must be one of the known ones.
+  return (formatBits <= 1) ? (propertyBits == 0)
+                           : (formatBits == 2 || formatBits == 4);
+}
+
+/// Check if the `DimLevelType` is the special undefined value.
+constexpr bool isUndefDLT(DimLevelType dlt) {
+  return dlt == DimLevelType::Undef;
+}
+
 /// Check if the `DimLevelType` is dense.
 constexpr bool isDenseDLT(DimLevelType dlt) {
   return dlt == DimLevelType::Dense;
 }
 
+// We use the idiom `(dlt & ~3) == format` in order to only return true
+// for valid DLTs.  Whereas the `dlt & format` idiom is a bit faster but
+// can return false-positives on invalid DLTs.
+
 /// Check if the `DimLevelType` is compressed (regardless of properties).
 constexpr bool isCompressedDLT(DimLevelType dlt) {
-  return static_cast<uint8_t>(dlt) &
+  return (static_cast<uint8_t>(dlt) & ~3) ==
          static_cast<uint8_t>(DimLevelType::Compressed);
 }
 
 /// Check if the `DimLevelType` is singleton (regardless of properties).
 constexpr bool isSingletonDLT(DimLevelType dlt) {
-  return static_cast<uint8_t>(dlt) &
+  return (static_cast<uint8_t>(dlt) & ~3) ==
          static_cast<uint8_t>(DimLevelType::Singleton);
 }
 
@@ -178,6 +203,18 @@ constexpr bool isUniqueDLT(DimLevelType dlt) {
 }
 
 // Ensure the above predicates work as intended.
+static_assert((isValidDLT(DimLevelType::Undef) &&
+               isValidDLT(DimLevelType::Dense) &&
+               isValidDLT(DimLevelType::Compressed) &&
+               isValidDLT(DimLevelType::CompressedNu) &&
+               isValidDLT(DimLevelType::CompressedNo) &&
+               isValidDLT(DimLevelType::CompressedNuNo) &&
+               isValidDLT(DimLevelType::Singleton) &&
+               isValidDLT(DimLevelType::SingletonNu) &&
+               isValidDLT(DimLevelType::SingletonNo) &&
+               isValidDLT(DimLevelType::SingletonNuNo)),
+              "isValidDLT definition is broken");
+
 static_assert((!isCompressedDLT(DimLevelType::Dense) &&
                isCompressedDLT(DimLevelType::Compressed) &&
                isCompressedDLT(DimLevelType::CompressedNu) &&
index 67aaa3e..7c615b4 100644 (file)
@@ -37,29 +37,37 @@ SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 // Dimension level types.
 //
 
-// Cannot be constexpr, because `getRank` isn't constexpr.  However,
-// for some strange reason, the wrapper functions below don't trigger
-// the same [-Winvalid-constexpr] warning (despite this function not
-// being constexpr).
-inline DimLevelType getDimLevelType(RankedTensorType type, uint64_t d) {
-  assert(d < static_cast<uint64_t>(type.getRank()));
-  if (auto enc = getSparseTensorEncoding(type))
-    return enc.getDimLevelType()[d];
+// MSVC does not allow this function to be constexpr, because
+// `SparseTensorEncodingAttr::operator bool` isn't declared constexpr.
+// And therefore all functions calling it cannot be constexpr either.
+// TODO: since Clang does allow these to be constexpr, perhaps we should
+// define a macro to abstract over `inline` vs `constexpr` annotations.
+inline DimLevelType getDimLevelType(const SparseTensorEncodingAttr &enc,
+                                    uint64_t d) {
+  if (enc) {
+    auto types = enc.getDimLevelType();
+    assert(d < types.size() && "Dimension out of bounds");
+    return types[d];
+  }
   return DimLevelType::Dense; // unannotated tensor is dense
 }
 
+inline DimLevelType getDimLevelType(RankedTensorType type, uint64_t d) {
+  return getDimLevelType(getSparseTensorEncoding(type), d);
+}
+
 /// Convenience function to test for dense dimension (0 <= d < rank).
-constexpr bool isDenseDim(RankedTensorType type, uint64_t d) {
+inline bool isDenseDim(RankedTensorType type, uint64_t d) {
   return isDenseDLT(getDimLevelType(type, d));
 }
 
 /// Convenience function to test for compressed dimension (0 <= d < rank).
-constexpr bool isCompressedDim(RankedTensorType type, uint64_t d) {
+inline bool isCompressedDim(RankedTensorType type, uint64_t d) {
   return isCompressedDLT(getDimLevelType(type, d));
 }
 
 /// Convenience function to test for singleton dimension (0 <= d < rank).
-constexpr bool isSingletonDim(RankedTensorType type, uint64_t d) {
+inline bool isSingletonDim(RankedTensorType type, uint64_t d) {
   return isSingletonDLT(getDimLevelType(type, d));
 }
 
@@ -69,13 +77,13 @@ constexpr bool isSingletonDim(RankedTensorType type, uint64_t d) {
 
 /// Convenience function to test for ordered property in the
 /// given dimension (0 <= d < rank).
-constexpr bool isOrderedDim(RankedTensorType type, uint64_t d) {
+inline bool isOrderedDim(RankedTensorType type, uint64_t d) {
   return isOrderedDLT(getDimLevelType(type, d));
 }
 
 /// Convenience function to test for unique property in the
 /// given dimension (0 <= d < rank).
-constexpr bool isUniqueDim(RankedTensorType type, uint64_t d) {
+inline bool isUniqueDim(RankedTensorType type, uint64_t d) {
   return isUniqueDLT(getDimLevelType(type, d));
 }
 
index a376d9a..52456a2 100644 (file)
 #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/Enums.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/BitVector.h"
 
 namespace mlir {
 namespace sparse_tensor {
 
-/// Dimension level type for a tensor (undef means index does not appear).
-enum class DimLvlType { kDense, kCompressed, kSingleton, kUndef };
-
-/// Per-dimension level format (type and properties). Dense and undefined
-/// level types should always be marked ordered and unique.
-struct DimLevelFormat {
-  DimLevelFormat(DimLvlType tp, bool o = true, bool u = true)
-      : levelType(tp), isOrdered(o), isUnique(u) {
-    assert((tp == DimLvlType::kCompressed || tp == DimLvlType::kSingleton) ||
-           (o && u));
-  }
-  DimLvlType levelType;
-  bool isOrdered;
-  bool isUnique;
-};
-
 /// Tensor expression kind.
 enum Kind {
   // Leaf.
@@ -171,8 +156,7 @@ public:
   Merger(unsigned t, unsigned l)
       : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l),
         hasSparseOut(false),
-        dims(t + 1, std::vector<DimLevelFormat>(
-                        l, DimLevelFormat(DimLvlType::kUndef))) {}
+        dimTypes(t + 1, std::vector<DimLevelType>(l, DimLevelType::Undef)) {}
 
   /// Adds a tensor expression. Returns its index.
   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
@@ -252,28 +236,24 @@ public:
   /// sparse vector a.
   bool isSingleCondition(unsigned t, unsigned e) const;
 
-  /// Returns true if bit corresponds to given dimension level type.
-  bool isDimLevelType(unsigned b, DimLvlType tp) const {
-    return isDimLevelType(tensor(b), index(b), tp);
-  }
-
-  /// Returns true if tensor access at index has given dimension level type.
-  bool isDimLevelType(unsigned t, unsigned i, DimLvlType tp) const {
-    return getDimLevelFormat(t, i).levelType == tp;
-  }
-
   /// Returns true if any set bit corresponds to sparse dimension level type.
   bool hasAnySparse(const BitVector &bits) const;
 
-  /// Dimension level format getter.
-  DimLevelFormat getDimLevelFormat(unsigned t, unsigned i) const {
+  /// Gets the dimension level type of the `i`th loop of the `t`th tensor.
+  DimLevelType getDimLevelType(unsigned t, unsigned i) const {
     assert(t < numTensors && i < numLoops);
-    return dims[t][i];
+    return dimTypes[t][i];
+  }
+
+  /// Gets the dimension level type of `b`.
+  DimLevelType getDimLevelType(unsigned b) const {
+    return getDimLevelType(tensor(b), index(b));
   }
 
-  /// Dimension level format setter.
-  void setDimLevelFormat(unsigned t, unsigned i, DimLevelFormat d) {
-    dims[t][i] = d;
+  /// Sets the dimension level type of the `i`th loop of the `t`th tensor.
+  void setDimLevelType(unsigned t, unsigned i, DimLevelType d) {
+    assert(isValidDLT(d));
+    dimTypes[t][i] = d;
   }
 
   // Has sparse output tensor setter.
@@ -323,7 +303,7 @@ private:
   const unsigned numTensors;
   const unsigned numLoops;
   bool hasSparseOut;
-  std::vector<std::vector<DimLevelFormat>> dims;
+  std::vector<std::vector<DimLevelType>> dimTypes;
   llvm::SmallVector<TensorExp, 32> tensorExps;
   llvm::SmallVector<LatPoint, 16> latPoints;
   llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
index c2c9901..7822b40 100644 (file)
@@ -143,6 +143,10 @@ void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
   printer << "<{ dimLevelType = [ ";
   for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
     switch (getDimLevelType()[i]) {
+    case DimLevelType::Undef:
+      // TODO: should probably raise an error instead of printing it...
+      printer << "\"undef\"";
+      break;
     case DimLevelType::Dense:
       printer << "\"dense\"";
       break;
index dc4f3cc..45013c0 100644 (file)
@@ -128,60 +128,29 @@ static AffineMap permute(MLIRContext *context, AffineMap m,
   return AffineMap::getPermutationMap(perm, context);
 }
 
-/// Helper method to obtain the dimension level format from the encoding.
-//
-//  TODO: note that we store, but currently completely *ignore* the properties
-//
-static DimLevelFormat toDimLevelFormat(const SparseTensorEncodingAttr &enc,
-                                       unsigned d) {
-  if (enc) {
-    switch (enc.getDimLevelType()[d]) {
-    case DimLevelType::Dense:
-      return DimLevelFormat(DimLvlType::kDense);
-    case DimLevelType::Compressed:
-      return DimLevelFormat(DimLvlType::kCompressed);
-    case DimLevelType::CompressedNu:
-      return DimLevelFormat(DimLvlType::kCompressed, true, false);
-    case DimLevelType::CompressedNo:
-      return DimLevelFormat(DimLvlType::kCompressed, false, true);
-    case DimLevelType::CompressedNuNo:
-      return DimLevelFormat(DimLvlType::kCompressed, false, false);
-    case DimLevelType::Singleton:
-      return DimLevelFormat(DimLvlType::kSingleton);
-    case DimLevelType::SingletonNu:
-      return DimLevelFormat(DimLvlType::kSingleton, true, false);
-    case DimLevelType::SingletonNo:
-      return DimLevelFormat(DimLvlType::kSingleton, false, true);
-    case DimLevelType::SingletonNuNo:
-      return DimLevelFormat(DimLvlType::kSingleton, false, false);
-    }
-  }
-  return DimLevelFormat(DimLvlType::kDense);
-}
-
 /// Helper method to inspect affine expressions. Rejects cases where the
 /// same index is used more than once. Also rejects compound affine
 /// expressions in sparse dimensions.
 static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
-                       DimLevelFormat dim) {
+                       DimLevelType dim) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     unsigned idx = a.cast<AffineDimExpr>().getPosition();
-    if (!merger.isDimLevelType(tensor, idx, DimLvlType::kUndef))
+    if (!isUndefDLT(merger.getDimLevelType(tensor, idx)))
       return false; // used more than once
-    merger.setDimLevelFormat(tensor, idx, dim);
+    merger.setDimLevelType(tensor, idx, dim);
     return true;
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
-    if (dim.levelType != DimLvlType::kDense)
+    if (!isDenseDLT(dim))
       return false; // compound only in dense dim
     auto binOp = a.cast<AffineBinaryOpExpr>();
     return findAffine(merger, tensor, binOp.getLHS(), dim) &&
            findAffine(merger, tensor, binOp.getRHS(), dim);
   }
   case AffineExprKind::Constant:
-    return dim.levelType == DimLvlType::kDense; // const only in dense dim
+    return isDenseDLT(dim); // const only in dense dim
   default:
     return false;
   }
@@ -203,7 +172,7 @@ static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       unsigned tensor = t.getOperandNumber();
       AffineExpr a = map.getResult(toOrigDim(enc, d));
-      if (!findAffine(merger, tensor, a, toDimLevelFormat(enc, d)))
+      if (!findAffine(merger, tensor, a, getDimLevelType(enc, d)))
         return false; // inadmissible affine expression
     }
   }
@@ -316,16 +285,16 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
     if (mask & SortMask::kIncludeUndef) {
       unsigned tensor = t.getOperandNumber();
       for (unsigned i = 0; i < n; i++)
-        if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
-            merger.isDimLevelType(tensor, i, DimLvlType::kSingleton)) {
+        if (isCompressedDLT(merger.getDimLevelType(tensor, i)) ||
+            isSingletonDLT(merger.getDimLevelType(tensor, i))) {
           for (unsigned j = 0; j < n; j++)
-            if (merger.isDimLevelType(tensor, j, DimLvlType::kUndef)) {
+            if (isUndefDLT(merger.getDimLevelType(tensor, j))) {
               adjM[i][j] = true;
               inDegree[j]++;
             }
         } else {
-          assert(merger.isDimLevelType(tensor, i, DimLvlType::kDense) ||
-                 merger.isDimLevelType(tensor, i, DimLvlType::kUndef));
+          assert(isDenseDLT(merger.getDimLevelType(tensor, i)) ||
+                 isUndefDLT(merger.getDimLevelType(tensor, i)));
         }
     }
   }
@@ -364,13 +333,13 @@ static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
   auto iteratorTypes = op.getIteratorTypesArray();
   unsigned numLoops = iteratorTypes.size();
   for (unsigned i = 0; i < numLoops; i++)
-    if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
-        merger.isDimLevelType(tensor, i, DimLvlType::kSingleton)) {
+    if (isCompressedDLT(merger.getDimLevelType(tensor, i)) ||
+        isSingletonDLT(merger.getDimLevelType(tensor, i))) {
       allDense = false;
       break;
     } else {
-      assert(merger.isDimLevelType(tensor, i, DimLvlType::kDense) ||
-             merger.isDimLevelType(tensor, i, DimLvlType::kUndef));
+      assert(isDenseDLT(merger.getDimLevelType(tensor, i)) ||
+             isUndefDLT(merger.getDimLevelType(tensor, i)));
     }
   if (allDense)
     return true;
@@ -552,7 +521,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
         continue; // compound
       unsigned idx = a.cast<AffineDimExpr>().getPosition();
       // Handle the different storage schemes.
-      if (merger.isDimLevelType(tensor, idx, DimLvlType::kCompressed)) {
+      if (isCompressedDLT(merger.getDimLevelType(tensor, idx))) {
         // Compressed dimension, fetch pointer and indices.
         auto ptrTp =
             MemRefType::get(dynShape, getPointerOverheadType(builder, enc));
@@ -563,7 +532,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
             builder.create<ToPointersOp>(loc, ptrTp, t.get(), dim);
         codegen.indices[tensor][idx] =
             builder.create<ToIndicesOp>(loc, indTp, t.get(), dim);
-      } else if (merger.isDimLevelType(tensor, idx, DimLvlType::kSingleton)) {
+      } else if (isSingletonDLT(merger.getDimLevelType(tensor, idx))) {
         // Singleton dimension, fetch indices.
         auto indTp =
             MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
@@ -572,7 +541,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
             builder.create<ToIndicesOp>(loc, indTp, t.get(), dim);
       } else {
         // Dense dimension, nothing to fetch.
-        assert(merger.isDimLevelType(tensor, idx, DimLvlType::kDense));
+        assert(isDenseDLT(merger.getDimLevelType(tensor, idx)));
       }
       // Find upper bound in current dimension.
       unsigned p = toOrigDim(enc, d);
@@ -1195,7 +1164,7 @@ static bool genInit(Merger &merger, CodeGen &codegen, OpBuilder &builder,
       continue;
     unsigned tensor = merger.tensor(b);
     assert(idx == merger.index(b));
-    if (merger.isDimLevelType(b, DimLvlType::kCompressed)) {
+    if (isCompressedDLT(merger.getDimLevelType(b))) {
       // Initialize sparse index that will implement the iteration:
       //   for pidx_idx = pointers(pidx_idx-1), pointers(1+pidx_idx-1)
       unsigned pat = at;
@@ -1210,7 +1179,7 @@ static bool genInit(Merger &merger, CodeGen &codegen, OpBuilder &builder,
       codegen.pidxs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p0);
       Value p1 = builder.create<arith::AddIOp>(loc, p0, one);
       codegen.highs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p1);
-    } else if (merger.isDimLevelType(b, DimLvlType::kSingleton)) {
+    } else if (isSingletonDLT(merger.getDimLevelType(b))) {
       // Initialize sparse index that will implement the "iteration":
       //   for pidx_idx = pidx_idx-1, 1+pidx_idx-1
       // We rely on subsequent loop unrolling to get rid of the loop
@@ -1226,8 +1195,8 @@ static bool genInit(Merger &merger, CodeGen &codegen, OpBuilder &builder,
       codegen.pidxs[tensor][idx] = p0;
       codegen.highs[tensor][idx] = builder.create<arith::AddIOp>(loc, p0, one);
     } else {
-      assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
-             merger.isDimLevelType(b, DimLvlType::kUndef));
+      assert(isDenseDLT(merger.getDimLevelType(b)) ||
+             isUndefDLT(merger.getDimLevelType(b)));
       // Dense index still in play.
       needsUniv = true;
     }
@@ -1316,8 +1285,8 @@ static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   assert(idx == merger.index(fb));
   auto iteratorTypes = op.getIteratorTypesArray();
   bool isReduction = linalg::isReductionIterator(iteratorTypes[idx]);
-  bool isSparse = merger.isDimLevelType(fb, DimLvlType::kCompressed) ||
-                  merger.isDimLevelType(fb, DimLvlType::kSingleton);
+  bool isSparse = isCompressedDLT(merger.getDimLevelType(fb)) ||
+                  isSingletonDLT(merger.getDimLevelType(fb));
   bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) &&
                   denseUnitStrides(merger, op, idx);
   bool isParallel =
@@ -1392,15 +1361,15 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   for (unsigned b = 0, be = indices.size(); b < be; b++) {
     if (!indices[b])
       continue;
-    if (merger.isDimLevelType(b, DimLvlType::kCompressed) ||
-        merger.isDimLevelType(b, DimLvlType::kSingleton)) {
+    if (isCompressedDLT(merger.getDimLevelType(b)) ||
+        isSingletonDLT(merger.getDimLevelType(b))) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       types.push_back(indexType);
       operands.push_back(codegen.pidxs[tensor][idx]);
     } else {
-      assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
-             merger.isDimLevelType(b, DimLvlType::kUndef));
+      assert(isDenseDLT(merger.getDimLevelType(b)) ||
+             isUndefDLT(merger.getDimLevelType(b)));
     }
   }
   if (codegen.redVal) {
@@ -1431,8 +1400,8 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   for (unsigned b = 0, be = indices.size(); b < be; b++) {
     if (!indices[b])
       continue;
-    if (merger.isDimLevelType(b, DimLvlType::kCompressed) ||
-        merger.isDimLevelType(b, DimLvlType::kSingleton)) {
+    if (isCompressedDLT(merger.getDimLevelType(b)) ||
+        isSingletonDLT(merger.getDimLevelType(b))) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value op1 = before->getArgument(o);
@@ -1442,8 +1411,8 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder,
       cond = cond ? builder.create<arith::AndIOp>(loc, cond, opc) : opc;
       codegen.pidxs[tensor][idx] = after->getArgument(o++);
     } else {
-      assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
-             merger.isDimLevelType(b, DimLvlType::kUndef));
+      assert(isDenseDLT(merger.getDimLevelType(b)) ||
+             isUndefDLT(merger.getDimLevelType(b)));
     }
   }
   if (codegen.redVal)
@@ -1486,8 +1455,8 @@ static void genLocals(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   for (unsigned b = 0, be = locals.size(); b < be; b++) {
     if (!locals[b])
       continue;
-    if (merger.isDimLevelType(b, DimLvlType::kCompressed) ||
-        merger.isDimLevelType(b, DimLvlType::kSingleton)) {
+    if (isCompressedDLT(merger.getDimLevelType(b)) ||
+        isSingletonDLT(merger.getDimLevelType(b))) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value ptr = codegen.indices[tensor][idx];
@@ -1504,8 +1473,8 @@ static void genLocals(Merger &merger, CodeGen &codegen, OpBuilder &builder,
         }
       }
     } else {
-      assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
-             merger.isDimLevelType(b, DimLvlType::kUndef));
+      assert(isDenseDLT(merger.getDimLevelType(b)) ||
+             isUndefDLT(merger.getDimLevelType(b)));
     }
   }
 
@@ -1520,7 +1489,7 @@ static void genLocals(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   // but may be needed for linearized codegen.
   for (unsigned b = 0, be = locals.size(); b < be; b++) {
     if ((locals[b] || merger.isOutTensor(b, idx)) &&
-        merger.isDimLevelType(b, DimLvlType::kDense)) {
+        isDenseDLT(merger.getDimLevelType(b))) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       unsigned pat = at;
@@ -1572,8 +1541,8 @@ static void genWhileInduction(Merger &merger, CodeGen &codegen,
   for (unsigned b = 0, be = induction.size(); b < be; b++) {
     if (!induction[b])
       continue;
-    if (merger.isDimLevelType(b, DimLvlType::kCompressed) ||
-        merger.isDimLevelType(b, DimLvlType::kSingleton)) {
+    if (isCompressedDLT(merger.getDimLevelType(b)) ||
+        isSingletonDLT(merger.getDimLevelType(b))) {
       unsigned tensor = merger.tensor(b);
       assert(idx == merger.index(b));
       Value op1 = codegen.idxs[tensor][idx];
@@ -1585,8 +1554,8 @@ static void genWhileInduction(Merger &merger, CodeGen &codegen,
       operands.push_back(builder.create<arith::SelectOp>(loc, cmp, add, op3));
       codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
     } else {
-      assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
-             merger.isDimLevelType(b, DimLvlType::kUndef));
+      assert(isDenseDLT(merger.getDimLevelType(b)) ||
+             isUndefDLT(merger.getDimLevelType(b)));
     }
   }
   if (codegen.redVal) {
@@ -1641,15 +1610,15 @@ static scf::IfOp genIf(Merger &merger, CodeGen &codegen, OpBuilder &builder,
     unsigned tensor = merger.tensor(b);
     assert(idx == merger.index(b));
     Value clause;
-    if (merger.isDimLevelType(b, DimLvlType::kCompressed) ||
-        merger.isDimLevelType(b, DimLvlType::kSingleton)) {
+    if (isCompressedDLT(merger.getDimLevelType(b)) ||
+        isSingletonDLT(merger.getDimLevelType(b))) {
       Value op1 = codegen.idxs[tensor][idx];
       Value op2 = codegen.loops[idx];
       clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, op1,
                                              op2);
     } else {
-      assert(merger.isDimLevelType(b, DimLvlType::kDense) ||
-             merger.isDimLevelType(b, DimLvlType::kUndef));
+      assert(isDenseDLT(merger.getDimLevelType(b)) ||
+             isUndefDLT(merger.getDimLevelType(b)));
       clause = constantI1(builder, loc, true);
     }
     cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
index 8d614ed..ebcf5d7 100644 (file)
@@ -9,4 +9,5 @@ add_mlir_dialect_library(MLIRSparseTensorUtils
   MLIRComplexDialect
   MLIRIR
   MLIRLinalgDialect
+  MLIRSparseTensorEnums
 )
index b8f6a93..5b6cddf 100644 (file)
@@ -271,7 +271,7 @@ BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
     // Starts resetting from a dense dimension, so that the first bit (if kept)
     // is not undefined dimension type.
     for (unsigned b = 0; b < be; b++) {
-      if (simple[b] && isDimLevelType(b, DimLvlType::kDense)) {
+      if (simple[b] && isDenseDLT(getDimLevelType(b))) {
         offset = be - b - 1; // relative to the end
         break;
       }
@@ -281,8 +281,8 @@ BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
   // keep the rightmost bit (which could possibly be a synthetic tensor).
   for (unsigned b = be - 1 - offset, i = 0; i < be;
        b = b == 0 ? be - 1 : b - 1, i++) {
-    if (simple[b] && (!isDimLevelType(b, DimLvlType::kCompressed) &&
-                      !isDimLevelType(b, DimLvlType::kSingleton))) {
+    if (simple[b] && (!isCompressedDLT(getDimLevelType(b)) &&
+                      !isSingletonDLT(getDimLevelType(b)))) {
       if (reset)
         simple.reset(b);
       reset = true;
@@ -396,8 +396,8 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
 
 bool Merger::hasAnySparse(const BitVector &bits) const {
   for (unsigned b = 0, be = bits.size(); b < be; b++)
-    if (bits[b] && (isDimLevelType(b, DimLvlType::kCompressed) ||
-                    isDimLevelType(b, DimLvlType::kSingleton)))
+    if (bits[b] && (isCompressedDLT(getDimLevelType(b)) ||
+                    isSingletonDLT(getDimLevelType(b))))
       return true;
   return false;
 }
@@ -613,23 +613,18 @@ void Merger::dumpBits(const BitVector &bits) const {
     if (bits[b]) {
       unsigned t = tensor(b);
       unsigned i = index(b);
-      DimLevelFormat f = dims[t][i];
+      DimLevelType dlt = dimTypes[t][i];
       llvm::dbgs() << " i_" << t << "_" << i << "_";
-      switch (f.levelType) {
-      case DimLvlType::kDense:
+      if (isDenseDLT(dlt))
         llvm::dbgs() << "D";
-        break;
-      case DimLvlType::kCompressed:
+      else if (isCompressedDLT(dlt))
         llvm::dbgs() << "C";
-        break;
-      case DimLvlType::kSingleton:
+      else if (isSingletonDLT(dlt))
         llvm::dbgs() << "S";
-        break;
-      case DimLvlType::kUndef:
+      else if (isUndefDLT(dlt))
         llvm::dbgs() << "U";
-        break;
-      }
-      llvm::dbgs() << "[O=" << f.isOrdered << ",U=" << f.isUnique << "]";
+      llvm::dbgs() << "[O=" << isOrderedDLT(dlt) << ",U=" << isUniqueDLT(dlt)
+                   << "]";
     }
   }
 }
index 4e95b8f..0e0882f 100644 (file)
@@ -311,15 +311,15 @@ protected:
   MergerTest3T1L() : MergerTestBase(3, 1) {
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kCompressed));
+    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
 
     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kCompressed));
+    merger.setDimLevelType(t1, l0, DimLevelType::Compressed);
 
     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kDense));
+    merger.setDimLevelType(t2, l0, DimLevelType::Dense);
   }
 };
 
@@ -334,19 +334,19 @@ protected:
   MergerTest4T1L() : MergerTestBase(4, 1) {
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kCompressed));
+    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
 
     // Tensor 1: sparse input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kCompressed));
+    merger.setDimLevelType(t1, l0, DimLevelType::Compressed);
 
     // Tensor 2: sparse input vector
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kCompressed));
+    merger.setDimLevelType(t2, l0, DimLevelType::Compressed);
 
     // Tensor 3: dense output vector
     merger.addExp(Kind::kTensor, t3, -1u);
-    merger.setDimLevelFormat(t3, l0, DimLevelFormat(DimLvlType::kDense));
+    merger.setDimLevelType(t3, l0, DimLevelType::Dense);
   }
 };
 
@@ -365,15 +365,15 @@ protected:
   MergerTest3T1LD() : MergerTestBase(3, 1) {
     // Tensor 0: sparse input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kCompressed));
+    merger.setDimLevelType(t0, l0, DimLevelType::Compressed);
 
     // Tensor 1: dense input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kDense));
+    merger.setDimLevelType(t1, l0, DimLevelType::Dense);
 
     // Tensor 2: dense output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kDense));
+    merger.setDimLevelType(t2, l0, DimLevelType::Dense);
   }
 };
 
@@ -392,19 +392,19 @@ protected:
   MergerTest4T1LU() : MergerTestBase(4, 1) {
     // Tensor 0: undef input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kUndef));
+    merger.setDimLevelType(t0, l0, DimLevelType::Undef);
 
     // Tensor 1: dense input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kDense));
+    merger.setDimLevelType(t1, l0, DimLevelType::Dense);
 
     // Tensor 2: undef input vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kUndef));
+    merger.setDimLevelType(t2, l0, DimLevelType::Undef);
 
     // Tensor 3: dense output vector.
     merger.addExp(Kind::kTensor, t3, -1u);
-    merger.setDimLevelFormat(t3, l0, DimLevelFormat(DimLvlType::kDense));
+    merger.setDimLevelType(t3, l0, DimLevelType::Dense);
   }
 };
 
@@ -425,15 +425,15 @@ protected:
 
     // Tensor 0: undef input vector.
     merger.addExp(Kind::kTensor, t0, -1u);
-    merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kUndef));
+    merger.setDimLevelType(t0, l0, DimLevelType::Undef);
 
     // Tensor 1: undef input vector.
     merger.addExp(Kind::kTensor, t1, -1u);
-    merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kUndef));
+    merger.setDimLevelType(t1, l0, DimLevelType::Undef);
 
     // Tensor 2: sparse output vector.
     merger.addExp(Kind::kTensor, t2, -1u);
-    merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kCompressed));
+    merger.setDimLevelType(t2, l0, DimLevelType::Compressed);
   }
 };
 
index 4401526..af3805f 100644 (file)
@@ -2112,6 +2112,7 @@ cc_library(
         ":LinalgDialect",
         ":MathDialect",
         ":SparseTensorDialect",
+        ":SparseTensorEnums",
         "//llvm:Support",
     ],
 )