//===----------------------------------------------------------------------===//
// Sparse Tensor Management Operations. These operations are "impure" in the
-// sense that they do not properly operate on SSA values. Instead, the behavior
-// is solely defined by side-effects. These operations provide a bridge between
-// "sparsification" on one hand and a support library or actual code generation
-// on the other hand. The semantics of these operations may be refined over time
-// as our sparse abstractions evolve.
+// sense that some behavior is defined by side-effects. These operations provide
+// a bridge between "sparsification" on one hand and a support library or actual
+// code generation on the other hand. The semantics of these operations may be
+// refined over time as our sparse abstractions evolve.
//===----------------------------------------------------------------------===//
def SparseTensor_InsertOp : SparseTensor_Op<"insert",
[TypesMatchWith<"value type matches element type of tensor",
"tensor", "value",
- "$_self.cast<ShapedType>().getElementType()">]>,
+ "$_self.cast<ShapedType>().getElementType()">,
+ AllTypesMatch<["tensor", "result"]>]>,
Arguments<(ins AnyType:$value,
AnySparseTensor:$tensor,
- Variadic<Index>:$indices)> {
+ Variadic<Index>:$indices)>,
+ Results<(outs AnySparseTensor:$result)> {
let summary = "Inserts a value into the given sparse tensor";
let description = [{
Inserts the given value at given indices into the underlying
different insertion regimens. Inserting in a way contrary to
these properties results in undefined behavior.
- Note that this operation is "impure" in the sense that its behavior
- is solely defined by side-effects and not SSA values. The semantics
- may be refined over time as our sparse abstractions evolve. In
- particular, this operation is scheduled to be unified with the
- dense counterpart `tensor.insert` that has proper SSA semantics.
+ Note that this operation is "impure" in the sense that even though
+ the result is modeled through an SSA value, the insertion is eventually
+ done "in place", and referencing the old SSA value is undefined behavior.
+ This operation is scheduled to be unified with the dense counterpart
+ `tensor.insert` that has pure SSA semantics.
Example:
```mlir
- sparse_tensor.insert %val into %tensor[%i,%j] : tensor<1024x1024xf64, #CSR>
+ %result = sparse_tensor.insert %val into %tensor[%i,%j] : tensor<1024x1024xf64, #CSR>
```
}];
- let assemblyFormat = "$value `into` $tensor `[` $indices `]` attr-dict`:` type($tensor)";
+ let assemblyFormat = "$value `into` $tensor `[` $indices `]` attr-dict `:` type($tensor)";
let hasVerifier = 1;
}
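To make the revised semantics concrete, here is a hypothetical usage sketch (not part of this diff; the value and index names are illustrative) showing how insertions now thread the SSA result, with only the most recent value remaining valid:

```mlir
// Sketch only: each insert yields a new SSA value; referencing %t0 or %t1
// after the later operations would be undefined behavior.
%t1 = sparse_tensor.insert %a into %t0[%i, %j] : tensor<1024x1024xf64, #CSR>
%t2 = sparse_tensor.insert %b into %t1[%i, %k] : tensor<1024x1024xf64, #CSR>
%r  = sparse_tensor.load %t2 hasInserts : tensor<1024x1024xf64, #CSR>
```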
the code for capacity check and reallocation. The typical usage will be for
"dynamic" sparse tensors for which a capacity can be set beforehand.
- The operation returns an SSA value for the memref. Referencing the memref
+ Note that this operation is "impure" in the sense that even though
+ the result is modeled through an SSA value, referencing the memref
through the old SSA value after this operation is undefined behavior.
Example:
through an indirection using the added array, so that the operations are
kept proportional to the number of nonzeros.
- Note that this operation is "impure" in the sense that its behavior
- is solely defined by side-effects and not SSA values. The semantics
- may be refined over time as our sparse abstractions evolve.
+ Note that this operation is "impure" in the sense that even though the
+ results are modeled through SSA values, the operation relies on a proper
+ side-effecting context that sets and resets the expanded arrays.
Example:
" `,` type($filled) `,` type($added)";
}
-def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
+def SparseTensor_CompressOp : SparseTensor_Op<"compress",
+ [AllTypesMatch<["tensor", "result"]>]>,
Arguments<(ins AnyStridedMemRefOfRank<1>:$values,
StridedMemRefRankOf<[I1],[1]>:$filled,
StridedMemRefRankOf<[Index],[1]>:$added,
Index:$count,
AnySparseTensor:$tensor,
- Variadic<Index>:$indices)> {
+ Variadic<Index>:$indices)>,
+ Results<(outs AnySparseTensor:$result)> {
let summary = "Compresses an access pattern for insertion";
let description = [{
Finishes a single access pattern expansion by moving inserted elements
array, so that the operations are kept proportional to the number of
nonzeros. See the `sparse_tensor.expand` operation for more details.
- Note that this operation is "impure" in the sense that its behavior
- is solely defined by side-effects and not SSA values. The semantics
- may be refined over time as our sparse abstractions evolve.
+ Note that this operation is "impure" in the sense that even though
+ the result is modeled through an SSA value, the insertion is eventually
+ done "in place", and referencing the old SSA value is undefined behavior.
Example:
```mlir
- sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
+ %result = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
: memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<4x4xf64, #CSR>
```
}];
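For orientation, a rough sketch of the expand/compress idiom follows (illustrative only; the scattering loop and the threading of the updated count through it are elided):

```mlir
// Sketch: expand the access pattern for one row, scatter into the expanded
// buffers, then compress the collected entries back into the sparse tensor.
%values, %filled, %added, %count = sparse_tensor.expand %tensor
  : tensor<4x4xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>
// ... writes into %values / %filled / %added yield an updated %new_count ...
%result = sparse_tensor.compress %values, %filled, %added, %new_count into %tensor[%i]
  : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<4x4xf64, #CSR>
```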
sparse storage format needs to be finalized. Otherwise, the operation
simply folds away.
- Note that this operation is "impure" in the sense that its behavior
- is solely defined by side-effects and not SSA values. The semantics
- may be refined over time as our sparse abstractions evolve.
+ Note that this operation is "impure" in the sense that even though
+ the result is modeled through an SSA value, the operation relies on
+ a proper context of materializing and inserting the tensor value.
- Example:
+ Examples:
```mlir
- %1 = sparse_tensor.load %0 : tensor<8xf64, #SV>
+ %result = sparse_tensor.load %tensor : tensor<8xf64, #SV>
+
+ %1 = sparse_tensor.load %0 hasInserts : tensor<16x32xf32, #CSR>
```
}];
let assemblyFormat = "$tensor (`hasInserts` $hasInserts^)? attr-dict `:` type($tensor)";
a buffer defined by a pointer.
Note that this operation is "impure" in the sense that its behavior
- is solely defined by side-effects and not SSA values. The semantics
- may be refined over time as our sparse abstractions evolve.
+ is solely defined by side-effects and not SSA values.
Example:
be used to implement the operator.
Note that this operation is "impure" in the sense that its behavior is
- solely defined by side-effects and not SSA values. The semantics may be
- refined over time as our sparse abstractions evolve.
+ solely defined by side-effects and not SSA values.
Example:
// Helper methods.
//===----------------------------------------------------------------------===//
+/// Returns the "tuple" (unrealized conversion cast) of the adapted tensor.
+static UnrealizedConversionCastOp getTuple(Value tensor) {
+ return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
+}
+
+/// Packs the given values as a "tuple" value.
+static Value genTuple(OpBuilder &rewriter, Location loc, Type tp,
+ ValueRange values) {
+ return rewriter.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values)
+ .getResult(0);
+}
+
/// Flatten a list of operands that may contain sparse tensors.
static void flattenOperands(ValueRange operands,
SmallVectorImpl<Value> &flattened) {
// ==>
// memref ..., c, memref ...
for (auto operand : operands) {
- if (auto cast =
- dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
- cast && getSparseTensorEncoding(cast->getResultTypes()[0]))
+ if (auto tuple = getTuple(operand);
+ tuple && getSparseTensorEncoding(tuple->getResultTypes()[0]))
// The type converter inserts an unrealized_conversion_cast to bridge the
// gap of the 1:N conversion between a sparse tensor and its flattened
// fields. In that case, take the operands of the cast and replace the
// sparse tensor output with the flattened type array.
- flattened.append(cast.getOperands().begin(), cast.getOperands().end());
+ flattened.append(tuple.getOperands().begin(), tuple.getOperands().end());
else
flattened.push_back(operand);
}
// Any other query can consult the dimSizes array at field 0 using,
// accounting for the reordering applied to the sparse storage.
- auto tuple =
- llvm::cast<UnrealizedConversionCastOp>(adaptedValue.getDefiningOp());
+ auto tuple = getTuple(adaptedValue);
Value idx = constantIndex(rewriter, loc, toStoredDim(tensorTp, dim));
return rewriter.create<memref::LoadOp>(loc, tuple.getInputs().front(), idx)
.getResult();
return forOp;
}
+/// Creates a pushback op for the given field and updates the fields array
+/// accordingly.
+static void createPushback(OpBuilder &builder, Location loc,
+ SmallVectorImpl<Value> &fields, unsigned field,
+ Value value) {
+ assert(field < fields.size());
+ fields[field] =
+ builder.create<PushBackOp>(loc, fields[field].getType(), fields[1],
+ fields[field], value, APInt(64, field));
+}
+
+/// Generates insertion code.
+//
+// TODO: generalize this for any rank and format; currently it is just sparse
+// vectors as a proof of concept that we have everything in place!
+//
+static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
+ SmallVectorImpl<Value> &fields,
+ SmallVectorImpl<Value> &indices, Value value) {
+ unsigned rank = indices.size();
+ assert(rtp.getShape().size() == rank);
+ if (rank != 1 || !isCompressedDim(rtp, 0) || !isUniqueDim(rtp, 0) ||
+ !isOrderedDim(rtp, 0))
+ return; // TODO: add codegen
+ // push_back memSizes pointers-0 0
+ // push_back memSizes indices-0 index
+ // push_back memSizes values value
+ Value zero = constantIndex(builder, loc, 0);
+ createPushback(builder, loc, fields, 2, zero);
+ createPushback(builder, loc, fields, 3, indices[0]);
+ createPushback(builder, loc, fields, 4, value);
+}
+
+/// Generates insertion finalization code.
+//
+// TODO: this too only works for the very simple case
+//
+static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
+ SmallVectorImpl<Value> &fields) {
+ if (rtp.getShape().size() != 1 || !isCompressedDim(rtp, 0) ||
+ !isUniqueDim(rtp, 0) || !isOrderedDim(rtp, 0))
+ return; // TODO: add codegen
+ // push_back memSizes pointers-0 memSizes[2]
+ Value two = constantIndex(builder, loc, 2);
+ Value size = builder.create<memref::LoadOp>(loc, fields[1], two);
+ createPushback(builder, loc, fields, 2, size);
+}
+
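As a reading aid, the IR these helpers emit for a single insertion into a 1-D compressed vector roughly mirrors the codegen test further below; a sketch with illustrative names (the field-index attribute and trailing types of push_back are elided):

```mlir
// genInsert: push a pointer placeholder, the index, and the value; each
// push_back is given the memSizes array (field 1) to maintain running sizes.
%c0 = arith.constant 0 : index
%ptr0 = sparse_tensor.push_back %memSizes, %pointers, %c0
%idx0 = sparse_tensor.push_back %memSizes, %indices, %index
%val0 = sparse_tensor.push_back %memSizes, %values, %value
// genEndInsert: finalize pointers-0 with the size recorded in memSizes[2].
%c2 = arith.constant 2 : index
%sz = memref.load %memSizes[%c2] : memref<3xindex>
%ptr1 = sparse_tensor.push_back %memSizes, %ptr0, %sz
```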
//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//
assert(!sparseFlat.empty());
if (sparseFlat.size() > 1) {
auto flatSize = sparseFlat.size();
- ValueRange sparseElem(iterator_range<ResultRange::iterator>(
+ ValueRange fields(iterator_range<ResultRange::iterator>(
newCall.result_begin() + retOffset,
newCall.result_begin() + retOffset + flatSize));
- auto castOp = rewriter.create<UnrealizedConversionCastOp>(
- loc, TypeRange({retType}), sparseElem);
- castedRet.push_back(castOp.getResult(0));
+ castedRet.push_back(genTuple(rewriter, loc, retType, fields));
retOffset += flatSize;
} else {
// If this is an 1:1 conversion, no need for casting.
Location loc = op.getLoc();
SmallVector<Value, 8> fields;
createAllocFields(rewriter, loc, resType, adaptor.getOperands(), fields);
- rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
- op, TypeRange{resType}, fields);
+ rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
return success();
}
};
// Replace the sparse tensor deallocation with field deallocations.
Location loc = op.getLoc();
- auto tuple = llvm::cast<UnrealizedConversionCastOp>(
- adaptor.getTensor().getDefiningOp());
+ auto tuple = getTuple(adaptor.getTensor());
for (auto input : tuple.getInputs())
// Deallocate every buffer used to store the sparse tensor handler.
rewriter.create<memref::DeallocOp>(loc, input);
LogicalResult
matchAndRewrite(LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (op.getHasInserts()) {
- // Finalize any pending insertions.
- // TODO: implement
- }
- rewriter.replaceOp(op, adaptor.getOperands());
+ RankedTensorType srcType =
+ op.getTensor().getType().cast<RankedTensorType>();
+ auto tuple = getTuple(adaptor.getTensor());
+ // Prepare fields.
+ SmallVector<Value, 8> fields(tuple.getInputs());
+ // Generate optional insertion finalization code.
+ if (op.getHasInserts())
+ genEndInsert(rewriter, op.getLoc(), srcType, fields);
+ rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), srcType, fields));
return success();
}
};
RankedTensorType dstType =
op.getTensor().getType().cast<RankedTensorType>();
Type eltType = dstType.getElementType();
+ auto tuple = getTuple(adaptor.getTensor());
Value values = adaptor.getValues();
Value filled = adaptor.getFilled();
Value added = adaptor.getAdded();
Value count = adaptor.getCount();
+ // Prepare fields and indices.
+ SmallVector<Value, 8> fields(tuple.getInputs());
+ SmallVector<Value, 8> indices(adaptor.getIndices());
// If the innermost dimension is ordered, we need to sort the indices
// in the "added" array prior to applying the compression.
unsigned rank = dstType.getShape().size();
// for (i = 0; i < count; i++) {
// index = added[i];
// value = values[index];
- //
- // TODO: insert prev_indices, index, value
- //
+ // insert({prev_indices, index}, value);
// values[index] = 0;
// filled[index] = false;
// }
Value i = createFor(rewriter, loc, count).getInductionVar();
Value index = rewriter.create<memref::LoadOp>(loc, added, i);
- rewriter.create<memref::LoadOp>(loc, values, index);
- // TODO: insert
+ Value value = rewriter.create<memref::LoadOp>(loc, values, index);
+ indices.push_back(index);
+ // TODO: generate yield cycle
+ genInsert(rewriter, loc, dstType, fields, indices, value);
rewriter.create<memref::StoreOp>(loc, constantZero(rewriter, loc, eltType),
values, index);
rewriter.create<memref::StoreOp>(loc, constantI1(rewriter, loc, false),
filled, index);
-
// Deallocate the buffers on exit of the full loop nest.
Operation *parent = op;
for (; isa<scf::ForOp>(parent->getParentOp()) ||
rewriter.create<memref::DeallocOp>(loc, values);
rewriter.create<memref::DeallocOp>(loc, filled);
rewriter.create<memref::DeallocOp>(loc, added);
- rewriter.eraseOp(op);
+ rewriter.replaceOp(op, genTuple(rewriter, loc, dstType, fields));
+ return success();
+ }
+};
+
+/// Sparse codegen rule for the insert operator.
+class SparseInsertConverter : public OpConversionPattern<InsertOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(InsertOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ RankedTensorType dstType =
+ op.getTensor().getType().cast<RankedTensorType>();
+ auto tuple = getTuple(adaptor.getTensor());
+ // Prepare fields and indices.
+ SmallVector<Value, 8> fields(tuple.getInputs());
+ SmallVector<Value, 8> indices(adaptor.getIndices());
+ // Generate insertion.
+ Value value = adaptor.getValue();
+ genInsert(rewriter, op->getLoc(), dstType, fields, indices, value);
+ rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields));
return success();
}
};
// Replace the requested pointer access with the corresponding field.
// The cast_op is inserted by the type converter to bridge the 1:N type
// conversion.
- auto tuple = llvm::cast<UnrealizedConversionCastOp>(
- adaptor.getTensor().getDefiningOp());
+ auto tuple = getTuple(adaptor.getTensor());
unsigned idx = Base::getIndexForOp(tuple, op);
auto fields = tuple.getInputs();
assert(idx < fields.size());
SparseCastConverter, SparseTensorAllocConverter,
SparseTensorDeallocConverter, SparseTensorLoadConverter,
SparseExpandConverter, SparseCompressConverter,
- SparseToPointersConverter, SparseToIndicesConverter,
- SparseToValuesConverter>(typeConverter, patterns.getContext());
+ SparseInsertConverter, SparseToPointersConverter,
+ SparseToIndicesConverter, SparseToValuesConverter>(
+ typeConverter, patterns.getContext());
}
constantIndex(rewriter, loc, i));
rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
- replaceOpWithFuncCall(rewriter, op, name, {},
- {adaptor.getTensor(), mref, vref},
- EmitCInterface::On);
+ createFuncCall(rewriter, loc, name, {}, {adaptor.getTensor(), mref, vref},
+ EmitCInterface::On);
+ rewriter.replaceOp(op, adaptor.getTensor());
return success();
}
};
rewriter.create<memref::StoreOp>(loc, adaptor.getIndices()[i], mref,
constantIndex(rewriter, loc, i));
SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)};
- replaceOpWithFuncCall(rewriter, op, name, {},
- {tensor, mref, values, filled, added, count},
- EmitCInterface::On);
+ createFuncCall(rewriter, loc, name, {},
+ {tensor, mref, values, filled, added, count},
+ EmitCInterface::On);
+ rewriter.replaceOp(op, adaptor.getTensor());
// Deallocate the buffers on exit of the loop nest.
Operation *parent = op;
for (; isa<scf::ForOp>(parent->getParentOp()) ||
// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s
+#SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
+
#SparseVector = #sparse_tensor.encoding<{
dimLevelType = [ "compressed" ],
indexBitWidth = 64,
%filled: memref<?xi1>,
%added: memref<?xindex>,
%count: index,
- %i: index) {
- sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
+ %i: index) -> tensor<8x8xf64, #CSR> {
+ %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
: memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #CSR>
- return
+ return %0 : tensor<8x8xf64, #CSR>
}
// CHECK-LABEL: func @sparse_compression_unordered(
%filled: memref<?xi1>,
%added: memref<?xindex>,
%count: index,
- %i: index) {
- sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
+ %i: index) -> tensor<8x8xf64, #UCSR> {
+ %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
: memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #UCSR>
- return
+ return %0 : tensor<8x8xf64, #UCSR>
+}
+
+// CHECK-LABEL: func @sparse_insert(
+// CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: index,
+// CHECK-SAME: %[[A6:.*6]]: f64)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[T0:.*]] = sparse_tensor.push_back %[[A1]], %[[A2]], %[[C0]]
+// CHECK: %[[T1:.*]] = sparse_tensor.push_back %[[A1]], %[[A3]], %[[A5]]
+// CHECK: %[[T2:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A6]]
+// CHECK: %[[T3:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+// CHECK: %[[T4:.*]] = sparse_tensor.push_back %[[A1]], %[[T0]], %[[T3]]
+// CHECK: return %[[A0]], %[[A1]], %[[T4]], %[[T1]], %[[T2]] : memref<1xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
+ %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
+ %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV>
+ return %1 : tensor<128xf64, #SV>
}
// CHECK-LABEL: func @sparse_insert(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>,
// CHECK-SAME: %[[B:.*]]: index,
-// CHECK-SAME: %[[C:.*]]: f32) {
+// CHECK-SAME: %[[C:.*]]: f32) -> !llvm.ptr<i8> {
// CHECK-DAG: %[[M:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[V:.*]] = memref.alloca() : memref<f32>
// CHECK-DAG: %[[MC:.*]] = memref.cast %[[M]] : memref<1xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[B]], %[[M]][%[[C0]]] : memref<1xindex>
// CHECK-DAG: memref.store %[[C]], %[[V]][] : memref<f32>
// CHECK: call @lexInsertF32(%[[A]], %[[MC]], %[[V]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<f32>) -> ()
-// CHECK: return
+// CHECK: return %[[A]] : !llvm.ptr<i8>
func.func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
%arg1: index,
- %arg2: f32) {
- sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf32, #SparseVector>
- return
+ %arg2: f32) -> tensor<128xf32, #SparseVector> {
+ %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf32, #SparseVector>
+ return %0 : tensor<128xf32, #SparseVector>
}
// CHECK-LABEL: func @sparse_expansion1()
// CHECK-SAME: %[[C:.*2]]: memref<?xi1>,
// CHECK-SAME: %[[D:.*3]]: memref<?xindex>,
// CHECK-SAME: %[[E:.*4]]: index,
-// CHECK-SAME: %[[F:.*5]]: index)
+// CHECK-SAME: %[[F:.*5]]: index) -> !llvm.ptr<i8> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[X:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[Y:.*]] = memref.cast %[[X]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.dealloc %[[B]] : memref<?xf64>
// CHECK-DAG: memref.dealloc %[[C]] : memref<?xi1>
// CHECK-DAG: memref.dealloc %[[D]] : memref<?xindex>
-// CHECK: return
+// CHECK: return %[[A]] : !llvm.ptr<i8>
func.func @sparse_compression(%tensor: tensor<8x8xf64, #CSR>,
%values: memref<?xf64>,
%filled: memref<?xi1>,
%added: memref<?xindex>,
%count: index,
- %i: index) {
- sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
+ %i: index) -> tensor<8x8xf64, #CSR> {
+ %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%i]
: memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #CSR>
- return
+ return %0 : tensor<8x8xf64, #CSR>
}
// CHECK-LABEL: func @sparse_out1(
// CHECK-LABEL: func @sparse_insert(
// CHECK-SAME: %[[A:.*]]: tensor<128xf64, #sparse_tensor.encoding<{{.*}}>>,
// CHECK-SAME: %[[B:.*]]: index,
-// CHECK-SAME: %[[C:.*]]: f64) {
-// CHECK: sparse_tensor.insert %[[C]] into %[[A]][%[[B]]] : tensor<128xf64, #{{.*}}>
-// CHECK: return
-func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) {
- sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
- return
+// CHECK-SAME: %[[C:.*]]: f64)
+// CHECK: %[[T:.*]] = sparse_tensor.insert %[[C]] into %[[A]][%[[B]]] : tensor<128xf64, #{{.*}}>
+// CHECK: return %[[T]] : tensor<128xf64, #{{.*}}>
+func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
+ %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
+ return %0 : tensor<128xf64, #SparseVector>
}
// -----
// CHECK-SAME: %[[A3:.*3]]: index
// CHECK-SAME: %[[A4:.*4]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
// CHECK-SAME: %[[A5:.*5]]: index)
-// CHECK: sparse_tensor.compress %[[A0]], %[[A1]], %[[A2]], %[[A3]] into %[[A4]][%[[A5]]
-// CHECK: return
+// CHECK: %[[T:.*]] = sparse_tensor.compress %[[A0]], %[[A1]], %[[A2]], %[[A3]] into %[[A4]][%[[A5]]
+// CHECK: return %[[T]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
func.func @sparse_compression(%values: memref<?xf64>,
%filled: memref<?xi1>,
%added: memref<?xindex>,
%count: index,
%tensor: tensor<8x8xf64, #SparseMatrix>,
- %index: index) {
- sparse_tensor.compress %values, %filled, %added, %count into %tensor[%index]
+ %index: index) -> tensor<8x8xf64, #SparseMatrix> {
+ %0 = sparse_tensor.compress %values, %filled, %added, %count into %tensor[%index]
: memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #SparseMatrix>
- return
+ return %0 : tensor<8x8xf64, #SparseMatrix>
}
// -----