[XLA] First step in adding Literal slice classes, to improve interface safety
authorKay Zhu <kayzhu@google.com>
Wed, 9 May 2018 20:07:35 +0000 (13:07 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 9 May 2018 20:47:51 +0000 (13:47 -0700)
and prepare for enabling more efficient interfacing from Tensor to Literal to
reduce host to device latency.

More specically:
* Introducing a new LiteralBase abstract base class that contains all immutable
methods of from the old Literal class.

* Introducing a subclass LiteralSlice to replace original LiteralView class.
LiteralSlice class is read-only and does not own Shape nor any buffer through
the Pieces. Change a number of callers to use LiteralSlice directly.

* Change Literal class to explicitly own the underlying Shape as well as owning
the underlying buffer via Piece.

* Conversion from Literal to LiteralSlice is now done via an implicit
conversion constructor instead of inheritance.

* Decouple ShapeTree from Literal classes.

* Use copy-and-swap for assignment constructors.

* Other minor cleanups.

PiperOrigin-RevId: 196016576

28 files changed:
tensorflow/compiler/tf2xla/literal_util.cc
tensorflow/compiler/tf2xla/literal_util.h
tensorflow/compiler/xla/client/computation_builder.cc
tensorflow/compiler/xla/client/computation_builder.h
tensorflow/compiler/xla/client/xla_client/xla_builder.cc
tensorflow/compiler/xla/client/xla_client/xla_builder.h
tensorflow/compiler/xla/literal_util.cc
tensorflow/compiler/xla/literal_util.h
tensorflow/compiler/xla/literal_util_test.cc
tensorflow/compiler/xla/python/numpy_bridge.cc
tensorflow/compiler/xla/python/numpy_bridge.h
tensorflow/compiler/xla/service/algebraic_simplifier.cc
tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h
tensorflow/compiler/xla/service/cpu/external_constant_pool.cc
tensorflow/compiler/xla/service/cpu/external_constant_pool.h
tensorflow/compiler/xla/service/generic_transfer_manager.cc
tensorflow/compiler/xla/service/generic_transfer_manager.h
tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
tensorflow/compiler/xla/service/hlo_evaluator.cc
tensorflow/compiler/xla/service/transfer_manager.h
tensorflow/compiler/xla/tests/broadcast_test.cc
tensorflow/compiler/xla/tests/client_test.cc
tensorflow/compiler/xla/tests/constants_test.cc
tensorflow/compiler/xla/tests/literal_test_util.cc
tensorflow/compiler/xla/tests/literal_test_util.h
tensorflow/compiler/xla/tests/local_client_execute_test.cc

index 2c3cd65..43e1c1e 100644 (file)
@@ -40,7 +40,7 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) {
   return Status::OK();
 }
 
-Status CopyLiteralToHostTensor(const xla::Literal& literal,
+Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal,
                                Tensor* host_tensor) {
   TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) &&
                xla::ShapeUtil::ElementsIn(literal.shape()) ==
@@ -63,8 +63,8 @@ Status CopyLiteralToHostTensor(const xla::Literal& literal,
   return Status::OK();
 }
 
-Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
-                           Tensor* host_tensor) {
+Status LiteralToHostTensor(const xla::LiteralSlice& literal,
+                           DataType target_type, Tensor* host_tensor) {
   TensorShape shape;
   TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape));
   *host_tensor = Tensor(target_type, shape);
index f283b02..220bec1 100644 (file)
@@ -36,13 +36,13 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal);
 // derivable from the type of <literal>, because multiple tensorflow types map
 // to the same XLA type (e.g. INT32 and QINT32 both map to INT32 in
 // XLA).
-Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
-                           Tensor* host_tensor);
+Status LiteralToHostTensor(const xla::LiteralSlice& literal,
+                           DataType target_type, Tensor* host_tensor);
 
 // Copies the contents of 'literal' to a previously allocated tensor
 // 'host_tensor'. The tensor and the literal must have the same number of
 // elements and the same type.
-Status CopyLiteralToHostTensor(const xla::Literal& literal,
+Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal,
                                Tensor* host_tensor);
 
 }  // namespace tensorflow
index 83c7cb1..f9f9944 100644 (file)
@@ -185,7 +185,7 @@ bool ComputationBuilder::MakeWindow(
 }
 
 ComputationDataHandle ComputationBuilder::ConstantLiteral(
-    const Literal& literal) {
+    const LiteralSlice& literal) {
   OpRequest op_request;
   ConstantRequest* request = op_request.mutable_constant_request();
   *request->mutable_literal() = literal.ToProto();
index ac1eb91..176962b 100644 (file)
@@ -108,7 +108,7 @@ class ComputationBuilder {
 
   // Enqueues a constant with the value of the given literal onto the
   // computation.
-  ComputationDataHandle ConstantLiteral(const Literal& literal);
+  ComputationDataHandle ConstantLiteral(const LiteralSlice& literal);
 
   // Enqueues a constant onto the computation. Methods are templated on the
   // native host type (NativeT) which corresponds to a specific XLA
index 1899983..4c59d62 100644 (file)
@@ -437,7 +437,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs,
   return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions);
 }
 
-XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) {
+XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
   return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
     *instr.mutable_shape() = literal.shape();
index 4955f15..e1920d6 100644 (file)
@@ -139,7 +139,7 @@ class XlaBuilder {
 
   // Enqueues a constant with the value of the given literal onto the
   // computation.
-  XlaOp ConstantLiteral(const Literal& literal);
+  XlaOp ConstantLiteral(const LiteralSlice& literal);
 
   // Enqueues a constant onto the computation. Methods are templated on the
   // native host type (NativeT) which corresponds to a specific XLA
index b3b5e34..e9b0e11 100644 (file)
@@ -64,6 +64,8 @@ void ConvertEndianShort(char* bytes, int64 size) {
 
 }  // namespace
 
+LiteralBase::~LiteralBase() {}
+
 std::ostream& operator<<(std::ostream& out, const Literal& literal) {
   out << literal.ToString();
   return out;
@@ -95,99 +97,90 @@ Literal::StrideConfig::StrideConfig(
 Literal::Literal(const Shape& shape)
     : Literal(shape, /*allocate_arrays=*/true) {}
 
-Literal::Literal(const Shape& shape, bool allocate_arrays)
-    : shape_(shape), pieces_(shape), owns_buffers_(true) {
-  CHECK(LayoutUtil::HasLayout(shape));
-  for (auto& pair : pieces_) {
-    const ShapeIndex& index = pair.first;
-    Piece& piece = pair.second;
-
-    piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
-    const Shape& subshape = piece.subshape();
-    if (ShapeUtil::IsArray(subshape)) {
-      if (allocate_arrays) {
-        if (LayoutUtil::IsSparseArray(subshape)) {
-          // For sparse arrays, the buffer must be of the size of the maximum
-          // number of sparse elements possible.
-          const int64 max_sparse_elements =
-              LayoutUtil::MaxSparseElements(subshape.layout());
-          piece.set_buffer(
-              new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType(
-                                                 subshape.element_type())]);
-          piece.set_sparse_indices(new SparseIndexArray(
-              max_sparse_elements, ShapeUtil::Rank(subshape)));
-        } else {
-          piece.set_buffer(new char[piece.size_bytes()]);
-        }
+void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
+  if (ShapeUtil::IsTuple(shape)) {
+    for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+      const Shape& subshape = shape.tuple_shapes(i);
+
+      auto child_piece = Piece();
+      child_piece.set_subshape(&subshape);
+
+      SetPiece(subshape, &child_piece, allocate_arrays);
+
+      piece->emplace_back(std::move(child_piece));
+    }
+  } else {
+    CHECK(ShapeUtil::IsArray(shape));
+    if (allocate_arrays) {
+      if (LayoutUtil::IsSparseArray(shape)) {
+        // For sparse arrays, the buffer must be of the size of the maximum
+        // number of sparse elements possible.
+        const int64 max_sparse_elements =
+            LayoutUtil::MaxSparseElements(shape.layout());
+        piece->set_buffer(
+            new char[max_sparse_elements *
+                     ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]);
+        piece->set_sparse_indices(
+            new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape)));
       } else {
-        piece.set_buffer(nullptr);
+        piece->set_buffer(new char[piece->size_bytes()]);
       }
     }
   }
 }
 
-Literal::~Literal() { DeallocateBuffers(); }
+Literal::Literal(const Shape& shape, bool allocate_arrays)
+    : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+  CHECK(LayoutUtil::HasLayout(*shape_));
+  root_piece_ = new Piece();
+  root_piece_->set_subshape(shape_.get());
+  CHECK(&root_piece_->subshape() == shape_.get());
 
-void Literal::DeallocateBuffers() {
-  if (owns_buffers_) {
-    for (auto& pair : pieces_) {
-      Piece& piece = pair.second;
-      if (piece.buffer() != nullptr) {
-        delete[] piece.buffer();
-        delete piece.sparse_indices();
-      }
-    }
-  }
+  SetPiece(*shape_, root_piece_, allocate_arrays);
 }
 
-Literal::Literal(Literal&& other) {
-  shape_ = std::move(other.shape_);
-  pieces_ = std::move(other.pieces_);
-  // We need to iterate through the pieces to set the subshape pointer
-  // properly. It must refer to subshapes within shape_.
-  for (auto& pair : pieces_) {
-    const ShapeIndex& index = pair.first;
-    Piece& piece = pair.second;
-    piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
+Literal::~Literal() {
+  if (root_piece_ != nullptr) {
+    DeallocateBuffers();
+    delete root_piece_;
   }
-  owns_buffers_ = other.owns_buffers_;
+}
 
-  other.shape_ = ShapeUtil::MakeNil();
-  other.pieces_ = ShapeTree<Piece>(other.shape_);
-  other.piece({}).set_subshape(&other.shape_);
+void Literal::DeallocateBuffers() {
+  root_piece_->ForEachMutableSubpiece(
+      [&](const ShapeIndex& index, Piece* piece) {
+        if (piece->buffer() != nullptr) {
+          delete[] piece->buffer();
+          delete piece->sparse_indices();
+        }
+      });
 }
 
+Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); }
+
 Literal& Literal::operator=(Literal&& other) {
-  DeallocateBuffers();
-  shape_ = std::move(other.shape_);
-  pieces_ = std::move(other.pieces_);
-  // We need to iterate through the pieces to set the subshape pointer
-  // properly. It must refer to subshapes within shape_.
-  for (auto& pair : pieces_) {
-    const ShapeIndex& index = pair.first;
-    Piece& piece = pair.second;
-    piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
-  }
-  owns_buffers_ = other.owns_buffers_;
-
-  other.shape_ = ShapeUtil::MakeNil();
-  other.pieces_ = ShapeTree<Piece>(other.shape_);
-  other.piece({}).set_subshape(&other.shape_);
+  CHECK(&other.root_piece_->subshape() == other.shape_.get());
+
+  using std::swap;
+  swap(shape_, other.shape_);
+  swap(root_piece_, other.root_piece_);
+  CHECK(&root_piece_->subshape() == shape_.get());
+
   return *this;
 }
 
-std::unique_ptr<Literal> Literal::CreateFromShape(const Shape& shape) {
+std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
   auto literal = MakeUnique<Literal>(shape);
-  for (auto& pair : literal->pieces_) {
-    Piece& piece = pair.second;
-    if (ShapeUtil::IsArray(piece.subshape())) {
-      memset(piece.untyped_data(), 0, piece.size_bytes());
-    }
-  }
+  literal->root_piece_->ForEachMutableSubpiece(
+      [&](const ShapeIndex& index, Piece* piece) {
+        if (ShapeUtil::IsArray(piece->subshape())) {
+          memset(piece->untyped_data(), 0, piece->size_bytes());
+        }
+      });
   return literal;
 }
 
-const SparseIndexArray* Literal::sparse_indices(
+const SparseIndexArray* LiteralBase::sparse_indices(
     const ShapeIndex& shape_index) const {
   return piece(shape_index).sparse_indices();
 }
@@ -204,7 +197,7 @@ SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) {
 
 template <typename NativeT>
 Status Literal::CopySliceFromInternal(
-    const Literal& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
+    const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
     tensorflow::gtl::ArraySlice<int64> dest_base,
     tensorflow::gtl::ArraySlice<int64> copy_size) {
   TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size());
@@ -217,8 +210,8 @@ Status Literal::CopySliceFromInternal(
 
   if (ShapeUtil::Rank(src_literal.shape()) == 0 ||
       ShapeUtil::Rank(shape()) == 0) {
-    // If any of the two shapes are scalars, we can just call the StridedCopy()
-    // directly, and we know we will be copying only one value.
+    // If any of the two shapes are scalars, we can just call the
+    // StridedCopy() directly, and we know we will be copying only one value.
     TF_RET_CHECK(copy_size.empty());
     StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
                 src_literal.data<NativeT>(),
@@ -264,7 +257,7 @@ Status Literal::CopySliceFromInternal(
   return Status::OK();
 }
 
-Status Literal::CopyElementFrom(const Literal& src_literal,
+Status Literal::CopyElementFrom(const LiteralSlice& src_literal,
                                 tensorflow::gtl::ArraySlice<int64> src_index,
                                 tensorflow::gtl::ArraySlice<int64> dest_index) {
   DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
@@ -293,22 +286,21 @@ std::vector<Literal> Literal::DecomposeTuple() {
     elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
                                /*allocate_arrays=*/false));
     Literal& element = elements.back();
-    for (auto& pair : element.pieces_) {
-      const ShapeIndex& index = pair.first;
-      Piece& dest_piece = pair.second;
-      ShapeIndex src_index = {i};
-      for (int64 j : index) {
-        src_index.push_back(j);
-      }
-      Piece& src_piece = piece(src_index);
-
-      // Move the respective buffer and sparse indices over to the element
-      // Literal.
-      dest_piece.set_buffer(src_piece.buffer());
-      src_piece.set_buffer(nullptr);
-      dest_piece.set_sparse_indices(src_piece.sparse_indices());
-      src_piece.set_sparse_indices(nullptr);
-    }
+    element.root_piece_->ForEachMutableSubpiece(
+        [&](const ShapeIndex& index, Piece* dest_piece) {
+          ShapeIndex src_index = {i};
+          for (int64 j : index) {
+            src_index.push_back(j);
+          }
+          Piece& src_piece = piece(src_index);
+
+          // Move the respective buffer and sparse indices over to the element
+          // Literal.
+          dest_piece->set_buffer(src_piece.buffer());
+          src_piece.set_buffer(nullptr);
+          dest_piece->set_sparse_indices(src_piece.sparse_indices());
+          src_piece.set_sparse_indices(nullptr);
+        });
   }
   // Set this literal to be nil-shaped.
   *this = Literal();
@@ -331,9 +323,9 @@ std::vector<Literal> Literal::DecomposeTuple() {
 }
 
 namespace {
-
-// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
-// the array slices are indicated by dest_shape and src_shape respectively.
+// Copies the elements in 'src' to 'dest'. The shape and layout of the data
+// in the array slices are indicated by dest_shape and src_shape
+// respectively.
 template <typename NativeT>
 void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
                          tensorflow::gtl::ArraySlice<NativeT> src,
@@ -351,7 +343,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
 
 }  // namespace
 
-Status Literal::Piece::CopyFrom(const Literal::Piece& src) {
+Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
   if (ShapeUtil::Equal(subshape(), src.subshape())) {
     // If the layouts are equal it's faster just to memcpy.
     memcpy(buffer(), src.buffer(), src.size_bytes());
@@ -381,14 +373,15 @@ Status Literal::Piece::CopyFrom(const Literal::Piece& src) {
 #undef COPY_ELEMENTS
       default:
         return Unimplemented(
-            "Copying a Literal object with element type %s is not implemented.",
+            "Copying a Literal object with element type %s is not "
+            "implemented.",
             PrimitiveType_Name(subshape().element_type()).c_str());
     }
   }
   return Status::OK();
 }
 
-Status Literal::CopyFrom(const Literal& src_literal,
+Status Literal::CopyFrom(const LiteralSlice& src_literal,
                          const ShapeIndex& dest_shape_index,
                          const ShapeIndex& src_shape_index) {
   const Shape& dest_subshape =
@@ -402,36 +395,33 @@ Status Literal::CopyFrom(const Literal& src_literal,
         ShapeUtil::HumanString(src_subshape).c_str());
   }
 
-  for (auto& pair : pieces_) {
-    const ShapeIndex& index = pair.first;
-    Piece& piece = pair.second;
-    if (!ShapeUtil::IsArray(piece.subshape())) {
-      continue;
-    }
-
-    // Determine if this index is in the part of this literal that we want to
-    // copy over from src_literal.
-    bool in_subtree_to_copy = true;
-    for (int i = 0; i < dest_shape_index.size(); ++i) {
-      if (index[i] != dest_shape_index[i]) {
-        in_subtree_to_copy = false;
-        break;
-      }
-    }
-    if (!in_subtree_to_copy) {
-      continue;
-    }
-
-    // Construct the index of the corresponding piece in the source literal.
-    ShapeIndex src_piece_index = src_shape_index;
-    for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
-      src_piece_index.push_back(index[i]);
-    }
+  return root_piece_->ForEachMutableSubpieceWithStatus(
+      [&](const ShapeIndex& index, Piece* piece) {
+        if (!ShapeUtil::IsArray(piece->subshape())) {
+          return Status::OK();
+        }
 
-    TF_RETURN_IF_ERROR(piece.CopyFrom(src_literal.piece(src_piece_index)));
-  }
-  return Status::OK();
-}
+        // Determine if this index is in the part of this literal that we want
+        // to copy over from src_literal.
+        bool in_subtree_to_copy = true;
+        for (int i = 0; i < dest_shape_index.size(); ++i) {
+          if (index[i] != dest_shape_index[i]) {
+            in_subtree_to_copy = false;
+            break;
+          }
+        }
+        if (!in_subtree_to_copy) {
+          return Status::OK();
+        }
+        // Construct the index of the corresponding piece in the source literal.
+        ShapeIndex src_piece_index = src_shape_index;
+        for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
+          src_piece_index.push_back(index[i]);
+        }
+        TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index)));
+        return Status::OK();
+      });
+}  // namespace xla
 
 Status Literal::MoveFrom(Literal&& src_literal,
                          const ShapeIndex& dest_shape_index) {
@@ -444,37 +434,32 @@ Status Literal::MoveFrom(Literal&& src_literal,
         ShapeUtil::HumanString(src_literal.shape()).c_str());
   }
 
-  if (!(owns_buffers_ && src_literal.owns_buffers_)) {
-    return InvalidArgument(
-        "Source and destination literals must both own their buffers (ie, not "
-        "be views)");
-  }
+  src_literal.root_piece_->ForEachSubpiece(
+      [&](const ShapeIndex& src_index, const Piece& src_piece) {
+        if (!ShapeUtil::IsArray(src_piece.subshape())) {
+          return;
+        }
 
-  for (auto& pair : src_literal.pieces_) {
-    const ShapeIndex& src_index = pair.first;
-    Piece& src_piece = pair.second;
-    if (!ShapeUtil::IsArray(src_piece.subshape())) {
-      continue;
-    }
+        ShapeIndex dest_index = dest_shape_index;
+        for (int64 i : src_index) {
+          dest_index.push_back(i);
+        }
+        Piece& dest_piece = piece(dest_index);
+        delete[] dest_piece.buffer();
+        dest_piece.set_buffer(src_piece.buffer());
+        delete dest_piece.sparse_indices();
+        dest_piece.set_sparse_indices(src_piece.sparse_indices());
+      });
 
-    ShapeIndex dest_index = dest_shape_index;
-    for (int64 i : src_index) {
-      dest_index.push_back(i);
-    }
-    Piece& dest_piece = piece(dest_index);
-    delete[] dest_piece.buffer();
-    dest_piece.set_buffer(src_piece.buffer());
-    delete dest_piece.sparse_indices();
-    dest_piece.set_sparse_indices(src_piece.sparse_indices());
-  }
+  src_literal.shape_ = MakeUnique<Shape>(ShapeUtil::MakeNil());
+  delete src_literal.root_piece_;
+  src_literal.root_piece_ = new LiteralBase::Piece();
+  src_literal.root_piece_->set_subshape(src_literal.shape_.get());
 
-  src_literal.shape_ = ShapeUtil::MakeNil();
-  src_literal.pieces_ = ShapeTree<Piece>(src_literal.shape_);
-  src_literal.piece({}).set_subshape(&src_literal.shape_);
   return Status::OK();
 }
 
-Status Literal::CopySliceFrom(const Literal& src_literal,
+Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
                               tensorflow::gtl::ArraySlice<int64> src_base,
                               tensorflow::gtl::ArraySlice<int64> dest_base,
                               tensorflow::gtl::ArraySlice<int64> copy_size) {
@@ -743,7 +728,7 @@ void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
   return CreateR2FromArray2D(*value);
 }
 
-std::unique_ptr<Literal> Literal::Relayout(
+std::unique_ptr<Literal> LiteralBase::Relayout(
     const Layout& new_layout, const ShapeIndex& shape_index) const {
   // Create new shape with 'new_layout' set at the given shape index.
   Shape new_shape = shape();
@@ -755,7 +740,7 @@ std::unique_ptr<Literal> Literal::Relayout(
   return result;
 }
 
-std::unique_ptr<Literal> Literal::Relayout(
+std::unique_ptr<Literal> LiteralBase::Relayout(
     const Shape& shape_with_layout) const {
   CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
       << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
@@ -774,7 +759,7 @@ std::unique_ptr<Literal> Literal::Relayout(
   return result;
 }
 
-StatusOr<std::unique_ptr<Literal>> Literal::Reshape(
+StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
     tensorflow::gtl::ArraySlice<int64> dimensions) const {
   if (!ShapeUtil::IsArray(shape())) {
     return InvalidArgument("Reshape does not support tuples.");
@@ -788,7 +773,8 @@ StatusOr<std::unique_ptr<Literal>> Literal::Reshape(
   }
   // Because the layout is monotonic, we can simply reuse the same sequence of
   // values without changing their order.
-  output->shape_ = ShapeUtil::MakeShape(shape().element_type(), dimensions);
+  *output->mutable_shape_do_not_use() =
+      ShapeUtil::MakeShape(shape().element_type(), dimensions);
 
   int64 elements_before = ShapeUtil::ElementsIn(shape());
   int64 elements_after = ShapeUtil::ElementsIn(output->shape());
@@ -802,7 +788,7 @@ StatusOr<std::unique_ptr<Literal>> Literal::Reshape(
   return std::move(output);
 }
 
-std::unique_ptr<Literal> Literal::Transpose(
+std::unique_ptr<Literal> LiteralBase::Transpose(
     tensorflow::gtl::ArraySlice<int64> permutation) const {
   CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
   CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
@@ -819,8 +805,8 @@ std::unique_ptr<Literal> Literal::Transpose(
   // representation intact.
   // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
   // The shape with affine layout resulting from that operation will be
-  // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
-  // most minor.
+  // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized),
+  // the most minor.
   //
   // Essentially, given MinMaj(Di) the position of the Di dimension within the
   // minor to major vector, and given T(Di) the index that the original Di
@@ -836,12 +822,11 @@ std::unique_ptr<Literal> Literal::Transpose(
   std::unique_ptr<Literal> new_literal = CreateFromShape(permuted_shape);
   DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()),
             ShapeUtil::ByteSizeOf(shape()));
-  std::memcpy(new_literal->root_piece().buffer(), root_piece().buffer(),
-              root_piece().size_bytes());
+  std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
   return new_literal;
 }
 
-std::unique_ptr<Literal> Literal::Slice(
+std::unique_ptr<Literal> LiteralBase::Slice(
     tensorflow::gtl::ArraySlice<int64> start_indices,
     tensorflow::gtl::ArraySlice<int64> limit_indices) const {
   CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
@@ -909,20 +894,20 @@ std::unique_ptr<Literal> Literal::Slice(
   }
 }
 
-Literal Literal::Clone() const {
+Literal LiteralBase::Clone() const {
   Literal result(shape());
   TF_CHECK_OK(result.CopyFrom(*this));
   return result;
 }
 
-std::unique_ptr<Literal> Literal::CloneToUnique() const {
+std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
   auto result = MakeUnique<Literal>(shape());
   TF_CHECK_OK(result->CopyFrom(*this));
   return result;
 }
 
-string Literal::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
-                            const ShapeIndex& shape_index) const {
+string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+                                const ShapeIndex& shape_index) const {
   const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
   CHECK(LayoutUtil::IsDenseArray(subshape));
   switch (subshape.element_type()) {
@@ -962,8 +947,8 @@ string Literal::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
   }
 }
 
-string Literal::GetSparseElementAsString(int64 sparse_element_number,
-                                         const ShapeIndex& shape_index) const {
+string LiteralBase::GetSparseElementAsString(
+    int64 sparse_element_number, const ShapeIndex& shape_index) const {
   const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
   CHECK(LayoutUtil::IsSparseArray(subshape));
   switch (subshape.element_type()) {
@@ -1017,7 +1002,7 @@ string Literal::GetSparseElementAsString(int64 sparse_element_number,
   }
 }
 
-StatusOr<int64> Literal::GetIntegralAsS64(
+StatusOr<int64> LiteralBase::GetIntegralAsS64(
     tensorflow::gtl::ArraySlice<int64> multi_index) const {
   CHECK(LayoutUtil::IsDenseArray(shape()));
   switch (shape().element_type()) {
@@ -1070,7 +1055,7 @@ Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
   return Status::OK();
 }
 
-tensorflow::gtl::ArraySlice<int64> Literal::GetSparseIndex(
+tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex(
     int64 sparse_element_number, const ShapeIndex& shape_index) const {
   const Piece& p = piece(shape_index);
   CHECK_GE(sparse_element_number, 0);
@@ -1082,10 +1067,10 @@ void Literal::SortSparseElements(const ShapeIndex& shape_index) {
   piece(shape_index).SortSparseElements();
 }
 
-Literal Literal::GetFirstScalarLiteral() const {
-  CHECK(ShapeUtil::IsArray(shape_));
-  CHECK_GT(ShapeUtil::ElementsIn(shape_), 0);
-  switch (shape_.element_type()) {
+Literal LiteralBase::GetFirstScalarLiteral() const {
+  CHECK(ShapeUtil::IsArray(shape()));
+  CHECK_GT(ShapeUtil::ElementsIn(shape()), 0);
+  switch (shape().element_type()) {
     case PRED:
       return std::move(*Literal::CreateR0<bool>(GetFirstElement<bool>()));
     // 8 bit types.
@@ -1121,11 +1106,11 @@ Literal Literal::GetFirstScalarLiteral() const {
     case U64:
       return std::move(*Literal::CreateR0<uint64>(GetFirstElement<uint64>()));
     default:
-      LOG(FATAL) << "Unhandled primitive type " << shape_.element_type();
+      LOG(FATAL) << "Unhandled primitive type " << shape().element_type();
   }
 }
 
-void Literal::Piece::SortSparseElements() {
+void LiteralBase::Piece::SortSparseElements() {
   switch (subshape().element_type()) {
     case PRED:
       SortSparseElementsInternal<bool>();
@@ -1176,7 +1161,7 @@ void Literal::Piece::SortSparseElements() {
 }
 
 template <typename NativeT>
-void Literal::Piece::SortSparseElementsInternal() {
+void LiteralBase::Piece::SortSparseElementsInternal() {
   CHECK(LayoutUtil::IsSparseArray(subshape()));
   int64 num_elements = sparse_indices()->index_count();
   auto values = data<NativeT>();
@@ -1186,10 +1171,11 @@ void Literal::Piece::SortSparseElementsInternal() {
 }
 
 namespace {
-
-void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index,
+void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
                     bool print_layout, std::vector<string>* pieces) {
   const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
+  CHECK(LayoutUtil::HasLayout(literal.shape()));
+  CHECK(LayoutUtil::HasLayout(subshape));
 
   auto shape_to_string = [print_layout](const Shape& shape) {
     if (print_layout) {
@@ -1348,13 +1334,14 @@ void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index,
 
 }  // namespace
 
-int64 Literal::sparse_element_count() const {
+int64 LiteralBase::sparse_element_count() const {
   CHECK(LayoutUtil::IsSparseArray(shape()));
   return sparse_indices()->index_count();
 }
 
-string Literal::ToString(bool print_layout) const {
+string LiteralBase::ToString(bool print_layout) const {
   std::vector<string> pieces;
+  CHECK(LayoutUtil::HasLayout(this->shape()));
   ToStringHelper(*this, {}, print_layout, &pieces);
   return tensorflow::str_util::Join(pieces, "");
 }
@@ -1362,7 +1349,7 @@ string Literal::ToString(bool print_layout) const {
 /* static */ std::unique_ptr<Literal> Literal::MakeTuple(
     tensorflow::gtl::ArraySlice<const Literal*> elements) {
   std::vector<Shape> element_shapes;
-  for (const Literal* element : elements) {
+  for (const auto* element : elements) {
     element_shapes.push_back(element->shape());
   }
   auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
@@ -1372,6 +1359,19 @@ string Literal::ToString(bool print_layout) const {
   return literal;
 }
 
+/* static */ std::unique_ptr<Literal> Literal::MakeTupleFromSlices(
+    tensorflow::gtl::ArraySlice<LiteralSlice> elements) {
+  std::vector<Shape> element_shapes;
+  for (const auto& element : elements) {
+    element_shapes.push_back(element.shape());
+  }
+  auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+  for (int i = 0; i < elements.size(); ++i) {
+    TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i}));
+  }
+  return literal;
+}
+
 /* static */ std::unique_ptr<Literal> Literal::MakeTupleOwned(
     std::vector<std::unique_ptr<Literal>> elements) {
   std::vector<Shape> element_shapes;
@@ -1387,7 +1387,7 @@ string Literal::ToString(bool print_layout) const {
   return literal;
 }
 
-void Literal::EachCellAsString(
+void LiteralBase::EachCellAsString(
     const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
                              const string& value)>& per_cell) const {
   if (ShapeUtil::HasZeroElements(shape())) {
@@ -1403,7 +1403,7 @@ void Literal::EachCellAsString(
 namespace {
 template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
 std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
-    const Literal& src_literal, const ConverterType& converter) {
+    const LiteralBase& src_literal, const ConverterType& converter) {
   CHECK(ShapeUtil::IsArray(src_literal.shape()));
   auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType(
       src_literal.shape(),
@@ -1419,7 +1419,8 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
 }
 
 template <typename NativeSrcT, typename NativeDestT>
-std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
+std::unique_ptr<Literal> ConvertBetweenNativeTypes(
+    const LiteralBase& src_literal) {
   auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
   return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
       src_literal, converter);
@@ -1428,7 +1429,7 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
 template <typename NativeSrcT, typename NativeDestT>
 typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
                         std::unique_ptr<Literal>>::type
-BitcastBetweenNativeTypes(const Literal& src_literal) {
+BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
   auto converter = [](NativeSrcT src) {
     return tensorflow::bit_cast<NativeDestT>(src);
   };
@@ -1436,19 +1437,19 @@ BitcastBetweenNativeTypes(const Literal& src_literal) {
       src_literal, converter);
 }
 
-// This template specialization is here to make the compiler happy. bit_cast has
-// a static check that the types are the same size. This specialization should
-// never be used because the source and destination types are checked for
-// identical sizes higher up.
+// This template specialization is here to make the compiler happy. bit_cast
+// has a static check that the types are the same size. This specialization
+// should never be used because the source and destination types are checked
+// for identical sizes higher up.
 template <typename NativeSrcT, typename NativeDestT>
 typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
                         std::unique_ptr<Literal>>::type
-BitcastBetweenNativeTypes(const Literal& src_literal) {
+BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
   LOG(FATAL) << "Invalid bitcast between types of different sizes.";
 }
 
 template <PrimitiveType primitive_src_type>
-std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) {
+std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
   CHECK(ShapeUtil::IsArray(src_literal.shape()));
   auto result_literal = MakeUnique<Literal>(
       ShapeUtil::ChangeElementType(src_literal.shape(), C64));
@@ -1466,7 +1467,7 @@ std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) {
 }
 
 template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
-std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal,
+std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
                                              bool bitcast) {
   CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
   if (bitcast) {
@@ -1486,7 +1487,7 @@ std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal,
 
 template <PrimitiveType primitive_src_type>
 StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
-    const Literal& src_literal, PrimitiveType primitive_dest_type,
+    const LiteralBase& src_literal, PrimitiveType primitive_dest_type,
     bool bitcast) {
   switch (primitive_dest_type) {
 #define CONVERT_IF_TYPES_MATCH(type)                                    \
@@ -1521,7 +1522,8 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
 }
 
 StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
-    const Literal& literal, PrimitiveType primitive_dest_type, bool bitcast) {
+    const LiteralBase& literal, PrimitiveType primitive_dest_type,
+    bool bitcast) {
   TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
   if (literal.shape().element_type() == primitive_dest_type) {
     return literal.CloneToUnique();
@@ -1555,17 +1557,18 @@ StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
 
 }  // namespace
 
-StatusOr<std::unique_ptr<Literal>> Literal::Convert(
+StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert(
     PrimitiveType primitive_dest_type) const {
   return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
 }
 
-StatusOr<std::unique_ptr<Literal>> Literal::BitcastConvert(
+StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
     PrimitiveType primitive_dest_type) const {
   if (primitive_util::BitWidth(shape().element_type()) !=
       primitive_util::BitWidth(primitive_dest_type)) {
     return InvalidArgument(
-        "Cannot bitcast convert from %s to %s, bit widths are different: %d != "
+        "Cannot bitcast convert from %s to %s, bit widths are different: %d "
+        "!= "
         "%d",
         PrimitiveType_Name(shape().element_type()).c_str(),
         PrimitiveType_Name(primitive_dest_type).c_str(),
@@ -1575,7 +1578,7 @@ StatusOr<std::unique_ptr<Literal>> Literal::BitcastConvert(
   return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
 }
 
-StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape(
+StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
     const Shape& dest_shape, bool round_f32_to_bf16) const {
   if (!ShapeUtil::IsTuple(dest_shape)) {
     if (round_f32_to_bf16 && shape().element_type() == F32 &&
@@ -1590,7 +1593,7 @@ StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape(
   }
   std::vector<Literal> elements;
   for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
-    auto element = LiteralView::Create(*this, {i});
+    auto element = LiteralSlice(*this, {i});
     TF_ASSIGN_OR_RETURN(
         auto new_element,
         element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
@@ -1602,8 +1605,8 @@ StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape(
 }
 
 template <typename NativeT>
-bool Literal::Piece::EqualElementsInternal(
-    const Literal::Piece& other, std::vector<int64>* multi_index) const {
+bool LiteralBase::Piece::EqualElementsInternal(
+    const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
   if (multi_index->size() == ShapeUtil::Rank(subshape())) {
     return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
   }
@@ -1617,7 +1620,7 @@ bool Literal::Piece::EqualElementsInternal(
   return true;
 }
 
-bool Literal::Piece::EqualElements(const Literal::Piece& other) const {
+bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
   DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
 
   std::vector<int64> multi_index;
@@ -1645,32 +1648,31 @@ bool Literal::Piece::EqualElements(const Literal::Piece& other) const {
     case C64:
       return EqualElementsInternal<complex64>(other, &multi_index);
     default:
-      LOG(FATAL) << "Unimplemented: Literal::Piece::EqualElements for type "
+      LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
                  << PrimitiveType_Name(subshape().element_type());
   }
 }
 
-bool Literal::operator==(const Literal& other) const {
+bool LiteralBase::operator==(const LiteralBase& other) const {
   if (!ShapeUtil::Compatible(shape(), other.shape())) {
     return false;
   }
-  for (const auto& pair : pieces_) {
-    const ShapeIndex& index = pair.first;
-    const Piece& piece = pair.second;
-    if (!ShapeUtil::IsArray(piece.subshape())) {
-      continue;
-    }
 
-    const Piece& other_piece = other.piece(index);
-    if (!piece.EqualElements(other_piece)) {
-      return false;
-    }
-  }
-  return true;
+  return root_piece().ForEachSubpieceWithBool(
+      [&](const ShapeIndex& index, const Piece& piece) {
+        if (!ShapeUtil::IsArray(piece.subshape())) {
+          return true;
+        }
+
+        const Piece& other_piece = other.piece(index);
+        if (!piece.EqualElements(other_piece)) {
+          return false;
+        }
+        return true;
+      });
 }
 
 namespace {
-
 template <typename NativeT>
 static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
                                   NativeT value) {
@@ -1684,11 +1686,11 @@ static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
 
 }  // namespace
 
-bool Literal::IsAll(int8 value) const {
-  for (const auto& pair : pieces_) {
-    const Piece& piece = pair.second;
+bool LiteralBase::IsAll(int8 value) const {
+  return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
+                                                  const Piece& piece) {
     if (!ShapeUtil::IsArray(piece.subshape())) {
-      continue;
+      return true;
     }
 
     auto piece_is_all = [&]() {
@@ -1741,41 +1743,41 @@ bool Literal::IsAll(int8 value) const {
     if (!piece_is_all()) {
       return false;
     }
-  }
-  return true;
-}
+    return true;
+  });
+}  // namespace xla
 
-bool Literal::IsAllFloat(float value) const {
-  for (const auto& pair : pieces_) {
-    const Piece& piece = pair.second;
-    if (!ShapeUtil::IsArray(piece.subshape())) {
-      continue;
-    }
+bool LiteralBase::IsAllFloat(float value) const {
+  return root_piece().ForEachSubpieceWithBool(
+      [&](const ShapeIndex& index, const Piece& piece) {
+        if (!ShapeUtil::IsArray(piece.subshape())) {
+          return true;
+        }
 
-    auto piece_is_all = [&]() {
-      switch (shape().element_type()) {
-        case F32:
-          return AllElementsEqualValue<float>(piece.data<float>(), value);
-        case F64:
-          return AllElementsEqualValue<double>(piece.data<double>(), value);
-        case F16:
-          return AllElementsEqualValue<half>(piece.data<half>(),
-                                             static_cast<half>(value));
-        case BF16:
-          return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
-                                                 static_cast<bfloat16>(value));
-        default:
+        auto piece_is_all = [&]() {
+          switch (shape().element_type()) {
+            case F32:
+              return AllElementsEqualValue<float>(piece.data<float>(), value);
+            case F64:
+              return AllElementsEqualValue<double>(piece.data<double>(), value);
+            case F16:
+              return AllElementsEqualValue<half>(piece.data<half>(),
+                                                 static_cast<half>(value));
+            case BF16:
+              return AllElementsEqualValue<bfloat16>(
+                  piece.data<bfloat16>(), static_cast<bfloat16>(value));
+            default:
+              return false;
+          }
+        };
+        if (!piece_is_all()) {
           return false;
-      }
-    };
-    if (!piece_is_all()) {
-      return false;
-    }
-  }
-  return true;
+        }
+        return true;
+      });
 }
 
-bool Literal::IsAllComplex(complex64 value) const {
+bool LiteralBase::IsAllComplex(complex64 value) const {
   switch (shape().element_type()) {
     case C64:
       return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
@@ -1785,93 +1787,93 @@ bool Literal::IsAllComplex(complex64 value) const {
   }
 }
 
-bool Literal::IsAllFirst() const {
-  for (const auto& pair : pieces_) {
-    const Piece& piece = pair.second;
-    if (!ShapeUtil::IsArray(piece.subshape())) {
-      continue;
-    }
-
-    // Empty shapes are not all the first element since there is no first
-    // element.
-    if (ShapeUtil::HasZeroElements(piece.subshape())) {
-      return false;
-    }
-    auto piece_is_all = [&]() {
-      switch (piece.subshape().element_type()) {
-        case PRED: {
-          auto data = piece.data<bool>();
-          return AllElementsEqualValue<bool>(data, data[0]);
+bool LiteralBase::IsAllFirst() const {
+  return root_piece().ForEachSubpieceWithBool(
+      [&](const ShapeIndex& index, const Piece& piece) {
+        if (!ShapeUtil::IsArray(piece.subshape())) {
+          return true;
         }
-        // 8 bit types
-        case S8: {
-          auto data = piece.data<int8>();
-          return AllElementsEqualValue<int8>(data, data[0]);
-        }
-        case U8: {
-          auto data = piece.data<uint8>();
-          return AllElementsEqualValue<uint8>(data, data[0]);
-        }
-        // 16 bit types
-        case BF16: {
-          auto data = piece.data<bfloat16>();
-          return AllElementsEqualValue<bfloat16>(data, data[0]);
-        }
-        case F16: {
-          auto data = piece.data<half>();
-          return AllElementsEqualValue<half>(data, data[0]);
-        }
-        case S16: {
-          auto data = piece.data<int16>();
-          return AllElementsEqualValue<int16>(data, data[0]);
-        }
-        case U16: {
-          auto data = piece.data<uint16>();
-          return AllElementsEqualValue<uint16>(data, data[0]);
-        }
-        // 32 bit types
-        case F32: {
-          auto data = piece.data<float>();
-          return AllElementsEqualValue<float>(data, data[0]);
-        }
-        case U32: {
-          auto data = piece.data<uint32>();
-          return AllElementsEqualValue<uint32>(data, data[0]);
-        }
-        case S32: {
-          auto data = piece.data<int32>();
-          return AllElementsEqualValue<int32>(data, data[0]);
-        }
-        // 64 bit types
-        case C64: {
-          auto data = piece.data<complex64>();
-          return AllElementsEqualValue<complex64>(data, data[0]);
-        }
-        case F64: {
-          auto data = piece.data<double>();
-          return AllElementsEqualValue<double>(data, data[0]);
-        }
-        case S64: {
-          auto data = piece.data<int64>();
-          return AllElementsEqualValue<int64>(data, data[0]);
-        }
-        case U64: {
-          auto data = piece.data<uint64>();
-          return AllElementsEqualValue<uint64>(data, data[0]);
-        }
-        default:
+
+        // Empty shapes are not all the first element since there is no first
+        // element.
+        if (ShapeUtil::HasZeroElements(piece.subshape())) {
           return false;
-      }
-    };
+        }
+        auto piece_is_all = [&]() {
+          switch (piece.subshape().element_type()) {
+            case PRED: {
+              auto data = piece.data<bool>();
+              return AllElementsEqualValue<bool>(data, data[0]);
+            }
+            // 8 bit types
+            case S8: {
+              auto data = piece.data<int8>();
+              return AllElementsEqualValue<int8>(data, data[0]);
+            }
+            case U8: {
+              auto data = piece.data<uint8>();
+              return AllElementsEqualValue<uint8>(data, data[0]);
+            }
+            // 16 bit types
+            case BF16: {
+              auto data = piece.data<bfloat16>();
+              return AllElementsEqualValue<bfloat16>(data, data[0]);
+            }
+            case F16: {
+              auto data = piece.data<half>();
+              return AllElementsEqualValue<half>(data, data[0]);
+            }
+            case S16: {
+              auto data = piece.data<int16>();
+              return AllElementsEqualValue<int16>(data, data[0]);
+            }
+            case U16: {
+              auto data = piece.data<uint16>();
+              return AllElementsEqualValue<uint16>(data, data[0]);
+            }
+            // 32 bit types
+            case F32: {
+              auto data = piece.data<float>();
+              return AllElementsEqualValue<float>(data, data[0]);
+            }
+            case U32: {
+              auto data = piece.data<uint32>();
+              return AllElementsEqualValue<uint32>(data, data[0]);
+            }
+            case S32: {
+              auto data = piece.data<int32>();
+              return AllElementsEqualValue<int32>(data, data[0]);
+            }
+            // 64 bit types
+            case C64: {
+              auto data = piece.data<complex64>();
+              return AllElementsEqualValue<complex64>(data, data[0]);
+            }
+            case F64: {
+              auto data = piece.data<double>();
+              return AllElementsEqualValue<double>(data, data[0]);
+            }
+            case S64: {
+              auto data = piece.data<int64>();
+              return AllElementsEqualValue<int64>(data, data[0]);
+            }
+            case U64: {
+              auto data = piece.data<uint64>();
+              return AllElementsEqualValue<uint64>(data, data[0]);
+            }
+            default:
+              return false;
+          }
+        };
 
-    if (!piece_is_all()) {
-      return false;
-    }
-  }
-  return true;
+        if (!piece_is_all()) {
+          return false;
+        }
+        return true;
+      });
 }
 
-bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
+bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
   CHECK(ShapeUtil::IsArray(shape()));
   switch (shape().element_type()) {
     case U8:
@@ -1904,7 +1906,6 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
 }
 
 namespace {
-
 template <typename RepeatedFieldT, typename NativeT>
 void CopyToRepeatedField(RepeatedFieldT* dest,
                          const tensorflow::gtl::ArraySlice<NativeT> src) {
@@ -1913,7 +1914,7 @@ void CopyToRepeatedField(RepeatedFieldT* dest,
 
 }  // namespace
 
-void Literal::Piece::WriteToProto(LiteralProto* proto) const {
+void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
   *proto->mutable_shape() = subshape();
   switch (subshape().element_type()) {
     case PRED:
@@ -1969,18 +1970,17 @@ void Literal::Piece::WriteToProto(LiteralProto* proto) const {
   }
 }
 
-const void* Literal::Piece::untyped_data() const {
+const void* LiteralBase::Piece::untyped_data() const {
   CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
   return buffer();
 }
 
-void* Literal::Piece::untyped_data() {
+void* LiteralBase::Piece::untyped_data() {
   CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
   return buffer();
 }
 
 namespace {
-
 template <typename RepeatedFieldT, typename NativeT>
 Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
                              const RepeatedFieldT& src) {
@@ -1995,7 +1995,7 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
 
 }  // namespace
 
-Status Literal::Piece::CopyFromProto(const LiteralProto& proto) {
+Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
   // These conditions should have been checked in Literal::CreateFromProto.
   TF_RET_CHECK(proto.has_shape());
   TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
@@ -2062,21 +2062,19 @@ Status Literal::Piece::CopyFromProto(const LiteralProto& proto) {
   return Status::OK();
 }
 
-LiteralProto Literal::ToProto() const {
+LiteralProto LiteralBase::ToProto() const {
   LiteralProto proto;
-  for (const auto& pair : pieces_) {
-    const ShapeIndex& index = pair.first;
-    const Piece& piece = pair.second;
-
-    LiteralProto* proto_piece = &proto;
-    for (int64 i : index) {
-      while (proto_piece->tuple_literals_size() <= i) {
-        proto_piece->add_tuple_literals();
-      }
-      proto_piece = proto_piece->mutable_tuple_literals(i);
-    }
-    piece.WriteToProto(proto_piece);
-  }
+  root_piece().ForEachSubpiece(
+      [&](const ShapeIndex& index, const Piece& piece) {
+        LiteralProto* proto_piece = &proto;
+        for (int64 i : index) {
+          while (proto_piece->tuple_literals_size() <= i) {
+            proto_piece->add_tuple_literals();
+          }
+          proto_piece = proto_piece->mutable_tuple_literals(i);
+        }
+        piece.WriteToProto(proto_piece);
+      });
 
   if (LayoutUtil::IsSparseArray(shape())) {
     CopyToRepeatedField(proto.mutable_sparse_indices(),
@@ -2098,33 +2096,34 @@ StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
 
   auto literal = MakeUnique<Literal>(proto.shape());
 
-  for (auto& pair : literal->pieces_) {
-    const ShapeIndex& index = pair.first;
-    Piece& piece = pair.second;
-    const LiteralProto* proto_element = &proto;
-    for (int64 i : index) {
-      TF_RET_CHECK(i < proto_element->tuple_literals_size());
-      proto_element = &proto_element->tuple_literals(i);
-    }
+  TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
+      [&](const ShapeIndex& index, Piece* piece) {
+        const LiteralProto* proto_element = &proto;
+        for (int64 i : index) {
+          CHECK(i < proto_element->tuple_literals_size());
+          proto_element = &proto_element->tuple_literals(i);
+        }
 
-    if (ShapeUtil::IsTuple(piece.subshape())) {
-      if (proto_element->tuple_literals_size() !=
-          ShapeUtil::TupleElementCount(piece.subshape())) {
-        return InvalidArgument(
-            "Expected %lld tuple elements in LiteralProto, has %d",
-            ShapeUtil::TupleElementCount(piece.subshape()),
-            proto_element->tuple_literals_size());
-      }
-      continue;
-    }
+        if (ShapeUtil::IsTuple(piece->subshape())) {
+          if (proto_element->tuple_literals_size() !=
+              ShapeUtil::TupleElementCount(piece->subshape())) {
+            return InvalidArgument(
+                "Expected %lld tuple elements in LiteralProto, has %d",
+                ShapeUtil::TupleElementCount(piece->subshape()),
+                proto_element->tuple_literals_size());
+          }
+          return Status::OK();
+        }
 
-    TF_RET_CHECK(ShapeUtil::IsArray(piece.subshape()));
-    TF_RETURN_IF_ERROR(piece.CopyFromProto(*proto_element));
-  }
+        CHECK(ShapeUtil::IsArray(piece->subshape()));
+        TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
+
+        return Status::OK();
+      }));
   return std::move(literal);
 }
 
-const void* Literal::untyped_data(const ShapeIndex& shape_index) const {
+const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
   return piece(shape_index).untyped_data();
 }
 
@@ -2132,11 +2131,11 @@ void* Literal::untyped_data(const ShapeIndex& shape_index) {
   return piece(shape_index).untyped_data();
 }
 
-int64 Literal::size_bytes(const ShapeIndex& shape_index) const {
+int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
   return piece(shape_index).size_bytes();
 }
 
-string Literal::GetR1U8AsString() const {
+string LiteralBase::GetR1U8AsString() const {
   CHECK(ShapeUtil::IsArray(shape()));
   CHECK_EQ(ShapeUtil::Rank(shape()), 1);
   CHECK_EQ(shape().element_type(), U8);
@@ -2144,12 +2143,14 @@ string Literal::GetR1U8AsString() const {
                 ShapeUtil::ElementsIn(shape()));
 }
 
-/* static */ const LiteralView LiteralView::Create(
-    const Literal& literal, const ShapeIndex& view_root) {
-  return LiteralView(literal, view_root);
-}
+LiteralSlice::LiteralSlice(const LiteralBase& literal)
+    : LiteralBase(), root_piece_(&literal.root_piece()) {}
+
+LiteralSlice::LiteralSlice(const LiteralBase& literal,
+                           const ShapeIndex& view_root)
+    : LiteralBase(), root_piece_(&literal.piece(view_root)) {}
 
-size_t Literal::Hash() const {
+size_t LiteralBase::Hash() const {
   using tensorflow::Hash64;
   using tensorflow::Hash64Combine;
 
@@ -2170,46 +2171,4 @@ size_t Literal::Hash() const {
   return hash_value;
 }
 
-LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) {
-  shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root);
-  pieces_ = ShapeTree<Piece>(shape_);
-  owns_buffers_ = false;
-  for (auto& pair : pieces_) {
-    const ShapeIndex& index = pair.first;
-    Piece& piece = pair.second;
-
-    ShapeIndex src_index = view_root;
-    for (int64 i : index) {
-      src_index.push_back(i);
-    }
-    const Piece& src_piece = literal.piece(src_index);
-    piece.set_buffer(src_piece.buffer());
-    piece.set_sparse_indices(src_piece.sparse_indices());
-    piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
-  }
-}
-
-LiteralView::~LiteralView() {}
-
-LiteralView::LiteralView(const LiteralView& other) { CopyFrom(other); }
-
-LiteralView& LiteralView::operator=(const LiteralView& other) {
-  CopyFrom(other);
-  return *this;
-}
-
-void LiteralView::CopyFrom(const LiteralView& other) {
-  // We can't use the default copy-constructor/copy-assignment because
-  // Piece::subshape_ points to subshapes within the Shape of the owning
-  // Literal/LiteralView.
-  shape_ = other.shape();
-  pieces_ = other.pieces_;
-  for (auto& pair : pieces_) {
-    const ShapeIndex& index = pair.first;
-    Piece& piece = pair.second;
-    piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
-  }
-  owns_buffers_ = false;
-}
-
 }  // namespace xla
index 290f388..30442af 100644 (file)
@@ -34,7 +34,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/compiler/xla/shape_tree.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/sparse_index_array.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -52,14 +51,491 @@ limitations under the License.
 
 namespace xla {
 
+// Forward declare Literal and LiteralSlice class to be used by the creation
+// methods in the base class.
+class Literal;
+class LiteralSlice;
+
+// Abstract base class for literals.
+class LiteralBase {
+ public:
+  virtual ~LiteralBase() = 0;
+
+  // Literals are equal if they have compatible shapes and the same data
+  // values. Layout is not compared.
+  bool operator==(const LiteralBase& other) const;
+  bool operator!=(const LiteralBase& other) const { return !(*this == other); }
+
+  // Returns the shape of the literal.
+  const Shape& shape() const { return root_piece().subshape(); }
+
+  // Serialize to proto.
+  LiteralProto ToProto() const;
+
+  // Returns an ArraySlice of the array for this literal for the given NativeT
+  // (e.g., float). CHECKs if the subshape of the literal at the given
+  // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type
+  // to native type.
+  template <typename NativeT>
+  tensorflow::gtl::ArraySlice<NativeT> data(
+      const ShapeIndex& shape_index = {}) const;
+
+  // Returns a const pointer to the sparse index array. Returns nullptr if the
+  // literal is not a sparse array.
+  const SparseIndexArray* sparse_indices(
+      const ShapeIndex& shape_index = {}) const;
+
+  // Returns a const pointer to (or size of) the underlying buffer holding the
+  // array at the given shape index. CHECKs if the subshape of the literal at
+  // the given ShapeIndex is not array.
+  const void* untyped_data(const ShapeIndex& shape_index = {}) const;
+  int64 size_bytes(const ShapeIndex& shape_index = {}) const;
+
+  // Returns this literal's data as a string. This literal must be a rank-1 U8
+  // array.
+  string GetR1U8AsString() const;
+
+  // Returns a string representation of the literal value.
+  // Warning: this function can take minutes for multi-million element Literals.
+  string ToString(bool print_layout = false) const;
+
+  // Gets an element in the literal at the given index. The multi_index is
+  // CHECKed against the dimension sizes.
+  template <typename NativeT>
+  NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+              const ShapeIndex& shape_index) const;
+  // Overloads of Get for array literals. CHECKs if the literal is not
+  // array-shaped and dense.
+  template <typename NativeT>
+  NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
+
+  // Returns the element value at index (0, ..., 0), however many zeroes are
+  // required for that index.
+  template <typename NativeT>
+  NativeT GetFirstElement() const;
+
+  // As Get(), but determines the correct type and converts the value
+  // into text.
+  string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+                     const ShapeIndex& shape_index = {}) const;
+  // As GetSparseElement(), but determines the correct type and converts the
+  // value into text.
+  string GetSparseElementAsString(int64 sparse_element_number,
+                                  const ShapeIndex& shape_index = {}) const;
+  // As Get(), but determines the correct type and converts the value into
+  // int64.  This literal must be an array.
+  StatusOr<int64> GetIntegralAsS64(
+      tensorflow::gtl::ArraySlice<int64> multi_index) const;
+
+  // Returns the multi-index of the element in a sparse literal at the given
+  // sparse element number.  The sparse element number is the position with in
+  // the sparse array's list of (index, value) pairs, and is checked against the
+  // total number of (index, value) pairs in the sparse array.
+  tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
+      int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
+
+  // Returns the value of the element in a sparse literal at the given sparse
+  // element number.  The sparse element number is the position with in the
+  // sparse array's list of (index, value) pairs, and is checked against the
+  // total number of (index, value) pairs in the sparse array.
+  template <typename NativeT>
+  NativeT GetSparseElement(int64 sparse_element_number,
+                           const ShapeIndex& shape_index = {}) const;
+
+  // Invokes the "per cell" callback for each element in the provided
+  // literal with the element's indices and a string representation of
+  // the element's value.
+  //
+  // This function is useful if you want a polymorphic representation
+  // of the tensor's elements (turning it to a string for something
+  // like representation in a protobuf).
+  //
+  // This literal must have a dense layout.
+  void EachCellAsString(
+      const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+                               const string& value)>& per_cell) const;
+  template <typename NativeT>
+  void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+                                   NativeT value)>
+                    per_cell) const;
+
+  // Returns whether every element in this literal is equal to value.
+  //
+  // value is an int8 because we expect this to be called with small
+  // compile-time constants (0, -1, etc.) and so that whatever value you pass
+  // can be represented exactly by floating-point types as small as 16 bits.
+  //
+  // If value doesn't fit in this literal's type, returns false.  Values of 1/0
+  // are considered equal to true/false; other values are not considered equal
+  // to true. Also if this literal is not array-shaped false is returned.
+  bool IsAll(int8 value) const;
+
+  // Like IsAll(const Literal&, int8), except we check whether the literal is
+  // equal to a particular floating-point number.
+  //
+  // If the literal is not a floating-point value, this always returns false.
+  //
+  // This casts value to the type of literal, then compares using ==.  The usual
+  // admonishments about floating-point equality checks apply.  We expect you to
+  // use this to check for values that can be expressed precisely as a float,
+  // e.g. -0.5.  Also if this literal is not array-shaped false is returned.
+  bool IsAllFloat(float value) const;
+
+  // Like IsAll(const Literal&, int8), except we check whether the literal is
+  // equal to a particular complex number.
+  //
+  // If the literal is not a complex value, this always returns false.
+  //
+  // This casts value to the type of literal, then compares using ==.  The usual
+  // admonishments about floating-point equality checks apply.  We expect you to
+  // use this to check for complex values that can be expressed precisely as
+  // float pairs e.g. (-0.5, 1.0).
+  //
+  // This literal must have a dense layout.
+  bool IsAllComplex(complex64 value) const;
+
+  // Literal consists entirely of the first element of the literal.
+  bool IsAllFirst() const;
+
+  // Returns whether this literal is zero at the specified index. This literal
+  // must be an array with a dense layout.
+  bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
+
+  // Returns the count of the elements in the array at the given shape index in
+  // this literal.
+  int64 element_count(const ShapeIndex& index = {}) const {
+    return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
+  }
+
+  // Return the count of the elements in the sparse array at the given shape
+  // index in this literal, which will be no larger than
+  // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()).
+  int64 sparse_element_count() const;
+
+  // Compute a hash for this literal.  This literal must not be a sparse tensor
+  // or a tuple containing a sparse tensor.
+  size_t Hash() const;
+
+  // Converts this literal to the given shape. Returns an error is the
+  // conversion is not possible.
+  //
+  // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding
+  // instead of truncation; otherwise, truncation is used.
+  //
+  // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
+  // the default behavior.
+  StatusOr<std::unique_ptr<Literal>> ConvertToShape(
+      const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
+
+  // Converts this literal to another primitive type using a bitcast
+  // conversion. The to and from primitive types must have the same bit
+  // width. Returns an error if the conversion is not possible. This literal
+  // must be array-shaped.
+  StatusOr<std::unique_ptr<Literal>> BitcastConvert(
+      PrimitiveType primitive_dest_type) const;
+
+  // Converts this literal to another primitive type. Returns an error if the
+  // conversion is not possible. This literal must be array-shaped.
+  StatusOr<std::unique_ptr<Literal>> Convert(
+      PrimitiveType primitive_dest_type) const;
+
+  // Returns a literal scalar representing the first element.
+  Literal GetFirstScalarLiteral() const;
+
+  // Clones the underlying buffers into a new Literal, or new
+  // std::unique_ptr<Literal>.
+  Literal Clone() const;
+  std::unique_ptr<Literal> CloneToUnique() const;
+
+  // TODO(b/67651157): The methods below which perform computation on Literals
+  // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
+  // evaluator code which operates on Literals.
+  //
+  // Creates a new value that has the equivalent value as this
+  // literal, but conforms to new_layout; e.g. a literal matrix that was in {0,
+  // 1} minor-to-major dimension layout can be re-layed-out as {1, 0}
+  // minor-to-major dimension layout and the value in the cell at any given
+  // logical index (i0, i1) will be the same.
+  //
+  // For tuple shaped literals, shape_index should be used to select the inner
+  // array that the new layout applies to.
+  //
+  // Note: this is useful when the client wants to ensure that a value placed in
+  // the XLA allocation tracker has a particular layout; for efficiency
+  // purposes or avoiding unimplemented operation/layout combinations.
+  std::unique_ptr<Literal> Relayout(const Layout& new_layout,
+                                    const ShapeIndex& shape_index = {}) const;
+
+  // An overload of Relayout which changes the layout of the entire shape rather
+  // than being limited to a single array within the shape.
+  std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
+
+  // Creates a new literal by reshaping this literal to have the given
+  // dimensions. The total number of elements must not change; The
+  // implementation currently only supports monotonic dim0-major layouts.
+  // This literal must be an array.
+  StatusOr<std::unique_ptr<Literal>> Reshape(
+      tensorflow::gtl::ArraySlice<int64> dimensions) const;
+
+  // Creates a new literal by reordering the dimensions of this literal.
+  // The given `permutation` must be a permutation of the dimension numbers
+  // in the original literal, and it specifies the order of the new dimensions
+  // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
+  // For example, a transpose call on a literal of shape [3 x 8 x 4] and
+  // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
+  // This literal must be an array.
+  std::unique_ptr<Literal> Transpose(
+      tensorflow::gtl::ArraySlice<int64> permutation) const;
+
+  // Creates a sub-array from this literal by extracting the indices
+  // [start_index, limit_index) of each dimension. The result literal has the
+  // same rank and layout as for the given literal. The number of indices in
+  // start_indices and limit_indices must be the rank of the literal, and the
+  // indices follow the order of the dimensions.
+  // This literal must be an array.
+  std::unique_ptr<Literal> Slice(
+      tensorflow::gtl::ArraySlice<int64> start_indices,
+      tensorflow::gtl::ArraySlice<int64> limit_indices) const;
+
+  // Creates a literal with a prepended dimension with bound "times"; e.g. a
+  // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
+  // literal replicated four times.
+  // This literal must be an array.
+  template <typename NativeT>
+  std::unique_ptr<Literal> Replicate(int64 times) const;
+
+  // Creates a new Literal object with the shape specified as parameter.
+  // The content of the literal values is the default value of the primitive
+  // type of literal itself (0 for numeric types, and false for predicates).
+  static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
+
+ protected:
+  // A data structure representing a subshape at a particular ShapeIndex within
+  // the literal. For array-shaped ShapeIndexes, this data structure holds the
+  // pointer to the memory allocated for the array data.
+  class Piece {
+   public:
+    // Returns the buffer holding the array data for this piece as an array
+    // slice. This piece must be array-shaped.
+    template <typename NativeT>
+    tensorflow::gtl::ArraySlice<NativeT> data() const;
+    template <typename NativeT>
+    tensorflow::gtl::MutableArraySlice<NativeT> data();
+
+    // Returns the buffer holding the array data for this piece as a void*. This
+    // piece must be array-shaped.
+    void* untyped_data();
+    const void* untyped_data() const;
+
+    // Gets or sets an element in the array at the given index. The multi_index
+    // is CHECKed against the dimension sizes of the array.  This piece must be
+    // array-shaped.
+    template <typename NativeT>
+    NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
+    template <typename NativeT>
+    void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);
+
+    // Gets/sets the buffer holding the array data.
+    char* buffer() const { return buffer_; }
+    void set_buffer(char* buffer) { buffer_ = buffer; }
+
+    // The array of multi-indices that provide the locations of non-zero
+    // elements in a sparse array.  Only used if
+    // LayoutUtil::IsSparseArray(shape()) is true.
+    SparseIndexArray* sparse_indices() const { return sparse_indices_; }
+    void set_sparse_indices(SparseIndexArray* sparse_indices) {
+      sparse_indices_ = sparse_indices;
+    }
+
+    // Gets or sets the subshape of this piece. This reference points to a
+    // subshape within the shape in the containing Literal (Literal::shape_).
+    const Shape& subshape() const { return *subshape_; }
+    void set_subshape(const Shape* subshape) { subshape_ = subshape; }
+
+    // Returns the size in bytes of the buffer holding the array data.
+    int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
+
+    // Returns the number of elements in this piece's array.
+    int64 element_count() const {
+      // If this is a sparse array, use the number of elements represented by
+      // the indices in the associated SparseIndexArray.
+      return LayoutUtil::IsSparseArray(subshape())
+                 ? sparse_indices()->index_count()
+                 : ShapeUtil::ElementsIn(subshape());
+    }
+
+    // Returns the child piece at 'index' of this piece.
+    Piece& child(int64 index) { return children_[index]; }
+
+    // Adds a child piece to this piece's children.
+    void emplace_back(Piece child_piece) {
+      children_.emplace_back(std::move(child_piece));
+    }
+
+    // Returns the size of children pieces of this piece.
+    int64 children_size() { return children_.size(); }
+
+    // Visitor functions that recursively traverses the piece and calls the
+    // given function at each child piece. The function has the type:
+    //    void (const ShapeIndex& index, const Piece& piece)
+    template <typename Fn>
+    void ForEachSubpiece(const Fn& func) const {
+      ShapeIndex index;
+      return ForEachHelper(
+                 [&func](const ShapeIndex& index, const Piece& piece) {
+                   func(index, piece);
+                   return Status::OK();
+                 },
+                 *this, &index)
+          .IgnoreError();
+    }
+    // Same as above, but the function has the type:
+    //    Status (const ShapeIndex& index, const Piece& piece)
+    // The first non-OK return value is returned by the function.
+    template <typename Fn>
+    Status ForEachSubpieceWithStatus(const Fn& func) const {
+      ShapeIndex index;
+      return ForEachHelper(func, *this, &index);
+    }
+    // Same as above, but the function has the type:
+    //    Bool (const ShapeIndex& index, const Piece& piece)
+    // The first non-true return value is returned by the function.
+    template <typename Fn>
+    bool ForEachSubpieceWithBool(const Fn& func) const {
+      ShapeIndex index;
+      return ForEachHelperBool(func, *this, &index);
+    }
+    // Same as above, but the function has the type:
+    //    Void (const ShapeIndex& index, Piece& piece)
+    template <typename Fn>
+    void ForEachMutableSubpiece(const Fn& func) {
+      ShapeIndex index;
+      return ForEachMutableHelper(
+                 [&func](const ShapeIndex& index, Piece* piece) {
+                   func(index, piece);
+                   return Status::OK();
+                 },
+                 const_cast<xla::LiteralBase::Piece*>(this), &index)
+          .IgnoreError();
+    }
+    // Same as above, but the function has the type:
+    //    Status (const ShapeIndex& index, Piece& piece)
+    // The first non-OK return value is returned by the function.
+    template <typename Fn>
+    Status ForEachMutableSubpieceWithStatus(const Fn& func) {
+      ShapeIndex index;
+      return ForEachMutableHelper(
+          func, const_cast<xla::LiteralBase::Piece*>(this), &index);
+    }
+
+    // Returns true if this piece and 'other' contain the same data. This piece
+    // and 'other' must be array-shaped and compatible.
+    bool EqualElements(const Piece& other) const;
+
+    // Writes the shape and data (if array-shaped) into the given proto.
+    void WriteToProto(LiteralProto* proto) const;
+
+    // Copy the data from 'src' into this piece's buffer. Shapes of this piece
+    // and src must be compatible.
+    Status CopyFrom(const Piece& src);
+
+    // Copies the data from the given proto into this piece. The shape of this
+    // piece must be equal (not just compatible) to the shape of the proto.
+    Status CopyFromProto(const LiteralProto& proto);
+
+    // Sorts the elements in a sparse array.
+    void SortSparseElements();
+
+   private:
+    // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
+    // The first non-OK (or non-true) value is returned by the function.
+    // The callable 'func' has the same signature as described above in
+    // ForEachSubpiece*.
+    template <typename Fn>
+    Status ForEachHelper(const Fn& func, const Piece& piece,
+                         ShapeIndex* index) const {
+      TF_RETURN_IF_ERROR(func(*index, piece));
+      for (int64 i = 0; i < piece.children_.size(); ++i) {
+        index->push_back(i);
+        TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index));
+        index->pop_back();
+      }
+      return Status::OK();
+    }
+    template <typename Fn>
+    bool ForEachHelperBool(const Fn& func, const Piece& piece,
+                           ShapeIndex* index) const {
+      if (!func(*index, piece)) {
+        return false;
+      }
+      for (int64 i = 0; i < piece.children_.size(); ++i) {
+        index->push_back(i);
+        if (!ForEachHelperBool(func, piece.children_[i], index)) {
+          return false;
+        }
+        index->pop_back();
+      }
+      return true;
+    }
+    template <typename Fn>
+    Status ForEachMutableHelper(const Fn& func, Piece* piece,
+                                ShapeIndex* index) {
+      TF_RETURN_IF_ERROR(func(*index, piece));
+      for (int64 i = 0; i < piece->children_.size(); ++i) {
+        index->push_back(i);
+        TF_RETURN_IF_ERROR(
+            ForEachMutableHelper(func, &piece->children_[i], index));
+        index->pop_back();
+      }
+      return Status::OK();
+    }
+
+    // Recursive helper for EqualElements.
+    template <typename NativeT>
+    bool EqualElementsInternal(const Piece& other,
+                               std::vector<int64>* multi_index) const;
+
+    // Helper for SortSparseElements that has the element type as a template
+    // parameter.
+    template <typename NativeT>
+    void SortSparseElementsInternal();
+
+    // For array-shaped pieces, this is the buffer holding the literal data.
+    char* buffer_ = nullptr;
+
+    // For sparse arrays, this is the array of indices.
+    SparseIndexArray* sparse_indices_ = nullptr;
+
+    // The shape of piece. This points into the shape of the containing Literal
+    // (Literal::shape_).
+    const Shape* subshape_ = nullptr;
+
+    // Children pieces for tuple shaped pieces.
+    std::vector<Piece> children_ = {};
+  };  // class Piece
+
+  const Piece& piece(const ShapeIndex& shape_index) const {
+    Piece* piece = &const_cast<Piece&>(root_piece());
+    for (const auto i : shape_index) {
+      DCHECK_GE(i, 0);
+      DCHECK_LT(i, piece->children_size());
+      piece = &piece->child(i);
+    }
+    return *piece;
+  }
+
+  // Returns the piece at the root of the shape.
+  virtual const Piece& root_piece() const = 0;
+
+  // LiteralSlice and Literal must access Pieces of other Literals.
+  friend class LiteralSlice;
+  friend class Literal;
+};
+
 // Class representing literal values in XLA.
 //
-// TODO(b/67651157): The methods in this class should be reduced to a minimal
-// set of methods which construct Literals and accessors methods. Other methods
-// which perform computation on Literals (Reshape, Slice, etc) should be moved
-// elsewhere, and perhaps combined with evaluator code which operates on
-// Literals.
-class Literal {
+// The underlying buffer and shape is always owned by this class.
+class Literal : public LiteralBase {
  public:
   Literal() : Literal(ShapeUtil::MakeNil()) {}
 
@@ -80,46 +556,156 @@ class Literal {
   Literal(const Shape& shape, bool allocate_arrays);
   Literal& operator=(Literal&& other);
 
-  // Literals are equal if they have compatible shapes and the same data
-  // values. Layout is not compared.
-  bool operator==(const Literal& other) const;
-  bool operator!=(const Literal& other) const { return !(*this == other); }
+  // TODO(b/67651157): Remove this accessor. Literal users should not be able to
+  // mutate the shape as this can produce malformed Literals.
+  Shape* mutable_shape_do_not_use() { return shape_.get(); }
 
-  // Serialize to and from a proto.
-  static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
-      const LiteralProto& proto);
-  LiteralProto ToProto() const;
+  // Returns a MutableArraySlice view of the array for this literal for the
+  // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
+  // given ShapeIndex is not array. See primitive_util.h for the mapping from
+  // XLA type to native type.
+  template <typename NativeT>
+  tensorflow::gtl::MutableArraySlice<NativeT> data(
+      const ShapeIndex& shape_index = {});
+  // Unhide const method from parent class.
+  using LiteralBase::data;
+
+  // Returns a pointer to the sparse index array. Returns nullptr if the literal
+  // is not a sparse array.
+  SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
+
+  // Returns a pointer to the underlying buffer holding the array at the given
+  // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
+  // is not array.
+  void* untyped_data(const ShapeIndex& shape_index = {});
+  // Unhide const method from parent class.
+  using LiteralBase::untyped_data;
+
+  // Populates a literal with a sparse layout with the given indices and values.
+  // Each index in the indices array is CHECKed against the dimensions in the
+  // literal's shape.  If sort is true, then the indices and values will be
+  // sorted.  If sort is false, then the indices and values are assumed to
+  // already be in sorted order.  See CreateSparse for an example of how data
+  // are populated.
+  template <typename NativeT>
+  void PopulateSparse(SparseIndexArray indices,
+                      tensorflow::gtl::ArraySlice<NativeT> values,
+                      bool sort = true);
+
+  // Copy values from 'src_literal' rooted at 'src_shape_index' into this
+  // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
+  // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
+  // rooted at 'src_shape_index', but need not be arrays.
+  Status CopyFrom(const LiteralSlice& src_literal,
+                  const ShapeIndex& dest_shape_index = {},
+                  const ShapeIndex& src_shape_index = {});
+
+  // Similar to CopyFrom, but with move semantincs. The subshape of this literal
+  // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
+  // (layouts and shapes must match), but need not be arrays. The memory
+  // allocated in this literal for the subshape at dest_shape_index is
+  // deallocated, and the respective buffers are replaced with those in
+  // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
+  Status MoveFrom(Literal&& src_literal,
+                  const ShapeIndex& dest_shape_index = {});
+
+  // Copies the values from src_literal, starting at src_base shape indexes,
+  // to this literal, starting at dest_base, where the copy size in each
+  // dimension is specified by copy_size.
+  // The src_literal and this literal must have the same primitive type,
+  // src_base+copy_size must fit the source literal dimensions, as well as
+  // dest_base+copy_size must fit the destination literal dimensions.
+  // Note: if either src_literal or this literal contains dimensions with zero
+  // element, then copy_size must be 0 in these dimensions while the
+  // corresponding base indices being 0.
+  // This literal and 'src_literal' must be arrays.
+  Status CopySliceFrom(const LiteralSlice& src_literal,
+                       tensorflow::gtl::ArraySlice<int64> src_base,
+                       tensorflow::gtl::ArraySlice<int64> dest_base,
+                       tensorflow::gtl::ArraySlice<int64> copy_size);
+
+  // Copies one element from src_literal[src_index] to (*this)[dest_index].
+  Status CopyElementFrom(const LiteralSlice& src_literal,
+                         tensorflow::gtl::ArraySlice<int64> src_index,
+                         tensorflow::gtl::ArraySlice<int64> dest_index);
 
-  // Return the shape of the literal.
-  const Shape& shape() const { return shape_; }
+  // Sets an element in the literal at the given index. The multi_index is
+  // CHECKed against the dimension sizes.
+  template <typename NativeT>
+  void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+           const ShapeIndex& shape_index, NativeT value);
+  // Overloads of Set for array literals. CHECKs if the literal is not
+  // array-shaped and dense.
+  template <typename NativeT>
+  void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
+
+  // Appends the given element to the literal.  If the elements are not appended
+  // in sorted order, then SortSparseElements should be called before calling
+  // other methods.  This literal must have a sparse layout.
+  template <typename NativeT>
+  void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
+                           NativeT value, const ShapeIndex& shape_index = {});
+
+  // Sorts the elements in a sparse array.
+  void SortSparseElements(const ShapeIndex& shape_index = {});
+
+  // As Set(), but truncates `value` to the literal element type before storing.
+  // This literal must be an array.
+  Status SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
+                          int64 value);
+
+  // Populate this literal with the given values. Examples:
+  //
+  //   // Populate with floats.
+  //   Array2D<float> float_values = ...
+  //   literal.PopulateR2FromArray2D(values);
+  //
+  //   // Populate with int32s.
+  //   literal.PopulateR2<int32>({{1, 2}, {3, 4}});
+  //
+  // The shape and element type of this literal must match given values. For
+  // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
+  // array of S32.
+  template <typename NativeT>
+  void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
+  void PopulateR1(const tensorflow::core::Bitmap& values);
+  template <typename NativeT>
+  void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
+  template <typename NativeT>
+  void PopulateFromArray(const Array<NativeT>& values);
+  template <typename NativeT>
+  void PopulateR2FromArray2D(const Array2D<NativeT>& values);
+  template <typename NativeT>
+  void PopulateR3FromArray3D(const Array3D<NativeT>& values);
+  template <typename NativeT>
+  void PopulateR4FromArray4D(const Array4D<NativeT>& values);
+
+  // Populates literal values by calling the generator function for every cell
+  // in this literal object.
+  //
+  // generator must be a callable of the type
+  // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
+  //
+  // This literal must have a dense layout.
+  template <typename NativeT, typename FnType>
+  Status Populate(const FnType& generator);
 
-  // TODO(b/67651157): Remove this accessor. Literal users should not be able to
-  // mutate the shape as this can produce malformed Literals.
-  Shape* mutable_shape_do_not_use() { return &shape_; }
+  // A parallel version of Populate(). This can be used if the generator is
+  // thread-safe and the values for the shape's different elements are
+  // independent.
+  template <typename NativeT, typename FnType>
+  Status PopulateParallel(const FnType& generator);
 
-  // Returns a (Mutable)ArraySlice view of the array for this literal for the
-  // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
-  // given ShapeIndex is not array. See primitive_util.h for the mapping from
-  // XLA type to native type.
-  template <typename NativeT>
-  tensorflow::gtl::ArraySlice<NativeT> data(
-      const ShapeIndex& shape_index = {}) const;
+  // Fills this literal with the given value.
   template <typename NativeT>
-  tensorflow::gtl::MutableArraySlice<NativeT> data(
-      const ShapeIndex& shape_index = {});
+  void PopulateWithValue(NativeT value);
 
-  // Returns a pointer to the sparse index array. Returns nullptr if the literal
-  // is not a sparse array.
-  const SparseIndexArray* sparse_indices(
-      const ShapeIndex& shape_index = {}) const;
-  SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
+  // Factory methods below.
+  //
 
-  // Returns a pointer to (or size of) the underlying buffer holding the array
-  // at the given shape index. CHECKs if the subshape of the literal at the
-  // given ShapeIndex is not array.
-  const void* untyped_data(const ShapeIndex& shape_index = {}) const;
-  void* untyped_data(const ShapeIndex& shape_index = {});
-  int64 size_bytes(const ShapeIndex& shape_index = {}) const;
+  // Serialize from a proto.
+  static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
+      const LiteralProto& proto);
 
   // Creates a new literal of a given rank. To minimize ambiguity (for users
   // and the compiler) these CreateR[0-2] methods should explicitly specify the
@@ -167,10 +753,6 @@ class Literal {
           values,
       const Layout& layout);
 
-  // Returns this literal's data as a string. This literal must be a rank-1 U8
-  // array.
-  string GetR1U8AsString() const;
-
   // Creates a literal with a sparse layout and the given indices and values.
   // The shape is initialized from the given dimensions.  The minor dimension of
   // the indices array must equal the rank of the shape (i.e. size of the
@@ -210,171 +792,16 @@ class Literal {
       tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
       tensorflow::gtl::ArraySlice<NativeT> values, bool sort = true);
 
-  // Populates a literal with a sparse layout with the given indices and values.
-  // Each index in the indices array is CHECKed against the dimensions in the
-  // literal's shape.  If sort is true, then the indices and values will be
-  // sorted.  If sort is false, then the indices and values are assumed to
-  // already be in sorted order.  See CreateSparse for an example of how data
-  // are populated.
-  template <typename NativeT>
-  void PopulateSparse(SparseIndexArray indices,
-                      tensorflow::gtl::ArraySlice<NativeT> values,
-                      bool sort = true);
-
-  // Creates a new Literal object with the shape specified as parameter.
-  // The content of the literal values is the default value of the primitive
-  // type of literal itself (0 for numeric types, and false for predicates).
-  static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
-
-  // Creates a new Literal object with its values havings the primitive_type
-  // type, and with dimensions defined by the dimensions parameter.
-  // The content of the literal values is the default value of the primitive
-  // type of literal itself (0 for numeric types, and false for predicates).
-  static std::unique_ptr<Literal> CreateFromDimensions(
-      PrimitiveType primitive_type,
-      tensorflow::gtl::ArraySlice<int64> dimensions);
-
-  // Copy values from 'src_literal' rooted at 'src_shape_index' into this
-  // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
-  // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
-  // rooted at 'src_shape_index', but need not be arrays.
-  Status CopyFrom(const Literal& src_literal,
-                  const ShapeIndex& dest_shape_index = {},
-                  const ShapeIndex& src_shape_index = {});
-
-  // Similar to CopyFrom, but with move semantincs. The subshape of this literal
-  // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
-  // (layouts and shapes must match), but need not be arrays. The memory
-  // allocated in this literal for the subshape at dest_shape_index is
-  // deallocated, and the respective buffers are replaced with those in
-  // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
-  Status MoveFrom(Literal&& src_literal,
-                  const ShapeIndex& dest_shape_index = {});
-
-  // Copies the values from src_literal, starting at src_base shape indexes,
-  // to this literal, starting at dest_base, where the copy size in each
-  // dimension is specified by copy_size.
-  // The src_literal and this literal must have the same primitive type,
-  // src_base+copy_size must fit the source literal dimensions, as well as
-  // dest_base+copy_size must fit the destination literal dimensions.
-  // Note: if either src_literal or this literal contains dimensions with zero
-  // element, then copy_size must be 0 in these dimensions while the
-  // corresponding base indices being 0.
-  // This literal and 'src_literal' must be arrays.
-  Status CopySliceFrom(const Literal& src_literal,
-                       tensorflow::gtl::ArraySlice<int64> src_base,
-                       tensorflow::gtl::ArraySlice<int64> dest_base,
-                       tensorflow::gtl::ArraySlice<int64> copy_size);
-
-  // Copies one element from src_literal[src_index] to (*this)[dest_index].
-  Status CopyElementFrom(const Literal& src_literal,
-                         tensorflow::gtl::ArraySlice<int64> src_index,
-                         tensorflow::gtl::ArraySlice<int64> dest_index);
-
-  // Returns a vector containing the tuple elements of this Literal as separate
-  // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
-  // elements are moved into the new Literals; no data is copied. Upon return
-  // this Literal is set to a nil shape (empty tuple)
-  std::vector<Literal> DecomposeTuple();
-
-  // This operation is the inverse of DecomposeTuple. The given elements are
-  // moved into the tuple elements of a new tuple-shaped Literal which is
-  // returned. Upon return, each of the Literals in 'elements' is set to a nil
-  // shape (empty tuple).
-  static Literal MoveIntoTuple(
-      tensorflow::gtl::MutableArraySlice<Literal> elements);
-
-  // Creates a new value that has the equivalent value as this literal, but
-  // conforms to new_layout; e.g. a literal matrix that was in {0, 1}
-  // minor-to-major dimension layout can be re-layed-out as {1, 0}
-  // minor-to-major dimension layout and the value in the cell at any given
-  // logical index (i0, i1) will be the same.
-  //
-  // For tuple shaped literals, shape_index should be used to select the inner
-  // array that the new layout applies to.
-  //
-  // Note: this is useful when the client wants to ensure that a value placed in
-  // the XLA allocation tracker has a particular layout; for efficiency
-  // purposes or avoiding unimplemented operation/layout combinations.
-  std::unique_ptr<Literal> Relayout(const Layout& new_layout,
-                                    const ShapeIndex& shape_index = {}) const;
-
-  // An overload of Relayout which changes the layout of the entire shape rather
-  // than being limited to a single array within the shape.
-  std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
-
-  // Creates a new literal by reshaping this literal to have the given
-  // dimensions. The total number of elements must not change; The
-  // implementation currently only supports monotonic dim0-major layouts.
-  // This literal must be an array.
-  StatusOr<std::unique_ptr<Literal>> Reshape(
-      tensorflow::gtl::ArraySlice<int64> dimensions) const;
-
-  // Creates a new literal by reordering the dimensions of this literal.
-  // The given `permutation` must be a permutation of the dimension numbers
-  // in the original literal, and it specifies the order of the new dimensions
-  // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
-  // For example, a transpose call on a literal of shape [3 x 8 x 4] and
-  // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
-  // This literal must be an array.
-  std::unique_ptr<Literal> Transpose(
-      tensorflow::gtl::ArraySlice<int64> permutation) const;
-
-  // Creates a sub-array from this literal by extracting the indices
-  // [start_index, limit_index) of each dimension. The result literal has the
-  // same rank and layout as for the given literal. The number of indices in
-  // start_indices and limit_indices must be the rank of the literal, and the
-  // indices follow the order of the dimensions.
-  // This literal must be an array.
-  std::unique_ptr<Literal> Slice(
-      tensorflow::gtl::ArraySlice<int64> start_indices,
-      tensorflow::gtl::ArraySlice<int64> limit_indices) const;
-
-  // Creates a literal with a prepended dimension with bound "times"; e.g. a
-  // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
-  // literal replicated four times.
-  // This literal must be an array.
-  template <typename NativeT>
-  std::unique_ptr<Literal> Replicate(int64 times) const;
-
-  // Converts this literal to another primitive type using
-  // static_cast<>. Returns an error if the conversion is not possible. This
-  // literal must be array-shaped.
-  StatusOr<std::unique_ptr<Literal>> Convert(
-      PrimitiveType primitive_dest_type) const;
-
-  // Converts this literal to another primitive type using a bitcast
-  // conversion. The to and from primitive types must have the same bit
-  // width. Returns an error if the conversion is not possible. This literal
-  // must be array-shaped.
-  StatusOr<std::unique_ptr<Literal>> BitcastConvert(
-      PrimitiveType primitive_dest_type) const;
-
-  // Converts this literal to the given shape. Returns an error is the
-  // conversion is not possible.
-  //
-  // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding
-  // instead of truncation; otherwise, truncation is used.
-  //
-  // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
-  // the default behavior.
-  StatusOr<std::unique_ptr<Literal>> ConvertToShape(
-      const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
-
   // Creates a scalar literal value zero of the given primitive type.
   static Literal Zero(PrimitiveType primitive_type);
-
   // Creates a scalar literal value one of the given primitive type.
   static Literal One(PrimitiveType primitive_type);
-
   // Creates a scalar literal value containing the minimum value of the given
   // primitive type. For floating-point types, returns -inf.
   static Literal MinValue(PrimitiveType primitive_type);
-
   // Creates a scalar literal value containing the maximum value of the given
   // primitive type. For floating-point types, returns inf.
   static Literal MaxValue(PrimitiveType primitive_type);
-
   // Creates a literal of the given shape where each element is `value`.
   template <typename NativeT>
   static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
@@ -419,88 +846,15 @@ class Literal {
   // the z dimension given by "projection".
   template <typename NativeT>
   static std::unique_ptr<Literal> CreateR3Projected(
-      std::initializer_list<std::initializer_list<NativeT>> values,
-      int64 projection);
-
-  // Creates a literal that projects the (x, y) dimensions given in values into
-  // the z and p dimensions given.
-  template <typename NativeT>
-  static std::unique_ptr<Literal> CreateR4Projected(
-      std::initializer_list<std::initializer_list<NativeT>> values,
-      int64 projection_p, int64 projection_z);
-
-  // Clones this literal into a new Literal, or new std::unique_ptr<Literal>.
-  Literal Clone() const;
-  std::unique_ptr<Literal> CloneToUnique() const;
-
-  // Gets or sets an element in the literal at the given index. The multi_index
-  // is CHECKed against the dimension sizes.
-  template <typename NativeT>
-  NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
-              const ShapeIndex& shape_index) const;
-  template <typename NativeT>
-  void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
-           const ShapeIndex& shape_index, NativeT value);
-
-  // Overloads of Get and Set for array literals. CHECKs if the literal is not
-  // array-shaped and dense.
-  template <typename NativeT>
-  NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
-  template <typename NativeT>
-  void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
-
-  // Returns the multi-index of the element in a sparse literal at the given
-  // sparse element number.  The sparse element number is the position with in
-  // the sparse array's list of (index, value) pairs, and is checked against the
-  // total number of (index, value) pairs in the sparse array.
-  tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
-      int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
-
-  // Returns the value of the element in a sparse literal at the given sparse
-  // element number.  The sparse element number is the position with in the
-  // sparse array's list of (index, value) pairs, and is checked against the
-  // total number of (index, value) pairs in the sparse array.
-  template <typename NativeT>
-  NativeT GetSparseElement(int64 sparse_element_number,
-                           const ShapeIndex& shape_index = {}) const;
-
-  // Appends the given element to the literal.  If the elements are not appended
-  // in sorted order, then SortSparseElements should be called before calling
-  // other methods.  This literal must have a sparse layout.
-  template <typename NativeT>
-  void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
-                           NativeT value, const ShapeIndex& shape_index = {});
-
-  // Sorts the elements in a sparse array.
-  void SortSparseElements(const ShapeIndex& shape_index = {});
-
-  // Returns the element value at index (0, ..., 0), however many zeroes are
-  // required for that index.
-  template <typename NativeT>
-  NativeT GetFirstElement() const;
-
-  // Returns a literal scalar representing the first element.
-  Literal GetFirstScalarLiteral() const;
-
-  // As Get(), but determines the correct type and converts the value
-  // into text.
-  string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
-                     const ShapeIndex& shape_index = {}) const;
-
-  // As GetSparseElement(), but determines the correct type and converts the
-  // value into text.
-  string GetSparseElementAsString(int64 sparse_element_number,
-                                  const ShapeIndex& shape_index = {}) const;
-
-  // As Get(), but determines the correct type and converts the value into
-  // int64.  This literal must be an array.
-  StatusOr<int64> GetIntegralAsS64(
-      tensorflow::gtl::ArraySlice<int64> multi_index) const;
+      std::initializer_list<std::initializer_list<NativeT>> values,
+      int64 projection);
 
-  // As Set(), but truncates `value` to the literal element type before storing.
-  // This literal must be an array.
-  Status SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
-                          int64 value);
+  // Creates a literal that projects the (x, y) dimensions given in values into
+  // the z and p dimensions given.
+  template <typename NativeT>
+  static std::unique_ptr<Literal> CreateR4Projected(
+      std::initializer_list<std::initializer_list<NativeT>> values,
+      int64 projection_p, int64 projection_z);
 
   // Returns an identity matrix (rank 2) with the given row and column count.
   template <typename NativeT>
@@ -511,6 +865,9 @@ class Literal {
   static std::unique_ptr<Literal> MakeTuple(
       tensorflow::gtl::ArraySlice<const Literal*> elements);
 
+  static std::unique_ptr<Literal> MakeTupleFromSlices(
+      tensorflow::gtl::ArraySlice<LiteralSlice> elements);
+
   // As above, but intended to be invoked with move semantics; i.e.
   //
   //  std::vector<std::unique_ptr<Literal>> elements = ...;
@@ -542,135 +899,48 @@ class Literal {
     return MakeTupleOwned(std::move(v));
   }
 
-  // Returns a string representation of the literal value.
-  // Warning: this function can take minutes for multi-million element Literals.
-  string ToString(bool print_layout = false) const;
-
-  // Invokes the "per cell" callback for each element in the provided
-  // literal with the element's indices and a string representation of
-  // the element's value.
-  //
-  // This function is useful if you want a polymorphic representation
-  // of the tensor's elements (turning it to a string for something
-  // like representation in a protobuf).
-  //
-  // This literal must have a dense layout.
-  void EachCellAsString(
-      const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
-                               const string& value)>& per_cell) const;
-  template <typename NativeT>
-  void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
-                                   NativeT value)>
-                    per_cell) const;
-
-  // Populate this literal with the given values. Examples:
-  //
-  //   // Populate with floats.
-  //   Array2D<float> float_values = ...
-  //   literal.PopulateR2FromArray2D(values);
-  //
-  //   // Populate with int32s.
-  //   literal.PopulateR2<int32>({{1, 2}, {3, 4}});
-  //
-  // The shape and element type of this literal must match given values. For
-  // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
-  // array of S32.
-  template <typename NativeT>
-  void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
-  void PopulateR1(const tensorflow::core::Bitmap& values);
-  template <typename NativeT>
-  void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
-  template <typename NativeT>
-  void PopulateFromArray(const Array<NativeT>& values);
-  template <typename NativeT>
-  void PopulateR2FromArray2D(const Array2D<NativeT>& values);
-  template <typename NativeT>
-  void PopulateR3FromArray3D(const Array3D<NativeT>& values);
-  template <typename NativeT>
-  void PopulateR4FromArray4D(const Array4D<NativeT>& values);
-
-  // Populates literal values by calling the generator function for every cell
-  // in this literal object.
-  //
-  // generator must be a callable of the type
-  // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
-  //
-  // This literal must have a dense layout.
-  template <typename NativeT, typename FnType>
-  Status Populate(const FnType& generator);
-
-  // A parallel version of Populate(). This can be used if the generator is
-  // thread-safe and the values for the shape's different elements are
-  // independent.
-  template <typename NativeT, typename FnType>
-  Status PopulateParallel(const FnType& generator);
-
-  // Fills this literal with the given value.
-  template <typename NativeT>
-  void PopulateWithValue(NativeT value);
+  // Returns a vector containing the tuple elements of this Literal as separate
+  // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
+  // elements are moved into the new Literals; no data is copied. Upon return
+  // this Literal is set to a nil shape (empty tuple)
+  std::vector<Literal> DecomposeTuple();
 
-  // Returns whether every element in this literal is equal to value.
-  //
-  // value is an int8 because we expect this to be called with small
-  // compile-time constants (0, -1, etc.) and so that whatever value you pass
-  // can be represented exactly by floating-point types as small as 16 bits.
-  //
-  // If value doesn't fit in this literal's type, returns false.  Values of 1/0
-  // are considered equal to true/false; other values are not considered equal
-  // to true. Also if this literal is not array-shaped false is returned.
-  bool IsAll(int8 value) const;
+  // This operation is the inverse of DecomposeTuple. The given elements are
+  // moved into the tuple elements of a new tuple-shaped Literal which is
+  // returned. Upon return, each of the Literals in 'elements' is set to a nil
+  // shape (empty tuple).
+  static Literal MoveIntoTuple(
+      tensorflow::gtl::MutableArraySlice<Literal> elements);
 
-  // Like IsAll(const Literal&, int8), except we check whether the literal is
-  // equal to a particular floating-point number.
-  //
-  // If the literal is not a floating-point value, this always returns false.
-  //
-  // This casts value to the type of literal, then compares using ==.  The usual
-  // admonishments about floating-point equality checks apply.  We expect you to
-  // use this to check for values that can be expressed precisely as a float,
-  // e.g. -0.5.  Also if this literal is not array-shaped false is returned.
-  bool IsAllFloat(float value) const;
+  // Creates a new Literal object with its values havings the primitive_type
+  // type, and with dimensions defined by the dimensions parameter.
+  // The content of the literal values is the default value of the primitive
+  // type of literal itself (0 for numeric types, and false for predicates).
+  static std::unique_ptr<Literal> CreateFromDimensions(
+      PrimitiveType primitive_type,
+      tensorflow::gtl::ArraySlice<int64> dimensions);
 
-  // Like IsAll(const Literal&, int8), except we check whether the literal is
-  // equal to a particular complex number.
-  //
-  // If the literal is not a complex value, this always returns false.
-  //
-  // This casts value to the type of literal, then compares using ==.  The usual
-  // admonishments about floating-point equality checks apply.  We expect you to
-  // use this to check for complex values that can be expressed precisely as
-  // float pairs e.g. (-0.5, 1.0).
   //
-  // This literal must have a dense layout.
-  bool IsAllComplex(complex64 value) const;
+  // End of factory methods.
 
-  // Literal consists entirely of the first element of the literal.
-  bool IsAllFirst() const;
-
-  // Returns whether this literal is zero at the specified index. This literal
-  // must be an array with a dense layout.
-  bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
+ protected:
+  // Recursively sets the subshapes and buffers of all subpieces rooted at
+  // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
+  // the shape.
+  void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
 
-  // Return the count of the elements in the array at the given shape index in
-  // this literal.
-  int64 element_count(const ShapeIndex& index = {}) const {
-    return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
+  // Returns the piece at the given ShapeIndex.
+  Piece& piece(const ShapeIndex& shape_index) {
+    return const_cast<Piece&>(LiteralBase::piece(shape_index));
   }
 
-  // Return the count of the elements in the sparse array at the given shape
-  // index in this literal, which will be no larger than
-  // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()).
-  int64 sparse_element_count() const;
-
-  // Compute a hash for this literal.  This literal must not be a sparse tensor
-  // or a tuple containing a sparse tensor.
-  size_t Hash() const;
+  Piece& root_piece() const override { return *root_piece_; };
 
- protected:
+ private:
   // Internal template helper for the Literal::CopySliceFrom(), matching its
   // arguments one by one.
   template <typename NativeT>
-  Status CopySliceFromInternal(const Literal& src_literal,
+  Status CopySliceFromInternal(const LiteralBase& src_literal,
                                tensorflow::gtl::ArraySlice<int64> src_base,
                                tensorflow::gtl::ArraySlice<int64> dest_base,
                                tensorflow::gtl::ArraySlice<int64> copy_size);
@@ -698,162 +968,40 @@ class Literal {
     int64 minor_loop_size = 1;
   };
 
-  // A data structure representing a subshape at a particular ShapeIndex within
-  // the literal. For array-shaped ShapeIndexes, this data structure holds the
-  // pointer to the memory allocated for the array data.
-  class Piece {
-   public:
-    // Return the buffer holding the array data for this piece as an array
-    // slice. This piece must be array-shaped.
-    template <typename NativeT>
-    tensorflow::gtl::ArraySlice<NativeT> data() const;
-    template <typename NativeT>
-    tensorflow::gtl::MutableArraySlice<NativeT> data();
-
-    // Return the buffer holding the array data for this piece as a void*. This
-    // piece must be array-shaped.
-    void* untyped_data();
-    const void* untyped_data() const;
-
-    // Gets or sets an element in the array at the given index. The multi_index
-    // is CHECKed against the dimension sizes of the array.  This piece must be
-    // array-shaped.
-    template <typename NativeT>
-    NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
-    template <typename NativeT>
-    void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);
-
-    // Gets/sets the buffer holding the array data.
-    char* buffer() const { return buffer_; }
-    void set_buffer(char* buffer) { buffer_ = buffer; }
-
-    // The array of multi-indices that provide the locations of non-zero
-    // elements in a sparse array.  Only used if
-    // LayoutUtil::IsSparseArray(shape()) is true.
-    SparseIndexArray* sparse_indices() const { return sparse_indices_; }
-    void set_sparse_indices(SparseIndexArray* sparse_indices) {
-      sparse_indices_ = sparse_indices;
-    }
-
-    // Gets or sets the subshape of this piece. This reference points to a
-    // subshape within the shape in the containing Literal (Literal::shape_).
-    const Shape& subshape() const { return *subshape_; }
-    void set_subshape(const Shape* subshape) { subshape_ = subshape; }
-
-    // Returns the size in bytes of the buffer holding the array data.
-    int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
-
-    // Returns the number of elements in this piece's array.
-    int64 element_count() const {
-      // If this is a sparse array, use the number of elements represented by
-      // the indices in the associated SparseIndexArray.
-      return LayoutUtil::IsSparseArray(subshape())
-                 ? sparse_indices()->index_count()
-                 : ShapeUtil::ElementsIn(subshape());
-    }
-
-    // Copy the data from 'src' into this piece's buffer. Shapes of this piece
-    // and src must be compatible.
-    Status CopyFrom(const Piece& src);
-
-    // Returns true if this piece and 'other' contain the same data. This piece
-    // and 'other' must be array-shaped and compatible.
-    bool EqualElements(const Piece& other) const;
-
-    // Writes the shape and data (if array-shaped) into the given proto.
-    void WriteToProto(LiteralProto* proto) const;
-
-    // Copies the data from the given proto into this piece. The shape of this
-    // piece must be equal (not just compatible) to the shape of the proto.
-    Status CopyFromProto(const LiteralProto& proto);
-
-    // Sorts the elements in a sparse array.
-    void SortSparseElements();
-
-   private:
-    // Recursive helper for EqualElements.
-    template <typename NativeT>
-    bool EqualElementsInternal(const Piece& other,
-                               std::vector<int64>* multi_index) const;
-
-    // Helper for SortSparseElements that has the element type as a template
-    // parameter.
-    template <typename NativeT>
-    void SortSparseElementsInternal();
-
-    // For array-shaped pieces, this is the buffer holding the literal data.
-    char* buffer_ = nullptr;
-
-    // For sparse arrays, this is the array of indices.
-    SparseIndexArray* sparse_indices_ = nullptr;
-
-    // The shape of piece. This points into the shape of the containing Literal
-    // (Literal::shape_).
-    const Shape* subshape_ = nullptr;
-  };
-
-  // Returns the piece at the given ShapeIndex.
-  Piece& piece(const ShapeIndex& shape_index) {
-    return *pieces_.mutable_element(shape_index);
-  }
-  const Piece& piece(const ShapeIndex& shape_index) const {
-    return pieces_.element(shape_index);
-  }
-
-  // Returns the piece at the root of the shape (empty ShapeIndex).
-  Piece& root_piece() { return piece({}); }
-  const Piece& root_piece() const { return piece({}); }
+  // Literal class always owns the shape. The parent class borrows this shape.
+  std::unique_ptr<Shape> shape_;
 
-  // Deallocate the buffers held by this literal (if the literal owns the
-  // buffer).
-  void DeallocateBuffers();
+  Piece* root_piece_ = nullptr;
 
   // Implementation details shared between Populate() and PopulateParallel()
   template <typename NativeT, typename FnType>
   Status PopulateInternal(const FnType& generator, bool parallel);
 
-  Shape shape_;
-  ShapeTree<Piece> pieces_;
-
-  // Whether the buffers held in pieces_ are owned by this Literal.
-  bool owns_buffers_;
+  // Deallocate the buffers held by this literal.
+  void DeallocateBuffers();
 
-  // LiteralView must access and manipulate Pieces of other Literals.
-  friend class LiteralView;
-};  // namespace xla
+  friend class LiteralBase;
+};
 
 std::ostream& operator<<(std::ostream& out, const Literal& literal);
 
-// A read-only view of a Literal. A LiteralView contains pointers to buffers
-// owned by the viewed Literal.
-//
-// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable
-// and mutable) similar to (Mutable)ArraySlice.
-class LiteralView : public Literal {
+// A read-only view of a Literal. A LiteralSlice contains pointers to shape and
+// literal buffers always owned by others.
+class LiteralSlice : public LiteralBase {
  public:
-  // Create and return a view of the given literal rooted at the given shape
-  // index within the given literal. A factory is used rather than a public
-  // constructor because only const LiteralViews are supported. It's still
-  // possible to create non-const LiteralViews via the copy constructors, but
-  // the factory method makes it a bit less likely. Implementing literal slices
-  // will fix this undesirable situation (b/71550060).
-  static const LiteralView Create(const Literal& literal,
-                                  const ShapeIndex& view_root = {});
-
-  LiteralView(const LiteralView& other);
-  LiteralView& operator=(const LiteralView& other);
-
-  virtual ~LiteralView();
+  LiteralSlice() : LiteralBase() {}
+  // Implicit conversion constructor that can also accept Literal.
+  LiteralSlice(const LiteralBase& literal);
+  LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root);
 
  private:
-  LiteralView(const Literal& literal, const ShapeIndex& view_root);
+  const Piece& root_piece() const override { return *root_piece_; };
 
-  // Helper for the copy constructor and copy assignment operator.
-  void CopyFrom(const LiteralView& other);
+  const Piece* root_piece_;  // Not owned.
 };
 
 template <typename NativeT>
-tensorflow::gtl::ArraySlice<NativeT> Literal::Piece::data() const {
+tensorflow::gtl::ArraySlice<NativeT> LiteralBase::Piece::data() const {
   CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
   CHECK_EQ(subshape().element_type(),
            primitive_util::NativeToPrimitiveType<NativeT>())
@@ -866,7 +1014,7 @@ tensorflow::gtl::ArraySlice<NativeT> Literal::Piece::data() const {
 }
 
 template <typename NativeT>
-tensorflow::gtl::MutableArraySlice<NativeT> Literal::Piece::data() {
+tensorflow::gtl::MutableArraySlice<NativeT> LiteralBase::Piece::data() {
   CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
   CHECK_EQ(subshape().element_type(),
            primitive_util::NativeToPrimitiveType<NativeT>())
@@ -879,7 +1027,7 @@ tensorflow::gtl::MutableArraySlice<NativeT> Literal::Piece::data() {
 }
 
 template <typename NativeT>
-NativeT Literal::Piece::Get(
+NativeT LiteralBase::Piece::Get(
     tensorflow::gtl::ArraySlice<int64> multi_index) const {
   CHECK(LayoutUtil::IsDenseArray(subshape()));
   return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
@@ -887,15 +1035,15 @@ NativeT Literal::Piece::Get(
 }
 
 template <typename NativeT>
-void Literal::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
-                         NativeT value) {
+void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+                             NativeT value) {
   CHECK(LayoutUtil::IsDenseArray(subshape()));
   data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
       subshape(), multi_index)] = value;
 }
 
 template <typename NativeT>
-tensorflow::gtl::ArraySlice<NativeT> Literal::data(
+tensorflow::gtl::ArraySlice<NativeT> LiteralBase::data(
     const ShapeIndex& shape_index) const {
   return piece(shape_index).data<NativeT>();
 }
@@ -907,13 +1055,13 @@ tensorflow::gtl::MutableArraySlice<NativeT> Literal::data(
 }
 
 template <typename NativeT>
-inline NativeT Literal::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
-                            const ShapeIndex& shape_index) const {
+inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+                                const ShapeIndex& shape_index) const {
   return piece(shape_index).Get<NativeT>(multi_index);
 }
 
 template <typename NativeT>
-inline NativeT Literal::Get(
+inline NativeT LiteralBase::Get(
     tensorflow::gtl::ArraySlice<int64> multi_index) const {
   return root_piece().Get<NativeT>(multi_index);
 }
@@ -1160,13 +1308,13 @@ template <typename NativeT>
 }
 
 template <typename NativeT>
-NativeT Literal::GetFirstElement() const {
+NativeT LiteralBase::GetFirstElement() const {
   return data<NativeT>().at(0);
 }
 
 template <typename NativeT>
-NativeT Literal::GetSparseElement(int64 sparse_element_number,
-                                  const ShapeIndex& shape_index) const {
+NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
+                                      const ShapeIndex& shape_index) const {
   CHECK(
       LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index)));
   return data<NativeT>(shape_index)[sparse_element_number];
@@ -1199,7 +1347,7 @@ template <typename NativeT>
 }
 
 template <typename NativeT>
-void Literal::EachCell(
+void LiteralBase::EachCell(
     std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
                        NativeT value)>
         per_cell) const {
@@ -1375,7 +1523,7 @@ template <typename NativeT>
 }
 
 template <typename NativeT>
-std::unique_ptr<Literal> Literal::Replicate(int64 times) const {
+std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
   DimensionVector bounds = {times};
   bounds.reserve(shape().dimensions_size() + 1);
   for (int64 bound : shape().dimensions()) {
index 6104678..087d509 100644 (file)
@@ -974,7 +974,7 @@ TEST_F(LiteralUtilTest, CopyFromTuples) {
                                    Literal::CreateR1<double>({2.0, 4.0}).get(),
                                    &nil_literal});
 
-  EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0}));
+  EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
   EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), 42);
   EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 23.0);
   EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 44.0);
@@ -985,7 +985,7 @@ TEST_F(LiteralUtilTest, CopyFromTuples) {
                                       /*src_shape_index=*/{}));
 
   // The matrix element should be unchanged.
-  EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0}));
+  EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
 
   // The tuple element should have been copied from 'tuple'.
   EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), -5);
@@ -1373,36 +1373,36 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
   ASSERT_EQ(h1, r[3]);
 }
 
-TEST_F(LiteralUtilTest, LiteralViewTest) {
+TEST_F(LiteralUtilTest, LiteralSliceTest) {
   auto scalar = Literal::CreateR0<float>(1.0);
   auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
   auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
   Literal nil(ShapeUtil::MakeNil());
 
-  EXPECT_EQ(LiteralView::Create(*scalar, {}), *scalar);
-  EXPECT_EQ(LiteralView::Create(*matrix, {}), *matrix);
-  EXPECT_EQ(LiteralView::Create(*tuple, {}), *tuple);
-  EXPECT_EQ(LiteralView::Create(*nested_tuple, {}), *nested_tuple);
-  EXPECT_EQ(LiteralView::Create(nil, {}), nil);
+  EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar);
+  EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix);
+  EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple);
+  EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple);
+  EXPECT_EQ(LiteralSlice(nil, {}), nil);
 
-  EXPECT_EQ(LiteralView::Create(*tuple, {0}), *scalar);
-  EXPECT_EQ(LiteralView::Create(*tuple, {1}), *matrix);
+  EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar);
+  EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix);
 
-  EXPECT_EQ(LiteralView::Create(*nested_tuple, {0}), *tuple);
-  EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 0}), *scalar);
-  EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 1}), *matrix);
-  EXPECT_EQ(LiteralView::Create(*nested_tuple, {1}), *scalar);
+  EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple);
+  EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar);
+  EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix);
+  EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar);
 }
 
-TEST_F(LiteralUtilTest, MutatingLiteralView) {
+TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
   auto scalar = Literal::CreateR0<float>(1.0);
   auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
   auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
   // Verify that changing the underlying data beneath the view changes the
   // data of the view itself.
-  const auto nested_tuple_view = LiteralView::Create(*nested_tuple);
+  const auto nested_tuple_view = LiteralSlice(*nested_tuple);
   EXPECT_EQ(
       nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
       1.0f);
@@ -1418,16 +1418,15 @@ TEST_F(LiteralUtilTest, MutatingLiteralView) {
             555.0f);
 }
 
-TEST_F(LiteralUtilTest, LiteralViewOfALiteralView) {
+TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
   auto scalar = Literal::CreateR0<float>(1.0);
   auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
   auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
 
-  const auto nested_tuple_view = LiteralView::Create(*nested_tuple);
-  const auto tuple_view =
-      LiteralView::Create(nested_tuple_view, /*view_root=*/{0});
-  const auto matrix_view = LiteralView::Create(tuple_view, /*view_root=*/{1});
+  const auto nested_tuple_view = LiteralSlice(*nested_tuple);
+  const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0});
+  const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1});
   EXPECT_EQ(matrix_view, *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
 }
 
@@ -1533,11 +1532,11 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
   EXPECT_EQ(literal.Get<float>({1, 1}), 4.0);
 }
 
-TEST_F(LiteralUtilTest, LiteralViewCopy) {
+TEST_F(LiteralUtilTest, LiteralSliceCopy) {
   std::unique_ptr<Literal> matrix =
       Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
-  const auto matrix_view = LiteralView::Create(*matrix);
-  LiteralView matrix_view_copy(matrix_view);
+  const auto matrix_view = LiteralSlice(*matrix);
+  LiteralSlice matrix_view_copy(matrix_view);
 
   EXPECT_EQ(matrix_view_copy.Get<float>({0, 0}), 1.0);
   EXPECT_EQ(matrix_view_copy.Get<float>({0, 1}), 2.0);
index dc6f5fe..68648a3 100644 (file)
@@ -340,13 +340,13 @@ StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o) {
   return result;
 }
 
-PyObject* PyObjectFromXlaLiteral(const Literal& literal) {
+PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
   if (ShapeUtil::IsTuple(literal.shape())) {
     int num_elements = ShapeUtil::TupleElementCount(literal.shape());
     PyObject* tuple = PyTuple_New(num_elements);
     for (int i = 0; i < num_elements; i++) {
-      PyTuple_SET_ITEM(
-          tuple, i, PyObjectFromXlaLiteral(LiteralView::Create(literal, {i})));
+      PyTuple_SET_ITEM(tuple, i,
+                       PyObjectFromXlaLiteral(LiteralSlice(literal, {i})));
     }
     return tuple;
   } else {
@@ -431,7 +431,7 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
   return Status::OK();
 }
 
-void CopyLiteralToNumpyArray(int np_type, const Literal& literal,
+void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
                              PyArrayObject* py_array) {
   switch (np_type) {
     case NPY_BOOL:
index 9656cb1..64f0aae 100644 (file)
@@ -74,7 +74,7 @@ StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o);
 // array data.
 //
 // The return value is a new reference.
-PyObject* PyObjectFromXlaLiteral(const Literal& literal);
+PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal);
 
 // Converts a Numpy ndarray or a nested Python tuple thereof to a
 // corresponding XLA literal.
@@ -90,7 +90,7 @@ StatusOr<std::unique_ptr<Literal> > XlaLiteralFromPyObject(PyObject* o);
 Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
                                Literal* literal);
 
-void CopyLiteralToNumpyArray(int np_type, const Literal& literal,
+void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
                              PyArrayObject* py_array);
 
 template <typename NativeT>
@@ -101,7 +101,8 @@ void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) {
 }
 
 template <typename NativeT>
-void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) {
+void CopyLiteralToNumpyArray(const LiteralSlice& literal,
+                             PyArrayObject* py_array) {
   NativeT* dest = static_cast<NativeT*>(PyArray_DATA(py_array));
   auto source = literal.data<NativeT>();
   std::copy(source.begin(), source.end(), dest);
index 4ec79a0..3ce80bb 100644 (file)
@@ -501,13 +501,13 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate(
 }
 
 static HloInstruction* BuildTupleConstant(HloComputation* computation,
-                                          const Literal& literal) {
+                                          const LiteralSlice& literal) {
   if (ShapeUtil::IsTuple(literal.shape())) {
     std::vector<HloInstruction*> elems;
     elems.reserve(ShapeUtil::TupleElementCount(literal.shape()));
     for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) {
       elems.push_back(
-          BuildTupleConstant(computation, LiteralView::Create(literal, {i})));
+          BuildTupleConstant(computation, LiteralSlice(literal, {i})));
     }
     return computation->AddInstruction(HloInstruction::CreateTuple(elems));
   } else {
index 9b39e7f..d97802e 100644 (file)
@@ -88,8 +88,8 @@ CpuTransferManager::CpuTransferManager()
     : GenericTransferManager(se::host::kHostPlatformId,
                              /*pointer_size=*/sizeof(void*)) {}
 
-Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor,
-                                                   const Literal& literal) {
+Status CpuTransferManager::TransferLiteralToInfeed(
+    se::StreamExecutor* executor, const LiteralSlice& literal) {
   const Shape& shape = literal.shape();
   VLOG(2) << "Transferring literal to infeed with shape: "
           << ShapeUtil::HumanString(shape);
index 3ecb0d2..6dfc666 100644 (file)
@@ -38,7 +38,7 @@ class CpuTransferManager : public GenericTransferManager {
   ~CpuTransferManager() override {}
 
   Status TransferLiteralToInfeed(se::StreamExecutor* executor,
-                                 const Literal& literal) override;
+                                 const LiteralSlice& literal) override;
   Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
                                 const void* source) override;
   Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
index 7dcc4ca..c562865 100644 (file)
@@ -26,13 +26,13 @@ limitations under the License.
 
 namespace xla {
 namespace cpu {
-void ExternalConstantPool::Insert(string name, const Literal& literal,
+void ExternalConstantPool::Insert(string name, const LiteralSlice& literal,
                                   int64 alignment) {
   CHECK(!ShapeUtil::IsTuple(literal.shape()));
   CHECK(alignment > 0 && IsPowerOfTwo(static_cast<uint64>(alignment)));
   CHECK(entries_.find(name) == entries_.end());
 
-  int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape());
+  const int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape());
   void* raw_pointer = tensorflow::port::AlignedMalloc(
       literal_size, std::max<size_t>(alignment, sizeof(void*)));
   CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size
index 8008a56..0677f5f 100644 (file)
@@ -43,7 +43,7 @@ class ExternalConstantPool {
   // The constant pool copies out the contents of `literal` into a buffer it
   // owns -- it does not keep pointers to `literal`, or to memory owned by
   // `literal`.
-  void Insert(string name, const Literal& literal, int64 alignment);
+  void Insert(string name, const LiteralSlice& literal, int64 alignment);
 
   // Find the constant with name `name` in this constant pool.  If there isn't
   // such constant, return nullptr.
index ddb6873..dbf1ab6 100644 (file)
@@ -115,7 +115,7 @@ Status GenericTransferManager::TransferLiteralToDevice(
           TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
                        device_memory.size());
           // Element is array-shaped: transfer array data to device buffer.
-          const auto subliteral = LiteralView::Create(literal, index);
+          const auto subliteral = LiteralSlice(literal, index);
           std::unique_ptr<Literal> relayed_out_literal;
           const void* source;
           if (LayoutUtil::Equal(device_subshape.layout(),
@@ -137,7 +137,7 @@ Status GenericTransferManager::TransferLiteralToDevice(
 }
 
 Status GenericTransferManager::TransferLiteralToInfeed(
-    se::StreamExecutor* executor, const Literal& literal) {
+    se::StreamExecutor* executor, const LiteralSlice& literal) {
   return Unimplemented("Generic transfer to Infeed");
 }
 
index 0579099..3343eca 100644 (file)
@@ -49,7 +49,7 @@ class GenericTransferManager : public TransferManager {
                                  const ShapedBuffer& device_buffer) override;
 
   Status TransferLiteralToInfeed(se::StreamExecutor* executor,
-                                 const Literal& literal) override;
+                                 const LiteralSlice& literal) override;
   Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
                                     const Shape& literal_shape,
                                     Literal* literal) override;
index f13727c..7bb8df6 100644 (file)
@@ -44,8 +44,8 @@ GpuTransferManager::GpuTransferManager()
           /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout)
               .getPointerSize(0 /* default address space */)) {}
 
-Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor,
-                                                   const Literal& literal) {
+Status GpuTransferManager::TransferLiteralToInfeed(
+    se::StreamExecutor* executor, const LiteralSlice& literal) {
   const Shape& shape = literal.shape();
   VLOG(2) << "Transferring literal to infeed with shape: "
           << ShapeUtil::HumanString(shape);
index d040a99..09f8227 100644 (file)
@@ -37,7 +37,7 @@ class GpuTransferManager : public GenericTransferManager {
   ~GpuTransferManager() override {}
 
   Status TransferLiteralToInfeed(se::StreamExecutor* executor,
-                                 const Literal& literal) override;
+                                 const LiteralSlice& literal) override;
   Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
                                 const void* source) override;
 
index fffe192..63eaf6f 100644 (file)
@@ -56,8 +56,8 @@ using tensorflow::gtl::FlatSet;
 
 template <typename OperandT>
 StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
-                                           const Literal& lhs_literal,
-                                           const Literal& rhs_literal) {
+                                           LiteralSlice lhs_literal,
+                                           LiteralSlice rhs_literal) {
   std::function<bool(OperandT, OperandT)> compare_op;
   switch (opcode) {
     case HloOpcode::kEq:
@@ -106,8 +106,8 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
 
 template <>
 StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
-    const Shape& shape, HloOpcode opcode, const Literal& lhs_literal,
-    const Literal& rhs_literal) {
+    const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal,
+    LiteralSlice rhs_literal) {
   std::function<bool(complex64, complex64)> compare_op;
   switch (opcode) {
     case HloOpcode::kEq:
index d82b4f0..55c544f 100644 (file)
@@ -81,7 +81,7 @@ class TransferManager {
   // Transfers the given literal into the Infeed interface of the device,
   // using the given executor.
   virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor,
-                                         const Literal& literal) = 0;
+                                         const LiteralSlice& literal) = 0;
 
   // Transfers the given literal from the Outfeed interface of the device,
   // using the given executor.
index 6ebbf71..a180cdd 100644 (file)
@@ -87,11 +87,11 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
 
   LiteralTestUtil::ExpectNear(
       *Literal::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
-      LiteralView::Create(*result, {0}), error_spec_);
+      LiteralSlice(*result, {0}), error_spec_);
 
   LiteralTestUtil::ExpectNear(
       *Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
-      LiteralView::Create(*result, {1}), error_spec_);
+      LiteralSlice(*result, {1}), error_spec_);
 }
 
 XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
index 0b425b9..abf7312 100644 (file)
@@ -91,9 +91,9 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
       auto result,
       client_->ExecuteAndTransfer(computation, {}, &execution_options));
   LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
-                                        LiteralView::Create(*result, {0}));
+                                        LiteralSlice(*result, {0}));
   LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
-                                        LiteralView::Create(*result, {1}));
+                                        LiteralSlice(*result, {1}));
 
   EXPECT_TRUE(ShapeUtil::IsTuple(result->shape()));
   EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape()));
index 4743673..d518e4a 100644 (file)
@@ -169,9 +169,9 @@ TEST_F(ConstantsTest, DISABLED_TupleConstant) {
       ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
 
   LiteralTestUtil::ExpectR2Near<float>(
-      {{1.0}, {2.0}}, LiteralView::Create(*result, {0}), error_spec_);
+      {{1.0}, {2.0}}, LiteralSlice(*result, {0}), error_spec_);
   LiteralTestUtil::ExpectR1Near<float>(
-      {2.0, 42.0}, LiteralView::Create(*result, {1}), error_spec_);
+      {2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_);
 }
 
 }  // namespace
index c28f79a..868876c 100644 (file)
@@ -111,7 +111,7 @@ namespace {
 // Return a literal with all arrays of type FromNativeT converted to type
 // ToNativeT in the given literal.
 template <typename FromNativeT, typename ToNativeT>
-std::unique_ptr<Literal> ConvertType(const Literal& literal) {
+std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
   // First construct shape of the result.
   Shape result_shape(literal.shape());
   ShapeUtil::ForEachMutableSubshape(
@@ -150,12 +150,12 @@ std::unique_ptr<Literal> ConvertType(const Literal& literal) {
 }  // namespace
 
 /* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertBF16ToF32(
-    const Literal& literal) {
+    LiteralSlice literal) {
   return ConvertType<bfloat16, float>(literal);
 }
 
 /* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertF32ToBF16(
-    const Literal& literal) {
+    LiteralSlice literal) {
   return ConvertType<float, bfloat16>(literal);
 }
 
@@ -237,7 +237,7 @@ template <>
 // actual literal and compares their values elementwise. Returns true if all
 // elements are equal.
 template <typename NativeT>
-bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
+bool ExpectLiteralsEqual(LiteralSlice expected, LiteralSlice actual,
                          tensorflow::gtl::MutableArraySlice<int64> multi_index,
                          int64 dimension) {
   if (dimension == expected.shape().dimensions_size()) {
@@ -259,8 +259,8 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
 
 }  // namespace
 
-/* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected,
-                                               const Literal& actual,
+/* static */ void LiteralTestUtil::ExpectEqual(LiteralSlice expected,
+                                               LiteralSlice actual,
                                                const string& message) {
   EXPECT_TRUE(Equal(expected, actual))
       << "expected:\n"
@@ -269,13 +269,13 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
       << (message.empty() ? "" : StrCat("\nmessage: ", message));
 }
 
-/* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected,
-                                                  const Literal& actual) {
+/* static */ void LiteralTestUtil::ExpectNotEqual(LiteralSlice expected,
+                                                  LiteralSlice actual) {
   EXPECT_FALSE(Equal(expected, actual));
 }
 
 /* static */ ::testing::AssertionResult LiteralTestUtil::Equal(
-    const Literal& expected, const Literal& actual) {
+    LiteralSlice expected, LiteralSlice actual) {
   VLOG(1) << "expected:";
   XLA_VLOG_LINES(1, expected.ToString());
   VLOG(1) << "actual:";
@@ -324,9 +324,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
         SCOPED_TRACE(StrCat("Tuple index ", i, " in ",
                             ShapeUtil::HumanString(expected.shape())));
 
-        // Create LiteralViews of the expected and actual elements.
-        auto result = Equal(LiteralView::Create(expected, {i}),
-                            LiteralView::Create(actual, {i}));
+        // Create LiteralSlices of the expected and actual elements.
+        auto result =
+            Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}));
         tuple_match = tuple_match ? !!result : false;
       }
       match = tuple_match;
@@ -368,7 +368,7 @@ int64 RecursiveElementCount(const Shape& shape) {
 // 3 minutes.  The utility of printing a literal with >1000 elements is
 // questionable, especially when writing the Literal proto to disk is orders
 // of magnitude faster.
-string TruncateHugeLiteral(const Literal& literal) {
+string TruncateHugeLiteral(LiteralSlice literal) {
   return RecursiveElementCount(literal.shape()) < 1000
              ? literal.ToString()
              : "[TRUNCATED, Literal with more than 1000 values]";
@@ -435,8 +435,8 @@ class NearComparator {
   // result. The assertion result is successful if all actual and expected
   // elements are within the given error bound. In case of error, the assertion
   // result contains a detailed error message in case of failure.
-  static ::testing::AssertionResult Compare(const Literal& expected,
-                                            const Literal& actual,
+  static ::testing::AssertionResult Compare(LiteralSlice expected,
+                                            LiteralSlice actual,
                                             ErrorSpec error,
                                             bool detailed_message) {
     NearComparator<NativeT> comparator(expected, actual, error,
@@ -472,7 +472,7 @@ class NearComparator {
     }
   };
 
-  explicit NearComparator(const Literal& expected, const Literal& actual,
+  explicit NearComparator(LiteralSlice expected, LiteralSlice actual,
                           ErrorSpec error, bool detailed_message)
       : expected_(expected),
         actual_(actual),
@@ -649,7 +649,7 @@ class NearComparator {
   }
 
   // Writes the given literal to a file in the test temporary directory.
-  void WriteLiteralToTempFile(const Literal& literal, const string& name) {
+  void WriteLiteralToTempFile(LiteralSlice literal, const string& name) {
     int64 now_usec = tensorflow::Env::Default()->NowMicros();
     string filename = tensorflow::io::JoinPath(
         tensorflow::testing::TmpDir(),
@@ -733,8 +733,8 @@ class NearComparator {
   }
 
   // 'actual' and 'expected' literals being compared.
-  const Literal& expected_;
-  const Literal& actual_;
+  LiteralSlice expected_;
+  LiteralSlice actual_;
 
   // The error bounds of the comparison.
   ErrorSpec error_;
@@ -794,8 +794,8 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
 // Helper function for comparing two literals for nearness. Handles tuple-shapes
 // via recursion. shape_index is the ShapeIndex of expected (or actual)
 // currently being compared.
-::testing::AssertionResult NearHelper(const Literal& expected,
-                                      const Literal& actual,
+::testing::AssertionResult NearHelper(LiteralSlice expected,
+                                      LiteralSlice actual,
                                       const ErrorSpec& error,
                                       bool detailed_message,
                                       const ShapeIndex& shape_index) {
@@ -807,8 +807,8 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
 
   if (ShapeUtil::IsTuple(expected.shape())) {
     for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
-      const auto expected_element = LiteralView::Create(expected, {i});
-      const auto actual_element = LiteralView::Create(actual, {i});
+      const auto expected_element = LiteralSlice(expected, {i});
+      const auto actual_element = LiteralSlice(actual, {i});
       ShapeIndex element_index = shape_index;
       element_index.push_back(i);
       ::testing::AssertionResult res =
@@ -874,14 +874,14 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
 }  // namespace
 
 /* static */ ::testing::AssertionResult LiteralTestUtil::Near(
-    const Literal& expected, const Literal& actual, const ErrorSpec& error,
+    LiteralSlice expected, LiteralSlice actual, const ErrorSpec& error,
     bool detailed_message) {
   return NearHelper(expected, actual, error, detailed_message,
                     /*shape_index=*/{});
 }
 
-/* static */ void LiteralTestUtil::ExpectNear(const Literal& expected,
-                                              const Literal& actual,
+/* static */ void LiteralTestUtil::ExpectNear(LiteralSlice expected,
+                                              LiteralSlice actual,
                                               const ErrorSpec& error,
                                               const string& message) {
   ::testing::AssertionResult res =
@@ -897,7 +897,7 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
 }
 
 /*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
-    const Literal& expected, const Literal& actual,
+    LiteralSlice expected, LiteralSlice actual,
     const tensorflow::gtl::optional<ErrorSpec>& error) {
   if (error.has_value()) {
     VLOG(1) << "Expects near";
@@ -908,7 +908,7 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
 }
 
 /*static*/ void LiteralTestUtil::ExpectNearOrEqual(
-    const Literal& expected, const Literal& actual,
+    LiteralSlice expected, LiteralSlice actual,
     const tensorflow::gtl::optional<ErrorSpec>& error) {
   EXPECT_TRUE(NearOrEqual(expected, actual, error));
 }
@@ -920,7 +920,7 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
 
 /* static */ std::unique_ptr<Literal> LiteralTestUtil::Reshape(
     tensorflow::gtl::ArraySlice<int64> new_dimensions,
-    tensorflow::gtl::ArraySlice<int64> minor_to_major, const Literal& literal) {
+    tensorflow::gtl::ArraySlice<int64> minor_to_major, LiteralSlice literal) {
   int64 new_num_elements = 1;
   for (int64 i = 0; i < new_dimensions.size(); ++i) {
     new_num_elements *= new_dimensions[i];
index a755568..4983ddd 100644 (file)
@@ -69,53 +69,53 @@ class LiteralTestUtil {
   // If the given literal's data type is bfloat16, converts it to a float
   // literal; otherwise, returns a copy of it. If the literal is a tuple,
   // recursively converts its elements.
-  static std::unique_ptr<Literal> ConvertBF16ToF32(const Literal& bf16_literal);
+  static std::unique_ptr<Literal> ConvertBF16ToF32(LiteralSlice bf16_literal);
 
   // If the given literal's data type is float, converts it to a bfloat16
   // literal; otherwise, returns a copy of it. If the literal is a tuple,
   // recursively converts its elements.
-  static std::unique_ptr<Literal> ConvertF32ToBF16(const Literal& f32_literal);
+  static std::unique_ptr<Literal> ConvertF32ToBF16(LiteralSlice f32_literal);
 
   // Asserts that the expected and actual literals are (bitwise) equal for all
   // elements in the literal. Also, asserts that the rank, dimensions sizes, and
   // primitive type are equal.
   static ::testing::AssertionResult Equal(
-      const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT;
+      LiteralSlice expected, LiteralSlice actual) TF_MUST_USE_RESULT;
 
   // Expects that expected and actual are Equal.
-  static void ExpectEqual(const Literal& expected, const Literal& actual,
+  static void ExpectEqual(LiteralSlice expected, LiteralSlice actual,
                           const string& message = "");
 
   // Expects that expected and actual are Not Equal.
-  static void ExpectNotEqual(const Literal& expected, const Literal& actual);
+  static void ExpectNotEqual(LiteralSlice expected, LiteralSlice actual);
 
   // Asserts the given literal are (bitwise) equal to given expected values.
   template <typename NativeT>
-  static void ExpectR0Equal(NativeT expected, const Literal& actual);
+  static void ExpectR0Equal(NativeT expected, LiteralSlice actual);
   template <typename NativeT>
   static void ExpectR1Equal(tensorflow::gtl::ArraySlice<NativeT> expected,
-                            const Literal& actual);
+                            LiteralSlice actual);
   template <typename NativeT>
   static void ExpectR2Equal(
       std::initializer_list<std::initializer_list<NativeT>> expected,
-      const Literal& actual);
+      LiteralSlice actual);
   template <typename NativeT>
   static void ExpectR3Equal(
       std::initializer_list<
           std::initializer_list<std::initializer_list<NativeT>>>
           expected,
-      const Literal& actual);
+      LiteralSlice actual);
 
   // Asserts the given literal are (bitwise) equal to given array.
   template <typename NativeT>
   static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected,
-                                   const Literal& actual);
+                                   LiteralSlice actual);
   template <typename NativeT>
   static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected,
-                                   const Literal& actual);
+                                   LiteralSlice actual);
   template <typename NativeT>
   static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
-                                   const Literal& actual);
+                                   LiteralSlice actual);
 
   // Asserts that the expected and actual literals are within the given error
   // bound for all elements. Also, asserts that the rank, dimensions sizes, and
@@ -133,64 +133,61 @@ class LiteralTestUtil {
   // If detailed_message is true, then the error message in the assertion result
   // will contain a more detailed breakdown of mismatches.
   static ::testing::AssertionResult Near(
-      const Literal& expected, const Literal& actual, const ErrorSpec& error,
+      LiteralSlice expected, LiteralSlice actual, const ErrorSpec& error,
       bool detailed_message = false) TF_MUST_USE_RESULT;
 
   // Expects expected and actual to be Near with the given error.
-  static void ExpectNear(const Literal& expected, const Literal& actual,
+  static void ExpectNear(LiteralSlice expected, LiteralSlice actual,
                          const ErrorSpec& error, const string& message = "");
 
   // Asserts the given literal are within the given error bound of the given
   // expected values. Only supported for floating point values.
   template <typename NativeT>
-  static void ExpectR0Near(NativeT expected, const Literal& actual,
+  static void ExpectR0Near(NativeT expected, LiteralSlice actual,
                            const ErrorSpec& error);
   template <typename NativeT>
   static void ExpectR1Near(tensorflow::gtl::ArraySlice<NativeT> expected,
-                           const Literal& actual, const ErrorSpec& error);
+                           LiteralSlice actual, const ErrorSpec& error);
   template <typename NativeT>
   static void ExpectR2Near(
       std::initializer_list<std::initializer_list<NativeT>> expected,
-      const Literal& actual, const ErrorSpec& error);
+      LiteralSlice actual, const ErrorSpec& error);
   template <typename NativeT>
   static void ExpectR3Near(
       std::initializer_list<
           std::initializer_list<std::initializer_list<NativeT>>>
           expected,
-      const Literal& actual, const ErrorSpec& error);
+      LiteralSlice actual, const ErrorSpec& error);
   template <typename NativeT>
   static void ExpectR4Near(
       std::initializer_list<std::initializer_list<
           std::initializer_list<std::initializer_list<NativeT>>>>
           expected,
-      const Literal& actual, const ErrorSpec& error);
+      LiteralSlice actual, const ErrorSpec& error);
 
   // Asserts the given literal are within the given error bound to the given
   // array. Only supported for floating point values.
   template <typename NativeT>
   static void ExpectR2NearArray2D(const Array2D<NativeT>& expected,
-                                  const Literal& actual,
-                                  const ErrorSpec& error);
+                                  LiteralSlice actual, const ErrorSpec& error);
   template <typename NativeT>
   static void ExpectR3NearArray3D(const Array3D<NativeT>& expected,
-                                  const Literal& actual,
-                                  const ErrorSpec& error);
+                                  LiteralSlice actual, const ErrorSpec& error);
   template <typename NativeT>
   static void ExpectR4NearArray4D(const Array4D<NativeT>& expected,
-                                  const Literal& actual,
-                                  const ErrorSpec& error);
+                                  LiteralSlice actual, const ErrorSpec& error);
 
   // If the error spec is given, returns whether the expected and the actual are
   // within the error bound; otherwise, returns whether they are equal. Tuples
   // will be compared recursively.
   static ::testing::AssertionResult NearOrEqual(
-      const Literal& expected, const Literal& actual,
+      LiteralSlice expected, LiteralSlice actual,
       const tensorflow::gtl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT;
 
   // If the error spec is given, expects the expected and the actual to be near;
   // otherwise, expects them to be equal. Tuples will be compared recursively.
   static void ExpectNearOrEqual(
-      const Literal& expected, const Literal& actual,
+      LiteralSlice expected, LiteralSlice actual,
       const tensorflow::gtl::optional<ErrorSpec>& error);
 
   // Returns a multi-dimensional index as a string. For example: '{7, 8}' will
@@ -205,8 +202,7 @@ class LiteralTestUtil {
   // layout order.
   static std::unique_ptr<Literal> Reshape(
       tensorflow::gtl::ArraySlice<int64> new_dimensions,
-      tensorflow::gtl::ArraySlice<int64> minor_to_major,
-      const Literal& literal);
+      tensorflow::gtl::ArraySlice<int64> minor_to_major, LiteralSlice literal);
 
   // Creates a literal with the supplied shape, and uses the provided value
   // generator to populate the literal's values.
@@ -244,20 +240,20 @@ class LiteralTestUtil {
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
-                                                 const Literal& actual) {
+                                                 LiteralSlice actual) {
   ExpectEqual(*Literal::CreateR0<NativeT>(expected), actual);
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR1Equal(
-    tensorflow::gtl::ArraySlice<NativeT> expected, const Literal& actual) {
+    tensorflow::gtl::ArraySlice<NativeT> expected, LiteralSlice actual) {
   ExpectEqual(*Literal::CreateR1<NativeT>(expected), actual);
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR2Equal(
     std::initializer_list<std::initializer_list<NativeT>> expected,
-    const Literal& actual) {
+    LiteralSlice actual) {
   ExpectEqual(*Literal::CreateR2<NativeT>(expected), actual);
 }
 
@@ -265,38 +261,38 @@ template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR3Equal(
     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
         expected,
-    const Literal& actual) {
+    LiteralSlice actual) {
   ExpectEqual(*Literal::CreateR3<NativeT>(expected), actual);
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
-    const Array2D<NativeT>& expected, const Literal& actual) {
+    const Array2D<NativeT>& expected, LiteralSlice actual) {
   ExpectEqual(*Literal::CreateR2FromArray2D(expected), actual);
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
-    const Array3D<NativeT>& expected, const Literal& actual) {
+    const Array3D<NativeT>& expected, LiteralSlice actual) {
   ExpectEqual(*Literal::CreateR3FromArray3D(expected), actual);
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
-    const Array4D<NativeT>& expected, const Literal& actual) {
+    const Array4D<NativeT>& expected, LiteralSlice actual) {
   ExpectEqual(*Literal::CreateR4FromArray4D(expected), actual);
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
-                                                const Literal& actual,
+                                                LiteralSlice actual,
                                                 const ErrorSpec& error) {
   ExpectNear(*Literal::CreateR0<NativeT>(expected), actual, error);
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR1Near(
-    tensorflow::gtl::ArraySlice<NativeT> expected, const Literal& actual,
+    tensorflow::gtl::ArraySlice<NativeT> expected, LiteralSlice actual,
     const ErrorSpec& error) {
   ExpectNear(*Literal::CreateR1<NativeT>(expected), actual, error);
 }
@@ -304,7 +300,7 @@ template <typename NativeT>
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR2Near(
     std::initializer_list<std::initializer_list<NativeT>> expected,
-    const Literal& actual, const ErrorSpec& error) {
+    LiteralSlice actual, const ErrorSpec& error) {
   ExpectNear(*Literal::CreateR2<NativeT>(expected), actual, error);
 }
 
@@ -312,7 +308,7 @@ template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR3Near(
     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
         expected,
-    const Literal& actual, const ErrorSpec& error) {
+    LiteralSlice actual, const ErrorSpec& error) {
   ExpectNear(*Literal::CreateR3<NativeT>(expected), actual, error);
 }
 
@@ -321,27 +317,27 @@ template <typename NativeT>
     std::initializer_list<std::initializer_list<
         std::initializer_list<std::initializer_list<NativeT>>>>
         expected,
-    const Literal& actual, const ErrorSpec& error) {
+    LiteralSlice actual, const ErrorSpec& error) {
   ExpectNear(*Literal::CreateR4<NativeT>(expected), actual, error);
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR2NearArray2D(
-    const Array2D<NativeT>& expected, const Literal& actual,
+    const Array2D<NativeT>& expected, LiteralSlice actual,
     const ErrorSpec& error) {
   ExpectNear(*Literal::CreateR2FromArray2D(expected), actual, error);
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR3NearArray3D(
-    const Array3D<NativeT>& expected, const Literal& actual,
+    const Array3D<NativeT>& expected, LiteralSlice actual,
     const ErrorSpec& error) {
   ExpectNear(*Literal::CreateR3FromArray3D(expected), actual, error);
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR4NearArray4D(
-    const Array4D<NativeT>& expected, const Literal& actual,
+    const Array4D<NativeT>& expected, LiteralSlice actual,
     const ErrorSpec& error) {
   ExpectNear(*Literal::CreateR4FromArray4D(expected), actual, error);
 }
index 44c6811..96858c0 100644 (file)
@@ -210,12 +210,12 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
 
   std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
   LiteralTestUtil::ExpectR2Equal<float>(
-      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0}));
+      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0}));
   LiteralTestUtil::ExpectR2Equal<float>(
       {{10.0f, 20.0f}, {30.0f, 40.0f}},
-      LiteralView::Create(*result_literal, {1}));
+      LiteralSlice(*result_literal, {1}));
   LiteralTestUtil::ExpectR2Equal<float>(
-      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {2}));
+      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {2}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
@@ -239,16 +239,16 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
 
   std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
   LiteralTestUtil::ExpectR2Equal<float>(
-      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1}));
+      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1}));
   LiteralTestUtil::ExpectR2Equal<float>(
       {{1.0f, 2.0f}, {3.0f, 4.0f}},
-      LiteralView::Create(*result_literal, {0, 0}));
+      LiteralSlice(*result_literal, {0, 0}));
   LiteralTestUtil::ExpectR2Equal<float>(
       {{10.0f, 20.0f}, {30.0f, 40.0f}},
-      LiteralView::Create(*result_literal, {0, 1}));
+      LiteralSlice(*result_literal, {0, 1}));
   LiteralTestUtil::ExpectR2Equal<float>(
       {{1.0f, 2.0f}, {3.0f, 4.0f}},
-      LiteralView::Create(*result_literal, {0, 2}));
+      LiteralSlice(*result_literal, {0, 2}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
@@ -274,9 +274,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
 
   std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
   LiteralTestUtil::ExpectR2Equal<float>(
-      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0}));
+      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0}));
   LiteralTestUtil::ExpectR2Equal<float>(
-      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1}));
+      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
@@ -321,9 +321,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
   std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
   LiteralTestUtil::ExpectR2Equal<float>(
       {{56.0f, 46.0f}, {36.0f, 26.0f}},
-      LiteralView::Create(*result_literal, {0}));
+      LiteralSlice(*result_literal, {0}));
   LiteralTestUtil::ExpectR1Equal<float>(
-      {40.0f, 71.0f, 117.0f}, LiteralView::Create(*result_literal, {1}));
+      {40.0f, 71.0f, 117.0f}, LiteralSlice(*result_literal, {1}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
@@ -361,9 +361,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
 
   std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
   LiteralTestUtil::ExpectR2Equal<float>(
-      {{-1.0, -2.0}, {-3.0, -4}}, LiteralView::Create(*result_literal, {0}));
+      {{-1.0, -2.0}, {-3.0, -4}}, LiteralSlice(*result_literal, {0}));
   LiteralTestUtil::ExpectR1Equal<float>(
-      {264.0, 73.0, 133.0}, LiteralView::Create(*result_literal, {1}));
+      {264.0, 73.0, 133.0}, LiteralSlice(*result_literal, {1}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
@@ -391,16 +391,16 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
   std::unique_ptr<Literal> result_0_literal = ShapedBufferToLiteral(result_0);
   LiteralTestUtil::ExpectR2Equal<float>(
       {{-1.0, -2.0}, {-3.0, -4.0}},
-      LiteralView::Create(*result_0_literal, {0}));
+      LiteralSlice(*result_0_literal, {0}));
   LiteralTestUtil::ExpectR2Equal<float>(
-      {{22.0, 6.0}, {8.0, 10}}, LiteralView::Create(*result_0_literal, {1}));
+      {{22.0, 6.0}, {8.0, 10}}, LiteralSlice(*result_0_literal, {1}));
 
   ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0});
   std::unique_ptr<Literal> result_1_literal = ShapedBufferToLiteral(result_1);
   LiteralTestUtil::ExpectR2Equal<float>(
-      {{1.0, 2.0}, {3.0, 4.0}}, LiteralView::Create(*result_1_literal, {0}));
+      {{1.0, 2.0}, {3.0, 4.0}}, LiteralSlice(*result_1_literal, {0}));
   LiteralTestUtil::ExpectR2Equal<float>(
-      {{44.0, 12.0}, {16.0, 20}}, LiteralView::Create(*result_1_literal, {1}));
+      {{44.0, 12.0}, {16.0, 20}}, LiteralSlice(*result_1_literal, {1}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
@@ -447,7 +447,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
 
   for (int i = 0; i < kElementCount; ++i) {
     LiteralTestUtil::ExpectR1Near<float>(
-        {2.0f * i, 0.0f}, LiteralView::Create(*result_literal, {i}),
+        {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}),
         error_spec_);
   }
 }
@@ -502,7 +502,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
   for (int i = 0; i < kFanout; ++i) {
     for (int j = 0; j < kFanout; ++j) {
       LiteralTestUtil::ExpectR0Near<float>(
-          i + j + i * kFanout + j, LiteralView::Create(*result_literal, {i, j}),
+          i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}),
           error_spec_);
     }
   }
@@ -548,7 +548,7 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
     index.push_back(0);
   }
   LiteralTestUtil::ExpectR0Equal<float>(
-      165.0, LiteralView::Create(*result_literal, index));
+      165.0, LiteralSlice(*result_literal, index));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
@@ -754,9 +754,9 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) {
       ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
   std::unique_ptr<Literal> tuple_literal = ShapedBufferToLiteral(result);
   LiteralTestUtil::ExpectR1Equal<float>(
-      {2.0f, 4.0f, 6.0f}, LiteralView::Create(*tuple_literal, {0}));
+      {2.0f, 4.0f, 6.0f}, LiteralSlice(*tuple_literal, {0}));
   LiteralTestUtil::ExpectR1Equal<float>(
-      {1.0f, 2.0f, 3.0f}, LiteralView::Create(*tuple_literal, {1}));
+      {1.0f, 2.0f, 3.0f}, LiteralSlice(*tuple_literal, {1}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {