[XLA] Redesign: delete Client::LoadSnapeshot(SessionModule). This is a precondition...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 17 May 2018 19:31:17 +0000 (12:31 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 17 May 2018 19:37:29 +0000 (12:37 -0700)
PiperOrigin-RevId: 197033641

tensorflow/compiler/xla/client/BUILD
tensorflow/compiler/xla/client/client.cc
tensorflow/compiler/xla/client/client.h
tensorflow/compiler/xla/tools/BUILD
tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc
tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc
tensorflow/compiler/xla/tools/replay_computation.cc
tensorflow/compiler/xla/tools/show_signature.cc

index 989cd61..9d86827 100644 (file)
@@ -76,7 +76,7 @@ cc_library(
         "//tensorflow/compiler/xla:xla_proto",
         "//tensorflow/compiler/xla/client/xla_client:xla_computation",
         "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
-        "//tensorflow/compiler/xla/service:session_proto",
+        "//tensorflow/compiler/xla/service:hlo_proto",
         "//tensorflow/core:lib",
     ],
 )
index 0a79b3c..10a2d97 100644 (file)
@@ -221,20 +221,6 @@ StatusOr<std::unique_ptr<Literal>> Client::ComputeConstant(
   return Literal::CreateFromProto(response.literal());
 }
 
-StatusOr<Computation> Client::LoadSnapshot(const SessionModule& module) {
-  LoadComputationSnapshotRequest request;
-  *request.mutable_module() = module;
-  LoadComputationSnapshotResponse response;
-
-  Status s = stub_->LoadComputationSnapshot(&request, &response);
-  if (!s.ok()) {
-    return s;
-  }
-
-  VLOG(1) << "load snapshot response: " << response.ShortDebugString();
-  return Computation(stub_, response.computation());
-}
-
 StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
   TF_RET_CHECK(module.has_hlo() && module.hlo().has_hlo_module());
   return XlaComputation(module.hlo().hlo_module());
index a63ff4c..d359e87 100644 (file)
@@ -23,7 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/global_data.h"
 #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
 #include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service_interface.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -253,9 +253,6 @@ class Client {
   // two computations via a pair of Send and Recv instructions.
   StatusOr<ChannelHandle> CreateChannelHandle();
 
-  StatusOr<Computation> LoadSnapshot(const SessionModule& module);
-
-  // TODO(b/74197823): This is a part of a NOT YET ready refactor.
   StatusOr<XlaComputation> LoadSnapshot(const HloSnapshot& module);
 
   ServiceInterface* stub() { return stub_; }
index 78ab2dc..1874004 100644 (file)
@@ -40,7 +40,7 @@ cc_library(
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
         "//tensorflow/compiler/xla/service",
-        "//tensorflow/compiler/xla/service:session_proto",
+        "//tensorflow/compiler/xla/service:hlo_proto",
         "//tensorflow/core:lib",
     ],
 )
@@ -65,8 +65,8 @@ tf_cc_binary(
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/client:computation",
         "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/service:hlo_proto",
         "//tensorflow/compiler/xla/service:interpreter_plugin",
-        "//tensorflow/compiler/xla/service:session_proto",
         "//tensorflow/core:lib",
     ],
 )
@@ -89,7 +89,6 @@ cc_library(
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/client/lib:testing",
         "//tensorflow/compiler/xla/service:hlo_proto",
-        "//tensorflow/compiler/xla/service:session_proto",
         "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
@@ -169,8 +168,8 @@ tf_cc_binary(
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/service",
         "//tensorflow/compiler/xla/service:computation_tracker",
+        "//tensorflow/compiler/xla/service:hlo_proto",
         "//tensorflow/compiler/xla/service:interpreter_plugin",
-        "//tensorflow/compiler/xla/service:session_proto",
         "//tensorflow/core:lib",
     ],
 )
@@ -188,8 +187,8 @@ tf_cc_binary(
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/service",
         "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_proto",
         "//tensorflow/compiler/xla/service:interpreter_plugin",
-        "//tensorflow/compiler/xla/service:session_proto",
         "//tensorflow/core:lib",
     ],
 )
@@ -207,8 +206,8 @@ tf_cc_binary(
         "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
         "//tensorflow/compiler/xla/service",
         "//tensorflow/compiler/xla/service:hlo_graph_dumper",
+        "//tensorflow/compiler/xla/service:hlo_proto",
         "//tensorflow/compiler/xla/service:interpreter_plugin",
-        "//tensorflow/compiler/xla/service:session_proto",
         "//tensorflow/core:lib",
     ],
 )
index 21ae858..befb554 100644 (file)
@@ -17,7 +17,7 @@ limitations under the License.
 //
 // Dumps a graphviz URL for a snapshot computation to the command line.
 //
-// some_binary_snapshot_proto is obtained by serializing the SessionModule from
+// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from
 // ServiceInterface::SnapshotComputation to disk.
 //
 // The GraphViz URL is placed into the log stderr, whereas computation
@@ -30,11 +30,10 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/client/client.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/service.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -49,10 +48,11 @@ namespace tools {
 void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
   Client* client = ClientLibrary::LocalClientOrDie();
   for (char* arg : args) {
-    SessionModule module;
+    HloSnapshot module;
     TF_CHECK_OK(
         tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module));
-    Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie();
+    XlaComputation computation =
+        client->LoadSnapshot(module).ConsumeValueOrDie();
     DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags();
     debug_options.set_xla_generate_hlo_graph(".*");
     ComputationStats stats =
index b82f1c8..cfb8f37 100644 (file)
@@ -21,11 +21,10 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/client/client.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/service.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -66,16 +65,16 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
   LocalService* local_service =
       ClientLibrary::GetXlaService(client->platform());
   for (char* arg : args) {
-    SessionModule session_module;
+    HloSnapshot snapshot;
     TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg,
-                                            &session_module));
-    auto computation_status = client->LoadSnapshot(session_module);
+                                            &snapshot));
+    auto computation_status = client->LoadSnapshot(snapshot);
     if (!computation_status.ok()) {
       fprintf(stderr, "could not load snapshot for %s: %s\n", arg,
               computation_status.status().ToString().c_str());
       continue;
     }
-    Computation computation = computation_status.ConsumeValueOrDie();
+    XlaComputation computation = computation_status.ConsumeValueOrDie();
 
     std::unique_ptr<ProgramShape> program_shape =
         client->GetComputationShape(computation).ConsumeValueOrDie();
@@ -89,8 +88,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
     build_options.set_device_ordinal(0);
     build_options.set_result_layout(program_shape->result());
     StatusOr<std::unique_ptr<Executable>> executable =
-        local_service->CompileExecutable(computation.handle(), layouts,
-                                         build_options);
+        local_service->CompileExecutable(computation, layouts, build_options);
 
     const HloModule& module = executable.ValueOrDie()->module();
 
index 05c0fdf..b815bbf 100644 (file)
@@ -19,11 +19,10 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/client/client.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/service/computation_tracker.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/service.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -40,16 +39,16 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
   LocalService* local_service =
       ClientLibrary::GetXlaService(client->platform());
   for (char* arg : args) {
-    SessionModule session_module;
+    HloSnapshot snapshot;
     TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg,
-                                            &session_module));
-    auto computation_status = client->LoadSnapshot(session_module);
+                                            &snapshot));
+    auto computation_status = client->LoadSnapshot(snapshot);
     if (!computation_status.ok()) {
       fprintf(stderr, "could not load snapshot for %s: %s\n", arg,
               computation_status.status().ToString().c_str());
       continue;
     }
-    Computation computation = computation_status.ConsumeValueOrDie();
+    XlaComputation computation = computation_status.ConsumeValueOrDie();
 
     if (compile) {
       std::unique_ptr<ProgramShape> program_shape =
@@ -65,8 +64,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
       build_options.set_device_ordinal(0);
       build_options.set_result_layout(program_shape->result());
       StatusOr<std::unique_ptr<Executable>> executable =
-          local_service->CompileExecutable(computation.handle(), layouts,
-                                           build_options);
+          local_service->CompileExecutable(computation, layouts, build_options);
 
       const HloModule& module = executable.ValueOrDie()->module();
 
@@ -74,13 +72,11 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool compile) {
               local_service->backend().platform()->Name().c_str(),
               module.ToString(HloPrintOptions::ShortParsable()).c_str());
     } else {
-      const ComputationTracker& tracker = local_service->computation_tracker();
-      UserComputation* user_computation =
-          tracker.Resolve(computation.handle()).ConsumeValueOrDie();
-      VersionedComputationHandle versioned_handle =
-          user_computation->GetVersionedHandle();
+      auto config = HloModule::CreateModuleConfigFromProto(computation.proto(),
+                                                           DebugOptions())
+                        .ConsumeValueOrDie();
       std::unique_ptr<HloModule> module =
-          tracker.BuildHloModule(versioned_handle, HloModuleConfig())
+          HloModule::CreateFromProto(computation.proto(), config)
               .ConsumeValueOrDie();
 
       fprintf(stdout, "%s\n",
index 51f90b0..a5dce20 100644 (file)
@@ -28,11 +28,10 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/client/client.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/service.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
@@ -48,10 +47,11 @@ namespace tools {
 void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
   Client* client = ClientLibrary::LocalClientOrDie();
   for (char* arg : args) {
-    SessionModule module;
+    HloSnapshot module;
     TF_CHECK_OK(
         tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module));
-    Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie();
+    XlaComputation computation =
+        client->LoadSnapshot(module).ConsumeValueOrDie();
     DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags();
     debug_options.set_xla_generate_hlo_graph(".*");
     debug_options.set_xla_hlo_dump_as_graphdef(true);
index d8cedad..df05013 100644 (file)
@@ -17,7 +17,7 @@ limitations under the License.
 //
 // Replays computations and shows the results on the command line.
 //
-// some_binary_snapshot_proto is obtained by serializing the SessionModule from
+// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from
 // ServiceInterface::SnapshotComputation to disk.
 //
 // Computations that require arguments can be replayed using fake data by
@@ -36,14 +36,12 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/client/client.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
 #include "tensorflow/compiler/xla/client/lib/testing.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/execution_options_util.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -76,13 +74,9 @@ struct Options {
 //
 // Similarly, infeeds fake data of shape fake_infeed_shape if it is provided;
 // otherwise, no infeed is performed.
-template <typename ModuleT>
-StatusOr<std::unique_ptr<Literal>> ReplayComputation(const ModuleT& module,
+StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
                                                      Client* client,
                                                      const Options& opts) {
-  static_assert(std::is_same<ModuleT, HloSnapshot>::value ||
-                    std::is_same<ModuleT, SessionModule>::value,
-                "Proto must be in HloSnapshot or SessionModule format");
   TF_ASSIGN_OR_RETURN(auto computation, client->LoadSnapshot(module));
 
   std::vector<std::unique_ptr<GlobalData>> arguments;
@@ -161,40 +155,13 @@ int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
   for (char* arg : args) {
     HloSnapshot snapshot;
     auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot);
-    if (status.ok()) {
-      StatusOr<std::unique_ptr<Literal>> result_status =
-          ReplayComputation(snapshot, client, opts);
-      if (!result_status.ok()) {
-        fprintf(stderr, "%s: error: %s\n", arg,
-                result_status.status().ToString().c_str());
-        exit_status = EXIT_FAILURE;
-        continue;
-      }
-
-      std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
-      if (result != nullptr) {
-        fprintf(stdout, "%s: %s :: %s:%s\n", arg,
-                snapshot.hlo().hlo_module().name().c_str(),
-                ShapeUtil::HumanString(result->shape()).c_str(),
-                result->ToString().c_str());
-        if (snapshot.has_result()) {
-          std::unique_ptr<Literal> literal =
-              Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
-          fprintf(stdout, "was %s:%s\n",
-                  ShapeUtil::HumanString(snapshot.result().shape()).c_str(),
-                  literal->ToString().c_str());
-        }
-      }
-
+    if (!status.ok()) {
+      fprintf(stderr, "%s: is not HloSnapshot: %s.\n", arg,
+              status.ToString().c_str());
       continue;
     }
-    fprintf(stderr, "%s: is not HloSnapshot: %s. Trying as SessionModule...\n",
-            arg, status.ToString().c_str());
-
-    SessionModule module;
-    TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module));
     StatusOr<std::unique_ptr<Literal>> result_status =
-        ReplayComputation(module, client, opts);
+        ReplayComputation(snapshot, client, opts);
     if (!result_status.ok()) {
       fprintf(stderr, "%s: error: %s\n", arg,
               result_status.status().ToString().c_str());
@@ -204,14 +171,15 @@ int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
 
     std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
     if (result != nullptr) {
-      fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(),
+      fprintf(stdout, "%s: %s :: %s:%s\n", arg,
+              snapshot.hlo().hlo_module().name().c_str(),
               ShapeUtil::HumanString(result->shape()).c_str(),
               result->ToString().c_str());
-      if (module.has_result()) {
+      if (snapshot.has_result()) {
         std::unique_ptr<Literal> literal =
-            Literal::CreateFromProto(module.result()).ConsumeValueOrDie();
+            Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
         fprintf(stdout, "was %s:%s\n",
-                ShapeUtil::HumanString(module.result().shape()).c_str(),
+                ShapeUtil::HumanString(snapshot.result().shape()).c_str(),
                 literal->ToString().c_str());
       }
     }
index 1f3340c..4e53faf 100644 (file)
@@ -18,7 +18,7 @@ limitations under the License.
 // Shows the signature (ProgramShape) of binary snapshot proto(s) on the command
 // line.
 //
-// some_binary_snapshot_proto is obtained by serializing the SessionModule from
+// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from
 // ServiceInterface::SnapshotComputation to disk.
 //
 // The output format is:
@@ -31,9 +31,8 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/client/client.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -49,13 +48,14 @@ namespace tools {
 void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
   Client* client = ClientLibrary::LocalClientOrDie();
   for (char* arg : args) {
-    SessionModule module;
+    HloSnapshot module;
     TF_CHECK_OK(
         tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module));
-    Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie();
+    auto computation = client->LoadSnapshot(module).ConsumeValueOrDie();
     std::unique_ptr<ProgramShape> shape =
         client->GetComputationShape(computation).ConsumeValueOrDie();
-    fprintf(stdout, "%s: %s :: %s\n", arg, module.entry().name().c_str(),
+    fprintf(stdout, "%s: %s :: %s\n", arg,
+            module.hlo().hlo_module().name().c_str(),
             ShapeUtil::HumanString(*shape).c_str());
   }
 }