#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
-/// New tensor storage action. Keep these values consistent with
-/// the sparse runtime support library.
-enum Action : uint32_t {
- kEmpty = 0,
- kFromFile = 1,
- kFromCOO = 2,
- kEmptyCOO = 3,
- kToCOO = 4,
- kToIter = 5
-};
-
//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//
-/// Returns internal type encoding for primary storage. Keep these
-/// values consistent with the sparse runtime support library.
-static uint32_t getPrimaryTypeEncoding(Type tp) {
- if (tp.isF64())
- return 1;
- if (tp.isF32())
- return 2;
- if (tp.isInteger(64))
- return 3;
- if (tp.isInteger(32))
- return 4;
- if (tp.isInteger(16))
- return 5;
- if (tp.isInteger(8))
- return 6;
- return 0;
-}
-
-/// Returns internal type encoding for overhead storage. Keep these
-/// values consistent with the sparse runtime support library.
-static uint32_t getOverheadTypeEncoding(unsigned width) {
- switch (width) {
- default:
- return 1;
- case 32:
- return 2;
- case 16:
- return 3;
- case 8:
- return 4;
- }
-}
-
-/// Returns internal dimension level type encoding. Keep these
-/// values consistent with the sparse runtime support library.
-static uint32_t
-getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
- switch (dlt) {
- case SparseTensorEncodingAttr::DimLevelType::Dense:
- return 0;
- case SparseTensorEncodingAttr::DimLevelType::Compressed:
- return 1;
- case SparseTensorEncodingAttr::DimLevelType::Singleton:
- return 2;
- }
- llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
-}
-
/// Generates a constant zero of the given type.
inline static Value constantZero(ConversionPatternRewriter &rewriter,
Location loc, Type t) {
return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
}
+/// Generates a constant of the given `Action`.
+static Value constantAction(ConversionPatternRewriter &rewriter, Location loc,
+ Action action) {
+ return constantI32(rewriter, loc, static_cast<uint32_t>(action));
+}
+
+/// Generates a constant of the internal type encoding for overhead storage.
+static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter,
+ Location loc, unsigned width) {
+ OverheadType sec;
+ switch (width) {
+ default:
+ sec = OverheadType::kU64;
+ break;
+ case 32:
+ sec = OverheadType::kU32;
+ break;
+ case 16:
+ sec = OverheadType::kU16;
+ break;
+ case 8:
+ sec = OverheadType::kU8;
+ break;
+ }
+ return constantI32(rewriter, loc, static_cast<uint32_t>(sec));
+}
+
+/// Generates a constant of the internal type encoding for primary storage.
+static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter,
+ Location loc, Type tp) {
+ PrimaryType primary;
+ if (tp.isF64())
+ primary = PrimaryType::kF64;
+ else if (tp.isF32())
+ primary = PrimaryType::kF32;
+ else if (tp.isInteger(64))
+ primary = PrimaryType::kI64;
+ else if (tp.isInteger(32))
+ primary = PrimaryType::kI32;
+ else if (tp.isInteger(16))
+ primary = PrimaryType::kI16;
+ else if (tp.isInteger(8))
+ primary = PrimaryType::kI8;
+ else
+ llvm_unreachable("Unknown element type");
+ return constantI32(rewriter, loc, static_cast<uint32_t>(primary));
+}
+
+/// Generates a constant of the internal dimension level type encoding.
+static Value
+constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc,
+ SparseTensorEncodingAttr::DimLevelType dlt) {
+ DimLevelType dlt2;
+ switch (dlt) {
+ case SparseTensorEncodingAttr::DimLevelType::Dense:
+ dlt2 = DimLevelType::kDense;
+ break;
+ case SparseTensorEncodingAttr::DimLevelType::Compressed:
+ dlt2 = DimLevelType::kCompressed;
+ break;
+ case SparseTensorEncodingAttr::DimLevelType::Singleton:
+ dlt2 = DimLevelType::kSingleton;
+ break;
+ default:
+ llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
+ }
+ return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2));
+}
+
/// Returns a function reference (first hit also inserts into module). Sets
/// the "_emit_c_interface" on the function declaration when requested,
/// so that LLVM lowering generates a wrapper function that takes care
/// computation.
static void newParams(ConversionPatternRewriter &rewriter,
SmallVector<Value, 8> ¶ms, Operation *op,
- SparseTensorEncodingAttr &enc, uint32_t action,
+ SparseTensorEncodingAttr &enc, Action action,
ValueRange szs, Value ptr = Value()) {
Location loc = op->getLoc();
ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
// Sparsity annotations.
SmallVector<Value, 4> attrs;
for (unsigned i = 0; i < sz; i++)
- attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
+ attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i]));
params.push_back(genBuffer(rewriter, loc, attrs));
// Dimension sizes array of the enveloping tensor. Useful for either
// verification of external data, or for construction of internal data.
params.push_back(genBuffer(rewriter, loc, rev));
// Secondary and primary types encoding.
ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
- uint32_t secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
- uint32_t secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
- uint32_t primary = getPrimaryTypeEncoding(resType.getElementType());
- assert(primary);
- params.push_back(constantI32(rewriter, loc, secPtr));
- params.push_back(constantI32(rewriter, loc, secInd));
- params.push_back(constantI32(rewriter, loc, primary));
+ params.push_back(
+ constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()));
+ params.push_back(
+ constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()));
+ params.push_back(
+ constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
// User action and pointer.
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
if (!ptr)
ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
- params.push_back(constantI32(rewriter, loc, action));
+ params.push_back(constantAction(rewriter, loc, action));
params.push_back(ptr);
}
SmallVector<Value, 8> params;
sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
Value ptr = adaptor.getOperands()[0];
- newParams(rewriter, params, op, enc, kFromFile, sizes, ptr);
+ newParams(rewriter, params, op, enc, Action::kFromFile, sizes, ptr);
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
return success();
}
// Generate the call to construct empty tensor. The sizes are
// explicitly defined by the arguments to the init operator.
SmallVector<Value, 8> params;
- newParams(rewriter, params, op, enc, kEmpty, adaptor.getOperands());
+ newParams(rewriter, params, op, enc, Action::kEmpty, adaptor.getOperands());
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
return success();
}
auto enc = SparseTensorEncodingAttr::get(
op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
- newParams(rewriter, params, op, enc, kToCOO, sizes, src);
+ newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
Value coo = genNewCall(rewriter, op, params);
- params[3] = constantI32(
- rewriter, loc, getOverheadTypeEncoding(encDst.getPointerBitWidth()));
- params[4] = constantI32(
- rewriter, loc, getOverheadTypeEncoding(encDst.getIndexBitWidth()));
- params[6] = constantI32(rewriter, loc, kFromCOO);
+ params[3] = constantOverheadTypeEncoding(rewriter, loc,
+ encDst.getPointerBitWidth());
+ params[4] = constantOverheadTypeEncoding(rewriter, loc,
+ encDst.getIndexBitWidth());
+ params[6] = constantAction(rewriter, loc, Action::kFromCOO);
params[7] = coo;
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
return success();
Type elemTp = dstTensorTp.getElementType();
// Fabricate a no-permutation encoding for newParams().
// The pointer/index types must be those of `src`.
- // The dimLevelTypes aren't actually used by kToIter.
+ // The dimLevelTypes aren't actually used by Action::kToIterator.
encDst = SparseTensorEncodingAttr::get(
op->getContext(),
SmallVector<SparseTensorEncodingAttr::DimLevelType>(
SmallVector<Value, 4> sizes;
SmallVector<Value, 8> params;
sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
- newParams(rewriter, params, op, encDst, kToIter, sizes, src);
+ newParams(rewriter, params, op, encDst, Action::kToIterator, sizes, src);
Value iter = genNewCall(rewriter, op, params);
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
SmallVector<Value, 4> sizes;
SmallVector<Value, 8> params;
sizesFromSrc(rewriter, sizes, loc, src);
- newParams(rewriter, params, op, encDst, kEmptyCOO, sizes);
+ newParams(rewriter, params, op, encDst, Action::kEmptyCOO, sizes);
Value ptr = genNewCall(rewriter, op, params);
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
Value perm = params[2];
return {};
});
// Final call to construct sparse tensor storage.
- params[6] = constantI32(rewriter, loc, kFromCOO);
+ params[6] = constantAction(rewriter, loc, Action::kFromCOO);
params[7] = ptr;
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
return success();
//
//===----------------------------------------------------------------------===//
+#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
/// function overloading to implement "partial" method specialization.
class SparseTensorStorageBase {
public:
- enum DimLevelType : uint8_t { kDense = 0, kCompressed = 1, kSingleton = 2 };
-
virtual uint64_t getDimSize(uint64_t) = 0;
// Overhead storage.
/// permutation, and per-dimension dense/sparse annotations, using
/// the coordinate scheme tensor for the initial contents if provided.
SparseTensorStorage(const std::vector<uint64_t> &szs, const uint64_t *perm,
- const uint8_t *sparsity, SparseTensorCOO<V> *tensor)
+ const DimLevelType *sparsity, SparseTensorCOO<V> *tensor)
: sizes(szs), rev(getRank()), pointers(getRank()), indices(getRank()) {
uint64_t rank = getRank();
// Store "reverse" permutation.
// TODO: needs fine-tuning based on sparsity
for (uint64_t r = 0, s = 1; r < rank; r++) {
s *= sizes[r];
- if (sparsity[r] == kCompressed) {
+ if (sparsity[r] == DimLevelType::kCompressed) {
pointers[r].reserve(s + 1);
indices[r].reserve(s);
s = 1;
} else {
- assert(sparsity[r] == kDense && "singleton not yet supported");
+ assert(sparsity[r] == DimLevelType::kDense &&
+ "singleton not yet supported");
}
}
// Prepare sparse pointer structures for all dimensions.
for (uint64_t r = 0; r < rank; r++)
- if (sparsity[r] == kCompressed)
+ if (sparsity[r] == DimLevelType::kCompressed)
pointers[r].push_back(0);
// Then assign contents from coordinate scheme tensor if provided.
if (tensor) {
/// permutation as is desired for the new sparse tensor storage.
static SparseTensorStorage<P, I, V> *
newSparseTensor(uint64_t rank, const uint64_t *sizes, const uint64_t *perm,
- const uint8_t *sparsity, SparseTensorCOO<V> *tensor) {
+ const DimLevelType *sparsity, SparseTensorCOO<V> *tensor) {
SparseTensorStorage<P, I, V> *n = nullptr;
if (tensor) {
assert(tensor->getRank() == rank);
/// Initializes sparse tensor storage scheme from a memory-resident sparse
/// tensor in coordinate scheme. This method prepares the pointers and
/// indices arrays under the given per-dimension dense/sparse annotations.
- void fromCOO(SparseTensorCOO<V> *tensor, const uint8_t *sparsity, uint64_t lo,
- uint64_t hi, uint64_t d) {
+ void fromCOO(SparseTensorCOO<V> *tensor, const DimLevelType *sparsity,
+ uint64_t lo, uint64_t hi, uint64_t d) {
const std::vector<Element<V>> &elements = tensor->getElements();
// Once dimensions are exhausted, insert the numerical values.
if (d == getRank()) {
while (seg < hi && elements[seg].indices[d] == idx)
seg++;
// Handle segment in interval for sparse or dense dimension.
- if (sparsity[d] == kCompressed) {
+ if (sparsity[d] == DimLevelType::kCompressed) {
indices[d].push_back(idx);
} else {
// For dense storage we must fill in all the zero values between
lo = seg;
}
// Finalize the sparse pointer structure at this dimension.
- if (sparsity[d] == kCompressed) {
+ if (sparsity[d] == DimLevelType::kCompressed) {
pointers[d].push_back(indices[d].size());
} else {
// For dense storage we must fill in all the zero values after
//
//===----------------------------------------------------------------------===//
-enum OverheadTypeEnum : uint32_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
-
-enum PrimaryTypeEnum : uint32_t {
- kF64 = 1,
- kF32 = 2,
- kI64 = 3,
- kI32 = 4,
- kI16 = 5,
- kI8 = 6
-};
-
-enum Action : uint32_t {
- kEmpty = 0,
- kFromFile = 1,
- kFromCOO = 2,
- kEmptyCOO = 3,
- kToCOO = 4,
- kToIter = 5
-};
-
#define CASE(p, i, v, P, I, V) \
if (ptrTp == (p) && indTp == (i) && valTp == (v)) { \
SparseTensorCOO<V> *tensor = nullptr; \
- if (action <= kFromCOO) { \
- if (action == kFromFile) { \
+ if (action <= Action::kFromCOO) { \
+ if (action == Action::kFromFile) { \
char *filename = static_cast<char *>(ptr); \
tensor = openSparseTensorCOO<V>(filename, rank, sizes, perm); \
- } else if (action == kFromCOO) { \
+ } else if (action == Action::kFromCOO) { \
tensor = static_cast<SparseTensorCOO<V> *>(ptr); \
} else { \
- assert(action == kEmpty); \
+ assert(action == Action::kEmpty); \
} \
return SparseTensorStorage<P, I, V>::newSparseTensor(rank, sizes, perm, \
sparsity, tensor); \
- } else if (action == kEmptyCOO) { \
+ } else if (action == Action::kEmptyCOO) { \
return SparseTensorCOO<V>::newSparseTensorCOO(rank, sizes, perm); \
} else { \
tensor = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm); \
- if (action == kToIter) { \
+ if (action == Action::kToIterator) { \
tensor->startIterator(); \
} else { \
- assert(action == kToCOO); \
+ assert(action == Action::kToCOO); \
} \
return tensor; \
} \
}
+#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
+
#define IMPL_SPARSEVALUES(NAME, TYPE, LIB) \
void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) { \
assert(ref); \
/// Constructs a new sparse tensor. This is the "swiss army knife"
/// method for materializing sparse tensors into the computation.
///
-/// action:
+/// Action:
/// kEmpty = returns empty storage to fill later
/// kFromFile = returns storage, where ptr contains filename to read
/// kFromCOO = returns storage, where ptr contains coordinate scheme to assign
/// kEmptyCOO = returns empty coordinate scheme to fill and use with kFromCOO
/// kToCOO = returns coordinate scheme from storage in ptr to use with kFromCOO
-/// kToIter = returns iterator from storage in ptr (call getNext() to use)
+/// kToIterator = returns iterator from storage in ptr (call getNext() to use)
void *
-_mlir_ciface_newSparseTensor(StridedMemRefType<uint8_t, 1> *aref, // NOLINT
+_mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
StridedMemRefType<index_t, 1> *sref,
StridedMemRefType<index_t, 1> *pref,
- uint32_t ptrTp, uint32_t indTp, uint32_t valTp,
- uint32_t action, void *ptr) {
+ OverheadType ptrTp, OverheadType indTp,
+ PrimaryType valTp, Action action, void *ptr) {
assert(aref && sref && pref);
assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
pref->strides[0] == 1);
assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
- const uint8_t *sparsity = aref->data + aref->offset;
+ const DimLevelType *sparsity = aref->data + aref->offset;
const index_t *sizes = sref->data + sref->offset;
const index_t *perm = pref->data + pref->offset;
uint64_t rank = aref->sizes[0];
// Double matrices with all combinations of overhead storage.
- CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
- CASE(kU64, kU32, kF64, uint64_t, uint32_t, double);
- CASE(kU64, kU16, kF64, uint64_t, uint16_t, double);
- CASE(kU64, kU8, kF64, uint64_t, uint8_t, double);
- CASE(kU32, kU64, kF64, uint32_t, uint64_t, double);
- CASE(kU32, kU32, kF64, uint32_t, uint32_t, double);
- CASE(kU32, kU16, kF64, uint32_t, uint16_t, double);
- CASE(kU32, kU8, kF64, uint32_t, uint8_t, double);
- CASE(kU16, kU64, kF64, uint16_t, uint64_t, double);
- CASE(kU16, kU32, kF64, uint16_t, uint32_t, double);
- CASE(kU16, kU16, kF64, uint16_t, uint16_t, double);
- CASE(kU16, kU8, kF64, uint16_t, uint8_t, double);
- CASE(kU8, kU64, kF64, uint8_t, uint64_t, double);
- CASE(kU8, kU32, kF64, uint8_t, uint32_t, double);
- CASE(kU8, kU16, kF64, uint8_t, uint16_t, double);
- CASE(kU8, kU8, kF64, uint8_t, uint8_t, double);
+ CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
+ uint64_t, double);
+ CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
+ uint32_t, double);
+ CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
+ uint16_t, double);
+ CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
+ uint8_t, double);
+ CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
+ uint64_t, double);
+ CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
+ uint32_t, double);
+ CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
+ uint16_t, double);
+ CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
+ uint8_t, double);
+ CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
+ uint64_t, double);
+ CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
+ uint32_t, double);
+ CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
+ uint16_t, double);
+ CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
+ uint8_t, double);
+ CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
+ uint64_t, double);
+ CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
+ uint32_t, double);
+ CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
+ uint16_t, double);
+ CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
+ uint8_t, double);
// Float matrices with all combinations of overhead storage.
- CASE(kU64, kU64, kF32, uint64_t, uint64_t, float);
- CASE(kU64, kU32, kF32, uint64_t, uint32_t, float);
- CASE(kU64, kU16, kF32, uint64_t, uint16_t, float);
- CASE(kU64, kU8, kF32, uint64_t, uint8_t, float);
- CASE(kU32, kU64, kF32, uint32_t, uint64_t, float);
- CASE(kU32, kU32, kF32, uint32_t, uint32_t, float);
- CASE(kU32, kU16, kF32, uint32_t, uint16_t, float);
- CASE(kU32, kU8, kF32, uint32_t, uint8_t, float);
- CASE(kU16, kU64, kF32, uint16_t, uint64_t, float);
- CASE(kU16, kU32, kF32, uint16_t, uint32_t, float);
- CASE(kU16, kU16, kF32, uint16_t, uint16_t, float);
- CASE(kU16, kU8, kF32, uint16_t, uint8_t, float);
- CASE(kU8, kU64, kF32, uint8_t, uint64_t, float);
- CASE(kU8, kU32, kF32, uint8_t, uint32_t, float);
- CASE(kU8, kU16, kF32, uint8_t, uint16_t, float);
- CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);
-
- // Integral matrices with same overhead storage.
- CASE(kU64, kU64, kI64, uint64_t, uint64_t, int64_t);
- CASE(kU64, kU64, kI32, uint64_t, uint64_t, int32_t);
- CASE(kU64, kU64, kI16, uint64_t, uint64_t, int16_t);
- CASE(kU64, kU64, kI8, uint64_t, uint64_t, int8_t);
- CASE(kU32, kU32, kI32, uint32_t, uint32_t, int32_t);
- CASE(kU32, kU32, kI16, uint32_t, uint32_t, int16_t);
- CASE(kU32, kU32, kI8, uint32_t, uint32_t, int8_t);
- CASE(kU16, kU16, kI32, uint16_t, uint16_t, int32_t);
- CASE(kU16, kU16, kI16, uint16_t, uint16_t, int16_t);
- CASE(kU16, kU16, kI8, uint16_t, uint16_t, int8_t);
- CASE(kU8, kU8, kI32, uint8_t, uint8_t, int32_t);
- CASE(kU8, kU8, kI16, uint8_t, uint8_t, int16_t);
- CASE(kU8, kU8, kI8, uint8_t, uint8_t, int8_t);
+ CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
+ uint64_t, float);
+ CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
+ uint32_t, float);
+ CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
+ uint16_t, float);
+ CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
+ uint8_t, float);
+ CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
+ uint64_t, float);
+ CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
+ uint32_t, float);
+ CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
+ uint16_t, float);
+ CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
+ uint8_t, float);
+ CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
+ uint64_t, float);
+ CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
+ uint32_t, float);
+ CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
+ uint16_t, float);
+ CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
+ uint8_t, float);
+ CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
+ uint64_t, float);
+ CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
+ uint32_t, float);
+ CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
+ uint16_t, float);
+ CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
+ uint8_t, float);
+
+ // Integral matrices with both overheads of the same type.
+ CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
+ CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
+ CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
+ CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
+ CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
+ CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
+ CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
+ CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
+ CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
+ CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
+ CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
+ CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
+ CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
// Unsupported case (add above if needed).
fputs("unsupported combination of types\n", stderr);
void *convertToMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape,
double *values, uint64_t *indices) {
// Setup all-dims compressed and default ordering.
- std::vector<uint8_t> sparse(rank, SparseTensorStorageBase::kCompressed);
+ std::vector<DimLevelType> sparse(rank, DimLevelType::kCompressed);
std::vector<uint64_t> perm(rank);
std::iota(perm.begin(), perm.end(), 0);
// Convert external format to internal COO.