From: Kay Zhu Date: Wed, 9 May 2018 20:07:35 +0000 (-0700) Subject: [XLA] First step in adding Literal slice classes, to improve interface safety X-Git-Tag: upstream/v1.9.0_rc1~116^2^2~192 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=e1347ba769b98e260d36e895be2963af35c88d18;p=platform%2Fupstream%2Ftensorflow.git [XLA] First step in adding Literal slice classes, to improve interface safety and prepare for enabling more efficient interfacing from Tensor to Literal to reduce host to device latency. More specically: * Introducing a new LiteralBase abstract base class that contains all immutable methods of from the old Literal class. * Introducing a subclass LiteralSlice to replace original LiteralView class. LiteralSlice class is read-only and does not own Shape nor any buffer through the Pieces. Change a number of callers to use LiteralSlice directly. * Change Literal class to explicitly own the underlying Shape as well as owning the underlying buffer via Piece. * Conversion from Literal to LiteralSlice is now done via an implicit conversion constructor instead of inheritance. * Decouple ShapeTree from Literal classes. * Use copy-and-swap for assignment constructors. * Other minor cleanups. PiperOrigin-RevId: 196016576 --- diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 2c3cd65..43e1c1e 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -40,7 +40,7 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { return Status::OK(); } -Status CopyLiteralToHostTensor(const xla::Literal& literal, +Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor) { TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) && xla::ShapeUtil::ElementsIn(literal.shape()) == @@ -63,8 +63,8 @@ Status CopyLiteralToHostTensor(const xla::Literal& literal, return Status::OK(); } -Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, - Tensor* host_tensor) { +Status LiteralToHostTensor(const xla::LiteralSlice& literal, + DataType target_type, Tensor* host_tensor) { TensorShape shape; TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape)); *host_tensor = Tensor(target_type, shape); diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index f283b02..220bec1 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -36,13 +36,13 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); // derivable from the type of , because multiple tensorflow types map // to the same XLA type (e.g. INT32 and QINT32 both map to INT32 in // XLA). -Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, - Tensor* host_tensor); +Status LiteralToHostTensor(const xla::LiteralSlice& literal, + DataType target_type, Tensor* host_tensor); // Copies the contents of 'literal' to a previously allocated tensor // 'host_tensor'. The tensor and the literal must have the same number of // elements and the same type. -Status CopyLiteralToHostTensor(const xla::Literal& literal, +Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor); } // namespace tensorflow diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 83c7cb1..f9f9944 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -185,7 +185,7 @@ bool ComputationBuilder::MakeWindow( } ComputationDataHandle ComputationBuilder::ConstantLiteral( - const Literal& literal) { + const LiteralSlice& literal) { OpRequest op_request; ConstantRequest* request = op_request.mutable_constant_request(); *request->mutable_literal() = literal.ToProto(); diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index ac1eb91..176962b 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -108,7 +108,7 @@ class ComputationBuilder { // Enqueues a constant with the value of the given literal onto the // computation. - ComputationDataHandle ConstantLiteral(const Literal& literal); + ComputationDataHandle ConstantLiteral(const LiteralSlice& literal); // Enqueues a constant onto the computation. Methods are templated on the // native host type (NativeT) which corresponds to a specific XLA diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 1899983..4c59d62 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -437,7 +437,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); } -XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) { +XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = literal.shape(); diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index 4955f15..e1920d6 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -139,7 +139,7 @@ class XlaBuilder { // Enqueues a constant with the value of the given literal onto the // computation. - XlaOp ConstantLiteral(const Literal& literal); + XlaOp ConstantLiteral(const LiteralSlice& literal); // Enqueues a constant onto the computation. Methods are templated on the // native host type (NativeT) which corresponds to a specific XLA diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index b3b5e34..e9b0e11 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -64,6 +64,8 @@ void ConvertEndianShort(char* bytes, int64 size) { } // namespace +LiteralBase::~LiteralBase() {} + std::ostream& operator<<(std::ostream& out, const Literal& literal) { out << literal.ToString(); return out; @@ -95,99 +97,90 @@ Literal::StrideConfig::StrideConfig( Literal::Literal(const Shape& shape) : Literal(shape, /*allocate_arrays=*/true) {} -Literal::Literal(const Shape& shape, bool allocate_arrays) - : shape_(shape), pieces_(shape), owns_buffers_(true) { - CHECK(LayoutUtil::HasLayout(shape)); - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - const Shape& subshape = piece.subshape(); - if (ShapeUtil::IsArray(subshape)) { - if (allocate_arrays) { - if (LayoutUtil::IsSparseArray(subshape)) { - // For sparse arrays, the buffer must be of the size of the maximum - // number of sparse elements possible. - const int64 max_sparse_elements = - LayoutUtil::MaxSparseElements(subshape.layout()); - piece.set_buffer( - new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType( - subshape.element_type())]); - piece.set_sparse_indices(new SparseIndexArray( - max_sparse_elements, ShapeUtil::Rank(subshape))); - } else { - piece.set_buffer(new char[piece.size_bytes()]); - } +void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + SetPiece(subshape, &child_piece, allocate_arrays); + + piece->emplace_back(std::move(child_piece)); + } + } else { + CHECK(ShapeUtil::IsArray(shape)); + if (allocate_arrays) { + if (LayoutUtil::IsSparseArray(shape)) { + // For sparse arrays, the buffer must be of the size of the maximum + // number of sparse elements possible. + const int64 max_sparse_elements = + LayoutUtil::MaxSparseElements(shape.layout()); + piece->set_buffer( + new char[max_sparse_elements * + ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); + piece->set_sparse_indices( + new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); } else { - piece.set_buffer(nullptr); + piece->set_buffer(new char[piece->size_bytes()]); } } } } -Literal::~Literal() { DeallocateBuffers(); } +Literal::Literal(const Shape& shape, bool allocate_arrays) + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(LayoutUtil::HasLayout(*shape_)); + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + CHECK(&root_piece_->subshape() == shape_.get()); -void Literal::DeallocateBuffers() { - if (owns_buffers_) { - for (auto& pair : pieces_) { - Piece& piece = pair.second; - if (piece.buffer() != nullptr) { - delete[] piece.buffer(); - delete piece.sparse_indices(); - } - } - } + SetPiece(*shape_, root_piece_, allocate_arrays); } -Literal::Literal(Literal&& other) { - shape_ = std::move(other.shape_); - pieces_ = std::move(other.pieces_); - // We need to iterate through the pieces to set the subshape pointer - // properly. It must refer to subshapes within shape_. - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); +Literal::~Literal() { + if (root_piece_ != nullptr) { + DeallocateBuffers(); + delete root_piece_; } - owns_buffers_ = other.owns_buffers_; +} - other.shape_ = ShapeUtil::MakeNil(); - other.pieces_ = ShapeTree(other.shape_); - other.piece({}).set_subshape(&other.shape_); +void Literal::DeallocateBuffers() { + root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (piece->buffer() != nullptr) { + delete[] piece->buffer(); + delete piece->sparse_indices(); + } + }); } +Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); } + Literal& Literal::operator=(Literal&& other) { - DeallocateBuffers(); - shape_ = std::move(other.shape_); - pieces_ = std::move(other.pieces_); - // We need to iterate through the pieces to set the subshape pointer - // properly. It must refer to subshapes within shape_. - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - } - owns_buffers_ = other.owns_buffers_; - - other.shape_ = ShapeUtil::MakeNil(); - other.pieces_ = ShapeTree(other.shape_); - other.piece({}).set_subshape(&other.shape_); + CHECK(&other.root_piece_->subshape() == other.shape_.get()); + + using std::swap; + swap(shape_, other.shape_); + swap(root_piece_, other.root_piece_); + CHECK(&root_piece_->subshape() == shape_.get()); + return *this; } -std::unique_ptr Literal::CreateFromShape(const Shape& shape) { +std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { auto literal = MakeUnique(shape); - for (auto& pair : literal->pieces_) { - Piece& piece = pair.second; - if (ShapeUtil::IsArray(piece.subshape())) { - memset(piece.untyped_data(), 0, piece.size_bytes()); - } - } + literal->root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (ShapeUtil::IsArray(piece->subshape())) { + memset(piece->untyped_data(), 0, piece->size_bytes()); + } + }); return literal; } -const SparseIndexArray* Literal::sparse_indices( +const SparseIndexArray* LiteralBase::sparse_indices( const ShapeIndex& shape_index) const { return piece(shape_index).sparse_indices(); } @@ -204,7 +197,7 @@ SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { template Status Literal::CopySliceFromInternal( - const Literal& src_literal, tensorflow::gtl::ArraySlice src_base, + const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size) { TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); @@ -217,8 +210,8 @@ Status Literal::CopySliceFromInternal( if (ShapeUtil::Rank(src_literal.shape()) == 0 || ShapeUtil::Rank(shape()) == 0) { - // If any of the two shapes are scalars, we can just call the StridedCopy() - // directly, and we know we will be copying only one value. + // If any of the two shapes are scalars, we can just call the + // StridedCopy() directly, and we know we will be copying only one value. TF_RET_CHECK(copy_size.empty()); StridedCopy(data(), linear_index(shape(), dest_base), 0, src_literal.data(), @@ -264,7 +257,7 @@ Status Literal::CopySliceFromInternal( return Status::OK(); } -Status Literal::CopyElementFrom(const Literal& src_literal, +Status Literal::CopyElementFrom(const LiteralSlice& src_literal, tensorflow::gtl::ArraySlice src_index, tensorflow::gtl::ArraySlice dest_index) { DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); @@ -293,22 +286,21 @@ std::vector Literal::DecomposeTuple() { elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), /*allocate_arrays=*/false)); Literal& element = elements.back(); - for (auto& pair : element.pieces_) { - const ShapeIndex& index = pair.first; - Piece& dest_piece = pair.second; - ShapeIndex src_index = {i}; - for (int64 j : index) { - src_index.push_back(j); - } - Piece& src_piece = piece(src_index); - - // Move the respective buffer and sparse indices over to the element - // Literal. - dest_piece.set_buffer(src_piece.buffer()); - src_piece.set_buffer(nullptr); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); - src_piece.set_sparse_indices(nullptr); - } + element.root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* dest_piece) { + ShapeIndex src_index = {i}; + for (int64 j : index) { + src_index.push_back(j); + } + Piece& src_piece = piece(src_index); + + // Move the respective buffer and sparse indices over to the element + // Literal. + dest_piece->set_buffer(src_piece.buffer()); + src_piece.set_buffer(nullptr); + dest_piece->set_sparse_indices(src_piece.sparse_indices()); + src_piece.set_sparse_indices(nullptr); + }); } // Set this literal to be nil-shaped. *this = Literal(); @@ -331,9 +323,9 @@ std::vector Literal::DecomposeTuple() { } namespace { - -// Copies the elements in 'src' to 'dest'. The shape and layout of the data in -// the array slices are indicated by dest_shape and src_shape respectively. +// Copies the elements in 'src' to 'dest'. The shape and layout of the data +// in the array slices are indicated by dest_shape and src_shape +// respectively. template void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, tensorflow::gtl::ArraySlice src, @@ -351,7 +343,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, } // namespace -Status Literal::Piece::CopyFrom(const Literal::Piece& src) { +Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { if (ShapeUtil::Equal(subshape(), src.subshape())) { // If the layouts are equal it's faster just to memcpy. memcpy(buffer(), src.buffer(), src.size_bytes()); @@ -381,14 +373,15 @@ Status Literal::Piece::CopyFrom(const Literal::Piece& src) { #undef COPY_ELEMENTS default: return Unimplemented( - "Copying a Literal object with element type %s is not implemented.", + "Copying a Literal object with element type %s is not " + "implemented.", PrimitiveType_Name(subshape().element_type()).c_str()); } } return Status::OK(); } -Status Literal::CopyFrom(const Literal& src_literal, +Status Literal::CopyFrom(const LiteralSlice& src_literal, const ShapeIndex& dest_shape_index, const ShapeIndex& src_shape_index) { const Shape& dest_subshape = @@ -402,36 +395,33 @@ Status Literal::CopyFrom(const Literal& src_literal, ShapeUtil::HumanString(src_subshape).c_str()); } - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - - // Determine if this index is in the part of this literal that we want to - // copy over from src_literal. - bool in_subtree_to_copy = true; - for (int i = 0; i < dest_shape_index.size(); ++i) { - if (index[i] != dest_shape_index[i]) { - in_subtree_to_copy = false; - break; - } - } - if (!in_subtree_to_copy) { - continue; - } - - // Construct the index of the corresponding piece in the source literal. - ShapeIndex src_piece_index = src_shape_index; - for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { - src_piece_index.push_back(index[i]); - } + return root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + if (!ShapeUtil::IsArray(piece->subshape())) { + return Status::OK(); + } - TF_RETURN_IF_ERROR(piece.CopyFrom(src_literal.piece(src_piece_index))); - } - return Status::OK(); -} + // Determine if this index is in the part of this literal that we want + // to copy over from src_literal. + bool in_subtree_to_copy = true; + for (int i = 0; i < dest_shape_index.size(); ++i) { + if (index[i] != dest_shape_index[i]) { + in_subtree_to_copy = false; + break; + } + } + if (!in_subtree_to_copy) { + return Status::OK(); + } + // Construct the index of the corresponding piece in the source literal. + ShapeIndex src_piece_index = src_shape_index; + for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { + src_piece_index.push_back(index[i]); + } + TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index))); + return Status::OK(); + }); +} // namespace xla Status Literal::MoveFrom(Literal&& src_literal, const ShapeIndex& dest_shape_index) { @@ -444,37 +434,32 @@ Status Literal::MoveFrom(Literal&& src_literal, ShapeUtil::HumanString(src_literal.shape()).c_str()); } - if (!(owns_buffers_ && src_literal.owns_buffers_)) { - return InvalidArgument( - "Source and destination literals must both own their buffers (ie, not " - "be views)"); - } + src_literal.root_piece_->ForEachSubpiece( + [&](const ShapeIndex& src_index, const Piece& src_piece) { + if (!ShapeUtil::IsArray(src_piece.subshape())) { + return; + } - for (auto& pair : src_literal.pieces_) { - const ShapeIndex& src_index = pair.first; - Piece& src_piece = pair.second; - if (!ShapeUtil::IsArray(src_piece.subshape())) { - continue; - } + ShapeIndex dest_index = dest_shape_index; + for (int64 i : src_index) { + dest_index.push_back(i); + } + Piece& dest_piece = piece(dest_index); + delete[] dest_piece.buffer(); + dest_piece.set_buffer(src_piece.buffer()); + delete dest_piece.sparse_indices(); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); + }); - ShapeIndex dest_index = dest_shape_index; - for (int64 i : src_index) { - dest_index.push_back(i); - } - Piece& dest_piece = piece(dest_index); - delete[] dest_piece.buffer(); - dest_piece.set_buffer(src_piece.buffer()); - delete dest_piece.sparse_indices(); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); - } + src_literal.shape_ = MakeUnique(ShapeUtil::MakeNil()); + delete src_literal.root_piece_; + src_literal.root_piece_ = new LiteralBase::Piece(); + src_literal.root_piece_->set_subshape(src_literal.shape_.get()); - src_literal.shape_ = ShapeUtil::MakeNil(); - src_literal.pieces_ = ShapeTree(src_literal.shape_); - src_literal.piece({}).set_subshape(&src_literal.shape_); return Status::OK(); } -Status Literal::CopySliceFrom(const Literal& src_literal, +Status Literal::CopySliceFrom(const LiteralSlice& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size) { @@ -743,7 +728,7 @@ void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { return CreateR2FromArray2D(*value); } -std::unique_ptr Literal::Relayout( +std::unique_ptr LiteralBase::Relayout( const Layout& new_layout, const ShapeIndex& shape_index) const { // Create new shape with 'new_layout' set at the given shape index. Shape new_shape = shape(); @@ -755,7 +740,7 @@ std::unique_ptr Literal::Relayout( return result; } -std::unique_ptr Literal::Relayout( +std::unique_ptr LiteralBase::Relayout( const Shape& shape_with_layout) const { CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) @@ -774,7 +759,7 @@ std::unique_ptr Literal::Relayout( return result; } -StatusOr> Literal::Reshape( +StatusOr> LiteralBase::Reshape( tensorflow::gtl::ArraySlice dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Reshape does not support tuples."); @@ -788,7 +773,8 @@ StatusOr> Literal::Reshape( } // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. - output->shape_ = ShapeUtil::MakeShape(shape().element_type(), dimensions); + *output->mutable_shape_do_not_use() = + ShapeUtil::MakeShape(shape().element_type(), dimensions); int64 elements_before = ShapeUtil::ElementsIn(shape()); int64 elements_after = ShapeUtil::ElementsIn(output->shape()); @@ -802,7 +788,7 @@ StatusOr> Literal::Reshape( return std::move(output); } -std::unique_ptr Literal::Transpose( +std::unique_ptr LiteralBase::Transpose( tensorflow::gtl::ArraySlice permutation) const { CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) @@ -819,8 +805,8 @@ std::unique_ptr Literal::Transpose( // representation intact. // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation. // The shape with affine layout resulting from that operation will be - // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the - // most minor. + // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), + // the most minor. // // Essentially, given MinMaj(Di) the position of the Di dimension within the // minor to major vector, and given T(Di) the index that the original Di @@ -836,12 +822,11 @@ std::unique_ptr Literal::Transpose( std::unique_ptr new_literal = CreateFromShape(permuted_shape); DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()), ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->root_piece().buffer(), root_piece().buffer(), - root_piece().size_bytes()); + std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); return new_literal; } -std::unique_ptr Literal::Slice( +std::unique_ptr LiteralBase::Slice( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices) const { CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; @@ -909,20 +894,20 @@ std::unique_ptr Literal::Slice( } } -Literal Literal::Clone() const { +Literal LiteralBase::Clone() const { Literal result(shape()); TF_CHECK_OK(result.CopyFrom(*this)); return result; } -std::unique_ptr Literal::CloneToUnique() const { +std::unique_ptr LiteralBase::CloneToUnique() const { auto result = MakeUnique(shape()); TF_CHECK_OK(result->CopyFrom(*this)); return result; } -string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const { +string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsDenseArray(subshape)); switch (subshape.element_type()) { @@ -962,8 +947,8 @@ string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, } } -string Literal::GetSparseElementAsString(int64 sparse_element_number, - const ShapeIndex& shape_index) const { +string LiteralBase::GetSparseElementAsString( + int64 sparse_element_number, const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsSparseArray(subshape)); switch (subshape.element_type()) { @@ -1017,7 +1002,7 @@ string Literal::GetSparseElementAsString(int64 sparse_element_number, } } -StatusOr Literal::GetIntegralAsS64( +StatusOr LiteralBase::GetIntegralAsS64( tensorflow::gtl::ArraySlice multi_index) const { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { @@ -1070,7 +1055,7 @@ Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, return Status::OK(); } -tensorflow::gtl::ArraySlice Literal::GetSparseIndex( +tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index) const { const Piece& p = piece(shape_index); CHECK_GE(sparse_element_number, 0); @@ -1082,10 +1067,10 @@ void Literal::SortSparseElements(const ShapeIndex& shape_index) { piece(shape_index).SortSparseElements(); } -Literal Literal::GetFirstScalarLiteral() const { - CHECK(ShapeUtil::IsArray(shape_)); - CHECK_GT(ShapeUtil::ElementsIn(shape_), 0); - switch (shape_.element_type()) { +Literal LiteralBase::GetFirstScalarLiteral() const { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_GT(ShapeUtil::ElementsIn(shape()), 0); + switch (shape().element_type()) { case PRED: return std::move(*Literal::CreateR0(GetFirstElement())); // 8 bit types. @@ -1121,11 +1106,11 @@ Literal Literal::GetFirstScalarLiteral() const { case U64: return std::move(*Literal::CreateR0(GetFirstElement())); default: - LOG(FATAL) << "Unhandled primitive type " << shape_.element_type(); + LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); } } -void Literal::Piece::SortSparseElements() { +void LiteralBase::Piece::SortSparseElements() { switch (subshape().element_type()) { case PRED: SortSparseElementsInternal(); @@ -1176,7 +1161,7 @@ void Literal::Piece::SortSparseElements() { } template -void Literal::Piece::SortSparseElementsInternal() { +void LiteralBase::Piece::SortSparseElementsInternal() { CHECK(LayoutUtil::IsSparseArray(subshape())); int64 num_elements = sparse_indices()->index_count(); auto values = data(); @@ -1186,10 +1171,11 @@ void Literal::Piece::SortSparseElementsInternal() { } namespace { - -void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + CHECK(LayoutUtil::HasLayout(literal.shape())); + CHECK(LayoutUtil::HasLayout(subshape)); auto shape_to_string = [print_layout](const Shape& shape) { if (print_layout) { @@ -1348,13 +1334,14 @@ void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, } // namespace -int64 Literal::sparse_element_count() const { +int64 LiteralBase::sparse_element_count() const { CHECK(LayoutUtil::IsSparseArray(shape())); return sparse_indices()->index_count(); } -string Literal::ToString(bool print_layout) const { +string LiteralBase::ToString(bool print_layout) const { std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); ToStringHelper(*this, {}, print_layout, &pieces); return tensorflow::str_util::Join(pieces, ""); } @@ -1362,7 +1349,7 @@ string Literal::ToString(bool print_layout) const { /* static */ std::unique_ptr Literal::MakeTuple( tensorflow::gtl::ArraySlice elements) { std::vector element_shapes; - for (const Literal* element : elements) { + for (const auto* element : elements) { element_shapes.push_back(element->shape()); } auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); @@ -1372,6 +1359,19 @@ string Literal::ToString(bool print_layout) const { return literal; } +/* static */ std::unique_ptr Literal::MakeTupleFromSlices( + tensorflow::gtl::ArraySlice elements) { + std::vector element_shapes; + for (const auto& element : elements) { + element_shapes.push_back(element.shape()); + } + auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + for (int i = 0; i < elements.size(); ++i) { + TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i})); + } + return literal; +} + /* static */ std::unique_ptr Literal::MakeTupleOwned( std::vector> elements) { std::vector element_shapes; @@ -1387,7 +1387,7 @@ string Literal::ToString(bool print_layout) const { return literal; } -void Literal::EachCellAsString( +void LiteralBase::EachCellAsString( const std::function indices, const string& value)>& per_cell) const { if (ShapeUtil::HasZeroElements(shape())) { @@ -1403,7 +1403,7 @@ void Literal::EachCellAsString( namespace { template std::unique_ptr ConvertBetweenNativeTypesWithConverter( - const Literal& src_literal, const ConverterType& converter) { + const LiteralBase& src_literal, const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( src_literal.shape(), @@ -1419,7 +1419,8 @@ std::unique_ptr ConvertBetweenNativeTypesWithConverter( } template -std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { +std::unique_ptr ConvertBetweenNativeTypes( + const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return static_cast(src); }; return ConvertBetweenNativeTypesWithConverter( src_literal, converter); @@ -1428,7 +1429,7 @@ std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { template typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), std::unique_ptr>::type -BitcastBetweenNativeTypes(const Literal& src_literal) { +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return tensorflow::bit_cast(src); }; @@ -1436,19 +1437,19 @@ BitcastBetweenNativeTypes(const Literal& src_literal) { src_literal, converter); } -// This template specialization is here to make the compiler happy. bit_cast has -// a static check that the types are the same size. This specialization should -// never be used because the source and destination types are checked for -// identical sizes higher up. +// This template specialization is here to make the compiler happy. bit_cast +// has a static check that the types are the same size. This specialization +// should never be used because the source and destination types are checked +// for identical sizes higher up. template typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), std::unique_ptr>::type -BitcastBetweenNativeTypes(const Literal& src_literal) { +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { LOG(FATAL) << "Invalid bitcast between types of different sizes."; } template -std::unique_ptr ConvertToC64(const Literal& src_literal) { +std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = MakeUnique( ShapeUtil::ChangeElementType(src_literal.shape(), C64)); @@ -1466,7 +1467,7 @@ std::unique_ptr ConvertToC64(const Literal& src_literal) { } template -std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal, +std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); if (bitcast) { @@ -1486,7 +1487,7 @@ std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal, template StatusOr> ConvertIfDestTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type, + const LiteralBase& src_literal, PrimitiveType primitive_dest_type, bool bitcast) { switch (primitive_dest_type) { #define CONVERT_IF_TYPES_MATCH(type) \ @@ -1521,7 +1522,8 @@ StatusOr> ConvertIfDestTypeMatches( } StatusOr> ConvertSwitch( - const Literal& literal, PrimitiveType primitive_dest_type, bool bitcast) { + const LiteralBase& literal, PrimitiveType primitive_dest_type, + bool bitcast) { TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); if (literal.shape().element_type() == primitive_dest_type) { return literal.CloneToUnique(); @@ -1555,17 +1557,18 @@ StatusOr> ConvertSwitch( } // namespace -StatusOr> Literal::Convert( +StatusOr> LiteralBase::Convert( PrimitiveType primitive_dest_type) const { return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); } -StatusOr> Literal::BitcastConvert( +StatusOr> LiteralBase::BitcastConvert( PrimitiveType primitive_dest_type) const { if (primitive_util::BitWidth(shape().element_type()) != primitive_util::BitWidth(primitive_dest_type)) { return InvalidArgument( - "Cannot bitcast convert from %s to %s, bit widths are different: %d != " + "Cannot bitcast convert from %s to %s, bit widths are different: %d " + "!= " "%d", PrimitiveType_Name(shape().element_type()).c_str(), PrimitiveType_Name(primitive_dest_type).c_str(), @@ -1575,7 +1578,7 @@ StatusOr> Literal::BitcastConvert( return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); } -StatusOr> Literal::ConvertToShape( +StatusOr> LiteralBase::ConvertToShape( const Shape& dest_shape, bool round_f32_to_bf16) const { if (!ShapeUtil::IsTuple(dest_shape)) { if (round_f32_to_bf16 && shape().element_type() == F32 && @@ -1590,7 +1593,7 @@ StatusOr> Literal::ConvertToShape( } std::vector elements; for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { - auto element = LiteralView::Create(*this, {i}); + auto element = LiteralSlice(*this, {i}); TF_ASSIGN_OR_RETURN( auto new_element, element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); @@ -1602,8 +1605,8 @@ StatusOr> Literal::ConvertToShape( } template -bool Literal::Piece::EqualElementsInternal( - const Literal::Piece& other, std::vector* multi_index) const { +bool LiteralBase::Piece::EqualElementsInternal( + const LiteralBase::Piece& other, std::vector* multi_index) const { if (multi_index->size() == ShapeUtil::Rank(subshape())) { return (Get(*multi_index) == other.Get(*multi_index)); } @@ -1617,7 +1620,7 @@ bool Literal::Piece::EqualElementsInternal( return true; } -bool Literal::Piece::EqualElements(const Literal::Piece& other) const { +bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); std::vector multi_index; @@ -1645,32 +1648,31 @@ bool Literal::Piece::EqualElements(const Literal::Piece& other) const { case C64: return EqualElementsInternal(other, &multi_index); default: - LOG(FATAL) << "Unimplemented: Literal::Piece::EqualElements for type " + LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " << PrimitiveType_Name(subshape().element_type()); } } -bool Literal::operator==(const Literal& other) const { +bool LiteralBase::operator==(const LiteralBase& other) const { if (!ShapeUtil::Compatible(shape(), other.shape())) { return false; } - for (const auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - const Piece& other_piece = other.piece(index); - if (!piece.EqualElements(other_piece)) { - return false; - } - } - return true; + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + const Piece& other_piece = other.piece(index); + if (!piece.EqualElements(other_piece)) { + return false; + } + return true; + }); } namespace { - template static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, NativeT value) { @@ -1684,11 +1686,11 @@ static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, } // namespace -bool Literal::IsAll(int8 value) const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; +bool LiteralBase::IsAll(int8 value) const { + return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index, + const Piece& piece) { if (!ShapeUtil::IsArray(piece.subshape())) { - continue; + return true; } auto piece_is_all = [&]() { @@ -1741,41 +1743,41 @@ bool Literal::IsAll(int8 value) const { if (!piece_is_all()) { return false; } - } - return true; -} + return true; + }); +} // namespace xla -bool Literal::IsAllFloat(float value) const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } +bool LiteralBase::IsAllFloat(float value) const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } - auto piece_is_all = [&]() { - switch (shape().element_type()) { - case F32: - return AllElementsEqualValue(piece.data(), value); - case F64: - return AllElementsEqualValue(piece.data(), value); - case F16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case BF16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - default: + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue( + piece.data(), static_cast(value)); + default: + return false; + } + }; + if (!piece_is_all()) { return false; - } - }; - if (!piece_is_all()) { - return false; - } - } - return true; + } + return true; + }); } -bool Literal::IsAllComplex(complex64 value) const { +bool LiteralBase::IsAllComplex(complex64 value) const { switch (shape().element_type()) { case C64: return AllElementsEqualValue(root_piece().data(), @@ -1785,93 +1787,93 @@ bool Literal::IsAllComplex(complex64 value) const { } } -bool Literal::IsAllFirst() const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - - // Empty shapes are not all the first element since there is no first - // element. - if (ShapeUtil::HasZeroElements(piece.subshape())) { - return false; - } - auto piece_is_all = [&]() { - switch (piece.subshape().element_type()) { - case PRED: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); +bool LiteralBase::IsAllFirst() const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; } - // 8 bit types - case S8: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U8: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 16 bit types - case BF16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case F16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 32 bit types - case F32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 64 bit types - case C64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case F64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - default: + + // Empty shapes are not all the first element since there is no first + // element. + if (ShapeUtil::HasZeroElements(piece.subshape())) { return false; - } - }; + } + auto piece_is_all = [&]() { + switch (piece.subshape().element_type()) { + case PRED: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 8 bit types + case S8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 16 bit types + case BF16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 32 bit types + case F32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 64 bit types + case C64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + default: + return false; + } + }; - if (!piece_is_all()) { - return false; - } - } - return true; + if (!piece_is_all()) { + return false; + } + return true; + }); } -bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { +bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice indices) const { CHECK(ShapeUtil::IsArray(shape())); switch (shape().element_type()) { case U8: @@ -1904,7 +1906,6 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { } namespace { - template void CopyToRepeatedField(RepeatedFieldT* dest, const tensorflow::gtl::ArraySlice src) { @@ -1913,7 +1914,7 @@ void CopyToRepeatedField(RepeatedFieldT* dest, } // namespace -void Literal::Piece::WriteToProto(LiteralProto* proto) const { +void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { *proto->mutable_shape() = subshape(); switch (subshape().element_type()) { case PRED: @@ -1969,18 +1970,17 @@ void Literal::Piece::WriteToProto(LiteralProto* proto) const { } } -const void* Literal::Piece::untyped_data() const { +const void* LiteralBase::Piece::untyped_data() const { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); return buffer(); } -void* Literal::Piece::untyped_data() { +void* LiteralBase::Piece::untyped_data() { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); return buffer(); } namespace { - template Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, const RepeatedFieldT& src) { @@ -1995,7 +1995,7 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, } // namespace -Status Literal::Piece::CopyFromProto(const LiteralProto& proto) { +Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { // These conditions should have been checked in Literal::CreateFromProto. TF_RET_CHECK(proto.has_shape()); TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); @@ -2062,21 +2062,19 @@ Status Literal::Piece::CopyFromProto(const LiteralProto& proto) { return Status::OK(); } -LiteralProto Literal::ToProto() const { +LiteralProto LiteralBase::ToProto() const { LiteralProto proto; - for (const auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - const Piece& piece = pair.second; - - LiteralProto* proto_piece = &proto; - for (int64 i : index) { - while (proto_piece->tuple_literals_size() <= i) { - proto_piece->add_tuple_literals(); - } - proto_piece = proto_piece->mutable_tuple_literals(i); - } - piece.WriteToProto(proto_piece); - } + root_piece().ForEachSubpiece( + [&](const ShapeIndex& index, const Piece& piece) { + LiteralProto* proto_piece = &proto; + for (int64 i : index) { + while (proto_piece->tuple_literals_size() <= i) { + proto_piece->add_tuple_literals(); + } + proto_piece = proto_piece->mutable_tuple_literals(i); + } + piece.WriteToProto(proto_piece); + }); if (LayoutUtil::IsSparseArray(shape())) { CopyToRepeatedField(proto.mutable_sparse_indices(), @@ -2098,33 +2096,34 @@ StatusOr> Literal::CreateFromProto( auto literal = MakeUnique(proto.shape()); - for (auto& pair : literal->pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - const LiteralProto* proto_element = &proto; - for (int64 i : index) { - TF_RET_CHECK(i < proto_element->tuple_literals_size()); - proto_element = &proto_element->tuple_literals(i); - } + TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + const LiteralProto* proto_element = &proto; + for (int64 i : index) { + CHECK(i < proto_element->tuple_literals_size()); + proto_element = &proto_element->tuple_literals(i); + } - if (ShapeUtil::IsTuple(piece.subshape())) { - if (proto_element->tuple_literals_size() != - ShapeUtil::TupleElementCount(piece.subshape())) { - return InvalidArgument( - "Expected %lld tuple elements in LiteralProto, has %d", - ShapeUtil::TupleElementCount(piece.subshape()), - proto_element->tuple_literals_size()); - } - continue; - } + if (ShapeUtil::IsTuple(piece->subshape())) { + if (proto_element->tuple_literals_size() != + ShapeUtil::TupleElementCount(piece->subshape())) { + return InvalidArgument( + "Expected %lld tuple elements in LiteralProto, has %d", + ShapeUtil::TupleElementCount(piece->subshape()), + proto_element->tuple_literals_size()); + } + return Status::OK(); + } - TF_RET_CHECK(ShapeUtil::IsArray(piece.subshape())); - TF_RETURN_IF_ERROR(piece.CopyFromProto(*proto_element)); - } + CHECK(ShapeUtil::IsArray(piece->subshape())); + TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); + + return Status::OK(); + })); return std::move(literal); } -const void* Literal::untyped_data(const ShapeIndex& shape_index) const { +const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { return piece(shape_index).untyped_data(); } @@ -2132,11 +2131,11 @@ void* Literal::untyped_data(const ShapeIndex& shape_index) { return piece(shape_index).untyped_data(); } -int64 Literal::size_bytes(const ShapeIndex& shape_index) const { +int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const { return piece(shape_index).size_bytes(); } -string Literal::GetR1U8AsString() const { +string LiteralBase::GetR1U8AsString() const { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(shape().element_type(), U8); @@ -2144,12 +2143,14 @@ string Literal::GetR1U8AsString() const { ShapeUtil::ElementsIn(shape())); } -/* static */ const LiteralView LiteralView::Create( - const Literal& literal, const ShapeIndex& view_root) { - return LiteralView(literal, view_root); -} +LiteralSlice::LiteralSlice(const LiteralBase& literal) + : LiteralBase(), root_piece_(&literal.root_piece()) {} + +LiteralSlice::LiteralSlice(const LiteralBase& literal, + const ShapeIndex& view_root) + : LiteralBase(), root_piece_(&literal.piece(view_root)) {} -size_t Literal::Hash() const { +size_t LiteralBase::Hash() const { using tensorflow::Hash64; using tensorflow::Hash64Combine; @@ -2170,46 +2171,4 @@ size_t Literal::Hash() const { return hash_value; } -LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) { - shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root); - pieces_ = ShapeTree(shape_); - owns_buffers_ = false; - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - - ShapeIndex src_index = view_root; - for (int64 i : index) { - src_index.push_back(i); - } - const Piece& src_piece = literal.piece(src_index); - piece.set_buffer(src_piece.buffer()); - piece.set_sparse_indices(src_piece.sparse_indices()); - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - } -} - -LiteralView::~LiteralView() {} - -LiteralView::LiteralView(const LiteralView& other) { CopyFrom(other); } - -LiteralView& LiteralView::operator=(const LiteralView& other) { - CopyFrom(other); - return *this; -} - -void LiteralView::CopyFrom(const LiteralView& other) { - // We can't use the default copy-constructor/copy-assignment because - // Piece::subshape_ points to subshapes within the Shape of the owning - // Literal/LiteralView. - shape_ = other.shape(); - pieces_ = other.pieces_; - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - } - owns_buffers_ = false; -} - } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 290f388..30442af 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -52,14 +51,491 @@ limitations under the License. namespace xla { +// Forward declare Literal and LiteralSlice class to be used by the creation +// methods in the base class. +class Literal; +class LiteralSlice; + +// Abstract base class for literals. +class LiteralBase { + public: + virtual ~LiteralBase() = 0; + + // Literals are equal if they have compatible shapes and the same data + // values. Layout is not compared. + bool operator==(const LiteralBase& other) const; + bool operator!=(const LiteralBase& other) const { return !(*this == other); } + + // Returns the shape of the literal. + const Shape& shape() const { return root_piece().subshape(); } + + // Serialize to proto. + LiteralProto ToProto() const; + + // Returns an ArraySlice of the array for this literal for the given NativeT + // (e.g., float). CHECKs if the subshape of the literal at the given + // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type + // to native type. + template + tensorflow::gtl::ArraySlice data( + const ShapeIndex& shape_index = {}) const; + + // Returns a const pointer to the sparse index array. Returns nullptr if the + // literal is not a sparse array. + const SparseIndexArray* sparse_indices( + const ShapeIndex& shape_index = {}) const; + + // Returns a const pointer to (or size of) the underlying buffer holding the + // array at the given shape index. CHECKs if the subshape of the literal at + // the given ShapeIndex is not array. + const void* untyped_data(const ShapeIndex& shape_index = {}) const; + int64 size_bytes(const ShapeIndex& shape_index = {}) const; + + // Returns this literal's data as a string. This literal must be a rank-1 U8 + // array. + string GetR1U8AsString() const; + + // Returns a string representation of the literal value. + // Warning: this function can take minutes for multi-million element Literals. + string ToString(bool print_layout = false) const; + + // Gets an element in the literal at the given index. The multi_index is + // CHECKed against the dimension sizes. + template + NativeT Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const; + // Overloads of Get for array literals. CHECKs if the literal is not + // array-shaped and dense. + template + NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; + + // Returns the element value at index (0, ..., 0), however many zeroes are + // required for that index. + template + NativeT GetFirstElement() const; + + // As Get(), but determines the correct type and converts the value + // into text. + string GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index = {}) const; + // As GetSparseElement(), but determines the correct type and converts the + // value into text. + string GetSparseElementAsString(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + // As Get(), but determines the correct type and converts the value into + // int64. This literal must be an array. + StatusOr GetIntegralAsS64( + tensorflow::gtl::ArraySlice multi_index) const; + + // Returns the multi-index of the element in a sparse literal at the given + // sparse element number. The sparse element number is the position with in + // the sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + tensorflow::gtl::ArraySlice GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; + + // Returns the value of the element in a sparse literal at the given sparse + // element number. The sparse element number is the position with in the + // sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + template + NativeT GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + + // Invokes the "per cell" callback for each element in the provided + // literal with the element's indices and a string representation of + // the element's value. + // + // This function is useful if you want a polymorphic representation + // of the tensor's elements (turning it to a string for something + // like representation in a protobuf). + // + // This literal must have a dense layout. + void EachCellAsString( + const std::function indices, + const string& value)>& per_cell) const; + template + void EachCell(std::function indices, + NativeT value)> + per_cell) const; + + // Returns whether every element in this literal is equal to value. + // + // value is an int8 because we expect this to be called with small + // compile-time constants (0, -1, etc.) and so that whatever value you pass + // can be represented exactly by floating-point types as small as 16 bits. + // + // If value doesn't fit in this literal's type, returns false. Values of 1/0 + // are considered equal to true/false; other values are not considered equal + // to true. Also if this literal is not array-shaped false is returned. + bool IsAll(int8 value) const; + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular floating-point number. + // + // If the literal is not a floating-point value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for values that can be expressed precisely as a float, + // e.g. -0.5. Also if this literal is not array-shaped false is returned. + bool IsAllFloat(float value) const; + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular complex number. + // + // If the literal is not a complex value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for complex values that can be expressed precisely as + // float pairs e.g. (-0.5, 1.0). + // + // This literal must have a dense layout. + bool IsAllComplex(complex64 value) const; + + // Literal consists entirely of the first element of the literal. + bool IsAllFirst() const; + + // Returns whether this literal is zero at the specified index. This literal + // must be an array with a dense layout. + bool IsZero(tensorflow::gtl::ArraySlice indices) const; + + // Returns the count of the elements in the array at the given shape index in + // this literal. + int64 element_count(const ShapeIndex& index = {}) const { + return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); + } + + // Return the count of the elements in the sparse array at the given shape + // index in this literal, which will be no larger than + // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). + int64 sparse_element_count() const; + + // Compute a hash for this literal. This literal must not be a sparse tensor + // or a tuple containing a sparse tensor. + size_t Hash() const; + + // Converts this literal to the given shape. Returns an error is the + // conversion is not possible. + // + // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding + // instead of truncation; otherwise, truncation is used. + // + // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes + // the default behavior. + StatusOr> ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16 = false) const; + + // Converts this literal to another primitive type using a bitcast + // conversion. The to and from primitive types must have the same bit + // width. Returns an error if the conversion is not possible. This literal + // must be array-shaped. + StatusOr> BitcastConvert( + PrimitiveType primitive_dest_type) const; + + // Converts this literal to another primitive type. Returns an error if the + // conversion is not possible. This literal must be array-shaped. + StatusOr> Convert( + PrimitiveType primitive_dest_type) const; + + // Returns a literal scalar representing the first element. + Literal GetFirstScalarLiteral() const; + + // Clones the underlying buffers into a new Literal, or new + // std::unique_ptr. + Literal Clone() const; + std::unique_ptr CloneToUnique() const; + + // TODO(b/67651157): The methods below which perform computation on Literals + // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with + // evaluator code which operates on Literals. + // + // Creates a new value that has the equivalent value as this + // literal, but conforms to new_layout; e.g. a literal matrix that was in {0, + // 1} minor-to-major dimension layout can be re-layed-out as {1, 0} + // minor-to-major dimension layout and the value in the cell at any given + // logical index (i0, i1) will be the same. + // + // For tuple shaped literals, shape_index should be used to select the inner + // array that the new layout applies to. + // + // Note: this is useful when the client wants to ensure that a value placed in + // the XLA allocation tracker has a particular layout; for efficiency + // purposes or avoiding unimplemented operation/layout combinations. + std::unique_ptr Relayout(const Layout& new_layout, + const ShapeIndex& shape_index = {}) const; + + // An overload of Relayout which changes the layout of the entire shape rather + // than being limited to a single array within the shape. + std::unique_ptr Relayout(const Shape& shape_with_layout) const; + + // Creates a new literal by reshaping this literal to have the given + // dimensions. The total number of elements must not change; The + // implementation currently only supports monotonic dim0-major layouts. + // This literal must be an array. + StatusOr> Reshape( + tensorflow::gtl::ArraySlice dimensions) const; + + // Creates a new literal by reordering the dimensions of this literal. + // The given `permutation` must be a permutation of the dimension numbers + // in the original literal, and it specifies the order of the new dimensions + // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). + // For example, a transpose call on a literal of shape [3 x 8 x 4] and + // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. + // This literal must be an array. + std::unique_ptr Transpose( + tensorflow::gtl::ArraySlice permutation) const; + + // Creates a sub-array from this literal by extracting the indices + // [start_index, limit_index) of each dimension. The result literal has the + // same rank and layout as for the given literal. The number of indices in + // start_indices and limit_indices must be the rank of the literal, and the + // indices follow the order of the dimensions. + // This literal must be an array. + std::unique_ptr Slice( + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) const; + + // Creates a literal with a prepended dimension with bound "times"; e.g. a + // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this + // literal replicated four times. + // This literal must be an array. + template + std::unique_ptr Replicate(int64 times) const; + + // Creates a new Literal object with the shape specified as parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr CreateFromShape(const Shape& shape); + + protected: + // A data structure representing a subshape at a particular ShapeIndex within + // the literal. For array-shaped ShapeIndexes, this data structure holds the + // pointer to the memory allocated for the array data. + class Piece { + public: + // Returns the buffer holding the array data for this piece as an array + // slice. This piece must be array-shaped. + template + tensorflow::gtl::ArraySlice data() const; + template + tensorflow::gtl::MutableArraySlice data(); + + // Returns the buffer holding the array data for this piece as a void*. This + // piece must be array-shaped. + void* untyped_data(); + const void* untyped_data() const; + + // Gets or sets an element in the array at the given index. The multi_index + // is CHECKed against the dimension sizes of the array. This piece must be + // array-shaped. + template + NativeT Get(tensorflow::gtl::ArraySlice index) const; + template + void Set(tensorflow::gtl::ArraySlice index, NativeT value); + + // Gets/sets the buffer holding the array data. + char* buffer() const { return buffer_; } + void set_buffer(char* buffer) { buffer_ = buffer; } + + // The array of multi-indices that provide the locations of non-zero + // elements in a sparse array. Only used if + // LayoutUtil::IsSparseArray(shape()) is true. + SparseIndexArray* sparse_indices() const { return sparse_indices_; } + void set_sparse_indices(SparseIndexArray* sparse_indices) { + sparse_indices_ = sparse_indices; + } + + // Gets or sets the subshape of this piece. This reference points to a + // subshape within the shape in the containing Literal (Literal::shape_). + const Shape& subshape() const { return *subshape_; } + void set_subshape(const Shape* subshape) { subshape_ = subshape; } + + // Returns the size in bytes of the buffer holding the array data. + int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } + + // Returns the number of elements in this piece's array. + int64 element_count() const { + // If this is a sparse array, use the number of elements represented by + // the indices in the associated SparseIndexArray. + return LayoutUtil::IsSparseArray(subshape()) + ? sparse_indices()->index_count() + : ShapeUtil::ElementsIn(subshape()); + } + + // Returns the child piece at 'index' of this piece. + Piece& child(int64 index) { return children_[index]; } + + // Adds a child piece to this piece's children. + void emplace_back(Piece child_piece) { + children_.emplace_back(std::move(child_piece)); + } + + // Returns the size of children pieces of this piece. + int64 children_size() { return children_.size(); } + + // Visitor functions that recursively traverses the piece and calls the + // given function at each child piece. The function has the type: + // void (const ShapeIndex& index, const Piece& piece) + template + void ForEachSubpiece(const Fn& func) const { + ShapeIndex index; + return ForEachHelper( + [&func](const ShapeIndex& index, const Piece& piece) { + func(index, piece); + return Status::OK(); + }, + *this, &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, const Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachSubpieceWithStatus(const Fn& func) const { + ShapeIndex index; + return ForEachHelper(func, *this, &index); + } + // Same as above, but the function has the type: + // Bool (const ShapeIndex& index, const Piece& piece) + // The first non-true return value is returned by the function. + template + bool ForEachSubpieceWithBool(const Fn& func) const { + ShapeIndex index; + return ForEachHelperBool(func, *this, &index); + } + // Same as above, but the function has the type: + // Void (const ShapeIndex& index, Piece& piece) + template + void ForEachMutableSubpiece(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + [&func](const ShapeIndex& index, Piece* piece) { + func(index, piece); + return Status::OK(); + }, + const_cast(this), &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachMutableSubpieceWithStatus(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + func, const_cast(this), &index); + } + + // Returns true if this piece and 'other' contain the same data. This piece + // and 'other' must be array-shaped and compatible. + bool EqualElements(const Piece& other) const; + + // Writes the shape and data (if array-shaped) into the given proto. + void WriteToProto(LiteralProto* proto) const; + + // Copy the data from 'src' into this piece's buffer. Shapes of this piece + // and src must be compatible. + Status CopyFrom(const Piece& src); + + // Copies the data from the given proto into this piece. The shape of this + // piece must be equal (not just compatible) to the shape of the proto. + Status CopyFromProto(const LiteralProto& proto); + + // Sorts the elements in a sparse array. + void SortSparseElements(); + + private: + // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'. + // The first non-OK (or non-true) value is returned by the function. + // The callable 'func' has the same signature as described above in + // ForEachSubpiece*. + template + Status ForEachHelper(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + template + bool ForEachHelperBool(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + if (!func(*index, piece)) { + return false; + } + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + if (!ForEachHelperBool(func, piece.children_[i], index)) { + return false; + } + index->pop_back(); + } + return true; + } + template + Status ForEachMutableHelper(const Fn& func, Piece* piece, + ShapeIndex* index) { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece->children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR( + ForEachMutableHelper(func, &piece->children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + + // Recursive helper for EqualElements. + template + bool EqualElementsInternal(const Piece& other, + std::vector* multi_index) const; + + // Helper for SortSparseElements that has the element type as a template + // parameter. + template + void SortSparseElementsInternal(); + + // For array-shaped pieces, this is the buffer holding the literal data. + char* buffer_ = nullptr; + + // For sparse arrays, this is the array of indices. + SparseIndexArray* sparse_indices_ = nullptr; + + // The shape of piece. This points into the shape of the containing Literal + // (Literal::shape_). + const Shape* subshape_ = nullptr; + + // Children pieces for tuple shaped pieces. + std::vector children_ = {}; + }; // class Piece + + const Piece& piece(const ShapeIndex& shape_index) const { + Piece* piece = &const_cast(root_piece()); + for (const auto i : shape_index) { + DCHECK_GE(i, 0); + DCHECK_LT(i, piece->children_size()); + piece = &piece->child(i); + } + return *piece; + } + + // Returns the piece at the root of the shape. + virtual const Piece& root_piece() const = 0; + + // LiteralSlice and Literal must access Pieces of other Literals. + friend class LiteralSlice; + friend class Literal; +}; + // Class representing literal values in XLA. // -// TODO(b/67651157): The methods in this class should be reduced to a minimal -// set of methods which construct Literals and accessors methods. Other methods -// which perform computation on Literals (Reshape, Slice, etc) should be moved -// elsewhere, and perhaps combined with evaluator code which operates on -// Literals. -class Literal { +// The underlying buffer and shape is always owned by this class. +class Literal : public LiteralBase { public: Literal() : Literal(ShapeUtil::MakeNil()) {} @@ -80,46 +556,156 @@ class Literal { Literal(const Shape& shape, bool allocate_arrays); Literal& operator=(Literal&& other); - // Literals are equal if they have compatible shapes and the same data - // values. Layout is not compared. - bool operator==(const Literal& other) const; - bool operator!=(const Literal& other) const { return !(*this == other); } + // TODO(b/67651157): Remove this accessor. Literal users should not be able to + // mutate the shape as this can produce malformed Literals. + Shape* mutable_shape_do_not_use() { return shape_.get(); } - // Serialize to and from a proto. - static StatusOr> CreateFromProto( - const LiteralProto& proto); - LiteralProto ToProto() const; + // Returns a MutableArraySlice view of the array for this literal for the + // given NativeT (e.g., float). CHECKs if the subshape of the literal at the + // given ShapeIndex is not array. See primitive_util.h for the mapping from + // XLA type to native type. + template + tensorflow::gtl::MutableArraySlice data( + const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::data; + + // Returns a pointer to the sparse index array. Returns nullptr if the literal + // is not a sparse array. + SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + + // Returns a pointer to the underlying buffer holding the array at the given + // shape index. CHECKs if the subshape of the literal at the given ShapeIndex + // is not array. + void* untyped_data(const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::untyped_data; + + // Populates a literal with a sparse layout with the given indices and values. + // Each index in the indices array is CHECKed against the dimensions in the + // literal's shape. If sort is true, then the indices and values will be + // sorted. If sort is false, then the indices and values are assumed to + // already be in sorted order. See CreateSparse for an example of how data + // are populated. + template + void PopulateSparse(SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, + bool sort = true); + + // Copy values from 'src_literal' rooted at 'src_shape_index' into this + // literal rooted at 'dest_shape_index'. The subshape of this literal rooted + // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' + // rooted at 'src_shape_index', but need not be arrays. + Status CopyFrom(const LiteralSlice& src_literal, + const ShapeIndex& dest_shape_index = {}, + const ShapeIndex& src_shape_index = {}); + + // Similar to CopyFrom, but with move semantincs. The subshape of this literal + // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' + // (layouts and shapes must match), but need not be arrays. The memory + // allocated in this literal for the subshape at dest_shape_index is + // deallocated, and the respective buffers are replaced with those in + // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). + Status MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index = {}); + + // Copies the values from src_literal, starting at src_base shape indexes, + // to this literal, starting at dest_base, where the copy size in each + // dimension is specified by copy_size. + // The src_literal and this literal must have the same primitive type, + // src_base+copy_size must fit the source literal dimensions, as well as + // dest_base+copy_size must fit the destination literal dimensions. + // Note: if either src_literal or this literal contains dimensions with zero + // element, then copy_size must be 0 in these dimensions while the + // corresponding base indices being 0. + // This literal and 'src_literal' must be arrays. + Status CopySliceFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); + + // Copies one element from src_literal[src_index] to (*this)[dest_index]. + Status CopyElementFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index); - // Return the shape of the literal. - const Shape& shape() const { return shape_; } + // Sets an element in the literal at the given index. The multi_index is + // CHECKed against the dimension sizes. + template + void Set(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index, NativeT value); + // Overloads of Set for array literals. CHECKs if the literal is not + // array-shaped and dense. + template + void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); + + // Appends the given element to the literal. If the elements are not appended + // in sorted order, then SortSparseElements should be called before calling + // other methods. This literal must have a sparse layout. + template + void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, + NativeT value, const ShapeIndex& shape_index = {}); + + // Sorts the elements in a sparse array. + void SortSparseElements(const ShapeIndex& shape_index = {}); + + // As Set(), but truncates `value` to the literal element type before storing. + // This literal must be an array. + Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, + int64 value); + + // Populate this literal with the given values. Examples: + // + // // Populate with floats. + // Array2D float_values = ... + // literal.PopulateR2FromArray2D(values); + // + // // Populate with int32s. + // literal.PopulateR2({{1, 2}, {3, 4}}); + // + // The shape and element type of this literal must match given values. For + // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 + // array of S32. + template + void PopulateR1(tensorflow::gtl::ArraySlice values); + void PopulateR1(const tensorflow::core::Bitmap& values); + template + void PopulateR2(std::initializer_list> values); + template + void PopulateFromArray(const Array& values); + template + void PopulateR2FromArray2D(const Array2D& values); + template + void PopulateR3FromArray3D(const Array3D& values); + template + void PopulateR4FromArray4D(const Array4D& values); + + // Populates literal values by calling the generator function for every cell + // in this literal object. + // + // generator must be a callable of the type + // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. + // + // This literal must have a dense layout. + template + Status Populate(const FnType& generator); - // TODO(b/67651157): Remove this accessor. Literal users should not be able to - // mutate the shape as this can produce malformed Literals. - Shape* mutable_shape_do_not_use() { return &shape_; } + // A parallel version of Populate(). This can be used if the generator is + // thread-safe and the values for the shape's different elements are + // independent. + template + Status PopulateParallel(const FnType& generator); - // Returns a (Mutable)ArraySlice view of the array for this literal for the - // given NativeT (e.g., float). CHECKs if the subshape of the literal at the - // given ShapeIndex is not array. See primitive_util.h for the mapping from - // XLA type to native type. - template - tensorflow::gtl::ArraySlice data( - const ShapeIndex& shape_index = {}) const; + // Fills this literal with the given value. template - tensorflow::gtl::MutableArraySlice data( - const ShapeIndex& shape_index = {}); + void PopulateWithValue(NativeT value); - // Returns a pointer to the sparse index array. Returns nullptr if the literal - // is not a sparse array. - const SparseIndexArray* sparse_indices( - const ShapeIndex& shape_index = {}) const; - SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + // Factory methods below. + // - // Returns a pointer to (or size of) the underlying buffer holding the array - // at the given shape index. CHECKs if the subshape of the literal at the - // given ShapeIndex is not array. - const void* untyped_data(const ShapeIndex& shape_index = {}) const; - void* untyped_data(const ShapeIndex& shape_index = {}); - int64 size_bytes(const ShapeIndex& shape_index = {}) const; + // Serialize from a proto. + static StatusOr> CreateFromProto( + const LiteralProto& proto); // Creates a new literal of a given rank. To minimize ambiguity (for users // and the compiler) these CreateR[0-2] methods should explicitly specify the @@ -167,10 +753,6 @@ class Literal { values, const Layout& layout); - // Returns this literal's data as a string. This literal must be a rank-1 U8 - // array. - string GetR1U8AsString() const; - // Creates a literal with a sparse layout and the given indices and values. // The shape is initialized from the given dimensions. The minor dimension of // the indices array must equal the rank of the shape (i.e. size of the @@ -210,171 +792,16 @@ class Literal { tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, tensorflow::gtl::ArraySlice values, bool sort = true); - // Populates a literal with a sparse layout with the given indices and values. - // Each index in the indices array is CHECKed against the dimensions in the - // literal's shape. If sort is true, then the indices and values will be - // sorted. If sort is false, then the indices and values are assumed to - // already be in sorted order. See CreateSparse for an example of how data - // are populated. - template - void PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort = true); - - // Creates a new Literal object with the shape specified as parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromShape(const Shape& shape); - - // Creates a new Literal object with its values havings the primitive_type - // type, and with dimensions defined by the dimensions parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromDimensions( - PrimitiveType primitive_type, - tensorflow::gtl::ArraySlice dimensions); - - // Copy values from 'src_literal' rooted at 'src_shape_index' into this - // literal rooted at 'dest_shape_index'. The subshape of this literal rooted - // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' - // rooted at 'src_shape_index', but need not be arrays. - Status CopyFrom(const Literal& src_literal, - const ShapeIndex& dest_shape_index = {}, - const ShapeIndex& src_shape_index = {}); - - // Similar to CopyFrom, but with move semantincs. The subshape of this literal - // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' - // (layouts and shapes must match), but need not be arrays. The memory - // allocated in this literal for the subshape at dest_shape_index is - // deallocated, and the respective buffers are replaced with those in - // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). - Status MoveFrom(Literal&& src_literal, - const ShapeIndex& dest_shape_index = {}); - - // Copies the values from src_literal, starting at src_base shape indexes, - // to this literal, starting at dest_base, where the copy size in each - // dimension is specified by copy_size. - // The src_literal and this literal must have the same primitive type, - // src_base+copy_size must fit the source literal dimensions, as well as - // dest_base+copy_size must fit the destination literal dimensions. - // Note: if either src_literal or this literal contains dimensions with zero - // element, then copy_size must be 0 in these dimensions while the - // corresponding base indices being 0. - // This literal and 'src_literal' must be arrays. - Status CopySliceFrom(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); - - // Copies one element from src_literal[src_index] to (*this)[dest_index]. - Status CopyElementFrom(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index); - - // Returns a vector containing the tuple elements of this Literal as separate - // Literals. This Literal must be tuple-shaped and can be a nested tuple. The - // elements are moved into the new Literals; no data is copied. Upon return - // this Literal is set to a nil shape (empty tuple) - std::vector DecomposeTuple(); - - // This operation is the inverse of DecomposeTuple. The given elements are - // moved into the tuple elements of a new tuple-shaped Literal which is - // returned. Upon return, each of the Literals in 'elements' is set to a nil - // shape (empty tuple). - static Literal MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements); - - // Creates a new value that has the equivalent value as this literal, but - // conforms to new_layout; e.g. a literal matrix that was in {0, 1} - // minor-to-major dimension layout can be re-layed-out as {1, 0} - // minor-to-major dimension layout and the value in the cell at any given - // logical index (i0, i1) will be the same. - // - // For tuple shaped literals, shape_index should be used to select the inner - // array that the new layout applies to. - // - // Note: this is useful when the client wants to ensure that a value placed in - // the XLA allocation tracker has a particular layout; for efficiency - // purposes or avoiding unimplemented operation/layout combinations. - std::unique_ptr Relayout(const Layout& new_layout, - const ShapeIndex& shape_index = {}) const; - - // An overload of Relayout which changes the layout of the entire shape rather - // than being limited to a single array within the shape. - std::unique_ptr Relayout(const Shape& shape_with_layout) const; - - // Creates a new literal by reshaping this literal to have the given - // dimensions. The total number of elements must not change; The - // implementation currently only supports monotonic dim0-major layouts. - // This literal must be an array. - StatusOr> Reshape( - tensorflow::gtl::ArraySlice dimensions) const; - - // Creates a new literal by reordering the dimensions of this literal. - // The given `permutation` must be a permutation of the dimension numbers - // in the original literal, and it specifies the order of the new dimensions - // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). - // For example, a transpose call on a literal of shape [3 x 8 x 4] and - // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. - // This literal must be an array. - std::unique_ptr Transpose( - tensorflow::gtl::ArraySlice permutation) const; - - // Creates a sub-array from this literal by extracting the indices - // [start_index, limit_index) of each dimension. The result literal has the - // same rank and layout as for the given literal. The number of indices in - // start_indices and limit_indices must be the rank of the literal, and the - // indices follow the order of the dimensions. - // This literal must be an array. - std::unique_ptr Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const; - - // Creates a literal with a prepended dimension with bound "times"; e.g. a - // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this - // literal replicated four times. - // This literal must be an array. - template - std::unique_ptr Replicate(int64 times) const; - - // Converts this literal to another primitive type using - // static_cast<>. Returns an error if the conversion is not possible. This - // literal must be array-shaped. - StatusOr> Convert( - PrimitiveType primitive_dest_type) const; - - // Converts this literal to another primitive type using a bitcast - // conversion. The to and from primitive types must have the same bit - // width. Returns an error if the conversion is not possible. This literal - // must be array-shaped. - StatusOr> BitcastConvert( - PrimitiveType primitive_dest_type) const; - - // Converts this literal to the given shape. Returns an error is the - // conversion is not possible. - // - // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding - // instead of truncation; otherwise, truncation is used. - // - // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes - // the default behavior. - StatusOr> ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16 = false) const; - // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); - // Creates a scalar literal value one of the given primitive type. static Literal One(PrimitiveType primitive_type); - // Creates a scalar literal value containing the minimum value of the given // primitive type. For floating-point types, returns -inf. static Literal MinValue(PrimitiveType primitive_type); - // Creates a scalar literal value containing the maximum value of the given // primitive type. For floating-point types, returns inf. static Literal MaxValue(PrimitiveType primitive_type); - // Creates a literal of the given shape where each element is `value`. template static std::unique_ptr CreateFullWithDescendingLayout( @@ -419,88 +846,15 @@ class Literal { // the z dimension given by "projection". template static std::unique_ptr CreateR3Projected( - std::initializer_list> values, - int64 projection); - - // Creates a literal that projects the (x, y) dimensions given in values into - // the z and p dimensions given. - template - static std::unique_ptr CreateR4Projected( - std::initializer_list> values, - int64 projection_p, int64 projection_z); - - // Clones this literal into a new Literal, or new std::unique_ptr. - Literal Clone() const; - std::unique_ptr CloneToUnique() const; - - // Gets or sets an element in the literal at the given index. The multi_index - // is CHECKed against the dimension sizes. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const; - template - void Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value); - - // Overloads of Get and Set for array literals. CHECKs if the literal is not - // array-shaped and dense. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; - template - void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); - - // Returns the multi-index of the element in a sparse literal at the given - // sparse element number. The sparse element number is the position with in - // the sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - tensorflow::gtl::ArraySlice GetSparseIndex( - int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; - - // Returns the value of the element in a sparse literal at the given sparse - // element number. The sparse element number is the position with in the - // sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - template - NativeT GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - - // Appends the given element to the literal. If the elements are not appended - // in sorted order, then SortSparseElements should be called before calling - // other methods. This literal must have a sparse layout. - template - void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, - NativeT value, const ShapeIndex& shape_index = {}); - - // Sorts the elements in a sparse array. - void SortSparseElements(const ShapeIndex& shape_index = {}); - - // Returns the element value at index (0, ..., 0), however many zeroes are - // required for that index. - template - NativeT GetFirstElement() const; - - // Returns a literal scalar representing the first element. - Literal GetFirstScalarLiteral() const; - - // As Get(), but determines the correct type and converts the value - // into text. - string GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index = {}) const; - - // As GetSparseElement(), but determines the correct type and converts the - // value into text. - string GetSparseElementAsString(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - - // As Get(), but determines the correct type and converts the value into - // int64. This literal must be an array. - StatusOr GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const; + std::initializer_list> values, + int64 projection); - // As Set(), but truncates `value` to the literal element type before storing. - // This literal must be an array. - Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value); + // Creates a literal that projects the (x, y) dimensions given in values into + // the z and p dimensions given. + template + static std::unique_ptr CreateR4Projected( + std::initializer_list> values, + int64 projection_p, int64 projection_z); // Returns an identity matrix (rank 2) with the given row and column count. template @@ -511,6 +865,9 @@ class Literal { static std::unique_ptr MakeTuple( tensorflow::gtl::ArraySlice elements); + static std::unique_ptr MakeTupleFromSlices( + tensorflow::gtl::ArraySlice elements); + // As above, but intended to be invoked with move semantics; i.e. // // std::vector> elements = ...; @@ -542,135 +899,48 @@ class Literal { return MakeTupleOwned(std::move(v)); } - // Returns a string representation of the literal value. - // Warning: this function can take minutes for multi-million element Literals. - string ToString(bool print_layout = false) const; - - // Invokes the "per cell" callback for each element in the provided - // literal with the element's indices and a string representation of - // the element's value. - // - // This function is useful if you want a polymorphic representation - // of the tensor's elements (turning it to a string for something - // like representation in a protobuf). - // - // This literal must have a dense layout. - void EachCellAsString( - const std::function indices, - const string& value)>& per_cell) const; - template - void EachCell(std::function indices, - NativeT value)> - per_cell) const; - - // Populate this literal with the given values. Examples: - // - // // Populate with floats. - // Array2D float_values = ... - // literal.PopulateR2FromArray2D(values); - // - // // Populate with int32s. - // literal.PopulateR2({{1, 2}, {3, 4}}); - // - // The shape and element type of this literal must match given values. For - // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 - // array of S32. - template - void PopulateR1(tensorflow::gtl::ArraySlice values); - void PopulateR1(const tensorflow::core::Bitmap& values); - template - void PopulateR2(std::initializer_list> values); - template - void PopulateFromArray(const Array& values); - template - void PopulateR2FromArray2D(const Array2D& values); - template - void PopulateR3FromArray3D(const Array3D& values); - template - void PopulateR4FromArray4D(const Array4D& values); - - // Populates literal values by calling the generator function for every cell - // in this literal object. - // - // generator must be a callable of the type - // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. - // - // This literal must have a dense layout. - template - Status Populate(const FnType& generator); - - // A parallel version of Populate(). This can be used if the generator is - // thread-safe and the values for the shape's different elements are - // independent. - template - Status PopulateParallel(const FnType& generator); - - // Fills this literal with the given value. - template - void PopulateWithValue(NativeT value); + // Returns a vector containing the tuple elements of this Literal as separate + // Literals. This Literal must be tuple-shaped and can be a nested tuple. The + // elements are moved into the new Literals; no data is copied. Upon return + // this Literal is set to a nil shape (empty tuple) + std::vector DecomposeTuple(); - // Returns whether every element in this literal is equal to value. - // - // value is an int8 because we expect this to be called with small - // compile-time constants (0, -1, etc.) and so that whatever value you pass - // can be represented exactly by floating-point types as small as 16 bits. - // - // If value doesn't fit in this literal's type, returns false. Values of 1/0 - // are considered equal to true/false; other values are not considered equal - // to true. Also if this literal is not array-shaped false is returned. - bool IsAll(int8 value) const; + // This operation is the inverse of DecomposeTuple. The given elements are + // moved into the tuple elements of a new tuple-shaped Literal which is + // returned. Upon return, each of the Literals in 'elements' is set to a nil + // shape (empty tuple). + static Literal MoveIntoTuple( + tensorflow::gtl::MutableArraySlice elements); - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular floating-point number. - // - // If the literal is not a floating-point value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for values that can be expressed precisely as a float, - // e.g. -0.5. Also if this literal is not array-shaped false is returned. - bool IsAllFloat(float value) const; + // Creates a new Literal object with its values havings the primitive_type + // type, and with dimensions defined by the dimensions parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr CreateFromDimensions( + PrimitiveType primitive_type, + tensorflow::gtl::ArraySlice dimensions); - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular complex number. - // - // If the literal is not a complex value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for complex values that can be expressed precisely as - // float pairs e.g. (-0.5, 1.0). // - // This literal must have a dense layout. - bool IsAllComplex(complex64 value) const; + // End of factory methods. - // Literal consists entirely of the first element of the literal. - bool IsAllFirst() const; - - // Returns whether this literal is zero at the specified index. This literal - // must be an array with a dense layout. - bool IsZero(tensorflow::gtl::ArraySlice indices) const; + protected: + // Recursively sets the subshapes and buffers of all subpieces rooted at + // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in + // the shape. + void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); - // Return the count of the elements in the array at the given shape index in - // this literal. - int64 element_count(const ShapeIndex& index = {}) const { - return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); + // Returns the piece at the given ShapeIndex. + Piece& piece(const ShapeIndex& shape_index) { + return const_cast(LiteralBase::piece(shape_index)); } - // Return the count of the elements in the sparse array at the given shape - // index in this literal, which will be no larger than - // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). - int64 sparse_element_count() const; - - // Compute a hash for this literal. This literal must not be a sparse tensor - // or a tuple containing a sparse tensor. - size_t Hash() const; + Piece& root_piece() const override { return *root_piece_; }; - protected: + private: // Internal template helper for the Literal::CopySliceFrom(), matching its // arguments one by one. template - Status CopySliceFromInternal(const Literal& src_literal, + Status CopySliceFromInternal(const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size); @@ -698,162 +968,40 @@ class Literal { int64 minor_loop_size = 1; }; - // A data structure representing a subshape at a particular ShapeIndex within - // the literal. For array-shaped ShapeIndexes, this data structure holds the - // pointer to the memory allocated for the array data. - class Piece { - public: - // Return the buffer holding the array data for this piece as an array - // slice. This piece must be array-shaped. - template - tensorflow::gtl::ArraySlice data() const; - template - tensorflow::gtl::MutableArraySlice data(); - - // Return the buffer holding the array data for this piece as a void*. This - // piece must be array-shaped. - void* untyped_data(); - const void* untyped_data() const; - - // Gets or sets an element in the array at the given index. The multi_index - // is CHECKed against the dimension sizes of the array. This piece must be - // array-shaped. - template - NativeT Get(tensorflow::gtl::ArraySlice index) const; - template - void Set(tensorflow::gtl::ArraySlice index, NativeT value); - - // Gets/sets the buffer holding the array data. - char* buffer() const { return buffer_; } - void set_buffer(char* buffer) { buffer_ = buffer; } - - // The array of multi-indices that provide the locations of non-zero - // elements in a sparse array. Only used if - // LayoutUtil::IsSparseArray(shape()) is true. - SparseIndexArray* sparse_indices() const { return sparse_indices_; } - void set_sparse_indices(SparseIndexArray* sparse_indices) { - sparse_indices_ = sparse_indices; - } - - // Gets or sets the subshape of this piece. This reference points to a - // subshape within the shape in the containing Literal (Literal::shape_). - const Shape& subshape() const { return *subshape_; } - void set_subshape(const Shape* subshape) { subshape_ = subshape; } - - // Returns the size in bytes of the buffer holding the array data. - int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } - - // Returns the number of elements in this piece's array. - int64 element_count() const { - // If this is a sparse array, use the number of elements represented by - // the indices in the associated SparseIndexArray. - return LayoutUtil::IsSparseArray(subshape()) - ? sparse_indices()->index_count() - : ShapeUtil::ElementsIn(subshape()); - } - - // Copy the data from 'src' into this piece's buffer. Shapes of this piece - // and src must be compatible. - Status CopyFrom(const Piece& src); - - // Returns true if this piece and 'other' contain the same data. This piece - // and 'other' must be array-shaped and compatible. - bool EqualElements(const Piece& other) const; - - // Writes the shape and data (if array-shaped) into the given proto. - void WriteToProto(LiteralProto* proto) const; - - // Copies the data from the given proto into this piece. The shape of this - // piece must be equal (not just compatible) to the shape of the proto. - Status CopyFromProto(const LiteralProto& proto); - - // Sorts the elements in a sparse array. - void SortSparseElements(); - - private: - // Recursive helper for EqualElements. - template - bool EqualElementsInternal(const Piece& other, - std::vector* multi_index) const; - - // Helper for SortSparseElements that has the element type as a template - // parameter. - template - void SortSparseElementsInternal(); - - // For array-shaped pieces, this is the buffer holding the literal data. - char* buffer_ = nullptr; - - // For sparse arrays, this is the array of indices. - SparseIndexArray* sparse_indices_ = nullptr; - - // The shape of piece. This points into the shape of the containing Literal - // (Literal::shape_). - const Shape* subshape_ = nullptr; - }; - - // Returns the piece at the given ShapeIndex. - Piece& piece(const ShapeIndex& shape_index) { - return *pieces_.mutable_element(shape_index); - } - const Piece& piece(const ShapeIndex& shape_index) const { - return pieces_.element(shape_index); - } - - // Returns the piece at the root of the shape (empty ShapeIndex). - Piece& root_piece() { return piece({}); } - const Piece& root_piece() const { return piece({}); } + // Literal class always owns the shape. The parent class borrows this shape. + std::unique_ptr shape_; - // Deallocate the buffers held by this literal (if the literal owns the - // buffer). - void DeallocateBuffers(); + Piece* root_piece_ = nullptr; // Implementation details shared between Populate() and PopulateParallel() template Status PopulateInternal(const FnType& generator, bool parallel); - Shape shape_; - ShapeTree pieces_; - - // Whether the buffers held in pieces_ are owned by this Literal. - bool owns_buffers_; + // Deallocate the buffers held by this literal. + void DeallocateBuffers(); - // LiteralView must access and manipulate Pieces of other Literals. - friend class LiteralView; -}; // namespace xla + friend class LiteralBase; +}; std::ostream& operator<<(std::ostream& out, const Literal& literal); -// A read-only view of a Literal. A LiteralView contains pointers to buffers -// owned by the viewed Literal. -// -// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable -// and mutable) similar to (Mutable)ArraySlice. -class LiteralView : public Literal { +// A read-only view of a Literal. A LiteralSlice contains pointers to shape and +// literal buffers always owned by others. +class LiteralSlice : public LiteralBase { public: - // Create and return a view of the given literal rooted at the given shape - // index within the given literal. A factory is used rather than a public - // constructor because only const LiteralViews are supported. It's still - // possible to create non-const LiteralViews via the copy constructors, but - // the factory method makes it a bit less likely. Implementing literal slices - // will fix this undesirable situation (b/71550060). - static const LiteralView Create(const Literal& literal, - const ShapeIndex& view_root = {}); - - LiteralView(const LiteralView& other); - LiteralView& operator=(const LiteralView& other); - - virtual ~LiteralView(); + LiteralSlice() : LiteralBase() {} + // Implicit conversion constructor that can also accept Literal. + LiteralSlice(const LiteralBase& literal); + LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root); private: - LiteralView(const Literal& literal, const ShapeIndex& view_root); + const Piece& root_piece() const override { return *root_piece_; }; - // Helper for the copy constructor and copy assignment operator. - void CopyFrom(const LiteralView& other); + const Piece* root_piece_; // Not owned. }; template -tensorflow::gtl::ArraySlice Literal::Piece::data() const { +tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); CHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) @@ -866,7 +1014,7 @@ tensorflow::gtl::ArraySlice Literal::Piece::data() const { } template -tensorflow::gtl::MutableArraySlice Literal::Piece::data() { +tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); CHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) @@ -879,7 +1027,7 @@ tensorflow::gtl::MutableArraySlice Literal::Piece::data() { } template -NativeT Literal::Piece::Get( +NativeT LiteralBase::Piece::Get( tensorflow::gtl::ArraySlice multi_index) const { CHECK(LayoutUtil::IsDenseArray(subshape())); return data()[IndexUtil::MultidimensionalIndexToLinearIndex( @@ -887,15 +1035,15 @@ NativeT Literal::Piece::Get( } template -void Literal::Piece::Set(tensorflow::gtl::ArraySlice multi_index, - NativeT value) { +void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, + NativeT value) { CHECK(LayoutUtil::IsDenseArray(subshape())); data()[IndexUtil::MultidimensionalIndexToLinearIndex( subshape(), multi_index)] = value; } template -tensorflow::gtl::ArraySlice Literal::data( +tensorflow::gtl::ArraySlice LiteralBase::data( const ShapeIndex& shape_index) const { return piece(shape_index).data(); } @@ -907,13 +1055,13 @@ tensorflow::gtl::MutableArraySlice Literal::data( } template -inline NativeT Literal::Get(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const { +inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { return piece(shape_index).Get(multi_index); } template -inline NativeT Literal::Get( +inline NativeT LiteralBase::Get( tensorflow::gtl::ArraySlice multi_index) const { return root_piece().Get(multi_index); } @@ -1160,13 +1308,13 @@ template } template -NativeT Literal::GetFirstElement() const { +NativeT LiteralBase::GetFirstElement() const { return data().at(0); } template -NativeT Literal::GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index) const { +NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index) const { CHECK( LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); return data(shape_index)[sparse_element_number]; @@ -1199,7 +1347,7 @@ template } template -void Literal::EachCell( +void LiteralBase::EachCell( std::function indices, NativeT value)> per_cell) const { @@ -1375,7 +1523,7 @@ template } template -std::unique_ptr Literal::Replicate(int64 times) const { +std::unique_ptr LiteralBase::Replicate(int64 times) const { DimensionVector bounds = {times}; bounds.reserve(shape().dimensions_size() + 1); for (int64 bound : shape().dimensions()) { diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 6104678..087d509 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -974,7 +974,7 @@ TEST_F(LiteralUtilTest, CopyFromTuples) { Literal::CreateR1({2.0, 4.0}).get(), &nil_literal}); - EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); + EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); EXPECT_EQ(nested_tuple->Get({}, {1, 0}), 42); EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 23.0); EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 44.0); @@ -985,7 +985,7 @@ TEST_F(LiteralUtilTest, CopyFromTuples) { /*src_shape_index=*/{})); // The matrix element should be unchanged. - EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); + EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); // The tuple element should have been copied from 'tuple'. EXPECT_EQ(nested_tuple->Get({}, {1, 0}), -5); @@ -1373,36 +1373,36 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { ASSERT_EQ(h1, r[3]); } -TEST_F(LiteralUtilTest, LiteralViewTest) { +TEST_F(LiteralUtilTest, LiteralSliceTest) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); Literal nil(ShapeUtil::MakeNil()); - EXPECT_EQ(LiteralView::Create(*scalar, {}), *scalar); - EXPECT_EQ(LiteralView::Create(*matrix, {}), *matrix); - EXPECT_EQ(LiteralView::Create(*tuple, {}), *tuple); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {}), *nested_tuple); - EXPECT_EQ(LiteralView::Create(nil, {}), nil); + EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar); + EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix); + EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple); + EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple); + EXPECT_EQ(LiteralSlice(nil, {}), nil); - EXPECT_EQ(LiteralView::Create(*tuple, {0}), *scalar); - EXPECT_EQ(LiteralView::Create(*tuple, {1}), *matrix); + EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar); + EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0}), *tuple); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 0}), *scalar); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 1}), *matrix); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {1}), *scalar); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix); + EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar); } -TEST_F(LiteralUtilTest, MutatingLiteralView) { +TEST_F(LiteralUtilTest, MutatingLiteralSlice) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); // Verify that changing the underlying data beneath the view changes the // data of the view itself. - const auto nested_tuple_view = LiteralView::Create(*nested_tuple); + const auto nested_tuple_view = LiteralSlice(*nested_tuple); EXPECT_EQ( nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 1.0f); @@ -1418,16 +1418,15 @@ TEST_F(LiteralUtilTest, MutatingLiteralView) { 555.0f); } -TEST_F(LiteralUtilTest, LiteralViewOfALiteralView) { +TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); - const auto nested_tuple_view = LiteralView::Create(*nested_tuple); - const auto tuple_view = - LiteralView::Create(nested_tuple_view, /*view_root=*/{0}); - const auto matrix_view = LiteralView::Create(tuple_view, /*view_root=*/{1}); + const auto nested_tuple_view = LiteralSlice(*nested_tuple); + const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0}); + const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1}); EXPECT_EQ(matrix_view, *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } @@ -1533,11 +1532,11 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { EXPECT_EQ(literal.Get({1, 1}), 4.0); } -TEST_F(LiteralUtilTest, LiteralViewCopy) { +TEST_F(LiteralUtilTest, LiteralSliceCopy) { std::unique_ptr matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - const auto matrix_view = LiteralView::Create(*matrix); - LiteralView matrix_view_copy(matrix_view); + const auto matrix_view = LiteralSlice(*matrix); + LiteralSlice matrix_view_copy(matrix_view); EXPECT_EQ(matrix_view_copy.Get({0, 0}), 1.0); EXPECT_EQ(matrix_view_copy.Get({0, 1}), 2.0); diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index dc6f5fe..68648a3 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -340,13 +340,13 @@ StatusOr OpMetadataFromPyObject(PyObject* o) { return result; } -PyObject* PyObjectFromXlaLiteral(const Literal& literal) { +PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { if (ShapeUtil::IsTuple(literal.shape())) { int num_elements = ShapeUtil::TupleElementCount(literal.shape()); PyObject* tuple = PyTuple_New(num_elements); for (int i = 0; i < num_elements; i++) { - PyTuple_SET_ITEM( - tuple, i, PyObjectFromXlaLiteral(LiteralView::Create(literal, {i}))); + PyTuple_SET_ITEM(tuple, i, + PyObjectFromXlaLiteral(LiteralSlice(literal, {i}))); } return tuple; } else { @@ -431,7 +431,7 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, return Status::OK(); } -void CopyLiteralToNumpyArray(int np_type, const Literal& literal, +void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, PyArrayObject* py_array) { switch (np_type) { case NPY_BOOL: diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 9656cb1..64f0aae 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -74,7 +74,7 @@ StatusOr OpMetadataFromPyObject(PyObject* o); // array data. // // The return value is a new reference. -PyObject* PyObjectFromXlaLiteral(const Literal& literal); +PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal); // Converts a Numpy ndarray or a nested Python tuple thereof to a // corresponding XLA literal. @@ -90,7 +90,7 @@ StatusOr > XlaLiteralFromPyObject(PyObject* o); Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, Literal* literal); -void CopyLiteralToNumpyArray(int np_type, const Literal& literal, +void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, PyArrayObject* py_array); template @@ -101,7 +101,8 @@ void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) { } template -void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) { +void CopyLiteralToNumpyArray(const LiteralSlice& literal, + PyArrayObject* py_array) { NativeT* dest = static_cast(PyArray_DATA(py_array)); auto source = literal.data(); std::copy(source.begin(), source.end(), dest); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 4ec79a0..3ce80bb 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -501,13 +501,13 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( } static HloInstruction* BuildTupleConstant(HloComputation* computation, - const Literal& literal) { + const LiteralSlice& literal) { if (ShapeUtil::IsTuple(literal.shape())) { std::vector elems; elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) { elems.push_back( - BuildTupleConstant(computation, LiteralView::Create(literal, {i}))); + BuildTupleConstant(computation, LiteralSlice(literal, {i}))); } return computation->AddInstruction(HloInstruction::CreateTuple(elems)); } else { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 9b39e7f..d97802e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -88,8 +88,8 @@ CpuTransferManager::CpuTransferManager() : GenericTransferManager(se::host::kHostPlatformId, /*pointer_size=*/sizeof(void*)) {} -Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) { +Status CpuTransferManager::TransferLiteralToInfeed( + se::StreamExecutor* executor, const LiteralSlice& literal) { const Shape& shape = literal.shape(); VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index 3ecb0d2..6dfc666 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -38,7 +38,7 @@ class CpuTransferManager : public GenericTransferManager { ~CpuTransferManager() override {} Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) override; + const LiteralSlice& literal) override; Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, const void* source) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc index 7dcc4ca..c562865 100644 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc @@ -26,13 +26,13 @@ limitations under the License. namespace xla { namespace cpu { -void ExternalConstantPool::Insert(string name, const Literal& literal, +void ExternalConstantPool::Insert(string name, const LiteralSlice& literal, int64 alignment) { CHECK(!ShapeUtil::IsTuple(literal.shape())); CHECK(alignment > 0 && IsPowerOfTwo(static_cast(alignment))); CHECK(entries_.find(name) == entries_.end()); - int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); + const int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); void* raw_pointer = tensorflow::port::AlignedMalloc( literal_size, std::max(alignment, sizeof(void*))); CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h index 8008a56..0677f5f 100644 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h @@ -43,7 +43,7 @@ class ExternalConstantPool { // The constant pool copies out the contents of `literal` into a buffer it // owns -- it does not keep pointers to `literal`, or to memory owned by // `literal`. - void Insert(string name, const Literal& literal, int64 alignment); + void Insert(string name, const LiteralSlice& literal, int64 alignment); // Find the constant with name `name` in this constant pool. If there isn't // such constant, return nullptr. diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index ddb6873..dbf1ab66 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -115,7 +115,7 @@ Status GenericTransferManager::TransferLiteralToDevice( TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == device_memory.size()); // Element is array-shaped: transfer array data to device buffer. - const auto subliteral = LiteralView::Create(literal, index); + const auto subliteral = LiteralSlice(literal, index); std::unique_ptr relayed_out_literal; const void* source; if (LayoutUtil::Equal(device_subshape.layout(), @@ -137,7 +137,7 @@ Status GenericTransferManager::TransferLiteralToDevice( } Status GenericTransferManager::TransferLiteralToInfeed( - se::StreamExecutor* executor, const Literal& literal) { + se::StreamExecutor* executor, const LiteralSlice& literal) { return Unimplemented("Generic transfer to Infeed"); } diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 0579099..3343eca 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -49,7 +49,7 @@ class GenericTransferManager : public TransferManager { const ShapedBuffer& device_buffer) override; Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) override; + const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, Literal* literal) override; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index f13727c..7bb8df6 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -44,8 +44,8 @@ GpuTransferManager::GpuTransferManager() /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout) .getPointerSize(0 /* default address space */)) {} -Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) { +Status GpuTransferManager::TransferLiteralToInfeed( + se::StreamExecutor* executor, const LiteralSlice& literal) { const Shape& shape = literal.shape(); VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h index d040a99..09f8227 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h @@ -37,7 +37,7 @@ class GpuTransferManager : public GenericTransferManager { ~GpuTransferManager() override {} Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) override; + const LiteralSlice& literal) override; Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, const void* source) override; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index fffe192..63eaf6f 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -56,8 +56,8 @@ using tensorflow::gtl::FlatSet; template StatusOr> Compare(const Shape& shape, HloOpcode opcode, - const Literal& lhs_literal, - const Literal& rhs_literal) { + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -106,8 +106,8 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, template <> StatusOr> Compare( - const Shape& shape, HloOpcode opcode, const Literal& lhs_literal, - const Literal& rhs_literal) { + const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index d82b4f0..55c544f 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -81,7 +81,7 @@ class TransferManager { // Transfers the given literal into the Infeed interface of the device, // using the given executor. virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) = 0; + const LiteralSlice& literal) = 0; // Transfers the given literal from the Outfeed interface of the device, // using the given executor. diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 6ebbf71..a180cdd 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -87,11 +87,11 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { LiteralTestUtil::ExpectNear( *Literal::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), - LiteralView::Create(*result, {0}), error_spec_); + LiteralSlice(*result, {0}), error_spec_); LiteralTestUtil::ExpectNear( *Literal::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), - LiteralView::Create(*result, {1}), error_spec_); + LiteralSlice(*result, {1}), error_spec_); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 0b425b9..abf7312 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -91,9 +91,9 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { auto result, client_->ExecuteAndTransfer(computation, {}, &execution_options)); LiteralTestUtil::ExpectR2Equal({{1, 2}, {3, 4}}, - LiteralView::Create(*result, {0})); + LiteralSlice(*result, {0})); LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, - LiteralView::Create(*result, {1})); + LiteralSlice(*result, {1})); EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 4743673..d518e4a 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -169,9 +169,9 @@ TEST_F(ConstantsTest, DISABLED_TupleConstant) { ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); LiteralTestUtil::ExpectR2Near( - {{1.0}, {2.0}}, LiteralView::Create(*result, {0}), error_spec_); + {{1.0}, {2.0}}, LiteralSlice(*result, {0}), error_spec_); LiteralTestUtil::ExpectR1Near( - {2.0, 42.0}, LiteralView::Create(*result, {1}), error_spec_); + {2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_); } } // namespace diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index c28f79a..868876c 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -111,7 +111,7 @@ namespace { // Return a literal with all arrays of type FromNativeT converted to type // ToNativeT in the given literal. template -std::unique_ptr ConvertType(const Literal& literal) { +std::unique_ptr ConvertType(LiteralSlice literal) { // First construct shape of the result. Shape result_shape(literal.shape()); ShapeUtil::ForEachMutableSubshape( @@ -150,12 +150,12 @@ std::unique_ptr ConvertType(const Literal& literal) { } // namespace /* static */ std::unique_ptr LiteralTestUtil::ConvertBF16ToF32( - const Literal& literal) { + LiteralSlice literal) { return ConvertType(literal); } /* static */ std::unique_ptr LiteralTestUtil::ConvertF32ToBF16( - const Literal& literal) { + LiteralSlice literal) { return ConvertType(literal); } @@ -237,7 +237,7 @@ template <> // actual literal and compares their values elementwise. Returns true if all // elements are equal. template -bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, +bool ExpectLiteralsEqual(LiteralSlice expected, LiteralSlice actual, tensorflow::gtl::MutableArraySlice multi_index, int64 dimension) { if (dimension == expected.shape().dimensions_size()) { @@ -259,8 +259,8 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, } // namespace -/* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected, - const Literal& actual, +/* static */ void LiteralTestUtil::ExpectEqual(LiteralSlice expected, + LiteralSlice actual, const string& message) { EXPECT_TRUE(Equal(expected, actual)) << "expected:\n" @@ -269,13 +269,13 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, << (message.empty() ? "" : StrCat("\nmessage: ", message)); } -/* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected, - const Literal& actual) { +/* static */ void LiteralTestUtil::ExpectNotEqual(LiteralSlice expected, + LiteralSlice actual) { EXPECT_FALSE(Equal(expected, actual)); } /* static */ ::testing::AssertionResult LiteralTestUtil::Equal( - const Literal& expected, const Literal& actual) { + LiteralSlice expected, LiteralSlice actual) { VLOG(1) << "expected:"; XLA_VLOG_LINES(1, expected.ToString()); VLOG(1) << "actual:"; @@ -324,9 +324,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, SCOPED_TRACE(StrCat("Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape()))); - // Create LiteralViews of the expected and actual elements. - auto result = Equal(LiteralView::Create(expected, {i}), - LiteralView::Create(actual, {i})); + // Create LiteralSlices of the expected and actual elements. + auto result = + Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i})); tuple_match = tuple_match ? !!result : false; } match = tuple_match; @@ -368,7 +368,7 @@ int64 RecursiveElementCount(const Shape& shape) { // 3 minutes. The utility of printing a literal with >1000 elements is // questionable, especially when writing the Literal proto to disk is orders // of magnitude faster. -string TruncateHugeLiteral(const Literal& literal) { +string TruncateHugeLiteral(LiteralSlice literal) { return RecursiveElementCount(literal.shape()) < 1000 ? literal.ToString() : "[TRUNCATED, Literal with more than 1000 values]"; @@ -435,8 +435,8 @@ class NearComparator { // result. The assertion result is successful if all actual and expected // elements are within the given error bound. In case of error, the assertion // result contains a detailed error message in case of failure. - static ::testing::AssertionResult Compare(const Literal& expected, - const Literal& actual, + static ::testing::AssertionResult Compare(LiteralSlice expected, + LiteralSlice actual, ErrorSpec error, bool detailed_message) { NearComparator comparator(expected, actual, error, @@ -472,7 +472,7 @@ class NearComparator { } }; - explicit NearComparator(const Literal& expected, const Literal& actual, + explicit NearComparator(LiteralSlice expected, LiteralSlice actual, ErrorSpec error, bool detailed_message) : expected_(expected), actual_(actual), @@ -649,7 +649,7 @@ class NearComparator { } // Writes the given literal to a file in the test temporary directory. - void WriteLiteralToTempFile(const Literal& literal, const string& name) { + void WriteLiteralToTempFile(LiteralSlice literal, const string& name) { int64 now_usec = tensorflow::Env::Default()->NowMicros(); string filename = tensorflow::io::JoinPath( tensorflow::testing::TmpDir(), @@ -733,8 +733,8 @@ class NearComparator { } // 'actual' and 'expected' literals being compared. - const Literal& expected_; - const Literal& actual_; + LiteralSlice expected_; + LiteralSlice actual_; // The error bounds of the comparison. ErrorSpec error_; @@ -794,8 +794,8 @@ constexpr std::array NearComparator::kErrorBucketBounds; // Helper function for comparing two literals for nearness. Handles tuple-shapes // via recursion. shape_index is the ShapeIndex of expected (or actual) // currently being compared. -::testing::AssertionResult NearHelper(const Literal& expected, - const Literal& actual, +::testing::AssertionResult NearHelper(LiteralSlice expected, + LiteralSlice actual, const ErrorSpec& error, bool detailed_message, const ShapeIndex& shape_index) { @@ -807,8 +807,8 @@ constexpr std::array NearComparator::kErrorBucketBounds; if (ShapeUtil::IsTuple(expected.shape())) { for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - const auto expected_element = LiteralView::Create(expected, {i}); - const auto actual_element = LiteralView::Create(actual, {i}); + const auto expected_element = LiteralSlice(expected, {i}); + const auto actual_element = LiteralSlice(actual, {i}); ShapeIndex element_index = shape_index; element_index.push_back(i); ::testing::AssertionResult res = @@ -874,14 +874,14 @@ constexpr std::array NearComparator::kErrorBucketBounds; } // namespace /* static */ ::testing::AssertionResult LiteralTestUtil::Near( - const Literal& expected, const Literal& actual, const ErrorSpec& error, + LiteralSlice expected, LiteralSlice actual, const ErrorSpec& error, bool detailed_message) { return NearHelper(expected, actual, error, detailed_message, /*shape_index=*/{}); } -/* static */ void LiteralTestUtil::ExpectNear(const Literal& expected, - const Literal& actual, +/* static */ void LiteralTestUtil::ExpectNear(LiteralSlice expected, + LiteralSlice actual, const ErrorSpec& error, const string& message) { ::testing::AssertionResult res = @@ -897,7 +897,7 @@ constexpr std::array NearComparator::kErrorBucketBounds; } /*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( - const Literal& expected, const Literal& actual, + LiteralSlice expected, LiteralSlice actual, const tensorflow::gtl::optional& error) { if (error.has_value()) { VLOG(1) << "Expects near"; @@ -908,7 +908,7 @@ constexpr std::array NearComparator::kErrorBucketBounds; } /*static*/ void LiteralTestUtil::ExpectNearOrEqual( - const Literal& expected, const Literal& actual, + LiteralSlice expected, LiteralSlice actual, const tensorflow::gtl::optional& error) { EXPECT_TRUE(NearOrEqual(expected, actual, error)); } @@ -920,7 +920,7 @@ constexpr std::array NearComparator::kErrorBucketBounds; /* static */ std::unique_ptr LiteralTestUtil::Reshape( tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, const Literal& literal) { + tensorflow::gtl::ArraySlice minor_to_major, LiteralSlice literal) { int64 new_num_elements = 1; for (int64 i = 0; i < new_dimensions.size(); ++i) { new_num_elements *= new_dimensions[i]; diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index a755568..4983ddd 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -69,53 +69,53 @@ class LiteralTestUtil { // If the given literal's data type is bfloat16, converts it to a float // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. - static std::unique_ptr ConvertBF16ToF32(const Literal& bf16_literal); + static std::unique_ptr ConvertBF16ToF32(LiteralSlice bf16_literal); // If the given literal's data type is float, converts it to a bfloat16 // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. - static std::unique_ptr ConvertF32ToBF16(const Literal& f32_literal); + static std::unique_ptr ConvertF32ToBF16(LiteralSlice f32_literal); // Asserts that the expected and actual literals are (bitwise) equal for all // elements in the literal. Also, asserts that the rank, dimensions sizes, and // primitive type are equal. static ::testing::AssertionResult Equal( - const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT; + LiteralSlice expected, LiteralSlice actual) TF_MUST_USE_RESULT; // Expects that expected and actual are Equal. - static void ExpectEqual(const Literal& expected, const Literal& actual, + static void ExpectEqual(LiteralSlice expected, LiteralSlice actual, const string& message = ""); // Expects that expected and actual are Not Equal. - static void ExpectNotEqual(const Literal& expected, const Literal& actual); + static void ExpectNotEqual(LiteralSlice expected, LiteralSlice actual); // Asserts the given literal are (bitwise) equal to given expected values. template - static void ExpectR0Equal(NativeT expected, const Literal& actual); + static void ExpectR0Equal(NativeT expected, LiteralSlice actual); template static void ExpectR1Equal(tensorflow::gtl::ArraySlice expected, - const Literal& actual); + LiteralSlice actual); template static void ExpectR2Equal( std::initializer_list> expected, - const Literal& actual); + LiteralSlice actual); template static void ExpectR3Equal( std::initializer_list< std::initializer_list>> expected, - const Literal& actual); + LiteralSlice actual); // Asserts the given literal are (bitwise) equal to given array. template static void ExpectR2EqualArray2D(const Array2D& expected, - const Literal& actual); + LiteralSlice actual); template static void ExpectR3EqualArray3D(const Array3D& expected, - const Literal& actual); + LiteralSlice actual); template static void ExpectR4EqualArray4D(const Array4D& expected, - const Literal& actual); + LiteralSlice actual); // Asserts that the expected and actual literals are within the given error // bound for all elements. Also, asserts that the rank, dimensions sizes, and @@ -133,64 +133,61 @@ class LiteralTestUtil { // If detailed_message is true, then the error message in the assertion result // will contain a more detailed breakdown of mismatches. static ::testing::AssertionResult Near( - const Literal& expected, const Literal& actual, const ErrorSpec& error, + LiteralSlice expected, LiteralSlice actual, const ErrorSpec& error, bool detailed_message = false) TF_MUST_USE_RESULT; // Expects expected and actual to be Near with the given error. - static void ExpectNear(const Literal& expected, const Literal& actual, + static void ExpectNear(LiteralSlice expected, LiteralSlice actual, const ErrorSpec& error, const string& message = ""); // Asserts the given literal are within the given error bound of the given // expected values. Only supported for floating point values. template - static void ExpectR0Near(NativeT expected, const Literal& actual, + static void ExpectR0Near(NativeT expected, LiteralSlice actual, const ErrorSpec& error); template static void ExpectR1Near(tensorflow::gtl::ArraySlice expected, - const Literal& actual, const ErrorSpec& error); + LiteralSlice actual, const ErrorSpec& error); template static void ExpectR2Near( std::initializer_list> expected, - const Literal& actual, const ErrorSpec& error); + LiteralSlice actual, const ErrorSpec& error); template static void ExpectR3Near( std::initializer_list< std::initializer_list>> expected, - const Literal& actual, const ErrorSpec& error); + LiteralSlice actual, const ErrorSpec& error); template static void ExpectR4Near( std::initializer_list>>> expected, - const Literal& actual, const ErrorSpec& error); + LiteralSlice actual, const ErrorSpec& error); // Asserts the given literal are within the given error bound to the given // array. Only supported for floating point values. template static void ExpectR2NearArray2D(const Array2D& expected, - const Literal& actual, - const ErrorSpec& error); + LiteralSlice actual, const ErrorSpec& error); template static void ExpectR3NearArray3D(const Array3D& expected, - const Literal& actual, - const ErrorSpec& error); + LiteralSlice actual, const ErrorSpec& error); template static void ExpectR4NearArray4D(const Array4D& expected, - const Literal& actual, - const ErrorSpec& error); + LiteralSlice actual, const ErrorSpec& error); // If the error spec is given, returns whether the expected and the actual are // within the error bound; otherwise, returns whether they are equal. Tuples // will be compared recursively. static ::testing::AssertionResult NearOrEqual( - const Literal& expected, const Literal& actual, + LiteralSlice expected, LiteralSlice actual, const tensorflow::gtl::optional& error) TF_MUST_USE_RESULT; // If the error spec is given, expects the expected and the actual to be near; // otherwise, expects them to be equal. Tuples will be compared recursively. static void ExpectNearOrEqual( - const Literal& expected, const Literal& actual, + LiteralSlice expected, LiteralSlice actual, const tensorflow::gtl::optional& error); // Returns a multi-dimensional index as a string. For example: '{7, 8}' will @@ -205,8 +202,7 @@ class LiteralTestUtil { // layout order. static std::unique_ptr Reshape( tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, - const Literal& literal); + tensorflow::gtl::ArraySlice minor_to_major, LiteralSlice literal); // Creates a literal with the supplied shape, and uses the provided value // generator to populate the literal's values. @@ -244,20 +240,20 @@ class LiteralTestUtil { template /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, - const Literal& actual) { + LiteralSlice actual) { ExpectEqual(*Literal::CreateR0(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR1Equal( - tensorflow::gtl::ArraySlice expected, const Literal& actual) { + tensorflow::gtl::ArraySlice expected, LiteralSlice actual) { ExpectEqual(*Literal::CreateR1(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR2Equal( std::initializer_list> expected, - const Literal& actual) { + LiteralSlice actual) { ExpectEqual(*Literal::CreateR2(expected), actual); } @@ -265,38 +261,38 @@ template /* static */ void LiteralTestUtil::ExpectR3Equal( std::initializer_list>> expected, - const Literal& actual) { + LiteralSlice actual) { ExpectEqual(*Literal::CreateR3(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR2EqualArray2D( - const Array2D& expected, const Literal& actual) { + const Array2D& expected, LiteralSlice actual) { ExpectEqual(*Literal::CreateR2FromArray2D(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( - const Array3D& expected, const Literal& actual) { + const Array3D& expected, LiteralSlice actual) { ExpectEqual(*Literal::CreateR3FromArray3D(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( - const Array4D& expected, const Literal& actual) { + const Array4D& expected, LiteralSlice actual) { ExpectEqual(*Literal::CreateR4FromArray4D(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, - const Literal& actual, + LiteralSlice actual, const ErrorSpec& error) { ExpectNear(*Literal::CreateR0(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR1Near( - tensorflow::gtl::ArraySlice expected, const Literal& actual, + tensorflow::gtl::ArraySlice expected, LiteralSlice actual, const ErrorSpec& error) { ExpectNear(*Literal::CreateR1(expected), actual, error); } @@ -304,7 +300,7 @@ template template /* static */ void LiteralTestUtil::ExpectR2Near( std::initializer_list> expected, - const Literal& actual, const ErrorSpec& error) { + LiteralSlice actual, const ErrorSpec& error) { ExpectNear(*Literal::CreateR2(expected), actual, error); } @@ -312,7 +308,7 @@ template /* static */ void LiteralTestUtil::ExpectR3Near( std::initializer_list>> expected, - const Literal& actual, const ErrorSpec& error) { + LiteralSlice actual, const ErrorSpec& error) { ExpectNear(*Literal::CreateR3(expected), actual, error); } @@ -321,27 +317,27 @@ template std::initializer_list>>> expected, - const Literal& actual, const ErrorSpec& error) { + LiteralSlice actual, const ErrorSpec& error) { ExpectNear(*Literal::CreateR4(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR2NearArray2D( - const Array2D& expected, const Literal& actual, + const Array2D& expected, LiteralSlice actual, const ErrorSpec& error) { ExpectNear(*Literal::CreateR2FromArray2D(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR3NearArray3D( - const Array3D& expected, const Literal& actual, + const Array3D& expected, LiteralSlice actual, const ErrorSpec& error) { ExpectNear(*Literal::CreateR3FromArray3D(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR4NearArray4D( - const Array4D& expected, const Literal& actual, + const Array4D& expected, LiteralSlice actual, const ErrorSpec& error) { ExpectNear(*Literal::CreateR4FromArray4D(expected), actual, error); } diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 44c6811..96858c0 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -210,12 +210,12 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR2Equal( {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralView::Create(*result_literal, {1})); + LiteralSlice(*result_literal, {1})); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {2})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {2})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { @@ -239,16 +239,16 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); LiteralTestUtil::ExpectR2Equal( {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralView::Create(*result_literal, {0, 0})); + LiteralSlice(*result_literal, {0, 0})); LiteralTestUtil::ExpectR2Equal( {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralView::Create(*result_literal, {0, 1})); + LiteralSlice(*result_literal, {0, 1})); LiteralTestUtil::ExpectR2Equal( {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralView::Create(*result_literal, {0, 2})); + LiteralSlice(*result_literal, {0, 2})); } XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { @@ -274,9 +274,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { @@ -321,9 +321,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( {{56.0f, 46.0f}, {36.0f, 26.0f}}, - LiteralView::Create(*result_literal, {0})); + LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {40.0f, 71.0f, 117.0f}, LiteralView::Create(*result_literal, {1})); + {40.0f, 71.0f, 117.0f}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { @@ -361,9 +361,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{-1.0, -2.0}, {-3.0, -4}}, LiteralView::Create(*result_literal, {0})); + {{-1.0, -2.0}, {-3.0, -4}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {264.0, 73.0, 133.0}, LiteralView::Create(*result_literal, {1})); + {264.0, 73.0, 133.0}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { @@ -391,16 +391,16 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { std::unique_ptr result_0_literal = ShapedBufferToLiteral(result_0); LiteralTestUtil::ExpectR2Equal( {{-1.0, -2.0}, {-3.0, -4.0}}, - LiteralView::Create(*result_0_literal, {0})); + LiteralSlice(*result_0_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{22.0, 6.0}, {8.0, 10}}, LiteralView::Create(*result_0_literal, {1})); + {{22.0, 6.0}, {8.0, 10}}, LiteralSlice(*result_0_literal, {1})); ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0}); std::unique_ptr result_1_literal = ShapedBufferToLiteral(result_1); LiteralTestUtil::ExpectR2Equal( - {{1.0, 2.0}, {3.0, 4.0}}, LiteralView::Create(*result_1_literal, {0})); + {{1.0, 2.0}, {3.0, 4.0}}, LiteralSlice(*result_1_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{44.0, 12.0}, {16.0, 20}}, LiteralView::Create(*result_1_literal, {1})); + {{44.0, 12.0}, {16.0, 20}}, LiteralSlice(*result_1_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { @@ -447,7 +447,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { for (int i = 0; i < kElementCount; ++i) { LiteralTestUtil::ExpectR1Near( - {2.0f * i, 0.0f}, LiteralView::Create(*result_literal, {i}), + {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_); } } @@ -502,7 +502,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { for (int i = 0; i < kFanout; ++i) { for (int j = 0; j < kFanout; ++j) { LiteralTestUtil::ExpectR0Near( - i + j + i * kFanout + j, LiteralView::Create(*result_literal, {i, j}), + i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}), error_spec_); } } @@ -548,7 +548,7 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { index.push_back(0); } LiteralTestUtil::ExpectR0Equal( - 165.0, LiteralView::Create(*result_literal, index)); + 165.0, LiteralSlice(*result_literal, index)); } XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { @@ -754,9 +754,9 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); std::unique_ptr tuple_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR1Equal( - {2.0f, 4.0f, 6.0f}, LiteralView::Create(*tuple_literal, {0})); + {2.0f, 4.0f, 6.0f}, LiteralSlice(*tuple_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {1.0f, 2.0f, 3.0f}, LiteralView::Create(*tuple_literal, {1})); + {1.0f, 2.0f, 3.0f}, LiteralSlice(*tuple_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {