ValueRange dynSizes, bool enableInit,
SmallVectorImpl<Value> &fields) {
RankedTensorType rtp = type.cast<RankedTensorType>();
- Value heuristic = constantIndex(builder, loc, 16);
+ // Build original sizes.
+ SmallVector<Value> sizes;
+ auto shape = rtp.getShape();
+ unsigned rank = shape.size();
+ for (unsigned r = 0, o = 0; r < rank; r++) {
+ if (ShapedType::isDynamic(shape[r]))
+ sizes.push_back(dynSizes[o++]);
+ else
+ sizes.push_back(constantIndex(builder, loc, shape[r]));
+ }
+ Value heuristic = constantIndex(builder, loc, 16);
+ Value valHeuristic = heuristic;
+ SparseTensorEncodingAttr enc = getSparseTensorEncoding(rtp);
+ if (enc.isAllDense()) {
+ Value linear = sizes[0];
+ for (unsigned r = 1; r < rank; r++) {
+ linear = builder.create<arith::MulIOp>(loc, linear, sizes[r]);
+ }
+ valHeuristic = linear;
+ }
foreachFieldAndTypeInSparseTensor(
rtp,
- [&builder, &fields, rtp, loc, heuristic,
+ [&builder, &fields, rtp, loc, heuristic, valHeuristic,
enableInit](Type fType, unsigned fIdx, SparseTensorFieldKind fKind,
unsigned /*dim*/, DimLevelType /*dlt*/) -> bool {
assert(fields.size() == fIdx);
case SparseTensorFieldKind::IdxMemRef:
case SparseTensorFieldKind::ValMemRef:
field = createAllocation(builder, loc, fType.cast<MemRefType>(),
- heuristic, enableInit);
+ fKind == SparseTensorFieldKind::ValMemRef
+ ? valHeuristic
+ : heuristic,
+ enableInit);
break;
}
assert(field);
MutSparseTensorDescriptor desc(rtp, fields);
- // Build original sizes.
- SmallVector<Value> sizes;
- auto shape = rtp.getShape();
- unsigned rank = shape.size();
- for (unsigned r = 0, o = 0; r < rank; r++) {
- if (ShapedType::isDynamic(shape[r]))
- sizes.push_back(dynSizes[o++]);
- else
- sizes.push_back(constantIndex(builder, loc, shape[r]));
- }
// Initialize the storage scheme to an empty tensor. Initialize memSizes
// to all zeros, set the dimSizes to known values, and give all pointer
// fields an initial zero entry, so that it is easier to maintain the
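The net effect for an all-dense encoding is that the value buffer is allocated at the exact linearized tensor size (the product of all dimension sizes) instead of the default heuristic of 16, while every other buffer keeps the heuristic as its initial capacity. A minimal standalone sketch of that sizing decision, assuming static dimension sizes and using plain integers rather than the MLIR builder API; the helper name pickValBufferSize is hypothetical and only for illustration:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Default starting capacity for the dynamically growing buffers.
constexpr int64_t kHeuristic = 16;

// For an all-dense tensor the value buffer holds exactly one entry per
// coordinate, i.e. the product of all dimension sizes; otherwise start
// small and let the buffer grow on demand.
int64_t pickValBufferSize(const std::vector<int64_t> &sizes, bool allDense) {
  if (!allDense)
    return kHeuristic;
  return std::accumulate(sizes.begin(), sizes.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

For example, pickValBufferSize({10, 20, 30}, /*allDense=*/true) yields 6000, which matches the updated test expectation below.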
// CHECK: %[[A2:.*]] = arith.constant 10 : i64
// CHECK: %[[A3:.*]] = arith.constant 30 : i64
// CHECK: %[[A4:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK: %[[A5:.*]] = memref.alloc() : memref<16xf64>
-// CHECK: %[[A6:.*]] = memref.cast %[[A5]] : memref<16xf64> to memref<?xf64>
+// CHECK: %[[A5:.*]] = memref.alloc() : memref<6000xf64>
+// CHECK: %[[A6:.*]] = memref.cast %[[A5]] : memref<6000xf64> to memref<?xf64>
// CHECK: %[[A7:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier
// CHECK: %[[A8:.*]] = sparse_tensor.storage_specifier.set %[[A7]] dim_sz at 0 with %[[A3]] : i64, !sparse_tensor.storage_specifier
// CHECK: %[[A9:.*]] = sparse_tensor.storage_specifier.set %[[A8]] dim_sz at 1 with %[[A2]] : i64, !sparse_tensor.storage_specifier
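With the exact sizing in place, the CHECK lines above now expect the value memref to be allocated with 6000 f64 entries, i.e. the product of the tensor's static dimension sizes (consistent with the 10 and 30 dimension sizes recorded in the storage specifier), rather than the previous default capacity of 16.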