Never use the LegacySession when a Master explicitly calls CreateWorkerSession.
authorDerek Murray <mrry@google.com>
Wed, 18 Apr 2018 22:27:05 +0000 (15:27 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 18 Apr 2018 22:30:46 +0000 (15:30 -0700)
Previously, if the session handle was unrecognized by the worker, it
would default to using the LegacySession. This prevents us from
noticing that a server has been restarted.

To address the problem in a backwards-compatible way, we add a bit to
each session-handle-carrying worker request, indicating whether the
master believes that CreateWorkerSession has been called. If this bit
is set and the handle is unrecognized, the worker will raise an
AbortedError, which can be caught by high-level frameworks such as
`tf.estimator`.

Note that CreateWorkerSession is not yet used by default, and a
follow-up change will add that.

PiperOrigin-RevId: 193427057

12 files changed:
tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc
tensorflow/core/distributed_runtime/master_session.cc
tensorflow/core/distributed_runtime/message_wrappers.cc
tensorflow/core/distributed_runtime/message_wrappers.h
tensorflow/core/distributed_runtime/session_mgr.cc
tensorflow/core/distributed_runtime/session_mgr.h
tensorflow/core/distributed_runtime/session_mgr_test.cc
tensorflow/core/distributed_runtime/worker.cc
tensorflow/core/distributed_runtime/worker_session.cc
tensorflow/core/protobuf/worker.proto

index 000a03d..6edc2ec 100644 (file)
@@ -145,6 +145,7 @@ Status ClusterFunctionLibraryRuntime::Instantiate(
 
   RegisterGraphRequest req;
   req.set_session_handle(worker_session_->session_name);
+  req.set_create_worker_session_called(create_worker_session_called_);
   *req.mutable_graph_def() = gdef;
   req.mutable_graph_options()
       ->mutable_optimizer_options()
@@ -182,6 +183,7 @@ void ClusterFunctionLibraryRuntime::Run(
 
   RunGraphRequest* req = new RunGraphRequest;
   req->set_session_handle(worker_session_->session_name);
+  req->set_create_worker_session_called(create_worker_session_called_);
   req->set_graph_handle(function_data->graph_handle);
   // Borrowed from master_session.cc
   const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
index d3ca350..1ea0a3a 100644 (file)
@@ -27,8 +27,10 @@ struct WorkerSession;
 // functions across processes by making RPCs.
 class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime {
  public:
-  ClusterFunctionLibraryRuntime(WorkerSession* worker_session)
-      : worker_session_(worker_session) {}
+  ClusterFunctionLibraryRuntime(WorkerSession* worker_session,
+                                bool create_worker_session_called)
+      : worker_session_(worker_session),
+        create_worker_session_called_(create_worker_session_called) {}
 
   ~ClusterFunctionLibraryRuntime() override;
 
@@ -51,6 +53,7 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime {
 
   mutable mutex mu_;
   WorkerSession* const worker_session_ = nullptr;  // not owned.
+  const bool create_worker_session_called_;
 
   struct FunctionData {
     const string graph_handle;
index 1810996..6f96d7c 100644 (file)
@@ -44,7 +44,7 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test {
         std::unique_ptr<GraphMgr>()));
 
     cluster_flr_.reset(
-        new ClusterFunctionLibraryRuntime(worker_session_.get()));
+        new ClusterFunctionLibraryRuntime(worker_session_.get(), true));
   }
 
   Status ConstructFunctionGraphHelper(
index e0a5bb4..08020f0 100644 (file)
@@ -431,6 +431,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
     const Part& part = partitions_[i];
     Call* c = &calls[i];
     c->req.set_session_handle(session_handle_);
+    c->req.set_create_worker_session_called(!should_deregister_);
     c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
     *c->req.mutable_graph_options() = session_opts_.config.graph_options();
     *c->req.mutable_debug_options() =
@@ -587,6 +588,7 @@ Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
       c->req->set_is_last_partial_run(is_last_partial_run);
     }
     c->req->set_session_handle(session_handle_);
+    c->req->set_create_worker_session_called(!should_deregister_);
     c->req->set_graph_handle(part.graph_handle);
     c->req->set_step_id(step_id);
     *c->req->mutable_exec_opts() = exec_opts;
@@ -1003,6 +1005,7 @@ void MasterSession::ReffedClientGraph::DeregisterPartitions() {
     if (!part.graph_handle.empty()) {
       Call* c = new Call;
       c->req.set_session_handle(session_handle_);
+      c->req.set_create_worker_session_called(!should_deregister_);
       c->req.set_graph_handle(part.graph_handle);
       // NOTE(mrry): We must capture `worker_cache_` since `this`
       // could be deleted before the callback is called.
index 18668b4..40bf564 100644 (file)
@@ -282,10 +282,18 @@ const string& InMemoryRunGraphRequest::session_handle() const {
   return session_handle_;
 }
 
+bool InMemoryRunGraphRequest::create_worker_session_called() const {
+  return create_worker_session_called_;
+}
+
 void InMemoryRunGraphRequest::set_session_handle(const string& handle) {
   session_handle_ = handle;
 }
 
+void InMemoryRunGraphRequest::set_create_worker_session_called(bool called) {
+  create_worker_session_called_ = called;
+}
+
 const string& InMemoryRunGraphRequest::graph_handle() const {
   return graph_handle_;
 }
@@ -378,6 +386,8 @@ const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
   if (!proto_version_) {
     proto_version_.reset(new RunGraphRequest);
     proto_version_->set_session_handle(session_handle());
+    proto_version_->set_create_worker_session_called(
+        create_worker_session_called());
     proto_version_->set_graph_handle(graph_handle());
     proto_version_->set_step_id(step_id());
     *proto_version_->mutable_exec_opts() = exec_opts();
@@ -403,6 +413,15 @@ void MutableProtoRunGraphRequest::set_session_handle(const string& handle) {
   request_.set_session_handle(handle);
 }
 
+bool MutableProtoRunGraphRequest::create_worker_session_called() const {
+  return request_.create_worker_session_called();
+}
+
+void MutableProtoRunGraphRequest::set_create_worker_session_called(
+    bool called) {
+  request_.set_create_worker_session_called(called);
+}
+
 const string& MutableProtoRunGraphRequest::graph_handle() const {
   return request_.graph_handle();
 }
@@ -514,6 +533,10 @@ const string& ProtoRunGraphRequest::session_handle() const {
   return request_->session_handle();
 }
 
+bool ProtoRunGraphRequest::create_worker_session_called() const {
+  return request_->create_worker_session_called();
+}
+
 const string& ProtoRunGraphRequest::graph_handle() const {
   return request_->graph_handle();
 }
index 1f7cdb9..92c5668 100644 (file)
@@ -246,6 +246,9 @@ class RunGraphRequestWrapper {
   // namespace is used.
   virtual const string& session_handle() const = 0;
 
+  // Set to true if `CreateWorkerSession` was called for `session_handle`.
+  virtual bool create_worker_session_called() const = 0;
+
   // REQUIRED: graph_handle must be returned by a RegisterGraph call
   // to the same WorkerService.
   virtual const string& graph_handle() const = 0;
@@ -293,6 +296,7 @@ class RunGraphRequestWrapper {
 class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
  public:
   virtual void set_session_handle(const string& handle) = 0;
+  virtual void set_create_worker_session_called(bool called) = 0;
   virtual void set_graph_handle(const string& handle) = 0;
   virtual void set_step_id(int64 step_id) = 0;
   virtual ExecutorOpts* mutable_exec_opts() = 0;
@@ -317,6 +321,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
   // RunGraphRequestWrapper methods.
   const string& session_handle() const override;
   const string& graph_handle() const override;
+  bool create_worker_session_called() const override;
   int64 step_id() const override;
   const ExecutorOpts& exec_opts() const override;
   size_t num_sends() const override;
@@ -331,6 +336,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
 
   // MutableRunGraphRequestWrapper methods.
   void set_session_handle(const string& handle) override;
+  void set_create_worker_session_called(bool called) override;
   void set_graph_handle(const string& handle) override;
   void set_step_id(int64 step_id) override;
   ExecutorOpts* mutable_exec_opts() override;
@@ -347,6 +353,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
 
  private:
   string session_handle_;
+  bool create_worker_session_called_;
   string graph_handle_;
   int64 step_id_;
   ExecutorOpts exec_opts_;
@@ -370,6 +377,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
  public:
   // RunGraphRequestWrapper methods.
   const string& session_handle() const override;
+  bool create_worker_session_called() const override;
   const string& graph_handle() const override;
   int64 step_id() const override;
   const ExecutorOpts& exec_opts() const override;
@@ -385,6 +393,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
 
   // MutableRunGraphRequestWrapper methods.
   void set_session_handle(const string& handle) override;
+  void set_create_worker_session_called(bool called) override;
   void set_graph_handle(const string& handle) override;
   void set_step_id(int64 step_id) override;
   ExecutorOpts* mutable_exec_opts() override;
@@ -409,6 +418,7 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper {
 
   // RunGraphRequestWrapper methods.
   const string& session_handle() const override;
+  bool create_worker_session_called() const override;
   const string& graph_handle() const override;
   int64 step_id() const override;
   const ExecutorOpts& exec_opts() const override;
index 51b9547..e51d63c 100644 (file)
@@ -98,20 +98,26 @@ Status SessionMgr::DeleteSession(const string& session) {
   return Status::OK();
 }
 
-std::shared_ptr<WorkerSession> SessionMgr::WorkerSessionForSessionUnlocked(
-    const string& session) {
-  auto it = sessions_.find(session);
-  if (it == sessions_.end()) {
-    return legacy_session_;
+Status SessionMgr::WorkerSessionForSessionLocked(
+    const string& session_handle, std::shared_ptr<WorkerSession>* out_session) {
+  if (session_handle.empty()) {
+    *out_session = legacy_session_;
   } else {
-    return it->second;
+    auto it = sessions_.find(session_handle);
+    if (it == sessions_.end()) {
+      return errors::Aborted("Session handle is not found: ", session_handle,
+                             ". Possibly this worker just restarted.");
+    } else {
+      *out_session = it->second;
+    }
   }
+  return Status::OK();
 }
 
-std::shared_ptr<WorkerSession> SessionMgr::WorkerSessionForSession(
-    const string& session) {
+Status SessionMgr::WorkerSessionForSession(
+    const string& session_handle, std::shared_ptr<WorkerSession>* out_session) {
   mutex_lock l(mu_);
-  return WorkerSessionForSessionUnlocked(session);
+  return WorkerSessionForSessionLocked(session_handle, out_session);
 }
 
 std::shared_ptr<WorkerSession> SessionMgr::LegacySession() {
index 4c9702d..0a10fe2 100644 (file)
@@ -50,7 +50,8 @@ class SessionMgr {
                        bool isolate_session_state);
 
   // Locates the worker session for a given session handle
-  std::shared_ptr<WorkerSession> WorkerSessionForSession(const string& session);
+  Status WorkerSessionForSession(const string& session_handle,
+                                 std::shared_ptr<WorkerSession>* out_session);
   std::shared_ptr<WorkerSession> LegacySession();
 
   Status DeleteSession(const string& session);
@@ -86,8 +87,9 @@ class SessionMgr {
 
   const WorkerCacheFactory worker_cache_factory_;
 
-  std::shared_ptr<WorkerSession> WorkerSessionForSessionUnlocked(
-      const string& session) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+  Status WorkerSessionForSessionLocked(
+      const string& session_handle, std::shared_ptr<WorkerSession>* out_session)
+      EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
   mutex mu_;
   // A map from session identifier to internal session structure.
index 4d028f7..858e636 100644 (file)
@@ -46,8 +46,8 @@ class SessionMgrTest : public ::testing::Test {
       : device_(FakeDevice::MakeCPU(
             "/job:mnist/replica:0/task:0/device:fakecpu:0")),
         mgr_(&env_, "/job:mnist/replica:0/task:0",
-             std::unique_ptr<WorkerCacheInterface>(), factory_),
-        legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {
+             std::unique_ptr<WorkerCacheInterface>(), factory_) {
+    TF_CHECK_OK(mgr_.WorkerSessionForSession("", &legacy_session_));
     env_.local_devices = {device_.get()};
   }
 
@@ -69,7 +69,8 @@ TEST_F(SessionMgrTest, CreateSessionSimple) {
 
   string session_handle = "test_session_handle";
   TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def, true));
-  auto session = mgr_.WorkerSessionForSession(session_handle);
+  std::shared_ptr<WorkerSession> session;
+  TF_EXPECT_OK(mgr_.WorkerSessionForSession(session_handle, &session));
   EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null";
   EXPECT_NE(mgr_.LegacySession(), session);
   TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
@@ -81,22 +82,26 @@ TEST_F(SessionMgrTest, CreateSessionIsolateSessionState) {
   server_def.set_task_index(3);
 
   TF_EXPECT_OK(mgr_.CreateSession("handle_1", server_def, false));
-  auto session_1 = mgr_.WorkerSessionForSession("handle_1");
+  std::shared_ptr<WorkerSession> session_1;
+  TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_1", &session_1));
   std::vector<Device*> devices_1 = session_1->device_mgr->ListDevices();
   EXPECT_EQ(1, devices_1.size());
 
   TF_EXPECT_OK(mgr_.CreateSession("handle_2", server_def, false));
-  auto session_2 = mgr_.WorkerSessionForSession("handle_2");
+  std::shared_ptr<WorkerSession> session_2;
+  TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_2", &session_2));
   std::vector<Device*> devices_2 = session_2->device_mgr->ListDevices();
   EXPECT_EQ(1, devices_2.size());
 
   TF_EXPECT_OK(mgr_.CreateSession("handle_3", server_def, true));
-  auto session_3 = mgr_.WorkerSessionForSession("handle_3");
+  std::shared_ptr<WorkerSession> session_3;
+  TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_3", &session_3));
   std::vector<Device*> devices_3 = session_3->device_mgr->ListDevices();
   EXPECT_EQ(1, devices_3.size());
 
   TF_EXPECT_OK(mgr_.CreateSession("handle_4", server_def, true));
-  auto session_4 = mgr_.WorkerSessionForSession("handle_4");
+  std::shared_ptr<WorkerSession> session_4;
+  TF_EXPECT_OK(mgr_.WorkerSessionForSession("handle_4", &session_4));
   std::vector<Device*> devices_4 = session_4->device_mgr->ListDevices();
   EXPECT_EQ(1, devices_4.size());
 
@@ -109,12 +114,23 @@ TEST_F(SessionMgrTest, CreateSessionIsolateSessionState) {
 TEST_F(SessionMgrTest, LegacySession) {
   ServerDef server_def;
   string session_handle = "";
-  auto session = mgr_.WorkerSessionForSession(session_handle);
+  std::shared_ptr<WorkerSession> session;
+  TF_EXPECT_OK(mgr_.WorkerSessionForSession(session_handle, &session));
   EXPECT_EQ(mgr_.LegacySession(), session);
 
   TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
 }
 
+TEST_F(SessionMgrTest, UnknownSessionHandle) {
+  ServerDef server_def;
+  string session_handle = "unknown_session_handle";
+  std::shared_ptr<WorkerSession> session;
+  Status s = mgr_.WorkerSessionForSession(session_handle, &session);
+  EXPECT_TRUE(errors::IsAborted(s));
+  EXPECT_TRUE(
+      str_util::StrContains(s.error_message(), "Session handle is not found"));
+}
+
 TEST_F(SessionMgrTest, WorkerNameFromServerDef) {
   ServerDef server_def;
   server_def.set_job_name("worker");
@@ -124,7 +140,7 @@ TEST_F(SessionMgrTest, WorkerNameFromServerDef) {
 }
 
 TEST_F(SessionMgrTest, DeleteLegacySession) {
-  TF_EXPECT_OK(mgr_.DeleteSession("legacy_session"));
+  TF_EXPECT_OK(mgr_.DeleteSession(""));
 }
 
 }  // namespace tensorflow
index 598652f..6b2536c 100644 (file)
@@ -59,21 +59,37 @@ void Worker::DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request,
 void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
                                 RegisterGraphResponse* response,
                                 StatusCallback done) {
-  auto session =
-      env_->session_mgr->WorkerSessionForSession(request->session_handle());
-  Status s = session->graph_mgr->Register(
-      request->session_handle(), request->graph_def(), request->graph_options(),
-      request->debug_options(), session->cluster_flr.get(),
-      response->mutable_graph_handle());
+  std::shared_ptr<WorkerSession> session;
+  Status s;
+  if (request->create_worker_session_called()) {
+    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
+                                                   &session);
+  } else {
+    session = env_->session_mgr->LegacySession();
+  }
+  if (s.ok()) {
+    s = session->graph_mgr->Register(
+        request->session_handle(), request->graph_def(),
+        request->graph_options(), request->debug_options(),
+        session->cluster_flr.get(), response->mutable_graph_handle());
+  }
   done(s);
 }
 
 void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
                                   DeregisterGraphResponse* response,
                                   StatusCallback done) {
-  auto session =
-      env_->session_mgr->WorkerSessionForSession(request->session_handle());
-  Status s = session->graph_mgr->Deregister(request->graph_handle());
+  std::shared_ptr<WorkerSession> session;
+  Status s;
+  if (request->create_worker_session_called()) {
+    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
+                                                   &session);
+  } else {
+    session = env_->session_mgr->LegacySession();
+  }
+  if (s.ok()) {
+    s = session->graph_mgr->Deregister(request->graph_handle());
+  }
 
   done(s);
 }
@@ -135,11 +151,21 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                         StatusCallback done) {
   const int64 step_id = request->step_id();
   TRACEPRINTF("RunGraph: %lld", step_id);
-  auto session =
-      env_->session_mgr->WorkerSessionForSession(request->session_handle());
+  std::shared_ptr<WorkerSession> session;
+  Status s;
+  if (request->create_worker_session_called()) {
+    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
+                                                   &session);
+  } else {
+    session = env_->session_mgr->LegacySession();
+  }
+  if (!s.ok()) {
+    done(s);
+    return;
+  }
   GraphMgr::NamedTensors in;
   GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
-  Status s = PrepareRunGraph(request, &in, out);
+  s = PrepareRunGraph(request, &in, out);
   if (!s.ok()) {
     delete out;
     done(s);
@@ -209,12 +235,23 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
   const int64 step_id = request->step_id();
   const string& graph_handle = request->graph_handle();
   TRACEPRINTF("PartialRunGraph: %lld", step_id);
-  auto session =
-      env_->session_mgr->WorkerSessionForSession(request->session_handle());
+  std::shared_ptr<WorkerSession> session;
+
+  Status s;
+  if (request->create_worker_session_called()) {
+    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
+                                                   &session);
+  } else {
+    session = env_->session_mgr->LegacySession();
+  }
+  if (!s.ok()) {
+    done(s);
+    return;
+  }
 
   GraphMgr::NamedTensors in;
   GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
-  Status s = PrepareRunGraph(request, &in, out);
+  s = PrepareRunGraph(request, &in, out);
   auto finish = [done, out, opts](const Status& s) {
     opts->ClearCancelCallback();
     delete out;
index cb7059b..18886ba 100644 (file)
@@ -97,6 +97,7 @@ WorkerSession::WorkerSession(const string& session_name,
       worker_cache(new WorkerFreeListCache(std::move(worker_cache))),
       device_mgr(std::move(device_mgr)),
       graph_mgr(std::move(graph_mgr)),
-      cluster_flr(new ClusterFunctionLibraryRuntime(this)) {}
+      cluster_flr(
+          new ClusterFunctionLibraryRuntime(this, !session_name.empty())) {}
 
 }  // namespace tensorflow
index 3e7289b..1819a35 100644 (file)
@@ -103,6 +103,9 @@ message RegisterGraphRequest {
   // Subgraphs are scoped within one session.
   string session_handle = 1;
 
+  // Set to true if `CreateWorkerSession` was called for `session_handle`.
+  bool create_worker_session_called = 6;
+
   // "graph_def" has the subgraph of nodes for this worker, with each node
   // having its device_name filled in.
   GraphDef graph_def = 2;
@@ -144,6 +147,9 @@ message DeregisterGraphRequest {
   // empty, a single global namespace is used.
   string session_handle = 2;
 
+  // Set to true if `CreateWorkerSession` was called for `session_handle`.
+  bool create_worker_session_called = 3;
+
   // REQUIRED: graph_handle must be returned by a RegisterGraph call
   // to the same WorkerService.
   string graph_handle = 1;
@@ -200,6 +206,9 @@ message RunGraphRequest {
   // search for the graph_handle.
   string session_handle = 8;
 
+  // Set to true if `CreateWorkerSession` was called for `session_handle`.
+  bool create_worker_session_called = 10;
+
   // REQUIRED: graph_handle must be returned by a RegisterGraph call
   // to the same WorkerService.
   string graph_handle = 1;
@@ -234,7 +243,7 @@ message RunGraphRequest {
   // truncate long metadata messages.
   bool store_errors_in_response_body = 9;
 
-  // Next: 10
+  // Next: 11
 }
 
 message RunGraphResponse {