From 988733c60037c61ca49233c356c0f928a5ac14bb Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Thu, 15 Dec 2022 18:45:07 +0000 Subject: [PATCH] [mlir][sparse] use sparse_tensor::StorageSpecifier to store dim/memSizes Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D140130 --- .../Dialect/SparseTensor/IR/SparseTensorOps.td | 41 +- .../SparseTensor/IR/SparseTensorDialect.cpp | 6 +- .../SparseTensor/Transforms/CodegenUtils.cpp | 110 ---- .../Dialect/SparseTensor/Transforms/CodegenUtils.h | 214 ------- .../Transforms/SparseBufferRewriting.cpp | 11 +- .../Transforms/SparseStorageSpecifierToLLVM.cpp | 2 +- .../Transforms/SparseTensorCodegen.cpp | 166 ++---- .../Transforms/SparseTensorStorageLayout.cpp | 80 ++- .../Transforms/SparseTensorStorageLayout.h | 5 +- .../Dialect/SparseTensor/buffer_rewriting.mlir | 47 +- mlir/test/Dialect/SparseTensor/codegen.mlir | 613 ++++++++++----------- .../codegen_buffer_initialization.mlir | 48 +- mlir/test/Dialect/SparseTensor/invalid.mlir | 12 +- mlir/test/Dialect/SparseTensor/roundtrip.mlir | 36 +- .../Dialect/SparseTensor/scf_1_N_conversion.mlir | 114 ++-- .../SparseTensor/sparse_matmul_codegen.mlir | 235 ++++---- .../SparseTensor/CPU/sparse_rewrite_push_back.mlir | 9 +- 17 files changed, 678 insertions(+), 1071 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index 13c6e03..771e97f 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -343,18 +343,20 @@ def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", "inBuffer", "value", "$_self.cast().getElementType()">, AllTypesMatch<["inBuffer", "outBuffer"]>]>, - Arguments<(ins StridedMemRefRankOf<[Index], [1]>:$bufferSizes, + Arguments<(ins Index:$curSize, StridedMemRefRankOf<[AnyType], [1]>:$inBuffer, - AnyType:$value, IndexAttr:$idx, Optional:$n, + AnyType:$value, 
Optional:$n, UnitAttr:$inbounds)>, - Results<(outs StridedMemRefRankOf<[AnyType], [1]>:$outBuffer)> { + Results<(outs StridedMemRefRankOf<[AnyType], [1]>:$outBuffer, + Index:$newSize)> { string summary = "Pushes a value to the back of a given buffer"; string description = [{ - Push `value` to the end of the given sparse tensor storage buffer - `inBuffer` and update the size of the buffer in `bufferSizes[idx]`. The - capacity of the buffer is recorded in the memref type of `inBuffer `. If the - current buffer is full, then `inBuffer.realloc` is called before pushing the - data to the buffer. This is similar to std::vector push_back. + Pushes `value` to the end of the given sparse tensor storage buffer + `inBuffer` as indicated by the value of `curSize` and returns the + new size of the buffer in `newSize` (`newSize = curSize + n`). + The capacity of the buffer is recorded in the memref type of `inBuffer`. + If the current buffer is full, then `inBuffer.realloc` is called before + pushing the data to the buffer. This is similar to std::vector push_back. The optional input `n` specifies the number of times to repeately push the value to the back of the tensor. 
When `n` is a compile-time constant, @@ -376,29 +378,28 @@ def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", Example: ```mlir - %r = sparse_tensor.push_back %bufferSizes, %buffer, %val - {idx = 0 : index} : memref, memref, f64 + %buf, %newSize = sparse_tensor.push_back %curSize, %buffer, %val + : index, memref, f64 ``` ```mlir - %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val - {idx = 0 : index} : memref, memref, f64 + %buf, %newSize = sparse_tensor.push_back inbounds %curSize, %buffer, %val + : index, memref, f64 ``` ```mlir - %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val, %n - {idx = 0 : index} : memref, memref, f64 + %buf, %newSize = sparse_tensor.push_back inbounds %curSize, %buffer, %val, %n + : index, memref, f64 ``` }]; - let assemblyFormat = "(`inbounds` $inbounds^)? $bufferSizes `,` $inBuffer" + let assemblyFormat = "(`inbounds` $inbounds^)? $curSize `,` $inBuffer" " `,` $value (`,` $n^ )? attr-dict `:`" - " type($bufferSizes) `,` type($inBuffer) `,`" - " type($value) (`,` type($n)^ )?"; + " type($curSize) `,` type($inBuffer) `,`" + " type($value) (`,` type($n)^ )?"; let builders = [ - //Build an op without input `n`. 
- OpBuilder<(ins "Type":$outBuffer, "Value":$bufferSizes, "Value":$inBuffer, - "Value":$value, "APInt":$idx)> + // Build an op (reusing type from curSize and inBuffer) without input `n` + OpBuilder<(ins "Value":$curSize, "Value":$inBuffer, "Value":$value)> ]; let hasVerifier = 1; diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 1e9aab8..f28abee 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -694,10 +694,8 @@ LogicalResult InsertOp::verify() { } void PushBackOp::build(OpBuilder &builder, OperationState &result, - Type outBuffer, Value bufferSizes, Value inBuffer, - Value value, APInt idx) { - build(builder, result, outBuffer, bufferSizes, inBuffer, value, - std::move(idx), Value()); + Value curSize, Value inBuffer, Value value) { + build(builder, result, curSize, inBuffer, value, Value()); } LogicalResult PushBackOp::verify() { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp index 9fd74f7..e3ab5ce 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -90,116 +90,6 @@ static Value genIndexAndValueForDense(OpBuilder &builder, Location loc, return val; } -void sparse_tensor::foreachFieldInSparseTensor( - const SparseTensorEncodingAttr enc, - llvm::function_ref - callback) { - assert(enc); - -#define RETURN_ON_FALSE(idx, kind, dim, dlt) \ - if (!(callback(idx, kind, dim, dlt))) \ - return; - - RETURN_ON_FALSE(dimSizesIdx, SparseTensorFieldKind::DimSizes, -1u, - DimLevelType::Undef); - RETURN_ON_FALSE(memSizesIdx, SparseTensorFieldKind::MemSizes, -1u, - DimLevelType::Undef); - - static_assert(dataFieldIdx == memSizesIdx + 1); - unsigned fieldIdx = dataFieldIdx; - // Per-dimension storage. 
- for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) { - // Dimension level types apply in order to the reordered dimension. - // As a result, the compound type can be constructed directly in the given - // order. - auto dlt = getDimLevelType(enc, r); - if (isCompressedDLT(dlt)) { - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt); - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); - } else if (isSingletonDLT(dlt)) { - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); - } else { - assert(isDenseDLT(dlt)); // no fields - } - } - // The values array. - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u, - DimLevelType::Undef); - -#undef RETURN_ON_FALSE -} - -void sparse_tensor::foreachFieldAndTypeInSparseTensor( - RankedTensorType rType, - llvm::function_ref - callback) { - auto enc = getSparseTensorEncoding(rType); - assert(enc); - // Construct the basic types. - Type indexType = IndexType::get(enc.getContext()); - Type idxType = enc.getIndexType(); - Type ptrType = enc.getPointerType(); - Type eltType = rType.getElementType(); - unsigned rank = rType.getShape().size(); - // memref dimSizes - Type dimSizeType = MemRefType::get({rank}, indexType); - // memref memSizes - Type memSizeType = - MemRefType::get({getNumDataFieldsFromEncoding(enc)}, indexType); - // memref pointers - Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType); - // memref indices - Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType); - // memref values - Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType); - - foreachFieldInSparseTensor( - enc, - [dimSizeType, memSizeType, ptrMemType, idxMemType, valMemType, - callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind, - unsigned dim, DimLevelType dlt) -> bool { - switch (fieldKind) { - case SparseTensorFieldKind::DimSizes: - return callback(dimSizeType, fieldIdx, fieldKind, dim, dlt); - case 
SparseTensorFieldKind::MemSizes: - return callback(memSizeType, fieldIdx, fieldKind, dim, dlt); - case SparseTensorFieldKind::PtrMemRef: - return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt); - case SparseTensorFieldKind::IdxMemRef: - return callback(idxMemType, fieldIdx, fieldKind, dim, dlt); - case SparseTensorFieldKind::ValMemRef: - return callback(valMemType, fieldIdx, fieldKind, dim, dlt); - }; - llvm_unreachable("unrecognized field kind"); - }); -} - -unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) { - unsigned numFields = 0; - foreachFieldInSparseTensor(enc, - [&numFields](unsigned, SparseTensorFieldKind, - unsigned, DimLevelType) -> bool { - numFields++; - return true; - }); - return numFields; -} - -unsigned -sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) { - unsigned numFields = 0; // one value memref - foreachFieldInSparseTensor(enc, - [&numFields](unsigned fidx, SparseTensorFieldKind, - unsigned, DimLevelType) -> bool { - if (fidx >= dataFieldIdx) - numFields++; - return true; - }); - assert(numFields == getNumFieldsFromEncoding(enc) - dataFieldIdx); - return numFields; -} //===----------------------------------------------------------------------===// // Sparse tensor loop emitter class implementations //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h index 6e4ea83..46f0214 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -314,220 +314,6 @@ inline bool isZeroRankedTensorOrScalar(Type type) { } //===----------------------------------------------------------------------===// -// SparseTensorDescriptor and helpers, manage the sparse tensor memory layout -// scheme. 
-// -// Sparse tensor storage scheme for rank-dimensional tensor is organized -// as a single compound type with the following fields. Note that every -// memref with ? size actually behaves as a "vector", i.e. the stored -// size is the capacity and the used size resides in the memSizes array. -// -// struct { -// memref dimSizes ; size in each dimension -// memref memSizes ; sizes of ptrs/inds/values -// ; per-dimension d: -// ; if dense: -// -// ; if compresed: -// memref pointers-d ; pointers for sparse dim d -// memref indices-d ; indices for sparse dim d -// ; if singleton: -// memref indices-d ; indices for singleton dim d -// memref values ; values -// }; -// -//===----------------------------------------------------------------------===// -enum class SparseTensorFieldKind { - DimSizes, - MemSizes, - PtrMemRef, - IdxMemRef, - ValMemRef -}; - -constexpr uint64_t dimSizesIdx = 0; -constexpr uint64_t memSizesIdx = dimSizesIdx + 1; -constexpr uint64_t dataFieldIdx = memSizesIdx + 1; - -/// For each field that will be allocated for the given sparse tensor encoding, -/// calls the callback with the corresponding field index, field kind, dimension -/// (for sparse tensor level memrefs) and dimlevelType. -/// The field index always starts with zero and increments by one between two -/// callback invocations. -/// Ideally, all other methods should rely on this function to query a sparse -/// tensor fields instead of relying on ad-hoc index computation. -void foreachFieldInSparseTensor( - SparseTensorEncodingAttr, - llvm::function_ref); - -/// Same as above, except that it also builds the Type for the corresponding -/// field. -void foreachFieldAndTypeInSparseTensor( - RankedTensorType, - llvm::function_ref); - -/// Gets the total number of fields for the given sparse tensor encoding. 
-unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc); - -/// Gets the total number of data fields (index arrays, pointer arrays, and a -/// value array) for the given sparse tensor encoding. -unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc); - -/// Get the index of the field in memSizes (only valid for data fields). -inline unsigned getFieldMemSizesIndex(unsigned fid) { - assert(fid >= dataFieldIdx); - return fid - dataFieldIdx; -} - -template -struct SparseTensorValueArrayRef; - -// Uses ValueRange for immuatable descriptors; uses SmallVectorImpl & -// for mutable descriptors. -template <> -struct SparseTensorValueArrayRef { - using ValueArray = ValueRange; -}; - -// Using SmallVector for mutable descriptor allows users to reuse it as a tmp -// buffers to append value for some special cases, though users should be -// responsible to restore the buffer to legal states after their use. It is -// probably not a clean way, but it is the most efficient way to avoid copying -// the fields into another SmallVector. If a more clear way is wanted, we -// should change it to MutableArrayRef instead. -template <> -struct SparseTensorValueArrayRef { - using ValueArray = SmallVectorImpl &; -}; - -/// A helper class around an array of values that corresponding to a sparse -/// tensor, provides a set of meaningful APIs to query and update a particular -/// field in a consistent way. -/// Users should not make assumption on how a sparse tensor is laid out but -/// instead relies on this class to access the right value for the right field. 
-template -class SparseTensorDescriptorImpl { -private: - using Storage = typename SparseTensorValueArrayRef::ValueArray; - -public: - SparseTensorDescriptorImpl(Type tp, Storage fields) - : rType(tp.cast()), fields(fields) { - assert(getSparseTensorEncoding(tp) && - getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) == - fields.size()); - // We should make sure the class is trivially copyable (and should be small - // enough) such that we can pass it by value. - static_assert( - std::is_trivially_copyable_v>); - } - - // Implicit (and cheap) type conversion from MutSparseTensorDescriptor to - // SparseTensorDescriptor. - template > - /*implicit*/ SparseTensorDescriptorImpl(std::enable_if_t &mDesc) - : rType(mDesc.getTensorType()), fields(mDesc.getFields()) {} - - /// - /// Getters: get the field index for required field. - /// - - unsigned getPtrMemRefIndex(unsigned ptrDim) const { - return getFieldIndex(ptrDim, SparseTensorFieldKind::PtrMemRef); - } - - unsigned getIdxMemRefIndex(unsigned idxDim) const { - return getFieldIndex(idxDim, SparseTensorFieldKind::IdxMemRef); - } - - unsigned getValMemRefIndex() const { return fields.size() - 1; } - - unsigned getPtrMemSizesIndex(unsigned dim) const { - return getPtrMemRefIndex(dim) - dataFieldIdx; - } - - unsigned getIdxMemSizesIndex(unsigned dim) const { - return getIdxMemRefIndex(dim) - dataFieldIdx; - } - - unsigned getValMemSizesIndex() const { - return getValMemRefIndex() - dataFieldIdx; - } - - unsigned getNumFields() const { return fields.size(); } - - /// - /// Getters: get the value for required field. 
- /// - - Value getDimSizesMemRef() const { return fields[dimSizesIdx]; } - Value getMemSizesMemRef() const { return fields[memSizesIdx]; } - - Value getPtrMemRef(unsigned ptrDim) const { - return fields[getPtrMemRefIndex(ptrDim)]; - } - - Value getIdxMemRef(unsigned idxDim) const { - return fields[getIdxMemRefIndex(idxDim)]; - } - - Value getValMemRef() const { return fields[getValMemRefIndex()]; } - - Value getField(unsigned fid) const { - assert(fid < fields.size()); - return fields[fid]; - } - - /// - /// Setters: update the value for required field (only enabled for - /// MutSparseTensorDescriptor). - /// - - template - void setField(unsigned fid, std::enable_if_t v) { - assert(fid < fields.size()); - fields[fid] = v; - } - - RankedTensorType getTensorType() const { return rType; } - Storage getFields() const { return fields; } - - Type getElementType(unsigned fidx) const { - return fields[fidx].getType().template cast().getElementType(); - } - -private: - unsigned getFieldIndex(unsigned dim, SparseTensorFieldKind kind) const { - unsigned fieldIdx = -1u; - foreachFieldInSparseTensor( - getSparseTensorEncoding(rType), - [dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind, - unsigned fDim, DimLevelType dlt) -> bool { - if (fDim == dim && kind == fKind) { - fieldIdx = fIdx; - // Returns false to break the iteration. - return false; - } - return true; - }); - assert(fieldIdx != -1u); - return fieldIdx; - } - - RankedTensorType rType; - Storage fields; -}; - -using SparseTensorDescriptor = SparseTensorDescriptorImpl; -using MutSparseTensorDescriptor = SparseTensorDescriptorImpl; - -//===----------------------------------------------------------------------===// // SparseTensorLoopEmiter class, manages sparse tensors and helps to // generate loop structure to (co)-iterate sparse tensors. 
// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp index ded1e65..fc9476c 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -331,7 +331,7 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, Location loc = func.getLoc(); ValueRange args = entryBlock->getArguments(); Value p = args[hiIdx]; - SmallVector types(2, p.getType()); // only two + SmallVector types(2, p.getType()); // only two scf::WhileOp whileOp = builder.create( loc, types, SmallVector{args[loIdx], args[hiIdx]}); @@ -490,7 +490,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module, Value i = lo; Value j = builder.create(loc, hi, c1); - SmallVector operands{i, j, p}; // exactly three + SmallVector operands{i, j, p}; // exactly three SmallVector types{i.getType(), j.getType(), p.getType()}; scf::WhileOp whileOp = builder.create(loc, types, operands); @@ -770,9 +770,7 @@ public: Value c0 = constantIndex(rewriter, loc, 0); Value buffer = op.getInBuffer(); Value capacity = rewriter.create(loc, buffer, c0); - Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue()); - Value bufferSizes = op.getBufferSizes(); - Value size = rewriter.create(loc, bufferSizes, idx); + Value size = op.getCurSize(); Value value = op.getValue(); Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1); @@ -852,8 +850,7 @@ public: } // Update the buffer size. 
- rewriter.create(loc, newSize, bufferSizes, idx); - rewriter.replaceOp(op, buffer); + rewriter.replaceOp(op, {buffer, newSize}); return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp index d6a9007..5843aef 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp @@ -118,7 +118,7 @@ public: op.getDim().value().getZExtValue()); } else { auto enc = op.getSpecifier().getType().getEncoding(); - builder::StorageLayout layout(enc); + StorageLayout layout(enc); Optional dim = std::nullopt; if (op.getDim()) dim = op.getDim().value().getZExtValue(); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 6b406d8..710d6ce 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -16,6 +16,7 @@ //===----------------------------------------------------------------------===// #include "CodegenUtils.h" +#include "SparseTensorStorageLayout.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -40,38 +41,6 @@ static constexpr const char kInsertFuncNamePrefix[] = "_insert_"; // Helper methods. //===----------------------------------------------------------------------===// -/// Returns the "tuple" value of the adapted tensor. 
-static UnrealizedConversionCastOp getTuple(Value tensor) { - return llvm::cast(tensor.getDefiningOp()); -} - -static SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) { - auto tuple = getTuple(tensor); - return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs()); -} - -static MutSparseTensorDescriptor -getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl &fields) { - auto tuple = getTuple(tensor); - fields.assign(tuple.getInputs().begin(), tuple.getInputs().end()); - return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields); -} - -/// Packs the given values as a "tuple" value. -static Value genTuple(OpBuilder &builder, Location loc, Type tp, - ValueRange values) { - return builder.create(loc, TypeRange(tp), values) - .getResult(0); -} - -static Value genTuple(OpBuilder &builder, Location loc, - SparseTensorDescriptor desc) { - return builder - .create(loc, desc.getTensorType(), - desc.getFields()) - .getResult(0); -} - /// Flatten a list of operands that may contain sparse tensors. static void flattenOperands(ValueRange operands, SmallVectorImpl &flattened) { @@ -146,9 +115,7 @@ static std::optional sizeFromTensorAtDim(OpBuilder &builder, // Any other query can consult the dimSizes array at field DimSizesIdx, // accounting for the reordering applied to the sparse storage. 
- Value idx = constantIndex(builder, loc, toStoredDim(rtp, dim)); - return builder.create(loc, desc.getDimSizesMemRef(), idx) - .getResult(); + return desc.getDimSize(builder, loc, toStoredDim(rtp, dim)); } // Gets the dimension size at the given stored dimension 'd', either as a @@ -161,40 +128,24 @@ Value sizeAtStoredDim(OpBuilder &builder, Location loc, if (!ShapedType::isDynamic(shape[dim])) return constantIndex(builder, loc, shape[dim]); - return genLoad(builder, loc, desc.getDimSizesMemRef(), - constantIndex(builder, loc, d)); + return desc.getDimSize(builder, loc, d); } static void createPushback(OpBuilder &builder, Location loc, - MutSparseTensorDescriptor desc, unsigned fidx, + MutSparseTensorDescriptor desc, + SparseTensorFieldKind kind, Optional dim, Value value, Value repeat = Value()) { - Type etp = desc.getElementType(fidx); - Value field = desc.getField(fidx); - Value newField = builder.create( - loc, field.getType(), desc.getMemSizesMemRef(), field, - toType(builder, loc, value, etp), APInt(64, getFieldMemSizesIndex(fidx)), - repeat); - desc.setField(fidx, newField); -} + Type etp = desc.getMemRefElementType(kind, dim); + Value field = desc.getMemRefField(kind, dim); + StorageSpecifierKind specFieldKind = toSpecifierKind(kind); -/// Maps a sparse tensor type to the appropriate compounded buffers. 
-static std::optional -convertSparseTensorType(Type type, SmallVectorImpl &fields) { - auto enc = getSparseTensorEncoding(type); - if (!enc) - return std::nullopt; + auto pushBackOp = builder.create( + loc, desc.getSpecifierField(builder, loc, specFieldKind, dim), field, + toType(builder, loc, value, etp), repeat); - RankedTensorType rType = type.cast(); - foreachFieldAndTypeInSparseTensor( - rType, - [&fields](Type fieldType, unsigned fieldIdx, - SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/, - DimLevelType /*dlt*/) -> bool { - assert(fieldIdx == fields.size()); - fields.push_back(fieldType); - return true; - }); - return success(); + desc.setMemRefField(kind, dim, pushBackOp.getOutBuffer()); + desc.setSpecifierField(builder, loc, specFieldKind, dim, + pushBackOp.getNewSize()); } /// Generates code that allocates a sparse storage scheme for given rank. @@ -210,8 +161,8 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc, // the desired "linear + 1" length property at all times. Type ptrType = getSparseTensorEncoding(rtp).getPointerType(); Value ptrZero = constantZero(builder, loc, ptrType); - createPushback(builder, loc, desc, desc.getPtrMemRefIndex(r), ptrZero, - linear); + createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, r, + ptrZero, linear); return; } if (isSingletonDim(rtp, r)) { @@ -226,7 +177,8 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc, } // Reached values array so prepare for an insertion. Value valZero = constantZero(builder, loc, rtp.getElementType()); - createPushback(builder, loc, desc, desc.getValMemRefIndex(), valZero, linear); + createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef, + std::nullopt, valZero, linear); } /// Creates allocation operation. 
@@ -257,22 +209,20 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type, foreachFieldAndTypeInSparseTensor( rtp, - [&builder, &fields, loc, heuristic, + [&builder, &fields, rtp, loc, heuristic, enableInit](Type fType, unsigned fIdx, SparseTensorFieldKind fKind, unsigned /*dim*/, DimLevelType /*dlt*/) -> bool { assert(fields.size() == fIdx); - auto memRefTp = fType.cast(); Value field; switch (fKind) { - case SparseTensorFieldKind::DimSizes: - case SparseTensorFieldKind::MemSizes: - field = builder.create(loc, memRefTp); + case SparseTensorFieldKind::StorageSpec: + field = SparseTensorSpecifier::getInitValue(builder, loc, rtp); break; case SparseTensorFieldKind::PtrMemRef: case SparseTensorFieldKind::IdxMemRef: case SparseTensorFieldKind::ValMemRef: - field = - createAllocation(builder, loc, memRefTp, heuristic, enableInit); + field = createAllocation(builder, loc, fType.cast(), + heuristic, enableInit); break; } assert(field); @@ -297,21 +247,18 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type, // to all zeros, sets the dimSizes to known values and gives all pointer // fields an initial zero entry, so that it is easier to maintain the // "linear + 1" length property. - builder.create( - loc, constantZero(builder, loc, builder.getIndexType()), - desc.getMemSizesMemRef()); // zero memSizes - Value ptrZero = constantZero(builder, loc, getSparseTensorEncoding(rtp).getPointerType()); for (unsigned r = 0; r < rank; r++) { unsigned ro = toOrigDim(rtp, r); // Fills dim sizes array. - genStore(builder, loc, sizes[ro], desc.getDimSizesMemRef(), - constantIndex(builder, loc, r)); + desc.setDimSize(builder, loc, r, sizes[ro]); // Pushes a leading zero to pointers memref. 
- if (isCompressedDim(rtp, r)) - createPushback(builder, loc, desc, desc.getPtrMemRefIndex(r), ptrZero); + if (isCompressedDim(rtp, r)) { + createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, r, + ptrZero); + } } allocSchemeForRank(builder, loc, desc, /*rank=*/0); } @@ -349,10 +296,11 @@ static Value genCompressed(OpBuilder &builder, Location loc, unsigned ptrIndex = desc.getPtrMemRefIndex(d); Value one = constantIndex(builder, loc, 1); Value pp1 = builder.create(loc, pos, one); - Value plo = genLoad(builder, loc, desc.getField(ptrIndex), pos); - Value phi = genLoad(builder, loc, desc.getField(ptrIndex), pp1); - Value psz = constantIndex(builder, loc, getFieldMemSizesIndex(idxIndex)); - Value msz = genLoad(builder, loc, desc.getMemSizesMemRef(), psz); + Value plo = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pos); + Value phi = genLoad(builder, loc, desc.getMemRefField(ptrIndex), pp1); + Value msz = desc.getIdxMemSize(builder, loc, d); + // Value msz = desc.getMemSize(builder, loc, getFieldMemSizesIndex(idxIndex)); + Value phim1 = builder.create( loc, toType(builder, loc, phi, indexType), one); // Conditional expression. 
@@ -362,14 +310,14 @@ static Value genCompressed(OpBuilder &builder, Location loc, scf::IfOp ifOp1 = builder.create(loc, types, lt, /*else*/ true); types.pop_back(); builder.setInsertionPointToStart(&ifOp1.getThenRegion().front()); - Value crd = genLoad(builder, loc, desc.getField(idxIndex), phim1); + Value crd = genLoad(builder, loc, desc.getMemRefField(idxIndex), phim1); Value eq = builder.create(loc, arith::CmpIPredicate::eq, toType(builder, loc, crd, indexType), indices[d]); builder.create(loc, eq); builder.setInsertionPointToStart(&ifOp1.getElseRegion().front()); if (d > 0) - genStore(builder, loc, msz, desc.getField(ptrIndex), pos); + genStore(builder, loc, msz, desc.getMemRefField(ptrIndex), pos); builder.create(loc, constantI1(builder, loc, false)); builder.setInsertionPointAfter(ifOp1); Value p = ifOp1.getResult(0); @@ -396,8 +344,9 @@ static Value genCompressed(OpBuilder &builder, Location loc, // If !present (changes fields, update next). builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); Value mszp1 = builder.create(loc, msz, one); - genStore(builder, loc, mszp1, desc.getField(ptrIndex), pp1); - createPushback(builder, loc, desc, idxIndex, indices[d]); + genStore(builder, loc, mszp1, desc.getMemRefField(ptrIndex), pp1); + createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, d, + indices[d]); // Prepare the next dimension "as needed". if ((d + 1) < rank) allocSchemeForRank(builder, loc, desc, d + 1); @@ -459,7 +408,8 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module, // indices[d].push_back(i[d]) // pos[d] = pos[d-1] // - createPushback(builder, loc, desc, desc.getIdxMemRefIndex(d), indices[d]); + createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, d, + indices[d]); } else { assert(isDenseDim(rtp, d)); // Construct the new position as: @@ -472,7 +422,8 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module, } // Reached the actual value append/insert. 
if (!isDenseDim(rtp, rank - 1)) - createPushback(builder, loc, desc, desc.getValMemRefIndex(), value); + createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef, + std::nullopt, value); else genStore(builder, loc, value, desc.getValMemRef(), pos); builder.create(loc, fields); @@ -565,8 +516,7 @@ static void genEndInsert(OpBuilder &builder, Location loc, if (d > 0) { Type ptrType = getSparseTensorEncoding(rtp).getPointerType(); Value ptrMemRef = desc.getPtrMemRef(d); - Value mz = constantIndex(builder, loc, desc.getPtrMemSizesIndex(d)); - Value hi = genLoad(builder, loc, desc.getMemSizesMemRef(), mz); + Value hi = desc.getPtrMemSize(builder, loc, d); Value zero = constantIndex(builder, loc, 0); Value one = constantIndex(builder, loc, 1); // Vector of only one, but needed by createFor's prototype. @@ -723,6 +673,7 @@ public: bool enableInit) : OpConversionPattern(typeConverter, context), enableBufferInitialization(enableInit) {} + LogicalResult matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -761,8 +712,8 @@ public: // Replace the sparse tensor deallocation with field deallocations. Location loc = op.getLoc(); - auto tuple = getTuple(adaptor.getTensor()); - for (auto input : tuple.getInputs()) + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + for (auto input : desc.getMemRefFields()) // Deallocate every buffer used to store the sparse tensor handler. rewriter.create(loc, input); @@ -1018,10 +969,7 @@ public: ConversionPatternRewriter &rewriter) const override { // Query memSizes for the actually stored values size. 
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); - Value field = - constantIndex(rewriter, op.getLoc(), desc.getValMemSizesIndex()); - rewriter.replaceOpWithNewOp(op, desc.getMemSizesMemRef(), - field); + rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc())); return success(); } }; @@ -1029,26 +977,6 @@ public: } // namespace //===----------------------------------------------------------------------===// -// Sparse tensor type conversion into an actual buffer. -//===----------------------------------------------------------------------===// - -mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { - addConversion([](Type type) { return type; }); - addConversion(convertSparseTensorType); - - // Required by scf.for 1:N type conversion. - addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp, - ValueRange inputs, - Location loc) -> std::optional { - if (!getSparseTensorEncoding(tp)) - // Not a sparse tensor. - return std::nullopt; - // Sparse compiler knows how to cancel out these casts. - return genTuple(builder, loc, tp, inputs); - }); -} - -//===----------------------------------------------------------------------===// // Public method for populating conversion rules. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp index 4dcd034..fd0126b 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp @@ -22,15 +22,51 @@ static Value createIndexCast(OpBuilder &builder, Location loc, Value value, return value; } -static IntegerAttr fromOptionalInt(MLIRContext *ctx, Optional dim) { +static IntegerAttr fromOptionalInt(MLIRContext *ctx, + std::optional dim) { if (!dim) return nullptr; return IntegerAttr::get(IndexType::get(ctx), dim.value()); } -unsigned -builder::StorageLayout::getMemRefFieldIndex(SparseTensorFieldKind kind, - Optional dim) const { +static std::optional +convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl &fields) { + auto enc = getSparseTensorEncoding(rtp); + if (!enc) + return std::nullopt; + + foreachFieldAndTypeInSparseTensor( + rtp, + [&fields](Type fieldType, unsigned fieldIdx, + SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/, + DimLevelType /*dlt*/) -> bool { + assert(fieldIdx == fields.size()); + fields.push_back(fieldType); + return true; + }); + return success(); +} + +SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { + addConversion([](Type type) { return type; }); + addConversion([&](RankedTensorType rtp, SmallVectorImpl &fields) { + return convertSparseTensorType(rtp, fields); + }); + + // Required by scf.for 1:N type conversion. + addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp, + ValueRange inputs, + Location loc) -> std::optional { + if (!getSparseTensorEncoding(tp)) + // Not a sparse tensor. + return std::nullopt; + // Sparse compiler knows how to cancel out these casts. 
+ return genTuple(builder, loc, tp, inputs); + }); +} + +unsigned StorageLayout::getMemRefFieldIndex(SparseTensorFieldKind kind, + std::optional dim) const { unsigned fieldIdx = -1u; foreachFieldInSparseTensor( enc, @@ -48,22 +84,20 @@ builder::StorageLayout::getMemRefFieldIndex(SparseTensorFieldKind kind, return fieldIdx; } -unsigned -builder::StorageLayout::getMemRefFieldIndex(StorageSpecifierKind kind, - Optional dim) const { +unsigned StorageLayout::getMemRefFieldIndex(StorageSpecifierKind kind, + std::optional dim) const { return getMemRefFieldIndex(toFieldKind(kind), dim); } -Value builder::SparseTensorSpecifier::getInitValue(OpBuilder &builder, - Location loc, - RankedTensorType rtp) { +Value SparseTensorSpecifier::getInitValue(OpBuilder &builder, Location loc, + RankedTensorType rtp) { return builder.create( loc, StorageSpecifierType::get(getSparseTensorEncoding(rtp))); } -Value builder::SparseTensorSpecifier::getSpecifierField( - OpBuilder &builder, Location loc, StorageSpecifierKind kind, - Optional dim) { +Value SparseTensorSpecifier::getSpecifierField(OpBuilder &builder, Location loc, + StorageSpecifierKind kind, + std::optional dim) { return createIndexCast(builder, loc, builder.create( loc, getFieldType(kind, dim), specifier, kind, @@ -71,9 +105,10 @@ Value builder::SparseTensorSpecifier::getSpecifierField( builder.getIndexType()); } -void builder::SparseTensorSpecifier::setSpecifierField( - OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind, - Optional dim) { +void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc, + Value v, + StorageSpecifierKind kind, + std::optional dim) { specifier = builder.create( loc, specifier, kind, fromOptionalInt(specifier.getContext(), dim), createIndexCast(builder, loc, v, getFieldType(kind, dim))); @@ -81,7 +116,7 @@ void builder::SparseTensorSpecifier::setSpecifierField( constexpr uint64_t kDataFieldStartingIdx = 0; -void sparse_tensor::builder::foreachFieldInSparseTensor( 
+void sparse_tensor::foreachFieldInSparseTensor( const SparseTensorEncodingAttr enc, llvm::function_ref @@ -120,7 +155,7 @@ void sparse_tensor::builder::foreachFieldInSparseTensor( #undef RETURN_ON_FALSE } -void sparse_tensor::builder::foreachFieldAndTypeInSparseTensor( +void sparse_tensor::foreachFieldAndTypeInSparseTensor( RankedTensorType rType, llvm::function_ref @@ -159,8 +194,7 @@ void sparse_tensor::builder::foreachFieldAndTypeInSparseTensor( }); } -unsigned -sparse_tensor::builder::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) { +unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) { unsigned numFields = 0; foreachFieldInSparseTensor(enc, [&numFields](unsigned, SparseTensorFieldKind, @@ -171,8 +205,8 @@ sparse_tensor::builder::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) { return numFields; } -unsigned sparse_tensor::builder::getNumDataFieldsFromEncoding( - SparseTensorEncodingAttr enc) { +unsigned +sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) { unsigned numFields = 0; // one value memref foreachFieldInSparseTensor(enc, [&numFields](unsigned fidx, SparseTensorFieldKind, @@ -183,6 +217,6 @@ unsigned sparse_tensor::builder::getNumDataFieldsFromEncoding( }); numFields -= 1; // the last field is MetaData field assert(numFields == - builder::getNumFieldsFromEncoding(enc) - kDataFieldStartingIdx - 1); + getNumFieldsFromEncoding(enc) - kDataFieldStartingIdx - 1); return numFields; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h index 9b4e235..d94aa1f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h @@ -22,8 +22,6 @@ namespace mlir { namespace sparse_tensor { -// FIXME: this is a tmp namespace -namespace builder { 
//===----------------------------------------------------------------------===// // SparseTensorDescriptor and helpers, manage the sparse tensor memory layout // scheme. @@ -171,7 +169,7 @@ public: SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields) : rType(tp.cast()), fields(fields) { assert(getSparseTensorEncoding(tp) && - builder::getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) == + getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) == fields.size()); // We should make sure the class is trivially copyable (and should be small // enough) such that we can pass it by value. @@ -355,7 +353,6 @@ getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl &fields) { return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields); } -} // namespace builder } // namespace sparse_tensor } // namespace mlir #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORBUILDER_H_ diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir index 04c850f..7e10ae1 100644 --- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir +++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir @@ -1,15 +1,14 @@ // RUN: mlir-opt %s -split-input-file --sparse-buffer-rewrite --canonicalize --cse | FileCheck %s // CHECK-LABEL: func @sparse_push_back( -// CHECK-SAME: %[[A:.*]]: memref, +// CHECK-SAME: %[[A:.*]]: index, // CHECK-SAME: %[[B:.*]]: memref, -// CHECK-SAME: %[[C:.*]]: f64) -> memref { +// CHECK-SAME: %[[C:.*]]: f64) -> (memref, index) { // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[P1:.*]] = memref.dim %[[B]], %[[C0]] -// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]] -// CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[C1]] : index +// CHECK: %[[S2:.*]] = arith.addi %[[A]], %[[C1]] : index // CHECK: %[[T:.*]] = arith.cmpi ugt, %[[S2]], %[[P1]] // CHECK: %[[M:.*]] = scf.if 
%[[T]] -> (memref) { // CHECK: %[[P2:.*]] = arith.muli %[[P1]], %[[C2]] @@ -18,25 +17,23 @@ // CHECK: } else { // CHECK: scf.yield %[[B]] : memref // CHECK: } -// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[S1]]] -// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]] -// CHECK: return %[[M]] : memref -func.func @sparse_push_back(%arg0: memref, %arg1: memref, %arg2: f64) -> memref { - %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref, memref, f64 - return %0 : memref +// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[A]]] +// CHECK: return %[[M]], %[[S2]] +func.func @sparse_push_back(%arg0: index, %arg1: memref, %arg2: f64) -> (memref, index) { + %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2 : index, memref, f64 + return %0#0, %0#1 : memref, index } // ----- // CHECK-LABEL: func @sparse_push_back_n( -// CHECK-SAME: %[[A:.*]]: memref, +// CHECK-SAME: %[[S1:.*]]: index, // CHECK-SAME: %[[B:.*]]: memref, // CHECK-SAME: %[[C:.*]]: f64, -// CHECK-SAME: %[[D:.*]]: index) -> memref { +// CHECK-SAME: %[[D:.*]]: index) -> (memref, index) { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[P1:.*]] = memref.dim %[[B]], %[[C0]] -// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]] // CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[D]] : index // CHECK: %[[T:.*]] = arith.cmpi ugt, %[[S2]], %[[P1]] // CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref) { @@ -55,29 +52,25 @@ func.func @sparse_push_back(%arg0: memref, %arg1: memref, %arg2: // CHECK: } // CHECK: %[[S:.*]] = memref.subview %[[M]]{{\[}}%[[S1]]] {{\[}}%[[D]]] [1] // CHECK: linalg.fill ins(%[[C]] : f64) outs(%[[S]] -// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]] -// CHECK: return %[[M]] : memref -func.func @sparse_push_back_n(%arg0: memref, %arg1: memref, %arg2: f64, %arg3: index) -> memref { - %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 {idx = 2 : index} : memref, memref, f64, index - return %0 : memref +// CHECK: 
return %[[M]], %[[S2]] : memref, index +func.func @sparse_push_back_n(%arg0: index, %arg1: memref, %arg2: f64, %arg3: index) -> (memref, index) { + %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 : index, memref, f64, index + return %0#0, %0#1 : memref, index } // ----- // CHECK-LABEL: func @sparse_push_back_inbound( -// CHECK-SAME: %[[A:.*]]: memref, +// CHECK-SAME: %[[S1:.*]]: index, // CHECK-SAME: %[[B:.*]]: memref, -// CHECK-SAME: %[[C:.*]]: f64) -> memref { +// CHECK-SAME: %[[C:.*]]: f64) -> (memref, index) { // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[S1:.*]] = memref.load %[[A]]{{\[}}%[[C2]]] // CHECK: %[[S2:.*]] = arith.addi %[[S1]], %[[C1]] // CHECK: memref.store %[[C]], %[[B]]{{\[}}%[[S1]]] -// CHECK: memref.store %[[S2]], %[[A]]{{\[}}%[[C2]]] -// CHECK: return %[[B]] : memref -func.func @sparse_push_back_inbound(%arg0: memref, %arg1: memref, %arg2: f64) -> memref { - %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref, memref, f64 - return %0 : memref +// CHECK: return %[[B]], %[[S2]] : memref, index +func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref, %arg2: f64) -> (memref, index) { + %0:2 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 : index, memref, f64 + return %0#0, %0#1 : memref, index } // ----- diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir index 0e3eb03..b60935c 100644 --- a/mlir/test/Dialect/SparseTensor/codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/codegen.mlir @@ -47,31 +47,28 @@ }> // CHECK-LABEL: func @sparse_nop( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, 
memref +// CHECK-SAME: %[[A0:.*]]: memref, +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier +// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : +// CHECK-SAME: memref, memref, memref, !sparse_tensor.storage_specifier func.func @sparse_nop(%arg0: tensor) -> tensor { return %arg0 : tensor } // CHECK-LABEL: func @sparse_nop_multi_ret( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref<1xindex>, -// CHECK-SAME: %[[A6:.*6]]: memref<3xindex>, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: memref, -// CHECK-SAME: %[[A9:.*9]]: memref) -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]] : -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref, -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A0:.*0]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: !sparse_tensor.storage_specifier +// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]] : +// CHECK-SAME: memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK-SAME: memref, memref, memref, !sparse_tensor.storage_specifier func.func @sparse_nop_multi_ret(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { @@ -79,20 +76,18 @@ func.func @sparse_nop_multi_ret(%arg0: tensor, } // CHECK-LABEL: func @sparse_nop_call( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// 
CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref<1xindex>, -// CHECK-SAME: %[[A6:.*6]]: memref<3xindex>, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: memref, -// CHECK-SAME: %[[A9:.*9]]: memref) -// CHECK: %[[T:.*]]:10 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]]) -// CHECK: return %[[T]]#0, %[[T]]#1, %[[T]]#2, %[[T]]#3, %[[T]]#4, %[[T]]#5, %[[T]]#6, %[[T]]#7, %[[T]]#8, %[[T]]#9 : -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref, -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A0:.*0]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: !sparse_tensor.storage_specifier +// CHECK: %[[T:.*]]:8 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]]) +// CHECK: return %[[T]]#0, %[[T]]#1, %[[T]]#2, %[[T]]#3, %[[T]]#4, %[[T]]#5, %[[T]]#6, %[[T]]#7 : +// CHECK-SAME: memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK-SAME: memref, memref, memref, !sparse_tensor.storage_specifier func.func @sparse_nop_call(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { @@ -103,68 +98,61 @@ func.func @sparse_nop_call(%arg0: tensor, } // CHECK-LABEL: func @sparse_nop_cast( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A0:.*]]: memref, +// CHECK-SAME: %[[A1:.*]]: memref, +// 
CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier +// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor { %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor return %0 : tensor } // CHECK-LABEL: func @sparse_nop_cast_3d( -// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref) -// CHECK: return %[[A0]], %[[A1]], %[[A2]] : -// CHECK-SAME: memref<3xindex>, memref<1xindex>, memref +// CHECK-SAME: %[[A0:.*]]: memref, +// CHECK-SAME: %[[A1:.*]]: !sparse_tensor.storage_specifier +// CHECK: return %[[A0]], %[[A1]] : +// CHECK-SAME: memref, !sparse_tensor.storage_specifier func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor { %0 = tensor.cast %arg0 : tensor<10x20x30xf32, #Dense3D> to tensor return %0 : tensor } // CHECK-LABEL: func @sparse_dense_2d( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref) +// CHECK-SAME: %[[A0:.*]]: memref, +// CHECK-SAME: %[[A1:.*]]: !sparse_tensor.storage_specifier // CHECK: return func.func @sparse_dense_2d(%arg0: tensor) { return } // CHECK-LABEL: func @sparse_row( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) +// CHECK-SAME: %[[A0:.*]]: memref, +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier // CHECK: return func.func @sparse_row(%arg0: tensor) { return } // CHECK-LABEL: func @sparse_csr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) +// 
CHECK-SAME: %[[A0:.*]]: memref, +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier // CHECK: return func.func @sparse_csr(%arg0: tensor) { return } // CHECK-LABEL: func @sparse_dcsr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>, +// CHECK-SAME: %[[A0:.*0]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref, // CHECK-SAME: %[[A2:.*2]]: memref, // CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref) +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier // CHECK: return func.func @sparse_dcsr(%arg0: tensor) { return @@ -175,9 +163,8 @@ func.func @sparse_dcsr(%arg0: tensor) { // fold using the original static dimension sizes. // // CHECK-LABEL: func @sparse_dense_3d( -// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref) +// CHECK-SAME: %[[A0:.*]]: memref, +// CHECK-SAME: %[[A1:.*]]: !sparse_tensor.storage_specifier // CHECK: %[[C:.*]] = arith.constant 20 : index // CHECK: return %[[C]] : index func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index { @@ -192,12 +179,11 @@ func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index { // since the latter honors the dimOrdering. 
// // CHECK-LABEL: func @sparse_dense_3d_dyn( -// CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref) -// CHECK: %[[C:.*]] = arith.constant 2 : index -// CHECK: %[[L:.*]] = memref.load %[[A0]][%[[C]]] : memref<3xindex> -// CHECK: return %[[L]] : index +// CHECK-SAME: %[[A0:.*]]: memref, +// CHECK-SAME: %[[A1:.*]]: !sparse_tensor.storage_specifier +// CHECK: %[[A2:.*]] = sparse_tensor.storage_specifier.get %[[A1]] dim_sz at 2 +// CHECK: %[[A3:.*]] = arith.index_cast %[[A2]] : i64 to index +// CHECK: return %[[A3]] : index func.func @sparse_dense_3d_dyn(%arg0: tensor) -> index { %c = arith.constant 1 : index %0 = tensor.dim %arg0, %c : tensor @@ -205,55 +191,51 @@ func.func @sparse_dense_3d_dyn(%arg0: tensor) -> index { } // CHECK-LABEL: func @sparse_pointers_dcsr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>, +// CHECK-SAME: %[[A0:.*0]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref, // CHECK-SAME: %[[A2:.*2]]: memref, // CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref) -// CHECK: return %[[A4]] : memref +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier +// CHECK: return %[[A2]] : memref func.func @sparse_pointers_dcsr(%arg0: tensor) -> memref { %0 = sparse_tensor.pointers %arg0 { dimension = 1 : index } : tensor to memref return %0 : memref } // CHECK-LABEL: func @sparse_indices_dcsr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>, +// CHECK-SAME: %[[A0:.*0]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref, // CHECK-SAME: %[[A2:.*2]]: memref, // CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref) -// CHECK: return %[[A5]] : memref +// CHECK-SAME: %[[A4:.*4]]: memref, +// 
CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier +// CHECK: return %[[A3]] : memref func.func @sparse_indices_dcsr(%arg0: tensor) -> memref { %0 = sparse_tensor.indices %arg0 { dimension = 1 : index } : tensor to memref return %0 : memref } // CHECK-LABEL: func @sparse_values_dcsr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>, +// CHECK-SAME: %[[A0:.*0]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref, // CHECK-SAME: %[[A2:.*2]]: memref, // CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref) -// CHECK: return %[[A6]] : memref +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier +// CHECK: return %[[A4]] : memref func.func @sparse_values_dcsr(%arg0: tensor) -> memref { %0 = sparse_tensor.values %arg0 : tensor to memref return %0 : memref } // CHECK-LABEL: func @sparse_noe( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[NOE:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex> +// CHECK-SAME: %[[A0:.*]]: memref, +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier +// CHECK: %[[A4:.*]] = sparse_tensor.storage_specifier.get %[[A3]] val_mem_sz +// CHECK: %[[NOE:.*]] = arith.index_cast %[[A4]] : i64 to index // CHECK: return %[[NOE]] : index func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index { %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector> @@ -261,70 +243,66 @@ func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index { } // CHECK-LABEL: func @sparse_dealloc_csr( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: 
%[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) -// CHECK: memref.dealloc %[[A0]] : memref<2xindex> -// CHECK: memref.dealloc %[[A1]] : memref<3xindex> -// CHECK: memref.dealloc %[[A2]] : memref -// CHECK: memref.dealloc %[[A3]] : memref -// CHECK: memref.dealloc %[[A4]] : memref +// CHECK-SAME: %[[A0:.*]]: memref, +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier +// CHECK: memref.dealloc %[[A0]] : memref +// CHECK: memref.dealloc %[[A1]] : memref +// CHECK: memref.dealloc %[[A2]] : memref // CHECK: return func.func @sparse_dealloc_csr(%arg0: tensor) { bufferization.dealloc_tensor %arg0 : tensor return } -// CHECK-LABEL: func @sparse_alloc_csc( -// CHECK-SAME: %[[A:.*]]: index) -> -// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref, memref, memref -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index -// CHECK: %[[T0:.*]] = memref.alloc() : memref<2xindex> -// CHECK: %[[T1:.*]] = memref.alloc() : memref<3xindex> -// CHECK: %[[T2:.*]] = memref.alloc() : memref<16xindex> -// CHECK: %[[T3:.*]] = memref.cast %[[T2]] : memref<16xindex> to memref -// CHECK: %[[T4:.*]] = memref.alloc() : memref<16xindex> -// CHECK: %[[T5:.*]] = memref.cast %[[T4]] : memref<16xindex> to memref -// CHECK: %[[T6:.*]] = memref.alloc() : memref<16xf64> -// CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<16xf64> to memref -// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>) -// CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<2xindex> -// CHECK: memref.store %[[C10]], %[[T0]][%[[C1]]] : memref<2xindex> -// CHECK: %[[P0:.*]] = sparse_tensor.push_back %[[T1]], %[[T3]] -// CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[T1]], %[[P0]] -// CHECK: return %[[T0]], %[[T1]], %[[P1]], %[[T5]], 
%[[T7]] : -// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK-LABEL: func.func @sparse_alloc_csc( +// CHECK-SAME: %[[A0:.*]]: index) -> (memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: %[[A1:.*]] = arith.constant 10 : i64 +// CHECK: %[[A2:.*]] = arith.constant 0 : index +// CHECK: %[[A3:.*]] = memref.alloc() : memref<16xindex> +// CHECK: %[[A4:.*]] = memref.cast %[[A3]] : memref<16xindex> to memref +// CHECK: %[[A5:.*]] = memref.alloc() : memref<16xindex> +// CHECK: %[[A6:.*]] = memref.cast %[[A5]] : memref<16xindex> to memref +// CHECK: %[[A7:.*]] = memref.alloc() : memref<16xf64> +// CHECK: %[[A8:.*]] = memref.cast %[[A7]] : memref<16xf64> to memref +// CHECK: %[[A9:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier +// CHECK: %[[A10:.*]] = arith.index_cast %[[A0]] : index to i64 +// CHECK: %[[A11:.*]] = sparse_tensor.storage_specifier.set %[[A9]] dim_sz at 0 with %[[A10]] : i64, !sparse_tensor.storage_specifier +// CHECK: %[[A12:.*]] = sparse_tensor.storage_specifier.set %[[A11]] dim_sz at 1 with %[[A1]] : i64, !sparse_tensor.storage_specifier +// CHECK: %[[A13:.*]] = sparse_tensor.storage_specifier.get %[[A12]] ptr_mem_sz at 1 : !sparse_tensor.storage_specifier +// CHECK: %[[A14:.*]] = arith.index_cast %[[A13]] : i64 to index +// CHECK: %[[A15:.*]], %[[A16:.*]] = sparse_tensor.push_back %[[A14]], %[[A4]], %[[A2]] : index, memref, index +// CHECK: %[[A17:.*]] = arith.index_cast %[[A16]] : index to i64 +// CHECK: %[[A18:.*]] = sparse_tensor.storage_specifier.set %[[A12]] ptr_mem_sz at 1 with %[[A17]] : i64, !sparse_tensor.storage_specifier +// CHECK: %[[A23:.*]], %[[A24:.*]] = sparse_tensor.push_back %[[A16]], %[[A15]], %[[A2]], %[[A0]] : index, memref, index, index +// CHECK: %[[A25:.*]] = arith.index_cast %[[A24]] : index to i64 +// CHECK: %[[A26:.*]] = sparse_tensor.storage_specifier.set %[[A18]] ptr_mem_sz at 1 with %[[A25]] : i64, !sparse_tensor.storage_specifier +// CHECK: return 
%[[A23]], %[[A6]], %[[A8]], %[[A26]] : memref, memref, memref, !sparse_tensor.storage_specifier func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> { %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC> %1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC> return %1 : tensor<10x?xf64, #CSC> } -// CHECK-LABEL: func @sparse_alloc_3d() -> -// CHECK-SAME: memref<3xindex>, memref<1xindex>, memref -// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index -// CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index -// CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index -// CHECK-DAG: %[[C6000:.*]] = arith.constant 6000 : index -// CHECK: %[[A0:.*]] = memref.alloc() : memref<3xindex> -// CHECK: %[[A1:.*]] = memref.alloc() : memref<1xindex> -// CHECK: %[[AV:.*]] = memref.alloc() : memref<16xf64> -// CHECK: %[[A2:.*]] = memref.cast %[[AV]] : memref<16xf64> to memref -// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[A1]] : memref<1xindex>) -// CHECK: memref.store %[[C30]], %[[A0]][%[[C0]]] : memref<3xindex> -// CHECK: memref.store %[[C10]], %[[A0]][%[[C1]]] : memref<3xindex> -// CHECK: memref.store %[[C20]], %[[A0]][%[[C2]]] : memref<3xindex> -// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[F0]], %[[C6000]] -// CHECK: return %[[A0]], %[[A1]], %[[P]] : -// CHECK-SAME: memref<3xindex>, memref<1xindex>, memref +// CHECK-LABEL: func.func @sparse_alloc_3d() -> (memref, !sparse_tensor.storage_specifier +// CHECK: %[[A0:.*]] = arith.constant 6000 : index +// CHECK: %[[A1:.*]] = arith.constant 20 : i64 +// CHECK: %[[A2:.*]] = arith.constant 10 : i64 +// CHECK: %[[A3:.*]] = arith.constant 30 : i64 +// CHECK: %[[A4:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK: %[[A5:.*]] = memref.alloc() : memref<16xf64> +// CHECK: %[[A6:.*]] 
= memref.cast %[[A5]] : memref<16xf64> to memref +// CHECK: %[[A7:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier +// CHECK: %[[A8:.*]] = sparse_tensor.storage_specifier.set %[[A7]] dim_sz at 0 with %[[A3]] : i64, !sparse_tensor.storage_specifier +// CHECK: %[[A9:.*]] = sparse_tensor.storage_specifier.set %[[A8]] dim_sz at 1 with %[[A2]] : i64, !sparse_tensor.storage_specifier +// CHECK: %[[A10:.*]] = sparse_tensor.storage_specifier.set %[[A9]] dim_sz at 2 with %[[A1]] : i64, !sparse_tensor.storage_specifier +// CHECK: %[[A11:.*]] = sparse_tensor.storage_specifier.get %[[A10]] val_mem_sz : !sparse_tensor.storage_specifier +// CHECK: %[[A12:.*]] = arith.index_cast %[[A11]] : i64 to index +// CHECK: %[[A13:.*]], %[[A14:.*]] = sparse_tensor.push_back %[[A12]], %[[A6]], %[[A4]], %[[A0]] : index, memref, f64, index +// CHECK: %[[A15:.*]] = arith.index_cast %[[A14]] : index to i64 +// CHECK: %[[A16:.*]] = sparse_tensor.storage_specifier.set %[[A10]] val_mem_sz with %[[A15]] : i64, !sparse_tensor.storage_specifier +// CHECK: return %[[A13]], %[[A16]] : memref, !sparse_tensor.storage_specifier func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> { %0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D> %1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D> @@ -364,13 +342,9 @@ func.func @sparse_expansion2() -> memref { // CHECK-LABEL: func.func @sparse_expansion3( // CHECK-SAME: %[[D0:.*]]: index, // CHECK-SAME: %{{.*}}: index) -> memref { -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[S0:.*]] = memref.alloc() : memref<2xindex> -// CHECK: memref.store %[[D0]], %[[S0]]{{\[}}%[[C1]]] : memref<2xindex> -// CHECK: %[[D1:.*]] = memref.load %[[S0]]{{\[}}%[[C1]]] : memref<2xindex> -// CHECK: %[[V:.*]] = memref.alloc(%[[D1]]) : memref -// CHECK: %[[B:.*]] = memref.alloc(%[[D1]]) : memref -// CHECK: %[[D:.*]] = memref.alloc(%[[D1]]) : memref +// CHECK: %[[V:.*]] = memref.alloc(%[[D0]]) : memref +// CHECK: 
%[[B:.*]] = memref.alloc(%[[D0]]) : memref +// CHECK: %[[D:.*]] = memref.alloc(%[[D0]]) : memref // CHECK: linalg.fill ins(%{{.*}} : f64) outs(%[[V]] : memref) // CHECK: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref) // CHECK: return %[[D]] : memref @@ -382,45 +356,39 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref { } // CHECK-LABEL: func.func private @_insert_C_100_f64_0_0( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: f64) -// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]] +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[A5:.*4]]: index, +// CHECK-SAME: %[[A6:.*5]]: f64) // -// CHECK-LABEL: func @sparse_compression_1d( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-LABEL: func.func @sparse_compression_1d( +// CHECK-SAME: %[[A0:.*0]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: !sparse_tensor.storage_specifier // CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: index) -// CHECK-DAG: %[[B0:.*]] = arith.constant false -// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : 
memref -// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] -// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<1xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref -// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref -// CHECK: %[[C:.*]]:5 = func.call @_insert_C_100_f64_0_0(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[INDEX]], %[[VAL]]) -// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref -// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref -// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: index) -> (memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK-DAG: %[[A8:.*]] = arith.constant false +// CHECK-DAG: %[[A9:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[A10:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index +// CHECK: sparse_tensor.sort %[[A7]], %[[A6]] : memref +// CHECK: %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]]) +// CHECK: %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref +// CHECK: %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref +// CHECK: %[[A20:.*]]:4 = func.call @_insert_C_100_f64_0_0(%[[A14]], %[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A19]]) +// CHECK: memref.store %[[A9]], %[[A4]]{{\[}}%[[A18]]] : memref +// CHECK: memref.store %[[A8]], %[[A5]]{{\[}}%[[A18]]] : memref +// CHECK: scf.yield %[[A20]]#0, %[[A20]]#1, %[[A20]]#2, %[[A20]]#3 // CHECK: } -// CHECK: memref.dealloc %[[A5]] : memref -// CHECK: memref.dealloc %[[A6]] : memref -// CHECK: 
memref.dealloc %[[A7]] : memref -// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK: memref.dealloc %[[A4]] : memref +// CHECK: memref.dealloc %[[A5]] : memref +// CHECK: memref.dealloc %[[A6]] : memref +// CHECK: return %[[A21:.*]]#0, %[[A21]]#1, %[[A21]]#2, %[[A21]]#3 : memref, memref, memref, !sparse_tensor.storage_specifier func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>, %values: memref, %filled: memref, @@ -433,47 +401,54 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>, } // CHECK-LABEL: func.func private @_insert_D_C_8_8_f64_64_32( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: index, -// CHECK-SAME: %[[A7:.*7]]: f64) -// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]] +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[A5:.*4]]: index, +// CHECK-SAME: %[[A6:.*5]]: index, +// CHECK-SAME: %[[A7:.*6]]: f64) // -// CHECK-LABEL: func @sparse_compression( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: index, -// CHECK-SAME: %[[A9:.*9]]: index) -// CHECK-DAG: %[[B0:.*]] = arith.constant false -// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 -// 
CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref -// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] -// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref -// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref -// CHECK: %[[C:.*]]:5 = func.call @_insert_D_C_8_8_f64_64_32(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]]) -// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref -// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref -// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<2xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } -// CHECK: memref.dealloc %[[A5]] : memref -// CHECK: memref.dealloc %[[A6]] : memref -// CHECK: memref.dealloc %[[A7]] : memref -// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 -// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK-LABEL: func.func @sparse_compression( +// CHECK-SAME: %[[A0:.*0]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: index, +// CHECK-SAME: %[[A8:.*8]]: index) -> (memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: %[[A9:.*]] = arith.constant 0 : i32 +// CHECK: %[[A10:.*]] = arith.constant false +// CHECK: %[[A11:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK: %[[A12:.*]] = arith.constant 1 : index +// CHECK: %[[A13:.*]] = arith.constant 0 : index +// CHECK: sparse_tensor.sort %[[A7]], 
%[[A6]] : memref +// CHECK: %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) -> (memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref +// CHECK: %[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref +// CHECK: %[[A22:.*]]:4 = func.call @_insert_D_C_8_8_f64_64_32(%[[A16]], %[[A17]], %[[A18]], %[[A19]], %[[A8]], %[[A20]], %[[A21]]) : (memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: memref.store %[[A11]], %[[A4]]{{\[}}%[[A20]]] : memref +// CHECK: memref.store %[[A10]], %[[A5]]{{\[}}%[[A20]]] : memref +// CHECK: scf.yield %[[A22]]#0, %[[A22]]#1, %[[A22]]#2, %[[A22]]#3 : memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: } +// CHECK: memref.dealloc %[[A4]] : memref +// CHECK: memref.dealloc %[[A5]] : memref +// CHECK: memref.dealloc %[[A6]] : memref +// CHECK: %[[A23:.*]] = sparse_tensor.storage_specifier.get %[[A24:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier +// CHECK: %[[A25:.*]] = arith.index_cast %[[A23]] : i64 to index +// CHECK: %[[A26:.*]] = memref.load %[[A24]]#0{{\[}}%[[A13]]] : memref +// CHECK: %[[A27:.*]] = scf.for %[[A28:.*]] = %[[A12]] to %[[A25]] step %[[A12]] iter_args(%[[A29:.*]] = %[[A26]]) -> (i32) { +// CHECK: %[[A30:.*]] = memref.load %[[A24]]#0{{\[}}%[[A28]]] : memref +// CHECK: %[[A31:.*]] = arith.cmpi eq, %[[A30]], %[[A9]] : i32 +// CHECK: %[[A32:.*]] = arith.select %[[A31]], %[[A29]], %[[A30]] : i32 +// CHECK: scf.if %[[A31]] { +// CHECK: memref.store %[[A29]], %[[A24]]#0{{\[}}%[[A28]]] : memref +// CHECK: } +// CHECK: scf.yield %[[A32]] : i32 +// CHECK: } +// CHECK: return %[[A24]]#0, %[[A24]]#1, %[[A24]]#2, %[[A24]]#3 : memref, memref, memref, !sparse_tensor.storage_specifier func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>, %values: memref, %filled: memref, @@ -487,47 +462,52 @@ 
func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>, } // CHECK-LABEL: func.func private @_insert_D_C_8_8_f64_0_0( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: index, -// CHECK-SAME: %[[A7:.*7]]: f64) -// CHECK: %[[PV:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A7]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[PV]] +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[A5:.*4]]: index, +// CHECK-SAME: %[[A6:.*5]]: index, +// CHECK-SAME: %[[A7:.*6]]: f64) // -// CHECK-LABEL: func @sparse_compression_unordered( -// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: index, -// CHECK-SAME: %[[A9:.*9]]: index) -// CHECK-DAG: %[[B0:.*]] = arith.constant false -// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-NOT: sparse_tensor.sort -// CHECK: %[[R:.*]]:5 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] -// CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref -// CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref 
-// CHECK: %[[C:.*]]:5 = func.call @_insert_D_C_8_8_f64_0_0(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]]) -// CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref -// CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref -// CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<2xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } -// CHECK: memref.dealloc %[[A5]] : memref -// CHECK: memref.dealloc %[[A6]] : memref -// CHECK: memref.dealloc %[[A7]] : memref -// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 -// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK-LABEL: func.func @sparse_compression_unordered( +// CHECK-SAME: %[[A0:.*0]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: index, +// CHECK-SAME: %[[A8:.*8]]: index) -> (memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: %[[A9:.*]] = arith.constant false +// CHECK: %[[A10:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK: %[[A11:.*]] = arith.constant 0 : index +// CHECK: %[[A12:.*]] = arith.constant 1 : index +// CHECK: %[[A13:.*]]:4 = scf.for %[[A14:.*]] = %[[A11]] to %[[A7]] step %[[A12]] iter_args(%[[A15:.*]] = %[[A0]], %[[A16:.*]] = %[[A1]], %[[A17:.*]] = %[[A2]], %[[A18:.*]] = %[[A3]]) -> (memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: %[[A19:.*]] = memref.load %[[A6]]{{\[}}%[[A14]]] : memref +// CHECK: %[[A20:.*]] = memref.load %[[A4]]{{\[}}%[[A19]]] : memref +// CHECK: %[[A21:.*]]:4 = func.call @_insert_D_C_8_8_f64_0_0(%[[A15]], %[[A16]], %[[A17]], %[[A18]], %[[A8]], %[[A19]], %[[A20]]) : (memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: memref.store %[[A10]], %[[A4]]{{\[}}%[[A19]]] : memref +// 
CHECK: memref.store %[[A9]], %[[A5]]{{\[}}%[[A19]]] : memref +// CHECK: scf.yield %[[A21]]#0, %[[A21]]#1, %[[A21]]#2, %[[A21]]#3 : memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: } +// CHECK: memref.dealloc %[[A4]] : memref +// CHECK: memref.dealloc %[[A5]] : memref +// CHECK: memref.dealloc %[[A6]] : memref +// CHECK: %[[A22:.*]] = sparse_tensor.storage_specifier.get %[[A23:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier +// CHECK: %[[A24:.*]] = arith.index_cast %[[A22]] : i64 to index +// CHECK: %[[A25:.*]] = memref.load %[[A23]]#0{{\[}}%[[A11]]] : memref +// CHECK: %[[A26:.*]] = scf.for %[[A27:.*]] = %[[A12]] to %[[A24]] step %[[A12]] iter_args(%[[A28:.*]] = %[[A25]]) -> (index) { +// CHECK: %[[A29:.*]] = memref.load %[[A23]]#0{{\[}}%[[A27]]] : memref +// CHECK: %[[A30:.*]] = arith.cmpi eq, %[[A29]], %[[A11]] : index +// CHECK: %[[A31:.*]] = arith.select %[[A30]], %[[A28]], %[[A29]] : index +// CHECK: scf.if %[[A30]] { +// CHECK: memref.store %[[A28]], %[[A23]]#0{{\[}}%[[A27]]] : memref +// CHECK: } +// CHECK: scf.yield %[[A31]] : index +// CHECK: } +// CHECK: return %[[A23]]#0, %[[A23]]#1, %[[A23]]#2, %[[A23]]#3 : memref, memref, memref, !sparse_tensor.storage_specifier func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>, %values: memref, %filled: memref, @@ -541,26 +521,22 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>, } // CHECK-LABEL: func.func private @_insert_C_128_f64_0_0( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: f64) -// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[P]] : -// CHECK: func @sparse_insert( -// CHECK-SAME: 
%[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: f64) -// CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]]) -// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[A5:.*4]]: index, +// CHECK-SAME: %[[A6:.*5]]: f64) +// +// CHECK-LABEL: func @sparse_insert( +// CHECK-SAME: %[[A1:.*0]]: memref, +// CHECK-SAME: %[[A2:.*1]]: memref, +// CHECK-SAME: %[[A3:.*2]]: memref, +// CHECK-SAME: %[[A4:.*3]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[A5:.*4]]: index, +// CHECK-SAME: %[[A6:.*5]]: f64) +// CHECK: %[[R:.*]]:4 = call @_insert_C_128_f64_0_0(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]]) +// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3 func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> { %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV> %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV> @@ -568,26 +544,22 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) } // CHECK-LABEL: func.func private @_insert_C_128_f64_64_32( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: f64) -// CHECK: %[[P:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] {idx = 2 : index} : memref<3xindex>, memref, 
f64 -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %{{.*}}, %[[P]] : -// CHECK: func @sparse_insert_typed( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: f64) -// CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_64_32(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]]) -// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: memref, +// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[A5:.*]]: index, +// CHECK-SAME: %[[A6:.*]]: f64) +// +// CHECK-LABEL: func @sparse_insert_typed( +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: memref, +// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[A5:.*]]: index, +// CHECK-SAME: %[[A6:.*]]: f64) +// CHECK: %[[R:.*]]:4 = call @_insert_C_128_f64_64_32(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]]) +// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3 func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> { %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector> %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SparseVector> @@ -595,14 +567,13 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind } // CHECK-LABEL: func.func @sparse_nop_convert( -// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) -// CHECK: return %[[A0]], 
%[[A1]], %[[A2]], %[[A3]], %[[A4]] : -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A1:.*]]: memref, +// CHECK-SAME: %[[A2:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: memref, +// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier +// CHECK: return %[[A1]], %[[A2]], %[[A3]], %[[A4]] : +// CHECK-SAME: memref, memref, memref, !sparse_tensor.storage_specifier func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor { %0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor return %0 : tensor -} \ No newline at end of file +} diff --git a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir index 4599b23..33bbe6a 100644 --- a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir +++ b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir @@ -2,28 +2,32 @@ #SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> -// CHECK-LABEL: func @sparse_alloc_sparse_vector( -// CHECK-SAME: %[[A:.*]]: index) -> -// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK: %[[T0:.*]] = memref.alloc() : memref<1xindex> -// CHECK: %[[T1:.*]] = memref.alloc() : memref<3xindex> -// CHECK: %[[T2:.*]] = memref.alloc() : memref<16xindex> -// CHECK: %[[T3:.*]] = memref.cast %[[T2]] : memref<16xindex> to memref -// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T2]] : memref<16xindex>) -// CHECK: %[[T4:.*]] = memref.alloc() : memref<16xindex> -// CHECK: %[[T5:.*]] = memref.cast %[[T4]] : memref<16xindex> to memref -// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T4]] : memref<16xindex>) -// CHECK: %[[T6:.*]] = memref.alloc() : memref<16xf64> -// CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<16xf64> to 
memref -// CHECK: linalg.fill ins(%[[F0]] : f64) outs(%[[T6]] : memref<16xf64>) -// CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>) -// CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<1xindex> -// CHECK: %[[P0:.*]] = sparse_tensor.push_back %[[T1]], %[[T3]] -// CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[T1]], %[[P0]] -// CHECK: return %[[T0]], %[[T1]], %[[P1]], %[[T5]], %[[T7]] : +// CHECK-LABEL: func.func @sparse_alloc_sparse_vector( +// CHECK-SAME: %[[VAL_0:.*]]: index) -> (memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = memref.alloc() : memref<16xindex> +// CHECK: %[[VAL_5:.*]] = memref.cast %[[VAL_4]] : memref<16xindex> to memref +// CHECK: linalg.fill ins(%[[VAL_3]] : index) outs(%[[VAL_4]] : memref<16xindex>) +// CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<16xindex> +// CHECK: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<16xindex> to memref +// CHECK: linalg.fill ins(%[[VAL_3]] : index) outs(%[[VAL_6]] : memref<16xindex>) +// CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<16xf64> +// CHECK: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<16xf64> to memref +// CHECK: linalg.fill ins(%[[VAL_2]] : f64) outs(%[[VAL_8]] : memref<16xf64>) +// CHECK: %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier +// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_0]] : index to i64 +// CHECK: %[[VAL_12:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] dim_sz at 0 with %[[VAL_11]] : i64, !sparse_tensor.storage_specifier +// CHECK: %[[VAL_13:.*]] = sparse_tensor.storage_specifier.get %[[VAL_12]] ptr_mem_sz at 0 : !sparse_tensor.storage_specifier +// CHECK: %[[VAL_14:.*]] = arith.index_cast %[[VAL_13]] : i64 to index +// CHECK: %[[VAL_15:.*]], %[[VAL_16:.*]] = sparse_tensor.push_back %[[VAL_14]], 
%[[VAL_5]], %[[VAL_3]] : index, memref, index +// CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_16]] : index to i64 +// CHECK: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_12]] ptr_mem_sz at 0 with %[[VAL_17]] : i64, !sparse_tensor.storage_specifier +// CHECK: %[[VAL_19:.*]], %[[VAL_20:.*]] = sparse_tensor.push_back %[[VAL_16]], %[[VAL_15]], %[[VAL_3]], %[[VAL_1]] : index, memref, index, index +// CHECK: %[[VAL_21:.*]] = arith.index_cast %[[VAL_20]] : index to i64 +// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] ptr_mem_sz at 0 with %[[VAL_21]] : i64, !sparse_tensor.storage_specifier +// CHECK: return %[[VAL_19]], %[[VAL_7]], %[[VAL_9]], %[[VAL_22]] : memref, memref, memref, !sparse_tensor.storage_specifier func.func @sparse_alloc_sparse_vector(%arg0: index) -> tensor { %0 = bufferization.alloc_tensor(%arg0) : tensor %1 = sparse_tensor.load %0 : tensor diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir index e94be7e..4482cf2 100644 --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -192,19 +192,19 @@ func.func @sparse_wrong_arity_insert(%arg0: tensor<128x64xf64, #CSR>, %arg1: ind // ----- -func.func @sparse_push_back(%arg0: memref, %arg1: memref, %arg2: f32) -> memref { +func.func @sparse_push_back(%arg0: index, %arg1: memref, %arg2: f32) -> (memref, index) { // expected-error@+1 {{'sparse_tensor.push_back' op failed to verify that value type matches element type of inBuffer}} - %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref, memref, f32 - return %0 : memref + %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2 : index, memref, f32 + return %0#0, %0#1 : memref, index } // ----- -func.func @sparse_push_back_n(%arg0: memref, %arg1: memref, %arg2: f32) -> memref { +func.func @sparse_push_back_n(%arg0: index, %arg1: memref, %arg2: f32) -> (memref, index) { %c0 = arith.constant 0: index // 
expected-error@+1 {{'sparse_tensor.push_back' op n must be not less than 1}} - %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %c0 {idx = 2 : index} : memref, memref, f32, index - return %0 : memref + %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2, %c0 : index, memref, f32, index + return %0#0, %0#1 : memref, index } // ----- diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir index 48b4509..67fefa2 100644 --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -201,41 +201,41 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %a // ----- // CHECK-LABEL: func @sparse_push_back( -// CHECK-SAME: %[[A:.*]]: memref, +// CHECK-SAME: %[[A:.*]]: index, // CHECK-SAME: %[[B:.*]]: memref, -// CHECK-SAME: %[[C:.*]]: f64) -> memref { -// CHECK: %[[D:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref, memref, f64 +// CHECK-SAME: %[[C:.*]]: f64) -> (memref, index) { +// CHECK: %[[D:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]] : index, memref, f64 // CHECK: return %[[D]] -func.func @sparse_push_back(%arg0: memref, %arg1: memref, %arg2: f64) -> memref { - %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref, memref, f64 - return %0 : memref +func.func @sparse_push_back(%arg0: index, %arg1: memref, %arg2: f64) -> (memref, index) { + %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2 : index, memref, f64 + return %0#0, %0#1 : memref, index } // ----- // CHECK-LABEL: func @sparse_push_back_inbound( -// CHECK-SAME: %[[A:.*]]: memref, +// CHECK-SAME: %[[A:.*]]: index, // CHECK-SAME: %[[B:.*]]: memref, -// CHECK-SAME: %[[C:.*]]: f64) -> memref { -// CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref, memref, f64 +// CHECK-SAME: %[[C:.*]]: f64) -> (memref, index) { +// CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], 
%[[C]] : index, memref, f64 // CHECK: return %[[D]] -func.func @sparse_push_back_inbound(%arg0: memref, %arg1: memref, %arg2: f64) -> memref { - %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref, memref, f64 - return %0 : memref +func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref, %arg2: f64) -> (memref, index) { + %0:2 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 : index, memref, f64 + return %0#0, %0#1 : memref, index } // ----- // CHECK-LABEL: func @sparse_push_back_n( -// CHECK-SAME: %[[A:.*]]: memref, +// CHECK-SAME: %[[A:.*]]: index, // CHECK-SAME: %[[B:.*]]: memref, // CHECK-SAME: %[[C:.*]]: f64, -// CHECK-SAME: %[[D:.*]]: index) -> memref { -// CHECK: %[[E:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]], %[[D]] {idx = 2 : index} : memref, memref, f64, index +// CHECK-SAME: %[[D:.*]]: index) -> (memref, index) { +// CHECK: %[[E:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]], %[[D]] : index, memref, f64, index // CHECK: return %[[E]] -func.func @sparse_push_back_n(%arg0: memref, %arg1: memref, %arg2: f64, %arg3: index) -> memref { - %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 {idx = 2 : index} : memref, memref, f64, index - return %0 : memref +func.func @sparse_push_back_n(%arg0: index, %arg1: memref, %arg2: f64, %arg3: index) -> (memref, index) { + %0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 : index, memref, f64, index + return %0#0, %0#1 : memref, index } // ----- diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir index 0c3c9bb..6922201e 100644 --- a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir @@ -1,24 +1,23 @@ // RUN: mlir-opt %s -sparse-tensor-codegen -cse | FileCheck %s #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> -// CHECK-LABEL: func @for( -// CHECK-SAME: %[[DIM_SIZE:.*0]]: 
memref<1xindex>, -// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[POINTER:.*2]]: memref, -// CHECK-SAME: %[[INDICES:.*3]]: memref, -// CHECK-SAME: %[[VALUE:.*4]]: memref, -// CHECK-SAME: %[[LB:.*5]]: index, -// CHECK-SAME: %[[UB:.*6]]: index, -// CHECK-SAME: %[[STEP:.*7]]: index) -// CHECK: %[[OUT:.*]]:5 = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args( -// CHECK-SAME: %[[SIZE:.*]] = %[[DIM_SIZE]], -// CHECK-SAME: %[[MEM:.*]] = %[[MEM_SIZE]], -// CHECK-SAME: %[[PTR:.*]] = %[[POINTER]], -// CHECK-SAME: %[[IDX:.*]] = %[[INDICES]], -// CHECK-SAME: %[[VAL:.*]] = %[[VALUE]]) -// CHECK: scf.yield %[[SIZE]], %[[MEM]], %[[PTR]], %[[IDX]], %[[VAL]] : memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } -// CHECK: return %[[OUT]]#0, %[[OUT]]#1, %[[OUT]]#2, %[[OUT]]#3, %[[OUT]]#4 : memref<1xindex>, memref<3xindex>, memref, memref, memref + +// CHECK-LABEL: func.func @for( +// CHECK-SAME: %[[VAL_1:.*0]]: memref, +// CHECK-SAME: %[[VAL_2:.*1]]: memref, +// CHECK-SAME: %[[VAL_3:.*2]]: memref, +// CHECK-SAME: %[[VAL_4:.*3]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[VAL_5:.*4]]: index, +// CHECK-SAME: %[[VAL_6:.*5]]: index, +// CHECK-SAME: %[[VAL_7:.*6]]: index) -> (memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: %[[VAL_8:.*]]:4 = scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_6]] step %[[VAL_7]] iter_args( +// CHECK-SAME: %[[VAL_11:.*]] = %[[VAL_1]], +// CHECK-SAME: %[[VAL_12:.*]] = %[[VAL_2]], +// CHECK-SAME: %[[VAL_13:.*]] = %[[VAL_3]], +// CHECK-SAME: %[[VAL_14:.*]] = %[[VAL_4]]) +// CHECK: scf.yield %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]] : +// CHECK: } +// CHECK: return %[[VAL_8]]#0, %[[VAL_8]]#1, %[[VAL_8]]#2, %[[VAL_8]]#3 func.func @for(%in: tensor<1024xf32, #SparseVector>, %lb: index, %ub: index, %step: index) -> tensor<1024xf32, #SparseVector> { %1 = scf.for %i = %lb to %ub step %step iter_args(%vin = %in) @@ -28,26 +27,23 @@ func.func @for(%in: tensor<1024xf32, 
#SparseVector>, return %1 : tensor<1024xf32, #SparseVector> } - -// CHECK-LABEL: func @if( -// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[POINTER:.*2]]: memref, -// CHECK-SAME: %[[INDICES:.*3]]: memref, -// CHECK-SAME: %[[VALUE:.*4]]: memref, -// CHECK-SAME: %[[DIM_SIZE_1:.*5]]: memref<1xindex>, -// CHECK-SAME: %[[MEM_SIZE_1:.*6]]: memref<3xindex>, -// CHECK-SAME: %[[POINTER_1:.*7]]: memref, -// CHECK-SAME: %[[INDICES_1:.*8]]: memref, -// CHECK-SAME: %[[VALUE_1:.*9]]: memref, -// CHECK-SAME: %[[I1:.*10]]: i1) -> -// CHECK-SAME: (memref<1xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[SV:.*]]:5 = scf.if %[[I1]] -> (memref<1xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: scf.yield %[[DIM_SIZE]], %[[MEM_SIZE]], %[[POINTER]], %[[INDICES]], %[[VALUE]] : memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } else { -// CHECK: scf.yield %[[DIM_SIZE_1]], %[[MEM_SIZE_1]], %[[POINTER_1]], %[[INDICES_1]], %[[VALUE_1]] : memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } -// CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4 : memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-LABEL: func.func @if( +// CHECK-SAME: %[[VAL_1:.*0]]: memref, +// CHECK-SAME: %[[VAL_2:.*1]]: memref, +// CHECK-SAME: %[[VAL_3:.*2]]: memref, +// CHECK-SAME: %[[VAL_4:.*3]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[VAL_6:.*4]]: memref, +// CHECK-SAME: %[[VAL_7:.*5]]: memref, +// CHECK-SAME: %[[VAL_8:.*6]]: memref, +// CHECK-SAME: %[[VAL_9:.*7]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[VAL_10:.*]]: i1) +// CHECK: %[[VAL_11:.*]]:4 = scf.if %[[VAL_10]] +// CHECK: scf.yield %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] +// CHECK: } else { +// CHECK: scf.yield %[[VAL_6]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]] +// CHECK: } +// CHECK: return %[[VAL_11]]#0, %[[VAL_11]]#1, %[[VAL_11]]#2, %[[VAL_11]]#3 : +// CHECK-SAME: 
memref, memref, memref, !sparse_tensor.storage_specifier func.func @if(%t: tensor<1024xf32, #SparseVector>, %f: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024xf32, #SparseVector> { @@ -59,26 +55,28 @@ func.func @if(%t: tensor<1024xf32, #SparseVector>, return %1 : tensor<1024xf32, #SparseVector> } -// CHECK-LABEL: func @while( -// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[POINTER:.*2]]: memref, -// CHECK-SAME: %[[INDICES:.*3]]: memref, -// CHECK-SAME: %[[VALUE:.*4]]: memref, -// CHECK-SAME: %[[I1:.*5]]: i1) -> -// CHECK-SAME: (memref<1xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[SV:.*]]:5 = scf.while ( -// CHECK-SAME: %[[TMP_DIM:.*]] = %[[DIM_SIZE]], -// CHECK-SAME: %[[TMP_MEM:.*]] = %[[MEM_SIZE]], -// CHECK-SAME: %[[TMP_PTR:.*]] = %[[POINTER]], -// CHECK-SAME: %[[TMP_IND:.*]] = %[[INDICES]], -// CHECK-SAME: %[[TMP_VAL:.*]] = %[[VALUE]]) -// CHECK: scf.condition(%[[I1]]) %[[TMP_DIM]], %[[TMP_MEM]], %[[TMP_PTR]], %[[TMP_IND]], %[[TMP_VAL]] : memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } do { -// CHECK: ^bb0(%[[TMP_DIM]]: memref<1xindex>, %[[TMP_MEM]]: memref<3xindex>, %[[TMP_PTR]]: memref, %[[TMP_IND]]: memref, %[[TMP_VAL]]: memref): -// CHECK: scf.yield %[[TMP_DIM]], %[[TMP_MEM]], %[[TMP_PTR]], %[[TMP_IND]], %[[TMP_VAL]] : memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } -// CHECK: return %[[SV]]#0, %[[SV]]#1, %[[SV]]#2, %[[SV]]#3, %[[SV]]#4 : memref<1xindex>, memref<3xindex>, memref, memref, memref + +// CHECK-LABEL: func.func @while( +// CHECK-SAME: %[[VAL_1:.*0]]: memref, +// CHECK-SAME: %[[VAL_2:.*1]]: memref, +// CHECK-SAME: %[[VAL_3:.*2]]: memref, +// CHECK-SAME: %[[VAL_4:.*3]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[VAL_5:.*4]]: i1) +// CHECK: %[[VAL_6:.*]]:4 = scf.while ( +// CHECK-SAME: %[[VAL_8:.*]] = %[[VAL_1]], +// CHECK-SAME: %[[VAL_9:.*]] = %[[VAL_2]], +// CHECK-SAME: %[[VAL_10:.*]] = 
%[[VAL_3]], +// CHECK-SAME: %[[VAL_11:.*]] = %[[VAL_4]]) +// CHECK: scf.condition(%[[VAL_5]]) %[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]] +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_13:.*5]]: memref, +// CHECK-SAME: %[[VAL_14:.*6]]: memref, +// CHECK-SAME: %[[VAL_15:.*7]]: memref, +// CHECK-SAME: %[[VAL_16:.*8]]: !sparse_tensor.storage_specifier +// CHECK: scf.yield %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] +// CHECK: } +// CHECK: return %[[VAL_6]]#0, %[[VAL_6]]#1, %[[VAL_6]]#2, %[[VAL_6]]#3 : +// CHECK-SAME: memref, memref, memref, !sparse_tensor.storage_specifier func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024xf32, #SparseVector> { %0 = scf.while (%in = %arg0) : (tensor<1024xf32, #SparseVector>) -> tensor<1024xf32, #SparseVector> { scf.condition(%c) %in : tensor<1024xf32, #SparseVector> diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir index 3808a5b..90fa3a5 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir @@ -13,134 +13,145 @@ // Computes C = A x B with all matrices sparse (SpMSpM) in CSR. 
// // CHECK-LABEL: func.func private @_insert_D_C_4_4_f64_0_0( -// CHECK-SAME: %[[VAL_0:.*]]: memref<2xindex>, -// CHECK-SAME: %[[VAL_1:.*]]: memref<3xindex>, -// CHECK-SAME: %[[VAL_2:[^ ]+]]: memref, -// CHECK-SAME: %[[VAL_3:.*]]: memref, -// CHECK-SAME: %[[VAL_4:.*]]: memref, -// CHECK-SAME: %[[VAL_5:[^ ]+]]: index, -// CHECK-SAME: %[[VAL_6:.*]]: index, -// CHECK-SAME: %[[VAL_7:.*]]: f64) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant false -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_5]], %[[VAL_9]] : index -// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref -// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref -// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref<3xindex> -// CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_12]], %[[VAL_9]] : index -// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_12]] : index +// CHECK-SAME: %[[VAL_0:.*0]]: memref, +// CHECK-SAME: %[[VAL_1:.*1]]: memref, +// CHECK-SAME: %[[VAL_2:.*2]]: memref, +// CHECK-SAME: %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[VAL_4:.*4]]: index, +// CHECK-SAME: %[[VAL_5:.*5]]: index, +// CHECK-SAME: %[[VAL_6:.*6]]: f64) -> (memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: %[[VAL_7:.*]] = arith.constant false +// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_4]], %[[VAL_8]] : index +// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref +// CHECK: %[[VAL_12:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]] idx_mem_sz at 1 : !sparse_tensor.storage_specifier +// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : i64 to index +// CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_11]], %[[VAL_8]] : index +// CHECK: 
%[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_10]], %[[VAL_11]] : index // CHECK: %[[VAL_16:.*]] = scf.if %[[VAL_15]] -> (i1) { -// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_14]]] : memref -// CHECK: %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_6]] : index +// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_14]]] : memref +// CHECK: %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_5]] : index // CHECK: scf.yield %[[VAL_18]] : i1 // CHECK: } else { -// CHECK: memref.store %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref -// CHECK: scf.yield %[[VAL_8]] : i1 +// CHECK: memref.store %[[VAL_13]], %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref +// CHECK: scf.yield %[[VAL_7]] : i1 // CHECK: } -// CHECK: %[[VAL_19:.*]] = scf.if %[[VAL_20:.*]] -> (memref) { -// CHECK: scf.yield %[[VAL_3]] : memref +// CHECK: %[[VAL_19:.*]]:2 = scf.if %[[VAL_20:.*]] -> (memref, !sparse_tensor.storage_specifier +// CHECK: scf.yield %[[VAL_1]], %[[VAL_3]] : memref, !sparse_tensor.storage_specifier // CHECK: } else { -// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : index -// CHECK: memref.store %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref -// CHECK: %[[VAL_22:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_3]], %[[VAL_6]] {idx = 1 : index} : memref<3xindex>, memref, index -// CHECK: scf.yield %[[VAL_22]] : memref +// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_8]] : index +// CHECK: memref.store %[[VAL_21]], %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref +// CHECK: %[[VAL_22:.*]], %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_13]], %[[VAL_1]], %[[VAL_5]] : index, memref, index +// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : index to i64 +// CHECK: %[[VAL_25:.*]] = sparse_tensor.storage_specifier.set %[[VAL_3]] idx_mem_sz at 1 with %[[VAL_24]] : i64, !sparse_tensor.storage_specifier +// CHECK: scf.yield %[[VAL_22]], %[[VAL_25]] : memref, !sparse_tensor.storage_specifier // CHECK: } -// CHECK: %[[VAL_23:.*]] = sparse_tensor.push_back 
%[[VAL_1]], %[[VAL_4]], %[[VAL_7]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_24:.*]], %[[VAL_23]] : memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK: %[[VAL_26:.*]] = sparse_tensor.storage_specifier.get %[[VAL_27:.*]]#1 val_mem_sz : !sparse_tensor.storage_specifier +// CHECK: %[[VAL_28:.*]] = arith.index_cast %[[VAL_26]] : i64 to index +// CHECK: %[[VAL_29:.*]], %[[VAL_30:.*]] = sparse_tensor.push_back %[[VAL_28]], %[[VAL_2]], %[[VAL_6]] : index, memref, f64 +// CHECK: %[[VAL_31:.*]] = arith.index_cast %[[VAL_30]] : index to i64 +// CHECK: %[[VAL_32:.*]] = sparse_tensor.storage_specifier.set %[[VAL_27]]#1 val_mem_sz with %[[VAL_31]] : i64, !sparse_tensor.storage_specifier +// CHECK: return %[[VAL_0]], %[[VAL_27]]#0, %[[VAL_29]], %[[VAL_32]] : memref, memref, memref, !sparse_tensor.storage_specifier // CHECK: } // CHECK-LABEL: func.func @matmul( -// CHECK-SAME: %[[VAL_0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[VAL_1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[VAL_2:.*2]]: memref, -// CHECK-SAME: %[[VAL_3:.*3]]: memref, -// CHECK-SAME: %[[VAL_4:.*4]]: memref, -// CHECK-SAME: %[[VAL_5:.*5]]: memref<2xindex>, -// CHECK-SAME: %[[VAL_6:.*6]]: memref<3xindex>, -// CHECK-SAME: %[[VAL_7:.*7]]: memref, -// CHECK-SAME: %[[VAL_8:.*8]]: memref, -// CHECK-SAME: %[[VAL_9:.*9]]: memref) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_14:.*]] = arith.constant false -// CHECK-DAG: %[[VAL_15:.*]] = arith.constant true -// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<2xindex> -// CHECK: %[[VAL_17:.*]] = memref.alloc() : memref<3xindex> -// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<16xindex> -// CHECK: %[[VAL_19:.*]] = 
memref.cast %[[VAL_18]] : memref<16xindex> to memref -// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<16xindex> -// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xindex> to memref -// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<16xf64> -// CHECK: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<16xf64> to memref -// CHECK: linalg.fill ins(%[[VAL_12]] : index) outs(%[[VAL_17]] : memref<3xindex>) -// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_12]]] : memref<2xindex> -// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_13]]] : memref<2xindex> -// CHECK: %[[VAL_24:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_19]], %[[VAL_12]] {idx = 0 : index} : memref<3xindex>, memref, index -// CHECK: %[[VAL_25:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_24]], %[[VAL_12]], %[[VAL_10]] {idx = 0 : index} : memref<3xindex>, memref, index, index -// CHECK: %[[VAL_26:.*]] = memref.alloc() : memref<4xf64> -// CHECK: %[[VAL_27:.*]] = memref.alloc() : memref<4xi1> -// CHECK: %[[VAL_28:.*]] = memref.alloc() : memref<4xindex> -// CHECK: %[[VAL_29:.*]] = memref.cast %[[VAL_28]] : memref<4xindex> to memref -// CHECK: linalg.fill ins(%[[VAL_11]] : f64) outs(%[[VAL_26]] : memref<4xf64>) -// CHECK: linalg.fill ins(%[[VAL_14]] : i1) outs(%[[VAL_27]] : memref<4xi1>) -// CHECK: %[[VAL_30:.*]]:5 = scf.for %[[VAL_31:.*]] = %[[VAL_12]] to %[[VAL_10]] step %[[VAL_13]] iter_args(%[[VAL_32:.*]] = %[[VAL_16]], %[[VAL_33:.*]] = %[[VAL_17]], %[[VAL_34:.*]] = %[[VAL_25]], %[[VAL_35:.*]] = %[[VAL_21]], %[[VAL_36:.*]] = %[[VAL_23]]) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_31]]] : memref -// CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_31]], %[[VAL_13]] : index -// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_38]]] : memref -// CHECK: %[[VAL_40:.*]] = scf.for %[[VAL_41:.*]] = %[[VAL_37]] to %[[VAL_39]] step %[[VAL_13]] iter_args(%[[VAL_42:.*]] = %[[VAL_12]]) -> 
(index) { -// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_41]]] : memref -// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_41]]] : memref -// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_43]]] : memref -// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_43]], %[[VAL_13]] : index -// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_46]]] : memref -// CHECK: %[[VAL_48:.*]] = scf.for %[[VAL_49:.*]] = %[[VAL_45]] to %[[VAL_47]] step %[[VAL_13]] iter_args(%[[VAL_50:.*]] = %[[VAL_42]]) -> (index) { -// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_49]]] : memref -// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64> -// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_49]]] : memref -// CHECK: %[[VAL_54:.*]] = arith.mulf %[[VAL_44]], %[[VAL_53]] : f64 -// CHECK: %[[VAL_55:.*]] = arith.addf %[[VAL_52]], %[[VAL_54]] : f64 -// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1> -// CHECK: %[[VAL_57:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_14]] : i1 -// CHECK: %[[VAL_58:.*]] = scf.if %[[VAL_57]] -> (index) { -// CHECK: memref.store %[[VAL_15]], %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1> -// CHECK: memref.store %[[VAL_51]], %[[VAL_28]]{{\[}}%[[VAL_50]]] : memref<4xindex> -// CHECK: %[[VAL_59:.*]] = arith.addi %[[VAL_50]], %[[VAL_13]] : index -// CHECK: scf.yield %[[VAL_59]] : index +// CHECK-SAME: %[[VAL_0:.*0]]: memref, +// CHECK-SAME: %[[VAL_1:.*1]]: memref, +// CHECK-SAME: %[[VAL_2:.*2]]: memref, +// CHECK-SAME: %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[VAL_4:.*4]]: memref, +// CHECK-SAME: %[[VAL_5:.*5]]: memref, +// CHECK-SAME: %[[VAL_6:.*6]]: memref, +// CHECK-SAME: %[[VAL_7:.*7]]: !sparse_tensor.storage_specifier +// CHECK: %[[VAL_8:.*]] = arith.constant 4 : index +// CHECK: %[[VAL_9:.*]] = arith.constant 4 : i64 +// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK: %[[VAL_11:.*]] = 
arith.constant 0 : index +// CHECK: %[[VAL_12:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_13:.*]] = arith.constant false +// CHECK: %[[VAL_14:.*]] = arith.constant true +// CHECK: %[[VAL_15:.*]] = memref.alloc() : memref<16xindex> +// CHECK: %[[VAL_16:.*]] = memref.cast %[[VAL_15]] : memref<16xindex> to memref +// CHECK: %[[VAL_17:.*]] = memref.alloc() : memref<16xindex> +// CHECK: %[[VAL_18:.*]] = memref.cast %[[VAL_17]] : memref<16xindex> to memref +// CHECK: %[[VAL_19:.*]] = memref.alloc() : memref<16xf64> +// CHECK: %[[VAL_20:.*]] = memref.cast %[[VAL_19]] : memref<16xf64> to memref +// CHECK: %[[VAL_21:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier +// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]] dim_sz at 0 with %[[VAL_9]] : i64, !sparse_tensor.storage_specifier +// CHECK: %[[VAL_23:.*]] = sparse_tensor.storage_specifier.set %[[VAL_22]] dim_sz at 1 with %[[VAL_9]] : i64, !sparse_tensor.storage_specifier +// CHECK: %[[VAL_24:.*]] = sparse_tensor.storage_specifier.get %[[VAL_23]] ptr_mem_sz at 1 : !sparse_tensor.storage_specifier +// CHECK: %[[VAL_25:.*]] = arith.index_cast %[[VAL_24]] : i64 to index +// CHECK: %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.push_back %[[VAL_25]], %[[VAL_16]], %[[VAL_11]] : index, memref, index +// CHECK: %[[VAL_28:.*]] = arith.index_cast %[[VAL_27]] : index to i64 +// CHECK: %[[VAL_29:.*]] = sparse_tensor.storage_specifier.set %[[VAL_23]] ptr_mem_sz at 1 with %[[VAL_28]] : i64, !sparse_tensor.storage_specifier +// CHECK: %[[VAL_32:.*]], %[[VAL_33:.*]] = sparse_tensor.push_back %[[VAL_27]], %[[VAL_26]], %[[VAL_11]], %[[VAL_8]] : index, memref, index, index +// CHECK: %[[VAL_34:.*]] = arith.index_cast %[[VAL_33]] : index to i64 +// CHECK: %[[VAL_35:.*]] = sparse_tensor.storage_specifier.set %[[VAL_29]] ptr_mem_sz at 1 with %[[VAL_34]] : i64, !sparse_tensor.storage_specifier +// CHECK: %[[VAL_36:.*]] = memref.alloc() : memref<4xf64> +// CHECK: %[[VAL_37:.*]] = 
memref.alloc() : memref<4xi1> +// CHECK: %[[VAL_38:.*]] = memref.alloc() : memref<4xindex> +// CHECK: %[[VAL_39:.*]] = memref.cast %[[VAL_38]] : memref<4xindex> to memref +// CHECK: linalg.fill ins(%[[VAL_10]] : f64) outs(%[[VAL_36]] : memref<4xf64>) +// CHECK: linalg.fill ins(%[[VAL_13]] : i1) outs(%[[VAL_37]] : memref<4xi1>) +// CHECK: %[[VAL_40:.*]]:4 = scf.for %[[VAL_41:.*]] = %[[VAL_11]] to %[[VAL_8]] step %[[VAL_12]] iter_args(%[[VAL_42:.*]] = %[[VAL_32]], %[[VAL_43:.*]] = %[[VAL_18]], %[[VAL_44:.*]] = %[[VAL_20]], %[[VAL_45:.*]] = %[[VAL_35]]) -> (memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_41]]] : memref +// CHECK: %[[VAL_47:.*]] = arith.addi %[[VAL_41]], %[[VAL_12]] : index +// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_47]]] : memref +// CHECK: %[[VAL_49:.*]] = scf.for %[[VAL_50:.*]] = %[[VAL_46]] to %[[VAL_48]] step %[[VAL_12]] iter_args(%[[VAL_51:.*]] = %[[VAL_11]]) -> (index) { +// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_50]]] : memref +// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_50]]] : memref +// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_52]]] : memref +// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_52]], %[[VAL_12]] : index +// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_55]]] : memref +// CHECK: %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_54]] to %[[VAL_56]] step %[[VAL_12]] iter_args(%[[VAL_59:.*]] = %[[VAL_51]]) -> (index) { +// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_58]]] : memref +// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_60]]] : memref<4xf64> +// CHECK: %[[VAL_62:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_58]]] : memref +// CHECK: %[[VAL_63:.*]] = arith.mulf %[[VAL_53]], %[[VAL_62]] : f64 +// CHECK: %[[VAL_64:.*]] = arith.addf %[[VAL_61]], %[[VAL_63]] : f64 +// CHECK: %[[VAL_65:.*]] = memref.load %[[VAL_37]]{{\[}}%[[VAL_60]]] : memref<4xi1> 
+// CHECK: %[[VAL_66:.*]] = arith.cmpi eq, %[[VAL_65]], %[[VAL_13]] : i1 +// CHECK: %[[VAL_67:.*]] = scf.if %[[VAL_66]] -> (index) { +// CHECK: memref.store %[[VAL_14]], %[[VAL_37]]{{\[}}%[[VAL_60]]] : memref<4xi1> +// CHECK: memref.store %[[VAL_60]], %[[VAL_38]]{{\[}}%[[VAL_59]]] : memref<4xindex> +// CHECK: %[[VAL_68:.*]] = arith.addi %[[VAL_59]], %[[VAL_12]] : index +// CHECK: scf.yield %[[VAL_68]] : index // CHECK: } else { -// CHECK: scf.yield %[[VAL_50]] : index +// CHECK: scf.yield %[[VAL_59]] : index // CHECK: } -// CHECK: memref.store %[[VAL_55]], %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64> -// CHECK: scf.yield %[[VAL_60:.*]] : index +// CHECK: memref.store %[[VAL_64]], %[[VAL_36]]{{\[}}%[[VAL_60]]] : memref<4xf64> +// CHECK: scf.yield %[[VAL_69:.*]] : index // CHECK: } {"Emitted from" = "linalg.generic"} -// CHECK: sparse_tensor.sort %[[VAL_62:.*]], %[[VAL_29]] : memref -// CHECK: %[[VAL_63:.*]]:5 = scf.for %[[VAL_64:.*]] = %[[VAL_12]] to %[[VAL_62]] step %[[VAL_13]] iter_args(%[[VAL_65:.*]] = %[[VAL_32]], %[[VAL_66:.*]] = %[[VAL_33]], %[[VAL_67:.*]] = %[[VAL_34]], %[[VAL_68:.*]] = %[[VAL_35]], %[[VAL_69:.*]] = %[[VAL_36]]) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { -// CHECK: %[[VAL_70:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_64]]] : memref<4xindex> -// CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64> -// CHECK: %[[VAL_72:.*]]:5 = func.call @_insert_D_C_4_4_f64_0_0(%[[VAL_65]], %[[VAL_66]], %[[VAL_67]], %[[VAL_68]], %[[VAL_69]], %[[VAL_31]], %[[VAL_70]], %[[VAL_71]]) : (memref<2xindex>, memref<3xindex>, memref, memref, memref, index, index, f64) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) -// CHECK: memref.store %[[VAL_11]], %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64> -// CHECK: memref.store %[[VAL_14]], %[[VAL_27]]{{\[}}%[[VAL_70]]] : memref<4xi1> -// CHECK: scf.yield %[[VAL_72]]#0, %[[VAL_72]]#1, %[[VAL_72]]#2, %[[VAL_72]]#3, %[[VAL_72]]#4 : memref<2xindex>, memref<3xindex>, 
memref, memref, memref +// CHECK: scf.yield %[[VAL_70:.*]] : index +// CHECK: } {"Emitted from" = "linalg.generic"} +// CHECK: sparse_tensor.sort %[[VAL_71:.*]], %[[VAL_39]] : memref +// CHECK: %[[VAL_72:.*]]:4 = scf.for %[[VAL_73:.*]] = %[[VAL_11]] to %[[VAL_71]] step %[[VAL_12]] iter_args(%[[VAL_74:.*]] = %[[VAL_42]], %[[VAL_75:.*]] = %[[VAL_43]], %[[VAL_76:.*]] = %[[VAL_44]], %[[VAL_77:.*]] = %[[VAL_45]]) -> (memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: %[[VAL_78:.*]] = memref.load %[[VAL_38]]{{\[}}%[[VAL_73]]] : memref<4xindex> +// CHECK: %[[VAL_79:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64> +// CHECK: %[[VAL_80:.*]]:4 = func.call @_insert_D_C_4_4_f64_0_0(%[[VAL_74]], %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_41]], %[[VAL_78]], %[[VAL_79]]) : (memref, memref, memref, !sparse_tensor.storage_specifier +// CHECK: memref.store %[[VAL_10]], %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64> +// CHECK: memref.store %[[VAL_13]], %[[VAL_37]]{{\[}}%[[VAL_78]]] : memref<4xi1> +// CHECK: scf.yield %[[VAL_80]]#0, %[[VAL_80]]#1, %[[VAL_80]]#2, %[[VAL_80]]#3 : memref, memref, memref, !sparse_tensor.storage_specifier // CHECK: } -// CHECK: scf.yield %[[VAL_73:.*]]#0, %[[VAL_73]]#1, %[[VAL_73]]#2, %[[VAL_73]]#3, %[[VAL_73]]#4 : memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK: scf.yield %[[VAL_81:.*]]#0, %[[VAL_81]]#1, %[[VAL_81]]#2, %[[VAL_81]]#3 : memref, memref, memref, !sparse_tensor.storage_specifier // CHECK: } {"Emitted from" = "linalg.generic"} -// CHECK: memref.dealloc %[[VAL_26]] : memref<4xf64> -// CHECK: memref.dealloc %[[VAL_27]] : memref<4xi1> -// CHECK: memref.dealloc %[[VAL_28]] : memref<4xindex> -// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_75:.*]]#1{{\[}}%[[VAL_12]]] : memref<3xindex> -// CHECK: %[[VAL_76:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_12]]] : memref -// CHECK: %[[VAL_77:.*]] = scf.for %[[VAL_78:.*]] = %[[VAL_13]] to %[[VAL_74]] step %[[VAL_13]] iter_args(%[[VAL_79:.*]] = %[[VAL_76]]) 
-> (index) { -// CHECK: %[[VAL_80:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref -// CHECK: %[[VAL_81:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_12]] : index -// CHECK: %[[VAL_82:.*]] = arith.select %[[VAL_81]], %[[VAL_79]], %[[VAL_80]] : index -// CHECK: scf.if %[[VAL_81]] { -// CHECK: memref.store %[[VAL_79]], %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref +// CHECK: memref.dealloc %[[VAL_36]] : memref<4xf64> +// CHECK: memref.dealloc %[[VAL_37]] : memref<4xi1> +// CHECK: memref.dealloc %[[VAL_38]] : memref<4xindex> +// CHECK: %[[VAL_82:.*]] = sparse_tensor.storage_specifier.get %[[VAL_83:.*]]#3 ptr_mem_sz at 1 : !sparse_tensor.storage_specifier +// CHECK: %[[VAL_84:.*]] = arith.index_cast %[[VAL_82]] : i64 to index +// CHECK: %[[VAL_85:.*]] = memref.load %[[VAL_83]]#0{{\[}}%[[VAL_11]]] : memref +// CHECK: %[[VAL_86:.*]] = scf.for %[[VAL_87:.*]] = %[[VAL_12]] to %[[VAL_84]] step %[[VAL_12]] iter_args(%[[VAL_88:.*]] = %[[VAL_85]]) -> (index) { +// CHECK: %[[VAL_89:.*]] = memref.load %[[VAL_83]]#0{{\[}}%[[VAL_87]]] : memref +// CHECK: %[[VAL_90:.*]] = arith.cmpi eq, %[[VAL_89]], %[[VAL_11]] : index +// CHECK: %[[VAL_91:.*]] = arith.select %[[VAL_90]], %[[VAL_88]], %[[VAL_89]] : index +// CHECK: scf.if %[[VAL_90]] { +// CHECK: memref.store %[[VAL_88]], %[[VAL_83]]#0{{\[}}%[[VAL_87]]] : memref // CHECK: } -// CHECK: scf.yield %[[VAL_82]] : index +// CHECK: scf.yield %[[VAL_91]] : index // CHECK: } -// CHECK: return %[[VAL_75]]#0, %[[VAL_75]]#1, %[[VAL_75]]#2, %[[VAL_75]]#3, %[[VAL_75]]#4 : memref<2xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } +// CHECK: return %[[VAL_83]]#0, %[[VAL_83]]#1, %[[VAL_83]]#2, %[[VAL_83]]#3 : memref, memref, memref, !sparse_tensor.storage_specifier func.func @matmul(%A: tensor<4x8xf64, #CSR>, %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> { %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR> diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir 
b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir index 626c6f8..7f385e5 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir @@ -24,16 +24,15 @@ module { %buffer = memref.alloc(%c1) : memref memref.store %c0, %bufferSizes[%c0] : memref - %buffer2 = sparse_tensor.push_back %bufferSizes, %buffer, %d2 {idx=0 : index} : memref, memref, f32 - %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1, %c10 {idx=0 : index} : memref, memref, f32, index + %buffer2, %s0 = sparse_tensor.push_back %c0, %buffer, %d2 : index, memref, f32 + %buffer3, %s1 = sparse_tensor.push_back %s0, %buffer2, %d1, %c10 : index, memref, f32, index // CHECK: 16 %capacity = memref.dim %buffer3, %c0 : memref vector.print %capacity : index - // CHECK: ( 11 ) - %size = vector.transfer_read %bufferSizes[%c0], %c0: memref, vector<1xindex> - vector.print %size : vector<1xindex> + // CHECK: 11 + vector.print %s1 : index // CHECK ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) %values = vector.transfer_read %buffer3[%c0], %d0: memref, vector<11xf32> -- 2.7.4