From 901d119b938d9ff4239f27fbede488ae3d05d598 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Tue, 6 Feb 2018 15:30:40 -0800 Subject: [PATCH] [XLA] Add and use new Literal::MakeTupleOwned overload. Previously MakeTupleOwned was cumbersome to use, because you had to explicitly materialize a vector>. With this new overload, you can pass unique_ptrs directly. PiperOrigin-RevId: 184751119 --- tensorflow/compiler/xla/literal_util.cc | 12 +++++++++--- tensorflow/compiler/xla/literal_util.h | 21 +++++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 89279b6..09db011 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -1257,11 +1257,17 @@ string Literal::ToString(bool print_layout) const { /* static */ std::unique_ptr Literal::MakeTupleOwned( std::vector> elements) { - std::vector element_ptrs; + std::vector element_shapes; + element_shapes.reserve(elements.size()); for (const auto& element : elements) { - element_ptrs.push_back(element.get()); + element_shapes.push_back(element->shape()); + } + auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + for (int64 i = 0; i < elements.size(); ++i) { + TF_CHECK_OK( + literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i})); } - return MakeTuple(element_ptrs); + return literal; } void Literal::EachCellAsString( diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 2b68b8f..d996004 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -485,6 +485,27 @@ class Literal { static std::unique_ptr MakeTupleOwned( std::vector> elements); + // This overload lets you pass a braced list of unique_ptrs to + // MakeTupleOwned: + // + // Literal::MakeTupleOwned(Literal::CreateR1(...), ...). + // + // Simply relying on the MakeTupleOwned(std::vector>) + // overload doesn't work because std::initializer_list's elements are always + // const. + // + // The arguments to this function must all be unique_ptr. + template + static std::unique_ptr MakeTupleOwned( + std::unique_ptr... elements) { + std::array, sizeof...(Ts)> arr{ + std::move(elements)...}; + std::vector> v; + v.insert(v.begin(), std::make_move_iterator(arr.begin()), + std::make_move_iterator(arr.end())); + return MakeTupleOwned(std::move(v)); + } + // Returns a string representation of the literal value. // Warning: this function can take minutes for multi-million element Literals. string ToString(bool print_layout = false) const; -- 2.7.4