[XLA] Fix the size of data buffer for sparse literals.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 11 Apr 2018 05:44:36 +0000 (22:44 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 11 Apr 2018 05:47:29 +0000 (22:47 -0700)
PiperOrigin-RevId: 192404543

tensorflow/compiler/xla/literal_util.cc
tensorflow/compiler/xla/literal_util.h

index c2950c1..c315b4f 100644 (file)
@@ -97,11 +97,18 @@ Literal::Literal(const Shape& shape, bool allocate_arrays)
     const Shape& subshape = piece.subshape();
     if (ShapeUtil::IsArray(subshape)) {
       if (allocate_arrays) {
-        piece.set_buffer(new char[piece.size_bytes()]);
         if (LayoutUtil::IsSparseArray(subshape)) {
+          // For sparse arrays, the buffer must be of the size of the maximum
+          // number of sparse elements possible.
+          const int64 max_sparse_elements =
+              LayoutUtil::MaxSparseElements(subshape.layout());
+          piece.set_buffer(
+              new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType(
+                                                 subshape.element_type())]);
           piece.set_sparse_indices(new SparseIndexArray(
-              LayoutUtil::MaxSparseElements(subshape.layout()),
-              ShapeUtil::Rank(subshape)));
+              max_sparse_elements, ShapeUtil::Rank(subshape)));
+        } else {
+          piece.set_buffer(new char[piece.size_bytes()]);
         }
       } else {
         piece.set_buffer(nullptr);
index a6a3dff..8aa1922 100644 (file)
@@ -1287,12 +1287,13 @@ void Literal::PopulateSparse(SparseIndexArray indices,
   CHECK_LE(num_elements, max_elements);
   CHECK_EQ(num_elements, indices.index_count());
   auto root_data = root_piece().data<NativeT>();
-  root_data.remove_suffix(max_elements - values.size());
+  // Piece::data() returns an ArraySlice of size equal to the number of indices
+  // in the SparseIndexArray. So there is no need to adjust the size of the data
+  // here. It is enough to just copy the incoming values into the data buffer.
   std::copy(values.begin(), values.end(), root_data.begin());
   *this->root_piece().sparse_indices() = std::move(indices);
   if (sort) {
     auto root_data = this->root_piece().data<NativeT>();
-    root_data.remove_suffix(root_data.size() - num_elements);
     this->root_piece().sparse_indices()->SortWithValues(root_data);
   }
   DCHECK(this->root_piece().sparse_indices()->Validate(shape()));