[mlir][sparse] skip zeros during dense2sparse
authorAart Bik <ajcbik@google.com>
Wed, 9 Nov 2022 22:57:43 +0000 (14:57 -0800)
committerAart Bik <ajcbik@google.com>
Thu, 10 Nov 2022 04:54:27 +0000 (20:54 -0800)
This enables the full matmul integration test with runtime_lib=true/false!

Background:
https://github.com/llvm/llvm-project/issues/51657

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D137750

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir

index 701be4c..57230eb 100644 (file)
@@ -533,6 +533,13 @@ private:
     SmallVector<Value, 4> dynSizes;
     getDynamicSizes(dstTp, sizes, dynSizes);
 
+    bool fromSparseConst = false;
+    if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>()) {
+      if (constOp.getValue().dyn_cast<SparseElementsAttr>()) {
+        fromSparseConst = true;
+      }
+    }
+
     RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
     auto cooBuffer =
         rewriter.create<AllocTensorOp>(loc, cooTp, dynSizes).getResult();
@@ -540,8 +547,22 @@ private:
         loc, src, cooBuffer,
         [&](OpBuilder &builder, Location loc, ValueRange indices, Value v,
             ValueRange reduc) {
-          builder.create<sparse_tensor::YieldOp>(
-              loc, builder.create<InsertOp>(loc, v, reduc.front(), indices));
+          Value input = reduc.front();
+          if (fromSparseConst) {
+            input = builder.create<InsertOp>(loc, v, input, indices);
+          } else {
+            Value cond = genIsNonzero(builder, loc, v);
+            auto ifOp = builder.create<scf::IfOp>(
+                loc, TypeRange(input.getType()), cond, /*else*/ true);
+            builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+            Value insert = builder.create<InsertOp>(loc, v, input, indices);
+            builder.create<scf::YieldOp>(loc, insert);
+            builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+            builder.create<scf::YieldOp>(loc, input);
+            builder.setInsertionPointAfter(ifOp);
+            input = ifOp.getResult(0);
+          }
+          builder.create<sparse_tensor::YieldOp>(loc, input);
         });
     rewriter.setInsertionPointAfter(op);
     src = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
index 96d7887..7c80314 100644 (file)
@@ -109,8 +109,14 @@ func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100
 //       CHECK-RWT:     %[[VAL_2:.*]] = bufferization.alloc_tensor()
 //       CHECK-RWT:     %[[VAL_3:.*]] = sparse_tensor.foreach in %[[T0]] init(%[[VAL_2]])
 //       CHECK-RWT:     ^bb0(%[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index, %[[VAL_6:.*]]: f64, %[[VAL_7:.*]]: tensor
-//       CHECK-RWT:       %[[VAL_8:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]{{\[}}%[[VAL_4]], %[[VAL_5]]]
-//       CHECK-RWT:       sparse_tensor.yield %[[VAL_8]]
+//       CHECK-RWT:        %[[CMP:.*]] = arith.cmpf une, %[[VAL_6]]
+//       CHECK-RWT:        %[[IFR:.*]] = scf.if %[[CMP]]
+//       CHECK-RWT:          %[[Y1:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]
+//       CHECK-RWT:          scf.yield %[[Y1]]
+//       CHECK-RWT:        } else {
+//       CHECK-RWT:          scf.yield %[[VAL_7]]
+//       CHECK-RWT:        }
+//       CHECK-RWT:        sparse_tensor.yield %[[IFR]]
 //       CHECK-RWT:     }
 //       CHECK-RWT:     %[[COO:.*]] = sparse_tensor.load %[[VAL_3]] hasInserts
 //       CHECK-RWT:     %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
@@ -166,7 +172,7 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
 //       CHECK-RWT:     %[[VAL_0:.*]] = arith.constant 1 : index
 //       CHECK-RWT:     %[[VAL_1:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32>
 //       CHECK-RWT:     %[[T0:.*]] = bufferization.alloc_tensor()
-//       CHECK-RWT:     %[[T1:.*]] = sparse_tensor.foreach in %[[VAL_1]] init(%[[T0]]) 
+//       CHECK-RWT:     %[[T1:.*]] = sparse_tensor.foreach in %[[VAL_1]] init(%[[T0]])
 //       CHECK-RWT:     ^bb0(%[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index, %[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: tensor
 //       CHECK-RWT:       %[[T2:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]{{\[}}%[[VAL_4]], %[[VAL_5]]]
 //       CHECK-RWT:       sparse_tensor.yield %[[T2]]
index 88238e9..ce930eb 100644 (file)
@@ -116,6 +116,35 @@ module {
     %b3 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #CSR>
     %b4 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #DCSR>
 
+    //
+    // Sanity check on stored entries before going into the computations.
+    //
+    // CHECK:      32
+    // CHECK-NEXT: 32
+    // CHECK-NEXT: 4
+    // CHECK-NEXT: 4
+    // CHECK-NEXT: 32
+    // CHECK-NEXT: 32
+    // CHECK-NEXT: 8
+    // CHECK-NEXT: 8
+    //
+    %noea1 = sparse_tensor.number_of_entries %a1 : tensor<4x8xf64, #CSR>
+    %noea2 = sparse_tensor.number_of_entries %a2 : tensor<4x8xf64, #DCSR>
+    %noea3 = sparse_tensor.number_of_entries %a3 : tensor<4x8xf64, #CSR>
+    %noea4 = sparse_tensor.number_of_entries %a4 : tensor<4x8xf64, #DCSR>
+    %noeb1 = sparse_tensor.number_of_entries %b1 : tensor<8x4xf64, #CSR>
+    %noeb2 = sparse_tensor.number_of_entries %b2 : tensor<8x4xf64, #DCSR>
+    %noeb3 = sparse_tensor.number_of_entries %b3 : tensor<8x4xf64, #CSR>
+    %noeb4 = sparse_tensor.number_of_entries %b4 : tensor<8x4xf64, #DCSR>
+    vector.print %noea1 : index
+    vector.print %noea2 : index
+    vector.print %noea3 : index
+    vector.print %noea4 : index
+    vector.print %noeb1 : index
+    vector.print %noeb2 : index
+    vector.print %noeb3 : index
+    vector.print %noeb4 : index
+
     // Call kernels with dense.
     %0 = call @matmul1(%da, %db, %zero)
        : (tensor<4x8xf64>, tensor<8x4xf64>, tensor<4x4xf64>) -> tensor<4x4xf64>
@@ -205,20 +234,20 @@ module {
     vector.print %v5 : vector<4x4xf64>
 
     //
-    // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
+    // CHECK-NEXT: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
     //
     %v6 = vector.transfer_read %6[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
     vector.print %v6 : vector<4x4xf64>
 
     //
-    // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
+    // CHECK-NEXT: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
     //
     %c7 = sparse_tensor.convert %7 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
     %v7 = vector.transfer_read %c7[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
     vector.print %v7 : vector<4x4xf64>
 
     //
-    // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
+    // CHECK-NEXT: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
     //
     %c8 = sparse_tensor.convert %8 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
     %v8 = vector.transfer_read %c8[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
@@ -227,17 +256,26 @@ module {
     //
     // Sanity check on nonzeros.
     //
-    // FIXME: bring this back once dense2sparse skips zeros
-    //
-    // C_HECK: ( 30.5, 4.2, 4.6, 7, 8 )
-    // C_HECK: ( 30.5, 4.2, 4.6, 7, 8 )
+    // CHECK-NEXT: ( 30.5, 4.2, 4.6, 7, 8 )
+    // CHECK-NEXT: ( 30.5, 4.2, 4.6, 7, 8 )
     //
     %val7 = sparse_tensor.values %7 : tensor<4x4xf64, #CSR> to memref<?xf64>
     %val8 = sparse_tensor.values %8 : tensor<4x4xf64, #DCSR> to memref<?xf64>
-    %nz7 = vector.transfer_read %val7[%c0], %d1 : memref<?xf64>, vector<8xf64>
-    %nz8 = vector.transfer_read %val8[%c0], %d1 : memref<?xf64>, vector<8xf64>
-    vector.print %nz7 : vector<8xf64>
-    vector.print %nz8 : vector<8xf64>
+    %nz7 = vector.transfer_read %val7[%c0], %d1 : memref<?xf64>, vector<5xf64>
+    %nz8 = vector.transfer_read %val8[%c0], %d1 : memref<?xf64>, vector<5xf64>
+    vector.print %nz7 : vector<5xf64>
+    vector.print %nz8 : vector<5xf64>
+
+    //
+    // Sanity check on stored entries after the computations.
+    //
+    // CHECK-NEXT: 5
+    // CHECK-NEXT: 5
+    //
+    %noe7 = sparse_tensor.number_of_entries %7 : tensor<4x4xf64, #CSR>
+    %noe8 = sparse_tensor.number_of_entries %8 : tensor<4x4xf64, #DCSR>
+    vector.print %noe7 : index
+    vector.print %noe8 : index
 
     // Release the resources.
     bufferization.dealloc_tensor %a1 : tensor<4x8xf64, #CSR>