xla::Literal lax;
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax));
std::vector<bool> revdims(x_shape.dims());
- std::copy(lax.preds().begin(), lax.preds().end(), revdims.begin());
+ std::copy(lax.data<bool>().begin(), lax.data<bool>().end(),
+ revdims.begin());
std::vector<int64> dimensions;
for (int d = 0; d < x_shape.dims(); ++d) {
xla::Literal literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal));
- int dim = literal.s32s(0);
+ int dim = literal.data<int32>()[0];
OP_REQUIRES(ctx,
(dim >= -1 - input_shape.dims() && dim <= input_shape.dims()),
OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal));
std::vector<int32> perm(dims);
- std::copy(literal.s32s().begin(), literal.s32s().end(), perm.begin());
+ std::copy(literal.data<int32>().begin(), literal.data<int32>().end(),
+ perm.begin());
std::vector<int64> transposed_order;
// Check whether the permutation is a permutation of the integers [0, dims).
namespace tensorflow {
Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) {
- literal->Clear();
+ xla::Shape literal_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
- host_tensor.dtype(), host_tensor.shape(), literal->mutable_shape()));
+ host_tensor.dtype(), host_tensor.shape(), &literal_shape));
- literal->Reserve(host_tensor.NumElements());
+ *literal = xla::Literal(literal_shape);
// memcpy over the payload ...
// TODO(phawkins): handle string types.
size_t total_bytes = host_tensor.TotalBytes();
if (total_bytes > 0) {
- void* dst_ptr = literal->MutableInternalData();
+ void* dst_ptr = literal->untyped_data();
const void* src_ptr = DMAHelper::base(&host_tensor);
memcpy(dst_ptr, src_ptr, total_bytes);
}
*host_tensor = Tensor(target_type, shape);
size_t total_bytes = host_tensor->TotalBytes();
if (total_bytes > 0) {
- const void* src_ptr = literal.InternalData();
+ const void* src_ptr = literal.untyped_data();
void* dst_ptr = DMAHelper::base(host_tensor);
memcpy(dst_ptr, src_ptr, total_bytes);
}
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
switch (type) {
case xla::U8:
- literal = *xla::Literal::CreateR0<uint8>(value);
+ literal = std::move(*xla::Literal::CreateR0<uint8>(value));
break;
case xla::U32:
- literal = *xla::Literal::CreateR0<uint32>(value);
+ literal = std::move(*xla::Literal::CreateR0<uint32>(value));
break;
case xla::U64:
- literal = *xla::Literal::CreateR0<uint64>(value);
+ literal = std::move(*xla::Literal::CreateR0<uint64>(value));
break;
case xla::S8:
- literal = *xla::Literal::CreateR0<int8>(value);
+ literal = std::move(*xla::Literal::CreateR0<int8>(value));
break;
case xla::S32:
- literal = *xla::Literal::CreateR0<int32>(value);
+ literal = std::move(*xla::Literal::CreateR0<int32>(value));
break;
case xla::S64:
- literal = *xla::Literal::CreateR0<int64>(value);
+ literal = std::move(*xla::Literal::CreateR0<int64>(value));
break;
case xla::F32:
- literal = *xla::Literal::CreateR0<float>(value);
+ literal = std::move(*xla::Literal::CreateR0<float>(value));
break;
case xla::F64:
- literal = *xla::Literal::CreateR0<double>(value);
+ literal = std::move(*xla::Literal::CreateR0<double>(value));
break;
case xla::C64:
- literal = *xla::Literal::CreateR0<complex64>(value);
+ literal = std::move(*xla::Literal::CreateR0<complex64>(value));
break;
case xla::PRED:
LOG(FATAL) << "pred element type is not integral";
case xla::U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case xla::BF16:
- literal = *xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value));
+ literal = std::move(
+ *xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
break;
case xla::F16:
- literal =
- *xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value));
+ literal = std::move(
+ *xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value)));
break;
case xla::TUPLE:
LOG(FATAL) << "tuple element type is not integral";
"elements.");
}
- *output = input;
- output->mutable_shape()->Swap(&shape);
+ *output = input.Clone();
+ output->mutable_shape_do_not_use()->Swap(&shape);
return Status::OK();
}
xla::Literal literal;
TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
switch (literal.shape().element_type()) {
- case xla::S32:
- out->Clear();
- *out->mutable_shape() = literal.shape();
- out->mutable_shape()->set_element_type(xla::S64);
- for (int32 x : literal.s32s()) {
- out->add_s64s(x);
+ case xla::S32: {
+ *out = xla::Literal(
+ xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64));
+ auto src_data = literal.data<int32>();
+ for (int64 i = 0; i < src_data.size(); ++i) {
+ out->data<int64>()[i] = src_data[i];
}
return Status::OK();
-
+ }
case xla::S64:
*out = std::move(literal);
return Status::OK();
":array2d",
":array3d",
":array4d",
+ ":shape_tree",
":shape_util",
":status_macros",
":types",
"server provided response without a literal in "
"TransferToClient request");
}
- return MakeUnique<Literal>(response.literal());
+ return Literal::CreateFromProto(*response.mutable_literal());
}
StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
"TransferToClient request");
}
- return MakeUnique<Literal>(response.literal());
+ return Literal::CreateFromProto(response.literal());
}
Status Client::ResetDevice() {
return true;
}
-ComputationDataHandle ComputationBuilder::ConstantOp(
- const PopulateLiteral& populate) {
+ComputationDataHandle ComputationBuilder::ConstantLiteral(
+ const Literal& literal) {
if (!first_error_.ok() || !PrepareComputation().ok()) {
return ComputationDataHandle();
}
ConstantRequest request;
- Literal literal;
- populate(&literal);
*request.mutable_literal() = literal.ToProto();
VLOG(3) << "created constant: " << request.literal().ShortDebugString();
OpRequest op_request;
return ParseOpResponse(s, &response);
}
-ComputationDataHandle ComputationBuilder::ConstantLiteral(
- const Literal& literal) {
- return ConstantOp(
- [literal](Literal* mutable_literal) { *mutable_literal = literal; });
-}
-
ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number,
const Shape& shape,
const string& name) {
"no computed literal in the provided response in ComputeConstant "
"request");
}
- return MakeUnique<Literal>(response.literal());
+ return Literal::CreateFromProto(response.literal());
}
ComputationDataHandle ComputationBuilder::Map(
Status first_error() const { return first_error_; }
private:
- using PopulateLiteral = std::function<void(Literal*)>;
-
// Limited checking of convolution parameters. Returns false on
// error.
bool VerifyConvolution(const Shape& lhs_shape, const Shape& rhs_shape,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
Window* window);
- // Internal helper method that makes a request for a constant operation -- the
- // provided function is used to populate the literal before sending the
- // request.
- ComputationDataHandle ConstantOp(const PopulateLiteral& populate);
-
// Internal helper method that does the building for an arbitrary unary op.
ComputationDataHandle UnaryOp(UnaryOperation binop,
const ComputationDataHandle& operand);
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) {
- return ConstantOp([value](Literal* literal) { literal->PopulateR0(value); });
+ return ConstantLiteral(*Literal::CreateR0<NativeT>(value));
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR1(
tensorflow::gtl::ArraySlice<NativeT> values) {
- return ConstantOp(
- [&values](Literal* literal) { literal->PopulateR1(values); });
+ return ConstantLiteral(*Literal::CreateR1<NativeT>(values));
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR1(int64 length,
NativeT value) {
- return ConstantOp([length, value](Literal* literal) {
- literal->PopulateWithValue(value, {length});
- });
+ Literal literal(ShapeUtil::MakeShape(
+ primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
+ literal.PopulateWithValue(value);
+ return ConstantLiteral(literal);
}
inline ComputationDataHandle ComputationBuilder::ConstantR1(
const tensorflow::core::Bitmap& values) {
- return ConstantOp(
- [&values](Literal* literal) { literal->PopulateR1(values); });
+ return ConstantLiteral(*Literal::CreateR1(values));
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantOp(
- [&values](Literal* literal) { literal->PopulateR2(values); });
+ return ConstantLiteral(*Literal::CreateR2<NativeT>(values));
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout) {
- return ConstantOp([&values, &layout](Literal* literal) {
- literal->PopulateFromArrayWithLayout(values, layout);
- });
+ return ConstantLiteral(
+ *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantFromArray(
const Array<NativeT>& values) {
- return ConstantOp(
- [&values](Literal* literal) { literal->PopulateFromArray(values); });
+ return ConstantLiteral(*Literal::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
- return ConstantFromArrayWithLayout(values, layout);
+ return ConstantLiteral(
+ *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D(
const Array2D<NativeT>& values) {
- return ConstantFromArray(values);
+ return ConstantLiteral(*Literal::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
- return ConstantFromArrayWithLayout(values, layout);
+ return ConstantLiteral(
+ *Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
+
+using tensorflow::strings::Printf;
+using tensorflow::strings::StrCat;
+
+namespace xla {
+
namespace {
-using tensorflow::int64;
constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
std::swap(bytes[i], bytes[i + 1]);
}
}
-} // namespace
-namespace xla {
+} // namespace
std::ostream& operator<<(std::ostream& out, const Literal& literal) {
out << literal.ToString();
return out;
}
}
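+// Constructs a literal of the given shape. If allocate_arrays is false, the
+// array pieces are created with null buffers, for callers that will supply
+// (e.g. move in) the buffers afterwards.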
+Literal::Literal(const Shape& shape)
+ : Literal(shape, /*allocate_arrays=*/true) {}
+
+Literal::Literal(const Shape& shape, bool allocate_arrays)
+ : shape_(shape), pieces_(shape), owns_buffers_(true) {
+ CHECK(LayoutUtil::HasLayout(shape));
+ for (auto& pair : pieces_) {
+ const ShapeIndex& index = pair.first;
+ Piece& piece = pair.second;
+
+ piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
+ if (ShapeUtil::IsArray(piece.subshape())) {
+ if (allocate_arrays) {
+ piece.set_buffer(new char[piece.size_bytes()]);
+ } else {
+ piece.set_buffer(nullptr);
+ }
+ }
+ }
+}
+
+Literal::~Literal() { DeallocateBuffers(); }
+
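+// Frees each piece's buffer, but only if this literal owns its buffers
+// (a view over another literal's buffers does not).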
+void Literal::DeallocateBuffers() {
+ if (owns_buffers_) {
+ for (auto& pair : pieces_) {
+ Piece& piece = pair.second;
+ if (piece.buffer() != nullptr) {
+ delete[] piece.buffer();
+ }
+ }
+ }
+}
+
+Literal::Literal(Literal&& other) {
+ shape_ = std::move(other.shape_);
+ pieces_ = std::move(other.pieces_);
+ // We need to iterate through the pieces to set the subshape pointer
+ // properly. It must refer to subshapes within shape_.
+ for (auto& pair : pieces_) {
+ const ShapeIndex& index = pair.first;
+ Piece& piece = pair.second;
+ piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
+ }
+ owns_buffers_ = other.owns_buffers_;
+
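+  // Leave the moved-from literal in a valid, nil-shaped state.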
+ other.shape_ = ShapeUtil::MakeNil();
+ other.pieces_ = ShapeTree<Piece>(other.shape_);
+ other.piece({}).set_subshape(&other.shape_);
+}
+
+Literal& Literal::operator=(Literal&& other) {
+ DeallocateBuffers();
+ shape_ = std::move(other.shape_);
+ pieces_ = std::move(other.pieces_);
+ // We need to iterate through the pieces to set the subshape pointer
+ // properly. It must refer to subshapes within shape_.
+ for (auto& pair : pieces_) {
+ const ShapeIndex& index = pair.first;
+ Piece& piece = pair.second;
+ piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
+ }
+ owns_buffers_ = other.owns_buffers_;
+
+ other.shape_ = ShapeUtil::MakeNil();
+ other.pieces_ = ShapeTree<Piece>(other.shape_);
+ other.piece({}).set_subshape(&other.shape_);
+ return *this;
+}
+
std::unique_ptr<Literal> Literal::CreateFromShape(const Shape& shape) {
- auto literal = MakeUnique<Literal>();
- *literal->mutable_shape() = shape;
- if (ShapeUtil::IsTuple(shape)) {
- int64 num_elements = ShapeUtil::TupleElementCount(shape);
- literal->tuple_literals_.resize(num_elements);
- for (int i = 0; i < num_elements; ++i) {
- std::unique_ptr<Literal> elem =
- CreateFromShape(ShapeUtil::GetTupleElementShape(shape, i));
- literal->tuple_literals_[i] = std::move(*elem);
+ auto literal = MakeUnique<Literal>(shape);
+ for (auto& pair : literal->pieces_) {
+ Piece& piece = pair.second;
+ if (ShapeUtil::IsArray(piece.subshape())) {
+ memset(piece.untyped_data(), 0, piece.size_bytes());
}
- } else {
- literal->Reserve(ShapeUtil::ElementsIn(literal->shape()));
}
return literal;
}
return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions));
}
-template <typename T>
-Status Literal::CopyRange(const Literal& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size) {
- const Shape& src_shape = src_literal.shape();
- const Shape& dest_shape = shape();
- tensorflow::gtl::ArraySlice<T> src_data = src_literal.GetArraySlice<T>();
- tensorflow::gtl::MutableArraySlice<T> dest_data = GetMutableArraySlice<T>();
-
- TF_RET_CHECK(ShapeUtil::Rank(src_shape) == src_base.size());
- TF_RET_CHECK(ShapeUtil::Rank(dest_shape) == dest_base.size());
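+// Copies a rectangular region of size 'copy_size' from 'src_literal',
+// starting at 'src_base', into this literal starting at 'dest_base'.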
+template <typename NativeT>
+Status Literal::CopySliceFromInternal(
+ const Literal& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
+ tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size) {
+ TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size());
+ TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size());
+
+ auto linear_index = [](const Shape& shape,
+ tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
+ };
- if (ShapeUtil::Rank(src_shape) == 0 || ShapeUtil::Rank(dest_shape) == 0) {
+ if (ShapeUtil::Rank(src_literal.shape()) == 0 ||
+ ShapeUtil::Rank(shape()) == 0) {
// If either of the two shapes is a scalar, we can just call StridedCopy()
// directly, and we know we will be copying only one value.
TF_RET_CHECK(copy_size.empty());
- StridedCopy(dest_data, LinearIndex(dest_base), 0, src_data,
- src_literal.LinearIndex(src_base), 0, 1);
- } else if (!ShapeUtil::HasZeroElements(dest_shape) &&
- !ShapeUtil::HasZeroElements(src_shape)) {
- // Perform copy if neither src literal nor dest literal has dimensions with
- // zero element, otherwise it's a no-op.
+ StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
+ src_literal.data<NativeT>(),
+ linear_index(src_literal.shape(), src_base), 0, 1);
+ } else if (!ShapeUtil::HasZeroElements(shape()) &&
+ !ShapeUtil::HasZeroElements(src_literal.shape())) {
+ // Perform the copy if neither src nor dest has a dimension with zero
+ // elements; otherwise it's a no-op.
TF_RET_CHECK(src_base.size() == dest_base.size());
TF_RET_CHECK(src_base.size() == copy_size.size());
// proper stride size at the matching dimension.
DimensionVector src_indexes(src_base.size(), 0);
DimensionVector dest_indexes(dest_base.size(), 0);
- StrideConfig stride_config(src_shape, dest_shape, copy_size);
+ Literal::StrideConfig stride_config(src_literal.shape(), shape(),
+ copy_size);
auto copy_proc = [&](const std::vector<int64>& indexes) {
// Map from the multi-dimensional index to the destination index.
std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
dest_indexes.begin(), std::plus<int64>());
- int64 src_index = src_literal.LinearIndex(src_indexes);
- int64 dest_index = LinearIndex(dest_indexes);
+ int64 src_index = linear_index(src_literal.shape(), src_indexes);
+ int64 dest_index = linear_index(shape(), dest_indexes);
- StridedCopy(dest_data, dest_index, stride_config.dest_stride, src_data,
- src_index, stride_config.source_stride,
- stride_config.minor_loop_size);
+ StridedCopy(data<NativeT>(), dest_index, stride_config.dest_stride,
+ src_literal.data<NativeT>(), src_index,
+ stride_config.source_stride, stride_config.minor_loop_size);
return true;
};
- ShapeUtil::ForEachIndex(src_shape, stride_config.base,
+ ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
stride_config.dimensions, stride_config.step,
copy_proc);
}
return Status::OK();
}
-Status Literal::Copy(const Literal& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size) {
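+// Breaks this tuple literal apart into one literal per tuple element,
+// transferring buffer ownership to the returned elements and leaving this
+// literal nil-shaped.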
+std::vector<Literal> Literal::DecomposeTuple() {
+ CHECK(ShapeUtil::IsTuple(shape()));
+ std::vector<Literal> elements;
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
+ elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
+ /*allocate_arrays=*/false));
+ Literal& element = elements.back();
+ for (auto& pair : element.pieces_) {
+ const ShapeIndex& index = pair.first;
+ Piece& dest_piece = pair.second;
+ ShapeIndex src_index = {i};
+ for (int64 j : index) {
+ src_index.push_back(j);
+ }
+ Piece& src_piece = piece(src_index);
+
+ // Move the respective buffer over to the element Literal.
+ dest_piece.set_buffer(src_piece.buffer());
+ src_piece.set_buffer(nullptr);
+ }
+ }
+ // Set this literal to be nil-shaped.
+ *this = Literal();
+ return elements;
+}
+
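+// Inverse of DecomposeTuple: assembles a tuple literal from the given
+// elements, moving each element's buffers rather than copying them.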
+/* static */ Literal Literal::MoveIntoTuple(
+ tensorflow::gtl::MutableArraySlice<Literal> elements) {
+ std::vector<Shape> element_shapes;
+ for (const Literal& element : elements) {
+ element_shapes.push_back(element.shape());
+ }
+ Literal literal(ShapeUtil::MakeTupleShape(element_shapes),
+ /*allocate_arrays=*/false);
+ for (int i = 0; i < elements.size(); ++i) {
+ TF_CHECK_OK(
+ literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
+ }
+ return literal;
+}
+
+namespace {
+
+// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
+// the array slices are indicated by dest_shape and src_shape respectively.
+template <typename NativeT>
+void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
+ tensorflow::gtl::ArraySlice<NativeT> src,
+ const Shape& dest_shape, const Shape& src_shape) {
+ CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
+ if (ShapeUtil::HasZeroElements(dest_shape)) {
+ return;
+ }
+ std::vector<int64> index(ShapeUtil::Rank(dest_shape));
+ do {
+ dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
+ src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
+ } while (IndexUtil::BumpIndices(dest_shape, &index));
+}
+
+} // namespace
+
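+// Copies the array data of 'src' into this piece. If the subshapes are equal
+// (including layout) this is a flat memcpy; otherwise the elements are copied
+// one at a time through their multidimensional indices.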
+Status Literal::Piece::CopyFrom(const Literal::Piece& src) {
+ if (ShapeUtil::Equal(subshape(), src.subshape())) {
+ // If the layouts are equal it's faster just to memcpy.
+ memcpy(buffer(), src.buffer(), src.size_bytes());
+ } else {
+ TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape()));
+ std::vector<int64> origin(ShapeUtil::Rank(subshape()), 0);
+ switch (subshape().element_type()) {
+#define COPY_ELEMENTS(XLA_T, NATIVE_T) \
+ case (XLA_T): \
+ CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
+ subshape(), src.subshape()); \
+ break;
+ COPY_ELEMENTS(U8, uint8);
+ COPY_ELEMENTS(U16, uint16);
+ COPY_ELEMENTS(U32, uint32);
+ COPY_ELEMENTS(U64, uint64);
+ COPY_ELEMENTS(S8, int8);
+ COPY_ELEMENTS(S16, int16);
+ COPY_ELEMENTS(S32, int32);
+ COPY_ELEMENTS(S64, int64);
+ COPY_ELEMENTS(F16, half);
+ COPY_ELEMENTS(BF16, bfloat16);
+ COPY_ELEMENTS(F32, float);
+ COPY_ELEMENTS(F64, double);
+ COPY_ELEMENTS(C64, complex64);
+ COPY_ELEMENTS(PRED, bool);
+#undef COPY_ELEMENTS
+ default:
+ return Unimplemented(
+ "Unhandled primitive type %s",
+ PrimitiveType_Name(subshape().element_type()).c_str());
+ }
+ }
+ return Status::OK();
+}
+
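+// Copies the subtree of 'src_literal' rooted at 'src_shape_index' into this
+// literal at 'dest_shape_index'. The two subshapes must be compatible.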
+Status Literal::CopyFrom(const Literal& src_literal,
+ const ShapeIndex& dest_shape_index,
+ const ShapeIndex& src_shape_index) {
+ const Shape& dest_subshape =
+ ShapeUtil::GetSubshape(shape(), dest_shape_index);
+ const Shape& src_subshape =
+ ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
+ if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
+ return InvalidArgument(
+ "Destination subshape incompatible with source subshape: %s vs %s",
+ ShapeUtil::HumanString(dest_subshape).c_str(),
+ ShapeUtil::HumanString(src_subshape).c_str());
+ }
+
+ for (auto& pair : pieces_) {
+ const ShapeIndex& index = pair.first;
+ Piece& piece = pair.second;
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ continue;
+ }
+
+ // Determine if this index is in the part of this literal that we want to
+ // copy over from src_literal.
+ bool in_subtree_to_copy = true;
+ for (int i = 0; i < dest_shape_index.size(); ++i) {
+ if (index[i] != dest_shape_index[i]) {
+ in_subtree_to_copy = false;
+ break;
+ }
+ }
+ if (!in_subtree_to_copy) {
+ continue;
+ }
+
+ // Construct the index of the corresponding piece in the source literal.
+ ShapeIndex src_piece_index = src_shape_index;
+ for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
+ src_piece_index.push_back(index[i]);
+ }
+
+ TF_RETURN_IF_ERROR(piece.CopyFrom(src_literal.piece(src_piece_index)));
+ }
+ return Status::OK();
+}
+
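+// Like CopyFrom, but steals the buffers of 'src_literal' rather than copying
+// them. Both literals must own their buffers; the source is left nil-shaped.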
+Status Literal::MoveFrom(Literal&& src_literal,
+ const ShapeIndex& dest_shape_index) {
+ const Shape& dest_subshape =
+ ShapeUtil::GetSubshape(shape(), dest_shape_index);
+ if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
+ return InvalidArgument(
+ "Destination subshape not equal to source shape: %s vs %s",
+ ShapeUtil::HumanString(dest_subshape).c_str(),
+ ShapeUtil::HumanString(src_literal.shape()).c_str());
+ }
+
+ if (!(owns_buffers_ && src_literal.owns_buffers_)) {
+ return InvalidArgument(
+ "Source and destination literals must both own their buffers (ie, not "
+ "be views)");
+ }
+
+ for (auto& pair : src_literal.pieces_) {
+ const ShapeIndex& src_index = pair.first;
+ Piece& src_piece = pair.second;
+ if (!ShapeUtil::IsArray(src_piece.subshape())) {
+ continue;
+ }
+
+ ShapeIndex dest_index = dest_shape_index;
+ for (int64 i : src_index) {
+ dest_index.push_back(i);
+ }
+ Piece& dest_piece = piece(dest_index);
+ delete[] dest_piece.buffer();
+ dest_piece.set_buffer(src_piece.buffer());
+ }
+
+ src_literal.shape_ = ShapeUtil::MakeNil();
+ src_literal.pieces_ = ShapeTree<Piece>(src_literal.shape_);
+ src_literal.piece({}).set_subshape(&src_literal.shape_);
+ return Status::OK();
+}
+
+Status Literal::CopySliceFrom(const Literal& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_base,
+ tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size) {
+ TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape());
+ TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape()))
+ << ShapeUtil::HumanString(src_literal.shape());
TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));
- switch (src_literal.shape().element_type()) {
+
+ switch (shape().element_type()) {
case U8:
- return CopyRange<uint8>(src_literal, src_base, dest_base, copy_size);
+ return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base,
+ copy_size);
case U16:
- return CopyRange<uint16>(src_literal, src_base, dest_base, copy_size);
+ return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base,
+ copy_size);
case U32:
- return CopyRange<uint32>(src_literal, src_base, dest_base, copy_size);
+ return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base,
+ copy_size);
case U64:
- return CopyRange<uint64>(src_literal, src_base, dest_base, copy_size);
+ return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base,
+ copy_size);
case S8:
- return CopyRange<int8>(src_literal, src_base, dest_base, copy_size);
+ return CopySliceFromInternal<int8>(src_literal, src_base, dest_base,
+ copy_size);
case S16:
- return CopyRange<int16>(src_literal, src_base, dest_base, copy_size);
+ return CopySliceFromInternal<int16>(src_literal, src_base, dest_base,
+ copy_size);
case S32:
- return CopyRange<int32>(src_literal, src_base, dest_base, copy_size);
+ return CopySliceFromInternal<int32>(src_literal, src_base, dest_base,
+ copy_size);
case S64:
- return CopyRange<int64>(src_literal, src_base, dest_base, copy_size);
+ return CopySliceFromInternal<int64>(src_literal, src_base, dest_base,
+ copy_size);
case F16:
- return CopyRange<half>(src_literal, src_base, dest_base, copy_size);
+ return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
+ copy_size);
case BF16:
- return CopyRange<bfloat16>(src_literal, src_base, dest_base, copy_size);
+ return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
+ copy_size);
case F32:
- return CopyRange<float>(src_literal, src_base, dest_base, copy_size);
+ return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
+ copy_size);
case F64:
- return CopyRange<double>(src_literal, src_base, dest_base, copy_size);
+ return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
+ copy_size);
case C64:
- return CopyRange<complex64>(src_literal, src_base, dest_base, copy_size);
+ return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
+ copy_size);
case PRED:
- return CopyRange<bool>(src_literal, src_base, dest_base, copy_size);
+ return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
+ copy_size);
default:
break;
}
- return Unimplemented("Unhandled primitive type %d",
- src_literal.shape().element_type());
+ return Unimplemented("Unhandled primitive type %d", shape().element_type());
}
/* static */ Literal Literal::Zero(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return *Literal::CreateR0<uint8>(0);
+ return std::move(*Literal::CreateR0<uint8>(0));
case U32:
- return *Literal::CreateR0<uint32>(0);
+ return std::move(*Literal::CreateR0<uint32>(0));
case U64:
- return *Literal::CreateR0<uint64>(0);
+ return std::move(*Literal::CreateR0<uint64>(0));
case S8:
- return *Literal::CreateR0<int8>(0);
+ return std::move(*Literal::CreateR0<int8>(0));
case S32:
- return *Literal::CreateR0<int32>(0);
+ return std::move(*Literal::CreateR0<int32>(0));
case S64:
- return *Literal::CreateR0<int64>(0);
+ return std::move(*Literal::CreateR0<int64>(0));
case F16:
- return *Literal::CreateR0<half>(static_cast<half>(0.0f));
+ return std::move(*Literal::CreateR0<half>(static_cast<half>(0.0f)));
case BF16:
- return *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
+ return std::move(
+ *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f)));
case F32:
- return *Literal::CreateR0<float>(0);
+ return std::move(*Literal::CreateR0<float>(0));
case F64:
- return *Literal::CreateR0<double>(0);
+ return std::move(*Literal::CreateR0<double>(0));
case C64:
- return *Literal::CreateR0<complex64>(0);
+ return std::move(*Literal::CreateR0<complex64>(0));
case PRED:
- return *Literal::CreateR0<bool>(false);
+ return std::move(*Literal::CreateR0<bool>(false));
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
/* static */ Literal Literal::One(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return *Literal::CreateR0<uint8>(1);
+ return std::move(*Literal::CreateR0<uint8>(1));
case U32:
- return *Literal::CreateR0<uint32>(1);
+ return std::move(*Literal::CreateR0<uint32>(1));
case U64:
- return *Literal::CreateR0<uint64>(1);
+ return std::move(*Literal::CreateR0<uint64>(1));
case S8:
- return *Literal::CreateR0<int8>(1);
+ return std::move(*Literal::CreateR0<int8>(1));
case S32:
- return *Literal::CreateR0<int32>(1);
+ return std::move(*Literal::CreateR0<int32>(1));
case S64:
- return *Literal::CreateR0<int64>(1);
+ return std::move(*Literal::CreateR0<int64>(1));
case F16:
- return *Literal::CreateR0<half>(static_cast<half>(1.0f));
+ return std::move(*Literal::CreateR0<half>(static_cast<half>(1.0f)));
case BF16:
- return *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f));
+ return std::move(
+ *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)));
case F32:
- return *Literal::CreateR0<float>(1);
+ return std::move(*Literal::CreateR0<float>(1));
case F64:
- return *Literal::CreateR0<double>(1);
+ return std::move(*Literal::CreateR0<double>(1));
case C64:
- return *Literal::CreateR0<complex64>(1);
+ return std::move(*Literal::CreateR0<complex64>(1));
case PRED:
- return *Literal::CreateR0<bool>(true);
+ return std::move(*Literal::CreateR0<bool>(true));
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
/* static */ Literal Literal::MinValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return *Literal::CreateR0<uint8>(std::numeric_limits<uint8>::min());
+ return std::move(
+ *Literal::CreateR0<uint8>(std::numeric_limits<uint8>::min()));
case U32:
- return *Literal::CreateR0<uint32>(std::numeric_limits<uint32>::min());
+ return std::move(
+ *Literal::CreateR0<uint32>(std::numeric_limits<uint32>::min()));
case U64:
- return *Literal::CreateR0<uint64>(std::numeric_limits<uint64>::min());
+ return std::move(
+ *Literal::CreateR0<uint64>(std::numeric_limits<uint64>::min()));
case S8:
- return *Literal::CreateR0<int8>(std::numeric_limits<int8>::min());
+ return std::move(
+ *Literal::CreateR0<int8>(std::numeric_limits<int8>::min()));
case S32:
- return *Literal::CreateR0<int32>(std::numeric_limits<int32>::min());
+ return std::move(
+ *Literal::CreateR0<int32>(std::numeric_limits<int32>::min()));
case S64:
- return *Literal::CreateR0<int64>(std::numeric_limits<int64>::min());
+ return std::move(
+ *Literal::CreateR0<int64>(std::numeric_limits<int64>::min()));
case F32:
- return *Literal::CreateR0<float>(-std::numeric_limits<float>::infinity());
+ return std::move(
+ *Literal::CreateR0<float>(-std::numeric_limits<float>::infinity()));
case F64:
- return *Literal::CreateR0<double>(
- -std::numeric_limits<double>::infinity());
+ return std::move(
+ *Literal::CreateR0<double>(-std::numeric_limits<double>::infinity()));
case C64:
LOG(FATAL) << "C64 element type has no minimum value";
case PRED:
- return *Literal::CreateR0<bool>(false);
+ return std::move(*Literal::CreateR0<bool>(false));
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
- return *Literal::CreateR0<half>(
- static_cast<half>(-std::numeric_limits<float>::infinity()));
+ return std::move(*Literal::CreateR0<half>(
+ static_cast<half>(-std::numeric_limits<float>::infinity())));
case BF16:
- return *Literal::CreateR0<bfloat16>(
- static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
+ return std::move(*Literal::CreateR0<bfloat16>(
+ static_cast<bfloat16>(-std::numeric_limits<float>::infinity())));
case TUPLE:
LOG(FATAL) << "tuple element type has no minimum value";
case OPAQUE:
/* static */ Literal Literal::MaxValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return *Literal::CreateR0<uint8>(std::numeric_limits<uint8>::max());
+ return std::move(
+ *Literal::CreateR0<uint8>(std::numeric_limits<uint8>::max()));
case U32:
- return *Literal::CreateR0<uint32>(std::numeric_limits<uint32>::max());
+ return std::move(
+ *Literal::CreateR0<uint32>(std::numeric_limits<uint32>::max()));
case U64:
- return *Literal::CreateR0<uint64>(std::numeric_limits<uint64>::max());
+ return std::move(
+ *Literal::CreateR0<uint64>(std::numeric_limits<uint64>::max()));
case S8:
- return *Literal::CreateR0<int8>(std::numeric_limits<int8>::max());
+ return std::move(
+ *Literal::CreateR0<int8>(std::numeric_limits<int8>::max()));
case S32:
- return *Literal::CreateR0<int32>(std::numeric_limits<int32>::max());
+ return std::move(
+ *Literal::CreateR0<int32>(std::numeric_limits<int32>::max()));
case S64:
- return *Literal::CreateR0<int64>(std::numeric_limits<int64>::max());
+ return std::move(
+ *Literal::CreateR0<int64>(std::numeric_limits<int64>::max()));
case F32:
- return *Literal::CreateR0<float>(std::numeric_limits<float>::infinity());
+ return std::move(
+ *Literal::CreateR0<float>(std::numeric_limits<float>::infinity()));
case F64:
- return *Literal::CreateR0<double>(
- std::numeric_limits<double>::infinity());
+ return std::move(
+ *Literal::CreateR0<double>(std::numeric_limits<double>::infinity()));
case PRED:
- return *Literal::CreateR0<bool>(true);
+ return std::move(*Literal::CreateR0<bool>(true));
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
- return *Literal::CreateR0<half>(
- static_cast<half>(std::numeric_limits<float>::infinity()));
+ return std::move(*Literal::CreateR0<half>(
+ static_cast<half>(std::numeric_limits<float>::infinity())));
case BF16:
- return *Literal::CreateR0<bfloat16>(
- static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
+ return std::move(*Literal::CreateR0<bfloat16>(
+ static_cast<bfloat16>(std::numeric_limits<float>::infinity())));
case TUPLE:
LOG(FATAL) << "tuple element type has no maximum value";
case OPAQUE:
/* static */ std::unique_ptr<Literal> Literal::CreateR1(
const tensorflow::core::Bitmap& values) {
- auto literal = MakeUnique<Literal>();
+ auto literal = MakeUnique<Literal>(
+ ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
literal->PopulateR1(values);
return literal;
}
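+// Fills this rank-1 PRED literal from the given bitmap, one element per bit.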
+void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(ShapeUtil::Rank(shape()), 1);
+ CHECK_EQ(element_count(), values.bits());
+ CHECK_EQ(shape().element_type(), PRED);
+ for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
+ Set({i}, values.get(i));
+ }
+}
+
/* static */ std::unique_ptr<Literal> Literal::CreateR1U8(
tensorflow::StringPiece value) {
- auto literal = MakeUnique<Literal>();
- *literal->mutable_shape() =
- ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())});
- literal->set_u8s(tensorflow::StringPiece(value.ToString()));
+ auto literal = MakeUnique<Literal>(
+ ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
+ for (int i = 0; i < value.size(); ++i) {
+ literal->Set<uint8>({i}, value[i]);
+ }
return literal;
}
std::unique_ptr<Literal> Literal::Relayout(
const Layout& new_layout, const ShapeIndex& shape_index) const {
- std::unique_ptr<Literal> outer_result = CloneToUnique();
-
- const Literal* copy_from = this;
- Literal* copy_to = outer_result.get();
- for (int64 i = 0; i < shape_index.size(); i++) {
- *ShapeUtil::GetMutableSubshape(copy_to->mutable_shape(), {shape_index, i})
- ->mutable_layout() = new_layout;
- copy_from = ©_from->tuple_literals_[shape_index[i]];
- copy_to = ©_to->tuple_literals_[shape_index[i]];
- }
-
- DimensionVector base(ShapeUtil::Rank(copy_from->shape()), 0);
- DimensionVector copy_size(copy_from->shape().dimensions().begin(),
- copy_from->shape().dimensions().end());
-
- CHECK(ShapeUtil::IsArray(copy_from->shape()));
- CHECK(ShapeUtil::IsArray(copy_to->shape()));
- *copy_to->mutable_shape()->mutable_layout() = new_layout;
- TF_CHECK_OK(copy_to->Copy(*copy_from, base, base, copy_size));
- return outer_result;
+ // Create new shape with 'new_layout' set at the given shape index.
+ Shape new_shape = shape();
+ Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
+ TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
+ *subshape->mutable_layout() = new_layout;
+ auto result = MakeUnique<Literal>(new_shape);
+ TF_CHECK_OK(result->CopyFrom(*this));
+ return result;
}
std::unique_ptr<Literal> Literal::Relayout(
result->shape(),
[this, &result](const Shape& subshape, const ShapeIndex& index) {
if (ShapeUtil::IsArray(subshape)) {
- DimensionVector base(ShapeUtil::Rank(subshape), 0);
- DimensionVector copy_size(subshape.dimensions().begin(),
- subshape.dimensions().end());
- TF_CHECK_OK(result->GetSubliteral(index).Copy(GetSubliteral(index),
- base, base, copy_size));
+ TF_CHECK_OK(result->CopyFrom(*this,
+ /*dest_shape_index=*/index,
+ /*src_shape_index=*/index));
}
});
return result;
StatusOr<std::unique_ptr<Literal>> Literal::Reshape(
tensorflow::gtl::ArraySlice<int64> dimensions) const {
- if (ShapeUtil::IsTuple(shape())) {
+ if (!ShapeUtil::IsArray(shape())) {
return InvalidArgument("Reshape does not support tuples.");
}
std::unique_ptr<Literal> output;
}
// Because the layout is monotonic, we can simply reuse the same sequence of
// values without changing their order.
- *output->mutable_shape() =
- ShapeUtil::MakeShape(shape().element_type(), dimensions);
+ output->shape_ = ShapeUtil::MakeShape(shape().element_type(), dimensions);
int64 elements_before = ShapeUtil::ElementsIn(shape());
int64 elements_after = ShapeUtil::ElementsIn(output->shape());
std::unique_ptr<Literal> Literal::Transpose(
tensorflow::gtl::ArraySlice<int64> permutation) const {
- CHECK(!ShapeUtil::IsTuple(shape())) << "Tuple is not supported for transpose";
+ CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
<< "Given permutation is not a permutation of dimension numbers";
// To transpose the array, we just permute the dimensions and layout, and
std::unique_ptr<Literal> new_literal = CreateFromShape(permuted_shape);
DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()),
ShapeUtil::ByteSizeOf(shape()));
- std::memcpy(new_literal->MutableInternalData(), InternalData(),
- ShapeUtil::ByteSizeOf(shape()));
+ std::memcpy(new_literal->root_piece().buffer(), root_piece().buffer(),
+ root_piece().size_bytes());
return new_literal;
}
std::unique_ptr<Literal> Literal::Slice(
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices) const {
- CHECK(!ShapeUtil::IsTuple(shape())) << "tuple is not supported for reshape";
+ CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
DimensionVector result_dimensions;
for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) {
ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
LayoutUtil::MinorToMajor(shape()));
- auto result_literal = MakeUnique<Literal>();
- *result_literal->mutable_shape() = result_shape;
- result_literal->Reserve(ShapeUtil::ElementsIn(result_shape));
+ auto result_literal = MakeUnique<Literal>(result_shape);
DimensionVector new_indices(ShapeUtil::Rank(result_shape));
switch (result_shape.element_type()) {
}
}
+Literal Literal::Clone() const {
+ Literal result(shape());
+ TF_CHECK_OK(result.CopyFrom(*this));
+ return result;
+}
+
std::unique_ptr<Literal> Literal::CloneToUnique() const {
- auto unique = MakeUnique<Literal>();
- *unique = *this;
- return unique;
+ auto result = MakeUnique<Literal>(shape());
+ TF_CHECK_OK(result->CopyFrom(*this));
+ return result;
}
-string Literal::GetAsString(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
- switch (shape().element_type()) {
+string Literal::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index) const {
+ const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
+ switch (subshape.element_type()) {
case PRED:
- return Get<bool>(multi_index) ? "true" : "false";
+ return Get<bool>(multi_index, shape_index) ? "true" : "false";
case U8:
- return tensorflow::strings::StrCat(Get<uint8>(multi_index));
+ return StrCat(Get<uint8>(multi_index, shape_index));
case S32:
- return tensorflow::strings::StrCat(Get<int32>(multi_index));
+ return StrCat(Get<int32>(multi_index, shape_index));
case S64:
- return tensorflow::strings::StrCat(Get<int64>(multi_index));
+ return StrCat(Get<int64>(multi_index, shape_index));
case U32:
- return tensorflow::strings::StrCat(Get<uint32>(multi_index));
+ return StrCat(Get<uint32>(multi_index, shape_index));
case U64:
- return tensorflow::strings::StrCat(Get<uint64>(multi_index));
+ return StrCat(Get<uint64>(multi_index, shape_index));
case F32:
- return tensorflow::strings::StrCat(Get<float>(multi_index));
+ return StrCat(Get<float>(multi_index, shape_index));
case F64:
- return tensorflow::strings::StrCat(Get<double>(multi_index));
+ return StrCat(Get<double>(multi_index, shape_index));
case C64: {
- complex64 c = Get<complex64>(multi_index);
- return tensorflow::strings::StrCat("(", c.real(), ", ", c.imag(), ")");
+ complex64 c = Get<complex64>(multi_index, shape_index);
+ return StrCat("(", c.real(), ", ", c.imag(), ")");
}
case F16:
- return tensorflow::strings::StrCat(Get<half>(multi_index));
+ return StrCat(Get<half>(multi_index, shape_index));
case BF16:
- return tensorflow::strings::StrCat(
- static_cast<float>(Get<bfloat16>(multi_index)));
+ return StrCat(
+ static_cast<float>(Get<bfloat16>(multi_index, shape_index)));
default:
- return tensorflow::strings::StrCat(
- "[", PrimitiveType_Name(shape().element_type()), "]");
+ return StrCat("[", PrimitiveType_Name(subshape.element_type()), "]");
}
}
}
}
-int64 Literal::LinearIndex(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
- return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index);
-}
+namespace {
-string Literal::ToString(bool print_layout) const {
- std::vector<string> pieces;
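+// Recursively appends a textual rendering of the subliteral of 'literal' at
+// 'shape_index' to 'pieces'.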
+void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index,
+ bool print_layout, std::vector<string>* pieces) {
+ const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
auto shape_to_string = [print_layout](const Shape& shape) {
if (print_layout) {
};
auto element_to_string =
- [this](tensorflow::gtl::ArraySlice<int64> indices) -> string {
- PrimitiveType element_type = shape().element_type();
+ [&](tensorflow::gtl::ArraySlice<int64> indices) -> string {
+ PrimitiveType element_type = subshape.element_type();
if (element_type == PRED) {
// We display predicates in a densely packed form.
- return Get<bool>(indices) ? "1" : "0";
+ return literal.Get<bool>(indices, shape_index) ? "1" : "0";
}
return ((!indices.empty() && indices.back() > 0) ? ", " : "") +
- GetAsString(indices);
+ literal.GetAsString(indices, shape_index);
};
// TODO(b/32894291): refactor this code to reduce code duplication.
- if (ShapeUtil::IsTuple(shape())) {
- pieces.push_back(shape_to_string(shape()));
- pieces.push_back(" (\n");
- pieces.push_back(tensorflow::str_util::Join(
- tuple_literals(), ",\n", [](string* out, const Literal& element) {
- tensorflow::strings::StrAppend(out, element.ToString());
- }));
- pieces.push_back("\n)");
- } else if (ShapeUtil::Rank(shape()) == 0) {
- pieces.push_back(GetAsString({}));
- } else if (ShapeUtil::Rank(shape()) == 1) {
- pieces.push_back("{");
- for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
- pieces.push_back(element_to_string({i0}));
+ if (ShapeUtil::IsTuple(subshape)) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" (\n");
+ std::vector<string> tuple_pieces;
+ for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
+ ShapeIndex element_index = shape_index;
+ element_index.push_back(i);
+ std::vector<string> element_pieces;
+ ToStringHelper(literal, element_index, print_layout, &element_pieces);
+ tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, ""));
+ }
+ pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n"));
+ pieces->push_back("\n)");
+ } else if (ShapeUtil::Rank(subshape) == 0) {
+ pieces->push_back(literal.GetAsString({}, shape_index));
+ } else if (ShapeUtil::Rank(subshape) == 1) {
+ pieces->push_back("{");
+ for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+ pieces->push_back(element_to_string({i0}));
}
- pieces.push_back("}");
- } else if (ShapeUtil::Rank(shape()) == 2) {
- pieces.push_back(shape_to_string(shape()));
- pieces.push_back(" {\n");
- for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
- pieces.push_back(" { ");
- for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) {
- pieces.push_back(element_to_string({i0, i1}));
+ pieces->push_back("}");
+ } else if (ShapeUtil::Rank(subshape) == 2) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" {\n");
+ for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+ pieces->push_back(" { ");
+ for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
+ pieces->push_back(element_to_string({i0, i1}));
}
- pieces.push_back(" ");
- pieces.push_back(i0 == shape().dimensions(0) - 1 ? "}\n" : "},\n");
+ pieces->push_back(" ");
+ pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n");
}
- pieces.push_back("}");
- } else if (ShapeUtil::Rank(shape()) == 3) {
- pieces.push_back(shape_to_string(shape()));
- pieces.push_back(" {\n");
- for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
- pieces.push_back(i0 > 0 ? ",\n{" : "{");
- for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) {
- pieces.push_back(i1 > 0 ? ",\n { " : " { ");
- for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) {
- pieces.push_back(element_to_string({i0, i1, i2}));
+ pieces->push_back("}");
+ } else if (ShapeUtil::Rank(subshape) == 3) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" {\n");
+ for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+ pieces->push_back(i0 > 0 ? ",\n{" : "{");
+ for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
+ pieces->push_back(i1 > 0 ? ",\n { " : " { ");
+ for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
+ pieces->push_back(element_to_string({i0, i1, i2}));
}
- pieces.push_back(" }");
+ pieces->push_back(" }");
}
- pieces.push_back(" }");
+ pieces->push_back(" }");
}
- pieces.push_back("\n}");
- } else if (ShapeUtil::Rank(shape()) == 4) {
- pieces.push_back(shape_to_string(shape()));
- pieces.push_back(" {\n");
- for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
- pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
- for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) {
- pieces.push_back(
- tensorflow::strings::Printf(" { /*i1=%lld*/\n", i1));
- for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) {
- pieces.push_back(" {");
- for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) {
- pieces.push_back(element_to_string({i0, i1, i2, i3}));
+ pieces->push_back("\n}");
+ } else if (ShapeUtil::Rank(subshape) == 4) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" {\n");
+ for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+ pieces->push_back(Printf(" { /*i0=%lld*/\n", i0));
+ for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
+ pieces->push_back(Printf(" { /*i1=%lld*/\n", i1));
+ for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
+ pieces->push_back(" {");
+ for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
+ pieces->push_back(element_to_string({i0, i1, i2, i3}));
}
- pieces.push_back(i2 == shape().dimensions(2) - 1 ? "}\n" : "},\n");
+ pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n");
}
- pieces.push_back(i1 == shape().dimensions(1) - 1 ? " }\n"
- : " },\n");
+ pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n"
+ : " },\n");
}
- pieces.push_back(i0 == shape().dimensions(0) - 1 ? " }\n" : " },\n");
+ pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n");
}
- pieces.push_back("}");
- } else if (ShapeUtil::Rank(shape()) == 5) {
- pieces.push_back(shape_to_string(shape()));
- pieces.push_back(" {\n");
- for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
- pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0));
- for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) {
- pieces.push_back(
- tensorflow::strings::Printf(" { /*i1=%lld*/\n", i1));
- for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) {
- pieces.push_back(
- tensorflow::strings::Printf(" { /*i2=%lld*/\n", i2));
- for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) {
- pieces.push_back(" {");
- for (int64 i4 = 0; i4 < shape().dimensions(4); ++i4) {
- pieces.push_back(element_to_string({i0, i1, i2, i3, i4}));
+ pieces->push_back("}");
+ } else if (ShapeUtil::Rank(subshape) == 5) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" {\n");
+ for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+ pieces->push_back(Printf(" { /*i0=%lld*/\n", i0));
+ for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
+ pieces->push_back(Printf(" { /*i1=%lld*/\n", i1));
+ for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
+ pieces->push_back(Printf(" { /*i2=%lld*/\n", i2));
+ for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
+ pieces->push_back(" {");
+ for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) {
+ pieces->push_back(element_to_string({i0, i1, i2, i3, i4}));
}
- pieces.push_back(i3 == shape().dimensions(3) - 1 ? "}\n" : "},\n");
+ pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n"
+ : "},\n");
}
- pieces.push_back(i2 == shape().dimensions(2) - 1 ? " }\n"
- : " },\n");
+ pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n"
+ : " },\n");
}
- pieces.push_back(i1 == shape().dimensions(1) - 1 ? " }\n"
- : " },\n");
+ pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n"
+ : " },\n");
}
- pieces.push_back(i0 == shape().dimensions(0) - 1 ? " }\n" : " },\n");
+ pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n");
}
- pieces.push_back("}");
+ pieces->push_back("}");
} else {
- pieces.push_back(shape_to_string(shape()));
- pieces.push_back(" {");
- EachCellAsString(
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" {");
+ literal.EachCellAsString(
[&](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
- pieces.push_back(" ");
- pieces.push_back(value);
+ pieces->push_back(" ");
+ pieces->push_back(value);
});
- pieces.push_back("}");
+ pieces->push_back("}");
}
+}
+} // namespace
+
+string Literal::ToString(bool print_layout) const {
+ std::vector<string> pieces;
+ ToStringHelper(*this, {}, print_layout, &pieces);
return tensorflow::str_util::Join(pieces, "");
}
/* static */ std::unique_ptr<Literal> Literal::MakeTuple(
tensorflow::gtl::ArraySlice<const Literal*> elements) {
- auto literal = MakeUnique<Literal>();
- std::vector<Shape> shape;
- for (const Literal* tuple_element : elements) {
- *literal->add_tuple_literals() = *tuple_element;
- shape.push_back(tuple_element->shape());
+ std::vector<Shape> element_shapes;
+ for (const Literal* element : elements) {
+ element_shapes.push_back(element->shape());
+ }
+ auto literal = MakeUnique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+ for (int i = 0; i < elements.size(); ++i) {
+ TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
}
- *literal->mutable_shape() = ShapeUtil::MakeTupleShape(shape);
return literal;
}
/* static */ std::unique_ptr<Literal> Literal::MakeTupleOwned(
std::vector<std::unique_ptr<Literal>> elements) {
- auto literal = MakeUnique<Literal>();
- std::vector<Shape> shape;
- for (auto& tuple_element : elements) {
- shape.push_back(tuple_element->shape());
- *literal->add_tuple_literals() = std::move(*tuple_element);
- }
- *literal->mutable_shape() = ShapeUtil::MakeTupleShape(shape);
- return literal;
-}
-
-const void* Literal::InternalData() const {
- return const_cast<const void*>(
- const_cast<Literal*>(this)->MutableInternalData());
-}
-
-void* Literal::MutableInternalData() {
- // NOTE: We access the vectors directly to avoid the const reference
- // created by the accessor functions.
- switch (shape().element_type()) {
- case PRED:
- case U8:
- return reinterpret_cast<void*>(u8s_.data());
- case S32:
- return reinterpret_cast<void*>(s32s_.data());
- case S64:
- return reinterpret_cast<void*>(s64s_.data());
- case U32:
- return reinterpret_cast<void*>(u32s_.data());
- case U64:
- return reinterpret_cast<void*>(u64s_.data());
- case F32:
- return reinterpret_cast<void*>(f32s_.data());
- case F64:
- return reinterpret_cast<void*>(f64s_.data());
- case C64:
- return reinterpret_cast<void*>(c64s_.data());
- case F16:
- return reinterpret_cast<void*>(f16s_.data());
- case BF16:
- return reinterpret_cast<void*>(bf16s_.data());
- default:
- LOG(FATAL) << "primitive type not supported in literals: "
- << PrimitiveType_Name(shape().element_type());
+ std::vector<const Literal*> element_ptrs;
+ for (const auto& element : elements) {
+ element_ptrs.push_back(element.get());
}
-}
-
-void Literal::Reserve(int64 num_elements) {
- CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
- switch (shape().element_type()) {
- case PRED:
- Resize<bool>(num_elements, false);
- break;
- case S8:
- Resize<int8>(num_elements, 0);
- break;
- case U8:
- Resize<uint8>(num_elements, 0);
- break;
- case S32:
- Resize<int32>(num_elements, 0);
- break;
- case S64:
- Resize<int64>(num_elements, 0);
- break;
- case U32:
- Resize<uint32>(num_elements, 0);
- break;
- case U64:
- Resize<uint64>(num_elements, 0);
- break;
- case F32:
- Resize<float>(num_elements, 0);
- break;
- case F64:
- Resize<double>(num_elements, 0);
- break;
- case C64:
- Resize<complex64>(num_elements, 0);
- break;
- case F16:
- Resize<half>(num_elements, static_cast<half>(0.0f));
- break;
- case BF16:
- Resize<bfloat16>(num_elements, static_cast<bfloat16>(0.0f));
- break;
- default:
- LOG(FATAL) << "primitive type not supported in literals: "
- << PrimitiveType_Name(shape().element_type());
- }
-}
-
-tensorflow::Status Literal::ValidateLiteral() const {
- TF_CHECK_OK(ShapeUtil::ValidateShape(shape()));
- int64 expected = ShapeUtil::ElementsIn(shape());
- int64 actual = -1;
- switch (shape().element_type()) {
- case PRED:
- case U8:
- actual = u8s_size();
- break;
- case S32:
- actual = s32s_size();
- break;
- case U32:
- actual = u32s_size();
- break;
- case S64:
- actual = s64s_size();
- break;
- case U64:
- actual = u64s_size();
- break;
- case F32:
- actual = f32s_size();
- break;
- case F64:
- actual = f64s_size();
- break;
- case C64:
- actual = c64s_size();
- break;
- case F16:
- actual = f16s().size() / sizeof(half);
- break;
- case BF16:
- actual = bf16s().size();
- break;
- default:
- return tensorflow::errors::Unimplemented(
- "unhandled element type for literal validation: " +
- PrimitiveType_Name(shape().element_type()));
- }
-
- if (expected != actual) {
- return tensorflow::errors::InvalidArgument(tensorflow::strings::Printf(
- "literal has bad number of elements for its shape %s: want %lld "
- "got %lld",
- ShapeUtil::HumanString(shape()).c_str(), expected, actual));
- }
-
- return tensorflow::Status::OK();
+ return MakeTuple(element_ptrs);
}
void Literal::EachCellAsString(
namespace {
template <typename NativeSrcT, typename NativeDestT>
std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
- auto result_literal = MakeUnique<Literal>();
- Shape* result_shape = result_literal->mutable_shape();
- *result_shape = src_literal.shape();
- result_shape->set_element_type(
- primitive_util::NativeToPrimitiveType<NativeDestT>());
- result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape));
- tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
- src_literal.GetArraySlice<NativeSrcT>();
- tensorflow::gtl::MutableArraySlice<NativeDestT> dest_data =
- result_literal->GetMutableArraySlice<NativeDestT>();
- int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape());
+ CHECK(ShapeUtil::IsArray(src_literal.shape()));
+ auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType(
+ src_literal.shape(),
+ primitive_util::NativeToPrimitiveType<NativeDestT>()));
+ auto src_data = src_literal.data<NativeSrcT>();
+ auto dest_data = result_literal->template data<NativeDestT>();
+ int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
dest_data[i] = static_cast<NativeDestT>(src_data[i]);
template <PrimitiveType primitive_src_type>
std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) {
- auto result_literal = MakeUnique<Literal>();
- Shape* result_shape = result_literal->mutable_shape();
- *result_shape = src_literal.shape();
- result_shape->set_element_type(C64);
- result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape));
+ CHECK(ShapeUtil::IsArray(src_literal.shape()));
+ auto result_literal = MakeUnique<Literal>(
+ ShapeUtil::ChangeElementType(src_literal.shape(), C64));
using NativeSrcT =
typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
- src_literal.GetArraySlice<NativeSrcT>();
+ src_literal.data<NativeSrcT>();
tensorflow::gtl::MutableArraySlice<complex64> dest_data =
- result_literal->GetMutableArraySlice<complex64>();
- int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape());
+ result_literal->data<complex64>();
+ int64 num_elements = src_literal.element_count();
for (int64 i = 0; i < num_elements; ++i) {
dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
}
PrimitiveType_Name(primitive_dest_type).c_str());
}
}
+
} // namespace
StatusOr<std::unique_ptr<Literal>> Literal::Convert(
PrimitiveType primitive_dest_type) const {
+ TF_RET_CHECK(ShapeUtil::IsArray(shape()));
switch (shape().element_type()) {
#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
case (type): \
}
}
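As a usage sketch of the array-only Convert path above (CreateR1, StatusOr, and TF_CHECK_OK are the usual helpers in this codebase, not part of this change; assumes xla/literal_util.h):

  // Convert an S32 array literal to F32; tuple shapes are now rejected by
  // the TF_RET_CHECK at the top of Convert.
  std::unique_ptr<xla::Literal> s32 = xla::Literal::CreateR1<xla::int32>({1, 2, 3});
  auto f32_or = s32->Convert(xla::F32);
  TF_CHECK_OK(f32_or.status());
  std::unique_ptr<xla::Literal> f32 = f32_or.ConsumeValueOrDie();
  CHECK_EQ(f32->data<float>()[2], 3.0f);  // values preserved, type changed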
-namespace {
-
-// Helper function which compares whether the elements of literal1 are equal to
-// the elements of literal2. Recursively iterates through the entire
-// multidimensional index space and compares the literal elements
-// one-by-one. literal1 and literal2 must be compatible (same dimensions and
-// type).
template <typename NativeT>
-bool EqualElements(const Literal& literal1, const Literal& literal2,
- int dimension, std::vector<int64>* multi_index) {
- if (dimension == ShapeUtil::Rank(literal1.shape())) {
- return (literal1.Get<NativeT>(*multi_index) ==
- literal2.Get<NativeT>(*multi_index));
+bool Literal::Piece::EqualElementsInternal(
+ const Literal::Piece& other, std::vector<int64>* multi_index) const {
+ if (multi_index->size() == ShapeUtil::Rank(subshape())) {
+ return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
}
- for (int64 i = 0; i < literal1.shape().dimensions(dimension); ++i) {
- (*multi_index)[dimension] = i;
- if (!EqualElements<NativeT>(literal1, literal2, dimension + 1,
- multi_index)) {
+ for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) {
+ multi_index->push_back(i);
+ if (!EqualElementsInternal<NativeT>(other, multi_index)) {
return false;
}
+ multi_index->pop_back();
}
return true;
}
-} // namespace
+bool Literal::Piece::EqualElements(const Literal::Piece& other) const {
+ DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
+
+ std::vector<int64> multi_index;
+ switch (subshape().element_type()) {
+ case PRED:
+ return EqualElementsInternal<bool>(other, &multi_index);
+ case U8:
+ return EqualElementsInternal<uint8>(other, &multi_index);
+ case S32:
+ return EqualElementsInternal<int32>(other, &multi_index);
+ case S64:
+ return EqualElementsInternal<int64>(other, &multi_index);
+ case U32:
+ return EqualElementsInternal<uint32>(other, &multi_index);
+ case U64:
+ return EqualElementsInternal<uint64>(other, &multi_index);
+ case F32:
+ return EqualElementsInternal<float>(other, &multi_index);
+ case F64:
+ return EqualElementsInternal<double>(other, &multi_index);
+ case F16:
+ return EqualElementsInternal<half>(other, &multi_index);
+ case BF16:
+ return EqualElementsInternal<bfloat16>(other, &multi_index);
+ case C64:
+ return EqualElementsInternal<complex64>(other, &multi_index);
+ default:
+ LOG(FATAL) << "Unimplemented: Literal::Piece::EqualElements for type "
+ << PrimitiveType_Name(subshape().element_type());
+ }
+}
bool Literal::operator==(const Literal& other) const {
if (!ShapeUtil::Compatible(shape(), other.shape())) {
return false;
}
- if (ShapeUtil::IsTuple(shape())) {
- // Because the shapes are compatible, they must have the same number of
- // tuple elements.
- CHECK_EQ(tuple_literals_size(), other.tuple_literals_size());
- for (int i = 0; i < tuple_literals_size(); ++i) {
- if (tuple_literals(i) != other.tuple_literals(i)) {
- return false;
- }
+ for (const auto& pair : pieces_) {
+ const ShapeIndex& index = pair.first;
+ const Piece& piece = pair.second;
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ continue;
}
- return true;
- } else {
- std::vector<int64> multi_index(ShapeUtil::Rank(shape()), 0);
- switch (shape().element_type()) {
- case PRED:
- return EqualElements<bool>(*this, other, 0, &multi_index);
- case U8:
- return EqualElements<uint8>(*this, other, 0, &multi_index);
- case S32:
- return EqualElements<int32>(*this, other, 0, &multi_index);
- case S64:
- return EqualElements<int64>(*this, other, 0, &multi_index);
- case U32:
- return EqualElements<uint32>(*this, other, 0, &multi_index);
- case U64:
- return EqualElements<uint64>(*this, other, 0, &multi_index);
- case F32:
- return EqualElements<float>(*this, other, 0, &multi_index);
- case F64:
- return EqualElements<double>(*this, other, 0, &multi_index);
- case F16:
- return EqualElements<half>(*this, other, 0, &multi_index);
- case BF16:
- return EqualElements<bfloat16>(*this, other, 0, &multi_index);
- case C64:
- return EqualElements<complex64>(*this, other, 0, &multi_index);
- default:
- LOG(FATAL) << "Unimplemented: Literal::Equal for type "
- << PrimitiveType_Name(shape().element_type());
+
+ const Piece& other_piece = other.piece(index);
+ if (!piece.EqualElements(other_piece)) {
+ return false;
}
}
+ return true;
}
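A sketch of what the piece-based equality gives: shapes must be compatible and element values equal, while layout is never consulted (PopulateR1 and Set are declared in the header portion of this change):

  xla::Literal a(xla::ShapeUtil::MakeShape(xla::F32, {2}));
  xla::Literal b(xla::ShapeUtil::MakeShape(xla::F32, {2}));
  a.PopulateR1<float>({1.0f, 2.0f});
  b.PopulateR1<float>({1.0f, 2.0f});
  CHECK(a == b);
  b.Set<float>({1}, 3.0f);  // change a single element
  CHECK(a != b);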
-template <>
-tensorflow::gtl::MutableArraySlice<bool> Literal::GetMutableArraySlice() {
- auto values = mutable_preds();
- return tensorflow::gtl::MutableArraySlice<bool>(
- reinterpret_cast<bool*>(values->data()), values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<int8> Literal::GetMutableArraySlice() {
- auto values = mutable_u8s();
- return tensorflow::gtl::MutableArraySlice<int8>(
- reinterpret_cast<int8*>(values->data()), values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<uint8> Literal::GetMutableArraySlice() {
- auto values = mutable_u8s();
- return tensorflow::gtl::MutableArraySlice<uint8>(values->data(),
- values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<int16> Literal::GetMutableArraySlice() {
- auto values = mutable_s16s();
- return tensorflow::gtl::MutableArraySlice<int16>(values->data(),
- values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<uint16> Literal::GetMutableArraySlice() {
- auto values = mutable_u16s();
- return tensorflow::gtl::MutableArraySlice<uint16>(values->data(),
- values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<int32> Literal::GetMutableArraySlice() {
- auto values = mutable_s32s();
- return tensorflow::gtl::MutableArraySlice<int32>(values->data(),
- values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<uint32> Literal::GetMutableArraySlice() {
- auto values = mutable_u32s();
- return tensorflow::gtl::MutableArraySlice<uint32>(values->data(),
- values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<int64> Literal::GetMutableArraySlice() {
- static_assert(sizeof(int64) == sizeof(tensorflow::protobuf_int64) &&
- alignof(int64) == alignof(tensorflow::protobuf_int64),
- "The int64 and tensorflow::protobuf_int64 types are not "
- "compatible");
- auto values = mutable_s64s();
- // Because of the fact that tensorflow::protobuf_int64 is defined as int64_t
- // while tensorflow::int64 is defined as long long, a reinterpret_cast<> is
- // necessary from the raw data pointer returned by the mutable_data() API.
- return tensorflow::gtl::MutableArraySlice<int64>(
- reinterpret_cast<int64*>(values->data()), values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<uint64> Literal::GetMutableArraySlice() {
- static_assert(sizeof(uint64) == sizeof(tensorflow::protobuf_uint64) &&
- alignof(uint64) == alignof(tensorflow::protobuf_uint64),
- "The uint64 and tensorflow::protobuf_uint64 types are not "
- "compatible");
- auto values = mutable_u64s();
- // Because of the fact that tensorflow::protobuf_uint64 is defined as uint64_t
- // while tensorflow::uint64 is defined as unsigned long long, a
- // reinterpret_cast<> is necessary from the raw data pointer returned by the
- // mutable_data() API.
- return tensorflow::gtl::MutableArraySlice<uint64>(
- reinterpret_cast<uint64*>(values->data()), values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<float> Literal::GetMutableArraySlice() {
- auto values = mutable_f32s();
- return tensorflow::gtl::MutableArraySlice<float>(values->data(),
- values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<double> Literal::GetMutableArraySlice() {
- auto values = mutable_f64s();
- return tensorflow::gtl::MutableArraySlice<double>(values->data(),
- values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice() {
- auto values = mutable_c64s();
- return {values->data(), values->size()};
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice<half>() {
- auto values = mutable_f16s();
- return tensorflow::gtl::MutableArraySlice<half>(values->data(),
- values->size());
-}
-
-template <>
-tensorflow::gtl::MutableArraySlice<bfloat16>
-Literal::GetMutableArraySlice<bfloat16>() {
- auto values = mutable_bf16s();
- return {values->data(), values->size()};
-}
-
-template <>
-tensorflow::gtl::ArraySlice<bool> Literal::GetArraySlice<bool>() const {
- CHECK_EQ(shape().element_type(), PRED) << ShapeUtil::HumanString(shape());
- return tensorflow::gtl::ArraySlice<bool>(
- reinterpret_cast<const bool*>(preds().data()), preds().size());
-}
-
-template <>
-tensorflow::gtl::ArraySlice<uint8> Literal::GetArraySlice<uint8>() const {
- CHECK_EQ(shape().element_type(), U8) << ShapeUtil::HumanString(shape());
- return tensorflow::gtl::ArraySlice<uint8>(
- reinterpret_cast<const uint8*>(u8s().data()), u8s().size());
-}
-
-template <>
-tensorflow::gtl::ArraySlice<int8> Literal::GetArraySlice<int8>() const {
- CHECK_EQ(shape().element_type(), S8) << ShapeUtil::HumanString(shape());
- return tensorflow::gtl::ArraySlice<int8>(
- reinterpret_cast<const int8*>(u8s().data()), u8s().size());
-}
-
-template <>
-tensorflow::gtl::ArraySlice<uint16> Literal::GetArraySlice<uint16>() const {
- CHECK_EQ(shape().element_type(), U16) << ShapeUtil::HumanString(shape());
- return tensorflow::gtl::ArraySlice<uint16>(u16s().data(), u16s().size());
-}
-
-template <>
-tensorflow::gtl::ArraySlice<int16> Literal::GetArraySlice<int16>() const {
- CHECK_EQ(shape().element_type(), S16) << ShapeUtil::HumanString(shape());
- return tensorflow::gtl::ArraySlice<int16>(s16s().data(), s16s().size());
-}
-
-template <>
-tensorflow::gtl::ArraySlice<uint32> Literal::GetArraySlice<uint32>() const {
- CHECK_EQ(shape().element_type(), U32) << ShapeUtil::HumanString(shape());
- return u32s();
-}
-
-template <>
-tensorflow::gtl::ArraySlice<uint64> Literal::GetArraySlice<uint64>() const {
- CHECK_EQ(shape().element_type(), U64) << ShapeUtil::HumanString(shape());
- return u64s();
-}
-
-template <>
-tensorflow::gtl::ArraySlice<int32> Literal::GetArraySlice<int32>() const {
- CHECK_EQ(shape().element_type(), S32) << ShapeUtil::HumanString(shape());
- return s32s();
-}
-
-template <>
-tensorflow::gtl::ArraySlice<int64> Literal::GetArraySlice<int64>() const {
- CHECK_EQ(shape().element_type(), S64) << ShapeUtil::HumanString(shape());
- return s64s();
-}
-
-template <>
-tensorflow::gtl::ArraySlice<double> Literal::GetArraySlice<double>() const {
- CHECK_EQ(shape().element_type(), F64) << ShapeUtil::HumanString(shape());
- return f64s();
-}
-
-template <>
-tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const {
- CHECK_EQ(shape().element_type(), F16) << ShapeUtil::HumanString(shape());
- return tensorflow::gtl::ArraySlice<half>(f16s().data(),
- f16s().size() / sizeof(half));
-}
-
-template <>
-tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const {
- CHECK_EQ(shape().element_type(), BF16) << ShapeUtil::HumanString(shape());
- return {bf16s().data(), bf16s().size()};
-}
-
-template <>
-tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
- const {
- CHECK_EQ(shape().element_type(), C64) << ShapeUtil::HumanString(shape());
- return c64s();
-}
+namespace {
template <typename NativeT>
-static bool AllElementsEqualValue(const Literal& literal, NativeT value) {
- for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
- auto multi_index =
- IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
- if (literal.Get<NativeT>(multi_index) != value) {
+static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
+ NativeT value) {
+ for (int64 i = 0; i < data.size(); ++i) {
+ if (data[i] != value) {
return false;
}
}
return true;
}
+} // namespace
+
bool Literal::IsAll(int8 value) const {
- switch (shape().element_type()) {
- case U8:
- if (value >= 0) {
- return AllElementsEqualValue<uint8>(*this, value);
- }
- return false;
- case U32:
- if (value >= 0) {
- return AllElementsEqualValue<uint32>(*this, value);
- }
- return false;
- case U64:
- if (value >= 0) {
- return AllElementsEqualValue<uint64>(*this, value);
- }
- return false;
- case S8:
- return AllElementsEqualValue<int8>(*this, value);
- case S32:
- return AllElementsEqualValue<int32>(*this, value);
- case S64:
- return AllElementsEqualValue<int64>(*this, value);
- case F32:
- return AllElementsEqualValue<float>(*this, value);
- case F64:
- return AllElementsEqualValue<double>(*this, value);
- case F16:
- return AllElementsEqualValue<half>(*this, static_cast<half>(value));
- case BF16:
- return AllElementsEqualValue<bfloat16>(*this,
- static_cast<bfloat16>(value));
- case PRED:
- if (value == 0) {
- return AllElementsEqualValue<bool>(*this, false);
- }
- if (value == 1) {
- return AllElementsEqualValue<bool>(*this, true);
+ for (const auto& pair : pieces_) {
+ const Piece& piece = pair.second;
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ continue;
+ }
+
+ auto piece_is_all = [&]() {
+ switch (shape().element_type()) {
+ case U8:
+ if (value >= 0) {
+ return AllElementsEqualValue<uint8>(piece.data<uint8>(), value);
+ }
+ return false;
+ case U32:
+ if (value >= 0) {
+ return AllElementsEqualValue<uint32>(piece.data<uint32>(), value);
+ }
+ return false;
+ case U64:
+ if (value >= 0) {
+ return AllElementsEqualValue<uint64>(piece.data<uint64>(), value);
+ }
+ return false;
+ case S8:
+ return AllElementsEqualValue<int8>(piece.data<int8>(), value);
+ case S32:
+ return AllElementsEqualValue<int32>(piece.data<int32>(), value);
+ case S64:
+ return AllElementsEqualValue<int64>(piece.data<int64>(), value);
+ case F32:
+ return AllElementsEqualValue<float>(piece.data<float>(), value);
+ case F64:
+ return AllElementsEqualValue<double>(piece.data<double>(), value);
+ case F16:
+ return AllElementsEqualValue<half>(piece.data<half>(),
+ static_cast<half>(value));
+ case BF16:
+ return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
+ static_cast<bfloat16>(value));
+ case PRED:
+ if (value == 0) {
+ return AllElementsEqualValue<bool>(piece.data<bool>(), false);
+ }
+ if (value == 1) {
+ return AllElementsEqualValue<bool>(piece.data<bool>(), true);
+ }
+ return false;
+ default:
+ return false;
}
return false;
- default:
+ };
+
+ if (!piece_is_all()) {
return false;
+ }
}
+ return true;
}
bool Literal::IsAllFloat(float value) const {
- switch (shape().element_type()) {
- case F32:
- return AllElementsEqualValue<float>(*this, value);
- case F64:
- return AllElementsEqualValue<double>(*this, value);
- case F16:
- return AllElementsEqualValue<half>(*this, static_cast<half>(value));
- case BF16:
- return AllElementsEqualValue<bfloat16>(*this,
- static_cast<bfloat16>(value));
- default:
+ for (const auto& pair : pieces_) {
+ const Piece& piece = pair.second;
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ continue;
+ }
+
+ auto piece_is_all = [&]() {
+ switch (shape().element_type()) {
+ case F32:
+ return AllElementsEqualValue<float>(piece.data<float>(), value);
+ case F64:
+ return AllElementsEqualValue<double>(piece.data<double>(), value);
+ case F16:
+ return AllElementsEqualValue<half>(piece.data<half>(),
+ static_cast<half>(value));
+ case BF16:
+ return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
+ static_cast<bfloat16>(value));
+ default:
+ return false;
+ }
+ };
+ if (!piece_is_all()) {
return false;
+ }
}
+ return true;
}
bool Literal::IsAllComplex(complex64 value) const {
switch (shape().element_type()) {
case C64:
- return AllElementsEqualValue<complex64>(*this, value);
+ return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
+ value);
default:
return false;
}
}
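A quick sketch of these predicates on array literals, using the CreateR[0-2] factories documented later in this change:

  auto ones = xla::Literal::CreateR2<xla::int32>({{1, 1}, {1, 1}});
  CHECK(ones->IsAll(1));
  CHECK(!ones->IsAll(0));
  auto halves = xla::Literal::CreateR1<float>({0.5f, 0.5f});
  CHECK(halves->IsAllFloat(0.5f));
  CHECK(!halves->IsAllFloat(0.25f));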
bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
+ CHECK(ShapeUtil::IsArray(shape()));
switch (shape().element_type()) {
case U8:
return Get<uint8>(indices) == 0;
}
}
-template <>
-/* static */ void Literal::Resize<bool>(int64 num_elements, bool value) {
- CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
- mutable_preds()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<int8>(int64 num_elements, int8 value) {
- CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
- mutable_u8s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<uint8>(int64 num_elements, uint8 value) {
- CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
- mutable_u8s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<int32>(int64 num_elements, int32 value) {
- CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
- mutable_s32s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<uint32>(int64 num_elements, uint32 value) {
- CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
- mutable_u32s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<int64>(int64 num_elements, int64 value) {
- CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
- mutable_s64s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<uint64>(int64 num_elements, uint64 value) {
- CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
- mutable_u64s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<float>(int64 num_elements, float value) {
- CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
- mutable_f32s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<double>(int64 num_elements, double value) {
- CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
- mutable_f64s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<half>(int64 num_elements, half value) {
- CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
- mutable_f16s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value) {
- CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
- mutable_bf16s()->resize(num_elements, value);
-}
-
-template <>
-void Literal::Resize<complex64>(int64 num_elements, complex64 value) {
- CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
- mutable_c64s()->resize(num_elements, value);
-}
+namespace {
template <typename RepeatedFieldT, typename NativeT>
void CopyToRepeatedField(RepeatedFieldT* dest,
- const std::vector<NativeT>& src) {
+ const tensorflow::gtl::ArraySlice<NativeT> src) {
*dest = RepeatedFieldT(src.begin(), src.end());
}
-template <>
-void CopyToRepeatedField<tensorflow::protobuf::RepeatedField<float>, complex64>(
- tensorflow::protobuf::RepeatedField<float>* dest,
- const std::vector<complex64>& src) {
- *dest = tensorflow::protobuf::RepeatedField<float>(
- reinterpret_cast<const float*>(src.data()),
- reinterpret_cast<const float*>(src.data()) + src.size() * 2);
-}
+} // namespace
-LiteralProto Literal::ToProto() const {
- LiteralProto proto;
- proto.Clear();
- *proto.mutable_shape() = shape();
- switch (shape().element_type()) {
+void Literal::Piece::WriteToProto(LiteralProto* proto) const {
+ *proto->mutable_shape() = subshape();
+ switch (subshape().element_type()) {
case PRED:
- CopyToRepeatedField(proto.mutable_preds(), preds());
+ CopyToRepeatedField(proto->mutable_preds(), data<bool>());
break;
case U8:
- *proto.mutable_u8s() = u8s_string();
- break;
- case S32:
- CopyToRepeatedField(proto.mutable_s32s(), s32s());
- break;
- case S64:
- CopyToRepeatedField(proto.mutable_s64s(), s64s());
+ proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
+ element_count());
break;
case U32:
- CopyToRepeatedField(proto.mutable_u32s(), u32s());
+ CopyToRepeatedField(proto->mutable_u32s(), data<uint32>());
break;
case U64:
- CopyToRepeatedField(proto.mutable_u64s(), u64s());
+ CopyToRepeatedField(proto->mutable_u64s(), data<uint64>());
+ break;
+ case S32:
+ CopyToRepeatedField(proto->mutable_s32s(), data<int32>());
+ break;
+ case S64:
+ CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
break;
case F16:
- *proto.mutable_f16s() =
- string(reinterpret_cast<const char*>(f16s_.data()),
- f16s_.size() * sizeof(half));
+ *proto->mutable_f16s() = string(
+ reinterpret_cast<const char*>(data<half>().data()), size_bytes());
if (!kLittleEndian) {
- ConvertEndianShort(const_cast<char*>(proto.mutable_f16s()->data()),
- proto.f16s().size());
+ ConvertEndianShort(const_cast<char*>(proto->mutable_f16s()->data()),
+ proto->f16s().size());
}
break;
case BF16:
- *proto.mutable_bf16s() =
- string(reinterpret_cast<const char*>(bf16s_.data()),
- bf16s_.size() * sizeof(bfloat16));
+ *proto->mutable_bf16s() = string(
+ reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
if (!kLittleEndian) {
- ConvertEndianShort(const_cast<char*>(proto.mutable_bf16s()->data()),
- proto.bf16s().size());
+ ConvertEndianShort(const_cast<char*>(proto->mutable_bf16s()->data()),
+ proto->bf16s().size());
}
break;
case F32:
- CopyToRepeatedField(proto.mutable_f32s(), f32s());
+ CopyToRepeatedField(proto->mutable_f32s(), data<float>());
break;
case F64:
- CopyToRepeatedField(proto.mutable_f64s(), f64s());
+ CopyToRepeatedField(proto->mutable_f64s(), data<double>());
break;
case C64:
- CopyToRepeatedField(proto.mutable_c64s(), c64s());
- break;
- case TUPLE:
- for (const auto& tuple : tuple_literals()) {
- *proto.add_tuple_literals() = tuple.ToProto();
+ for (complex64 value : data<complex64>()) {
+ proto->add_c64s(value.real());
+ proto->add_c64s(value.imag());
}
break;
+ case TUPLE:
+      // Nothing to do but assign the shape, which is done above.
+ return;
default:
- LOG(FATAL) << "Unhandled primitive type " << shape().element_type();
+ LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
}
-
- return proto;
}
-template <typename RepeatedFieldT, typename NativeT>
-void CopyFromRepeatedField(std::vector<NativeT>* dest,
- const RepeatedFieldT& src) {
- *dest = std::vector<NativeT>(src.begin(), src.end());
+const void* Literal::Piece::untyped_data() const {
+ CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ return buffer();
}
-template <>
-void CopyFromRepeatedField<tensorflow::protobuf::RepeatedField<float>,
- complex64>(
- std::vector<complex64>* dest,
- const tensorflow::protobuf::RepeatedField<float>& src) {
- *dest = std::vector<complex64>(
- reinterpret_cast<const complex64*>(src.data()),
- reinterpret_cast<const complex64*>(src.data()) + src.size() / 2);
+void* Literal::Piece::untyped_data() {
+ CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ return buffer();
}
-void Literal::CopyFromProto(const LiteralProto& literal_proto) {
- if (!literal_proto.has_shape()) {
- return;
+namespace {
+
+template <typename RepeatedFieldT, typename NativeT>
+Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
+ const RepeatedFieldT& src) {
+ if (dest.size() != src.size()) {
+ return InvalidArgument(
+ "Expected %lu elements in LiteralProto repeated field, has %d",
+ dest.size(), src.size());
}
+ std::copy(src.begin(), src.end(), dest.begin());
+ return Status::OK();
+}
- *mutable_shape() = literal_proto.shape();
- switch (shape().element_type()) {
+} // namespace
+
+Status Literal::Piece::CopyFromProto(const LiteralProto& proto) {
+ // These conditions should have been checked in Literal::CreateFromProto.
+ TF_RET_CHECK(proto.has_shape());
+ TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
+ TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));
+
+ switch (subshape().element_type()) {
case PRED:
- CopyFromRepeatedField(mutable_preds(), literal_proto.preds());
- break;
- case U8:
- set_u8s(literal_proto.u8s());
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
break;
+ case U8: {
+ auto u8_data = data<uint8>();
+ TF_RET_CHECK(proto.u8s().size() == u8_data.size());
+ std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
+ } break;
case S32:
- CopyFromRepeatedField(mutable_s32s(), literal_proto.s32s());
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
break;
case S64:
- CopyFromRepeatedField(mutable_s64s(), literal_proto.s64s());
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
break;
case U32:
- CopyFromRepeatedField(mutable_u32s(), literal_proto.u32s());
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
break;
case U64:
- CopyFromRepeatedField(mutable_u64s(), literal_proto.u64s());
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
break;
case F16: {
- const string& s(literal_proto.f16s());
- CHECK_EQ(0, s.size() % sizeof(half));
- f16s_ = std::vector<half>(s.size() / sizeof(half));
- memcpy(f16s_.data(), s.data(), s.size());
-
+ const string& s(proto.f16s());
+ TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
+ memcpy(untyped_data(), s.data(), s.size());
if (!kLittleEndian) {
- ConvertEndianShort(reinterpret_cast<char*>(f16s_.data()), s.size());
+ ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
}
- break;
- }
- case BF16: {
- const string& s(literal_proto.bf16s());
- CHECK_EQ(0, s.size() % sizeof(bfloat16));
- bf16s_ = std::vector<bfloat16>(s.size() / sizeof(bfloat16));
- memcpy(bf16s_.data(), s.data(), s.size());
+ } break;
+ case BF16: {
+ const string& s(proto.bf16s());
+ TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
+ memcpy(untyped_data(), s.data(), s.size());
if (!kLittleEndian) {
- ConvertEndianShort(reinterpret_cast<char*>(bf16s_.data()), s.size());
+ ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
}
- break;
- }
+ } break;
case F32:
- CopyFromRepeatedField(mutable_f32s(), literal_proto.f32s());
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
break;
case F64:
- CopyFromRepeatedField(mutable_f64s(), literal_proto.f64s());
- break;
- case C64:
- CopyFromRepeatedField(mutable_c64s(), literal_proto.c64s());
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
break;
- case TUPLE:
- for (const auto& proto : literal_proto.tuple_literals()) {
- mutable_tuple_literals()->push_back(Literal(proto));
+ case C64: {
+ auto complex_data = data<complex64>();
+ TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
+ for (int64 i = 0; i < complex_data.size(); ++i) {
+ complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
}
+ } break;
+ case TUPLE:
+ LOG(FATAL) << "Should not be called on tuple shapes: "
+ << ShapeUtil::HumanString(subshape());
break;
default:
- LOG(FATAL) << "Unhandled primitive type " << shape().element_type();
+ LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
+ }
+ return Status::OK();
+}
+
+LiteralProto Literal::ToProto() const {
+ LiteralProto proto;
+ for (const auto& pair : pieces_) {
+ const ShapeIndex& index = pair.first;
+ const Piece& piece = pair.second;
+
+ LiteralProto* proto_piece = &proto;
+ for (int64 i : index) {
+ while (proto_piece->tuple_literals_size() <= i) {
+ proto_piece->add_tuple_literals();
+ }
+ proto_piece = proto_piece->mutable_tuple_literals(i);
+ }
+ piece.WriteToProto(proto_piece);
+ }
+ return proto;
+}
+
+/* static */
+StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
+ const LiteralProto& proto) {
+ if (!proto.has_shape()) {
+ return InvalidArgument("LiteralProto has no shape");
}
+ if (!LayoutUtil::HasLayout(proto.shape())) {
+ return InvalidArgument("LiteralProto has no layout");
+ }
+
+ auto literal = MakeUnique<Literal>(proto.shape());
+
+ for (auto& pair : literal->pieces_) {
+ const ShapeIndex& index = pair.first;
+ Piece& piece = pair.second;
+ const LiteralProto* proto_element = &proto;
+ for (int64 i : index) {
+ TF_RET_CHECK(i < proto_element->tuple_literals_size());
+ proto_element = &proto_element->tuple_literals(i);
+ }
+
+ if (ShapeUtil::IsTuple(piece.subshape())) {
+ if (proto_element->tuple_literals_size() !=
+ ShapeUtil::TupleElementCount(piece.subshape())) {
+ return InvalidArgument(
+ "Expected %lld tuple elements in LiteralProto, has %d",
+ ShapeUtil::TupleElementCount(piece.subshape()),
+ proto_element->tuple_literals_size());
+ }
+ continue;
+ }
+
+ TF_RET_CHECK(ShapeUtil::IsArray(piece.subshape()));
+ TF_RETURN_IF_ERROR(piece.CopyFromProto(*proto_element));
+ }
+ return std::move(literal);
+}
+
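A round-trip sketch through the new serialization path; ToProto writes each piece into the nested tuple_literals structure, and CreateFromProto validates shape and layout before copying buffers back:

  auto original = xla::Literal::CreateR1<float>({1.0f, 2.0f, 3.0f});
  xla::LiteralProto proto = original->ToProto();
  std::unique_ptr<xla::Literal> restored =
      xla::Literal::CreateFromProto(proto).ConsumeValueOrDie();
  CHECK(*original == *restored);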
+const void* Literal::untyped_data(const ShapeIndex& shape_index) const {
+ return piece(shape_index).untyped_data();
+}
+
+void* Literal::untyped_data(const ShapeIndex& shape_index) {
+ return piece(shape_index).untyped_data();
}
-const Literal& Literal::GetSubliteral(const ShapeIndex& index) const {
- return const_cast<Literal*>(this)->GetSubliteral(index);
+int64 Literal::size_bytes(const ShapeIndex& shape_index) const {
+ return piece(shape_index).size_bytes();
+}
+
+string Literal::GetR1U8AsString() const {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(ShapeUtil::Rank(shape()), 1);
+ CHECK_EQ(shape().element_type(), U8);
+ return string(tensorflow::bit_cast<const char*>(data<uint8>().data()),
+ ShapeUtil::ElementsIn(shape()));
+}
+
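A sketch of GetR1U8AsString on a small rank-1 U8 literal:

  auto u8 = xla::Literal::CreateR1<xla::uint8>({'h', 'i'});
  CHECK_EQ(u8->GetR1U8AsString(), "hi");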
+/* static */ const LiteralView LiteralView::Create(
+ const Literal& literal, const ShapeIndex& view_root) {
+ return LiteralView(literal, view_root);
+}
+
+LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) {
+ shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root);
+ pieces_ = ShapeTree<Piece>(shape_);
+ owns_buffers_ = false;
+ for (auto& pair : pieces_) {
+ const ShapeIndex& index = pair.first;
+ Piece& piece = pair.second;
+
+ ShapeIndex src_index = view_root;
+ for (int64 i : index) {
+ src_index.push_back(i);
+ }
+ const Piece& src_piece = literal.piece(src_index);
+ piece.set_buffer(src_piece.buffer());
+ piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
+ }
+}
+
+LiteralView::~LiteralView() {}
+
+LiteralView::LiteralView(const LiteralView& other) { CopyFrom(other); }
+
+LiteralView& LiteralView::operator=(const LiteralView& other) {
+ CopyFrom(other);
+ return *this;
}
-Literal& Literal::GetSubliteral(const ShapeIndex& index) {
- Literal* subliteral = this;
- for (int64 i : index) {
- subliteral = &subliteral->tuple_literals_.at(i);
+void LiteralView::CopyFrom(const LiteralView& other) {
+ // We can't use the default copy-constructor/copy-assignment because
+ // Piece::subshape_ points to subshapes within the Shape of the owning
+ // Literal/LiteralView.
+ shape_ = other.shape();
+ pieces_ = other.pieces_;
+ for (auto& pair : pieces_) {
+ const ShapeIndex& index = pair.first;
+ Piece& piece = pair.second;
+ piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index));
}
- return *subliteral;
+ owns_buffers_ = false;
}
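A sketch of the aliasing semantics, assuming (as the shared Piece plumbing suggests) that a LiteralView can be used wherever a const Literal is expected:

  auto e0 = xla::Literal::CreateR0<float>(42.0f);
  auto e1 = xla::Literal::CreateR1<xla::int32>({1, 2});
  auto tuple = xla::Literal::MakeTuple({e0.get(), e1.get()});
  // The view shares buffers with 'tuple'; nothing is copied and nothing is owned.
  const xla::LiteralView view = xla::LiteralView::Create(*tuple, /*view_root=*/{1});
  CHECK(view == *e1);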
} // namespace xla
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
-// Utility class for dealing with XLA literal values. Most methods are
-// templated by native (host) type which corresponds to a unique XLA
-// PrimitiveType. See ComputationBuilder for details. Not all primitive types
-// defined in xla_data.proto have a corresponding native type or even have a
-// storage location in the Literal proto yet (for example, primitive type F16).
+// Class representing literal values in XLA.
+//
+// TODO(b/67651157): The methods in this class should be reduced to a minimal
+// set of methods which construct Literals and accessor methods. Other methods
+// which perform computation on Literals (Reshape, Slice, etc) should be moved
+// elsewhere, and perhaps combined with evaluator code which operates on
+// Literals.
class Literal {
public:
- Literal() {}
+ Literal() : Literal(ShapeUtil::MakeNil()) {}
- Literal(const Literal& other) = default;
- Literal(Literal&&) = default;
+ // Create a literal of the given shape. The literal is allocated sufficient
+ // memory to hold the shape. Memory is uninitialized.
+ explicit Literal(const Shape& shape);
+ virtual ~Literal();
- explicit Literal(const LiteralProto& other) { CopyFromProto(other); }
-
- Literal& operator=(const Literal& other) = default;
- Literal& operator=(Literal&&) = default;
+  // Literals are movable, but not copyable. To copy a literal use
+  // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
+  // of literals, which can be expensive.
+ Literal(const Literal& other) = delete;
+ Literal& operator=(const Literal& other) = delete;
+ Literal(Literal&& other);
+ Literal& operator=(Literal&& other);
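A minimal sketch of what the move-only design means at call sites:

  xla::Literal a(xla::ShapeUtil::MakeShape(xla::F32, {4}));
  a.PopulateWithValue(1.0f);
  xla::Literal b = std::move(a);  // buffers transfer; no copy
  xla::Literal c = b.Clone();     // deep copies must now be explicit
  // xla::Literal d = b;          // error: copy constructor is deleted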
// Literals are equal if they have compatible shapes and the same data
- // values. Layout is not checked.
+ // values. Layout is not compared.
bool operator==(const Literal& other) const;
bool operator!=(const Literal& other) const { return !(*this == other); }
+ // Serialize to and from a proto.
+ static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
+ const LiteralProto& proto);
LiteralProto ToProto() const;
- bool has_shape() const {
- return shape_.element_type() != PRIMITIVE_TYPE_INVALID;
- }
-
- // Basic accessor functions. Names mirror the original protobuf
- // functions for convenience.
- string DebugString() const { return ToProto().DebugString(); }
- string ShortDebugString() const { return ToProto().ShortDebugString(); }
-
- // Return the nested literal at the given shape index.
- const Literal& GetSubliteral(const ShapeIndex& index) const;
- Literal& GetSubliteral(const ShapeIndex& index);
-
- void Clear() {
- shape_.Clear();
- u8s_.clear();
- s16s_.clear();
- s32s_.clear();
- s64s_.clear();
- u16s_.clear();
- u32s_.clear();
- u64s_.clear();
- f16s_.clear();
- f32s_.clear();
- f64s_.clear();
- c64s_.clear();
- tuple_literals_.clear();
- }
-
- int preds_size() const { return u8s().size(); }
- const std::vector<uint8>& preds() const {
- static_assert(sizeof(uint8) == sizeof(bool),
- "The uint8 and bool types should be the same size");
- return u8s_;
- }
- std::vector<uint8>* mutable_preds() {
- static_assert(sizeof(uint8) == sizeof(bool),
- "The uint8 and bool types should be the same size");
- return &u8s_;
- }
-
- int s16s_size() const { return s16s().size(); }
- int32 s16s(int i) const { return s16s_[i]; }
- const std::vector<int16>& s16s() const { return s16s_; }
- std::vector<int16>* mutable_s16s() { return &s16s_; }
-
- int s32s_size() const { return s32s().size(); }
- int32 s32s(int i) const { return s32s_[i]; }
- const std::vector<int32>& s32s() const { return s32s_; }
- std::vector<int32>* mutable_s32s() { return &s32s_; }
-
- int s64s_size() const { return s64s().size(); }
- void add_s64s(int64 value) { s64s_.push_back(value); }
- const std::vector<int64>& s64s() const { return s64s_; }
- std::vector<int64>* mutable_s64s() { return &s64s_; }
-
- int u16s_size() const { return u16s().size(); }
- uint32 u16s(int i) const { return u16s_[i]; }
- const std::vector<uint16>& u16s() const { return u16s_; }
- std::vector<uint16>* mutable_u16s() { return &u16s_; }
-
- int u32s_size() const { return u32s().size(); }
- uint32 u32s(int i) const { return u32s_[i]; }
- const std::vector<uint32>& u32s() const { return u32s_; }
- std::vector<uint32>* mutable_u32s() { return &u32s_; }
-
- int u64s_size() const { return u64s().size(); }
- const std::vector<uint64>& u64s() const { return u64s_; }
- std::vector<uint64>* mutable_u64s() { return &u64s_; }
-
- int f16s_size() const { return f16s().size(); }
- half f16s(int i) const { return f16s_[i]; }
- const std::vector<half>& f16s() const { return f16s_; }
- std::vector<half>* mutable_f16s() { return &f16s_; }
-
- int f32s_size() const { return f32s().size(); }
- float f32s(int i) const { return f32s_[i]; }
- void add_f32s(float value) { f32s_.push_back(value); }
- const std::vector<float>& f32s() const { return f32s_; }
- std::vector<float>& f32s() { return f32s_; }
- std::vector<float>* mutable_f32s() { return &f32s_; }
-
- int f64s_size() const { return f64s().size(); }
- const std::vector<double>& f64s() const { return f64s_; }
- std::vector<double>* mutable_f64s() { return &f64s_; }
-
- int c64s_size() const { return c64s().size(); }
- const std::vector<complex64>& c64s() const { return c64s_; }
- std::vector<complex64>* mutable_c64s() { return &c64s_; }
-
- int bf16s_size() const { return bf16s().size(); }
- bfloat16 bf16s(int i) const { return bf16s_[i]; }
- const std::vector<bfloat16>& bf16s() const { return bf16s_; }
- std::vector<bfloat16>* mutable_bf16s() { return &bf16s_; }
-
- int tuple_literals_size() const { return tuple_literals().size(); }
- const Literal& tuple_literals(int i) const { return tuple_literals_[i]; }
- Literal* add_tuple_literals() {
- tuple_literals_.push_back(Literal());
- return &tuple_literals_.back();
- }
- std::vector<Literal>* mutable_tuple_literals() { return &tuple_literals_; }
- const std::vector<Literal>& tuple_literals() const { return tuple_literals_; }
-
- int u8s_size() const { return u8s().size(); }
- const std::vector<uint8>& u8s() const { return u8s_; }
- void set_u8s(const std::vector<uint8>& value) { u8s_ = value; }
- void set_u8s(tensorflow::StringPiece value) {
- u8s_ = std::vector<uint8>(value.size());
- u8s_.clear();
- append_u8s(value);
- }
-
- void append_u8s(tensorflow::StringPiece value) {
- u8s_.insert(u8s_.end(), value.begin(), value.end());
- }
+ // Return the shape of the literal.
+ const Shape& shape() const { return shape_; }
- string u8s_string() const { return string(u8s().begin(), u8s().end()); }
+ // TODO(b/67651157): Remove this accessor. Literal users should not be able to
+ // mutate the shape as this can produce malformed Literals.
+ Shape* mutable_shape_do_not_use() { return &shape_; }
- std::vector<uint8>* mutable_u8s() { return &u8s_; }
+ // Returns a (Mutable)ArraySlice view of the array for this literal for the
+ // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
+  // given ShapeIndex is not an array. See primitive_util.h for the mapping from
+ // XLA type to native type.
+ template <typename NativeT>
+ tensorflow::gtl::ArraySlice<NativeT> data(
+ const ShapeIndex& shape_index = {}) const;
+ template <typename NativeT>
+ tensorflow::gtl::MutableArraySlice<NativeT> data(
+ const ShapeIndex& shape_index = {});
- const Shape& shape() const { return shape_; }
- Shape* mutable_shape() { return &shape_; }
+ // Returns a pointer to (or size of) the underlying buffer holding the array
+ // at the given shape index. CHECKs if the subshape of the literal at the
+  // given ShapeIndex is not an array.
+ const void* untyped_data(const ShapeIndex& shape_index = {}) const;
+ void* untyped_data(const ShapeIndex& shape_index = {});
+ int64 size_bytes(const ShapeIndex& shape_index = {}) const;
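A sketch of the typed and untyped accessors on a tuple-shaped literal (ShapeUtil::MakeShape/MakeTupleShape are the standard shape factories):

  xla::Literal lit(xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeShape(xla::F32, {2}),
       xla::ShapeUtil::MakeShape(xla::S32, {3})}));
  lit.data<float>({0})[0] = 1.5f;            // typed view of piece {0}
  lit.data<xla::int32>({1})[2] = 7;          // typed view of piece {1}
  CHECK_EQ(lit.size_bytes({0}), 2 * sizeof(float));
  CHECK(lit.untyped_data({1}) != nullptr);   // raw buffer of piece {1}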
// Creates a new literal of a given rank. To minimize ambiguity (for users
// and the compiler) these CreateR[0-2] methods should explicitly specify the
values,
const Layout& layout);
+ // Returns this literal's data as a string. This literal must be a rank-1 U8
+ // array.
+ string GetR1U8AsString() const;
+
// Creates a new Literal object with the shape specified as parameter.
// The content of the literal values is the default value of the primitive
// type of literal itself (0 for numeric types, and false for predicates).
PrimitiveType primitive_type,
tensorflow::gtl::ArraySlice<int64> dimensions);
+ // Copy values from 'src_literal' rooted at 'src_shape_index' into this
+ // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
+ // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
+ // rooted at 'src_shape_index', but need not be arrays.
+ Status CopyFrom(const Literal& src_literal,
+ const ShapeIndex& dest_shape_index = {},
+ const ShapeIndex& src_shape_index = {});
+
+  // Similar to CopyFrom, but with move semantics. The subshape of this literal
+  // rooted at 'dest_shape_index' must be *equal* to the shape of 'src_literal'
+ // (layouts and shapes must match), but need not be arrays. The memory
+ // allocated in this literal for the subshape at dest_shape_index is
+ // deallocated, and the respective buffers are replaced with those in
+ // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
+ Status MoveFrom(Literal&& src_literal,
+ const ShapeIndex& dest_shape_index = {});
+
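A sketch contrasting the two under the preconditions stated above:

  auto src = xla::Literal::CreateR1<float>({1.0f, 2.0f});
  xla::Literal dst(xla::ShapeUtil::MakeShape(xla::F32, {2}));
  TF_CHECK_OK(dst.CopyFrom(*src));              // copies; '*src' is untouched
  xla::Literal sink(xla::ShapeUtil::MakeShape(xla::F32, {2}));
  TF_CHECK_OK(sink.MoveFrom(std::move(*src)));  // steals buffers; '*src' becomes nil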
// Copies the values from src_literal, starting at src_base shape indexes,
// to this literal, starting at dest_base, where the copy size in each
// dimension is specified by copy_size.
// Note: if either src_literal or this literal contains dimensions with zero
// elements, then copy_size must be 0 in these dimensions and the
// corresponding base indices must be 0.
- Status Copy(const Literal& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size);
+ // This literal and 'src_literal' must be arrays.
+ Status CopySliceFrom(const Literal& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_base,
+ tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size);
+
+ // Returns a vector containing the tuple elements of this Literal as separate
+ // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
+  // elements are moved into the new Literals; no data is copied. Upon return,
+  // this Literal is set to a nil shape (empty tuple).
+ std::vector<Literal> DecomposeTuple();
+
+ // This operation is the inverse of DecomposeTuple. The given elements are
+  // moved into the tuple elements of a new tuple-shaped Literal, which is
+ // returned. Upon return, each of the Literals in 'elements' is set to a nil
+ // shape (empty tuple).
+ static Literal MoveIntoTuple(
+ tensorflow::gtl::MutableArraySlice<Literal> elements);
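A sketch of the DecomposeTuple/MoveIntoTuple round trip:

  auto e0 = xla::Literal::CreateR0<float>(42.0f);
  auto e1 = xla::Literal::CreateR1<xla::int32>({1, 2, 3});
  auto tuple = xla::Literal::MakeTuple({e0.get(), e1.get()});
  std::vector<xla::Literal> parts = tuple->DecomposeTuple();  // '*tuple' is now nil
  xla::Literal rebuilt = xla::Literal::MoveIntoTuple(
      tensorflow::gtl::MutableArraySlice<xla::Literal>(parts.data(), parts.size()));
  CHECK_EQ(rebuilt.element_count({1}), 3);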
// Creates a new value that has the equivalent value as this literal, but
// conforms to new_layout; e.g. a literal matrix that was in {0, 1}
// Creates a new literal by reshaping this literal to have the given
// dimensions. The total number of elements must not change; The
// implementation currently only supports monotonic dim0-major layouts.
+ // This literal must be an array.
StatusOr<std::unique_ptr<Literal>> Reshape(
tensorflow::gtl::ArraySlice<int64> dimensions) const;
// in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
// For example, a transpose call on a literal of shape [3 x 8 x 4] and
// `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
+ // This literal must be an array.
std::unique_ptr<Literal> Transpose(
tensorflow::gtl::ArraySlice<int64> permutation) const;
// same rank and layout as for the given literal. The number of indices in
// start_indices and limit_indices must be the rank of the literal, and the
// indices follow the order of the dimensions.
+ // This literal must be an array.
std::unique_ptr<Literal> Slice(
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices) const;
// Creates a literal with a prepended dimension with bound "times"; e.g. a
// f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
// literal replicated four times.
+ // This literal must be an array.
template <typename NativeT>
std::unique_ptr<Literal> Replicate(int64 times) const;
// Converts this literal to another primitive type. Returns an error if the
- // conversion is not possible.
+ // conversion is not possible. This literal must be array-shaped.
StatusOr<std::unique_ptr<Literal>> Convert(
PrimitiveType primitive_dest_type) const;
- // Creates a literal value zero of the given primitive type.
+ // Creates a scalar literal value zero of the given primitive type.
static Literal Zero(PrimitiveType primitive_type);
- // Creates a literal value one of the given primitive type.
+ // Creates a scalar literal value one of the given primitive type.
static Literal One(PrimitiveType primitive_type);
- // Creates a literal value containing the minimum value of the given
+ // Creates a scalar literal value containing the minimum value of the given
// primitive type. For floating-point types, returns -inf.
static Literal MinValue(PrimitiveType primitive_type);
- // Creates a literal value containing the maximum value of the given
+ // Creates a scalar literal value containing the maximum value of the given
// primitive type. For floating-point types, returns inf.
static Literal MaxValue(PrimitiveType primitive_type);
static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value);
- // Creates a new literal from an array. The variants not ending with
+ // Creates a new literal from an Array type. The variants not ending with
// WithLayout use the default XLA layout for the literal's linear
// representation in memory.
template <typename NativeT>
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z);
- // Clones this literal into an owned unique_ptr version.
+  // Clones this literal into a new Literal (Clone) or a new
+  // std::unique_ptr<Literal> (CloneToUnique).
+ Literal Clone() const;
std::unique_ptr<Literal> CloneToUnique() const;
- // Returns the linear index of the given index within this literal's
- // element_type repeated field.
- int64 LinearIndex(tensorflow::gtl::ArraySlice<int64> multi_index) const;
-
- // Gets or sets an element in the literal at the given index. The index is
- // CHECKed against the dimension sizes.
+ // Gets or sets an element in the literal at the given index. The multi_index
+ // is CHECKed against the dimension sizes.
template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
+ NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index) const;
template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
+ void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index, NativeT value);
- // Returns a (Mutable)ArraySlice view of the array for this literal for the
- // given NativeT (e.g., float). These functions map native type to XLA
- // PrimitiveType via template specialization. The unspecialized forms below
- // aborts to handle the error case where the given native type does not map to
- // an XLA primitive type.
+ // Overloads of Get and Set for array literals. CHECKs if the literal is not
+ // array-shaped.
template <typename NativeT>
- tensorflow::gtl::ArraySlice<NativeT> GetArraySlice() const {
- static_assert(!std::is_same<NativeT, NativeT>::value,
- "Cannot map native type to primitive type.");
- }
+ NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
template <typename NativeT>
- tensorflow::gtl::MutableArraySlice<NativeT> GetMutableArraySlice() {
- static_assert(!std::is_same<NativeT, NativeT>::value,
- "Cannot map native type to primitive type.");
- }
+ void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
// Returns the element value at index (0, ..., 0), however many zeroes are
// required for that index.
// As Get(), but determines the correct type and converts the value
// into text.
- string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index) const;
+ string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index = {}) const;
// As Get(), but determines the correct type and converts the value into
- // int64.
+ // int64. This literal must be an array.
StatusOr<int64> GetIntegralAsS64(
tensorflow::gtl::ArraySlice<int64> multi_index) const;
template <typename NativeT>
static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
- // Returns a tuple literal composed of given literals.
+ // Returns a tuple literal composed of given literals. Data is copied from the
+ // given elements into the returned literal.
static std::unique_ptr<Literal> MakeTuple(
tensorflow::gtl::ArraySlice<const Literal*> elements);
static std::unique_ptr<Literal> MakeTupleOwned(
std::vector<std::unique_ptr<Literal>> elements);
- // Validates that the data payload of the literal matches the literal shape;
- // if it does not, an appropriate status is returned.
- tensorflow::Status ValidateLiteral() const;
-
// Returns a string representation of the literal value.
string ToString(bool print_layout = false) const;
NativeT value)>
per_cell) const;
- // Templated methods which populate the given repeated field in this literal
- // with the given value(s). The Shape field of this literal is set
- // to match the array dimensions and type. Examples:
+ // Populate this literal with the given values. Examples:
//
// // Populate with floats.
// Array2D<float> float_values = ...
// literal.PopulateR2FromArray2D(values);
//
// // Populate with int32s.
- // literal.PopulateR2({{1, 2}, {3, 4}});
+ // literal.PopulateR2<int32>({{1, 2}, {3, 4}});
//
- template <typename NativeT>
- void PopulateR0(NativeT values);
+  // The shape and element type of this literal must match the given values. For
+ // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
+ // array of S32.
template <typename NativeT>
void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
void PopulateR1(const tensorflow::core::Bitmap& values);
template <typename NativeT>
void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
- void PopulateR2WithLayout(
- std::initializer_list<std::initializer_list<NativeT>> values,
- const Layout& layout);
- template <typename NativeT>
void PopulateFromArray(const Array<NativeT>& values);
template <typename NativeT>
- void PopulateFromArrayWithLayout(const Array<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
void PopulateR2FromArray2D(const Array2D<NativeT>& values);
template <typename NativeT>
- void PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
void PopulateR3FromArray3D(const Array3D<NativeT>& values);
template <typename NativeT>
- void PopulateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
void PopulateR4FromArray4D(const Array4D<NativeT>& values);
- template <typename NativeT>
- void PopulateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
- const Layout& layout);
// Populates literal values by calling the generator function for every cell
// in this literal object.
template <typename NativeT, typename FnType>
Status Populate(const FnType& generator);
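A sketch of Populate with a per-cell generator; the generator receives each cell's multi-index as an ArraySlice<int64>:

  xla::Literal iota(xla::ShapeUtil::MakeShape(xla::F32, {2, 3}));
  TF_CHECK_OK(iota.Populate<float>(
      [](tensorflow::gtl::ArraySlice<xla::int64> indexes) {
        return static_cast<float>(indexes[0] * 3 + indexes[1]);
      }));
  CHECK_EQ(iota.Get<float>({1, 2}), 5.0f);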
- // Creates a Literal of the given dimensions with all elements set to the
- // given value.
+ // Fills this literal with the given value.
template <typename NativeT>
- void PopulateWithValue(NativeT value,
- tensorflow::gtl::ArraySlice<int64> dimensions);
-
- // Returns a pointer to the underlying vector corresponding to the Literal's
- // shape.
- const void* InternalData() const;
- void* MutableInternalData();
-
- // Allocates space in the underlying vector of this literal sufficient to hold
- // num_elements of this literal's primitive type. Values in the vector are set
- // to zero. num_elements must equal the number of elements in the literal's
- // shape.
- void Reserve(int64 num_elements);
-
- // Allocates space in the underlying vector of this literal sufficient to hold
- // num_elements of this literal's primitive type and sets each element in this
- // literal to the given value. num_elements must equal the number of elements
- // in this literal's shape.
- template <typename NativeT>
- void Resize(int64 num_elements, NativeT value);
+ void PopulateWithValue(NativeT value);
// Returns whether every element in this literal is equal to value.
//
//
// If value doesn't fit in this literal's type, returns false. Values of 1/0
// are considered equal to true/false; other values are not considered equal
- // to true.
+  // to true. Returns false if this literal is not array-shaped.
bool IsAll(int8 value) const;
// Like IsAll(const Literal&, int8), except we check whether the literal is
// This casts value to the type of literal, then compares using ==. The usual
// admonishments about floating-point equality checks apply. We expect you to
// use this to check for values that can be expressed precisely as a float,
- // e.g. -0.5.
+  // e.g. -0.5. Returns false if this literal is not array-shaped.
bool IsAllFloat(float value) const;
// Like IsAll(const Literal&, int8), except we check whether the literal is
// must be an array.
bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
- private:
- // Copy from a LiteralProto instance.
- void CopyFromProto(const LiteralProto& literal_proto);
+  // Returns the count of elements in the array at the given shape index in
+ // this literal.
+ int64 element_count(const ShapeIndex& index = {}) const {
+ return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
+ }
+
+ protected:
+ // 'allocate_arrays' indicates whether to allocate memory for the arrays in
+  // the shape. If false, buffer pointers inside the Literal::Pieces are set
+ // to nullptr.
+ Literal(const Shape& shape, bool allocate_arrays);
- // Internal template helper for the Copy() API, matching its arguments one by
- // one.
- template <typename T>
- Status CopyRange(const Literal& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size);
+  // Internal template helper for Literal::CopySliceFrom(), matching its
+ // arguments one by one.
+ template <typename NativeT>
+ Status CopySliceFromInternal(const Literal& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_base,
+ tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size);
// Utility structure which is used to create the optimal configuration for
// a ShapeUtil::ForEachIndex() scan across two literals.
int64 minor_loop_size = 1;
};
- Shape shape_;
- std::vector<uint8> u8s_;
- std::vector<int16> s16s_;
- std::vector<int32> s32s_;
- std::vector<int64> s64s_;
- std::vector<uint16> u16s_;
- std::vector<uint32> u32s_;
- std::vector<uint64> u64s_;
- std::vector<bfloat16> bf16s_;
- std::vector<half> f16s_;
- std::vector<float> f32s_;
- std::vector<double> f64s_;
- std::vector<complex64> c64s_;
- std::vector<Literal> tuple_literals_;
-};
-
-std::ostream& operator<<(std::ostream& out, const Literal& literal);
-
-// Declarations of template specializations for GetArraySlice and
-// GetMutableArraySlice. The specializations map native type to XLA primitive
-// type.
-template <>
-tensorflow::gtl::ArraySlice<bool> Literal::GetArraySlice<bool>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<uint8> Literal::GetArraySlice<uint8>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<int8> Literal::GetArraySlice<int8>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<uint16> Literal::GetArraySlice<uint16>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<int16> Literal::GetArraySlice<int16>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<uint32> Literal::GetArraySlice<uint32>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<uint64> Literal::GetArraySlice<uint64>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<int32> Literal::GetArraySlice<int32>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<int64> Literal::GetArraySlice<int64>() const;
-
-template <>
-inline tensorflow::gtl::ArraySlice<float> Literal::GetArraySlice<float>()
- const {
- DCHECK(shape().element_type() == F32);
- return f32s();
-}
-
-template <>
-tensorflow::gtl::ArraySlice<double> Literal::GetArraySlice<double>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const;
-
-template <>
-tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
- const;
-
-template <>
-tensorflow::gtl::MutableArraySlice<bool> Literal::GetMutableArraySlice();
-
-template <>
-tensorflow::gtl::MutableArraySlice<int8> Literal::GetMutableArraySlice();
-
-template <>
-tensorflow::gtl::MutableArraySlice<uint8> Literal::GetMutableArraySlice();
-
-template <>
-tensorflow::gtl::MutableArraySlice<int16> Literal::GetMutableArraySlice();
+ // A data structure representing a subshape at a particular ShapeIndex within
+ // the literal. For array-shaped ShapeIndexes, this data structure holds the
+ // pointer to the memory allocated for the array data.
+ class Piece {
+ public:
+ // Return the buffer holding the array data for this piece as an array
+ // slice. This piece must be array-shaped.
+ template <typename NativeT>
+ tensorflow::gtl::ArraySlice<NativeT> data() const;
+ template <typename NativeT>
+ tensorflow::gtl::MutableArraySlice<NativeT> data();
+
+ // Return the buffer holding the array data for this piece as a void*. This
+ // piece must be array-shaped.
+ void* untyped_data();
+ const void* untyped_data() const;
+
+ // Gets or sets an element in the array at the given index. The multi_index
+ // is CHECKed against the dimension sizes of the array. This piece must be
+ // array-shaped.
+ template <typename NativeT>
+ NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
+ template <typename NativeT>
+ void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
+
+ // Gets/sets the buffer holding the array data.
+ char* buffer() const { return buffer_; }
+ void set_buffer(char* buffer) { buffer_ = buffer; }
+
+ // Gets or sets the subshape of this piece. This reference points to a
+ // subshape within the shape of the containing Literal (Literal::shape_).
+ const Shape& subshape() const { return *subshape_; }
+ void set_subshape(const Shape* subshape) { subshape_ = subshape; }
+
+ // Returns the size in bytes of the buffer holding the array data.
+ int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
+
+ // Returns the number of elements in this piece's array.
+ int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); }
+
+ // Copy the data from 'src' into this piece's buffer. Shapes of this piece
+ // and src must be compatible.
+ Status CopyFrom(const Piece& src);
+
+ // Returns true if this piece and 'other' contain the same data. This piece
+ // and 'other' must be array-shaped and compatible.
+ bool EqualElements(const Piece& other) const;
+
+ // Writes the shape and data (if array-shaped) into the given proto.
+ void WriteToProto(LiteralProto* proto) const;
+
+ // Copies the data from the given proto into this piece. The shape of this
+ // piece must be equal (not just compatible) to the shape of the proto.
+ Status CopyFromProto(const LiteralProto& proto);
+
+ private:
+ // Recursive helper for EqualElements.
+ template <typename NativeT>
+ bool EqualElementsInternal(const Piece& other,
+ std::vector<int64>* multi_index) const;
+
+ // For array-shaped pieces, this is the buffer holding the literal data.
+ char* buffer_ = nullptr;
+
+ // The shape of this piece. This points into the shape of the containing
+ // Literal (Literal::shape_).
+ const Shape* subshape_ = nullptr;
+ };
-template <>
-tensorflow::gtl::MutableArraySlice<uint16> Literal::GetMutableArraySlice();
+ // Returns the piece at the given ShapeIndex.
+ Piece& piece(const ShapeIndex& shape_index) {
+ return *pieces_.mutable_element(shape_index);
+ }
+ const Piece& piece(const ShapeIndex& shape_index) const {
+ return pieces_.element(shape_index);
+ }
-template <>
-tensorflow::gtl::MutableArraySlice<int32> Literal::GetMutableArraySlice();
+ // Returns the piece at the root of the shape (empty ShapeIndex).
+ Piece& root_piece() { return piece({}); }
+ const Piece& root_piece() const { return piece({}); }
-template <>
-tensorflow::gtl::MutableArraySlice<uint32> Literal::GetMutableArraySlice();
+ // Deallocate the buffers held by this literal (if the literal owns the
+ // buffers).
+ void DeallocateBuffers();
-template <>
-tensorflow::gtl::MutableArraySlice<int64> Literal::GetMutableArraySlice();
+ Shape shape_;
+ ShapeTree<Piece> pieces_;
-template <>
-tensorflow::gtl::MutableArraySlice<uint64> Literal::GetMutableArraySlice();
+ // Whether the buffers held in pieces_ are owned by this Literal.
+ bool owns_buffers_;
-template <>
-tensorflow::gtl::MutableArraySlice<float> Literal::GetMutableArraySlice();
+ // LiteralView must access and manipulate Pieces of other Literals.
+ friend class LiteralView;
+};
-template <>
-tensorflow::gtl::MutableArraySlice<double> Literal::GetMutableArraySlice();
+std::ostream& operator<<(std::ostream& out, const Literal& literal);
-template <>
-tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice();
+// A read-only view of a Literal. A LiteralView contains pointers to buffers
+// owned by the viewed Literal.
+//
+// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable
+// and mutable) similar to (Mutable)ArraySlice.
+class LiteralView : public Literal {
+ public:
+ // Create and return a view rooted at the given shape index within the
+ // given literal. A factory is used rather than a public constructor
+ // because only const LiteralViews are supported. It's still possible to
+ // create non-const LiteralViews via the copy constructors, but the
+ // factory method makes it a bit less likely. Implementing literal slices
+ // will fix this undesirable situation (b/71550060).
+ static const LiteralView Create(const Literal& literal,
+ const ShapeIndex& view_root = {});
-template <>
-tensorflow::gtl::MutableArraySlice<bfloat16> Literal::GetMutableArraySlice();
+ LiteralView(const LiteralView& other);
+ LiteralView& operator=(const LiteralView& other);
-template <>
-tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice();
+ virtual ~LiteralView();
-template <>
-void Literal::Resize<bool>(int64 num_elements, bool value);
+ private:
+ LiteralView(const Literal& literal, const ShapeIndex& view_root);
-template <>
-void Literal::Resize<int8>(int64 num_elements, int8 value);
+ // Helper for the copy constructor and copy assignment operator.
+ void CopyFrom(const LiteralView& other);
+};
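+
+// Example (an illustrative sketch mirroring the tests below; 'scalar' and
+// 'matrix' are assumed to be std::unique_ptr<Literal> values):
+//   auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
+//   const auto matrix_view = LiteralView::Create(*tuple, /*view_root=*/{1});
+//   float v = matrix_view.Get<float>({0, 0});  // no buffer copy is made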
-template <>
-void Literal::Resize<uint8>(int64 num_elements, uint8 value);
+template <typename NativeT>
+tensorflow::gtl::ArraySlice<NativeT> Literal::Piece::data() const {
+ CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ CHECK_EQ(subshape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>())
+ << "Attempting to access "
+ << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
+ << " type, but literal element type is "
+ << PrimitiveType_Name(subshape().element_type());
+ return tensorflow::gtl::ArraySlice<NativeT>(
+ reinterpret_cast<const NativeT*>(buffer()),
+ ShapeUtil::ElementsIn(subshape()));
+}
-template <>
-void Literal::Resize<int32>(int64 num_elements, int32 value);
+template <typename NativeT>
+tensorflow::gtl::MutableArraySlice<NativeT> Literal::Piece::data() {
+ CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ CHECK_EQ(subshape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>())
+ << "Attempting to access "
+ << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
+ << " type, but literal element type is "
+ << PrimitiveType_Name(subshape().element_type());
+ return tensorflow::gtl::MutableArraySlice<NativeT>(
+ reinterpret_cast<NativeT*>(buffer()), ShapeUtil::ElementsIn(subshape()));
+}
-template <>
-void Literal::Resize<uint32>(int64 num_elements, uint32 value);
+template <typename NativeT>
+NativeT Literal::Piece::Get(
+ tensorflow::gtl::ArraySlice<int64> multi_index) const {
+ return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
+ subshape(), multi_index)];
+}
-template <>
-void Literal::Resize<int64>(int64 num_elements, int64 value);
+template <typename NativeT>
+void Literal::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+ NativeT value) {
+ data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
+ subshape(), multi_index)] = value;
+}
-template <>
-void Literal::Resize<uint64>(int64 num_elements, uint64 value);
+template <typename NativeT>
+tensorflow::gtl::ArraySlice<NativeT> Literal::data(
+ const ShapeIndex& shape_index) const {
+ return piece(shape_index).data<NativeT>();
+}
-template <>
-void Literal::Resize<float>(int64 num_elements, float value);
+template <typename NativeT>
+tensorflow::gtl::MutableArraySlice<NativeT> Literal::data(
+ const ShapeIndex& shape_index) {
+ return piece(shape_index).data<NativeT>();
+}
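+
+// Example (illustrative): data() gives direct, typed access to the
+// underlying buffer, so elements can be read or mutated in place:
+//   auto literal = Literal::CreateR1<int32>({1, 2, 3});
+//   for (int32 v : literal->data<int32>()) { /* v is 1, 2, 3 */ }
+//   literal->data<int32>()[0] = 42;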
-template <>
-void Literal::Resize<double>(int64 num_elements, double value);
+template <typename NativeT>
+inline NativeT Literal::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index) const {
+ return piece(shape_index).Get<NativeT>(multi_index);
+}
-template <>
-void Literal::Resize<half>(int64 num_elements, half value);
+template <typename NativeT>
+inline NativeT Literal::Get(
+ tensorflow::gtl::ArraySlice<int64> multi_index) const {
+ return root_piece().Get<NativeT>(multi_index);
+}
-template <>
-void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value);
+template <typename NativeT>
+inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index, NativeT value) {
+ return piece(shape_index).Set<NativeT>(multi_index, value);
+}
-template <>
-void Literal::Resize<complex64>(int64 num_elements, complex64 value);
+template <typename NativeT>
+inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+ NativeT value) {
+ return root_piece().Set<NativeT>(multi_index, value);
+}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR0(NativeT value) {
- auto literal = MakeUnique<Literal>();
- literal->PopulateR0<NativeT>(value);
+ auto literal = MakeUnique<Literal>(ShapeUtil::MakeShape(
+ primitive_util::NativeToPrimitiveType<NativeT>(), {}));
+ literal->Set({}, value);
return literal;
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR1(
tensorflow::gtl::ArraySlice<NativeT> values) {
- auto literal = MakeUnique<Literal>();
+ auto literal = MakeUnique<Literal>(
+ ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
+ {static_cast<int64>(values.size())}));
literal->PopulateR1(values);
return literal;
}
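+
+// Example (illustrative): under the constructor-allocates model the factory
+// builds the shape up front, so no separate Populate step is required:
+//   auto scalar = Literal::CreateR0<float>(2.5f);       // shape F32[]
+//   auto vector = Literal::CreateR1<int32>({1, 2, 3});  // shape S32[3]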
/* static */ std::unique_ptr<Literal> Literal::CreateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout) {
- auto literal = MakeUnique<Literal>();
- literal->PopulateR2WithLayout(values, layout);
+ auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(),
+ {static_cast<int64>(values.size()),
+ static_cast<int64>(values.begin()->size())},
+ AsInt64Slice(layout.minor_to_major())));
+ literal->PopulateR2(values);
return literal;
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout) {
- auto literal = MakeUnique<Literal>();
- literal->PopulateFromArrayWithLayout(values, layout);
+ auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
+ AsInt64Slice(layout.minor_to_major())));
+ literal->PopulateFromArray(values);
return literal;
}
return CreateFromArrayWithLayout(values, layout);
}
-template <typename NativeT>
-NativeT Literal::Get(tensorflow::gtl::ArraySlice<int64> multi_index) const {
- int64 linear_index = LinearIndex(multi_index);
- return GetArraySlice<NativeT>().at(linear_index);
-}
-
template <typename NativeT>
NativeT Literal::GetFirstElement() const {
- return GetArraySlice<NativeT>().at(0);
-}
-
-template <>
-inline uint8 Literal::Get<uint8>(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
- CHECK(shape().element_type() == U8);
- int64 linear_index = LinearIndex(multi_index);
- return u8s()[linear_index];
-}
-
-template <>
-inline int8 Literal::Get<int8>(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
- CHECK(shape().element_type() == S8);
- int64 linear_index = LinearIndex(multi_index);
- return u8s()[linear_index];
-}
-
-template <>
-inline half Literal::Get<half>(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
- CHECK(shape().element_type() == F16);
- int64 linear_index = LinearIndex(multi_index);
- return GetArraySlice<half>()[linear_index];
-}
-
-template <>
-inline bfloat16 Literal::Get<bfloat16>(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
- CHECK(shape().element_type() == BF16);
- int64 linear_index = LinearIndex(multi_index);
- return GetArraySlice<bfloat16>()[linear_index];
-}
-
-template <typename NativeT>
-void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- NativeT value) {
- int64 linear_index = LinearIndex(multi_index);
- GetMutableArraySlice<NativeT>().at(linear_index) = value;
-}
-
-template <>
-inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- uint8 value) {
- int64 linear_index = LinearIndex(multi_index);
- (*mutable_u8s())[linear_index] = value;
-}
-
-template <>
-inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- int8 value) {
- return Set<uint8>(multi_index, value);
-}
-
-template <>
-inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- int64 value) {
- int64 linear_index = LinearIndex(multi_index);
- (*mutable_s64s())[linear_index] = value;
-}
-
-template <>
-/* static */ inline void Literal::Set(
- tensorflow::gtl::ArraySlice<int64> multi_index, uint64 value) {
- int64 linear_index = LinearIndex(multi_index);
- (*mutable_u64s())[linear_index] = value;
+ return data<NativeT>().at(0);
}
// Returns an identity matrix (rank 2) with the given row and column count.
} while (IndexUtil::BumpIndices(shape(), &indices));
}
-template <typename NativeT>
-inline void Literal::PopulateR0(NativeT value) {
- *mutable_shape() = ShapeUtil::MakeShape(
- primitive_util::NativeToPrimitiveType<NativeT>(), {});
- Resize<NativeT>(1, value);
-}
-
template <typename NativeT>
inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
- *mutable_shape() =
- ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
- {static_cast<int64>(values.size())});
- Reserve(values.size());
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(ShapeUtil::Rank(shape()), 1);
+ CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
+ CHECK_EQ(shape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>());
for (int64 i = 0; i < values.size(); ++i) {
Set({i}, values[i]);
}
}
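+
+// Example (illustrative, mirroring the updated tests): PopulateR1 now fills
+// an already-shaped literal instead of reshaping it:
+//   Literal v(ShapeUtil::MakeShape(S64, {3}));
+//   v.PopulateR1<int64>({7, 8, 9});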
-inline void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
- *mutable_shape() =
- ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())});
- Reserve(values.bits());
- for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
- Set({i}, values.get(i));
- }
-}
-
template <typename NativeT>
-void Literal::PopulateR2WithLayout(
- std::initializer_list<std::initializer_list<NativeT>> values,
- const Layout& layout) {
- *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
- primitive_util::NativeToPrimitiveType<NativeT>(),
- {static_cast<int64>(values.size()),
- static_cast<int64>(values.begin()->size())},
- LayoutUtil::MinorToMajor(layout));
+void Literal::PopulateR2(
+ std::initializer_list<std::initializer_list<NativeT>> values) {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(ShapeUtil::Rank(shape()), 2);
+ CHECK_EQ(shape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>());
const int64 dim0_size = values.size();
const int64 dim1_size = values.begin()->size();
CHECK_EQ(dim0_size, shape().dimensions(0));
CHECK_EQ(dim1_size, shape().dimensions(1));
- const int64 num_elements = dim1_size * dim0_size;
- Reserve(num_elements);
-
int64 dim0 = 0;
for (auto inner_list : values) {
int64 dim1 = 0;
}
template <typename NativeT>
-void Literal::PopulateR2(
- std::initializer_list<std::initializer_list<NativeT>> values) {
- PopulateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
-}
-
-template <typename NativeT>
-void Literal::PopulateFromArrayWithLayout(const Array<NativeT>& values,
- const Layout& layout) {
- CHECK_EQ(layout.format(), DENSE);
- *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
- primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
- LayoutUtil::MinorToMajor(layout));
- Reserve(values.num_elements());
+void Literal::PopulateFromArray(const Array<NativeT>& values) {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(shape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>());
+ CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions());
+ for (int dim = 0; dim < values.num_dimensions(); ++dim) {
+ CHECK_EQ(values.dim(dim), shape().dimensions(dim));
+ }
values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
NativeT value) { this->Set(indices, value); });
}
-template <typename NativeT>
-void Literal::PopulateFromArray(const Array<NativeT>& values) {
- PopulateFromArrayWithLayout(
- values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
-}
-
-template <typename NativeT>
-void Literal::PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
- const Layout& layout) {
- PopulateFromArrayWithLayout(values, layout);
-}
-
template <typename NativeT>
void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
PopulateFromArray(values);
}
-template <typename NativeT>
-void Literal::PopulateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
- const Layout& layout) {
- PopulateFromArrayWithLayout(values, layout);
-}
-
template <typename NativeT>
void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
PopulateFromArray(values);
}
-template <typename NativeT>
-void Literal::PopulateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
- const Layout& layout) {
- PopulateFromArrayWithLayout(values, layout);
-}
-
template <typename NativeT>
void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
PopulateFromArray(values);
Status Literal::Populate(const FnType& generator) {
const Shape& this_shape = shape();
const int64 rank = ShapeUtil::Rank(this_shape);
+ TF_RET_CHECK(ShapeUtil::IsArray(this_shape));
TF_RET_CHECK(this_shape.element_type() ==
primitive_util::NativeToPrimitiveType<NativeT>());
- tensorflow::gtl::MutableArraySlice<NativeT> data =
- GetMutableArraySlice<NativeT>();
+ tensorflow::gtl::MutableArraySlice<NativeT> literal_data = data<NativeT>();
if (rank > 0) {
StrideConfig stride_config(this_shape, this_shape,
AsInt64Slice(this_shape.dimensions()));
ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
auto init_function = [&](const std::vector<int64>& indexes) {
- const int64 index = LinearIndex(indexes);
+ const int64 index =
+ IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
for (int64 i = 0; i < minor_dimension_size; ++i) {
minor_scan_indexes[stride_config.minor_dimension] = i;
- data.at(index + i) = generator(minor_scan_indexes);
+ literal_data.at(index + i) = generator(minor_scan_indexes);
}
return true;
};
init_function);
} else {
// For scalars.
- data.at(0) = generator({});
+ literal_data.at(0) = generator({});
}
return Status::OK();
}
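+
+// Example (illustrative): Populate computes each element from its
+// multi-dimensional index; the literal must already have a matching shape:
+//   Literal lit(ShapeUtil::MakeShape(F32, {2, 2}));
+//   TF_CHECK_OK(lit.Populate<float>(
+//       [](tensorflow::gtl::ArraySlice<int64> idx) -> float {
+//         return idx[0] * 2 + idx[1];
+//       }));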
template <typename NativeT>
-void Literal::PopulateWithValue(NativeT value,
- tensorflow::gtl::ArraySlice<int64> dimensions) {
- *mutable_shape() = ShapeUtil::MakeShape(
- primitive_util::NativeToPrimitiveType<NativeT>(), dimensions);
- Resize<NativeT>(ShapeUtil::ElementsIn(shape()), value);
+void Literal::PopulateWithValue(NativeT value) {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(shape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>());
+ for (NativeT& element : data<NativeT>()) {
+ element = value;
+ }
}
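+
+// Example (illustrative): PopulateWithValue no longer takes dimensions; the
+// literal's existing shape determines how many elements are written:
+//   Literal ones(ShapeUtil::MakeShape(F32, {2, 2}));
+//   ones.PopulateWithValue<float>(1.0f);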
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateFullWithDescendingLayout(
tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
- Shape this_shape = ShapeUtil::MakeShapeWithDescendingLayout(
- primitive_util::NativeToPrimitiveType<NativeT>(), dimensions);
- auto literal = MakeUnique<Literal>();
- *literal->mutable_shape() = this_shape;
- literal->Reserve(ShapeUtil::ElementsIn(this_shape));
- std::vector<int64> index(dimensions.size(), 0);
- do {
- literal->Set(index, value);
- } while (IndexUtil::BumpIndices(this_shape, &index));
+ auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
+ literal->PopulateWithValue(value);
return literal;
}
for (int64 bound : shape().dimensions()) {
bounds.push_back(bound);
}
- auto literal = MakeUnique<Literal>();
- *literal->mutable_shape() =
- ShapeUtil::MakeShape(shape().element_type(), bounds);
+ auto literal =
+ MakeUnique<Literal>(ShapeUtil::MakeShape(shape().element_type(), bounds));
int64 elements = ShapeUtil::ElementsIn(literal->shape());
if (elements == 0) {
return literal;
}
- literal->Reserve(elements);
DimensionVector output_indices(bounds.size(), 0);
tensorflow::gtl::ArraySlice<int64> input_indices = output_indices;
namespace {
using ::testing::ElementsAre;
+using ::testing::HasSubstr;
class LiteralUtilTest : public ::testing::Test {
protected:
auto matrix_different = Literal::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}});
auto vector_literal = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
auto scalar = Literal::CreateR0<float>(1.0);
+ Literal nil(ShapeUtil::MakeNil());
EXPECT_EQ(*matrix, *matrix);
EXPECT_EQ(*matrix, *matrix_clone);
EXPECT_NE(*matrix, *matrix_different);
EXPECT_NE(*matrix, *vector_literal);
EXPECT_NE(*matrix, *scalar);
+ EXPECT_NE(*matrix, nil);
+ EXPECT_EQ(nil, nil);
}
TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
// Test equality with literals which have different layouts.
- auto colmajor = MakeUnique<Literal>();
- *colmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2});
- *colmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
- colmajor->Reserve(4);
+ auto colmajor =
+ MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
colmajor->Set<float>({0, 0}, 1.0);
colmajor->Set<float>({0, 1}, 2.0);
colmajor->Set<float>({1, 0}, 3.0);
colmajor->Set<float>({1, 1}, 4.0);
- auto rowmajor = MakeUnique<Literal>();
- *rowmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2});
- *rowmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
- rowmajor->Reserve(4);
+ auto rowmajor =
+ MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
rowmajor->Set<float>({0, 0}, 1.0);
rowmajor->Set<float>({0, 1}, 2.0);
rowmajor->Set<float>({1, 0}, 3.0);
TEST_F(LiteralUtilTest, TestR2LinearLayout) {
// Test expected memory layout of R2 dim0-minor (column-major) literal.
- auto mat_dim0minor = Literal::CreateR2WithLayout<int>({{1, 2, 3}, {4, 5, 6}},
- layout_r2_dim0minor_);
- EXPECT_EQ(mat_dim0minor->s32s_size(), 6);
- EXPECT_THAT(mat_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6));
+ auto mat_dim0minor = Literal::CreateR2WithLayout<int32>(
+ {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
+ EXPECT_EQ(mat_dim0minor->element_count(), 6);
+ EXPECT_THAT(mat_dim0minor->data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
// Test expected memory layout when using Relayout to row major.
auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_);
- EXPECT_THAT(relaid_mat_to_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6));
+ EXPECT_THAT(relaid_mat_to_dim0major->data<int32>(),
+ ElementsAre(1, 2, 3, 4, 5, 6));
// Test expected memory layout of R2 created with dim0-major (row-major).
- auto mat_dim0major = Literal::CreateR2WithLayout<int>({{1, 2, 3}, {4, 5, 6}},
- layout_r2_dim0major_);
- EXPECT_EQ(mat_dim0major->s32s_size(), 6);
- EXPECT_THAT(mat_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6));
+ auto mat_dim0major = Literal::CreateR2WithLayout<int32>(
+ {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
+ EXPECT_EQ(mat_dim0major->element_count(), 6);
+ EXPECT_THAT(mat_dim0major->data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
// Test expected memory layout when using Relayout to column major.
auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_);
- EXPECT_THAT(relaid_mat_to_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6));
+ EXPECT_THAT(relaid_mat_to_dim0minor->data<int32>(),
+ ElementsAre(1, 4, 2, 5, 3, 6));
}
TEST_F(LiteralUtilTest, TestR3LinearLayout) {
auto lit_dim0minor =
Literal::CreateR3FromArray3DWithLayout<int>(arr3d, layout_r3_dim0minor_);
- EXPECT_EQ(lit_dim0minor->s32s_size(), 12);
+ EXPECT_EQ(lit_dim0minor->element_count(), 12);
std::vector<int> expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
- EXPECT_THAT(lit_dim0minor->s32s(),
+ EXPECT_THAT(lit_dim0minor->data<int32>(),
testing::ElementsAreArray(expected_dim0minor));
// Test expected memory layout when using Relayout to row major.
auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_);
std::vector<int> expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
- EXPECT_THAT(relaid_lit_to_dim0major->s32s(),
+ EXPECT_THAT(relaid_lit_to_dim0major->data<int32>(),
testing::ElementsAreArray(expected_dim0major));
// Test expected memory layout of R3 created with dim0-major (row-major).
auto lit_dim0major =
Literal::CreateR3FromArray3DWithLayout<int>(arr3d, layout_r3_dim0major_);
- EXPECT_EQ(lit_dim0major->s32s_size(), 12);
- EXPECT_THAT(lit_dim0major->s32s(),
+ EXPECT_EQ(lit_dim0major->element_count(), 12);
+ EXPECT_THAT(lit_dim0major->data<int32>(),
testing::ElementsAreArray(expected_dim0major));
// Test expected memory layout when using Relayout to column major.
auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_);
- EXPECT_THAT(relaid_lit_to_dim0minor->s32s(),
+ EXPECT_THAT(relaid_lit_to_dim0minor->data<int32>(),
testing::ElementsAreArray(expected_dim0minor));
}
}
TEST_F(LiteralUtilTest, PopulateR1S64) {
- Literal output;
+ Literal output(ShapeUtil::MakeShape(S64, {1}));
output.PopulateR1<int64>({77});
auto expected = Literal::CreateR1<int64>({77});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateR1U64) {
- Literal output;
+ Literal output(ShapeUtil::MakeShape(U64, {2}));
output.PopulateR1<uint64>({{77, 88}});
auto expected = Literal::CreateR1<uint64>({{77, 88}});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateR1C64) {
- Literal output;
+ Literal output(ShapeUtil::MakeShape(C64, {1}));
output.PopulateR1<complex64>({{77, 88}});
auto expected = Literal::CreateR1<complex64>({{77, 88}});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateR2C64) {
- Literal output;
+ Literal output(ShapeUtil::MakeShape(C64, {2, 2}));
output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
auto expected =
Literal::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
}
TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
- Literal output;
+ Literal output(ShapeUtil::MakeShape(BF16, {}));
bfloat16 h(0.25f);
- output.PopulateWithValue<bfloat16>(h, {});
+ output.PopulateWithValue<bfloat16>(h);
auto expected = Literal::CreateR0<bfloat16>(h);
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
- Literal output;
+ Literal output(ShapeUtil::MakeShape(BF16, {3}));
bfloat16 h(0.5f);
- output.PopulateWithValue<bfloat16>(h, {3});
+ output.PopulateWithValue<bfloat16>(h);
auto expected = Literal::CreateR1<bfloat16>({h, h, h});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
- Literal output;
+ Literal output(ShapeUtil::MakeShape(BF16, {2, 2}));
bfloat16 h(2.0f);
- output.PopulateWithValue<bfloat16>(h, {2, 2});
+ output.PopulateWithValue<bfloat16>(h);
auto expected = Literal::CreateR2<bfloat16>({{h, h}, {h, h}});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
- Literal output;
- output.PopulateWithValue<float>(2.5f, {});
+ Literal output(ShapeUtil::MakeShape(F32, {}));
+ output.PopulateWithValue<float>(2.5f);
auto expected = Literal::CreateR0<float>(2.5f);
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
- Literal output;
- output.PopulateWithValue<int64>(-7, {3});
+ Literal output(ShapeUtil::MakeShape(S64, {3}));
+ output.PopulateWithValue<int64>(-7);
auto expected = Literal::CreateR1<int64>({-7, -7, -7});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
- Literal output;
- output.PopulateWithValue<uint64>(42, {2, 2});
+ Literal output(ShapeUtil::MakeShape(U64, {2, 2}));
+ output.PopulateWithValue<uint64>(42);
auto expected = Literal::CreateR2<uint64>({{42, 42}, {42, 42}});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
- Literal output;
- output.PopulateWithValue<complex64>({4, 2}, {2, 2});
+ Literal output(ShapeUtil::MakeShape(C64, {2, 2}));
+ output.PopulateWithValue<complex64>({4, 2});
auto expected =
Literal::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
- Literal output;
+ Literal output(ShapeUtil::MakeShape(F16, {}));
half h(0.25f);
- output.PopulateWithValue<half>(h, {});
+ output.PopulateWithValue<half>(h);
auto expected = Literal::CreateR0<half>(h);
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
- Literal output;
+ Literal output(ShapeUtil::MakeShape(F16, {3}));
half h(0.5f);
- output.PopulateWithValue<half>(h, {3});
+ output.PopulateWithValue<half>(h);
auto expected = Literal::CreateR1<half>({h, h, h});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
- Literal output;
+ Literal output(ShapeUtil::MakeShape(F16, {2, 2}));
half h(2.0f);
- output.PopulateWithValue<half>(h, {2, 2});
+ output.PopulateWithValue<half>(h);
auto expected = Literal::CreateR2<half>({{h, h}, {h, h}});
EXPECT_EQ(output, *expected);
}
EXPECT_EQ(*output, *expected);
}
-TEST_F(LiteralUtilTest, Copy) {
+TEST_F(LiteralUtilTest, CopySliceFrom) {
const int64 dimensions[] = {17, 15, 34, 21};
const int64 layouts[][4] = {
{3, 2, 1, 0}, {0, 2, 1, 3}, {0, 1, 2, 3}, {2, 0, 3, 1}, {1, 3, 0, 2}};
const int64 src_base[] = {3, 1, 5, 7};
const int64 dest_base[] = {6, 4, 12, 2};
const int64 copy_size[] = {7, 8, 11, 9};
- TF_EXPECT_OK(blank->Copy(*source, src_base, dest_base, copy_size));
+ TF_EXPECT_OK(blank->CopySliceFrom(*source, src_base, dest_base, copy_size));
std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
}
}
-TEST_F(LiteralUtilTest, CopyScalars) {
+TEST_F(LiteralUtilTest, CopyFromScalars) {
auto zero = Literal::CreateR0<uint32>(0);
auto nine = Literal::CreateR0<uint32>(9);
- TF_EXPECT_OK(zero->Copy(*nine, {}, {}, {}));
+ TF_EXPECT_OK(zero->CopyFrom(*nine));
EXPECT_EQ(*zero, *nine);
auto vect = Literal::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
- TF_EXPECT_OK(zero->Copy(*vect, {5}, {}, {}));
+ TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {}));
EXPECT_EQ(zero->Get<uint32>({}), 17);
- TF_EXPECT_OK(vect->Copy(*zero, {}, {4}, {}));
+ TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {}));
EXPECT_EQ(vect->Get<uint32>({4}), 17);
}
const auto empty = Literal::CreateFromShape(empty_r1_shape);
auto nine = Literal::CreateR1<float>({9});
- TF_EXPECT_OK(nine->Copy(*empty, {0}, {0}, {0}));
+ TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0}));
EXPECT_EQ(*nine, *const_nine);
}
const auto empty = Literal::CreateFromShape(empty_r1_shape);
auto nine = Literal::CreateR1<float>({9});
- TF_EXPECT_OK(empty->Copy(*nine, {0}, {0}, {0}));
+ TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0}));
EXPECT_EQ(*empty, *const_empty);
}
}
+TEST_F(LiteralUtilTest, CopyFromNilShape) {
+ Literal nil_literal0(ShapeUtil::MakeNil());
+ Literal nil_literal1(ShapeUtil::MakeNil());
+ // This doesn't actually do any copying, but it should succeed.
+ TF_ASSERT_OK(nil_literal0.CopyFrom(nil_literal1));
+}
+
+TEST_F(LiteralUtilTest, CopyFromArrays) {
+ auto scalar_42 = Literal::CreateR0<float>(42.0);
+ auto scalar_123 = Literal::CreateR0<float>(123.0);
+ EXPECT_NE(*scalar_42, *scalar_123);
+ TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{},
+ /*src_shape_index=*/{}));
+ EXPECT_EQ(*scalar_42, *scalar_123);
+ EXPECT_EQ(scalar_42->Get<float>({}), 123.0f);
+
+ auto matrix_1234 = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto matrix_5678 = Literal::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
+ EXPECT_NE(*matrix_1234, *matrix_5678);
+ EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 1.0f);
+ TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{},
+ /*src_shape_index=*/{}));
+ EXPECT_EQ(*matrix_1234, *matrix_5678);
+ EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 5.0f);
+}
+
+TEST_F(LiteralUtilTest, CopyFromTuples) {
+ auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal nil_literal(ShapeUtil::MakeNil());
+ auto nested_tuple = Literal::MakeTuple(
+ {matrix.get(),
+ Literal::MakeTuple({Literal::CreateR0<int32>(42).get(),
+ Literal::CreateR1<double>({23.0, 44.0}).get(),
+ &nil_literal})
+ .get()});
+ // Create a tuple with the same shape as the inner tuple of nested_tuple
+ // but with different values.
+ auto tuple = Literal::MakeTuple({Literal::CreateR0<int32>(-5).get(),
+ Literal::CreateR1<double>({2.0, 4.0}).get(),
+ &nil_literal});
+
+ EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0}));
+ EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), 42);
+ EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 23.0);
+ EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 44.0);
+
+ // Overwrite the inner tuple element of nested_tuple with the contents of
+ // 'tuple'.
+ TF_ASSERT_OK(nested_tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
+ /*src_shape_index=*/{}));
+
+ // The matrix element should be unchanged.
+ EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0}));
+
+ // The tuple element should have been copied from 'tuple'.
+ EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), -5);
+ EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 2.0);
+ EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 4.0);
+}
+TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
+ auto tuple = Literal::MakeTuple(
+ {Literal::CreateR0<int32>(-2).get(), Literal::CreateR0<int32>(4).get()});
+
+ EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
+ EXPECT_EQ(tuple->Get<int32>({}, {1}), 4);
+
+ // Copy from one element to the other.
+ TF_ASSERT_OK(tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
+ /*src_shape_index=*/{0}));
+
+ EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
+ EXPECT_EQ(tuple->Get<int32>({}, {1}), -2);
+}
+
+TEST_F(LiteralUtilTest, CopyFromDifferentShapes) {
+ auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto vector = Literal::CreateR1<float>({5.0, 7.0});
+ Status status = matrix->CopyFrom(*vector);
+ ASSERT_FALSE(status.ok());
+ ASSERT_THAT(status.error_message(),
+ HasSubstr("Destination subshape incompatible"));
+}
+
TEST_F(LiteralUtilTest, F16) {
// Verify that the internal data views are consistent and that they
// are in little-endian format.
// TODO - modify if we make the data format machine-endianness dependent.
auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
Literal* l1 = m1.get();
- const char* d1 = static_cast<const char*>(l1->InternalData());
+ const char* d1 = reinterpret_cast<const char*>(l1->data<half>().data());
EXPECT_EQ(d1[0], 0);
EXPECT_EQ(d1[1], 0);
EXPECT_EQ(d1[2], 0);
EXPECT_EQ(d1[5], 0);
EXPECT_EQ(d1[6], 0);
EXPECT_EQ(d1[7], 0);
- EXPECT_EQ(l1->InternalData(), l1->MutableInternalData());
half h1(1.0f);
half h2(2.0f);
auto m2 = Literal::CreateR2<half>({{h1, h2}, {h2, h1}});
Literal* l2 = m2.get();
- const char* d2 = static_cast<const char*>(l2->InternalData());
+ const char* d2 = reinterpret_cast<const char*>(l2->data<half>().data());
EXPECT_EQ(d2[0], 0);
EXPECT_EQ(d2[1], 0x3C);
EXPECT_EQ(d2[2], 0);
EXPECT_EQ(d2[5], 0x40);
EXPECT_EQ(d2[6], 0);
EXPECT_EQ(d2[7], 0x3C);
- EXPECT_EQ(l2->InternalData(), l2->MutableInternalData());
}
TEST_F(LiteralUtilTest, Populate) {
auto generator = [&](tensorflow::gtl::ArraySlice<int64> indexes) -> uint32 {
// Offset from the linear index so that even R0 literals are not
// initialized to zero.
- return literal->LinearIndex(indexes) + 17;
+ return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
+ indexes) +
+ 17;
};
TF_EXPECT_OK(literal->Populate<uint32>(generator));
for (int len = 0; len < 25; ++len) {
p.mutable_shape()->clear_dimensions();
p.mutable_shape()->add_dimensions(len);
+ LayoutUtil::SetToDefaultLayout(p.mutable_shape());
p.clear_preds();
for (int i = 0; i < len; ++i) {
p.add_preds((i % 2) == (len % 2));
}
- Literal literal(p);
- ASSERT_EQ(len, literal.preds_size());
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
+ Literal::CreateFromProto(p));
+ ASSERT_EQ(len, literal->data<bool>().size());
int i = 0;
- for (auto it = literal.preds().begin(); it < literal.preds().end(); ++it) {
- EXPECT_EQ((i % 2) == (len % 2), *it);
+ for (bool value : literal->data<bool>()) {
+ EXPECT_EQ((i % 2) == (len % 2), value);
++i;
}
}
auto m = Literal::CreateR2<half>({{h1, h2}, {h2, h1}});
Literal* l = m.get();
EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape()));
- EXPECT_EQ(4, l->f16s().size());
- EXPECT_EQ(4, l->f16s_size());
+ EXPECT_EQ(4, l->data<half>().size());
LiteralProto p = l->ToProto();
EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape()));
p.mutable_shape()->set_element_type(F16);
p.mutable_shape()->clear_dimensions();
p.mutable_shape()->add_dimensions(4);
+ LayoutUtil::SetToDefaultLayout(p.mutable_shape());
p.clear_f16s();
p.set_f16s(half_vals, 8);
-
- Literal literal(p);
- ASSERT_EQ(4, literal.f16s_size());
- ASSERT_EQ(h1, literal.f16s(0));
- ASSERT_EQ(h2, literal.f16s(1));
- ASSERT_EQ(h2, literal.f16s(2));
- ASSERT_EQ(h1, literal.f16s(3));
-
- const std::vector<half>& r = literal.f16s();
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
+ Literal::CreateFromProto(p));
+ auto r = literal->data<half>();
ASSERT_EQ(4, r.size());
ASSERT_EQ(h1, r[0]);
ASSERT_EQ(h2, r[1]);
ASSERT_EQ(h1, r[3]);
}
-TEST_F(LiteralUtilTest, Subliterals) {
+TEST_F(LiteralUtilTest, LiteralViewTest) {
+ auto scalar = Literal::CreateR0<float>(1.0);
+ auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
+ auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
+ Literal nil(ShapeUtil::MakeNil());
+
+ EXPECT_EQ(LiteralView::Create(*scalar, {}), *scalar);
+ EXPECT_EQ(LiteralView::Create(*matrix, {}), *matrix);
+ EXPECT_EQ(LiteralView::Create(*tuple, {}), *tuple);
+ EXPECT_EQ(LiteralView::Create(*nested_tuple, {}), *nested_tuple);
+ EXPECT_EQ(LiteralView::Create(nil, {}), nil);
+
+ EXPECT_EQ(LiteralView::Create(*tuple, {0}), *scalar);
+ EXPECT_EQ(LiteralView::Create(*tuple, {1}), *matrix);
+
+ EXPECT_EQ(LiteralView::Create(*nested_tuple, {0}), *tuple);
+ EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 0}), *scalar);
+ EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 1}), *matrix);
+ EXPECT_EQ(LiteralView::Create(*nested_tuple, {1}), *scalar);
+}
+
+TEST_F(LiteralUtilTest, MutatingLiteralView) {
auto scalar = Literal::CreateR0<float>(1.0);
auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
+ // Verify that changing the underlying data beneath the view changes the
+ // data of the view itself.
+ const auto nested_tuple_view = LiteralView::Create(*nested_tuple);
+ EXPECT_EQ(
+ nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
+ 1.0f);
+ EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
+ /*shape_index=*/{0, 0}),
+ 1.0f);
+ nested_tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
+ EXPECT_EQ(
+ nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
+ 555.0f);
+ EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
+ /*shape_index=*/{0, 0}),
+ 555.0f);
+}
+
+TEST_F(LiteralUtilTest, LiteralViewOfALiteralView) {
+ auto scalar = Literal::CreateR0<float>(1.0);
+ auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
+ auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
+
+ const auto nested_tuple_view = LiteralView::Create(*nested_tuple);
+ const auto tuple_view =
+ LiteralView::Create(nested_tuple_view, /*view_root=*/{0});
+ const auto matrix_view = LiteralView::Create(tuple_view, /*view_root=*/{1});
+ EXPECT_EQ(matrix_view, *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
+}
+
+TEST_F(LiteralUtilTest, LiteralMove) {
+ std::unique_ptr<Literal> matrix =
+ Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ Literal literal(std::move(*matrix));
+
+ EXPECT_TRUE(
+ ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
+ EXPECT_EQ(literal.Get<float>({0, 0}), 1.0);
+ EXPECT_EQ(literal.Get<float>({0, 1}), 2.0);
+ EXPECT_EQ(literal.Get<float>({1, 0}), 3.0);
+ EXPECT_EQ(literal.Get<float>({1, 1}), 4.0);
+}
- EXPECT_EQ(&scalar->GetSubliteral(/*index=*/{}), scalar.get());
- EXPECT_EQ(&matrix->GetSubliteral(/*index=*/{}), matrix.get());
- EXPECT_EQ(&tuple->GetSubliteral(/*index=*/{}), tuple.get());
- EXPECT_EQ(&nested_tuple->GetSubliteral(/*index=*/{}), nested_tuple.get());
+TEST_F(LiteralUtilTest, DecomposeTuple) {
+ Literal nil_literal(ShapeUtil::MakeNil());
+ auto nested_tuple = Literal::MakeTuple(
+ {Literal::CreateR2<int32>({{1, 2}, {3, 4}}).get(),
+ Literal::MakeTuple({Literal::CreateR0<int32>(42).get(),
+ Literal::CreateR1<double>({23.0, 44.0}).get(),
+ &nil_literal})
+ .get(),
+ &nil_literal});
+
+ EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple->shape()));
+ std::vector<Literal> elements = nested_tuple->DecomposeTuple();
+ EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple->shape()));
+
+ ASSERT_EQ(elements.size(), 3);
+
+ EXPECT_TRUE(ShapeUtil::Compatible(elements[0].shape(),
+ ShapeUtil::MakeShape(S32, {2, 2})));
+ EXPECT_EQ(elements[0].Get<int32>({0, 0}), 1);
+ EXPECT_EQ(elements[0].Get<int32>({0, 1}), 2);
+ EXPECT_EQ(elements[0].Get<int32>({1, 0}), 3);
+ EXPECT_EQ(elements[0].Get<int32>({1, 1}), 4);
+
+ EXPECT_TRUE(ShapeUtil::Compatible(
+ elements[1].shape(),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {}),
+ ShapeUtil::MakeShape(F64, {2}),
+ ShapeUtil::MakeNil()})));
+ EXPECT_EQ(elements[1].Get<int32>({}, /*shape_index=*/{0}), 42);
+ EXPECT_EQ(elements[1].Get<double>({0}, /*shape_index=*/{1}), 23.0);
+ EXPECT_EQ(elements[1].Get<double>({1}, /*shape_index=*/{1}), 44.0);
+
+ EXPECT_TRUE(ShapeUtil::Compatible(elements[2].shape(), ShapeUtil::MakeNil()));
+}
+
+TEST_F(LiteralUtilTest, DecomposeEmptyTuple) {
+ Literal nil_literal(ShapeUtil::MakeNil());
+ std::vector<Literal> elements = nil_literal.DecomposeTuple();
+ EXPECT_EQ(elements.size(), 0);
+}
+
+TEST_F(LiteralUtilTest, MoveIntoTuple) {
+ std::vector<Literal> elements;
+ elements.push_back(std::move(*Literal::CreateR0<float>(1.0)));
+ elements.push_back(std::move(*Literal::CreateR1<int32>({4, 8})));
+ elements.push_back(std::move(
+ *Literal::MakeTuple({Literal::CreateR0<int32>(42).get(),
+ Literal::CreateR1<double>({23.0, 44.0}).get()})));
+
+ Literal literal = Literal::MoveIntoTuple(&elements);
+ ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape()));
+ ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3);
+
+ EXPECT_EQ(literal.Get<float>({}, /*shape_index=*/{0}), 1.0);
+ EXPECT_EQ(literal.Get<int32>({0}, /*shape_index=*/{1}), 4);
+ EXPECT_EQ(literal.Get<int32>({1}, /*shape_index=*/{1}), 8);
+ EXPECT_EQ(literal.Get<int32>({}, /*shape_index=*/{2, 0}), 42);
+ EXPECT_EQ(literal.Get<double>({0}, /*shape_index=*/{2, 1}), 23.0);
+ EXPECT_EQ(literal.Get<double>({1}, /*shape_index=*/{2, 1}), 44.0);
+
+ for (const Literal& element : elements) {
+ EXPECT_TRUE(ShapeUtil::IsNil(element.shape()));
+ }
+}
+
+TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) {
+ Literal literal = Literal::MoveIntoTuple({});
+ ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape()));
+ ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0);
+}
+
+TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
+ Literal literal;
+ EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape()));
+
+ std::unique_ptr<Literal> matrix =
+ Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ literal = std::move(*matrix);
+
+ EXPECT_TRUE(
+ ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
+ EXPECT_EQ(literal.Get<float>({0, 0}), 1.0);
+ EXPECT_EQ(literal.Get<float>({0, 1}), 2.0);
+ EXPECT_EQ(literal.Get<float>({1, 0}), 3.0);
+ EXPECT_EQ(literal.Get<float>({1, 1}), 4.0);
+}
+
+TEST_F(LiteralUtilTest, LiteralViewCopy) {
+ std::unique_ptr<Literal> matrix =
+ Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ const auto matrix_view = LiteralView::Create(*matrix);
+ LiteralView matrix_view_copy(matrix_view);
+
+ EXPECT_EQ(matrix_view_copy.Get<float>({0, 0}), 1.0);
+ EXPECT_EQ(matrix_view_copy.Get<float>({0, 1}), 2.0);
+ EXPECT_EQ(matrix_view_copy.Get<float>({1, 0}), 3.0);
+ EXPECT_EQ(matrix_view_copy.Get<float>({1, 1}), 4.0);
+}
+
+TEST_F(LiteralUtilTest, GetSetTuple) {
+ auto tuple = Literal::MakeTuple(
+ {Literal::CreateR0<float>(42.0).get(),
+ Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get()});
+ EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
+ tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
+ EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
+
+ EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
+ 3.0);
+ tuple->Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
+ EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
+ -4.0);
+}
+
+TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
+ // Literals constructed using CreateFromShape should be zero initialized.
+ std::unique_ptr<Literal> scalar_f32 =
+ Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
+ EXPECT_EQ(scalar_f32->Get<float>({}), 0.0);
+ EXPECT_TRUE(scalar_f32->IsAll(0));
+
+ std::unique_ptr<Literal> vector_s32 =
+ Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
+ EXPECT_EQ(vector_s32->Get<int32>({0}), 0);
+ EXPECT_EQ(vector_s32->Get<int32>({1}), 0);
+ EXPECT_EQ(vector_s32->Get<int32>({2}), 0);
+ EXPECT_TRUE(vector_s32->IsAll(0));
+
+ std::unique_ptr<Literal> tuple =
+ Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
+ ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})}));
+
+ EXPECT_EQ(tuple->Get<double>({}, {0}), 0.0);
+ EXPECT_EQ(tuple->Get<bool>({0}, {1}), false);
+ EXPECT_EQ(tuple->Get<bool>({1}, {1}), false);
+ EXPECT_EQ(tuple->Get<uint64>({0, 0}, {2}), 0);
+ EXPECT_EQ(tuple->Get<uint64>({1, 0}, {2}), 0);
+ EXPECT_EQ(tuple->Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
+}
+
+TEST_F(LiteralUtilTest, ProtoRoundTrip) {
+ // Test serializing then deserializing a Literal through a proto.
+ auto one_f32 = Literal::CreateR0<float>(1.0);
+ auto two_f32 = Literal::CreateR0<float>(2.0);
+ auto vector_int8 = Literal::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
+ auto vector_c64 = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
+ auto vector_bfloat16 = Literal::CreateR1<bfloat16>(
+ {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
+ auto vector_half =
+ Literal::CreateR1<half>({half{10.0}, half{20.0}, half{-30.0}});
+ auto matrix_pred =
+ Literal::CreateR2<bool>({{true, false, true}, {false, false, true}});
+ auto tuple = Literal::MakeTuple(
+ {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()});
+ Literal nil_literal(ShapeUtil::MakeNil());
+ auto nested_tuple = Literal::MakeTuple(
+ {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal});
+
+ auto to_from_proto = [](const Literal& literal) -> Literal {
+ return std::move(*Literal::CreateFromProto(literal.ToProto()).ValueOrDie());
+ };
+
+ EXPECT_EQ(*one_f32, to_from_proto(*one_f32));
+ EXPECT_EQ(*vector_c64, to_from_proto(*vector_c64));
+ EXPECT_EQ(*vector_bfloat16, to_from_proto(*vector_bfloat16));
+ EXPECT_EQ(*matrix_pred, to_from_proto(*matrix_pred));
+ EXPECT_EQ(*tuple, to_from_proto(*tuple));
+ EXPECT_EQ(*nested_tuple, to_from_proto(*nested_tuple));
+ EXPECT_EQ(nil_literal, to_from_proto(nil_literal));
+
+ EXPECT_NE(*one_f32, *two_f32);
+ EXPECT_NE(*one_f32, to_from_proto(*two_f32));
+}
+
+TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
+ // Proto contains a shape, but no values.
+ LiteralProto proto;
+ *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3});
+ Status status = Literal::CreateFromProto(proto).status();
+ ASSERT_FALSE(status.ok());
+ ASSERT_THAT(status.error_message(),
+ HasSubstr("Expected 3 elements in LiteralProto"));
+}
+
+TEST_F(LiteralUtilTest, InvalidProtoNoShape) {
+ // Proto contains values, but no shape.
+ LiteralProto proto;
+ proto.add_preds(false);
+ proto.add_preds(true);
+ proto.add_preds(false);
+ Status status = Literal::CreateFromProto(proto).status();
+ ASSERT_FALSE(status.ok());
+ ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape"));
+}
- EXPECT_EQ(tuple->GetSubliteral(/*index=*/{0}), *scalar);
- EXPECT_EQ(tuple->GetSubliteral(/*index=*/{1}), *matrix);
+TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) {
+ // Proto contains values in the wrong container.
+ LiteralProto proto;
+ *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3});
+ proto.add_preds(false);
+ proto.add_preds(true);
+ proto.add_preds(false);
+ Status status = Literal::CreateFromProto(proto).status();
+ ASSERT_FALSE(status.ok());
+ ASSERT_THAT(status.error_message(),
+ HasSubstr("Expected 3 elements in LiteralProto"));
+}
+
+TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) {
+ // Proto contains too few values.
+ LiteralProto proto;
+ *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2});
+ proto.add_f32s(1.0);
+ proto.add_f32s(2.0);
+ proto.add_f32s(3.0);
+ Status status = Literal::CreateFromProto(proto).status();
+ ASSERT_FALSE(status.ok());
+ ASSERT_THAT(status.error_message(),
+ HasSubstr("Expected 84 elements in LiteralProto"));
+}
+
+TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) {
+ // Proto contains too many values.
+ LiteralProto proto;
+ *proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2});
+ proto.add_s32s(42);
+ proto.add_s32s(-10);
+ proto.add_s32s(100);
+ Status status = Literal::CreateFromProto(proto).status();
+ ASSERT_FALSE(status.ok());
+ ASSERT_THAT(status.error_message(),
+ HasSubstr("Expected 2 elements in LiteralProto"));
+}
+
+TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) {
+ // Proto shape missing layout.
+ LiteralProto proto;
+ *proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2});
+ LayoutUtil::ClearLayout(proto.mutable_shape());
+ proto.add_preds(true);
+ proto.add_preds(false);
+ proto.add_preds(true);
+ proto.add_preds(false);
+ Status status = Literal::CreateFromProto(proto).status();
+ ASSERT_FALSE(status.ok());
+ ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout"));
+}
+
+TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) {
+ // Proto has too few tuple elements.
+ LiteralProto proto;
+ *proto.mutable_shape() = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})});
+ LiteralProto* element0 = proto.add_tuple_literals();
+ *element0->mutable_shape() =
+ ShapeUtil::GetTupleElementShape(proto.shape(), 0);
+ element0->add_preds(false);
+ element0->add_preds(true);
+
+ Status status = Literal::CreateFromProto(proto).status();
+ ASSERT_FALSE(status.ok());
+ ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
+}
- EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0}), *tuple);
- EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0, 0}), *scalar);
- EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0, 1}), *matrix);
- EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{1}), *scalar);
+TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
+ // Proto has too many tuple elements.
+ LiteralProto proto;
+ *proto.mutable_shape() = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})});
+ LiteralProto* element0 = proto.add_tuple_literals();
+ *element0->mutable_shape() =
+ ShapeUtil::GetTupleElementShape(proto.shape(), 0);
+ element0->add_preds(false);
+ element0->add_preds(true);
+ LiteralProto* element1 = proto.add_tuple_literals();
+ *element1->mutable_shape() =
+ ShapeUtil::GetTupleElementShape(proto.shape(), 1);
+ element1->add_f32s(42.0);
+ LiteralProto* element2 = proto.add_tuple_literals();
+ *element2->mutable_shape() = ShapeUtil::MakeShape(F32, {});
+ element2->add_f32s(123.0);
+
+ Status status = Literal::CreateFromProto(proto).status();
+ ASSERT_FALSE(status.ok());
+ ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
}
} // namespace
VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape)
<< " layout: "
<< (layout == nullptr ? "<none>" : layout->ShortDebugString());
- auto result = MakeUnique<Literal>();
- *result->mutable_shape() = shape;
+ Shape literal_shape = shape;
if (layout != nullptr) {
- TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(*layout, shape));
- *result->mutable_shape()->mutable_layout() = *layout;
+ TF_RETURN_IF_ERROR(
+ LayoutUtil::ValidateLayoutForShape(*layout, literal_shape));
+ *literal_shape.mutable_layout() = *layout;
}
if (shape.element_type() != F32) {
PrimitiveType_Name(shape.element_type()).c_str());
}
+ auto result = MakeUnique<Literal>(literal_shape);
+ result->PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
+
int64 elements = ShapeUtil::ElementsIn(shape);
- result->Resize(elements, std::numeric_limits<float>::quiet_NaN());
- std::vector<float>* field = result->mutable_f32s();
- char* data = tensorflow::bit_cast<char*>(field->data());
+ tensorflow::gtl::MutableArraySlice<float> field = result->data<float>();
+ char* data = tensorflow::bit_cast<char*>(field.data());
uint64 bytes = elements * sizeof(float);
tensorflow::StringPiece sp;
auto s = file_->Read(offset_, bytes, &sp, data);
const int size = PySequence_Size($input);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
- temps.push_back(*numpy::XlaLiteralFromPyObject(o));
+ temps.push_back(std::move(*numpy::XlaLiteralFromPyObject(o)));
Py_DECREF(o);
}
$1 = &temps;
PyObject* PyObjectFromXlaLiteral(const Literal& literal) {
if (ShapeUtil::IsTuple(literal.shape())) {
- const std::vector<Literal>& tuple_literals = literal.tuple_literals();
int num_elements = ShapeUtil::TupleElementCount(literal.shape());
PyObject* tuple = PyTuple_New(num_elements);
for (int i = 0; i < num_elements; i++) {
- PyTuple_SET_ITEM(tuple, i, PyObjectFromXlaLiteral(tuple_literals[i]));
+ PyTuple_SET_ITEM(
+ tuple, i, PyObjectFromXlaLiteral(LiteralView::Create(literal, {i})));
}
return tuple;
} else {
template <typename NativeT>
void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) {
NativeT* source = static_cast<NativeT*>(PyArray_DATA(py_array));
- auto dest = literal->GetMutableArraySlice<NativeT>();
+ auto dest = literal->data<NativeT>();
std::copy(source, source + PyArray_SIZE(py_array), dest.data());
}
template <typename NativeT>
void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) {
NativeT* dest = static_cast<NativeT*>(PyArray_DATA(py_array));
- auto source = literal.GetArraySlice<NativeT>();
+ auto source = literal.data<NativeT>();
std::copy(source.begin(), source.end(), dest);
}
if (ShapeUtil::IsTuple(literal.shape())) {
std::vector<HloInstruction*> elems;
elems.reserve(ShapeUtil::TupleElementCount(literal.shape()));
- for (const Literal& child : literal.tuple_literals()) {
- elems.push_back(BuildTupleConstant(computation, child));
+ for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) {
+ elems.push_back(
+ BuildTupleConstant(computation, LiteralView::Create(literal, {i})));
}
return computation->AddInstruction(HloInstruction::CreateTuple(elems));
} else {
return computation->AddInstruction(
- HloInstruction::CreateConstant(MakeUnique<Literal>(literal)));
+ HloInstruction::CreateConstant(literal.CloneToUnique()));
}
}
if (instruction->opcode() == HloOpcode::kConstant) {
// Copy the constant out of the ProtocolBuffer so that we can give it a
// higher alignment.
- const void* data = instruction->literal().InternalData();
+ const void* data = instruction->literal().untyped_data();
int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape());
auto iter = aligned_constants.emplace(
instruction, xla::MakeUnique<unsigned char[]>(size));
if (!ShapeUtil::IsTuple(shape)) {
int64 size = GetByteSizeRequirement(shape);
- return TransferBufferToInfeed(executor, size, literal.InternalData());
+ return TransferBufferToInfeed(executor, size, literal.untyped_data());
}
if (ShapeUtil::IsNestedTuple(shape)) {
// enqueue the resulting destination device addresses with the
// infeed manager.
std::vector<cpu::runtime::XfeedBuffer*> buffers;
- buffers.reserve(literal.tuple_literals_size());
+ buffers.reserve(ShapeUtil::TupleElementCount(shape));
auto cleanup = tensorflow::gtl::MakeCleanup([&buffers]() {
for (cpu::runtime::XfeedBuffer* b : buffers) {
b->Done(Cancelled("Failed to infeed buffer to device."));
}
});
- for (const auto& tuple_element : literal.tuple_literals()) {
- const Shape& tuple_element_shape = tuple_element.shape();
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ const Shape& tuple_element_shape = ShapeUtil::GetSubshape(shape, {i});
int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape);
TF_ASSIGN_OR_RETURN(
cpu::runtime::XfeedBuffer * buffer,
TransferBufferToInfeedInternal(executor, tuple_element_size,
- tuple_element.InternalData()));
+ literal.untyped_data({i})));
buffers.push_back(buffer);
}
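untyped_data({i}), used above, returns a raw pointer directly into the buffer of tuple element i, which is what lets the infeed path stream each element without materializing per-element Literal objects. Paired with the element's byte size (a sketch, assuming a tuple-shaped `literal` in scope):

    const void* element_bytes = literal.untyped_data(/*shape_index=*/{0});
    xla::int64 element_size = xla::ShapeUtil::ByteSizeOf(
        xla::ShapeUtil::GetSubshape(literal.shape(), {0}));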
literal_shape.element_type(), dimensions));
TF_ASSIGN_OR_RETURN(Shape received_shape,
TransferArrayBufferFromOutfeed(
- executor, literal->MutableInternalData(), size));
+ executor, literal->untyped_data(), size));
TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape()))
<< "Shape received from outfeed "
<< ShapeUtil::HumanString(received_shape)
<< " did not match the shape that was requested for outfeed: "
<< ShapeUtil::HumanString(literal_shape);
TF_RET_CHECK(size == GetByteSizeRequirement(received_shape));
- *literal->mutable_shape() = received_shape;
+ *literal->mutable_shape_do_not_use() = received_shape;
return Status::OK();
}
auto empty = Literal::CreateFromDimensions(
tuple_element_shape.element_type(), dimensions);
int64 size = GetByteSizeRequirement(tuple_element_shape);
- buffer_data.push_back({empty->MutableInternalData(), size});
+ buffer_data.push_back({empty->untyped_data(), size});
elements.push_back(std::move(empty));
}
GetByteSizeRequirement(received_shape));
for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) {
- *elements[i]->mutable_shape() = received_shape.tuple_shapes(i);
+ *elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i);
}
*literal = std::move(*Literal::MakeTupleOwned(std::move(elements)));
TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape));
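MakeTupleOwned, used above to assemble the outfeed result, takes ownership of its element literals rather than copying them. Sketch:

    std::vector<std::unique_ptr<xla::Literal>> elems;
    elems.push_back(xla::Literal::CreateR0<float>(1.0f));
    elems.push_back(xla::Literal::CreateR1<float>({2.0f, 3.0f}));
    std::unique_ptr<xla::Literal> tuple =
        xla::Literal::MakeTupleOwned(std::move(elems));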
CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size
<< " bytes with alignment of " << alignment;
- std::memcpy(raw_pointer, literal.InternalData(), literal_size);
+ std::memcpy(raw_pointer, literal.untyped_data(), literal_size);
entries_.emplace(std::move(name), static_cast<uint8*>(raw_pointer));
}
/*source=*/device_buffer.buffer(index),
/*size=*/GetByteSizeRequirement(subshape),
/*destination=*/
- literal->GetSubliteral(index).MutableInternalData()));
+ literal->untyped_data(index)));
}
return Status::OK();
TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
device_memory.size());
// Element is array-shaped: transfer array data to device buffer.
- const Literal& subliteral = literal.GetSubliteral(index);
+ const auto subliteral = LiteralView::Create(literal, index);
std::unique_ptr<Literal> relayed_out_literal;
const void* source;
if (LayoutUtil::Equal(device_subshape.layout(),
subliteral.shape().layout())) {
- source = subliteral.InternalData();
+ source = subliteral.untyped_data();
} else {
// Relayout data before transferring.
relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
/*shape_index=*/{});
- source = relayed_out_literal->InternalData();
+ source = relayed_out_literal->untyped_data();
}
return TransferBufferToDevice(
executor,
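When the device layout differs from the literal's, the code above relays out first. Relayout returns a new owning literal whose bytes follow the requested layout, so its untyped_data() can be handed straight to the transfer. A sketch for a rank-2 element (the layout value is illustrative):

    std::unique_ptr<xla::Literal> relaid = subliteral.Relayout(
        xla::LayoutUtil::MakeLayout({0, 1}), /*shape_index=*/{});
    const void* source = relaid->untyped_data();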
if (!ShapeUtil::IsTuple(shape)) {
int64 size = GetByteSizeRequirement(shape);
- return TransferBufferToInfeed(executor, size, literal.InternalData());
+ return TransferBufferToInfeed(executor, size, literal.untyped_data());
}
if (ShapeUtil::IsNestedTuple(shape)) {
// enqueue the resulting destination device addresses with the
// infeed manager.
std::vector<gpu::InfeedBuffer*> buffers;
- buffers.reserve(literal.tuple_literals_size());
+ buffers.reserve(ShapeUtil::TupleElementCount(shape));
auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() {
for (gpu::InfeedBuffer* b : buffers) {
b->Done();
}
});
- for (const auto& tuple_element : literal.tuple_literals()) {
- const Shape& tuple_element_shape = tuple_element.shape();
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ const Shape& tuple_element_shape =
+ ShapeUtil::GetTupleElementShape(shape, i);
int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape);
TF_ASSIGN_OR_RETURN(
gpu::InfeedBuffer * buffer,
TransferBufferToInfeedInternal(executor, tuple_element_size,
- tuple_element.InternalData()));
+ literal.untyped_data({i})));
buffers.push_back(buffer);
}
const HloInstruction* operand = inst->operand(0);
CHECK_EQ(HloOpcode::kConstant, operand->opcode());
return MakeUnique<HostToDeviceCopyThunk>(
- /*source_address=*/operand->literal().InternalData(),
+ /*source_address=*/operand->literal().untyped_data(),
/*destination_buffer=*/GetAllocationSlice(*inst),
/*mem_size=*/
llvm_ir::ByteSizeOf(operand->shape(),
parent_->evaluated_[broadcast] =
Literal::CreateFromShape(broadcast->shape());
auto output = parent_->evaluated_[broadcast].get();
- auto operand_to_broadcast =
+ const Literal& operand_to_broadcast =
parent_->GetEvaluatedLiteralFor(broadcast->operand(0));
std::vector<int64> broadcast_indices(
ShapeUtil::Rank(broadcast->operand(0)->shape()), 0);
<< " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
- auto operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+ const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
auto result = Literal::CreateFromShape(result_shape);
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
return scalar;
}));
- auto evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0));
+ const Literal& evaluated_operand =
+ parent_->GetEvaluatedLiteralFor(pad->operand(0));
std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
0);
<< ShapeUtil::HumanString(inferred_return_shape);
const int64 rank = ShapeUtil::Rank(operand->shape());
- auto operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+ const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
auto func = [&](tensorflow::gtl::ArraySlice<int64> out_index) {
DimensionVector operand_index(rank);
for (int64 i = 0; i < rank; ++i) {
StatusOr<std::unique_ptr<Literal>> DynamicSlice(
const Literal& operand_literal, const Literal& start_indices_literal,
const Shape& result_shape) {
- const auto& start_indices_typed =
- start_indices_literal.GetArraySlice<IndexT>();
+ auto start_indices_typed = start_indices_literal.data<IndexT>();
std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
const Literal& operand_literal, const Literal& update_literal,
const Literal& start_indices_literal) {
- const auto& start_indices_typed =
- start_indices_literal.GetArraySlice<IndexT>();
+ auto start_indices_typed = start_indices_literal.data<IndexT>();
const std::vector<int64> start(start_indices_typed.begin(),
start_indices_typed.end());
- auto result = MakeUnique<Literal>(operand_literal);
+ auto result = operand_literal.CloneToUnique();
std::vector<int64> result_index(ShapeUtil::Rank(result->shape()), 0);
auto func = [&](const std::vector<int64>& update_index) {
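The evaluator changes all follow the same two idioms: data<NativeT>() replaces GetArraySlice/GetMutableArraySlice for typed access, and CloneToUnique() replaces copy construction. For example, reading S32 slice-start indices and cloning the operand before updating it in place (a sketch, not the patch itself):

    auto starts = start_indices_literal.data<xla::int32>();  // typed view, no copy
    std::vector<xla::int64> start(starts.begin(), starts.end());
    std::unique_ptr<xla::Literal> result = operand_literal.CloneToUnique();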
TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this));
- return MakeUnique<Literal>(
- GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()));
+ return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction())
+ .CloneToUnique();
}
template <typename LiteralPtr>
}
TF_RETURN_IF_ERROR(computation.Accept(this));
- return MakeUnique<Literal>(
- GetEvaluatedLiteralFor(computation.root_instruction()));
+ return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique();
}
template <typename LiteralPtr>
<< input_literal->ToString();
TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape()));
- evaluated_[operand] = MakeUnique<Literal>(*input_literal);
+ evaluated_[operand] = input_literal->CloneToUnique();
}
}
TF_RETURN_IF_ERROR(Preprocess(instruction));
TF_RETURN_IF_ERROR(instruction->Visit(this));
TF_RETURN_IF_ERROR(Postprocess(instruction));
- return MakeUnique<Literal>(GetEvaluatedLiteralFor(instruction));
+ return GetEvaluatedLiteralFor(instruction).CloneToUnique();
}
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
TF_RETURN_IF_ERROR(Preprocess(instruction));
TF_RETURN_IF_ERROR(instruction->Visit(this));
TF_RETURN_IF_ERROR(Postprocess(instruction));
- return MakeUnique<Literal>(GetEvaluatedLiteralFor(instruction));
+ return GetEvaluatedLiteralFor(instruction).CloneToUnique();
}
std::unique_ptr<Literal> HloEvaluator::TryEvaluate(
<< ", but input literal shape is: "
<< ShapeUtil::HumanString(input_literal->shape());
- evaluated_[parameter] = MakeUnique<Literal>(*input_literal);
+ evaluated_[parameter] = input_literal->CloneToUnique();
return Status::OK();
}
for (auto operand : operands) {
const Shape& operand_shape = operand->shape();
- TF_RETURN_IF_ERROR(result_literal->Copy(
+ TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
AsInt64Slice(operand_shape.dimensions())));
dest_indices[concat_dim] +=
const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
- evaluated_[get_tuple_element] =
- MakeUnique<Literal>(operand_tuple_literal.tuple_literals(index));
-
- return Status::OK();
+ evaluated_[get_tuple_element] = MakeUnique<Literal>(
+ ShapeUtil::GetTupleElementShape(operand->shape(), index));
+ return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal,
+ /*dest_shape_index=*/{},
+ /*src_shape_index=*/{index});
}
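HandleGetTupleElement now allocates a literal of the element's shape and copies the element's buffers into it. CopyFrom takes destination and source ShapeIndex arguments, so one tuple element can be extracted into a standalone literal (sketch, assuming a tuple-shaped `tuple` in scope):

    xla::Literal element(
        xla::ShapeUtil::GetTupleElementShape(tuple.shape(), 0));
    TF_CHECK_OK(element.CopyFrom(tuple, /*dest_shape_index=*/{},
                                 /*src_shape_index=*/{0}));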
Status HloEvaluator::HandleCopy(HloInstruction* copy) {
TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
- auto result = MakeUnique<Literal>(GetEvaluatedLiteralFor(copy->operand(0)));
+ auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique();
evaluated_[copy] = std::move(result);
return Status::OK();
}
instruction->metadata_ = proto.metadata();
if (proto.has_literal()) {
- instruction->literal_ = MakeUnique<Literal>(proto.literal());
+ TF_ASSIGN_OR_RETURN(instruction->literal_,
+ Literal::CreateFromProto(proto.literal()));
}
instruction->parameter_number_ = proto.parameter_number();
auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
instruction->operands_.push_back(operand);
- instruction->literal_.reset(new Literal);
- instruction->literal_->append_u8s(tag);
+ instruction->literal_ = Literal::CreateR1U8(tag);
return instruction;
}
string HloInstruction::TracingTag() const {
CHECK_EQ(HloOpcode::kTrace, opcode());
CHECK(literal_ != nullptr);
- return literal_->u8s_string();
+ return literal_->GetR1U8AsString();
}
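The trace tag becomes an ordinary U8 rank-1 literal instead of raw bytes appended to a proto field. Round-trip sketch:

    std::unique_ptr<xla::Literal> tag = xla::Literal::CreateR1U8("trace-tag");
    CHECK_EQ(tag->GetR1U8AsString(), "trace-tag");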
bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
TransferToServerResponse* result) {
- Literal literal = Literal(arg->literal());
- const Shape& shape = literal.shape();
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ Literal::CreateFromProto(arg->literal()));
+ const Shape& shape = literal->shape();
std::vector<se::StreamExecutor*> replicas;
if (arg->has_device_handle()) {
if (executor->device_ordinal() == master_device_ordinal) {
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralToDevice(
- executor, literal, *shaped_buffer));
+ executor, *literal, *shaped_buffer));
} else {
// The replica is not the master. Create a cloned shaped buffer with
// the replica's device ordinal. This is required because
CloneShapedBufferOnDevice(*shaped_buffer, executor->device_ordinal());
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralToDevice(
- executor, literal, *clone));
+ executor, *literal, *clone));
}
}
TF_ASSIGN_OR_RETURN(
executor = replicas[arg->replica_id()];
}
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+ Literal::CreateFromProto(arg->literal()));
return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
- executor, Literal(arg->literal()));
+ executor, *literal);
}
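Deserialization is now fallible by design: Literal::CreateFromProto returns StatusOr<std::unique_ptr<Literal>>, so malformed wire data surfaces as a Status rather than a silently bogus literal. The two call patterns seen in this patch (sketch; `proto` is assumed in scope, and the first form must sit inside a Status-returning function):

    TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> from_wire,
                        xla::Literal::CreateFromProto(proto));
    // ...or, where failure would be a programming error:
    std::unique_ptr<xla::Literal> checked =
        xla::Literal::CreateFromProto(proto).ConsumeValueOrDie();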
tensorflow::Status Service::TransferFromOutfeed(
/*include_unreachable_instructions=*/
false));
- std::vector<Literal> parameters(arg->parameters_size());
+ std::vector<std::unique_ptr<Literal>> parameters(arg->parameters_size());
for (int64 i = 0; i < arg->parameters_size(); ++i) {
- parameters[i] = Literal(arg->parameters(i));
+ TF_ASSIGN_OR_RETURN(parameters[i],
+ Literal::CreateFromProto(arg->parameters(i)));
}
- std::vector<const Literal*> parameter_ptrs;
- std::transform(parameters.begin(), parameters.end(),
- std::back_inserter(parameter_ptrs),
- [](const Literal& literal) { return &literal; });
-
HloEvaluator evaluator;
- TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate<const Literal*>(
- *module, parameter_ptrs));
+ TF_ASSIGN_OR_RETURN(
+ auto result_literal,
+ evaluator.Evaluate<std::unique_ptr<Literal>>(*module, parameters));
// Since the shape_with_output_layout option in ExecutionOption is
// non-effective to the Evaluator results, explicit relayout here.
const ConstantRequest& constant_request =
request.request().constant_request();
hlo_instruction = add_instruction(HloInstruction::CreateConstant(
- Literal(constant_request.literal()).CloneToUnique()));
+ Literal::CreateFromProto(constant_request.literal())
+ .ConsumeValueOrDie()));
break;
}
VLOG(2) << "Couldn't evaluate while cond: " << result.status();
return nullopt;
}
- return result.ValueOrDie()->GetArraySlice<bool>() ==
+ return result.ValueOrDie()->data<bool>() ==
tensorflow::gtl::ArraySlice<bool>{true};
};
return out;
}
+std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) {
+ out << shape_index.ToString();
+ return out;
+}
+
namespace {
// Recursive helper for comparing the equality of two shapes. Returns true if
}
/* static */ int64 ShapeUtil::Rank(const Shape& shape) {
- CHECK(!ShapeUtil::IsTuple(shape)) << "Tuples do not have a rank";
+ CHECK(!ShapeUtil::IsTuple(shape))
+ << "Tuples do not have a rank, shape: " << shape;
return shape.dimensions_size();
}
ShapeIndexView index) {
const Shape* return_shape = &shape;
for (auto i : index) {
- CHECK(IsTuple(*return_shape));
+ CHECK(IsTuple(*return_shape))
+ << "Invalid index " << index << " for shape " << shape;
return_shape = &return_shape->tuple_shapes(i);
}
return *return_shape;
return shape;
}
+std::ostream& operator<<(std::ostream& out, const Shape& shape) {
+ out << ShapeUtil::HumanString(shape);
+ return out;
+}
+
} // namespace xla
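The new operator<< overloads exist mainly so Shape and ShapeIndexView values can be streamed straight into LOG and CHECK messages, as the strengthened CHECKs above now do. Sketch (`shape` and `index` assumed in scope):

    CHECK(xla::ShapeUtil::IsTuple(shape))
        << "expected a tuple, got " << shape << " at index " << index;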
};
std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
+std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index);
// Namespaced collection of (static) shape utilities.
//
TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
};
+std::ostream& operator<<(std::ostream& out, const Shape& shape);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
std::iota(r1.begin(), r1.end(), 1.0);
ComputationBuilder builder(client_, TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR4FromArray4D(r4);
- *a_literal->mutable_shape()->mutable_layout() =
- LayoutUtil::MakeLayout({0, 1, 2, 3});
+ std::unique_ptr<Literal> a_literal = Literal::CreateR4FromArray4DWithLayout(
+ r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
auto a = builder.ConstantLiteral(*a_literal);
auto b = builder.ConstantR1<float>(r1);
builder.Add(a, b, {1});
{5.0f, 4.4f}, // p2
});
input_array_.FillWithPZ(pz);
- input_literal_ = *Literal::CreateR4FromArray4D(input_array_);
+ input_literal_ = std::move(*Literal::CreateR4FromArray4D(input_array_));
CHECK_EQ(kSamples, input_array_.planes());
CHECK_EQ(kZ, input_array_.depth());
CHECK_EQ(kY, input_array_.height());
auto tuple = builder.BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
- auto expected = *Literal::MakeTuple(
+ auto expected = Literal::MakeTuple(
{Literal::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
{{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
.get(),
Literal::CreateR1<float>({4, 5}).get(),
Literal::CreateR1<float>({5, 5}).get()});
- ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}
XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) {
auto tuple = builder.BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
- auto expected = *Literal::MakeTuple(
+ auto expected = Literal::MakeTuple(
{Literal::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
{{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
.get(),
Literal::CreateR1<float>({4, 5}).get(),
Literal::CreateR1<float>({5, 5}).get()});
- ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}
XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
auto tuple = builder.BatchNormTraining(h0, h1, h2,
/*epsilon=*/1, kFeatureIndex);
- auto expected = *Literal::MakeTuple(
+ auto expected = Literal::MakeTuple(
{Literal::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
.get(),
Literal::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
Literal::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
- ComputeAndCompareTuple(&builder, expected,
+ ComputeAndCompareTuple(&builder, *expected,
{operand.get(), scale.get(), offset.get()},
ErrorSpec(0.1));
}
auto tuple = builder.BatchNormTraining(h0, h1, h2,
/*epsilon=*/-100, kFeatureIndex);
- auto expected = *Literal::MakeTuple(
+ auto expected = Literal::MakeTuple(
{Literal::CreateR3FromArray3D<float>({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
.get(),
Literal::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
Literal::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
- ComputeAndCompareTuple(&builder, expected,
+ ComputeAndCompareTuple(&builder, *expected,
{operand.get(), scale.get(), offset.get()},
ErrorSpec(0.1));
}
builder.BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
- auto expected = *Literal::MakeTuple(
+ auto expected = Literal::MakeTuple(
{Literal::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
{{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
.get(),
Literal::CreateR1<float>({0, 0}).get(),
Literal::CreateR1<float>({16, 20}).get()});
- ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
+ ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}
struct BatchNormTestParam {
auto offset_activations =
builder.Parameter(2, offset_literal->shape(), "scale");
- auto expected = *Literal::MakeTuple({expected_normalized.get(),
- Literal::CreateR1<float>(mean).get(),
- Literal::CreateR1<float>(var).get()});
+ auto expected = Literal::MakeTuple({expected_normalized.get(),
+ Literal::CreateR1<float>(mean).get(),
+ Literal::CreateR1<float>(var).get()});
std::unique_ptr<GlobalData> input_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
// testcase.
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
ComputeAndCompareTuple(
- &builder, expected,
+ &builder, *expected,
{input_data.get(), scale_data.get(), offset_data.get()},
ErrorSpec(0.01, 1));
}
grad_output_parameter, epsilon, feature_index);
auto expected =
- *Literal::MakeTuple({expected_grad_activation.get(),
- Literal::CreateR1<float>(grad_scale).get(),
- Literal::CreateR1<float>(grad_offset).get()});
+ Literal::MakeTuple({expected_grad_activation.get(),
+ Literal::CreateR1<float>(grad_scale).get(),
+ Literal::CreateR1<float>(grad_offset).get()});
// Run all HLO passes during this test. In particular, ClientLibraryTestBase
// disables constant folding, but we want it enabled for our zero-sized tensor
// testcase.
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
- ComputeAndCompareTuple(&builder, expected,
+ ComputeAndCompareTuple(&builder, *expected,
{input_data.get(), scale_data.get(), mean_data.get(),
var_data.get(), grad_output_data.get()},
ErrorSpec(0.01, 1));
auto tuple = builder.BatchNormTraining(operand, scale, offset,
/*epsilon=*/0.001, kFeatureIndex);
- auto expected = *Literal::MakeTuple(
+ auto expected = Literal::MakeTuple(
{Literal::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-1.7f)}, {static_cast<bfloat16>(-2.04f)}},
{{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.65f)}}},
{static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
.get()});
- ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
+ ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01));
}
XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
builder.BatchNormGrad(operand, scale, mean, var, grad_output,
/*epsilon=*/0.0, kFeatureIndex);
- auto expected = *Literal::MakeTuple(
+ auto expected = Literal::MakeTuple(
{Literal::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}},
{{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}},
{static_cast<bfloat16>(16), static_cast<bfloat16>(20)})
.get()});
- ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
+ ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01));
}
} // namespace
LiteralTestUtil::ExpectNear(
*Literal::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
- result->tuple_literals(0), error_spec_);
+ LiteralView::Create(*result, {0}), error_spec_);
LiteralTestUtil::ExpectNear(
*Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
- result->tuple_literals(1), error_spec_);
+ LiteralView::Create(*result, {1}), error_spec_);
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
VLOG(1) << "expected: " << expected_literal->ToString();
VLOG(1) << "actual: " << actual->ToString();
- EXPECT_EQ(expected, actual->u8s_string());
+ EXPECT_EQ(expected, actual->GetR1U8AsString());
}
void ClientLibraryTestBase::ComputeAndCompareTuple(
auto result,
client_->ExecuteAndTransfer(computation, {}, &execution_options));
LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
- result->tuple_literals(0));
+ LiteralView::Create(*result, {0}));
LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
- result->tuple_literals(1));
+ LiteralView::Create(*result, {1}));
EXPECT_TRUE(ShapeUtil::IsTuple(result->shape()));
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape()));
auto computation = b.Add(param, b.ConstantR0<float>(1.5f));
std::vector<Literal> arguments;
- arguments.emplace_back(*Literal::CreateR0(42.5f));
+ arguments.push_back(std::move(*Literal::CreateR0(42.5f)));
EXPECT_TRUE(IsConstant(computation, &b, arguments.size()));
auto value =
{5.0f, 4.4f}, // p2
});
input_array.FillWithPZ(pz);
- Literal input_literal = *Literal::CreateR4FromArray4D(input_array);
+ std::unique_ptr<Literal> input_literal =
+ Literal::CreateR4FromArray4D(input_array);
{
ComputationBuilder builder(client_, TestName());
- builder.ConstantLiteral(input_literal);
+ builder.ConstantLiteral(*input_literal);
ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
}
std::unique_ptr<Literal> result = ExecuteAndTransferOrDie(&builder, {});
- LiteralTestUtil::ExpectR2Near<float>({{1.0}, {2.0}},
- result->tuple_literals(0), error_spec_);
- LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, result->tuple_literals(1),
- error_spec_);
+ LiteralTestUtil::ExpectR2Near<float>(
+ {{1.0}, {2.0}}, LiteralView::Create(*result, {0}), error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>(
+ {2.0, 42.0}, LiteralView::Create(*result, {1}), error_spec_);
}
} // namespace
}));
ComputeAndCompare(&builder, conv,
- {*Literal::CreateFromArray(input_data),
- *Literal::CreateFromArray(filter_data)},
+ {std::move(*Literal::CreateFromArray(input_data)),
+ std::move(*Literal::CreateFromArray(filter_data))},
error_spec_);
}
}));
// clang-format on
ComputeAndCompare(&builder, conv,
- {*Literal::CreateFromArray(input_data),
- *Literal::CreateFromArray(filter_data)},
+ {std::move(*Literal::CreateFromArray(input_data)),
+ std::move(*Literal::CreateFromArray(filter_data))},
error_spec_);
}
}));
// clang-format on
ComputeAndCompare(&builder, conv,
- {*Literal::CreateFromArray(input_data),
- *Literal::CreateFromArray(filter_data)},
+ {std::move(*Literal::CreateFromArray(input_data)),
+ std::move(*Literal::CreateFromArray(filter_data))},
error_spec_);
}
}));
// clang-format on
ComputeAndCompare(&builder, conv,
- {*Literal::CreateFromArray(input_data),
- *Literal::CreateFromArray(filter_data)},
+ {std::move(*Literal::CreateFromArray(input_data)),
+ std::move(*Literal::CreateFromArray(filter_data))},
error_spec_);
}
Array2D<float> expected_result(29, 10);
expected_result.Fill(0);
- ComputeAndCompare(
- &builder, conv,
- {*Literal::CreateFromArray(param0), *Literal::CreateFromArray(param1)},
- error_spec_);
+ ComputeAndCompare(&builder, conv,
+ {std::move(*Literal::CreateFromArray(param0)),
+ std::move(*Literal::CreateFromArray(param1))},
+ error_spec_);
}
INSTANTIATE_TEST_CASE_P(ConvolveWithAndWithoutCanonicalization_Instantiation,
void TestCopyOp(const Literal& literal) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(MakeUnique<Literal>(literal)));
+ HloInstruction::CreateConstant(literal.CloneToUnique()));
builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kCopy, constant));
auto computation = builder.Build();
std::unique_ptr<Literal> literal =
Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
// Reverse the minor-to-major order of the literal.
- Layout* literal_layout = literal->mutable_shape()->mutable_layout();
+ Layout* literal_layout =
+ literal->mutable_shape_do_not_use()->mutable_layout();
ASSERT_EQ(2, literal_layout->minor_to_major_size());
literal_layout->mutable_minor_to_major()->SwapElements(0, 1);
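mutable_shape() is renamed mutable_shape_do_not_use() because a literal's buffers are now allocated against its shape: in-place shape edits are only safe when they preserve element type and byte size, as with this layout swap or the outfeed path's checked reassignment earlier in the patch. A hedged sketch of the guarded pattern (`new_shape` assumed in scope):

    // Only reassign the shape after verifying the buffers still fit.
    CHECK_EQ(xla::ShapeUtil::ByteSizeOf(new_shape),
             xla::ShapeUtil::ByteSizeOf(literal->shape()));
    *literal->mutable_shape_do_not_use() = new_shape;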
ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString());
}
+namespace {
+
+// Return a literal with all arrays of type FromNativeT converted to type
+// ToNativeT in the given literal.
+template <typename FromNativeT, typename ToNativeT>
+std::unique_ptr<Literal> ConvertType(const Literal& literal) {
+ // First construct shape of the result.
+ Shape result_shape(literal.shape());
+ ShapeUtil::ForEachMutableSubshape(
+ &result_shape, [](Shape* subshape, const ShapeIndex&) {
+ if (subshape->element_type() ==
+ primitive_util::NativeToPrimitiveType<FromNativeT>()) {
+ subshape->set_element_type(
+ primitive_util::NativeToPrimitiveType<ToNativeT>());
+ }
+ });
+ auto result = MakeUnique<Literal>(result_shape);
+
+ // Then copy over the data from 'literal' converting FromNativeT values to
+ // ToNativeT values as necessary.
+ ShapeUtil::ForEachSubshape(
+ literal.shape(),
+ [&](const Shape& subshape, const ShapeIndex& shape_index) {
+ if (ShapeUtil::IsArray(subshape)) {
+ if (subshape.element_type() ==
+ primitive_util::NativeToPrimitiveType<FromNativeT>()) {
+ auto src = literal.data<FromNativeT>(shape_index);
+ auto dest = result->data<ToNativeT>(shape_index);
+ for (int64 i = 0; i < src.size(); ++i) {
+ dest[i] = static_cast<ToNativeT>(src[i]);
+ }
+ } else {
+ TF_CHECK_OK(result->CopyFrom(literal,
+ /*dest_shape_index=*/shape_index,
+ /*src_shape_index=*/shape_index));
+ }
+ }
+ });
+ return result;
+}
+
+} // namespace
+
/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertBF16ToF32(
const Literal& literal) {
- if (ShapeUtil::IsTuple(literal.shape())) {
- std::vector<std::unique_ptr<Literal>> converted_elements;
- for (const auto& element : literal.tuple_literals()) {
- converted_elements.push_back(ConvertBF16ToF32(element));
- }
- return Literal::MakeTupleOwned(std::move(converted_elements));
- }
-
- if (literal.shape().element_type() != BF16) {
- return MakeUnique<Literal>(literal);
- }
- Shape converted_shape = literal.shape();
- converted_shape.set_element_type(F32);
- auto converted = Literal::CreateFromShape(converted_shape);
- if (!ShapeUtil::HasZeroElements(converted_shape)) {
- std::vector<int64> index(converted_shape.dimensions_size(), 0);
- do {
- converted->Set<float>(index,
- static_cast<float>(literal.Get<bfloat16>(index)));
- } while (IndexUtil::BumpIndices(converted_shape, &index));
- }
- return converted;
+ return ConvertType<bfloat16, float>(literal);
}
/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertF32ToBF16(
const Literal& literal) {
- if (ShapeUtil::IsTuple(literal.shape())) {
- std::vector<std::unique_ptr<Literal>> converted_elements;
- for (const auto& element : literal.tuple_literals()) {
- converted_elements.push_back(ConvertF32ToBF16(element));
- }
- return Literal::MakeTupleOwned(std::move(converted_elements));
- }
-
- if (literal.shape().element_type() != F32) {
- return MakeUnique<Literal>(literal);
- }
- Shape converted_shape = literal.shape();
- converted_shape.set_element_type(BF16);
- auto converted = Literal::CreateFromShape(converted_shape);
- if (!ShapeUtil::HasZeroElements(converted_shape)) {
- std::vector<int64> index(converted_shape.dimensions_size(), 0);
- do {
- converted->Set<bfloat16>(
- index, static_cast<bfloat16>(literal.Get<float>(index)));
- } while (IndexUtil::BumpIndices(converted_shape, &index));
- }
- return converted;
+ return ConvertType<float, bfloat16>(literal);
}
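Because ConvertType is generic over the element-type pair, the same helper would cover further conversions unchanged; a hypothetical instantiation (F64 arrays to F32, with tuples and other element types passed through):

    std::unique_ptr<Literal> as_f32 = ConvertType<double, float>(literal);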
namespace {
break;
case TUPLE: {
bool tuple_match = true;
- for (int i = 0; i < actual.tuple_literals_size(); ++i) {
- auto result =
- Equal(expected.tuple_literals(i), actual.tuple_literals(i));
+ for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
+ // Create LiteralViews of the expected and actual elements.
+ auto result = Equal(LiteralView::Create(expected, {i}),
+ LiteralView::Create(actual, {i}));
tuple_match = tuple_match ? !!result : false;
}
match = tuple_match;
AssertEqualShapes(expected.shape(), actual.shape());
::testing::AssertionResult err = ::testing::AssertionSuccess();
- for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) {
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
SCOPED_TRACE(tensorflow::strings::StrCat(
"Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape())));
- const auto& expected_element = expected.tuple_literals(i);
- const auto& actual_element = actual.tuple_literals(i);
+ const auto expected_element = LiteralView::Create(expected, {i});
+ const auto actual_element = LiteralView::Create(actual, {i});
::testing::AssertionResult res = [&] {
if (ShapeUtil::IsTuple(expected_element.shape())) {
abs_expected_miscompare_sum_ = 0.0;
max_rel_err_ = 0.0;
max_abs_err_ = 0.0;
- *miscompares_.mutable_shape() =
- ShapeUtil::ChangeElementType(actual.shape(), PRED);
- miscompares_.mutable_preds()->resize(
- ShapeUtil::ElementsIn(miscompares_.shape()), false);
+ miscompares_ = Literal(ShapeUtil::ChangeElementType(actual.shape(), PRED));
multi_index_.resize(expected.shape().dimensions_size(), 0);
switch (expected.shape().element_type()) {
AssertEqualShapes(expected.shape(), actual.shape());
::testing::AssertionResult err = ::testing::AssertionSuccess();
- for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) {
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
SCOPED_TRACE(tensorflow::strings::StrCat(
"Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape())));
- const auto& expected_element = expected.tuple_literals(i);
- const auto& actual_element = actual.tuple_literals(i);
+ const auto expected_element = LiteralView::Create(expected, {i});
+ const auto actual_element = LiteralView::Create(actual, {i});
::testing::AssertionResult res = [&] {
if (ShapeUtil::IsTuple(expected_element.shape())) {
}
CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
- auto new_literal = MakeUnique<Literal>();
- *new_literal->mutable_shape() =
- ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions);
+ auto new_literal = MakeUnique<Literal>(
+ ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
// Create a new shape with the given minor-to-major layout. This shape is used
// solely for converting linear address to multi-dimensional addresses when
Shape shape_with_layout = new_literal->shape();
*shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
- // Allocate space in the new literal.
- new_literal->Reserve(ShapeUtil::ElementsIn(literal.shape()));
-
// Copy data into new literal, element-by-element.
for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
std::vector<int64> from_multi_index =
LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
&literal_proto));
- Literal literal(literal_proto);
+ std::unique_ptr<Literal> literal =
+ Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
if (result.find("expected") != string::npos) {
- EXPECT_EQ("2", literal.ToString());
+ EXPECT_EQ("2", literal->ToString());
} else if (result.find("actual") != string::npos) {
- EXPECT_EQ("4", literal.ToString());
+ EXPECT_EQ("4", literal->ToString());
} else if (result.find("miscompares") != string::npos) {
- EXPECT_EQ("true", literal.ToString());
+ EXPECT_EQ("true", literal->ToString());
} else {
FAIL() << "unknown file in temporary directory: " << result;
}
EXPECT_EQ(3, ShapeUtil::TupleElementCount(result->on_host_shape()));
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result);
- LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- result_literal->tuple_literals(0));
- LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
- result_literal->tuple_literals(1));
- LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- result_literal->tuple_literals(2));
+ LiteralTestUtil::ExpectR2Equal<float>(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0}));
+ LiteralTestUtil::ExpectR2Equal<float>(
+ {{10.0f, 20.0f}, {30.0f, 40.0f}},
+ LiteralView::Create(*result_literal, {1}));
+ LiteralTestUtil::ExpectR2Equal<float>(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {2}));
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->on_host_shape()));
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result);
- LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- result_literal->tuple_literals(1));
- const Literal& inner_tuple_literal = result_literal->tuple_literals(0);
- LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- inner_tuple_literal.tuple_literals(0));
- LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
- inner_tuple_literal.tuple_literals(1));
- LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- inner_tuple_literal.tuple_literals(2));
+ LiteralTestUtil::ExpectR2Equal<float>(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1}));
+ LiteralTestUtil::ExpectR2Equal<float>(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}},
+ LiteralView::Create(*result_literal, {0, 0}));
+ LiteralTestUtil::ExpectR2Equal<float>(
+ {{10.0f, 20.0f}, {30.0f, 40.0f}},
+ LiteralView::Create(*result_literal, {0, 1}));
+ LiteralTestUtil::ExpectR2Equal<float>(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}},
+ LiteralView::Create(*result_literal, {0, 2}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
DefaultExecutableRunOptions());
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result);
- LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- result_literal->tuple_literals(0));
- LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
- result_literal->tuple_literals(1));
+ LiteralTestUtil::ExpectR2Equal<float>(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0}));
+ LiteralTestUtil::ExpectR2Equal<float>(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->on_host_shape()));
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result);
- LiteralTestUtil::ExpectR2Equal<float>({{56.0f, 46.0f}, {36.0f, 26.0f}},
- result_literal->tuple_literals(0));
- LiteralTestUtil::ExpectR1Equal<float>({40.0f, 71.0f, 117.0f},
- result_literal->tuple_literals(1));
+ LiteralTestUtil::ExpectR2Equal<float>(
+ {{56.0f, 46.0f}, {36.0f, 26.0f}},
+ LiteralView::Create(*result_literal, {0}));
+ LiteralTestUtil::ExpectR1Equal<float>(
+ {40.0f, 71.0f, 117.0f}, LiteralView::Create(*result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
ExecuteLocallyOrDie(computation, {arg_buffer.get()});
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result);
- LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4}},
- result_literal->tuple_literals(0));
- LiteralTestUtil::ExpectR1Equal<float>({264.0, 73.0, 133.0},
- result_literal->tuple_literals(1));
+ LiteralTestUtil::ExpectR2Equal<float>(
+ {{-1.0, -2.0}, {-3.0, -4}}, LiteralView::Create(*result_literal, {0}));
+ LiteralTestUtil::ExpectR1Equal<float>(
+ {264.0, 73.0, 133.0}, LiteralView::Create(*result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
std::unique_ptr<ScopedShapedBuffer> result_0 =
ExecuteLocallyOrDie(computation, {arg_buffer.get()});
std::unique_ptr<Literal> result_0_literal = ShapedBufferToLiteral(*result_0);
- LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4.0}},
- result_0_literal->tuple_literals(0));
- LiteralTestUtil::ExpectR2Equal<float>({{22.0, 6.0}, {8.0, 10}},
- result_0_literal->tuple_literals(1));
+ LiteralTestUtil::ExpectR2Equal<float>(
+ {{-1.0, -2.0}, {-3.0, -4.0}},
+ LiteralView::Create(*result_0_literal, {0}));
+ LiteralTestUtil::ExpectR2Equal<float>(
+ {{22.0, 6.0}, {8.0, 10}}, LiteralView::Create(*result_0_literal, {1}));
std::unique_ptr<ScopedShapedBuffer> result_1 =
ExecuteLocallyOrDie(computation, {result_0.get()});
std::unique_ptr<Literal> result_1_literal = ShapedBufferToLiteral(*result_1);
- LiteralTestUtil::ExpectR2Equal<float>({{1.0, 2.0}, {3.0, 4.0}},
- result_1_literal->tuple_literals(0));
- LiteralTestUtil::ExpectR2Equal<float>({{44.0, 12.0}, {16.0, 20}},
- result_1_literal->tuple_literals(1));
+ LiteralTestUtil::ExpectR2Equal<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, LiteralView::Create(*result_1_literal, {0}));
+ LiteralTestUtil::ExpectR2Equal<float>(
+ {{44.0, 12.0}, {16.0, 20}}, LiteralView::Create(*result_1_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
for (int i = 0; i < kElementCount; ++i) {
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f * i, 0.0f}, result_literal->tuple_literals(i), error_spec_);
+ {2.0f * i, 0.0f}, LiteralView::Create(*result_literal, {i}),
+ error_spec_);
}
}
for (int i = 0; i < kFanout; ++i) {
for (int j = 0; j < kFanout; ++j) {
LiteralTestUtil::ExpectR0Near<float>(
- i + j + i * kFanout + j,
- result_literal->tuple_literals(i).tuple_literals(j), error_spec_);
+ i + j + i * kFanout + j, LiteralView::Create(*result_literal, {i, j}),
+ error_spec_);
}
}
}
ExecuteLocallyOrDie(computation, {arg_buffer.get()});
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(*result);
- const Literal* result_element = result_literal.get();
+ ShapeIndex index;
for (int i = 0; i < kTupleDepth; ++i) {
- result_element = &result_element->tuple_literals(0);
+ index.push_back(0);
}
- LiteralTestUtil::ExpectR0Equal<float>(165.0, *result_element);
+ LiteralTestUtil::ExpectR0Equal<float>(
+ 165.0, LiteralView::Create(*result_literal, index));
}
XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
std::unique_ptr<ScopedShapedBuffer> result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
std::unique_ptr<Literal> tuple_literal = ShapedBufferToLiteral(*result);
- LiteralTestUtil::ExpectR1Equal<float>({2.0f, 4.0f, 6.0f},
- tuple_literal->tuple_literals(0));
- LiteralTestUtil::ExpectR1Equal<float>({1.0f, 2.0f, 3.0f},
- tuple_literal->tuple_literals(1));
+ LiteralTestUtil::ExpectR1Equal<float>(
+ {2.0f, 4.0f, 6.0f}, LiteralView::Create(*tuple_literal, {0}));
+ LiteralTestUtil::ExpectR1Equal<float>(
+ {1.0f, 2.0f, 3.0f}, LiteralView::Create(*tuple_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
nullptr);
}
- Literal arg1;
- arg1.PopulateWithValue<float>(2.5f, {size, size});
+ Literal arg1(ShapeUtil::MakeShape(F32, {size, size}));
+ arg1.PopulateWithValue<float>(2.5f);
- Literal expect;
- expect.PopulateWithValue<float>(size * 1.5f * 3.5f, {size, size});
+ Literal expect(ShapeUtil::MakeShape(F32, {size, size}));
+ expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
auto actual = ExecuteAndTransfer(
std::move(hlo_module), {Literal::CreateR0<float>(-9.0f).get(), &arg1});
LiteralTestUtil::ExpectNear(expect, *actual, error_spec_);
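PopulateWithValue loses its dimensions argument: the shape is fixed at construction, and the value is broadcast over the existing buffer. Sketch:

    xla::Literal ones(xla::ShapeUtil::MakeShape(xla::F32, {2, 3}));
    ones.PopulateWithValue<float>(1.0f);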
nullptr);
}
- Literal input0, input1;
- input0.PopulateWithValue<float>(2.5f, {size});
- input1.PopulateWithValue<double>(1, {size});
+ Literal input0(ShapeUtil::MakeShape(F32, {size}));
+ input0.PopulateWithValue(2.5f);
+ Literal input1(ShapeUtil::MakeShape(F64, {size}));
+ input1.PopulateWithValue(1.);
- Literal expect = *Literal::CreateR1<float>({size * 1.5f * 3.5f});
+ Literal expect = std::move(*Literal::CreateR1<float>({size * 1.5f * 3.5f}));
auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
LiteralTestUtil::ExpectNear(expect, *actual, error_spec_);
}
// Verifies that passing a 2x2 with {0, 1} layout returns the same value back
// when (transferred to the server and) passed through a parameter.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
- std::unique_ptr<Literal> literal = Literal::CreateR2<float>({
- {1, 2}, {3, 4},
- });
- *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+ std::unique_ptr<Literal> literal = Literal::CreateR2WithLayout<float>(
+ {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1}));
ComputationBuilder builder(client_, TestName());
builder.Parameter(0, literal->shape(), "input");
// As above, but for {1, 0} layout.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
- std::unique_ptr<Literal> literal = Literal::CreateR2<float>({
- {1, 3}, {2, 4},
- });
- *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+ std::unique_ptr<Literal> literal = Literal::CreateR2WithLayout<float>(
+ {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0}));
ComputationBuilder builder(client_, TestName());
builder.Parameter(0, literal->shape(), "input");
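Since layouts can no longer be patched onto a literal after the fact, the WithLayout factory variants take the layout at creation time. Sketch:

    std::unique_ptr<xla::Literal> col_major =
        xla::Literal::CreateR2WithLayout<float>(
            {{1, 2}, {3, 4}}, xla::LayoutUtil::MakeLayout({0, 1}));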
original.layout().minor_to_major().begin(),
original.layout().minor_to_major().end());
std::reverse(original_layout.begin(), original_layout.end());
- *literal->mutable_shape()->mutable_layout() =
+ *literal->mutable_shape_do_not_use()->mutable_layout() =
LayoutUtil::MakeLayout(original_layout);
ASSERT_EQ(2, literal->Get<float>({0, 1}));
}
computation,
/*arguments=*/{param0_data.get()}, &execution_options));
- EXPECT_EQ(actual->f32s_size(), param0_literal->f32s_size());
- for (int i = 0; i < param0_literal->f32s_size(); ++i) {
- EXPECT_GE(actual->f32s(i), param0_literal->f32s(i));
- EXPECT_LT(actual->f32s(i), param0_literal->f32s(i) + 1.0f);
+ EXPECT_EQ(ShapeUtil::ElementsIn(actual->shape()),
+ ShapeUtil::ElementsIn(param0_literal->shape()));
+ for (int i = 0; i < ShapeUtil::ElementsIn(actual->shape()); ++i) {
+ EXPECT_GE(actual->data<float>()[i], param0_literal->data<float>()[i]);
+ EXPECT_LT(actual->data<float>()[i],
+ param0_literal->data<float>()[i] + 1.0f);
}
}
XLA_TEST_P(ReshapeTest, ToScalar) {
for (int rank = 0; rank < 8; ++rank) {
ComputationBuilder b(client_, TestName());
- auto input_literal = Literal::CreateR1<float>({83.0f});
std::vector<int64> ones(rank, 1); // this is {1, ..., 1}.
std::vector<int64> dimensions(rank);
std::iota(dimensions.begin(), dimensions.end(), 0);
- *input_literal->mutable_shape() = ShapeUtil::MakeShape(F32, ones);
+ Literal input_literal(ShapeUtil::MakeShape(F32, ones));
+ std::vector<int64> zeros(rank, 0); // this is {0, ..., 0}.
+ input_literal.Set<float>(zeros, 83.0f);
ComputationDataHandle parameter;
- auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+ auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&b, ¶meter);
b.Reshape(parameter, dimensions, {});
// data.
if (use_bfloat16()) {
auto expected = LiteralTestUtil::ConvertF32ToBF16(*input_literal);
- EXPECT_EQ(tensorflow::gtl::ArraySlice<bfloat16>(expected->bf16s()),
- tensorflow::gtl::ArraySlice<bfloat16>(output_literal->bf16s()));
+ EXPECT_EQ(expected->data<bfloat16>(), output_literal->data<bfloat16>());
} else {
- EXPECT_EQ(tensorflow::gtl::ArraySlice<float>(input_literal->f32s()),
- tensorflow::gtl::ArraySlice<float>(output_literal->f32s()));
+ EXPECT_EQ(input_literal->data<float>(), output_literal->data<float>());
}
}
transfer_manager_->TransferLiteralFromDevice(
stream_executor_, *device_buffer));
- EXPECT_EQ(result->u8s_string(), test_string);
+ EXPECT_EQ(result->GetR1U8AsString(), test_string);
}
XLA_TEST_F(TransferManagerTest, TransferR2F32) {
ShapeUtil::HumanString(shape).c_str());
}
- auto result = MakeUnique<Literal>();
+ auto result = MakeUnique<Literal>(shape);
const float fill = std::numeric_limits<float>::quiet_NaN();
- result->PopulateWithValue<float>(fill, AsInt64Slice(shape.dimensions()));
+ result->PopulateWithValue<float>(fill);
std::vector<tensorflow::StringPiece> pieces;
std::vector<tensorflow::StringPiece> coordinates;
std::vector<int64> coordinate_values;
PrimitiveType_Name(literal->shape().element_type())));
}
- literal->GetMutableArraySlice<LiteralNativeT>().at(linear_index) =
+ literal->data<LiteralNativeT>().at(linear_index) =
static_cast<LiteralNativeT>(value);
return true;
}
arguments = MakeFakeArgumentsOrDie(computation, client);
} else { // use recorded data if available
for (const auto& proto : module.arguments()) {
- Literal literal(proto);
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
+ Literal::CreateFromProto(proto));
TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
- client->TransferToServer(literal));
+ client->TransferToServer(*literal));
arguments.push_back(std::move(data));
}
}
ShapeUtil::HumanString(result->shape()).c_str(),
result->ToString().c_str());
if (module.has_result()) {
+ std::unique_ptr<Literal> literal =
+ Literal::CreateFromProto(module.result()).ConsumeValueOrDie();
fprintf(stdout, "was %s:%s\n",
ShapeUtil::HumanString(module.result().shape()).c_str(),
- Literal(module.result()).ToString().c_str());
+ literal->ToString().c_str());
}
}
}
xla::LiteralProto literal_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1],
&literal_proto));
- xla::Literal literal(literal_proto);
+ std::unique_ptr<xla::Literal> literal =
+ xla::Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
LOG(INFO) << "literal: " << literal_proto.ShortDebugString();
- fprintf(stderr, "%s\n", literal.ToString().c_str());
+ fprintf(stderr, "%s\n", literal->ToString().c_str());
}
std::unique_ptr<xla::Literal> literal =
xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie();
- LOG(INFO) << "literal: " << literal->ShortDebugString();
+ LOG(INFO) << "literal: " << *literal;
fprintf(stderr, "%s\n", literal->ToString().c_str());
if (literal->shape().element_type() == xla::F32) {
- float min =
- *std::min_element(literal->f32s().begin(), literal->f32s().end());
- float max =
- *std::max_element(literal->f32s().begin(), literal->f32s().end());
+ float min = *std::min_element(literal->data<float>().begin(),
+ literal->data<float>().end());
+ float max = *std::max_element(literal->data<float>().begin(),
+ literal->data<float>().end());
fprintf(stderr, "min: %a=%f\n", min, min);
fprintf(stderr, "max: %a=%f\n", max, max);
}