From 231146433a45ca8135e132ee0b48469798ca0b1f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 10 Apr 2018 22:44:36 -0700 Subject: [PATCH] [XLA] Fix the size of data buffer for sparse literals. PiperOrigin-RevId: 192404543 --- tensorflow/compiler/xla/literal_util.cc | 13 ++++++++++--- tensorflow/compiler/xla/literal_util.h | 5 +++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index c2950c1..c315b4f 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -97,11 +97,18 @@ Literal::Literal(const Shape& shape, bool allocate_arrays) const Shape& subshape = piece.subshape(); if (ShapeUtil::IsArray(subshape)) { if (allocate_arrays) { - piece.set_buffer(new char[piece.size_bytes()]); if (LayoutUtil::IsSparseArray(subshape)) { + // For sparse arrays, the buffer must be of the size of the maximum + // number of sparse elements possible. + const int64 max_sparse_elements = + LayoutUtil::MaxSparseElements(subshape.layout()); + piece.set_buffer( + new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType( + subshape.element_type())]); piece.set_sparse_indices(new SparseIndexArray( - LayoutUtil::MaxSparseElements(subshape.layout()), - ShapeUtil::Rank(subshape))); + max_sparse_elements, ShapeUtil::Rank(subshape))); + } else { + piece.set_buffer(new char[piece.size_bytes()]); } } else { piece.set_buffer(nullptr); diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index a6a3dff..8aa1922 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -1287,12 +1287,13 @@ void Literal::PopulateSparse(SparseIndexArray indices, CHECK_LE(num_elements, max_elements); CHECK_EQ(num_elements, indices.index_count()); auto root_data = root_piece().data(); - root_data.remove_suffix(max_elements - values.size()); + // Piece::data() returns an ArraySlice of size equal to the number of indices + // in the SparseIndexArray. So there is no need to adjust the size of the data + // here. It is enough to just copy the incoming values into the data buffer. std::copy(values.begin(), values.end(), root_data.begin()); *this->root_piece().sparse_indices() = std::move(indices); if (sort) { auto root_data = this->root_piece().data(); - root_data.remove_suffix(root_data.size() - num_elements); this->root_piece().sparse_indices()->SortWithValues(root_data); } DCHECK(this->root_piece().sparse_indices()->Validate(shape())); -- 2.7.4