return Status::OK();
}
-Status CopyLiteralToHostTensor(const xla::Literal& literal,
+Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal,
Tensor* host_tensor) {
TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) &&
xla::ShapeUtil::ElementsIn(literal.shape()) ==
return Status::OK();
}
-Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
- Tensor* host_tensor) {
+Status LiteralToHostTensor(const xla::LiteralSlice& literal,
+ DataType target_type, Tensor* host_tensor) {
TensorShape shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape));
*host_tensor = Tensor(target_type, shape);
// derivable from the type of <literal>, because multiple tensorflow types map
// to the same XLA type (e.g. INT32 and QINT32 both map to INT32 in
// XLA).
-Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
- Tensor* host_tensor);
+Status LiteralToHostTensor(const xla::LiteralSlice& literal,
+ DataType target_type, Tensor* host_tensor);
// Copies the contents of 'literal' to a previously allocated tensor
// 'host_tensor'. The tensor and the literal must have the same number of
// elements and the same type.
-Status CopyLiteralToHostTensor(const xla::Literal& literal,
+Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal,
Tensor* host_tensor);
} // namespace tensorflow
}
ComputationDataHandle ComputationBuilder::ConstantLiteral(
- const Literal& literal) {
+ const LiteralSlice& literal) {
OpRequest op_request;
ConstantRequest* request = op_request.mutable_constant_request();
*request->mutable_literal() = literal.ToProto();
// Enqueues a constant with the value of the given literal onto the
// computation.
- ComputationDataHandle ConstantLiteral(const Literal& literal);
+ ComputationDataHandle ConstantLiteral(const LiteralSlice& literal);
// Enqueues a constant onto the computation. Methods are templated on the
// native host type (NativeT) which corresponds to a specific XLA
return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions);
}
-XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) {
+XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = literal.shape();
// Enqueues a constant with the value of the given literal onto the
// computation.
- XlaOp ConstantLiteral(const Literal& literal);
+ XlaOp ConstantLiteral(const LiteralSlice& literal);
// Enqueues a constant onto the computation. Methods are templated on the
// native host type (NativeT) which corresponds to a specific XLA
} // namespace
+LiteralBase::~LiteralBase() {}
+
std::ostream& operator<<(std::ostream& out, const Literal& literal) {
out << literal.ToString();
return out;
Literal::Literal(const Shape& shape)
: Literal(shape, /*allocate_arrays=*/true) {}
-Literal::Literal(const Shape& shape, bool allocate_arrays)
- : shape_(shape), pieces_(shape), owns_buffers_(true) {
- CHECK(LayoutUtil::HasLayout(shape));
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
-
- piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
- const Shape& subshape = piece.subshape();
- if (ShapeUtil::IsArray(subshape)) {
- if (allocate_arrays) {
- if (LayoutUtil::IsSparseArray(subshape)) {
- // For sparse arrays, the buffer must be of the size of the maximum
- // number of sparse elements possible.
- const int64 max_sparse_elements =
- LayoutUtil::MaxSparseElements(subshape.layout());
- piece.set_buffer(
- new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType(
- subshape.element_type())]);
- piece.set_sparse_indices(new SparseIndexArray(
- max_sparse_elements, ShapeUtil::Rank(subshape)));
- } else {
- piece.set_buffer(new char[piece.size_bytes()]);
- }
+void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
+ if (ShapeUtil::IsTuple(shape)) {
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ const Shape& subshape = shape.tuple_shapes(i);
+
+ auto child_piece = Piece();
+ child_piece.set_subshape(&subshape);
+
+ SetPiece(subshape, &child_piece, allocate_arrays);
+
+ piece->emplace_back(std::move(child_piece));
+ }
+ } else {
+ CHECK(ShapeUtil::IsArray(shape));
+ if (allocate_arrays) {
+ if (LayoutUtil::IsSparseArray(shape)) {
+ // For sparse arrays, the buffer must be of the size of the maximum
+ // number of sparse elements possible.
+ const int64 max_sparse_elements =
+ LayoutUtil::MaxSparseElements(shape.layout());
+ piece->set_buffer(
+ new char[max_sparse_elements *
+ ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]);
+ piece->set_sparse_indices(
+ new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape)));
} else {
- piece.set_buffer(nullptr);
+ piece->set_buffer(new char[piece->size_bytes()]);
}
}
}
}
-Literal::~Literal() { DeallocateBuffers(); }
+Literal::Literal(const Shape& shape, bool allocate_arrays)
+ : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ CHECK(LayoutUtil::HasLayout(*shape_));
+ root_piece_ = new Piece();
+ root_piece_->set_subshape(shape_.get());
+ CHECK(&root_piece_->subshape() == shape_.get());
-void Literal::DeallocateBuffers() {
- if (owns_buffers_) {
- for (auto& pair : pieces_) {
- Piece& piece = pair.second;
- if (piece.buffer() != nullptr) {
- delete[] piece.buffer();
- delete piece.sparse_indices();
- }
- }
- }
+ SetPiece(*shape_, root_piece_, allocate_arrays);
}
-Literal::Literal(Literal&& other) {
- shape_ = std::move(other.shape_);
- pieces_ = std::move(other.pieces_);
- // We need to iterate through the pieces to set the subshape pointer
- // properly. It must refer to subshapes within shape_.
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
- piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
+Literal::~Literal() {
+ if (root_piece_ != nullptr) {
+ DeallocateBuffers();
+ delete root_piece_;
}
- owns_buffers_ = other.owns_buffers_;
+}
- other.shape_ = ShapeUtil::MakeNil();
- other.pieces_ = ShapeTree<Piece>(other.shape_);
- other.piece({}).set_subshape(&other.shape_);
+void Literal::DeallocateBuffers() {
+ root_piece_->ForEachMutableSubpiece(
+ [&](const ShapeIndex& index, Piece* piece) {
+ if (piece->buffer() != nullptr) {
+ delete[] piece->buffer();
+ delete piece->sparse_indices();
+ }
+ });
}
+Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); }
+
Literal& Literal::operator=(Literal&& other) {
- DeallocateBuffers();
- shape_ = std::move(other.shape_);
- pieces_ = std::move(other.pieces_);
- // We need to iterate through the pieces to set the subshape pointer
- // properly. It must refer to subshapes within shape_.
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
- piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
- }
- owns_buffers_ = other.owns_buffers_;
-
- other.shape_ = ShapeUtil::MakeNil();
- other.pieces_ = ShapeTree<Piece>(other.shape_);
- other.piece({}).set_subshape(&other.shape_);
+ CHECK(&other.root_piece_->subshape() == other.shape_.get());
+
+ using std::swap;
+ swap(shape_, other.shape_);
+ swap(root_piece_, other.root_piece_);
+ CHECK(&root_piece_->subshape() == shape_.get());
+
return *this;
}
-std::unique_ptr<Literal> Literal::CreateFromShape(const Shape& shape) {
+std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
auto literal = MakeUnique<Literal>(shape);
- for (auto& pair : literal->pieces_) {
- Piece& piece = pair.second;
- if (ShapeUtil::IsArray(piece.subshape())) {
- memset(piece.untyped_data(), 0, piece.size_bytes());
- }
- }
+ literal->root_piece_->ForEachMutableSubpiece(
+ [&](const ShapeIndex& index, Piece* piece) {
+ if (ShapeUtil::IsArray(piece->subshape())) {
+ memset(piece->untyped_data(), 0, piece->size_bytes());
+ }
+ });
return literal;
}
-const SparseIndexArray* Literal::sparse_indices(
+const SparseIndexArray* LiteralBase::sparse_indices(
const ShapeIndex& shape_index) const {
return piece(shape_index).sparse_indices();
}
template <typename NativeT>
Status Literal::CopySliceFromInternal(
- const Literal& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
+ const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size) {
TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size());
if (ShapeUtil::Rank(src_literal.shape()) == 0 ||
ShapeUtil::Rank(shape()) == 0) {
- // If any of the two shapes are scalars, we can just call the StridedCopy()
- // directly, and we know we will be copying only one value.
+ // If any of the two shapes are scalars, we can just call the
+ // StridedCopy() directly, and we know we will be copying only one value.
TF_RET_CHECK(copy_size.empty());
StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
src_literal.data<NativeT>(),
return Status::OK();
}
-Status Literal::CopyElementFrom(const Literal& src_literal,
+Status Literal::CopyElementFrom(const LiteralSlice& src_literal,
tensorflow::gtl::ArraySlice<int64> src_index,
tensorflow::gtl::ArraySlice<int64> dest_index) {
DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
/*allocate_arrays=*/false));
Literal& element = elements.back();
- for (auto& pair : element.pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& dest_piece = pair.second;
- ShapeIndex src_index = {i};
- for (int64 j : index) {
- src_index.push_back(j);
- }
- Piece& src_piece = piece(src_index);
-
- // Move the respective buffer and sparse indices over to the element
- // Literal.
- dest_piece.set_buffer(src_piece.buffer());
- src_piece.set_buffer(nullptr);
- dest_piece.set_sparse_indices(src_piece.sparse_indices());
- src_piece.set_sparse_indices(nullptr);
- }
+ element.root_piece_->ForEachMutableSubpiece(
+ [&](const ShapeIndex& index, Piece* dest_piece) {
+ ShapeIndex src_index = {i};
+ for (int64 j : index) {
+ src_index.push_back(j);
+ }
+ Piece& src_piece = piece(src_index);
+
+ // Move the respective buffer and sparse indices over to the element
+ // Literal.
+ dest_piece->set_buffer(src_piece.buffer());
+ src_piece.set_buffer(nullptr);
+ dest_piece->set_sparse_indices(src_piece.sparse_indices());
+ src_piece.set_sparse_indices(nullptr);
+ });
}
// Set this literal to be nil-shaped.
*this = Literal();
}
namespace {
-
-// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
-// the array slices are indicated by dest_shape and src_shape respectively.
+// Copies the elements in 'src' to 'dest'. The shape and layout of the data
+// in the array slices are indicated by dest_shape and src_shape
+// respectively.
template <typename NativeT>
void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
tensorflow::gtl::ArraySlice<NativeT> src,
} // namespace
-Status Literal::Piece::CopyFrom(const Literal::Piece& src) {
+Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
if (ShapeUtil::Equal(subshape(), src.subshape())) {
// If the layouts are equal it's faster just to memcpy.
memcpy(buffer(), src.buffer(), src.size_bytes());
#undef COPY_ELEMENTS
default:
return Unimplemented(
- "Copying a Literal object with element type %s is not implemented.",
+ "Copying a Literal object with element type %s is not "
+ "implemented.",
PrimitiveType_Name(subshape().element_type()).c_str());
}
}
return Status::OK();
}
-Status Literal::CopyFrom(const Literal& src_literal,
+Status Literal::CopyFrom(const LiteralSlice& src_literal,
const ShapeIndex& dest_shape_index,
const ShapeIndex& src_shape_index) {
const Shape& dest_subshape =
ShapeUtil::HumanString(src_subshape).c_str());
}
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
- if (!ShapeUtil::IsArray(piece.subshape())) {
- continue;
- }
-
- // Determine if this index is in the part of this literal that we want to
- // copy over from src_literal.
- bool in_subtree_to_copy = true;
- for (int i = 0; i < dest_shape_index.size(); ++i) {
- if (index[i] != dest_shape_index[i]) {
- in_subtree_to_copy = false;
- break;
- }
- }
- if (!in_subtree_to_copy) {
- continue;
- }
-
- // Construct the index of the corresponding piece in the source literal.
- ShapeIndex src_piece_index = src_shape_index;
- for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
- src_piece_index.push_back(index[i]);
- }
+ return root_piece_->ForEachMutableSubpieceWithStatus(
+ [&](const ShapeIndex& index, Piece* piece) {
+ if (!ShapeUtil::IsArray(piece->subshape())) {
+ return Status::OK();
+ }
- TF_RETURN_IF_ERROR(piece.CopyFrom(src_literal.piece(src_piece_index)));
- }
- return Status::OK();
-}
+ // Determine if this index is in the part of this literal that we want
+ // to copy over from src_literal.
+ bool in_subtree_to_copy = true;
+ for (int i = 0; i < dest_shape_index.size(); ++i) {
+ if (index[i] != dest_shape_index[i]) {
+ in_subtree_to_copy = false;
+ break;
+ }
+ }
+ if (!in_subtree_to_copy) {
+ return Status::OK();
+ }
+ // Construct the index of the corresponding piece in the source literal.
+ ShapeIndex src_piece_index = src_shape_index;
+ for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
+ src_piece_index.push_back(index[i]);
+ }
+ TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index)));
+ return Status::OK();
+ });
+} // namespace xla
Status Literal::MoveFrom(Literal&& src_literal,
const ShapeIndex& dest_shape_index) {
ShapeUtil::HumanString(src_literal.shape()).c_str());
}
- if (!(owns_buffers_ && src_literal.owns_buffers_)) {
- return InvalidArgument(
- "Source and destination literals must both own their buffers (ie, not "
- "be views)");
- }
+ src_literal.root_piece_->ForEachSubpiece(
+ [&](const ShapeIndex& src_index, const Piece& src_piece) {
+ if (!ShapeUtil::IsArray(src_piece.subshape())) {
+ return;
+ }
- for (auto& pair : src_literal.pieces_) {
- const ShapeIndex& src_index = pair.first;
- Piece& src_piece = pair.second;
- if (!ShapeUtil::IsArray(src_piece.subshape())) {
- continue;
- }
+ ShapeIndex dest_index = dest_shape_index;
+ for (int64 i : src_index) {
+ dest_index.push_back(i);
+ }
+ Piece& dest_piece = piece(dest_index);
+ delete[] dest_piece.buffer();
+ dest_piece.set_buffer(src_piece.buffer());
+ delete dest_piece.sparse_indices();
+ dest_piece.set_sparse_indices(src_piece.sparse_indices());
+ });
- ShapeIndex dest_index = dest_shape_index;
- for (int64 i : src_index) {
- dest_index.push_back(i);
- }
- Piece& dest_piece = piece(dest_index);
- delete[] dest_piece.buffer();
- dest_piece.set_buffer(src_piece.buffer());
- delete dest_piece.sparse_indices();
- dest_piece.set_sparse_indices(src_piece.sparse_indices());
- }
+ src_literal.shape_ = MakeUnique<Shape>(ShapeUtil::MakeNil());
+ delete src_literal.root_piece_;
+ src_literal.root_piece_ = new LiteralBase::Piece();
+ src_literal.root_piece_->set_subshape(src_literal.shape_.get());
- src_literal.shape_ = ShapeUtil::MakeNil();
- src_literal.pieces_ = ShapeTree<Piece>(src_literal.shape_);
- src_literal.piece({}).set_subshape(&src_literal.shape_);
return Status::OK();
}
-Status Literal::CopySliceFrom(const Literal& src_literal,
+Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
tensorflow::gtl::ArraySlice<int64> src_base,
tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size) {
return CreateR2FromArray2D(*value);
}
-std::unique_ptr<Literal> Literal::Relayout(
+std::unique_ptr<Literal> LiteralBase::Relayout(
const Layout& new_layout, const ShapeIndex& shape_index) const {
// Create new shape with 'new_layout' set at the given shape index.
Shape new_shape = shape();
return result;
}
-std::unique_ptr<Literal> Literal::Relayout(
+std::unique_ptr<Literal> LiteralBase::Relayout(
const Shape& shape_with_layout) const {
CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
<< "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
return result;
}
-StatusOr<std::unique_ptr<Literal>> Literal::Reshape(
+StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
tensorflow::gtl::ArraySlice<int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Reshape does not support tuples.");
}
// Because the layout is monotonic, we can simply reuse the same sequence of
// values without changing their order.
- output->shape_ = ShapeUtil::MakeShape(shape().element_type(), dimensions);
+ *output->mutable_shape_do_not_use() =
+ ShapeUtil::MakeShape(shape().element_type(), dimensions);
int64 elements_before = ShapeUtil::ElementsIn(shape());
int64 elements_after = ShapeUtil::ElementsIn(output->shape());
return std::move(output);
}
-std::unique_ptr<Literal> Literal::Transpose(
+std::unique_ptr<Literal> LiteralBase::Transpose(
tensorflow::gtl::ArraySlice<int64> permutation) const {
CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
// representation intact.
// For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
// The shape with affine layout resulting from that operation will be
- // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
- // most minor.
+ // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized),
+ // the most minor.
//
// Essentially, given MinMaj(Di) the position of the Di dimension within the
// minor to major vector, and given T(Di) the index that the original Di
std::unique_ptr<Literal> new_literal = CreateFromShape(permuted_shape);
DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()),
ShapeUtil::ByteSizeOf(shape()));
- std::memcpy(new_literal->root_piece().buffer(), root_piece().buffer(),
- root_piece().size_bytes());
+ std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
return new_literal;
}
-std::unique_ptr<Literal> Literal::Slice(
+std::unique_ptr<Literal> LiteralBase::Slice(
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices) const {
CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
}
}
-Literal Literal::Clone() const {
+Literal LiteralBase::Clone() const {
Literal result(shape());
TF_CHECK_OK(result.CopyFrom(*this));
return result;
}
-std::unique_ptr<Literal> Literal::CloneToUnique() const {
+std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
auto result = MakeUnique<Literal>(shape());
TF_CHECK_OK(result->CopyFrom(*this));
return result;
}
-string Literal::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index) const {
+string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index) const {
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
CHECK(LayoutUtil::IsDenseArray(subshape));
switch (subshape.element_type()) {
}
}
-string Literal::GetSparseElementAsString(int64 sparse_element_number,
- const ShapeIndex& shape_index) const {
+string LiteralBase::GetSparseElementAsString(
+ int64 sparse_element_number, const ShapeIndex& shape_index) const {
const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
CHECK(LayoutUtil::IsSparseArray(subshape));
switch (subshape.element_type()) {
}
}
-StatusOr<int64> Literal::GetIntegralAsS64(
+StatusOr<int64> LiteralBase::GetIntegralAsS64(
tensorflow::gtl::ArraySlice<int64> multi_index) const {
CHECK(LayoutUtil::IsDenseArray(shape()));
switch (shape().element_type()) {
return Status::OK();
}
-tensorflow::gtl::ArraySlice<int64> Literal::GetSparseIndex(
+tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex(
int64 sparse_element_number, const ShapeIndex& shape_index) const {
const Piece& p = piece(shape_index);
CHECK_GE(sparse_element_number, 0);
piece(shape_index).SortSparseElements();
}
-Literal Literal::GetFirstScalarLiteral() const {
- CHECK(ShapeUtil::IsArray(shape_));
- CHECK_GT(ShapeUtil::ElementsIn(shape_), 0);
- switch (shape_.element_type()) {
+Literal LiteralBase::GetFirstScalarLiteral() const {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_GT(ShapeUtil::ElementsIn(shape()), 0);
+ switch (shape().element_type()) {
case PRED:
return std::move(*Literal::CreateR0<bool>(GetFirstElement<bool>()));
// 8 bit types.
case U64:
return std::move(*Literal::CreateR0<uint64>(GetFirstElement<uint64>()));
default:
- LOG(FATAL) << "Unhandled primitive type " << shape_.element_type();
+ LOG(FATAL) << "Unhandled primitive type " << shape().element_type();
}
}
-void Literal::Piece::SortSparseElements() {
+void LiteralBase::Piece::SortSparseElements() {
switch (subshape().element_type()) {
case PRED:
SortSparseElementsInternal<bool>();
}
template <typename NativeT>
-void Literal::Piece::SortSparseElementsInternal() {
+void LiteralBase::Piece::SortSparseElementsInternal() {
CHECK(LayoutUtil::IsSparseArray(subshape()));
int64 num_elements = sparse_indices()->index_count();
auto values = data<NativeT>();
}
namespace {
-
-void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index,
+void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
bool print_layout, std::vector<string>* pieces) {
const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
+ CHECK(LayoutUtil::HasLayout(literal.shape()));
+ CHECK(LayoutUtil::HasLayout(subshape));
auto shape_to_string = [print_layout](const Shape& shape) {
if (print_layout) {
} // namespace
-int64 Literal::sparse_element_count() const {
+int64 LiteralBase::sparse_element_count() const {
CHECK(LayoutUtil::IsSparseArray(shape()));
return sparse_indices()->index_count();
}
-string Literal::ToString(bool print_layout) const {
+string LiteralBase::ToString(bool print_layout) const {
std::vector<string> pieces;
+ CHECK(LayoutUtil::HasLayout(this->shape()));
ToStringHelper(*this, {}, print_layout, &pieces);
return tensorflow::str_util::Join(pieces, "");
}
/* static */ std::unique_ptr<Literal> Literal::MakeTuple(
tensorflow::gtl::ArraySlice<const Literal*> elements) {
std::vector<Shape> element_shapes;
- for (const Literal* element : elements) {
+ for (const auto* element : elements) {
element_shapes.push_back(element->shape());
}
auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
return literal;
}
+/* static */ std::unique_ptr<Literal> Literal::MakeTupleFromSlices(
+ tensorflow::gtl::ArraySlice<LiteralSlice> elements) {
+ std::vector<Shape> element_shapes;
+ for (const auto& element : elements) {
+ element_shapes.push_back(element.shape());
+ }
+ auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ for (int i = 0; i < elements.size(); ++i) {
+ TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i}));
+ }
+ return literal;
+}
+
/* static */ std::unique_ptr<Literal> Literal::MakeTupleOwned(
std::vector<std::unique_ptr<Literal>> elements) {
std::vector<Shape> element_shapes;
return literal;
}
-void Literal::EachCellAsString(
+void LiteralBase::EachCellAsString(
const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
const string& value)>& per_cell) const {
if (ShapeUtil::HasZeroElements(shape())) {
namespace {
template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
- const Literal& src_literal, const ConverterType& converter) {
+ const LiteralBase& src_literal, const ConverterType& converter) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType(
src_literal.shape(),
}
template <typename NativeSrcT, typename NativeDestT>
-std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
+std::unique_ptr<Literal> ConvertBetweenNativeTypes(
+ const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
src_literal, converter);
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
std::unique_ptr<Literal>>::type
-BitcastBetweenNativeTypes(const Literal& src_literal) {
+BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) {
return tensorflow::bit_cast<NativeDestT>(src);
};
src_literal, converter);
}
-// This template specialization is here to make the compiler happy. bit_cast has
-// a static check that the types are the same size. This specialization should
-// never be used because the source and destination types are checked for
-// identical sizes higher up.
+// This template specialization is here to make the compiler happy. bit_cast
+// has a static check that the types are the same size. This specialization
+// should never be used because the source and destination types are checked
+// for identical sizes higher up.
template <typename NativeSrcT, typename NativeDestT>
typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
std::unique_ptr<Literal>>::type
-BitcastBetweenNativeTypes(const Literal& src_literal) {
+BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
LOG(FATAL) << "Invalid bitcast between types of different sizes.";
}
template <PrimitiveType primitive_src_type>
-std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) {
+std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
auto result_literal = MakeUnique<Literal>(
ShapeUtil::ChangeElementType(src_literal.shape(), C64));
}
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
-std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal,
+std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
bool bitcast) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
if (bitcast) {
template <PrimitiveType primitive_src_type>
StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
- const Literal& src_literal, PrimitiveType primitive_dest_type,
+ const LiteralBase& src_literal, PrimitiveType primitive_dest_type,
bool bitcast) {
switch (primitive_dest_type) {
#define CONVERT_IF_TYPES_MATCH(type) \
}
StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
- const Literal& literal, PrimitiveType primitive_dest_type, bool bitcast) {
+ const LiteralBase& literal, PrimitiveType primitive_dest_type,
+ bool bitcast) {
TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
if (literal.shape().element_type() == primitive_dest_type) {
return literal.CloneToUnique();
} // namespace
-StatusOr<std::unique_ptr<Literal>> Literal::Convert(
+StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert(
PrimitiveType primitive_dest_type) const {
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
}
-StatusOr<std::unique_ptr<Literal>> Literal::BitcastConvert(
+StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
PrimitiveType primitive_dest_type) const {
if (primitive_util::BitWidth(shape().element_type()) !=
primitive_util::BitWidth(primitive_dest_type)) {
return InvalidArgument(
- "Cannot bitcast convert from %s to %s, bit widths are different: %d != "
+ "Cannot bitcast convert from %s to %s, bit widths are different: %d "
+ "!= "
"%d",
PrimitiveType_Name(shape().element_type()).c_str(),
PrimitiveType_Name(primitive_dest_type).c_str(),
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
}
-StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape(
+StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
const Shape& dest_shape, bool round_f32_to_bf16) const {
if (!ShapeUtil::IsTuple(dest_shape)) {
if (round_f32_to_bf16 && shape().element_type() == F32 &&
}
std::vector<Literal> elements;
for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
- auto element = LiteralView::Create(*this, {i});
+ auto element = LiteralSlice(*this, {i});
TF_ASSIGN_OR_RETURN(
auto new_element,
element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
}
template <typename NativeT>
-bool Literal::Piece::EqualElementsInternal(
- const Literal::Piece& other, std::vector<int64>* multi_index) const {
+bool LiteralBase::Piece::EqualElementsInternal(
+ const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
if (multi_index->size() == ShapeUtil::Rank(subshape())) {
return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
}
return true;
}
-bool Literal::Piece::EqualElements(const Literal::Piece& other) const {
+bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
std::vector<int64> multi_index;
case C64:
return EqualElementsInternal<complex64>(other, &multi_index);
default:
- LOG(FATAL) << "Unimplemented: Literal::Piece::EqualElements for type "
+ LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
<< PrimitiveType_Name(subshape().element_type());
}
}
-bool Literal::operator==(const Literal& other) const {
+bool LiteralBase::operator==(const LiteralBase& other) const {
if (!ShapeUtil::Compatible(shape(), other.shape())) {
return false;
}
- for (const auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- const Piece& piece = pair.second;
- if (!ShapeUtil::IsArray(piece.subshape())) {
- continue;
- }
- const Piece& other_piece = other.piece(index);
- if (!piece.EqualElements(other_piece)) {
- return false;
- }
- }
- return true;
+ return root_piece().ForEachSubpieceWithBool(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ return true;
+ }
+
+ const Piece& other_piece = other.piece(index);
+ if (!piece.EqualElements(other_piece)) {
+ return false;
+ }
+ return true;
+ });
}
namespace {
-
template <typename NativeT>
static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
NativeT value) {
} // namespace
-bool Literal::IsAll(int8 value) const {
- for (const auto& pair : pieces_) {
- const Piece& piece = pair.second;
+bool LiteralBase::IsAll(int8 value) const {
+ return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
+ const Piece& piece) {
if (!ShapeUtil::IsArray(piece.subshape())) {
- continue;
+ return true;
}
auto piece_is_all = [&]() {
if (!piece_is_all()) {
return false;
}
- }
- return true;
-}
+ return true;
+ });
+} // namespace xla
-bool Literal::IsAllFloat(float value) const {
- for (const auto& pair : pieces_) {
- const Piece& piece = pair.second;
- if (!ShapeUtil::IsArray(piece.subshape())) {
- continue;
- }
+bool LiteralBase::IsAllFloat(float value) const {
+ return root_piece().ForEachSubpieceWithBool(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ return true;
+ }
- auto piece_is_all = [&]() {
- switch (shape().element_type()) {
- case F32:
- return AllElementsEqualValue<float>(piece.data<float>(), value);
- case F64:
- return AllElementsEqualValue<double>(piece.data<double>(), value);
- case F16:
- return AllElementsEqualValue<half>(piece.data<half>(),
- static_cast<half>(value));
- case BF16:
- return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
- static_cast<bfloat16>(value));
- default:
+ auto piece_is_all = [&]() {
+ switch (shape().element_type()) {
+ case F32:
+ return AllElementsEqualValue<float>(piece.data<float>(), value);
+ case F64:
+ return AllElementsEqualValue<double>(piece.data<double>(), value);
+ case F16:
+ return AllElementsEqualValue<half>(piece.data<half>(),
+ static_cast<half>(value));
+ case BF16:
+ return AllElementsEqualValue<bfloat16>(
+ piece.data<bfloat16>(), static_cast<bfloat16>(value));
+ default:
+ return false;
+ }
+ };
+ if (!piece_is_all()) {
return false;
- }
- };
- if (!piece_is_all()) {
- return false;
- }
- }
- return true;
+ }
+ return true;
+ });
}
-bool Literal::IsAllComplex(complex64 value) const {
+bool LiteralBase::IsAllComplex(complex64 value) const {
switch (shape().element_type()) {
case C64:
return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
}
}
-bool Literal::IsAllFirst() const {
- for (const auto& pair : pieces_) {
- const Piece& piece = pair.second;
- if (!ShapeUtil::IsArray(piece.subshape())) {
- continue;
- }
-
- // Empty shapes are not all the first element since there is no first
- // element.
- if (ShapeUtil::HasZeroElements(piece.subshape())) {
- return false;
- }
- auto piece_is_all = [&]() {
- switch (piece.subshape().element_type()) {
- case PRED: {
- auto data = piece.data<bool>();
- return AllElementsEqualValue<bool>(data, data[0]);
+bool LiteralBase::IsAllFirst() const {
+ return root_piece().ForEachSubpieceWithBool(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ return true;
}
- // 8 bit types
- case S8: {
- auto data = piece.data<int8>();
- return AllElementsEqualValue<int8>(data, data[0]);
- }
- case U8: {
- auto data = piece.data<uint8>();
- return AllElementsEqualValue<uint8>(data, data[0]);
- }
- // 16 bit types
- case BF16: {
- auto data = piece.data<bfloat16>();
- return AllElementsEqualValue<bfloat16>(data, data[0]);
- }
- case F16: {
- auto data = piece.data<half>();
- return AllElementsEqualValue<half>(data, data[0]);
- }
- case S16: {
- auto data = piece.data<int16>();
- return AllElementsEqualValue<int16>(data, data[0]);
- }
- case U16: {
- auto data = piece.data<uint16>();
- return AllElementsEqualValue<uint16>(data, data[0]);
- }
- // 32 bit types
- case F32: {
- auto data = piece.data<float>();
- return AllElementsEqualValue<float>(data, data[0]);
- }
- case U32: {
- auto data = piece.data<uint32>();
- return AllElementsEqualValue<uint32>(data, data[0]);
- }
- case S32: {
- auto data = piece.data<int32>();
- return AllElementsEqualValue<int32>(data, data[0]);
- }
- // 64 bit types
- case C64: {
- auto data = piece.data<complex64>();
- return AllElementsEqualValue<complex64>(data, data[0]);
- }
- case F64: {
- auto data = piece.data<double>();
- return AllElementsEqualValue<double>(data, data[0]);
- }
- case S64: {
- auto data = piece.data<int64>();
- return AllElementsEqualValue<int64>(data, data[0]);
- }
- case U64: {
- auto data = piece.data<uint64>();
- return AllElementsEqualValue<uint64>(data, data[0]);
- }
- default:
+
+ // Empty shapes are not all the first element since there is no first
+ // element.
+ if (ShapeUtil::HasZeroElements(piece.subshape())) {
return false;
- }
- };
+ }
+ auto piece_is_all = [&]() {
+ switch (piece.subshape().element_type()) {
+ case PRED: {
+ auto data = piece.data<bool>();
+ return AllElementsEqualValue<bool>(data, data[0]);
+ }
+ // 8 bit types
+ case S8: {
+ auto data = piece.data<int8>();
+ return AllElementsEqualValue<int8>(data, data[0]);
+ }
+ case U8: {
+ auto data = piece.data<uint8>();
+ return AllElementsEqualValue<uint8>(data, data[0]);
+ }
+ // 16 bit types
+ case BF16: {
+ auto data = piece.data<bfloat16>();
+ return AllElementsEqualValue<bfloat16>(data, data[0]);
+ }
+ case F16: {
+ auto data = piece.data<half>();
+ return AllElementsEqualValue<half>(data, data[0]);
+ }
+ case S16: {
+ auto data = piece.data<int16>();
+ return AllElementsEqualValue<int16>(data, data[0]);
+ }
+ case U16: {
+ auto data = piece.data<uint16>();
+ return AllElementsEqualValue<uint16>(data, data[0]);
+ }
+ // 32 bit types
+ case F32: {
+ auto data = piece.data<float>();
+ return AllElementsEqualValue<float>(data, data[0]);
+ }
+ case U32: {
+ auto data = piece.data<uint32>();
+ return AllElementsEqualValue<uint32>(data, data[0]);
+ }
+ case S32: {
+ auto data = piece.data<int32>();
+ return AllElementsEqualValue<int32>(data, data[0]);
+ }
+ // 64 bit types
+ case C64: {
+ auto data = piece.data<complex64>();
+ return AllElementsEqualValue<complex64>(data, data[0]);
+ }
+ case F64: {
+ auto data = piece.data<double>();
+ return AllElementsEqualValue<double>(data, data[0]);
+ }
+ case S64: {
+ auto data = piece.data<int64>();
+ return AllElementsEqualValue<int64>(data, data[0]);
+ }
+ case U64: {
+ auto data = piece.data<uint64>();
+ return AllElementsEqualValue<uint64>(data, data[0]);
+ }
+ default:
+ return false;
+ }
+ };
- if (!piece_is_all()) {
- return false;
- }
- }
- return true;
+ if (!piece_is_all()) {
+ return false;
+ }
+ return true;
+ });
}
-bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
+bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
CHECK(ShapeUtil::IsArray(shape()));
switch (shape().element_type()) {
case U8:
}
namespace {
-
template <typename RepeatedFieldT, typename NativeT>
void CopyToRepeatedField(RepeatedFieldT* dest,
const tensorflow::gtl::ArraySlice<NativeT> src) {
} // namespace
-void Literal::Piece::WriteToProto(LiteralProto* proto) const {
+void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
*proto->mutable_shape() = subshape();
switch (subshape().element_type()) {
case PRED:
}
}
-const void* Literal::Piece::untyped_data() const {
+const void* LiteralBase::Piece::untyped_data() const {
CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
return buffer();
}
-void* Literal::Piece::untyped_data() {
+void* LiteralBase::Piece::untyped_data() {
CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
return buffer();
}
namespace {
-
template <typename RepeatedFieldT, typename NativeT>
Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
const RepeatedFieldT& src) {
} // namespace
-Status Literal::Piece::CopyFromProto(const LiteralProto& proto) {
+Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
// These conditions should have been checked in Literal::CreateFromProto.
TF_RET_CHECK(proto.has_shape());
TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
return Status::OK();
}
-LiteralProto Literal::ToProto() const {
+LiteralProto LiteralBase::ToProto() const {
LiteralProto proto;
- for (const auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- const Piece& piece = pair.second;
-
- LiteralProto* proto_piece = &proto;
- for (int64 i : index) {
- while (proto_piece->tuple_literals_size() <= i) {
- proto_piece->add_tuple_literals();
- }
- proto_piece = proto_piece->mutable_tuple_literals(i);
- }
- piece.WriteToProto(proto_piece);
- }
+ root_piece().ForEachSubpiece(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ LiteralProto* proto_piece = &proto;
+ for (int64 i : index) {
+ while (proto_piece->tuple_literals_size() <= i) {
+ proto_piece->add_tuple_literals();
+ }
+ proto_piece = proto_piece->mutable_tuple_literals(i);
+ }
+ piece.WriteToProto(proto_piece);
+ });
if (LayoutUtil::IsSparseArray(shape())) {
CopyToRepeatedField(proto.mutable_sparse_indices(),
auto literal = MakeUnique<Literal>(proto.shape());
- for (auto& pair : literal->pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
- const LiteralProto* proto_element = &proto;
- for (int64 i : index) {
- TF_RET_CHECK(i < proto_element->tuple_literals_size());
- proto_element = &proto_element->tuple_literals(i);
- }
+ TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
+ [&](const ShapeIndex& index, Piece* piece) {
+ const LiteralProto* proto_element = &proto;
+ for (int64 i : index) {
+ CHECK(i < proto_element->tuple_literals_size());
+ proto_element = &proto_element->tuple_literals(i);
+ }
- if (ShapeUtil::IsTuple(piece.subshape())) {
- if (proto_element->tuple_literals_size() !=
- ShapeUtil::TupleElementCount(piece.subshape())) {
- return InvalidArgument(
- "Expected %lld tuple elements in LiteralProto, has %d",
- ShapeUtil::TupleElementCount(piece.subshape()),
- proto_element->tuple_literals_size());
- }
- continue;
- }
+ if (ShapeUtil::IsTuple(piece->subshape())) {
+ if (proto_element->tuple_literals_size() !=
+ ShapeUtil::TupleElementCount(piece->subshape())) {
+ return InvalidArgument(
+ "Expected %lld tuple elements in LiteralProto, has %d",
+ ShapeUtil::TupleElementCount(piece->subshape()),
+ proto_element->tuple_literals_size());
+ }
+ return Status::OK();
+ }
- TF_RET_CHECK(ShapeUtil::IsArray(piece.subshape()));
- TF_RETURN_IF_ERROR(piece.CopyFromProto(*proto_element));
- }
+ CHECK(ShapeUtil::IsArray(piece->subshape()));
+ TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
+
+ return Status::OK();
+ }));
return std::move(literal);
}
-const void* Literal::untyped_data(const ShapeIndex& shape_index) const {
+const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
return piece(shape_index).untyped_data();
}
return piece(shape_index).untyped_data();
}
-int64 Literal::size_bytes(const ShapeIndex& shape_index) const {
+int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
return piece(shape_index).size_bytes();
}
-string Literal::GetR1U8AsString() const {
+string LiteralBase::GetR1U8AsString() const {
CHECK(ShapeUtil::IsArray(shape()));
CHECK_EQ(ShapeUtil::Rank(shape()), 1);
CHECK_EQ(shape().element_type(), U8);
ShapeUtil::ElementsIn(shape()));
}
-/* static */ const LiteralView LiteralView::Create(
- const Literal& literal, const ShapeIndex& view_root) {
- return LiteralView(literal, view_root);
-}
+LiteralSlice::LiteralSlice(const LiteralBase& literal)
+ : LiteralBase(), root_piece_(&literal.root_piece()) {}
+
+LiteralSlice::LiteralSlice(const LiteralBase& literal,
+ const ShapeIndex& view_root)
+ : LiteralBase(), root_piece_(&literal.piece(view_root)) {}
-size_t Literal::Hash() const {
+size_t LiteralBase::Hash() const {
using tensorflow::Hash64;
using tensorflow::Hash64Combine;
return hash_value;
}
-LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) {
- shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root);
- pieces_ = ShapeTree<Piece>(shape_);
- owns_buffers_ = false;
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
-
- ShapeIndex src_index = view_root;
- for (int64 i : index) {
- src_index.push_back(i);
- }
- const Piece& src_piece = literal.piece(src_index);
- piece.set_buffer(src_piece.buffer());
- piece.set_sparse_indices(src_piece.sparse_indices());
- piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
- }
-}
-
-LiteralView::~LiteralView() {}
-
-LiteralView::LiteralView(const LiteralView& other) { CopyFrom(other); }
-
-LiteralView& LiteralView::operator=(const LiteralView& other) {
- CopyFrom(other);
- return *this;
-}
-
-void LiteralView::CopyFrom(const LiteralView& other) {
- // We can't use the default copy-constructor/copy-assignment because
- // Piece::subshape_ points to subshapes within the Shape of the owning
- // Literal/LiteralView.
- shape_ = other.shape();
- pieces_ = other.pieces_;
- for (auto& pair : pieces_) {
- const ShapeIndex& index = pair.first;
- Piece& piece = pair.second;
- piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
- }
- owns_buffers_ = false;
-}
-
} // namespace xla
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/sparse_index_array.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace xla {
+// Forward declare Literal and LiteralSlice class to be used by the creation
+// methods in the base class.
+class Literal;
+class LiteralSlice;
+
+// Abstract base class for literals.
+class LiteralBase {
+ public:
+ virtual ~LiteralBase() = 0;
+
+ // Literals are equal if they have compatible shapes and the same data
+ // values. Layout is not compared.
+ bool operator==(const LiteralBase& other) const;
+ bool operator!=(const LiteralBase& other) const { return !(*this == other); }
+
+ // Returns the shape of the literal.
+ const Shape& shape() const { return root_piece().subshape(); }
+
+ // Serialize to proto.
+ LiteralProto ToProto() const;
+
+ // Returns an ArraySlice of the array for this literal for the given NativeT
+ // (e.g., float). CHECKs if the subshape of the literal at the given
+ // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type
+ // to native type.
+ template <typename NativeT>
+ tensorflow::gtl::ArraySlice<NativeT> data(
+ const ShapeIndex& shape_index = {}) const;
+
+ // Returns a const pointer to the sparse index array. Returns nullptr if the
+ // literal is not a sparse array.
+ const SparseIndexArray* sparse_indices(
+ const ShapeIndex& shape_index = {}) const;
+
+ // Returns a const pointer to (or size of) the underlying buffer holding the
+ // array at the given shape index. CHECKs if the subshape of the literal at
+ // the given ShapeIndex is not array.
+ const void* untyped_data(const ShapeIndex& shape_index = {}) const;
+ int64 size_bytes(const ShapeIndex& shape_index = {}) const;
+
+ // Returns this literal's data as a string. This literal must be a rank-1 U8
+ // array.
+ string GetR1U8AsString() const;
+
+ // Returns a string representation of the literal value.
+ // Warning: this function can take minutes for multi-million element Literals.
+ string ToString(bool print_layout = false) const;
+
+ // Gets an element in the literal at the given index. The multi_index is
+ // CHECKed against the dimension sizes.
+ template <typename NativeT>
+ NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index) const;
+ // Overloads of Get for array literals. CHECKs if the literal is not
+ // array-shaped and dense.
+ template <typename NativeT>
+ NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
+
+ // Returns the element value at index (0, ..., 0), however many zeroes are
+ // required for that index.
+ template <typename NativeT>
+ NativeT GetFirstElement() const;
+
+ // As Get(), but determines the correct type and converts the value
+ // into text.
+ string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index = {}) const;
+ // As GetSparseElement(), but determines the correct type and converts the
+ // value into text.
+ string GetSparseElementAsString(int64 sparse_element_number,
+ const ShapeIndex& shape_index = {}) const;
+ // As Get(), but determines the correct type and converts the value into
+ // int64. This literal must be an array.
+ StatusOr<int64> GetIntegralAsS64(
+ tensorflow::gtl::ArraySlice<int64> multi_index) const;
+
+ // Returns the multi-index of the element in a sparse literal at the given
+ // sparse element number. The sparse element number is the position with in
+ // the sparse array's list of (index, value) pairs, and is checked against the
+ // total number of (index, value) pairs in the sparse array.
+ tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
+ int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
+
+ // Returns the value of the element in a sparse literal at the given sparse
+ // element number. The sparse element number is the position with in the
+ // sparse array's list of (index, value) pairs, and is checked against the
+ // total number of (index, value) pairs in the sparse array.
+ template <typename NativeT>
+ NativeT GetSparseElement(int64 sparse_element_number,
+ const ShapeIndex& shape_index = {}) const;
+
+ // Invokes the "per cell" callback for each element in the provided
+ // literal with the element's indices and a string representation of
+ // the element's value.
+ //
+ // This function is useful if you want a polymorphic representation
+ // of the tensor's elements (turning it to a string for something
+ // like representation in a protobuf).
+ //
+ // This literal must have a dense layout.
+ void EachCellAsString(
+ const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ const string& value)>& per_cell) const;
+ template <typename NativeT>
+ void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ NativeT value)>
+ per_cell) const;
+
+ // Returns whether every element in this literal is equal to value.
+ //
+ // value is an int8 because we expect this to be called with small
+ // compile-time constants (0, -1, etc.) and so that whatever value you pass
+ // can be represented exactly by floating-point types as small as 16 bits.
+ //
+ // If value doesn't fit in this literal's type, returns false. Values of 1/0
+ // are considered equal to true/false; other values are not considered equal
+ // to true. Also if this literal is not array-shaped false is returned.
+ bool IsAll(int8 value) const;
+
+ // Like IsAll(const Literal&, int8), except we check whether the literal is
+ // equal to a particular floating-point number.
+ //
+ // If the literal is not a floating-point value, this always returns false.
+ //
+ // This casts value to the type of literal, then compares using ==. The usual
+ // admonishments about floating-point equality checks apply. We expect you to
+ // use this to check for values that can be expressed precisely as a float,
+ // e.g. -0.5. Also if this literal is not array-shaped false is returned.
+ bool IsAllFloat(float value) const;
+
+ // Like IsAll(const Literal&, int8), except we check whether the literal is
+ // equal to a particular complex number.
+ //
+ // If the literal is not a complex value, this always returns false.
+ //
+ // This casts value to the type of literal, then compares using ==. The usual
+ // admonishments about floating-point equality checks apply. We expect you to
+ // use this to check for complex values that can be expressed precisely as
+ // float pairs e.g. (-0.5, 1.0).
+ //
+ // This literal must have a dense layout.
+ bool IsAllComplex(complex64 value) const;
+
+ // Literal consists entirely of the first element of the literal.
+ bool IsAllFirst() const;
+
+ // Returns whether this literal is zero at the specified index. This literal
+ // must be an array with a dense layout.
+ bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
+
+ // Returns the count of the elements in the array at the given shape index in
+ // this literal.
+ int64 element_count(const ShapeIndex& index = {}) const {
+ return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
+ }
+
+ // Return the count of the elements in the sparse array at the given shape
+ // index in this literal, which will be no larger than
+ // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()).
+ int64 sparse_element_count() const;
+
+ // Compute a hash for this literal. This literal must not be a sparse tensor
+ // or a tuple containing a sparse tensor.
+ size_t Hash() const;
+
+ // Converts this literal to the given shape. Returns an error is the
+ // conversion is not possible.
+ //
+ // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding
+ // instead of truncation; otherwise, truncation is used.
+ //
+ // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
+ // the default behavior.
+ StatusOr<std::unique_ptr<Literal>> ConvertToShape(
+ const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
+
+ // Converts this literal to another primitive type using a bitcast
+ // conversion. The to and from primitive types must have the same bit
+ // width. Returns an error if the conversion is not possible. This literal
+ // must be array-shaped.
+ StatusOr<std::unique_ptr<Literal>> BitcastConvert(
+ PrimitiveType primitive_dest_type) const;
+
+ // Converts this literal to another primitive type. Returns an error if the
+ // conversion is not possible. This literal must be array-shaped.
+ StatusOr<std::unique_ptr<Literal>> Convert(
+ PrimitiveType primitive_dest_type) const;
+
+ // Returns a literal scalar representing the first element.
+ Literal GetFirstScalarLiteral() const;
+
+ // Clones the underlying buffers into a new Literal, or new
+ // std::unique_ptr<Literal>.
+ Literal Clone() const;
+ std::unique_ptr<Literal> CloneToUnique() const;
+
+ // TODO(b/67651157): The methods below which perform computation on Literals
+ // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
+ // evaluator code which operates on Literals.
+ //
+ // Creates a new value that has the equivalent value as this
+ // literal, but conforms to new_layout; e.g. a literal matrix that was in {0,
+ // 1} minor-to-major dimension layout can be re-layed-out as {1, 0}
+ // minor-to-major dimension layout and the value in the cell at any given
+ // logical index (i0, i1) will be the same.
+ //
+ // For tuple shaped literals, shape_index should be used to select the inner
+ // array that the new layout applies to.
+ //
+ // Note: this is useful when the client wants to ensure that a value placed in
+ // the XLA allocation tracker has a particular layout; for efficiency
+ // purposes or avoiding unimplemented operation/layout combinations.
+ std::unique_ptr<Literal> Relayout(const Layout& new_layout,
+ const ShapeIndex& shape_index = {}) const;
+
+ // An overload of Relayout which changes the layout of the entire shape rather
+ // than being limited to a single array within the shape.
+ std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
+
+ // Creates a new literal by reshaping this literal to have the given
+ // dimensions. The total number of elements must not change; The
+ // implementation currently only supports monotonic dim0-major layouts.
+ // This literal must be an array.
+ StatusOr<std::unique_ptr<Literal>> Reshape(
+ tensorflow::gtl::ArraySlice<int64> dimensions) const;
+
+ // Creates a new literal by reordering the dimensions of this literal.
+ // The given `permutation` must be a permutation of the dimension numbers
+ // in the original literal, and it specifies the order of the new dimensions
+ // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
+ // For example, a transpose call on a literal of shape [3 x 8 x 4] and
+ // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
+ // This literal must be an array.
+ std::unique_ptr<Literal> Transpose(
+ tensorflow::gtl::ArraySlice<int64> permutation) const;
+
+ // Creates a sub-array from this literal by extracting the indices
+ // [start_index, limit_index) of each dimension. The result literal has the
+ // same rank and layout as for the given literal. The number of indices in
+ // start_indices and limit_indices must be the rank of the literal, and the
+ // indices follow the order of the dimensions.
+ // This literal must be an array.
+ std::unique_ptr<Literal> Slice(
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices) const;
+
+ // Creates a literal with a prepended dimension with bound "times"; e.g. a
+ // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
+ // literal replicated four times.
+ // This literal must be an array.
+ template <typename NativeT>
+ std::unique_ptr<Literal> Replicate(int64 times) const;
+
+ // Creates a new Literal object with the shape specified as parameter.
+ // The content of the literal values is the default value of the primitive
+ // type of literal itself (0 for numeric types, and false for predicates).
+ static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
+
+ protected:
+ // A data structure representing a subshape at a particular ShapeIndex within
+ // the literal. For array-shaped ShapeIndexes, this data structure holds the
+ // pointer to the memory allocated for the array data.
+ class Piece {
+ public:
+ // Returns the buffer holding the array data for this piece as an array
+ // slice. This piece must be array-shaped.
+ template <typename NativeT>
+ tensorflow::gtl::ArraySlice<NativeT> data() const;
+ template <typename NativeT>
+ tensorflow::gtl::MutableArraySlice<NativeT> data();
+
+ // Returns the buffer holding the array data for this piece as a void*. This
+ // piece must be array-shaped.
+ void* untyped_data();
+ const void* untyped_data() const;
+
+ // Gets or sets an element in the array at the given index. The multi_index
+ // is CHECKed against the dimension sizes of the array. This piece must be
+ // array-shaped.
+ template <typename NativeT>
+ NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
+ template <typename NativeT>
+ void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);
+
+ // Gets/sets the buffer holding the array data.
+ char* buffer() const { return buffer_; }
+ void set_buffer(char* buffer) { buffer_ = buffer; }
+
+ // The array of multi-indices that provide the locations of non-zero
+ // elements in a sparse array. Only used if
+ // LayoutUtil::IsSparseArray(shape()) is true.
+ SparseIndexArray* sparse_indices() const { return sparse_indices_; }
+ void set_sparse_indices(SparseIndexArray* sparse_indices) {
+ sparse_indices_ = sparse_indices;
+ }
+
+ // Gets or sets the subshape of this piece. This reference points to a
+ // subshape within the shape in the containing Literal (Literal::shape_).
+ const Shape& subshape() const { return *subshape_; }
+ void set_subshape(const Shape* subshape) { subshape_ = subshape; }
+
+ // Returns the size in bytes of the buffer holding the array data.
+ int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
+
+ // Returns the number of elements in this piece's array.
+ int64 element_count() const {
+ // If this is a sparse array, use the number of elements represented by
+ // the indices in the associated SparseIndexArray.
+ return LayoutUtil::IsSparseArray(subshape())
+ ? sparse_indices()->index_count()
+ : ShapeUtil::ElementsIn(subshape());
+ }
+
+ // Returns the child piece at 'index' of this piece.
+ Piece& child(int64 index) { return children_[index]; }
+
+ // Adds a child piece to this piece's children.
+ void emplace_back(Piece child_piece) {
+ children_.emplace_back(std::move(child_piece));
+ }
+
+ // Returns the size of children pieces of this piece.
+ int64 children_size() { return children_.size(); }
+
+ // Visitor functions that recursively traverses the piece and calls the
+ // given function at each child piece. The function has the type:
+ // void (const ShapeIndex& index, const Piece& piece)
+ template <typename Fn>
+ void ForEachSubpiece(const Fn& func) const {
+ ShapeIndex index;
+ return ForEachHelper(
+ [&func](const ShapeIndex& index, const Piece& piece) {
+ func(index, piece);
+ return Status::OK();
+ },
+ *this, &index)
+ .IgnoreError();
+ }
+ // Same as above, but the function has the type:
+ // Status (const ShapeIndex& index, const Piece& piece)
+ // The first non-OK return value is returned by the function.
+ template <typename Fn>
+ Status ForEachSubpieceWithStatus(const Fn& func) const {
+ ShapeIndex index;
+ return ForEachHelper(func, *this, &index);
+ }
+ // Same as above, but the function has the type:
+ // Bool (const ShapeIndex& index, const Piece& piece)
+ // The first non-true return value is returned by the function.
+ template <typename Fn>
+ bool ForEachSubpieceWithBool(const Fn& func) const {
+ ShapeIndex index;
+ return ForEachHelperBool(func, *this, &index);
+ }
+ // Same as above, but the function has the type:
+ // Void (const ShapeIndex& index, Piece& piece)
+ template <typename Fn>
+ void ForEachMutableSubpiece(const Fn& func) {
+ ShapeIndex index;
+ return ForEachMutableHelper(
+ [&func](const ShapeIndex& index, Piece* piece) {
+ func(index, piece);
+ return Status::OK();
+ },
+ const_cast<xla::LiteralBase::Piece*>(this), &index)
+ .IgnoreError();
+ }
+ // Same as above, but the function has the type:
+ // Status (const ShapeIndex& index, Piece& piece)
+ // The first non-OK return value is returned by the function.
+ template <typename Fn>
+ Status ForEachMutableSubpieceWithStatus(const Fn& func) {
+ ShapeIndex index;
+ return ForEachMutableHelper(
+ func, const_cast<xla::LiteralBase::Piece*>(this), &index);
+ }
+
+ // Returns true if this piece and 'other' contain the same data. This piece
+ // and 'other' must be array-shaped and compatible.
+ bool EqualElements(const Piece& other) const;
+
+ // Writes the shape and data (if array-shaped) into the given proto.
+ void WriteToProto(LiteralProto* proto) const;
+
+ // Copy the data from 'src' into this piece's buffer. Shapes of this piece
+ // and src must be compatible.
+ Status CopyFrom(const Piece& src);
+
+ // Copies the data from the given proto into this piece. The shape of this
+ // piece must be equal (not just compatible) to the shape of the proto.
+ Status CopyFromProto(const LiteralProto& proto);
+
+ // Sorts the elements in a sparse array.
+ void SortSparseElements();
+
+ private:
+ // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
+ // The first non-OK (or non-true) value is returned by the function.
+ // The callable 'func' has the same signature as described above in
+ // ForEachSubpiece*.
+ template <typename Fn>
+ Status ForEachHelper(const Fn& func, const Piece& piece,
+ ShapeIndex* index) const {
+ TF_RETURN_IF_ERROR(func(*index, piece));
+ for (int64 i = 0; i < piece.children_.size(); ++i) {
+ index->push_back(i);
+ TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index));
+ index->pop_back();
+ }
+ return Status::OK();
+ }
+ template <typename Fn>
+ bool ForEachHelperBool(const Fn& func, const Piece& piece,
+ ShapeIndex* index) const {
+ if (!func(*index, piece)) {
+ return false;
+ }
+ for (int64 i = 0; i < piece.children_.size(); ++i) {
+ index->push_back(i);
+ if (!ForEachHelperBool(func, piece.children_[i], index)) {
+ return false;
+ }
+ index->pop_back();
+ }
+ return true;
+ }
+ template <typename Fn>
+ Status ForEachMutableHelper(const Fn& func, Piece* piece,
+ ShapeIndex* index) {
+ TF_RETURN_IF_ERROR(func(*index, piece));
+ for (int64 i = 0; i < piece->children_.size(); ++i) {
+ index->push_back(i);
+ TF_RETURN_IF_ERROR(
+ ForEachMutableHelper(func, &piece->children_[i], index));
+ index->pop_back();
+ }
+ return Status::OK();
+ }
+
+ // Recursive helper for EqualElements.
+ template <typename NativeT>
+ bool EqualElementsInternal(const Piece& other,
+ std::vector<int64>* multi_index) const;
+
+ // Helper for SortSparseElements that has the element type as a template
+ // parameter.
+ template <typename NativeT>
+ void SortSparseElementsInternal();
+
+ // For array-shaped pieces, this is the buffer holding the literal data.
+ char* buffer_ = nullptr;
+
+ // For sparse arrays, this is the array of indices.
+ SparseIndexArray* sparse_indices_ = nullptr;
+
+ // The shape of piece. This points into the shape of the containing Literal
+ // (Literal::shape_).
+ const Shape* subshape_ = nullptr;
+
+ // Children pieces for tuple shaped pieces.
+ std::vector<Piece> children_ = {};
+ }; // class Piece
+
+ const Piece& piece(const ShapeIndex& shape_index) const {
+ Piece* piece = &const_cast<Piece&>(root_piece());
+ for (const auto i : shape_index) {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, piece->children_size());
+ piece = &piece->child(i);
+ }
+ return *piece;
+ }
+
+ // Returns the piece at the root of the shape.
+ virtual const Piece& root_piece() const = 0;
+
+ // LiteralSlice and Literal must access Pieces of other Literals.
+ friend class LiteralSlice;
+ friend class Literal;
+};
+
// Class representing literal values in XLA.
//
-// TODO(b/67651157): The methods in this class should be reduced to a minimal
-// set of methods which construct Literals and accessors methods. Other methods
-// which perform computation on Literals (Reshape, Slice, etc) should be moved
-// elsewhere, and perhaps combined with evaluator code which operates on
-// Literals.
-class Literal {
+// The underlying buffer and shape is always owned by this class.
+class Literal : public LiteralBase {
public:
Literal() : Literal(ShapeUtil::MakeNil()) {}
Literal(const Shape& shape, bool allocate_arrays);
Literal& operator=(Literal&& other);
- // Literals are equal if they have compatible shapes and the same data
- // values. Layout is not compared.
- bool operator==(const Literal& other) const;
- bool operator!=(const Literal& other) const { return !(*this == other); }
+ // TODO(b/67651157): Remove this accessor. Literal users should not be able to
+ // mutate the shape as this can produce malformed Literals.
+ Shape* mutable_shape_do_not_use() { return shape_.get(); }
- // Serialize to and from a proto.
- static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
- const LiteralProto& proto);
- LiteralProto ToProto() const;
+ // Returns a MutableArraySlice view of the array for this literal for the
+ // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
+ // given ShapeIndex is not array. See primitive_util.h for the mapping from
+ // XLA type to native type.
+ template <typename NativeT>
+ tensorflow::gtl::MutableArraySlice<NativeT> data(
+ const ShapeIndex& shape_index = {});
+ // Unhide const method from parent class.
+ using LiteralBase::data;
+
+ // Returns a pointer to the sparse index array. Returns nullptr if the literal
+ // is not a sparse array.
+ SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
+
+ // Returns a pointer to the underlying buffer holding the array at the given
+ // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
+ // is not array.
+ void* untyped_data(const ShapeIndex& shape_index = {});
+ // Unhide const method from parent class.
+ using LiteralBase::untyped_data;
+
+ // Populates a literal with a sparse layout with the given indices and values.
+ // Each index in the indices array is CHECKed against the dimensions in the
+ // literal's shape. If sort is true, then the indices and values will be
+ // sorted. If sort is false, then the indices and values are assumed to
+ // already be in sorted order. See CreateSparse for an example of how data
+ // are populated.
+ template <typename NativeT>
+ void PopulateSparse(SparseIndexArray indices,
+ tensorflow::gtl::ArraySlice<NativeT> values,
+ bool sort = true);
+
+ // Copy values from 'src_literal' rooted at 'src_shape_index' into this
+ // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
+ // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
+ // rooted at 'src_shape_index', but need not be arrays.
+ Status CopyFrom(const LiteralSlice& src_literal,
+ const ShapeIndex& dest_shape_index = {},
+ const ShapeIndex& src_shape_index = {});
+
+ // Similar to CopyFrom, but with move semantincs. The subshape of this literal
+ // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
+ // (layouts and shapes must match), but need not be arrays. The memory
+ // allocated in this literal for the subshape at dest_shape_index is
+ // deallocated, and the respective buffers are replaced with those in
+ // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
+ Status MoveFrom(Literal&& src_literal,
+ const ShapeIndex& dest_shape_index = {});
+
+ // Copies the values from src_literal, starting at src_base shape indexes,
+ // to this literal, starting at dest_base, where the copy size in each
+ // dimension is specified by copy_size.
+ // The src_literal and this literal must have the same primitive type,
+ // src_base+copy_size must fit the source literal dimensions, as well as
+ // dest_base+copy_size must fit the destination literal dimensions.
+ // Note: if either src_literal or this literal contains dimensions with zero
+ // element, then copy_size must be 0 in these dimensions while the
+ // corresponding base indices being 0.
+ // This literal and 'src_literal' must be arrays.
+ Status CopySliceFrom(const LiteralSlice& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_base,
+ tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size);
+
+ // Copies one element from src_literal[src_index] to (*this)[dest_index].
+ Status CopyElementFrom(const LiteralSlice& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_index,
+ tensorflow::gtl::ArraySlice<int64> dest_index);
- // Return the shape of the literal.
- const Shape& shape() const { return shape_; }
+ // Sets an element in the literal at the given index. The multi_index is
+ // CHECKed against the dimension sizes.
+ template <typename NativeT>
+ void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index, NativeT value);
+ // Overloads of Set for array literals. CHECKs if the literal is not
+ // array-shaped and dense.
+ template <typename NativeT>
+ void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
+
+ // Appends the given element to the literal. If the elements are not appended
+ // in sorted order, then SortSparseElements should be called before calling
+ // other methods. This literal must have a sparse layout.
+ template <typename NativeT>
+ void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
+ NativeT value, const ShapeIndex& shape_index = {});
+
+ // Sorts the elements in a sparse array.
+ void SortSparseElements(const ShapeIndex& shape_index = {});
+
+ // As Set(), but truncates `value` to the literal element type before storing.
+ // This literal must be an array.
+ Status SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
+ int64 value);
+
+ // Populate this literal with the given values. Examples:
+ //
+ // // Populate with floats.
+ // Array2D<float> float_values = ...
+ // literal.PopulateR2FromArray2D(values);
+ //
+ // // Populate with int32s.
+ // literal.PopulateR2<int32>({{1, 2}, {3, 4}});
+ //
+ // The shape and element type of this literal must match given values. For
+ // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
+ // array of S32.
+ template <typename NativeT>
+ void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
+ void PopulateR1(const tensorflow::core::Bitmap& values);
+ template <typename NativeT>
+ void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
+ template <typename NativeT>
+ void PopulateFromArray(const Array<NativeT>& values);
+ template <typename NativeT>
+ void PopulateR2FromArray2D(const Array2D<NativeT>& values);
+ template <typename NativeT>
+ void PopulateR3FromArray3D(const Array3D<NativeT>& values);
+ template <typename NativeT>
+ void PopulateR4FromArray4D(const Array4D<NativeT>& values);
+
+ // Populates literal values by calling the generator function for every cell
+ // in this literal object.
+ //
+ // generator must be a callable of the type
+ // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
+ //
+ // This literal must have a dense layout.
+ template <typename NativeT, typename FnType>
+ Status Populate(const FnType& generator);
- // TODO(b/67651157): Remove this accessor. Literal users should not be able to
- // mutate the shape as this can produce malformed Literals.
- Shape* mutable_shape_do_not_use() { return &shape_; }
+ // A parallel version of Populate(). This can be used if the generator is
+ // thread-safe and the values for the shape's different elements are
+ // independent.
+ template <typename NativeT, typename FnType>
+ Status PopulateParallel(const FnType& generator);
- // Returns a (Mutable)ArraySlice view of the array for this literal for the
- // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
- // given ShapeIndex is not array. See primitive_util.h for the mapping from
- // XLA type to native type.
- template <typename NativeT>
- tensorflow::gtl::ArraySlice<NativeT> data(
- const ShapeIndex& shape_index = {}) const;
+ // Fills this literal with the given value.
template <typename NativeT>
- tensorflow::gtl::MutableArraySlice<NativeT> data(
- const ShapeIndex& shape_index = {});
+ void PopulateWithValue(NativeT value);
- // Returns a pointer to the sparse index array. Returns nullptr if the literal
- // is not a sparse array.
- const SparseIndexArray* sparse_indices(
- const ShapeIndex& shape_index = {}) const;
- SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
+ // Factory methods below.
+ //
- // Returns a pointer to (or size of) the underlying buffer holding the array
- // at the given shape index. CHECKs if the subshape of the literal at the
- // given ShapeIndex is not array.
- const void* untyped_data(const ShapeIndex& shape_index = {}) const;
- void* untyped_data(const ShapeIndex& shape_index = {});
- int64 size_bytes(const ShapeIndex& shape_index = {}) const;
+ // Serialize from a proto.
+ static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
+ const LiteralProto& proto);
// Creates a new literal of a given rank. To minimize ambiguity (for users
// and the compiler) these CreateR[0-2] methods should explicitly specify the
values,
const Layout& layout);
- // Returns this literal's data as a string. This literal must be a rank-1 U8
- // array.
- string GetR1U8AsString() const;
-
// Creates a literal with a sparse layout and the given indices and values.
// The shape is initialized from the given dimensions. The minor dimension of
// the indices array must equal the rank of the shape (i.e. size of the
tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
tensorflow::gtl::ArraySlice<NativeT> values, bool sort = true);
- // Populates a literal with a sparse layout with the given indices and values.
- // Each index in the indices array is CHECKed against the dimensions in the
- // literal's shape. If sort is true, then the indices and values will be
- // sorted. If sort is false, then the indices and values are assumed to
- // already be in sorted order. See CreateSparse for an example of how data
- // are populated.
- template <typename NativeT>
- void PopulateSparse(SparseIndexArray indices,
- tensorflow::gtl::ArraySlice<NativeT> values,
- bool sort = true);
-
- // Creates a new Literal object with the shape specified as parameter.
- // The content of the literal values is the default value of the primitive
- // type of literal itself (0 for numeric types, and false for predicates).
- static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
-
- // Creates a new Literal object with its values havings the primitive_type
- // type, and with dimensions defined by the dimensions parameter.
- // The content of the literal values is the default value of the primitive
- // type of literal itself (0 for numeric types, and false for predicates).
- static std::unique_ptr<Literal> CreateFromDimensions(
- PrimitiveType primitive_type,
- tensorflow::gtl::ArraySlice<int64> dimensions);
-
- // Copy values from 'src_literal' rooted at 'src_shape_index' into this
- // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
- // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
- // rooted at 'src_shape_index', but need not be arrays.
- Status CopyFrom(const Literal& src_literal,
- const ShapeIndex& dest_shape_index = {},
- const ShapeIndex& src_shape_index = {});
-
- // Similar to CopyFrom, but with move semantincs. The subshape of this literal
- // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
- // (layouts and shapes must match), but need not be arrays. The memory
- // allocated in this literal for the subshape at dest_shape_index is
- // deallocated, and the respective buffers are replaced with those in
- // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
- Status MoveFrom(Literal&& src_literal,
- const ShapeIndex& dest_shape_index = {});
-
- // Copies the values from src_literal, starting at src_base shape indexes,
- // to this literal, starting at dest_base, where the copy size in each
- // dimension is specified by copy_size.
- // The src_literal and this literal must have the same primitive type,
- // src_base+copy_size must fit the source literal dimensions, as well as
- // dest_base+copy_size must fit the destination literal dimensions.
- // Note: if either src_literal or this literal contains dimensions with zero
- // element, then copy_size must be 0 in these dimensions while the
- // corresponding base indices being 0.
- // This literal and 'src_literal' must be arrays.
- Status CopySliceFrom(const Literal& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size);
-
- // Copies one element from src_literal[src_index] to (*this)[dest_index].
- Status CopyElementFrom(const Literal& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_index,
- tensorflow::gtl::ArraySlice<int64> dest_index);
-
- // Returns a vector containing the tuple elements of this Literal as separate
- // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
- // elements are moved into the new Literals; no data is copied. Upon return
- // this Literal is set to a nil shape (empty tuple)
- std::vector<Literal> DecomposeTuple();
-
- // This operation is the inverse of DecomposeTuple. The given elements are
- // moved into the tuple elements of a new tuple-shaped Literal which is
- // returned. Upon return, each of the Literals in 'elements' is set to a nil
- // shape (empty tuple).
- static Literal MoveIntoTuple(
- tensorflow::gtl::MutableArraySlice<Literal> elements);
-
- // Creates a new value that has the equivalent value as this literal, but
- // conforms to new_layout; e.g. a literal matrix that was in {0, 1}
- // minor-to-major dimension layout can be re-layed-out as {1, 0}
- // minor-to-major dimension layout and the value in the cell at any given
- // logical index (i0, i1) will be the same.
- //
- // For tuple shaped literals, shape_index should be used to select the inner
- // array that the new layout applies to.
- //
- // Note: this is useful when the client wants to ensure that a value placed in
- // the XLA allocation tracker has a particular layout; for efficiency
- // purposes or avoiding unimplemented operation/layout combinations.
- std::unique_ptr<Literal> Relayout(const Layout& new_layout,
- const ShapeIndex& shape_index = {}) const;
-
- // An overload of Relayout which changes the layout of the entire shape rather
- // than being limited to a single array within the shape.
- std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
-
- // Creates a new literal by reshaping this literal to have the given
- // dimensions. The total number of elements must not change; The
- // implementation currently only supports monotonic dim0-major layouts.
- // This literal must be an array.
- StatusOr<std::unique_ptr<Literal>> Reshape(
- tensorflow::gtl::ArraySlice<int64> dimensions) const;
-
- // Creates a new literal by reordering the dimensions of this literal.
- // The given `permutation` must be a permutation of the dimension numbers
- // in the original literal, and it specifies the order of the new dimensions
- // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
- // For example, a transpose call on a literal of shape [3 x 8 x 4] and
- // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
- // This literal must be an array.
- std::unique_ptr<Literal> Transpose(
- tensorflow::gtl::ArraySlice<int64> permutation) const;
-
- // Creates a sub-array from this literal by extracting the indices
- // [start_index, limit_index) of each dimension. The result literal has the
- // same rank and layout as for the given literal. The number of indices in
- // start_indices and limit_indices must be the rank of the literal, and the
- // indices follow the order of the dimensions.
- // This literal must be an array.
- std::unique_ptr<Literal> Slice(
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices) const;
-
- // Creates a literal with a prepended dimension with bound "times"; e.g. a
- // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
- // literal replicated four times.
- // This literal must be an array.
- template <typename NativeT>
- std::unique_ptr<Literal> Replicate(int64 times) const;
-
- // Converts this literal to another primitive type using
- // static_cast<>. Returns an error if the conversion is not possible. This
- // literal must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> Convert(
- PrimitiveType primitive_dest_type) const;
-
- // Converts this literal to another primitive type using a bitcast
- // conversion. The to and from primitive types must have the same bit
- // width. Returns an error if the conversion is not possible. This literal
- // must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> BitcastConvert(
- PrimitiveType primitive_dest_type) const;
-
- // Converts this literal to the given shape. Returns an error is the
- // conversion is not possible.
- //
- // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding
- // instead of truncation; otherwise, truncation is used.
- //
- // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
- // the default behavior.
- StatusOr<std::unique_ptr<Literal>> ConvertToShape(
- const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
-
// Creates a scalar literal value zero of the given primitive type.
static Literal Zero(PrimitiveType primitive_type);
-
// Creates a scalar literal value one of the given primitive type.
static Literal One(PrimitiveType primitive_type);
-
// Creates a scalar literal value containing the minimum value of the given
// primitive type. For floating-point types, returns -inf.
static Literal MinValue(PrimitiveType primitive_type);
-
// Creates a scalar literal value containing the maximum value of the given
// primitive type. For floating-point types, returns inf.
static Literal MaxValue(PrimitiveType primitive_type);
-
// Creates a literal of the given shape where each element is `value`.
template <typename NativeT>
static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
// the z dimension given by "projection".
template <typename NativeT>
static std::unique_ptr<Literal> CreateR3Projected(
- std::initializer_list<std::initializer_list<NativeT>> values,
- int64 projection);
-
- // Creates a literal that projects the (x, y) dimensions given in values into
- // the z and p dimensions given.
- template <typename NativeT>
- static std::unique_ptr<Literal> CreateR4Projected(
- std::initializer_list<std::initializer_list<NativeT>> values,
- int64 projection_p, int64 projection_z);
-
- // Clones this literal into a new Literal, or new std::unique_ptr<Literal>.
- Literal Clone() const;
- std::unique_ptr<Literal> CloneToUnique() const;
-
- // Gets or sets an element in the literal at the given index. The multi_index
- // is CHECKed against the dimension sizes.
- template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index) const;
- template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index, NativeT value);
-
- // Overloads of Get and Set for array literals. CHECKs if the literal is not
- // array-shaped and dense.
- template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
- template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
-
- // Returns the multi-index of the element in a sparse literal at the given
- // sparse element number. The sparse element number is the position with in
- // the sparse array's list of (index, value) pairs, and is checked against the
- // total number of (index, value) pairs in the sparse array.
- tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
- int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
-
- // Returns the value of the element in a sparse literal at the given sparse
- // element number. The sparse element number is the position with in the
- // sparse array's list of (index, value) pairs, and is checked against the
- // total number of (index, value) pairs in the sparse array.
- template <typename NativeT>
- NativeT GetSparseElement(int64 sparse_element_number,
- const ShapeIndex& shape_index = {}) const;
-
- // Appends the given element to the literal. If the elements are not appended
- // in sorted order, then SortSparseElements should be called before calling
- // other methods. This literal must have a sparse layout.
- template <typename NativeT>
- void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
- NativeT value, const ShapeIndex& shape_index = {});
-
- // Sorts the elements in a sparse array.
- void SortSparseElements(const ShapeIndex& shape_index = {});
-
- // Returns the element value at index (0, ..., 0), however many zeroes are
- // required for that index.
- template <typename NativeT>
- NativeT GetFirstElement() const;
-
- // Returns a literal scalar representing the first element.
- Literal GetFirstScalarLiteral() const;
-
- // As Get(), but determines the correct type and converts the value
- // into text.
- string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index = {}) const;
-
- // As GetSparseElement(), but determines the correct type and converts the
- // value into text.
- string GetSparseElementAsString(int64 sparse_element_number,
- const ShapeIndex& shape_index = {}) const;
-
- // As Get(), but determines the correct type and converts the value into
- // int64. This literal must be an array.
- StatusOr<int64> GetIntegralAsS64(
- tensorflow::gtl::ArraySlice<int64> multi_index) const;
+ std::initializer_list<std::initializer_list<NativeT>> values,
+ int64 projection);
- // As Set(), but truncates `value` to the literal element type before storing.
- // This literal must be an array.
- Status SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
- int64 value);
+ // Creates a literal that projects the (x, y) dimensions given in values into
+ // the z and p dimensions given.
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateR4Projected(
+ std::initializer_list<std::initializer_list<NativeT>> values,
+ int64 projection_p, int64 projection_z);
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
static std::unique_ptr<Literal> MakeTuple(
tensorflow::gtl::ArraySlice<const Literal*> elements);
+ static std::unique_ptr<Literal> MakeTupleFromSlices(
+ tensorflow::gtl::ArraySlice<LiteralSlice> elements);
+
// As above, but intended to be invoked with move semantics; i.e.
//
// std::vector<std::unique_ptr<Literal>> elements = ...;
return MakeTupleOwned(std::move(v));
}
- // Returns a string representation of the literal value.
- // Warning: this function can take minutes for multi-million element Literals.
- string ToString(bool print_layout = false) const;
-
- // Invokes the "per cell" callback for each element in the provided
- // literal with the element's indices and a string representation of
- // the element's value.
- //
- // This function is useful if you want a polymorphic representation
- // of the tensor's elements (turning it to a string for something
- // like representation in a protobuf).
- //
- // This literal must have a dense layout.
- void EachCellAsString(
- const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
- const string& value)>& per_cell) const;
- template <typename NativeT>
- void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
- NativeT value)>
- per_cell) const;
-
- // Populate this literal with the given values. Examples:
- //
- // // Populate with floats.
- // Array2D<float> float_values = ...
- // literal.PopulateR2FromArray2D(values);
- //
- // // Populate with int32s.
- // literal.PopulateR2<int32>({{1, 2}, {3, 4}});
- //
- // The shape and element type of this literal must match given values. For
- // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
- // array of S32.
- template <typename NativeT>
- void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
- void PopulateR1(const tensorflow::core::Bitmap& values);
- template <typename NativeT>
- void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
- template <typename NativeT>
- void PopulateFromArray(const Array<NativeT>& values);
- template <typename NativeT>
- void PopulateR2FromArray2D(const Array2D<NativeT>& values);
- template <typename NativeT>
- void PopulateR3FromArray3D(const Array3D<NativeT>& values);
- template <typename NativeT>
- void PopulateR4FromArray4D(const Array4D<NativeT>& values);
-
- // Populates literal values by calling the generator function for every cell
- // in this literal object.
- //
- // generator must be a callable of the type
- // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
- //
- // This literal must have a dense layout.
- template <typename NativeT, typename FnType>
- Status Populate(const FnType& generator);
-
- // A parallel version of Populate(). This can be used if the generator is
- // thread-safe and the values for the shape's different elements are
- // independent.
- template <typename NativeT, typename FnType>
- Status PopulateParallel(const FnType& generator);
-
- // Fills this literal with the given value.
- template <typename NativeT>
- void PopulateWithValue(NativeT value);
+ // Returns a vector containing the tuple elements of this Literal as separate
+ // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
+ // elements are moved into the new Literals; no data is copied. Upon return
+ // this Literal is set to a nil shape (empty tuple)
+ std::vector<Literal> DecomposeTuple();
- // Returns whether every element in this literal is equal to value.
- //
- // value is an int8 because we expect this to be called with small
- // compile-time constants (0, -1, etc.) and so that whatever value you pass
- // can be represented exactly by floating-point types as small as 16 bits.
- //
- // If value doesn't fit in this literal's type, returns false. Values of 1/0
- // are considered equal to true/false; other values are not considered equal
- // to true. Also if this literal is not array-shaped false is returned.
- bool IsAll(int8 value) const;
+ // This operation is the inverse of DecomposeTuple. The given elements are
+ // moved into the tuple elements of a new tuple-shaped Literal which is
+ // returned. Upon return, each of the Literals in 'elements' is set to a nil
+ // shape (empty tuple).
+ static Literal MoveIntoTuple(
+ tensorflow::gtl::MutableArraySlice<Literal> elements);
- // Like IsAll(const Literal&, int8), except we check whether the literal is
- // equal to a particular floating-point number.
- //
- // If the literal is not a floating-point value, this always returns false.
- //
- // This casts value to the type of literal, then compares using ==. The usual
- // admonishments about floating-point equality checks apply. We expect you to
- // use this to check for values that can be expressed precisely as a float,
- // e.g. -0.5. Also if this literal is not array-shaped false is returned.
- bool IsAllFloat(float value) const;
+ // Creates a new Literal object with its values havings the primitive_type
+ // type, and with dimensions defined by the dimensions parameter.
+ // The content of the literal values is the default value of the primitive
+ // type of literal itself (0 for numeric types, and false for predicates).
+ static std::unique_ptr<Literal> CreateFromDimensions(
+ PrimitiveType primitive_type,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
- // Like IsAll(const Literal&, int8), except we check whether the literal is
- // equal to a particular complex number.
- //
- // If the literal is not a complex value, this always returns false.
- //
- // This casts value to the type of literal, then compares using ==. The usual
- // admonishments about floating-point equality checks apply. We expect you to
- // use this to check for complex values that can be expressed precisely as
- // float pairs e.g. (-0.5, 1.0).
//
- // This literal must have a dense layout.
- bool IsAllComplex(complex64 value) const;
+ // End of factory methods.
- // Literal consists entirely of the first element of the literal.
- bool IsAllFirst() const;
-
- // Returns whether this literal is zero at the specified index. This literal
- // must be an array with a dense layout.
- bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
+ protected:
+ // Recursively sets the subshapes and buffers of all subpieces rooted at
+ // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
+ // the shape.
+ void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
- // Return the count of the elements in the array at the given shape index in
- // this literal.
- int64 element_count(const ShapeIndex& index = {}) const {
- return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
+ // Returns the piece at the given ShapeIndex.
+ Piece& piece(const ShapeIndex& shape_index) {
+ return const_cast<Piece&>(LiteralBase::piece(shape_index));
}
- // Return the count of the elements in the sparse array at the given shape
- // index in this literal, which will be no larger than
- // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()).
- int64 sparse_element_count() const;
-
- // Compute a hash for this literal. This literal must not be a sparse tensor
- // or a tuple containing a sparse tensor.
- size_t Hash() const;
+ Piece& root_piece() const override { return *root_piece_; };
- protected:
+ private:
// Internal template helper for the Literal::CopySliceFrom(), matching its
// arguments one by one.
template <typename NativeT>
- Status CopySliceFromInternal(const Literal& src_literal,
+ Status CopySliceFromInternal(const LiteralBase& src_literal,
tensorflow::gtl::ArraySlice<int64> src_base,
tensorflow::gtl::ArraySlice<int64> dest_base,
tensorflow::gtl::ArraySlice<int64> copy_size);
int64 minor_loop_size = 1;
};
- // A data structure representing a subshape at a particular ShapeIndex within
- // the literal. For array-shaped ShapeIndexes, this data structure holds the
- // pointer to the memory allocated for the array data.
- class Piece {
- public:
- // Return the buffer holding the array data for this piece as an array
- // slice. This piece must be array-shaped.
- template <typename NativeT>
- tensorflow::gtl::ArraySlice<NativeT> data() const;
- template <typename NativeT>
- tensorflow::gtl::MutableArraySlice<NativeT> data();
-
- // Return the buffer holding the array data for this piece as a void*. This
- // piece must be array-shaped.
- void* untyped_data();
- const void* untyped_data() const;
-
- // Gets or sets an element in the array at the given index. The multi_index
- // is CHECKed against the dimension sizes of the array. This piece must be
- // array-shaped.
- template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
- template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);
-
- // Gets/sets the buffer holding the array data.
- char* buffer() const { return buffer_; }
- void set_buffer(char* buffer) { buffer_ = buffer; }
-
- // The array of multi-indices that provide the locations of non-zero
- // elements in a sparse array. Only used if
- // LayoutUtil::IsSparseArray(shape()) is true.
- SparseIndexArray* sparse_indices() const { return sparse_indices_; }
- void set_sparse_indices(SparseIndexArray* sparse_indices) {
- sparse_indices_ = sparse_indices;
- }
-
- // Gets or sets the subshape of this piece. This reference points to a
- // subshape within the shape in the containing Literal (Literal::shape_).
- const Shape& subshape() const { return *subshape_; }
- void set_subshape(const Shape* subshape) { subshape_ = subshape; }
-
- // Returns the size in bytes of the buffer holding the array data.
- int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
-
- // Returns the number of elements in this piece's array.
- int64 element_count() const {
- // If this is a sparse array, use the number of elements represented by
- // the indices in the associated SparseIndexArray.
- return LayoutUtil::IsSparseArray(subshape())
- ? sparse_indices()->index_count()
- : ShapeUtil::ElementsIn(subshape());
- }
-
- // Copy the data from 'src' into this piece's buffer. Shapes of this piece
- // and src must be compatible.
- Status CopyFrom(const Piece& src);
-
- // Returns true if this piece and 'other' contain the same data. This piece
- // and 'other' must be array-shaped and compatible.
- bool EqualElements(const Piece& other) const;
-
- // Writes the shape and data (if array-shaped) into the given proto.
- void WriteToProto(LiteralProto* proto) const;
-
- // Copies the data from the given proto into this piece. The shape of this
- // piece must be equal (not just compatible) to the shape of the proto.
- Status CopyFromProto(const LiteralProto& proto);
-
- // Sorts the elements in a sparse array.
- void SortSparseElements();
-
- private:
- // Recursive helper for EqualElements.
- template <typename NativeT>
- bool EqualElementsInternal(const Piece& other,
- std::vector<int64>* multi_index) const;
-
- // Helper for SortSparseElements that has the element type as a template
- // parameter.
- template <typename NativeT>
- void SortSparseElementsInternal();
-
- // For array-shaped pieces, this is the buffer holding the literal data.
- char* buffer_ = nullptr;
-
- // For sparse arrays, this is the array of indices.
- SparseIndexArray* sparse_indices_ = nullptr;
-
- // The shape of piece. This points into the shape of the containing Literal
- // (Literal::shape_).
- const Shape* subshape_ = nullptr;
- };
-
- // Returns the piece at the given ShapeIndex.
- Piece& piece(const ShapeIndex& shape_index) {
- return *pieces_.mutable_element(shape_index);
- }
- const Piece& piece(const ShapeIndex& shape_index) const {
- return pieces_.element(shape_index);
- }
-
- // Returns the piece at the root of the shape (empty ShapeIndex).
- Piece& root_piece() { return piece({}); }
- const Piece& root_piece() const { return piece({}); }
+ // Literal class always owns the shape. The parent class borrows this shape.
+ std::unique_ptr<Shape> shape_;
- // Deallocate the buffers held by this literal (if the literal owns the
- // buffer).
- void DeallocateBuffers();
+ Piece* root_piece_ = nullptr;
// Implementation details shared between Populate() and PopulateParallel()
template <typename NativeT, typename FnType>
Status PopulateInternal(const FnType& generator, bool parallel);
- Shape shape_;
- ShapeTree<Piece> pieces_;
-
- // Whether the buffers held in pieces_ are owned by this Literal.
- bool owns_buffers_;
+ // Deallocate the buffers held by this literal.
+ void DeallocateBuffers();
- // LiteralView must access and manipulate Pieces of other Literals.
- friend class LiteralView;
-}; // namespace xla
+ friend class LiteralBase;
+};
std::ostream& operator<<(std::ostream& out, const Literal& literal);
-// A read-only view of a Literal. A LiteralView contains pointers to buffers
-// owned by the viewed Literal.
-//
-// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable
-// and mutable) similar to (Mutable)ArraySlice.
-class LiteralView : public Literal {
+// A read-only view of a Literal. A LiteralSlice contains pointers to shape and
+// literal buffers always owned by others.
+class LiteralSlice : public LiteralBase {
public:
- // Create and return a view of the given literal rooted at the given shape
- // index within the given literal. A factory is used rather than a public
- // constructor because only const LiteralViews are supported. It's still
- // possible to create non-const LiteralViews via the copy constructors, but
- // the factory method makes it a bit less likely. Implementing literal slices
- // will fix this undesirable situation (b/71550060).
- static const LiteralView Create(const Literal& literal,
- const ShapeIndex& view_root = {});
-
- LiteralView(const LiteralView& other);
- LiteralView& operator=(const LiteralView& other);
-
- virtual ~LiteralView();
+ LiteralSlice() : LiteralBase() {}
+ // Implicit conversion constructor that can also accept Literal.
+ LiteralSlice(const LiteralBase& literal);
+ LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root);
private:
- LiteralView(const Literal& literal, const ShapeIndex& view_root);
+ const Piece& root_piece() const override { return *root_piece_; };
- // Helper for the copy constructor and copy assignment operator.
- void CopyFrom(const LiteralView& other);
+ const Piece* root_piece_; // Not owned.
};
template <typename NativeT>
-tensorflow::gtl::ArraySlice<NativeT> Literal::Piece::data() const {
+tensorflow::gtl::ArraySlice<NativeT> LiteralBase::Piece::data() const {
CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
CHECK_EQ(subshape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>())
}
template <typename NativeT>
-tensorflow::gtl::MutableArraySlice<NativeT> Literal::Piece::data() {
+tensorflow::gtl::MutableArraySlice<NativeT> LiteralBase::Piece::data() {
CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
CHECK_EQ(subshape().element_type(),
primitive_util::NativeToPrimitiveType<NativeT>())
}
template <typename NativeT>
-NativeT Literal::Piece::Get(
+NativeT LiteralBase::Piece::Get(
tensorflow::gtl::ArraySlice<int64> multi_index) const {
CHECK(LayoutUtil::IsDenseArray(subshape()));
return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
}
template <typename NativeT>
-void Literal::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- NativeT value) {
+void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+ NativeT value) {
CHECK(LayoutUtil::IsDenseArray(subshape()));
data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
subshape(), multi_index)] = value;
}
template <typename NativeT>
-tensorflow::gtl::ArraySlice<NativeT> Literal::data(
+tensorflow::gtl::ArraySlice<NativeT> LiteralBase::data(
const ShapeIndex& shape_index) const {
return piece(shape_index).data<NativeT>();
}
}
template <typename NativeT>
-inline NativeT Literal::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index) const {
+inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index) const {
return piece(shape_index).Get<NativeT>(multi_index);
}
template <typename NativeT>
-inline NativeT Literal::Get(
+inline NativeT LiteralBase::Get(
tensorflow::gtl::ArraySlice<int64> multi_index) const {
return root_piece().Get<NativeT>(multi_index);
}
}
template <typename NativeT>
-NativeT Literal::GetFirstElement() const {
+NativeT LiteralBase::GetFirstElement() const {
return data<NativeT>().at(0);
}
template <typename NativeT>
-NativeT Literal::GetSparseElement(int64 sparse_element_number,
- const ShapeIndex& shape_index) const {
+NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
+ const ShapeIndex& shape_index) const {
CHECK(
LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index)));
return data<NativeT>(shape_index)[sparse_element_number];
}
template <typename NativeT>
-void Literal::EachCell(
+void LiteralBase::EachCell(
std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
NativeT value)>
per_cell) const {
}
template <typename NativeT>
-std::unique_ptr<Literal> Literal::Replicate(int64 times) const {
+std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
DimensionVector bounds = {times};
bounds.reserve(shape().dimensions_size() + 1);
for (int64 bound : shape().dimensions()) {
Literal::CreateR1<double>({2.0, 4.0}).get(),
&nil_literal});
- EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0}));
+ EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), 42);
EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 23.0);
EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 44.0);
/*src_shape_index=*/{}));
// The matrix element should be unchanged.
- EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0}));
+ EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
// The tuple element should have been copied from 'tuple'.
EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), -5);
ASSERT_EQ(h1, r[3]);
}
-TEST_F(LiteralUtilTest, LiteralViewTest) {
+TEST_F(LiteralUtilTest, LiteralSliceTest) {
auto scalar = Literal::CreateR0<float>(1.0);
auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
Literal nil(ShapeUtil::MakeNil());
- EXPECT_EQ(LiteralView::Create(*scalar, {}), *scalar);
- EXPECT_EQ(LiteralView::Create(*matrix, {}), *matrix);
- EXPECT_EQ(LiteralView::Create(*tuple, {}), *tuple);
- EXPECT_EQ(LiteralView::Create(*nested_tuple, {}), *nested_tuple);
- EXPECT_EQ(LiteralView::Create(nil, {}), nil);
+ EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar);
+ EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix);
+ EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple);
+ EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple);
+ EXPECT_EQ(LiteralSlice(nil, {}), nil);
- EXPECT_EQ(LiteralView::Create(*tuple, {0}), *scalar);
- EXPECT_EQ(LiteralView::Create(*tuple, {1}), *matrix);
+ EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar);
+ EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix);
- EXPECT_EQ(LiteralView::Create(*nested_tuple, {0}), *tuple);
- EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 0}), *scalar);
- EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 1}), *matrix);
- EXPECT_EQ(LiteralView::Create(*nested_tuple, {1}), *scalar);
+ EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple);
+ EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar);
+ EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix);
+ EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar);
}
-TEST_F(LiteralUtilTest, MutatingLiteralView) {
+TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
auto scalar = Literal::CreateR0<float>(1.0);
auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
// Verify that changing the underlying data beneath the view changes the
// data of the view itself.
- const auto nested_tuple_view = LiteralView::Create(*nested_tuple);
+ const auto nested_tuple_view = LiteralSlice(*nested_tuple);
EXPECT_EQ(
nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
1.0f);
555.0f);
}
-TEST_F(LiteralUtilTest, LiteralViewOfALiteralView) {
+TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
auto scalar = Literal::CreateR0<float>(1.0);
auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
- const auto nested_tuple_view = LiteralView::Create(*nested_tuple);
- const auto tuple_view =
- LiteralView::Create(nested_tuple_view, /*view_root=*/{0});
- const auto matrix_view = LiteralView::Create(tuple_view, /*view_root=*/{1});
+ const auto nested_tuple_view = LiteralSlice(*nested_tuple);
+ const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0});
+ const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1});
EXPECT_EQ(matrix_view, *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
}
EXPECT_EQ(literal.Get<float>({1, 1}), 4.0);
}
-TEST_F(LiteralUtilTest, LiteralViewCopy) {
+TEST_F(LiteralUtilTest, LiteralSliceCopy) {
std::unique_ptr<Literal> matrix =
Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- const auto matrix_view = LiteralView::Create(*matrix);
- LiteralView matrix_view_copy(matrix_view);
+ const auto matrix_view = LiteralSlice(*matrix);
+ LiteralSlice matrix_view_copy(matrix_view);
EXPECT_EQ(matrix_view_copy.Get<float>({0, 0}), 1.0);
EXPECT_EQ(matrix_view_copy.Get<float>({0, 1}), 2.0);
return result;
}
-PyObject* PyObjectFromXlaLiteral(const Literal& literal) {
+PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
if (ShapeUtil::IsTuple(literal.shape())) {
int num_elements = ShapeUtil::TupleElementCount(literal.shape());
PyObject* tuple = PyTuple_New(num_elements);
for (int i = 0; i < num_elements; i++) {
- PyTuple_SET_ITEM(
- tuple, i, PyObjectFromXlaLiteral(LiteralView::Create(literal, {i})));
+ PyTuple_SET_ITEM(tuple, i,
+ PyObjectFromXlaLiteral(LiteralSlice(literal, {i})));
}
return tuple;
} else {
return Status::OK();
}
-void CopyLiteralToNumpyArray(int np_type, const Literal& literal,
+void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
PyArrayObject* py_array) {
switch (np_type) {
case NPY_BOOL:
// array data.
//
// The return value is a new reference.
-PyObject* PyObjectFromXlaLiteral(const Literal& literal);
+PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal);
// Converts a Numpy ndarray or a nested Python tuple thereof to a
// corresponding XLA literal.
Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
Literal* literal);
-void CopyLiteralToNumpyArray(int np_type, const Literal& literal,
+void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
PyArrayObject* py_array);
template <typename NativeT>
}
template <typename NativeT>
-void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) {
+void CopyLiteralToNumpyArray(const LiteralSlice& literal,
+ PyArrayObject* py_array) {
NativeT* dest = static_cast<NativeT*>(PyArray_DATA(py_array));
auto source = literal.data<NativeT>();
std::copy(source.begin(), source.end(), dest);
}
static HloInstruction* BuildTupleConstant(HloComputation* computation,
- const Literal& literal) {
+ const LiteralSlice& literal) {
if (ShapeUtil::IsTuple(literal.shape())) {
std::vector<HloInstruction*> elems;
elems.reserve(ShapeUtil::TupleElementCount(literal.shape()));
for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) {
elems.push_back(
- BuildTupleConstant(computation, LiteralView::Create(literal, {i})));
+ BuildTupleConstant(computation, LiteralSlice(literal, {i})));
}
return computation->AddInstruction(HloInstruction::CreateTuple(elems));
} else {
: GenericTransferManager(se::host::kHostPlatformId,
/*pointer_size=*/sizeof(void*)) {}
-Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor,
- const Literal& literal) {
+Status CpuTransferManager::TransferLiteralToInfeed(
+ se::StreamExecutor* executor, const LiteralSlice& literal) {
const Shape& shape = literal.shape();
VLOG(2) << "Transferring literal to infeed with shape: "
<< ShapeUtil::HumanString(shape);
~CpuTransferManager() override {}
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
- const Literal& literal) override;
+ const LiteralSlice& literal) override;
Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
const void* source) override;
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
namespace xla {
namespace cpu {
-void ExternalConstantPool::Insert(string name, const Literal& literal,
+void ExternalConstantPool::Insert(string name, const LiteralSlice& literal,
int64 alignment) {
CHECK(!ShapeUtil::IsTuple(literal.shape()));
CHECK(alignment > 0 && IsPowerOfTwo(static_cast<uint64>(alignment)));
CHECK(entries_.find(name) == entries_.end());
- int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape());
+ const int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape());
void* raw_pointer = tensorflow::port::AlignedMalloc(
literal_size, std::max<size_t>(alignment, sizeof(void*)));
CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size
// The constant pool copies out the contents of `literal` into a buffer it
// owns -- it does not keep pointers to `literal`, or to memory owned by
// `literal`.
- void Insert(string name, const Literal& literal, int64 alignment);
+ void Insert(string name, const LiteralSlice& literal, int64 alignment);
// Find the constant with name `name` in this constant pool. If there isn't
// such constant, return nullptr.
TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
device_memory.size());
// Element is array-shaped: transfer array data to device buffer.
- const auto subliteral = LiteralView::Create(literal, index);
+ const auto subliteral = LiteralSlice(literal, index);
std::unique_ptr<Literal> relayed_out_literal;
const void* source;
if (LayoutUtil::Equal(device_subshape.layout(),
}
Status GenericTransferManager::TransferLiteralToInfeed(
- se::StreamExecutor* executor, const Literal& literal) {
+ se::StreamExecutor* executor, const LiteralSlice& literal) {
return Unimplemented("Generic transfer to Infeed");
}
const ShapedBuffer& device_buffer) override;
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
- const Literal& literal) override;
+ const LiteralSlice& literal) override;
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
const Shape& literal_shape,
Literal* literal) override;
/*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout)
.getPointerSize(0 /* default address space */)) {}
-Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor,
- const Literal& literal) {
+Status GpuTransferManager::TransferLiteralToInfeed(
+ se::StreamExecutor* executor, const LiteralSlice& literal) {
const Shape& shape = literal.shape();
VLOG(2) << "Transferring literal to infeed with shape: "
<< ShapeUtil::HumanString(shape);
~GpuTransferManager() override {}
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
- const Literal& literal) override;
+ const LiteralSlice& literal) override;
Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
const void* source) override;
template <typename OperandT>
StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
- const Literal& lhs_literal,
- const Literal& rhs_literal) {
+ LiteralSlice lhs_literal,
+ LiteralSlice rhs_literal) {
std::function<bool(OperandT, OperandT)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
template <>
StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
- const Shape& shape, HloOpcode opcode, const Literal& lhs_literal,
- const Literal& rhs_literal) {
+ const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal,
+ LiteralSlice rhs_literal) {
std::function<bool(complex64, complex64)> compare_op;
switch (opcode) {
case HloOpcode::kEq:
// Transfers the given literal into the Infeed interface of the device,
// using the given executor.
virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor,
- const Literal& literal) = 0;
+ const LiteralSlice& literal) = 0;
// Transfers the given literal from the Outfeed interface of the device,
// using the given executor.
LiteralTestUtil::ExpectNear(
*Literal::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
- LiteralView::Create(*result, {0}), error_spec_);
+ LiteralSlice(*result, {0}), error_spec_);
LiteralTestUtil::ExpectNear(
*Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
- LiteralView::Create(*result, {1}), error_spec_);
+ LiteralSlice(*result, {1}), error_spec_);
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
auto result,
client_->ExecuteAndTransfer(computation, {}, &execution_options));
LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
- LiteralView::Create(*result, {0}));
+ LiteralSlice(*result, {0}));
LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
- LiteralView::Create(*result, {1}));
+ LiteralSlice(*result, {1}));
EXPECT_TRUE(ShapeUtil::IsTuple(result->shape()));
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape()));
ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
LiteralTestUtil::ExpectR2Near<float>(
- {{1.0}, {2.0}}, LiteralView::Create(*result, {0}), error_spec_);
+ {{1.0}, {2.0}}, LiteralSlice(*result, {0}), error_spec_);
LiteralTestUtil::ExpectR1Near<float>(
- {2.0, 42.0}, LiteralView::Create(*result, {1}), error_spec_);
+ {2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_);
}
} // namespace
// Return a literal with all arrays of type FromNativeT converted to type
// ToNativeT in the given literal.
template <typename FromNativeT, typename ToNativeT>
-std::unique_ptr<Literal> ConvertType(const Literal& literal) {
+std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
// First construct shape of the result.
Shape result_shape(literal.shape());
ShapeUtil::ForEachMutableSubshape(
} // namespace
/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertBF16ToF32(
- const Literal& literal) {
+ LiteralSlice literal) {
return ConvertType<bfloat16, float>(literal);
}
/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertF32ToBF16(
- const Literal& literal) {
+ LiteralSlice literal) {
return ConvertType<float, bfloat16>(literal);
}
// actual literal and compares their values elementwise. Returns true if all
// elements are equal.
template <typename NativeT>
-bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
+bool ExpectLiteralsEqual(LiteralSlice expected, LiteralSlice actual,
tensorflow::gtl::MutableArraySlice<int64> multi_index,
int64 dimension) {
if (dimension == expected.shape().dimensions_size()) {
} // namespace
-/* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected,
- const Literal& actual,
+/* static */ void LiteralTestUtil::ExpectEqual(LiteralSlice expected,
+ LiteralSlice actual,
const string& message) {
EXPECT_TRUE(Equal(expected, actual))
<< "expected:\n"
<< (message.empty() ? "" : StrCat("\nmessage: ", message));
}
-/* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected,
- const Literal& actual) {
+/* static */ void LiteralTestUtil::ExpectNotEqual(LiteralSlice expected,
+ LiteralSlice actual) {
EXPECT_FALSE(Equal(expected, actual));
}
/* static */ ::testing::AssertionResult LiteralTestUtil::Equal(
- const Literal& expected, const Literal& actual) {
+ LiteralSlice expected, LiteralSlice actual) {
VLOG(1) << "expected:";
XLA_VLOG_LINES(1, expected.ToString());
VLOG(1) << "actual:";
SCOPED_TRACE(StrCat("Tuple index ", i, " in ",
ShapeUtil::HumanString(expected.shape())));
- // Create LiteralViews of the expected and actual elements.
- auto result = Equal(LiteralView::Create(expected, {i}),
- LiteralView::Create(actual, {i}));
+ // Create LiteralSlices of the expected and actual elements.
+ auto result =
+ Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}));
tuple_match = tuple_match ? !!result : false;
}
match = tuple_match;
// 3 minutes. The utility of printing a literal with >1000 elements is
// questionable, especially when writing the Literal proto to disk is orders
// of magnitude faster.
-string TruncateHugeLiteral(const Literal& literal) {
+string TruncateHugeLiteral(LiteralSlice literal) {
return RecursiveElementCount(literal.shape()) < 1000
? literal.ToString()
: "[TRUNCATED, Literal with more than 1000 values]";
// result. The assertion result is successful if all actual and expected
// elements are within the given error bound. In case of error, the assertion
// result contains a detailed error message in case of failure.
- static ::testing::AssertionResult Compare(const Literal& expected,
- const Literal& actual,
+ static ::testing::AssertionResult Compare(LiteralSlice expected,
+ LiteralSlice actual,
ErrorSpec error,
bool detailed_message) {
NearComparator<NativeT> comparator(expected, actual, error,
}
};
- explicit NearComparator(const Literal& expected, const Literal& actual,
+ explicit NearComparator(LiteralSlice expected, LiteralSlice actual,
ErrorSpec error, bool detailed_message)
: expected_(expected),
actual_(actual),
}
// Writes the given literal to a file in the test temporary directory.
- void WriteLiteralToTempFile(const Literal& literal, const string& name) {
+ void WriteLiteralToTempFile(LiteralSlice literal, const string& name) {
int64 now_usec = tensorflow::Env::Default()->NowMicros();
string filename = tensorflow::io::JoinPath(
tensorflow::testing::TmpDir(),
}
// 'actual' and 'expected' literals being compared.
- const Literal& expected_;
- const Literal& actual_;
+ LiteralSlice expected_;
+ LiteralSlice actual_;
// The error bounds of the comparison.
ErrorSpec error_;
// Helper function for comparing two literals for nearness. Handles tuple-shapes
// via recursion. shape_index is the ShapeIndex of expected (or actual)
// currently being compared.
-::testing::AssertionResult NearHelper(const Literal& expected,
- const Literal& actual,
+::testing::AssertionResult NearHelper(LiteralSlice expected,
+ LiteralSlice actual,
const ErrorSpec& error,
bool detailed_message,
const ShapeIndex& shape_index) {
if (ShapeUtil::IsTuple(expected.shape())) {
for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
- const auto expected_element = LiteralView::Create(expected, {i});
- const auto actual_element = LiteralView::Create(actual, {i});
+ const auto expected_element = LiteralSlice(expected, {i});
+ const auto actual_element = LiteralSlice(actual, {i});
ShapeIndex element_index = shape_index;
element_index.push_back(i);
::testing::AssertionResult res =
} // namespace
/* static */ ::testing::AssertionResult LiteralTestUtil::Near(
- const Literal& expected, const Literal& actual, const ErrorSpec& error,
+ LiteralSlice expected, LiteralSlice actual, const ErrorSpec& error,
bool detailed_message) {
return NearHelper(expected, actual, error, detailed_message,
/*shape_index=*/{});
}
-/* static */ void LiteralTestUtil::ExpectNear(const Literal& expected,
- const Literal& actual,
+/* static */ void LiteralTestUtil::ExpectNear(LiteralSlice expected,
+ LiteralSlice actual,
const ErrorSpec& error,
const string& message) {
::testing::AssertionResult res =
}
/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
- const Literal& expected, const Literal& actual,
+ LiteralSlice expected, LiteralSlice actual,
const tensorflow::gtl::optional<ErrorSpec>& error) {
if (error.has_value()) {
VLOG(1) << "Expects near";
}
/*static*/ void LiteralTestUtil::ExpectNearOrEqual(
- const Literal& expected, const Literal& actual,
+ LiteralSlice expected, LiteralSlice actual,
const tensorflow::gtl::optional<ErrorSpec>& error) {
EXPECT_TRUE(NearOrEqual(expected, actual, error));
}
/* static */ std::unique_ptr<Literal> LiteralTestUtil::Reshape(
tensorflow::gtl::ArraySlice<int64> new_dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major, const Literal& literal) {
+ tensorflow::gtl::ArraySlice<int64> minor_to_major, LiteralSlice literal) {
int64 new_num_elements = 1;
for (int64 i = 0; i < new_dimensions.size(); ++i) {
new_num_elements *= new_dimensions[i];
// If the given literal's data type is bfloat16, converts it to a float
// literal; otherwise, returns a copy of it. If the literal is a tuple,
// recursively converts its elements.
- static std::unique_ptr<Literal> ConvertBF16ToF32(const Literal& bf16_literal);
+ static std::unique_ptr<Literal> ConvertBF16ToF32(LiteralSlice bf16_literal);
// If the given literal's data type is float, converts it to a bfloat16
// literal; otherwise, returns a copy of it. If the literal is a tuple,
// recursively converts its elements.
- static std::unique_ptr<Literal> ConvertF32ToBF16(const Literal& f32_literal);
+ static std::unique_ptr<Literal> ConvertF32ToBF16(LiteralSlice f32_literal);
// Asserts that the expected and actual literals are (bitwise) equal for all
// elements in the literal. Also, asserts that the rank, dimensions sizes, and
// primitive type are equal.
static ::testing::AssertionResult Equal(
- const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT;
+ LiteralSlice expected, LiteralSlice actual) TF_MUST_USE_RESULT;
// Expects that expected and actual are Equal.
- static void ExpectEqual(const Literal& expected, const Literal& actual,
+ static void ExpectEqual(LiteralSlice expected, LiteralSlice actual,
const string& message = "");
// Expects that expected and actual are Not Equal.
- static void ExpectNotEqual(const Literal& expected, const Literal& actual);
+ static void ExpectNotEqual(LiteralSlice expected, LiteralSlice actual);
// Asserts the given literal are (bitwise) equal to given expected values.
template <typename NativeT>
- static void ExpectR0Equal(NativeT expected, const Literal& actual);
+ static void ExpectR0Equal(NativeT expected, LiteralSlice actual);
template <typename NativeT>
static void ExpectR1Equal(tensorflow::gtl::ArraySlice<NativeT> expected,
- const Literal& actual);
+ LiteralSlice actual);
template <typename NativeT>
static void ExpectR2Equal(
std::initializer_list<std::initializer_list<NativeT>> expected,
- const Literal& actual);
+ LiteralSlice actual);
template <typename NativeT>
static void ExpectR3Equal(
std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>
expected,
- const Literal& actual);
+ LiteralSlice actual);
// Asserts the given literal are (bitwise) equal to given array.
template <typename NativeT>
static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected,
- const Literal& actual);
+ LiteralSlice actual);
template <typename NativeT>
static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected,
- const Literal& actual);
+ LiteralSlice actual);
template <typename NativeT>
static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
- const Literal& actual);
+ LiteralSlice actual);
// Asserts that the expected and actual literals are within the given error
// bound for all elements. Also, asserts that the rank, dimensions sizes, and
// If detailed_message is true, then the error message in the assertion result
// will contain a more detailed breakdown of mismatches.
static ::testing::AssertionResult Near(
- const Literal& expected, const Literal& actual, const ErrorSpec& error,
+ LiteralSlice expected, LiteralSlice actual, const ErrorSpec& error,
bool detailed_message = false) TF_MUST_USE_RESULT;
// Expects expected and actual to be Near with the given error.
- static void ExpectNear(const Literal& expected, const Literal& actual,
+ static void ExpectNear(LiteralSlice expected, LiteralSlice actual,
const ErrorSpec& error, const string& message = "");
// Asserts the given literal are within the given error bound of the given
// expected values. Only supported for floating point values.
template <typename NativeT>
- static void ExpectR0Near(NativeT expected, const Literal& actual,
+ static void ExpectR0Near(NativeT expected, LiteralSlice actual,
const ErrorSpec& error);
template <typename NativeT>
static void ExpectR1Near(tensorflow::gtl::ArraySlice<NativeT> expected,
- const Literal& actual, const ErrorSpec& error);
+ LiteralSlice actual, const ErrorSpec& error);
template <typename NativeT>
static void ExpectR2Near(
std::initializer_list<std::initializer_list<NativeT>> expected,
- const Literal& actual, const ErrorSpec& error);
+ LiteralSlice actual, const ErrorSpec& error);
template <typename NativeT>
static void ExpectR3Near(
std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>
expected,
- const Literal& actual, const ErrorSpec& error);
+ LiteralSlice actual, const ErrorSpec& error);
template <typename NativeT>
static void ExpectR4Near(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
expected,
- const Literal& actual, const ErrorSpec& error);
+ LiteralSlice actual, const ErrorSpec& error);
// Asserts the given literal are within the given error bound to the given
// array. Only supported for floating point values.
template <typename NativeT>
static void ExpectR2NearArray2D(const Array2D<NativeT>& expected,
- const Literal& actual,
- const ErrorSpec& error);
+ LiteralSlice actual, const ErrorSpec& error);
template <typename NativeT>
static void ExpectR3NearArray3D(const Array3D<NativeT>& expected,
- const Literal& actual,
- const ErrorSpec& error);
+ LiteralSlice actual, const ErrorSpec& error);
template <typename NativeT>
static void ExpectR4NearArray4D(const Array4D<NativeT>& expected,
- const Literal& actual,
- const ErrorSpec& error);
+ LiteralSlice actual, const ErrorSpec& error);
// If the error spec is given, returns whether the expected and the actual are
// within the error bound; otherwise, returns whether they are equal. Tuples
// will be compared recursively.
static ::testing::AssertionResult NearOrEqual(
- const Literal& expected, const Literal& actual,
+ LiteralSlice expected, LiteralSlice actual,
const tensorflow::gtl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT;
// If the error spec is given, expects the expected and the actual to be near;
// otherwise, expects them to be equal. Tuples will be compared recursively.
static void ExpectNearOrEqual(
- const Literal& expected, const Literal& actual,
+ LiteralSlice expected, LiteralSlice actual,
const tensorflow::gtl::optional<ErrorSpec>& error);
// Returns a multi-dimensional index as a string. For example: '{7, 8}' will
// layout order.
static std::unique_ptr<Literal> Reshape(
tensorflow::gtl::ArraySlice<int64> new_dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major,
- const Literal& literal);
+ tensorflow::gtl::ArraySlice<int64> minor_to_major, LiteralSlice literal);
// Creates a literal with the supplied shape, and uses the provided value
// generator to populate the literal's values.
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
- const Literal& actual) {
+ LiteralSlice actual) {
ExpectEqual(*Literal::CreateR0<NativeT>(expected), actual);
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Equal(
- tensorflow::gtl::ArraySlice<NativeT> expected, const Literal& actual) {
+ tensorflow::gtl::ArraySlice<NativeT> expected, LiteralSlice actual) {
ExpectEqual(*Literal::CreateR1<NativeT>(expected), actual);
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Equal(
std::initializer_list<std::initializer_list<NativeT>> expected,
- const Literal& actual) {
+ LiteralSlice actual) {
ExpectEqual(*Literal::CreateR2<NativeT>(expected), actual);
}
/* static */ void LiteralTestUtil::ExpectR3Equal(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
- const Literal& actual) {
+ LiteralSlice actual) {
ExpectEqual(*Literal::CreateR3<NativeT>(expected), actual);
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
- const Array2D<NativeT>& expected, const Literal& actual) {
+ const Array2D<NativeT>& expected, LiteralSlice actual) {
ExpectEqual(*Literal::CreateR2FromArray2D(expected), actual);
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
- const Array3D<NativeT>& expected, const Literal& actual) {
+ const Array3D<NativeT>& expected, LiteralSlice actual) {
ExpectEqual(*Literal::CreateR3FromArray3D(expected), actual);
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
- const Array4D<NativeT>& expected, const Literal& actual) {
+ const Array4D<NativeT>& expected, LiteralSlice actual) {
ExpectEqual(*Literal::CreateR4FromArray4D(expected), actual);
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
- const Literal& actual,
+ LiteralSlice actual,
const ErrorSpec& error) {
ExpectNear(*Literal::CreateR0<NativeT>(expected), actual, error);
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Near(
- tensorflow::gtl::ArraySlice<NativeT> expected, const Literal& actual,
+ tensorflow::gtl::ArraySlice<NativeT> expected, LiteralSlice actual,
const ErrorSpec& error) {
ExpectNear(*Literal::CreateR1<NativeT>(expected), actual, error);
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Near(
std::initializer_list<std::initializer_list<NativeT>> expected,
- const Literal& actual, const ErrorSpec& error) {
+ LiteralSlice actual, const ErrorSpec& error) {
ExpectNear(*Literal::CreateR2<NativeT>(expected), actual, error);
}
/* static */ void LiteralTestUtil::ExpectR3Near(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
- const Literal& actual, const ErrorSpec& error) {
+ LiteralSlice actual, const ErrorSpec& error) {
ExpectNear(*Literal::CreateR3<NativeT>(expected), actual, error);
}
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
expected,
- const Literal& actual, const ErrorSpec& error) {
+ LiteralSlice actual, const ErrorSpec& error) {
ExpectNear(*Literal::CreateR4<NativeT>(expected), actual, error);
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2NearArray2D(
- const Array2D<NativeT>& expected, const Literal& actual,
+ const Array2D<NativeT>& expected, LiteralSlice actual,
const ErrorSpec& error) {
ExpectNear(*Literal::CreateR2FromArray2D(expected), actual, error);
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3NearArray3D(
- const Array3D<NativeT>& expected, const Literal& actual,
+ const Array3D<NativeT>& expected, LiteralSlice actual,
const ErrorSpec& error) {
ExpectNear(*Literal::CreateR3FromArray3D(expected), actual, error);
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4NearArray4D(
- const Array4D<NativeT>& expected, const Literal& actual,
+ const Array4D<NativeT>& expected, LiteralSlice actual,
const ErrorSpec& error) {
ExpectNear(*Literal::CreateR4FromArray4D(expected), actual, error);
}
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0}));
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>(
{{10.0f, 20.0f}, {30.0f, 40.0f}},
- LiteralView::Create(*result_literal, {1}));
+ LiteralSlice(*result_literal, {1}));
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {2}));
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {2}));
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1}));
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1}));
LiteralTestUtil::ExpectR2Equal<float>(
{{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralView::Create(*result_literal, {0, 0}));
+ LiteralSlice(*result_literal, {0, 0}));
LiteralTestUtil::ExpectR2Equal<float>(
{{10.0f, 20.0f}, {30.0f, 40.0f}},
- LiteralView::Create(*result_literal, {0, 1}));
+ LiteralSlice(*result_literal, {0, 1}));
LiteralTestUtil::ExpectR2Equal<float>(
{{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralView::Create(*result_literal, {0, 2}));
+ LiteralSlice(*result_literal, {0, 2}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0}));
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1}));
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>(
{{56.0f, 46.0f}, {36.0f, 26.0f}},
- LiteralView::Create(*result_literal, {0}));
+ LiteralSlice(*result_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>(
- {40.0f, 71.0f, 117.0f}, LiteralView::Create(*result_literal, {1}));
+ {40.0f, 71.0f, 117.0f}, LiteralSlice(*result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR2Equal<float>(
- {{-1.0, -2.0}, {-3.0, -4}}, LiteralView::Create(*result_literal, {0}));
+ {{-1.0, -2.0}, {-3.0, -4}}, LiteralSlice(*result_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>(
- {264.0, 73.0, 133.0}, LiteralView::Create(*result_literal, {1}));
+ {264.0, 73.0, 133.0}, LiteralSlice(*result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
std::unique_ptr<Literal> result_0_literal = ShapedBufferToLiteral(result_0);
LiteralTestUtil::ExpectR2Equal<float>(
{{-1.0, -2.0}, {-3.0, -4.0}},
- LiteralView::Create(*result_0_literal, {0}));
+ LiteralSlice(*result_0_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>(
- {{22.0, 6.0}, {8.0, 10}}, LiteralView::Create(*result_0_literal, {1}));
+ {{22.0, 6.0}, {8.0, 10}}, LiteralSlice(*result_0_literal, {1}));
ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0});
std::unique_ptr<Literal> result_1_literal = ShapedBufferToLiteral(result_1);
LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0, 2.0}, {3.0, 4.0}}, LiteralView::Create(*result_1_literal, {0}));
+ {{1.0, 2.0}, {3.0, 4.0}}, LiteralSlice(*result_1_literal, {0}));
LiteralTestUtil::ExpectR2Equal<float>(
- {{44.0, 12.0}, {16.0, 20}}, LiteralView::Create(*result_1_literal, {1}));
+ {{44.0, 12.0}, {16.0, 20}}, LiteralSlice(*result_1_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
for (int i = 0; i < kElementCount; ++i) {
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f * i, 0.0f}, LiteralView::Create(*result_literal, {i}),
+ {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}),
error_spec_);
}
}
for (int i = 0; i < kFanout; ++i) {
for (int j = 0; j < kFanout; ++j) {
LiteralTestUtil::ExpectR0Near<float>(
- i + j + i * kFanout + j, LiteralView::Create(*result_literal, {i, j}),
+ i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}),
error_spec_);
}
}
index.push_back(0);
}
LiteralTestUtil::ExpectR0Equal<float>(
- 165.0, LiteralView::Create(*result_literal, index));
+ 165.0, LiteralSlice(*result_literal, index));
}
XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
std::unique_ptr<Literal> tuple_literal = ShapedBufferToLiteral(result);
LiteralTestUtil::ExpectR1Equal<float>(
- {2.0f, 4.0f, 6.0f}, LiteralView::Create(*tuple_literal, {0}));
+ {2.0f, 4.0f, 6.0f}, LiteralSlice(*tuple_literal, {0}));
LiteralTestUtil::ExpectR1Equal<float>(
- {1.0f, 2.0f, 3.0f}, LiteralView::Create(*tuple_literal, {1}));
+ {1.0f, 2.0f, 3.0f}, LiteralSlice(*tuple_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {