[mlir][sparse] Avoid values buffer reallocation for annotated all dense tensors.
author    bixia1 <bixia@google.com>
Wed, 11 Jan 2023 17:06:42 +0000 (09:06 -0800)
committer bixia1 <bixia@google.com>
Thu, 12 Jan 2023 00:31:07 +0000 (16:31 -0800)
Previously, we relied on the InsertOp to gradually grow the storage of every
sparse tensor. We now allocate a full-size values buffer for annotated
all-dense tensors when the tensor is first allocated. This avoids the cost of
gradually growing the buffer and allows accessing the values buffer as if it
were a dense tensor.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D141516
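
The core of the change is how createAllocFields picks the initial size of the
values buffer: the dimension sizes are materialized before the field
allocations, and for an all-dense encoding they are multiplied into a
linearized size that replaces the constant-16 heuristic for the ValMemRef
field only. Below is a simplified, self-contained C++ sketch of that sizing
decision; the function name, the kDynamic stand-in, and the example shape are
illustrative only and do not appear in the patch.

// Simplified sketch of the new sizing logic (not the actual MLIR code).
// For an annotated all-dense tensor the values buffer is sized to the full
// linearized tensor size up front; otherwise the small constant heuristic
// (16) is kept and the buffer still grows on demand.
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1; // stand-in for a dynamic-dimension marker

int64_t valuesAllocSize(const std::vector<int64_t> &shape,
                        const std::vector<int64_t> &dynSizes,
                        bool isAllDense) {
  if (!isAllDense)
    return 16; // gradual growth via InsertOp, as before
  int64_t linear = 1;
  size_t o = 0;
  for (int64_t d : shape)
    linear *= (d == kDynamic) ? dynSizes[o++] : d;
  assert(o == dynSizes.size() && "every dynamic size must be consumed");
  return linear;
}

int main() {
  // A hypothetical 10x? all-dense tensor with the dynamic dimension bound
  // to 30 gets a 10 * 30 = 300 element values buffer at allocation time.
  return valuesAllocSize({10, kDynamic}, {30}, /*isAllDense=*/true) == 300 ? 0 : 1;
}

Because an all-dense tensor stores one value per coordinate, the linearized
size is exact, so no later reallocation is needed and the values buffer can
be indexed like a dense tensor.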

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir

index 38a7e0e0610fb6e9985ab0593d72ef6c95aa46dc..0ce37620061a6a640a00a3a816c07edc1a591bf1 100644 (file)
@@ -205,11 +205,30 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
                               ValueRange dynSizes, bool enableInit,
                               SmallVectorImpl<Value> &fields) {
   RankedTensorType rtp = type.cast<RankedTensorType>();
-  Value heuristic = constantIndex(builder, loc, 16);
+  // Build original sizes.
+  SmallVector<Value> sizes;
+  auto shape = rtp.getShape();
+  unsigned rank = shape.size();
+  for (unsigned r = 0, o = 0; r < rank; r++) {
+    if (ShapedType::isDynamic(shape[r]))
+      sizes.push_back(dynSizes[o++]);
+    else
+      sizes.push_back(constantIndex(builder, loc, shape[r]));
+  }
 
+  Value heuristic = constantIndex(builder, loc, 16);
+  Value valHeuristic = heuristic;
+  SparseTensorEncodingAttr enc = getSparseTensorEncoding(rtp);
+  if (enc.isAllDense()) {
+    Value linear = sizes[0];
+    for (unsigned r = 1; r < rank; r++) {
+      linear = builder.create<arith::MulIOp>(loc, linear, sizes[r]);
+    }
+    valHeuristic = linear;
+  }
   foreachFieldAndTypeInSparseTensor(
       rtp,
-      [&builder, &fields, rtp, loc, heuristic,
+      [&builder, &fields, rtp, loc, heuristic, valHeuristic,
        enableInit](Type fType, unsigned fIdx, SparseTensorFieldKind fKind,
                    unsigned /*dim*/, DimLevelType /*dlt*/) -> bool {
         assert(fields.size() == fIdx);
@@ -222,7 +241,10 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
         case SparseTensorFieldKind::IdxMemRef:
         case SparseTensorFieldKind::ValMemRef:
           field = createAllocation(builder, loc, fType.cast<MemRefType>(),
-                                   heuristic, enableInit);
+                                   fKind == SparseTensorFieldKind::ValMemRef
+                                       ? valHeuristic
+                                       : heuristic,
+                                   enableInit);
           break;
         }
         assert(field);
@@ -233,16 +255,6 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
 
   MutSparseTensorDescriptor desc(rtp, fields);
 
-  // Build original sizes.
-  SmallVector<Value> sizes;
-  auto shape = rtp.getShape();
-  unsigned rank = shape.size();
-  for (unsigned r = 0, o = 0; r < rank; r++) {
-    if (ShapedType::isDynamic(shape[r]))
-      sizes.push_back(dynSizes[o++]);
-    else
-      sizes.push_back(constantIndex(builder, loc, shape[r]));
-  }
   // Initialize the storage scheme to an empty tensor. Initialized memSizes
   // to all zeros, sets the dimSizes to known values and gives all pointer
   // fields an initial zero entry, so that it is easier to maintain the
index 652923ea22d07cb97091ddb7a16ff244c71bbc06..61c4324cf1a413fef47a80f99f812e9ef0defed5 100644 (file)
@@ -345,8 +345,8 @@ func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
 //       CHECK:     %[[A2:.*]] = arith.constant 10 : i64
 //       CHECK:     %[[A3:.*]] = arith.constant 30 : i64
 //       CHECK:     %[[A4:.*]] = arith.constant 0.000000e+00 : f64
-//       CHECK:     %[[A5:.*]] = memref.alloc() : memref<16xf64>
-//       CHECK:     %[[A6:.*]] = memref.cast %[[A5]] : memref<16xf64> to memref<?xf64>
+//       CHECK:     %[[A5:.*]] = memref.alloc() : memref<6000xf64>
+//       CHECK:     %[[A6:.*]] = memref.cast %[[A5]] : memref<6000xf64> to memref<?xf64>
 //       CHECK:     %[[A7:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier
 //       CHECK:     %[[A8:.*]] = sparse_tensor.storage_specifier.set %[[A7]]  dim_sz at 0 with %[[A3]] : i64, !sparse_tensor.storage_specifier
 //       CHECK:     %[[A9:.*]] = sparse_tensor.storage_specifier.set %[[A8]]  dim_sz at 1 with %[[A2]] : i64, !sparse_tensor.storage_specifier