From d4db52893857a836940e0951daa205de1bb1d201 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Fri, 21 Apr 2023 20:06:36 +0000 Subject: [PATCH] [mlir][sparse] extend unpack operation to support unpacking a batched COO type Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D149103 --- .../Dialect/SparseTensor/IR/SparseTensorOps.td | 44 +++- .../SparseTensor/IR/SparseTensorDialect.cpp | 6 +- .../Transforms/BufferizableOpInterfaceImpl.cpp | 9 +- .../SparseTensor/Transforms/CodegenUtils.cpp | 12 ++ .../Dialect/SparseTensor/Transforms/CodegenUtils.h | 5 + .../SparseTensor/Transforms/LoopEmitter.cpp | 22 +- .../Transforms/SparseTensorCodegen.cpp | 236 +++++++++++++++++---- mlir/test/Dialect/SparseTensor/invalid.mlir | 12 ++ mlir/test/Dialect/SparseTensor/roundtrip.mlir | 15 ++ mlir/test/Dialect/SparseTensor/sparse_2d.mlir | 74 +++---- mlir/test/Dialect/SparseTensor/sparse_foreach.mlir | 22 +- mlir/test/Dialect/SparseTensor/sparse_pack.mlir | 2 +- .../Dialect/SparseTensor/CPU/sparse_pack.mlir | 59 +++++- 13 files changed, 393 insertions(+), 125 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index eea58f9..f29ea60 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -124,9 +124,10 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>, } def SparseTensor_UnpackOp : SparseTensor_Op<"unpack">, - Arguments<(ins AnySparseTensor:$tensor)>, - Results<(outs 1DTensorOf<[AnyType]>:$values, - 2DTensorOf<[AnySignlessIntegerOrIndex]>:$coordinates, + Arguments<(ins AnySparseTensor:$tensor, + OptionalAttr:$batched_lvls)>, + Results<(outs TensorOf<[AnyType]>:$values, + TensorOf<[AnySignlessIntegerOrIndex]>:$coordinates, AnySignlessIntegerOrIndex:$nse)> { let summary = "Returns the (values, coordinates) pair unpacked from the input tensor"; @@ -159,11 +160,44 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack">, // %coordinates = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex> // %nse = 3 ``` + + If `batched_lvls` is provided, the operation unpacks each batch of the tensors + separately. The returned `nse` is the maximum nse of all batches. For a batch with + a smaller nse, trailing zeros are appended in the result. + Example: + + ```mlir + // input BCOO format |1.1, 2.2, 3.3, 0.0| + // of 2x4 matrix |0.0, 1.2, 2.3, 0.0| + %values, %coordinates, %nse = sparse_tensor.unpack %st batched_lvls=1 + : tensor<2x3xf64>, tensor<2x3x1xindex> to tensor<2x4xf64, #BCOO> + // %values = arith.constant dense<[[ 1.1, 2.2, 3.3 ], + // [ 1.2, 2.3, 0.0 ]]> : tensor<2x3xf64> + // %coordinates = arith.constant dense<[[ [0], [1], [2] ], + // [ [1], [2], [0] ]> : tensor<2x3x1xindex> + ``` + }]; + + let extraClassDeclaration = [{ + /// Returns the number of leading levels that are batched. + unsigned getNumBatchedLvls(); }]; + let builders = [ + OpBuilder<(ins "Type":$values, "Type":$coordinates, "Type":$nse, "Value": $tensor), + [{ + build($_builder, $_state, values, coordinates, nse, tensor, nullptr); + }]>, + OpBuilder<(ins "TypeRange":$resultTypes, "Value": $tensor), + [{ + build($_builder, $_state, resultTypes, tensor, nullptr); + }]> + ]; + + let assemblyFormat = - "$tensor attr-dict `:` type($tensor)" - "`to` type($values) `,` type($coordinates) `,` type($nse)"; + "$tensor (`batched_lvls` `=` $batched_lvls^)? 
attr-dict `:`" + "type($tensor) `to` type($values) `,` type($coordinates) `,` type($nse)"; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index b235301..42776c7 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -719,7 +719,11 @@ LogicalResult UnpackOp::verify() { const auto coordinatesTp = getRankedTensorType(getCoordinates()); const auto srcTp = getSparseTensorType(getTensor()); return verifyPackUnPack(*this, false, srcTp, valuesTp, coordinatesTp, - nullptr); + getBatchedLvlsAttr()); +} + +unsigned UnpackOp::getNumBatchedLvls() { + return getBatchedLvls().has_value() ? getBatchedLvls()->getZExtValue() : 0; } LogicalResult ConvertOp::verify() { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp index 8a8b2ed..f17c001 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -153,9 +153,12 @@ struct UnpackOpInterface : public BufferizableOpInterface::ExternalModel { bool bufferizesToAllocation(Operation *op, OpResult opResult) const { - // Similar to InsertOp, reallocation is not considered to allocate a new - // piece of memory. - return false; + // We allocate and return unpacked memory if this is a batched unpack. + // When the number of batched levels equals to zero, we reuse the + // coordinates/values memref (and reallocation if the requested output size + // is larger than the actual size). Similar to InsertOp, reallocation is + // not considered to allocate a new piece of memory. + return llvm::cast(op).getNumBatchedLvls() != 0; } bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp index 3a488b3..9aae52d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -213,6 +213,18 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value, return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast); } +Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem, + Value s) { + Value load = builder.create(loc, mem, s); + if (!load.getType().isa()) { + if (load.getType().getIntOrFloatBitWidth() < 64) + load = builder.create(loc, builder.getI64Type(), load); + load = + builder.create(loc, builder.getIndexType(), load); + } + return load; +} + mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) { if (tp.isa()) return builder.getFloatAttr(tp, 1.0); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h index b6e6def..3e1d0b0 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -75,6 +75,11 @@ StringRef primaryTypeFunctionSuffix(Type elemTp); /// Add type casting between arith and index types when needed. Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy); +/// Generates a pointer/index load from the sparse storage scheme. 
Narrower +/// data types need to be zero extended before casting the value into the +/// index type used for looping and indexing. +Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, Value s); + /// Generates a 1-valued attribute of the given type. This supports /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`, /// for unsupported types we raise `llvm_unreachable` rather than diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp index afa4828..ba6b464 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -41,25 +41,6 @@ using namespace mlir::sparse_tensor; // File local helper functions. //===----------------------------------------------------------------------===// -/// Generates a pointer/index load from the sparse storage scheme. Narrower -/// data types need to be zero extended before casting the value into the -/// index type used for looping and indexing. -static Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, - Value s) { - // For the scalar case, we simply zero extend narrower indices into 64-bit - // values before casting to index without a performance penalty. Here too, - // however, indices that already are 64-bit, in theory, cannot express the - // full range as explained above. - Value load = builder.create(loc, mem, s); - if (!load.getType().isa()) { - if (load.getType().getIntOrFloatBitWidth() < 64) - load = builder.create(loc, builder.getI64Type(), load); - load = - builder.create(loc, builder.getIndexType(), load); - } - return load; -} - static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor, Level lvl) { auto enc = getSparseTensorEncoding(tensor.getType()); @@ -707,7 +688,8 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl( continue; } - bool isSparse = isCompressedDLT(lvlType) || isSingletonDLT(lvlType); + bool isSparse = isCompressedDLT(lvlType) || isSingletonDLT(lvlType) || + isCompressedWithHiDLT(lvlType); // We can at most have one sparse input, otherwise, a while loop is // required to co-iterate multiple sparse tensors. assert(!isSparseCond || !isSparse); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index c1cb092..4b94392 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -602,6 +602,25 @@ static Value reallocOrSubView(OpBuilder &builder, Location loc, int64_t len, return ifOp.getResult(0); } +static Value linearize(OpBuilder &builder, Location loc, ValueRange ivs, + ValueRange bounds) { + assert(ivs.size() == bounds.size()); + Value crd = constantIndex(builder, loc, 0); + for (unsigned i = 0, e = ivs.size(); i < e; i++) { + crd = builder.create(loc, crd, ivs[i]); + if (i != ivs.size() - 1) + crd = builder.create(loc, crd, bounds[i + 1]); + } + return crd; +} + +ReassociationIndices getReassociationForFlattening(ShapedType srcTp) { + ReassociationIndices reassociation; + for (int i = 0, e = srcTp.getRank(); i < e; i++) + reassociation.push_back(i); + return reassociation; +} + //===----------------------------------------------------------------------===// // Codegen rules. 
//===----------------------------------------------------------------------===// @@ -1252,12 +1271,7 @@ static void populateCompressedWithHiPosArray(OpBuilder &builder, Location loc, [&ubs, c0, c1, c2, nse, batV, posMemRef](OpBuilder &builder, Location loc, ValueRange ivs) { // Linearize index variables - Value crd = constantIndex(builder, loc, 0); - for (unsigned i = 0, e = ivs.size(); i < e; i++) { - crd = builder.create(loc, crd, ivs[i]); - if (i != ivs.size() - 1) - crd = builder.create(loc, crd, ubs[i + 1]); - } + Value crd = linearize(builder, loc, ivs, ubs); Value len = constantIndex(builder, loc, nse); Value pLo = builder.create(loc, crd, len); SmallVector indices(ivs.begin(), ivs.end()); @@ -1420,6 +1434,166 @@ struct SparsePackOpConverter : public OpConversionPattern { } }; +static LogicalResult genUnBatchedUnpackOp(UnpackOp op, + SparseTensorDescriptor desc, + ConversionPatternRewriter &rewriter) { + Location loc = op.getLoc(); + const auto srcTp = getSparseTensorType(op.getTensor()); + const Level lvlRank = srcTp.getLvlRank(); + Value flatBuf = lvlRank == 1 ? desc.getCrdMemRefOrView(rewriter, loc, 0) + : desc.getAOSMemRef(); + Value valuesBuf = desc.getValMemRef(); + + // If frontend requests a static buffer, we reallocate the + // values/coordinates to ensure that we meet their need. + const auto valuesTp = getRankedTensorType(op.getValues()); + if (valuesTp.hasStaticShape()) { + // FIXME: Reallocation is not always safe! E.g., if we are unpacking a + // tensor that is packed from constants. + valuesBuf = + reallocOrSubView(rewriter, loc, valuesTp.getShape()[0], valuesBuf); + } + + const auto coordinatesTp = getRankedTensorType(op.getCoordinates()); + if (coordinatesTp.hasStaticShape()) { + // FIXME: Reallocation is not always safe! E.g., if we are unpacking a + // tensor that is packed from constants. + auto len = coordinatesTp.getShape()[0] * coordinatesTp.getShape()[1]; + flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf); + } + + Value coordinatesBuf = rewriter.create( + loc, + MemRefType::get(coordinatesTp.getShape(), coordinatesTp.getElementType()), + flatBuf, ArrayRef{ReassociationIndices{0, 1}}); + + // Converts MemRefs back to Tensors. + Value values = rewriter.create(loc, valuesBuf); + Value coordinates = + rewriter.create(loc, coordinatesBuf); + Value nse = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc), + op.getNse().getType()); + + rewriter.replaceOp(op, {values, coordinates, nse}); + return success(); +} + +static LogicalResult genBatchedUnpackOp(UnpackOp op, unsigned nBatched, + SparseTensorDescriptor desc, + ConversionPatternRewriter &rewriter) { + assert(nBatched != 0); + Location loc = op.getLoc(); + Value c0 = constantIndex(rewriter, loc, 0); + Value c1 = constantIndex(rewriter, loc, 1); + Value c2 = constantIndex(rewriter, loc, 2); + + auto genZeroedAlloc = [loc, + &rewriter](TensorType tt) -> TypedValue { + auto mem = rewriter + .create( + loc, MemRefType::get(tt.getShape(), tt.getElementType())) + .getMemref(); + // TODO: Instead of filling the entire buffer, we can only fill the + // trailing zeros. 
+ rewriter.create( + loc, ValueRange{constantZero(rewriter, loc, tt.getElementType())}, mem); + return mem; + }; + SparseTensorType stt = getSparseTensorType(op.getTensor()); + TensorType valTensorTp = op.getValues().getType(); + TensorType crdTensorTp = op.getCoordinates().getType(); + TypedValue valMemref = genZeroedAlloc(valTensorTp); + TypedValue crdMemref = genZeroedAlloc(crdTensorTp); + assert(valTensorTp.hasStaticShape() && crdTensorTp.hasStaticShape()); + + SmallVector lbs(nBatched, c0), steps(nBatched, c1); + SmallVector ubs; + for (unsigned i = 0; i < nBatched; i++) { + assert(!ShapedType::isDynamic(stt.getDimShape()[i])); + ubs.push_back(constantIndex(rewriter, loc, stt.getDimShape()[i])); + } + + DimLevelType dlt = stt.getLvlType(nBatched); + assert(isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt)); + Value posStep = isCompressedDLT(dlt) ? c1 // forward position index by 1 + : c2; // forward position index by 2 + auto loopNest = scf::buildLoopNest( + rewriter, loc, lbs, ubs, steps, {c0 /*maximum nse*/}, + [&ubs, c0, c1, posStep, desc, nBatched, &valMemref, + &crdMemref](OpBuilder &builder, Location loc, ValueRange ivs, + ValueRange args) -> scf::ValueVector { + // crdMemref has shape: <... x nse x rank> + unsigned unBatchedRank = crdMemref.getType().getShape().back(); + Value values = desc.getValMemRef(); + Value flatCrds = unBatchedRank == 1 + ? desc.getCrdMemRefOrView(builder, loc, 0) + : desc.getAOSMemRef(); + + Value positions = desc.getPosMemRef(nBatched); + Value positLo = builder.create( + loc, linearize(builder, loc, ivs, ubs), posStep); + Value positHi = builder.create(loc, positLo, c1); + + Value pLo = genIndexLoad(builder, loc, positions, positLo); + Value pHi = genIndexLoad(builder, loc, positions, positHi); + Value nse = builder.create(loc, pHi, pLo); + + Value crdLo = builder.create( + loc, pLo, constantIndex(builder, loc, unBatchedRank)); + Value nCrd = builder.create( + loc, nse, constantIndex(builder, loc, unBatchedRank)); + + SmallVector offsets, sizes, strides; + for (unsigned i = 0; i < nBatched; i++) { + offsets.push_back(ivs[i]); + sizes.push_back(c1); + strides.push_back(c1); + } + // [0, nse, 1]. + offsets.push_back(c0); + sizes.push_back(nse); + strides.push_back(c1); + + auto valView = builder.create( + loc, valMemref, offsets, sizes, strides); + auto valReass = getReassociationForFlattening(valView.getType()); + Value valDst = + builder.create(loc, valView, valReass); + Value valSrc = + builder.create(loc, values, pLo, nse, c1); + builder.create(loc, valSrc, valDst); + + // [0, rank, 1]. + offsets.push_back(c0); + sizes.push_back(constantIndex(builder, loc, unBatchedRank)); + strides.push_back(c1); + + auto crdView = builder.create( + loc, crdMemref, offsets, sizes, strides); + auto crdReass = getReassociationForFlattening(crdView.getType()); + Value crdDst = + builder.create(loc, crdView, crdReass); + Value crdSrc = + builder.create(loc, flatCrds, crdLo, nCrd, c1); + builder.create(loc, crdSrc, crdDst); + + Value pred = builder.create( + loc, arith::CmpIPredicate::ugt, nse, args[0]); + // Choose the larger NSE + return {builder.create(loc, pred, nse, args[0])}; + }); + + // Converts MemRefs back to Tensors. 
+ Value values = rewriter.create(loc, valMemref); + Value coordinates = + rewriter.create(loc, crdMemref); + Value nse = + genCast(rewriter, loc, loopNest.results.front(), op.getNse().getType()); + + rewriter.replaceOp(op, {values, coordinates, nse}); + return success(); +} + struct SparseUnpackOpConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; SparseUnpackOpConverter(TypeConverter &typeConverter, MLIRContext *context, @@ -1431,52 +1605,26 @@ struct SparseUnpackOpConverter : public OpConversionPattern { matchAndRewrite(UnpackOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); - Location loc = op.getLoc(); const auto srcTp = getSparseTensorType(op.getTensor()); - const Level lvlRank = srcTp.getLvlRank(); + const unsigned nBatched = op.getNumBatchedLvls(); + assert(isCOOType(srcTp.getEncoding(), nBatched, true) && + desc.getFields().size() == 4); // specifier + pos + crds + values + auto logicRes = nBatched == 0 + ? genUnBatchedUnpackOp(op, desc, rewriter) + : genBatchedUnpackOp(op, nBatched, desc, rewriter); + Value posBuf = desc.getPosMemRef(nBatched); - assert(isUniqueCOOType(srcTp) && desc.getFields().size() == 4); - - Value flatBuf = lvlRank == 1 ? desc.getCrdMemRefOrView(rewriter, loc, 0) - : desc.getAOSMemRef(); - Value valuesBuf = desc.getValMemRef(); - Value posBuf = desc.getPosMemRef(0); if (createDeallocs) { // Unpack ends the lifetime of the sparse tensor. While the value array // and coordinate array are unpacked and returned, the position array // becomes useless and need to be freed (if user requests). - rewriter.create(loc, posBuf); - } - - // If frontend requests a static buffer, we reallocate the - // values/coordinates to ensure that we meet their need. - const auto valuesTp = getRankedTensorType(op.getValues()); - if (valuesTp.hasStaticShape()) { - valuesBuf = - reallocOrSubView(rewriter, loc, valuesTp.getShape()[0], valuesBuf); - } - - const auto coordinatesTp = getRankedTensorType(op.getCoordinates()); - if (coordinatesTp.hasStaticShape()) { - auto len = coordinatesTp.getShape()[0] * coordinatesTp.getShape()[1]; - flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf); + // FIXME: Depending on whether the tensor being unpacked is created by + // PackOp or not, we may or may not need to free other memref fields of + // the sparse tensor too (PackOp borrows value/coordinate buffer). + rewriter.create(op.getLoc(), posBuf); } - Value coordinatesBuf = rewriter.create( - loc, - MemRefType::get(coordinatesTp.getShape(), - coordinatesTp.getElementType()), - flatBuf, ArrayRef{ReassociationIndices{0, 1}}); - - // Converts MemRefs back to Tensors. 
- Value values = rewriter.create(loc, valuesBuf); - Value coordinates = - rewriter.create(loc, coordinatesBuf); - Value nse = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc), - op.getNse().getType()); - - rewriter.replaceOp(op, {values, coordinates, nse}); - return success(); + return logicRes; } private: diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir index b6f43ad..0766e90 100644 --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -128,6 +128,18 @@ func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>) // ----- +#BCOO = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed-hi"], crdWidth=32}> + +func.func @invalid_unpack_type(%sp: tensor<2x100xf32, #BCOO>) + -> (tensor<2x6xf32>, tensor<3x6x2xi32>, i32) { + // expected-error@+1 {{values/coordinates batched level sizes don't match statically}} + %values, %coordinates, %nse = sparse_tensor.unpack %sp batched_lvls=1 + : tensor<2x100xf32, #BCOO> to tensor<2x6xf32>, tensor<3x6x2xi32>, i32 + return %values, %coordinates, %nse : tensor<2x6xf32>, tensor<3x6x2xi32>, i32 +} + +// ----- + func.func @invalid_positions_dense(%arg0: tensor<128xf64>) -> memref { // expected-error@+1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}} %0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<128xf64> to memref diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir index e3e548c..3bfa7c2 100644 --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -59,6 +59,21 @@ func.func @sparse_unpack(%sp : tensor<100xf64, #SparseVector>) // ----- +#BatchedSparseVector = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed-hi"], crdWidth=32}> + +// CHECK-LABEL: func @sparse_unpack( +// CHECK-SAME: %[[T:.*]]: tensor<2x100xf64, # +// CHECK: %[[D:.*]], %[[I:.*]], %[[N:.*]] = sparse_tensor.unpack %[[T]] batched_lvls = 1 +// CHECK: return %[[D]], %[[I]], %[[N]] +func.func @sparse_unpack(%sp : tensor<2x100xf64, #BatchedSparseVector>) + -> (tensor<2x6xf64>, tensor<2x6x1xi32>, i32) { + %data, %indices, %nnz = sparse_tensor.unpack %sp batched_lvls=1 + : tensor<2x100xf64, #BatchedSparseVector> to tensor<2x6xf64>, tensor<2x6x1xi32>, i32 + return %data, %indices, %nnz : tensor<2x6xf64>, tensor<2x6x1xi32>, i32 +} + +// ----- + #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> // CHECK-LABEL: func @sparse_dealloc( diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir index 42f2f1c..58dc1e4 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir @@ -603,19 +603,19 @@ func.func @add_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T dimLevelType = [ "dense", "compressed-hi" ], }> // CHECK-LABEL: func.func @sub_ss_batched( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) -> tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> { +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3xf64, #{{.*}}>>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3xf64, #{{.*}}>>) -> tensor<2x3xf64, #{{.*}}>> { // CHECK-DAG: %[[VAL_2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 
1 : index -// CHECK-DAG: %[[VAL_5:.*]] = bufferization.alloc_tensor() : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> -// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref -// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref -// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref -// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref -// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref -// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> to memref -// CHECK: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] iter_args(%[[VAL_14:.*]] = %[[VAL_5]]) -> (tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) { +// CHECK-DAG: %[[VAL_5:.*]] = bufferization.alloc_tensor() : tensor<2x3xf64, #{{.*}}>> +// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<2x3xf64, #{{.*}}>> to memref +// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<2x3xf64, #{{.*}}>> to memref +// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<2x3xf64, #{{.*}}>> to memref +// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<2x3xf64, #{{.*}}>> to memref +// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<2x3xf64, #{{.*}}>> to memref +// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<2x3xf64, #{{.*}}>> to memref +// CHECK: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] iter_args(%[[VAL_14:.*]] = %[[VAL_5]]) -> (tensor<2x3xf64, #{{.*}}>>) { // CHECK: %[[VAL_15:.*]] = arith.muli %[[VAL_13]], %[[VAL_2]] : index // CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref // CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_4]] : index @@ -628,9 +628,9 @@ func.func @add_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T // CHECK: %[[VAL_27:.*]] = arith.cmpi ult, %[[VAL_24]], %[[VAL_18]] : index // CHECK: %[[VAL_28:.*]] = arith.cmpi ult, %[[VAL_25]], %[[VAL_22]] : index // CHECK: %[[VAL_29:.*]] = arith.andi %[[VAL_27]], %[[VAL_28]] : i1 -// CHECK: scf.condition(%[[VAL_29]]) %[[VAL_24]], %[[VAL_25]], %[[VAL_26]] : index, index, tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> +// CHECK: scf.condition(%[[VAL_29]]) %[[VAL_24]], %[[VAL_25]], %[[VAL_26]] : index, index, tensor<2x3xf64, #{{.*}}>> // CHECK: } do { -// CHECK: ^bb0(%[[VAL_30:.*]]: index, %[[VAL_31:.*]]: index, %[[VAL_32:.*]]: tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>): +// CHECK: ^bb0(%[[VAL_30:.*]]: index, %[[VAL_31:.*]]: index, %[[VAL_32:.*]]: tensor<2x3xf64, #{{.*}}>>): // CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_30]]] : memref // CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_31]]] : memref // CHECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_33]] : index @@ -638,31 +638,31 @@ func.func @add_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T // CHECK: %[[VAL_37:.*]] = arith.cmpi eq, %[[VAL_33]], 
%[[VAL_36]] : index // CHECK: %[[VAL_38:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_36]] : index // CHECK: %[[VAL_39:.*]] = arith.andi %[[VAL_37]], %[[VAL_38]] : i1 -// CHECK: %[[VAL_40:.*]] = scf.if %[[VAL_39]] -> (tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) { +// CHECK: %[[VAL_40:.*]] = scf.if %[[VAL_39]] -> (tensor<2x3xf64, #{{.*}}>>) { // CHECK: %[[VAL_41:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_30]]] : memref // CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_31]]] : memref // CHECK: %[[VAL_43:.*]] = arith.subf %[[VAL_41]], %[[VAL_42]] : f64 -// CHECK: %[[VAL_44:.*]] = sparse_tensor.insert %[[VAL_43]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> -// CHECK: scf.yield %[[VAL_44]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> +// CHECK: %[[VAL_44:.*]] = sparse_tensor.insert %[[VAL_43]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #{{.*}}>> +// CHECK: scf.yield %[[VAL_44]] : tensor<2x3xf64, #{{.*}}>> // CHECK: } else { // CHECK: %[[VAL_45:.*]] = arith.cmpi eq, %[[VAL_33]], %[[VAL_36]] : index -// CHECK: %[[VAL_46:.*]] = scf.if %[[VAL_45]] -> (tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) { +// CHECK: %[[VAL_46:.*]] = scf.if %[[VAL_45]] -> (tensor<2x3xf64, #{{.*}}>>) { // CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_30]]] : memref -// CHECK: %[[VAL_48:.*]] = sparse_tensor.insert %[[VAL_47]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> -// CHECK: scf.yield %[[VAL_48]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> +// CHECK: %[[VAL_48:.*]] = sparse_tensor.insert %[[VAL_47]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #{{.*}}>> +// CHECK: scf.yield %[[VAL_48]] : tensor<2x3xf64, #{{.*}}>> // CHECK: } else { // CHECK: %[[VAL_49:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_36]] : index -// CHECK: %[[VAL_50:.*]] = scf.if %[[VAL_49]] -> (tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>>) { +// CHECK: %[[VAL_50:.*]] = scf.if %[[VAL_49]] -> (tensor<2x3xf64, #{{.*}}>>) { // CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_31]]] : memref // CHECK: %[[VAL_52:.*]] = arith.negf %[[VAL_51]] : f64 -// CHECK: %[[VAL_53:.*]] = sparse_tensor.insert %[[VAL_52]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> -// CHECK: scf.yield %[[VAL_53]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> +// CHECK: %[[VAL_53:.*]] = sparse_tensor.insert %[[VAL_52]] into %[[VAL_32]]{{\[}}%[[VAL_13]], %[[VAL_36]]] : tensor<2x3xf64, #{{.*}}>> +// CHECK: scf.yield %[[VAL_53]] : tensor<2x3xf64, #{{.*}}>> // CHECK: } else { -// CHECK: scf.yield %[[VAL_32]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> +// CHECK: scf.yield %[[VAL_32]] : tensor<2x3xf64, #{{.*}}>> // CHECK: } -// CHECK: scf.yield %[[VAL_54:.*]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> +// CHECK: scf.yield %[[VAL_54:.*]] : tensor<2x3xf64, #{{.*}}>> // CHECK: } -// CHECK: scf.yield %[[VAL_55:.*]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> +// CHECK: scf.yield %[[VAL_55:.*]] : tensor<2x3xf64, #{{.*}}>> // CHECK: } // CHECK: %[[VAL_56:.*]] = arith.cmpi eq, %[[VAL_33]], %[[VAL_36]] : index // CHECK: %[[VAL_57:.*]] = arith.addi %[[VAL_30]], %[[VAL_4]] : index @@ -670,25 +670,25 @@ func.func @add_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T // CHECK: %[[VAL_59:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_36]] : index // CHECK: %[[VAL_60:.*]] = arith.addi %[[VAL_31]], %[[VAL_4]] : 
index // CHECK: %[[VAL_61:.*]] = arith.select %[[VAL_59]], %[[VAL_60]], %[[VAL_31]] : index -// CHECK: scf.yield %[[VAL_58]], %[[VAL_61]], %[[VAL_62:.*]] : index, index, tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> +// CHECK: scf.yield %[[VAL_58]], %[[VAL_61]], %[[VAL_62:.*]] : index, index, tensor<2x3xf64, #{{.*}}>> // CHECK: } attributes {"Emitted from" = "linalg.generic"} -// CHECK: %[[VAL_63:.*]] = scf.for %[[VAL_64:.*]] = %[[VAL_3]] to %[[VAL_18]] step %[[VAL_4]] iter_args(%[[VAL_65:.*]] = %[[VAL_66:.*]]#2) +// CHECK: %[[VAL_63:.*]] = scf.for %[[VAL_64:.*]] = %[[VAL_65:.*]]#0 to %[[VAL_18]] step %[[VAL_4]] iter_args(%[[VAL_66:.*]] = %[[VAL_65]]#2) // CHECK: %[[VAL_67:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_64]]] : memref // CHECK: %[[VAL_68:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_64]]] : memref -// CHECK: %[[VAL_69:.*]] = sparse_tensor.insert %[[VAL_68]] into %[[VAL_65]]{{\[}}%[[VAL_13]], %[[VAL_67]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> -// CHECK: scf.yield %[[VAL_69]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> +// CHECK: %[[VAL_69:.*]] = sparse_tensor.insert %[[VAL_68]] into %[[VAL_66]]{{\[}}%[[VAL_13]], %[[VAL_67]]] : tensor<2x3xf64, #{{.*}}>> +// CHECK: scf.yield %[[VAL_69]] : tensor<2x3xf64, #{{.*}}>> // CHECK: } {"Emitted from" = "linalg.generic"} -// CHECK: %[[VAL_70:.*]] = scf.for %[[VAL_71:.*]] = %[[VAL_3]] to %[[VAL_22]] step %[[VAL_4]] iter_args(%[[VAL_72:.*]] = %[[VAL_73:.*]]) -// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_71]]] : memref -// CHECK: %[[VAL_75:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_71]]] : memref -// CHECK: %[[VAL_76:.*]] = arith.negf %[[VAL_75]] : f64 -// CHECK: %[[VAL_77:.*]] = sparse_tensor.insert %[[VAL_76]] into %[[VAL_72]]{{\[}}%[[VAL_13]], %[[VAL_74]]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> -// CHECK: scf.yield %[[VAL_77]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> +// CHECK: %[[VAL_70:.*]] = scf.for %[[VAL_71:.*]] = %[[VAL_72:.*]]#1 to %[[VAL_22]] step %[[VAL_4]] iter_args(%[[VAL_73:.*]] = %[[VAL_74:.*]]) +// CHECK: %[[VAL_75:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_71]]] : memref +// CHECK: %[[VAL_76:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_71]]] : memref +// CHECK: %[[VAL_77:.*]] = arith.negf %[[VAL_76]] : f64 +// CHECK: %[[VAL_78:.*]] = sparse_tensor.insert %[[VAL_77]] into %[[VAL_73]]{{\[}}%[[VAL_13]], %[[VAL_75]]] : tensor<2x3xf64, #{{.*}}>> +// CHECK: scf.yield %[[VAL_78]] : tensor<2x3xf64, #{{.*}}>> // CHECK: } {"Emitted from" = "linalg.generic"} -// CHECK: scf.yield %[[VAL_78:.*]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> +// CHECK: scf.yield %[[VAL_79:.*]] : tensor<2x3xf64, #{{.*}}>> // CHECK: } {"Emitted from" = "linalg.generic"} -// CHECK: %[[VAL_79:.*]] = sparse_tensor.load %[[VAL_80:.*]] hasInserts : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> -// CHECK: return %[[VAL_79]] : tensor<2x3xf64, #sparse_tensor.encoding<{{.*}}>> +// CHECK: %[[VAL_80:.*]] = sparse_tensor.load %[[VAL_81:.*]] hasInserts : tensor<2x3xf64, #{{.*}}>> +// CHECK: return %[[VAL_80]] : tensor<2x3xf64, #{{.*}}>> // CHECK: } func.func @sub_ss_batched(%0: tensor<2x3xf64, #BatchedVector>, %1: tensor<2x3xf64, #BatchedVector>) -> tensor<2x3xf64, #BatchedVector> { diff --git a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir index 57013e7..3d95c86 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir @@ -145,23 +145,25 @@ func.func @foreach_print_slice(%A: 
tensor<4x4xf64, #CSR_SLICE>) { }> // CHECK-LABEL: func.func @foreach_bcoo( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4x4xf64, #sparse_tensor.encoding<{{.*}}>>) { +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4x4xf64, #{{.*}}>>) { // CHECK-DAG: %[[VAL_1:.*]] = arith.constant 4 : index // CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4x4xf64, #sparse_tensor.encoding<{{.*}}>> to memref -// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4x4xf64, #sparse_tensor.encoding<{{.*}}>> to memref +// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4x4xf64, #{{.*}}>> to memref +// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4x4xf64, #{{.*}}>> to memref // CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] { // CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_4]] : index -// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : index -// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_9]]] : memref -// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_2]] to %[[VAL_10]] step %[[VAL_3]] { -// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref -// CHECK: "test.use"(%[[VAL_12]]) : (f64) -> () -// CHECK: } -// CHECK: } +// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_8]]] : memref +// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : index +// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref +// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_9]] to %[[VAL_11]] step %[[VAL_3]] { +// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref +// CHECK: "test.use"(%[[VAL_13]]) : (f64) -> () +// CHECK: } {"Emitted from" = "sparse_tensor.foreach"} +// CHECK: } {"Emitted from" = "sparse_tensor.foreach"} // CHECK: return +// CHECK: } func.func @foreach_bcoo(%A: tensor<4x4x4xf64, #BCOO>) { sparse_tensor.foreach in %A : tensor<4x4x4xf64, #BCOO> do { ^bb0(%1: index, %2: index, %3: index, %v: f64) : diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir index 4648cb3..fb0d4a7 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir @@ -45,7 +45,6 @@ func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>) // CHECK-SAME: %[[VAL_3:.*]] // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 6 : index // CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index -// CHECK-DAG: memref.dealloc %[[VAL_0]] : memref // CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref // CHECK: %[[VAL_7:.*]] = arith.cmpi ugt, %[[VAL_4]], %[[VAL_6]] : index // CHECK: %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (memref<6xf64>) { @@ -69,6 +68,7 @@ func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>) // CHECK: %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_20:.*]] : memref<6xf64> // CHECK: %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6x2xi32> // CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier +// CHECK: memref.dealloc %[[VAL_0]] : memref // CHECK: return %[[VAL_19]], %[[VAL_21]], %[[VAL_22]] : tensor<6xf64>, tensor<6x2xi32>, index // CHECK: } func.func @sparse_unpack(%sp: tensor<100x100xf64, #COO>) -> (tensor<6xf64>, tensor<6x2xi32>, index) { diff --git 
a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir index 2b86d56..34f0188 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir @@ -31,6 +31,10 @@ crdWidth = 32 }> +#BCOO = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed-hi-nu", "singleton" ] +}> + module { // // Main driver. @@ -60,6 +64,25 @@ module { %s4 = sparse_tensor.pack %data, %index : tensor<3xf64>, tensor<3x2xindex> to tensor<10x10xf64, #SortedCOO> + %s5= sparse_tensor.pack %data, %index32 : tensor<3xf64>, tensor<3x2xi32> + to tensor<10x10xf64, #SortedCOOI32> + + %bdata = arith.constant dense< + [[ 1.0, 2.0, 3.0], + [ 4.0, 5.0, 0.0]] + > : tensor<2x3xf64> + + %bindex = arith.constant dense< + [[[ 1, 2], + [ 5, 6], + [ 7, 8]], + [[ 2, 3], + [ 4, 2], + [ 10, 10]]] + > : tensor<2x3x2xindex> + %bs = sparse_tensor.pack %bdata, %bindex batched_lvls = 1 : + tensor<2x3xf64>, tensor<2x3x2xindex> to tensor<2x10x10xf64, #BCOO> + // CHECK:1 // CHECK-NEXT:2 // CHECK-NEXT:1 @@ -78,8 +101,6 @@ module { vector.print %v: f64 } - %s5= sparse_tensor.pack %data, %index32 : tensor<3xf64>, tensor<3x2xi32> - to tensor<10x10xf64, #SortedCOOI32> // CHECK-NEXT:1 // CHECK-NEXT:2 // CHECK-NEXT:1 @@ -98,11 +119,23 @@ module { vector.print %v: f64 } + // CHECK-NEXT:1 + // CHECK-NEXT:2 + // CHECK-NEXT:3 + // + // CHECK-NEXT:4 + // CHECK-NEXT:5 + // + // Make sure the trailing zeros are not traversed. + // CHECK-NOT: 0 + sparse_tensor.foreach in %bs : tensor<2x10x10xf64, #BCOO> do { + ^bb0(%0: index, %1: index, %2: index, %v: f64) : + vector.print %v: f64 + } + %d, %i, %n = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32> to tensor<3xf64>, tensor<3x2xi32>, i32 - - // CHECK-NEXT: ( 1, 2, 3 ) %vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64> vector.print %vd : vector<3xf64> @@ -114,8 +147,26 @@ module { // CHECK-NEXT: 3 vector.print %n : i32 + + %bd, %bi, %bn = sparse_tensor.unpack %bs batched_lvls=1 : + tensor<2x10x10xf64, #BCOO> to tensor<2x3xf64>, tensor<2x3x2xindex>, i32 + + // CHECK-NEXT: ( ( 1, 2, 3 ), ( 4, 5, 0 ) ) + %vbd = vector.transfer_read %bd[%c0, %c0], %f0 : tensor<2x3xf64>, vector<2x3xf64> + vector.print %vbd : vector<2x3xf64> + + // CHECK-NEXT: ( ( ( 1, 2 ), ( 5, 6 ), ( 7, 8 ) ), ( ( 2, 3 ), ( 4, 2 ), ( 0, 0 ) ) ) + %vbi = vector.transfer_read %bi[%c0, %c0, %c0], %c0 : tensor<2x3x2xindex>, vector<2x3x2xindex> + vector.print %vbi : vector<2x3x2xindex> + + // CHECK-NEXT: 3 + vector.print %bn : i32 + %d1, %i1, %n1 = sparse_tensor.unpack %s4 : tensor<10x10xf64, #SortedCOO> to tensor<3xf64>, tensor<3x2xindex>, index + // FIXME: This should be freed by one-shot-bufferization. + bufferization.dealloc_tensor %bd : tensor<2x3xf64> + bufferization.dealloc_tensor %bi : tensor<2x3x2xindex> return } } -- 2.7.4
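
For reference, below is a minimal end-to-end sketch of the `batched_lvls` round trip introduced by this patch, condensed from the integration test above. The `#BCOO` encoding, the constants, and the `pack`/`unpack` syntax are taken from that test; the function name is purely illustrative, and this is a sketch rather than an additional test in the patch.

```mlir
// Illustrative only: batched pack followed by batched unpack, as exercised
// by mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir above.
#BCOO = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed-hi-nu", "singleton" ]
}>

func.func @pack_unpack_batched() -> (tensor<2x3xf64>, tensor<2x3x2xindex>, i32) {
  // Two batches of COO data: batch 0 holds 3 entries, batch 1 holds 2
  // (its third value/coordinate slot is padding).
  %bdata = arith.constant dense<
    [[ 1.0, 2.0, 3.0 ],
     [ 4.0, 5.0, 0.0 ]]> : tensor<2x3xf64>
  %bindex = arith.constant dense<
    [[[ 1,  2], [ 5,  6], [ 7,  8]],
     [[ 2,  3], [ 4,  2], [10, 10]]]> : tensor<2x3x2xindex>

  // The leading level is a batch; each 10x10 slice is packed as COO.
  %bs = sparse_tensor.pack %bdata, %bindex batched_lvls = 1
      : tensor<2x3xf64>, tensor<2x3x2xindex> to tensor<2x10x10xf64, #BCOO>

  // Unpack each batch separately; %bn is the maximum nse over all batches,
  // and batches with fewer entries get trailing zeros in the results.
  %bd, %bi, %bn = sparse_tensor.unpack %bs batched_lvls = 1
      : tensor<2x10x10xf64, #BCOO> to tensor<2x3xf64>, tensor<2x3x2xindex>, i32

  return %bd, %bi, %bn : tensor<2x3xf64>, tensor<2x3x2xindex>, i32
}
```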