[mlir][sparse] implement simple codegen for insertion (and related ops)
authorAart Bik <ajcbik@google.com>
Tue, 18 Oct 2022 00:02:25 +0000 (17:02 -0700)
committerAart Bik <ajcbik@google.com>
Tue, 18 Oct 2022 01:02:08 +0000 (18:02 -0700)
This is a proof of concept insertion implementation that sets up
the basic framework and implements it with push backs for just
sparse vectors. It adds insertion/compression through SSA values,
so that we properly update the memref after after pushback operation.

Note that properly using SSA values in sparsification is still TBD
but I will wait until Peiming's loop emitter is in to avoid conflicts.

Reviewed By: wrengr

Differential Revision: https://reviews.llvm.org/D136008

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
mlir/test/Dialect/SparseTensor/conversion.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir

index fef73ad..14e5af0 100644 (file)
@@ -190,20 +190,21 @@ def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>,
 
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Management Operations. These operations are "impure" in the
-// sense that they do not properly operate on SSA values. Instead, the behavior
-// is solely defined by side-effects. These operations provide a bridge between
-// "sparsification" on one hand and a support library or actual code generation
-// on the other hand. The semantics of these operations may be refined over time
-// as our sparse abstractions evolve.
+// sense that some behavior is defined by side-effects. These operations provide
+// a bridge between "sparsification" on one hand and a support library or actual
+// code generation on the other hand. The semantics of these operations may be
+// refined over time as our sparse abstractions evolve.
 //===----------------------------------------------------------------------===//
 
 def SparseTensor_InsertOp : SparseTensor_Op<"insert",
     [TypesMatchWith<"value type matches element type of tensor",
                     "tensor", "value",
-                    "$_self.cast<ShapedType>().getElementType()">]>,
+                    "$_self.cast<ShapedType>().getElementType()">,
+     AllTypesMatch<["tensor", "result"]>]>,
     Arguments<(ins AnyType:$value,
                AnySparseTensor:$tensor,
-               Variadic<Index>:$indices)> {
+               Variadic<Index>:$indices)>,
+    Results<(outs AnySparseTensor:$result)> {
   string summary = "Inserts a value into given sparse tensor";
   string description = [{
     Inserts the given value at given indices into the underlying
@@ -221,19 +222,19 @@ def SparseTensor_InsertOp : SparseTensor_Op<"insert",
     different insertion regimens. Inserting in a way contrary to
     these properties results in undefined behavior.
 
-    Note that this operation is "impure" in the sense that its behavior
-    is solely defined by side-effects and not SSA values. The semantics
-    may be refined over time as our sparse abstractions evolve. In
-    particular, this operation is scheduled to be unified with the
-    dense counterpart `tensor.insert` that has proper SSA semantics.
+    Note that this operation is "impure" in the sense that even though
+    the result is modeled through an SSA value, the insertion is eventually
+    done "in place", and referencing the old SSA value is undefined behavior.
+    This operation is scheduled to be unified with the dense counterpart
+    `tensor.insert` that has pure SSA semantics.
 
     Example:
 
     ```mlir
-    sparse_tensor.insert %val into %tensor[%i,%j] : tensor<1024x1024xf64, #CSR>
+    %result = sparse_tensor.insert %val into %tensor[%i,%j] : tensor<1024x1024xf64, #CSR>
     ```
   }];
-  let assemblyFormat = "$value `into` $tensor `[` $indices `]` attr-dict`:` type($tensor)";
+  let assemblyFormat = "$value `into` $tensor `[` $indices `]` attr-dict `:` type($tensor)";
   let hasVerifier = 1;
 }
 
@@ -255,7 +256,8 @@ def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>,
     the code for capacity check and reallocation. The typical usage will be for
     "dynamic" sparse tensors for which a capacity can be set beforehand.
 
-    The operation returns an SSA value for the memref. Referencing the memref
+    Note that this operation is "impure" in the sense that even though
+    the result is modeled through an SSA value, referencing the memref
     through the old SSA value after this operation is undefined behavior.
 
     Example:
@@ -302,9 +304,9 @@ def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
     through an indirection using the added array, so that the operations are
     kept proportional to the number of nonzeros.
 
-    Note that this operation is "impure" in the sense that its behavior
-    is solely defined by side-effects and not SSA values. The semantics
-    may be refined over time as our sparse abstractions evolve.
+    Note that this operation is "impure" in the sense that even though the
+    results are modeled through SSA values, the operation relies on a proper
+    side-effecting context that sets and resets the expanded arrays.
 
     Example:
 
@@ -317,13 +319,15 @@ def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
                        " `,` type($filled) `,` type($added)";
 }
 
-def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
+def SparseTensor_CompressOp : SparseTensor_Op<"compress",
+    [AllTypesMatch<["tensor", "result"]>]>,
     Arguments<(ins AnyStridedMemRefOfRank<1>:$values,
                    StridedMemRefRankOf<[I1],[1]>:$filled,
                    StridedMemRefRankOf<[Index],[1]>:$added,
                    Index:$count,
                    AnySparseTensor:$tensor,
-                  Variadic<Index>:$indices)> {
+                  Variadic<Index>:$indices)>,
+    Results<(outs AnySparseTensor:$result)> {
   string summary = "Compressed an access pattern for insertion";
   string description = [{
     Finishes a single access pattern expansion by moving inserted elements
@@ -335,14 +339,14 @@ def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
     array, so that the operations are kept proportional to the number of
     nonzeros. See the `sparse_tensor.expand` operation for more details.
 
-    Note that this operation is "impure" in the sense that its behavior
-    is solely defined by side-effects and not SSA values. The semantics
-    may be refined over time as our sparse abstractions evolve.
+    Note that this operation is "impure" in the sense that even though
+    the result is modeled through an SSA value, the insertion is eventually
+    done "in place", and referencing the old SSA value is undefined behavior.
 
     Example:
 
     ```mlir
-    sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
+    %result = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
       : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<4x4xf64, #CSR>
     ```
   }];
@@ -372,14 +376,16 @@ def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
     sparse storage format needs to be finalized. Otherwise, the operation
     simply folds away.
 
-    Note that this operation is "impure" in the sense that its behavior
-    is solely defined by side-effects and not SSA values. The semantics
-    may be refined over time as our sparse abstractions evolve.
+    Note that this operation is "impure" in the sense that even though
+    the result is modeled through an SSA value, the operation relies on
+    a proper context of materializing and inserting the tensor value.
 
-    Example:
+    Examples:
 
     ```mlir
-    %1 = sparse_tensor.load %0 : tensor<8xf64, #SV>
+    %result = sparse_tensor.load %tensor : tensor<8xf64, #SV>
+
+    %1 = sparse_tensor.load %0 hasInserts : tensor<16x32xf32, #CSR>
     ```
   }];
   let assemblyFormat = "$tensor (`hasInserts` $hasInserts^)? attr-dict `:` type($tensor)";
@@ -397,8 +403,7 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
     a buffer defined by a pointer.
 
     Note that this operation is "impure" in the sense that its behavior
-    is solely defined by side-effects and not SSA values. The semantics
-    may be refined over time as our sparse abstractions evolve.
+    is solely defined by side-effects and not SSA values.
 
     Example:
 
@@ -442,8 +447,7 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
     be used to implement the operator.
 
     Note that this operation is "impure" in the sense that its behavior is
-    solely defined by side-effects and not SSA values. The semantics may be
-    refined over time as our sparse abstractions evolve.
+    solely defined by side-effects and not SSA values.
 
     Example:
 
index e18e1ce..5e5815b 100644 (file)
@@ -35,6 +35,18 @@ namespace {
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
+/// Returns the "tuple" value of the adapted tensor.
+static UnrealizedConversionCastOp getTuple(Value tensor) {
+  return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
+}
+
+/// Packs the given values as a "tuple" value.
+static Value genTuple(OpBuilder &rewriter, Location loc, Type tp,
+                      ValueRange values) {
+  return rewriter.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values)
+      .getResult(0);
+}
+
 /// Flatten a list of operands that may contain sparse tensors.
 static void flattenOperands(ValueRange operands,
                             SmallVectorImpl<Value> &flattened) {
@@ -43,14 +55,13 @@ static void flattenOperands(ValueRange operands,
   // ==>
   // memref ..., c, memref ...
   for (auto operand : operands) {
-    if (auto cast =
-            dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
-        cast && getSparseTensorEncoding(cast->getResultTypes()[0]))
+    if (auto tuple = getTuple(operand);
+        tuple && getSparseTensorEncoding(tuple->getResultTypes()[0]))
       // An unrealized_conversion_cast will be inserted by type converter to
       // inter-mix the gap between 1:N conversion between sparse tensors and
       // fields. In this case, take the operands in the cast and replace the
       // sparse tensor output with the flattened type array.
-      flattened.append(cast.getOperands().begin(), cast.getOperands().end());
+      flattened.append(tuple.getOperands().begin(), tuple.getOperands().end());
     else
       flattened.push_back(operand);
   }
@@ -73,8 +84,7 @@ static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
 
   // Any other query can consult the dimSizes array at field 0 using,
   // accounting for the reordering applied to the sparse storage.
-  auto tuple =
-      llvm::cast<UnrealizedConversionCastOp>(adaptedValue.getDefiningOp());
+  auto tuple = getTuple(adaptedValue);
   Value idx = constantIndex(rewriter, loc, toStoredDim(tensorTp, dim));
   return rewriter.create<memref::LoadOp>(loc, tuple.getInputs().front(), idx)
       .getResult();
@@ -264,6 +274,54 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count) {
   return forOp;
 }
 
+/// Creates a pushback op for given field and updates the fields array
+/// accordingly.
+static void createPushback(OpBuilder &builder, Location loc,
+                           SmallVectorImpl<Value> &fields, unsigned field,
+                           Value value) {
+  assert(field < fields.size());
+  fields[field] =
+      builder.create<PushBackOp>(loc, fields[field].getType(), fields[1],
+                                 fields[field], value, APInt(64, field));
+}
+
+/// Generates insertion code.
+//
+// TODO: generalize this for any rank and format currently it is just sparse
+//       vectors as a proof of concept that we have everything in place!
+//
+static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
+                      SmallVectorImpl<Value> &fields,
+                      SmallVectorImpl<Value> &indices, Value value) {
+  unsigned rank = indices.size();
+  assert(rtp.getShape().size() == rank);
+  if (rank != 1 || !isCompressedDim(rtp, 0) || !isUniqueDim(rtp, 0) ||
+      !isOrderedDim(rtp, 0))
+    return; // TODO: add codegen
+  // push_back memSizes pointers-0 0
+  // push_back memSizes indices-0 index
+  // push_back memSizes values    value
+  Value zero = constantIndex(builder, loc, 0);
+  createPushback(builder, loc, fields, 2, zero);
+  createPushback(builder, loc, fields, 3, indices[0]);
+  createPushback(builder, loc, fields, 4, value);
+}
+
+/// Generations insertion finalization code.
+//
+// TODO: this too only works for the very simple case
+//
+static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
+                         SmallVectorImpl<Value> &fields) {
+  if (rtp.getShape().size() != 1 || !isCompressedDim(rtp, 0) ||
+      !isUniqueDim(rtp, 0) || !isOrderedDim(rtp, 0))
+    return; // TODO: add codegen
+  // push_back memSizes pointers-0 memSizes[2]
+  Value two = constantIndex(builder, loc, 2);
+  Value size = builder.create<memref::LoadOp>(loc, fields[1], two);
+  createPushback(builder, loc, fields, 2, size);
+}
+
 //===----------------------------------------------------------------------===//
 // Codegen rules.
 //===----------------------------------------------------------------------===//
@@ -325,12 +383,10 @@ public:
       assert(!sparseFlat.empty());
       if (sparseFlat.size() > 1) {
         auto flatSize = sparseFlat.size();
-        ValueRange sparseElem(iterator_range<ResultRange::iterator>(
+        ValueRange fields(iterator_range<ResultRange::iterator>(
             newCall.result_begin() + retOffset,
             newCall.result_begin() + retOffset + flatSize));
-        auto castOp = rewriter.create<UnrealizedConversionCastOp>(
-            loc, TypeRange({retType}), sparseElem);
-        castedRet.push_back(castOp.getResult(0));
+        castedRet.push_back(genTuple(rewriter, loc, retType, fields));
         retOffset += flatSize;
       } else {
         // If this is an 1:1 conversion, no need for casting.
@@ -404,8 +460,7 @@ public:
     Location loc = op.getLoc();
     SmallVector<Value, 8> fields;
     createAllocFields(rewriter, loc, resType, adaptor.getOperands(), fields);
-    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
-        op, TypeRange{resType}, fields);
+    rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
     return success();
   }
 };
@@ -424,8 +479,7 @@ public:
 
     // Replace the sparse tensor deallocation with field deallocations.
     Location loc = op.getLoc();
-    auto tuple = llvm::cast<UnrealizedConversionCastOp>(
-        adaptor.getTensor().getDefiningOp());
+    auto tuple = getTuple(adaptor.getTensor());
     for (auto input : tuple.getInputs())
       // Deallocate every buffer used to store the sparse tensor handler.
       rewriter.create<memref::DeallocOp>(loc, input);
@@ -442,11 +496,15 @@ public:
   LogicalResult
   matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (op.getHasInserts()) {
-      // Finalize any pending insertions.
-      // TODO: implement
-    }
-    rewriter.replaceOp(op, adaptor.getOperands());
+    RankedTensorType srcType =
+        op.getTensor().getType().cast<RankedTensorType>();
+    auto tuple = getTuple(adaptor.getTensor());
+    // Prepare fields.
+    SmallVector<Value, 8> fields(tuple.getInputs());
+    // Generate optional insertion finalization code.
+    if (op.getHasInserts())
+      genEndInsert(rewriter, op.getLoc(), srcType, fields);
+    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), srcType, fields));
     return success();
   }
 };
@@ -514,10 +572,14 @@ public:
     RankedTensorType dstType =
         op.getTensor().getType().cast<RankedTensorType>();
     Type eltType = dstType.getElementType();
+    auto tuple = getTuple(adaptor.getTensor());
     Value values = adaptor.getValues();
     Value filled = adaptor.getFilled();
     Value added = adaptor.getAdded();
     Value count = adaptor.getCount();
+    // Prepare fields and indices.
+    SmallVector<Value, 8> fields(tuple.getInputs());
+    SmallVector<Value, 8> indices(adaptor.getIndices());
     // If the innermost dimension is ordered, we need to sort the indices
     // in the "added" array prior to applying the compression.
     unsigned rank = dstType.getShape().size();
@@ -532,21 +594,20 @@ public:
     //    for (i = 0; i < count; i++) {
     //      index = added[i];
     //      value = values[index];
-    //
-    //      TODO: insert prev_indices, index, value
-    //
+    //      insert({prev_indices, index}, value);
     //      values[index] = 0;
     //      filled[index] = false;
     //    }
     Value i = createFor(rewriter, loc, count).getInductionVar();
     Value index = rewriter.create<memref::LoadOp>(loc, added, i);
-    rewriter.create<memref::LoadOp>(loc, values, index);
-    // TODO: insert
+    Value value = rewriter.create<memref::LoadOp>(loc, values, index);
+    indices.push_back(index);
+    // TODO: generate yield cycle
+    genInsert(rewriter, loc, dstType, fields, indices, value);
     rewriter.create<memref::StoreOp>(loc, constantZero(rewriter, loc, eltType),
                                      values, index);
     rewriter.create<memref::StoreOp>(loc, constantI1(rewriter, loc, false),
                                      filled, index);
-
     // Deallocate the buffers on exit of the full loop nest.
     Operation *parent = op;
     for (; isa<scf::ForOp>(parent->getParentOp()) ||
@@ -559,7 +620,28 @@ public:
     rewriter.create<memref::DeallocOp>(loc, values);
     rewriter.create<memref::DeallocOp>(loc, filled);
     rewriter.create<memref::DeallocOp>(loc, added);
-    rewriter.eraseOp(op);
+    rewriter.replaceOp(op, genTuple(rewriter, loc, dstType, fields));
+    return success();
+  }
+};
+
+/// Sparse codegen rule for the insert operator.
+class SparseInsertConverter : public OpConversionPattern<InsertOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(InsertOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    RankedTensorType dstType =
+        op.getTensor().getType().cast<RankedTensorType>();
+    auto tuple = getTuple(adaptor.getTensor());
+    // Prepare fields and indices.
+    SmallVector<Value, 8> fields(tuple.getInputs());
+    SmallVector<Value, 8> indices(adaptor.getIndices());
+    // Generate insertion.
+    Value value = adaptor.getValue();
+    genInsert(rewriter, op->getLoc(), dstType, fields, indices, value);
+    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields));
     return success();
   }
 };
@@ -576,8 +658,7 @@ public:
     // Replace the requested pointer access with corresponding field.
     // The cast_op is inserted by type converter to intermix 1:N type
     // conversion.
-    auto tuple = llvm::cast<UnrealizedConversionCastOp>(
-        adaptor.getTensor().getDefiningOp());
+    auto tuple = getTuple(adaptor.getTensor());
     unsigned idx = Base::getIndexForOp(tuple, op);
     auto fields = tuple.getInputs();
     assert(idx < fields.size());
@@ -648,6 +729,7 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                SparseCastConverter, SparseTensorAllocConverter,
                SparseTensorDeallocConverter, SparseTensorLoadConverter,
                SparseExpandConverter, SparseCompressConverter,
-               SparseToPointersConverter, SparseToIndicesConverter,
-               SparseToValuesConverter>(typeConverter, patterns.getContext());
+               SparseInsertConverter, SparseToPointersConverter,
+               SparseToIndicesConverter, SparseToValuesConverter>(
+      typeConverter, patterns.getContext());
 }
index c7a6048..e54e586 100644 (file)
@@ -1071,9 +1071,9 @@ public:
                                        constantIndex(rewriter, loc, i));
     rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
     SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
-    replaceOpWithFuncCall(rewriter, op, name, {},
-                          {adaptor.getTensor(), mref, vref},
-                          EmitCInterface::On);
+    createFuncCall(rewriter, loc, name, {}, {adaptor.getTensor(), mref, vref},
+                   EmitCInterface::On);
+    rewriter.replaceOp(op, adaptor.getTensor());
     return success();
   }
 };
@@ -1149,9 +1149,10 @@ public:
       rewriter.create<memref::StoreOp>(loc, adaptor.getIndices()[i], mref,
                                        constantIndex(rewriter, loc, i));
     SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
-    replaceOpWithFuncCall(rewriter, op, name, {},
-                          {tensor, mref, values, filled, added, count},
-                          EmitCInterface::On);
+    createFuncCall(rewriter, loc, name, {},
+                   {tensor, mref, values, filled, added, count},
+                   EmitCInterface::On);
+    rewriter.replaceOp(op, adaptor.getTensor());
     // Deallocate the buffers on exit of the loop nest.
     Operation *parent = op;
     for (; isa<scf::ForOp>(parent->getParentOp()) ||
index b227cfc..ecdf0e4 100644 (file)
@@ -1,5 +1,7 @@
 // RUN: mlir-opt %s --sparse-tensor-codegen  --canonicalize --cse | FileCheck %s
 
+#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+
 #SparseVector = #sparse_tensor.encoding<{
   dimLevelType = [ "compressed" ],
   indexBitWidth = 64,
@@ -383,10 +385,10 @@ func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
                               %filled: memref<?xi1>,
                               %added: memref<?xindex>,
                               %count: index,
-                              %i: index) {
-  sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
+                              %i: index) -> tensor<8x8xf64, #CSR> {
+  %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
     : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #CSR>
-  return
+  return %0 : tensor<8x8xf64, #CSR>
 }
 
 // CHECK-LABEL: func @sparse_compression_unordered(
@@ -420,8 +422,30 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
                                         %filled: memref<?xi1>,
                                         %added: memref<?xindex>,
                                         %count: index,
-                                        %i: index) {
-  sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
+                                        %i: index) -> tensor<8x8xf64, #UCSR> {
+  %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
     : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #UCSR>
-  return
+  return %0 : tensor<8x8xf64, #UCSR>
+}
+
+// CHECK-LABEL: func @sparse_insert(
+//  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+//  CHECK-SAME: %[[A5:.*5]]: index,
+//  CHECK-SAME: %[[A6:.*6]]: f64)
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//       CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]]
+//       CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A1]], %[[A3]], %[[A5]]
+//       CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]]
+//       CHECK: %[[T3:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+//       CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[T0]], %[[T3]]
+//       CHECK: return %[[A0]], %[[A1]], %[[T4]], %[[T1]], %[[T2]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
+  %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
+  %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV>
+  return %1 : tensor<128xf64, #SV>
 }
index 2694564..44fcd42 100644 (file)
@@ -288,7 +288,7 @@ func.func @sparse_reconstruct_ins(%arg0: tensor<128xf32, #SparseVector>) -> tens
 // CHECK-LABEL: func @sparse_insert(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>,
 //  CHECK-SAME: %[[B:.*]]: index,
-//  CHECK-SAME: %[[C:.*]]: f32) {
+//  CHECK-SAME: %[[C:.*]]: f32) -> !llvm.ptr<i8> {
 //   CHECK-DAG: %[[M:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[V:.*]] = memref.alloca() : memref<f32>
 //   CHECK-DAG: %[[MC:.*]] = memref.cast %[[M]] : memref<1xindex> to memref<?xindex>
@@ -296,12 +296,12 @@ func.func @sparse_reconstruct_ins(%arg0: tensor<128xf32, #SparseVector>) -> tens
 //   CHECK-DAG: memref.store %[[B]], %[[M]][%[[C0]]] : memref<1xindex>
 //   CHECK-DAG: memref.store %[[C]], %[[V]][] : memref<f32>
 //       CHECK: call @lexInsertF32(%[[A]], %[[MC]], %[[V]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f32>) -> ()
-//       CHECK: return
+//       CHECK: return %[[A]] : !llvm.ptr<i8>
 func.func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
                          %arg1: index,
-                         %arg2: f32) {
-  sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf32, #SparseVector>
-  return
+                         %arg2: f32) -> tensor<128xf32, #SparseVector> {
+  %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf32, #SparseVector>
+  return %0 : tensor<128xf32, #SparseVector>
 }
 
 // CHECK-LABEL: func @sparse_expansion1()
@@ -359,7 +359,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
 //  CHECK-SAME: %[[C:.*2]]: memref<?xi1>,
 //  CHECK-SAME: %[[D:.*3]]: memref<?xindex>,
 //  CHECK-SAME: %[[E:.*4]]: index,
-//  CHECK-SAME: %[[F:.*5]]: index)
+//  CHECK-SAME: %[[F:.*5]]: index) -> !llvm.ptr<i8> {
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[X:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[Y:.*]] = memref.cast %[[X]] : memref<2xindex> to memref<?xindex>
@@ -368,16 +368,16 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
 //   CHECK-DAG: memref.dealloc %[[B]] : memref<?xf64>
 //   CHECK-DAG: memref.dealloc %[[C]] : memref<?xi1>
 //   CHECK-DAG: memref.dealloc %[[D]] : memref<?xindex>
-//       CHECK: return
+//       CHECK: return %[[A]] : !llvm.ptr<i8>
 func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
                               %values: memref<?xf64>,
                               %filled: memref<?xi1>,
                               %added: memref<?xindex>,
                               %count: index,
-                              %i: index) {
-  sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
+                              %i: index) -> tensor<8x8xf64, #CSR> {
+  %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
     : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #CSR>
-  return
+  return %0 : tensor<8x8xf64, #CSR>
 }
 
 // CHECK-LABEL: func @sparse_out1(
index fd850aa..ca3f884 100644 (file)
@@ -122,12 +122,12 @@ func.func @sparse_load_ins(%arg0: tensor<16x32xf64, #DenseMatrix>) -> tensor<16x
 // CHECK-LABEL: func @sparse_insert(
 //  CHECK-SAME: %[[A:.*]]: tensor<128xf64, #sparse_tensor.encoding<{{.*}}>>,
 //  CHECK-SAME: %[[B:.*]]: index,
-//  CHECK-SAME: %[[C:.*]]: f64) {
-//       CHECK: sparse_tensor.insert %[[C]] into %[[A]][%[[B]]] : tensor<128xf64, #{{.*}}>
-//       CHECK: return
-func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) {
-  sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
-  return
+//  CHECK-SAME: %[[C:.*]]: f64)
+//       CHECK: %[[T:.*]] = sparse_tensor.insert %[[C]] into %[[A]][%[[B]]] : tensor<128xf64, #{{.*}}>
+//       CHECK: return %[[T]] : tensor<128xf64, #{{.*}}>
+func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
+  %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
+  return %0 : tensor<128xf64, #SparseVector>
 }
 
 // -----
@@ -181,17 +181,17 @@ func.func @sparse_expansion(%tensor: tensor<8x8xf64, #SparseMatrix>) -> index {
 //  CHECK-SAME: %[[A3:.*3]]: index
 //  CHECK-SAME: %[[A4:.*4]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
 //  CHECK-SAME: %[[A5:.*5]]: index)
-//       CHECK: sparse_tensor.compress %[[A0]], %[[A1]], %[[A2]], %[[A3]] into %[[A4]][%[[A5]]
-//       CHECK: return
+//       CHECK: %[[T:.*]] = sparse_tensor.compress %[[A0]], %[[A1]], %[[A2]], %[[A3]] into %[[A4]][%[[A5]]
+//       CHECK: return %[[T]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
 func.func @sparse_compression(%values: memref<?xf64>,
                               %filled: memref<?xi1>,
                               %added: memref<?xindex>,
                               %count: index,
                              %tensor: tensor<8x8xf64, #SparseMatrix>,
-                             %index: index) {
-  sparse_tensor.compress %values, %filled, %added, %count into %tensor[%index]
+                             %index: index) -> tensor<8x8xf64, #SparseMatrix> {
+  %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%index]
     : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #SparseMatrix>
-  return
+  return %0 : tensor<8x8xf64, #SparseMatrix>
 }
 
 // -----