SmallVector<Value, 4> dynSizes;
getDynamicSizes(dstTp, sizes, dynSizes);
+ bool fromSparseConst = false;
+ if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>()) {
+ if (constOp.getValue().dyn_cast<SparseElementsAttr>()) {
+ fromSparseConst = true;
+ }
+ }
+
RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
auto cooBuffer =
rewriter.create<AllocTensorOp>(loc, cooTp, dynSizes).getResult();
loc, src, cooBuffer,
[&](OpBuilder &builder, Location loc, ValueRange indices, Value v,
ValueRange reduc) {
- builder.create<sparse_tensor::YieldOp>(
- loc, builder.create<InsertOp>(loc, v, reduc.front(), indices));
+ Value input = reduc.front();
+ if (fromSparseConst) {
+ input = builder.create<InsertOp>(loc, v, input, indices);
+ } else {
+ Value cond = genIsNonzero(builder, loc, v);
+ auto ifOp = builder.create<scf::IfOp>(
+ loc, TypeRange(input.getType()), cond, /*else*/ true);
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ Value insert = builder.create<InsertOp>(loc, v, input, indices);
+ builder.create<scf::YieldOp>(loc, insert);
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ builder.create<scf::YieldOp>(loc, input);
+ builder.setInsertionPointAfter(ifOp);
+ input = ifOp.getResult(0);
+ }
+ builder.create<sparse_tensor::YieldOp>(loc, input);
});
rewriter.setInsertionPointAfter(op);
src = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
// CHECK-RWT: %[[VAL_2:.*]] = bufferization.alloc_tensor()
// CHECK-RWT: %[[VAL_3:.*]] = sparse_tensor.foreach in %[[T0]] init(%[[VAL_2]])
// CHECK-RWT: ^bb0(%[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index, %[[VAL_6:.*]]: f64, %[[VAL_7:.*]]: tensor
-// CHECK-RWT: %[[VAL_8:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]{{\[}}%[[VAL_4]], %[[VAL_5]]]
-// CHECK-RWT: sparse_tensor.yield %[[VAL_8]]
+// CHECK-RWT: %[[CMP:.*]] = arith.cmpf une, %[[VAL_6]]
+// CHECK-RWT: %[[IFR:.*]] = scf.if %[[CMP]]
+// CHECK-RWT: %[[Y1:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]
+// CHECK-RWT: scf.yield %[[Y1]]
+// CHECK-RWT: } else {
+// CHECK-RWT: scf.yield %[[VAL_7]]
+// CHECK-RWT: }
+// CHECK-RWT: sparse_tensor.yield %[[IFR]]
// CHECK-RWT: }
// CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[VAL_3]] hasInserts
// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
// CHECK-RWT: %[[VAL_0:.*]] = arith.constant 1 : index
// CHECK-RWT: %[[VAL_1:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32>
// CHECK-RWT: %[[T0:.*]] = bufferization.alloc_tensor()
-// CHECK-RWT: %[[T1:.*]] = sparse_tensor.foreach in %[[VAL_1]] init(%[[T0]])
+// CHECK-RWT: %[[T1:.*]] = sparse_tensor.foreach in %[[VAL_1]] init(%[[T0]])
// CHECK-RWT: ^bb0(%[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index, %[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: tensor
// CHECK-RWT: %[[T2:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]{{\[}}%[[VAL_4]], %[[VAL_5]]]
// CHECK-RWT: sparse_tensor.yield %[[T2]]
%b3 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #CSR>
%b4 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #DCSR>
+ //
+ // Sanity check on stored entries before going into the computations.
+ //
+ // CHECK: 32
+ // CHECK-NEXT: 32
+ // CHECK-NEXT: 4
+ // CHECK-NEXT: 4
+ // CHECK-NEXT: 32
+ // CHECK-NEXT: 32
+ // CHECK-NEXT: 8
+ // CHECK-NEXT: 8
+ //
+ %noea1 = sparse_tensor.number_of_entries %a1 : tensor<4x8xf64, #CSR>
+ %noea2 = sparse_tensor.number_of_entries %a2 : tensor<4x8xf64, #DCSR>
+ %noea3 = sparse_tensor.number_of_entries %a3 : tensor<4x8xf64, #CSR>
+ %noea4 = sparse_tensor.number_of_entries %a4 : tensor<4x8xf64, #DCSR>
+ %noeb1 = sparse_tensor.number_of_entries %b1 : tensor<8x4xf64, #CSR>
+ %noeb2 = sparse_tensor.number_of_entries %b2 : tensor<8x4xf64, #DCSR>
+ %noeb3 = sparse_tensor.number_of_entries %b3 : tensor<8x4xf64, #CSR>
+ %noeb4 = sparse_tensor.number_of_entries %b4 : tensor<8x4xf64, #DCSR>
+ vector.print %noea1 : index
+ vector.print %noea2 : index
+ vector.print %noea3 : index
+ vector.print %noea4 : index
+ vector.print %noeb1 : index
+ vector.print %noeb2 : index
+ vector.print %noeb3 : index
+ vector.print %noeb4 : index
+
    // Call the kernels with the dense tensor operands.
%0 = call @matmul1(%da, %db, %zero)
: (tensor<4x8xf64>, tensor<8x4xf64>, tensor<4x4xf64>) -> tensor<4x4xf64>
vector.print %v5 : vector<4x4xf64>
//
- // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
+ // CHECK-NEXT: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
//
%v6 = vector.transfer_read %6[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
vector.print %v6 : vector<4x4xf64>
//
- // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
+ // CHECK-NEXT: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
//
%c7 = sparse_tensor.convert %7 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
%v7 = vector.transfer_read %c7[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
vector.print %v7 : vector<4x4xf64>
//
- // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
+ // CHECK-NEXT: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
//
%c8 = sparse_tensor.convert %8 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
%v8 = vector.transfer_read %c8[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
//
    // Sanity check on the stored nonzero values.
//
- // FIXME: bring this back once dense2sparse skips zeros
- //
- // C_HECK: ( 30.5, 4.2, 4.6, 7, 8 )
- // C_HECK: ( 30.5, 4.2, 4.6, 7, 8 )
+ // CHECK-NEXT: ( 30.5, 4.2, 4.6, 7, 8 )
+ // CHECK-NEXT: ( 30.5, 4.2, 4.6, 7, 8 )
//
%val7 = sparse_tensor.values %7 : tensor<4x4xf64, #CSR> to memref<?xf64>
%val8 = sparse_tensor.values %8 : tensor<4x4xf64, #DCSR> to memref<?xf64>
- %nz7 = vector.transfer_read %val7[%c0], %d1 : memref<?xf64>, vector<8xf64>
- %nz8 = vector.transfer_read %val8[%c0], %d1 : memref<?xf64>, vector<8xf64>
- vector.print %nz7 : vector<8xf64>
- vector.print %nz8 : vector<8xf64>
+ %nz7 = vector.transfer_read %val7[%c0], %d1 : memref<?xf64>, vector<5xf64>
+ %nz8 = vector.transfer_read %val8[%c0], %d1 : memref<?xf64>, vector<5xf64>
+ vector.print %nz7 : vector<5xf64>
+ vector.print %nz8 : vector<5xf64>
+
+ //
+ // Sanity check on stored entries after the computations.
+ //
+ // CHECK-NEXT: 5
+ // CHECK-NEXT: 5
+ //
+ %noe7 = sparse_tensor.number_of_entries %7 : tensor<4x4xf64, #CSR>
+ %noe8 = sparse_tensor.number_of_entries %8 : tensor<4x4xf64, #DCSR>
+ vector.print %noe7 : index
+ vector.print %noe8 : index
// Release the resources.
bufferization.dealloc_tensor %a1 : tensor<4x8xf64, #CSR>