From f36c93505fc4562b90703966ff67cc4edac2c002 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 17 May 2018 16:23:33 -0700 Subject: [PATCH] [XLA] Redesign: delete xla::Computation. PiperOrigin-RevId: 197069851 --- tensorflow/compiler/xla/client/BUILD | 19 --- tensorflow/compiler/xla/client/client.cc | 148 --------------------- tensorflow/compiler/xla/client/client.h | 71 ---------- .../compiler/xla/client/compile_only_client.cc | 18 --- .../compiler/xla/client/compile_only_client.h | 22 --- tensorflow/compiler/xla/client/computation.cc | 77 ----------- tensorflow/compiler/xla/client/computation.h | 82 ------------ tensorflow/compiler/xla/client/lib/testing.cc | 15 --- tensorflow/compiler/xla/client/lib/testing.h | 6 - tensorflow/compiler/xla/client/local_client.cc | 19 --- tensorflow/compiler/xla/client/local_client.h | 10 -- .../compiler/xla/client/xla_client/xla_builder.h | 10 -- .../xla/client/xla_client/xla_computation.h | 2 - tensorflow/compiler/xla/tools/BUILD | 6 - 14 files changed, 505 deletions(-) delete mode 100644 tensorflow/compiler/xla/client/computation.cc delete mode 100644 tensorflow/compiler/xla/client/computation.h diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 9d86827..aacb394 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -63,7 +63,6 @@ cc_library( srcs = ["client.cc"], hdrs = ["client.h"], deps = [ - ":computation", ":global_data", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", @@ -99,7 +98,6 @@ cc_library( hdrs = ["local_client.h"], deps = [ ":client", - ":computation", ":executable_build_options", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", @@ -126,7 +124,6 @@ cc_library( hdrs = ["compile_only_client.h"], deps = [ ":client", - ":computation", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -163,22 +160,6 @@ cc_library( ) cc_library( - name = "computation", - srcs = ["computation.cc"], - hdrs = ["computation.h"], - deps = [ - "//tensorflow/compiler/xla:service_interface", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/service:session_proto", - "//tensorflow/core:lib", - ], -) - -cc_library( name = "sharding_builder", srcs = ["sharding_builder.cc"], hdrs = ["sharding_builder.h"], diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 10a2d97..c9d275a 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -162,22 +162,6 @@ Status Client::ResetDevice() { } StatusOr> Client::ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options, - ExecutionProfile* execution_profile) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr data, - Execute(computation, arguments, execution_options, execution_profile)); - - const Shape* shape_with_output_layout = nullptr; - if (execution_options && execution_options->has_shape_with_output_layout()) { - shape_with_output_layout = &execution_options->shape_with_output_layout(); - } - return Transfer(*data, shape_with_output_layout); -} - -StatusOr> Client::ExecuteAndTransfer( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const ExecutionOptions* execution_options, @@ -227,46 +211,6 @@ StatusOr Client::LoadSnapshot(const HloSnapshot& module) { } StatusOr> Client::Execute( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options, - ExecutionProfile* execution_profile) { - ExecuteRequest request; - *request.mutable_computation() = computation.handle(); - - if (execution_options == nullptr) { - *request.mutable_execution_options() = CreateDefaultExecutionOptions(); - } else { - *request.mutable_execution_options() = *execution_options; - } - for (GlobalData* argument : arguments) { - CHECK(argument != nullptr) << "Argument pointers must not be null."; - *request.add_arguments() = argument->handle(); - } - - ExecuteResponse response; - VLOG(1) << "making execute request: " << request.ShortDebugString(); - Status s = stub_->Execute(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - if (execution_profile != nullptr) { - *execution_profile = response.profile(); - if (VLOG_IS_ON(1)) { - TF_ASSIGN_OR_RETURN( - auto execution_stats, - ExecutionStatsAsString(computation, response.profile())); - VLOG(1) << execution_stats; - } - } - - return MakeUnique(stub_, response.output()); -} - -StatusOr> Client::Execute( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const ExecutionOptions* execution_options, @@ -307,41 +251,6 @@ StatusOr> Client::Execute( } StatusOr>> Client::ExecuteParallel( - tensorflow::gtl::ArraySlice computations) { - ExecuteParallelRequest request; - - for (const ComputationInstance& computation : computations) { - ExecuteRequest single_request; - *single_request.mutable_computation() = computation.computation.handle(); - for (GlobalData* argument : computation.arguments) { - *single_request.add_arguments() = argument->handle(); - } - *single_request.mutable_execution_options() = computation.execution_options; - *request.add_requests() = single_request; - } - - ExecuteParallelResponse response; - VLOG(1) << "making execute-parallel request: " << request.ShortDebugString(); - Status s = stub_->ExecuteParallel(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - std::vector> outputs; - for (size_t i = 0; i < computations.size(); ++i) { - outputs.push_back( - MakeUnique(stub_, response.responses(i).output())); - if (computations[i].execution_profile != nullptr) { - *computations[i].execution_profile = response.responses(i).profile(); - } - } - - return std::move(outputs); -} - -StatusOr>> Client::ExecuteParallel( tensorflow::gtl::ArraySlice computations) { ExecuteGraphParallelRequest request; @@ -436,24 +345,6 @@ StatusOr>> Client::DeconstructTuple( } StatusOr Client::GetComputationStats( - const Computation& computation, const DebugOptions& debug_options) const { - ComputationStatsRequest request; - *request.mutable_computation() = computation.handle(); - *request.mutable_debug_options() = debug_options; - ComputationStatsResponse response; - - VLOG(1) << "making computation stats request"; - Status s = stub_->GetComputationStats(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - CHECK(response.has_stats()); - return response.stats(); -} - -StatusOr Client::GetComputationStats( const XlaComputation& computation, const DebugOptions& debug_options) const { ComputationGraphStatsRequest request; @@ -475,23 +366,6 @@ StatusOr Client::GetComputationStats( } StatusOr> Client::GetComputationShape( - const Computation& computation) { - GetComputationShapeRequest request; - *request.mutable_computation() = computation.handle(); - GetComputationShapeResponse response; - - VLOG(1) << "making get-computation-shape request"; - Status s = stub_->GetComputationShape(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - return WrapUnique(response.release_program_shape()); -} - -StatusOr> Client::GetComputationShape( const XlaComputation& computation) { TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape()); return MakeUnique(result); @@ -514,28 +388,6 @@ StatusOr Client::GetShape(const GlobalData& data) { } StatusOr Client::ExecutionStatsAsString( - const Computation& computation, const ExecutionProfile& profile) { - TF_ASSIGN_OR_RETURN( - auto computation_stats, - GetComputationStats(computation, - legacy_flags::GetDebugOptionsFromFlags())); - int64 total_flops = - computation_stats.flop_count() + computation_stats.transcendental_count(); - if (profile.compute_time_ns() > 0) { - int64 nanoseconds = profile.compute_time_ns(); - int64 cycle_count = profile.compute_cycle_count(); - double gflops = total_flops / nanoseconds; - return tensorflow::strings::StrCat( - "[Execution Statistics] flop count: ", computation_stats.flop_count(), - ", transcendental count: ", computation_stats.transcendental_count(), - ", compute execution time: ", nanoseconds, " nsec", - ", compute cycles: ", cycle_count, ", performance: ", gflops, - "gflop/s"); - } - return string("[Execution Statistics] not available."); -} - -StatusOr Client::ExecutionStatsAsString( const XlaComputation& computation, const ExecutionProfile& profile) { TF_ASSIGN_OR_RETURN( auto computation_stats, diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index d359e87..d57e253 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -53,21 +52,6 @@ class Client { // * If execution_profile is not nullptr then the pointed-to ExecutionProfile // will be filled with profile data from the execution. StatusOr> Execute( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options = nullptr, - ExecutionProfile* execution_profile = nullptr); - - // Executes the computation with the given arguments and returns the global - // data that was produced from the execution. - // * If execution_options is not nullptr, these options are passed to the - // service to affect how it compiles our computation. (The pointer does not - // need to live beyond this call.) - // * If execution_profile is not nullptr then the pointed-to ExecutionProfile - // will be filled with profile data from the execution. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. - StatusOr> Execute( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const ExecutionOptions* execution_options = nullptr, @@ -78,34 +62,6 @@ class Client { // executed on the devices associated with the handles by partitioning the // computation based on the attached sharding attributes. Otherwise, a // device is chosen by the service. - struct ComputationInstance { - const Computation& computation; - std::vector arguments; - ExecutionOptions execution_options; - ExecutionProfile* execution_profile; - - ComputationInstance(const Computation& computation, - std::vector arguments, - ExecutionOptions execution_options, - ExecutionProfile* execution_profile) - : computation(computation), - arguments(std::move(arguments)), - execution_options(execution_options), - execution_profile(execution_profile) {} - }; - - // Executes a list ComputationInstances and returns global data produced from - // each computation. - StatusOr>> ExecuteParallel( - tensorflow::gtl::ArraySlice computations); - - // A struct to represent a computation instance to be executed. - // * If execution_options.device_handles is not empty, the computation is - // executed on the devices associated with the handles by partitioning the - // computation based on the attached sharding attributes. Otherwise, a - // device is chosen by the service. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. struct XlaComputationInstance { const XlaComputation& computation; std::vector arguments; @@ -125,7 +81,6 @@ class Client { // Executes a list XlaComputationInstances and returns global data produced // from each computation. // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr>> ExecuteParallel( tensorflow::gtl::ArraySlice computations); @@ -178,17 +133,6 @@ class Client { // to the client as a literal. Parameters are defined the same as for // Execute() and Transfer(). StatusOr> ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options = nullptr, - ExecutionProfile* execution_profile = nullptr); - - // Executes the computation with the given arguments and transfers the result - // to the client as a literal. Parameters are defined the same as for - // Execute() and Transfer(). - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. - StatusOr> ExecuteAndTransfer( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const ExecutionOptions* execution_options = nullptr, @@ -224,12 +168,6 @@ class Client { // Retrieves the statistics of the given computation. StatusOr GetComputationStats( - const Computation& computation, const DebugOptions& debug_options) const; - - // Retrieves the statistics of the given computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. - StatusOr GetComputationStats( const XlaComputation& computation, const DebugOptions& debug_options) const; @@ -240,13 +178,6 @@ class Client { // As above, but returns the shape of the provided computation (parameter // types/names and return type). StatusOr> GetComputationShape( - const Computation& computation); - - // As above, but returns the shape of the provided computation (parameter - // types/names and return type). - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. - StatusOr> GetComputationShape( const XlaComputation& computation); // Creates a channel handle that can be used to transfer data between @@ -260,8 +191,6 @@ class Client { private: // Returns the execution statistics (e.g., gflop/s) as a string from the // ExecutionProfile returned from an execution of the computation. - StatusOr ExecutionStatsAsString(const Computation& computation, - const ExecutionProfile& profile); StatusOr ExecutionStatsAsString(const XlaComputation& computation, const ExecutionProfile& profile); diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index 96e38bc..dc69d20 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -23,24 +23,6 @@ namespace xla { StatusOr>> CompileOnlyClient::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { - std::vector service_instances; - service_instances.reserve(computations.size()); - for (const AotComputationInstance& instance : computations) { - service_instances.push_back({}); - CompileOnlyService::AotComputationInstance& service_instance = - service_instances.back(); - TF_RET_CHECK(instance.computation != nullptr); - service_instance.computation = instance.computation->handle(); - service_instance.argument_layouts = instance.argument_layouts; - service_instance.result_layout = instance.result_layout; - } - return compiler_service_->CompileAheadOfTime(service_instances, options); -} - -StatusOr>> -CompileOnlyClient::CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, const AotCompilationOptions& options) { std::vector service_instances; diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index c8725b8..f9a7c31 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/compile_only_service.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -38,26 +37,7 @@ class CompileOnlyClient : public Client { CompileOnlyClient(const CompileOnlyClient&) = delete; void operator=(const CompileOnlyClient&) = delete; - // A description of a computation to compile using CompileAheadOfTime. - struct AotComputationInstance { - const Computation* computation; - // Inform the compiler of the expected layout for arguments. - std::vector argument_layouts; - // Specifies the expected result layout. - const Shape* result_layout; - }; - - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. The |options| parameter describes - // the target for which the compiler should emit code. - StatusOr>> - CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options); - // A description of an xla computation to compile using CompileAheadOfTime. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. struct AotXlaComputationInstance { const XlaComputation* computation; // Inform the compiler of the expected layout for arguments. @@ -69,8 +49,6 @@ class CompileOnlyClient : public Client { // Compiles a list of xla computations for ahead-of-time execution. This is // intended for use in static compilation. The |options| parameter describes // the target for which the compiler should emit code. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr>> CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, diff --git a/tensorflow/compiler/xla/client/computation.cc b/tensorflow/compiler/xla/client/computation.cc deleted file mode 100644 index e6c57bd..0000000 --- a/tensorflow/compiler/xla/client/computation.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/client/computation.h" - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace xla { - -Computation::Computation() : parent_(nullptr) {} - -Computation::Computation(ServiceInterface* parent, - const ComputationHandle& handle) - : handle_(handle), parent_(parent) {} - -Computation::Computation(Computation&& computation) - : handle_(std::move(computation.handle_)), parent_(computation.parent_) { - computation.ResetWithoutFreeing(); -} - -void Computation::Reset() { - // TODO(b/34469253) deallocate any owned computation. - ResetWithoutFreeing(); -} - -StatusOr> Computation::Snapshot() const { - SnapshotComputationRequest request; - *request.mutable_computation() = handle_; - SnapshotComputationResponse response; - - TF_RETURN_IF_ERROR(parent_->SnapshotComputation(&request, &response)); - - return WrapUnique(response.release_module()); -} - -Computation::~Computation() { Reset(); } - -Computation& Computation::operator=(Computation&& computation) { - if (&computation != this) { - Reset(); - handle_ = computation.handle_; - parent_ = computation.parent_; - computation.ResetWithoutFreeing(); - } - return *this; -} - -void Computation::ResetWithoutFreeing() { - handle_.Clear(); - parent_ = nullptr; -} - -StatusOr Computation::GetProgramShape() const { - GetComputationShapeRequest request; - *request.mutable_computation() = handle_; - GetComputationShapeResponse response; - - TF_RETURN_IF_ERROR(parent_->GetComputationShape(&request, &response)); - - return std::move(*response.mutable_program_shape()); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/client/computation.h b/tensorflow/compiler/xla/client/computation.h deleted file mode 100644 index 9a1bcde..0000000 --- a/tensorflow/compiler/xla/client/computation.h +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ - -#include - -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service_interface.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" - -namespace xla { - -// Wraps a ComputationHandle protobuf with a lifetime. Computation is -// movable and not copyable to capture the same kind of unique -// ownership that std::unique_ptr represents. -// -// TODO(b/74197823): Deprecated. Use XlaComputation instead. -class Computation { - public: - // Creates a null Computation. - Computation(); - - // parent: stub for the service on which we will deallocate the computation - // when it is no longer needed. - // handle: the computation handle protobuf from the service. - Computation(ServiceInterface* parent, const ComputationHandle& handle); - - Computation(Computation&& computation); - - // Deallocates the computation. - ~Computation(); - - Computation& operator=(Computation&& computation); - - // Returns the underlying handle. - const ComputationHandle& handle() const { return handle_; } - - // Sets handle to a null state and clears any owned computation. - void Reset(); - - // Requests that we snapshot the computation into a serializable protocol - // buffer form. - StatusOr> Snapshot() const; - - // Returns true if this object is a null Computation. - bool IsNull() const { return parent_ == nullptr; } - - // Returns the "program shape" (parameter and return shapes) for this - // computation. - StatusOr GetProgramShape() const; - - private: - void ResetWithoutFreeing(); - - ComputationHandle handle_; // Handle that is wrapped by this class. - - // Stub that the handle is deallocated on when this object's lifetime ends. - ServiceInterface* parent_; - - TF_DISALLOW_COPY_AND_ASSIGN(Computation); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 9cd87f7..3380af9 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -93,21 +93,6 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, } std::vector> MakeFakeArgumentsOrDie( - const Computation& computation, Client* client) { - auto program_shape = - client->GetComputationShape(computation).ConsumeValueOrDie(); - - // For every (unbound) parameter that the computation wants, we manufacture - // some arbitrary data so that we can invoke the computation. - std::vector> fake_arguments; - for (const Shape& parameter : program_shape->parameters()) { - fake_arguments.push_back(MakeFakeDataOrDie(parameter, client)); - } - - return fake_arguments; -} - -std::vector> MakeFakeArgumentsOrDie( const XlaComputation& computation, Client* client) { CHECK(computation.proto().has_program_shape()) << "Computation should have progran shape."; diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index 9e06141..dc61309 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -34,12 +34,6 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, // Returns vector of GlobalData handles of fake data (created using // MakeFakeDataOrDie) that are correctly shaped arguments for the given -// computation. -std::vector> MakeFakeArgumentsOrDie( - const Computation& computation, Client* client); - -// Returns vector of GlobalData handles of fake data (created using -// MakeFakeDataOrDie) that are correctly shaped arguments for the given // xla computation. std::vector> MakeFakeArgumentsOrDie( const XlaComputation& computation, Client* client); diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 9d44d3a..a7c55c6 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -262,25 +262,6 @@ Backend* LocalClient::mutable_backend() { } StatusOr> LocalClient::Compile( - const Computation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options) { - ExecutableBuildOptions updated_options = options; - if (options.device_ordinal() == -1) { - updated_options.set_device_ordinal(default_device_ordinal()); - VLOG(3) << "Set device ordinal to default value of: " - << updated_options.device_ordinal(); - } - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - local_service_->CompileExecutable(computation.handle(), argument_layouts, - updated_options)); - return WrapUnique(new LocalExecutable(std::move(executable), - local_service_->mutable_backend(), - updated_options)); -} - -StatusOr> LocalClient::Compile( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& options) { diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 3195037..d63d4ec 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" @@ -109,16 +108,7 @@ class LocalClient : public Client { void operator=(const LocalClient&) = delete; // Build and return a LocalExecutable object. The executable is compiled using - // the given argument layouts and options. - StatusOr> Compile( - const Computation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options); - - // Build and return a LocalExecutable object. The executable is compiled using // the given XlaComputation, argument layouts and options. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> Compile( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index d802e43..2b3013a 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -13,10 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// TODO(b/74197823): Replace computation_builder.h with this file. -// -// This is NOT YET ready to use. - #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ @@ -48,8 +44,6 @@ class XlaBuilder; // This represents an instruction that has been enqueued using the XlaBuilder. // This is used to pass to subsequent computations that depends upon the // instruction as an operand. -// -// TODO(b/74197823): Replace xla::ComputationDataHandle with this one. class XlaOp { public: XlaOp() : handle_(0), builder_(nullptr) {} @@ -85,8 +79,6 @@ class XlaOp { // A convenient interface for building up computations. // // Thread-compatible. -// -// TODO(b/74197823): Replace xla::ComputationBuilder with this one. class XlaBuilder { public: // computation_name: name to use for the built computation. @@ -989,8 +981,6 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { // RAII-style object: sets the current sharding assignment in builder on // construction, and sets back to the previous assignment on destruction. -// -// TODO(b/74197823): This is a part of a NOT YET ready refactor. class XlaScopedShardingAssignment { public: XlaScopedShardingAssignment(xla::XlaBuilder* builder, diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_client/xla_computation.h index b70b57e..0ffba20 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_computation.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.h @@ -25,8 +25,6 @@ limitations under the License. namespace xla { // The computation graph that the user builds up with the XlaBuilder. -// -// TODO(b/74197823): Replace xla::Computation with this one. class XlaComputation { public: XlaComputation() : unique_id_(-1) {} diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 1874004..415cf9c 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -36,7 +36,6 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", @@ -63,7 +62,6 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", @@ -84,7 +82,6 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:testing", @@ -164,7 +161,6 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:computation_tracker", @@ -183,7 +179,6 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo", @@ -201,7 +196,6 @@ tf_cc_binary( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", -- 2.7.4