def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
[SingleBlockImplicitTerminator<"YieldOp">]>,
- Arguments<(ins AnySparseTensor:$tensor)>{
- let summary = "Iterates over non-zero elements in a sparse tensor";
+ Arguments<(ins AnyTensor:$tensor)>{
+ let summary = "Iterates over elements in a tensor";
let description = [{
- Iterates over every non-zero element in the given sparse tensor and executes
- the block.
+ Iterates over stored elements in a tensor (which are typically, but not always,
+ non-zero for sparse tensors) and executes the block.
- For a input sparse tensor with rank n, the block must take n + 1 arguments. The
+ For an input tensor with rank n, the block must take n + 1 arguments. The
first n arguments must be Index type, together indicating the current coordinates
of the element being visited. The last argument must have the same type as the
- sparse tensor's element type, representing the actual value loaded from the input
+ tensor's element type, representing the actual value loaded from the input
tensor at the given coordinates.
Example:
for (int64_t i = 0; i < rank; i++)
loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, i);
- Value vals = loopEmitter.getTensorValueBuffer(0);
- Value idx = loopEmitter.getLastLevelTensorPointerIndex(0);
- Value val = rewriter.create<memref::LoadOp>(op.getLoc(), vals, idx);
-
SmallVector<Value, 4> coords;
coords.reserve(rank);
loopEmitter.getCoordinateArray(coords);
+ Value vals = loopEmitter.getTensorValueBuffer(0);
+ Value pidx = loopEmitter.getLastLevelTensorPointerIndex(0);
+ // Loads the value from sparse tensor using pointer index;
+ // loads the value from dense tensor using coordinate array.
+ Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pidx)
+ : rewriter.create<memref::LoadOp>(loc, vals, coords);
+
for (int64_t i = 0; i < rank; i++)
loopEmitter.exitCurrentLoop();
return
}
+ func.func @foreach_print_dense(%arg0: tensor<2x2xf64>) {
+ sparse_tensor.foreach in %arg0 : tensor<2x2xf64> do {
+ ^bb0(%1: index, %2: index, %v: f64) :
+ vector.print %1: index
+ vector.print %2: index
+ vector.print %v: f64
+ }
+ return
+ }
+
//
// Main driver.
//
// CHECK-NEXT: 5
// CHECK-NEXT: 1
// CHECK-NEXT: 1
+ // CHECK-NEXT: 6
+ call @foreach_print_dense(%src) : (tensor<2x2xf64>) -> ()
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 2
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 5
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 1
// CHECK-NEXT: 6
call @foreach_print_1(%s1) : (tensor<2x2xf64, #Row>) -> ()
// CHECK-NEXT: 0