const Shape& subshape = piece.subshape();
if (ShapeUtil::IsArray(subshape)) {
if (allocate_arrays) {
- piece.set_buffer(new char[piece.size_bytes()]);
if (LayoutUtil::IsSparseArray(subshape)) {
+ // For sparse arrays, the buffer must be of the size of the maximum
+ // number of sparse elements possible.
+ const int64 max_sparse_elements =
+ LayoutUtil::MaxSparseElements(subshape.layout());
+ piece.set_buffer(
+ new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType(
+ subshape.element_type())]);
piece.set_sparse_indices(new SparseIndexArray(
- LayoutUtil::MaxSparseElements(subshape.layout()),
- ShapeUtil::Rank(subshape)));
+ max_sparse_elements, ShapeUtil::Rank(subshape)));
+ } else {
+ piece.set_buffer(new char[piece.size_bytes()]);
}
} else {
piece.set_buffer(nullptr);
CHECK_LE(num_elements, max_elements);
CHECK_EQ(num_elements, indices.index_count());
auto root_data = root_piece().data<NativeT>();
- root_data.remove_suffix(max_elements - values.size());
+ // Piece::data() returns an ArraySlice of size equal to the number of indices
+ // in the SparseIndexArray. So there is no need to adjust the size of the data
+ // here. It is enough to just copy the incoming values into the data buffer.
std::copy(values.begin(), values.end(), root_data.begin());
*this->root_piece().sparse_indices() = std::move(indices);
if (sort) {
auto root_data = this->root_piece().data<NativeT>();
- root_data.remove_suffix(root_data.size() - num_elements);
this->root_piece().sparse_indices()->SortWithValues(root_data);
}
DCHECK(this->root_piece().sparse_indices()->Validate(shape()));