[XLA] Redesign: add a method that creates fake data for XlaComputation.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 13 Apr 2018 19:35:32 +0000 (12:35 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 13 Apr 2018 19:38:21 +0000 (12:38 -0700)
PiperOrigin-RevId: 192807851

tensorflow/compiler/xla/client/lib/BUILD
tensorflow/compiler/xla/client/lib/testing.cc
tensorflow/compiler/xla/client/lib/testing.h

index f4673a8..59c4a53 100644 (file)
@@ -46,6 +46,7 @@ cc_library(
         "//tensorflow/compiler/xla/client:computation",
         "//tensorflow/compiler/xla/client:computation_builder",
         "//tensorflow/compiler/xla/client:global_data",
+        "//tensorflow/compiler/xla/client/xla_client:xla_computation",
         "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/core:lib",
     ],
index b63a146..311dc4b 100644 (file)
@@ -111,4 +111,20 @@ std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
   return fake_arguments;
 }
 
+std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
+    const XlaComputation& computation, Client* client) {
+  CHECK(computation.proto().has_program_shape())
+      << "Computation should have progran shape.";
+  auto program_shape = computation.proto().program_shape();
+
+  // For every (unbound) parameter that the computation wants, we manufacture
+  // some arbitrary data so that we can invoke the computation.
+  std::vector<std::unique_ptr<GlobalData>> fake_arguments;
+  for (const Shape& parameter : program_shape.parameters()) {
+    fake_arguments.push_back(MakeFakeDataOrDie(parameter, client));
+  }
+
+  return fake_arguments;
+}
+
 }  // namespace xla
index 7e640d1..1dc2622 100644 (file)
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/client.h"
 #include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
 namespace xla {
@@ -38,6 +39,12 @@ std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
 std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
     const Computation& computation, Client* client);
 
+// Returns vector of GlobalData handles of fake data (created using
+// MakeFakeDataOrDie) that are correctly shaped arguments for the given
+// xla computation.
+std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
+    const XlaComputation& computation, Client* client);
+
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TESTING_H_