From 80b08b68f219949a6479aeff6a54e3e5129ce7dc Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Wed, 26 Oct 2022 15:07:18 -0700 Subject: [PATCH] [mlir][sparse] add a cursor to sparse storage scheme This prepare a subsequent revision that will generalize the insertion code generation. Similar to the support lib, insertions become much easier to perform with some "cursor" bookkeeping. Note that we, in the long run, could perhaps avoid storing the "cursor" permanently and use some retricted-scope solution (alloca?) instead. However, that puts harder restrictions on insertion-chain operations, so for now we follow the more straightforward approach. Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D136800 --- .../Transforms/SparseTensorCodegen.cpp | 58 +-- mlir/test/Dialect/SparseTensor/codegen.mlir | 402 +++++++++++---------- mlir/test/Dialect/SparseTensor/invalid.mlir | 10 + .../Dialect/SparseTensor/scf_1_N_conversion.mlir | 35 +- 4 files changed, 283 insertions(+), 222 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index bf2f77d..6baf360 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -31,6 +31,11 @@ using namespace mlir::sparse_tensor; namespace { +static constexpr uint64_t DimSizesIdx = 0; +static constexpr uint64_t DimCursorIdx = 1; +static constexpr uint64_t MemSizesIdx = 2; +static constexpr uint64_t FieldsIdx = 3; + //===----------------------------------------------------------------------===// // Helper methods. //===----------------------------------------------------------------------===// @@ -90,11 +95,17 @@ static Optional sizeFromTensorAtDim(OpBuilder &rewriter, Location loc, .getResult(); } +/// Translates field index to memSizes index. +static unsigned getMemSizesIndex(unsigned field) { + assert(FieldsIdx <= field); + return field - FieldsIdx; +} + /// Returns field index of sparse tensor type for pointers/indices, when set. static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) { assert(getSparseTensorEncoding(type)); RankedTensorType rType = type.cast(); - unsigned field = 2; // start past sizes + unsigned field = FieldsIdx; // start past header unsigned ptr = 0; unsigned idx = 0; for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) { @@ -140,6 +151,7 @@ convertSparseTensorType(Type type, SmallVectorImpl &fields) { // // struct { // memref dimSizes ; size in each dimension + // memref dimCursor ; cursor in each dimension // memref memSizes ; sizes of ptrs/inds/values // ; per-dimension d: // ; if dense: @@ -153,11 +165,11 @@ convertSparseTensorType(Type type, SmallVectorImpl &fields) { // }; // unsigned rank = rType.getShape().size(); - // The dimSizes array. - fields.push_back(MemRefType::get({rank}, indexType)); - // The memSizes array. unsigned lastField = getFieldIndex(type, -1u, -1u); - fields.push_back(MemRefType::get({lastField - 2}, indexType)); + // The dimSizes array, dimCursor array, and memSizes array. + fields.push_back(MemRefType::get({rank}, indexType)); + fields.push_back(MemRefType::get({rank}, indexType)); + fields.push_back(MemRefType::get({getMemSizesIndex(lastField)}, indexType)); // Per-dimension storage. for (unsigned r = 0; r < rank; r++) { // Dimension level types apply in order to the reordered dimension. @@ -179,7 +191,7 @@ convertSparseTensorType(Type type, SmallVectorImpl &fields) { return success(); } -/// Create allocation operation. +/// Creates allocation operation. static Value createAllocation(OpBuilder &builder, Location loc, Type type, Value sz) { auto memType = MemRefType::get({ShapedType::kDynamicSize}, type); @@ -220,14 +232,16 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type, else sizes.push_back(constantIndex(builder, loc, shape[r])); } - // The dimSizes array. + // The dimSizes array, dimCursor array, and memSizes array. + unsigned lastField = getFieldIndex(type, -1u, -1u); Value dimSizes = builder.create(loc, MemRefType::get({rank}, indexType)); - fields.push_back(dimSizes); - // The sizes array. - unsigned lastField = getFieldIndex(type, -1u, -1u); + Value dimCursor = + builder.create(loc, MemRefType::get({rank}, indexType)); Value memSizes = builder.create( - loc, MemRefType::get({lastField - 2}, indexType)); + loc, MemRefType::get({getMemSizesIndex(lastField)}, indexType)); + fields.push_back(dimSizes); + fields.push_back(dimCursor); fields.push_back(memSizes); // Per-dimension storage. for (unsigned r = 0; r < rank; r++) { @@ -277,23 +291,17 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count, return forOp; } -/// Translates field index to memSizes index. -static unsigned getMemSizesIndex(unsigned field) { - assert(2 <= field); - return field - 2; -} - /// Creates a pushback op for given field and updates the fields array /// accordingly. static void createPushback(OpBuilder &builder, Location loc, SmallVectorImpl &fields, unsigned field, Value value) { - assert(2 <= field && field < fields.size()); + assert(FieldsIdx <= field && field < fields.size()); Type etp = fields[field].getType().cast().getElementType(); if (value.getType() != etp) value = builder.create(loc, etp, value); fields[field] = builder.create( - loc, fields[field].getType(), fields[1], fields[field], value, + loc, fields[field].getType(), fields[MemSizesIdx], fields[field], value, APInt(64, getMemSizesIndex(field))); } @@ -312,8 +320,8 @@ static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp, return; // TODO: add codegen // push_back memSizes indices-0 index // push_back memSizes values value - createPushback(builder, loc, fields, 3, indices[0]); - createPushback(builder, loc, fields, 4, value); + createPushback(builder, loc, fields, FieldsIdx + 1, indices[0]); + createPushback(builder, loc, fields, FieldsIdx + 2, value); } /// Generations insertion finalization code. @@ -329,9 +337,9 @@ static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp, // push_back memSizes pointers-0 memSizes[2] Value zero = constantIndex(builder, loc, 0); Value two = constantIndex(builder, loc, 2); - Value size = builder.create(loc, fields[1], two); - createPushback(builder, loc, fields, 2, zero); - createPushback(builder, loc, fields, 2, size); + Value size = builder.create(loc, fields[MemSizesIdx], two); + createPushback(builder, loc, fields, FieldsIdx, zero); + createPushback(builder, loc, fields, FieldsIdx, size); } //===----------------------------------------------------------------------===// @@ -759,7 +767,7 @@ public: unsigned lastField = fields.size() - 1; Value field = constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField)); - rewriter.replaceOpWithNewOp(op, fields[1], field); + rewriter.replaceOpWithNewOp(op, fields[MemSizesIdx], field); return success(); } }; diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir index 71f736d..98009ff 100644 --- a/mlir/test/Dialect/SparseTensor/codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/codegen.mlir @@ -48,27 +48,33 @@ // CHECK-LABEL: func @sparse_nop( // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref) +// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]] : +// CHECK-SAME: memref<1xindex>, memref<1xindex>, memref<3xindex>, memref, memref, memref func.func @sparse_nop(%arg0: tensor) -> tensor { return %arg0 : tensor } // CHECK-LABEL: func @sparse_nop_multi_ret( // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref<1xindex>, -// CHECK-SAME: %[[A6:.*6]]: memref<3xindex>, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: memref, -// CHECK-SAME: %[[A9:.*9]]: memref) -> -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]] +// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref<1xindex>, +// CHECK-SAME: %[[A7:.*7]]: memref<1xindex>, +// CHECK-SAME: %[[A8:.*8]]: memref<3xindex>, +// CHECK-SAME: %[[A9:.*9]]: memref, +// CHECK-SAME: %[[A10:.*10]]: memref, +// CHECK-SAME: %[[A11:.*11]]: memref) +// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]], %[[A10]], %[[A11]] : +// CHECK-SAME: memref<1xindex>, memref<1xindex>, memref<3xindex>, memref, memref, memref, +// CHECK-SAME: memref<1xindex>, memref<1xindex>, memref<3xindex>, memref, memref, memref func.func @sparse_nop_multi_ret(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { @@ -77,17 +83,21 @@ func.func @sparse_nop_multi_ret(%arg0: tensor, // CHECK-LABEL: func @sparse_nop_call( // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref<1xindex>, -// CHECK-SAME: %[[A6:.*6]]: memref<3xindex>, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: memref, -// CHECK-SAME: %[[A9:.*9]]: memref) -// CHECK: %[[T0:.*]]:10 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]]) -// CHECK: return %[[T0]]#0, %[[T0]]#1, %[[T0]]#2, %[[T0]]#3, %[[T0]]#4, %[[T0]]#5, %[[T0]]#6, %[[T0]]#7, %[[T0]]#8, %[[T0]]#9 +// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref<1xindex>, +// CHECK-SAME: %[[A7:.*7]]: memref<1xindex>, +// CHECK-SAME: %[[A8:.*8]]: memref<3xindex>, +// CHECK-SAME: %[[A9:.*9]]: memref, +// CHECK-SAME: %[[A10:.*10]]: memref, +// CHECK-SAME: %[[A11:.*11]]: memref) +// CHECK: %[[T:.*]]:12 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]], %[[A10]], %[[A11]]) +// CHECK: return %[[T]]#0, %[[T]]#1, %[[T]]#2, %[[T]]#3, %[[T]]#4, %[[T]]#5, %[[T]]#6, %[[T]]#7, %[[T]]#8, %[[T]]#9, %[[T]]#10, %[[T]]#11 +// CHECK-SAME: memref<1xindex>, memref<1xindex>, memref<3xindex>, memref, memref, memref, +// CHECK-SAME: memref<1xindex>, memref<1xindex>, memref<3xindex>, memref, memref, memref func.func @sparse_nop_call(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { @@ -99,11 +109,13 @@ func.func @sparse_nop_call(%arg0: tensor, // CHECK-LABEL: func @sparse_nop_cast( // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref) +// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]] : +// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor { %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor return %0 : tensor @@ -111,9 +123,11 @@ func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor, -// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref) -// CHECK: return %[[A0]], %[[A1]], %[[A2]] : memref<3xindex>, memref<1xindex>, memref +// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<1xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref) +// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] : +// CHECK-SAME: memref<3xindex>, memref<3xindex>, memref<1xindex>, memref func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor { %0 = tensor.cast %arg0 : tensor<10x20x30xf32, #Dense3D> to tensor return %0 : tensor @@ -121,8 +135,9 @@ func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor, -// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref) +// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<1xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref) // CHECK: return func.func @sparse_dense_2d(%arg0: tensor) { return @@ -130,10 +145,11 @@ func.func @sparse_dense_2d(%arg0: tensor) { // CHECK-LABEL: func @sparse_row( // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) +// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref) // CHECK: return func.func @sparse_row(%arg0: tensor) { return @@ -141,10 +157,11 @@ func.func @sparse_row(%arg0: tensor) { // CHECK-LABEL: func @sparse_csr( // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) +// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref) // CHECK: return func.func @sparse_csr(%arg0: tensor) { return @@ -152,12 +169,13 @@ func.func @sparse_csr(%arg0: tensor) { // CHECK-LABEL: func @sparse_dcsr( // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref) +// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<5xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: memref) // CHECK: return func.func @sparse_dcsr(%arg0: tensor) { return @@ -169,8 +187,9 @@ func.func @sparse_dcsr(%arg0: tensor) { // // CHECK-LABEL: func @sparse_dense_3d( // CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref) +// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<1xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref) // CHECK: %[[C:.*]] = arith.constant 20 : index // CHECK: return %[[C]] : index func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index { @@ -186,8 +205,9 @@ func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index { // // CHECK-LABEL: func @sparse_dense_3d_dyn( // CHECK-SAME: %[[A0:.*0]]: memref<3xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref) +// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<1xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref) // CHECK: %[[C:.*]] = arith.constant 2 : index // CHECK: %[[L:.*]] = memref.load %[[A0]][%[[C]]] : memref<3xindex> // CHECK: return %[[L]] : index @@ -199,13 +219,14 @@ func.func @sparse_dense_3d_dyn(%arg0: tensor) -> index { // CHECK-LABEL: func @sparse_pointers_dcsr( // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref) -// CHECK: return %[[A4]] : memref +// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<5xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: memref) +// CHECK: return %[[A5]] : memref func.func @sparse_pointers_dcsr(%arg0: tensor) -> memref { %0 = sparse_tensor.pointers %arg0 { dimension = 1 : index } : tensor to memref return %0 : memref @@ -213,13 +234,14 @@ func.func @sparse_pointers_dcsr(%arg0: tensor) -> memref // CHECK-LABEL: func @sparse_indices_dcsr( // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref) -// CHECK: return %[[A5]] : memref +// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<5xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: memref) +// CHECK: return %[[A6]] : memref func.func @sparse_indices_dcsr(%arg0: tensor) -> memref { %0 = sparse_tensor.indices %arg0 { dimension = 1 : index } : tensor to memref return %0 : memref @@ -227,13 +249,14 @@ func.func @sparse_indices_dcsr(%arg0: tensor) -> memref { // CHECK-LABEL: func @sparse_values_dcsr( // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<5xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref) -// CHECK: return %[[A6]] : memref +// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<5xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: memref) +// CHECK: return %[[A7]] : memref func.func @sparse_values_dcsr(%arg0: tensor) -> memref { %0 = sparse_tensor.values %arg0 : tensor to memref return %0 : memref @@ -241,12 +264,13 @@ func.func @sparse_values_dcsr(%arg0: tensor) -> memref { // CHECK-LABEL: func @sparse_noe( // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) +// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref) // CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[NOE:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex> +// CHECK: %[[NOE:.*]] = memref.load %[[A2]][%[[C2]]] : memref<3xindex> // CHECK: return %[[NOE]] : index func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index { %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector> @@ -255,15 +279,17 @@ func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index { // CHECK-LABEL: func @sparse_dealloc_csr( // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref) +// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref) // CHECK: memref.dealloc %[[A0]] : memref<2xindex> -// CHECK: memref.dealloc %[[A1]] : memref<3xindex> -// CHECK: memref.dealloc %[[A2]] : memref -// CHECK: memref.dealloc %[[A3]] : memref -// CHECK: memref.dealloc %[[A4]] : memref +// CHECK: memref.dealloc %[[A1]] : memref<2xindex> +// CHECK: memref.dealloc %[[A2]] : memref<3xindex> +// CHECK: memref.dealloc %[[A3]] : memref +// CHECK: memref.dealloc %[[A4]] : memref +// CHECK: memref.dealloc %[[A5]] : memref // CHECK: return func.func @sparse_dealloc_csr(%arg0: tensor) { bufferization.dealloc_tensor %arg0 : tensor @@ -272,11 +298,12 @@ func.func @sparse_dealloc_csr(%arg0: tensor) { // CHECK-LABEL: func @sparse_alloc_csc( // CHECK-SAME: %[[A:.*]]: index) -> -// CHECK-SAME: memref<2xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: memref<2xindex>, memref<2xindex>, memref<3xindex>, memref, memref, memref // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index // CHECK: %[[T0:.*]] = memref.alloc() : memref<2xindex> +// CHECK: %[[CC:.*]] = memref.alloc() : memref<2xindex> // CHECK: %[[T1:.*]] = memref.alloc() : memref<3xindex> // CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<2xindex> // CHECK: memref.store %[[C10]], %[[T0]][%[[C1]]] : memref<2xindex> @@ -287,7 +314,8 @@ func.func @sparse_dealloc_csr(%arg0: tensor) { // CHECK: %[[T6:.*]] = memref.alloc() : memref<1xf64> // CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<1xf64> to memref // CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>) -// CHECK: return %[[T0]], %[[T1]], %[[T3]], %[[T5]], %[[T7]] +// CHECK: return %[[T0]], %[[CC]], %[[T1]], %[[T3]], %[[T5]], %[[T7]] : +// CHECK-SAME: memref<2xindex>, memref<2xindex>, memref<3xindex>, memref, memref, memref func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> { %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC> %1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC> @@ -295,7 +323,7 @@ func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> { } // CHECK-LABEL: func @sparse_alloc_3d() -> -// CHECK-SAME: memref<3xindex>, memref<1xindex>, memref +// CHECK-SAME: memref<3xindex>, memref<3xindex>, memref<1xindex>, memref // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index @@ -304,6 +332,7 @@ func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> { // CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index // CHECK-DAG: %[[C6000:.*]] = arith.constant 6000 : index // CHECK: %[[A0:.*]] = memref.alloc() : memref<3xindex> +// CHECK: %[[CC:.*]] = memref.alloc() : memref<3xindex> // CHECK: %[[A1:.*]] = memref.alloc() : memref<1xindex> // CHECK: memref.store %[[C30]], %[[A0]][%[[C0]]] : memref<3xindex> // CHECK: memref.store %[[C10]], %[[A0]][%[[C1]]] : memref<3xindex> @@ -311,7 +340,8 @@ func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> { // CHECK: %[[A:.*]] = memref.alloc() : memref<6000xf64> // CHECK: %[[A2:.*]] = memref.cast %[[A]] : memref<6000xf64> to memref // CHECK: memref.store %[[C6000]], %[[A1]][%[[C0]]] : memref<1xindex> -// CHECK: return %[[A0]], %[[A1]], %[[A2]] : memref<3xindex>, memref<1xindex>, memref +// CHECK: return %[[A0]], %[[CC]], %[[A1]], %[[A2]] : +// CHECK-SAME: memref<3xindex>, memref<3xindex>, memref<1xindex>, memref func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> { %0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D> %1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D> @@ -370,36 +400,38 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref { // CHECK-LABEL: func @sparse_compression_1d( // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>, // CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, // CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: index) -// CHECK-DAG: %[[B0:.*]] = arith.constant false -// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref -// CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A3]], %[[P1:.*]] = %[[A4]]) -> (memref, memref) { -// CHECK: %[[T1:.*]] = memref.load %[[A7]][%[[I]]] : memref -// CHECK: %[[T2:.*]] = memref.load %[[A5]][%[[T1]]] : memref -// CHECK: %[[T3:.*]] = sparse_tensor.push_back %[[A1]], %[[P0]], %[[T1]] {idx = 1 : index} : memref<3xindex>, memref, index -// CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[T2]] {idx = 2 : index} : memref<3xindex>, memref, f64 -// CHECK: memref.store %[[F0]], %arg5[%[[T1]]] : memref -// CHECK: memref.store %[[B0]], %arg6[%[[T1]]] : memref -// CHECK: scf.yield %[[T3]], %[[T4]] : memref, memref -// CHECK: } -// CHECK: memref.dealloc %[[A5]] : memref -// CHECK: memref.dealloc %[[A6]] : memref -// CHECK: memref.dealloc %[[A7]] : memref -// CHECK: %[[LL:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex> -// CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]] {idx = 0 : index} : memref<3xindex>, memref, index -// CHECK: %[[P2:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[LL]] {idx = 0 : index} : memref<3xindex>, memref, index -// CHECK: return %[[A0]], %[[A1]], %[[P2]], %[[R]]#0, %[[R]]#1 : memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: memref, +// CHECK-SAME: %[[A8:.*8]]: memref, +// CHECK-SAME: %[[A9:.*9]]: index) +// CHECK-DAG: %[[B0:.*]] = arith.constant false +// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK: sparse_tensor.sort %[[A9]], %[[A8]] : memref +// CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A9]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A4]], %[[P1:.*]] = %[[A5]]) -> (memref, memref) { +// CHECK: %[[T1:.*]] = memref.load %[[A8]][%[[I]]] : memref +// CHECK: %[[T2:.*]] = memref.load %[[A6]][%[[T1]]] : memref +// CHECK: %[[T3:.*]] = sparse_tensor.push_back %[[A2]], %[[P0]], %[[T1]] {idx = 1 : index} : memref<3xindex>, memref, index +// CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A2]], %[[P1]], %[[T2]] {idx = 2 : index} : memref<3xindex>, memref, f64 +// CHECK: memref.store %[[F0]], %[[A6]][%[[T1]]] : memref +// CHECK: memref.store %[[B0]], %[[A7]][%[[T1]]] : memref +// CHECK: scf.yield %[[T3]], %[[T4]] : memref, memref +// CHECK: } +// CHECK: memref.dealloc %[[A6]] : memref +// CHECK: memref.dealloc %[[A7]] : memref +// CHECK: memref.dealloc %[[A8]] : memref +// CHECK: %[[LL:.*]] = memref.load %[[A2]][%[[C2]]] : memref<3xindex> +// CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[A2]], %[[A3]], %[[C0]] {idx = 0 : index} : memref<3xindex>, memref, index +// CHECK: %[[P2:.*]] = sparse_tensor.push_back %[[A2]], %[[P1]], %[[LL]] {idx = 0 : index} : memref<3xindex>, memref, index +// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[P2]], %[[R]]#0, %[[R]]#1 +// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>, %values: memref, %filled: memref, @@ -413,29 +445,30 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>, // CHECK-LABEL: func @sparse_compression( // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, // CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: index, -// CHECK-SAME: %[[A9:.*9]]: index) +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: memref, +// CHECK-SAME: %[[A8:.*8]]: memref, +// CHECK-SAME: %[[A9:.*9]]: index, +// CHECK-SAME: %[[A10:.*10]]: index) // CHECK-DAG: %[[B0:.*]] = arith.constant false // CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref -// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] { -// CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref +// CHECK: sparse_tensor.sort %[[A9]], %[[A8]] : memref +// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A9]] step %[[C1]] { +// CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A8]][%[[I]]] : memref // TODO: 2D-insert -// CHECK-DAG: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref -// CHECK-DAG: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref +// CHECK-DAG: memref.store %[[F0]], %[[A6]][%[[INDEX]]] : memref +// CHECK-DAG: memref.store %[[B0]], %[[A7]][%[[INDEX]]] : memref // CHECK-NEXT: } -// CHECK-DAG: memref.dealloc %[[A5]] : memref -// CHECK-DAG: memref.dealloc %[[A6]] : memref -// CHECK-DAG: memref.dealloc %[[A7]] : memref +// CHECK-DAG: memref.dealloc %[[A6]] : memref +// CHECK-DAG: memref.dealloc %[[A7]] : memref +// CHECK-DAG: memref.dealloc %[[A8]] : memref // CHECK: return func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>, %values: memref, @@ -451,29 +484,30 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>, // CHECK-LABEL: func @sparse_compression_unordered( // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref<2xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>, // CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, // CHECK-SAME: %[[A5:.*5]]: memref, -// CHECK-SAME: %[[A6:.*6]]: memref, -// CHECK-SAME: %[[A7:.*7]]: memref, -// CHECK-SAME: %[[A8:.*8]]: index, -// CHECK-SAME: %[[A9:.*9]]: index) +// CHECK-SAME: %[[A6:.*6]]: memref, +// CHECK-SAME: %[[A7:.*7]]: memref, +// CHECK-SAME: %[[A8:.*8]]: memref, +// CHECK-SAME: %[[A9:.*9]]: index, +// CHECK-SAME: %[[A10:.*10]]: index) // CHECK-DAG: %[[B0:.*]] = arith.constant false // CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-NOT: sparse_tensor.sort -// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] { -// CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref +// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A9]] step %[[C1]] { +// CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A8]][%[[I]]] : memref // TODO: 2D-insert -// CHECK-DAG: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref -// CHECK-DAG: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref +// CHECK-DAG: memref.store %[[F0]], %[[A6]][%[[INDEX]]] : memref +// CHECK-DAG: memref.store %[[B0]], %[[A7]][%[[INDEX]]] : memref // CHECK-NEXT: } -// CHECK-DAG: memref.dealloc %[[A5]] : memref -// CHECK-DAG: memref.dealloc %[[A6]] : memref -// CHECK-DAG: memref.dealloc %[[A7]] : memref +// CHECK-DAG: memref.dealloc %[[A6]] : memref +// CHECK-DAG: memref.dealloc %[[A7]] : memref +// CHECK-DAG: memref.dealloc %[[A8]] : memref // CHECK: return func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>, %values: memref, @@ -489,20 +523,22 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>, // CHECK-LABEL: func @sparse_insert( // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>, // CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: f64) +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: index, +// CHECK-SAME: %[[A7:.*7]]: f64) // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A1]], %[[A3]], %[[A5]] -// CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] -// CHECK: %[[T3:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex> -// CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]] -// CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[T0]], %[[T3]] -// CHECK: return %[[A0]], %[[A1]], %[[T4]], %[[T1]], %[[T2]] : memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A2]], %[[A4]], %[[A6]] +// CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A2]], %[[A5]], %[[A7]] +// CHECK: %[[T3:.*]] = memref.load %[[A2]][%[[C2]]] : memref<3xindex> +// CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A2]], %[[A3]], %[[C0]] +// CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A2]], %[[T0]], %[[T3]] +// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[T4]], %[[T1]], %[[T2]] : +// CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> { %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV> %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV> @@ -511,35 +547,39 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) // CHECK-LABEL: func @sparse_insert_typed( // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*2]]: memref, -// CHECK-SAME: %[[A3:.*3]]: memref, -// CHECK-SAME: %[[A4:.*4]]: memref, -// CHECK-SAME: %[[A5:.*5]]: index, -// CHECK-SAME: %[[A6:.*6]]: f64) +// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: index, +// CHECK-SAME: %[[A7:.*7]]: f64) // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[S1:.*]] = arith.index_cast %[[A5]] : index to i64 -// CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A1]], %[[A3]], %[[S1]] -// CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]] -// CHECK: %[[T3:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex> -// CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]] +// CHECK: %[[S1:.*]] = arith.index_cast %[[A6]] : index to i64 +// CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A2]], %[[A4]], %[[S1]] +// CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A2]], %[[A5]], %[[A7]] +// CHECK: %[[T3:.*]] = memref.load %[[A2]][%[[C2]]] : memref<3xindex> +// CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A2]], %[[A3]], %[[C0]] // CHECK: %[[S2:.*]] = arith.index_cast %[[T3]] : index to i32 -// CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[T0]], %[[S2]] -// CHECK: return %[[A0]], %[[A1]], %[[T4]], %[[T1]], %[[T2]] : memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A2]], %[[T0]], %[[S2]] +// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[T4]], %[[T1]], %[[T2]] : +// CHECK-SAME: memref<1xindex>, memref<1xindex>, memref<3xindex>, memref, memref, memref func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> { %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector> %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SparseVector> return %1 : tensor<128xf64, #SparseVector> } -// CHECK-LABEL: func.func @sparse_nop_convert( -// CHECK-SAME: %[[A0:.*]]: memref<1xindex>, -// CHECK-SAME: %[[A1:.*]]: memref<3xindex>, -// CHECK-SAME: %[[A2:.*]]: memref, -// CHECK-SAME: %[[A3:.*]]: memref, -// CHECK-SAME: %[[A4:.*]]: memref) -// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-LABEL: func.func @sparse_nop_convert( +// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref<1xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref<3xindex>, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref) +// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]] : +// CHECK-SAME: memref<1xindex>, memref<1xindex>, memref<3xindex>, memref, memref, memref func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor { %0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor return %0 : tensor diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir index 1591cb4..1ab4a66 100644 --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -574,3 +574,13 @@ func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %a sparse_tensor.sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8> return } + +// ----- + +#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}> + +func.func @sparse_alloc_escapes(%arg0: index) -> tensor<10x?xf64, #CSR> { + // expected-error@+1 {{sparse tensor allocation should not escape function}} + %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSR> + return %0: tensor<10x?xf64, #CSR> +} diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir index 13765eb..334d58c 100644 --- a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir @@ -3,22 +3,24 @@ #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> // CHECK-LABEL: func @for( // CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>, -// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>, -// CHECK-SAME: %[[POINTER:.*2]]: memref, -// CHECK-SAME: %[[INDICES:.*3]]: memref, -// CHECK-SAME: %[[VALUE:.*4]]: memref, -// CHECK-SAME: %[[TMP_arg5:.*5]]: index, -// CHECK-SAME: %[[TMP_arg6:.*6]]: index, -// CHECK-SAME: %[[TMP_arg7:.*7]]: index -// CHECK: %[[TMP_0:.*]]:5 = scf.for %[[TMP_arg8:.*]] = %[[TMP_arg5]] to %[[TMP_arg6]] step %[[TMP_arg7]] iter_args( -// CHECK-SAME: %[[TMP_arg9:.*]] = %[[DIM_SIZE]], -// CHECK-SAME: %[[TMP_arg10:.*]] = %[[MEM_SIZE]], -// CHECK-SAME: %[[TMP_arg11:.*]] = %[[POINTER]], -// CHECK-SAME: %[[TMP_arg12:.*]] = %[[INDICES]], -// CHECK-SAME: %[[TMP_arg13:.*]] = %[[VALUE]]) -// CHECK: scf.yield %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]], %[[TMP_arg13]] : memref<1xindex>, memref<3xindex>, memref, memref, memref -// CHECK: } -// CHECK: return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2, %[[TMP_0]]#3, %[[TMP_0]]#4 : memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK-SAME: %[[DIM_CURSOR:.*1]]: memref<1xindex>, +// CHECK-SAME: %[[MEM_SIZE:.*2]]: memref<3xindex>, +// CHECK-SAME: %[[POINTER:.*3]]: memref, +// CHECK-SAME: %[[INDICES:.*4]]: memref, +// CHECK-SAME: %[[VALUE:.*5]]: memref, +// CHECK-SAME: %[[LB:.*6]]: index, +// CHECK-SAME: %[[UB:.*7]]: index, +// CHECK-SAME: %[[STEP:.*8]]: index) +// CHECK: %[[OUT:.*]]:6 = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args( +// CHECK-SAME: %[[SIZE:.*]] = %[[DIM_SIZE]], +// CHECK-SAME: %[[CUR:.*]] = %[[DIM_CURSOR]], +// CHECK-SAME: %[[MEM:.*]] = %[[MEM_SIZE]], +// CHECK-SAME: %[[PTR:.*]] = %[[POINTER]], +// CHECK-SAME: %[[IDX:.*]] = %[[INDICES]], +// CHECK-SAME: %[[VAL:.*]] = %[[VALUE]]) +// CHECK: scf.yield %[[SIZE]], %[[CUR]], %[[MEM]], %[[PTR]], %[[IDX]], %[[VAL]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK: } +// CHECK: return %[[OUT]]#0, %[[OUT]]#1, %[[OUT]]#2, %[[OUT]]#3, %[[OUT]]#4, %[[OUT]]#5 : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref, memref, memref func.func @for(%in: tensor<1024xf32, #SparseVector>, %lb: index, %ub: index, %step: index) -> tensor<1024xf32, #SparseVector> { %1 = scf.for %i = %lb to %ub step %step iter_args(%vin = %in) @@ -27,3 +29,4 @@ func.func @for(%in: tensor<1024xf32, #SparseVector>, } return %1 : tensor<1024xf32, #SparseVector> } + -- 2.7.4