From 7e9113ab912caff9ad15195b15771ff20bde6080 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 2 Apr 2018 17:08:27 -0700
Subject: [PATCH] [XLA] Redesign: implement ExecuteGraphParallel.

PiperOrigin-RevId: 191371793
---
 tensorflow/compiler/xla/client/client.cc     |  34 +++-
 tensorflow/compiler/xla/service/service.cc   | 229 +++++++++++++++++++++++----
 tensorflow/compiler/xla/service/service.h    |  22 +++
 tensorflow/compiler/xla/tests/BUILD          |   2 +
 tensorflow/compiler/xla/tests/client_test.cc |  10 +-
 5 files changed, 259 insertions(+), 38 deletions(-)

diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index c4c8894..3f45167 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -324,8 +324,38 @@ StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(

 StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
     tensorflow::gtl::ArraySlice<XlaComputationInstance> computations) {
-  return Unimplemented(
-      "ExecuteParallel is not yet implemented for XlaComputation.");
+  ExecuteGraphParallelRequest request;
+
+  for (const XlaComputationInstance& computation : computations) {
+    ExecuteGraphRequest single_request;
+    *single_request.mutable_computation() = computation.computation.proto();
+    for (GlobalData* argument : computation.arguments) {
+      *single_request.add_arguments() = argument->handle();
+    }
+    *single_request.mutable_execution_options() = computation.execution_options;
+    *request.add_requests() = single_request;
+  }
+
+  ExecuteParallelResponse response;
+  VLOG(1) << "making execute-graph-parallel request: "
+          << request.ShortDebugString();
+  tensorflow::Status s = stub_->ExecuteGraphParallel(&request, &response);
+  VLOG(1) << "done with request";
+
+  if (!s.ok()) {
+    return s;
+  }
+
+  std::vector<std::unique_ptr<GlobalData>> outputs;
+  for (size_t i = 0; i < computations.size(); ++i) {
+    outputs.push_back(
+        MakeUnique<GlobalData>(stub_, response.responses(i).output()));
+    if (computations[i].execution_profile != nullptr) {
+      *computations[i].execution_profile = response.responses(i).profile();
+    }
+  }
+
+  return std::move(outputs);
 }

 StatusOr<std::vector<DeviceHandle>> Client::GetDeviceHandles(
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index ca8071b..ec883a6 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -409,6 +409,37 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
   return std::move(executables);
 }

+StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
+    const std::vector<const HloModuleProto*>& module_protos,
+    std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
+    Backend* backend,
+    std::vector<std::vector<se::StreamExecutor*>> executors,
+    DeviceMemoryAllocator* device_allocator) {
+  VLOG(1) << Printf("BuildExecutable on service %p", this);
+
+  VLOG(1) << "Computations:";
+  for (const HloModuleProto* proto : module_protos) {
+    VLOG(1) << proto->name();
+  }
+
+  CHECK_EQ(module_protos.size(), module_configs.size());
+  std::vector<std::unique_ptr<HloModule>> modules;
+  for (int64 i = 0; i < module_protos.size(); ++i) {
+    const HloModuleProto* proto = module_protos[i];
+    const HloModuleConfig& config = *module_configs[i];
+    TF_ASSIGN_OR_RETURN(auto module,
+                        HloModule::CreateFromProto(*proto, config));
+    modules.push_back(std::move(module));
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      std::vector<std::unique_ptr<Executable>> executables,
+      backend->compiler()->Compile(std::move(modules), std::move(executors),
+                                   device_allocator));
+
+  return std::move(executables);
+}
+
 StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
     const VersionedComputationHandle& versioned_handle,
     std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
@@ -703,6 +734,47 @@ tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg,
   return computation->SetReturnValue(arg->operand());
 }

+StatusOr<std::vector<se::StreamExecutor*>>
+Service::GetExecutors(const ExecutionOptions& execution_options,
+                      int64 requests_size, int64 request_index) const {
+  if (execution_options.device_handles().empty()) {
+    return FailedPrecondition(
+        "device handles must be given to execute parallel computations");
+  }
+  if (requests_size > 1 && execution_options.device_handles_size() > 1) {
+    return InvalidArgument(
+        "Parallel requests with multiple device handles is not supported. "
+        "Found %lld parallel requests, with request %lld containing %d device "
+        "handles.",
+        requests_size, request_index, execution_options.device_handles_size());
+  }
+  std::vector<se::StreamExecutor*> executors;
+  for (const auto& device_handle : execution_options.device_handles()) {
+    TF_ASSIGN_OR_RETURN(auto replicas,
+                        Replicas(*execute_backend_, device_handle));
+    se::StreamExecutor* executor = replicas[0];
+    CHECK(executor != nullptr);
+    executors.push_back(executor);
+  }
+  return executors;
+}
+
+StatusOr<std::vector<std::vector<const ShapedBuffer*>>> Service::GetArguments(
+    const ExecutionOptions& execution_options,
+    tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments) {
+  // Resolve the allocations for the arguments of the computation, and create
+  // a vector of device memory offsets for the arguments from the allocations.
+  // In the case of partitioned computations, assume all arguments go on the
+  // zeroth core.
+  TF_ASSIGN_OR_RETURN(
+      auto replicas,
+      Replicas(*execute_backend_, execution_options.device_handles(0)));
+  TF_ASSIGN_OR_RETURN(
+      std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
+      ResolveAndValidateArguments(arguments, replicas));
+  return replicated_arguments;
+}
+
 tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
                                             ExecuteParallelResponse* result) {
   VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString();
@@ -731,26 +803,10 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
     // is one of the executors to run the replicated computation.
     const ExecutionOptions& execution_options =
         arg->requests(i).execution_options();
-    if (execution_options.device_handles().empty()) {
-      return FailedPrecondition(
-          "device handles must be given to execute parallel computations");
-    }
-    if (arg->requests_size() > 1 &&
-        execution_options.device_handles_size() > 1) {
-      return InvalidArgument(
-          "Parallel requests with multiple device handles is not supported. "
-          "Found %d parallel requests, with request %lld containing %d device "
-          "handles.",
-          arg->requests_size(), i, execution_options.device_handles_size());
-    }
-    std::vector<se::StreamExecutor*> executors;
-    for (const auto& device_handle : execution_options.device_handles()) {
-      TF_ASSIGN_OR_RETURN(auto replicas,
-                          Replicas(*execute_backend_, device_handle));
-      se::StreamExecutor* executor = replicas[0];
-      CHECK(executor != nullptr);
-      executors.push_back(executor);
-    }
+
+    // Get the executors.
+    TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options,
+                                                     arg->requests_size(), i));

     // Resolve the UserComputation object associated with the requested
     // computation and compute the program shape.
@@ -767,16 +823,9 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
         std::shared_ptr<const ProgramShape> program_shape,
         user_computation->ComputeProgramShape(versioned_handle.version));

-    // Resolve the allocations for the arguments of the computation, and create
-    // a vector of device memory offsets for the arguments from the allocations.
-    // In the case of partitioned computations, assume all arguments go on the
-    // zeroth core.
-    TF_ASSIGN_OR_RETURN(
-        auto replicas,
-        Replicas(*execute_backend_, execution_options.device_handles(0)));
-    TF_ASSIGN_OR_RETURN(
-        std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
-        ResolveAndValidateArguments(request.arguments(), replicas));
+    // Get the replicated arguments.
+    TF_ASSIGN_OR_RETURN(auto replicated_arguments,
+                        GetArguments(execution_options, request.arguments()));

     // Create an HloModuleConfig object for the computation, given the shape of
     // the program and the argument allocations. Here, we care only about the
@@ -839,7 +888,103 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,

 tensorflow::Status Service::ExecuteGraphParallel(
     const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) {
-  return Unimplemented("execute-graph-parallel is not yet implemented");
+  VLOG(1) << "running execute-graph-parallel request";
+
+  std::vector<std::vector<std::vector<const ShapedBuffer*>>> all_arguments;
+  std::vector<std::vector<se::StreamExecutor*>> all_executors;
+  std::vector<const HloModuleProto*> module_protos;
+  std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
+  std::vector<string> computation_names;
+  std::vector<DeviceHandle> device_handles;
+
+  int num_requested_devices =
+      std::accumulate(arg->requests().begin(), arg->requests().end(), 0,
+                      [](int a, const ExecuteGraphRequest& r) -> int {
+                        return a + r.execution_options().device_handles_size();
+                      });
+  if (num_requested_devices * options_.number_of_replicas() >
+      execute_backend_->device_count()) {
+    return FailedPrecondition(
+        "there are not enough stream executors to execute %d computations",
+        num_requested_devices);
+  }
+
+  for (int64 i = 0; i < arg->requests_size(); ++i) {
+    // Get the stream executor for the i'th computation. This stream executor
+    // is one of the executors to run the replicated computation.
+    const ExecutionOptions& execution_options =
+        arg->requests(i).execution_options();
+    const ExecuteGraphRequest& request = arg->requests(i);
+    TF_RET_CHECK(request.has_computation()) << "computations may not be empty";
+    TF_RET_CHECK(request.computation().has_program_shape())
+        << "program shape may not be empty";
+
+    // Get the executors.
+    TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options,
+                                                     arg->requests_size(), i));
+
+    // Get the replicated arguments.
+    TF_ASSIGN_OR_RETURN(auto replicated_arguments,
+                        GetArguments(execution_options, request.arguments()));
+
+    // Create an HloModuleConfig object for the computation, given the shape of
+    // the program and the argument allocations. Here, we care only about the
+    // shapes of the arguments, so, it is sufficient to use the arguments of
+    // replica 0.
+    TF_ASSIGN_OR_RETURN(
+        std::unique_ptr<HloModuleConfig> module_config,
+        CreateModuleConfig(request.computation().program_shape(),
+                           replicated_arguments.front(),
+                           request.execution_options(),
+                           /*user_computation=*/nullptr));
+    VLOG(3)
+        << "ExecuteGraphParallel created HloModuleConfig computation layout: "
+        << module_config->entry_computation_layout().ToString();
+
+    // Adds to the vectors used to build and execute the computations after the loop.
+    all_arguments.push_back(replicated_arguments);
+    all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}});
+    module_protos.push_back(&request.computation());
+    module_configs.push_back(std::move(module_config));
+    computation_names.insert(computation_names.end(), executors.size(),
+                             request.computation().name());
+    all_executors.push_back(executors);
+    device_handles.insert(device_handles.end(),
+                          execution_options.device_handles().begin(),
+                          execution_options.device_handles().end());
+  }
+
+  // Build the HloModules and compile to generate the executables.
+  //
+  // TODO(jlebar): There's currently no way to pass a device allocator to
+  // ExecuteGraphParallel, so we have to pass a null device_allocator below.
+  TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
+                      BuildExecutables(module_protos, std::move(module_configs),
+                                       execute_backend_.get(), all_executors,
+                                       /*device_allocator=*/nullptr));
+  std::vector<Executable*> executable_ptrs;
+  executable_ptrs.reserve(executables.size());
+  for (const auto& executable : executables) {
+    executable_ptrs.push_back(executable.get());
+  }
+
+  // Execute the generated executables in parallel and return the device
+  // handles for each computation's output.
+  ExecutionProfile profile;
+  TF_ASSIGN_OR_RETURN(
+      std::vector<GlobalDataHandle> outputs,
+      ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
+                                       execute_backend_.get(), device_handles,
+                                       computation_names, &profile));
+  for (const GlobalDataHandle& output : outputs) {
+    ExecuteResponse response;
+    *response.mutable_output() = output;
+    *response.mutable_profile() = profile;
+    *result->add_responses() = response;
+  }
+
+  VLOG(1) << "successfully completed 'execute-graph-parallel' request";
+  return tensorflow::Status::OK();
 }

 tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
@@ -872,6 +1017,20 @@ tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg,
   *parallel_arg.add_requests() = *arg;
   ExecuteParallelResponse parallel_result;
   TF_RETURN_IF_ERROR(ExecuteParallel(&parallel_arg, &parallel_result));
+  return PickParallelResponse(parallel_result, result);
+}
+
+tensorflow::Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg,
+                                          ExecuteResponse* result) {
+  ExecuteGraphParallelRequest parallel_arg;
+  *parallel_arg.add_requests() = *arg;
+  ExecuteParallelResponse parallel_result;
+  TF_RETURN_IF_ERROR(ExecuteGraphParallel(&parallel_arg, &parallel_result));
+  return PickParallelResponse(parallel_result, result);
+}
+
+tensorflow::Status Service::PickParallelResponse(
+    const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) {
   // The "result device" selection is a bit hacky, but better than assuming it
   // is device 0. We have b/76035356 for restructuring the client API to clean
   // up the current asymmetries and support more functionalities.
@@ -999,8 +1158,14 @@ tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
   if (!arg->has_computation()) {
     return InvalidArgument("computations may not be empty");
   }
+  if (!arg->computation().has_program_shape()) {
+    return InvalidArgument("program shape may not be empty");
+  }

-  // TODO(b/74197823): Handle partitioning.
+  // If we received multiple device handles, we must partition the module.
+  if (arg->execution_options().device_handles_size() > 1) {
+    return ExecuteOneToN(arg, result);
+  }

   TF_ASSIGN_OR_RETURN(auto replicas,
                       Replicas(*execute_backend_, SingleComputationDeviceHandle()));
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index ebe4a2e..e09d58b 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -278,6 +278,20 @@ class Service : public ServiceInterface {
       const ExecutionOptions& execution_options,
       const UserComputation* user_computation = nullptr);

+  // Picks one response from the parallel responses and fills in the result.
+  Status PickParallelResponse(const ExecuteParallelResponse& parallel_result,
+                              ExecuteResponse* result);
+
+  // Prepares the executors for a parallel execution.
+  StatusOr<std::vector<se::StreamExecutor*>> GetExecutors(
+      const ExecutionOptions& execution_options, int64 requests_size,
+      int64 request_index) const;
+
+  // Prepares the replicated arguments for a parallel execution.
+  StatusOr<std::vector<std::vector<const ShapedBuffer*>>> GetArguments(
+      const ExecutionOptions& execution_options,
+      tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments);
+
 protected:
   friend class LocalExecutable;

@@ -334,6 +348,12 @@ class Service : public ServiceInterface {
       Backend* backend,
       std::vector<std::vector<se::StreamExecutor*>> executors,
       DeviceMemoryAllocator* device_allocator);
+  StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables(
+      const std::vector<const HloModuleProto*>& module_protos,
+      std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
+      Backend* backend,
+      std::vector<std::vector<se::StreamExecutor*>> executors,
+      DeviceMemoryAllocator* device_allocator);

   // Similar to BuildExecutable, but look in the compilation cache for the
   // executable first. If the executable is not in the cache, it is built and
@@ -378,6 +398,8 @@ class Service : public ServiceInterface {
   // will be the result of this computation.
   tensorflow::Status ExecuteOneToN(const ExecuteRequest* arg,
                                    ExecuteResponse* result);
+  tensorflow::Status ExecuteOneToN(const ExecuteGraphRequest* arg,
+                                   ExecuteResponse* result);

   // Convenience function which checks whether the given shape_with_layout
   // (presumably passed by the client to set the result layout) is valid for the
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 283efbb..9cead12 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1566,6 +1566,8 @@ xla_test(
         "//tensorflow/compiler/xla/client:computation_builder",
         "//tensorflow/compiler/xla/client:global_data",
         "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+        "//tensorflow/compiler/xla/client/xla_client:xla_computation",
         "//tensorflow/compiler/xla/tests:client_library_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:test_utils",
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
index 045148c..32e2f2c 100644
--- a/tensorflow/compiler/xla/tests/client_test.cc
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -109,14 +111,14 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { XLA_TEST_F(ClientTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) { - Computation add_with_one_arg, mul_with_two_args, dot_with_one_arg; + XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg; Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr const_arg, client_->TransferToServer(*Literal::CreateR2({{5, 6}, {7, 8}}))); - ComputationBuilder b(client_, TestName() + ".add"); + XlaBuilder b(TestName() + ".add"); b.Add(b.Parameter(0, shape, "param_0"), b.ConstantR2({{1, 2}, {3, 4}})); TF_ASSERT_OK_AND_ASSIGN(add_with_one_arg, b.Build()); @@ -124,14 +126,14 @@ XLA_TEST_F(ClientTest, // We can't really test parallel execution on CPU since all of the cores in a // CPU are presented as a single device. So for now we test "parallel" // execution on a single device. - std::vector computation_instances; + std::vector computation_instances; TF_ASSERT_OK_AND_ASSIGN(std::vector devices, client_->GetDeviceHandles(1)); ASSERT_EQ(devices.size(), 1); ExecutionOptions options = execution_options_; *options.add_device_handles() = devices[0]; - computation_instances.push_back(Client::ComputationInstance( + computation_instances.push_back(Client::XlaComputationInstance( add_with_one_arg, {const_arg.get()}, options, nullptr)); TF_ASSERT_OK_AND_ASSIGN(auto results, -- 2.7.4