[XLA] Redesign: delete xla::Computation.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 17 May 2018 23:23:33 +0000 (16:23 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 17 May 2018 23:26:35 +0000 (16:26 -0700)
PiperOrigin-RevId: 197069851

14 files changed:
tensorflow/compiler/xla/client/BUILD
tensorflow/compiler/xla/client/client.cc
tensorflow/compiler/xla/client/client.h
tensorflow/compiler/xla/client/compile_only_client.cc
tensorflow/compiler/xla/client/compile_only_client.h
tensorflow/compiler/xla/client/computation.cc [deleted file]
tensorflow/compiler/xla/client/computation.h [deleted file]
tensorflow/compiler/xla/client/lib/testing.cc
tensorflow/compiler/xla/client/lib/testing.h
tensorflow/compiler/xla/client/local_client.cc
tensorflow/compiler/xla/client/local_client.h
tensorflow/compiler/xla/client/xla_client/xla_builder.h
tensorflow/compiler/xla/client/xla_client/xla_computation.h
tensorflow/compiler/xla/tools/BUILD

index 9d86827..aacb394 100644 (file)
@@ -63,7 +63,6 @@ cc_library(
     srcs = ["client.cc"],
     hdrs = ["client.h"],
     deps = [
-        ":computation",
         ":global_data",
         "//tensorflow/compiler/xla:execution_options_util",
         "//tensorflow/compiler/xla:literal_util",
@@ -99,7 +98,6 @@ cc_library(
     hdrs = ["local_client.h"],
     deps = [
         ":client",
-        ":computation",
         ":executable_build_options",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla:status_macros",
@@ -126,7 +124,6 @@ cc_library(
     hdrs = ["compile_only_client.h"],
     deps = [
         ":client",
-        ":computation",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:util",
@@ -163,22 +160,6 @@ cc_library(
 )
 
 cc_library(
-    name = "computation",
-    srcs = ["computation.cc"],
-    hdrs = ["computation.h"],
-    deps = [
-        "//tensorflow/compiler/xla:service_interface",
-        "//tensorflow/compiler/xla:status_macros",
-        "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla:util",
-        "//tensorflow/compiler/xla:xla_data_proto",
-        "//tensorflow/compiler/xla:xla_proto",
-        "//tensorflow/compiler/xla/service:session_proto",
-        "//tensorflow/core:lib",
-    ],
-)
-
-cc_library(
     name = "sharding_builder",
     srcs = ["sharding_builder.cc"],
     hdrs = ["sharding_builder.h"],
index 10a2d97..c9d275a 100644 (file)
@@ -162,22 +162,6 @@ Status Client::ResetDevice() {
 }
 
 StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
-    const Computation& computation,
-    tensorflow::gtl::ArraySlice<GlobalData*> arguments,
-    const ExecutionOptions* execution_options,
-    ExecutionProfile* execution_profile) {
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<GlobalData> data,
-      Execute(computation, arguments, execution_options, execution_profile));
-
-  const Shape* shape_with_output_layout = nullptr;
-  if (execution_options && execution_options->has_shape_with_output_layout()) {
-    shape_with_output_layout = &execution_options->shape_with_output_layout();
-  }
-  return Transfer(*data, shape_with_output_layout);
-}
-
-StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
     const XlaComputation& computation,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments,
     const ExecutionOptions* execution_options,
@@ -227,46 +211,6 @@ StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
 }
 
 StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
-    const Computation& computation,
-    tensorflow::gtl::ArraySlice<GlobalData*> arguments,
-    const ExecutionOptions* execution_options,
-    ExecutionProfile* execution_profile) {
-  ExecuteRequest request;
-  *request.mutable_computation() = computation.handle();
-
-  if (execution_options == nullptr) {
-    *request.mutable_execution_options() = CreateDefaultExecutionOptions();
-  } else {
-    *request.mutable_execution_options() = *execution_options;
-  }
-  for (GlobalData* argument : arguments) {
-    CHECK(argument != nullptr) << "Argument pointers must not be null.";
-    *request.add_arguments() = argument->handle();
-  }
-
-  ExecuteResponse response;
-  VLOG(1) << "making execute request: " << request.ShortDebugString();
-  Status s = stub_->Execute(&request, &response);
-  VLOG(1) << "done with request";
-
-  if (!s.ok()) {
-    return s;
-  }
-
-  if (execution_profile != nullptr) {
-    *execution_profile = response.profile();
-    if (VLOG_IS_ON(1)) {
-      TF_ASSIGN_OR_RETURN(
-          auto execution_stats,
-          ExecutionStatsAsString(computation, response.profile()));
-      VLOG(1) << execution_stats;
-    }
-  }
-
-  return MakeUnique<GlobalData>(stub_, response.output());
-}
-
-StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
     const XlaComputation& computation,
     tensorflow::gtl::ArraySlice<GlobalData*> arguments,
     const ExecutionOptions* execution_options,
@@ -307,41 +251,6 @@ StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
 }
 
 StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
-    tensorflow::gtl::ArraySlice<ComputationInstance> computations) {
-  ExecuteParallelRequest request;
-
-  for (const ComputationInstance& computation : computations) {
-    ExecuteRequest single_request;
-    *single_request.mutable_computation() = computation.computation.handle();
-    for (GlobalData* argument : computation.arguments) {
-      *single_request.add_arguments() = argument->handle();
-    }
-    *single_request.mutable_execution_options() = computation.execution_options;
-    *request.add_requests() = single_request;
-  }
-
-  ExecuteParallelResponse response;
-  VLOG(1) << "making execute-parallel request: " << request.ShortDebugString();
-  Status s = stub_->ExecuteParallel(&request, &response);
-  VLOG(1) << "done with request";
-
-  if (!s.ok()) {
-    return s;
-  }
-
-  std::vector<std::unique_ptr<GlobalData>> outputs;
-  for (size_t i = 0; i < computations.size(); ++i) {
-    outputs.push_back(
-        MakeUnique<GlobalData>(stub_, response.responses(i).output()));
-    if (computations[i].execution_profile != nullptr) {
-      *computations[i].execution_profile = response.responses(i).profile();
-    }
-  }
-
-  return std::move(outputs);
-}
-
-StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
     tensorflow::gtl::ArraySlice<XlaComputationInstance> computations) {
   ExecuteGraphParallelRequest request;
 
@@ -436,24 +345,6 @@ StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::DeconstructTuple(
 }
 
 StatusOr<ComputationStats> Client::GetComputationStats(
-    const Computation& computation, const DebugOptions& debug_options) const {
-  ComputationStatsRequest request;
-  *request.mutable_computation() = computation.handle();
-  *request.mutable_debug_options() = debug_options;
-  ComputationStatsResponse response;
-
-  VLOG(1) << "making computation stats request";
-  Status s = stub_->GetComputationStats(&request, &response);
-  VLOG(1) << "done with request";
-
-  if (!s.ok()) {
-    return s;
-  }
-  CHECK(response.has_stats());
-  return response.stats();
-}
-
-StatusOr<ComputationStats> Client::GetComputationStats(
     const XlaComputation& computation,
     const DebugOptions& debug_options) const {
   ComputationGraphStatsRequest request;
@@ -475,23 +366,6 @@ StatusOr<ComputationStats> Client::GetComputationStats(
 }
 
 StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
-    const Computation& computation) {
-  GetComputationShapeRequest request;
-  *request.mutable_computation() = computation.handle();
-  GetComputationShapeResponse response;
-
-  VLOG(1) << "making get-computation-shape request";
-  Status s = stub_->GetComputationShape(&request, &response);
-  VLOG(1) << "done with request";
-
-  if (!s.ok()) {
-    return s;
-  }
-
-  return WrapUnique(response.release_program_shape());
-}
-
-StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
     const XlaComputation& computation) {
   TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape());
   return MakeUnique<ProgramShape>(result);
@@ -514,28 +388,6 @@ StatusOr<Shape> Client::GetShape(const GlobalData& data) {
 }
 
 StatusOr<string> Client::ExecutionStatsAsString(
-    const Computation& computation, const ExecutionProfile& profile) {
-  TF_ASSIGN_OR_RETURN(
-      auto computation_stats,
-      GetComputationStats(computation,
-                          legacy_flags::GetDebugOptionsFromFlags()));
-  int64 total_flops =
-      computation_stats.flop_count() + computation_stats.transcendental_count();
-  if (profile.compute_time_ns() > 0) {
-    int64 nanoseconds = profile.compute_time_ns();
-    int64 cycle_count = profile.compute_cycle_count();
-    double gflops = total_flops / nanoseconds;
-    return tensorflow::strings::StrCat(
-        "[Execution Statistics] flop count: ", computation_stats.flop_count(),
-        ", transcendental count: ", computation_stats.transcendental_count(),
-        ", compute execution time: ", nanoseconds, " nsec",
-        ", compute cycles: ", cycle_count, ", performance: ", gflops,
-        "gflop/s");
-  }
-  return string("[Execution Statistics] not available.");
-}
-
-StatusOr<string> Client::ExecutionStatsAsString(
     const XlaComputation& computation, const ExecutionProfile& profile) {
   TF_ASSIGN_OR_RETURN(
       auto computation_stats,
index d359e87..d57e253 100644 (file)
@@ -19,7 +19,6 @@ limitations under the License.
 #include <memory>
 #include <vector>
 
-#include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
 #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 #include "tensorflow/compiler/xla/literal_util.h"
@@ -53,21 +52,6 @@ class Client {
   // * If execution_profile is not nullptr then the pointed-to ExecutionProfile
   //   will be filled with profile data from the execution.
   StatusOr<std::unique_ptr<GlobalData>> Execute(
-      const Computation& computation,
-      tensorflow::gtl::ArraySlice<GlobalData*> arguments,
-      const ExecutionOptions* execution_options = nullptr,
-      ExecutionProfile* execution_profile = nullptr);
-
-  // Executes the computation with the given arguments and returns the global
-  // data that was produced from the execution.
-  // * If execution_options is not nullptr, these options are passed to the
-  //   service to affect how it compiles our computation.  (The pointer does not
-  //   need to live beyond this call.)
-  // * If execution_profile is not nullptr then the pointed-to ExecutionProfile
-  //   will be filled with profile data from the execution.
-  //
-  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
-  StatusOr<std::unique_ptr<GlobalData>> Execute(
       const XlaComputation& computation,
       tensorflow::gtl::ArraySlice<GlobalData*> arguments,
       const ExecutionOptions* execution_options = nullptr,
@@ -78,34 +62,6 @@ class Client {
   //   executed on the devices associated with the handles by partitioning the
   //   computation based on the attached sharding attributes. Otherwise, a
   //   device is chosen by the service.
-  struct ComputationInstance {
-    const Computation& computation;
-    std::vector<GlobalData*> arguments;
-    ExecutionOptions execution_options;
-    ExecutionProfile* execution_profile;
-
-    ComputationInstance(const Computation& computation,
-                        std::vector<GlobalData*> arguments,
-                        ExecutionOptions execution_options,
-                        ExecutionProfile* execution_profile)
-        : computation(computation),
-          arguments(std::move(arguments)),
-          execution_options(execution_options),
-          execution_profile(execution_profile) {}
-  };
-
-  // Executes a list ComputationInstances and returns global data produced from
-  // each computation.
-  StatusOr<std::vector<std::unique_ptr<GlobalData>>> ExecuteParallel(
-      tensorflow::gtl::ArraySlice<ComputationInstance> computations);
-
-  // A struct to represent a computation instance to be executed.
-  // * If execution_options.device_handles is not empty, the computation is
-  //   executed on the devices associated with the handles by partitioning the
-  //   computation based on the attached sharding attributes. Otherwise, a
-  //   device is chosen by the service.
-  //
-  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
   struct XlaComputationInstance {
     const XlaComputation& computation;
     std::vector<GlobalData*> arguments;
@@ -125,7 +81,6 @@ class Client {
   // Executes a list XlaComputationInstances and returns global data produced
   // from each computation.
   //
-  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
   StatusOr<std::vector<std::unique_ptr<GlobalData>>> ExecuteParallel(
       tensorflow::gtl::ArraySlice<XlaComputationInstance> computations);
 
@@ -178,17 +133,6 @@ class Client {
   // to the client as a literal. Parameters are defined the same as for
   // Execute() and Transfer().
   StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
-      const Computation& computation,
-      tensorflow::gtl::ArraySlice<GlobalData*> arguments,
-      const ExecutionOptions* execution_options = nullptr,
-      ExecutionProfile* execution_profile = nullptr);
-
-  // Executes the computation with the given arguments and transfers the result
-  // to the client as a literal. Parameters are defined the same as for
-  // Execute() and Transfer().
-  //
-  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
-  StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
       const XlaComputation& computation,
       tensorflow::gtl::ArraySlice<GlobalData*> arguments,
       const ExecutionOptions* execution_options = nullptr,
@@ -224,12 +168,6 @@ class Client {
 
   // Retrieves the statistics of the given computation.
   StatusOr<ComputationStats> GetComputationStats(
-      const Computation& computation, const DebugOptions& debug_options) const;
-
-  // Retrieves the statistics of the given computation.
-  //
-  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
-  StatusOr<ComputationStats> GetComputationStats(
       const XlaComputation& computation,
       const DebugOptions& debug_options) const;
 
@@ -240,13 +178,6 @@ class Client {
   // As above, but returns the shape of the provided computation (parameter
   // types/names and return type).
   StatusOr<std::unique_ptr<ProgramShape>> GetComputationShape(
-      const Computation& computation);
-
-  // As above, but returns the shape of the provided computation (parameter
-  // types/names and return type).
-  //
-  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
-  StatusOr<std::unique_ptr<ProgramShape>> GetComputationShape(
       const XlaComputation& computation);
 
   // Creates a channel handle that can be used to transfer data between
@@ -260,8 +191,6 @@ class Client {
  private:
   // Returns the execution statistics (e.g., gflop/s) as a string from the
   // ExecutionProfile returned from an execution of the computation.
-  StatusOr<string> ExecutionStatsAsString(const Computation& computation,
-                                          const ExecutionProfile& profile);
   StatusOr<string> ExecutionStatsAsString(const XlaComputation& computation,
                                           const ExecutionProfile& profile);
 
index 96e38bc..dc69d20 100644 (file)
@@ -23,24 +23,6 @@ namespace xla {
 
 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
 CompileOnlyClient::CompileAheadOfTime(
-    const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
-    const AotCompilationOptions& options) {
-  std::vector<CompileOnlyService::AotComputationInstance> service_instances;
-  service_instances.reserve(computations.size());
-  for (const AotComputationInstance& instance : computations) {
-    service_instances.push_back({});
-    CompileOnlyService::AotComputationInstance& service_instance =
-        service_instances.back();
-    TF_RET_CHECK(instance.computation != nullptr);
-    service_instance.computation = instance.computation->handle();
-    service_instance.argument_layouts = instance.argument_layouts;
-    service_instance.result_layout = instance.result_layout;
-  }
-  return compiler_service_->CompileAheadOfTime(service_instances, options);
-}
-
-StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-CompileOnlyClient::CompileAheadOfTime(
     const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
     const AotCompilationOptions& options) {
   std::vector<CompileOnlyService::AotXlaComputationInstance> service_instances;
index c8725b8..f9a7c31 100644 (file)
@@ -17,7 +17,6 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_
 
 #include "tensorflow/compiler/xla/client/client.h"
-#include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 #include "tensorflow/compiler/xla/service/compile_only_service.h"
 #include "tensorflow/compiler/xla/service/compiler.h"
@@ -38,26 +37,7 @@ class CompileOnlyClient : public Client {
   CompileOnlyClient(const CompileOnlyClient&) = delete;
   void operator=(const CompileOnlyClient&) = delete;
 
-  // A description of a computation to compile using CompileAheadOfTime.
-  struct AotComputationInstance {
-    const Computation* computation;
-    // Inform the compiler of the expected layout for arguments.
-    std::vector<const Shape*> argument_layouts;
-    // Specifies the expected result layout.
-    const Shape* result_layout;
-  };
-
-  // Compiles a list of computations for ahead-of-time execution.  This is
-  // intended for use in static compilation. The |options| parameter describes
-  // the target for which the compiler should emit code.
-  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-  CompileAheadOfTime(
-      const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
-      const AotCompilationOptions& options);
-
   // A description of an xla computation to compile using CompileAheadOfTime.
-  //
-  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
   struct AotXlaComputationInstance {
     const XlaComputation* computation;
     // Inform the compiler of the expected layout for arguments.
@@ -69,8 +49,6 @@ class CompileOnlyClient : public Client {
   // Compiles a list of xla computations for ahead-of-time execution.  This is
   // intended for use in static compilation. The |options| parameter describes
   // the target for which the compiler should emit code.
-  //
-  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
   CompileAheadOfTime(
       const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
diff --git a/tensorflow/compiler/xla/client/computation.cc b/tensorflow/compiler/xla/client/computation.cc
deleted file mode 100644 (file)
index e6c57bd..0000000
+++ /dev/null
@@ -1,77 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/client/computation.h"
-
-#include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/core/lib/core/errors.h"
-
-namespace xla {
-
-Computation::Computation() : parent_(nullptr) {}
-
-Computation::Computation(ServiceInterface* parent,
-                         const ComputationHandle& handle)
-    : handle_(handle), parent_(parent) {}
-
-Computation::Computation(Computation&& computation)
-    : handle_(std::move(computation.handle_)), parent_(computation.parent_) {
-  computation.ResetWithoutFreeing();
-}
-
-void Computation::Reset() {
-  // TODO(b/34469253) deallocate any owned computation.
-  ResetWithoutFreeing();
-}
-
-StatusOr<std::unique_ptr<SessionModule>> Computation::Snapshot() const {
-  SnapshotComputationRequest request;
-  *request.mutable_computation() = handle_;
-  SnapshotComputationResponse response;
-
-  TF_RETURN_IF_ERROR(parent_->SnapshotComputation(&request, &response));
-
-  return WrapUnique(response.release_module());
-}
-
-Computation::~Computation() { Reset(); }
-
-Computation& Computation::operator=(Computation&& computation) {
-  if (&computation != this) {
-    Reset();
-    handle_ = computation.handle_;
-    parent_ = computation.parent_;
-    computation.ResetWithoutFreeing();
-  }
-  return *this;
-}
-
-void Computation::ResetWithoutFreeing() {
-  handle_.Clear();
-  parent_ = nullptr;
-}
-
-StatusOr<ProgramShape> Computation::GetProgramShape() const {
-  GetComputationShapeRequest request;
-  *request.mutable_computation() = handle_;
-  GetComputationShapeResponse response;
-
-  TF_RETURN_IF_ERROR(parent_->GetComputationShape(&request, &response));
-
-  return std::move(*response.mutable_program_shape());
-}
-
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/client/computation.h b/tensorflow/compiler/xla/client/computation.h
deleted file mode 100644 (file)
index 9a1bcde..0000000
+++ /dev/null
@@ -1,82 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_
-#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_
-
-#include <memory>
-
-#include "tensorflow/compiler/xla/service/session.pb.h"
-#include "tensorflow/compiler/xla/service_interface.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/xla.pb.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/platform/macros.h"
-
-namespace xla {
-
-// Wraps a ComputationHandle protobuf with a lifetime. Computation is
-// movable and not copyable to capture the same kind of unique
-// ownership that std::unique_ptr represents.
-//
-// TODO(b/74197823): Deprecated. Use XlaComputation instead.
-class Computation {
- public:
-  // Creates a null Computation.
-  Computation();
-
-  // parent: stub for the service on which we will deallocate the computation
-  //   when it is no longer needed.
-  // handle: the computation handle protobuf from the service.
-  Computation(ServiceInterface* parent, const ComputationHandle& handle);
-
-  Computation(Computation&& computation);
-
-  // Deallocates the computation.
-  ~Computation();
-
-  Computation& operator=(Computation&& computation);
-
-  // Returns the underlying handle.
-  const ComputationHandle& handle() const { return handle_; }
-
-  // Sets handle to a null state and clears any owned computation.
-  void Reset();
-
-  // Requests that we snapshot the computation into a serializable protocol
-  // buffer form.
-  StatusOr<std::unique_ptr<SessionModule>> Snapshot() const;
-
-  // Returns true if this object is a null Computation.
-  bool IsNull() const { return parent_ == nullptr; }
-
-  // Returns the "program shape" (parameter and return shapes) for this
-  // computation.
-  StatusOr<ProgramShape> GetProgramShape() const;
-
- private:
-  void ResetWithoutFreeing();
-
-  ComputationHandle handle_;  // Handle that is wrapped by this class.
-
-  // Stub that the handle is deallocated on when this object's lifetime ends.
-  ServiceInterface* parent_;
-
-  TF_DISALLOW_COPY_AND_ASSIGN(Computation);
-};
-
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_
index 9cd87f7..3380af9 100644 (file)
@@ -93,21 +93,6 @@ std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
 }
 
 std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
-    const Computation& computation, Client* client) {
-  auto program_shape =
-      client->GetComputationShape(computation).ConsumeValueOrDie();
-
-  // For every (unbound) parameter that the computation wants, we manufacture
-  // some arbitrary data so that we can invoke the computation.
-  std::vector<std::unique_ptr<GlobalData>> fake_arguments;
-  for (const Shape& parameter : program_shape->parameters()) {
-    fake_arguments.push_back(MakeFakeDataOrDie(parameter, client));
-  }
-
-  return fake_arguments;
-}
-
-std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
     const XlaComputation& computation, Client* client) {
   CHECK(computation.proto().has_program_shape())
       << "Computation should have progran shape.";
index 9e06141..dc61309 100644 (file)
@@ -34,12 +34,6 @@ std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
 
 // Returns vector of GlobalData handles of fake data (created using
 // MakeFakeDataOrDie) that are correctly shaped arguments for the given
-// computation.
-std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
-    const Computation& computation, Client* client);
-
-// Returns vector of GlobalData handles of fake data (created using
-// MakeFakeDataOrDie) that are correctly shaped arguments for the given
 // xla computation.
 std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
     const XlaComputation& computation, Client* client);
index 9d44d3a..a7c55c6 100644 (file)
@@ -262,25 +262,6 @@ Backend* LocalClient::mutable_backend() {
 }
 
 StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
-    const Computation& computation,
-    const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
-    const ExecutableBuildOptions& options) {
-  ExecutableBuildOptions updated_options = options;
-  if (options.device_ordinal() == -1) {
-    updated_options.set_device_ordinal(default_device_ordinal());
-    VLOG(3) << "Set device ordinal to default value of: "
-            << updated_options.device_ordinal();
-  }
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<Executable> executable,
-      local_service_->CompileExecutable(computation.handle(), argument_layouts,
-                                        updated_options));
-  return WrapUnique(new LocalExecutable(std::move(executable),
-                                        local_service_->mutable_backend(),
-                                        updated_options));
-}
-
-StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
     const XlaComputation& computation,
     const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
     const ExecutableBuildOptions& options) {
index 3195037..d63d4ec 100644 (file)
@@ -19,7 +19,6 @@ limitations under the License.
 #include <memory>
 
 #include "tensorflow/compiler/xla/client/client.h"
-#include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/executable_build_options.h"
 #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 #include "tensorflow/compiler/xla/executable_run_options.h"
@@ -109,16 +108,7 @@ class LocalClient : public Client {
   void operator=(const LocalClient&) = delete;
 
   // Build and return a LocalExecutable object. The executable is compiled using
-  // the given argument layouts and options.
-  StatusOr<std::unique_ptr<LocalExecutable>> Compile(
-      const Computation& computation,
-      const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
-      const ExecutableBuildOptions& options);
-
-  // Build and return a LocalExecutable object. The executable is compiled using
   // the given XlaComputation, argument layouts and options.
-  //
-  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
   StatusOr<std::unique_ptr<LocalExecutable>> Compile(
       const XlaComputation& computation,
       const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
index d802e43..2b3013a 100644 (file)
@@ -13,10 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// TODO(b/74197823): Replace computation_builder.h with this file.
-//
-// This is NOT YET ready to use.
-
 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_
 #define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_
 
@@ -48,8 +44,6 @@ class XlaBuilder;
 // This represents an instruction that has been enqueued using the XlaBuilder.
 // This is used to pass to subsequent computations that depends upon the
 // instruction as an operand.
-//
-// TODO(b/74197823): Replace xla::ComputationDataHandle with this one.
 class XlaOp {
  public:
   XlaOp() : handle_(0), builder_(nullptr) {}
@@ -85,8 +79,6 @@ class XlaOp {
 // A convenient interface for building up computations.
 //
 // Thread-compatible.
-//
-// TODO(b/74197823): Replace xla::ComputationBuilder with this one.
 class XlaBuilder {
  public:
   // computation_name: name to use for the built computation.
@@ -989,8 +981,6 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
 
 // RAII-style object: sets the current sharding assignment in builder on
 // construction, and sets back to the previous assignment on destruction.
-//
-// TODO(b/74197823): This is a part of a NOT YET ready refactor.
 class XlaScopedShardingAssignment {
  public:
   XlaScopedShardingAssignment(xla::XlaBuilder* builder,
index b70b57e..0ffba20 100644 (file)
@@ -25,8 +25,6 @@ limitations under the License.
 namespace xla {
 
 // The computation graph that the user builds up with the XlaBuilder.
-//
-// TODO(b/74197823): Replace xla::Computation with this one.
 class XlaComputation {
  public:
   XlaComputation() : unique_id_(-1) {}
index 1874004..415cf9c 100644 (file)
@@ -36,7 +36,6 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client",
         "//tensorflow/compiler/xla/client:client_library",
-        "//tensorflow/compiler/xla/client:computation",
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
         "//tensorflow/compiler/xla/service",
@@ -63,7 +62,6 @@ tf_cc_binary(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client",
         "//tensorflow/compiler/xla/client:client_library",
-        "//tensorflow/compiler/xla/client:computation",
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/service:hlo_proto",
         "//tensorflow/compiler/xla/service:interpreter_plugin",
@@ -84,7 +82,6 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client",
         "//tensorflow/compiler/xla/client:client_library",
-        "//tensorflow/compiler/xla/client:computation",
         "//tensorflow/compiler/xla/client:global_data",
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/client/lib:testing",
@@ -164,7 +161,6 @@ tf_cc_binary(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client",
         "//tensorflow/compiler/xla/client:client_library",
-        "//tensorflow/compiler/xla/client:computation",
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/service",
         "//tensorflow/compiler/xla/service:computation_tracker",
@@ -183,7 +179,6 @@ tf_cc_binary(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client",
         "//tensorflow/compiler/xla/client:client_library",
-        "//tensorflow/compiler/xla/client:computation",
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/service",
         "//tensorflow/compiler/xla/service:hlo",
@@ -201,7 +196,6 @@ tf_cc_binary(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla/client",
         "//tensorflow/compiler/xla/client:client_library",
-        "//tensorflow/compiler/xla/client:computation",
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
         "//tensorflow/compiler/xla/service",