From ca4bda919793cc2578e5c0f7440525261da16fdf Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 22:03:16 -0700 Subject: [PATCH] [XLA] Redesign: delete the old service interface. - Computation - ComputeConstant - Execute - ExecuteAsync - ExecuteParallel - GetComputationStats - GetComputationShape - GetLocalShape - IsConstant - LoadComputationSnapshot - Op - SetReturnValue - SnapshotComputation PiperOrigin-RevId: 198669035 --- tensorflow/compiler/xla/client/client.h | 2 - tensorflow/compiler/xla/client/xla_client/BUILD | 1 - tensorflow/compiler/xla/rpc/grpc_service.cc | 88 --- tensorflow/compiler/xla/rpc/grpc_service.h | 47 -- tensorflow/compiler/xla/rpc/grpc_stub.cc | 93 --- tensorflow/compiler/xla/rpc/grpc_stub.h | 39 -- tensorflow/compiler/xla/rpc/xla_service.proto | 60 -- .../compiler/xla/service/compile_only_service.cc | 52 -- .../compiler/xla/service/compile_only_service.h | 33 - tensorflow/compiler/xla/service/local_service.cc | 64 -- tensorflow/compiler/xla/service/local_service.h | 12 - tensorflow/compiler/xla/service/service.cc | 704 --------------------- tensorflow/compiler/xla/service/service.h | 76 --- tensorflow/compiler/xla/service_interface.h | 41 -- 14 files changed, 1312 deletions(-) diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index cda8a71..68f0d0a 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -153,8 +153,6 @@ class Client { // // If output_layout is non-null, then the output of the computation will be // stored using that layout. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> ComputeConstant( const XlaComputation& computation, const Layout* output_layout = nullptr) const; diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD index 0d6e207..507a2dc 100644 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ b/tensorflow/compiler/xla/client/xla_client/BUILD @@ -37,7 +37,6 @@ cc_library( ], ) -# TODO(b/74197823): Replace computation_builder with xla_builder. 
cc_library( name = "xla_builder", srcs = ["xla_builder.cc"], diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc index 5f4dc6b..4e1435f 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -32,19 +32,6 @@ namespace xla { return tensorflow::ToGrpcStatus(s); } -::grpc::Status GRPCService::Computation(::grpc::ServerContext* context, - const ComputationRequest* arg, - ComputationResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->Computation(arg, result); }); -} - -::grpc::Status GRPCService::CreateOp(::grpc::ServerContext* context, - const OpRequest* arg, OpResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->Op(arg, result); }); -} - ::grpc::Status GRPCService::Unregister(::grpc::ServerContext* context, const UnregisterRequest* arg, UnregisterResponse* result) { @@ -60,21 +47,6 @@ namespace xla { }); } -::grpc::Status GRPCService::SetReturnValue(::grpc::ServerContext* context, - const SetReturnValueRequest* arg, - SetReturnValueResponse* results) { - return DelegateRPC([this, arg, results]() { - return service_->SetReturnValue(arg, results); - }); -} - -::grpc::Status GRPCService::Execute(::grpc::ServerContext* context, - const ExecuteRequest* arg, - ExecuteResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->Execute(arg, result); }); -} - ::grpc::Status GRPCService::ExecuteGraph(::grpc::ServerContext* /*context*/, const ExecuteGraphRequest* arg, ExecuteResponse* result) { @@ -82,13 +54,6 @@ namespace xla { [this, arg, result]() { return service_->ExecuteGraph(arg, result); }); } -::grpc::Status GRPCService::ExecuteAsync(::grpc::ServerContext* context, - const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->ExecuteAsync(arg, result); }); -} - ::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) { @@ -136,20 +101,6 @@ namespace xla { [this, arg, result]() { return service_->ResetDevice(arg, result); }); } -::grpc::Status GRPCService::IsConstant(::grpc::ServerContext* context, - const IsConstantRequest* arg, - IsConstantResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->IsConstant(arg, result); }); -} - -::grpc::Status GRPCService::ComputeConstant(::grpc::ServerContext* context, - const ComputeConstantRequest* arg, - ComputeConstantResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->ComputeConstant(arg, result); }); -} - ::grpc::Status GRPCService::GetShape(::grpc::ServerContext* context, const GetShapeRequest* arg, GetShapeResponse* result) { @@ -157,43 +108,4 @@ namespace xla { [this, arg, result]() { return service_->GetShape(arg, result); }); } -::grpc::Status GRPCService::GetComputationShape( - ::grpc::ServerContext* context, const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->GetComputationShape(arg, result); - }); -} - -::grpc::Status GRPCService::GetLocalShape(::grpc::ServerContext* context, - const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->GetLocalShape(arg, result); }); -} - -::grpc::Status GRPCService::GetComputationStats( - ::grpc::ServerContext* context, const 
ComputationStatsRequest* arg, - ComputationStatsResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->GetComputationStats(arg, result); - }); -} - -::grpc::Status GRPCService::SnapshotComputation( - ::grpc::ServerContext* context, const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->SnapshotComputation(arg, result); - }); -} - -::grpc::Status GRPCService::LoadComputationSnapshot( - ::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->LoadComputationSnapshot(arg, result); - }); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h index 50f0279..5cd5731 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.h +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -31,13 +31,6 @@ class GRPCService : public grpc::XlaService::Service { static StatusOr> NewService( se::Platform* platform = nullptr); - ::grpc::Status Computation(::grpc::ServerContext* context, - const ComputationRequest* arg, - ComputationResponse* result) override; - - ::grpc::Status CreateOp(::grpc::ServerContext* context, const OpRequest* arg, - OpResponse* result) override; - ::grpc::Status Unregister(::grpc::ServerContext* context, const UnregisterRequest* arg, UnregisterResponse* result) override; @@ -46,22 +39,10 @@ class GRPCService : public grpc::XlaService::Service { const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - ::grpc::Status SetReturnValue(::grpc::ServerContext* context, - const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - ::grpc::Status Execute(::grpc::ServerContext* context, - const ExecuteRequest* arg, - ExecuteResponse* result) override; - ::grpc::Status ExecuteGraph(::grpc::ServerContext* context, const ExecuteGraphRequest* arg, ExecuteResponse* result) override; - ::grpc::Status ExecuteAsync(::grpc::ServerContext* context, - const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; - ::grpc::Status WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) override; @@ -86,38 +67,10 @@ class GRPCService : public grpc::XlaService::Service { const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; - ::grpc::Status IsConstant(::grpc::ServerContext* context, - const IsConstantRequest* arg, - IsConstantResponse* result) override; - - ::grpc::Status ComputeConstant(::grpc::ServerContext* context, - const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; - ::grpc::Status GetShape(::grpc::ServerContext* context, const GetShapeRequest* arg, GetShapeResponse* result) override; - ::grpc::Status GetComputationShape( - ::grpc::ServerContext* context, const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - - ::grpc::Status GetLocalShape(::grpc::ServerContext* context, - const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; - - ::grpc::Status GetComputationStats(::grpc::ServerContext* context, - const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - - ::grpc::Status SnapshotComputation( - ::grpc::ServerContext* context, const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) override; - - ::grpc::Status LoadComputationSnapshot( - 
::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) override; - private: std::unique_ptr<::xla::Service> service_; diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.cc b/tensorflow/compiler/xla/rpc/grpc_stub.cc index 620ac6c..7b8ab15 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.cc +++ b/tensorflow/compiler/xla/rpc/grpc_stub.cc @@ -62,21 +62,6 @@ Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, }); } -Status GRPCStub::LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->LoadComputationSnapshot(context, *request, response); - }); -} - -Status GRPCStub::Execute(const ExecuteRequest* request, - ExecuteResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->Execute(context, *request, response); - }); -} - Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, ExecuteResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -84,13 +69,6 @@ Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, }); } -Status GRPCStub::ExecuteParallel(const ExecuteParallelRequest* request, - ExecuteParallelResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ExecuteParallel(context, *request, response); - }); -} - Status GRPCStub::ExecuteGraphParallel( const ExecuteGraphParallelRequest* request, ExecuteParallelResponse* response) { @@ -99,13 +77,6 @@ Status GRPCStub::ExecuteGraphParallel( }); } -Status GRPCStub::ExecuteAsync(const ExecuteAsyncRequest* request, - ExecuteAsyncResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ExecuteAsync(context, *request, response); - }); -} - Status GRPCStub::WaitForExecution(const WaitForExecutionRequest* request, WaitForExecutionResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -120,13 +91,6 @@ Status GRPCStub::DeconstructTuple(const DeconstructTupleRequest* request, }); } -Status GRPCStub::GetComputationStats(const ComputationStatsRequest* request, - ComputationStatsResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetComputationStats(context, *request, response); - }); -} - Status GRPCStub::GetComputationGraphStats( const ComputationGraphStatsRequest* request, ComputationStatsResponse* response) { @@ -135,13 +99,6 @@ Status GRPCStub::GetComputationGraphStats( }); } -Status GRPCStub::GetComputationShape(const GetComputationShapeRequest* request, - GetComputationShapeResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetComputationShape(context, *request, response); - }); -} - Status GRPCStub::GetShape(const GetShapeRequest* request, GetShapeResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -163,48 +120,6 @@ Status GRPCStub::CreateChannelHandle(const CreateChannelHandleRequest* request, }); } -// Methods used by ComputationBuilder. 
-Status GRPCStub::Computation(const ComputationRequest* request, - ComputationResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->Computation(context, *request, response); - }); -} - -Status GRPCStub::Op(const OpRequest* request, OpResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->CreateOp(context, *request, response); - }); -} - -Status GRPCStub::GetLocalShape(const GetLocalShapeRequest* request, - GetLocalShapeResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetLocalShape(context, *request, response); - }); -} - -Status GRPCStub::SetReturnValue(const SetReturnValueRequest* request, - SetReturnValueResponse* responses) { - return MakeRPC([this, request, responses](::grpc::ClientContext* context) { - return grpc_stub_->SetReturnValue(context, *request, responses); - }); -} - -Status GRPCStub::IsConstant(const IsConstantRequest* request, - IsConstantResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->IsConstant(context, *request, response); - }); -} - -Status GRPCStub::ComputeConstant(const ComputeConstantRequest* request, - ComputeConstantResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ComputeConstant(context, *request, response); - }); -} - Status GRPCStub::ComputeConstantGraph( const ComputeConstantGraphRequest* request, ComputeConstantResponse* response) { @@ -213,14 +128,6 @@ Status GRPCStub::ComputeConstantGraph( }); } -// Methods used by Computation. -Status GRPCStub::SnapshotComputation(const SnapshotComputationRequest* request, - SnapshotComputationResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->SnapshotComputation(context, *request, response); - }); -} - // Methods used by GlobalData. 
Status GRPCStub::Unregister(const UnregisterRequest* request, UnregisterResponse* response) { diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.h b/tensorflow/compiler/xla/rpc/grpc_stub.h index 5906d45..8dfcb76 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.h +++ b/tensorflow/compiler/xla/rpc/grpc_stub.h @@ -43,39 +43,21 @@ class GRPCStub : public ServiceInterface { Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; - Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* result) override; - - Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; - Status ExecuteGraph(const ExecuteGraphRequest* request, ExecuteResponse* response) override; - Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; - Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request, ExecuteParallelResponse* response) override; - Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; - Status WaitForExecution(const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) override; Status DeconstructTuple(const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - Status GetComputationStats(const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - Status GetComputationGraphStats(const ComputationGraphStatsRequest* request, ComputationStatsResponse* response) override; - Status GetComputationShape(const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - Status GetShape(const GetShapeRequest* arg, GetShapeResponse* result) override; @@ -85,30 +67,9 @@ class GRPCStub : public ServiceInterface { Status CreateChannelHandle(const CreateChannelHandleRequest* arg, CreateChannelHandleResponse* result) override; - // Methods used by ComputationBuilder. - Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; - - Status Op(const OpRequest* arg, OpResponse* result) override; - Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; - - Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; - - Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; - Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) override; - // Methods used by Computation. - Status SnapshotComputation(const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) override; - // Methods used by GlobalData. Status Unregister(const UnregisterRequest* arg, UnregisterResponse* result) override; diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto index c47164e..92eb19e 100644 --- a/tensorflow/compiler/xla/rpc/xla_service.proto +++ b/tensorflow/compiler/xla/rpc/xla_service.proto @@ -75,19 +75,7 @@ service XlaService { rpc GetShape(GetShapeRequest) returns (GetShapeResponse) { } - // Requests the program shape of the referenced computation. - rpc GetComputationShape(GetComputationShapeRequest) - returns (GetComputationShapeResponse) { - } - // Requests the statistics of the given computation. 
- rpc GetComputationStats(ComputationStatsRequest) - returns (ComputationStatsResponse) { - } - - // Requests the statistics of the given computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. rpc GetComputationGraphStats(ComputationGraphStatsRequest) returns (ComputationStatsResponse) { } @@ -121,15 +109,6 @@ service XlaService { rpc ResetDevice(ResetDeviceRequest) returns (ResetDeviceResponse) { } - // Tests if an expression is a compile-time constant. - rpc IsConstant(IsConstantRequest) returns (IsConstantResponse) { - } - - // Computes the value of a constant expression. - rpc ComputeConstant(ComputeConstantRequest) - returns (ComputeConstantResponse) { - } - // Computes the value of a constant expression. The request contains the // computation graph for the constant expression. rpc ComputeConstantGraph(ComputeConstantGraphRequest) returns (ComputeConstantResponse) { } @@ -165,20 +144,6 @@ service XlaService { rpc SetReturnValue(SetReturnValueRequest) returns (SetReturnValueResponse) { } - // Computation creates a new computation with the given name. - // A unique ComputationHandle is returned. - rpc Computation(ComputationRequest) returns (ComputationResponse) { - } - - // Adds a new op to a computation. - rpc CreateOp(OpRequest) returns (OpResponse) { - } - - // Invokes the provided computation with the provided global data passed as - // immutable arguments. Returns global data output and execution timing. - rpc Execute(ExecuteRequest) returns (ExecuteResponse) { - } - // Invokes the provided computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. @@ -188,38 +153,13 @@ service XlaService { // Invokes the provided list of computations in parallel with the provided // global data for each computation. Returns a list of global data output and // execution timing. - rpc ExecuteParallel(ExecuteParallelRequest) - returns (ExecuteParallelResponse) { - } - - // Invokes the provided list of computations in parallel with the provided - // global data for each computation. Returns a list of global data output and - // execution timing. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. rpc ExecuteGraphParallel(ExecuteGraphParallelRequest) returns (ExecuteParallelResponse) { } - // Invokes the provided computation with the provided global data passed as - // immutable arguments. Returns a handle to the execution. - rpc ExecuteAsync(ExecuteAsyncRequest) returns (ExecuteAsyncResponse) { - } - // Waits until the given execution (asynchronously launched) is complete, and // returns the global data output. rpc WaitForExecution(WaitForExecutionRequest) returns (WaitForExecutionResponse) { } - - // Serializes a computation to proto form, so it can be loaded via - // LoadComputationSnapshot. - rpc SnapshotComputation(SnapshotComputationRequest) - returns (SnapshotComputationResponse) { - } - - // Loads a computation from a captured snapshot. 
- rpc LoadComputationSnapshot(LoadComputationSnapshotRequest) - returns (LoadComputationSnapshotResponse) { - } } diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index d39fd73..c2e698a 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -104,56 +104,4 @@ CompileOnlyService::CompileAheadOfTime( return compiler_->CompileAheadOfTime(std::move(hlo_modules), options); } -StatusOr>> -CompileOnlyService::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { - std::vector> hlo_modules; - for (const AotComputationInstance& instance : computations) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(instance.computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - const DebugOptions& debug_options = options.debug_options(); - - // Dump computation proto state if flag is set. - const string& directory_path = debug_options.xla_dump_computations_to(); - if (!directory_path.empty()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr session_module, - computation_tracker_.SnapshotComputation(versioned_handle.handle)); - string filename = tensorflow::strings::StrCat( - "computation_", versioned_handle.handle.handle(), "__", - session_module->entry().name(), "__version_", - versioned_handle.version); - const string& per_host_path = tensorflow::io::JoinPath( - directory_path, tensorflow::port::Hostname()); - - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(per_host_path, filename, - *session_module)); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - ExecutionOptions execution_options; - *execution_options.mutable_debug_options() = debug_options; - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, instance.argument_layouts, - &execution_options, user_computation)); - - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, - computation_tracker_.BuildHloModule( - versioned_handle, *module_config, - /*include_unreachable_instructions=*/true)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); - hlo_modules.push_back(std::move(hlo_module)); - } - - return compiler_->CompileAheadOfTime(std::move(hlo_modules), options); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index 7f2ce0e..e6a66c2 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -38,24 +38,7 @@ class CompileOnlyService : public Service { static StatusOr> NewService( const ServiceOptions& options); - // A description of a computation to compile using CompileAheadOfTime. - struct AotComputationInstance { - ComputationHandle computation; - std::vector argument_layouts; - const Shape* result_layout = nullptr; - }; - - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. See - // |CompileOnlyClient::CompileAheadOfTime| for additional details. - StatusOr>> - CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& Options); - // A description of a xla computation to compile using CompileAheadOfTime. 
- // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. struct AotXlaComputationInstance { HloModuleProto computation; std::vector argument_layouts; @@ -65,31 +48,15 @@ class CompileOnlyService : public Service { // Compiles a list of xla computations for ahead-of-time execution. This is // intended for use in static compilation. See // |CompileOnlyClient::CompileAheadOfTime| for additional details. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr>> CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, const AotCompilationOptions& options); - // Override Service methods that require or imply the existence of an - // execute backend. Note that this does not include TransferToClient, as - // computing constants produces global data that we may wish to transfer. - Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } - Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); } - Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } Status WaitForExecution(const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index f54b52b..968db7c 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -136,70 +136,6 @@ ExecutionOptions CreateExecutionOptions( } // namespace StatusOr> LocalService::CompileExecutable( - const ComputationHandle& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& build_options) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - // Validate incoming layouts. 
- if (argument_layouts.size() != program_shape->parameters_size()) { - return InvalidArgument( - "Invalid number of arguments for computation: expected %d, got %zu.", - program_shape->parameters_size(), argument_layouts.size()); - } - for (int i = 0; i < argument_layouts.size(); ++i) { - const Shape& argument_shape = *argument_layouts[i]; - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); - if (!ShapeUtil::Compatible(argument_shape, program_shape->parameters(i))) { - tensorflow::gtl::optional metadata = - user_computation->ParameterMetadata(i); - auto metadata_string = [&metadata]() -> string { - if (!metadata.has_value()) { - return ""; - } - CHECK(metadata.value() != nullptr); - const OpMetadata& m = *metadata.value(); - if (!m.source_file().empty()) { - return tensorflow::strings::Printf( - " (%s:%d)", m.source_file().c_str(), m.source_line()); - } - return ""; - }; - return InvalidArgument( - "Invalid argument shape for argument %d%s, expected %s, got %s.", i, - metadata_string().c_str(), - ShapeUtil::HumanString(program_shape->parameters(i)).c_str(), - ShapeUtil::HumanString(argument_shape).c_str()); - } - } - if (build_options.result_layout() != nullptr) { - TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( - *build_options.result_layout(), program_shape->result())); - } - - ExecutionOptions execution_options = - CreateExecutionOptions(build_options, program_shape.get()); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, argument_layouts, - &execution_options, user_computation)); - - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - execute_backend_->stream_executor(build_options.device_ordinal())); - - return BuildExecutable(versioned_handle, std::move(module_config), - execute_backend_.get(), executor, - build_options.device_allocator()); -} - -StatusOr> LocalService::CompileExecutable( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& build_options) { diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index b55f119..39d6734 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -41,23 +41,11 @@ class LocalService : public Service { static StatusOr> NewService( const ServiceOptions& options); - // Builds an Executable with the given argument layouts and options. If - // result_layout is non-null, then the executable is compiled to produce a - // result of the given layout. If device_allocator is non-null, then the - // compiler may use it to allocate temp space on the device. The compiler is - // responsible for freeing any memory it allocates this way. - StatusOr> CompileExecutable( - const ComputationHandle& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options); - // Builds an Executable with the given XlaComputation, argument layouts and // options. If result_layout is non-null, then the executable is compiled to // produce a result of the given layout. If device_allocator is non-null, // then the compiler may use it to allocate temp space on the device. The // compiler is responsible for freeing any memory it allocates this way. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. 
StatusOr> CompileExecutable( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 5a813dc..79c098a 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -195,20 +195,6 @@ Service::Service(const ServiceOptions& options, } } -Status Service::Computation(const ComputationRequest* arg, - ComputationResponse* result) { - if (arg->name().empty()) { - return InvalidArgument("computation request needs a name"); - } - - *result->mutable_computation() = - computation_tracker_.NewComputation(arg->name()); - VLOG(1) << Printf("Created new computation %s on service %p, name %s", - result->computation().ShortDebugString().c_str(), this, - arg->name().c_str()); - return Status::OK(); -} - Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg, CreateChannelHandleResponse* result) { *result->mutable_channel() = channel_tracker_.NewChannel(); @@ -806,13 +792,6 @@ StatusOr Service::ExecuteAndRegisterResult( result_tag); } -Status Service::SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - return computation->SetReturnValue(arg->operand()); -} - StatusOr> Service::GetExecutors( const ExecutionOptions& execution_options, int64 requests_size, int64 request_index) const { @@ -854,117 +833,6 @@ StatusOr>> Service::GetArguments( return replicated_arguments; } -Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) { - VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); - - std::vector>> all_arguments; - std::vector> all_executors; - std::vector versioned_handles; - std::vector> module_configs; - std::vector computation_names; - std::vector device_handles; - - int num_requested_devices = - std::accumulate(arg->requests().begin(), arg->requests().end(), 0, - [](int a, const ExecuteRequest& r) -> int { - return a + r.execution_options().device_handles_size(); - }); - if (num_requested_devices * options_.number_of_replicas() > - execute_backend_->device_count()) { - return FailedPrecondition( - "there are not enough stream executors to execute %d computations", - num_requested_devices); - } - - for (int64 i = 0; i < arg->requests_size(); ++i) { - // Get the stream executor for the i'th computation. This stream executor - // is one of the executors to run the replicated computation. - const ExecutionOptions& execution_options = - arg->requests(i).execution_options(); - - // Get the executors. - TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options, - arg->requests_size(), i)); - - // Resolve the UserComputation object associated with the requested - // computation and compute the program shape. - const ExecuteRequest& request = arg->requests(i); - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(request.computation())); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - // Get the replicated arguments. 
- TF_ASSIGN_OR_RETURN(auto replicated_arguments, - GetArguments(execution_options, request.arguments())); - - // Create an HloModuleConfig object for the computation, given the shape of - // the program and the argument allocations. Here, we care only about the - // shapes of the arguments, so, it is sufficient to use the arguments of - // replica 0. - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - request.execution_options(), user_computation)); - VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); - - // Adds to the vectors to build and execute the computations after the loop. - all_arguments.push_back(replicated_arguments); - all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}}); - versioned_handles.push_back(versioned_handle); - module_configs.push_back(std::move(module_config)); - computation_names.insert(computation_names.end(), executors.size(), - user_computation->name()); - all_executors.push_back(executors); - device_handles.insert(device_handles.end(), - execution_options.device_handles().begin(), - execution_options.device_handles().end()); - } - - // Build the user computations into HloModules and compile to generate the - // executables. - // - // TODO(jlebar): There's currently no way to pass a device allocator to - // ExecuteParallel, so we have to pass a null device_allocator below. - TF_ASSIGN_OR_RETURN( - std::vector> executables, - BuildExecutables(versioned_handles, std::move(module_configs), - execute_backend_.get(), all_executors, - /*device_allocator=*/nullptr)); - std::vector executable_ptrs; - executable_ptrs.reserve(executables.size()); - for (const auto& executable : executables) { - executable_ptrs.push_back(executable.get()); - } - - // Execute the generated executables in parallel and return the device - // handles for each computation's output. 
- ExecutionProfile profile; - TF_ASSIGN_OR_RETURN( - std::vector outputs, - ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, - execute_backend_.get(), device_handles, - computation_names, &profile)); - for (const GlobalDataHandle& output : outputs) { - ExecuteResponse response; - *response.mutable_output() = output; - *response.mutable_profile() = profile; - *result->add_responses() = response; - } - - VLOG(1) << "successfully completed 'execute-parallel' request"; - return Status::OK(); -} - Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) { VLOG(1) << "running execute-graph-parallel request"; @@ -1090,15 +958,6 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, return Status::OK(); } -Status Service::ExecuteOneToN(const ExecuteRequest* arg, - ExecuteResponse* result) { - ExecuteParallelRequest parallel_arg; - *parallel_arg.add_requests() = *arg; - ExecuteParallelResponse parallel_result; - TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); - return PickParallelResponse(parallel_result, result); -} - Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result) { ExecuteGraphParallelRequest parallel_arg; @@ -1131,80 +990,6 @@ Status Service::PickParallelResponse( return Status::OK(); } -Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { - VLOG(1) << "running execute request: " << arg->ShortDebugString(); - - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - // If we received multiple device handles, we must partition the module. - if (arg->execution_options().device_handles_size() > 1) { - return ExecuteOneToN(arg, result); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); - TF_ASSIGN_OR_RETURN( - std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); - - // Since we care only about the shapes of the arguments, it is sufficient to - // use the arguments of replica 0. 
- TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - arg->execution_options(), user_computation)); - - VLOG(3) << "Execute created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); - - TF_ASSIGN_OR_RETURN( - std::shared_ptr executable, - BuildAndCacheExecutable(versioned_handle, std::move(module_config), - execute_backend_.get(), - execute_backend_->default_stream_executor(), - result->mutable_profile())); - - if (executable->dumping()) { - executable->session_module()->set_execution_platform( - execute_backend_->platform()->Name()); - TF_RETURN_IF_ERROR(RecordArguments( - replicated_arguments.front(), - execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->session_module())); - } - - TF_ASSIGN_OR_RETURN( - *result->mutable_output(), - ExecuteAndRegisterResult( - executable.get(), replicated_arguments, execute_backend_.get(), - "result of " + user_computation->name(), result->mutable_profile())); - - if (executable->dumping()) { - TF_ASSIGN_OR_RETURN( - const ShapedBuffer* result_buffer, - allocation_tracker_.ResolveForReplica(result->output(), 0)); - TF_RETURN_IF_ERROR(RecordResult( - *result_buffer, execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->session_module())); - TF_RETURN_IF_ERROR(executable->DumpSessionModule()); - } - - VLOG(1) << "successfully completed 'execute' request"; - return Status::OK(); -} - StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, @@ -1310,86 +1095,6 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, return Status::OK(); } -Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) { - VLOG(1) << "running execute-async request: " << arg->ShortDebugString(); - - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); - TF_RET_CHECK(!replicas.empty()); - TF_ASSIGN_OR_RETURN( - std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - arg->execution_options(), user_computation)); - - VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); - - ExecutionProfile profile; - - TF_ASSIGN_OR_RETURN( - std::shared_ptr executable, - BuildAndCacheExecutable( - versioned_handle, std::move(module_config), execute_backend_.get(), - execute_backend_->default_stream_executor(), &profile)); - - // Set up streams. 
- std::vector::SmartPtr> streams; - for (se::StreamExecutor* executor : replicas) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, - execute_backend_->BorrowStream(executor)); - streams.push_back(std::move(stream)); - } - - std::vector result_buffers; - for (size_t i = 0; i < streams.size(); ++i) { - const auto& stream = streams[i]; - ExecutableRunOptions options; - options.set_stream(stream.get()); - options.set_allocator(execute_backend_->memory_allocator()); - options.set_intra_op_thread_pool( - execute_backend_->eigen_intra_op_thread_pool_device()); - - ServiceExecutableRunOptions service_options( - options, execute_backend_->StreamBorrower()); - - TF_ASSIGN_OR_RETURN(ScopedShapedBuffer this_result_buffer, - executable->ExecuteAsyncOnStream( - &service_options, replicated_arguments[i])); - - result_buffers.emplace_back(std::move(this_result_buffer)); - } - - TF_ASSIGN_OR_RETURN( - GlobalDataHandle output, - allocation_tracker_.RegisterReplicatedBuffers( - std::move(result_buffers), "result of " + user_computation->name())); - - *result->mutable_execution() = execution_tracker_.Register( - execute_backend_.get(), std::move(streams), profile, output); - streams.clear(); - - VLOG(1) << "successfully completed 'execute-async' request"; - return Status::OK(); -} - Status Service::WaitForExecution(const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) { TF_ASSIGN_OR_RETURN(const auto execution, @@ -1556,117 +1261,6 @@ Status Service::ResetDevice(const ResetDeviceRequest* arg, return execute_backend_->ResetDevices(); } -Status Service::IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandleAtOperation(arg->operand()); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - bool is_constant, - user_computation->IsConstant(arg->operand(), arg->num_parameters())); - - result->set_is_constant(is_constant); - return Status::OK(); -} - -Status Service::ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandleAtOperation(arg->operand()); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - bool is_constant, - user_computation->IsConstant(arg->operand(), arg->parameters_size())); - if (!is_constant) { - StatusOr op_request_status = - user_computation->LookUpRequestForErrorReporting(arg->operand()); - string op_request_string = ""; - if (op_request_status.ok()) { - op_request_string = op_request_status.ValueOrDie()->ShortDebugString(); - } - return InvalidArgument( - "Operand to ComputeConstant depends on a parameter.\n\n" - " op requested for constant evaluation: %s\n\n" - "This is an internal error that typically happens when the XLA user " - "(e.g. TensorFlow) is attempting to determine a value that must be a " - "compile-time constant (e.g. an array dimension) but it is not capable " - "of being evaluated at XLA compile time.\n\n" - "Please file a usability bug with the framework being used (e.g. 
" - "TensorFlow).", - op_request_string.c_str()); - } - - // We can't use ComputeProgramShape because it checks that all parameter - // instructions are present and contiguous. Instead construct ProgramShape - // directly. - ProgramShape program_shape; - TF_ASSIGN_OR_RETURN(*program_shape.mutable_result(), - user_computation->GetShape(arg->operand())); - - TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); - - ExecutionOptions execution_options = xla::CreateDefaultExecutionOptions(); - execution_options.mutable_debug_options()->set_xla_enable_fast_math(false); - execution_options.mutable_debug_options() - ->set_xla_eliminate_hlo_implicit_broadcast(true); - *execution_options.mutable_shape_with_output_layout() = - program_shape.result(); - - Shape shape_with_output_layout(program_shape.result()); - if (arg->has_output_layout()) { - TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( - arg->output_layout(), execution_options.shape_with_output_layout())); - *execution_options.mutable_shape_with_output_layout()->mutable_layout() = - arg->output_layout(); - } - - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(program_shape, {}, execution_options, - user_computation)); - - // Exclude dead parameter instructions for the purpose of computing constants. - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, *module_config, - /*include_unreachable_instructions=*/ - false)); - - std::vector> parameters(arg->parameters_size()); - for (int64 i = 0; i < arg->parameters_size(); ++i) { - TF_ASSIGN_OR_RETURN(parameters[i], - Literal::CreateFromProto(arg->parameters(i))); - } - HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN( - auto result_literal, - evaluator.Evaluate>(*module, parameters)); - - // Since the shape_with_output_layout option in ExecutionOption is - // non-effective to the Evaluator results, explicit relayout here. - // - // TODO(b/77824332): Make HloEvaluator take care of the re-layout. 
- if (arg->has_output_layout()) { - result_literal = result_literal->Relayout(arg->output_layout()); - } - *result->mutable_literal() = result_literal->ToProto(); - - return Status::OK(); -} - Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) { if (!arg->has_computation()) { @@ -1716,60 +1310,6 @@ Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { return Status::OK(); } -Status Service::GetComputationShape(const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - computation->GetVersionedHandle(); - - TF_ASSIGN_OR_RETURN(auto program_shape, computation->ComputeProgramShape( - versioned_handle.version)); - *result->mutable_program_shape() = *program_shape; - return Status::OK(); -} - -Status Service::GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - TF_ASSIGN_OR_RETURN(*result->mutable_shape(), - computation->GetShape(arg->operand())); - return Status::OK(); -} - -Status Service::GetComputationStats(const ComputationStatsRequest* arg, - ComputationStatsResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - HloModuleConfig config; - config.set_debug_options(arg->debug_options()); - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, config)); - - hlo_graph_dumper::MaybeDumpHloModule(*module, - "computation statistics subject"); - - // Run HLO analysis to get the computation statistics. 
- HloCostAnalysis analysis( - execute_backend_->compiler()->ShapeSizeBytesFunction()); - - TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis)); - - ComputationStats stats; - stats.set_flop_count(analysis.flop_count()); - stats.set_transcendental_count(analysis.transcendental_count()); - *result->mutable_stats() = stats; - return Status::OK(); -} - Status Service::GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) { if (!arg->has_computation()) { @@ -1812,250 +1352,6 @@ Status Service::AddInstruction( return Status::OK(); } -Status Service::Op(const OpRequest* arg, OpResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - StatusOr handle_status; - - switch (arg->op_case()) { - case OpRequest::kBatchNormTrainingRequest: - handle_status = computation->AddBatchNormTrainingInstruction( - arg->batch_norm_training_request()); - break; - case OpRequest::kBatchNormInferenceRequest: - handle_status = computation->AddBatchNormInferenceInstruction( - arg->batch_norm_inference_request()); - break; - case OpRequest::kBatchNormGradRequest: - handle_status = computation->AddBatchNormGradInstruction( - arg->batch_norm_grad_request()); - break; - case OpRequest::kBinaryOpRequest: - handle_status = - computation->AddBinaryInstruction(arg->binary_op_request()); - break; - case OpRequest::kBroadcastRequest: - handle_status = - computation->AddBroadcastInstruction(arg->broadcast_request()); - break; - case OpRequest::kCallRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->call_request().to_apply())); - handle_status = - computation->AddCallInstruction(arg->call_request(), *to_apply); - break; - } - case OpRequest::kConcatenateRequest: - handle_status = - computation->AddConcatenateInstruction(arg->concatenate_request()); - break; - case OpRequest::kConditionalRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * true_computation, - computation_tracker_.Resolve( - arg->conditional_request().true_computation())); - TF_ASSIGN_OR_RETURN(UserComputation * false_computation, - computation_tracker_.Resolve( - arg->conditional_request().false_computation())); - handle_status = computation->AddConditionalInstruction( - arg->conditional_request(), *true_computation, *false_computation); - break; - } - case OpRequest::kConstantRequest: - handle_status = - computation->AddConstantInstruction(arg->constant_request()); - break; - case OpRequest::kConvertRequest: - handle_status = - computation->AddConvertInstruction(arg->convert_request()); - break; - case OpRequest::kBitcastConvertRequest: - handle_status = computation->AddBitcastConvertInstruction( - arg->bitcast_convert_request()); - break; - case OpRequest::kConvolveRequest: - handle_status = - computation->AddConvolveInstruction(arg->convolve_request()); - break; - case OpRequest::kCrossReplicaSumRequest: - handle_status = computation->AddCrossReplicaSumInstruction( - arg->cross_replica_sum_request()); - break; - case OpRequest::kCustomCallRequest: - handle_status = - computation->AddCustomCallInstruction(arg->custom_call_request()); - break; - case OpRequest::kDotRequest: - handle_status = computation->AddDotInstruction(arg->dot_request()); - break; - case OpRequest::kDynamicSliceRequest: - handle_status = - computation->AddDynamicSliceInstruction(arg->dynamic_slice_request()); - break; - case OpRequest::kDynamicUpdateSliceRequest: - handle_status = 
computation->AddDynamicUpdateSliceInstruction( - arg->dynamic_update_slice_request()); - break; - case OpRequest::kFftRequest: - handle_status = computation->AddFftInstruction(arg->fft_request()); - break; - case OpRequest::kGatherRequest: - handle_status = computation->AddGatherInstruction(arg->gather_request()); - break; - case OpRequest::kGetTupleElementRequest: - handle_status = computation->AddGetTupleElementInstruction( - arg->get_tuple_element_request()); - break; - case OpRequest::kInfeedRequest: - handle_status = computation->AddInfeedInstruction(arg->infeed_request()); - break; - case OpRequest::kOutfeedRequest: - handle_status = - computation->AddOutfeedInstruction(arg->outfeed_request()); - break; - case OpRequest::kHostComputeRequest: - handle_status = - computation->AddHostComputeInstruction(arg->host_compute_request()); - break; - case OpRequest::kMapRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->map_request().to_apply())); - handle_status = - computation->AddMapInstruction(arg->map_request(), *to_apply); - break; - } - case OpRequest::kPadRequest: - handle_status = computation->AddPadInstruction(arg->pad_request()); - break; - case OpRequest::kParameterRequest: - handle_status = - computation->AddParameterInstruction(arg->parameter_request()); - break; - case OpRequest::kReduceRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->reduce_request().to_apply())); - handle_status = - computation->AddReduceInstruction(arg->reduce_request(), *to_apply); - break; - } - case OpRequest::kReducePrecisionRequest: { - handle_status = computation->AddReducePrecisionInstruction( - arg->reduce_precision_request()); - break; - } - case OpRequest::kReduceWindowRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * to_apply, - computation_tracker_.Resolve( - arg->reduce_window_request().to_apply())); - handle_status = computation->AddReduceWindowInstruction( - arg->reduce_window_request(), *to_apply); - break; - } - case OpRequest::kReshapeRequest: - handle_status = - computation->AddReshapeInstruction(arg->reshape_request()); - break; - case OpRequest::kReverseRequest: - handle_status = - computation->AddReverseInstruction(arg->reverse_request()); - break; - case OpRequest::kRngRequest: - handle_status = computation->AddRngInstruction(arg->rng_request()); - break; - case OpRequest::kSelectAndScatterRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * select, - computation_tracker_.Resolve( - arg->select_and_scatter_request().select())); - TF_ASSIGN_OR_RETURN(UserComputation * scatter, - computation_tracker_.Resolve( - arg->select_and_scatter_request().scatter())); - handle_status = computation->AddSelectAndScatterInstruction( - arg->select_and_scatter_request(), *select, *scatter); - break; - } - case OpRequest::kSliceRequest: - handle_status = computation->AddSliceInstruction(arg->slice_request()); - break; - case OpRequest::kTernaryOpRequest: - handle_status = - computation->AddTernaryInstruction(arg->ternary_op_request()); - break; - case OpRequest::kTraceRequest: - return computation->AddTraceInstruction(arg->trace_request()); - case OpRequest::kTransposeRequest: - handle_status = - computation->AddTransposeInstruction(arg->transpose_request()); - break; - case OpRequest::kUnaryOpRequest: - handle_status = computation->AddUnaryInstruction(arg->unary_op_request()); - break; - case OpRequest::kVariadicOpRequest: - handle_status = - computation->AddVariadicInstruction(arg->variadic_op_request()); - 
break; - case OpRequest::kWhileRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * condition, - computation_tracker_.Resolve(arg->while_request().condition())); - TF_ASSIGN_OR_RETURN( - UserComputation * body, - computation_tracker_.Resolve(arg->while_request().body())); - handle_status = computation->AddWhileInstruction(arg->while_request(), - *condition, *body); - break; - } - case OpRequest::kSendRequest: { - TF_RETURN_IF_ERROR( - channel_tracker_.RegisterSend(arg->send_request().channel_handle())); - // Send does not return a value, but we need a handle to be able to - // set OpMetadata and OpSharding (device assignment). - handle_status = computation->AddSendInstruction(arg->send_request()); - break; - } - case OpRequest::kRecvRequest: { - TF_RETURN_IF_ERROR( - channel_tracker_.RegisterRecv(arg->recv_request().channel_handle())); - handle_status = computation->AddRecvInstruction(arg->recv_request()); - break; - } - case OpRequest::OP_NOT_SET: - return InvalidArgument("XLA service received OpRequest with OP_NOT_SET"); - default: - return InvalidArgument("Unsupported operation in XLA service"); - } - TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle_status); - - // We set the debug metadata here, because we slice off part of the OpRequest - // proto in the above switch statement. - TF_ASSIGN_OR_RETURN(ComputationDataHandle handle, handle_status); - TF_RETURN_IF_ERROR(computation->SetOpMetadata(handle, arg->metadata())); - if (arg->has_sharding()) { - TF_RETURN_IF_ERROR(computation->SetOpSharding(handle, arg->sharding())); - } - return Status::OK(); -} - -Status Service::SnapshotComputation(const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.SnapshotComputation(arg->computation())); - - result->set_allocated_module(module.release()); - - return Status::OK(); -} - -Status Service::LoadComputationSnapshot( - const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) { - TF_ASSIGN_OR_RETURN(*result->mutable_computation(), - computation_tracker_.LoadSessionModule(arg->module())); - return Status::OK(); -} - DeviceHandle Service::SingleComputationDeviceHandle() const { DeviceHandle device_handle; device_handle.set_handle(0); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 81fbd41..b3c0eac 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -83,11 +83,6 @@ class Service : public ServiceInterface { static StatusOr> NewService( const ServiceOptions& options); - // Creates a new computation with the given name. - // A unique ComputationHandle is returned. - Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; - // Unregisters a previously-allocated global handle. // // If the handle given is not currently allocated, a NOT_FOUND status is @@ -100,35 +95,15 @@ class Service : public ServiceInterface { Status DeconstructTuple(const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - // Modifies the provided computation so that subsequent executions - // will compute the provided ComputationDataHandle, rather than the - // last expression enqueued on that Computation. - Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - // Executes a computation with the provided global data passed as - // immutable arguments. 
-  // immutable arguments. Returns global data output and execution timing.
-  Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override;
-
   // Executes a computation with the provided global data passed as
   // immutable arguments. The request contains the whole computation graph.
   // Returns global data output and execution timing.
-  //
-  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
   Status ExecuteGraph(const ExecuteGraphRequest* arg,
                       ExecuteResponse* result) override;

   // Executes one or more computations in parallel with the provided global data
   // passed as immutable arguments. Returns global data output for each
   // computation.
-  Status ExecuteParallel(const ExecuteParallelRequest* arg,
-                         ExecuteParallelResponse* result) override;
-
-  // Executes one or more computations in parallel with the provided global data
-  // passed as immutable arguments. Returns global data output for each
-  // computation.
-  //
-  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
   Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
                               ExecuteParallelResponse* result) override;

@@ -143,16 +118,6 @@ class Service : public ServiceInterface {
   Status GetDeviceHandles(const GetDeviceHandlesRequest* arg,
                           GetDeviceHandlesResponse* result) override;

-  // Asynchronously executes a computation with provided arguments. Invokes
-  // the provided computation with the provided global data passed as
-  // immutable arguments. Returns a handle to the execution.
-  //
-  // (Note: The corresponding function in xla::Client was removed as part of
-  // b/64116060, in an attempt to simplify our API. We're keeping this around
-  // for now in case we want to expose this to clients in a different way.)
-  Status ExecuteAsync(const ExecuteAsyncRequest* arg,
-                      ExecuteAsyncResponse* result) override;
-
   // Waits until the specified execution is complete and returns the result.
   // Calling this API multiple times with the same execution handle returns the
   // method with an error since the execution handle is destroyed after the
@@ -190,13 +155,6 @@ class Service : public ServiceInterface {
   Status ResetDevice(const ResetDeviceRequest* arg,
                      ResetDeviceResponse* result) override;

-  // Tests if an expression is a compile-time constant.
-  Status IsConstant(const IsConstantRequest* arg,
-                    IsConstantResponse* result) override;
-
-  // Computes the value of a constant expression.
-  Status ComputeConstant(const ComputeConstantRequest* arg,
-                         ComputeConstantResponse* result) override;
   Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
                               ComputeConstantResponse* result) override;

@@ -205,43 +163,10 @@ class Service : public ServiceInterface {
   Status GetShape(const GetShapeRequest* arg,
                   GetShapeResponse* result) override;

-  // Returns the program shape of the computation associated with the given
-  // handle.
-  Status GetComputationShape(const GetComputationShapeRequest* arg,
-                             GetComputationShapeResponse* result) override;
-
-  /////
-  // Computation-oriented methods.
-
-  // Enqueues an Op on the computation.
-  Status Op(const OpRequest* arg, OpResponse* result) override;
-
-  // Retrieves the inferred shape for a value within a computation.
-  Status GetLocalShape(const GetLocalShapeRequest* arg,
-                       GetLocalShapeResponse* result) override;
-
-  // Retrieves the statistics of a computation.
-  Status GetComputationStats(const ComputationStatsRequest* arg,
-                             ComputationStatsResponse* result) override;
-
   // Retrieves the statistics of a computation.
-  //
-  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
   Status GetComputationGraphStats(const ComputationGraphStatsRequest* arg,
                                   ComputationStatsResponse* result) override;

-  // Snapshots the current state of a computation handle into a serializable
-  // protocol buffer form, so it can be loaded via
-  // LoadComputationSnapshot.
-  Status SnapshotComputation(const SnapshotComputationRequest* arg,
-                             SnapshotComputationResponse* result) override;
-
-  // Loads a computation from a serialized protocol buffer created via
-  // SnapshotComputation.
-  Status LoadComputationSnapshot(
-      const LoadComputationSnapshotRequest* arg,
-      LoadComputationSnapshotResponse* result) override;
-
   // Creates a unique channel handle that can be used for Send/Recv
   // instructions.
   Status CreateChannelHandle(const CreateChannelHandleRequest* arg,
@@ -382,7 +307,6 @@ class Service : public ServiceInterface {
   // Executes a single computation which has more than one target device.
   // The N devices are expected to all return an empty tuple, but one, which
   // will be the result of this computation.
-  Status ExecuteOneToN(const ExecuteRequest* arg, ExecuteResponse* result);
   Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result);

   // Convenience function which checks whether the given shape_with_layout
diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h
index 141347a..14c35e7 100644
--- a/tensorflow/compiler/xla/service_interface.h
+++ b/tensorflow/compiler/xla/service_interface.h
@@ -47,41 +47,22 @@ class ServiceInterface {
   virtual Status ResetDevice(const ResetDeviceRequest* arg,
                              ResetDeviceResponse* result) = 0;

-  virtual Status LoadComputationSnapshot(
-      const LoadComputationSnapshotRequest* request,
-      LoadComputationSnapshotResponse* result) = 0;
-
-  virtual Status Execute(const ExecuteRequest* arg,
-                         ExecuteResponse* result) = 0;
-
   virtual Status ExecuteGraph(const ExecuteGraphRequest* arg,
                               ExecuteResponse* result) = 0;

-  virtual Status ExecuteParallel(const ExecuteParallelRequest* arg,
-                                 ExecuteParallelResponse* result) = 0;
-
   virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
                                       ExecuteParallelResponse* result) = 0;

-  virtual Status ExecuteAsync(const ExecuteAsyncRequest* arg,
-                              ExecuteAsyncResponse* result) = 0;
-
   virtual Status WaitForExecution(const WaitForExecutionRequest* arg,
                                   WaitForExecutionResponse* result) = 0;

   virtual Status DeconstructTuple(const DeconstructTupleRequest* arg,
                                   DeconstructTupleResponse* result) = 0;

-  virtual Status GetComputationStats(const ComputationStatsRequest* arg,
-                                     ComputationStatsResponse* result) = 0;
-
   virtual Status GetComputationGraphStats(
       const ComputationGraphStatsRequest* arg,
       ComputationStatsResponse* result) = 0;

-  virtual Status GetComputationShape(const GetComputationShapeRequest* arg,
-                                     GetComputationShapeResponse* result) = 0;
-
   virtual Status GetShape(const GetShapeRequest* arg,
                           GetShapeResponse* result) = 0;

@@ -91,31 +72,9 @@ class ServiceInterface {
   virtual Status GetDeviceHandles(const GetDeviceHandlesRequest* arg,
                                   GetDeviceHandlesResponse* result) = 0;

-  // Methods used by ComputationBuilder.
-  virtual Status Computation(const ComputationRequest* arg,
-                             ComputationResponse* result) = 0;
-
-  virtual Status Op(const OpRequest* arg, OpResponse* result) = 0;
-
-  virtual Status GetLocalShape(const GetLocalShapeRequest* arg,
-                               GetLocalShapeResponse* result) = 0;
-
-  virtual Status SetReturnValue(const SetReturnValueRequest* arg,
-                                SetReturnValueResponse* results) = 0;
-
-  virtual Status IsConstant(const IsConstantRequest* arg,
-                            IsConstantResponse* result) = 0;
-
-  virtual Status ComputeConstant(const ComputeConstantRequest* arg,
-                                 ComputeConstantResponse* result) = 0;
-
   virtual Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
                                       ComputeConstantResponse* result) = 0;

-  // Methods used by Computation.
-  virtual Status SnapshotComputation(const SnapshotComputationRequest* ag,
-                                     SnapshotComputationResponse* result) = 0;
-
   // Methods used by GlobalData.
   virtual Status Unregister(const UnregisterRequest* arg,
                             UnregisterResponse* result) = 0;
-- 
2.7.4
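For reference, with the per-op RPCs (Computation/Op/SetReturnValue/Execute) deleted, a client records the whole graph locally with XlaBuilder and submits it in a single round trip through the surviving ExecuteGraph path. Below is a minimal sketch of that flow, assuming the XlaBuilder/Client API at this revision; AddOneExample is an illustrative name, not code from this patch.

#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"

// Illustrative sketch: f(x) = x + 1 is recorded entirely client-side by
// XlaBuilder; Build() packages the graph as an XlaComputation (an
// HloModuleProto), and ExecuteAndTransfer hands the finished graph to the
// service via ExecuteGraph -- no per-instruction Op() RPCs are issued.
xla::StatusOr<std::unique_ptr<xla::Literal>> AddOneExample() {
  xla::Client* client = xla::ClientLibrary::LocalClientOrDie();

  xla::XlaBuilder builder("add_one");
  auto x = builder.ConstantR0<float>(41.0f);
  builder.Add(x, builder.ConstantR0<float>(1.0f));  // Last op becomes the root.

  TF_ASSIGN_OR_RETURN(xla::XlaComputation computation, builder.Build());

  // One round trip: compile, execute, and transfer the result literal back.
  return client->ExecuteAndTransfer(computation, /*arguments=*/{});
}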
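Likewise, the deleted IsConstant/ComputeConstant pair is subsumed by ComputeConstantGraph: the client ships a parameter-free graph and receives its folded value. Again a sketch under the same API assumptions; FoldConstantExample is illustrative.

#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"

// Illustrative sketch: a parameter-free computation is evaluated by the
// service through the ComputeConstantGraph RPC (reached via
// Client::ComputeConstant), with no device allocation or execution handle.
xla::StatusOr<std::unique_ptr<xla::Literal>> FoldConstantExample() {
  xla::Client* client = xla::ClientLibrary::LocalClientOrDie();

  xla::XlaBuilder builder("fold_constant");
  builder.Mul(builder.ConstantR0<int32_t>(6), builder.ConstantR0<int32_t>(7));
  TF_ASSIGN_OR_RETURN(xla::XlaComputation computation, builder.Build());

  // There is no separate constness query any more: the whole subgraph is
  // sent at once and the service folds it to a literal (here, 42).
  return client->ComputeConstant(computation);
}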