[mlir][sparse] Factoring magic numbers into a header
authorwren romano <2998727+wrengr@users.noreply.github.com>
Fri, 5 Nov 2021 22:15:39 +0000 (15:15 -0700)
committerwren romano <2998727+wrengr@users.noreply.github.com>
Fri, 5 Nov 2021 22:59:16 +0000 (15:59 -0700)
Addresses https://bugs.llvm.org/show_bug.cgi?id=52303

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D112962

mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h [new file with mode: 0644]
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
new file mode 100644 (file)
index 0000000..4361fc7
--- /dev/null
@@ -0,0 +1,55 @@
+//===- SparseTensorUtils.h - Enums shared with the runtime ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines several enums shared between
+// Transforms/SparseTensorConversion.cpp and ExecutionEngine/SparseUtils.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_EXECUTIONENGINE_SPARSETENSORUTILS_H_
+#define MLIR_EXECUTIONENGINE_SPARSETENSORUTILS_H_
+
+#include <cinttypes>
+
+extern "C" {
+
+/// Encoding of the elemental type, for "overloading" @newSparseTensor.
+enum class OverheadType : uint32_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
+
+/// Encoding of the elemental type, for "overloading" @newSparseTensor.
+enum class PrimaryType : uint32_t {
+  kF64 = 1,
+  kF32 = 2,
+  kI64 = 3,
+  kI32 = 4,
+  kI16 = 5,
+  kI8 = 6
+};
+
+/// The actions performed by @newSparseTensor.
+enum class Action : uint32_t {
+  kEmpty = 0,
+  kFromFile = 1,
+  kFromCOO = 2,
+  kEmptyCOO = 3,
+  kToCOO = 4,
+  kToIterator = 5
+};
+
+/// This enum mimics `SparseTensorEncodingAttr::DimLevelType` for
+/// breaking dependency cycles.  `SparseTensorEncodingAttr::DimLevelType`
+/// is the source of truth and this enum should be kept consistent with it.
+enum class DimLevelType : uint8_t {
+  kDense = 0,
+  kCompressed = 1,
+  kSingleton = 2
+};
+
+} // extern "C"
+
+#endif // MLIR_EXECUTIONENGINE_SPARSETENSORUTILS_H_
index ed2c540..1ef4ea8 100644 (file)
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/ExecutionEngine/SparseTensorUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
@@ -29,69 +30,10 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
-/// New tensor storage action. Keep these values consistent with
-/// the sparse runtime support library.
-enum Action : uint32_t {
-  kEmpty = 0,
-  kFromFile = 1,
-  kFromCOO = 2,
-  kEmptyCOO = 3,
-  kToCOO = 4,
-  kToIter = 5
-};
-
 //===----------------------------------------------------------------------===//
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-/// Returns internal type encoding for primary storage. Keep these
-/// values consistent with the sparse runtime support library.
-static uint32_t getPrimaryTypeEncoding(Type tp) {
-  if (tp.isF64())
-    return 1;
-  if (tp.isF32())
-    return 2;
-  if (tp.isInteger(64))
-    return 3;
-  if (tp.isInteger(32))
-    return 4;
-  if (tp.isInteger(16))
-    return 5;
-  if (tp.isInteger(8))
-    return 6;
-  return 0;
-}
-
-/// Returns internal type encoding for overhead storage. Keep these
-/// values consistent with the sparse runtime support library.
-static uint32_t getOverheadTypeEncoding(unsigned width) {
-  switch (width) {
-  default:
-    return 1;
-  case 32:
-    return 2;
-  case 16:
-    return 3;
-  case 8:
-    return 4;
-  }
-}
-
-/// Returns internal dimension level type encoding. Keep these
-/// values consistent with the sparse runtime support library.
-static uint32_t
-getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
-  switch (dlt) {
-  case SparseTensorEncodingAttr::DimLevelType::Dense:
-    return 0;
-  case SparseTensorEncodingAttr::DimLevelType::Compressed:
-    return 1;
-  case SparseTensorEncodingAttr::DimLevelType::Singleton:
-    return 2;
-  }
-  llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
-}
-
 /// Generates a constant zero of the given type.
 inline static Value constantZero(ConversionPatternRewriter &rewriter,
                                  Location loc, Type t) {
@@ -116,6 +58,75 @@ inline static Value constantI8(ConversionPatternRewriter &rewriter,
   return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
 }
 
+/// Generates a constant of the given `Action`.
+static Value constantAction(ConversionPatternRewriter &rewriter, Location loc,
+                            Action action) {
+  return constantI32(rewriter, loc, static_cast<uint32_t>(action));
+}
+
+/// Generates a constant of the internal type encoding for overhead storage.
+static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter,
+                                          Location loc, unsigned width) {
+  OverheadType sec;
+  switch (width) {
+  default:
+    sec = OverheadType::kU64;
+    break;
+  case 32:
+    sec = OverheadType::kU32;
+    break;
+  case 16:
+    sec = OverheadType::kU16;
+    break;
+  case 8:
+    sec = OverheadType::kU8;
+    break;
+  }
+  return constantI32(rewriter, loc, static_cast<uint32_t>(sec));
+}
+
+/// Generates a constant of the internal type encoding for primary storage.
+static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter,
+                                         Location loc, Type tp) {
+  PrimaryType primary;
+  if (tp.isF64())
+    primary = PrimaryType::kF64;
+  else if (tp.isF32())
+    primary = PrimaryType::kF32;
+  else if (tp.isInteger(64))
+    primary = PrimaryType::kI64;
+  else if (tp.isInteger(32))
+    primary = PrimaryType::kI32;
+  else if (tp.isInteger(16))
+    primary = PrimaryType::kI16;
+  else if (tp.isInteger(8))
+    primary = PrimaryType::kI8;
+  else
+    llvm_unreachable("Unknown element type");
+  return constantI32(rewriter, loc, static_cast<uint32_t>(primary));
+}
+
+/// Generates a constant of the internal dimension level type encoding.
+static Value
+constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc,
+                             SparseTensorEncodingAttr::DimLevelType dlt) {
+  DimLevelType dlt2;
+  switch (dlt) {
+  case SparseTensorEncodingAttr::DimLevelType::Dense:
+    dlt2 = DimLevelType::kDense;
+    break;
+  case SparseTensorEncodingAttr::DimLevelType::Compressed:
+    dlt2 = DimLevelType::kCompressed;
+    break;
+  case SparseTensorEncodingAttr::DimLevelType::Singleton:
+    dlt2 = DimLevelType::kSingleton;
+    break;
+  default:
+    llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
+  }
+  return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2));
+}
+
 /// Returns a function reference (first hit also inserts into module). Sets
 /// the "_emit_c_interface" on the function declaration when requested,
 /// so that LLVM lowering generates a wrapper function that takes care
@@ -238,7 +249,7 @@ static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
 /// computation.
 static void newParams(ConversionPatternRewriter &rewriter,
                       SmallVector<Value, 8> &params, Operation *op,
-                      SparseTensorEncodingAttr &enc, uint32_t action,
+                      SparseTensorEncodingAttr &enc, Action action,
                       ValueRange szs, Value ptr = Value()) {
   Location loc = op->getLoc();
   ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
@@ -246,7 +257,7 @@ static void newParams(ConversionPatternRewriter &rewriter,
   // Sparsity annotations.
   SmallVector<Value, 4> attrs;
   for (unsigned i = 0; i < sz; i++)
-    attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
+    attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i]));
   params.push_back(genBuffer(rewriter, loc, attrs));
   // Dimension sizes array of the enveloping tensor. Useful for either
   // verification of external data, or for construction of internal data.
@@ -268,18 +279,17 @@ static void newParams(ConversionPatternRewriter &rewriter,
   params.push_back(genBuffer(rewriter, loc, rev));
   // Secondary and primary types encoding.
   ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
-  uint32_t secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
-  uint32_t secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
-  uint32_t primary = getPrimaryTypeEncoding(resType.getElementType());
-  assert(primary);
-  params.push_back(constantI32(rewriter, loc, secPtr));
-  params.push_back(constantI32(rewriter, loc, secInd));
-  params.push_back(constantI32(rewriter, loc, primary));
+  params.push_back(
+      constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()));
+  params.push_back(
+      constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()));
+  params.push_back(
+      constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
   // User action and pointer.
   Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
   if (!ptr)
     ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
-  params.push_back(constantI32(rewriter, loc, action));
+  params.push_back(constantAction(rewriter, loc, action));
   params.push_back(ptr);
 }
 
@@ -530,7 +540,7 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
     SmallVector<Value, 8> params;
     sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
     Value ptr = adaptor.getOperands()[0];
-    newParams(rewriter, params, op, enc, kFromFile, sizes, ptr);
+    newParams(rewriter, params, op, enc, Action::kFromFile, sizes, ptr);
     rewriter.replaceOp(op, genNewCall(rewriter, op, params));
     return success();
   }
@@ -549,7 +559,7 @@ class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
     // Generate the call to construct empty tensor. The sizes are
     // explicitly defined by the arguments to the init operator.
     SmallVector<Value, 8> params;
-    newParams(rewriter, params, op, enc, kEmpty, adaptor.getOperands());
+    newParams(rewriter, params, op, enc, Action::kEmpty, adaptor.getOperands());
     rewriter.replaceOp(op, genNewCall(rewriter, op, params));
     return success();
   }
@@ -588,13 +598,13 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       auto enc = SparseTensorEncodingAttr::get(
           op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
           encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
-      newParams(rewriter, params, op, enc, kToCOO, sizes, src);
+      newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
       Value coo = genNewCall(rewriter, op, params);
-      params[3] = constantI32(
-          rewriter, loc, getOverheadTypeEncoding(encDst.getPointerBitWidth()));
-      params[4] = constantI32(
-          rewriter, loc, getOverheadTypeEncoding(encDst.getIndexBitWidth()));
-      params[6] = constantI32(rewriter, loc, kFromCOO);
+      params[3] = constantOverheadTypeEncoding(rewriter, loc,
+                                               encDst.getPointerBitWidth());
+      params[4] = constantOverheadTypeEncoding(rewriter, loc,
+                                               encDst.getIndexBitWidth());
+      params[6] = constantAction(rewriter, loc, Action::kFromCOO);
       params[7] = coo;
       rewriter.replaceOp(op, genNewCall(rewriter, op, params));
       return success();
@@ -613,7 +623,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       Type elemTp = dstTensorTp.getElementType();
       // Fabricate a no-permutation encoding for newParams().
       // The pointer/index types must be those of `src`.
-      // The dimLevelTypes aren't actually used by kToIter.
+      // The dimLevelTypes aren't actually used by Action::kToIterator.
       encDst = SparseTensorEncodingAttr::get(
           op->getContext(),
           SmallVector<SparseTensorEncodingAttr::DimLevelType>(
@@ -622,7 +632,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       SmallVector<Value, 4> sizes;
       SmallVector<Value, 8> params;
       sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
-      newParams(rewriter, params, op, encDst, kToIter, sizes, src);
+      newParams(rewriter, params, op, encDst, Action::kToIterator, sizes, src);
       Value iter = genNewCall(rewriter, op, params);
       Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
       Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
@@ -677,7 +687,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
     SmallVector<Value, 4> sizes;
     SmallVector<Value, 8> params;
     sizesFromSrc(rewriter, sizes, loc, src);
-    newParams(rewriter, params, op, encDst, kEmptyCOO, sizes);
+    newParams(rewriter, params, op, encDst, Action::kEmptyCOO, sizes);
     Value ptr = genNewCall(rewriter, op, params);
     Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
     Value perm = params[2];
@@ -718,7 +728,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
           return {};
         });
     // Final call to construct sparse tensor storage.
-    params[6] = constantI32(rewriter, loc, kFromCOO);
+    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
     params[7] = ptr;
     rewriter.replaceOp(op, genNewCall(rewriter, op, params));
     return success();
index fcdc231..52396d4 100644 (file)
@@ -14,6 +14,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/ExecutionEngine/SparseTensorUtils.h"
 #include "mlir/ExecutionEngine/CRunnerUtils.h"
 
 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
@@ -162,8 +163,6 @@ private:
 /// function overloading to implement "partial" method specialization.
 class SparseTensorStorageBase {
 public:
-  enum DimLevelType : uint8_t { kDense = 0, kCompressed = 1, kSingleton = 2 };
-
   virtual uint64_t getDimSize(uint64_t) = 0;
 
   // Overhead storage.
@@ -206,7 +205,7 @@ public:
   /// permutation, and per-dimension dense/sparse annotations, using
   /// the coordinate scheme tensor for the initial contents if provided.
   SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
-                      const uint8_t *sparsity, SparseTensorCOO<V> *tensor)
+                      const DimLevelType *sparsity, SparseTensorCOO<V> *tensor)
       : sizes(szs), rev(getRank()), pointers(getRank()), indices(getRank()) {
     uint64_t rank = getRank();
     // Store "reverse" permutation.
@@ -216,17 +215,18 @@ public:
     // TODO: needs fine-tuning based on sparsity
     for (uint64_t r = 0, s = 1; r < rank; r++) {
       s *= sizes[r];
-      if (sparsity[r] == kCompressed) {
+      if (sparsity[r] == DimLevelType::kCompressed) {
         pointers[r].reserve(s + 1);
         indices[r].reserve(s);
         s = 1;
       } else {
-        assert(sparsity[r] == kDense && "singleton not yet supported");
+        assert(sparsity[r] == DimLevelType::kDense &&
+               "singleton not yet supported");
       }
     }
     // Prepare sparse pointer structures for all dimensions.
     for (uint64_t r = 0; r < rank; r++)
-      if (sparsity[r] == kCompressed)
+      if (sparsity[r] == DimLevelType::kCompressed)
         pointers[r].push_back(0);
     // Then assign contents from coordinate scheme tensor if provided.
     if (tensor) {
@@ -288,7 +288,7 @@ public:
   /// permutation as is desired for the new sparse tensor storage.
   static SparseTensorStorage<P, I, V> *
   newSparseTensor(uint64_t rank, const uint64_t *sizes, const uint64_t *perm,
-                  const uint8_t *sparsity, SparseTensorCOO<V> *tensor) {
+                  const DimLevelType *sparsity, SparseTensorCOO<V> *tensor) {
     SparseTensorStorage<P, I, V> *n = nullptr;
     if (tensor) {
       assert(tensor->getRank() == rank);
@@ -311,8 +311,8 @@ private:
   /// Initializes sparse tensor storage scheme from a memory-resident sparse
   /// tensor in coordinate scheme. This method prepares the pointers and
   /// indices arrays under the given per-dimension dense/sparse annotations.
-  void fromCOO(SparseTensorCOO<V> *tensor, const uint8_t *sparsity, uint64_t lo,
-               uint64_t hi, uint64_t d) {
+  void fromCOO(SparseTensorCOO<V> *tensor, const DimLevelType *sparsity,
+               uint64_t lo, uint64_t hi, uint64_t d) {
     const std::vector<Element<V>> &elements = tensor->getElements();
     // Once dimensions are exhausted, insert the numerical values.
     if (d == getRank()) {
@@ -331,7 +331,7 @@ private:
       while (seg < hi && elements[seg].indices[d] == idx)
         seg++;
       // Handle segment in interval for sparse or dense dimension.
-      if (sparsity[d] == kCompressed) {
+      if (sparsity[d] == DimLevelType::kCompressed) {
         indices[d].push_back(idx);
       } else {
         // For dense storage we must fill in all the zero values between
@@ -346,7 +346,7 @@ private:
       lo = seg;
     }
     // Finalize the sparse pointer structure at this dimension.
-    if (sparsity[d] == kCompressed) {
+    if (sparsity[d] == DimLevelType::kCompressed) {
       pointers[d].push_back(indices[d].size());
     } else {
       // For dense storage we must fill in all the zero values after
@@ -543,53 +543,35 @@ typedef uint64_t index_t;
 //
 //===----------------------------------------------------------------------===//
 
-enum OverheadTypeEnum : uint32_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
-
-enum PrimaryTypeEnum : uint32_t {
-  kF64 = 1,
-  kF32 = 2,
-  kI64 = 3,
-  kI32 = 4,
-  kI16 = 5,
-  kI8 = 6
-};
-
-enum Action : uint32_t {
-  kEmpty = 0,
-  kFromFile = 1,
-  kFromCOO = 2,
-  kEmptyCOO = 3,
-  kToCOO = 4,
-  kToIter = 5
-};
-
 #define CASE(p, i, v, P, I, V)                                                 \
   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
     SparseTensorCOO<V> *tensor = nullptr;                                      \
-    if (action <= kFromCOO) {                                                  \
-      if (action == kFromFile) {                                               \
+    if (action <= Action::kFromCOO) {                                          \
+      if (action == Action::kFromFile) {                                       \
         char *filename = static_cast<char *>(ptr);                             \
         tensor = openSparseTensorCOO<V>(filename, rank, sizes, perm);          \
-      } else if (action == kFromCOO) {                                         \
+      } else if (action == Action::kFromCOO) {                                 \
         tensor = static_cast<SparseTensorCOO<V> *>(ptr);                       \
       } else {                                                                 \
-        assert(action == kEmpty);                                              \
+        assert(action == Action::kEmpty);                                      \
       }                                                                        \
       return SparseTensorStorage<P, I, V>::newSparseTensor(rank, sizes, perm,  \
                                                            sparsity, tensor);  \
-    } else if (action == kEmptyCOO) {                                          \
+    } else if (action == Action::kEmptyCOO) {                                  \
       return SparseTensorCOO<V>::newSparseTensorCOO(rank, sizes, perm);        \
     } else {                                                                   \
       tensor = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm);  \
-      if (action == kToIter) {                                                 \
+      if (action == Action::kToIterator) {                                     \
         tensor->startIterator();                                               \
       } else {                                                                 \
-        assert(action == kToCOO);                                              \
+        assert(action == Action::kToCOO);                                      \
       }                                                                        \
       return tensor;                                                           \
     }                                                                          \
   }
 
+#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
+
 #define IMPL_SPARSEVALUES(NAME, TYPE, LIB)                                     \
   void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) {    \
     assert(ref);                                                               \
@@ -656,78 +638,110 @@ enum Action : uint32_t {
 /// Constructs a new sparse tensor. This is the "swiss army knife"
 /// method for materializing sparse tensors into the computation.
 ///
-/// action:
+/// Action:
 /// kEmpty = returns empty storage to fill later
 /// kFromFile = returns storage, where ptr contains filename to read
 /// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
 /// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
 /// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
-/// kToIter = returns iterator from storage in ptr (call getNext() to use)
+/// kToIterator = returns iterator from storage in ptr (call getNext() to use)
 void *
-_mlir_ciface_newSparseTensor(StridedMemRefType<uint8_t, 1> *aref, // NOLINT
+_mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
                              StridedMemRefType<index_t, 1> *sref,
                              StridedMemRefType<index_t, 1> *pref,
-                             uint32_t ptrTp, uint32_t indTp, uint32_t valTp,
-                             uint32_t action, void *ptr) {
+                             OverheadType ptrTp, OverheadType indTp,
+                             PrimaryType valTp, Action action, void *ptr) {
   assert(aref && sref && pref);
   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
          pref->strides[0] == 1);
   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
-  const uint8_t *sparsity = aref->data + aref->offset;
+  const DimLevelType *sparsity = aref->data + aref->offset;
   const index_t *sizes = sref->data + sref->offset;
   const index_t *perm = pref->data + pref->offset;
   uint64_t rank = aref->sizes[0];
 
   // Double matrices with all combinations of overhead storage.
-  CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
-  CASE(kU64, kU32, kF64, uint64_t, uint32_t, double);
-  CASE(kU64, kU16, kF64, uint64_t, uint16_t, double);
-  CASE(kU64, kU8, kF64, uint64_t, uint8_t, double);
-  CASE(kU32, kU64, kF64, uint32_t, uint64_t, double);
-  CASE(kU32, kU32, kF64, uint32_t, uint32_t, double);
-  CASE(kU32, kU16, kF64, uint32_t, uint16_t, double);
-  CASE(kU32, kU8, kF64, uint32_t, uint8_t, double);
-  CASE(kU16, kU64, kF64, uint16_t, uint64_t, double);
-  CASE(kU16, kU32, kF64, uint16_t, uint32_t, double);
-  CASE(kU16, kU16, kF64, uint16_t, uint16_t, double);
-  CASE(kU16, kU8, kF64, uint16_t, uint8_t, double);
-  CASE(kU8, kU64, kF64, uint8_t, uint64_t, double);
-  CASE(kU8, kU32, kF64, uint8_t, uint32_t, double);
-  CASE(kU8, kU16, kF64, uint8_t, uint16_t, double);
-  CASE(kU8, kU8, kF64, uint8_t, uint8_t, double);
+  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
+       uint64_t, double);
+  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
+       uint32_t, double);
+  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
+       uint16_t, double);
+  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
+       uint8_t, double);
+  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
+       uint64_t, double);
+  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
+       uint32_t, double);
+  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
+       uint16_t, double);
+  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
+       uint8_t, double);
+  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
+       uint64_t, double);
+  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
+       uint32_t, double);
+  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
+       uint16_t, double);
+  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
+       uint8_t, double);
+  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
+       uint64_t, double);
+  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
+       uint32_t, double);
+  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
+       uint16_t, double);
+  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
+       uint8_t, double);
 
   // Float matrices with all combinations of overhead storage.
-  CASE(kU64, kU64, kF32, uint64_t, uint64_t, float);
-  CASE(kU64, kU32, kF32, uint64_t, uint32_t, float);
-  CASE(kU64, kU16, kF32, uint64_t, uint16_t, float);
-  CASE(kU64, kU8, kF32, uint64_t, uint8_t, float);
-  CASE(kU32, kU64, kF32, uint32_t, uint64_t, float);
-  CASE(kU32, kU32, kF32, uint32_t, uint32_t, float);
-  CASE(kU32, kU16, kF32, uint32_t, uint16_t, float);
-  CASE(kU32, kU8, kF32, uint32_t, uint8_t, float);
-  CASE(kU16, kU64, kF32, uint16_t, uint64_t, float);
-  CASE(kU16, kU32, kF32, uint16_t, uint32_t, float);
-  CASE(kU16, kU16, kF32, uint16_t, uint16_t, float);
-  CASE(kU16, kU8, kF32, uint16_t, uint8_t, float);
-  CASE(kU8, kU64, kF32, uint8_t, uint64_t, float);
-  CASE(kU8, kU32, kF32, uint8_t, uint32_t, float);
-  CASE(kU8, kU16, kF32, uint8_t, uint16_t, float);
-  CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);
-
-  // Integral matrices with same overhead storage.
-  CASE(kU64, kU64, kI64, uint64_t, uint64_t, int64_t);
-  CASE(kU64, kU64, kI32, uint64_t, uint64_t, int32_t);
-  CASE(kU64, kU64, kI16, uint64_t, uint64_t, int16_t);
-  CASE(kU64, kU64, kI8, uint64_t, uint64_t, int8_t);
-  CASE(kU32, kU32, kI32, uint32_t, uint32_t, int32_t);
-  CASE(kU32, kU32, kI16, uint32_t, uint32_t, int16_t);
-  CASE(kU32, kU32, kI8, uint32_t, uint32_t, int8_t);
-  CASE(kU16, kU16, kI32, uint16_t, uint16_t, int32_t);
-  CASE(kU16, kU16, kI16, uint16_t, uint16_t, int16_t);
-  CASE(kU16, kU16, kI8, uint16_t, uint16_t, int8_t);
-  CASE(kU8, kU8, kI32, uint8_t, uint8_t, int32_t);
-  CASE(kU8, kU8, kI16, uint8_t, uint8_t, int16_t);
-  CASE(kU8, kU8, kI8, uint8_t, uint8_t, int8_t);
+  CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
+       uint64_t, float);
+  CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
+       uint32_t, float);
+  CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
+       uint16_t, float);
+  CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
+       uint8_t, float);
+  CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
+       uint64_t, float);
+  CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
+       uint32_t, float);
+  CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
+       uint16_t, float);
+  CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
+       uint8_t, float);
+  CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
+       uint64_t, float);
+  CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
+       uint32_t, float);
+  CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
+       uint16_t, float);
+  CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
+       uint8_t, float);
+  CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
+       uint64_t, float);
+  CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
+       uint32_t, float);
+  CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
+       uint16_t, float);
+  CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
+       uint8_t, float);
+
+  // Integral matrices with both overheads of the same type.
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
+  CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
+  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
+  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
+  CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
+  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
+  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
+  CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
+  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
+  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
+  CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
 
   // Unsupported case (add above if needed).
   fputs("unsupported combination of types\n", stderr);
@@ -830,7 +844,7 @@ void delSparseTensor(void *tensor) {
 void *convertToMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape,
                                 double *values, uint64_t *indices) {
   // Setup all-dims compressed and default ordering.
-  std::vector<uint8_t> sparse(rank, SparseTensorStorageBase::kCompressed);
+  std::vector<DimLevelType> sparse(rank, DimLevelType::kCompressed);
   std::vector<uint64_t> perm(rank);
   std::iota(perm.begin(), perm.end(), 0);
   // Convert external format to internal COO.
index 13a6fed..6aa22bf 100644 (file)
@@ -1707,7 +1707,10 @@ cc_library(
 cc_library(
     name = "SparseTensorTransforms",
     srcs = glob(["lib/Dialect/SparseTensor/Transforms/*.cpp"]),
-    hdrs = ["include/mlir/Dialect/SparseTensor/Transforms/Passes.h"],
+    hdrs = [
+        "include/mlir/Dialect/SparseTensor/Transforms/Passes.h",
+        "include/mlir/ExecutionEngine/SparseTensorUtils.h",
+    ],
     includes = ["include"],
     deps = [
         ":Affine",
@@ -5391,7 +5394,10 @@ cc_library(
         "lib/ExecutionEngine/CRunnerUtils.cpp",
         "lib/ExecutionEngine/SparseTensorUtils.cpp",
     ],
-    hdrs = ["include/mlir/ExecutionEngine/CRunnerUtils.h"],
+    hdrs = [
+        "include/mlir/ExecutionEngine/CRunnerUtils.h",
+        "include/mlir/ExecutionEngine/SparseTensorUtils.h",
+    ],
     includes = ["include"],
 )