[mlir][sparse] add a cursor to sparse storage scheme
authorAart Bik <ajcbik@google.com>
Wed, 26 Oct 2022 22:07:18 +0000 (15:07 -0700)
committerAart Bik <ajcbik@google.com>
Thu, 27 Oct 2022 18:18:50 +0000 (11:18 -0700)
This prepare a subsequent revision that will generalize
the insertion code generation. Similar to the support lib,
insertions become much easier to perform with some "cursor"
bookkeeping. Note that we, in the long run, could perhaps
avoid storing the "cursor" permanently and use some
retricted-scope solution (alloca?) instead. However,
that puts harder restrictions on insertion-chain operations,
so for now we follow the more straightforward approach.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D136800

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir

index bf2f77d..6baf360 100644 (file)
@@ -31,6 +31,11 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
+static constexpr uint64_t DimSizesIdx = 0;
+static constexpr uint64_t DimCursorIdx = 1;
+static constexpr uint64_t MemSizesIdx = 2;
+static constexpr uint64_t FieldsIdx = 3;
+
 //===----------------------------------------------------------------------===//
 // Helper methods.
 //===----------------------------------------------------------------------===//
@@ -90,11 +95,17 @@ static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
       .getResult();
 }
 
+/// Translates field index to memSizes index.
+static unsigned getMemSizesIndex(unsigned field) {
+  assert(FieldsIdx <= field);
+  return field - FieldsIdx;
+}
+
 /// Returns field index of sparse tensor type for pointers/indices, when set.
 static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
   assert(getSparseTensorEncoding(type));
   RankedTensorType rType = type.cast<RankedTensorType>();
-  unsigned field = 2; // start past sizes
+  unsigned field = FieldsIdx; // start past header
   unsigned ptr = 0;
   unsigned idx = 0;
   for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
@@ -140,6 +151,7 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
   //
   // struct {
   //   memref<rank x index> dimSizes     ; size in each dimension
+  //   memref<rank x index> dimCursor    ; cursor in each dimension
   //   memref<n x index> memSizes        ; sizes of ptrs/inds/values
   //   ; per-dimension d:
   //   ;  if dense:
@@ -153,11 +165,11 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
   // };
   //
   unsigned rank = rType.getShape().size();
-  // The dimSizes array.
-  fields.push_back(MemRefType::get({rank}, indexType));
-  // The memSizes array.
   unsigned lastField = getFieldIndex(type, -1u, -1u);
-  fields.push_back(MemRefType::get({lastField - 2}, indexType));
+  // The dimSizes array, dimCursor array, and memSizes array.
+  fields.push_back(MemRefType::get({rank}, indexType));
+  fields.push_back(MemRefType::get({rank}, indexType));
+  fields.push_back(MemRefType::get({getMemSizesIndex(lastField)}, indexType));
   // Per-dimension storage.
   for (unsigned r = 0; r < rank; r++) {
     // Dimension level types apply in order to the reordered dimension.
@@ -179,7 +191,7 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
   return success();
 }
 
-/// Create allocation operation.
+/// Creates allocation operation.
 static Value createAllocation(OpBuilder &builder, Location loc, Type type,
                               Value sz) {
   auto memType = MemRefType::get({ShapedType::kDynamicSize}, type);
@@ -220,14 +232,16 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
     else
       sizes.push_back(constantIndex(builder, loc, shape[r]));
   }
-  // The dimSizes array.
+  // The dimSizes array, dimCursor array, and memSizes array.
+  unsigned lastField = getFieldIndex(type, -1u, -1u);
   Value dimSizes =
       builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
-  fields.push_back(dimSizes);
-  // The sizes array.
-  unsigned lastField = getFieldIndex(type, -1u, -1u);
+  Value dimCursor =
+      builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
   Value memSizes = builder.create<memref::AllocOp>(
-      loc, MemRefType::get({lastField - 2}, indexType));
+      loc, MemRefType::get({getMemSizesIndex(lastField)}, indexType));
+  fields.push_back(dimSizes);
+  fields.push_back(dimCursor);
   fields.push_back(memSizes);
   // Per-dimension storage.
   for (unsigned r = 0; r < rank; r++) {
@@ -277,23 +291,17 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count,
   return forOp;
 }
 
-/// Translates field index to memSizes index.
-static unsigned getMemSizesIndex(unsigned field) {
-  assert(2 <= field);
-  return field - 2;
-}
-
 /// Creates a pushback op for given field and updates the fields array
 /// accordingly.
 static void createPushback(OpBuilder &builder, Location loc,
                            SmallVectorImpl<Value> &fields, unsigned field,
                            Value value) {
-  assert(2 <= field && field < fields.size());
+  assert(FieldsIdx <= field && field < fields.size());
   Type etp = fields[field].getType().cast<ShapedType>().getElementType();
   if (value.getType() != etp)
     value = builder.create<arith::IndexCastOp>(loc, etp, value);
   fields[field] = builder.create<PushBackOp>(
-      loc, fields[field].getType(), fields[1], fields[field], value,
+      loc, fields[field].getType(), fields[MemSizesIdx], fields[field], value,
       APInt(64, getMemSizesIndex(field)));
 }
 
@@ -312,8 +320,8 @@ static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
     return; // TODO: add codegen
   // push_back memSizes indices-0 index
   // push_back memSizes values    value
-  createPushback(builder, loc, fields, 3, indices[0]);
-  createPushback(builder, loc, fields, 4, value);
+  createPushback(builder, loc, fields, FieldsIdx + 1, indices[0]);
+  createPushback(builder, loc, fields, FieldsIdx + 2, value);
 }
 
 /// Generations insertion finalization code.
@@ -329,9 +337,9 @@ static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
   // push_back memSizes pointers-0 memSizes[2]
   Value zero = constantIndex(builder, loc, 0);
   Value two = constantIndex(builder, loc, 2);
-  Value size = builder.create<memref::LoadOp>(loc, fields[1], two);
-  createPushback(builder, loc, fields, 2, zero);
-  createPushback(builder, loc, fields, 2, size);
+  Value size = builder.create<memref::LoadOp>(loc, fields[MemSizesIdx], two);
+  createPushback(builder, loc, fields, FieldsIdx, zero);
+  createPushback(builder, loc, fields, FieldsIdx, size);
 }
 
 //===----------------------------------------------------------------------===//
@@ -759,7 +767,7 @@ public:
     unsigned lastField = fields.size() - 1;
     Value field =
         constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField));
-    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[1], field);
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[MemSizesIdx], field);
     return success();
   }
 };
index 71f736d..98009ff 100644 (file)
 
 // CHECK-LABEL: func @sparse_nop(
 //  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
-//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
+//  CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
+//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]] :
+//  CHECK-SAME:   memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
 func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
   return %arg0 : tensor<?xf64, #SparseVector>
 }
 
 // CHECK-LABEL: func @sparse_nop_multi_ret(
 //  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
-//  CHECK-SAME: %[[A5:.*5]]: memref<1xindex>,
-//  CHECK-SAME: %[[A6:.*6]]: memref<3xindex>,
-//  CHECK-SAME: %[[A7:.*7]]: memref<?xi32>,
-//  CHECK-SAME: %[[A8:.*8]]: memref<?xi64>,
-//  CHECK-SAME: %[[A9:.*9]]: memref<?xf64>) ->
-//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]]
+//  CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
+//  CHECK-SAME: %[[A6:.*6]]: memref<1xindex>,
+//  CHECK-SAME: %[[A7:.*7]]: memref<1xindex>,
+//  CHECK-SAME: %[[A8:.*8]]: memref<3xindex>,
+//  CHECK-SAME: %[[A9:.*9]]: memref<?xi32>,
+//  CHECK-SAME: %[[A10:.*10]]: memref<?xi64>,
+//  CHECK-SAME: %[[A11:.*11]]: memref<?xf64>)
+//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]], %[[A10]], %[[A11]] :
+//  CHECK-SAME:   memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>,
+//  CHECK-SAME:   memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
 func.func @sparse_nop_multi_ret(%arg0: tensor<?xf64, #SparseVector>,
                                 %arg1: tensor<?xf64, #SparseVector>) ->
                                 (tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>) {
@@ -77,17 +83,21 @@ func.func @sparse_nop_multi_ret(%arg0: tensor<?xf64, #SparseVector>,
 
 // CHECK-LABEL: func @sparse_nop_call(
 //  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
-//  CHECK-SAME: %[[A5:.*5]]: memref<1xindex>,
-//  CHECK-SAME: %[[A6:.*6]]: memref<3xindex>,
-//  CHECK-SAME: %[[A7:.*7]]: memref<?xi32>,
-//  CHECK-SAME: %[[A8:.*8]]: memref<?xi64>,
-//  CHECK-SAME: %[[A9:.*9]]: memref<?xf64>)
-//       CHECK: %[[T0:.*]]:10 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]])
-//       CHECK: return %[[T0]]#0, %[[T0]]#1, %[[T0]]#2, %[[T0]]#3, %[[T0]]#4, %[[T0]]#5, %[[T0]]#6, %[[T0]]#7, %[[T0]]#8, %[[T0]]#9
+//  CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
+//  CHECK-SAME: %[[A6:.*6]]: memref<1xindex>,
+//  CHECK-SAME: %[[A7:.*7]]: memref<1xindex>,
+//  CHECK-SAME: %[[A8:.*8]]: memref<3xindex>,
+//  CHECK-SAME: %[[A9:.*9]]: memref<?xi32>,
+//  CHECK-SAME: %[[A10:.*10]]: memref<?xi64>,
+//  CHECK-SAME: %[[A11:.*11]]: memref<?xf64>)
+//       CHECK: %[[T:.*]]:12 = call @sparse_nop_multi_ret(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[A9]], %[[A10]], %[[A11]])
+//       CHECK: return %[[T]]#0, %[[T]]#1, %[[T]]#2, %[[T]]#3, %[[T]]#4, %[[T]]#5, %[[T]]#6, %[[T]]#7, %[[T]]#8, %[[T]]#9, %[[T]]#10, %[[T]]#11
+//  CHECK-SAME:   memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>,
+//  CHECK-SAME:   memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
 func.func @sparse_nop_call(%arg0: tensor<?xf64, #SparseVector>,
                            %arg1: tensor<?xf64, #SparseVector>) ->
                            (tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>) {
@@ -99,11 +109,13 @@ func.func @sparse_nop_call(%arg0: tensor<?xf64, #SparseVector>,
 
 // CHECK-LABEL: func @sparse_nop_cast(
 //  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-//  CHECK-SAME: %[[A4:.*4]]: memref<?xf32>)
-//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
+//  CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xf32>)
+//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]] :
+//  CHECK-SAME:   memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
 func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
   %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor<?xf32, #SparseVector>
   return %0 : tensor<?xf32, #SparseVector>
@@ -111,9 +123,11 @@ func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32
 
 // CHECK-LABEL: func @sparse_nop_cast_3d(
 //  CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xf32>)
-//       CHECK: return %[[A0]], %[[A1]], %[[A2]] : memref<3xindex>, memref<1xindex>, memref<?xf32>
+//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<1xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xf32>)
+//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]] :
+//  CHECK-SAME:   memref<3xindex>, memref<3xindex>, memref<1xindex>, memref<?xf32>
 func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor<?x?x?xf32, #Dense3D> {
   %0 = tensor.cast %arg0 : tensor<10x20x30xf32, #Dense3D> to tensor<?x?x?xf32, #Dense3D>
   return %0 : tensor<?x?x?xf32, #Dense3D>
@@ -121,8 +135,9 @@ func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor<?
 
 // CHECK-LABEL: func @sparse_dense_2d(
 //  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xf64>)
+//  CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<1xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xf64>)
 //       CHECK: return
 func.func @sparse_dense_2d(%arg0: tensor<?x?xf64, #Dense2D>) {
   return
@@ -130,10 +145,11 @@ func.func @sparse_dense_2d(%arg0: tensor<?x?xf64, #Dense2D>) {
 
 // CHECK-LABEL: func @sparse_row(
 //  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
+//  CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
 //       CHECK: return
 func.func @sparse_row(%arg0: tensor<?x?xf64, #Row>) {
   return
@@ -141,10 +157,11 @@ func.func @sparse_row(%arg0: tensor<?x?xf64, #Row>) {
 
 // CHECK-LABEL: func @sparse_csr(
 //  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
+//  CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
 //       CHECK: return
 func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
   return
@@ -152,12 +169,13 @@ func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
 
 // CHECK-LABEL: func @sparse_dcsr(
 //  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<5xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-//  CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
-//  CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
-//  CHECK-SAME: %[[A6:.*6]]: memref<?xf64>)
+//  CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<5xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xi32>,
+//  CHECK-SAME: %[[A6:.*6]]: memref<?xi64>,
+//  CHECK-SAME: %[[A7:.*7]]: memref<?xf64>)
 //       CHECK: return
 func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
   return
@@ -169,8 +187,9 @@ func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
 //
 // CHECK-LABEL: func @sparse_dense_3d(
 //  CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xf64>)
+//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<1xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xf64>)
 //       CHECK: %[[C:.*]] = arith.constant 20 : index
 //       CHECK: return %[[C]] : index
 func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
@@ -186,8 +205,9 @@ func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
 //
 // CHECK-LABEL: func @sparse_dense_3d_dyn(
 //  CHECK-SAME: %[[A0:.*0]]: memref<3xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xf64>)
+//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<1xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xf64>)
 //       CHECK: %[[C:.*]] = arith.constant 2 : index
 //       CHECK: %[[L:.*]] = memref.load %[[A0]][%[[C]]] : memref<3xindex>
 //       CHECK: return %[[L]] : index
@@ -199,13 +219,14 @@ func.func @sparse_dense_3d_dyn(%arg0: tensor<?x?x?xf64, #Dense3D>) -> index {
 
 // CHECK-LABEL: func @sparse_pointers_dcsr(
 //  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<5xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-//  CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
-//  CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
-//  CHECK-SAME: %[[A6:.*6]]: memref<?xf64>)
-//       CHECK: return %[[A4]] : memref<?xi32>
+//  CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<5xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xi32>,
+//  CHECK-SAME: %[[A6:.*6]]: memref<?xi64>,
+//  CHECK-SAME: %[[A7:.*7]]: memref<?xf64>)
+//       CHECK: return %[[A5]] : memref<?xi32>
 func.func @sparse_pointers_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi32> {
   %0 = sparse_tensor.pointers %arg0 { dimension = 1 : index } : tensor<?x?xf64, #DCSR> to memref<?xi32>
   return %0 : memref<?xi32>
@@ -213,13 +234,14 @@ func.func @sparse_pointers_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi32>
 
 // CHECK-LABEL: func @sparse_indices_dcsr(
 //  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<5xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-//  CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
-//  CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
-//  CHECK-SAME: %[[A6:.*6]]: memref<?xf64>)
-//       CHECK: return %[[A5]] : memref<?xi64>
+//  CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<5xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xi32>,
+//  CHECK-SAME: %[[A6:.*6]]: memref<?xi64>,
+//  CHECK-SAME: %[[A7:.*7]]: memref<?xf64>)
+//       CHECK: return %[[A6]] : memref<?xi64>
 func.func @sparse_indices_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi64> {
   %0 = sparse_tensor.indices %arg0 { dimension = 1 : index } : tensor<?x?xf64, #DCSR> to memref<?xi64>
   return %0 : memref<?xi64>
@@ -227,13 +249,14 @@ func.func @sparse_indices_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi64> {
 
 // CHECK-LABEL: func @sparse_values_dcsr(
 //  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<5xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-//  CHECK-SAME: %[[A4:.*4]]: memref<?xi32>,
-//  CHECK-SAME: %[[A5:.*5]]: memref<?xi64>,
-//  CHECK-SAME: %[[A6:.*6]]: memref<?xf64>)
-//       CHECK: return %[[A6]] : memref<?xf64>
+//  CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<5xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xi32>,
+//  CHECK-SAME: %[[A6:.*6]]: memref<?xi64>,
+//  CHECK-SAME: %[[A7:.*7]]: memref<?xf64>)
+//       CHECK: return %[[A7]] : memref<?xf64>
 func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
   %0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #DCSR> to memref<?xf64>
   return %0 : memref<?xf64>
@@ -241,12 +264,13 @@ func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
 
 // CHECK-LABEL: func @sparse_noe(
 //  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
+//  CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
 //       CHECK: %[[C2:.*]] = arith.constant 2 : index
-//       CHECK: %[[NOE:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+//       CHECK: %[[NOE:.*]] = memref.load %[[A2]][%[[C2]]] : memref<3xindex>
 //       CHECK: return %[[NOE]] : index
 func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
   %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector>
@@ -255,15 +279,17 @@ func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
 
 // CHECK-LABEL: func @sparse_dealloc_csr(
 //  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
+//  CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xf64>)
 //       CHECK: memref.dealloc %[[A0]] : memref<2xindex>
-//       CHECK: memref.dealloc %[[A1]] : memref<3xindex>
-//       CHECK: memref.dealloc %[[A2]] : memref<?xi32>
-//       CHECK: memref.dealloc %[[A3]] : memref<?xi64>
-//       CHECK: memref.dealloc %[[A4]] : memref<?xf64>
+//       CHECK: memref.dealloc %[[A1]] : memref<2xindex>
+//       CHECK: memref.dealloc %[[A2]] : memref<3xindex>
+//       CHECK: memref.dealloc %[[A3]] : memref<?xi32>
+//       CHECK: memref.dealloc %[[A4]] : memref<?xi64>
+//       CHECK: memref.dealloc %[[A5]] : memref<?xf64>
 //       CHECK: return
 func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
   bufferization.dealloc_tensor %arg0 : tensor<?x?xf64, #CSR>
@@ -272,11 +298,12 @@ func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
 
 // CHECK-LABEL: func @sparse_alloc_csc(
 //  CHECK-SAME: %[[A:.*]]: index) ->
-//  CHECK-SAME: memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+//  CHECK-SAME: memref<2xindex>, memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
 //       CHECK: %[[T0:.*]] = memref.alloc() : memref<2xindex>
+//       CHECK: %[[CC:.*]] = memref.alloc() : memref<2xindex>
 //       CHECK: %[[T1:.*]] = memref.alloc() : memref<3xindex>
 //       CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<2xindex>
 //       CHECK: memref.store %[[C10]], %[[T0]][%[[C1]]] : memref<2xindex>
@@ -287,7 +314,8 @@ func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
 //       CHECK: %[[T6:.*]] = memref.alloc() : memref<1xf64>
 //       CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<1xf64> to memref<?xf64>
 //       CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>)
-//       CHECK: return %[[T0]], %[[T1]], %[[T3]], %[[T5]], %[[T7]]
+//       CHECK: return %[[T0]], %[[CC]], %[[T1]], %[[T3]], %[[T5]], %[[T7]] :
+//  CHECK-SAME:   memref<2xindex>, memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
   %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC>
   %1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC>
@@ -295,7 +323,7 @@ func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
 }
 
 // CHECK-LABEL: func @sparse_alloc_3d() ->
-//  CHECK-SAME: memref<3xindex>, memref<1xindex>, memref<?xf64>
+//  CHECK-SAME: memref<3xindex>, memref<3xindex>, memref<1xindex>, memref<?xf64>
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
@@ -304,6 +332,7 @@ func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
 //   CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index
 //   CHECK-DAG: %[[C6000:.*]] = arith.constant 6000 : index
 //       CHECK: %[[A0:.*]] = memref.alloc() : memref<3xindex>
+//       CHECK: %[[CC:.*]] = memref.alloc() : memref<3xindex>
 //       CHECK: %[[A1:.*]] = memref.alloc() : memref<1xindex>
 //       CHECK: memref.store %[[C30]], %[[A0]][%[[C0]]] : memref<3xindex>
 //       CHECK: memref.store %[[C10]], %[[A0]][%[[C1]]] : memref<3xindex>
@@ -311,7 +340,8 @@ func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
 //       CHECK: %[[A:.*]] = memref.alloc() : memref<6000xf64>
 //       CHECK: %[[A2:.*]] = memref.cast %[[A]] : memref<6000xf64> to memref<?xf64>
 //       CHECK: memref.store %[[C6000]], %[[A1]][%[[C0]]] : memref<1xindex>
-//       CHECK: return %[[A0]], %[[A1]], %[[A2]] : memref<3xindex>, memref<1xindex>, memref<?xf64>
+//       CHECK: return %[[A0]], %[[CC]], %[[A1]], %[[A2]] :
+//  CHECK-SAME:   memref<3xindex>, memref<3xindex>, memref<1xindex>, memref<?xf64>
 func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
   %0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D>
   %1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>
@@ -370,36 +400,38 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
 
 // CHECK-LABEL: func @sparse_compression_1d(
 //  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
 //  CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
-//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xindex>,
 //  CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
-//  CHECK-SAME: %[[A6:.*6]]: memref<?xi1>,
-//  CHECK-SAME: %[[A7:.*7]]: memref<?xindex>,
-//  CHECK-SAME: %[[A8:.*8]]: index)
-//   CHECK-DAG:  %[[B0:.*]] = arith.constant false
-//   CHECK-DAG:  %[[F0:.*]] = arith.constant 0.000000e+00 : f64
-//   CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
-//   CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG:  %[[C2:.*]] = arith.constant 2 : index
-//       CHECK:  sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
-//       CHECK:  %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A3]], %[[P1:.*]] = %[[A4]]) -> (memref<?xindex>, memref<?xf64>) {
-//       CHECK:    %[[T1:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
-//       CHECK:    %[[T2:.*]] = memref.load %[[A5]][%[[T1]]] : memref<?xf64>
-//       CHECK:    %[[T3:.*]] = sparse_tensor.push_back %[[A1]], %[[P0]], %[[T1]] {idx = 1 : index} : memref<3xindex>, memref<?xindex>, index
-//       CHECK:    %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[T2]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
-//       CHECK:    memref.store %[[F0]], %arg5[%[[T1]]] : memref<?xf64>
-//       CHECK:    memref.store %[[B0]], %arg6[%[[T1]]] : memref<?xi1>
-//       CHECK:    scf.yield %[[T3]], %[[T4]] : memref<?xindex>, memref<?xf64>
-//       CHECK:  }
-//       CHECK:  memref.dealloc %[[A5]] : memref<?xf64>
-//       CHECK:  memref.dealloc %[[A6]] : memref<?xi1>
-//       CHECK:  memref.dealloc %[[A7]] : memref<?xindex>
-//       CHECK:  %[[LL:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
-//       CHECK:  %[[P1:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index
-//       CHECK:  %[[P2:.*]] = sparse_tensor.push_back %[[A1]], %[[P1]], %[[LL]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index
-//       CHECK:  return %[[A0]], %[[A1]], %[[P2]], %[[R]]#0, %[[R]]#1 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+//  CHECK-SAME: %[[A6:.*6]]: memref<?xf64>,
+//  CHECK-SAME: %[[A7:.*7]]: memref<?xi1>,
+//  CHECK-SAME: %[[A8:.*8]]: memref<?xindex>,
+//  CHECK-SAME: %[[A9:.*9]]: index)
+//   CHECK-DAG: %[[B0:.*]] = arith.constant false
+//   CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//       CHECK: sparse_tensor.sort %[[A9]], %[[A8]] : memref<?xindex>
+//       CHECK: %[[R:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[A9]] step %[[C1]] iter_args(%[[P0:.*]] = %[[A4]], %[[P1:.*]] = %[[A5]]) -> (memref<?xindex>, memref<?xf64>) {
+//       CHECK:   %[[T1:.*]] = memref.load %[[A8]][%[[I]]] : memref<?xindex>
+//       CHECK:   %[[T2:.*]] = memref.load %[[A6]][%[[T1]]] : memref<?xf64>
+//       CHECK:   %[[T3:.*]] = sparse_tensor.push_back %[[A2]], %[[P0]], %[[T1]] {idx = 1 : index} : memref<3xindex>, memref<?xindex>, index
+//       CHECK:   %[[T4:.*]] = sparse_tensor.push_back %[[A2]], %[[P1]], %[[T2]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
+//       CHECK:   memref.store %[[F0]], %[[A6]][%[[T1]]] : memref<?xf64>
+//       CHECK:   memref.store %[[B0]], %[[A7]][%[[T1]]] : memref<?xi1>
+//       CHECK:   scf.yield %[[T3]], %[[T4]] : memref<?xindex>, memref<?xf64>
+//       CHECK: }
+//       CHECK: memref.dealloc %[[A6]] : memref<?xf64>
+//       CHECK: memref.dealloc %[[A7]] : memref<?xi1>
+//       CHECK: memref.dealloc %[[A8]] : memref<?xindex>
+//       CHECK: %[[LL:.*]] = memref.load %[[A2]][%[[C2]]] : memref<3xindex>
+//       CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[A2]], %[[A3]], %[[C0]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index
+//       CHECK: %[[P2:.*]] = sparse_tensor.push_back %[[A2]], %[[P1]], %[[LL]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index
+//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[P2]], %[[R]]#0, %[[R]]#1
+//  CHECK-SAME:   memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
                                  %values: memref<?xf64>,
                                  %filled: memref<?xi1>,
@@ -413,29 +445,30 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
 
 // CHECK-LABEL: func @sparse_compression(
 //  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
 //  CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
-//  CHECK-SAME: %[[A6:.*6]]: memref<?xi1>,
-//  CHECK-SAME: %[[A7:.*7]]: memref<?xindex>,
-//  CHECK-SAME: %[[A8:.*8]]: index,
-//  CHECK-SAME: %[[A9:.*9]]: index)
+//  CHECK-SAME: %[[A6:.*6]]: memref<?xf64>,
+//  CHECK-SAME: %[[A7:.*7]]: memref<?xi1>,
+//  CHECK-SAME: %[[A8:.*8]]: memref<?xindex>,
+//  CHECK-SAME: %[[A9:.*9]]: index,
+//  CHECK-SAME: %[[A10:.*10]]: index)
 //   CHECK-DAG: %[[B0:.*]] = arith.constant false
 //   CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//       CHECK: sparse_tensor.sort %[[A8]], %[[A7]] : memref<?xindex>
-//  CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] {
-//  CHECK-NEXT:   %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
+//       CHECK: sparse_tensor.sort %[[A9]], %[[A8]] : memref<?xindex>
+//  CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A9]] step %[[C1]] {
+//  CHECK-NEXT:   %[[INDEX:.*]] = memref.load %[[A8]][%[[I]]] : memref<?xindex>
 //        TODO:   2D-insert
-//   CHECK-DAG:   memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
-//   CHECK-DAG:   memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
+//   CHECK-DAG:   memref.store %[[F0]], %[[A6]][%[[INDEX]]] : memref<?xf64>
+//   CHECK-DAG:   memref.store %[[B0]], %[[A7]][%[[INDEX]]] : memref<?xi1>
 //  CHECK-NEXT: }
-//   CHECK-DAG: memref.dealloc %[[A5]] : memref<?xf64>
-//   CHECK-DAG: memref.dealloc %[[A6]] : memref<?xi1>
-//   CHECK-DAG: memref.dealloc %[[A7]] : memref<?xindex>
+//   CHECK-DAG: memref.dealloc %[[A6]] : memref<?xf64>
+//   CHECK-DAG: memref.dealloc %[[A7]] : memref<?xi1>
+//   CHECK-DAG: memref.dealloc %[[A8]] : memref<?xindex>
 //       CHECK: return
 func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
                               %values: memref<?xf64>,
@@ -451,29 +484,30 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
 
 // CHECK-LABEL: func @sparse_compression_unordered(
 //  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<2xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
 //  CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
-//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xindex>,
 //  CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
-//  CHECK-SAME: %[[A6:.*6]]: memref<?xi1>,
-//  CHECK-SAME: %[[A7:.*7]]: memref<?xindex>,
-//  CHECK-SAME: %[[A8:.*8]]: index,
-//  CHECK-SAME: %[[A9:.*9]]: index)
+//  CHECK-SAME: %[[A6:.*6]]: memref<?xf64>,
+//  CHECK-SAME: %[[A7:.*7]]: memref<?xi1>,
+//  CHECK-SAME: %[[A8:.*8]]: memref<?xindex>,
+//  CHECK-SAME: %[[A9:.*9]]: index,
+//  CHECK-SAME: %[[A10:.*10]]: index)
 //   CHECK-DAG: %[[B0:.*]] = arith.constant false
 //   CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-NOT: sparse_tensor.sort
-//  CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A8]] step %[[C1]] {
-//  CHECK-NEXT:   %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref<?xindex>
+//  CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A9]] step %[[C1]] {
+//  CHECK-NEXT:   %[[INDEX:.*]] = memref.load %[[A8]][%[[I]]] : memref<?xindex>
 //        TODO:   2D-insert
-//   CHECK-DAG:   memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref<?xf64>
-//   CHECK-DAG:   memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref<?xi1>
+//   CHECK-DAG:   memref.store %[[F0]], %[[A6]][%[[INDEX]]] : memref<?xf64>
+//   CHECK-DAG:   memref.store %[[B0]], %[[A7]][%[[INDEX]]] : memref<?xi1>
 //  CHECK-NEXT: }
-//   CHECK-DAG: memref.dealloc %[[A5]] : memref<?xf64>
-//   CHECK-DAG: memref.dealloc %[[A6]] : memref<?xi1>
-//   CHECK-DAG: memref.dealloc %[[A7]] : memref<?xindex>
+//   CHECK-DAG: memref.dealloc %[[A6]] : memref<?xf64>
+//   CHECK-DAG: memref.dealloc %[[A7]] : memref<?xi1>
+//   CHECK-DAG: memref.dealloc %[[A8]] : memref<?xindex>
 //       CHECK: return
 func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
                                         %values: memref<?xf64>,
@@ -489,20 +523,22 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
 
 // CHECK-LABEL: func @sparse_insert(
 //  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
 //  CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
-//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
-//  CHECK-SAME: %[[A5:.*5]]: index,
-//  CHECK-SAME: %[[A6:.*6]]: f64)
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xindex>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
+//  CHECK-SAME: %[[A6:.*6]]: index,
+//  CHECK-SAME: %[[A7:.*7]]: f64)
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-//       CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A1]], %[[A3]], %[[A5]]
-//       CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]]
-//       CHECK: %[[T3:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
-//       CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]]
-//       CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[T0]], %[[T3]]
-//       CHECK: return %[[A0]], %[[A1]], %[[T4]], %[[T1]], %[[T2]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+//       CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A2]], %[[A4]], %[[A6]]
+//       CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A2]], %[[A5]], %[[A7]]
+//       CHECK: %[[T3:.*]] = memref.load %[[A2]][%[[C2]]] : memref<3xindex>
+//       CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A2]], %[[A3]], %[[C0]]
+//       CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A2]], %[[T0]], %[[T3]]
+//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[T4]], %[[T1]], %[[T2]] :
+//  CHECK-SAME:   memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
 func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
   %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
   %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV>
@@ -511,35 +547,39 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64)
 
 // CHECK-LABEL: func @sparse_insert_typed(
 //  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
-//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
-//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
-//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
-//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
-//  CHECK-SAME: %[[A5:.*5]]: index,
-//  CHECK-SAME: %[[A6:.*6]]: f64)
+//  CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xf64>,
+//  CHECK-SAME: %[[A6:.*6]]: index,
+//  CHECK-SAME: %[[A7:.*7]]: f64)
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
 //   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-//       CHECK: %[[S1:.*]] = arith.index_cast %[[A5]] : index to i64
-//       CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A1]], %[[A3]], %[[S1]]
-//       CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]]
-//       CHECK: %[[T3:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
-//       CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]]
+//       CHECK: %[[S1:.*]] = arith.index_cast %[[A6]] : index to i64
+//       CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A2]], %[[A4]], %[[S1]]
+//       CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A2]], %[[A5]], %[[A7]]
+//       CHECK: %[[T3:.*]] = memref.load %[[A2]][%[[C2]]] : memref<3xindex>
+//       CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A2]], %[[A3]], %[[C0]]
 //       CHECK: %[[S2:.*]] = arith.index_cast %[[T3]] : index to i32
-//       CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[T0]], %[[S2]]
-//       CHECK: return %[[A0]], %[[A1]], %[[T4]], %[[T1]], %[[T2]] : memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
+//       CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A2]], %[[T0]], %[[S2]]
+//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[T4]], %[[T1]], %[[T2]] :
+//  CHECK-SAME:   memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>
 func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
   %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
   %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SparseVector>
   return %1 : tensor<128xf64, #SparseVector>
 }
 
-// CHECK-LABEL:   func.func @sparse_nop_convert(
-//  CHECK-SAME:   %[[A0:.*]]: memref<1xindex>,
-//  CHECK-SAME:   %[[A1:.*]]: memref<3xindex>,
-//  CHECK-SAME:   %[[A2:.*]]: memref<?xi32>,
-//  CHECK-SAME:   %[[A3:.*]]: memref<?xi64>,
-//  CHECK-SAME:   %[[A4:.*]]: memref<?xf32>)
-//       CHECK:   return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
+// CHECK-LABEL: func.func @sparse_nop_convert(
+//  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<1xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<3xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi32>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xi64>,
+//  CHECK-SAME: %[[A5:.*5]]: memref<?xf32>)
+//       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]] :
+//  CHECK-SAME: memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
 func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
   %0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor<?xf32, #SparseVector>
   return %0 : tensor<?xf32, #SparseVector>
index 1591cb4..1ab4a66 100644 (file)
@@ -574,3 +574,13 @@ func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %a
   sparse_tensor.sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
   return
 }
+
+// -----
+
+#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+
+func.func @sparse_alloc_escapes(%arg0: index) -> tensor<10x?xf64, #CSR> {
+  // expected-error@+1 {{sparse tensor allocation should not escape function}}
+  %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSR>
+  return %0: tensor<10x?xf64, #CSR>
+}
index 13765eb..334d58c 100644 (file)
@@ -3,22 +3,24 @@
 #SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
 // CHECK-LABEL:  func @for(
 // CHECK-SAME:             %[[DIM_SIZE:.*0]]: memref<1xindex>,
-// CHECK-SAME:             %[[MEM_SIZE:.*1]]: memref<3xindex>,
-// CHECK-SAME:             %[[POINTER:.*2]]: memref<?xindex>,
-// CHECK-SAME:             %[[INDICES:.*3]]: memref<?xindex>,
-// CHECK-SAME:             %[[VALUE:.*4]]: memref<?xf32>,
-// CHECK-SAME:             %[[TMP_arg5:.*5]]: index,
-// CHECK-SAME:             %[[TMP_arg6:.*6]]: index,
-// CHECK-SAME:             %[[TMP_arg7:.*7]]: index
-// CHECK:          %[[TMP_0:.*]]:5 = scf.for %[[TMP_arg8:.*]] = %[[TMP_arg5]] to %[[TMP_arg6]] step %[[TMP_arg7]] iter_args(
-// CHECK-SAME:       %[[TMP_arg9:.*]] = %[[DIM_SIZE]],
-// CHECK-SAME:       %[[TMP_arg10:.*]] = %[[MEM_SIZE]],
-// CHECK-SAME:       %[[TMP_arg11:.*]] = %[[POINTER]],
-// CHECK-SAME:       %[[TMP_arg12:.*]] = %[[INDICES]],
-// CHECK-SAME:       %[[TMP_arg13:.*]] = %[[VALUE]]) 
-// CHECK:              scf.yield %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]], %[[TMP_arg13]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
-// CHECK:            }
-// CHECK:          return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2, %[[TMP_0]]#3, %[[TMP_0]]#4 : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK-SAME:             %[[DIM_CURSOR:.*1]]: memref<1xindex>,
+// CHECK-SAME:             %[[MEM_SIZE:.*2]]: memref<3xindex>,
+// CHECK-SAME:             %[[POINTER:.*3]]: memref<?xindex>,
+// CHECK-SAME:             %[[INDICES:.*4]]: memref<?xindex>,
+// CHECK-SAME:             %[[VALUE:.*5]]: memref<?xf32>,
+// CHECK-SAME:             %[[LB:.*6]]: index,
+// CHECK-SAME:             %[[UB:.*7]]: index,
+// CHECK-SAME:             %[[STEP:.*8]]: index)
+// CHECK:          %[[OUT:.*]]:6 = scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(
+// CHECK-SAME:       %[[SIZE:.*]] = %[[DIM_SIZE]],
+// CHECK-SAME:       %[[CUR:.*]] = %[[DIM_CURSOR]],
+// CHECK-SAME:       %[[MEM:.*]] = %[[MEM_SIZE]],
+// CHECK-SAME:       %[[PTR:.*]] = %[[POINTER]],
+// CHECK-SAME:       %[[IDX:.*]] = %[[INDICES]],
+// CHECK-SAME:       %[[VAL:.*]] = %[[VALUE]])
+// CHECK:            scf.yield %[[SIZE]], %[[CUR]], %[[MEM]], %[[PTR]], %[[IDX]], %[[VAL]] : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+// CHECK:          }
+// CHECK:          return %[[OUT]]#0, %[[OUT]]#1, %[[OUT]]#2, %[[OUT]]#3, %[[OUT]]#4, %[[OUT]]#5 : memref<1xindex>, memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
 func.func @for(%in: tensor<1024xf32, #SparseVector>,
                %lb: index, %ub: index, %step: index) -> tensor<1024xf32, #SparseVector> {
   %1 = scf.for %i = %lb to %ub step %step iter_args(%vin = %in)
@@ -27,3 +29,4 @@ func.func @for(%in: tensor<1024xf32, #SparseVector>,
   }
   return %1 : tensor<1024xf32, #SparseVector>
 }
+