[mlir][sparse] Fix a bug in rewriting for the convert op.
author bixia1 <bixia@google.com>
Thu, 10 Nov 2022 21:07:57 +0000 (13:07 -0800)
committer bixia1 <bixia@google.com>
Thu, 10 Nov 2022 22:30:51 +0000 (14:30 -0800)
The code to retrieve the number of entries isn't correct.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D137795

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir

index 57230eb..6d1fa8b 100644 (file)
@@ -666,12 +666,7 @@ private:
     }
 
     // Retrieve NNZ.
-    auto ptrTp =
-        MemRefType::get(dynShape, getPointerOverheadType(rewriter, encSrc));
-    Value p0 =
-        rewriter.create<ToIndicesOp>(loc, ptrTp, src, rewriter.getIndexAttr(0));
-    Value c1 = constantIndex(rewriter, loc, 1);
-    Value nnz = rewriter.create<memref::LoadOp>(loc, p0, c1);
+    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
     nnz =
         rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), nnz);
 
index 7c80314..74f7117 100644 (file)
@@ -105,35 +105,34 @@ func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100
 
 // CHECK-RWT-LABEL:   func.func @sparse_convert_2d(
 //  CHECK-RWT-SAME:     %[[T0:.*]]: tensor<2x4xf64>) -> tensor<2x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> {
-//       CHECK-RWT:     %[[VAL_1:.*]] = arith.constant 1 : index
-//       CHECK-RWT:     %[[VAL_2:.*]] = bufferization.alloc_tensor()
-//       CHECK-RWT:     %[[VAL_3:.*]] = sparse_tensor.foreach in %[[T0]] init(%[[VAL_2]])
-//       CHECK-RWT:     ^bb0(%[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index, %[[VAL_6:.*]]: f64, %[[VAL_7:.*]]: tensor
-//       CHECK-RWT:        %[[CMP:.*]] = arith.cmpf une, %[[VAL_6]]
+//       CHECK-RWT:     %[[T1:.*]] = bufferization.alloc_tensor()
+//       CHECK-RWT:     %[[T2:.*]] = sparse_tensor.foreach in %[[T0]] init(%[[T1]])
+//       CHECK-RWT:     ^bb0(%[[L0I0:.*]]: index, %[[L0I1:.*]]: index, %[[L0V:.*]]: f64, %[[L0T:.*]]: tensor
+//       CHECK-RWT:        %[[CMP:.*]] = arith.cmpf une, %[[L0V]]
 //       CHECK-RWT:        %[[IFR:.*]] = scf.if %[[CMP]]
-//       CHECK-RWT:          %[[Y1:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]
-//       CHECK-RWT:          scf.yield %[[Y1]]
+//       CHECK-RWT:          %[[L0T2:.*]] = sparse_tensor.insert %[[L0V]] into %[[L0T]]{{\[}}%[[L0I0]], %[[L0I1]]]
+//       CHECK-RWT:          scf.yield %[[L0T2]]
 //       CHECK-RWT:        } else {
-//       CHECK-RWT:          scf.yield %[[VAL_7]]
+//       CHECK-RWT:          scf.yield %[[L0T]]
 //       CHECK-RWT:        }
 //       CHECK-RWT:        sparse_tensor.yield %[[IFR]]
 //       CHECK-RWT:     }
-//       CHECK-RWT:     %[[COO:.*]] = sparse_tensor.load %[[VAL_3]] hasInserts
+//       CHECK-RWT:     %[[COO:.*]] = sparse_tensor.load %[[T2]] hasInserts
 //       CHECK-RWT:     %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
 //       CHECK-RWT:     %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
-//       CHECK-RWT:     %[[VAL_13:.*]] = memref.load %[[I0]]{{\[}}%[[VAL_1]]] : memref<?xindex>
-//       CHECK-RWT:     %[[VAL_14:.*]] = sparse_tensor.values %[[COO]]
-//       CHECK-RWT:     sparse_tensor.sort %[[VAL_13]], %[[I0]], %[[I1]] jointly %[[VAL_14]] : memref<?xindex>, memref<?xindex> jointly memref<?xf64>
-//       CHECK-RWT:     %[[VAL_15:.*]] = bufferization.alloc_tensor()
-//       CHECK-RWT:     %[[VAL_16:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[VAL_15]])
-//       CHECK-RWT:     ^bb0(%[[VAL_17:.*]]: index, %[[VAL_18:.*]]: index, %[[VAL_19:.*]]: f64, %[[VAL_20:.*]]: tensor
-//       CHECK-RWT:       %[[VAL_21:.*]] = sparse_tensor.insert %[[VAL_19]] into %[[VAL_20]]{{\[}}%[[VAL_17]], %[[VAL_18]]]
-//       CHECK-RWT:       sparse_tensor.yield %[[VAL_21]]
+//       CHECK-RWT:     %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
+//       CHECK-RWT:     %[[V:.*]] = sparse_tensor.values %[[COO]]
+//       CHECK-RWT:     sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]] : memref<?xindex>, memref<?xindex> jointly memref<?xf64>
+//       CHECK-RWT:     %[[T3:.*]] = bufferization.alloc_tensor()
+//       CHECK-RWT:     %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
+//       CHECK-RWT:     ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f64, %[[L1T:.*]]: tensor
+//       CHECK-RWT:       %[[L1T2:.*]] = sparse_tensor.insert %[[L1V]] into %[[L1T]]{{\[}}%[[L1I0]], %[[L1I1]]]
+//       CHECK-RWT:       sparse_tensor.yield %[[L1T2]]
 //       CHECK-RWT:     }
-//       CHECK-RWT:     %[[VAL_22:.*]] = sparse_tensor.load %[[VAL_16]] hasInserts
-//       CHECK-RWT:     %[[VAL_24:.*]] = sparse_tensor.convert %[[VAL_22]]
+//       CHECK-RWT:     %[[T5:.*]] = sparse_tensor.load %[[T4]] hasInserts
+//       CHECK-RWT:     %[[T6:.*]] = sparse_tensor.convert %[[T5]]
 //       CHECK-RWT:     bufferization.dealloc_tensor %[[COO]]
-//       CHECK-RWT:     return %[[VAL_24]]
+//       CHECK-RWT:     return %[[T6]]
 //       CHECK-RWT:   }
 func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #CSR>
@@ -169,30 +168,29 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 
 // CHECK-RWT-LABEL:   func.func @sparse_constant() -> tensor<8x7xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> {
-//       CHECK-RWT:     %[[VAL_0:.*]] = arith.constant 1 : index
-//       CHECK-RWT:     %[[VAL_1:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32>
+//       CHECK-RWT:     %[[F0:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32>
 //       CHECK-RWT:     %[[T0:.*]] = bufferization.alloc_tensor()
-//       CHECK-RWT:     %[[T1:.*]] = sparse_tensor.foreach in %[[VAL_1]] init(%[[T0]])
-//       CHECK-RWT:     ^bb0(%[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index, %[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: tensor
-//       CHECK-RWT:       %[[T2:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]{{\[}}%[[VAL_4]], %[[VAL_5]]]
-//       CHECK-RWT:       sparse_tensor.yield %[[T2]]
+//       CHECK-RWT:     %[[T1:.*]] = sparse_tensor.foreach in %[[F0]] init(%[[T0]])
+//       CHECK-RWT:     ^bb0(%[[L0I0:.*]]: index, %[[L0I1:.*]]: index, %[[L0V:.*]]: f32, %[[L0T:.*]]: tensor
+//       CHECK-RWT:       %[[L0T2:.*]] = sparse_tensor.insert %[[L0V]] into %[[L0T]]{{\[}}%[[L0I0]], %[[L0I1]]]
+//       CHECK-RWT:       sparse_tensor.yield %[[L0T2]]
 //       CHECK-RWT:     }
 //       CHECK-RWT:     %[[COO:.*]] = sparse_tensor.load %[[T1]] hasInserts
 //       CHECK-RWT:     %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
 //       CHECK-RWT:     %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
-//       CHECK-RWT:     %[[VAL_13:.*]] = memref.load %[[I0]]{{\[}}%[[VAL_0]]] : memref<?xindex>
+//       CHECK-RWT:     %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
 //       CHECK-RWT:     %[[V:.*]] = sparse_tensor.values %[[COO]]
-//       CHECK-RWT:     sparse_tensor.sort %[[VAL_13]], %[[I0]], %[[I1]] jointly %[[V]] : memref<?xindex>, memref<?xindex> jointly memref<?xf32>
-//       CHECK-RWT:     %[[VAL_15:.*]] = bufferization.alloc_tensor()
-//       CHECK-RWT:     %[[VAL_16:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[VAL_15]])
-//       CHECK-RWT:     ^bb0(%[[VAL_17:.*]]: index, %[[VAL_18:.*]]: index, %[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: tensor
-//       CHECK-RWT:       %[[VAL_21:.*]] = sparse_tensor.insert %[[VAL_19]] into %[[VAL_20]]{{\[}}%[[VAL_17]], %[[VAL_18]]]
-//       CHECK-RWT:       sparse_tensor.yield %[[VAL_21]]
+//       CHECK-RWT:     sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]] : memref<?xindex>, memref<?xindex> jointly memref<?xf32>
+//       CHECK-RWT:     %[[T3:.*]] = bufferization.alloc_tensor()
+//       CHECK-RWT:     %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
+//       CHECK-RWT:     ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f32, %[[L1T:.*]]: tensor
+//       CHECK-RWT:       %[[L1T2:.*]] = sparse_tensor.insert %[[L1V]] into %[[L1T]]{{\[}}%[[L1I0]], %[[L1I1]]]
+//       CHECK-RWT:       sparse_tensor.yield %[[L1T2]]
 //       CHECK-RWT:     }
-//       CHECK-RWT:     %[[VAL_22:.*]] = sparse_tensor.load %[[VAL_16]] hasInserts
-//       CHECK-RWT:     %[[VAL_24:.*]] = sparse_tensor.convert %[[VAL_22]]
+//       CHECK-RWT:     %[[T5:.*]] = sparse_tensor.load %[[T4]] hasInserts
+//       CHECK-RWT:     %[[T6:.*]] = sparse_tensor.convert %[[T5]]
 //       CHECK-RWT:     bufferization.dealloc_tensor %[[COO]]
-//       CHECK-RWT:     return %[[VAL_24]]
+//       CHECK-RWT:     return %[[T6]]
 //       CHECK-RWT:   }
 func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{
   // Initialize a tensor.
index 17145f8..f806d89 100644 (file)
@@ -86,11 +86,9 @@ func.func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor
 // CHECK-RWT-LABEL: func.func @sparse_convert(
 //  CHECK-RWT-SAME: %[[A:.*]]: tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 64, indexBitWidth = 64 }>>)
 //  CHECK-RWT-DAG:  %[[C0:.*]] = arith.constant 0 : index
-//  CHECK-RWT-DAG:  %[[C1:.*]] = arith.constant 1 : index
 //      CHECK-RWT:  %[[D:.*]] = tensor.dim %[[A]], %[[C0]]
 //      CHECK-RWT:  %[[I0:.*]] = sparse_tensor.indices %[[A]] {dimension = 0 : index}
-//      CHECK-RWT:  %[[NNZr:.*]] = memref.load %[[I0]]{{\[}}%[[C1]]] : memref<?xi64>
-//      CHECK-RWT:  %[[NNZ:.*]] = arith.index_cast %[[NNZr]] : i64 to index
+//      CHECK-RWT:  %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[A]]
 //      CHECK-RWT:  %[[V:.*]] = sparse_tensor.values %[[A]]
 //      CHECK-RWT:  sparse_tensor.sort %[[NNZ]], %[[I0]] jointly %[[V]]
 //      CHECK-RWT:  %[[DST:.*]] = bufferization.alloc_tensor(%[[D]])