[XLA] Switch replay_computation to use LocalClient.
authorJustin Lebar <jlebar@google.com>
Thu, 31 May 2018 00:00:50 +0000 (17:00 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 31 May 2018 00:03:41 +0000 (17:03 -0700)
This lets replay_computation build an executable once and run it
multiple times.  This is particularly important because in XLA:GPU, the
first run of an executable does some autotuning and therefore is
unrepresentative.

This change removes --xla_hlo_profile_last_run, because I don't see how
to support it in LocalClient -- LocalClient wants the do-profile bit to
be set when we *compile*.  (There may not be an easy fix for this; it
worked with regular Client because we were recompiling every time we
ran.)

PiperOrigin-RevId: 198643577

tensorflow/compiler/xla/client/local_client.cc
tensorflow/compiler/xla/client/local_client.h
tensorflow/compiler/xla/service/local_service.cc
tensorflow/compiler/xla/service/local_service.h
tensorflow/compiler/xla/tools/replay_computation.cc

index a7c55c6..f900337 100644 (file)
@@ -304,6 +304,11 @@ StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
                                                                  shaped_buffer);
 }
 
+StatusOr<const ShapedBuffer*> LocalClient::GlobalDataToShapedBuffer(
+    const GlobalDataHandle& data, int replica_number) {
+  return local_service_->GlobalDataToShapedBuffer(data, replica_number);
+}
+
 Status LocalClient::TransferToInfeedLocal(const Literal& literal,
                                           int device_ordinal) {
   TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
index 3f23e52..5b408cc 100644 (file)
@@ -136,6 +136,11 @@ class LocalClient : public Client {
   StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
       const ShapedBuffer& shaped_buffer);
 
+  // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
+  // as long as the handle is valid.
+  StatusOr<const ShapedBuffer*> GlobalDataToShapedBuffer(
+      const GlobalDataHandle& data, int replica_number);
+
   // Transfer the given literal to the infeed queue of the given device.
   // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
   // not inherit from Client and there is no possibility of confusion with
index 0fa4061..41aef39 100644 (file)
@@ -260,4 +260,15 @@ StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {
       /*computation_count=*/1);
 }
 
+StatusOr<const ShapedBuffer*> LocalService::GlobalDataToShapedBuffer(
+    const GlobalDataHandle& data, int replica_number) {
+  TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data));
+  if (replica_number >= buffers.size()) {
+    return InvalidArgument(
+        "replica_number %d out of range; must be less than num_replicas = %zu.",
+        replica_number, buffers.size());
+  }
+  return buffers[replica_number];
+}
+
 }  // namespace xla
index 06567ca..b55f119 100644 (file)
@@ -70,6 +70,11 @@ class LocalService : public Service {
   // the "easy" case where a single replica is a single device.
   StatusOr<int> ReplicaNumberToDeviceOrdinal(int replica_number);
 
+  // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
+  // as long as the handle is valid.
+  StatusOr<const ShapedBuffer*> GlobalDataToShapedBuffer(
+      const GlobalDataHandle& data, int replica_number);
+
  private:
   explicit LocalService(const ServiceOptions& options,
                         std::unique_ptr<Backend> backend);
index fc7e800..be094b7 100644 (file)
@@ -68,7 +68,6 @@ struct Options {
   bool use_fake_data = false;
   bool print_result = true;
   int num_runs = 1;
-  bool xla_hlo_profile_last_run = false;
 };
 
 // Invokes the given computation passing arbitrary data for every (unbound)
@@ -80,21 +79,35 @@ struct Options {
 //
 // If neither generate_fake_infeed is true nor a fake_infeed_shape is provided,
 // no infeed is performed.
-StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
-                                                     Client* client,
-                                                     const Options& opts) {
+StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
+                                    LocalClient* client, const Options& opts) {
   XlaComputation computation(module.hlo().hlo_module());
 
-  std::vector<std::unique_ptr<GlobalData>> arguments;
+  // Build the `argument_ptrs` vector, which contains ShapedBuffer*s to our
+  // arguments.  This is a bit involved, because we may have to convert from
+  // GlobalData to ShapedBuffer*, and we have to manage the lifetime of all our
+  // objects.
+  std::vector<ScopedShapedBuffer> scoped_shaped_buffer_arguments;
+  std::vector<std::unique_ptr<GlobalData>> global_data_arguments;
+  std::vector<const ShapedBuffer*> argument_ptrs;
   if (opts.use_fake_data) {
-    arguments = MakeFakeArgumentsOrDie(computation, client);
+    global_data_arguments = MakeFakeArgumentsOrDie(computation, client);
+    for (const auto& data : global_data_arguments) {
+      argument_ptrs.push_back(
+          client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0)
+              .ValueOrDie());
+    }
   } else {  // use recorded data if available
     for (const auto& proto : module.arguments()) {
       TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
                           Literal::CreateFromProto(proto));
-      TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
-                          client->TransferToServer(*literal));
-      arguments.push_back(std::move(data));
+      TF_ASSIGN_OR_RETURN(
+          ScopedShapedBuffer data,
+          client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0));
+      scoped_shaped_buffer_arguments.push_back(std::move(data));
+    }
+    for (const auto& argument : scoped_shaped_buffer_arguments) {
+      argument_ptrs.push_back(&argument);
     }
   }
 
@@ -149,43 +162,41 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
     });
   }
 
-  std::vector<GlobalData*> execute_arguments;
-  execute_arguments.reserve(arguments.size());
-  for (auto& argument : arguments) {
-    execute_arguments.push_back(argument.get());
+  std::vector<const Shape*> argument_layouts;
+  for (const auto& param : computation.proto().program_shape().parameters()) {
+    argument_layouts.push_back(&param);
   }
+  std::unique_ptr<LocalExecutable> executable =
+      client->Compile(computation, argument_layouts, ExecutableBuildOptions())
+          .ValueOrDie();
 
   // Run the computation num_runs times, and return the result from the last
   // execution.
-  std::unique_ptr<Literal> result;
+  StreamExecutorMemoryAllocator allocator(
+      client->platform(),
+      {client->platform()->ExecutorForDevice(0).ValueOrDie()});
+  tensorflow::gtl::optional<ScopedShapedBuffer> result;
   for (int i = 0; i < opts.num_runs; ++i) {
     ExecutionProfile profile;
-    ExecutionOptions execution_options = CreateDefaultExecutionOptions();
-    if (opts.xla_hlo_profile_last_run && i == opts.num_runs - 1) {
-      execution_options.mutable_debug_options()->set_xla_hlo_profile(true);
-    }
+    ExecutableRunOptions run_options;
+    run_options.set_execution_profile(&profile);
+    run_options.set_allocator(&allocator);
 
-    if (opts.print_result) {
-      TF_ASSIGN_OR_RETURN(
-          result, client->ExecuteAndTransfer(computation, execute_arguments,
-                                             &execution_options, &profile));
-    } else {
-      // If we're not printing the result, execute the computation but don't
-      // bother retrieving the result.  This can be a significant speedup.
-      TF_RETURN_IF_ERROR(client
-                             ->Execute(computation, execute_arguments,
-                                       &execution_options, &profile)
-                             .status());
-    }
+    TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options));
     LOG(INFO) << "Execution took "
               << static_cast<double>(profile.compute_time_ns()) / 1e9 << "s";
   }
 
-  return std::move(result);
+  // Check that --num_runs > 0, otherwise *result below will fail with an
+  // unhelpful error (because the loop didn't run any iterations).
+  CHECK_GT(opts.num_runs, 0) << "--num_runs must be > 0";
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result_literal,
+                      client->ShapedBufferToLiteral(*result));
+  return std::move(*result_literal);
 }
 
 int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
-  Client* client = ClientLibrary::LocalClientOrDie();
+  LocalClient* client = ClientLibrary::LocalClientOrDie();
   tensorflow::Env* env = tensorflow::Env::Default();
   int exit_status = EXIT_SUCCESS;
   for (char* arg : args) {
@@ -202,8 +213,8 @@ int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
       CHECK(opts.use_fake_data)
           << "HloProto input must be handled with --use_fake_data";
     }
-    StatusOr<std::unique_ptr<Literal>> result_status =
-        ReplayComputation(snapshot, client, opts);
+
+    StatusOr<Literal> result_status = ReplayComputation(snapshot, client, opts);
     if (!result_status.ok()) {
       fprintf(stderr, "%s: error: %s\n", arg,
               result_status.status().ToString().c_str());
@@ -211,12 +222,12 @@ int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
       continue;
     }
 
-    std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
-    if (result != nullptr) {
+    if (opts.print_result) {
+      Literal result = std::move(result_status).ValueOrDie();
       fprintf(stdout, "%s: %s :: %s:%s\n", arg,
               snapshot.hlo().hlo_module().name().c_str(),
-              ShapeUtil::HumanString(result->shape()).c_str(),
-              result->ToString().c_str());
+              ShapeUtil::HumanString(result.shape()).c_str(),
+              result.ToString().c_str());
       if (snapshot.has_result()) {
         std::unique_ptr<Literal> literal =
             Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
@@ -249,9 +260,6 @@ int main(int argc, char** argv) {
       tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed,
                        "Whether a fake infeed shape should be generated "
                        "derived from the computation"),
-      tensorflow::Flag(
-          "xla_hlo_profile_last_run", &opts.xla_hlo_profile_last_run,
-          "Pass --xla_hlo_profile the last time we run the computation."),
   };
   xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
   bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);