Always use the local worker name in CreateWorkerSession when not doing ClusterSpec...
authorDerek Murray <mrry@google.com>
Fri, 20 Apr 2018 22:38:06 +0000 (15:38 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 20 Apr 2018 22:40:46 +0000 (15:40 -0700)
Previously, the master would send a job name and task index in an
otherwise-empty ServerDef, and the worker would unquestioningly use
those to build its worker name. However, this would lead to errors if
the worker had a local name like "/job:worker/replica:1/task:0",
because the ServerDef doesn't support non-zero replica IDs, and so the
local worker would end up with an inconsistent view of what its worker name
should be. In particular `WorkerSession::worker_name` would disagree
with the device names added during graph partitioning by the master,
which would lead to runtime failures ("InvalidArgumentError: Invalid
rendezvous key").

PiperOrigin-RevId: 193733855

tensorflow/core/distributed_runtime/BUILD
tensorflow/core/distributed_runtime/master_session.cc
tensorflow/core/distributed_runtime/session_mgr.cc
tensorflow/core/distributed_runtime/session_mgr_test.cc

index d564727..343dd5d 100644 (file)
@@ -145,6 +145,7 @@ tf_cc_test(
     deps = [
         ":session_mgr",
         ":worker_env",
+        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
index ebe350d..e3022f3 100644 (file)
@@ -1219,17 +1219,6 @@ Status MasterSession::CreateWorkerSessions(
     workers[i].name = &worker_names[i];
     workers[i].worker = worker_cache->CreateWorker(worker_names[i]);
     workers[i].request.set_session_handle(handle_);
-    if (options.cluster_def) {
-      *workers[i].request.mutable_server_def()->mutable_cluster() =
-          *options.cluster_def;
-      workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
-      // Session state is always isolated when ClusterSpec propagation
-      // is in use.
-      workers[i].request.set_isolate_session_state(true);
-    } else {
-      workers[i].request.set_isolate_session_state(
-          session_opts_.config.isolate_session_state());
-    }
 
     DeviceNameUtils::ParsedName name;
     if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
@@ -1243,8 +1232,21 @@ Status MasterSession::CreateWorkerSessions(
       return status;
     }
 
-    workers[i].request.mutable_server_def()->set_job_name(name.job);
-    workers[i].request.mutable_server_def()->set_task_index(name.task);
+    if (options.cluster_def) {
+      *workers[i].request.mutable_server_def()->mutable_cluster() =
+          *options.cluster_def;
+      workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
+      workers[i].request.mutable_server_def()->set_job_name(name.job);
+      workers[i].request.mutable_server_def()->set_task_index(name.task);
+      // Session state is always isolated when ClusterSpec propagation
+      // is in use.
+      workers[i].request.set_isolate_session_state(true);
+    } else {
+      // NOTE(mrry): Do not set any component of the ServerDef,
+      // because the worker will use its local configuration.
+      workers[i].request.set_isolate_session_state(
+          session_opts_.config.isolate_session_state());
+    }
   }
 
   for (size_t i = 0; i < worker_names.size(); ++i) {
index 357e9f8..7ef4206 100644 (file)
@@ -43,6 +43,7 @@ SessionMgr::SessionMgr(
               new GraphMgr(worker_env, worker_env->device_mgr)))),
       worker_cache_factory_(std::move(worker_cache_factory)) {}
 
+/* static */
 string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
   return strings::StrCat("/job:", server_def.job_name(), "/replica:0/task:",
                          server_def.task_index());
@@ -56,13 +57,14 @@ Status SessionMgr::CreateSession(const string& session,
     return errors::InvalidArgument("Session must be non-empty.");
   }
 
-  const string worker_name = WorkerNameFromServerDef(server_def);
-
   WorkerCacheInterface* worker_cache = nullptr;
+  string worker_name;
   if (server_def.cluster().job().empty()) {
     worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
+    worker_name = legacy_session_->worker_name;
   } else {
     TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
+    worker_name = WorkerNameFromServerDef(server_def);
   }
 
   if (worker_cache != nullptr & default_worker_cache_.get() != nullptr) {
index 0da3338..9919211 100644 (file)
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/distributed_runtime/worker_env.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/cluster.pb.h"
 
 namespace tensorflow {
 
@@ -77,6 +78,34 @@ TEST_F(SessionMgrTest, CreateSessionSimple) {
   TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
 }
 
+TEST_F(SessionMgrTest, CreateSessionClusterDefWorkerName) {
+  ServerDef server_def;
+  server_def.set_job_name("worker");
+  server_def.set_task_index(3);
+  auto job = server_def.mutable_cluster()->add_job();
+  job->set_name("worker");
+  job->mutable_tasks()->insert({3, "localhost:3333"});
+
+  string session_handle = "test_session_handle";
+  TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def, true));
+  std::shared_ptr<WorkerSession> session;
+  TF_EXPECT_OK(mgr_.WorkerSessionForSession(session_handle, &session));
+  EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null";
+  EXPECT_EQ("/job:worker/replica:0/task:3", session->worker_name);
+  TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
+}
+
+TEST_F(SessionMgrTest, CreateSessionDefaultWorkerName) {
+  ServerDef server_def;
+  string session_handle = "test_session_handle";
+  TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def, true));
+  std::shared_ptr<WorkerSession> session;
+  TF_EXPECT_OK(mgr_.WorkerSessionForSession(session_handle, &session));
+  EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null";
+  EXPECT_EQ("/job:mnist/replica:0/task:0", session->worker_name);
+  TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
+}
+
 TEST_F(SessionMgrTest, CreateSessionIsolateSessionState) {
   ServerDef server_def;
   server_def.set_job_name("worker");