From a2504ab91afcdeb917a272d2cc54f13dd67c04e7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 12 Feb 2018 05:34:05 -0800 Subject: [PATCH] [XLA] Support generating tuple shaped fake data in client testing The previous implementation failed over in case of a tuple shaped input what broke the replay computation tool for the case where the input is a tuple. PiperOrigin-RevId: 185366228 --- tensorflow/compiler/xla/client/lib/testing.cc | 37 ++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 5f2b557..b63a146 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -31,14 +31,43 @@ limitations under the License. namespace xla { namespace { +// Calculates the number of bytes required to store the data within the +// specified shape. In case of a (nested) tuple shape this is the total byte +// size of all sub-shapes within the tuple. +int64 DataSizeOfShape(const Shape& shape) { + if (ShapeUtil::IsArray(shape)) { + return ShapeUtil::ByteSizeOf(shape); + } + + int64 total_size = 0; + for (const Shape& s : shape.tuple_shapes()) { + total_size += DataSizeOfShape(s); + } + return total_size; +} + +// Create a ComputationDataHandle for an op what generates fake data with the +// given shape. +ComputationDataHandle BuildFakeDataOpOnDevice(const Shape& shape, + ComputationBuilder* builder) { + if (ShapeUtil::IsArray(shape)) { + return builder->Broadcast( + builder->ConstantLiteral(Literal::One(shape.element_type())), + AsInt64Slice(shape.dimensions())); + } + std::vector parts; + for (const Shape& s : shape.tuple_shapes()) { + parts.push_back(BuildFakeDataOpOnDevice(s, builder)); + } + return builder->Tuple(parts); +} + std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, Client* client) { ComputationBuilder b( client, tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape))); - // TODO(b/26811613): Replace this when RNG is supported on all backends. - b.Broadcast(b.ConstantLiteral(Literal::One(shape.element_type())), - AsInt64Slice(shape.dimensions())); + BuildFakeDataOpOnDevice(shape, &b); Computation computation = b.Build().ConsumeValueOrDie(); auto execution_options = CreateDefaultExecutionOptions(); @@ -51,7 +80,7 @@ std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, std::unique_ptr MakeFakeDataOrDie(const Shape& shape, Client* client) { - if (ShapeUtil::ByteSizeOf(shape) < (1LL << 20)) { + if (DataSizeOfShape(shape) < (1LL << 20)) { StatusOr> literal_status = MakeFakeLiteral(shape); if (!literal_status.ok()) { // If we got an Unimplemented error, fall back to making the fake data via -- 2.7.4