/// instead relies on this class to access the right value for the right field.
template <bool mut>
class SparseTensorDescriptorImpl {
-private:
+protected:
// Uses ValueRange for immuatable descriptors; uses SmallVectorImpl<Value> &
// for mutable descriptors.
// Using SmallVector for mutable descriptor allows users to reuse it as a tmp
return getSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim);
}
- Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
- return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize,
- dim);
- }
-
- Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
- return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize,
- dim);
- }
-
- Value getValMemSize(OpBuilder &builder, Location loc) const {
- return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
- std::nullopt);
- }
-
Value getPtrMemRef(unsigned ptrDim) const {
return getMemRefField(SparseTensorFieldKind::PtrMemRef, ptrDim);
}
return fields[fidx];
}
+ ValueRange getMemRefFields() const {
+ ValueRange ret = fields;
+ // Drop the last metadata fields.
+ return ret.slice(0, fields.size() - 1);
+ }
+
+ Type getMemRefElementType(SparseTensorFieldKind kind,
+ Optional<unsigned> dim) const {
+ return getMemRefField(kind, dim)
+ .getType()
+ .template cast<MemRefType>()
+ .getElementType();
+ }
+
+ RankedTensorType getTensorType() const { return rType; }
+ ValueArrayRef getFields() const { return fields; }
+
+protected:
+ RankedTensorType rType;
+ ValueArrayRef fields;
+};
+
+class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
+public:
+ MutSparseTensorDescriptor(Type tp, ValueArrayRef buffers)
+ : SparseTensorDescriptorImpl<true>(tp, buffers) {}
+
+ ///
+ /// Getters: get the value for required field.
+ ///
+
+ Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+ return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize,
+ dim);
+ }
+
+ Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+ return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize,
+ dim);
+ }
+
+ Value getValMemSize(OpBuilder &builder, Location loc) const {
+ return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
+ std::nullopt);
+ }
+
///
/// Setters: update the value for required field (only enabled for
/// MutSparseTensorDescriptor).
///
template <typename T = Value>
- void setMemRefField(SparseTensorFieldKind kind, Optional<unsigned> dim,
- std::enable_if_t<mut, T> v) {
+ void setMemRefField(SparseTensorFieldKind kind, Optional<unsigned> dim, T v) {
fields[getMemRefFieldIndex(kind, dim)] = v;
}
template <typename T = Value>
- void setMemRefField(unsigned fidx, std::enable_if_t<mut, T> v) {
+ void setMemRefField(unsigned fidx, T v) {
assert(fidx < fields.size() - 1);
fields[fidx] = v;
}
template <typename T = Value>
- void setField(unsigned fidx, std::enable_if_t<mut, T> v) {
+ void setField(unsigned fidx, T v) {
assert(fidx < fields.size());
fields[fidx] = v;
}
template <typename T = Value>
void setSpecifierField(OpBuilder &builder, Location loc,
StorageSpecifierKind kind, Optional<unsigned> dim,
- std::enable_if_t<mut, T> v) {
+ T v) {
SparseTensorSpecifier md(fields.back());
md.setSpecifierField(builder, loc, v, kind, dim);
fields.back() = md;
}
template <typename T = Value>
- void setDimSize(OpBuilder &builder, Location loc, unsigned dim,
- std::enable_if_t<mut, T> v) {
+ void setDimSize(OpBuilder &builder, Location loc, unsigned dim, T v) {
setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v);
}
-
- ValueRange getMemRefFields() const {
- ValueRange ret = fields;
- // drop the last metadata fields
- return ret.slice(0, fields.size() - 1);
- }
-
- Type getMemRefElementType(SparseTensorFieldKind kind,
- Optional<unsigned> dim) const {
- return getMemRefField(kind, dim)
- .getType()
- .template cast<MemRefType>()
- .getElementType();
- }
-
- RankedTensorType getTensorType() const { return rType; }
- ValueArrayRef getFields() const { return fields; }
-
-private:
- RankedTensorType rType;
- ValueArrayRef fields;
};
using SparseTensorDescriptor = SparseTensorDescriptorImpl<false>;
-using MutSparseTensorDescriptor = SparseTensorDescriptorImpl<true>;
/// Returns the "tuple" value of the adapted tensor.
inline UnrealizedConversionCastOp getTuple(Value tensor) {