Respect any device filters in {Create,Delete}WorkerSessions().
authorDerek Murray <mrry@google.com>
Fri, 20 Apr 2018 01:12:57 +0000 (18:12 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 20 Apr 2018 01:15:41 +0000 (18:15 -0700)
This is another step towards enabling us to turn on explicit worker
sessions for all master sessions.

PiperOrigin-RevId: 193605565

tensorflow/core/distributed_runtime/master.cc
tensorflow/core/distributed_runtime/master_env.h
tensorflow/core/distributed_runtime/master_session.cc
tensorflow/core/distributed_runtime/master_session.h
tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc

index f47502e..288656e 100644 (file)
@@ -417,9 +417,13 @@ void Master::CreateSession(const CreateSessionRequest* req,
     SessionOptions options;
     options.config = req->config();
 
+    std::vector<string> filtered_worker_list;
+    DeviceFinder::GetRemoteWorkers(req->config().device_filters(), env_,
+                                   worker_cache, &filtered_worker_list);
+
     MasterSession* session = env_->master_session_factory(
         options, env_, std::move(remote_devices), std::move(worker_cache_ptr),
-        std::move(device_set));
+        std::move(device_set), std::move(filtered_worker_list));
 
     GraphDef* gdef =
         const_cast<CreateSessionRequest*>(req)->mutable_graph_def();
index 178c5b4..16f4d93 100644 (file)
@@ -83,7 +83,8 @@ struct MasterEnv {
       SessionOptions, MasterEnv*,
       std::unique_ptr<std::vector<std::unique_ptr<Device>>>,
       std::unique_ptr<WorkerCacheInterface>,
-      std::unique_ptr<DeviceSet> device_set)>
+      std::unique_ptr<DeviceSet> device_set,
+      std::vector<string> filtered_worker_list)>
       master_session_factory;
 
   std::function<Status(const WorkerCacheFactoryOptions&,
index 7868200..ebe350d 100644 (file)
@@ -416,6 +416,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
   if (!s.ok()) {
     for (Part& part : partitions_) {
       worker_cache_->ReleaseWorker(part.name, part.worker);
+      part.worker = nullptr;
     }
     return s;
   }
@@ -1119,6 +1120,7 @@ MasterSession::MasterSession(
     std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
     std::unique_ptr<WorkerCacheInterface> worker_cache,
     std::unique_ptr<DeviceSet> device_set,
+    std::vector<string> filtered_worker_list,
     StatsPublisherFactory stats_publisher_factory)
     : session_opts_(opt),
       env_(env),
@@ -1126,6 +1128,7 @@ MasterSession::MasterSession(
       remote_devs_(std::move(remote_devs)),
       worker_cache_(std::move(worker_cache)),
       devices_(std::move(device_set)),
+      filtered_worker_list_(std::move(filtered_worker_list)),
       stats_publisher_factory_(std::move(stats_publisher_factory)),
       graph_version_(0),
       run_graphs_(5),
@@ -1183,9 +1186,8 @@ Status MasterSession::Create(GraphDef* graph_def,
 
 Status MasterSession::CreateWorkerSessions(
     const WorkerCacheFactoryOptions& options) {
-  std::vector<string> worker_names;
+  const std::vector<string> worker_names = filtered_worker_list_;
   WorkerCacheInterface* worker_cache = get_worker_cache();
-  worker_cache->ListWorkers(&worker_names);
 
   struct WorkerGroup {
     // The worker name. (Not owned.)
@@ -1263,8 +1265,7 @@ Status MasterSession::CreateWorkerSessions(
 
 Status MasterSession::DeleteWorkerSessions() {
   WorkerCacheInterface* worker_cache = get_worker_cache();
-  std::vector<string> worker_names;
-  worker_cache->ListWorkers(&worker_names);
+  const std::vector<string>& worker_names = filtered_worker_list_;
 
   struct WorkerGroup {
     // The worker name. (Not owned.)
index a054199..ec34e20 100644 (file)
@@ -52,6 +52,7 @@ class MasterSession : public core::RefCounted {
       std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
       std::unique_ptr<WorkerCacheInterface> worker_cache,
       std::unique_ptr<DeviceSet> device_set,
+      std::vector<string> filtered_worker_list,
       StatsPublisherFactory stats_publisher_factory);
 
   // Initialize the MasterSession for "def".  Must be called before Extend(),
@@ -130,6 +131,10 @@ class MasterSession : public core::RefCounted {
   // The device set used by this session.
   std::unique_ptr<DeviceSet> devices_;
 
+  // The (partial device) names of remote worker tasks that this
+  // session will contact.
+  const std::vector<string> filtered_worker_list_;
+
   StatsPublisherFactory stats_publisher_factory_;
 
   std::atomic_ulong last_access_time_usec_;
@@ -212,7 +217,6 @@ class MasterSession : public core::RefCounted {
   // workers.
   Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def);
 
-  // TODO(b/36574172): Always use Create/DeleteWorkerSession.
   bool should_delete_worker_sessions_ = false;
   Status DeleteWorkerSessions();
 
index be19103..488dcde 100644 (file)
@@ -222,10 +222,12 @@ Status GrpcServer::Init(
           SessionOptions options, const MasterEnv* env,
           std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
           std::unique_ptr<WorkerCacheInterface> worker_cache,
-          std::unique_ptr<DeviceSet> device_set) {
+          std::unique_ptr<DeviceSet> device_set,
+          std::vector<string> filtered_worker_list) {
         options.config.MergeFrom(config);
         return new MasterSession(options, env, std::move(remote_devs),
                                  std::move(worker_cache), std::move(device_set),
+                                 std::move(filtered_worker_list),
                                  stats_factory);
       };
   master_env_.worker_cache_factory =