[XLA] Redesign: implement local client and local service interface.
author A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 23 Mar 2018 23:13:13 +0000 (16:13 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Sun, 25 Mar 2018 11:22:41 +0000 (04:22 -0700)
PiperOrigin-RevId: 190291400

tensorflow/compiler/xla/client/local_client.cc
tensorflow/compiler/xla/client/local_client.h
tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/local_service.cc
tensorflow/compiler/xla/service/local_service.h
tensorflow/compiler/xla/service/service.cc
tensorflow/compiler/xla/service/service.h

index 91396f0..3059424 100644 (file)
@@ -265,6 +265,24 @@ StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
                                         updated_options));
 }
 
+StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
+    const XlaComputation& computation,
+    const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+    const ExecutableBuildOptions& options) {
+  ExecutableBuildOptions updated_options = options;
+  if (options.device_ordinal() == -1) {
+    updated_options.set_device_ordinal(default_device_ordinal());
+    VLOG(3) << "Set device ordinal to default value of: "
+            << updated_options.device_ordinal();
+  }
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
+                      local_service_->CompileExecutable(
+                          computation, argument_layouts, updated_options));
+  return WrapUnique(new LocalExecutable(std::move(executable),
+                                        local_service_->mutable_backend(),
+                                        updated_options));
+}
+
 StatusOr<std::unique_ptr<ScopedShapedBuffer>>
 LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal,
                                    DeviceMemoryAllocator* allocator) {
index 2e5d85b..98ee7c6 100644 (file)
@@ -123,7 +123,14 @@ class LocalClient : public Client {
       const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
       const ExecutableBuildOptions& options);
 
-  // TODO(b/74197823): Add a overload of Compile for XlaComputation.
+  // Build and return a LocalExecutable object. The executable is compiled using
+  // the given XlaComputation, argument layouts and options.
+  //
+  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
+  StatusOr<std::unique_ptr<LocalExecutable>> Compile(
+      const XlaComputation& computation,
+      const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+      const ExecutableBuildOptions& options);
 
   // Copy the literal data to the device with the given ordinal and return as a
   // ScopedShapedBuffer. If non-null the given memory allocator is used for
index d4d6787..da16976 100644 (file)
@@ -623,6 +623,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client:executable_build_options",
+        "//tensorflow/compiler/xla/client/xla_client:xla_computation",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
     ],
index 1e2d8ee..499f280 100644 (file)
@@ -69,6 +69,68 @@ LocalService::LocalService(const ServiceOptions& options,
                            std::unique_ptr<Backend> execute_backend)
     : Service(options, std::move(execute_backend)) {}
 
+namespace {
+
+// Retrieves the parameter metadata for the given computation and parameter
+// number.
+//
+// If the parameter number is invalid for this computation, nullopt is
+// returned. When the return value has_value(), nullptr will never be
+// the held value.
+tensorflow::gtl::optional<const OpMetadata*> ParameterMetadata(
+    const XlaComputation& computation, int parameter_number) {
+  for (const HloComputationProto& comp : computation.proto().computations()) {
+    if (comp.id() == computation.proto().entry_computation_id()) {
+      for (const HloInstructionProto& instr : comp.instructions()) {
+        if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
+            instr.parameter_number() == parameter_number) {
+          if (!instr.has_metadata()) {
+            return tensorflow::gtl::nullopt;
+          }
+          return &instr.metadata();
+        }
+      }
+    }
+  }
+  return tensorflow::gtl::nullopt;
+}
+
+ExecutionOptions CreateExecutionOptions(
+    const ExecutableBuildOptions& build_options,
+    const ProgramShape* program_shape) {
+  ExecutionOptions execution_options = CreateDefaultExecutionOptions();
+  if (build_options.hlo_profile().has_value()) {
+    execution_options.mutable_debug_options()->set_xla_hlo_profile(
+        *build_options.hlo_profile());
+  }
+  if (build_options.generate_hlo_graph().has_value()) {
+    execution_options.mutable_debug_options()->set_xla_generate_hlo_graph(
+        build_options.generate_hlo_graph().value());
+  }
+  if (build_options.dump_optimized_hlo_proto_to().has_value()) {
+    execution_options.mutable_debug_options()
+        ->set_xla_dump_optimized_hlo_proto_to(
+            build_options.dump_optimized_hlo_proto_to().value());
+  }
+  if (build_options.dump_per_pass_hlo_proto_to().has_value()) {
+    execution_options.mutable_debug_options()
+        ->set_xla_dump_per_pass_hlo_proto_to(
+            build_options.dump_per_pass_hlo_proto_to().value());
+  }
+  if (build_options.result_layout() != nullptr) {
+    *execution_options.mutable_shape_with_output_layout() =
+        *build_options.result_layout();
+  } else {
+    *execution_options.mutable_shape_with_output_layout() =
+        program_shape->result();
+    LayoutUtil::SetToDefaultLayout(
+        execution_options.mutable_shape_with_output_layout());
+  }
+  return execution_options;
+}
+
+}  // namespace
+
 StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
     const ComputationHandle& computation,
     const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
@@ -118,34 +180,8 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
         *build_options.result_layout(), program_shape->result()));
   }
 
-  ExecutionOptions execution_options = CreateDefaultExecutionOptions();
-  if (build_options.hlo_profile().has_value()) {
-    execution_options.mutable_debug_options()->set_xla_hlo_profile(
-        *build_options.hlo_profile());
-  }
-  if (build_options.generate_hlo_graph().has_value()) {
-    execution_options.mutable_debug_options()->set_xla_generate_hlo_graph(
-        build_options.generate_hlo_graph().value());
-  }
-  if (build_options.dump_optimized_hlo_proto_to().has_value()) {
-    execution_options.mutable_debug_options()
-        ->set_xla_dump_optimized_hlo_proto_to(
-            build_options.dump_optimized_hlo_proto_to().value());
-  }
-  if (build_options.dump_per_pass_hlo_proto_to().has_value()) {
-    execution_options.mutable_debug_options()
-        ->set_xla_dump_per_pass_hlo_proto_to(
-            build_options.dump_per_pass_hlo_proto_to().value());
-  }
-  if (build_options.result_layout() != nullptr) {
-    *execution_options.mutable_shape_with_output_layout() =
-        *build_options.result_layout();
-  } else {
-    *execution_options.mutable_shape_with_output_layout() =
-        program_shape->result();
-    LayoutUtil::SetToDefaultLayout(
-        execution_options.mutable_shape_with_output_layout());
-  }
+  ExecutionOptions execution_options =
+      CreateExecutionOptions(build_options, program_shape.get());
   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
                       CreateModuleConfig(*program_shape, argument_layouts,
                                          &execution_options, user_computation));
@@ -159,6 +195,67 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
                          build_options.device_allocator());
 }
 
+StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
+    const XlaComputation& computation,
+    const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+    const ExecutableBuildOptions& build_options) {
+  const HloModuleProto& proto = computation.proto();
+  TF_RET_CHECK(proto.has_program_shape());
+  const ProgramShape& program_shape = proto.program_shape();
+
+  // Validate incoming layouts.
+  if (argument_layouts.size() != program_shape.parameters_size()) {
+    return InvalidArgument(
+        "Invalid number of arguments for computation: expected %d, got %zu.",
+        program_shape.parameters_size(), argument_layouts.size());
+  }
+
+  for (int i = 0; i < argument_layouts.size(); ++i) {
+    const Shape& argument_shape = *argument_layouts[i];
+    TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape));
+    if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) {
+      tensorflow::gtl::optional<const OpMetadata*> metadata =
+          ParameterMetadata(computation, /*parameter_number=*/i);
+      auto metadata_string = [&metadata]() -> string {
+        if (!metadata.has_value()) {
+          return "";
+        }
+        CHECK(metadata.value() != nullptr);
+        const OpMetadata& m = *metadata.value();
+        if (!m.source_file().empty()) {
+          return tensorflow::strings::Printf(
+              " (%s:%d)", m.source_file().c_str(), m.source_line());
+        }
+        return "";
+      };
+      return InvalidArgument(
+          "Invalid argument shape for argument %d%s, expected %s, got %s.", i,
+          metadata_string().c_str(),
+          ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
+          ShapeUtil::HumanString(argument_shape).c_str());
+    }
+  }
+  if (build_options.result_layout() != nullptr) {
+    TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(
+        *build_options.result_layout(), program_shape.result()));
+  }
+
+  ExecutionOptions execution_options =
+      CreateExecutionOptions(build_options, &program_shape);
+
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<HloModuleConfig> module_config,
+      CreateModuleConfig(program_shape, argument_layouts, &execution_options));
+
+  TF_ASSIGN_OR_RETURN(
+      se::StreamExecutor * executor,
+      execute_backend_->stream_executor(build_options.device_ordinal()));
+
+  return BuildExecutable(proto, std::move(module_config),
+                         execute_backend_.get(), executor,
+                         build_options.device_allocator());
+}
+
 StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {
   return backend().computation_placer()->DeviceId(
       replica_number, /*computation=*/0, options_.number_of_replicas(),
index 15e1206..06567ca 100644 (file)
@@ -19,6 +19,7 @@ limitations under the License.
 #include <memory>
 
 #include "tensorflow/compiler/xla/client/executable_build_options.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 #include "tensorflow/compiler/xla/service/backend.h"
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -50,6 +51,18 @@ class LocalService : public Service {
       const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
       const ExecutableBuildOptions& options);
 
+  // Builds an Executable with the given XlaComputation, argument layouts and
+  // options. If result_layout is non-null, then the executable is compiled to
+  // produce a result of the given layout.  If device_allocator is non-null,
+  // then the compiler may use it to allocate temp space on the device.  The
+  // compiler is responsible for freeing any memory it allocates this way.
+  //
+  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
+  StatusOr<std::unique_ptr<Executable>> CompileExecutable(
+      const XlaComputation& computation,
+      const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+      const ExecutableBuildOptions& build_options);
+
   // Returns the device ordinal that corresponds to the given replica number.
   //
   // This returns an error if there is not a one-to-one correspondence of
index 4f6a823..1d379f0 100644 (file)
@@ -963,6 +963,30 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
   return tensorflow::Status::OK();
 }
 
+StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
+    const HloModuleProto& module_proto,
+    std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
+    se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) {
+  VLOG(1) << Printf(
+      "BuildExecutable on service %p with serialized module proto: %s", this,
+      module_proto.name().c_str());
+
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
+                      HloModule::CreateFromProto(module_proto, *module_config));
+
+  TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module));
+
+  TF_ASSIGN_OR_RETURN(
+      module, backend->compiler()->RunHloPasses(std::move(module), executor,
+                                                device_allocator));
+
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
+                      backend->compiler()->RunBackend(
+                          std::move(module), executor, device_allocator));
+
+  return std::move(executable);
+}
+
 tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
                                          ExecuteResponse* result) {
   VLOG(1) << "running execute-graph request";
@@ -979,24 +1003,17 @@ tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
       std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
       ResolveAndValidateArguments(arg->arguments(), replicas));
 
-  TF_ASSIGN_OR_RETURN(const auto& config,
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
                       CreateModuleConfig(arg->computation().program_shape(),
                                          replicated_arguments.front(),
                                          arg->execution_options()));
 
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
-                      HloModule::CreateFromProto(arg->computation(), *config));
-  TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module));
-
-  TF_ASSIGN_OR_RETURN(module, execute_backend_->compiler()->RunHloPasses(
-                                  std::move(module),
-                                  execute_backend_->default_stream_executor(),
-                                  /*device_allocator=*/nullptr));
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<Executable> executable,
-      execute_backend_->compiler()->RunBackend(
-          std::move(module), execute_backend_->default_stream_executor(),
-          /*device_allocator=*/nullptr));
+      BuildExecutable(arg->computation(), std::move(module_config),
+                      execute_backend_.get(),
+                      execute_backend_->default_stream_executor(),
+                      /*device_allocator=*/nullptr));
 
   TF_ASSIGN_OR_RETURN(
       *result->mutable_output(),
index 3b79920..773f0a6 100644 (file)
@@ -115,6 +115,8 @@ class Service : public ServiceInterface {
   // Executes a computation with the provided global data passed as
   // immutable arguments. The request contains the whole computation graph.
   // Returns global data output and execution timing.
+  //
+  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
   tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg,
                                   ExecuteResponse* result) override;
 
@@ -299,6 +301,15 @@ class Service : public ServiceInterface {
       perftools::gputools::StreamExecutor* executor,
       DeviceMemoryAllocator* device_allocator = nullptr);
 
+  // Builds an Executable for the given HLO module proto.
+  //
+  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
+  StatusOr<std::unique_ptr<Executable>> BuildExecutable(
+      const HloModuleProto& module_proto,
+      std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
+      perftools::gputools::StreamExecutor* executor,
+      DeviceMemoryAllocator* device_allocator = nullptr);
+
   // Same as BuildExecutable() above, but builds a list of Executables for the
   // given computations that may interact with each other.
   StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables(