[XLA] Support generating tuple shaped fake data in client testing
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 12 Feb 2018 13:34:05 +0000 (05:34 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Feb 2018 13:38:27 +0000 (05:38 -0800)
The previous implementation failed over in case of a tuple shaped input
what broke the replay computation tool for the case where the input is a
tuple.

PiperOrigin-RevId: 185366228

tensorflow/compiler/xla/client/lib/testing.cc

index 5f2b557..b63a146 100644 (file)
@@ -31,14 +31,43 @@ limitations under the License.
 namespace xla {
 namespace {
 
+// Calculates the number of bytes required to store the data within the
+// specified shape. In case of a (nested) tuple shape this is the total byte
+// size of all sub-shapes within the tuple.
+int64 DataSizeOfShape(const Shape& shape) {
+  if (ShapeUtil::IsArray(shape)) {
+    return ShapeUtil::ByteSizeOf(shape);
+  }
+
+  int64 total_size = 0;
+  for (const Shape& s : shape.tuple_shapes()) {
+    total_size += DataSizeOfShape(s);
+  }
+  return total_size;
+}
+
+// Create a ComputationDataHandle for an op what generates fake data with the
+// given shape.
+ComputationDataHandle BuildFakeDataOpOnDevice(const Shape& shape,
+                                              ComputationBuilder* builder) {
+  if (ShapeUtil::IsArray(shape)) {
+    return builder->Broadcast(
+        builder->ConstantLiteral(Literal::One(shape.element_type())),
+        AsInt64Slice(shape.dimensions()));
+  }
+  std::vector<ComputationDataHandle> parts;
+  for (const Shape& s : shape.tuple_shapes()) {
+    parts.push_back(BuildFakeDataOpOnDevice(s, builder));
+  }
+  return builder->Tuple(parts);
+}
+
 std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
                                                        Client* client) {
   ComputationBuilder b(
       client,
       tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape)));
-  // TODO(b/26811613): Replace this when RNG is supported on all backends.
-  b.Broadcast(b.ConstantLiteral(Literal::One(shape.element_type())),
-              AsInt64Slice(shape.dimensions()));
+  BuildFakeDataOpOnDevice(shape, &b);
   Computation computation = b.Build().ConsumeValueOrDie();
 
   auto execution_options = CreateDefaultExecutionOptions();
@@ -51,7 +80,7 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
 
 std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
                                               Client* client) {
-  if (ShapeUtil::ByteSizeOf(shape) < (1LL << 20)) {
+  if (DataSizeOfShape(shape) < (1LL << 20)) {
     StatusOr<std::unique_ptr<Literal>> literal_status = MakeFakeLiteral(shape);
     if (!literal_status.ok()) {
       // If we got an Unimplemented error, fall back to making the fake data via