From 10033a179f0c73f28f051ac70b058a0c61882e3a Mon Sep 17 00:00:00 2001 From: Stella Stamenova Date: Mon, 5 Dec 2022 17:20:01 -0800 Subject: [PATCH] Revert "[mlir][sparse] Refactoring: abstract sparse tensor memory scheme into a SparseTensorDescriptor class." This reverts commit 8a7e69d145ff72e7e4fc10ce6b81c3aa4794201c. This broke the windows mlir buildbot: https://lab.llvm.org/buildbot/#/builders/13/builds/29257 --- .../mlir/Dialect/SparseTensor/IR/SparseTensor.h | 15 - .../SparseTensor/Transforms/CodegenUtils.cpp | 109 ----- .../Dialect/SparseTensor/Transforms/CodegenUtils.h | 247 +---------- .../Transforms/SparseTensorCodegen.cpp | 481 ++++++++++++--------- 4 files changed, 288 insertions(+), 564 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h index e9b63a6..52f9fef 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -75,21 +75,6 @@ inline bool isSingletonDim(RankedTensorType type, uint64_t d) { return isSingletonDLT(getDimLevelType(type, d)); } -/// Convenience function to test for dense dimension (0 <= d < rank). -inline bool isDenseDim(SparseTensorEncodingAttr enc, uint64_t d) { - return isDenseDLT(getDimLevelType(enc, d)); -} - -/// Convenience function to test for compressed dimension (0 <= d < rank). -inline bool isCompressedDim(SparseTensorEncodingAttr enc, uint64_t d) { - return isCompressedDLT(getDimLevelType(enc, d)); -} - -/// Convenience function to test for singleton dimension (0 <= d < rank). -inline bool isSingletonDim(SparseTensorEncodingAttr enc, uint64_t d) { - return isSingletonDLT(getDimLevelType(enc, d)); -} - // // Dimension level properties. // diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp index cd2fb58..2ac3f3b 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -90,115 +90,6 @@ static Value genIndexAndValueForDense(OpBuilder &builder, Location loc, return val; } -void sparse_tensor::foreachFieldInSparseTensor( - const SparseTensorEncodingAttr enc, - llvm::function_ref - callback) { - assert(enc); - -#define RETURN_ON_FALSE(idx, kind, dim, dlt) \ - if (!(callback(idx, kind, dim, dlt))) \ - return; - - RETURN_ON_FALSE(dimSizesIdx, SparseTensorFieldKind::DimSizes, -1u, - DimLevelType::Undef); - RETURN_ON_FALSE(memSizesIdx, SparseTensorFieldKind::MemSizes, -1u, - DimLevelType::Undef); - - static_assert(dataFieldIdx == memSizesIdx + 1); - unsigned fieldIdx = dataFieldIdx; - // Per-dimension storage. - for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) { - // Dimension level types apply in order to the reordered dimension. - // As a result, the compound type can be constructed directly in the given - // order. - auto dlt = getDimLevelType(enc, r); - if (isCompressedDLT(dlt)) { - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt); - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); - } else if (isSingletonDLT(dlt)) { - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); - } else { - assert(isDenseDLT(dlt)); // no fields - } - } - // The values array. 
- RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u, - DimLevelType::Undef); - -#undef RETURN_ON_FALSE -} - -void sparse_tensor::foreachFieldAndTypeInSparseTensor( - RankedTensorType rType, - llvm::function_ref - callback) { - auto enc = getSparseTensorEncoding(rType); - assert(enc); - // Construct the basic types. - Type indexType = IndexType::get(enc.getContext()); - Type idxType = enc.getIndexType(); - Type ptrType = enc.getPointerType(); - Type eltType = rType.getElementType(); - unsigned rank = rType.getShape().size(); - // memref dimSizes - Type dimSizeType = MemRefType::get({rank}, indexType); - // memref memSizes - Type memSizeType = - MemRefType::get({getNumDataFieldsFromEncoding(enc)}, indexType); - // memref pointers - Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType); - // memref indices - Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType); - // memref values - Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType); - - foreachFieldInSparseTensor( - enc, - [dimSizeType, memSizeType, ptrMemType, idxMemType, valMemType, - callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind, - unsigned dim, DimLevelType dlt) -> bool { - switch (fieldKind) { - case SparseTensorFieldKind::DimSizes: - return callback(dimSizeType, fieldIdx, fieldKind, dim, dlt); - case SparseTensorFieldKind::MemSizes: - return callback(memSizeType, fieldIdx, fieldKind, dim, dlt); - case SparseTensorFieldKind::PtrMemRef: - return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt); - case SparseTensorFieldKind::IdxMemRef: - return callback(idxMemType, fieldIdx, fieldKind, dim, dlt); - case SparseTensorFieldKind::ValMemRef: - return callback(valMemType, fieldIdx, fieldKind, dim, dlt); - }; - }); -} - -unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) { - unsigned numFields = 0; - foreachFieldInSparseTensor(enc, - [&numFields](unsigned, SparseTensorFieldKind, - unsigned, DimLevelType) -> bool { - numFields++; - return true; - }); - return numFields; -} - -unsigned -sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) { - unsigned numFields = 0; // one value memref - foreachFieldInSparseTensor(enc, - [&numFields](unsigned fidx, SparseTensorFieldKind, - unsigned, DimLevelType) -> bool { - if (fidx >= dataFieldIdx) - numFields++; - return true; - }); - assert(numFields == getNumFieldsFromEncoding(enc) - dataFieldIdx); - return numFields; -} //===----------------------------------------------------------------------===// // Sparse tensor loop emitter class implementations //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h index 4c86135..bafe752 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -311,222 +311,8 @@ inline bool isZeroRankedTensorOrScalar(Type type) { } //===----------------------------------------------------------------------===// -// SparseTensorDescriptor and helpers, manage the sparse tensor memory layout -// scheme. -// -// Sparse tensor storage scheme for rank-dimensional tensor is organized -// as a single compound type with the following fields. Note that every -// memref with ? size actually behaves as a "vector", i.e. the stored -// size is the capacity and the used size resides in the memSizes array. 
-//
-// struct {
-//   memref<rank x index> dimSizes   ; size in each dimension
-//   memref<n x index> memSizes      ; sizes of ptrs/inds/values
-//   ; per-dimension d:
-//   ;  if dense:
-//        <nothing>
-//   ;  if compressed:
-//     memref<? x ptr> pointers-d    ; pointers for sparse dim d
-//     memref<? x idx> indices-d     ; indices for sparse dim d
-//   ;  if singleton:
-//     memref<? x idx> indices-d     ; indices for singleton dim d
-//   memref<? x eltType> values      ; values
-// };
-//
-//===----------------------------------------------------------------------===//
-enum class SparseTensorFieldKind {
-  DimSizes,
-  MemSizes,
-  PtrMemRef,
-  IdxMemRef,
-  ValMemRef
-};
-
-constexpr uint64_t dimSizesIdx = 0;
-constexpr uint64_t memSizesIdx = dimSizesIdx + 1;
-constexpr uint64_t dataFieldIdx = memSizesIdx + 1;
-
-/// For each field that will be allocated for the given sparse tensor encoding,
-/// calls the callback with the corresponding field index, field kind, dimension
-/// (for sparse tensor level memrefs) and DimLevelType.
-/// The field index always starts with zero and increments by one between two
-/// callback invocations.
-/// Ideally, all other methods should rely on this function to query sparse
-/// tensor fields instead of relying on ad-hoc index computation.
-void foreachFieldInSparseTensor(
-    SparseTensorEncodingAttr,
-    llvm::function_ref<bool(unsigned, SparseTensorFieldKind, unsigned,
-                            DimLevelType)>);
-
-/// Same as above, except that it also builds the Type for the corresponding
-/// field.
-void foreachFieldAndTypeInSparseTensor(
-    RankedTensorType,
-    llvm::function_ref<bool(Type, unsigned, SparseTensorFieldKind, unsigned,
-                            DimLevelType)>);
-
-/// Gets the total number of fields for the given sparse tensor encoding.
-unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc);
-
-/// Gets the total number of data fields (index arrays, pointer arrays, and a
-/// value array) for the given sparse tensor encoding.
-unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc);
-
-/// Gets the index of the field in memSizes (only valid for data fields).
-inline unsigned getFieldMemSizesIndex(unsigned fid) {
-  assert(fid >= dataFieldIdx);
-  return fid - dataFieldIdx;
-}
-
-/// A helper class around an array of values that correspond to a sparse
-/// tensor; provides a set of meaningful APIs to query and update a particular
-/// field in a consistent way.
-/// Users should not make assumptions on how a sparse tensor is laid out but
-/// instead rely on this class to access the right value for the right field.
-template <bool mut>
-class SparseTensorDescriptorImpl {
-private:
-  template <bool v>
-  struct ArrayStorage;
-
-  template <>
-  struct ArrayStorage<false> {
-    using ValueArray = ValueRange;
-  };
-
-  template <>
-  struct ArrayStorage<true> {
-    using ValueArray = SmallVectorImpl<Value> &;
-  };
-
-  // Uses ValueRange for immutable descriptors; uses SmallVectorImpl<Value> &
-  // for mutable descriptors.
-  // Using SmallVector for a mutable descriptor allows users to reuse it as a
-  // tmp buffer to append values for some special cases, though users are
-  // responsible for restoring the buffer to a legal state after such use. It is
-  // probably not a clean way, but it is the most efficient way to avoid copying
-  // the fields into another SmallVector. If a cleaner way is wanted, we
-  // should change it to MutableArrayRef instead.
-  using Storage = typename ArrayStorage<mut>::ValueArray;
-
-public:
-  SparseTensorDescriptorImpl(Type tp, Storage fields)
-      : rType(tp.cast<RankedTensorType>()), fields(fields) {
-    assert(getSparseTensorEncoding(tp) &&
-           getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) ==
-               fields.size());
-    // We should make sure the class is trivially copyable (and should be small
-    // enough) such that we can pass it by value.
-    static_assert(
-        std::is_trivially_copyable_v>);
-  }
-
-  // Implicit (and cheap) type conversion from MutSparseTensorDescriptor to
-  // SparseTensorDescriptor.
-  template >
-  /*implicit*/ SparseTensorDescriptorImpl(std::enable_if_t &mDesc)
-      : rType(mDesc.getTensorType()), fields(mDesc.getFields()) {}
-
-  ///
-  /// Getters: get the field index for the required field.
-  ///
-
-  unsigned getPtrMemRefIndex(unsigned ptrDim) const {
-    return getFieldIndex(ptrDim, SparseTensorFieldKind::PtrMemRef);
-  }
-
-  unsigned getIdxMemRefIndex(unsigned idxDim) const {
-    return getFieldIndex(idxDim, SparseTensorFieldKind::IdxMemRef);
-  }
-
-  unsigned getValMemRefIndex() const { return fields.size() - 1; }
-
-  unsigned getPtrMemSizesIndex(unsigned dim) const {
-    return getPtrMemRefIndex(dim) - dataFieldIdx;
-  }
-
-  unsigned getIdxMemSizesIndex(unsigned dim) const {
-    return getIdxMemRefIndex(dim) - dataFieldIdx;
-  }
-
-  unsigned getValMemSizesIndex() const {
-    return getValMemRefIndex() - dataFieldIdx;
-  }
-
-  unsigned getNumFields() const { return fields.size(); }
-
-  ///
-  /// Getters: get the value for the required field.
-  ///
-
-  Value getDimSizesMemRef() const { return fields[dimSizesIdx]; }
-  Value getMemSizesMemRef() const { return fields[memSizesIdx]; }
-
-  Value getPtrMemRef(unsigned ptrDim) const {
-    return fields[getPtrMemRefIndex(ptrDim)];
-  }
-
-  Value getIdxMemRef(unsigned idxDim) const {
-    return fields[getIdxMemRefIndex(idxDim)];
-  }
-
-  Value getValMemRef() const { return fields[getValMemRefIndex()]; }
-
-  Value getField(unsigned fid) const {
-    assert(fid < fields.size());
-    return fields[fid];
-  }
-
-  ///
-  /// Setters: update the value for the required field (only enabled for
-  /// MutSparseTensorDescriptor).
-  ///
-
-  template 
-  void setField(unsigned fid, std::enable_if_t v) {
-    assert(fid < fields.size());
-    fields[fid] = v;
-  }
-
-  RankedTensorType getTensorType() const { return rType; }
-  Storage getFields() const { return fields; }
-
-  Type getElementType(unsigned fidx) const {
-    return fields[fidx].getType().template cast<ShapedType>().getElementType();
-  }
-
-private:
-  unsigned getFieldIndex(unsigned dim, SparseTensorFieldKind kind) const {
-    unsigned fieldIdx = -1u;
-    foreachFieldInSparseTensor(
-        getSparseTensorEncoding(rType),
-        [dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind,
-                               unsigned fDim, DimLevelType dlt) -> bool {
-          if (fDim == dim && kind == fKind) {
-            fieldIdx = fIdx;
-            // Returns false to break the iteration.
-            return false;
-          }
-          return true;
-        });
-    assert(fieldIdx != -1u);
-    return fieldIdx;
-  }
-
-  RankedTensorType rType;
-  Storage fields;
-};
-
-using SparseTensorDescriptor = SparseTensorDescriptorImpl<false>;
-using MutSparseTensorDescriptor = SparseTensorDescriptorImpl<true>;
-
-//===----------------------------------------------------------------------===//
-// SparseTensorLoopEmitter class, manages sparse tensors and helps to
-// generate loop structure to (co)-iterate sparse tensors.
+// SparseTensorLoopEmitter class, manages sparse tensors and helps to generate
+// loop structure to (co)-iterate sparse tensors.
 //
 // An example usage:
 // To generate the following loops over T1 and T2
@@ -559,15 +345,15 @@ public:
   using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
                                            Value memref, Value tensor)>;
 
-  /// Constructor: take an array of tensor inputs, on which the generated
-  /// loops will iterate. The index of the tensor in the array is also the
+  /// Constructor: take an array of tensor inputs, on which the generated loops
+  /// will iterate. The index of the tensor in the array is also the
   /// tensor id (tid) used in related functions.
   /// If isSparseOut is set, the loop emitter assumes that the sparse output
   /// tensor is empty, and will always generate loops on it based on the dim
   /// sizes.
   /// An optional array could be provided (by sparsification) to indicate the
   /// loop id sequence that will be generated. It is used to establish the
-  /// mapping from affineDimExpr to the corresponding loop index in the
-  /// loop stack that is maintained by the loop emitter.
+  /// mapping from affineDimExpr to the corresponding loop index in the loop
+  /// stack that is maintained by the loop emitter.
   explicit SparseTensorLoopEmitter(ValueRange tensors,
                                    StringAttr loopTag = nullptr,
                                    bool hasOutput = false,
@@ -582,8 +368,8 @@ public:
   /// Generates a list of operations to compute the affine expression.
   Value genAffine(OpBuilder &builder, AffineExpr a, Location loc);
 
-  /// Enters a new loop sequence; the loops within the same sequence start
-  /// from the break points of the previous loop instead of starting over from 0.
+  /// Enters a new loop sequence; the loops within the same sequence start from
+  /// the break points of the previous loop instead of starting over from 0.
   /// e.g.,
   /// {
   ///   // loop sequence start.
@@ -738,10 +524,10 @@ private:
   ///        scf.reduce.return %val
   ///      }
   ///    }
-  /// NOTE: only one instruction will be moved into the reduce block; the
-  /// transformation will fail if multiple instructions are used to compute
-  /// the reduction value. Returns %ret to the user, while %val is provided by
-  /// the user (`reduc`).
+  /// NOTE: only one instruction will be moved into the reduce block; the
+  /// transformation will fail if multiple instructions are used to compute the
+  /// reduction value.
+  /// Returns %ret to the user, while %val is provided by the user (`reduc`).
   void exitForLoop(RewriterBase &rewriter, Location loc,
                    MutableArrayRef<Value> reduc);
 
   void exitCoIterationLoop(OpBuilder &builder, Location loc,
                            MutableArrayRef<Value> reduc);
 
-  /// An optional string attribute that should be attached to the loop
-  /// generated by the loop emitter; it might help subsequent passes to identify
-  /// loops that operate on sparse tensors more easily.
+  /// An optional string attribute that should be attached to the loop generated
+  /// by the loop emitter; it might help subsequent passes to identify loops
+  /// that operate on sparse tensors more easily.
   StringAttr loopTag;
   /// Whether the loop emitter needs to treat the last tensor as the output
   /// tensor.
@@ -770,8 +556,7 @@ private:
   std::vector<std::vector<Value>> idxBuffer; // to_indices
   std::vector<Value> valBuffer;              // to_value
 
-  // Loop Stack, stores the information of all the nested loops that are
-  // alive.
+  // Loop Stack, stores the information of all the nested loops that are alive.
  std::vector<LoopLevelInfo> loopStack;
 
   // Loop Sequence Stack, stores the universal index for the current loop
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index e059bd3..113347c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -36,6 +36,10 @@ using FuncGeneratorType =
 
 static constexpr const char kInsertFuncNamePrefix[] = "_insert_";
 
+static constexpr uint64_t dimSizesIdx = 0;
+static constexpr uint64_t memSizesIdx = 1;
+static constexpr uint64_t fieldsIdx = 2;
+
 //===----------------------------------------------------------------------===//
 // Helper methods.
 //===----------------------------------------------------------------------===//
@@ -45,18 +49,6 @@ static UnrealizedConversionCastOp getTuple(Value tensor) {
   return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
 }
 
-static SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
-  auto tuple = getTuple(tensor);
-  return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs());
-}
-
-static MutSparseTensorDescriptor
-getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
-  auto tuple = getTuple(tensor);
-  fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
-  return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields);
-}
-
 /// Packs the given values as a "tuple" value.
 static Value genTuple(OpBuilder &builder, Location loc, Type tp,
                       ValueRange values) {
@@ -64,14 +56,6 @@ static Value genTuple(OpBuilder &builder, Location loc, Type tp,
       .getResult(0);
 }
 
-static Value genTuple(OpBuilder &builder, Location loc,
-                      SparseTensorDescriptor desc) {
-  return builder
-      .create<UnrealizedConversionCastOp>(loc, desc.getTensorType(),
-                                          desc.getFields())
-      .getResult(0);
-}
-
 /// Flatten a list of operands that may contain sparse tensors.
 static void flattenOperands(ValueRange operands,
                             SmallVectorImpl<Value> &flattened) {
@@ -117,7 +101,7 @@ static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
 
 /// Creates a straightforward counting for-loop.
 static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
-                            MutableArrayRef<Value> fields,
+                            SmallVectorImpl<Value> &fields,
                             Value lower = Value()) {
   Type indexType = builder.getIndexType();
   if (!lower)
@@ -134,46 +118,81 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
 /// original dimension 'dim'. Returns std::nullopt if no sparse encoding is
 /// attached to the given tensor type.
 static Optional<Value> sizeFromTensorAtDim(OpBuilder &builder, Location loc,
-                                           SparseTensorDescriptor desc,
-                                           unsigned dim) {
-  RankedTensorType rtp = desc.getTensorType();
+                                           RankedTensorType tensorTp,
+                                           Value adaptedValue, unsigned dim) {
+  auto enc = getSparseTensorEncoding(tensorTp);
+  if (!enc)
+    return std::nullopt;
+
   // Access into static dimension can query original type directly.
   // Note that this is typically already done by DimOp's folding.
-  auto shape = rtp.getShape();
+  auto shape = tensorTp.getShape();
   if (!ShapedType::isDynamic(shape[dim]))
     return constantIndex(builder, loc, shape[dim]);
 
   // Any other query can consult the dimSizes array at field dimSizesIdx,
   // accounting for the reordering applied to the sparse storage.
- Value idx = constantIndex(builder, loc, toStoredDim(rtp, dim)); - return builder.create(loc, desc.getDimSizesMemRef(), idx) + auto tuple = getTuple(adaptedValue); + Value idx = constantIndex(builder, loc, toStoredDim(tensorTp, dim)); + return builder + .create(loc, tuple.getInputs()[dimSizesIdx], idx) .getResult(); } // Gets the dimension size at the given stored dimension 'd', either as a // constant for a static size, or otherwise dynamically through memSizes. -Value sizeAtStoredDim(OpBuilder &builder, Location loc, - SparseTensorDescriptor desc, unsigned d) { - RankedTensorType rtp = desc.getTensorType(); +Value sizeAtStoredDim(OpBuilder &builder, Location loc, RankedTensorType rtp, + SmallVectorImpl &fields, unsigned d) { unsigned dim = toOrigDim(rtp, d); auto shape = rtp.getShape(); if (!ShapedType::isDynamic(shape[dim])) return constantIndex(builder, loc, shape[dim]); - - return genLoad(builder, loc, desc.getDimSizesMemRef(), + return genLoad(builder, loc, fields[dimSizesIdx], constantIndex(builder, loc, d)); } +/// Translates field index to memSizes index. +static unsigned getMemSizesIndex(unsigned field) { + assert(fieldsIdx <= field); + return field - fieldsIdx; +} + +/// Creates a pushback op for given field and updates the fields array +/// accordingly. This operation also updates the memSizes contents. static void createPushback(OpBuilder &builder, Location loc, - MutSparseTensorDescriptor desc, unsigned fidx, + SmallVectorImpl &fields, unsigned field, Value value, Value repeat = Value()) { - Type etp = desc.getElementType(fidx); - Value field = desc.getField(fidx); - Value newField = builder.create( - loc, field.getType(), desc.getMemSizesMemRef(), field, - toType(builder, loc, value, etp), APInt(64, getFieldMemSizesIndex(fidx)), + assert(fieldsIdx <= field && field < fields.size()); + Type etp = fields[field].getType().cast().getElementType(); + fields[field] = builder.create( + loc, fields[field].getType(), fields[memSizesIdx], fields[field], + toType(builder, loc, value, etp), APInt(64, getMemSizesIndex(field)), repeat); - desc.setField(fidx, newField); +} + +/// Returns field index of sparse tensor type for pointers/indices, when set. +static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) { + assert(getSparseTensorEncoding(type)); + RankedTensorType rType = type.cast(); + unsigned field = fieldsIdx; // start past header + for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) { + if (isCompressedDim(rType, r)) { + if (r == ptrDim) + return field; + field++; + if (r == idxDim) + return field; + field++; + } else if (isSingletonDim(rType, r)) { + if (r == idxDim) + return field; + field++; + } else { + assert(isDenseDim(rType, r)); // no fields + } + } + assert(ptrDim == -1u && idxDim == -1u); + return field + 1; // return values field index } /// Maps a sparse tensor type to the appropriate compounded buffers. @@ -182,24 +201,64 @@ convertSparseTensorType(Type type, SmallVectorImpl &fields) { auto enc = getSparseTensorEncoding(type); if (!enc) return std::nullopt; - + // Construct the basic types. 
+  auto *context = type.getContext();
   RankedTensorType rType = type.cast<RankedTensorType>();
-  foreachFieldAndTypeInSparseTensor(
-      rType,
-      [&fields](Type fieldType, unsigned fieldIdx,
-                SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/,
-                DimLevelType /*dlt*/) -> bool {
-        assert(fieldIdx == fields.size());
-        fields.push_back(fieldType);
-        return true;
-      });
+  Type indexType = IndexType::get(context);
+  Type idxType = enc.getIndexType();
+  Type ptrType = enc.getPointerType();
+  Type eltType = rType.getElementType();
+  //
+  // Sparse tensor storage scheme for rank-dimensional tensor is organized
+  // as a single compound type with the following fields. Note that every
+  // memref with ? size actually behaves as a "vector", i.e. the stored
+  // size is the capacity and the used size resides in the memSizes array.
+  //
+  // struct {
+  //   memref<rank x index> dimSizes   ; size in each dimension
+  //   memref<n x index> memSizes      ; sizes of ptrs/inds/values
+  //   ; per-dimension d:
+  //   ;  if dense:
+  //        <nothing>
+  //   ;  if compressed:
+  //     memref<? x ptr> pointers-d    ; pointers for sparse dim d
+  //     memref<? x idx> indices-d     ; indices for sparse dim d
+  //   ;  if singleton:
+  //     memref<? x idx> indices-d     ; indices for singleton dim d
+  //   memref<? x eltType> values      ; values
+  // };
+  //
+  unsigned rank = rType.getShape().size();
+  unsigned lastField = getFieldIndex(type, -1u, -1u);
+  // The dimSizes array and memSizes array.
+  fields.push_back(MemRefType::get({rank}, indexType));
+  fields.push_back(MemRefType::get({getMemSizesIndex(lastField)}, indexType));
+  // Per-dimension storage.
+  for (unsigned r = 0; r < rank; r++) {
+    // Dimension level types apply in order to the reordered dimension.
+    // As a result, the compound type can be constructed directly in the given
+    // order. Clients of this type know what field is what from the sparse
+    // tensor type.
+    if (isCompressedDim(rType, r)) {
+      fields.push_back(MemRefType::get({ShapedType::kDynamic}, ptrType));
+      fields.push_back(MemRefType::get({ShapedType::kDynamic}, idxType));
+    } else if (isSingletonDim(rType, r)) {
+      fields.push_back(MemRefType::get({ShapedType::kDynamic}, idxType));
+    } else {
+      assert(isDenseDim(rType, r)); // no fields
+    }
+  }
+  // The values array.
+  fields.push_back(MemRefType::get({ShapedType::kDynamic}, eltType));
+  assert(fields.size() == lastField);
   return success();
 }
 
 /// Generates code that allocates a sparse storage scheme for given rank.
 static void allocSchemeForRank(OpBuilder &builder, Location loc,
-                               MutSparseTensorDescriptor desc, unsigned r0) {
-  RankedTensorType rtp = desc.getTensorType();
+                               RankedTensorType rtp,
+                               SmallVectorImpl<Value> &fields, unsigned field,
+                               unsigned r0) {
   unsigned rank = rtp.getShape().size();
   Value linear = constantIndex(builder, loc, 1);
   for (unsigned r = r0; r < rank; r++) {
@@ -209,8 +268,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
       // the desired "linear + 1" length property at all times.
       Type ptrType = getSparseTensorEncoding(rtp).getPointerType();
       Value ptrZero = constantZero(builder, loc, ptrType);
-      createPushback(builder, loc, desc, desc.getPtrMemRefIndex(r), ptrZero,
-                     linear);
+      createPushback(builder, loc, fields, field, ptrZero, linear);
       return;
     }
     if (isSingletonDim(rtp, r)) {
@@ -220,23 +278,23 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
       // at this level. We will eventually reach a compressed level or
       // otherwise the values array for the from-here "all-dense" case.
assert(isDenseDim(rtp, r)); - Value size = sizeAtStoredDim(builder, loc, desc, r); + Value size = sizeAtStoredDim(builder, loc, rtp, fields, r); linear = builder.create(loc, linear, size); } // Reached values array so prepare for an insertion. Value valZero = constantZero(builder, loc, rtp.getElementType()); - createPushback(builder, loc, desc, desc.getValMemRefIndex(), valZero, linear); + createPushback(builder, loc, fields, field, valZero, linear); + assert(fields.size() == ++field); } /// Creates allocation operation. -static Value createAllocation(OpBuilder &builder, Location loc, - MemRefType memRefType, Value sz, - bool enableInit) { - Value buffer = builder.create(loc, memRefType, sz); - Type elemType = memRefType.getElementType(); +static Value createAllocation(OpBuilder &builder, Location loc, Type type, + Value sz, bool enableInit) { + auto memType = MemRefType::get({ShapedType::kDynamic}, type); + Value buffer = builder.create(loc, memType, sz); if (enableInit) { - Value fillValue = builder.create( - loc, elemType, builder.getZeroAttr(elemType)); + Value fillValue = + builder.create(loc, type, builder.getZeroAttr(type)); builder.create(loc, fillValue, buffer); } return buffer; @@ -252,68 +310,69 @@ static Value createAllocation(OpBuilder &builder, Location loc, static void createAllocFields(OpBuilder &builder, Location loc, Type type, ValueRange dynSizes, bool enableInit, SmallVectorImpl &fields) { + auto enc = getSparseTensorEncoding(type); + assert(enc); RankedTensorType rtp = type.cast(); + Type indexType = builder.getIndexType(); + Type idxType = enc.getIndexType(); + Type ptrType = enc.getPointerType(); + Type eltType = rtp.getElementType(); + auto shape = rtp.getShape(); + unsigned rank = shape.size(); Value heuristic = constantIndex(builder, loc, 16); - - foreachFieldAndTypeInSparseTensor( - rtp, - [&builder, &fields, loc, heuristic, - enableInit](Type fType, unsigned fIdx, SparseTensorFieldKind fKind, - unsigned /*dim*/, DimLevelType /*dlt*/) -> bool { - assert(fields.size() == fIdx); - auto memRefTp = fType.cast(); - Value field; - switch (fKind) { - case SparseTensorFieldKind::DimSizes: - case SparseTensorFieldKind::MemSizes: - field = builder.create(loc, memRefTp); - break; - case SparseTensorFieldKind::PtrMemRef: - case SparseTensorFieldKind::IdxMemRef: - case SparseTensorFieldKind::ValMemRef: - field = - createAllocation(builder, loc, memRefTp, heuristic, enableInit); - break; - } - assert(field); - fields.push_back(field); - // Returns true to continue the iteration. - return true; - }); - - MutSparseTensorDescriptor desc(rtp, fields); - // Build original sizes. SmallVector sizes; - auto shape = rtp.getShape(); - unsigned rank = shape.size(); for (unsigned r = 0, o = 0; r < rank; r++) { if (ShapedType::isDynamic(shape[r])) sizes.push_back(dynSizes[o++]); else sizes.push_back(constantIndex(builder, loc, shape[r])); } + // The dimSizes array and memSizes array. + unsigned lastField = getFieldIndex(type, -1u, -1u); + Value dimSizes = + builder.create(loc, MemRefType::get({rank}, indexType)); + Value memSizes = builder.create( + loc, MemRefType::get({getMemSizesIndex(lastField)}, indexType)); + fields.push_back(dimSizes); + fields.push_back(memSizes); + // Per-dimension storage. 
+ for (unsigned r = 0; r < rank; r++) { + if (isCompressedDim(rtp, r)) { + fields.push_back( + createAllocation(builder, loc, ptrType, heuristic, enableInit)); + fields.push_back( + createAllocation(builder, loc, idxType, heuristic, enableInit)); + } else if (isSingletonDim(rtp, r)) { + fields.push_back( + createAllocation(builder, loc, idxType, heuristic, enableInit)); + } else { + assert(isDenseDim(rtp, r)); // no fields + } + } + // The values array. + fields.push_back( + createAllocation(builder, loc, eltType, heuristic, enableInit)); + assert(fields.size() == lastField); // Initialize the storage scheme to an empty tensor. Initialized memSizes // to all zeros, sets the dimSizes to known values and gives all pointer // fields an initial zero entry, so that it is easier to maintain the // "linear + 1" length property. builder.create( - loc, constantZero(builder, loc, builder.getIndexType()), - desc.getMemSizesMemRef()); // zero memSizes - - Value ptrZero = - constantZero(builder, loc, getSparseTensorEncoding(rtp).getPointerType()); - for (unsigned r = 0; r < rank; r++) { + loc, ValueRange{constantZero(builder, loc, indexType)}, + ValueRange{memSizes}); // zero memSizes + Value ptrZero = constantZero(builder, loc, ptrType); + for (unsigned r = 0, field = fieldsIdx; r < rank; r++) { unsigned ro = toOrigDim(rtp, r); - // Fills dim sizes array. - genStore(builder, loc, sizes[ro], desc.getDimSizesMemRef(), - constantIndex(builder, loc, r)); - - // Pushes a leading zero to pointers memref. - if (isCompressedDim(rtp, r)) - createPushback(builder, loc, desc, desc.getPtrMemRefIndex(r), ptrZero); + genStore(builder, loc, sizes[ro], dimSizes, constantIndex(builder, loc, r)); + if (isCompressedDim(rtp, r)) { + createPushback(builder, loc, fields, field, ptrZero); + field += 2; + } else if (isSingletonDim(rtp, r)) { + field += 1; + } } - allocSchemeForRank(builder, loc, desc, /*rank=*/0); + allocSchemeForRank(builder, loc, rtp, fields, fieldsIdx, /*rank=*/0); } /// Helper method that generates block specific to compressed case: @@ -337,22 +396,19 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type, /// } /// pos[d] = next static Value genCompressed(OpBuilder &builder, Location loc, - MutSparseTensorDescriptor desc, + RankedTensorType rtp, SmallVectorImpl &fields, SmallVectorImpl &indices, Value value, - Value pos, unsigned d) { - RankedTensorType rtp = desc.getTensorType(); + Value pos, unsigned field, unsigned d) { unsigned rank = rtp.getShape().size(); SmallVector types; Type indexType = builder.getIndexType(); Type boolType = builder.getIntegerType(1); - unsigned idxIndex = desc.getIdxMemRefIndex(d); - unsigned ptrIndex = desc.getPtrMemRefIndex(d); Value one = constantIndex(builder, loc, 1); Value pp1 = builder.create(loc, pos, one); - Value plo = genLoad(builder, loc, desc.getField(ptrIndex), pos); - Value phi = genLoad(builder, loc, desc.getField(ptrIndex), pp1); - Value psz = constantIndex(builder, loc, getFieldMemSizesIndex(idxIndex)); - Value msz = genLoad(builder, loc, desc.getMemSizesMemRef(), psz); + Value plo = genLoad(builder, loc, fields[field], pos); + Value phi = genLoad(builder, loc, fields[field], pp1); + Value psz = constantIndex(builder, loc, getMemSizesIndex(field + 1)); + Value msz = genLoad(builder, loc, fields[memSizesIdx], psz); Value phim1 = builder.create( loc, toType(builder, loc, phi, indexType), one); // Conditional expression. 
@@ -362,55 +418,49 @@ static Value genCompressed(OpBuilder &builder, Location loc, scf::IfOp ifOp1 = builder.create(loc, types, lt, /*else*/ true); types.pop_back(); builder.setInsertionPointToStart(&ifOp1.getThenRegion().front()); - Value crd = genLoad(builder, loc, desc.getField(idxIndex), phim1); + Value crd = genLoad(builder, loc, fields[field + 1], phim1); Value eq = builder.create(loc, arith::CmpIPredicate::eq, toType(builder, loc, crd, indexType), indices[d]); builder.create(loc, eq); builder.setInsertionPointToStart(&ifOp1.getElseRegion().front()); if (d > 0) - genStore(builder, loc, msz, desc.getField(ptrIndex), pos); + genStore(builder, loc, msz, fields[field], pos); builder.create(loc, constantI1(builder, loc, false)); builder.setInsertionPointAfter(ifOp1); Value p = ifOp1.getResult(0); - // If present construct. Note that for a non-unique dimension level, we - // simply set the condition to false and rely on CSE/DCE to clean up the IR. + // If present construct. Note that for a non-unique dimension level, we simply + // set the condition to false and rely on CSE/DCE to clean up the IR. // // TODO: generate less temporary IR? // - for (unsigned i = 0, e = desc.getNumFields(); i < e; i++) - types.push_back(desc.getField(i).getType()); + for (unsigned i = 0, e = fields.size(); i < e; i++) + types.push_back(fields[i].getType()); types.push_back(indexType); if (!isUniqueDim(rtp, d)) p = constantI1(builder, loc, false); scf::IfOp ifOp2 = builder.create(loc, types, p, /*else*/ true); // If present (fields unaffected, update next to phim1). builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); - - // FIXME: This does not looks like a clean way, but probably the most - // efficient way. - desc.getFields().push_back(phim1); - builder.create(loc, desc.getFields()); - desc.getFields().pop_back(); - + fields.push_back(phim1); + builder.create(loc, fields); + fields.pop_back(); // If !present (changes fields, update next). builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); Value mszp1 = builder.create(loc, msz, one); - genStore(builder, loc, mszp1, desc.getField(ptrIndex), pp1); - createPushback(builder, loc, desc, idxIndex, indices[d]); + genStore(builder, loc, mszp1, fields[field], pp1); + createPushback(builder, loc, fields, field + 1, indices[d]); // Prepare the next dimension "as needed". if ((d + 1) < rank) - allocSchemeForRank(builder, loc, desc, d + 1); - - desc.getFields().push_back(msz); - builder.create(loc, desc.getFields()); - desc.getFields().pop_back(); - + allocSchemeForRank(builder, loc, rtp, fields, field + 2, d + 1); + fields.push_back(msz); + builder.create(loc, fields); + fields.pop_back(); // Update fields and return next pos. builder.setInsertionPointAfter(ifOp2); unsigned o = 0; - for (unsigned i = 0, e = desc.getNumFields(); i < e; i++) - desc.setField(i, ifOp2.getResult(o++)); + for (unsigned i = 0, e = fields.size(); i < e; i++) + fields[i] = ifOp2.getResult(o++); return ifOp2.getResult(o); } @@ -438,10 +488,11 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module, // Construct fields and indices arrays from parameters. ValueRange tmp = args.drop_back(rank + 1); SmallVector fields(tmp.begin(), tmp.end()); - MutSparseTensorDescriptor desc(rtp, fields); tmp = args.take_back(rank + 1).drop_back(); SmallVector indices(tmp.begin(), tmp.end()); Value value = args.back(); + + unsigned field = fieldsIdx; // Start past header. Value pos = constantZero(builder, loc, builder.getIndexType()); // Generate code for every dimension. 
   for (unsigned d = 0; d < rank; d++) {
     if (isCompressedDim(rtp, d)) {
       // Create:
       //   if (!present) {
       //     indices[d].push_back(i[d])
       //
       //   }
       //   pos[d] = indices.size() - 1
       //
-      pos = genCompressed(builder, loc, desc, indices, value, pos, d);
+      pos = genCompressed(builder, loc, rtp, fields, indices, value, pos, field,
+                          d);
+      field += 2;
     } else if (isSingletonDim(rtp, d)) {
       // Create:
       //   indices[d].push_back(i[d])
       //   pos[d] = pos[d-1]
       //
-      createPushback(builder, loc, desc, desc.getIdxMemRefIndex(d), indices[d]);
+      createPushback(builder, loc, fields, field, indices[d]);
+      field += 1;
     } else {
       assert(isDenseDim(rtp, d));
       // Construct the new position as:
       //   pos[d] = size * pos[d-1] + i[d]
       //
-      Value size = sizeAtStoredDim(builder, loc, desc, d);
+      Value size = sizeAtStoredDim(builder, loc, rtp, fields, d);
       Value mult = builder.create<arith::MulIOp>(loc, size, pos);
       pos = builder.create<arith::AddIOp>(loc, mult, indices[d]);
     }
   }
   // Reached the actual value append/insert.
   if (!isDenseDim(rtp, rank - 1))
-    createPushback(builder, loc, desc, desc.getValMemRefIndex(), value);
+    createPushback(builder, loc, fields, field++, value);
   else
-    genStore(builder, loc, value, desc.getValMemRef(), pos);
+    genStore(builder, loc, value, fields[field++], pos);
+  assert(fields.size() == field);
   builder.create<func::ReturnOp>(loc, fields);
 }
 
 /// Generates a call to a function to perform an insertion operation. If the
 /// function doesn't exist yet, call `createFunc` to generate the function.
-static void genInsertionCallHelper(OpBuilder &builder,
-                                   MutSparseTensorDescriptor desc,
+static void genInsertionCallHelper(OpBuilder &builder, RankedTensorType rtp,
+                                   SmallVectorImpl<Value> &fields,
                                    SmallVectorImpl<Value> &indices,
                                    Value value, func::FuncOp insertPoint,
                                    StringRef namePrefix,
@@ -489,7 +544,6 @@ static void genInsertionCallHelper(OpBuilder &builder,
   // The mangled name of the function has this format:
   //   _[C|S|D]___
   //     __
-  RankedTensorType rtp = desc.getTensorType();
   SmallString<32> nameBuffer;
   llvm::raw_svector_ostream nameOstream(nameBuffer);
   nameOstream << namePrefix;
@@ -523,7 +577,7 @@ static void genInsertionCallHelper(OpBuilder &builder,
   auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
 
   // Construct parameters for fields and indices.
-  SmallVector<Value> operands(desc.getFields().begin(), desc.getFields().end());
+  SmallVector<Value> operands(fields.begin(), fields.end());
   operands.append(indices.begin(), indices.end());
   operands.push_back(value);
   Location loc = insertPoint.getLoc();
@@ -536,7 +590,7 @@ static void genInsertionCallHelper(OpBuilder &builder,
     func = builder.create<func::FuncOp>(
         loc, nameOstream.str(),
         FunctionType::get(context, ValueRange(operands).getTypes(),
-                          ValueRange(desc.getFields()).getTypes()));
+                          ValueRange(fields).getTypes()));
     func.setPrivate();
     createFunc(builder, module, func, rtp);
   }
@@ -544,44 +598,42 @@ static void genInsertionCallHelper(OpBuilder &builder,
   // Generate a call to perform the insertion and update `fields` with values
   // returned from the call.
   func::CallOp call = builder.create<func::CallOp>(loc, func, operands);
-  for (size_t i = 0, e = desc.getNumFields(); i < e; i++) {
-    desc.getFields()[i] = call.getResult(i);
+  for (size_t i = 0; i < fields.size(); i++) {
+    fields[i] = call.getResult(i);
   }
 }
 
 /// Generates insertion finalization code.
-static void genEndInsert(OpBuilder &builder, Location loc, - MutSparseTensorDescriptor desc) { - RankedTensorType rtp = desc.getTensorType(); +static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp, + SmallVectorImpl &fields) { unsigned rank = rtp.getShape().size(); + unsigned field = fieldsIdx; // start past header for (unsigned d = 0; d < rank; d++) { if (isCompressedDim(rtp, d)) { // Compressed dimensions need a pointer cleanup for all entries // that were not visited during the insertion pass. // - // TODO: avoid cleanup and keep compressed scheme consistent at all - // times? + // TODO: avoid cleanup and keep compressed scheme consistent at all times? // if (d > 0) { Type ptrType = getSparseTensorEncoding(rtp).getPointerType(); - Value ptrMemRef = desc.getPtrMemRef(d); - Value mz = constantIndex(builder, loc, desc.getPtrMemSizesIndex(d)); - Value hi = genLoad(builder, loc, desc.getMemSizesMemRef(), mz); + Value mz = constantIndex(builder, loc, getMemSizesIndex(field)); + Value hi = genLoad(builder, loc, fields[memSizesIdx], mz); Value zero = constantIndex(builder, loc, 0); Value one = constantIndex(builder, loc, 1); // Vector of only one, but needed by createFor's prototype. - SmallVector inits{genLoad(builder, loc, ptrMemRef, zero)}; + SmallVector inits{genLoad(builder, loc, fields[field], zero)}; scf::ForOp loop = createFor(builder, loc, hi, inits, one); Value i = loop.getInductionVar(); Value oldv = loop.getRegionIterArg(0); - Value newv = genLoad(builder, loc, ptrMemRef, i); + Value newv = genLoad(builder, loc, fields[field], i); Value ptrZero = constantZero(builder, loc, ptrType); Value cond = builder.create( loc, arith::CmpIPredicate::eq, newv, ptrZero); scf::IfOp ifOp = builder.create(loc, TypeRange(ptrType), cond, /*else*/ true); builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - genStore(builder, loc, oldv, ptrMemRef, i); + genStore(builder, loc, oldv, fields[field], i); builder.create(loc, oldv); builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); builder.create(loc, newv); @@ -589,10 +641,14 @@ static void genEndInsert(OpBuilder &builder, Location loc, builder.create(loc, ifOp.getResult(0)); builder.setInsertionPointAfter(loop); } + field += 2; + } else if (isSingletonDim(rtp, d)) { + field++; } else { - assert(isDenseDim(rtp, d) || isSingletonDim(rtp, d)); + assert(isDenseDim(rtp, d)); } } + assert(fields.size() == ++field); } //===----------------------------------------------------------------------===// @@ -683,12 +739,12 @@ public: matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Optional index = op.getConstantIndex(); - if (!index || !getSparseTensorEncoding(adaptor.getSource().getType())) + if (!index) return failure(); - - auto desc = getDescriptorFromTensorTuple(adaptor.getSource()); - auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index); - + auto sz = + sizeFromTensorAtDim(rewriter, op.getLoc(), + op.getSource().getType().cast(), + adaptor.getSource(), *index); if (!sz) return failure(); @@ -778,14 +834,16 @@ public: LogicalResult matchAndRewrite(LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Prepare descriptor. - SmallVector fields; - auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields); + RankedTensorType srcType = + op.getTensor().getType().cast(); + auto tuple = getTuple(adaptor.getTensor()); + // Prepare fields. 
+ SmallVector fields(tuple.getInputs()); // Generate optional insertion finalization code. if (op.getHasInserts()) - genEndInsert(rewriter, op.getLoc(), desc); + genEndInsert(rewriter, op.getLoc(), srcType, fields); // Replace operation with resulting memrefs. - rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc)); + rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), srcType, fields)); return success(); } }; @@ -797,10 +855,7 @@ public: LogicalResult matchAndRewrite(ExpandOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!getSparseTensorEncoding(op.getTensor().getType())) - return failure(); Location loc = op->getLoc(); - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); RankedTensorType srcType = op.getTensor().getType().cast(); Type eltType = srcType.getElementType(); @@ -812,7 +867,8 @@ public: // dimension size, translated back to original dimension). Note that we // recursively rewrite the new DimOp on the **original** tensor. unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1); - auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim); + auto sz = sizeFromTensorAtDim(rewriter, loc, srcType, adaptor.getTensor(), + innerDim); assert(sz); // This for sure is a sparse tensor // Generate a memref for `sz` elements of type `t`. auto genAlloc = [&](Type t) { @@ -852,15 +908,16 @@ public: matchAndRewrite(CompressOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - SmallVector fields; - auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields); + RankedTensorType dstType = + op.getTensor().getType().cast(); + Type eltType = dstType.getElementType(); + auto tuple = getTuple(adaptor.getTensor()); Value values = adaptor.getValues(); Value filled = adaptor.getFilled(); Value added = adaptor.getAdded(); Value count = adaptor.getCount(); - RankedTensorType dstType = desc.getTensorType(); - Type eltType = dstType.getElementType(); - // Prepare indices. + // Prepare fields and indices. + SmallVector fields(tuple.getInputs()); SmallVector indices(adaptor.getIndices()); // If the innermost dimension is ordered, we need to sort the indices // in the "added" array prior to applying the compression. @@ -882,19 +939,19 @@ public: // filled[index] = false; // yield new_memrefs // } - scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields()); + scf::ForOp loop = createFor(rewriter, loc, count, fields); Value i = loop.getInductionVar(); Value index = genLoad(rewriter, loc, added, i); Value value = genLoad(rewriter, loc, values, index); indices.push_back(index); // TODO: faster for subsequent insertions? auto insertPoint = op->template getParentOfType(); - genInsertionCallHelper(rewriter, desc, indices, value, insertPoint, - kInsertFuncNamePrefix, genInsertBody); + genInsertionCallHelper(rewriter, dstType, fields, indices, value, + insertPoint, kInsertFuncNamePrefix, genInsertBody); genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, index); genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, index); - rewriter.create(loc, desc.getFields()); + rewriter.create(loc, fields); rewriter.setInsertionPointAfter(loop); Value result = genTuple(rewriter, loc, dstType, loop->getResults()); // Deallocate the buffers on exit of the full loop nest. 
@@ -916,18 +973,20 @@ public: LogicalResult matchAndRewrite(InsertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector fields; - auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields); - // Prepare and indices. + RankedTensorType dstType = + op.getTensor().getType().cast(); + auto tuple = getTuple(adaptor.getTensor()); + // Prepare fields and indices. + SmallVector fields(tuple.getInputs()); SmallVector indices(adaptor.getIndices()); // Generate insertion. Value value = adaptor.getValue(); auto insertPoint = op->template getParentOfType(); - genInsertionCallHelper(rewriter, desc, indices, value, insertPoint, - kInsertFuncNamePrefix, genInsertBody); + genInsertionCallHelper(rewriter, dstType, fields, indices, value, + insertPoint, kInsertFuncNamePrefix, genInsertBody); // Replace operation with resulting memrefs. - rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc)); + rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields)); return success(); } }; @@ -944,9 +1003,11 @@ public: // Replace the requested pointer access with corresponding field. // The cast_op is inserted by type converter to intermix 1:N type // conversion. - auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); - Value field = Base::getFieldForOp(desc, op); - rewriter.replaceOp(op, field); + auto tuple = getTuple(adaptor.getTensor()); + unsigned idx = Base::getIndexForOp(tuple, op); + auto fields = tuple.getInputs(); + assert(idx < fields.size()); + rewriter.replaceOp(op, fields[idx]); return success(); } }; @@ -957,10 +1018,10 @@ class SparseToPointersConverter public: using SparseGetterOpConverter::SparseGetterOpConverter; // Callback for SparseGetterOpConverter. - static Value getFieldForOp(const SparseTensorDescriptor &desc, - ToPointersOp op) { + static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/, + ToPointersOp op) { uint64_t dim = op.getDimension().getZExtValue(); - return desc.getPtrMemRef(dim); + return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1u); } }; @@ -970,10 +1031,10 @@ class SparseToIndicesConverter public: using SparseGetterOpConverter::SparseGetterOpConverter; // Callback for SparseGetterOpConverter. - static Value getFieldForOp(const SparseTensorDescriptor &desc, - ToIndicesOp op) { + static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/, + ToIndicesOp op) { uint64_t dim = op.getDimension().getZExtValue(); - return desc.getIdxMemRef(dim); + return getFieldIndex(op.getTensor().getType(), -1u, /*idxDim=*/dim); } }; @@ -983,9 +1044,10 @@ class SparseToValuesConverter public: using SparseGetterOpConverter::SparseGetterOpConverter; // Callback for SparseGetterOpConverter. - static Value getFieldForOp(const SparseTensorDescriptor &desc, - ToValuesOp /*op*/) { - return desc.getValMemRef(); + static unsigned getIndexForOp(UnrealizedConversionCastOp tuple, + ToValuesOp /*op*/) { + // The last field holds the value buffer. + return tuple.getInputs().size() - 1; } }; @@ -1017,11 +1079,12 @@ public: matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Query memSizes for the actually stored values size. 
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + auto tuple = getTuple(adaptor.getTensor()); + auto fields = tuple.getInputs(); + unsigned lastField = fields.size() - 1; Value field = - constantIndex(rewriter, op.getLoc(), desc.getValMemSizesIndex()); - rewriter.replaceOpWithNewOp(op, desc.getMemSizesMemRef(), - field); + constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField)); + rewriter.replaceOpWithNewOp(op, fields[memSizesIdx], field); return success(); } }; -- 2.7.4
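
For readers tracing the revert, the flat field numbering that SparseTensorCodegen.cpp
returns to can be modeled in isolation. The following is a minimal standalone C++
sketch, not the MLIR API: LevelType, valuesFieldIndex, memSizesSlot, and kFieldsIdx
are hypothetical stand-ins for DimLevelType, getFieldIndex, getMemSizesIndex, and the
restored dimSizesIdx/memSizesIdx/fieldsIdx constants, assuming the
dense/compressed/singleton rules spelled out in the storage-scheme comment above.

// Standalone model of the restored field layout (plain C++; all names here
// are illustrative stand-ins, not MLIR declarations).
#include <cassert>
#include <cstdio>
#include <vector>

enum class LevelType { Dense, Compressed, Singleton }; // stand-in for DimLevelType

// Fields 0 and 1 are the dimSizes and memSizes header memrefs; data fields
// (pointers/indices/values) start at index 2, as in the restored constants.
constexpr unsigned kFieldsIdx = 2;

// Walks the dimensions the way the restored getFieldIndex does: a compressed
// dimension contributes a pointers and an indices memref, a singleton dimension
// contributes an indices memref, a dense dimension contributes nothing; the
// values memref always comes last.
unsigned valuesFieldIndex(const std::vector<LevelType> &levels) {
  unsigned field = kFieldsIdx; // start past the header
  for (LevelType lt : levels) {
    if (lt == LevelType::Compressed)
      field += 2; // pointers-d, indices-d
    else if (lt == LevelType::Singleton)
      field += 1; // indices-d
  }
  return field; // the values memref lands here
}

// Mirrors the restored getMemSizesIndex: a data field's slot in memSizes.
unsigned memSizesSlot(unsigned field) {
  assert(field >= kFieldsIdx);
  return field - kFieldsIdx;
}

int main() {
  // CSR-style rank-2 tensor: dense outer dimension, compressed inner one.
  std::vector<LevelType> csr = {LevelType::Dense, LevelType::Compressed};
  unsigned val = valuesFieldIndex(csr);
  // Flat layout: 0 dimSizes, 1 memSizes, 2 pointers-1, 3 indices-1, 4 values.
  assert(val == 4 && memSizesSlot(val) == 2);
  std::printf("values field = %u, memSizes slot = %u, total fields = %u\n",
              val, memSizesSlot(val), val + 1);
  return 0;
}

For the CSR-style encoding in this example, the storage is [dimSizes, memSizes,
pointers-1, indices-1, values]: the values memref is field 4, its used size lives
in memSizes slot 2, and the memSizes array itself holds three entries.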