Remove protobuf-compatibility methods from the Literal class.
author     Mark Heffernan <meheff@google.com>
           Sat, 6 Jan 2018 05:46:44 +0000 (21:46 -0800)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Sat, 6 Jan 2018 05:50:47 +0000 (21:50 -0800)
This CL primarily does two things:

 (1) Remove the protobuf-compatibility methods (e.g., mutable_f32s()) from Literal. These were added to Literal as part of its migration from a proto to a C++ class. Now that Literal is a proper class, these protobuf-style methods make it difficult to enforce invariants and expose too much of the class's implementation details.

 (2) Make shape an immutable property of Literals, and make the shape and the data members holding the literal data coherent by construction. Previously, the shape could be set arbitrarily, and data members such as f32s_ could be sized arbitrarily, irrespective of the shape of the literal.
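
 For illustration, a minimal sketch of the new construction path, following the Literal(const Shape&) constructor and the PopulateWithValue() call that appear in the computation_builder.h hunk below; the wrapping function is hypothetical:

     #include "tensorflow/compiler/xla/literal_util.h"
     #include "tensorflow/compiler/xla/shape_util.h"

     // The shape is fixed at construction and the backing buffers are
     // allocated to match it, so shape and data cannot drift apart.
     xla::Literal MakeZeroVector() {
       xla::Literal literal(xla::ShapeUtil::MakeShape(xla::F32, {4}));
       literal.PopulateWithValue(0.0f);  // fill every element of the f32[4] array
       return literal;
     }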

The remainder of the CL mostly deals with the fallout. Other notable changes:

- Literal is no longer a recursive data structure. To avoid copies when passing a subliteral of a tuple-shaped Literal, a LiteralView class is added, which provides a read-only view of an arbitrary subliteral.
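
A rough sketch of reading through a view (the LiteralView hunks are not shown below, so the LiteralView::Create factory and its signature are assumptions):

    #include "tensorflow/compiler/xla/literal_util.h"

    // Hypothetical helper: inspect tuple element {1} without copying it.
    // Assumes a LiteralView::Create(literal, view_root) factory.
    float FirstElementOfTupleMember(const xla::Literal& tuple_literal) {
      const xla::LiteralView view =
          xla::LiteralView::Create(tuple_literal, /*view_root=*/{1});
      return view.Get<float>({0});  // read-only access; no buffers are copied
    }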

- Tuple-shaped Literals can no longer be built up incrementally, so to avoid copying Literal values during construction, the following methods with move semantics are added: Literal::MoveFrom and Literal::MoveIntoTuple. These methods transfer ownership of the underlying buffers, enabling, for example, a literal to be moved into an element of a tuple-shaped literal with no data copying.
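
A minimal sketch of the move-based tuple construction, based on the MoveIntoTuple signature added in literal_util.cc below (the std::vector-to-MutableArraySlice conversion and the helper function are assumptions):

    #include <utility>
    #include <vector>
    #include "tensorflow/compiler/xla/literal_util.h"

    // Assemble a (f32[], f32[3]) tuple without copying any array data.
    xla::Literal BuildTupleByMoving() {
      std::vector<xla::Literal> elements;
      elements.push_back(std::move(*xla::Literal::CreateR0<float>(42.0f)));
      elements.push_back(std::move(*xla::Literal::CreateR1<float>({1.f, 2.f, 3.f})));
      // Ownership of each element's buffers is transferred into the tuple.
      return xla::Literal::MoveIntoTuple(&elements);
    }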

- Replace the internal data structures holding the actual data, a collection of std::vectors (e.g., s32s_, f32s_, etc.), with a single ShapeTree<char*>. This significantly simplifies the accessors and makes improved support for tuple-shaped literals much easier (e.g., Literal::Get<>() can now access elements in arbitrary subliterals).
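
For example (a sketch based on the Get<>() accessor with a ShapeIndex argument used in the literal_util.cc hunks below; the function itself is hypothetical):

    #include "tensorflow/compiler/xla/literal_util.h"

    // Read element (0, 1) of the array stored at tuple index {1}; the
    // ShapeIndex argument selects the subliteral directly, so no
    // intermediate Literal objects are materialized.
    float ReadFromSubliteral(const xla::Literal& tuple_literal) {
      return tuple_literal.Get<float>(/*multi_index=*/{0, 1},
                                      /*shape_index=*/{1});
    }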

Also, Literal is made movable, but not copyable. Otherwise, it is all too easy to accidentally introduce expensive copies of Literals. Literal::Clone is added to handle the case where a copy is needed (Literal::CloneToUnique already exists).
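
A brief sketch of the resulting usage (hypothetical function; Clone performs the explicit deep copy):

    #include <utility>
    #include "tensorflow/compiler/xla/literal_util.h"

    void CopyVersusMove(xla::Literal a) {
      // xla::Literal b = a;              // does not compile: Literal is not copyable
      xla::Literal copy = a.Clone();      // explicit, potentially expensive deep copy
      xla::Literal moved = std::move(a);  // cheap: buffer ownership is transferred
      (void)copy;
      (void)moved;
    }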

PiperOrigin-RevId: 181014890

54 files changed:
tensorflow/compiler/tf2xla/kernels/reverse_op.cc
tensorflow/compiler/tf2xla/kernels/shape_op.cc
tensorflow/compiler/tf2xla/kernels/transpose_op.cc
tensorflow/compiler/tf2xla/literal_util.cc
tensorflow/compiler/tf2xla/xla_helpers.cc
tensorflow/compiler/tf2xla/xla_op_kernel.cc
tensorflow/compiler/xla/BUILD
tensorflow/compiler/xla/client/client.cc
tensorflow/compiler/xla/client/computation_builder.cc
tensorflow/compiler/xla/client/computation_builder.h
tensorflow/compiler/xla/literal_util.cc
tensorflow/compiler/xla/literal_util.h
tensorflow/compiler/xla/literal_util_test.cc
tensorflow/compiler/xla/packed_literal_reader.cc
tensorflow/compiler/xla/python/local_computation_builder.i
tensorflow/compiler/xla/python/numpy_bridge.cc
tensorflow/compiler/xla/python/numpy_bridge.h
tensorflow/compiler/xla/service/algebraic_simplifier.cc
tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
tensorflow/compiler/xla/service/cpu/external_constant_pool.cc
tensorflow/compiler/xla/service/generic_transfer_manager.cc
tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
tensorflow/compiler/xla/service/hlo_evaluator.cc
tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/service/service.cc
tensorflow/compiler/xla/service/user_computation.cc
tensorflow/compiler/xla/service/while_loop_simplifier.cc
tensorflow/compiler/xla/shape_util.cc
tensorflow/compiler/xla/shape_util.h
tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
tensorflow/compiler/xla/tests/batch_normalization_test.cc
tensorflow/compiler/xla/tests/bfloat16_test.cc
tensorflow/compiler/xla/tests/broadcast_test.cc
tensorflow/compiler/xla/tests/client_library_test_base.cc
tensorflow/compiler/xla/tests/client_test.cc
tensorflow/compiler/xla/tests/compute_constant_test.cc
tensorflow/compiler/xla/tests/constants_test.cc
tensorflow/compiler/xla/tests/convolution_test.cc
tensorflow/compiler/xla/tests/copy_test.cc
tensorflow/compiler/xla/tests/literal_test_util.cc
tensorflow/compiler/xla/tests/literal_test_util_test.cc
tensorflow/compiler/xla/tests/local_client_execute_test.cc
tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
tensorflow/compiler/xla/tests/params_test.cc
tensorflow/compiler/xla/tests/prng_test.cc
tensorflow/compiler/xla/tests/reshape_test.cc
tensorflow/compiler/xla/tests/transfer_manager_test.cc
tensorflow/compiler/xla/text_literal_reader.cc
tensorflow/compiler/xla/tools/parser/hlo_parser.cc
tensorflow/compiler/xla/tools/replay_computation.cc
tensorflow/compiler/xla/tools/show_literal.cc
tensorflow/compiler/xla/tools/show_text_literal.cc

index 17a345fc942f78d161218d9a66e9583f3465a569..e51d386926763ecbb5a943dfb6f872e78901dc69 100644 (file)
@@ -52,7 +52,8 @@ class ReverseOp : public XlaOpKernel {
     xla::Literal lax;
     OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax));
     std::vector<bool> revdims(x_shape.dims());
-    std::copy(lax.preds().begin(), lax.preds().end(), revdims.begin());
+    std::copy(lax.data<bool>().begin(), lax.data<bool>().end(),
+              revdims.begin());
     std::vector<int64> dimensions;
 
     for (int d = 0; d < x_shape.dims(); ++d) {
index 8fb7a74310b6760af8e6893ea2e3e3868c85f536..05354bca5bb089703fdcceb6f44648bbb98d004b 100644 (file)
@@ -121,7 +121,7 @@ class ExpandDimsOp : public XlaOpKernel {
     xla::Literal literal;
     OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal));
 
-    int dim = literal.s32s(0);
+    int dim = literal.data<int32>()[0];
 
     OP_REQUIRES(ctx,
                 (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()),
index 5c17b7fbf01990704c47ac564e69ba6d479a902e..c167642174b328a968d7f7ce1f0ad6e0ab8a7a68 100644 (file)
@@ -54,7 +54,8 @@ class TransposeOp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal));
 
     std::vector<int32> perm(dims);
-    std::copy(literal.s32s().begin(), literal.s32s().end(), perm.begin());
+    std::copy(literal.data<int32>().begin(), literal.data<int32>().end(),
+              perm.begin());
 
     std::vector<int64> transposed_order;
     // Check whether permutation is a permutation of integers of [0 .. dims).
index 576cd9bf9abb43e29d9eb8f706e0f42ac2d038e9..fcbd157c6191655865d5e250fdf71338780bc2a6 100644 (file)
@@ -23,17 +23,17 @@ limitations under the License.
 namespace tensorflow {
 
 Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) {
-  literal->Clear();
+  xla::Shape literal_shape;
   TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
-      host_tensor.dtype(), host_tensor.shape(), literal->mutable_shape()));
+      host_tensor.dtype(), host_tensor.shape(), &literal_shape));
 
-  literal->Reserve(host_tensor.NumElements());
+  *literal = xla::Literal(literal_shape);
 
   // memcpy over the payload ...
   // TODO(phawkins): handle string types.
   size_t total_bytes = host_tensor.TotalBytes();
   if (total_bytes > 0) {
-    void* dst_ptr = literal->MutableInternalData();
+    void* dst_ptr = literal->untyped_data();
     const void* src_ptr = DMAHelper::base(&host_tensor);
     memcpy(dst_ptr, src_ptr, total_bytes);
   }
@@ -56,7 +56,7 @@ Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
   *host_tensor = Tensor(target_type, shape);
   size_t total_bytes = host_tensor->TotalBytes();
   if (total_bytes > 0) {
-    const void* src_ptr = literal.InternalData();
+    const void* src_ptr = literal.untyped_data();
     void* dst_ptr = DMAHelper::base(host_tensor);
     memcpy(dst_ptr, src_ptr, total_bytes);
   }
index ec9e535b707beec6ea26dc81c7ee76b1d4da9225..77e24162676045b88dc8b62d2c6a4ecc1e738e96 100644 (file)
@@ -140,31 +140,31 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
   TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
   switch (type) {
     case xla::U8:
-      literal = *xla::Literal::CreateR0<uint8>(value);
+      literal = std::move(*xla::Literal::CreateR0<uint8>(value));
       break;
     case xla::U32:
-      literal = *xla::Literal::CreateR0<uint32>(value);
+      literal = std::move(*xla::Literal::CreateR0<uint32>(value));
       break;
     case xla::U64:
-      literal = *xla::Literal::CreateR0<uint64>(value);
+      literal = std::move(*xla::Literal::CreateR0<uint64>(value));
       break;
     case xla::S8:
-      literal = *xla::Literal::CreateR0<int8>(value);
+      literal = std::move(*xla::Literal::CreateR0<int8>(value));
       break;
     case xla::S32:
-      literal = *xla::Literal::CreateR0<int32>(value);
+      literal = std::move(*xla::Literal::CreateR0<int32>(value));
       break;
     case xla::S64:
-      literal = *xla::Literal::CreateR0<int64>(value);
+      literal = std::move(*xla::Literal::CreateR0<int64>(value));
       break;
     case xla::F32:
-      literal = *xla::Literal::CreateR0<float>(value);
+      literal = std::move(*xla::Literal::CreateR0<float>(value));
       break;
     case xla::F64:
-      literal = *xla::Literal::CreateR0<double>(value);
+      literal = std::move(*xla::Literal::CreateR0<double>(value));
       break;
     case xla::C64:
-      literal = *xla::Literal::CreateR0<complex64>(value);
+      literal = std::move(*xla::Literal::CreateR0<complex64>(value));
       break;
     case xla::PRED:
       LOG(FATAL) << "pred element type is not integral";
@@ -172,11 +172,12 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
     case xla::U16:
       LOG(FATAL) << "u16/s16 literals not yet implemented";
     case xla::BF16:
-      literal = *xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value));
+      literal = std::move(
+          *xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
       break;
     case xla::F16:
-      literal =
-          *xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value));
+      literal = std::move(
+          *xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value)));
       break;
     case xla::TUPLE:
       LOG(FATAL) << "tuple element type is not integral";
@@ -212,8 +213,8 @@ xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b,
         "elements.");
   }
 
-  *output = input;
-  output->mutable_shape()->Swap(&shape);
+  *output = input.Clone();
+  output->mutable_shape_do_not_use()->Swap(&shape);
   return Status::OK();
 }
 
index 73a91bfade74ed5f792b58575a73df95ea352658..a9dcb662b3178ad3d3703f6c27a1c838236cc82e 100644 (file)
@@ -206,15 +206,15 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index,
   xla::Literal literal;
   TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
   switch (literal.shape().element_type()) {
-    case xla::S32:
-      out->Clear();
-      *out->mutable_shape() = literal.shape();
-      out->mutable_shape()->set_element_type(xla::S64);
-      for (int32 x : literal.s32s()) {
-        out->add_s64s(x);
+    case xla::S32: {
+      *out = xla::Literal(
+          xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64));
+      auto src_data = literal.data<int32>();
+      for (int64 i = 0; i < src_data.size(); ++i) {
+        out->data<int64>()[i] = src_data[i];
       }
       return Status::OK();
-
+    }
     case xla::S64:
       *out = std::move(literal);
       return Status::OK();
index cd69c69889b2487ad12abea275e79fee4f5c51e6..88de17a5ff0d300cf05be0734ac972042d77c3d0 100644 (file)
@@ -300,6 +300,7 @@ cc_library(
         ":array2d",
         ":array3d",
         ":array4d",
+        ":shape_tree",
         ":shape_util",
         ":status_macros",
         ":types",
index 66937d64aff18817bbd5310e0c24e19556e9d727..d15ccb0c28522c647617153aaa8e738d029dfaba 100644 (file)
@@ -60,7 +60,7 @@ StatusOr<std::unique_ptr<Literal>> Client::Transfer(
         "server provided response without a literal in "
         "TransferToClient request");
   }
-  return MakeUnique<Literal>(response.literal());
+  return Literal::CreateFromProto(*response.mutable_literal());
 }
 
 StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
@@ -142,7 +142,7 @@ StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
         "TransferToClient request");
   }
 
-  return MakeUnique<Literal>(response.literal());
+  return Literal::CreateFromProto(response.literal());
 }
 
 Status Client::ResetDevice() {
index bb45a9a0827c53920d45c5dacff947b5b83cbd01..72d92ef0f7d4cbe2fc5d46d577e4d953e168b02e 100644 (file)
@@ -158,15 +158,13 @@ bool ComputationBuilder::MakeWindow(
   return true;
 }
 
-ComputationDataHandle ComputationBuilder::ConstantOp(
-    const PopulateLiteral& populate) {
+ComputationDataHandle ComputationBuilder::ConstantLiteral(
+    const Literal& literal) {
   if (!first_error_.ok() || !PrepareComputation().ok()) {
     return ComputationDataHandle();
   }
 
   ConstantRequest request;
-  Literal literal;
-  populate(&literal);
   *request.mutable_literal() = literal.ToProto();
   VLOG(3) << "created constant: " << request.literal().ShortDebugString();
   OpRequest op_request;
@@ -180,12 +178,6 @@ ComputationDataHandle ComputationBuilder::ConstantOp(
   return ParseOpResponse(s, &response);
 }
 
-ComputationDataHandle ComputationBuilder::ConstantLiteral(
-    const Literal& literal) {
-  return ConstantOp(
-      [literal](Literal* mutable_literal) { *mutable_literal = literal; });
-}
-
 ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number,
                                                     const Shape& shape,
                                                     const string& name) {
@@ -1456,7 +1448,7 @@ StatusOr<std::unique_ptr<Literal>> ComputationBuilder::ComputeConstant(
         "no computed literal in the provided response in ComputeConstant "
         "request");
   }
-  return MakeUnique<Literal>(response.literal());
+  return Literal::CreateFromProto(response.literal());
 }
 
 ComputationDataHandle ComputationBuilder::Map(
index 7a8d810191f089f033755472313e12327142bc83..afe0a722d8273f0038a2fe03febf47c76774425b 100644 (file)
@@ -830,8 +830,6 @@ class ComputationBuilder {
   Status first_error() const { return first_error_; }
 
  private:
-  using PopulateLiteral = std::function<void(Literal*)>;
-
   // Limited checking of convolution parameters. Returns false on
   // error.
   bool VerifyConvolution(const Shape& lhs_shape, const Shape& rhs_shape,
@@ -850,11 +848,6 @@ class ComputationBuilder {
                   tensorflow::gtl::ArraySlice<int64> rhs_dilation,
                   Window* window);
 
-  // Internal helper method that makes a request for a constant operation -- the
-  // provided function is used to populate the literal before sending the
-  // request.
-  ComputationDataHandle ConstantOp(const PopulateLiteral& populate);
-
   // Internal helper method that does the building for an arbitrary unary op.
   ComputationDataHandle UnaryOp(UnaryOperation binop,
                                 const ComputationDataHandle& operand);
@@ -930,68 +923,66 @@ class ComputationBuilder {
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) {
-  return ConstantOp([value](Literal* literal) { literal->PopulateR0(value); });
+  return ConstantLiteral(*Literal::CreateR0<NativeT>(value));
 }
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR1(
     tensorflow::gtl::ArraySlice<NativeT> values) {
-  return ConstantOp(
-      [&values](Literal* literal) { literal->PopulateR1(values); });
+  return ConstantLiteral(*Literal::CreateR1<NativeT>(values));
 }
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR1(int64 length,
                                                      NativeT value) {
-  return ConstantOp([length, value](Literal* literal) {
-    literal->PopulateWithValue(value, {length});
-  });
+  Literal literal(ShapeUtil::MakeShape(
+      primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
+  literal.PopulateWithValue(value);
+  return ConstantLiteral(literal);
 }
 
 inline ComputationDataHandle ComputationBuilder::ConstantR1(
     const tensorflow::core::Bitmap& values) {
-  return ConstantOp(
-      [&values](Literal* literal) { literal->PopulateR1(values); });
+  return ConstantLiteral(*Literal::CreateR1(values));
 }
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR2(
     std::initializer_list<std::initializer_list<NativeT>> values) {
-  return ConstantOp(
-      [&values](Literal* literal) { literal->PopulateR2(values); });
+  return ConstantLiteral(*Literal::CreateR2<NativeT>(values));
 }
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout(
     const Array<NativeT>& values, const Layout& layout) {
-  return ConstantOp([&values, &layout](Literal* literal) {
-    literal->PopulateFromArrayWithLayout(values, layout);
-  });
+  return ConstantLiteral(
+      *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
 }
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantFromArray(
     const Array<NativeT>& values) {
-  return ConstantOp(
-      [&values](Literal* literal) { literal->PopulateFromArray(values); });
+  return ConstantLiteral(*Literal::CreateFromArray<NativeT>(values));
 }
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
     const Array2D<NativeT>& values, const Layout& layout) {
-  return ConstantFromArrayWithLayout(values, layout);
+  return ConstantLiteral(
+      *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
 }
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D(
     const Array2D<NativeT>& values) {
-  return ConstantFromArray(values);
+  return ConstantLiteral(*Literal::CreateR2FromArray2D<NativeT>(values));
 }
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout(
     const Array3D<NativeT>& values, const Layout& layout) {
-  return ConstantFromArrayWithLayout(values, layout);
+  return ConstantLiteral(
+      *Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
 }
 
 template <typename NativeT>
index 3e909f76f9b94e0d612254a30ce1f462019a1821..cc1735e6f2c1bd130d847dbf4de58896402de2f1 100644 (file)
@@ -27,14 +27,20 @@ limitations under the License.
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
+
+using tensorflow::strings::Printf;
+using tensorflow::strings::StrCat;
+
+namespace xla {
+
 namespace {
-using tensorflow::int64;
 
 constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
 
@@ -46,9 +52,8 @@ void ConvertEndianShort(char* bytes, int64 size) {
     std::swap(bytes[i], bytes[i + 1]);
   }
 }
-}  // namespace
 
-namespace xla {
+}  // namespace
 
 std::ostream& operator<<(std::ostream& out, const Literal& literal) {
   out << literal.ToString();
@@ -78,19 +83,83 @@ Literal::StrideConfig::StrideConfig(
   }
 }
 
+Literal::Literal(const Shape& shape)
+    : Literal(shape, /*allocate_arrays=*/true) {}
+
+Literal::Literal(const Shape& shape, bool allocate_arrays)
+    : shape_(shape), pieces_(shape), owns_buffers_(true) {
+  CHECK(LayoutUtil::HasLayout(shape));
+  for (auto& pair : pieces_) {
+    const ShapeIndex& index = pair.first;
+    Piece& piece = pair.second;
+
+    piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
+    if (ShapeUtil::IsArray(piece.subshape())) {
+      if (allocate_arrays) {
+        piece.set_buffer(new char[piece.size_bytes()]);
+      } else {
+        piece.set_buffer(nullptr);
+      }
+    }
+  }
+}
+
+Literal::~Literal() { DeallocateBuffers(); }
+
+void Literal::DeallocateBuffers() {
+  if (owns_buffers_) {
+    for (auto& pair : pieces_) {
+      Piece& piece = pair.second;
+      if (piece.buffer() != nullptr) {
+        delete[] piece.buffer();
+      }
+    }
+  }
+}
+
+Literal::Literal(Literal&& other) {
+  shape_ = std::move(other.shape_);
+  pieces_ = std::move(other.pieces_);
+  // We need to iterate through the pieces to set the subshape pointer
+  // properly. It must refer to subshapes within shape_.
+  for (auto& pair : pieces_) {
+    const ShapeIndex& index = pair.first;
+    Piece& piece = pair.second;
+    piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
+  }
+  owns_buffers_ = other.owns_buffers_;
+
+  other.shape_ = ShapeUtil::MakeNil();
+  other.pieces_ = ShapeTree<Piece>(other.shape_);
+  other.piece({}).set_subshape(&other.shape_);
+}
+
+Literal& Literal::operator=(Literal&& other) {
+  DeallocateBuffers();
+  shape_ = std::move(other.shape_);
+  pieces_ = std::move(other.pieces_);
+  // We need to iterate through the pieces to set the subshape pointer
+  // properly. It must refer to subshapes within shape_.
+  for (auto& pair : pieces_) {
+    const ShapeIndex& index = pair.first;
+    Piece& piece = pair.second;
+    piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
+  }
+  owns_buffers_ = other.owns_buffers_;
+
+  other.shape_ = ShapeUtil::MakeNil();
+  other.pieces_ = ShapeTree<Piece>(other.shape_);
+  other.piece({}).set_subshape(&other.shape_);
+  return *this;
+}
+
 std::unique_ptr<Literal> Literal::CreateFromShape(const Shape& shape) {
-  auto literal = MakeUnique<Literal>();
-  *literal->mutable_shape() = shape;
-  if (ShapeUtil::IsTuple(shape)) {
-    int64 num_elements = ShapeUtil::TupleElementCount(shape);
-    literal->tuple_literals_.resize(num_elements);
-    for (int i = 0; i < num_elements; ++i) {
-      std::unique_ptr<Literal> elem =
-          CreateFromShape(ShapeUtil::GetTupleElementShape(shape, i));
-      literal->tuple_literals_[i] = std::move(*elem);
+  auto literal = MakeUnique<Literal>(shape);
+  for (auto& pair : literal->pieces_) {
+    Piece& piece = pair.second;
+    if (ShapeUtil::IsArray(piece.subshape())) {
+      memset(piece.untyped_data(), 0, piece.size_bytes());
     }
-  } else {
-    literal->Reserve(ShapeUtil::ElementsIn(literal->shape()));
   }
   return literal;
 }
@@ -101,29 +170,31 @@ std::unique_ptr<Literal> Literal::CreateFromShape(const Shape& shape) {
   return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions));
 }
 
-template <typename T>
-Status Literal::CopyRange(const Literal& src_literal,
-                          tensorflow::gtl::ArraySlice<int64> src_base,
-                          tensorflow::gtl::ArraySlice<int64> dest_base,
-                          tensorflow::gtl::ArraySlice<int64> copy_size) {
-  const Shape& src_shape = src_literal.shape();
-  const Shape& dest_shape = shape();
-  tensorflow::gtl::ArraySlice<T> src_data = src_literal.GetArraySlice<T>();
-  tensorflow::gtl::MutableArraySlice<T> dest_data = GetMutableArraySlice<T>();
-
-  TF_RET_CHECK(ShapeUtil::Rank(src_shape) == src_base.size());
-  TF_RET_CHECK(ShapeUtil::Rank(dest_shape) == dest_base.size());
+template <typename NativeT>
+Status Literal::CopySliceFromInternal(
+    const Literal& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
+    tensorflow::gtl::ArraySlice<int64> dest_base,
+    tensorflow::gtl::ArraySlice<int64> copy_size) {
+  TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size());
+  TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size());
+
+  auto linear_index = [](const Shape& shape,
+                         tensorflow::gtl::ArraySlice<int64> multi_index) {
+    return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
+  };
 
-  if (ShapeUtil::Rank(src_shape) == 0 || ShapeUtil::Rank(dest_shape) == 0) {
+  if (ShapeUtil::Rank(src_literal.shape()) == 0 ||
+      ShapeUtil::Rank(shape()) == 0) {
     // If any of the two shapes are scalars, we can just call the StridedCopy()
     // directly, and we know we will be copying only one value.
     TF_RET_CHECK(copy_size.empty());
-    StridedCopy(dest_data, LinearIndex(dest_base), 0, src_data,
-                src_literal.LinearIndex(src_base), 0, 1);
-  } else if (!ShapeUtil::HasZeroElements(dest_shape) &&
-             !ShapeUtil::HasZeroElements(src_shape)) {
-    // Perform copy if neither src literal nor dest literal has dimensions with
-    // zero element, otherwise it's a no-op.
+    StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
+                src_literal.data<NativeT>(),
+                linear_index(src_literal.shape(), src_base), 0, 1);
+  } else if (!ShapeUtil::HasZeroElements(shape()) &&
+             !ShapeUtil::HasZeroElements(src_literal.shape())) {
+    // Perform copy if neither src nor dest has dimensions with zero element,
+    // otherwise it's a no-op.
     TF_RET_CHECK(src_base.size() == dest_base.size());
     TF_RET_CHECK(src_base.size() == copy_size.size());
 
@@ -133,7 +204,8 @@ Status Literal::CopyRange(const Literal& src_literal,
     // proper stride size at the matching dimension.
     DimensionVector src_indexes(src_base.size(), 0);
     DimensionVector dest_indexes(dest_base.size(), 0);
-    StrideConfig stride_config(src_shape, dest_shape, copy_size);
+    Literal::StrideConfig stride_config(src_literal.shape(), shape(),
+                                        copy_size);
 
     auto copy_proc = [&](const std::vector<int64>& indexes) {
       // Map from multi-dimensional index, to source index.
@@ -143,89 +215,290 @@ Status Literal::CopyRange(const Literal& src_literal,
       std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
                      dest_indexes.begin(), std::plus<int64>());
 
-      int64 src_index = src_literal.LinearIndex(src_indexes);
-      int64 dest_index = LinearIndex(dest_indexes);
+      int64 src_index = linear_index(src_literal.shape(), src_indexes);
+      int64 dest_index = linear_index(shape(), dest_indexes);
 
-      StridedCopy(dest_data, dest_index, stride_config.dest_stride, src_data,
-                  src_index, stride_config.source_stride,
-                  stride_config.minor_loop_size);
+      StridedCopy(data<NativeT>(), dest_index, stride_config.dest_stride,
+                  src_literal.data<NativeT>(), src_index,
+                  stride_config.source_stride, stride_config.minor_loop_size);
       return true;
     };
 
-    ShapeUtil::ForEachIndex(src_shape, stride_config.base,
+    ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
                             stride_config.dimensions, stride_config.step,
                             copy_proc);
   }
   return Status::OK();
 }
 
-Status Literal::Copy(const Literal& src_literal,
-                     tensorflow::gtl::ArraySlice<int64> src_base,
-                     tensorflow::gtl::ArraySlice<int64> dest_base,
-                     tensorflow::gtl::ArraySlice<int64> copy_size) {
+std::vector<Literal> Literal::DecomposeTuple() {
+  CHECK(ShapeUtil::IsTuple(shape()));
+  std::vector<Literal> elements;
+  for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
+    elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
+                               /*allocate_arrays=*/false));
+    Literal& element = elements.back();
+    for (auto& pair : element.pieces_) {
+      const ShapeIndex& index = pair.first;
+      Piece& dest_piece = pair.second;
+      ShapeIndex src_index = {i};
+      for (int64 j : index) {
+        src_index.push_back(j);
+      }
+      Piece& src_piece = piece(src_index);
+
+      // Move the respective buffer over to the element Literal.
+      dest_piece.set_buffer(src_piece.buffer());
+      src_piece.set_buffer(nullptr);
+    }
+  }
+  // Set this literal to be nil-shaped.
+  *this = Literal();
+  return elements;
+}
+
+/* static */ Literal Literal::MoveIntoTuple(
+    tensorflow::gtl::MutableArraySlice<Literal> elements) {
+  std::vector<Shape> element_shapes;
+  for (const Literal& element : elements) {
+    element_shapes.push_back(element.shape());
+  }
+  Literal literal(ShapeUtil::MakeTupleShape(element_shapes),
+                  /*allocate_arrays=*/false);
+  for (int i = 0; i < elements.size(); ++i) {
+    TF_CHECK_OK(
+        literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
+  }
+  return literal;
+}
+
+namespace {
+
+// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
+// the array slices are indicated by dest_shape and src_shape respectively.
+template <typename NativeT>
+void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
+                         tensorflow::gtl::ArraySlice<NativeT> src,
+                         const Shape& dest_shape, const Shape& src_shape) {
+  CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
+  if (ShapeUtil::HasZeroElements(dest_shape)) {
+    return;
+  }
+  std::vector<int64> index(ShapeUtil::Rank(dest_shape));
+  do {
+    dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
+        src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
+  } while (IndexUtil::BumpIndices(dest_shape, &index));
+}
+
+}  // namespace
+
+Status Literal::Piece::CopyFrom(const Literal::Piece& src) {
+  if (ShapeUtil::Equal(subshape(), src.subshape())) {
+    // If the layouts are equal it's faster just to memcpy.
+    memcpy(buffer(), src.buffer(), src.size_bytes());
+  } else {
+    TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape()));
+    std::vector<int64> origin(ShapeUtil::Rank(subshape()), 0);
+    switch (subshape().element_type()) {
+#define COPY_ELEMENTS(XLA_T, NATIVE_T)                                    \
+  case (XLA_T):                                                           \
+    CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
+                                  subshape(), src.subshape());            \
+    break;
+      COPY_ELEMENTS(U8, uint8);
+      COPY_ELEMENTS(U16, uint16);
+      COPY_ELEMENTS(U32, uint32);
+      COPY_ELEMENTS(U64, uint64);
+      COPY_ELEMENTS(S8, int8);
+      COPY_ELEMENTS(S16, int16);
+      COPY_ELEMENTS(S32, int32);
+      COPY_ELEMENTS(S64, int64);
+      COPY_ELEMENTS(F16, half);
+      COPY_ELEMENTS(BF16, bfloat16);
+      COPY_ELEMENTS(F32, float);
+      COPY_ELEMENTS(F64, double);
+      COPY_ELEMENTS(C64, complex64);
+      COPY_ELEMENTS(PRED, bool);
+#undef COPY_ELEMENTS
+      default:
+        return Unimplemented(
+            "Unhandled primitive type %s",
+            PrimitiveType_Name(subshape().element_type()).c_str());
+    }
+  }
+  return Status::OK();
+}
+
+Status Literal::CopyFrom(const Literal& src_literal,
+                         const ShapeIndex& dest_shape_index,
+                         const ShapeIndex& src_shape_index) {
+  const Shape& dest_subshape =
+      ShapeUtil::GetSubshape(shape(), dest_shape_index);
+  const Shape& src_subshape =
+      ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
+  if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
+    return InvalidArgument(
+        "Destination subshape incompatible with source subshape: %s vs %s",
+        ShapeUtil::HumanString(dest_subshape).c_str(),
+        ShapeUtil::HumanString(src_subshape).c_str());
+  }
+
+  for (auto& pair : pieces_) {
+    const ShapeIndex& index = pair.first;
+    Piece& piece = pair.second;
+    if (!ShapeUtil::IsArray(piece.subshape())) {
+      continue;
+    }
+
+    // Determine if this index is in the part of this literal that we want to
+    // copy over from src_literal.
+    bool in_subtree_to_copy = true;
+    for (int i = 0; i < dest_shape_index.size(); ++i) {
+      if (index[i] != dest_shape_index[i]) {
+        in_subtree_to_copy = false;
+        break;
+      }
+    }
+    if (!in_subtree_to_copy) {
+      continue;
+    }
+
+    // Construct the index of the corresponding piece in the source literal.
+    ShapeIndex src_piece_index = src_shape_index;
+    for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
+      src_piece_index.push_back(index[i]);
+    }
+
+    TF_RETURN_IF_ERROR(piece.CopyFrom(src_literal.piece(src_piece_index)));
+  }
+  return Status::OK();
+}
+
+Status Literal::MoveFrom(Literal&& src_literal,
+                         const ShapeIndex& dest_shape_index) {
+  const Shape& dest_subshape =
+      ShapeUtil::GetSubshape(shape(), dest_shape_index);
+  if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
+    return InvalidArgument(
+        "Destination subshape not equal to source shape: %s vs %s",
+        ShapeUtil::HumanString(dest_subshape).c_str(),
+        ShapeUtil::HumanString(src_literal.shape()).c_str());
+  }
+
+  if (!(owns_buffers_ && src_literal.owns_buffers_)) {
+    return InvalidArgument(
+        "Source and destination literals must both own their buffers (ie, not "
+        "be views)");
+  }
+
+  for (auto& pair : src_literal.pieces_) {
+    const ShapeIndex& src_index = pair.first;
+    Piece& src_piece = pair.second;
+    if (!ShapeUtil::IsArray(src_piece.subshape())) {
+      continue;
+    }
+
+    ShapeIndex dest_index = dest_shape_index;
+    for (int64 i : src_index) {
+      dest_index.push_back(i);
+    }
+    Piece& dest_piece = piece(dest_index);
+    delete[] dest_piece.buffer();
+    dest_piece.set_buffer(src_piece.buffer());
+  }
+
+  src_literal.shape_ = ShapeUtil::MakeNil();
+  src_literal.pieces_ = ShapeTree<Piece>(src_literal.shape_);
+  src_literal.piece({}).set_subshape(&src_literal.shape_);
+  return Status::OK();
+}
+
+Status Literal::CopySliceFrom(const Literal& src_literal,
+                              tensorflow::gtl::ArraySlice<int64> src_base,
+                              tensorflow::gtl::ArraySlice<int64> dest_base,
+                              tensorflow::gtl::ArraySlice<int64> copy_size) {
+  TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape());
+  TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape()))
+      << ShapeUtil::HumanString(src_literal.shape());
   TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));
-  switch (src_literal.shape().element_type()) {
+
+  switch (shape().element_type()) {
     case U8:
-      return CopyRange<uint8>(src_literal, src_base, dest_base, copy_size);
+      return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base,
+                                          copy_size);
     case U16:
-      return CopyRange<uint16>(src_literal, src_base, dest_base, copy_size);
+      return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base,
+                                           copy_size);
     case U32:
-      return CopyRange<uint32>(src_literal, src_base, dest_base, copy_size);
+      return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base,
+                                           copy_size);
     case U64:
-      return CopyRange<uint64>(src_literal, src_base, dest_base, copy_size);
+      return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base,
+                                           copy_size);
     case S8:
-      return CopyRange<int8>(src_literal, src_base, dest_base, copy_size);
+      return CopySliceFromInternal<int8>(src_literal, src_base, dest_base,
+                                         copy_size);
     case S16:
-      return CopyRange<int16>(src_literal, src_base, dest_base, copy_size);
+      return CopySliceFromInternal<int16>(src_literal, src_base, dest_base,
+                                          copy_size);
     case S32:
-      return CopyRange<int32>(src_literal, src_base, dest_base, copy_size);
+      return CopySliceFromInternal<int32>(src_literal, src_base, dest_base,
+                                          copy_size);
     case S64:
-      return CopyRange<int64>(src_literal, src_base, dest_base, copy_size);
+      return CopySliceFromInternal<int64>(src_literal, src_base, dest_base,
+                                          copy_size);
     case F16:
-      return CopyRange<half>(src_literal, src_base, dest_base, copy_size);
+      return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
+                                         copy_size);
     case BF16:
-      return CopyRange<bfloat16>(src_literal, src_base, dest_base, copy_size);
+      return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
+                                             copy_size);
     case F32:
-      return CopyRange<float>(src_literal, src_base, dest_base, copy_size);
+      return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
+                                          copy_size);
     case F64:
-      return CopyRange<double>(src_literal, src_base, dest_base, copy_size);
+      return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
+                                           copy_size);
     case C64:
-      return CopyRange<complex64>(src_literal, src_base, dest_base, copy_size);
+      return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
+                                              copy_size);
     case PRED:
-      return CopyRange<bool>(src_literal, src_base, dest_base, copy_size);
+      return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
+                                         copy_size);
     default:
       break;
   }
-  return Unimplemented("Unhandled primitive type %d",
-                       src_literal.shape().element_type());
+  return Unimplemented("Unhandled primitive type %d", shape().element_type());
 }
 
 /* static */ Literal Literal::Zero(PrimitiveType primitive_type) {
   switch (primitive_type) {
     case U8:
-      return *Literal::CreateR0<uint8>(0);
+      return std::move(*Literal::CreateR0<uint8>(0));
     case U32:
-      return *Literal::CreateR0<uint32>(0);
+      return std::move(*Literal::CreateR0<uint32>(0));
     case U64:
-      return *Literal::CreateR0<uint64>(0);
+      return std::move(*Literal::CreateR0<uint64>(0));
     case S8:
-      return *Literal::CreateR0<int8>(0);
+      return std::move(*Literal::CreateR0<int8>(0));
     case S32:
-      return *Literal::CreateR0<int32>(0);
+      return std::move(*Literal::CreateR0<int32>(0));
     case S64:
-      return *Literal::CreateR0<int64>(0);
+      return std::move(*Literal::CreateR0<int64>(0));
     case F16:
-      return *Literal::CreateR0<half>(static_cast<half>(0.0f));
+      return std::move(*Literal::CreateR0<half>(static_cast<half>(0.0f)));
     case BF16:
-      return *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
+      return std::move(
+          *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f)));
     case F32:
-      return *Literal::CreateR0<float>(0);
+      return std::move(*Literal::CreateR0<float>(0));
     case F64:
-      return *Literal::CreateR0<double>(0);
+      return std::move(*Literal::CreateR0<double>(0));
     case C64:
-      return *Literal::CreateR0<complex64>(0);
+      return std::move(*Literal::CreateR0<complex64>(0));
     case PRED:
-      return *Literal::CreateR0<bool>(false);
+      return std::move(*Literal::CreateR0<bool>(false));
     case S16:
     case U16:
       LOG(FATAL) << "u16/s16 literals not yet implemented";
@@ -241,29 +514,30 @@ Status Literal::Copy(const Literal& src_literal,
 /* static */ Literal Literal::One(PrimitiveType primitive_type) {
   switch (primitive_type) {
     case U8:
-      return *Literal::CreateR0<uint8>(1);
+      return std::move(*Literal::CreateR0<uint8>(1));
     case U32:
-      return *Literal::CreateR0<uint32>(1);
+      return std::move(*Literal::CreateR0<uint32>(1));
     case U64:
-      return *Literal::CreateR0<uint64>(1);
+      return std::move(*Literal::CreateR0<uint64>(1));
     case S8:
-      return *Literal::CreateR0<int8>(1);
+      return std::move(*Literal::CreateR0<int8>(1));
     case S32:
-      return *Literal::CreateR0<int32>(1);
+      return std::move(*Literal::CreateR0<int32>(1));
     case S64:
-      return *Literal::CreateR0<int64>(1);
+      return std::move(*Literal::CreateR0<int64>(1));
     case F16:
-      return *Literal::CreateR0<half>(static_cast<half>(1.0f));
+      return std::move(*Literal::CreateR0<half>(static_cast<half>(1.0f)));
     case BF16:
-      return *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f));
+      return std::move(
+          *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)));
     case F32:
-      return *Literal::CreateR0<float>(1);
+      return std::move(*Literal::CreateR0<float>(1));
     case F64:
-      return *Literal::CreateR0<double>(1);
+      return std::move(*Literal::CreateR0<double>(1));
     case C64:
-      return *Literal::CreateR0<complex64>(1);
+      return std::move(*Literal::CreateR0<complex64>(1));
     case PRED:
-      return *Literal::CreateR0<bool>(true);
+      return std::move(*Literal::CreateR0<bool>(true));
     case S16:
     case U16:
       LOG(FATAL) << "u16/s16 literals not yet implemented";
@@ -279,35 +553,42 @@ Status Literal::Copy(const Literal& src_literal,
 /* static */ Literal Literal::MinValue(PrimitiveType primitive_type) {
   switch (primitive_type) {
     case U8:
-      return *Literal::CreateR0<uint8>(std::numeric_limits<uint8>::min());
+      return std::move(
+          *Literal::CreateR0<uint8>(std::numeric_limits<uint8>::min()));
     case U32:
-      return *Literal::CreateR0<uint32>(std::numeric_limits<uint32>::min());
+      return std::move(
+          *Literal::CreateR0<uint32>(std::numeric_limits<uint32>::min()));
     case U64:
-      return *Literal::CreateR0<uint64>(std::numeric_limits<uint64>::min());
+      return std::move(
+          *Literal::CreateR0<uint64>(std::numeric_limits<uint64>::min()));
     case S8:
-      return *Literal::CreateR0<int8>(std::numeric_limits<int8>::min());
+      return std::move(
+          *Literal::CreateR0<int8>(std::numeric_limits<int8>::min()));
     case S32:
-      return *Literal::CreateR0<int32>(std::numeric_limits<int32>::min());
+      return std::move(
+          *Literal::CreateR0<int32>(std::numeric_limits<int32>::min()));
     case S64:
-      return *Literal::CreateR0<int64>(std::numeric_limits<int64>::min());
+      return std::move(
+          *Literal::CreateR0<int64>(std::numeric_limits<int64>::min()));
     case F32:
-      return *Literal::CreateR0<float>(-std::numeric_limits<float>::infinity());
+      return std::move(
+          *Literal::CreateR0<float>(-std::numeric_limits<float>::infinity()));
     case F64:
-      return *Literal::CreateR0<double>(
-          -std::numeric_limits<double>::infinity());
+      return std::move(
+          *Literal::CreateR0<double>(-std::numeric_limits<double>::infinity()));
     case C64:
       LOG(FATAL) << "C64 element type has no minimum value";
     case PRED:
-      return *Literal::CreateR0<bool>(false);
+      return std::move(*Literal::CreateR0<bool>(false));
     case S16:
     case U16:
       LOG(FATAL) << "u16/s16 literals not yet implemented";
     case F16:
-      return *Literal::CreateR0<half>(
-          static_cast<half>(-std::numeric_limits<float>::infinity()));
+      return std::move(*Literal::CreateR0<half>(
+          static_cast<half>(-std::numeric_limits<float>::infinity())));
     case BF16:
-      return *Literal::CreateR0<bfloat16>(
-          static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
+      return std::move(*Literal::CreateR0<bfloat16>(
+          static_cast<bfloat16>(-std::numeric_limits<float>::infinity())));
     case TUPLE:
       LOG(FATAL) << "tuple element type has no minimum value";
     case OPAQUE:
@@ -320,33 +601,40 @@ Status Literal::Copy(const Literal& src_literal,
 /* static */ Literal Literal::MaxValue(PrimitiveType primitive_type) {
   switch (primitive_type) {
     case U8:
-      return *Literal::CreateR0<uint8>(std::numeric_limits<uint8>::max());
+      return std::move(
+          *Literal::CreateR0<uint8>(std::numeric_limits<uint8>::max()));
     case U32:
-      return *Literal::CreateR0<uint32>(std::numeric_limits<uint32>::max());
+      return std::move(
+          *Literal::CreateR0<uint32>(std::numeric_limits<uint32>::max()));
     case U64:
-      return *Literal::CreateR0<uint64>(std::numeric_limits<uint64>::max());
+      return std::move(
+          *Literal::CreateR0<uint64>(std::numeric_limits<uint64>::max()));
     case S8:
-      return *Literal::CreateR0<int8>(std::numeric_limits<int8>::max());
+      return std::move(
+          *Literal::CreateR0<int8>(std::numeric_limits<int8>::max()));
     case S32:
-      return *Literal::CreateR0<int32>(std::numeric_limits<int32>::max());
+      return std::move(
+          *Literal::CreateR0<int32>(std::numeric_limits<int32>::max()));
     case S64:
-      return *Literal::CreateR0<int64>(std::numeric_limits<int64>::max());
+      return std::move(
+          *Literal::CreateR0<int64>(std::numeric_limits<int64>::max()));
     case F32:
-      return *Literal::CreateR0<float>(std::numeric_limits<float>::infinity());
+      return std::move(
+          *Literal::CreateR0<float>(std::numeric_limits<float>::infinity()));
     case F64:
-      return *Literal::CreateR0<double>(
-          std::numeric_limits<double>::infinity());
+      return std::move(
+          *Literal::CreateR0<double>(std::numeric_limits<double>::infinity()));
     case PRED:
-      return *Literal::CreateR0<bool>(true);
+      return std::move(*Literal::CreateR0<bool>(true));
     case S16:
     case U16:
       LOG(FATAL) << "u16/s16 literals not yet implemented";
     case F16:
-      return *Literal::CreateR0<half>(
-          static_cast<half>(std::numeric_limits<float>::infinity()));
+      return std::move(*Literal::CreateR0<half>(
+          static_cast<half>(std::numeric_limits<float>::infinity())));
     case BF16:
-      return *Literal::CreateR0<bfloat16>(
-          static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
+      return std::move(*Literal::CreateR0<bfloat16>(
+          static_cast<bfloat16>(std::numeric_limits<float>::infinity())));
     case TUPLE:
       LOG(FATAL) << "tuple element type has no maximum value";
     case OPAQUE:
@@ -358,17 +646,29 @@ Status Literal::Copy(const Literal& src_literal,
 
 /* static */ std::unique_ptr<Literal> Literal::CreateR1(
     const tensorflow::core::Bitmap& values) {
-  auto literal = MakeUnique<Literal>();
+  auto literal = MakeUnique<Literal>(
+      ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
   literal->PopulateR1(values);
   return literal;
 }
 
+void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
+  CHECK(ShapeUtil::IsArray(shape()));
+  CHECK_EQ(ShapeUtil::Rank(shape()), 1);
+  CHECK_EQ(element_count(), values.bits());
+  CHECK_EQ(shape().element_type(), PRED);
+  for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
+    Set({i}, values.get(i));
+  }
+}
+
 /* static */ std::unique_ptr<Literal> Literal::CreateR1U8(
     tensorflow::StringPiece value) {
-  auto literal = MakeUnique<Literal>();
-  *literal->mutable_shape() =
-      ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())});
-  literal->set_u8s(tensorflow::StringPiece(value.ToString()));
+  auto literal = MakeUnique<Literal>(
+      ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
+  for (int i = 0; i < value.size(); ++i) {
+    literal->Set<uint8>({i}, value[i]);
+  }
   return literal;
 }
 
@@ -382,26 +682,14 @@ Status Literal::Copy(const Literal& src_literal,
 
 std::unique_ptr<Literal> Literal::Relayout(
     const Layout& new_layout, const ShapeIndex& shape_index) const {
-  std::unique_ptr<Literal> outer_result = CloneToUnique();
-
-  const Literal* copy_from = this;
-  Literal* copy_to = outer_result.get();
-  for (int64 i = 0; i < shape_index.size(); i++) {
-    *ShapeUtil::GetMutableSubshape(copy_to->mutable_shape(), {shape_index, i})
-         ->mutable_layout() = new_layout;
-    copy_from = &copy_from->tuple_literals_[shape_index[i]];
-    copy_to = &copy_to->tuple_literals_[shape_index[i]];
-  }
-
-  DimensionVector base(ShapeUtil::Rank(copy_from->shape()), 0);
-  DimensionVector copy_size(copy_from->shape().dimensions().begin(),
-                            copy_from->shape().dimensions().end());
-
-  CHECK(ShapeUtil::IsArray(copy_from->shape()));
-  CHECK(ShapeUtil::IsArray(copy_to->shape()));
-  *copy_to->mutable_shape()->mutable_layout() = new_layout;
-  TF_CHECK_OK(copy_to->Copy(*copy_from, base, base, copy_size));
-  return outer_result;
+  // Create new shape with 'new_layout' set at the given shape index.
+  Shape new_shape = shape();
+  Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
+  TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
+  *subshape->mutable_layout() = new_layout;
+  auto result = MakeUnique<Literal>(new_shape);
+  TF_CHECK_OK(result->CopyFrom(*this));
+  return result;
 }
 
 std::unique_ptr<Literal> Literal::Relayout(
@@ -415,11 +703,9 @@ std::unique_ptr<Literal> Literal::Relayout(
       result->shape(),
       [this, &result](const Shape& subshape, const ShapeIndex& index) {
         if (ShapeUtil::IsArray(subshape)) {
-          DimensionVector base(ShapeUtil::Rank(subshape), 0);
-          DimensionVector copy_size(subshape.dimensions().begin(),
-                                    subshape.dimensions().end());
-          TF_CHECK_OK(result->GetSubliteral(index).Copy(GetSubliteral(index),
-                                                        base, base, copy_size));
+          TF_CHECK_OK(result->CopyFrom(*this,
+                                       /*dest_shape_index=*/index,
+                                       /*src_shape_index=*/index));
         }
       });
   return result;
@@ -427,7 +713,7 @@ std::unique_ptr<Literal> Literal::Relayout(
 
 StatusOr<std::unique_ptr<Literal>> Literal::Reshape(
     tensorflow::gtl::ArraySlice<int64> dimensions) const {
-  if (ShapeUtil::IsTuple(shape())) {
+  if (!ShapeUtil::IsArray(shape())) {
     return InvalidArgument("Reshape does not support tuples.");
   }
   std::unique_ptr<Literal> output;
@@ -439,8 +725,7 @@ StatusOr<std::unique_ptr<Literal>> Literal::Reshape(
   }
   // Because the layout is monotonic, we can simply reuse the same sequence of
   // values without changing their order.
-  *output->mutable_shape() =
-      ShapeUtil::MakeShape(shape().element_type(), dimensions);
+  output->shape_ = ShapeUtil::MakeShape(shape().element_type(), dimensions);
 
   int64 elements_before = ShapeUtil::ElementsIn(shape());
   int64 elements_after = ShapeUtil::ElementsIn(output->shape());
@@ -456,7 +741,7 @@ StatusOr<std::unique_ptr<Literal>> Literal::Reshape(
 
 std::unique_ptr<Literal> Literal::Transpose(
     tensorflow::gtl::ArraySlice<int64> permutation) const {
-  CHECK(!ShapeUtil::IsTuple(shape())) << "Tuple is not supported for transpose";
+  CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
   CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
       << "Given permutation is not a permutation of dimension numbers";
   // To transpose the array, we just permute the dimensions and layout, and
@@ -488,15 +773,15 @@ std::unique_ptr<Literal> Literal::Transpose(
   std::unique_ptr<Literal> new_literal = CreateFromShape(permuted_shape);
   DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()),
             ShapeUtil::ByteSizeOf(shape()));
-  std::memcpy(new_literal->MutableInternalData(), InternalData(),
-              ShapeUtil::ByteSizeOf(shape()));
+  std::memcpy(new_literal->root_piece().buffer(), root_piece().buffer(),
+              root_piece().size_bytes());
   return new_literal;
 }
 
 std::unique_ptr<Literal> Literal::Slice(
     tensorflow::gtl::ArraySlice<int64> start_indices,
     tensorflow::gtl::ArraySlice<int64> limit_indices) const {
-  CHECK(!ShapeUtil::IsTuple(shape())) << "tuple is not supported for reshape";
+  CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
 
   DimensionVector result_dimensions;
   for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) {
@@ -510,9 +795,7 @@ std::unique_ptr<Literal> Literal::Slice(
       ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
                                      LayoutUtil::MinorToMajor(shape()));
 
-  auto result_literal = MakeUnique<Literal>();
-  *result_literal->mutable_shape() = result_shape;
-  result_literal->Reserve(ShapeUtil::ElementsIn(result_shape));
+  auto result_literal = MakeUnique<Literal>(result_shape);
 
   DimensionVector new_indices(ShapeUtil::Rank(result_shape));
   switch (result_shape.element_type()) {
@@ -552,43 +835,49 @@ std::unique_ptr<Literal> Literal::Slice(
   }
 }
 
+Literal Literal::Clone() const {
+  Literal result(shape());
+  TF_CHECK_OK(result.CopyFrom(*this));
+  return result;
+}
+
 std::unique_ptr<Literal> Literal::CloneToUnique() const {
-  auto unique = MakeUnique<Literal>();
-  *unique = *this;
-  return unique;
+  auto result = MakeUnique<Literal>(shape());
+  TF_CHECK_OK(result->CopyFrom(*this));
+  return result;
 }
 
-string Literal::GetAsString(
-    tensorflow::gtl::ArraySlice<int64> multi_index) const {
-  switch (shape().element_type()) {
+string Literal::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+                            const ShapeIndex& shape_index) const {
+  const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
+  switch (subshape.element_type()) {
     case PRED:
-      return Get<bool>(multi_index) ? "true" : "false";
+      return Get<bool>(multi_index, shape_index) ? "true" : "false";
     case U8:
-      return tensorflow::strings::StrCat(Get<uint8>(multi_index));
+      return StrCat(Get<uint8>(multi_index, shape_index));
     case S32:
-      return tensorflow::strings::StrCat(Get<int32>(multi_index));
+      return StrCat(Get<int32>(multi_index, shape_index));
     case S64:
-      return tensorflow::strings::StrCat(Get<int64>(multi_index));
+      return StrCat(Get<int64>(multi_index, shape_index));
     case U32:
-      return tensorflow::strings::StrCat(Get<uint32>(multi_index));
+      return StrCat(Get<uint32>(multi_index, shape_index));
     case U64:
-      return tensorflow::strings::StrCat(Get<uint64>(multi_index));
+      return StrCat(Get<uint64>(multi_index, shape_index));
     case F32:
-      return tensorflow::strings::StrCat(Get<float>(multi_index));
+      return StrCat(Get<float>(multi_index, shape_index));
     case F64:
-      return tensorflow::strings::StrCat(Get<double>(multi_index));
+      return StrCat(Get<double>(multi_index, shape_index));
     case C64: {
-      complex64 c = Get<complex64>(multi_index);
-      return tensorflow::strings::StrCat("(", c.real(), ", ", c.imag(), ")");
+      complex64 c = Get<complex64>(multi_index, shape_index);
+      return StrCat("(", c.real(), ", ", c.imag(), ")");
     }
     case F16:
-      return tensorflow::strings::StrCat(Get<half>(multi_index));
+      return StrCat(Get<half>(multi_index, shape_index));
     case BF16:
-      return tensorflow::strings::StrCat(
-          static_cast<float>(Get<bfloat16>(multi_index)));
+      return StrCat(
+          static_cast<float>(Get<bfloat16>(multi_index, shape_index)));
     default:
-      return tensorflow::strings::StrCat(
-          "[", PrimitiveType_Name(shape().element_type()), "]");
+      return StrCat("[", PrimitiveType_Name(subshape.element_type()), "]");
   }
 }
 
@@ -614,13 +903,11 @@ StatusOr<int64> Literal::GetIntegralAsS64(
   }
 }
 
-int64 Literal::LinearIndex(
-    tensorflow::gtl::ArraySlice<int64> multi_index) const {
-  return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index);
-}
+namespace {
 
-string Literal::ToString(bool print_layout) const {
-  std::vector<string> pieces;
+void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index,
+                    bool print_layout, std::vector<string>* pieces) {
+  const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
 
   auto shape_to_string = [print_layout](const Shape& shape) {
     if (print_layout) {
@@ -631,277 +918,151 @@ string Literal::ToString(bool print_layout) const {
   };
 
   auto element_to_string =
-      [this](tensorflow::gtl::ArraySlice<int64> indices) -> string {
-    PrimitiveType element_type = shape().element_type();
+      [&](tensorflow::gtl::ArraySlice<int64> indices) -> string {
+    PrimitiveType element_type = subshape.element_type();
     if (element_type == PRED) {
       // We display predicates in a densely packed form.
-      return Get<bool>(indices) ? "1" : "0";
+      return literal.Get<bool>(indices, shape_index) ? "1" : "0";
     }
     return ((!indices.empty() && indices.back() > 0) ? ", " : "") +
-           GetAsString(indices);
+           literal.GetAsString(indices, shape_index);
   };
 
   // TODO(b/32894291): refactor this code to reduce code duplication.
-  if (ShapeUtil::IsTuple(shape())) {
-    pieces.push_back(shape_to_string(shape()));
-    pieces.push_back(" (\n");
-    pieces.push_back(tensorflow::str_util::Join(
-        tuple_literals(), ",\n", [](string* out, const Literal& element) {
-          tensorflow::strings::StrAppend(out, element.ToString());
-        }));
-    pieces.push_back("\n)");
-  } else if (ShapeUtil::Rank(shape()) == 0) {
-    pieces.push_back(GetAsString({}));
-  } else if (ShapeUtil::Rank(shape()) == 1) {
-    pieces.push_back("{");
-    for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
-      pieces.push_back(element_to_string({i0}));
+  if (ShapeUtil::IsTuple(subshape)) {
+    pieces->push_back(shape_to_string(subshape));
+    pieces->push_back(" (\n");
+    std::vector<string> tuple_pieces;
+    for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
+      ShapeIndex element_index = shape_index;
+      element_index.push_back(i);
+      std::vector<string> element_pieces;
+      ToStringHelper(literal, element_index, print_layout, &element_pieces);
+      tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, ""));
+    }
+    pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n"));
+    pieces->push_back("\n)");
+  } else if (ShapeUtil::Rank(subshape) == 0) {
+    pieces->push_back(literal.GetAsString({}, shape_index));
+  } else if (ShapeUtil::Rank(subshape) == 1) {
+    pieces->push_back("{");
+    for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+      pieces->push_back(element_to_string({i0}));
     }
-    pieces.push_back("}");
-  } else if (ShapeUtil::Rank(shape()) == 2) {
-    pieces.push_back(shape_to_string(shape()));
-    pieces.push_back(" {\n");
-    for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
-      pieces.push_back("  { ");
-      for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) {
-        pieces.push_back(element_to_string({i0, i1}));
+    pieces->push_back("}");
+  } else if (ShapeUtil::Rank(subshape) == 2) {
+    pieces->push_back(shape_to_string(subshape));
+    pieces->push_back(" {\n");
+    for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+      pieces->push_back("  { ");
+      for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
+        pieces->push_back(element_to_string({i0, i1}));
       }
-      pieces.push_back(" ");
-      pieces.push_back(i0 == shape().dimensions(0) - 1 ? "}\n" : "},\n");
+      pieces->push_back(" ");
+      pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n");
     }
-    pieces.push_back("}");
-  } else if (ShapeUtil::Rank(shape()) == 3) {
-    pieces.push_back(shape_to_string(shape()));
-    pieces.push_back(" {\n");
-    for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
-      pieces.push_back(i0 > 0 ? ",\n{" : "{");
-      for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) {
-        pieces.push_back(i1 > 0 ? ",\n  { " : " { ");
-        for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) {
-          pieces.push_back(element_to_string({i0, i1, i2}));
+    pieces->push_back("}");
+  } else if (ShapeUtil::Rank(subshape) == 3) {
+    pieces->push_back(shape_to_string(subshape));
+    pieces->push_back(" {\n");
+    for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+      pieces->push_back(i0 > 0 ? ",\n{" : "{");
+      for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
+        pieces->push_back(i1 > 0 ? ",\n  { " : " { ");
+        for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
+          pieces->push_back(element_to_string({i0, i1, i2}));
         }
-        pieces.push_back(" }");
+        pieces->push_back(" }");
       }
-      pieces.push_back(" }");
+      pieces->push_back(" }");
     }
-    pieces.push_back("\n}");
-  } else if (ShapeUtil::Rank(shape()) == 4) {
-    pieces.push_back(shape_to_string(shape()));
-    pieces.push_back(" {\n");
-    for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
-      pieces.push_back(tensorflow::strings::Printf("  {  /*i0=%lld*/\n", i0));
-      for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) {
-        pieces.push_back(
-            tensorflow::strings::Printf("    {  /*i1=%lld*/\n", i1));
-        for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) {
-          pieces.push_back("      {");
-          for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) {
-            pieces.push_back(element_to_string({i0, i1, i2, i3}));
+    pieces->push_back("\n}");
+  } else if (ShapeUtil::Rank(subshape) == 4) {
+    pieces->push_back(shape_to_string(subshape));
+    pieces->push_back(" {\n");
+    for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+      pieces->push_back(Printf("  {  /*i0=%lld*/\n", i0));
+      for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
+        pieces->push_back(Printf("    {  /*i1=%lld*/\n", i1));
+        for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
+          pieces->push_back("      {");
+          for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
+            pieces->push_back(element_to_string({i0, i1, i2, i3}));
           }
-          pieces.push_back(i2 == shape().dimensions(2) - 1 ? "}\n" : "},\n");
+          pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n");
         }
-        pieces.push_back(i1 == shape().dimensions(1) - 1 ? "    }\n"
-                                                         : "    },\n");
+        pieces->push_back(i1 == subshape.dimensions(1) - 1 ? "    }\n"
+                                                           : "    },\n");
       }
-      pieces.push_back(i0 == shape().dimensions(0) - 1 ? "  }\n" : "  },\n");
+      pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "  }\n" : "  },\n");
     }
-    pieces.push_back("}");
-  } else if (ShapeUtil::Rank(shape()) == 5) {
-    pieces.push_back(shape_to_string(shape()));
-    pieces.push_back(" {\n");
-    for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
-      pieces.push_back(tensorflow::strings::Printf("  {  /*i0=%lld*/\n", i0));
-      for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) {
-        pieces.push_back(
-            tensorflow::strings::Printf("    {  /*i1=%lld*/\n", i1));
-        for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) {
-          pieces.push_back(
-              tensorflow::strings::Printf("      {  /*i2=%lld*/\n", i2));
-          for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) {
-            pieces.push_back("        {");
-            for (int64 i4 = 0; i4 < shape().dimensions(4); ++i4) {
-              pieces.push_back(element_to_string({i0, i1, i2, i3, i4}));
+    pieces->push_back("}");
+  } else if (ShapeUtil::Rank(subshape) == 5) {
+    pieces->push_back(shape_to_string(subshape));
+    pieces->push_back(" {\n");
+    for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+      pieces->push_back(Printf("  {  /*i0=%lld*/\n", i0));
+      for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
+        pieces->push_back(Printf("    {  /*i1=%lld*/\n", i1));
+        for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
+          pieces->push_back(Printf("      {  /*i2=%lld*/\n", i2));
+          for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
+            pieces->push_back("        {");
+            for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) {
+              pieces->push_back(element_to_string({i0, i1, i2, i3, i4}));
             }
-            pieces.push_back(i3 == shape().dimensions(3) - 1 ? "}\n" : "},\n");
+            pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n"
+                                                               : "},\n");
           }
-          pieces.push_back(i2 == shape().dimensions(2) - 1 ? "      }\n"
-                                                           : "      },\n");
+          pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "      }\n"
+                                                             : "      },\n");
         }
-        pieces.push_back(i1 == shape().dimensions(1) - 1 ? "    }\n"
-                                                         : "    },\n");
+        pieces->push_back(i1 == subshape.dimensions(1) - 1 ? "    }\n"
+                                                           : "    },\n");
       }
-      pieces.push_back(i0 == shape().dimensions(0) - 1 ? "  }\n" : "  },\n");
+      pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "  }\n" : "  },\n");
     }
-    pieces.push_back("}");
+    pieces->push_back("}");
   } else {
-    pieces.push_back(shape_to_string(shape()));
-    pieces.push_back(" {");
-    EachCellAsString(
+    pieces->push_back(shape_to_string(subshape));
+    pieces->push_back(" {");
+    literal.EachCellAsString(
         [&](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
-          pieces.push_back(" ");
-          pieces.push_back(value);
+          pieces->push_back(" ");
+          pieces->push_back(value);
         });
-    pieces.push_back("}");
+    pieces->push_back("}");
   }
+}
 
+}  // namespace
+
+string Literal::ToString(bool print_layout) const {
+  std::vector<string> pieces;
+  ToStringHelper(*this, {}, print_layout, &pieces);
   return tensorflow::str_util::Join(pieces, "");
 }
 
 /* static */ std::unique_ptr<Literal> Literal::MakeTuple(
     tensorflow::gtl::ArraySlice<const Literal*> elements) {
-  auto literal = MakeUnique<Literal>();
-  std::vector<Shape> shape;
-  for (const Literal* tuple_element : elements) {
-    *literal->add_tuple_literals() = *tuple_element;
-    shape.push_back(tuple_element->shape());
+  std::vector<Shape> element_shapes;
+  for (const Literal* element : elements) {
+    element_shapes.push_back(element->shape());
+  }
+  auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+  for (int i = 0; i < elements.size(); ++i) {
+    TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
   }
-  *literal->mutable_shape() = ShapeUtil::MakeTupleShape(shape);
   return literal;
 }
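A sketch of the new construction path, assuming the existing CreateR0/CreateR1 factory helpers: the tuple shape is allocated up front and each element is copied in via CopyFrom.

auto scalar = Literal::CreateR0<float>(42.0f);
auto values = Literal::CreateR1<int32>({1, 2, 3});
// Tuple-shaped literal (f32[], s32[3]) holding copies of the two elements.
std::unique_ptr<Literal> tuple = Literal::MakeTuple({scalar.get(), values.get()});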
 
 /* static */ std::unique_ptr<Literal> Literal::MakeTupleOwned(
     std::vector<std::unique_ptr<Literal>> elements) {
-  auto literal = MakeUnique<Literal>();
-  std::vector<Shape> shape;
-  for (auto& tuple_element : elements) {
-    shape.push_back(tuple_element->shape());
-    *literal->add_tuple_literals() = std::move(*tuple_element);
-  }
-  *literal->mutable_shape() = ShapeUtil::MakeTupleShape(shape);
-  return literal;
-}
-
-const void* Literal::InternalData() const {
-  return const_cast<const void*>(
-      const_cast<Literal*>(this)->MutableInternalData());
-}
-
-void* Literal::MutableInternalData() {
-  // NOTE: We access the vectors directly to avoid the const reference
-  // created by the accessor functions.
-  switch (shape().element_type()) {
-    case PRED:
-    case U8:
-      return reinterpret_cast<void*>(u8s_.data());
-    case S32:
-      return reinterpret_cast<void*>(s32s_.data());
-    case S64:
-      return reinterpret_cast<void*>(s64s_.data());
-    case U32:
-      return reinterpret_cast<void*>(u32s_.data());
-    case U64:
-      return reinterpret_cast<void*>(u64s_.data());
-    case F32:
-      return reinterpret_cast<void*>(f32s_.data());
-    case F64:
-      return reinterpret_cast<void*>(f64s_.data());
-    case C64:
-      return reinterpret_cast<void*>(c64s_.data());
-    case F16:
-      return reinterpret_cast<void*>(f16s_.data());
-    case BF16:
-      return reinterpret_cast<void*>(bf16s_.data());
-    default:
-      LOG(FATAL) << "primitive type not supported in literals: "
-                 << PrimitiveType_Name(shape().element_type());
+  std::vector<const Literal*> element_ptrs;
+  for (const auto& element : elements) {
+    element_ptrs.push_back(element.get());
   }
-}
-
-void Literal::Reserve(int64 num_elements) {
-  CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
-  switch (shape().element_type()) {
-    case PRED:
-      Resize<bool>(num_elements, false);
-      break;
-    case S8:
-      Resize<int8>(num_elements, 0);
-      break;
-    case U8:
-      Resize<uint8>(num_elements, 0);
-      break;
-    case S32:
-      Resize<int32>(num_elements, 0);
-      break;
-    case S64:
-      Resize<int64>(num_elements, 0);
-      break;
-    case U32:
-      Resize<uint32>(num_elements, 0);
-      break;
-    case U64:
-      Resize<uint64>(num_elements, 0);
-      break;
-    case F32:
-      Resize<float>(num_elements, 0);
-      break;
-    case F64:
-      Resize<double>(num_elements, 0);
-      break;
-    case C64:
-      Resize<complex64>(num_elements, 0);
-      break;
-    case F16:
-      Resize<half>(num_elements, static_cast<half>(0.0f));
-      break;
-    case BF16:
-      Resize<bfloat16>(num_elements, static_cast<bfloat16>(0.0f));
-      break;
-    default:
-      LOG(FATAL) << "primitive type not supported in literals: "
-                 << PrimitiveType_Name(shape().element_type());
-  }
-}
-
-tensorflow::Status Literal::ValidateLiteral() const {
-  TF_CHECK_OK(ShapeUtil::ValidateShape(shape()));
-  int64 expected = ShapeUtil::ElementsIn(shape());
-  int64 actual = -1;
-  switch (shape().element_type()) {
-    case PRED:
-    case U8:
-      actual = u8s_size();
-      break;
-    case S32:
-      actual = s32s_size();
-      break;
-    case U32:
-      actual = u32s_size();
-      break;
-    case S64:
-      actual = s64s_size();
-      break;
-    case U64:
-      actual = u64s_size();
-      break;
-    case F32:
-      actual = f32s_size();
-      break;
-    case F64:
-      actual = f64s_size();
-      break;
-    case C64:
-      actual = c64s_size();
-      break;
-    case F16:
-      actual = f16s().size() / sizeof(half);
-      break;
-    case BF16:
-      actual = bf16s().size();
-      break;
-    default:
-      return tensorflow::errors::Unimplemented(
-          "unhandled element type for literal validation: " +
-          PrimitiveType_Name(shape().element_type()));
-  }
-
-  if (expected != actual) {
-    return tensorflow::errors::InvalidArgument(tensorflow::strings::Printf(
-        "literal has bad number of elements for its shape %s: want %lld "
-        "got %lld",
-        ShapeUtil::HumanString(shape()).c_str(), expected, actual));
-  }
-
-  return tensorflow::Status::OK();
+  return MakeTuple(element_ptrs);
 }
 
 void Literal::EachCellAsString(
@@ -920,17 +1081,13 @@ void Literal::EachCellAsString(
 namespace {
 template <typename NativeSrcT, typename NativeDestT>
 std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
-  auto result_literal = MakeUnique<Literal>();
-  Shape* result_shape = result_literal->mutable_shape();
-  *result_shape = src_literal.shape();
-  result_shape->set_element_type(
-      primitive_util::NativeToPrimitiveType<NativeDestT>());
-  result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape));
-  tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
-      src_literal.GetArraySlice<NativeSrcT>();
-  tensorflow::gtl::MutableArraySlice<NativeDestT> dest_data =
-      result_literal->GetMutableArraySlice<NativeDestT>();
-  int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape());
+  CHECK(ShapeUtil::IsArray(src_literal.shape()));
+  auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType(
+      src_literal.shape(),
+      primitive_util::NativeToPrimitiveType<NativeDestT>()));
+  auto src_data = src_literal.data<NativeSrcT>();
+  auto dest_data = result_literal->template data<NativeDestT>();
+  int64 num_elements = src_literal.element_count();
 
   for (int64 i = 0; i < num_elements; ++i) {
     dest_data[i] = static_cast<NativeDestT>(src_data[i]);
@@ -940,18 +1097,16 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
 
 template <PrimitiveType primitive_src_type>
 std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) {
-  auto result_literal = MakeUnique<Literal>();
-  Shape* result_shape = result_literal->mutable_shape();
-  *result_shape = src_literal.shape();
-  result_shape->set_element_type(C64);
-  result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape));
+  CHECK(ShapeUtil::IsArray(src_literal.shape()));
+  auto result_literal = MakeUnique<Literal>(
+      ShapeUtil::ChangeElementType(src_literal.shape(), C64));
   using NativeSrcT =
       typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
   tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
-      src_literal.GetArraySlice<NativeSrcT>();
+      src_literal.data<NativeSrcT>();
   tensorflow::gtl::MutableArraySlice<complex64> dest_data =
-      result_literal->GetMutableArraySlice<complex64>();
-  int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape());
+      result_literal->data<complex64>();
+  int64 num_elements = src_literal.element_count();
   for (int64 i = 0; i < num_elements; ++i) {
     dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
   }
@@ -996,10 +1151,12 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
           PrimitiveType_Name(primitive_dest_type).c_str());
   }
 }
+
 }  // namespace
 
 StatusOr<std::unique_ptr<Literal>> Literal::Convert(
     PrimitiveType primitive_dest_type) const {
+  TF_RET_CHECK(ShapeUtil::IsArray(shape()));
   switch (shape().element_type()) {
 #define CONVERT_IF_DEST_TYPE_MATCHES(type) \
   case (type):                             \
@@ -1024,356 +1181,192 @@ StatusOr<std::unique_ptr<Literal>> Literal::Convert(
   }
 }
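A hedged sketch of element-type conversion, assuming the existing CreateR1 factory: Convert() now rejects tuple-shaped literals up front and returns a freshly allocated literal of the requested type.

auto f32_literal = Literal::CreateR1<float>({1.5f, 2.5f});
StatusOr<std::unique_ptr<Literal>> converted = f32_literal->Convert(S64);
if (converted.ok()) {
  std::unique_ptr<Literal> s64_literal = converted.ConsumeValueOrDie();  // values 1 and 2
}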
 
-namespace {
-
-// Helper function which compares whether the elements of literal1 are equal to
-// the elements of literal2. Recursively iterates through the entire
-// multidimensional index space and compares the literal elements
-// one-by-one. literal1 and literal2 must be compatible (same dimensions and
-// type).
 template <typename NativeT>
-bool EqualElements(const Literal& literal1, const Literal& literal2,
-                   int dimension, std::vector<int64>* multi_index) {
-  if (dimension == ShapeUtil::Rank(literal1.shape())) {
-    return (literal1.Get<NativeT>(*multi_index) ==
-            literal2.Get<NativeT>(*multi_index));
+bool Literal::Piece::EqualElementsInternal(
+    const Literal::Piece& other, std::vector<int64>* multi_index) const {
+  if (multi_index->size() == ShapeUtil::Rank(subshape())) {
+    return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
   }
-  for (int64 i = 0; i < literal1.shape().dimensions(dimension); ++i) {
-    (*multi_index)[dimension] = i;
-    if (!EqualElements<NativeT>(literal1, literal2, dimension + 1,
-                                multi_index)) {
+  for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) {
+    multi_index->push_back(i);
+    if (!EqualElementsInternal<NativeT>(other, multi_index)) {
       return false;
     }
+    multi_index->pop_back();
   }
   return true;
 }
 
-}  // namespace
+bool Literal::Piece::EqualElements(const Literal::Piece& other) const {
+  DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
+
+  std::vector<int64> multi_index;
+  switch (subshape().element_type()) {
+    case PRED:
+      return EqualElementsInternal<bool>(other, &multi_index);
+    case U8:
+      return EqualElementsInternal<uint8>(other, &multi_index);
+    case S32:
+      return EqualElementsInternal<int32>(other, &multi_index);
+    case S64:
+      return EqualElementsInternal<int64>(other, &multi_index);
+    case U32:
+      return EqualElementsInternal<uint32>(other, &multi_index);
+    case U64:
+      return EqualElementsInternal<uint64>(other, &multi_index);
+    case F32:
+      return EqualElementsInternal<float>(other, &multi_index);
+    case F64:
+      return EqualElementsInternal<double>(other, &multi_index);
+    case F16:
+      return EqualElementsInternal<half>(other, &multi_index);
+    case BF16:
+      return EqualElementsInternal<bfloat16>(other, &multi_index);
+    case C64:
+      return EqualElementsInternal<complex64>(other, &multi_index);
+    default:
+      LOG(FATAL) << "Unimplemented: Literal::Piece::EqualElements for type "
+                 << PrimitiveType_Name(subshape().element_type());
+  }
+}
 
 bool Literal::operator==(const Literal& other) const {
   if (!ShapeUtil::Compatible(shape(), other.shape())) {
     return false;
   }
-  if (ShapeUtil::IsTuple(shape())) {
-    // Because the shapes are compatible, they must have the same number of
-    // tuple elements.
-    CHECK_EQ(tuple_literals_size(), other.tuple_literals_size());
-    for (int i = 0; i < tuple_literals_size(); ++i) {
-      if (tuple_literals(i) != other.tuple_literals(i)) {
-        return false;
-      }
+  for (const auto& pair : pieces_) {
+    const ShapeIndex& index = pair.first;
+    const Piece& piece = pair.second;
+    if (!ShapeUtil::IsArray(piece.subshape())) {
+      continue;
     }
-    return true;
-  } else {
-    std::vector<int64> multi_index(ShapeUtil::Rank(shape()), 0);
-    switch (shape().element_type()) {
-      case PRED:
-        return EqualElements<bool>(*this, other, 0, &multi_index);
-      case U8:
-        return EqualElements<uint8>(*this, other, 0, &multi_index);
-      case S32:
-        return EqualElements<int32>(*this, other, 0, &multi_index);
-      case S64:
-        return EqualElements<int64>(*this, other, 0, &multi_index);
-      case U32:
-        return EqualElements<uint32>(*this, other, 0, &multi_index);
-      case U64:
-        return EqualElements<uint64>(*this, other, 0, &multi_index);
-      case F32:
-        return EqualElements<float>(*this, other, 0, &multi_index);
-      case F64:
-        return EqualElements<double>(*this, other, 0, &multi_index);
-      case F16:
-        return EqualElements<half>(*this, other, 0, &multi_index);
-      case BF16:
-        return EqualElements<bfloat16>(*this, other, 0, &multi_index);
-      case C64:
-        return EqualElements<complex64>(*this, other, 0, &multi_index);
-      default:
-        LOG(FATAL) << "Unimplemented: Literal::Equal for type "
-                   << PrimitiveType_Name(shape().element_type());
+
+    const Piece& other_piece = other.piece(index);
+    if (!piece.EqualElements(other_piece)) {
+      return false;
     }
   }
+  return true;
 }
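A small sketch of the comparison semantics, assuming the existing CreateR2 factory: operator== walks every array piece and compares element values; layouts are not compared.

auto a = Literal::CreateR2<int32>({{1, 2}, {3, 4}});
auto b = Literal::CreateR2<int32>({{1, 2}, {3, 4}});
CHECK(*a == *b);  // compatible shapes (layout ignored) and identical element values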
 
-template <>
-tensorflow::gtl::MutableArraySlice<bool> Literal::GetMutableArraySlice() {
-  auto values = mutable_preds();
-  return tensorflow::gtl::MutableArraySlice<bool>(
-      reinterpret_cast<bool*>(values->data()), values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<int8> Literal::GetMutableArraySlice() {
-  auto values = mutable_u8s();
-  return tensorflow::gtl::MutableArraySlice<int8>(
-      reinterpret_cast<int8*>(values->data()), values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<uint8> Literal::GetMutableArraySlice() {
-  auto values = mutable_u8s();
-  return tensorflow::gtl::MutableArraySlice<uint8>(values->data(),
-                                                   values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<int16> Literal::GetMutableArraySlice() {
-  auto values = mutable_s16s();
-  return tensorflow::gtl::MutableArraySlice<int16>(values->data(),
-                                                   values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<uint16> Literal::GetMutableArraySlice() {
-  auto values = mutable_u16s();
-  return tensorflow::gtl::MutableArraySlice<uint16>(values->data(),
-                                                    values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<int32> Literal::GetMutableArraySlice() {
-  auto values = mutable_s32s();
-  return tensorflow::gtl::MutableArraySlice<int32>(values->data(),
-                                                   values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<uint32> Literal::GetMutableArraySlice() {
-  auto values = mutable_u32s();
-  return tensorflow::gtl::MutableArraySlice<uint32>(values->data(),
-                                                    values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<int64> Literal::GetMutableArraySlice() {
-  static_assert(sizeof(int64) == sizeof(tensorflow::protobuf_int64) &&
-                    alignof(int64) == alignof(tensorflow::protobuf_int64),
-                "The int64 and tensorflow::protobuf_int64 types are not "
-                "compatible");
-  auto values = mutable_s64s();
-  // Because of the fact that tensorflow::protobuf_int64 is defined as int64_t
-  // while tensorflow::int64 is defined as long long, a reinterpret_cast<> is
-  // necessary from the raw data pointer returned by the mutable_data() API.
-  return tensorflow::gtl::MutableArraySlice<int64>(
-      reinterpret_cast<int64*>(values->data()), values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<uint64> Literal::GetMutableArraySlice() {
-  static_assert(sizeof(uint64) == sizeof(tensorflow::protobuf_uint64) &&
-                    alignof(uint64) == alignof(tensorflow::protobuf_uint64),
-                "The uint64 and tensorflow::protobuf_uint64 types are not "
-                "compatible");
-  auto values = mutable_u64s();
-  // Because of the fact that tensorflow::protobuf_uint64 is defined as uint64_t
-  // while tensorflow::uint64 is defined as unsigned long long, a
-  // reinterpret_cast<> is necessary from the raw data pointer returned by the
-  // mutable_data() API.
-  return tensorflow::gtl::MutableArraySlice<uint64>(
-      reinterpret_cast<uint64*>(values->data()), values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<float> Literal::GetMutableArraySlice() {
-  auto values = mutable_f32s();
-  return tensorflow::gtl::MutableArraySlice<float>(values->data(),
-                                                   values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<double> Literal::GetMutableArraySlice() {
-  auto values = mutable_f64s();
-  return tensorflow::gtl::MutableArraySlice<double>(values->data(),
-                                                    values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice() {
-  auto values = mutable_c64s();
-  return {values->data(), values->size()};
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice<half>() {
-  auto values = mutable_f16s();
-  return tensorflow::gtl::MutableArraySlice<half>(values->data(),
-                                                  values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<bfloat16>
-Literal::GetMutableArraySlice<bfloat16>() {
-  auto values = mutable_bf16s();
-  return {values->data(), values->size()};
-}
-
-template <>
-tensorflow::gtl::ArraySlice<bool> Literal::GetArraySlice<bool>() const {
-  CHECK_EQ(shape().element_type(), PRED) << ShapeUtil::HumanString(shape());
-  return tensorflow::gtl::ArraySlice<bool>(
-      reinterpret_cast<const bool*>(preds().data()), preds().size());
-}
-
-template <>
-tensorflow::gtl::ArraySlice<uint8> Literal::GetArraySlice<uint8>() const {
-  CHECK_EQ(shape().element_type(), U8) << ShapeUtil::HumanString(shape());
-  return tensorflow::gtl::ArraySlice<uint8>(
-      reinterpret_cast<const uint8*>(u8s().data()), u8s().size());
-}
-
-template <>
-tensorflow::gtl::ArraySlice<int8> Literal::GetArraySlice<int8>() const {
-  CHECK_EQ(shape().element_type(), S8) << ShapeUtil::HumanString(shape());
-  return tensorflow::gtl::ArraySlice<int8>(
-      reinterpret_cast<const int8*>(u8s().data()), u8s().size());
-}
-
-template <>
-tensorflow::gtl::ArraySlice<uint16> Literal::GetArraySlice<uint16>() const {
-  CHECK_EQ(shape().element_type(), U16) << ShapeUtil::HumanString(shape());
-  return tensorflow::gtl::ArraySlice<uint16>(u16s().data(), u16s().size());
-}
-
-template <>
-tensorflow::gtl::ArraySlice<int16> Literal::GetArraySlice<int16>() const {
-  CHECK_EQ(shape().element_type(), S16) << ShapeUtil::HumanString(shape());
-  return tensorflow::gtl::ArraySlice<int16>(s16s().data(), s16s().size());
-}
-
-template <>
-tensorflow::gtl::ArraySlice<uint32> Literal::GetArraySlice<uint32>() const {
-  CHECK_EQ(shape().element_type(), U32) << ShapeUtil::HumanString(shape());
-  return u32s();
-}
-
-template <>
-tensorflow::gtl::ArraySlice<uint64> Literal::GetArraySlice<uint64>() const {
-  CHECK_EQ(shape().element_type(), U64) << ShapeUtil::HumanString(shape());
-  return u64s();
-}
-
-template <>
-tensorflow::gtl::ArraySlice<int32> Literal::GetArraySlice<int32>() const {
-  CHECK_EQ(shape().element_type(), S32) << ShapeUtil::HumanString(shape());
-  return s32s();
-}
-
-template <>
-tensorflow::gtl::ArraySlice<int64> Literal::GetArraySlice<int64>() const {
-  CHECK_EQ(shape().element_type(), S64) << ShapeUtil::HumanString(shape());
-  return s64s();
-}
-
-template <>
-tensorflow::gtl::ArraySlice<double> Literal::GetArraySlice<double>() const {
-  CHECK_EQ(shape().element_type(), F64) << ShapeUtil::HumanString(shape());
-  return f64s();
-}
-
-template <>
-tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const {
-  CHECK_EQ(shape().element_type(), F16) << ShapeUtil::HumanString(shape());
-  return tensorflow::gtl::ArraySlice<half>(f16s().data(),
-                                           f16s().size() / sizeof(half));
-}
-
-template <>
-tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const {
-  CHECK_EQ(shape().element_type(), BF16) << ShapeUtil::HumanString(shape());
-  return {bf16s().data(), bf16s().size()};
-}
-
-template <>
-tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
-    const {
-  CHECK_EQ(shape().element_type(), C64) << ShapeUtil::HumanString(shape());
-  return c64s();
-}
+namespace {
 
 template <typename NativeT>
-static bool AllElementsEqualValue(const Literal& literal, NativeT value) {
-  for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
-    auto multi_index =
-        IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
-    if (literal.Get<NativeT>(multi_index) != value) {
+static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
+                                  NativeT value) {
+  for (int64 i = 0; i < data.size(); ++i) {
+    if (data[i] != value) {
       return false;
     }
   }
   return true;
 }
 
+}  // namespace
+
 bool Literal::IsAll(int8 value) const {
-  switch (shape().element_type()) {
-    case U8:
-      if (value >= 0) {
-        return AllElementsEqualValue<uint8>(*this, value);
-      }
-      return false;
-    case U32:
-      if (value >= 0) {
-        return AllElementsEqualValue<uint32>(*this, value);
-      }
-      return false;
-    case U64:
-      if (value >= 0) {
-        return AllElementsEqualValue<uint64>(*this, value);
-      }
-      return false;
-    case S8:
-      return AllElementsEqualValue<int8>(*this, value);
-    case S32:
-      return AllElementsEqualValue<int32>(*this, value);
-    case S64:
-      return AllElementsEqualValue<int64>(*this, value);
-    case F32:
-      return AllElementsEqualValue<float>(*this, value);
-    case F64:
-      return AllElementsEqualValue<double>(*this, value);
-    case F16:
-      return AllElementsEqualValue<half>(*this, static_cast<half>(value));
-    case BF16:
-      return AllElementsEqualValue<bfloat16>(*this,
-                                             static_cast<bfloat16>(value));
-    case PRED:
-      if (value == 0) {
-        return AllElementsEqualValue<bool>(*this, false);
-      }
-      if (value == 1) {
-        return AllElementsEqualValue<bool>(*this, true);
+  for (const auto& pair : pieces_) {
+    const Piece& piece = pair.second;
+    if (!ShapeUtil::IsArray(piece.subshape())) {
+      continue;
+    }
+
+    auto piece_is_all = [&]() {
+      switch (shape().element_type()) {
+        case U8:
+          if (value >= 0) {
+            return AllElementsEqualValue<uint8>(piece.data<uint8>(), value);
+          }
+          return false;
+        case U32:
+          if (value >= 0) {
+            return AllElementsEqualValue<uint32>(piece.data<uint32>(), value);
+          }
+          return false;
+        case U64:
+          if (value >= 0) {
+            return AllElementsEqualValue<uint64>(piece.data<uint64>(), value);
+          }
+          return false;
+        case S8:
+          return AllElementsEqualValue<int8>(piece.data<int8>(), value);
+        case S32:
+          return AllElementsEqualValue<int32>(piece.data<int32>(), value);
+        case S64:
+          return AllElementsEqualValue<int64>(piece.data<int64>(), value);
+        case F32:
+          return AllElementsEqualValue<float>(piece.data<float>(), value);
+        case F64:
+          return AllElementsEqualValue<double>(piece.data<double>(), value);
+        case F16:
+          return AllElementsEqualValue<half>(piece.data<half>(),
+                                             static_cast<half>(value));
+        case BF16:
+          return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
+                                                 static_cast<bfloat16>(value));
+        case PRED:
+          if (value == 0) {
+            return AllElementsEqualValue<bool>(piece.data<bool>(), false);
+          }
+          if (value == 1) {
+            return AllElementsEqualValue<bool>(piece.data<bool>(), true);
+          }
+          return false;
+        default:
+          return false;
       }
       return false;
-    default:
+    };
+
+    if (!piece_is_all()) {
       return false;
+    }
   }
+  return true;
 }
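A hedged sketch, assuming the existing CreateR2 factory: IsAll() reports whether every element of every array piece equals the given integer value, subject to the per-type sign and width rules above.

auto zeros = Literal::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}});
CHECK(zeros->IsAll(0));
CHECK(!zeros->IsAll(1));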
 
 bool Literal::IsAllFloat(float value) const {
-  switch (shape().element_type()) {
-    case F32:
-      return AllElementsEqualValue<float>(*this, value);
-    case F64:
-      return AllElementsEqualValue<double>(*this, value);
-    case F16:
-      return AllElementsEqualValue<half>(*this, static_cast<half>(value));
-    case BF16:
-      return AllElementsEqualValue<bfloat16>(*this,
-                                             static_cast<bfloat16>(value));
-    default:
+  for (const auto& pair : pieces_) {
+    const Piece& piece = pair.second;
+    if (!ShapeUtil::IsArray(piece.subshape())) {
+      continue;
+    }
+
+    auto piece_is_all = [&]() {
+      switch (shape().element_type()) {
+        case F32:
+          return AllElementsEqualValue<float>(piece.data<float>(), value);
+        case F64:
+          return AllElementsEqualValue<double>(piece.data<double>(), value);
+        case F16:
+          return AllElementsEqualValue<half>(piece.data<half>(),
+                                             static_cast<half>(value));
+        case BF16:
+          return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
+                                                 static_cast<bfloat16>(value));
+        default:
+          return false;
+      }
+    };
+    if (!piece_is_all()) {
       return false;
+    }
   }
+  return true;
 }
 
 bool Literal::IsAllComplex(complex64 value) const {
   switch (shape().element_type()) {
     case C64:
-      return AllElementsEqualValue<complex64>(*this, value);
+      return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
+                                              value);
     default:
       return false;
   }
 }
 
 bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
+  CHECK(ShapeUtil::IsArray(shape()));
   switch (shape().element_type()) {
     case U8:
       return Get<uint8>(indices) == 0;
@@ -1404,247 +1397,287 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
   }
 }
 
-template <>
-/* static */ void Literal::Resize<bool>(int64 num_elements, bool value) {
-  CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
-  mutable_preds()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<int8>(int64 num_elements, int8 value) {
-  CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
-  mutable_u8s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<uint8>(int64 num_elements, uint8 value) {
-  CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
-  mutable_u8s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<int32>(int64 num_elements, int32 value) {
-  CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
-  mutable_s32s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<uint32>(int64 num_elements, uint32 value) {
-  CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
-  mutable_u32s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<int64>(int64 num_elements, int64 value) {
-  CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
-  mutable_s64s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<uint64>(int64 num_elements, uint64 value) {
-  CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
-  mutable_u64s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<float>(int64 num_elements, float value) {
-  CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
-  mutable_f32s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<double>(int64 num_elements, double value) {
-  CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
-  mutable_f64s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<half>(int64 num_elements, half value) {
-  CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
-  mutable_f16s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value) {
-  CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
-  mutable_bf16s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<complex64>(int64 num_elements, complex64 value) {
-  CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
-  mutable_c64s()->resize(num_elements, value);
-}
+namespace {
 
 template <typename RepeatedFieldT, typename NativeT>
 void CopyToRepeatedField(RepeatedFieldT* dest,
-                         const std::vector<NativeT>& src) {
+                         const tensorflow::gtl::ArraySlice<NativeT> src) {
   *dest = RepeatedFieldT(src.begin(), src.end());
 }
 
-template <>
-void CopyToRepeatedField<tensorflow::protobuf::RepeatedField<float>, complex64>(
-    tensorflow::protobuf::RepeatedField<float>* dest,
-    const std::vector<complex64>& src) {
-  *dest = tensorflow::protobuf::RepeatedField<float>(
-      reinterpret_cast<const float*>(src.data()),
-      reinterpret_cast<const float*>(src.data()) + src.size() * 2);
-}
+}  // namespace
 
-LiteralProto Literal::ToProto() const {
-  LiteralProto proto;
-  proto.Clear();
-  *proto.mutable_shape() = shape();
-  switch (shape().element_type()) {
+void Literal::Piece::WriteToProto(LiteralProto* proto) const {
+  *proto->mutable_shape() = subshape();
+  switch (subshape().element_type()) {
     case PRED:
-      CopyToRepeatedField(proto.mutable_preds(), preds());
+      CopyToRepeatedField(proto->mutable_preds(), data<bool>());
       break;
     case U8:
-      *proto.mutable_u8s() = u8s_string();
-      break;
-    case S32:
-      CopyToRepeatedField(proto.mutable_s32s(), s32s());
-      break;
-    case S64:
-      CopyToRepeatedField(proto.mutable_s64s(), s64s());
+      proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
+                     element_count());
       break;
     case U32:
-      CopyToRepeatedField(proto.mutable_u32s(), u32s());
+      CopyToRepeatedField(proto->mutable_u32s(), data<uint32>());
       break;
     case U64:
-      CopyToRepeatedField(proto.mutable_u64s(), u64s());
+      CopyToRepeatedField(proto->mutable_u64s(), data<uint64>());
+      break;
+    case S32:
+      CopyToRepeatedField(proto->mutable_s32s(), data<int32>());
+      break;
+    case S64:
+      CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
       break;
     case F16:
-      *proto.mutable_f16s() =
-          string(reinterpret_cast<const char*>(f16s_.data()),
-                 f16s_.size() * sizeof(half));
+      *proto->mutable_f16s() = string(
+          reinterpret_cast<const char*>(data<half>().data()), size_bytes());
       if (!kLittleEndian) {
-        ConvertEndianShort(const_cast<char*>(proto.mutable_f16s()->data()),
-                           proto.f16s().size());
+        ConvertEndianShort(const_cast<char*>(proto->mutable_f16s()->data()),
+                           proto->f16s().size());
       }
       break;
     case BF16:
-      *proto.mutable_bf16s() =
-          string(reinterpret_cast<const char*>(bf16s_.data()),
-                 bf16s_.size() * sizeof(bfloat16));
+      *proto->mutable_bf16s() = string(
+          reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
       if (!kLittleEndian) {
-        ConvertEndianShort(const_cast<char*>(proto.mutable_bf16s()->data()),
-                           proto.bf16s().size());
+        ConvertEndianShort(const_cast<char*>(proto->mutable_bf16s()->data()),
+                           proto->bf16s().size());
       }
       break;
     case F32:
-      CopyToRepeatedField(proto.mutable_f32s(), f32s());
+      CopyToRepeatedField(proto->mutable_f32s(), data<float>());
       break;
     case F64:
-      CopyToRepeatedField(proto.mutable_f64s(), f64s());
+      CopyToRepeatedField(proto->mutable_f64s(), data<double>());
       break;
     case C64:
-      CopyToRepeatedField(proto.mutable_c64s(), c64s());
-      break;
-    case TUPLE:
-      for (const auto& tuple : tuple_literals()) {
-        *proto.add_tuple_literals() = tuple.ToProto();
+      for (complex64 value : data<complex64>()) {
+        proto->add_c64s(value.real());
+        proto->add_c64s(value.imag());
       }
       break;
+    case TUPLE:
+      // Nothing to do but assign the shape, which is done above.
+      return;
     default:
-      LOG(FATAL) << "Unhandled primitive type " << shape().element_type();
+      LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
   }
-
-  return proto;
 }
 
-template <typename RepeatedFieldT, typename NativeT>
-void CopyFromRepeatedField(std::vector<NativeT>* dest,
-                           const RepeatedFieldT& src) {
-  *dest = std::vector<NativeT>(src.begin(), src.end());
+const void* Literal::Piece::untyped_data() const {
+  CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+  return buffer();
 }
 
-template <>
-void CopyFromRepeatedField<tensorflow::protobuf::RepeatedField<float>,
-                           complex64>(
-    std::vector<complex64>* dest,
-    const tensorflow::protobuf::RepeatedField<float>& src) {
-  *dest = std::vector<complex64>(
-      reinterpret_cast<const complex64*>(src.data()),
-      reinterpret_cast<const complex64*>(src.data()) + src.size() / 2);
+void* Literal::Piece::untyped_data() {
+  CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+  return buffer();
 }
 
-void Literal::CopyFromProto(const LiteralProto& literal_proto) {
-  if (!literal_proto.has_shape()) {
-    return;
+namespace {
+
+template <typename RepeatedFieldT, typename NativeT>
+Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
+                             const RepeatedFieldT& src) {
+  if (dest.size() != src.size()) {
+    return InvalidArgument(
+        "Expected %lu elements in LiteralProto repeated field, has %d",
+        dest.size(), src.size());
   }
+  std::copy(src.begin(), src.end(), dest.begin());
+  return Status::OK();
+}
 
-  *mutable_shape() = literal_proto.shape();
-  switch (shape().element_type()) {
+}  // namespace
+
+Status Literal::Piece::CopyFromProto(const LiteralProto& proto) {
+  // These conditions should have been checked in Literal::CreateFromProto.
+  TF_RET_CHECK(proto.has_shape());
+  TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
+  TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));
+
+  switch (subshape().element_type()) {
     case PRED:
-      CopyFromRepeatedField(mutable_preds(), literal_proto.preds());
-      break;
-    case U8:
-      set_u8s(literal_proto.u8s());
+      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
       break;
+    case U8: {
+      auto u8_data = data<uint8>();
+      TF_RET_CHECK(proto.u8s().size() == u8_data.size());
+      std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
+    } break;
     case S32:
-      CopyFromRepeatedField(mutable_s32s(), literal_proto.s32s());
+      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
       break;
     case S64:
-      CopyFromRepeatedField(mutable_s64s(), literal_proto.s64s());
+      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
       break;
     case U32:
-      CopyFromRepeatedField(mutable_u32s(), literal_proto.u32s());
+      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
       break;
     case U64:
-      CopyFromRepeatedField(mutable_u64s(), literal_proto.u64s());
+      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
       break;
     case F16: {
-      const string& s(literal_proto.f16s());
-      CHECK_EQ(0, s.size() % sizeof(half));
-      f16s_ = std::vector<half>(s.size() / sizeof(half));
-      memcpy(f16s_.data(), s.data(), s.size());
-
+      const string& s(proto.f16s());
+      TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
+      memcpy(untyped_data(), s.data(), s.size());
       if (!kLittleEndian) {
-        ConvertEndianShort(reinterpret_cast<char*>(f16s_.data()), s.size());
+        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
       }
-      break;
-    }
-    case BF16: {
-      const string& s(literal_proto.bf16s());
-      CHECK_EQ(0, s.size() % sizeof(bfloat16));
-      bf16s_ = std::vector<bfloat16>(s.size() / sizeof(bfloat16));
-      memcpy(bf16s_.data(), s.data(), s.size());
+    } break;
 
+    case BF16: {
+      const string& s(proto.bf16s());
+      TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
+      memcpy(untyped_data(), s.data(), s.size());
       if (!kLittleEndian) {
-        ConvertEndianShort(reinterpret_cast<char*>(bf16s_.data()), s.size());
+        ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
       }
-      break;
-    }
+    } break;
     case F32:
-      CopyFromRepeatedField(mutable_f32s(), literal_proto.f32s());
+      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
       break;
     case F64:
-      CopyFromRepeatedField(mutable_f64s(), literal_proto.f64s());
-      break;
-    case C64:
-      CopyFromRepeatedField(mutable_c64s(), literal_proto.c64s());
+      TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
       break;
-    case TUPLE:
-      for (const auto& proto : literal_proto.tuple_literals()) {
-        mutable_tuple_literals()->push_back(Literal(proto));
+    case C64: {
+      auto complex_data = data<complex64>();
+      TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
+      for (int64 i = 0; i < complex_data.size(); ++i) {
+        complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
       }
+    } break;
+    case TUPLE:
+      LOG(FATAL) << "Should not be called on tuple shapes: "
+                 << ShapeUtil::HumanString(subshape());
       break;
     default:
-      LOG(FATAL) << "Unhandled primitive type " << shape().element_type();
+      LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
+  }
+  return Status::OK();
+}
+
+LiteralProto Literal::ToProto() const {
+  LiteralProto proto;
+  for (const auto& pair : pieces_) {
+    const ShapeIndex& index = pair.first;
+    const Piece& piece = pair.second;
+
+    LiteralProto* proto_piece = &proto;
+    for (int64 i : index) {
+      while (proto_piece->tuple_literals_size() <= i) {
+        proto_piece->add_tuple_literals();
+      }
+      proto_piece = proto_piece->mutable_tuple_literals(i);
+    }
+    piece.WriteToProto(proto_piece);
+  }
+  return proto;
+}
+
+/* static */
+StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
+    const LiteralProto& proto) {
+  if (!proto.has_shape()) {
+    return InvalidArgument("LiteralProto has no shape");
   }
+  if (!LayoutUtil::HasLayout(proto.shape())) {
+    return InvalidArgument("LiteralProto has no layout");
+  }
+
+  auto literal = MakeUnique<Literal>(proto.shape());
+
+  for (auto& pair : literal->pieces_) {
+    const ShapeIndex& index = pair.first;
+    Piece& piece = pair.second;
+    const LiteralProto* proto_element = &proto;
+    for (int64 i : index) {
+      TF_RET_CHECK(i < proto_element->tuple_literals_size());
+      proto_element = &proto_element->tuple_literals(i);
+    }
+
+    if (ShapeUtil::IsTuple(piece.subshape())) {
+      if (proto_element->tuple_literals_size() !=
+          ShapeUtil::TupleElementCount(piece.subshape())) {
+        return InvalidArgument(
+            "Expected %lld tuple elements in LiteralProto, has %d",
+            ShapeUtil::TupleElementCount(piece.subshape()),
+            proto_element->tuple_literals_size());
+      }
+      continue;
+    }
+
+    TF_RET_CHECK(ShapeUtil::IsArray(piece.subshape()));
+    TF_RETURN_IF_ERROR(piece.CopyFromProto(*proto_element));
+  }
+  return std::move(literal);
+}
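A sketch of the proto round trip, assuming the existing CreateR1 factory: ToProto() mirrors the piece tree into nested tuple_literals, and CreateFromProto() validates shape, layout, and element counts before copying the data back in.

auto original = Literal::CreateR1<float>({1.0f, 2.0f, 3.0f});
LiteralProto proto = original->ToProto();
StatusOr<std::unique_ptr<Literal>> restored = Literal::CreateFromProto(proto);
TF_CHECK_OK(restored.status());
CHECK(*restored.ValueOrDie() == *original);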
+
+const void* Literal::untyped_data(const ShapeIndex& shape_index) const {
+  return piece(shape_index).untyped_data();
+}
+
+void* Literal::untyped_data(const ShapeIndex& shape_index) {
+  return piece(shape_index).untyped_data();
 }
 
-const Literal& Literal::GetSubliteral(const ShapeIndex& index) const {
-  return const_cast<Literal*>(this)->GetSubliteral(index);
+int64 Literal::size_bytes(const ShapeIndex& shape_index) const {
+  return piece(shape_index).size_bytes();
+}
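A hedged sketch of the flat-buffer accessors, assuming the existing CreateR1 factory and a defaulted (empty) shape_index: data<NativeT>() gives a typed slice over a piece's buffer, untyped_data() the raw pointer, and size_bytes() its size.

auto literal = Literal::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
auto elements = literal->data<float>();     // typed view over the four floats
const void* raw = literal->untyped_data();  // same buffer, untyped
int64 num_bytes = literal->size_bytes();    // 4 * sizeof(float) for this shape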
+
+string Literal::GetR1U8AsString() const {
+  CHECK(ShapeUtil::IsArray(shape()));
+  CHECK_EQ(ShapeUtil::Rank(shape()), 1);
+  CHECK_EQ(shape().element_type(), U8);
+  return string(tensorflow::bit_cast<const char*>(data<uint8>().data()),
+                ShapeUtil::ElementsIn(shape()));
+}
+
+/* static */ const LiteralView LiteralView::Create(
+    const Literal& literal, const ShapeIndex& view_root) {
+  return LiteralView(literal, view_root);
+}
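A hedged sketch, assuming a hypothetical tuple-shaped literal `t` whose first element has shape f32[2,2]: LiteralView provides a read-only, zero-copy view of a subliteral, aliasing the owning literal's buffers.

const LiteralView view = LiteralView::Create(t, /*view_root=*/{0});  // the f32[2,2] element
float x = view.Get<float>({1, 1});  // reads directly from t's buffer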
+
+LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) {
+  shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root);
+  pieces_ = ShapeTree<Piece>(shape_);
+  owns_buffers_ = false;
+  for (auto& pair : pieces_) {
+    const ShapeIndex& index = pair.first;
+    Piece& piece = pair.second;
+
+    ShapeIndex src_index = view_root;
+    for (int64 i : index) {
+      src_index.push_back(i);
+    }
+    const Piece& src_piece = literal.piece(src_index);
+    piece.set_buffer(src_piece.buffer());
+    piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
+  }
+}
+
+LiteralView::~LiteralView() {}
+
+LiteralView::LiteralView(const LiteralView& other) { CopyFrom(other); }
+
+LiteralView& LiteralView::operator=(const LiteralView& other) {
+  CopyFrom(other);
+  return *this;
 }
 
-Literal& Literal::GetSubliteral(const ShapeIndex& index) {
-  Literal* subliteral = this;
-  for (int64 i : index) {
-    subliteral = &subliteral->tuple_literals_.at(i);
+void LiteralView::CopyFrom(const LiteralView& other) {
+  // We can't use the default copy-constructor/copy-assignment because
+  // Piece::subshape_ points to subshapes within the Shape of the owning
+  // Literal/LiteralView.
+  shape_ = other.shape();
+  pieces_ = other.pieces_;
+  for (auto& pair : pieces_) {
+    const ShapeIndex& index = pair.first;
+    Piece& piece = pair.second;
+    piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
   }
-  return *subliteral;
+  owns_buffers_ = false;
 }
 
 }  // namespace xla
index 6254fafaf3ee2867b50ff85c64f456e9f0ed7654..dc29c6359c6c691e25bb73d5814f787d0941b698 100644
@@ -34,6 +34,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -50,153 +51,64 @@ limitations under the License.
 
 namespace xla {
 
-// Utility class for dealing with XLA literal values.  Most methods are
-// templated by native (host) type which corresponds to a unique XLA
-// PrimitiveType. See ComputationBuilder for details.  Not all primitive types
-// defined in xla_data.proto have a corresponding native type or even have a
-// storage location in the Literal proto yet (for example, primitive type F16).
+// Class representing literal values in XLA.
+//
+// TODO(b/67651157): The methods in this class should be reduced to a minimal
+// set of methods which construct Literals and accessor methods. Other methods
+// which perform computation on Literals (Reshape, Slice, etc.) should be moved
+// elsewhere, and perhaps combined with evaluator code which operates on
+// Literals.
 class Literal {
  public:
-  Literal() {}
+  Literal() : Literal(ShapeUtil::MakeNil()) {}
 
-  Literal(const Literal& other) = default;
-  Literal(Literal&&) = default;
+  // Create a literal of the given shape. The literal allocates sufficient
+  // memory to hold the data described by the shape. The memory is
+  // uninitialized.
+  explicit Literal(const Shape& shape);
+  virtual ~Literal();
 
-  explicit Literal(const LiteralProto& other) { CopyFromProto(other); }
-
-  Literal& operator=(const Literal& other) = default;
-  Literal& operator=(Literal&&) = default;
+  // Literals are movable, but not copyable. To copy a literal, use
+  // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
+  // of literals, which can be expensive.
+  Literal(const Literal& other) = delete;
+  Literal& operator=(const Literal& other) = delete;
+  Literal(Literal&& other);
+  Literal& operator=(Literal&& other);
 
   // Literals are equal if they have compatible shapes and the same data
-  // values. Layout is not checked.
+  // values. Layout is not compared.
   bool operator==(const Literal& other) const;
   bool operator!=(const Literal& other) const { return !(*this == other); }
 
+  // Serialize to and from a proto.
+  static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
+      const LiteralProto& proto);
   LiteralProto ToProto() const;
 
-  bool has_shape() const {
-    return shape_.element_type() != PRIMITIVE_TYPE_INVALID;
-  }
-
-  // Basic accessor functions.  Names mirror the original protobuf
-  // functions for convenience.
-  string DebugString() const { return ToProto().DebugString(); }
-  string ShortDebugString() const { return ToProto().ShortDebugString(); }
-
-  // Return the nested literal at the given shape index.
-  const Literal& GetSubliteral(const ShapeIndex& index) const;
-  Literal& GetSubliteral(const ShapeIndex& index);
-
-  void Clear() {
-    shape_.Clear();
-    u8s_.clear();
-    s16s_.clear();
-    s32s_.clear();
-    s64s_.clear();
-    u16s_.clear();
-    u32s_.clear();
-    u64s_.clear();
-    f16s_.clear();
-    f32s_.clear();
-    f64s_.clear();
-    c64s_.clear();
-    tuple_literals_.clear();
-  }
-
-  int preds_size() const { return u8s().size(); }
-  const std::vector<uint8>& preds() const {
-    static_assert(sizeof(uint8) == sizeof(bool),
-                  "The uint8 and bool types should be the same size");
-    return u8s_;
-  }
-  std::vector<uint8>* mutable_preds() {
-    static_assert(sizeof(uint8) == sizeof(bool),
-                  "The uint8 and bool types should be the same size");
-    return &u8s_;
-  }
-
-  int s16s_size() const { return s16s().size(); }
-  int32 s16s(int i) const { return s16s_[i]; }
-  const std::vector<int16>& s16s() const { return s16s_; }
-  std::vector<int16>* mutable_s16s() { return &s16s_; }
-
-  int s32s_size() const { return s32s().size(); }
-  int32 s32s(int i) const { return s32s_[i]; }
-  const std::vector<int32>& s32s() const { return s32s_; }
-  std::vector<int32>* mutable_s32s() { return &s32s_; }
-
-  int s64s_size() const { return s64s().size(); }
-  void add_s64s(int64 value) { s64s_.push_back(value); }
-  const std::vector<int64>& s64s() const { return s64s_; }
-  std::vector<int64>* mutable_s64s() { return &s64s_; }
-
-  int u16s_size() const { return u16s().size(); }
-  uint32 u16s(int i) const { return u16s_[i]; }
-  const std::vector<uint16>& u16s() const { return u16s_; }
-  std::vector<uint16>* mutable_u16s() { return &u16s_; }
-
-  int u32s_size() const { return u32s().size(); }
-  uint32 u32s(int i) const { return u32s_[i]; }
-  const std::vector<uint32>& u32s() const { return u32s_; }
-  std::vector<uint32>* mutable_u32s() { return &u32s_; }
-
-  int u64s_size() const { return u64s().size(); }
-  const std::vector<uint64>& u64s() const { return u64s_; }
-  std::vector<uint64>* mutable_u64s() { return &u64s_; }
-
-  int f16s_size() const { return f16s().size(); }
-  half f16s(int i) const { return f16s_[i]; }
-  const std::vector<half>& f16s() const { return f16s_; }
-  std::vector<half>* mutable_f16s() { return &f16s_; }
-
-  int f32s_size() const { return f32s().size(); }
-  float f32s(int i) const { return f32s_[i]; }
-  void add_f32s(float value) { f32s_.push_back(value); }
-  const std::vector<float>& f32s() const { return f32s_; }
-  std::vector<float>& f32s() { return f32s_; }
-  std::vector<float>* mutable_f32s() { return &f32s_; }
-
-  int f64s_size() const { return f64s().size(); }
-  const std::vector<double>& f64s() const { return f64s_; }
-  std::vector<double>* mutable_f64s() { return &f64s_; }
-
-  int c64s_size() const { return c64s().size(); }
-  const std::vector<complex64>& c64s() const { return c64s_; }
-  std::vector<complex64>* mutable_c64s() { return &c64s_; }
-
-  int bf16s_size() const { return bf16s().size(); }
-  bfloat16 bf16s(int i) const { return bf16s_[i]; }
-  const std::vector<bfloat16>& bf16s() const { return bf16s_; }
-  std::vector<bfloat16>* mutable_bf16s() { return &bf16s_; }
-
-  int tuple_literals_size() const { return tuple_literals().size(); }
-  const Literal& tuple_literals(int i) const { return tuple_literals_[i]; }
-  Literal* add_tuple_literals() {
-    tuple_literals_.push_back(Literal());
-    return &tuple_literals_.back();
-  }
-  std::vector<Literal>* mutable_tuple_literals() { return &tuple_literals_; }
-  const std::vector<Literal>& tuple_literals() const { return tuple_literals_; }
-
-  int u8s_size() const { return u8s().size(); }
-  const std::vector<uint8>& u8s() const { return u8s_; }
-  void set_u8s(const std::vector<uint8>& value) { u8s_ = value; }
-  void set_u8s(tensorflow::StringPiece value) {
-    u8s_ = std::vector<uint8>(value.size());
-    u8s_.clear();
-    append_u8s(value);
-  }
-
-  void append_u8s(tensorflow::StringPiece value) {
-    u8s_.insert(u8s_.end(), value.begin(), value.end());
-  }
+  // Return the shape of the literal.
+  const Shape& shape() const { return shape_; }
 
-  string u8s_string() const { return string(u8s().begin(), u8s().end()); }
+  // TODO(b/67651157): Remove this accessor. Literal users should not be able to
+  // mutate the shape as this can produce malformed Literals.
+  Shape* mutable_shape_do_not_use() { return &shape_; }
 
-  std::vector<uint8>* mutable_u8s() { return &u8s_; }
+  // Returns a (Mutable)ArraySlice view of the array for this literal for the
+  // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
+  // given ShapeIndex is not an array. See primitive_util.h for the mapping from
+  // XLA type to native type.
+  template <typename NativeT>
+  tensorflow::gtl::ArraySlice<NativeT> data(
+      const ShapeIndex& shape_index = {}) const;
+  template <typename NativeT>
+  tensorflow::gtl::MutableArraySlice<NativeT> data(
+      const ShapeIndex& shape_index = {});
 
-  const Shape& shape() const { return shape_; }
-  Shape* mutable_shape() { return &shape_; }
+  // Returns a pointer to (or size of) the underlying buffer holding the array
+  // at the given shape index. CHECKs if the subshape of the literal at the
+  // given ShapeIndex is not an array.
+  const void* untyped_data(const ShapeIndex& shape_index = {}) const;
+  void* untyped_data(const ShapeIndex& shape_index = {});
+  int64 size_bytes(const ShapeIndex& shape_index = {}) const;
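For example, a sketch of the typed and untyped accessors on an array literal (illustrative fragment):

  auto literal = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  tensorflow::gtl::ArraySlice<float> elements = literal->data<float>();
  CHECK_EQ(elements.size(), 4);
  const void* raw = literal->untyped_data();  // same storage, untyped
  int64 num_bytes = literal->size_bytes();    // 4 elements * sizeof(float) = 16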
 
   // Creates a new literal of a given rank. To minimize ambiguity (for users
   // and the compiler) these CreateR[0-2] methods should explicitly specify the
@@ -244,6 +156,10 @@ class Literal {
           values,
       const Layout& layout);
 
+  // Returns this literal's data as a string. This literal must be a rank-1 U8
+  // array.
+  string GetR1U8AsString() const;
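For instance (illustrative sketch; the character values are assumptions chosen for the example):

  auto literal = Literal::CreateR1<uint8>({'x', 'l', 'a'});
  string text = literal->GetR1U8AsString();  // "xla"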
+
   // Creates a new Literal object with the shape specified as parameter.
   // The content of the literal values is the default value of the primitive
   // type of literal itself (0 for numeric types, and false for predicates).
@@ -257,6 +173,23 @@ class Literal {
       PrimitiveType primitive_type,
       tensorflow::gtl::ArraySlice<int64> dimensions);
 
+  // Copy values from 'src_literal' rooted at 'src_shape_index' into this
+  // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
+  // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
+  // rooted at 'src_shape_index'; the subshapes need not be arrays.
+  Status CopyFrom(const Literal& src_literal,
+                  const ShapeIndex& dest_shape_index = {},
+                  const ShapeIndex& src_shape_index = {});
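A brief sketch of CopyFrom between two compatible array literals, mirroring the CopyFromArrays test below:

  auto dst = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto src = Literal::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
  CHECK(dst->CopyFrom(*src).ok());  // 'dst' now holds {{5, 6}, {7, 8}}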
+
+  // Similar to CopyFrom, but with move semantics. The subshape of this literal
+  // rooted at 'dest_shape_index' must be *equal* to the shape of 'src_literal'
+  // (layouts and shapes must match), but need not be an array. The memory
+  // allocated in this literal for the subshape at dest_shape_index is
+  // deallocated, and the respective buffers are replaced with those in
+  // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
+  Status MoveFrom(Literal&& src_literal,
+                  const ShapeIndex& dest_shape_index = {});
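A sketch of MoveFrom donating an array literal's buffer to a tuple element (illustrative; ShapeUtil::MakeTupleShape is assumed to be available from shape_util.h):

  auto src = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  Literal dest(ShapeUtil::MakeTupleShape({src->shape(), src->shape()}));
  CHECK(dest.MoveFrom(std::move(*src), /*dest_shape_index=*/{0}).ok());
  // '*src' is now nil-shaped; its buffer is owned by element {0} of 'dest'.
  // Element {1} of 'dest' still holds uninitialized data.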
+
   // Copies the values from src_literal, starting at src_base shape indexes,
   // to this literal, starting at dest_base, where the copy size in each
   // dimension is specified by copy_size.
@@ -266,10 +199,24 @@ class Literal {
   // Note: if either src_literal or this literal contains dimensions with zero
   // elements, then copy_size must be 0 in these dimensions and the
   // corresponding base indices must be 0.
-  Status Copy(const Literal& src_literal,
-              tensorflow::gtl::ArraySlice<int64> src_base,
-              tensorflow::gtl::ArraySlice<int64> dest_base,
-              tensorflow::gtl::ArraySlice<int64> copy_size);
+  // This literal and 'src_literal' must be arrays.
+  Status CopySliceFrom(const Literal& src_literal,
+                       tensorflow::gtl::ArraySlice<int64> src_base,
+                       tensorflow::gtl::ArraySlice<int64> dest_base,
+                       tensorflow::gtl::ArraySlice<int64> copy_size);
+
+  // Returns a vector containing the tuple elements of this Literal as separate
+  // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
+  // elements are moved into the new Literals; no data is copied. Upon return
+  // this Literal is set to a nil shape (empty tuple).
+  std::vector<Literal> DecomposeTuple();
+
+  // This operation is the inverse of DecomposeTuple. The given elements are
+  // moved into the tuple elements of a new tuple-shaped Literal which is
+  // returned. Upon return, each of the Literals in 'elements' is set to a nil
+  // shape (empty tuple).
+  static Literal MoveIntoTuple(
+      tensorflow::gtl::MutableArraySlice<Literal> elements);
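A sketch of the copy-free round trip between a tuple literal and its elements (illustrative; 'tuple' stands for an existing tuple-shaped Literal):

  std::vector<Literal> elements = tuple.DecomposeTuple();  // 'tuple' becomes nil-shaped
  // ... mutate the element Literals in place ...
  Literal rebuilt = Literal::MoveIntoTuple(
      tensorflow::gtl::MutableArraySlice<Literal>(elements.data(), elements.size()));
  // Every entry of 'elements' is now nil-shaped; 'rebuilt' owns all the buffers.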
 
   // Creates a new value that has the equivalent value as this literal, but
   // conforms to new_layout; e.g. a literal matrix that was in {0, 1}
@@ -293,6 +240,7 @@ class Literal {
   // Creates a new literal by reshaping this literal to have the given
   // dimensions. The total number of elements must not change; The
   // implementation currently only supports monotonic dim0-major layouts.
+  // This literal must be an array.
   StatusOr<std::unique_ptr<Literal>> Reshape(
       tensorflow::gtl::ArraySlice<int64> dimensions) const;
 
@@ -302,6 +250,7 @@ class Literal {
   // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
   // For example, a transpose call on a literal of shape [3 x 8 x 4] and
   // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
+  // This literal must be an array.
   std::unique_ptr<Literal> Transpose(
       tensorflow::gtl::ArraySlice<int64> permutation) const;
 
@@ -310,6 +259,7 @@ class Literal {
   // same rank and layout as for the given literal. The number of indices in
   // start_indices and limit_indices must be the rank of the literal, and the
   // indices follow the order of the dimensions.
+  // This literal must be an array.
   std::unique_ptr<Literal> Slice(
       tensorflow::gtl::ArraySlice<int64> start_indices,
       tensorflow::gtl::ArraySlice<int64> limit_indices) const;
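A combined sketch of the array-only operations above (illustrative; the default dim0-major layout produced by CreateR2 is assumed):

  auto literal = Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});  // f32[2,3]
  auto transposed = literal->Transpose({1, 0});                                 // f32[3,2]
  auto sliced = literal->Slice({0, 0}, {2, 2});                                 // f32[2,2]
  auto reshaped = literal->Reshape({3, 2}).ValueOrDie();                        // f32[3,2]
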
@@ -317,25 +267,26 @@ class Literal {
   // Creates a literal with a prepended dimension with bound "times"; e.g. a
   // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
   // literal replicated four times.
+  // This literal must be an array.
   template <typename NativeT>
   std::unique_ptr<Literal> Replicate(int64 times) const;
 
   // Converts this literal to another primitive type. Returns an error if the
-  // conversion is not possible.
+  // conversion is not possible. This literal must be array-shaped.
   StatusOr<std::unique_ptr<Literal>> Convert(
       PrimitiveType primitive_dest_type) const;
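For example (illustrative sketch):

  auto s32_literal = Literal::CreateR1<int32>({1, 2, 3});
  StatusOr<std::unique_ptr<Literal>> f32_literal = s32_literal->Convert(F32);
  CHECK(f32_literal.ok());  // an f32[3] literal holding {1.0, 2.0, 3.0}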
 
-  // Creates a literal value zero of the given primitive type.
+  // Creates a scalar literal value zero of the given primitive type.
   static Literal Zero(PrimitiveType primitive_type);
 
-  // Creates a literal value one of the given primitive type.
+  // Creates a scalar literal value one of the given primitive type.
   static Literal One(PrimitiveType primitive_type);
 
-  // Creates a literal value containing the minimum value of the given
+  // Creates a scalar literal value containing the minimum value of the given
   // primitive type. For floating-point types, returns -inf.
   static Literal MinValue(PrimitiveType primitive_type);
 
-  // Creates a literal value containing the maximum value of the given
+  // Creates a scalar literal value containing the maximum value of the given
   // primitive type. For floating-point types, returns inf.
   static Literal MaxValue(PrimitiveType primitive_type);
 
@@ -344,7 +295,7 @@ class Literal {
   static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
       tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value);
 
-  // Creates a new literal from an array. The variants not ending with
+  // Creates a new literal from an Array type. The variants not ending with
   // WithLayout use the default XLA layout for the literal's linear
   // representation in memory.
   template <typename NativeT>
@@ -393,35 +344,25 @@ class Literal {
       std::initializer_list<std::initializer_list<NativeT>> values,
       int64 projection_p, int64 projection_z);
 
-  // Clones this literal into an owned unique_ptr version.
+  // Clones this literal into a new Literal, or new std::unique_ptr<Literal>.
+  Literal Clone() const;
   std::unique_ptr<Literal> CloneToUnique() const;
 
-  // Returns the linear index of the given index within this literal's
-  // element_type repeated field.
-  int64 LinearIndex(tensorflow::gtl::ArraySlice<int64> multi_index) const;
-
-  // Gets or sets an element in the literal at the given index. The index is
-  // CHECKed against the dimension sizes.
+  // Gets or sets an element in the literal at the given index. The multi_index
+  // is CHECKed against the dimension sizes.
   template <typename NativeT>
-  NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
+  NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+              const ShapeIndex& shape_index) const;
   template <typename NativeT>
-  void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
+  void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+           const ShapeIndex& shape_index, NativeT value);
 
-  // Returns a (Mutable)ArraySlice view of the array for this literal for the
-  // given NativeT (e.g., float). These functions map native type to XLA
-  // PrimitiveType via template specialization. The unspecialized forms below
-  // aborts to handle the error case where the given native type does not map to
-  // an XLA primitive type.
+  // Overloads of Get and Set for array literals. CHECKs if the literal is not
+  // array-shaped.
   template <typename NativeT>
-  tensorflow::gtl::ArraySlice<NativeT> GetArraySlice() const {
-    static_assert(!std::is_same<NativeT, NativeT>::value,
-                  "Cannot map native type to primitive type.");
-  }
+  NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
   template <typename NativeT>
-  tensorflow::gtl::MutableArraySlice<NativeT> GetMutableArraySlice() {
-    static_assert(!std::is_same<NativeT, NativeT>::value,
-                  "Cannot map native type to primitive type.");
-  }
+  void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
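A sketch of element access on an array literal and inside a tuple subliteral (illustrative; shape_index {1} addresses the second tuple element):

  auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  matrix->Set<float>({1, 0}, 7.0f);       // array overload
  float v = matrix->Get<float>({1, 0});   // 7.0f

  auto scalar = Literal::CreateR0<int32>(42);
  auto tuple = Literal::MakeTuple({matrix.get(), scalar.get()});
  int32 w = tuple->Get<int32>({}, /*shape_index=*/{1});  // 42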
 
   // Returns the element value at index (0, ..., 0), however many zeroes are
   // required for that index.
@@ -430,10 +371,11 @@ class Literal {
 
   // As Get(), but determines the correct type and converts the value
   // into text.
-  string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index) const;
+  string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+                     const ShapeIndex& shape_index = {}) const;
 
   // As Get(), but determines the correct type and converts the value into
-  // int64.
+  // int64.  This literal must be an array.
   StatusOr<int64> GetIntegralAsS64(
       tensorflow::gtl::ArraySlice<int64> multi_index) const;
 
@@ -441,7 +383,8 @@ class Literal {
   template <typename NativeT>
   static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
 
-  // Returns a tuple literal composed of given literals.
+  // Returns a tuple literal composed of the given literals. Data is copied
+  // from the given elements into the returned literal.
   static std::unique_ptr<Literal> MakeTuple(
       tensorflow::gtl::ArraySlice<const Literal*> elements);
 
@@ -455,10 +398,6 @@ class Literal {
   static std::unique_ptr<Literal> MakeTupleOwned(
       std::vector<std::unique_ptr<Literal>> elements);
 
-  // Validates that the data payload of the literal matches the literal shape;
-  // if it does not, an appropriate status is returned.
-  tensorflow::Status ValidateLiteral() const;
-
   // Returns a string representation of the literal value.
   string ToString(bool print_layout = false) const;
 
@@ -477,48 +416,31 @@ class Literal {
                                    NativeT value)>
                     per_cell) const;
 
-  // Templated methods which populate the given repeated field in this literal
-  // with the given value(s). The Shape field of this literal is set
-  // to match the array dimensions and type. Examples:
+  // Populate this literal with the given values. Examples:
   //
   //   // Populate with floats.
   //   Array2D<float> float_values = ...
   //   literal.PopulateR2FromArray2D(values);
   //
   //   // Populate with int32s.
-  //   literal.PopulateR2({{1, 2}, {3, 4}});
+  //   literal.PopulateR2<int32>({{1, 2}, {3, 4}});
   //
-  template <typename NativeT>
-  void PopulateR0(NativeT values);
+  // The shape and element type of this literal must match the given values. For
+  // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
+  // array of S32.
   template <typename NativeT>
   void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
   void PopulateR1(const tensorflow::core::Bitmap& values);
   template <typename NativeT>
   void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
   template <typename NativeT>
-  void PopulateR2WithLayout(
-      std::initializer_list<std::initializer_list<NativeT>> values,
-      const Layout& layout);
-  template <typename NativeT>
   void PopulateFromArray(const Array<NativeT>& values);
   template <typename NativeT>
-  void PopulateFromArrayWithLayout(const Array<NativeT>& values,
-                                   const Layout& layout);
-  template <typename NativeT>
   void PopulateR2FromArray2D(const Array2D<NativeT>& values);
   template <typename NativeT>
-  void PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
-                                       const Layout& layout);
-  template <typename NativeT>
   void PopulateR3FromArray3D(const Array3D<NativeT>& values);
   template <typename NativeT>
-  void PopulateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
-                                       const Layout& layout);
-  template <typename NativeT>
   void PopulateR4FromArray4D(const Array4D<NativeT>& values);
-  template <typename NativeT>
-  void PopulateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
-                                       const Layout& layout);
 
   // Populates literal values by calling the generator function for every cell
   // in this literal object.
@@ -528,29 +450,9 @@ class Literal {
   template <typename NativeT, typename FnType>
   Status Populate(const FnType& generator);
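A sketch of generator-based population (illustrative; the generator receives the multi-index of each cell):

  Literal literal(ShapeUtil::MakeShape(F32, {2, 3}));
  Status populated = literal.Populate<float>(
      [](tensorflow::gtl::ArraySlice<int64> indexes) {
        // Each element encodes its own (row, column) position.
        return static_cast<float>(indexes[0] * 10 + indexes[1]);
      });
  CHECK(populated.ok());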
 
-  // Creates a Literal of the given dimensions with all elements set to the
-  // given value.
+  // Fills this literal with the given value.
   template <typename NativeT>
-  void PopulateWithValue(NativeT value,
-                         tensorflow::gtl::ArraySlice<int64> dimensions);
-
-  // Returns a pointer to the underlying vector corresponding to the Literal's
-  // shape.
-  const void* InternalData() const;
-  void* MutableInternalData();
-
-  // Allocates space in the underlying vector of this literal sufficient to hold
-  // num_elements of this literal's primitive type. Values in the vector are set
-  // to zero. num_elements must equal the number of elements in the literal's
-  // shape.
-  void Reserve(int64 num_elements);
-
-  // Allocates space in the underlying vector of this literal sufficient to hold
-  // num_elements of this literal's primitive type and sets each element in this
-  // literal to the given value. num_elements must equal the number of elements
-  // in this literal's shape.
-  template <typename NativeT>
-  void Resize(int64 num_elements, NativeT value);
+  void PopulateWithValue(NativeT value);
 
   // Returns whether every element in this literal is equal to value.
   //
@@ -560,7 +462,7 @@ class Literal {
   //
   // If value doesn't fit in this literal's type, returns false.  Values of 1/0
   // are considered equal to true/false; other values are not considered equal
-  // to true.
+  // to true. Also, if this literal is not array-shaped, false is returned.
   bool IsAll(int8 value) const;
 
   // Like IsAll(const Literal&, int8), except we check whether the literal is
@@ -571,7 +473,7 @@ class Literal {
   // This casts value to the type of literal, then compares using ==.  The usual
   // admonishments about floating-point equality checks apply.  We expect you to
   // use this to check for values that can be expressed precisely as a float,
-  // e.g. -0.5.
+  // e.g. -0.5.  Also, if this literal is not array-shaped, false is returned.
   bool IsAllFloat(float value) const;
 
   // Like IsAll(const Literal&, int8), except we check whether the literal is
@@ -589,17 +491,25 @@ class Literal {
   // must be an array.
   bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
 
- private:
-  // Copy from a LiteralProto instance.
-  void CopyFromProto(const LiteralProto& literal_proto);
+  // Returns the number of elements in the array at the given shape index in
+  // this literal.
+  int64 element_count(const ShapeIndex& index = {}) const {
+    return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
+  }
+
+ protected:
+  // 'allocate_arrays' indicates whether to allocate memory for the arrays in
+  // the shape. If false, buffer pointers inside the Literal::Pieces are set
+  // to nullptr.
+  Literal(const Shape& shape, bool allocate_arrays);
 
-  // Internal template helper for the Copy() API, matching its arguments one by
-  // one.
-  template <typename T>
-  Status CopyRange(const Literal& src_literal,
-                   tensorflow::gtl::ArraySlice<int64> src_base,
-                   tensorflow::gtl::ArraySlice<int64> dest_base,
-                   tensorflow::gtl::ArraySlice<int64> copy_size);
+  // Internal template helper for Literal::CopySliceFrom(), matching its
+  // arguments one by one.
+  template <typename NativeT>
+  Status CopySliceFromInternal(const Literal& src_literal,
+                               tensorflow::gtl::ArraySlice<int64> src_base,
+                               tensorflow::gtl::ArraySlice<int64> dest_base,
+                               tensorflow::gtl::ArraySlice<int64> copy_size);
 
   // Utility structure which is used to create the optimal configuration for
   // a ShapeUtil::ForEachIndex() scan across two literals.
@@ -624,163 +534,222 @@ class Literal {
     int64 minor_loop_size = 1;
   };
 
-  Shape shape_;
-  std::vector<uint8> u8s_;
-  std::vector<int16> s16s_;
-  std::vector<int32> s32s_;
-  std::vector<int64> s64s_;
-  std::vector<uint16> u16s_;
-  std::vector<uint32> u32s_;
-  std::vector<uint64> u64s_;
-  std::vector<bfloat16> bf16s_;
-  std::vector<half> f16s_;
-  std::vector<float> f32s_;
-  std::vector<double> f64s_;
-  std::vector<complex64> c64s_;
-  std::vector<Literal> tuple_literals_;
-};
-
-std::ostream& operator<<(std::ostream& out, const Literal& literal);
-
-// Declarations of template specializations for GetArraySlice and
-// GetMutableArraySlice. The specializations map native type to XLA primitive
-// type.
-template <>
-tensorflow::gtl::ArraySlice<bool> Literal::GetArraySlice<bool>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<uint8> Literal::GetArraySlice<uint8>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<int8> Literal::GetArraySlice<int8>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<uint16> Literal::GetArraySlice<uint16>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<int16> Literal::GetArraySlice<int16>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<uint32> Literal::GetArraySlice<uint32>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<uint64> Literal::GetArraySlice<uint64>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<int32> Literal::GetArraySlice<int32>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<int64> Literal::GetArraySlice<int64>() const;
-
-template <>
-inline tensorflow::gtl::ArraySlice<float> Literal::GetArraySlice<float>()
-    const {
-  DCHECK(shape().element_type() == F32);
-  return f32s();
-}
-
-template <>
-tensorflow::gtl::ArraySlice<double> Literal::GetArraySlice<double>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
-    const;
-
-template <>
-tensorflow::gtl::MutableArraySlice<bool> Literal::GetMutableArraySlice();
-
-template <>
-tensorflow::gtl::MutableArraySlice<int8> Literal::GetMutableArraySlice();
-
-template <>
-tensorflow::gtl::MutableArraySlice<uint8> Literal::GetMutableArraySlice();
-
-template <>
-tensorflow::gtl::MutableArraySlice<int16> Literal::GetMutableArraySlice();
+  // A data structure representing a subshape at a particular ShapeIndex within
+  // the literal. For array-shaped subshapes, this data structure holds the
+  // pointer to the memory allocated for the array data.
+  class Piece {
+   public:
+    // Return the buffer holding the array data for this piece as an array
+    // slice. This piece must be array-shaped.
+    template <typename NativeT>
+    tensorflow::gtl::ArraySlice<NativeT> data() const;
+    template <typename NativeT>
+    tensorflow::gtl::MutableArraySlice<NativeT> data();
+
+    // Return the buffer holding the array data for this piece as a void*. This
+    // piece must be array-shaped.
+    void* untyped_data();
+    const void* untyped_data() const;
+
+    // Gets or sets an element in the array at the given index. The multi_index
+    // is CHECKed against the dimension sizes of the array.  This piece must be
+    // array-shaped.
+    template <typename NativeT>
+    NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
+    template <typename NativeT>
+    void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);
+
+    // Gets/sets the buffer holding the array data.
+    char* buffer() const { return buffer_; }
+    void set_buffer(char* buffer) { buffer_ = buffer; }
+
+    // Gets or sets the subshape of this piece. This reference points to a
+    // subshape within the shape in the containing Literal (Literal::shape_).
+    const Shape& subshape() const { return *subshape_; }
+    void set_subshape(const Shape* subshape) { subshape_ = subshape; }
+
+    // Returns the size in bytes of the buffer holding the array data.
+    int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
+
+    // Returns the number of elements in this piece's array.
+    int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); }
+
+    // Copy the data from 'src' into this piece's buffer. Shapes of this piece
+    // and src must be compatible.
+    Status CopyFrom(const Piece& src);
+
+    // Returns true if this piece and 'other' contain the same data. This piece
+    // and 'other' must be array-shaped and compatible.
+    bool EqualElements(const Piece& other) const;
+
+    // Writes the shape and data (if array-shaped) into the given proto.
+    void WriteToProto(LiteralProto* proto) const;
+
+    // Copies the data from the given proto into this piece. The shape of this
+    // piece must be equal (not just compatible) to the shape of the proto.
+    Status CopyFromProto(const LiteralProto& proto);
+
+   private:
+    // Recursive helper for EqualElements.
+    template <typename NativeT>
+    bool EqualElementsInternal(const Piece& other,
+                               std::vector<int64>* multi_index) const;
+
+    // For array-shaped pieces, this is the buffer holding the literal data.
+    char* buffer_ = nullptr;
+
+    // The shape of this piece. This points into the shape of the containing
+    // Literal (Literal::shape_).
+    const Shape* subshape_ = nullptr;
+  };
 
-template <>
-tensorflow::gtl::MutableArraySlice<uint16> Literal::GetMutableArraySlice();
+  // Returns the piece at the given ShapeIndex.
+  Piece& piece(const ShapeIndex& shape_index) {
+    return *pieces_.mutable_element(shape_index);
+  }
+  const Piece& piece(const ShapeIndex& shape_index) const {
+    return pieces_.element(shape_index);
+  }
 
-template <>
-tensorflow::gtl::MutableArraySlice<int32> Literal::GetMutableArraySlice();
+  // Returns the piece at the root of the shape (empty ShapeIndex).
+  Piece& root_piece() { return piece({}); }
+  const Piece& root_piece() const { return piece({}); }
 
-template <>
-tensorflow::gtl::MutableArraySlice<uint32> Literal::GetMutableArraySlice();
+  // Deallocate the buffers held by this literal (if the literal owns the
+  // buffers).
+  void DeallocateBuffers();
 
-template <>
-tensorflow::gtl::MutableArraySlice<int64> Literal::GetMutableArraySlice();
+  Shape shape_;
+  ShapeTree<Piece> pieces_;
 
-template <>
-tensorflow::gtl::MutableArraySlice<uint64> Literal::GetMutableArraySlice();
+  // Whether the buffers held in pieces_ are owned by this Literal.
+  bool owns_buffers_;
 
-template <>
-tensorflow::gtl::MutableArraySlice<float> Literal::GetMutableArraySlice();
+  // LiteralView must access and manipulate Pieces of other Literals.
+  friend class LiteralView;
+};
 
-template <>
-tensorflow::gtl::MutableArraySlice<double> Literal::GetMutableArraySlice();
+std::ostream& operator<<(std::ostream& out, const Literal& literal);
 
-template <>
-tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice();
+// A read-only view of a Literal. A LiteralView contains pointers to buffers
+// owned by the viewed Literal.
+//
+// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable
+// and mutable) similar to (Mutable)ArraySlice.
+class LiteralView : public Literal {
+ public:
+  // Create and return a view of the given literal rooted at the given shape
+  // index. A factory is used rather than a public
+  // constructor because only const LiteralViews are supported. It's still
+  // possible to create non-const LiteralViews via the copy constructors, but
+  // the factory method makes it a bit less likely. Implementing literal slices
+  // will fix this undesirable situation (b/71550060).
+  static const LiteralView Create(const Literal& literal,
+                                  const ShapeIndex& view_root = {});
 
-template <>
-tensorflow::gtl::MutableArraySlice<bfloat16> Literal::GetMutableArraySlice();
+  LiteralView(const LiteralView& other);
+  LiteralView& operator=(const LiteralView& other);
 
-template <>
-tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice();
+  virtual ~LiteralView();
 
-template <>
-void Literal::Resize<bool>(int64 num_elements, bool value);
+ private:
+  LiteralView(const Literal& literal, const ShapeIndex& view_root);
 
-template <>
-void Literal::Resize<int8>(int64 num_elements, int8 value);
+  // Helper for the copy constructor and copy assignment operator.
+  void CopyFrom(const LiteralView& other);
+};
 
-template <>
-void Literal::Resize<uint8>(int64 num_elements, uint8 value);
+template <typename NativeT>
+tensorflow::gtl::ArraySlice<NativeT> Literal::Piece::data() const {
+  CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+  CHECK_EQ(subshape().element_type(),
+           primitive_util::NativeToPrimitiveType<NativeT>())
+      << "Attempting to access "
+      << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
+      << " type, but literal element type is "
+      << PrimitiveType_Name(subshape().element_type());
+  return tensorflow::gtl::ArraySlice<NativeT>(
+      reinterpret_cast<const NativeT*>(buffer()),
+      ShapeUtil::ElementsIn(subshape()));
+}
 
-template <>
-void Literal::Resize<int32>(int64 num_elements, int32 value);
+template <typename NativeT>
+tensorflow::gtl::MutableArraySlice<NativeT> Literal::Piece::data() {
+  CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+  CHECK_EQ(subshape().element_type(),
+           primitive_util::NativeToPrimitiveType<NativeT>())
+      << "Attempting to access "
+      << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
+      << " type, but literal element type is "
+      << PrimitiveType_Name(subshape().element_type());
+  return tensorflow::gtl::MutableArraySlice<NativeT>(
+      reinterpret_cast<NativeT*>(buffer()), ShapeUtil::ElementsIn(subshape()));
+}
 
-template <>
-void Literal::Resize<uint32>(int64 num_elements, uint32 value);
+template <typename NativeT>
+NativeT Literal::Piece::Get(
+    tensorflow::gtl::ArraySlice<int64> multi_index) const {
+  return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
+      subshape(), multi_index)];
+}
 
-template <>
-void Literal::Resize<int64>(int64 num_elements, int64 value);
+template <typename NativeT>
+void Literal::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+                         NativeT value) {
+  data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
+      subshape(), multi_index)] = value;
+}
 
-template <>
-void Literal::Resize<uint64>(int64 num_elements, uint64 value);
+template <typename NativeT>
+tensorflow::gtl::ArraySlice<NativeT> Literal::data(
+    const ShapeIndex& shape_index) const {
+  return piece(shape_index).data<NativeT>();
+}
 
-template <>
-void Literal::Resize<float>(int64 num_elements, float value);
+template <typename NativeT>
+tensorflow::gtl::MutableArraySlice<NativeT> Literal::data(
+    const ShapeIndex& shape_index) {
+  return piece(shape_index).data<NativeT>();
+}
 
-template <>
-void Literal::Resize<double>(int64 num_elements, double value);
+template <typename NativeT>
+inline NativeT Literal::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+                            const ShapeIndex& shape_index) const {
+  return piece(shape_index).Get<NativeT>(multi_index);
+}
 
-template <>
-void Literal::Resize<half>(int64 num_elements, half value);
+template <typename NativeT>
+inline NativeT Literal::Get(
+    tensorflow::gtl::ArraySlice<int64> multi_index) const {
+  return root_piece().Get<NativeT>(multi_index);
+}
 
-template <>
-void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value);
+template <typename NativeT>
+inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+                         const ShapeIndex& shape_index, NativeT value) {
+  return piece(shape_index).Set<NativeT>(multi_index, value);
+}
 
-template <>
-void Literal::Resize<complex64>(int64 num_elements, complex64 value);
+template <typename NativeT>
+inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+                         NativeT value) {
+  return root_piece().Set<NativeT>(multi_index, value);
+}
 
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR0(NativeT value) {
-  auto literal = MakeUnique<Literal>();
-  literal->PopulateR0<NativeT>(value);
+  auto literal = MakeUnique<Literal>(ShapeUtil::MakeShape(
+      primitive_util::NativeToPrimitiveType<NativeT>(), {}));
+  literal->Set({}, value);
   return literal;
 }
 
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR1(
     tensorflow::gtl::ArraySlice<NativeT> values) {
-  auto literal = MakeUnique<Literal>();
+  auto literal = MakeUnique<Literal>(
+      ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
+                           {static_cast<int64>(values.size())}));
   literal->PopulateR1(values);
   return literal;
 }
@@ -789,8 +758,12 @@ template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR2WithLayout(
     std::initializer_list<std::initializer_list<NativeT>> values,
     const Layout& layout) {
-  auto literal = MakeUnique<Literal>();
-  literal->PopulateR2WithLayout(values, layout);
+  auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
+      primitive_util::NativeToPrimitiveType<NativeT>(),
+      {static_cast<int64>(values.size()),
+       static_cast<int64>(values.begin()->size())},
+      AsInt64Slice(layout.minor_to_major())));
+  literal->PopulateR2(values);
   return literal;
 }
 
@@ -874,8 +847,10 @@ template <typename NativeT>
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout(
     const Array<NativeT>& values, const Layout& layout) {
-  auto literal = MakeUnique<Literal>();
-  literal->PopulateFromArrayWithLayout(values, layout);
+  auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
+      primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
+      AsInt64Slice(layout.minor_to_major())));
+  literal->PopulateFromArray(values);
   return literal;
 }
 
@@ -975,81 +950,9 @@ template <typename NativeT>
   return CreateFromArrayWithLayout(values, layout);
 }
 
-template <typename NativeT>
-NativeT Literal::Get(tensorflow::gtl::ArraySlice<int64> multi_index) const {
-  int64 linear_index = LinearIndex(multi_index);
-  return GetArraySlice<NativeT>().at(linear_index);
-}
-
 template <typename NativeT>
 NativeT Literal::GetFirstElement() const {
-  return GetArraySlice<NativeT>().at(0);
-}
-
-template <>
-inline uint8 Literal::Get<uint8>(
-    tensorflow::gtl::ArraySlice<int64> multi_index) const {
-  CHECK(shape().element_type() == U8);
-  int64 linear_index = LinearIndex(multi_index);
-  return u8s()[linear_index];
-}
-
-template <>
-inline int8 Literal::Get<int8>(
-    tensorflow::gtl::ArraySlice<int64> multi_index) const {
-  CHECK(shape().element_type() == S8);
-  int64 linear_index = LinearIndex(multi_index);
-  return u8s()[linear_index];
-}
-
-template <>
-inline half Literal::Get<half>(
-    tensorflow::gtl::ArraySlice<int64> multi_index) const {
-  CHECK(shape().element_type() == F16);
-  int64 linear_index = LinearIndex(multi_index);
-  return GetArraySlice<half>()[linear_index];
-}
-
-template <>
-inline bfloat16 Literal::Get<bfloat16>(
-    tensorflow::gtl::ArraySlice<int64> multi_index) const {
-  CHECK(shape().element_type() == BF16);
-  int64 linear_index = LinearIndex(multi_index);
-  return GetArraySlice<bfloat16>()[linear_index];
-}
-
-template <typename NativeT>
-void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
-                  NativeT value) {
-  int64 linear_index = LinearIndex(multi_index);
-  GetMutableArraySlice<NativeT>().at(linear_index) = value;
-}
-
-template <>
-inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
-                         uint8 value) {
-  int64 linear_index = LinearIndex(multi_index);
-  (*mutable_u8s())[linear_index] = value;
-}
-
-template <>
-inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
-                         int8 value) {
-  return Set<uint8>(multi_index, value);
-}
-
-template <>
-inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
-                         int64 value) {
-  int64 linear_index = LinearIndex(multi_index);
-  (*mutable_s64s())[linear_index] = value;
-}
-
-template <>
-/* static */ inline void Literal::Set(
-    tensorflow::gtl::ArraySlice<int64> multi_index, uint64 value) {
-  int64 linear_index = LinearIndex(multi_index);
-  (*mutable_u64s())[linear_index] = value;
+  return data<NativeT>().at(0);
 }
 
 // Returns an identity matrix (rank 2) with the given row and column count.
@@ -1076,51 +979,31 @@ void Literal::EachCell(
   } while (IndexUtil::BumpIndices(shape(), &indices));
 }
 
-template <typename NativeT>
-inline void Literal::PopulateR0(NativeT value) {
-  *mutable_shape() = ShapeUtil::MakeShape(
-      primitive_util::NativeToPrimitiveType<NativeT>(), {});
-  Resize<NativeT>(1, value);
-}
-
 template <typename NativeT>
 inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
-  *mutable_shape() =
-      ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
-                           {static_cast<int64>(values.size())});
-  Reserve(values.size());
+  CHECK(ShapeUtil::IsArray(shape()));
+  CHECK_EQ(ShapeUtil::Rank(shape()), 1);
+  CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
+  CHECK_EQ(shape().element_type(),
+           primitive_util::NativeToPrimitiveType<NativeT>());
   for (int64 i = 0; i < values.size(); ++i) {
     Set({i}, values[i]);
   }
 }
 
-inline void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
-  *mutable_shape() =
-      ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())});
-  Reserve(values.bits());
-  for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
-    Set({i}, values.get(i));
-  }
-}
-
 template <typename NativeT>
-void Literal::PopulateR2WithLayout(
-    std::initializer_list<std::initializer_list<NativeT>> values,
-    const Layout& layout) {
-  *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
-      primitive_util::NativeToPrimitiveType<NativeT>(),
-      {static_cast<int64>(values.size()),
-       static_cast<int64>(values.begin()->size())},
-      LayoutUtil::MinorToMajor(layout));
+void Literal::PopulateR2(
+    std::initializer_list<std::initializer_list<NativeT>> values) {
+  CHECK(ShapeUtil::IsArray(shape()));
+  CHECK_EQ(ShapeUtil::Rank(shape()), 2);
+  CHECK_EQ(shape().element_type(),
+           primitive_util::NativeToPrimitiveType<NativeT>());
 
   const int64 dim0_size = values.size();
   const int64 dim1_size = values.begin()->size();
   CHECK_EQ(dim0_size, shape().dimensions(0));
   CHECK_EQ(dim1_size, shape().dimensions(1));
 
-  const int64 num_elements = dim1_size * dim0_size;
-  Reserve(num_elements);
-
   int64 dim0 = 0;
   for (auto inner_list : values) {
     int64 dim1 = 0;
@@ -1134,57 +1017,28 @@ void Literal::PopulateR2WithLayout(
 }
 
 template <typename NativeT>
-void Literal::PopulateR2(
-    std::initializer_list<std::initializer_list<NativeT>> values) {
-  PopulateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
-}
-
-template <typename NativeT>
-void Literal::PopulateFromArrayWithLayout(const Array<NativeT>& values,
-                                          const Layout& layout) {
-  CHECK_EQ(layout.format(), DENSE);
-  *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
-      primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
-      LayoutUtil::MinorToMajor(layout));
-  Reserve(values.num_elements());
+void Literal::PopulateFromArray(const Array<NativeT>& values) {
+  CHECK(ShapeUtil::IsArray(shape()));
+  CHECK_EQ(shape().element_type(),
+           primitive_util::NativeToPrimitiveType<NativeT>());
+  CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions());
+  for (int dim = 0; dim < values.num_dimensions(); ++dim) {
+    CHECK_EQ(values.dim(dim), shape().dimensions(dim));
+  }
   values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
                      NativeT value) { this->Set(indices, value); });
 }
 
-template <typename NativeT>
-void Literal::PopulateFromArray(const Array<NativeT>& values) {
-  PopulateFromArrayWithLayout(
-      values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
-}
-
-template <typename NativeT>
-void Literal::PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
-                                              const Layout& layout) {
-  PopulateFromArrayWithLayout(values, layout);
-}
-
 template <typename NativeT>
 void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
   PopulateFromArray(values);
 }
 
-template <typename NativeT>
-void Literal::PopulateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
-                                              const Layout& layout) {
-  PopulateFromArrayWithLayout(values, layout);
-}
-
 template <typename NativeT>
 void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
   PopulateFromArray(values);
 }
 
-template <typename NativeT>
-void Literal::PopulateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
-                                              const Layout& layout) {
-  PopulateFromArrayWithLayout(values, layout);
-}
-
 template <typename NativeT>
 void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
   PopulateFromArray(values);
@@ -1194,10 +1048,10 @@ template <typename NativeT, typename FnType>
 Status Literal::Populate(const FnType& generator) {
   const Shape& this_shape = shape();
   const int64 rank = ShapeUtil::Rank(this_shape);
+  TF_RET_CHECK(ShapeUtil::IsArray(this_shape));
   TF_RET_CHECK(this_shape.element_type() ==
                primitive_util::NativeToPrimitiveType<NativeT>());
-  tensorflow::gtl::MutableArraySlice<NativeT> data =
-      GetMutableArraySlice<NativeT>();
+  tensorflow::gtl::MutableArraySlice<NativeT> literal_data = data<NativeT>();
   if (rank > 0) {
     StrideConfig stride_config(this_shape, this_shape,
                                AsInt64Slice(this_shape.dimensions()));
@@ -1206,11 +1060,12 @@ Status Literal::Populate(const FnType& generator) {
         ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
 
     auto init_function = [&](const std::vector<int64>& indexes) {
-      const int64 index = LinearIndex(indexes);
+      const int64 index =
+          IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
       std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
       for (int64 i = 0; i < minor_dimension_size; ++i) {
         minor_scan_indexes[stride_config.minor_dimension] = i;
-        data.at(index + i) = generator(minor_scan_indexes);
+        literal_data.at(index + i) = generator(minor_scan_indexes);
       }
       return true;
     };
@@ -1219,31 +1074,27 @@ Status Literal::Populate(const FnType& generator) {
                             init_function);
   } else {
     // For scalars.
-    data.at(0) = generator({});
+    literal_data.at(0) = generator({});
   }
   return Status::OK();
 }
 
 template <typename NativeT>
-void Literal::PopulateWithValue(NativeT value,
-                                tensorflow::gtl::ArraySlice<int64> dimensions) {
-  *mutable_shape() = ShapeUtil::MakeShape(
-      primitive_util::NativeToPrimitiveType<NativeT>(), dimensions);
-  Resize<NativeT>(ShapeUtil::ElementsIn(shape()), value);
+void Literal::PopulateWithValue(NativeT value) {
+  CHECK(ShapeUtil::IsArray(shape()));
+  CHECK_EQ(shape().element_type(),
+           primitive_util::NativeToPrimitiveType<NativeT>());
+  for (NativeT& element : data<NativeT>()) {
+    element = value;
+  }
 }
 
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateFullWithDescendingLayout(
     tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
-  Shape this_shape = ShapeUtil::MakeShapeWithDescendingLayout(
-      primitive_util::NativeToPrimitiveType<NativeT>(), dimensions);
-  auto literal = MakeUnique<Literal>();
-  *literal->mutable_shape() = this_shape;
-  literal->Reserve(ShapeUtil::ElementsIn(this_shape));
-  std::vector<int64> index(dimensions.size(), 0);
-  do {
-    literal->Set(index, value);
-  } while (IndexUtil::BumpIndices(this_shape, &index));
+  auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
+      primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
+  literal->PopulateWithValue(value);
   return literal;
 }
 
@@ -1254,14 +1105,12 @@ std::unique_ptr<Literal> Literal::Replicate(int64 times) const {
   for (int64 bound : shape().dimensions()) {
     bounds.push_back(bound);
   }
-  auto literal = MakeUnique<Literal>();
-  *literal->mutable_shape() =
-      ShapeUtil::MakeShape(shape().element_type(), bounds);
+  auto literal =
+      MakeUnique<Literal>(ShapeUtil::MakeShape(shape().element_type(), bounds));
   int64 elements = ShapeUtil::ElementsIn(literal->shape());
   if (elements == 0) {
     return literal;
   }
-  literal->Reserve(elements);
 
   DimensionVector output_indices(bounds.size(), 0);
   tensorflow::gtl::ArraySlice<int64> input_indices = output_indices;
index 7ff64c4134155e7fe22ab99584970a7d6d6e8803..4974ead048d7939036cbb4e607c5ec30a817cac8 100644
@@ -31,6 +31,7 @@ namespace xla {
 namespace {
 
 using ::testing::ElementsAre;
+using ::testing::HasSubstr;
 
 class LiteralUtilTest : public ::testing::Test {
  protected:
@@ -293,29 +294,28 @@ TEST_F(LiteralUtilTest, NonScalarEquality) {
   auto matrix_different = Literal::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}});
   auto vector_literal = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
   auto scalar = Literal::CreateR0<float>(1.0);
+  Literal nil(ShapeUtil::MakeNil());
 
   EXPECT_EQ(*matrix, *matrix);
   EXPECT_EQ(*matrix, *matrix_clone);
   EXPECT_NE(*matrix, *matrix_different);
   EXPECT_NE(*matrix, *vector_literal);
   EXPECT_NE(*matrix, *scalar);
+  EXPECT_NE(*matrix, nil);
+  EXPECT_EQ(nil, nil);
 }
 
 TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
   // Test equality with literals which have different layouts.
-  auto colmajor = MakeUnique<Literal>();
-  *colmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2});
-  *colmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
-  colmajor->Reserve(4);
+  auto colmajor =
+      MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
   colmajor->Set<float>({0, 0}, 1.0);
   colmajor->Set<float>({0, 1}, 2.0);
   colmajor->Set<float>({1, 0}, 3.0);
   colmajor->Set<float>({1, 1}, 4.0);
 
-  auto rowmajor = MakeUnique<Literal>();
-  *rowmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2});
-  *rowmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
-  rowmajor->Reserve(4);
+  auto rowmajor =
+      MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
   rowmajor->Set<float>({0, 0}, 1.0);
   rowmajor->Set<float>({0, 1}, 2.0);
   rowmajor->Set<float>({1, 0}, 3.0);
@@ -597,24 +597,26 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
 
 TEST_F(LiteralUtilTest, TestR2LinearLayout) {
   // Test expected memory layout of R2 dim0-minor (column-major) literal.
-  auto mat_dim0minor = Literal::CreateR2WithLayout<int>({{1, 2, 3}, {4, 5, 6}},
-                                                        layout_r2_dim0minor_);
-  EXPECT_EQ(mat_dim0minor->s32s_size(), 6);
-  EXPECT_THAT(mat_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6));
+  auto mat_dim0minor = Literal::CreateR2WithLayout<int32>(
+      {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
+  EXPECT_EQ(mat_dim0minor->element_count(), 6);
+  EXPECT_THAT(mat_dim0minor->data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
 
   // Test expected memory layout when using Relayout to row major.
   auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_);
-  EXPECT_THAT(relaid_mat_to_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6));
+  EXPECT_THAT(relaid_mat_to_dim0major->data<int32>(),
+              ElementsAre(1, 2, 3, 4, 5, 6));
 
   // Test expected memory layout of R2 created with dim0-major (row-major).
-  auto mat_dim0major = Literal::CreateR2WithLayout<int>({{1, 2, 3}, {4, 5, 6}},
-                                                        layout_r2_dim0major_);
-  EXPECT_EQ(mat_dim0major->s32s_size(), 6);
-  EXPECT_THAT(mat_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6));
+  auto mat_dim0major = Literal::CreateR2WithLayout<int32>(
+      {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
+  EXPECT_EQ(mat_dim0major->element_count(), 6);
+  EXPECT_THAT(mat_dim0major->data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
 
   // Test expected memory layout when using Relayout to column major.
   auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_);
-  EXPECT_THAT(relaid_mat_to_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6));
+  EXPECT_THAT(relaid_mat_to_dim0minor->data<int32>(),
+              ElementsAre(1, 4, 2, 5, 3, 6));
 }
 
 TEST_F(LiteralUtilTest, TestR3LinearLayout) {
@@ -634,27 +636,27 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) {
   auto lit_dim0minor =
       Literal::CreateR3FromArray3DWithLayout<int>(arr3d, layout_r3_dim0minor_);
 
-  EXPECT_EQ(lit_dim0minor->s32s_size(), 12);
+  EXPECT_EQ(lit_dim0minor->element_count(), 12);
   std::vector<int> expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
-  EXPECT_THAT(lit_dim0minor->s32s(),
+  EXPECT_THAT(lit_dim0minor->data<int32>(),
               testing::ElementsAreArray(expected_dim0minor));
 
   // Test expected memory layout when using Relayout to row major.
   auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_);
   std::vector<int> expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
-  EXPECT_THAT(relaid_lit_to_dim0major->s32s(),
+  EXPECT_THAT(relaid_lit_to_dim0major->data<int32>(),
               testing::ElementsAreArray(expected_dim0major));
 
   // Test expected memory layout of R3 created with dim0-major (row-major).
   auto lit_dim0major =
       Literal::CreateR3FromArray3DWithLayout<int>(arr3d, layout_r3_dim0major_);
-  EXPECT_EQ(lit_dim0major->s32s_size(), 12);
-  EXPECT_THAT(lit_dim0major->s32s(),
+  EXPECT_EQ(lit_dim0major->element_count(), 12);
+  EXPECT_THAT(lit_dim0major->data<int32>(),
               testing::ElementsAreArray(expected_dim0major));
 
   // Test expected memory layout when using Relayout to column major.
   auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_);
-  EXPECT_THAT(relaid_lit_to_dim0minor->s32s(),
+  EXPECT_THAT(relaid_lit_to_dim0minor->data<int32>(),
               testing::ElementsAreArray(expected_dim0minor));
 }
 
@@ -687,28 +689,28 @@ TEST_F(LiteralUtilTest, SliceR3U32Full) {
 }
 
 TEST_F(LiteralUtilTest, PopulateR1S64) {
-  Literal output;
+  Literal output(ShapeUtil::MakeShape(S64, {1}));
   output.PopulateR1<int64>({77});
   auto expected = Literal::CreateR1<int64>({77});
   EXPECT_EQ(output, *expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateR1U64) {
-  Literal output;
+  Literal output(ShapeUtil::MakeShape(U64, {2}));
   output.PopulateR1<uint64>({{77, 88}});
   auto expected = Literal::CreateR1<uint64>({{77, 88}});
   EXPECT_EQ(output, *expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateR1C64) {
-  Literal output;
+  Literal output(ShapeUtil::MakeShape(C64, {1}));
   output.PopulateR1<complex64>({{77, 88}});
   auto expected = Literal::CreateR1<complex64>({{77, 88}});
   EXPECT_EQ(output, *expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateR2C64) {
-  Literal output;
+  Literal output(ShapeUtil::MakeShape(C64, {2, 2}));
   output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
   auto expected =
       Literal::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
@@ -716,78 +718,78 @@ TEST_F(LiteralUtilTest, PopulateR2C64) {
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
-  Literal output;
+  Literal output(ShapeUtil::MakeShape(BF16, {}));
   bfloat16 h(0.25f);
-  output.PopulateWithValue<bfloat16>(h, {});
+  output.PopulateWithValue<bfloat16>(h);
   auto expected = Literal::CreateR0<bfloat16>(h);
   EXPECT_EQ(output, *expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
-  Literal output;
+  Literal output(ShapeUtil::MakeShape(BF16, {3}));
   bfloat16 h(0.5f);
-  output.PopulateWithValue<bfloat16>(h, {3});
+  output.PopulateWithValue<bfloat16>(h);
   auto expected = Literal::CreateR1<bfloat16>({h, h, h});
   EXPECT_EQ(output, *expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
-  Literal output;
+  Literal output(ShapeUtil::MakeShape(BF16, {2, 2}));
   bfloat16 h(2.0f);
-  output.PopulateWithValue<bfloat16>(h, {2, 2});
+  output.PopulateWithValue<bfloat16>(h);
   auto expected = Literal::CreateR2<bfloat16>({{h, h}, {h, h}});
   EXPECT_EQ(output, *expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
-  Literal output;
-  output.PopulateWithValue<float>(2.5f, {});
+  Literal output(ShapeUtil::MakeShape(F32, {}));
+  output.PopulateWithValue<float>(2.5f);
   auto expected = Literal::CreateR0<float>(2.5f);
   EXPECT_EQ(output, *expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
-  Literal output;
-  output.PopulateWithValue<int64>(-7, {3});
+  Literal output(ShapeUtil::MakeShape(S64, {3}));
+  output.PopulateWithValue<int64>(-7);
   auto expected = Literal::CreateR1<int64>({-7, -7, -7});
   EXPECT_EQ(output, *expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
-  Literal output;
-  output.PopulateWithValue<uint64>(42, {2, 2});
+  Literal output(ShapeUtil::MakeShape(U64, {2, 2}));
+  output.PopulateWithValue<uint64>(42);
   auto expected = Literal::CreateR2<uint64>({{42, 42}, {42, 42}});
   EXPECT_EQ(output, *expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
-  Literal output;
-  output.PopulateWithValue<complex64>({4, 2}, {2, 2});
+  Literal output(ShapeUtil::MakeShape(C64, {2, 2}));
+  output.PopulateWithValue<complex64>({4, 2});
   auto expected =
       Literal::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
   EXPECT_EQ(output, *expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
-  Literal output;
+  Literal output(ShapeUtil::MakeShape(F16, {}));
   half h(0.25f);
-  output.PopulateWithValue<half>(h, {});
+  output.PopulateWithValue<half>(h);
   auto expected = Literal::CreateR0<half>(h);
   EXPECT_EQ(output, *expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
-  Literal output;
+  Literal output(ShapeUtil::MakeShape(F16, {3}));
   half h(0.5f);
-  output.PopulateWithValue<half>(h, {3});
+  output.PopulateWithValue<half>(h);
   auto expected = Literal::CreateR1<half>({h, h, h});
   EXPECT_EQ(output, *expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
-  Literal output;
+  Literal output(ShapeUtil::MakeShape(F16, {2, 2}));
   half h(2.0f);
-  output.PopulateWithValue<half>(h, {2, 2});
+  output.PopulateWithValue<half>(h);
   auto expected = Literal::CreateR2<half>({{h, h}, {h, h}});
   EXPECT_EQ(output, *expected);
 }
@@ -803,7 +805,7 @@ TEST_F(LiteralUtilTest, ReplicateR2U32) {
   EXPECT_EQ(*output, *expected);
 }
 
-TEST_F(LiteralUtilTest, Copy) {
+TEST_F(LiteralUtilTest, CopySliceFrom) {
   const int64 dimensions[] = {17, 15, 34, 21};
   const int64 layouts[][4] = {
       {3, 2, 1, 0}, {0, 2, 1, 3}, {0, 1, 2, 3}, {2, 0, 3, 1}, {1, 3, 0, 2}};
@@ -826,7 +828,7 @@ TEST_F(LiteralUtilTest, Copy) {
     const int64 src_base[] = {3, 1, 5, 7};
     const int64 dest_base[] = {6, 4, 12, 2};
     const int64 copy_size[] = {7, 8, 11, 9};
-    TF_EXPECT_OK(blank->Copy(*source, src_base, dest_base, copy_size));
+    TF_EXPECT_OK(blank->CopySliceFrom(*source, src_base, dest_base, copy_size));
 
     std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
     std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
@@ -849,16 +851,16 @@ TEST_F(LiteralUtilTest, Copy) {
   }
 }
 
-TEST_F(LiteralUtilTest, CopyScalars) {
+TEST_F(LiteralUtilTest, CopyFromScalars) {
   auto zero = Literal::CreateR0<uint32>(0);
   auto nine = Literal::CreateR0<uint32>(9);
-  TF_EXPECT_OK(zero->Copy(*nine, {}, {}, {}));
+  TF_EXPECT_OK(zero->CopyFrom(*nine));
   EXPECT_EQ(*zero, *nine);
 
   auto vect = Literal::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
-  TF_EXPECT_OK(zero->Copy(*vect, {5}, {}, {}));
+  TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {}));
   EXPECT_EQ(zero->Get<uint32>({}), 17);
-  TF_EXPECT_OK(vect->Copy(*zero, {}, {4}, {}));
+  TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {}));
   EXPECT_EQ(vect->Get<uint32>({4}), 17);
 }
 
@@ -872,7 +874,7 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
     const auto empty = Literal::CreateFromShape(empty_r1_shape);
     auto nine = Literal::CreateR1<float>({9});
 
-    TF_EXPECT_OK(nine->Copy(*empty, {0}, {0}, {0}));
+    TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0}));
     EXPECT_EQ(*nine, *const_nine);
   }
 
@@ -881,18 +883,101 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
     const auto empty = Literal::CreateFromShape(empty_r1_shape);
     auto nine = Literal::CreateR1<float>({9});
 
-    TF_EXPECT_OK(empty->Copy(*nine, {0}, {0}, {0}));
+    TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0}));
     EXPECT_EQ(*empty, *const_empty);
   }
 }
 
+TEST_F(LiteralUtilTest, CopyFromNilShape) {
+  Literal nil_literal0(ShapeUtil::MakeNil());
+  Literal nil_literal1(ShapeUtil::MakeNil());
+  // This doesn't actually do any copying, but it should succeed.
+  TF_ASSERT_OK(nil_literal0.CopyFrom(nil_literal1));
+}
+
+TEST_F(LiteralUtilTest, CopyFromArrays) {
+  auto scalar_42 = Literal::CreateR0<float>(42.0);
+  auto scalar_123 = Literal::CreateR0<float>(123.0);
+  EXPECT_NE(*scalar_42, *scalar_123);
+  TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{},
+                                   /*src_shape_index=*/{}));
+  EXPECT_EQ(*scalar_42, *scalar_123);
+  EXPECT_EQ(scalar_42->Get<float>({}), 123.0f);
+
+  auto matrix_1234 = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+  auto matrix_5678 = Literal::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
+  EXPECT_NE(*matrix_1234, *matrix_5678);
+  EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 1.0f);
+  TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{},
+                                     /*src_shape_index=*/{}));
+  EXPECT_EQ(*matrix_1234, *matrix_5678);
+  EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 5.0f);
+}
+
+TEST_F(LiteralUtilTest, CopyFromTuples) {
+  auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+  Literal nil_literal(ShapeUtil::MakeNil());
+  auto nested_tuple = Literal::MakeTuple(
+      {matrix.get(),
+       Literal::MakeTuple({Literal::CreateR0<int32>(42).get(),
+                           Literal::CreateR1<double>({23.0, 44.0}).get(),
+                           &nil_literal})
+           .get()});
+  // Create a tuple with the same shape as the inner tuple of nested_tuple,
+  // but with different values.
+  auto tuple = Literal::MakeTuple({Literal::CreateR0<int32>(-5).get(),
+                                   Literal::CreateR1<double>({2.0, 4.0}).get(),
+                                   &nil_literal});
+
+  EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0}));
+  EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), 42);
+  EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 23.0);
+  EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 44.0);
+
+  // Overwrite the inner tuple element of nested_tuple with the contents of
+  // 'tuple'.
+  TF_ASSERT_OK(nested_tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
+                                      /*src_shape_index=*/{}));
+
+  // The matrix element should be unchanged.
+  EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0}));
+
+  // The tuple element should have been copied from 'tuple'.
+  EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), -5);
+  EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 2.0);
+  EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 4.0);
+}
+
+TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
+  auto tuple = Literal::MakeTuple(
+      {Literal::CreateR0<int32>(-2).get(), Literal::CreateR0<int32>(4).get()});
+
+  EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
+  EXPECT_EQ(tuple->Get<int32>({}, {1}), 4);
+
+  // Copy from one element to the other.
+  TF_ASSERT_OK(tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
+                               /*src_shape_index=*/{0}));
+
+  EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
+  EXPECT_EQ(tuple->Get<int32>({}, {1}), -2);
+}
+
+TEST_F(LiteralUtilTest, CopyFromDifferentShapes) {
+  auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+  auto vector = Literal::CreateR1<float>({5.0, 7.0});
+  Status status = matrix->CopyFrom(*vector);
+  ASSERT_FALSE(status.ok());
+  ASSERT_THAT(status.error_message(),
+              HasSubstr("Destination subshape incompatible"));
+}
+
 TEST_F(LiteralUtilTest, F16) {
  // Verify that the internal data views are consistent and that they
  // are in little-endian format.
  // TODO - modify if we make the data format machine-endianness dependent.
   auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
   Literal* l1 = m1.get();
-  const char* d1 = static_cast<const char*>(l1->InternalData());
+  const char* d1 = reinterpret_cast<const char*>(l1->data<half>().data());
   EXPECT_EQ(d1[0], 0);
   EXPECT_EQ(d1[1], 0);
   EXPECT_EQ(d1[2], 0);
@@ -901,13 +986,12 @@ TEST_F(LiteralUtilTest, F16) {
   EXPECT_EQ(d1[5], 0);
   EXPECT_EQ(d1[6], 0);
   EXPECT_EQ(d1[7], 0);
-  EXPECT_EQ(l1->InternalData(), l1->MutableInternalData());
 
   half h1(1.0f);
   half h2(2.0f);
   auto m2 = Literal::CreateR2<half>({{h1, h2}, {h2, h1}});
   Literal* l2 = m2.get();
-  const char* d2 = static_cast<const char*>(l2->InternalData());
+  const char* d2 = reinterpret_cast<const char*>(l2->data<half>().data());
   EXPECT_EQ(d2[0], 0);
   EXPECT_EQ(d2[1], 0x3C);
   EXPECT_EQ(d2[2], 0);
@@ -916,7 +1000,6 @@ TEST_F(LiteralUtilTest, F16) {
   EXPECT_EQ(d2[5], 0x40);
   EXPECT_EQ(d2[6], 0);
   EXPECT_EQ(d2[7], 0x3C);
-  EXPECT_EQ(l2->InternalData(), l2->MutableInternalData());
 }
 
 TEST_F(LiteralUtilTest, Populate) {
@@ -941,7 +1024,9 @@ TEST_F(LiteralUtilTest, Populate) {
     auto generator = [&](tensorflow::gtl::ArraySlice<int64> indexes) -> uint32 {
      // Offset from the linear index just to avoid R0 literals being
      // initialized with zero.
-      return literal->LinearIndex(indexes) + 17;
+      return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
+                                                           indexes) +
+             17;
     };
     TF_EXPECT_OK(literal->Populate<uint32>(generator));
 
@@ -1118,16 +1203,18 @@ TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
   for (int len = 0; len < 25; ++len) {
     p.mutable_shape()->clear_dimensions();
     p.mutable_shape()->add_dimensions(len);
+    LayoutUtil::SetToDefaultLayout(p.mutable_shape());
     p.clear_preds();
     for (int i = 0; i < len; ++i) {
       p.add_preds((i % 2) == (len % 2));
     }
 
-    Literal literal(p);
-    ASSERT_EQ(len, literal.preds_size());
+    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
+                            Literal::CreateFromProto(p));
+    ASSERT_EQ(len, literal->data<bool>().size());
     int i = 0;
-    for (auto it = literal.preds().begin(); it < literal.preds().end(); ++it) {
-      EXPECT_EQ((i % 2) == (len % 2), *it);
+    for (bool value : literal->data<bool>()) {
+      EXPECT_EQ((i % 2) == (len % 2), value);
       ++i;
     }
   }
@@ -1141,8 +1228,7 @@ TEST_F(LiteralUtilTest, ToProto_f16) {
   auto m = Literal::CreateR2<half>({{h1, h2}, {h2, h1}});
   Literal* l = m.get();
   EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape()));
-  EXPECT_EQ(4, l->f16s().size());
-  EXPECT_EQ(4, l->f16s_size());
+  EXPECT_EQ(4, l->data<half>().size());
 
   LiteralProto p = l->ToProto();
   EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape()));
@@ -1168,17 +1254,12 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
   p.mutable_shape()->set_element_type(F16);
   p.mutable_shape()->clear_dimensions();
   p.mutable_shape()->add_dimensions(4);
+  LayoutUtil::SetToDefaultLayout(p.mutable_shape());
   p.clear_f16s();
   p.set_f16s(half_vals, 8);
-
-  Literal literal(p);
-  ASSERT_EQ(4, literal.f16s_size());
-  ASSERT_EQ(h1, literal.f16s(0));
-  ASSERT_EQ(h2, literal.f16s(1));
-  ASSERT_EQ(h2, literal.f16s(2));
-  ASSERT_EQ(h1, literal.f16s(3));
-
-  const std::vector<half>& r = literal.f16s();
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
+                          Literal::CreateFromProto(p));
+  auto r = literal->data<half>();
   ASSERT_EQ(4, r.size());
   ASSERT_EQ(h1, r[0]);
   ASSERT_EQ(h2, r[1]);
@@ -1186,24 +1267,365 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
   ASSERT_EQ(h1, r[3]);
 }
 
-TEST_F(LiteralUtilTest, Subliterals) {
+TEST_F(LiteralUtilTest, LiteralViewTest) {
+  auto scalar = Literal::CreateR0<float>(1.0);
+  auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+  auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
+  auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
+  Literal nil(ShapeUtil::MakeNil());
+
+  EXPECT_EQ(LiteralView::Create(*scalar, {}), *scalar);
+  EXPECT_EQ(LiteralView::Create(*matrix, {}), *matrix);
+  EXPECT_EQ(LiteralView::Create(*tuple, {}), *tuple);
+  EXPECT_EQ(LiteralView::Create(*nested_tuple, {}), *nested_tuple);
+  EXPECT_EQ(LiteralView::Create(nil, {}), nil);
+
+  EXPECT_EQ(LiteralView::Create(*tuple, {0}), *scalar);
+  EXPECT_EQ(LiteralView::Create(*tuple, {1}), *matrix);
+
+  EXPECT_EQ(LiteralView::Create(*nested_tuple, {0}), *tuple);
+  EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 0}), *scalar);
+  EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 1}), *matrix);
+  EXPECT_EQ(LiteralView::Create(*nested_tuple, {1}), *scalar);
+}
+
+TEST_F(LiteralUtilTest, MutatingLiteralView) {
   auto scalar = Literal::CreateR0<float>(1.0);
   auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
   auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
+  // Verify that changing the underlying data beneath the view changes the
+  // data of the view itself.
+  const auto nested_tuple_view = LiteralView::Create(*nested_tuple);
+  EXPECT_EQ(
+      nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
+      1.0f);
+  EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
+                                         /*shape_index=*/{0, 0}),
+            1.0f);
+  nested_tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
+  EXPECT_EQ(
+      nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
+      555.0f);
+  EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
+                                         /*shape_index=*/{0, 0}),
+            555.0f);
+}
+
+TEST_F(LiteralUtilTest, LiteralViewOfALiteralView) {
+  auto scalar = Literal::CreateR0<float>(1.0);
+  auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+  auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
+  auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
+
+  const auto nested_tuple_view = LiteralView::Create(*nested_tuple);
+  const auto tuple_view =
+      LiteralView::Create(nested_tuple_view, /*view_root=*/{0});
+  const auto matrix_view = LiteralView::Create(tuple_view, /*view_root=*/{1});
+  EXPECT_EQ(matrix_view, *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
+}
+
+TEST_F(LiteralUtilTest, LiteralMove) {
+  std::unique_ptr<Literal> matrix =
+      Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+  Literal literal(std::move(*matrix));
+
+  EXPECT_TRUE(
+      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
+  EXPECT_EQ(literal.Get<float>({0, 0}), 1.0);
+  EXPECT_EQ(literal.Get<float>({0, 1}), 2.0);
+  EXPECT_EQ(literal.Get<float>({1, 0}), 3.0);
+  EXPECT_EQ(literal.Get<float>({1, 1}), 4.0);
+}
 
-  EXPECT_EQ(&scalar->GetSubliteral(/*index=*/{}), scalar.get());
-  EXPECT_EQ(&matrix->GetSubliteral(/*index=*/{}), matrix.get());
-  EXPECT_EQ(&tuple->GetSubliteral(/*index=*/{}), tuple.get());
-  EXPECT_EQ(&nested_tuple->GetSubliteral(/*index=*/{}), nested_tuple.get());
+TEST_F(LiteralUtilTest, DecomposeTuple) {
+  Literal nil_literal(ShapeUtil::MakeNil());
+  auto nested_tuple = Literal::MakeTuple(
+      {Literal::CreateR2<int32>({{1, 2}, {3, 4}}).get(),
+       Literal::MakeTuple({Literal::CreateR0<int32>(42).get(),
+                           Literal::CreateR1<double>({23.0, 44.0}).get(),
+                           &nil_literal})
+           .get(),
+       &nil_literal});
+
+  EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple->shape()));
+  std::vector<Literal> elements = nested_tuple->DecomposeTuple();
+  EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple->shape()));
+
+  ASSERT_EQ(elements.size(), 3);
+
+  EXPECT_TRUE(ShapeUtil::Compatible(elements[0].shape(),
+                                    ShapeUtil::MakeShape(S32, {2, 2})));
+  EXPECT_EQ(elements[0].Get<int32>({0, 0}), 1);
+  EXPECT_EQ(elements[0].Get<int32>({0, 1}), 2);
+  EXPECT_EQ(elements[0].Get<int32>({1, 0}), 3);
+  EXPECT_EQ(elements[0].Get<int32>({1, 1}), 4);
+
+  EXPECT_TRUE(ShapeUtil::Compatible(
+      elements[1].shape(),
+      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {}),
+                                 ShapeUtil::MakeShape(F64, {2}),
+                                 ShapeUtil::MakeNil()})));
+  EXPECT_EQ(elements[1].Get<int32>({}, /*shape_index=*/{0}), 42);
+  EXPECT_EQ(elements[1].Get<double>({0}, /*shape_index=*/{1}), 23.0);
+  EXPECT_EQ(elements[1].Get<double>({1}, /*shape_index=*/{1}), 44.0);
+
+  EXPECT_TRUE(ShapeUtil::Compatible(elements[2].shape(), ShapeUtil::MakeNil()));
+}
+
+TEST_F(LiteralUtilTest, DecomposeEmptyTuple) {
+  Literal nil_literal(ShapeUtil::MakeNil());
+  std::vector<Literal> elements = nil_literal.DecomposeTuple();
+  EXPECT_EQ(elements.size(), 0);
+}
+
+TEST_F(LiteralUtilTest, MoveIntoTuple) {
+  std::vector<Literal> elements;
+  elements.push_back(std::move(*Literal::CreateR0<float>(1.0)));
+  elements.push_back(std::move(*Literal::CreateR1<int32>({4, 8})));
+  elements.push_back(std::move(
+      *Literal::MakeTuple({Literal::CreateR0<int32>(42).get(),
+                           Literal::CreateR1<double>({23.0, 44.0}).get()})));
+
+  Literal literal = Literal::MoveIntoTuple(&elements);
+  ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape()));
+  ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3);
+
+  EXPECT_EQ(literal.Get<float>({}, /*shape_index=*/{0}), 1.0);
+  EXPECT_EQ(literal.Get<int32>({0}, /*shape_index=*/{1}), 4);
+  EXPECT_EQ(literal.Get<int32>({1}, /*shape_index=*/{1}), 8);
+  EXPECT_EQ(literal.Get<int32>({}, /*shape_index=*/{2, 0}), 42);
+  EXPECT_EQ(literal.Get<double>({0}, /*shape_index=*/{2, 1}), 23.0);
+  EXPECT_EQ(literal.Get<double>({1}, /*shape_index=*/{2, 1}), 44.0);
+
+  for (const Literal& element : elements) {
+    EXPECT_TRUE(ShapeUtil::IsNil(element.shape()));
+  }
+}
+
+TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) {
+  Literal literal = Literal::MoveIntoTuple({});
+  ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape()));
+  ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0);
+}
+
+TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
+  Literal literal;
+  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape()));
+
+  std::unique_ptr<Literal> matrix =
+      Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+  literal = std::move(*matrix);
+
+  EXPECT_TRUE(
+      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
+  EXPECT_EQ(literal.Get<float>({0, 0}), 1.0);
+  EXPECT_EQ(literal.Get<float>({0, 1}), 2.0);
+  EXPECT_EQ(literal.Get<float>({1, 0}), 3.0);
+  EXPECT_EQ(literal.Get<float>({1, 1}), 4.0);
+}
+
+TEST_F(LiteralUtilTest, LiteralViewCopy) {
+  std::unique_ptr<Literal> matrix =
+      Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+  const auto matrix_view = LiteralView::Create(*matrix);
+  LiteralView matrix_view_copy(matrix_view);
+
+  EXPECT_EQ(matrix_view_copy.Get<float>({0, 0}), 1.0);
+  EXPECT_EQ(matrix_view_copy.Get<float>({0, 1}), 2.0);
+  EXPECT_EQ(matrix_view_copy.Get<float>({1, 0}), 3.0);
+  EXPECT_EQ(matrix_view_copy.Get<float>({1, 1}), 4.0);
+}
+
+TEST_F(LiteralUtilTest, GetSetTuple) {
+  auto tuple = Literal::MakeTuple(
+      {Literal::CreateR0<float>(42.0).get(),
+       Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get()});
+  EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
+  tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
+  EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
+
+  EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
+            3.0);
+  tuple->Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
+  EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
+            -4.0);
+}
+
+TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
+  // Literals constructed using CreateFromShape should be zero initialized.
+  std::unique_ptr<Literal> scalar_f32 =
+      Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
+  EXPECT_EQ(scalar_f32->Get<float>({}), 0.0);
+  EXPECT_TRUE(scalar_f32->IsAll(0));
+
+  std::unique_ptr<Literal> vector_s32 =
+      Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
+  EXPECT_EQ(vector_s32->Get<int32>({0}), 0);
+  EXPECT_EQ(vector_s32->Get<int32>({1}), 0);
+  EXPECT_EQ(vector_s32->Get<int32>({2}), 0);
+  EXPECT_TRUE(vector_s32->IsAll(0));
+
+  std::unique_ptr<Literal> tuple =
+      Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
+          {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
+           ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})}));
+
+  EXPECT_EQ(tuple->Get<double>({}, {0}), 0.0);
+  EXPECT_EQ(tuple->Get<bool>({0}, {1}), false);
+  EXPECT_EQ(tuple->Get<bool>({1}, {1}), false);
+  EXPECT_EQ(tuple->Get<uint64>({0, 0}, {2}), 0);
+  EXPECT_EQ(tuple->Get<uint64>({1, 0}, {2}), 0);
+  EXPECT_EQ(tuple->Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
+}
+
+TEST_F(LiteralUtilTest, ProtoRoundTrip) {
+  // Test serializing then deserializing a Literal through a proto.
+  auto one_f32 = Literal::CreateR0<float>(1.0);
+  auto two_f32 = Literal::CreateR0<float>(2.0);
+  auto vector_int8 = Literal::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
+  auto vector_c64 = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
+  auto vector_bfloat16 = Literal::CreateR1<bfloat16>(
+      {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
+  auto vector_half =
+      Literal::CreateR1<half>({half{10.0}, half{20.0}, half{-30.0}});
+  auto matrix_pred =
+      Literal::CreateR2<bool>({{true, false, true}, {false, false, true}});
+  auto tuple = Literal::MakeTuple(
+      {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()});
+  Literal nil_literal(ShapeUtil::MakeNil());
+  auto nested_tuple = Literal::MakeTuple(
+      {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal});
+
+  auto to_from_proto = [](const Literal& literal) -> Literal {
+    return std::move(*Literal::CreateFromProto(literal.ToProto()).ValueOrDie());
+  };
+
+  EXPECT_EQ(*one_f32, to_from_proto(*one_f32));
+  EXPECT_EQ(*vector_c64, to_from_proto(*vector_c64));
+  EXPECT_EQ(*vector_bfloat16, to_from_proto(*vector_bfloat16));
+  EXPECT_EQ(*matrix_pred, to_from_proto(*matrix_pred));
+  EXPECT_EQ(*tuple, to_from_proto(*tuple));
+  EXPECT_EQ(*nested_tuple, to_from_proto(*nested_tuple));
+  EXPECT_EQ(nil_literal, to_from_proto(nil_literal));
+
+  EXPECT_NE(*one_f32, *two_f32);
+  EXPECT_NE(*one_f32, to_from_proto(*two_f32));
+}
+
+TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
+  // Proto contains a shape, but no values.
+  LiteralProto proto;
+  *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3});
+  Status status = Literal::CreateFromProto(proto).status();
+  ASSERT_FALSE(status.ok());
+  ASSERT_THAT(status.error_message(),
+              HasSubstr("Expected 3 elements in LiteralProto"));
+}
+
+TEST_F(LiteralUtilTest, InvalidProtoNoShape) {
+  // Proto contains values, but no shape.
+  LiteralProto proto;
+  proto.add_preds(false);
+  proto.add_preds(true);
+  proto.add_preds(false);
+  Status status = Literal::CreateFromProto(proto).status();
+  ASSERT_FALSE(status.ok());
+  ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape"));
+}
 
-  EXPECT_EQ(tuple->GetSubliteral(/*index=*/{0}), *scalar);
-  EXPECT_EQ(tuple->GetSubliteral(/*index=*/{1}), *matrix);
+TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) {
+  // Proto contains values in the wrong container.
+  LiteralProto proto;
+  *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3});
+  proto.add_preds(false);
+  proto.add_preds(true);
+  proto.add_preds(false);
+  Status status = Literal::CreateFromProto(proto).status();
+  ASSERT_FALSE(status.ok());
+  ASSERT_THAT(status.error_message(),
+              HasSubstr("Expected 3 elements in LiteralProto"));
+}
+
+TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) {
+  // Proto contains too few values.
+  LiteralProto proto;
+  *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2});
+  proto.add_f32s(1.0);
+  proto.add_f32s(2.0);
+  proto.add_f32s(3.0);
+  Status status = Literal::CreateFromProto(proto).status();
+  ASSERT_FALSE(status.ok());
+  ASSERT_THAT(status.error_message(),
+              HasSubstr("Expected 84 elements in LiteralProto"));
+}
+
+TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) {
+  // Proto contains too many values.
+  LiteralProto proto;
+  *proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2});
+  proto.add_s32s(42);
+  proto.add_s32s(-10);
+  proto.add_s32s(100);
+  Status status = Literal::CreateFromProto(proto).status();
+  ASSERT_FALSE(status.ok());
+  ASSERT_THAT(status.error_message(),
+              HasSubstr("Expected 2 elements in LiteralProto"));
+}
+
+TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) {
+  // Proto shape missing layout.
+  LiteralProto proto;
+  *proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2});
+  LayoutUtil::ClearLayout(proto.mutable_shape());
+  proto.add_preds(true);
+  proto.add_preds(false);
+  proto.add_preds(true);
+  proto.add_preds(false);
+  Status status = Literal::CreateFromProto(proto).status();
+  ASSERT_FALSE(status.ok());
+  ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout"));
+}
+
+TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) {
+  // Proto has too few tuple elements.
+  LiteralProto proto;
+  *proto.mutable_shape() = ShapeUtil::MakeTupleShape(
+      {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})});
+  LiteralProto* element0 = proto.add_tuple_literals();
+  *element0->mutable_shape() =
+      ShapeUtil::GetTupleElementShape(proto.shape(), 0);
+  element0->add_preds(false);
+  element0->add_preds(true);
+
+  Status status = Literal::CreateFromProto(proto).status();
+  ASSERT_FALSE(status.ok());
+  ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
+}
 
-  EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0}), *tuple);
-  EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0, 0}), *scalar);
-  EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0, 1}), *matrix);
-  EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{1}), *scalar);
+TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
+  // Proto has too many tuple elements.
+  LiteralProto proto;
+  *proto.mutable_shape() = ShapeUtil::MakeTupleShape(
+      {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})});
+  LiteralProto* element0 = proto.add_tuple_literals();
+  *element0->mutable_shape() =
+      ShapeUtil::GetTupleElementShape(proto.shape(), 0);
+  element0->add_preds(false);
+  element0->add_preds(true);
+  LiteralProto* element1 = proto.add_tuple_literals();
+  *element1->mutable_shape() =
+      ShapeUtil::GetTupleElementShape(proto.shape(), 1);
+  element1->add_f32s(42.0);
+  LiteralProto* element2 = proto.add_tuple_literals();
+  *element2->mutable_shape() = ShapeUtil::MakeShape(F32, {});
+  element2->add_f32s(123.0);
+
+  Status status = Literal::CreateFromProto(proto).status();
+  ASSERT_FALSE(status.ok());
+  ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
 }
 
 }  // namespace
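Taken together, the new tests above reduce to a handful of patterns: the shape is fixed when a Literal is constructed, PopulateWithValue no longer takes dimensions, subliterals are read through copy-free LiteralViews, and whole subliterals are overwritten with CopyFrom plus shape indices. A minimal sketch of that usage based only on the calls exercised above; the function and variable names are illustrative, not part of the commit:

    #include "tensorflow/compiler/xla/literal_util.h"
    #include "tensorflow/compiler/xla/shape_util.h"
    #include "tensorflow/core/lib/core/status.h"
    #include "tensorflow/core/platform/logging.h"

    void LiteralApiSketch() {
      // The shape is a construction-time property; PopulateWithValue fills
      // every element of the already-shaped literal.
      xla::Literal values(xla::ShapeUtil::MakeShape(xla::F32, {3}));
      values.PopulateWithValue<float>(2.5f);

      // Tuples are assembled from existing literals; elements are read through
      // read-only, copy-free views rather than GetSubliteral().
      auto scalar = xla::Literal::CreateR0<float>(1.0f);
      auto matrix = xla::Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
      auto tuple = xla::Literal::MakeTuple({scalar.get(), matrix.get()});
      const auto matrix_view = xla::LiteralView::Create(*tuple, /*shape_index=*/{1});
      CHECK_EQ(matrix_view.Get<float>({1, 0}), 3.0f);

      // Whole subliterals are overwritten in place via CopyFrom + shape indices.
      auto new_scalar = xla::Literal::CreateR0<float>(7.0f);
      TF_CHECK_OK(tuple->CopyFrom(*new_scalar, /*dest_shape_index=*/{0},
                                  /*src_shape_index=*/{}));
      CHECK_EQ(tuple->Get<float>({}, /*shape_index=*/{0}), 7.0f);
    }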
index 70e0f5a74711c8ceef1b6d4225141aa1cc9c6219..857aae0a7982a57bb3057a6f267f5f033a0fdde4 100644 (file)
@@ -44,11 +44,11 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
   VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape)
           << " layout: "
           << (layout == nullptr ? "<none>" : layout->ShortDebugString());
-  auto result = MakeUnique<Literal>();
-  *result->mutable_shape() = shape;
+  Shape literal_shape = shape;
   if (layout != nullptr) {
-    TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(*layout, shape));
-    *result->mutable_shape()->mutable_layout() = *layout;
+    TF_RETURN_IF_ERROR(
+        LayoutUtil::ValidateLayoutForShape(*layout, literal_shape));
+    *literal_shape.mutable_layout() = *layout;
   }
 
   if (shape.element_type() != F32) {
@@ -57,10 +57,12 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
         PrimitiveType_Name(shape.element_type()).c_str());
   }
 
+  auto result = MakeUnique<Literal>(literal_shape);
+  result->PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
+
   int64 elements = ShapeUtil::ElementsIn(shape);
-  result->Resize(elements, std::numeric_limits<float>::quiet_NaN());
-  std::vector<float>* field = result->mutable_f32s();
-  char* data = tensorflow::bit_cast<char*>(field->data());
+  tensorflow::gtl::ArraySlice<float> field = result->data<float>();
+  char* data = tensorflow::bit_cast<char*>(field.data());
   uint64 bytes = elements * sizeof(float);
   tensorflow::StringPiece sp;
   auto s = file_->Read(offset_, bytes, &sp, data);
index 8b4779a0cdf6b056ea0703d0dd03f7dd377c1a57..7d1f057101a9240ed03545499dd76c6ca3d74a02 100644 (file)
@@ -311,7 +311,7 @@ tensorflow::ImportNumpy();
   const int size = PySequence_Size($input);
   for (int i = 0; i < size; ++i) {
     PyObject* o = PySequence_GetItem($input, i);
-    temps.push_back(*numpy::XlaLiteralFromPyObject(o));
+    temps.push_back(std::move(*numpy::XlaLiteralFromPyObject(o)));
     Py_DECREF(o);
   }
   $1 = &temps;
index b30bdc3669de3992a08ab70ef49b0aa17cc855f3..d88d78e474c95656edd581065b8642fafae15737 100644 (file)
@@ -225,11 +225,11 @@ Shape XlaShapeFromPyShapeInfo(PyObject* o) {
 
 PyObject* PyObjectFromXlaLiteral(const Literal& literal) {
   if (ShapeUtil::IsTuple(literal.shape())) {
-    const std::vector<Literal>& tuple_literals = literal.tuple_literals();
     int num_elements = ShapeUtil::TupleElementCount(literal.shape());
     PyObject* tuple = PyTuple_New(num_elements);
     for (int i = 0; i < num_elements; i++) {
-      PyTuple_SET_ITEM(tuple, i, PyObjectFromXlaLiteral(tuple_literals[i]));
+      PyTuple_SET_ITEM(
+          tuple, i, PyObjectFromXlaLiteral(LiteralView::Create(literal, {i})));
     }
     return tuple;
   } else {
index 4e6ecbb0e8b58979ec1f1484e722725c391106fb..3f39869765873f31b705535c873087c649ad5876 100644 (file)
@@ -96,14 +96,14 @@ void CopyLiteralToNumpyArray(int np_type, const Literal& literal,
 template <typename NativeT>
 void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) {
   NativeT* source = static_cast<NativeT*>(PyArray_DATA(py_array));
-  auto dest = literal->GetMutableArraySlice<NativeT>();
+  auto dest = literal->data<NativeT>();
   std::copy(source, source + PyArray_SIZE(py_array), dest.data());
 }
 
 template <typename NativeT>
 void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) {
   NativeT* dest = static_cast<NativeT*>(PyArray_DATA(py_array));
-  auto source = literal.GetArraySlice<NativeT>();
+  auto source = literal.data<NativeT>();
   std::copy(source.begin(), source.end(), dest);
 }
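Typed reads and writes now go through a single data<NativeT>() accessor in place of GetArraySlice/GetMutableArraySlice, and raw bytes through untyped_data(). A small sketch of the write path used above, assuming <algorithm> and the XLA literal headers are already included; the variable names are illustrative:

    xla::Literal literal(xla::ShapeUtil::MakeShape(xla::F32, {4}));
    const float source[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    auto dest = literal.data<float>();           // typed view of the backing buffer
    std::copy(source, source + 4, dest.data());  // writes through to the literal
    const void* raw = literal.untyped_data();    // same storage, untyped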
 
index 8d476d771499dd21ded24d2f073b228b193b7382..90a3f0b6748fc00c9cd9226700805bf243a1acdd 100644 (file)
@@ -498,13 +498,14 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation,
   if (ShapeUtil::IsTuple(literal.shape())) {
     std::vector<HloInstruction*> elems;
     elems.reserve(ShapeUtil::TupleElementCount(literal.shape()));
-    for (const Literal& child : literal.tuple_literals()) {
-      elems.push_back(BuildTupleConstant(computation, child));
+    for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) {
+      elems.push_back(
+          BuildTupleConstant(computation, LiteralView::Create(literal, {i})));
     }
     return computation->AddInstruction(HloInstruction::CreateTuple(elems));
   } else {
     return computation->AddInstruction(
-        HloInstruction::CreateConstant(MakeUnique<Literal>(literal)));
+        HloInstruction::CreateConstant(literal.CloneToUnique()));
   }
 }
 
index a5c1f29832c9eb54adbff46b04d7da8e1ff72d44..8e6562c237e310c19249e64009f3ffdd1b7a86e8 100644 (file)
@@ -572,7 +572,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
       if (instruction->opcode() == HloOpcode::kConstant) {
         // Copy the constant out of the ProtocolBuffer so that we can give it a
         // higher alignment.
-        const void* data = instruction->literal().InternalData();
+        const void* data = instruction->literal().untyped_data();
         int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape());
         auto iter = aligned_constants.emplace(
             instruction, xla::MakeUnique<unsigned char[]>(size));
index b53719fcc260d706eab3d7460c42af4a1b5e775f..f5e61aef534da57ce13d3ee9bbeaeaec31f53d2e 100644 (file)
@@ -98,7 +98,7 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor,
 
   if (!ShapeUtil::IsTuple(shape)) {
     int64 size = GetByteSizeRequirement(shape);
-    return TransferBufferToInfeed(executor, size, literal.InternalData());
+    return TransferBufferToInfeed(executor, size, literal.untyped_data());
   }
 
   if (ShapeUtil::IsNestedTuple(shape)) {
@@ -111,20 +111,20 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor,
   // enqueue the resulting destination device addresses with the
   // infeed manager.
   std::vector<cpu::runtime::XfeedBuffer*> buffers;
-  buffers.reserve(literal.tuple_literals_size());
+  buffers.reserve(ShapeUtil::TupleElementCount(shape));
   auto cleanup = tensorflow::gtl::MakeCleanup([&buffers]() {
     for (cpu::runtime::XfeedBuffer* b : buffers) {
       b->Done(Cancelled("Failed to infeed buffer to device."));
     }
   });
 
-  for (const auto& tuple_element : literal.tuple_literals()) {
-    const Shape& tuple_element_shape = tuple_element.shape();
+  for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+    const Shape& tuple_element_shape = ShapeUtil::GetSubshape(shape, {i});
     int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape);
     TF_ASSIGN_OR_RETURN(
         cpu::runtime::XfeedBuffer * buffer,
         TransferBufferToInfeedInternal(executor, tuple_element_size,
-                                       tuple_element.InternalData()));
+                                       literal.untyped_data({i})));
     buffers.push_back(buffer);
   }
 
@@ -187,14 +187,14 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
         literal_shape.element_type(), dimensions));
     TF_ASSIGN_OR_RETURN(Shape received_shape,
                         TransferArrayBufferFromOutfeed(
-                            executor, literal->MutableInternalData(), size));
+                            executor, literal->untyped_data(), size));
     TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape()))
         << "Shape received from outfeed "
         << ShapeUtil::HumanString(received_shape)
         << " did not match the shape that was requested for outfeed: "
         << ShapeUtil::HumanString(literal_shape);
     TF_RET_CHECK(size == GetByteSizeRequirement(received_shape));
-    *literal->mutable_shape() = received_shape;
+    *literal->mutable_shape_do_not_use() = received_shape;
     return Status::OK();
   }
 
@@ -217,7 +217,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
     auto empty = Literal::CreateFromDimensions(
         tuple_element_shape.element_type(), dimensions);
     int64 size = GetByteSizeRequirement(tuple_element_shape);
-    buffer_data.push_back({empty->MutableInternalData(), size});
+    buffer_data.push_back({empty->untyped_data(), size});
     elements.push_back(std::move(empty));
   }
 
@@ -233,7 +233,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
                GetByteSizeRequirement(received_shape));
 
   for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) {
-    *elements[i]->mutable_shape() = received_shape.tuple_shapes(i);
+    *elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i);
   }
   *literal = std::move(*Literal::MakeTupleOwned(std::move(elements)));
   TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape));
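The same rewrite appears in the GPU transfer manager below: tuple elements are walked by index on the shape, and each element's buffer is taken straight from the one literal with untyped_data({i}). A sketch of the loop shape, assuming literal is a non-nested, tuple-shaped xla::Literal in scope and GetByteSizeRequirement is the TransferManager helper used in the surrounding code:

    const xla::Shape& shape = literal.shape();
    for (xla::int64 i = 0; i < xla::ShapeUtil::TupleElementCount(shape); ++i) {
      const xla::Shape& element_shape = xla::ShapeUtil::GetSubshape(shape, {i});
      const void* element_data = literal.untyped_data({i});
      xla::int64 element_size = GetByteSizeRequirement(element_shape);
      // ... enqueue (element_data, element_size) for the infeed transfer ...
    }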
index 7a97021dda82942e2ddc35d25e620b73b1e378bb..7dcc4ca7fa08b478f24065275ffa69725dc51682 100644 (file)
@@ -38,7 +38,7 @@ void ExternalConstantPool::Insert(string name, const Literal& literal,
   CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size
                                 << " bytes with alignment of " << alignment;
 
-  std::memcpy(raw_pointer, literal.InternalData(), literal_size);
+  std::memcpy(raw_pointer, literal.untyped_data(), literal_size);
   entries_.emplace(std::move(name), static_cast<uint8*>(raw_pointer));
 }
 
index 271a856efd66f9f977ac4e201161ba4b505f31e1..78dc0ad4fcd167c93f19d0c2b18ea72d666897ef 100644 (file)
@@ -89,7 +89,7 @@ GenericTransferManager::TransferLiteralFromDevice(
               /*source=*/device_buffer.buffer(index),
               /*size=*/GetByteSizeRequirement(subshape),
               /*destination=*/
-              literal->GetSubliteral(index).MutableInternalData()));
+              literal->untyped_data(index)));
         }
 
         return Status::OK();
@@ -124,17 +124,17 @@ Status GenericTransferManager::TransferLiteralToDevice(
           TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
                        device_memory.size());
           // Element is array-shaped: transfer array data to device buffer.
-          const Literal& subliteral = literal.GetSubliteral(index);
+          const auto subliteral = LiteralView::Create(literal, index);
           std::unique_ptr<Literal> relayed_out_literal;
           const void* source;
           if (LayoutUtil::Equal(device_subshape.layout(),
                                 subliteral.shape().layout())) {
-            source = subliteral.InternalData();
+            source = subliteral.untyped_data();
           } else {
             // Relayout data before transferring.
             relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
                                                       /*shape_index=*/{});
-            source = relayed_out_literal->InternalData();
+            source = relayed_out_literal->untyped_data();
           }
           return TransferBufferToDevice(
               executor,
index ae92daef8882de2e7d64b69f68452061cb5507f2..af9897769fda371e47af06c19abce9a06015e094 100644 (file)
@@ -54,7 +54,7 @@ Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor,
 
   if (!ShapeUtil::IsTuple(shape)) {
     int64 size = GetByteSizeRequirement(shape);
-    return TransferBufferToInfeed(executor, size, literal.InternalData());
+    return TransferBufferToInfeed(executor, size, literal.untyped_data());
   }
 
   if (ShapeUtil::IsNestedTuple(shape)) {
@@ -67,20 +67,21 @@ Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor,
   // enqueue the resulting destination device addresses with the
   // infeed manager.
   std::vector<gpu::InfeedBuffer*> buffers;
-  buffers.reserve(literal.tuple_literals_size());
+  buffers.reserve(ShapeUtil::TupleElementCount(shape));
   auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() {
     for (gpu::InfeedBuffer* b : buffers) {
       b->Done();
     }
   });
 
-  for (const auto& tuple_element : literal.tuple_literals()) {
-    const Shape& tuple_element_shape = tuple_element.shape();
+  for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+    const Shape& tuple_element_shape =
+        ShapeUtil::GetTupleElementShape(shape, i);
     int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape);
     TF_ASSIGN_OR_RETURN(
         gpu::InfeedBuffer * buffer,
         TransferBufferToInfeedInternal(executor, tuple_element_size,
-                                       tuple_element.InternalData()));
+                                       literal.untyped_data({i})));
     buffers.push_back(buffer);
   }
 
index 1aa506a3a9479dbc88ffeff01e6944993692accc..6761fe08b5f3dae4f481945badb60f08c0ec047f 100644 (file)
@@ -1737,7 +1737,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
   const HloInstruction* operand = inst->operand(0);
   CHECK_EQ(HloOpcode::kConstant, operand->opcode());
   return MakeUnique<HostToDeviceCopyThunk>(
-      /*source_address=*/operand->literal().InternalData(),
+      /*source_address=*/operand->literal().untyped_data(),
       /*destination_buffer=*/GetAllocationSlice(*inst),
       /*mem_size=*/
       llvm_ir::ByteSizeOf(operand->shape(),
index ddd75bbfd12d4d15b67a6e37b8fd0d04375d34f2..021e06f32c94a58e55ae1d2baea82886115cc71c 100644 (file)
@@ -261,7 +261,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     parent_->evaluated_[broadcast] =
         Literal::CreateFromShape(broadcast->shape());
     auto output = parent_->evaluated_[broadcast].get();
-    auto operand_to_broadcast =
+    const Literal& operand_to_broadcast =
         parent_->GetEvaluatedLiteralFor(broadcast->operand(0));
     std::vector<int64> broadcast_indices(
         ShapeUtil::Rank(broadcast->operand(0)->shape()), 0);
@@ -836,7 +836,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
         << " but is inferred to be: "
         << ShapeUtil::HumanString(inferred_return_shape);
 
-    auto operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
     auto result = Literal::CreateFromShape(result_shape);
 
     TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
@@ -1079,7 +1079,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
           return scalar;
         }));
 
-    auto evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0));
+    const Literal& evaluated_operand =
+        parent_->GetEvaluatedLiteralFor(pad->operand(0));
 
     std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
                                    0);
@@ -1514,7 +1515,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
         << ShapeUtil::HumanString(inferred_return_shape);
 
     const int64 rank = ShapeUtil::Rank(operand->shape());
-    auto operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+    const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
     auto func = [&](tensorflow::gtl::ArraySlice<int64> out_index) {
       DimensionVector operand_index(rank);
       for (int64 i = 0; i < rank; ++i) {
@@ -1580,8 +1581,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
   StatusOr<std::unique_ptr<Literal>> DynamicSlice(
       const Literal& operand_literal, const Literal& start_indices_literal,
       const Shape& result_shape) {
-    const auto& start_indices_typed =
-        start_indices_literal.GetArraySlice<IndexT>();
+    auto start_indices_typed = start_indices_literal.data<IndexT>();
     std::vector<int64> start(start_indices_typed.begin(),
                              start_indices_typed.end());
 
@@ -1609,12 +1609,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
   StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
       const Literal& operand_literal, const Literal& update_literal,
       const Literal& start_indices_literal) {
-    const auto& start_indices_typed =
-        start_indices_literal.GetArraySlice<IndexT>();
+    auto start_indices_typed = start_indices_literal.data<IndexT>();
     const std::vector<int64> start(start_indices_typed.begin(),
                                    start_indices_typed.end());
 
-    auto result = MakeUnique<Literal>(operand_literal);
+    auto result = operand_literal.CloneToUnique();
     std::vector<int64> result_index(ShapeUtil::Rank(result->shape()), 0);
 
     auto func = [&](const std::vector<int64>& update_index) {
@@ -1772,8 +1771,8 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
 
   TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this));
 
-  return MakeUnique<Literal>(
-      GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()));
+  return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction())
+      .CloneToUnique();
 }
 
 template <typename LiteralPtr>
@@ -1790,8 +1789,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
   }
 
   TF_RETURN_IF_ERROR(computation.Accept(this));
-  return MakeUnique<Literal>(
-      GetEvaluatedLiteralFor(computation.root_instruction()));
+  return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique();
 }
 
 template <typename LiteralPtr>
@@ -1816,14 +1814,14 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
               << input_literal->ToString();
       TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape()));
 
-      evaluated_[operand] = MakeUnique<Literal>(*input_literal);
+      evaluated_[operand] = input_literal->CloneToUnique();
     }
   }
 
   TF_RETURN_IF_ERROR(Preprocess(instruction));
   TF_RETURN_IF_ERROR(instruction->Visit(this));
   TF_RETURN_IF_ERROR(Postprocess(instruction));
-  return MakeUnique<Literal>(GetEvaluatedLiteralFor(instruction));
+  return GetEvaluatedLiteralFor(instruction).CloneToUnique();
 }
 
 StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
@@ -1844,7 +1842,7 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
   TF_RETURN_IF_ERROR(Preprocess(instruction));
   TF_RETURN_IF_ERROR(instruction->Visit(this));
   TF_RETURN_IF_ERROR(Postprocess(instruction));
-  return MakeUnique<Literal>(GetEvaluatedLiteralFor(instruction));
+  return GetEvaluatedLiteralFor(instruction).CloneToUnique();
 }
 
 std::unique_ptr<Literal> HloEvaluator::TryEvaluate(
@@ -1901,7 +1899,7 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
       << ", but input literal shape is: "
       << ShapeUtil::HumanString(input_literal->shape());
 
-  evaluated_[parameter] = MakeUnique<Literal>(*input_literal);
+  evaluated_[parameter] = input_literal->CloneToUnique();
   return Status::OK();
 }
 
@@ -1952,7 +1950,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
 
   for (auto operand : operands) {
     const Shape& operand_shape = operand->shape();
-    TF_RETURN_IF_ERROR(result_literal->Copy(
+    TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
         GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
         AsInt64Slice(operand_shape.dimensions())));
     dest_indices[concat_dim] +=
@@ -2110,16 +2108,17 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
 
   const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
 
-  evaluated_[get_tuple_element] =
-      MakeUnique<Literal>(operand_tuple_literal.tuple_literals(index));
-
-  return Status::OK();
+  evaluated_[get_tuple_element] = MakeUnique<Literal>(
+      ShapeUtil::GetTupleElementShape(operand->shape(), index));
+  return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal,
+                                                 /*dest_shape_index=*/{},
+                                                 /*src_shape_index=*/{index});
 }
 
 Status HloEvaluator::HandleCopy(HloInstruction* copy) {
   TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
 
-  auto result = MakeUnique<Literal>(GetEvaluatedLiteralFor(copy->operand(0)));
+  auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique();
   evaluated_[copy] = std::move(result);
   return Status::OK();
 }
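Since Literal is now movable but not copyable, the evaluator's former MakeUnique<Literal>(x) copies become explicit clones, and temporaries are moved instead. A two-line sketch, where evaluated stands for any const xla::Literal& the evaluator holds:

    std::unique_ptr<xla::Literal> owned = evaluated.CloneToUnique();       // explicit deep copy
    xla::Literal taken = std::move(*xla::Literal::CreateR0<float>(1.0f));  // move, no copy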
index 138fb9a19049b8119d30c2750a6a0727d33e1010..fffaa1708f049a9f7ebbe20803527b08f0f9ebc4 100644 (file)
@@ -101,7 +101,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
 
   instruction->metadata_ = proto.metadata();
   if (proto.has_literal()) {
-    instruction->literal_ = MakeUnique<Literal>(proto.literal());
+    TF_ASSIGN_OR_RETURN(instruction->literal_,
+                        Literal::CreateFromProto(proto.literal()));
   }
   instruction->parameter_number_ = proto.parameter_number();
 
@@ -166,8 +167,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
   auto instruction =
       WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
   instruction->operands_.push_back(operand);
-  instruction->literal_.reset(new Literal);
-  instruction->literal_->append_u8s(tag);
+  instruction->literal_ = Literal::CreateR1U8(tag);
   return instruction;
 }
 
@@ -2308,7 +2308,7 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
 string HloInstruction::TracingTag() const {
   CHECK_EQ(HloOpcode::kTrace, opcode());
   CHECK(literal_ != nullptr);
-  return literal_->u8s_string();
+  return literal_->GetR1U8AsString();
 }
 
 bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
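Deserialization is now a StatusOr-returning factory rather than a Literal(proto) constructor, so call sites thread the status through. A sketch, assuming it runs inside a Status-returning function with the status macros already used in this file, and that proto is a LiteralProto:

    TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
                        xla::Literal::CreateFromProto(proto));
    xla::LiteralProto round_trip = literal->ToProto();  // the serialization side is unchanged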
index 65c0e3c2d988e971d8123db24d4208ddbbef215a..fc848bdb036125e5dadb471be431d3d2523c6770 100644 (file)
@@ -1040,8 +1040,9 @@ std::unique_ptr<ShapedBuffer> CloneShapedBufferOnDevice(
 
 tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
                                              TransferToServerResponse* result) {
-  Literal literal = Literal(arg->literal());
-  const Shape& shape = literal.shape();
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+                      Literal::CreateFromProto(arg->literal()));
+  const Shape& shape = literal->shape();
 
   std::vector<se::StreamExecutor*> replicas;
   if (arg->has_device_handle()) {
@@ -1065,7 +1066,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
     if (executor->device_ordinal() == master_device_ordinal) {
       TF_RETURN_IF_ERROR(
           execute_backend_->transfer_manager()->TransferLiteralToDevice(
-              executor, literal, *shaped_buffer));
+              executor, *literal, *shaped_buffer));
     } else {
      // The replica is not the master. Create a cloned shaped buffer with
       // the replica's device ordinal. This is required because
@@ -1075,7 +1076,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
           CloneShapedBufferOnDevice(*shaped_buffer, executor->device_ordinal());
       TF_RETURN_IF_ERROR(
           execute_backend_->transfer_manager()->TransferLiteralToDevice(
-              executor, literal, *clone));
+              executor, *literal, *clone));
     }
   }
   TF_ASSIGN_OR_RETURN(
@@ -1111,8 +1112,10 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
     executor = replicas[arg->replica_id()];
   }
 
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+                      Literal::CreateFromProto(arg->literal()));
   return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
-      executor, Literal(arg->literal()));
+      executor, *literal);
 }
 
 tensorflow::Status Service::TransferFromOutfeed(
@@ -1239,18 +1242,15 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
                                           /*include_unreachable_instructions=*/
                                           false));
 
-  std::vector<Literal> parameters(arg->parameters_size());
+  std::vector<std::unique_ptr<Literal>> parameters(arg->parameters_size());
   for (int64 i = 0; i < arg->parameters_size(); ++i) {
-    parameters[i] = Literal(arg->parameters(i));
+    TF_ASSIGN_OR_RETURN(parameters[i],
+                        Literal::CreateFromProto(arg->parameters(i)));
   }
-  std::vector<const Literal*> parameter_ptrs;
-  std::transform(parameters.begin(), parameters.end(),
-                 std::back_inserter(parameter_ptrs),
-                 [](const Literal& literal) { return &literal; });
-
   HloEvaluator evaluator;
-  TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate<const Literal*>(
-                                               *module, parameter_ptrs));
+  TF_ASSIGN_OR_RETURN(
+      auto result_literal,
+      evaluator.Evaluate<std::unique_ptr<Literal>>(*module, parameters));
 
    // Since the shape_with_output_layout option in ExecutionOptions has no
    // effect on the Evaluator's results, explicitly relayout them here.
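ComputeConstant now hands the evaluator its owning pointers directly instead of building a parallel vector of raw Literal pointers. A sketch of the call, assuming module is an HloModule built as above, the includes already present in service.cc, and a Status-returning enclosing function:

    std::vector<std::unique_ptr<xla::Literal>> parameters;
    parameters.push_back(xla::Literal::CreateR0<float>(42.0f));
    xla::HloEvaluator evaluator;
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<xla::Literal> result,
        evaluator.Evaluate<std::unique_ptr<xla::Literal>>(*module, parameters));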
index a2411dc93cb3d7a312b4fdf90d43855c313f7a27..4987e03a70b9901ea9ae18623d6add215f46fce3 100644 (file)
@@ -2838,7 +2838,8 @@ void ComputationLowerer::Visit(
       const ConstantRequest& constant_request =
           request.request().constant_request();
       hlo_instruction = add_instruction(HloInstruction::CreateConstant(
-          Literal(constant_request.literal()).CloneToUnique()));
+          Literal::CreateFromProto(constant_request.literal())
+              .ConsumeValueOrDie()));
       break;
     }
 
index 1073cc7700ee4b43c6af6005ed33d4c96845b10a..615a089d125d3ddc7f7a007d10ada3ea373cb4b2 100644 (file)
@@ -236,7 +236,7 @@ static optional<int64> GetLoopTripCount(HloInstruction* while_op) {
       VLOG(2) << "Couldn't evaluate while cond: " << result.status();
       return nullopt;
     }
-    return result.ValueOrDie()->GetArraySlice<bool>() ==
+    return result.ValueOrDie()->data<bool>() ==
            tensorflow::gtl::ArraySlice<bool>{true};
   };
 
index 2c1b1d22ad4b5c9206c79143b3e9897105246de4..3d4080e353e8e618b3b45b2cb1bf8f2e8ea6114b 100644 (file)
@@ -59,6 +59,11 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) {
   return out;
 }
 
+std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) {
+  out << shape_index.ToString();
+  return out;
+}
+
 namespace {
 
 // Recursive helper for comparing the equality of two shapes. Returns true if
@@ -148,7 +153,8 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
 }
 
 /* static */ int64 ShapeUtil::Rank(const Shape& shape) {
-  CHECK(!ShapeUtil::IsTuple(shape)) << "Tuples do not have a rank";
+  CHECK(!ShapeUtil::IsTuple(shape))
+      << "Tuples do not have a rank, shape: " << shape;
   return shape.dimensions_size();
 }
 
@@ -735,7 +741,8 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
                                                  ShapeIndexView index) {
   const Shape* return_shape = &shape;
   for (auto i : index) {
-    CHECK(IsTuple(*return_shape));
+    CHECK(IsTuple(*return_shape))
+        << "Invalid index " << index << " for shape " << shape;
     return_shape = &return_shape->tuple_shapes(i);
   }
   return *return_shape;
@@ -1352,4 +1359,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
   return shape;
 }
 
+std::ostream& operator<<(std::ostream& out, const Shape& shape) {
+  out << ShapeUtil::HumanString(shape);
+  return out;
+}
+
 }  // namespace xla
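With the new stream operators, shapes and shape indices can be dropped straight into CHECK and log messages, which is what the added error text above relies on. A one-line sketch, assuming shape is an xla::Shape in scope:

    CHECK(xla::ShapeUtil::IsTuple(shape)) << "expected a tuple shape, got " << shape;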
index a2043eff1ed64104cfad69b271edbdb8cebdb02b..59bdffee5a8c19f11b4edcc478f62fb9ae75beee 100644 (file)
@@ -134,6 +134,7 @@ class ShapeIndexView {
 };
 
 std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
+std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index);
 
 // Namespaced collection of (static) shape utilities.
 //
@@ -538,6 +539,8 @@ class ShapeUtil {
   TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
 };
 
+std::ostream& operator<<(std::ostream& out, const Shape& shape);
+
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
index 935b94c718c4c032ac13d6792ff48e31977151a0..56fc21d019bb823f8f4631420a15fd607ef46a9a 100644 (file)
@@ -2532,9 +2532,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
   std::iota(r1.begin(), r1.end(), 1.0);
 
   ComputationBuilder builder(client_, TestName());
-  std::unique_ptr<Literal> a_literal = Literal::CreateR4FromArray4D(r4);
-  *a_literal->mutable_shape()->mutable_layout() =
-      LayoutUtil::MakeLayout({0, 1, 2, 3});
+  std::unique_ptr<Literal> a_literal = Literal::CreateR4FromArray4DWithLayout(
+      r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
   auto a = builder.ConstantLiteral(*a_literal);
   auto b = builder.ConstantR1<float>(r1);
   builder.Add(a, b, {1});
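Because the shape, including its layout, is now immutable after construction, layouts are requested up front through the WithLayout factories instead of being patched onto mutable_shape(). A sketch of the pattern used above, assuming r4 is an xla::Array4D<float>:

    std::unique_ptr<xla::Literal> a_literal =
        xla::Literal::CreateR4FromArray4DWithLayout(
            r4, xla::LayoutUtil::MakeLayout({0, 1, 2, 3}));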
index dbbecc58fb65fbf9fcb9229a00fe78229d8fee54..28ab9654997728fbafd6610af840e721e72cce5a 100644 (file)
@@ -62,7 +62,7 @@ class BatchNormalizationTest
         {5.0f, 4.4f},   // p2
     });
     input_array_.FillWithPZ(pz);
-    input_literal_ = *Literal::CreateR4FromArray4D(input_array_);
+    input_literal_ = std::move(*Literal::CreateR4FromArray4D(input_array_));
     CHECK_EQ(kSamples, input_array_.planes());
     CHECK_EQ(kZ, input_array_.depth());
     CHECK_EQ(kY, input_array_.height());
@@ -231,14 +231,14 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
   auto tuple = builder.BatchNormTraining(operand, scale, offset,
                                          /*epsilon=*/0.001, kFeatureIndex);
 
-  auto expected = *Literal::MakeTuple(
+  auto expected = Literal::MakeTuple(
       {Literal::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
                                  {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
            .get(),
        Literal::CreateR1<float>({4, 5}).get(),
        Literal::CreateR1<float>({5, 5}).get()});
 
-  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
+  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
 }
 
 XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) {
@@ -255,14 +255,14 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) {
   auto tuple = builder.BatchNormTraining(operand, scale, offset,
                                          /*epsilon=*/0.001, kFeatureIndex);
 
-  auto expected = *Literal::MakeTuple(
+  auto expected = Literal::MakeTuple(
       {Literal::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
                                  {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
            .get(),
        Literal::CreateR1<float>({4, 5}).get(),
        Literal::CreateR1<float>({5, 5}).get()});
 
-  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
+  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
 }
 
 XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
@@ -286,13 +286,13 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
   auto tuple = builder.BatchNormTraining(h0, h1, h2,
                                          /*epsilon=*/1, kFeatureIndex);
 
-  auto expected = *Literal::MakeTuple(
+  auto expected = Literal::MakeTuple(
       {Literal::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
            .get(),
        Literal::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
        Literal::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
 
-  ComputeAndCompareTuple(&builder, expected,
+  ComputeAndCompareTuple(&builder, *expected,
                          {operand.get(), scale.get(), offset.get()},
                          ErrorSpec(0.1));
 }
@@ -319,13 +319,13 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
   auto tuple = builder.BatchNormTraining(h0, h1, h2,
                                          /*epsilon=*/-100, kFeatureIndex);
 
-  auto expected = *Literal::MakeTuple(
+  auto expected = Literal::MakeTuple(
       {Literal::CreateR3FromArray3D<float>({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
            .get(),
        Literal::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
        Literal::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
 
-  ComputeAndCompareTuple(&builder, expected,
+  ComputeAndCompareTuple(&builder, *expected,
                          {operand.get(), scale.get(), offset.get()},
                          ErrorSpec(0.1));
 }
@@ -349,14 +349,14 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
   builder.BatchNormGrad(operand, scale, mean, var, grad_output,
                         /*epsilon=*/0.0, kFeatureIndex);
 
-  auto expected = *Literal::MakeTuple(
+  auto expected = Literal::MakeTuple(
       {Literal::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
                                  {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
            .get(),
        Literal::CreateR1<float>({0, 0}).get(),
        Literal::CreateR1<float>({16, 20}).get()});
 
-  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
+  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
 }
 
 struct BatchNormTestParam {
@@ -513,9 +513,9 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
   auto offset_activations =
       builder.Parameter(2, offset_literal->shape(), "scale");
 
-  auto expected = *Literal::MakeTuple({expected_normalized.get(),
-                                       Literal::CreateR1<float>(mean).get(),
-                                       Literal::CreateR1<float>(var).get()});
+  auto expected = Literal::MakeTuple({expected_normalized.get(),
+                                      Literal::CreateR1<float>(mean).get(),
+                                      Literal::CreateR1<float>(var).get()});
 
   std::unique_ptr<GlobalData> input_data =
       client_->TransferToServer(*input_literal).ConsumeValueOrDie();
@@ -532,7 +532,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
   // testcase.
   execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
   ComputeAndCompareTuple(
-      &builder, expected,
+      &builder, *expected,
       {input_data.get(), scale_data.get(), offset_data.get()},
       ErrorSpec(0.01, 1));
 }
@@ -819,16 +819,16 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
                                  grad_output_parameter, epsilon, feature_index);
 
   auto expected =
-      *Literal::MakeTuple({expected_grad_activation.get(),
-                           Literal::CreateR1<float>(grad_scale).get(),
-                           Literal::CreateR1<float>(grad_offset).get()});
+      Literal::MakeTuple({expected_grad_activation.get(),
+                          Literal::CreateR1<float>(grad_scale).get(),
+                          Literal::CreateR1<float>(grad_offset).get()});
 
   // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
   // disables constant folding, but we want it enabled for our zero-sized tensor
   // testcase.
   execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
 
-  ComputeAndCompareTuple(&builder, expected,
+  ComputeAndCompareTuple(&builder, *expected,
                          {input_data.get(), scale_data.get(), mean_data.get(),
                           var_data.get(), grad_output_data.get()},
                          ErrorSpec(0.01, 1));
index ac3f3f4c9ddb03d003a44f5abd7a2e26c42f490d..e47fcad475bb176a7b4598daf2c98897eb34182b 100644 (file)
@@ -97,7 +97,7 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
   auto tuple = builder.BatchNormTraining(operand, scale, offset,
                                          /*epsilon=*/0.001, kFeatureIndex);
 
-  auto expected = *Literal::MakeTuple(
+  auto expected = Literal::MakeTuple(
       {Literal::CreateR4<bfloat16>(
            {{{{static_cast<bfloat16>(-1.7f)}, {static_cast<bfloat16>(-2.04f)}},
              {{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.65f)}}},
@@ -111,7 +111,7 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
            {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
            .get()});
 
-  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
+  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01));
 }
 
 XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
@@ -139,7 +139,7 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
   builder.BatchNormGrad(operand, scale, mean, var, grad_output,
                         /*epsilon=*/0.0, kFeatureIndex);
 
-  auto expected = *Literal::MakeTuple(
+  auto expected = Literal::MakeTuple(
       {Literal::CreateR4<bfloat16>(
            {{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}},
              {{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}},
@@ -153,7 +153,7 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
            {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})
            .get()});
 
-  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
+  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01));
 }
 
 }  // namespace
index 0294628a127c9d506e6387d0b80f3da583c5a174..6ebbf7191833ef85ee4a48cc96c0a3be38c71228 100644 (file)
@@ -87,11 +87,11 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
 
   LiteralTestUtil::ExpectNear(
       *Literal::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
-      result->tuple_literals(0), error_spec_);
+      LiteralView::Create(*result, {0}), error_spec_);
 
   LiteralTestUtil::ExpectNear(
       *Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
-      result->tuple_literals(1), error_spec_);
+      LiteralView::Create(*result, {1}), error_spec_);
 }
 
 XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
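
The tuple-element checks above bind the subliteral through LiteralView::Create(literal, shape_index), which yields a view of the element at that shape index, replacing the removed tuple_literals(i) accessor. A minimal sketch of the pattern, with a hypothetical tuple and illustrative values:

  std::unique_ptr<Literal> result = Literal::MakeTuple(
      {Literal::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}}).get(),
       Literal::CreateR1<float>({1.0, 2.0, 3.0}).get()});
  // View of tuple element 0; the comparison reads through the view.
  const auto element0 = LiteralView::Create(*result, {0});
  LiteralTestUtil::ExpectR2Equal<float>({{1.0, 1.0}, {2.0, 2.0}}, element0);
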
index ee8b6dec43b9433f283cbb73345a3cf3a70e64b5..d445ced7b08cf930dea6ebc7a8395eccdfe92d81 100644 (file)
@@ -375,7 +375,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
   VLOG(1) << "expected: " << expected_literal->ToString();
   VLOG(1) << "actual:   " << actual->ToString();
 
-  EXPECT_EQ(expected, actual->u8s_string());
+  EXPECT_EQ(expected, actual->GetR1U8AsString());
 }
 
 void ClientLibraryTestBase::ComputeAndCompareTuple(
index 92c2956f87e4d1f62b4cf6c50ee41db3fcde7e35..045148cdd11da94ae4789a753efca95c6aaa1f27 100644 (file)
@@ -90,9 +90,9 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
       auto result,
       client_->ExecuteAndTransfer(computation, {}, &execution_options));
   LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
-                                        result->tuple_literals(0));
+                                        LiteralView::Create(*result, {0}));
   LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
-                                        result->tuple_literals(1));
+                                        LiteralView::Create(*result, {1}));
 
   EXPECT_TRUE(ShapeUtil::IsTuple(result->shape()));
   EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape()));
index 7a849f2b8dbb2a4ae747c5ada4485ffd8ebeeace..ec2c580670cfac14ba42e8c9a836c86551af4b89 100644 (file)
@@ -149,7 +149,7 @@ TEST_F(ComputeConstantTest, Param) {
     auto computation = b.Add(param, b.ConstantR0<float>(1.5f));
 
     std::vector<Literal> arguments;
-    arguments.emplace_back(*Literal::CreateR0(42.5f));
+    arguments.push_back(std::move(*Literal::CreateR0(42.5f)));
     EXPECT_TRUE(IsConstant(computation, &b, arguments.size()));
 
     auto value =
index 97bd1553664a6c0fcb097b441ec42efb4eaa9cc2..35aa3f6d696297efb7d95d826ed75a504a24529d 100644 (file)
@@ -141,11 +141,12 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
       {5.0f, 4.4f},   // p2
   });
   input_array.FillWithPZ(pz);
-  Literal input_literal = *Literal::CreateR4FromArray4D(input_array);
+  std::unique_ptr<Literal> input_literal =
+      Literal::CreateR4FromArray4D(input_array);
 
   {
     ComputationBuilder builder(client_, TestName());
-    builder.ConstantLiteral(input_literal);
+    builder.ConstantLiteral(*input_literal);
     ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
   }
 
@@ -165,10 +166,10 @@ TEST_F(ConstantsTest, DISABLED_TupleConstant) {
 
   std::unique_ptr<Literal> result = ExecuteAndTransferOrDie(&builder, {});
 
-  LiteralTestUtil::ExpectR2Near<float>({{1.0}, {2.0}},
-                                       result->tuple_literals(0), error_spec_);
-  LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, result->tuple_literals(1),
-                                       error_spec_);
+  LiteralTestUtil::ExpectR2Near<float>(
+      {{1.0}, {2.0}}, LiteralView::Create(*result, {0}), error_spec_);
+  LiteralTestUtil::ExpectR1Near<float>(
+      {2.0, 42.0}, LiteralView::Create(*result, {1}), error_spec_);
 }
 
 }  // namespace
index 2924c08615fa706bb19addf04bf58e1d5dd5a659..a10e17dbf34b3a6fe503f156fab496708b833c07 100644 (file)
@@ -105,8 +105,8 @@ TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) {
   }));
 
   ComputeAndCompare(&builder, conv,
-                    {*Literal::CreateFromArray(input_data),
-                     *Literal::CreateFromArray(filter_data)},
+                    {std::move(*Literal::CreateFromArray(input_data)),
+                     std::move(*Literal::CreateFromArray(filter_data))},
                     error_spec_);
 }
 
@@ -136,8 +136,8 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) {
   }));
   // clang-format on
   ComputeAndCompare(&builder, conv,
-                    {*Literal::CreateFromArray(input_data),
-                     *Literal::CreateFromArray(filter_data)},
+                    {std::move(*Literal::CreateFromArray(input_data)),
+                     std::move(*Literal::CreateFromArray(filter_data))},
                     error_spec_);
 }
 
@@ -167,8 +167,8 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) {
   }));
   // clang-format on
   ComputeAndCompare(&builder, conv,
-                    {*Literal::CreateFromArray(input_data),
-                     *Literal::CreateFromArray(filter_data)},
+                    {std::move(*Literal::CreateFromArray(input_data)),
+                     std::move(*Literal::CreateFromArray(filter_data))},
                     error_spec_);
 }
 
@@ -200,8 +200,8 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) {
   }));
   // clang-format on
   ComputeAndCompare(&builder, conv,
-                    {*Literal::CreateFromArray(input_data),
-                     *Literal::CreateFromArray(filter_data)},
+                    {std::move(*Literal::CreateFromArray(input_data)),
+                     std::move(*Literal::CreateFromArray(filter_data))},
                     error_spec_);
 }
 
@@ -501,10 +501,10 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
   Array2D<float> expected_result(29, 10);
   expected_result.Fill(0);
 
-  ComputeAndCompare(
-      &builder, conv,
-      {*Literal::CreateFromArray(param0), *Literal::CreateFromArray(param1)},
-      error_spec_);
+  ComputeAndCompare(&builder, conv,
+                    {std::move(*Literal::CreateFromArray(param0)),
+                     std::move(*Literal::CreateFromArray(param1))},
+                    error_spec_);
 }
 
 INSTANTIATE_TEST_CASE_P(ConvolveWithAndWithoutCanonicalization_Instantiation,
index d64bf0aa5bd5e9d6213ea07b3da3305a9c621c65..ece7c3b05e7fafa299db7f9cbf50610c8204f95e 100644 (file)
@@ -40,7 +40,7 @@ class CopyOpTest : public HloTestBase {
   void TestCopyOp(const Literal& literal) {
     auto builder = HloComputation::Builder(TestName());
     auto constant = builder.AddInstruction(
-        HloInstruction::CreateConstant(MakeUnique<Literal>(literal)));
+        HloInstruction::CreateConstant(literal.CloneToUnique()));
     builder.AddInstruction(HloInstruction::CreateUnary(
         constant->shape(), HloOpcode::kCopy, constant));
     auto computation = builder.Build();
@@ -132,7 +132,8 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
   std::unique_ptr<Literal> literal =
       Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   // Reverse the minor-to-major order of the literal.
-  Layout* literal_layout = literal->mutable_shape()->mutable_layout();
+  Layout* literal_layout =
+      literal->mutable_shape_do_not_use()->mutable_layout();
   ASSERT_EQ(2, literal_layout->minor_to_major_size());
   literal_layout->mutable_minor_to_major()->SwapElements(0, 1);
 
index fb425fe6f3cfbb35d7824f3dd1b7d3a2f869313f..e5b96c51ce303819e33d67f5f383c119d313bae1 100644 (file)
@@ -101,56 +101,57 @@ namespace xla {
   ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString());
 }
 
+namespace {
+
+// Returns a copy of the given literal with every array of type FromNativeT
+// converted element-wise to type ToNativeT.
+template <typename FromNativeT, typename ToNativeT>
+std::unique_ptr<Literal> ConvertType(const Literal& literal) {
+  // First construct shape of the result.
+  Shape result_shape(literal.shape());
+  ShapeUtil::ForEachMutableSubshape(
+      &result_shape, [](Shape* subshape, const ShapeIndex&) {
+        if (subshape->element_type() ==
+            primitive_util::NativeToPrimitiveType<FromNativeT>()) {
+          subshape->set_element_type(
+              primitive_util::NativeToPrimitiveType<ToNativeT>());
+        }
+      });
+  auto result = MakeUnique<Literal>(result_shape);
+
+  // Then copy over the data from 'literal' converting FromNativeT values to
+  // ToNativeT values as necessary.
+  ShapeUtil::ForEachSubshape(
+      literal.shape(),
+      [&](const Shape& subshape, const ShapeIndex& shape_index) {
+        if (ShapeUtil::IsArray(subshape)) {
+          if (subshape.element_type() ==
+              primitive_util::NativeToPrimitiveType<FromNativeT>()) {
+            auto src = literal.data<FromNativeT>(shape_index);
+            auto dest = result->data<ToNativeT>(shape_index);
+            for (int64 i = 0; i < src.size(); ++i) {
+              dest[i] = static_cast<ToNativeT>(src[i]);
+            }
+          } else {
+            TF_CHECK_OK(result->CopyFrom(literal,
+                                         /*dest_shape_index=*/shape_index,
+                                         /*src_shape_index=*/shape_index));
+          }
+        }
+      });
+  return result;
+}
+
+}  // namespace
+
 /* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertBF16ToF32(
     const Literal& literal) {
-  if (ShapeUtil::IsTuple(literal.shape())) {
-    std::vector<std::unique_ptr<Literal>> converted_elements;
-    for (const auto& element : literal.tuple_literals()) {
-      converted_elements.push_back(ConvertBF16ToF32(element));
-    }
-    return Literal::MakeTupleOwned(std::move(converted_elements));
-  }
-
-  if (literal.shape().element_type() != BF16) {
-    return MakeUnique<Literal>(literal);
-  }
-  Shape converted_shape = literal.shape();
-  converted_shape.set_element_type(F32);
-  auto converted = Literal::CreateFromShape(converted_shape);
-  if (!ShapeUtil::HasZeroElements(converted_shape)) {
-    std::vector<int64> index(converted_shape.dimensions_size(), 0);
-    do {
-      converted->Set<float>(index,
-                            static_cast<float>(literal.Get<bfloat16>(index)));
-    } while (IndexUtil::BumpIndices(converted_shape, &index));
-  }
-  return converted;
+  return ConvertType<bfloat16, float>(literal);
 }
 
 /* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertF32ToBF16(
     const Literal& literal) {
-  if (ShapeUtil::IsTuple(literal.shape())) {
-    std::vector<std::unique_ptr<Literal>> converted_elements;
-    for (const auto& element : literal.tuple_literals()) {
-      converted_elements.push_back(ConvertF32ToBF16(element));
-    }
-    return Literal::MakeTupleOwned(std::move(converted_elements));
-  }
-
-  if (literal.shape().element_type() != F32) {
-    return MakeUnique<Literal>(literal);
-  }
-  Shape converted_shape = literal.shape();
-  converted_shape.set_element_type(BF16);
-  auto converted = Literal::CreateFromShape(converted_shape);
-  if (!ShapeUtil::HasZeroElements(converted_shape)) {
-    std::vector<int64> index(converted_shape.dimensions_size(), 0);
-    do {
-      converted->Set<bfloat16>(
-          index, static_cast<bfloat16>(literal.Get<float>(index)));
-    } while (IndexUtil::BumpIndices(converted_shape, &index));
-  }
-  return converted;
+  return ConvertType<float, bfloat16>(literal);
 }
 
 namespace {
@@ -311,9 +312,10 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
       break;
     case TUPLE: {
       bool tuple_match = true;
-      for (int i = 0; i < actual.tuple_literals_size(); ++i) {
-        auto result =
-            Equal(expected.tuple_literals(i), actual.tuple_literals(i));
+      for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
+        // Create LiteralViews of the expected and actual elements.
+        auto result = Equal(LiteralView::Create(expected, {i}),
+                            LiteralView::Create(actual, {i}));
         tuple_match = tuple_match ? !!result : false;
       }
       match = tuple_match;
@@ -348,11 +350,11 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
   AssertEqualShapes(expected.shape(), actual.shape());
 
   ::testing::AssertionResult err = ::testing::AssertionSuccess();
-  for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) {
+  for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
     SCOPED_TRACE(tensorflow::strings::StrCat(
         "Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape())));
-    const auto& expected_element = expected.tuple_literals(i);
-    const auto& actual_element = actual.tuple_literals(i);
+    const auto expected_element = LiteralView::Create(expected, {i});
+    const auto actual_element = LiteralView::Create(actual, {i});
 
     ::testing::AssertionResult res = [&] {
       if (ShapeUtil::IsTuple(expected_element.shape())) {
@@ -408,10 +410,7 @@ class NearComparator {
     abs_expected_miscompare_sum_ = 0.0;
     max_rel_err_ = 0.0;
     max_abs_err_ = 0.0;
-    *miscompares_.mutable_shape() =
-        ShapeUtil::ChangeElementType(actual.shape(), PRED);
-    miscompares_.mutable_preds()->resize(
-        ShapeUtil::ElementsIn(miscompares_.shape()), false);
+    miscompares_ = Literal(ShapeUtil::ChangeElementType(actual.shape(), PRED));
     multi_index_.resize(expected.shape().dimensions_size(), 0);
 
     switch (expected.shape().element_type()) {
@@ -644,11 +643,11 @@ bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
   AssertEqualShapes(expected.shape(), actual.shape());
 
   ::testing::AssertionResult err = ::testing::AssertionSuccess();
-  for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) {
+  for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
     SCOPED_TRACE(tensorflow::strings::StrCat(
         "Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape())));
-    const auto& expected_element = expected.tuple_literals(i);
-    const auto& actual_element = actual.tuple_literals(i);
+    const auto expected_element = LiteralView::Create(expected, {i});
+    const auto actual_element = LiteralView::Create(actual, {i});
 
     ::testing::AssertionResult res = [&] {
       if (ShapeUtil::IsTuple(expected_element.shape())) {
@@ -714,9 +713,8 @@ bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
   }
   CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
 
-  auto new_literal = MakeUnique<Literal>();
-  *new_literal->mutable_shape() =
-      ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions);
+  auto new_literal = MakeUnique<Literal>(
+      ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
 
   // Create a new shape with the given minor-to-major layout. This shape is used
   // solely for converting linear address to multi-dimensional addresses when
@@ -724,9 +722,6 @@ bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
   Shape shape_with_layout = new_literal->shape();
   *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
 
-  // Allocate space in the new literal.
-  new_literal->Reserve(ShapeUtil::ElementsIn(literal.shape()));
-
   // Copy data into new literal, element-by-element.
   for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
     std::vector<int64> from_multi_index =
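
The new ConvertType<FromNativeT, ToNativeT> helper above is what both conversion entry points delegate to. A short usage sketch of the BF16-to-F32 direction, with illustrative values:

  auto bf16_literal = Literal::CreateR1<bfloat16>(
      {static_cast<bfloat16>(1.0f), static_cast<bfloat16>(2.0f)});
  // Same shape, every BF16 array converted element-wise to F32.
  std::unique_ptr<Literal> f32_literal =
      LiteralTestUtil::ConvertBF16ToF32(*bf16_literal);
  EXPECT_EQ(2.0f, f32_literal->Get<float>({1}));
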
index 2acf27ed390b0732ba40fcf505c746bd7d8b651e..e477784557a3b9340cff644a3695485389d8cc22 100644 (file)
@@ -83,13 +83,14 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
     LiteralProto literal_proto;
     TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
                                             &literal_proto));
-    Literal literal(literal_proto);
+    std::unique_ptr<Literal> literal =
+        Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
     if (result.find("expected") != string::npos) {
-      EXPECT_EQ("2", literal.ToString());
+      EXPECT_EQ("2", literal->ToString());
     } else if (result.find("actual") != string::npos) {
-      EXPECT_EQ("4", literal.ToString());
+      EXPECT_EQ("4", literal->ToString());
     } else if (result.find("miscompares") != string::npos) {
-      EXPECT_EQ("true", literal.ToString());
+      EXPECT_EQ("true", literal->ToString());
     } else {
       FAIL() << "unknown file in temporary directory: " << result;
     }
index e3298e98c67969f97adfdf15d22dfe72592b56aa..b60266426b17cff723b39701cfe532a1364bf1e8 100644 (file)
@@ -217,12 +217,13 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
   EXPECT_EQ(3, ShapeUtil::TupleElementCount(result->on_host_shape()));
 
   std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result);
-  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
-                                        result_literal->tuple_literals(0));
-  LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
-                                        result_literal->tuple_literals(1));
-  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
-                                        result_literal->tuple_literals(2));
+  LiteralTestUtil::ExpectR2Equal<float>(
+      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0}));
+  LiteralTestUtil::ExpectR2Equal<float>(
+      {{10.0f, 20.0f}, {30.0f, 40.0f}},
+      LiteralView::Create(*result_literal, {1}));
+  LiteralTestUtil::ExpectR2Equal<float>(
+      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {2}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
@@ -245,15 +246,17 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
   EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->on_host_shape()));
 
   std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result);
-  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
-                                        result_literal->tuple_literals(1));
-  const Literal& inner_tuple_literal = result_literal->tuple_literals(0);
-  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
-                                        inner_tuple_literal.tuple_literals(0));
-  LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
-                                        inner_tuple_literal.tuple_literals(1));
-  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
-                                        inner_tuple_literal.tuple_literals(2));
+  LiteralTestUtil::ExpectR2Equal<float>(
+      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1}));
+  LiteralTestUtil::ExpectR2Equal<float>(
+      {{1.0f, 2.0f}, {3.0f, 4.0f}},
+      LiteralView::Create(*result_literal, {0, 0}));
+  LiteralTestUtil::ExpectR2Equal<float>(
+      {{10.0f, 20.0f}, {30.0f, 40.0f}},
+      LiteralView::Create(*result_literal, {0, 1}));
+  LiteralTestUtil::ExpectR2Equal<float>(
+      {{1.0f, 2.0f}, {3.0f, 4.0f}},
+      LiteralView::Create(*result_literal, {0, 2}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
@@ -278,10 +281,10 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
       DefaultExecutableRunOptions());
 
   std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result);
-  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
-                                        result_literal->tuple_literals(0));
-  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
-                                        result_literal->tuple_literals(1));
+  LiteralTestUtil::ExpectR2Equal<float>(
+      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0}));
+  LiteralTestUtil::ExpectR2Equal<float>(
+      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
@@ -324,10 +327,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
   EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->on_host_shape()));
 
   std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result);
-  LiteralTestUtil::ExpectR2Equal<float>({{56.0f, 46.0f}, {36.0f, 26.0f}},
-                                        result_literal->tuple_literals(0));
-  LiteralTestUtil::ExpectR1Equal<float>({40.0f, 71.0f, 117.0f},
-                                        result_literal->tuple_literals(1));
+  LiteralTestUtil::ExpectR2Equal<float>(
+      {{56.0f, 46.0f}, {36.0f, 26.0f}},
+      LiteralView::Create(*result_literal, {0}));
+  LiteralTestUtil::ExpectR1Equal<float>(
+      {40.0f, 71.0f, 117.0f}, LiteralView::Create(*result_literal, {1}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
@@ -365,10 +369,10 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
       ExecuteLocallyOrDie(computation, {arg_buffer.get()});
 
   std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result);
-  LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4}},
-                                        result_literal->tuple_literals(0));
-  LiteralTestUtil::ExpectR1Equal<float>({264.0, 73.0, 133.0},
-                                        result_literal->tuple_literals(1));
+  LiteralTestUtil::ExpectR2Equal<float>(
+      {{-1.0, -2.0}, {-3.0, -4}}, LiteralView::Create(*result_literal, {0}));
+  LiteralTestUtil::ExpectR1Equal<float>(
+      {264.0, 73.0, 133.0}, LiteralView::Create(*result_literal, {1}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
@@ -395,18 +399,19 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
   std::unique_ptr<ScopedShapedBuffer> result_0 =
       ExecuteLocallyOrDie(computation, {arg_buffer.get()});
   std::unique_ptr<Literal> result_0_literal = ShapedBufferToLiteral(*result_0);
-  LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4.0}},
-                                        result_0_literal->tuple_literals(0));
-  LiteralTestUtil::ExpectR2Equal<float>({{22.0, 6.0}, {8.0, 10}},
-                                        result_0_literal->tuple_literals(1));
+  LiteralTestUtil::ExpectR2Equal<float>(
+      {{-1.0, -2.0}, {-3.0, -4.0}},
+      LiteralView::Create(*result_0_literal, {0}));
+  LiteralTestUtil::ExpectR2Equal<float>(
+      {{22.0, 6.0}, {8.0, 10}}, LiteralView::Create(*result_0_literal, {1}));
 
   std::unique_ptr<ScopedShapedBuffer> result_1 =
       ExecuteLocallyOrDie(computation, {result_0.get()});
   std::unique_ptr<Literal> result_1_literal = ShapedBufferToLiteral(*result_1);
-  LiteralTestUtil::ExpectR2Equal<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                        result_1_literal->tuple_literals(0));
-  LiteralTestUtil::ExpectR2Equal<float>({{44.0, 12.0}, {16.0, 20}},
-                                        result_1_literal->tuple_literals(1));
+  LiteralTestUtil::ExpectR2Equal<float>(
+      {{1.0, 2.0}, {3.0, 4.0}}, LiteralView::Create(*result_1_literal, {0}));
+  LiteralTestUtil::ExpectR2Equal<float>(
+      {{44.0, 12.0}, {16.0, 20}}, LiteralView::Create(*result_1_literal, {1}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
@@ -455,7 +460,8 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
 
   for (int i = 0; i < kElementCount; ++i) {
     LiteralTestUtil::ExpectR1Near<float>(
-        {2.0f * i, 0.0f}, result_literal->tuple_literals(i), error_spec_);
+        {2.0f * i, 0.0f}, LiteralView::Create(*result_literal, {i}),
+        error_spec_);
   }
 }
 
@@ -512,8 +518,8 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_CPU_PARALLEL(LargeNestedTuple)) {
   for (int i = 0; i < kFanout; ++i) {
     for (int j = 0; j < kFanout; ++j) {
       LiteralTestUtil::ExpectR0Near<float>(
-          i + j + i * kFanout + j,
-          result_literal->tuple_literals(i).tuple_literals(j), error_spec_);
+          i + j + i * kFanout + j, LiteralView::Create(*result_literal, {i, j}),
+          error_spec_);
     }
   }
 }
@@ -554,11 +560,12 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
       ExecuteLocallyOrDie(computation, {arg_buffer.get()});
   std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result);
 
-  const Literal* result_element = result_literal.get();
+  ShapeIndex index;
   for (int i = 0; i < kTupleDepth; ++i) {
-    result_element = &result_element->tuple_literals(0);
+    index.push_back(0);
   }
-  LiteralTestUtil::ExpectR0Equal<float>(165.0, *result_element);
+  LiteralTestUtil::ExpectR0Equal<float>(
+      165.0, LiteralView::Create(*result_literal, index));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
@@ -763,10 +770,10 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) {
   std::unique_ptr<ScopedShapedBuffer> result =
       ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
   std::unique_ptr<Literal> tuple_literal = ShapedBufferToLiteral(*result);
-  LiteralTestUtil::ExpectR1Equal<float>({2.0f, 4.0f, 6.0f},
-                                        tuple_literal->tuple_literals(0));
-  LiteralTestUtil::ExpectR1Equal<float>({1.0f, 2.0f, 3.0f},
-                                        tuple_literal->tuple_literals(1));
+  LiteralTestUtil::ExpectR1Equal<float>(
+      {2.0f, 4.0f, 6.0f}, LiteralView::Create(*tuple_literal, {0}));
+  LiteralTestUtil::ExpectR1Equal<float>(
+      {1.0f, 2.0f, 3.0f}, LiteralView::Create(*tuple_literal, {1}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
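
Nested tuple elements are addressed with a ShapeIndex instead of chained accessors: each entry selects one tuple level, so {0, 1} names element 1 of the tuple stored at element 0. A minimal sketch, assuming result_literal holds a tuple of tuples:

  ShapeIndex index;
  index.push_back(0);  // outer tuple element 0
  index.push_back(1);  // inner tuple element 1
  const auto inner = LiteralView::Create(*result_literal, index);
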
index 62d24a11fdb164ed6776d1e83877cf3acd319cc6..6e6cb7ff1e2ac74dc54f14d8811c9a5d3662bbd2 100644 (file)
@@ -99,11 +99,11 @@ class MultiOutputFusionTest : public HloTestBase {
           nullptr);
     }
 
-    Literal arg1;
-    arg1.PopulateWithValue<float>(2.5f, {size, size});
+    Literal arg1(ShapeUtil::MakeShape(F32, {size, size}));
+    arg1.PopulateWithValue<float>(2.5f);
 
-    Literal expect;
-    expect.PopulateWithValue<float>(size * 1.5f * 3.5f, {size, size});
+    Literal expect(ShapeUtil::MakeShape(F32, {size, size}));
+    expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
     auto actual = ExecuteAndTransfer(
         std::move(hlo_module), {Literal::CreateR0<float>(-9.0f).get(), &arg1});
     LiteralTestUtil::ExpectNear(expect, *actual, error_spec_);
@@ -159,11 +159,12 @@ class MultiOutputFusionTest : public HloTestBase {
                nullptr);
     }
 
-    Literal input0, input1;
-    input0.PopulateWithValue<float>(2.5f, {size});
-    input1.PopulateWithValue<double>(1, {size});
+    Literal input0(ShapeUtil::MakeShape(F32, {size}));
+    input0.PopulateWithValue(2.5f);
+    Literal input1(ShapeUtil::MakeShape(F64, {size}));
+    input1.PopulateWithValue(1.);
 
-    Literal expect = *Literal::CreateR1<float>({size * 1.5f * 3.5f});
+    Literal expect = std::move(*Literal::CreateR1<float>({size * 1.5f * 3.5f}));
     auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
     LiteralTestUtil::ExpectNear(expect, *actual, error_spec_);
   }
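
A Literal is now constructed directly from a Shape and filled afterwards; PopulateWithValue takes only the value, since the element count is fixed by the shape passed to the constructor. A minimal sketch of the construction pattern, with illustrative dimensions:

  Literal ones(ShapeUtil::MakeShape(F32, {4, 4}));
  ones.PopulateWithValue<float>(1.0f);
  Literal zeros(ShapeUtil::MakeShape(F64, {8}));
  zeros.PopulateWithValue(0.0);  // template argument deduced from the value
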
index 98becbbfd7e336a6c0d709531b792cd2970ecd4a..bb7e800df84121f2045141bc366c34b94ba694ea 100644 (file)
@@ -462,10 +462,8 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
 // Verifies that passing a 2x2 with {0, 1} layout returns the same value back
 // when (transferred to the server and) passed through a parameter.
 XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
-  std::unique_ptr<Literal> literal = Literal::CreateR2<float>({
-      {1, 2}, {3, 4},
-  });
-  *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+  std::unique_ptr<Literal> literal = Literal::CreateR2WithLayout<float>(
+      {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1}));
   ComputationBuilder builder(client_, TestName());
   builder.Parameter(0, literal->shape(), "input");
 
@@ -476,10 +474,8 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
 
 // As above, but for {1, 0} layout.
 XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
-  std::unique_ptr<Literal> literal = Literal::CreateR2<float>({
-      {1, 3}, {2, 4},
-  });
-  *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+  std::unique_ptr<Literal> literal = Literal::CreateR2WithLayout<float>(
+      {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0}));
   ComputationBuilder builder(client_, TestName());
   builder.Parameter(0, literal->shape(), "input");
 
@@ -500,7 +496,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
         original.layout().minor_to_major().begin(),
         original.layout().minor_to_major().end());
     std::reverse(original_layout.begin(), original_layout.end());
-    *literal->mutable_shape()->mutable_layout() =
+    *literal->mutable_shape_do_not_use()->mutable_layout() =
         LayoutUtil::MakeLayout(original_layout);
     ASSERT_EQ(2, literal->Get<float>({0, 1}));
   }
index 19857f1e4d1887c477baecf85176086b7f215cc8..6489eee9f34c6c4426d52e166f7b401d5948742f 100644 (file)
@@ -152,10 +152,12 @@ XLA_TEST_F(PrngTest, MapUsingRng) {
                        computation,
                        /*arguments=*/{param0_data.get()}, &execution_options));
 
-  EXPECT_EQ(actual->f32s_size(), param0_literal->f32s_size());
-  for (int i = 0; i < param0_literal->f32s_size(); ++i) {
-    EXPECT_GE(actual->f32s(i), param0_literal->f32s(i));
-    EXPECT_LT(actual->f32s(i), param0_literal->f32s(i) + 1.0f);
+  EXPECT_EQ(ShapeUtil::ElementsIn(actual->shape()),
+            ShapeUtil::ElementsIn(param0_literal->shape()));
+  for (int i = 0; i < ShapeUtil::ElementsIn(actual->shape()); ++i) {
+    EXPECT_GE(actual->data<float>()[i], param0_literal->data<float>()[i]);
+    EXPECT_LT(actual->data<float>()[i],
+              param0_literal->data<float>()[i] + 1.0f);
   }
 }
 
index ddd50d7a5864d73de7916ce736bb7cd40c1c4dc9..f7b04debd4f5c40a904e32c832b6fc384a03c33b 100644 (file)
@@ -566,14 +566,15 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) {
 XLA_TEST_P(ReshapeTest, ToScalar) {
   for (int rank = 0; rank < 8; ++rank) {
     ComputationBuilder b(client_, TestName());
-    auto input_literal = Literal::CreateR1<float>({83.0f});
     std::vector<int64> ones(rank, 1);  // this is {1, ..., 1}.
     std::vector<int64> dimensions(rank);
     std::iota(dimensions.begin(), dimensions.end(), 0);
-    *input_literal->mutable_shape() = ShapeUtil::MakeShape(F32, ones);
+    Literal input_literal(ShapeUtil::MakeShape(F32, ones));
+    std::vector<int64> zeros(rank, 0);  // this is {0, ..., 0}.
+    input_literal.Set<float>(zeros, 83.0f);
 
     ComputationDataHandle parameter;
-    auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+    auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                    &b, &parameter);
     b.Reshape(parameter, dimensions, {});
 
@@ -818,11 +819,9 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
   // data.
   if (use_bfloat16()) {
     auto expected = LiteralTestUtil::ConvertF32ToBF16(*input_literal);
-    EXPECT_EQ(tensorflow::gtl::ArraySlice<bfloat16>(expected->bf16s()),
-              tensorflow::gtl::ArraySlice<bfloat16>(output_literal->bf16s()));
+    EXPECT_EQ(expected->data<bfloat16>(), output_literal->data<bfloat16>());
   } else {
-    EXPECT_EQ(tensorflow::gtl::ArraySlice<float>(input_literal->f32s()),
-              tensorflow::gtl::ArraySlice<float>(output_literal->f32s()));
+    EXPECT_EQ(input_literal->data<float>(), output_literal->data<float>());
   }
 }
 
index ed556fafb17cb2d243141783f822400d3c54b459..268ba338f2e6740a1d1a046d5a85494f3cf2e9f8 100644 (file)
@@ -119,7 +119,7 @@ XLA_TEST_F(TransferManagerTest, TransferR1U8) {
                           transfer_manager_->TransferLiteralFromDevice(
                               stream_executor_, *device_buffer));
 
-  EXPECT_EQ(result->u8s_string(), test_string);
+  EXPECT_EQ(result->GetR1U8AsString(), test_string);
 }
 
 XLA_TEST_F(TransferManagerTest, TransferR2F32) {
index 4d060895d357493327ec50b38016478c65fef94d..6fa4c48e11d1102367b21bc21d4734466495ef0e 100644 (file)
@@ -102,9 +102,9 @@ StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
         ShapeUtil::HumanString(shape).c_str());
   }
 
-  auto result = MakeUnique<Literal>();
+  auto result = MakeUnique<Literal>(shape);
   const float fill = std::numeric_limits<float>::quiet_NaN();
-  result->PopulateWithValue<float>(fill, AsInt64Slice(shape.dimensions()));
+  result->PopulateWithValue<float>(fill);
   std::vector<tensorflow::StringPiece> pieces;
   std::vector<tensorflow::StringPiece> coordinates;
   std::vector<int64> coordinate_values;
index 3caa465769067cdf15d4a3c7a23ad3302f2e2708..2fc369dc0e69ac74e93470cc135b964b8d86306f 100644 (file)
@@ -1324,7 +1324,7 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index,
         PrimitiveType_Name(literal->shape().element_type())));
   }
 
-  literal->GetMutableArraySlice<LiteralNativeT>().at(linear_index) =
+  literal->data<LiteralNativeT>().at(linear_index) =
       static_cast<LiteralNativeT>(value);
   return true;
 }
index a7dc5862057047f7c56faeb211cc0b13992caec7..22e02de5e2f7dd6886acef8181e418aa376aa6fc 100644 (file)
@@ -82,9 +82,10 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
     arguments = MakeFakeArgumentsOrDie(computation, client);
   } else {  // use recorded data if available
     for (const auto& proto : module.arguments()) {
-      Literal literal(proto);
+      TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
+                          Literal::CreateFromProto(proto));
       TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
-                          client->TransferToServer(literal));
+                          client->TransferToServer(*literal));
       arguments.push_back(std::move(data));
     }
   }
@@ -162,9 +163,11 @@ int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
               ShapeUtil::HumanString(result->shape()).c_str(),
               result->ToString().c_str());
       if (module.has_result()) {
+        std::unique_ptr<Literal> literal =
+            Literal::CreateFromProto(module.result()).ConsumeValueOrDie();
         fprintf(stdout, "was %s:%s\n",
                 ShapeUtil::HumanString(module.result().shape()).c_str(),
-                Literal(module.result()).ToString().c_str());
+                literal->ToString().c_str());
       }
     }
   }
index b50cb5e28eac14ed99af566939f8bd64e393ff64..fe8e72ba32bb4493b2751cfdfeb977f271092f9c 100644 (file)
@@ -40,7 +40,8 @@ int main(int argc, char **argv) {
   xla::LiteralProto literal_proto;
   TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1],
                                           &literal_proto));
-  xla::Literal literal(literal_proto);
+  std::unique_ptr<xla::Literal> literal =
+      xla::Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
   LOG(INFO) << "literal: " << literal_proto.ShortDebugString();
-  fprintf(stderr, "%s\n", literal.ToString().c_str());
+  fprintf(stderr, "%s\n", literal->ToString().c_str());
 }
index bbe9902aa17a585c4bad5b732330305dfdd45302..8525873e913185554d18df8c8c3584bfcdcdcabe 100644 (file)
@@ -39,13 +39,13 @@ int main(int argc, char **argv) {
   std::unique_ptr<xla::Literal> literal =
       xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie();
 
-  LOG(INFO) << "literal: " << literal->ShortDebugString();
+  LOG(INFO) << "literal: " << *literal;
   fprintf(stderr, "%s\n", literal->ToString().c_str());
   if (literal->shape().element_type() == xla::F32) {
-    float min =
-        *std::min_element(literal->f32s().begin(), literal->f32s().end());
-    float max =
-        *std::max_element(literal->f32s().begin(), literal->f32s().end());
+    float min = *std::min_element(literal->data<float>().begin(),
+                                  literal->data<float>().end());
+    float max = *std::max_element(literal->data<float>().begin(),
+                                  literal->data<float>().end());
     fprintf(stderr, "min: %a=%f\n", min, min);
     fprintf(stderr, "max: %a=%f\n", max, max);
   }
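
Across these tools, typed element access goes through Literal::data<NativeT>(), which behaves like a slice with begin()/end(), size(), operator[] and at(). A minimal sketch, assuming an F32 literal:

  std::unique_ptr<xla::Literal> literal =
      xla::Literal::CreateR1<float>({1.0f, 2.0f, 3.0f});
  float sum = 0.0f;
  for (float v : literal->data<float>()) {  // iterates the flat f32 buffer
    sum += v;
  }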