Add remote session support for the MakeCallable API.
authorDerek Murray <mrry@google.com>
Sat, 7 Apr 2018 00:39:17 +0000 (17:39 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 7 Apr 2018 01:18:06 +0000 (18:18 -0700)
PiperOrigin-RevId: 191964391

18 files changed:
tensorflow/core/distributed_runtime/local_master.cc
tensorflow/core/distributed_runtime/local_master.h
tensorflow/core/distributed_runtime/master.cc
tensorflow/core/distributed_runtime/master.h
tensorflow/core/distributed_runtime/master_interface.h
tensorflow/core/distributed_runtime/master_session.cc
tensorflow/core/distributed_runtime/master_session.h
tensorflow/core/distributed_runtime/message_wrappers.cc
tensorflow/core/distributed_runtime/message_wrappers.h
tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
tensorflow/core/distributed_runtime/rpc/grpc_session.cc
tensorflow/core/distributed_runtime/rpc/grpc_session.h
tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
tensorflow/core/protobuf/master.proto
tensorflow/core/protobuf/master_service.proto

index aaa4cfa..7631546 100644 (file)
@@ -157,6 +157,47 @@ Status LocalMaster::Reset(CallOptions* call_options,
   return ret;
 }
 
+Status LocalMaster::MakeCallable(CallOptions* call_options,
+                                 const MakeCallableRequest* request,
+                                 MakeCallableResponse* response) {
+  Notification n;
+  Status ret;
+  master_impl_->MakeCallable(request, response, [&n, &ret](const Status& s) {
+    ret.Update(s);
+    n.Notify();
+  });
+  TF_RETURN_IF_ERROR(
+      WaitForNotification(call_options, default_timeout_in_ms_, &n));
+  return ret;
+}
+Status LocalMaster::RunCallable(CallOptions* call_options,
+                                const RunCallableRequest* request,
+                                RunCallableResponse* response) {
+  Notification n;
+  Status ret;
+  master_impl_->RunCallable(call_options, request, response,
+                            [&n, &ret](const Status& s) {
+                              ret.Update(s);
+                              n.Notify();
+                            });
+  TF_RETURN_IF_ERROR(
+      WaitForNotification(call_options, default_timeout_in_ms_, &n));
+  return ret;
+}
+Status LocalMaster::ReleaseCallable(CallOptions* call_options,
+                                    const ReleaseCallableRequest* request,
+                                    ReleaseCallableResponse* response) {
+  Notification n;
+  Status ret;
+  master_impl_->ReleaseCallable(request, response, [&n, &ret](const Status& s) {
+    ret.Update(s);
+    n.Notify();
+  });
+  TF_RETURN_IF_ERROR(
+      WaitForNotification(call_options, default_timeout_in_ms_, &n));
+  return ret;
+}
+
 namespace {
 mutex* get_local_master_registry_lock() {
   static mutex local_master_registry_lock(LINKER_INITIALIZED);
index c20b403..cad6bab 100644 (file)
@@ -71,6 +71,16 @@ class LocalMaster : public MasterInterface {
   Status Reset(CallOptions* call_options, const ResetRequest* request,
                ResetResponse* response) override;
 
+  Status MakeCallable(CallOptions* call_options,
+                      const MakeCallableRequest* request,
+                      MakeCallableResponse* response) override;
+  Status RunCallable(CallOptions* call_options,
+                     const RunCallableRequest* request,
+                     RunCallableResponse* response) override;
+  Status ReleaseCallable(CallOptions* call_options,
+                         const ReleaseCallableRequest* request,
+                         ReleaseCallableResponse* response);
+
   // Registers the mapping from the given `target` to the given `master`.
   //
   // WARNING: The `master` pointer remains owned by the caller. It is
index 1a48830..f47502e 100644 (file)
@@ -611,4 +611,55 @@ void Master::Reset(const ResetRequest* req, ResetResponse* resp,
   });
 }
 
+void Master::MakeCallable(const MakeCallableRequest* req,
+                          MakeCallableResponse* resp, MyClosure done) {
+  auto session = FindMasterSession(req->session_handle());
+  if (session == nullptr) {
+    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
+    return;
+  }
+
+  SchedClosure(std::bind(
+      [this, session, req, resp](MyClosure done) {
+        Status s = session->MakeCallable(*req, resp);
+        session->Unref();
+        done(s);
+      },
+      std::move(done)));
+}
+
+void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req,
+                         RunCallableResponse* resp, MyClosure done) {
+  auto session = FindMasterSession(req->session_handle());
+  if (session == nullptr) {
+    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
+    return;
+  }
+
+  SchedClosure(std::bind(
+      [this, session, opts, req, resp](MyClosure done) {
+        Status s = session->RunCallable(opts, *req, resp);
+        session->Unref();
+        done(s);
+      },
+      std::move(done)));
+}
+
+void Master::ReleaseCallable(const ReleaseCallableRequest* req,
+                             ReleaseCallableResponse* resp, MyClosure done) {
+  auto session = FindMasterSession(req->session_handle());
+  if (session == nullptr) {
+    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
+    return;
+  }
+
+  SchedClosure(std::bind(
+      [this, session, req, resp](MyClosure done) {
+        Status s = session->ReleaseCallable(*req, resp);
+        session->Unref();
+        done(s);
+      },
+      std::move(done)));
+}
+
 }  // end namespace tensorflow
index 678fc46..dbb337f 100644 (file)
@@ -61,6 +61,13 @@ class Master {
   // See tensorflow::Reset() and the comment on ResetRequest.
   void Reset(const ResetRequest* req, ResetResponse* resp, MyClosure done);
 
+  void MakeCallable(const MakeCallableRequest* req, MakeCallableResponse* resp,
+                    MyClosure done);
+  void RunCallable(CallOptions* opts, const RunCallableRequest* req,
+                   RunCallableResponse* resp, MyClosure done);
+  void ReleaseCallable(const ReleaseCallableRequest* req,
+                       ReleaseCallableResponse* resp, MyClosure done);
+
  private:
   typedef Master ME;
 
index bf6a2db..a8ae3cb 100644 (file)
@@ -89,6 +89,16 @@ class MasterInterface {
   virtual Status Reset(CallOptions* call_options, const ResetRequest* request,
                        ResetResponse* response) = 0;
 
+  virtual Status MakeCallable(CallOptions* call_options,
+                              const MakeCallableRequest* request,
+                              MakeCallableResponse* response) = 0;
+  virtual Status RunCallable(CallOptions* call_options,
+                             const RunCallableRequest* request,
+                             RunCallableResponse* response) = 0;
+  virtual Status ReleaseCallable(CallOptions* call_options,
+                                 const ReleaseCallableRequest* request,
+                                 ReleaseCallableResponse* response) = 0;
+
  protected:
   // NOTE: This should only be called by implementations of this
   // interface whose CreateRunStepResponse() method returns a
index 64adf35..e0a5bb4 100644 (file)
@@ -72,7 +72,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
         client_graph_(std::move(cg)),
         session_opts_(session_opts),
         is_partial_(is_partial),
-        debug_opts_(bopts.callable_options.run_options().debug_options()),
+        callable_opts_(bopts.callable_options),
         worker_cache_(worker_cache),
         should_deregister_(should_deregister) {
     VLOG(1) << "Created ReffedClientGraph for node with "
@@ -94,12 +94,18 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
 
   const ClientGraph* client_graph() { return client_graph_.get(); }
 
+  const CallableOptions& callable_options() { return callable_opts_; }
+
   std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
                                                     int64 execution_count,
                                                     const RunOptions& ropts) {
     return stats_publisher_->GetProfileHandler(step, execution_count, ropts);
   }
 
+  int64 get_and_increment_execution_count() {
+    return execution_count_.fetch_add(1);
+  }
+
   // Turn RPC logging on or off, both at the WorkerCache used by this
   // master process, and at each remote worker in use for the current
   // partitions.
@@ -178,6 +184,10 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
                        CallOptions* opts, const RunStepRequestWrapper& req,
                        MutableRunStepResponseWrapper* resp,
                        CancellationManager* cm, const bool is_last_partial_run);
+  Status RunPartitions(const MasterEnv* env, int64 step_id,
+                       int64 execution_count, PerStepState* pss,
+                       CallOptions* call_opts, const RunCallableRequest& req,
+                       RunCallableResponse* resp, CancellationManager* cm);
 
   // Calls workers to cleanup states for the step "step_id".  Calls
   // `done` when all cleanup RPCs have completed.
@@ -211,10 +221,11 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
   const std::unique_ptr<ClientGraph> client_graph_;
   const SessionOptions session_opts_;
   const bool is_partial_;
-  const DebugOptions& debug_opts_;
+  const CallableOptions callable_opts_;
   WorkerCacheInterface* const worker_cache_;  // Not owned.
   std::unordered_map<StringPiece, Node*, StringPieceHasher> name_to_node_;
   const bool should_deregister_;
+  std::atomic<int64> execution_count_ = {0};
 
   // Graph partitioned into per-location subgraphs.
   struct Part {
@@ -269,6 +280,17 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
       const PartitionOptions& popts,
       std::unordered_map<string, GraphDef> graph_partitions);
 
+  // Prepares a number of calls to workers. One call per partition.
+  // This is a generic method that handles Run, PartialRun, and RunCallable.
+  template <class FetchListType, class ClientRequestType,
+            class ClientResponseType>
+  Status RunPartitionsHelper(
+      const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
+      const FetchListType& fetches, const MasterEnv* env, int64 step_id,
+      int64 execution_count, PerStepState* pss, CallOptions* call_opts,
+      const ClientRequestType& req, ClientResponseType* resp,
+      CancellationManager* cm, bool is_last_partial_run);
+
   // Deregisters the partitions on the workers.  Called in the
   // destructor and does not wait for the rpc completion.
   void DeregisterPartitions();
@@ -411,7 +433,8 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
     c->req.set_session_handle(session_handle_);
     c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
     *c->req.mutable_graph_options() = session_opts_.config.graph_options();
-    *c->req.mutable_debug_options() = debug_opts_;
+    *c->req.mutable_debug_options() =
+        callable_opts_.run_options().debug_options();
     VLOG(2) << "Register " << c->req.graph_def().DebugString();
     auto cb = [c, &done](const Status& s) {
       c->status = s;
@@ -490,24 +513,46 @@ class RunManyGraphs {
   TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
 };
 
-Status MasterSession::ReffedClientGraph::RunPartitions(
-    const MasterEnv* env, int64 step_id, int64 execution_count,
-    PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
-    MutableRunStepResponseWrapper* resp, CancellationManager* cm,
-    const bool is_last_partial_run) {
-  VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
-          << execution_count;
-  // Maps the names of fed tensors to their index in `req`.
-  std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
+namespace {
+Status AddSendFromClientRequest(const RunStepRequestWrapper& client_req,
+                                MutableRunGraphRequestWrapper* worker_req,
+                                size_t index, const string& send_key) {
+  return worker_req->AddSendFromRunStepRequest(client_req, index, send_key);
+}
 
-  for (size_t i = 0; i < req.num_feeds(); ++i) {
-    if (!feeds.insert({req.feed_name(i), i}).second) {
-      return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i));
-    }
-  }
+Status AddSendFromClientRequest(const RunCallableRequest& client_req,
+                                MutableRunGraphRequestWrapper* worker_req,
+                                size_t index, const string& send_key) {
+  return worker_req->AddSendFromRunCallableRequest(client_req, index, send_key);
+}
 
-  // Prepares a number of calls to workers. One call per partition.
+// TODO(mrry): Add a full-fledged wrapper that avoids TensorProto copies for
+// in-process messages.
+struct RunCallableResponseWrapper {
+  RunCallableResponse* resp;  // Not owned.
+  std::unordered_map<string, TensorProto> fetch_key_to_protos;
+
+  RunMetadata* mutable_metadata() { return resp->mutable_metadata(); }
 
+  Status AddTensorFromRunGraphResponse(
+      const string& tensor_name, MutableRunGraphResponseWrapper* worker_resp,
+      size_t index) {
+    // TODO(b/74355905): Add a specialized implementation that avoids
+    // copying the tensor into the RunCallableResponse when at least
+    // two of the {client, master, worker} are in the same process.
+    return worker_resp->RecvValue(index, &fetch_key_to_protos[tensor_name]);
+  }
+};
+}  // namespace
+
+template <class FetchListType, class ClientRequestType,
+          class ClientResponseType>
+Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
+    const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
+    const FetchListType& fetches, const MasterEnv* env, int64 step_id,
+    int64 execution_count, PerStepState* pss, CallOptions* call_opts,
+    const ClientRequestType& req, ClientResponseType* resp,
+    CancellationManager* cm, bool is_last_partial_run) {
   // Collect execution cost stats on a smoothly decreasing frequency.
   ExecutorOpts exec_opts;
   if (pss->report_tensor_allocations_upon_oom) {
@@ -553,28 +598,19 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
     // We keep these as separate paths for now, to ensure we aren't
     // inadvertently slowing down the normal run path.
     if (is_partial_) {
-      for (size_t i = 0; i < req.num_feeds(); ++i) {
-        const string& name = req.feed_name(i);
-        const auto iter = part.feed_key.find(name);
+      for (const auto& name_index : feeds) {
+        const auto iter = part.feed_key.find(name_index.first.ToString());
         if (iter == part.feed_key.end()) {
           // The provided feed must be for a different partition.
           continue;
         }
         const string& key = iter->second;
-        auto feeds_iter = feeds.find(name);
-        if (feeds_iter == feeds.end()) {
-          return errors::InvalidArgument("No feed is provided for feed=", name,
-                                         ", key=", key);
-        } else if (feeds_iter->second != static_cast<size_t>(i)) {
-          return errors::Internal("Cannot find feed named \"", name,
-                                  " in request.");
-        }
-        TF_RETURN_IF_ERROR(c->req->AddSendFromRunStepRequest(req, i, key));
+        TF_RETURN_IF_ERROR(AddSendFromClientRequest(req, c->req.get(),
+                                                    name_index.second, key));
       }
       // TODO(suharshs): Make a map from feed to fetch_key to make this faster.
       // For now, we just iterate through partitions to find the matching key.
-      for (int i = 0; static_cast<size_t>(i) < req.num_fetches(); ++i) {
-        const string& req_fetch = req.fetch_name(i);
+      for (const string& req_fetch : fetches) {
         for (const auto& key_fetch : part.key_fetch) {
           if (key_fetch.second == req_fetch) {
             c->req->add_recv_key(key_fetch.first);
@@ -586,9 +622,13 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
       for (const auto& feed_key : part.feed_key) {
         const string& feed = feed_key.first;
         const string& key = feed_key.second;
-        const int64 feed_index = feeds[feed];
+        auto iter = feeds.find(feed);
+        if (iter == feeds.end()) {
+          return errors::Internal("No feed index found for feed: ", feed);
+        }
+        const int64 feed_index = iter->second;
         TF_RETURN_IF_ERROR(
-            c->req->AddSendFromRunStepRequest(req, feed_index, key));
+            AddSendFromClientRequest(req, c->req.get(), feed_index, key));
       }
       for (const auto& key_fetch : part.key_fetch) {
         const string& key = key_fetch.first;
@@ -622,50 +662,115 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
   } else {
     return errors::Cancelled("Step was cancelled");
   }
+  TF_RETURN_IF_ERROR(calls.status());
 
-  // Collects fetches.
-  Status status = calls.status();
-  if (status.ok()) {
-    for (int i = 0; i < num; ++i) {
-      const Part& part = partitions_[i];
-      MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
-      for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
-        auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
-        if (iter == part.key_fetch.end()) {
-          status.Update(errors::Internal("Unexpected fetch key: ",
-                                         run_graph_resp->recv_key(j)));
-          break;
-        }
-        const string& fetch = iter->second;
-        status.Update(
-            resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
-        if (!status.ok()) {
-          break;
-        }
+  // Collects fetches and metadata.
+  Status status;
+  for (int i = 0; i < num; ++i) {
+    const Part& part = partitions_[i];
+    MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
+    for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
+      auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
+      if (iter == part.key_fetch.end()) {
+        status.Update(errors::Internal("Unexpected fetch key: ",
+                                       run_graph_resp->recv_key(j)));
+        break;
       }
-      if (pss->collect_timeline) {
-        pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
+      const string& fetch = iter->second;
+      status.Update(
+          resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
+      if (!status.ok()) {
+        break;
       }
-      if (pss->collect_costs) {
-        CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
-        for (int j = 0; j < cost_graph->node_size(); ++j) {
-          resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
-              cost_graph->mutable_node(j));
-        }
+    }
+    if (pss->collect_timeline) {
+      pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
+    }
+    if (pss->collect_costs) {
+      CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
+      for (int j = 0; j < cost_graph->node_size(); ++j) {
+        resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
+            cost_graph->mutable_node(j));
       }
-      if (pss->collect_partition_graphs) {
-        protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
-            resp->mutable_metadata()->mutable_partition_graphs();
-        for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
-          partition_graph_defs->Add()->Swap(
-              run_graph_resp->mutable_partition_graph(i));
-        }
+    }
+    if (pss->collect_partition_graphs) {
+      protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
+          resp->mutable_metadata()->mutable_partition_graphs();
+      for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
+        partition_graph_defs->Add()->Swap(
+            run_graph_resp->mutable_partition_graph(i));
       }
     }
   }
   return status;
 }
 
+Status MasterSession::ReffedClientGraph::RunPartitions(
+    const MasterEnv* env, int64 step_id, int64 execution_count,
+    PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
+    MutableRunStepResponseWrapper* resp, CancellationManager* cm,
+    const bool is_last_partial_run) {
+  VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
+          << execution_count;
+  // Maps the names of fed tensors to their index in `req`.
+  std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
+  for (size_t i = 0; i < req.num_feeds(); ++i) {
+    if (!feeds.insert({req.feed_name(i), i}).second) {
+      return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i));
+    }
+  }
+
+  std::vector<string> fetches;
+  fetches.reserve(req.num_fetches());
+  for (size_t i = 0; i < req.num_fetches(); ++i) {
+    fetches.push_back(req.fetch_name(i));
+  }
+
+  return RunPartitionsHelper(feeds, fetches, env, step_id, execution_count, pss,
+                             call_opts, req, resp, cm, is_last_partial_run);
+}
+
+Status MasterSession::ReffedClientGraph::RunPartitions(
+    const MasterEnv* env, int64 step_id, int64 execution_count,
+    PerStepState* pss, CallOptions* call_opts, const RunCallableRequest& req,
+    RunCallableResponse* resp, CancellationManager* cm) {
+  VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
+          << execution_count;
+  // Maps the names of fed tensors to their index in `req`.
+  std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
+  for (size_t i = 0; i < callable_opts_.feed_size(); ++i) {
+    if (!feeds.insert({callable_opts_.feed(i), i}).second) {
+      // MakeCallable will fail if there are two feeds with the same name.
+      return errors::Internal("Duplicated feeds in callable: ",
+                              callable_opts_.feed(i));
+    }
+  }
+
+  // Create a wrapped response object to collect the fetched values and
+  // rearrange them for the RunCallableResponse.
+  RunCallableResponseWrapper wrapped_resp;
+  wrapped_resp.resp = resp;
+
+  TF_RETURN_IF_ERROR(RunPartitionsHelper(
+      feeds, callable_opts_.fetch(), env, step_id, execution_count, pss,
+      call_opts, req, &wrapped_resp, cm, false /* is_last_partial_run */));
+
+  // Collects fetches.
+  // TODO(b/74355905): Add a specialized implementation that avoids
+  // copying the tensor into the RunCallableResponse when at least
+  // two of the {client, master, worker} are in the same process.
+  for (const string& fetch : callable_opts_.fetch()) {
+    TensorProto* fetch_proto = resp->mutable_fetch()->Add();
+    auto iter = wrapped_resp.fetch_key_to_protos.find(fetch);
+    if (iter == wrapped_resp.fetch_key_to_protos.end()) {
+      return errors::Internal("Worker did not return a value for fetch: ",
+                              fetch);
+    }
+    fetch_proto->Swap(&iter->second);
+  }
+  return Status::OK();
+}
+
 namespace {
 
 class CleanupBroadcastHelper {
@@ -1266,15 +1371,11 @@ WorkerCacheInterface* MasterSession::get_worker_cache() const {
   return env_->worker_cache;
 }
 
-Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
-                                ReffedClientGraph** rcg, bool is_partial) {
+Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial,
+                                ReffedClientGraph** out_rcg, int64* out_count) {
   const uint64 hash = HashBuildGraphOptions(opts);
   {
     mutex_lock l(mu_);
-    // Keep track of how many times this subgraph has been executed in
-    // this session.
-    int64* c = &subgraph_execution_counts_[hash];
-    *count = (*c)++;
     // TODO(suharshs): We cache partial run graphs and run graphs separately
     // because there is preprocessing that needs to only be run for partial
     // run calls.
@@ -1296,8 +1397,9 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
       iter = m->insert({hash, entry}).first;
       VLOG(1) << "Preparing to execute new graph";
     }
-    *rcg = iter->second;
-    (*rcg)->Ref();
+    *out_rcg = iter->second;
+    (*out_rcg)->Ref();
+    *out_count = (*out_rcg)->get_and_increment_execution_count();
   }
   return Status::OK();
 }
@@ -1316,6 +1418,12 @@ void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
   rcg_map->clear();
 }
 
+namespace {
+uint64 MakeStepId() {
+  return (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
+}
+}  // namespace
+
 Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
                                       PartialRunSetupResponse* resp) {
   std::vector<string> inputs, outputs, targets;
@@ -1332,15 +1440,15 @@ Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
   string handle = std::to_string(partial_run_handle_counter_.fetch_add(1));
 
   ReffedClientGraph* rcg = nullptr;
-  int64 count = 0;
 
   // Prepare.
   BuildGraphOptions opts;
   BuildBuildGraphOptions(*req, &opts);
-  TF_RETURN_IF_ERROR(StartStep(opts, &count, &rcg, true));
+  int64 count;
+  TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count));
   // Keeps the highest 8 bits 0x01: we reserve some bits of the
   // step_id for future use.
-  uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
+  const uint64 step_id = MakeStepId();
   TRACEPRINTF("stepid %llu", step_id);
 
   rcg->Ref();
@@ -1585,6 +1693,73 @@ Status MasterSession::CreateDebuggerState(
   return Status::OK();
 }
 
+void MasterSession::FillPerStepState(MasterSession::ReffedClientGraph* rcg,
+                                     const RunOptions& run_options,
+                                     uint64 step_id, int64 count,
+                                     PerStepState* out_pss,
+                                     std::unique_ptr<ProfileHandler>* out_ph) {
+  out_pss->collect_timeline =
+      run_options.trace_level() == RunOptions::FULL_TRACE;
+  out_pss->collect_rpcs = run_options.trace_level() == RunOptions::FULL_TRACE;
+  out_pss->report_tensor_allocations_upon_oom =
+      run_options.report_tensor_allocations_upon_oom();
+  // Build the cost model every 'build_cost_model_every' steps after skipping an
+  // initial 'build_cost_model_after' steps.
+  const int64 build_cost_model_after =
+      session_opts_.config.graph_options().build_cost_model_after();
+  const int64 build_cost_model_every =
+      session_opts_.config.graph_options().build_cost_model();
+  out_pss->collect_costs =
+      build_cost_model_every > 0 &&
+      ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
+  out_pss->collect_partition_graphs = run_options.output_partition_graphs();
+
+  *out_ph = rcg->GetProfileHandler(step_id, count, run_options);
+  if (*out_ph) {
+    out_pss->collect_timeline = true;
+    out_pss->collect_rpcs = (*out_ph)->should_collect_rpcs();
+  }
+}
+
+Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
+                                     uint64 step_id,
+                                     const RunOptions& run_options,
+                                     PerStepState* pss,
+                                     const std::unique_ptr<ProfileHandler>& ph,
+                                     const Status& run_status,
+                                     RunMetadata* out_run_metadata) {
+  Status s = run_status;
+  if (s.ok()) {
+    pss->end_micros = Env::Default()->NowMicros();
+
+    // Schedule post-processing and cleanup to be done asynchronously.
+    rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
+  } else if (errors::IsCancelled(s)) {
+    mutex_lock l(mu_);
+    if (closed_) {
+      if (garbage_collected_) {
+        s = errors::Cancelled(
+            "Step was cancelled because the session was garbage collected due "
+            "to inactivity.");
+      } else {
+        s = errors::Cancelled(
+            "Step was cancelled by an explicit call to `Session::Close()`.");
+      }
+    }
+  }
+  Ref();
+  rcg->Ref();
+  rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
+    if (!s.ok()) {
+      LOG(ERROR) << "Cleanup partition error: " << s;
+    }
+    rcg->Unref();
+    MarkRunCompletion();
+    Unref();
+  });
+  return s;
+}
+
 Status MasterSession::DoRunWithLocalExecution(
     CallOptions* opts, const RunStepRequestWrapper& req,
     MutableRunStepResponseWrapper* resp) {
@@ -1597,8 +1772,8 @@ Status MasterSession::DoRunWithLocalExecution(
   BuildGraphOptions bgopts;
   BuildBuildGraphOptions(req, &bgopts);
   ReffedClientGraph* rcg = nullptr;
-  int64 count = 0;
-  TF_RETURN_IF_ERROR(StartStep(bgopts, &count, &rcg, false));
+  int64 count;
+  TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count));
 
   // Unref "rcg" when out of scope.
   core::ScopedUnref unref(rcg);
@@ -1614,64 +1789,133 @@ Status MasterSession::DoRunWithLocalExecution(
 
   // Keeps the highest 8 bits 0x01: we reserve some bits of the
   // step_id for future use.
-  const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
+  const uint64 step_id = MakeStepId();
   TRACEPRINTF("stepid %llu", step_id);
 
-  pss.collect_timeline = req.options().trace_level() == RunOptions::FULL_TRACE;
-  pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE;
-  pss.report_tensor_allocations_upon_oom =
-      req.options().report_tensor_allocations_upon_oom();
-  // Build the cost model every 'build_cost_model_every' steps after skipping an
-  // initial 'build_cost_model_after' steps.
-  const int64 build_cost_model_after =
-      session_opts_.config.graph_options().build_cost_model_after();
-  const int64 build_cost_model_every =
-      session_opts_.config.graph_options().build_cost_model();
-  pss.collect_costs =
-      build_cost_model_every > 0 &&
-      ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
-  pss.collect_partition_graphs = req.options().output_partition_graphs();
+  std::unique_ptr<ProfileHandler> ph;
+  FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph);
 
-  std::unique_ptr<ProfileHandler> ph =
-      rcg->GetProfileHandler(step_id, count, req.options());
-  if (ph) {
-    pss.collect_timeline = true;
-    pss.collect_rpcs = ph->should_collect_rpcs();
+  Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
+                                &cancellation_manager_, false);
+  cleanup.release();  // MarkRunCompletion called in PostRunCleanup().
+  return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s,
+                        resp->mutable_metadata());
+}
+
+Status MasterSession::MakeCallable(const MakeCallableRequest& req,
+                                   MakeCallableResponse* resp) {
+  UpdateLastAccessTime();
+
+  BuildGraphOptions opts;
+  opts.callable_options = req.options();
+  opts.use_function_convention = false;
+
+  ReffedClientGraph* callable;
+
+  {
+    mutex_lock l(mu_);
+    if (closed_) {
+      return errors::FailedPrecondition("Session is closed.");
+    }
+    std::unique_ptr<ClientGraph> client_graph;
+    TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
+    callable = new ReffedClientGraph(handle_, opts, std::move(client_graph),
+                                     session_opts_, stats_publisher_factory_,
+                                     false /* is_partial */, get_worker_cache(),
+                                     !should_delete_worker_sessions_);
+  }
+
+  Status s = BuildAndRegisterPartitions(callable);
+  if (!s.ok()) {
+    callable->Unref();
+    return s;
   }
 
+  uint64 handle;
+  {
+    mutex_lock l(mu_);
+    handle = next_callable_handle_++;
+    callables_[handle] = callable;
+  }
+
+  resp->set_handle(handle);
+  return Status::OK();
+}
+
+Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
+                                    const RunCallableRequest& req,
+                                    RunCallableResponse* resp) {
+  VLOG(2) << "DoRunCallable req: " << req.DebugString();
+  PerStepState pss;
+  pss.start_micros = Env::Default()->NowMicros();
+  auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
+
+  // Prepare.
+  int64 count = rcg->get_and_increment_execution_count();
+
+  // Keeps the highest 8 bits 0x01: we reserve some bits of the
+  // step_id for future use.
+  const uint64 step_id = MakeStepId();
+  TRACEPRINTF("stepid %llu", step_id);
+
+  const RunOptions& run_options = rcg->callable_options().run_options();
+
+  if (run_options.timeout_in_ms() != 0) {
+    opts->SetTimeout(run_options.timeout_in_ms());
+  }
+
+  std::unique_ptr<ProfileHandler> ph;
+  FillPerStepState(rcg, run_options, step_id, count, &pss, &ph);
   Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
-                                &cancellation_manager_, false);
-  if (s.ok()) {
-    pss.end_micros = Env::Default()->NowMicros();
+                                &cancellation_manager_);
+  cleanup.release();  // MarkRunCompletion called in PostRunCleanup().
+  return PostRunCleanup(rcg, step_id, run_options, &pss, ph, s,
+                        resp->mutable_metadata());
+}
 
-    // Schedule post-processing and cleanup to be done asynchronously.
-    rcg->ProcessStats(step_id, &pss, ph.get(), req.options(),
-                      resp->mutable_metadata());
-  } else if (errors::IsCancelled(s)) {
+Status MasterSession::RunCallable(CallOptions* opts,
+                                  const RunCallableRequest& req,
+                                  RunCallableResponse* resp) {
+  UpdateLastAccessTime();
+  ReffedClientGraph* callable;
+  {
     mutex_lock l(mu_);
     if (closed_) {
-      if (garbage_collected_) {
-        s = errors::Cancelled(
-            "Step was cancelled because the session was garbage collected due "
-            "to inactivity.");
-      } else {
-        s = errors::Cancelled(
-            "Step was cancelled by an explicit call to `Session::Close()`.");
-      }
+      return errors::FailedPrecondition("Session is closed.");
+    }
+    int64 handle = req.handle();
+    if (handle >= next_callable_handle_) {
+      return errors::InvalidArgument("No such callable handle: ", handle);
+    }
+    auto iter = callables_.find(req.handle());
+    if (iter == callables_.end()) {
+      return errors::InvalidArgument(
+          "Attempted to run callable after handle was released: ", handle);
     }
+    callable = iter->second;
+    callable->Ref();
+    ++num_running_;
   }
-  Ref();
-  rcg->Ref();
-  cleanup.release();  // MarkRunCompletion called in done closure.
-  rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
-    if (!s.ok()) {
-      LOG(ERROR) << "Cleanup partition error: " << s;
+  core::ScopedUnref unref_callable(callable);
+  return DoRunCallable(opts, callable, req, resp);
+}
+
+Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req,
+                                      ReleaseCallableResponse* resp) {
+  UpdateLastAccessTime();
+  ReffedClientGraph* to_unref = nullptr;
+  {
+    mutex_lock l(mu_);
+    auto iter = callables_.find(req.handle());
+    if (iter != callables_.end()) {
+      to_unref = iter->second;
+      callables_.erase(iter);
     }
-    rcg->Unref();
-    MarkRunCompletion();
-    Unref();
-  });
-  return s;
+  }
+  if (to_unref != nullptr) {
+    to_unref->Unref();
+  }
+  return Status::OK();
 }
 
 Status MasterSession::Close() {
@@ -1688,6 +1932,7 @@ Status MasterSession::Close() {
     }
     ClearRunsTable(&to_unref, &run_graphs_);
     ClearRunsTable(&to_unref, &partial_run_graphs_);
+    ClearRunsTable(&to_unref, &callables_);
   }
   for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
   if (should_delete_worker_sessions_) {
index 4bd4e13..a054199 100644 (file)
@@ -89,6 +89,15 @@ class MasterSession : public core::RefCounted {
 
   Status ListDevices(ListDevicesResponse* resp) const;
 
+  Status MakeCallable(const MakeCallableRequest& req,
+                      MakeCallableResponse* resp);
+
+  Status RunCallable(CallOptions* opts, const RunCallableRequest& req,
+                     RunCallableResponse* resp);
+
+  Status ReleaseCallable(const ReleaseCallableRequest& req,
+                         ReleaseCallableResponse* resp);
+
   // Close this session and delete "*this". Returns OK if all known
   // states are cleanup successfully.
   //
@@ -140,6 +149,8 @@ class MasterSession : public core::RefCounted {
   typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap;
   RCGMap run_graphs_ GUARDED_BY(mu_);
   RCGMap partial_run_graphs_ GUARDED_BY(mu_);
+  int64 next_callable_handle_ GUARDED_BY(mu_) = 0;
+  RCGMap callables_ GUARDED_BY(mu_);
 
   struct PerStepState {
     bool collect_costs = false;
@@ -205,15 +216,28 @@ class MasterSession : public core::RefCounted {
   bool should_delete_worker_sessions_ = false;
   Status DeleteWorkerSessions();
 
-  Status StartStep(const BuildGraphOptions& opts, int64* count,
-                   ReffedClientGraph** graph, bool is_partial);
+  Status StartStep(const BuildGraphOptions& opts, bool is_partial,
+                   ReffedClientGraph** out_rcg, int64* out_count);
   void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
                       RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+  void FillPerStepState(MasterSession::ReffedClientGraph* rcg,
+                        const RunOptions& run_options, uint64 step_id,
+                        int64 count, PerStepState* out_pss,
+                        std::unique_ptr<ProfileHandler>* out_ph);
   Status DoRunWithLocalExecution(CallOptions* opts,
                                  const RunStepRequestWrapper& req,
                                  MutableRunStepResponseWrapper* resp);
   Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
                       MutableRunStepResponseWrapper* resp);
+  Status DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
+                       const RunCallableRequest& req,
+                       RunCallableResponse* resp);
+  Status PostRunCleanup(MasterSession::ReffedClientGraph* rcg, uint64 step_id,
+                        const RunOptions& run_options, PerStepState* pss,
+                        const std::unique_ptr<ProfileHandler>& ph,
+                        const Status& run_status,
+                        RunMetadata* out_run_metadata);
+
   void MarkRunCompletion();
   void UpdateLastAccessTime();
 
index 66ebb30..18668b4 100644 (file)
@@ -326,6 +326,20 @@ Status InMemoryRunGraphRequest::AddSendFromRunStepRequest(
   return Status::OK();
 }
 
+// TODO(b/74355905): Add a specialized implementation that avoids
+// copying the tensor when at least two of the {client, master,
+// worker} are in the same process.
+Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest(
+    const RunCallableRequest& run_callable_request, size_t i,
+    const string& send_key) {
+  Tensor tensor;
+  if (!ParseTensorProtoToTensor(run_callable_request.feed(i), &tensor)) {
+    return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
+  }
+  sends_.emplace_back(send_key, std::move(tensor));
+  return Status::OK();
+}
+
 size_t InMemoryRunGraphRequest::num_recvs() const { return recvs_.size(); }
 
 const string& InMemoryRunGraphRequest::recv_key(size_t i) const {
@@ -439,6 +453,18 @@ Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest(
   return Status::OK();
 }
 
+// TODO(b/74355905): Add a specialized implementation that avoids
+// copying the tensor when at least two of the {client, master,
+// worker} are in the same process.
+Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest(
+    const RunCallableRequest& run_callable_request, size_t i,
+    const string& send_key) {
+  NamedTensorProto* send = request_.add_send();
+  send->set_name(send_key);
+  *send->mutable_tensor() = run_callable_request.feed(i);
+  return Status::OK();
+}
+
 size_t MutableProtoRunGraphRequest::num_recvs() const {
   return request_.recv_key_size();
 }
index 79fa6f9..1f7cdb9 100644 (file)
@@ -302,6 +302,9 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
   virtual Status AddSendFromRunStepRequest(
       const RunStepRequestWrapper& run_step_request, size_t i,
       const string& send_key) = 0;
+  virtual Status AddSendFromRunCallableRequest(
+      const RunCallableRequest& run_callable_request, size_t i,
+      const string& send_key) = 0;
 
   virtual void add_recv_key(const string& recv_key) = 0;
   virtual void set_is_partial(bool is_partial) = 0;
@@ -334,6 +337,9 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
   Status AddSendFromRunStepRequest(
       const RunStepRequestWrapper& run_step_request, size_t i,
       const string& send_key) override;
+  Status AddSendFromRunCallableRequest(
+      const RunCallableRequest& run_callable_request, size_t i,
+      const string& send_key) override;
   void add_recv_key(const string& recv_key) override;
   void set_is_partial(bool is_partial) override;
   void set_is_last_partial_run(bool is_last_partial_run) override;
@@ -385,6 +391,9 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
   Status AddSendFromRunStepRequest(
       const RunStepRequestWrapper& run_step_request, size_t i,
       const string& send_key) override;
+  Status AddSendFromRunCallableRequest(
+      const RunCallableRequest& run_callable_request, size_t i,
+      const string& send_key) override;
   void add_recv_key(const string& recv_key) override;
   void set_is_partial(bool is_partial) override;
   void set_is_last_partial_run(bool is_last_partial_run) override;
index 63745e8..23968e2 100644 (file)
@@ -111,6 +111,11 @@ class GrpcMasterService : public AsyncServiceInterface {
     ENQUEUE_REQUEST(CloseSession, false);
     ENQUEUE_REQUEST(ListDevices, false);
     ENQUEUE_REQUEST(Reset, false);
+    ENQUEUE_REQUEST(MakeCallable, false);
+    for (int i = 0; i < 100; ++i) {
+      ENQUEUE_REQUEST(RunCallable, true);
+    }
+    ENQUEUE_REQUEST(ReleaseCallable, false);
 
     void* tag;
     bool ok;
@@ -236,6 +241,47 @@ class GrpcMasterService : public AsyncServiceInterface {
                         });
     ENQUEUE_REQUEST(Reset, false);
   }
+
+  // RPC handler for making a callable.
+  void MakeCallableHandler(
+      MasterCall<MakeCallableRequest, MakeCallableResponse>* call) {
+    master_impl_->MakeCallable(&call->request, &call->response,
+                               [call](const Status& status) {
+                                 call->SendResponse(ToGrpcStatus(status));
+                               });
+    ENQUEUE_REQUEST(MakeCallable, false);
+  }
+
+  // RPC handler for running a callable.
+  void RunCallableHandler(
+      MasterCall<RunCallableRequest, RunCallableResponse>* call) {
+    auto* trace = TraceRpc("RunCallable/Server", call->client_metadata());
+    CallOptions* call_opts = new CallOptions;
+    // The timeout may be overridden by a non-zero timeout in the
+    // callable's `RunOptions`; this overriding will happen inside the
+    // `MasterSession` implementation.
+    call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms());
+    call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
+    master_impl_->RunCallable(call_opts, &call->request, &call->response,
+                              [call, call_opts, trace](const Status& status) {
+                                call->ClearCancelCallback();
+                                delete call_opts;
+                                delete trace;
+                                call->SendResponse(ToGrpcStatus(status));
+                              });
+    ENQUEUE_REQUEST(RunCallable, false);
+  }
+
+  // RPC handler for making a callable.
+  void ReleaseCallableHandler(
+      MasterCall<ReleaseCallableRequest, ReleaseCallableResponse>* call) {
+    master_impl_->ReleaseCallable(&call->request, &call->response,
+                                  [call](const Status& status) {
+                                    call->SendResponse(ToGrpcStatus(status));
+                                  });
+    ENQUEUE_REQUEST(ReleaseCallable, false);
+  }
+
 #undef ENQUEUE_REQUEST
 
   // Start tracing, including the ID attached to the RPC.
index e2016e8..c832adb 100644 (file)
@@ -36,6 +36,9 @@ static const char* grpcMasterService_method_names[] = {
     "/tensorflow.MasterService/CloseSession",
     "/tensorflow.MasterService/ListDevices",
     "/tensorflow.MasterService/Reset",
+    "/tensorflow.MasterService/MakeCallable",
+    "/tensorflow.MasterService/RunCallable",
+    "/tensorflow.MasterService/ReleaseCallable",
 };
 
 std::unique_ptr<MasterService::Stub> MasterService::NewStub(
@@ -64,7 +67,14 @@ MasterService::Stub::Stub(
       rpcmethod_ListDevices_(grpcMasterService_method_names[5],
                              ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
       rpcmethod_Reset_(grpcMasterService_method_names[6],
-                       ::grpc::internal::RpcMethod::NORMAL_RPC, channel) {}
+                       ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
+      rpcmethod_MakeCallable_(grpcMasterService_method_names[7],
+                              ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
+      rpcmethod_RunCallable_(grpcMasterService_method_names[8],
+                             ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
+      rpcmethod_ReleaseCallable_(grpcMasterService_method_names[9],
+                                 ::grpc::internal::RpcMethod::NORMAL_RPC,
+                                 channel) {}
 
 ::grpc::Status MasterService::Stub::CreateSession(
     ::grpc::ClientContext* context, const CreateSessionRequest& request,
@@ -115,8 +125,29 @@ MasterService::Stub::Stub(
                                              context, request, response);
 }
 
+::grpc::Status MasterService::Stub::MakeCallable(
+    ::grpc::ClientContext* context, const MakeCallableRequest& request,
+    MakeCallableResponse* response) {
+  return ::grpc::internal::BlockingUnaryCall(
+      channel_.get(), rpcmethod_MakeCallable_, context, request, response);
+}
+
+::grpc::Status MasterService::Stub::RunCallable(
+    ::grpc::ClientContext* context, const RunCallableRequest& request,
+    RunCallableResponse* response) {
+  return ::grpc::internal::BlockingUnaryCall(
+      channel_.get(), rpcmethod_RunCallable_, context, request, response);
+}
+
+::grpc::Status MasterService::Stub::ReleaseCallable(
+    ::grpc::ClientContext* context, const ReleaseCallableRequest& request,
+    ReleaseCallableResponse* response) {
+  return ::grpc::internal::BlockingUnaryCall(
+      channel_.get(), rpcmethod_ReleaseCallable_, context, request, response);
+}
+
 MasterService::AsyncService::AsyncService() {
-  for (int i = 0; i < 7; ++i) {
+  for (int i = 0; i < 10; ++i) {
     AddMethod(new ::grpc::internal::RpcServiceMethod(
         grpcMasterService_method_names[i],
         ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
index 6ae94b7..3c38273 100644 (file)
@@ -79,6 +79,15 @@ class MasterService final {
     virtual ::grpc::Status Reset(::grpc::ClientContext* context,
                                  const ResetRequest& request,
                                  ResetResponse* response) = 0;
+    virtual ::grpc::Status MakeCallable(::grpc::ClientContext* context,
+                                        const MakeCallableRequest& request,
+                                        MakeCallableResponse* response) = 0;
+    virtual ::grpc::Status RunCallable(::grpc::ClientContext* context,
+                                       const RunCallableRequest& request,
+                                       RunCallableResponse* response) = 0;
+    virtual ::grpc::Status ReleaseCallable(
+        ::grpc::ClientContext* context, const ReleaseCallableRequest& request,
+        ReleaseCallableResponse* response) = 0;
   };
   class Stub final : public StubInterface {
    public:
@@ -104,6 +113,15 @@ class MasterService final {
     ::grpc::Status Reset(::grpc::ClientContext* context,
                          const ResetRequest& request,
                          ResetResponse* response) override;
+    ::grpc::Status MakeCallable(::grpc::ClientContext* context,
+                                const MakeCallableRequest& request,
+                                MakeCallableResponse* response) override;
+    ::grpc::Status RunCallable(::grpc::ClientContext* context,
+                               const RunCallableRequest& request,
+                               RunCallableResponse* response) override;
+    ::grpc::Status ReleaseCallable(::grpc::ClientContext* context,
+                                   const ReleaseCallableRequest& request,
+                                   ReleaseCallableResponse* response) override;
 
    private:
     std::shared_ptr< ::grpc::ChannelInterface> channel_;
@@ -114,6 +132,9 @@ class MasterService final {
     const ::grpc::internal::RpcMethod rpcmethod_CloseSession_;
     const ::grpc::internal::RpcMethod rpcmethod_ListDevices_;
     const ::grpc::internal::RpcMethod rpcmethod_Reset_;
+    const ::grpc::internal::RpcMethod rpcmethod_MakeCallable_;
+    const ::grpc::internal::RpcMethod rpcmethod_RunCallable_;
+    const ::grpc::internal::RpcMethod rpcmethod_ReleaseCallable_;
   };
   static std::unique_ptr<Stub> NewStub(
       const std::shared_ptr< ::grpc::ChannelInterface>& channel,
@@ -179,6 +200,30 @@ class MasterService final {
       ::grpc::Service::RequestAsyncUnary(6, context, request, response,
                                          new_call_cq, notification_cq, tag);
     }
+    void RequestMakeCallable(
+        ::grpc::ServerContext* context, MakeCallableRequest* request,
+        ::grpc::ServerAsyncResponseWriter<MakeCallableResponse>* response,
+        ::grpc::CompletionQueue* new_call_cq,
+        ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+      ::grpc::Service::RequestAsyncUnary(7, context, request, response,
+                                         new_call_cq, notification_cq, tag);
+    }
+    void RequestRunCallable(
+        ::grpc::ServerContext* context, RunCallableRequest* request,
+        ::grpc::ServerAsyncResponseWriter<RunCallableResponse>* response,
+        ::grpc::CompletionQueue* new_call_cq,
+        ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+      ::grpc::Service::RequestAsyncUnary(8, context, request, response,
+                                         new_call_cq, notification_cq, tag);
+    }
+    void RequestReleaseCallable(
+        ::grpc::ServerContext* context, ReleaseCallableRequest* request,
+        ::grpc::ServerAsyncResponseWriter<ReleaseCallableResponse>* response,
+        ::grpc::CompletionQueue* new_call_cq,
+        ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+      ::grpc::Service::RequestAsyncUnary(9, context, request, response,
+                                         new_call_cq, notification_cq, tag);
+    }
   };
 };
 
index 1088e9b..1b92a79 100644 (file)
@@ -95,6 +95,28 @@ class GrpcRemoteMaster : public MasterInterface {
                 &MasterServiceStub::Reset);
   }
 
+  Status MakeCallable(CallOptions* call_options,
+                      const MakeCallableRequest* request,
+                      MakeCallableResponse* response) override {
+    ::grpc::ClientContext ctx;
+    return Call(&ctx, call_options, request, response,
+                &MasterServiceStub::MakeCallable);
+  }
+  Status RunCallable(CallOptions* call_options,
+                     const RunCallableRequest* request,
+                     RunCallableResponse* response) override {
+    ::grpc::ClientContext ctx;
+    return Call(&ctx, call_options, request, response,
+                &MasterServiceStub::RunCallable);
+  }
+  Status ReleaseCallable(CallOptions* call_options,
+                         const ReleaseCallableRequest* request,
+                         ReleaseCallableResponse* response) override {
+    ::grpc::ClientContext ctx;
+    return Call(&ctx, call_options, request, response,
+                &MasterServiceStub::ReleaseCallable);
+  }
+
  private:
   // Start tracing, attaching a unique ID to both the trace and the RPC.
   port::Tracing::TraceMe TraceRpc(StringPiece name,
index 3e79a40..fd1c150 100644 (file)
@@ -91,6 +91,15 @@ void ReEncodeConsts(GraphDef* gdef) {
 }
 }  // namespace
 
+Status GrpcSession::Handle(string* out_handle) {
+  mutex_lock l(mu_);
+  if (handle_.empty()) {
+    return errors::InvalidArgument("A session is not created yet....");
+  }
+  *out_handle = handle_;
+  return Status::OK();
+}
+
 Status GrpcSession::CreateImpl(CallOptions* call_options,
                                const GraphDef& graph) {
   {
@@ -274,14 +283,9 @@ Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
 Status GrpcSession::RunProto(CallOptions* call_options,
                              MutableRunStepRequestWrapper* req,
                              MutableRunStepResponseWrapper* resp) {
-  {
-    mutex_lock l(mu_);
-    if (handle_.empty()) {
-      return errors::InvalidArgument("A session is not created yet....");
-    }
-
-    req->set_session_handle(handle_);
-  }
+  string handle;
+  TF_RETURN_IF_ERROR(Handle(&handle));
+  req->set_session_handle(handle);
   return master_->RunStep(call_options, req, resp);
 }
 
@@ -293,14 +297,7 @@ Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
   PartialRunSetupRequest req;
   PartialRunSetupResponse resp;
   CallOptions call_options;
-  {
-    mutex_lock l(mu_);
-    if (handle_.empty()) {
-      return errors::InvalidArgument("A session is not created yet....");
-    }
-
-    req.set_session_handle(handle_);
-  }
+  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
   for (const string& feed : input_names) {
     req.add_feed(feed);
   }
@@ -400,6 +397,55 @@ Status GrpcSession::Reset(const SessionOptions& options,
   return ret;
 }
 
+Status GrpcSession::MakeCallable(const CallableOptions& callable_options,
+                                 CallableHandle* out_handle) {
+  MakeCallableRequest req;
+  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
+  *req.mutable_options() = callable_options;
+  MakeCallableResponse resp;
+  CallOptions call_options;
+  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
+  TF_RETURN_IF_ERROR(master_->MakeCallable(&call_options, &req, &resp));
+  *out_handle = resp.handle();
+  return Status::OK();
+}
+
+Status GrpcSession::RunCallable(CallableHandle handle,
+                                const std::vector<Tensor>& feed_tensors,
+                                std::vector<Tensor>* fetch_tensors,
+                                RunMetadata* run_metadata) {
+  RunCallableRequest req;
+  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
+  req.set_handle(handle);
+  for (const Tensor& feed : feed_tensors) {
+    feed.AsProtoTensorContent(req.mutable_feed()->Add());
+  }
+
+  RunCallableResponse resp;
+  CallOptions call_options;
+  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
+  TF_RETURN_IF_ERROR(master_->RunCallable(&call_options, &req, &resp));
+  for (const TensorProto& fetch : resp.fetch()) {
+    Tensor fetch_tensor;
+    if (!fetch_tensor.FromProto(cpu_allocator(), fetch)) {
+      return errors::Internal(
+          "Could not parse fetched tensor data in response from master.");
+    }
+    fetch_tensors->push_back(std::move(fetch_tensor));
+  }
+  return Status::OK();
+}
+
+Status GrpcSession::ReleaseCallable(CallableHandle handle) {
+  ReleaseCallableRequest req;
+  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
+  req.set_handle(handle);
+  ReleaseCallableResponse resp;
+  CallOptions call_options;
+  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
+  return master_->ReleaseCallable(&call_options, &req, &resp);
+}
+
 class GrpcSessionFactory : public SessionFactory {
  public:
   bool AcceptsOptions(const SessionOptions& options) override {
index d87956a..6379511 100644 (file)
@@ -82,20 +82,27 @@ class GrpcSession : public Session {
   Status Close() override;
 
   // NOTE: This API is still experimental and may change.
-  ::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
-                                 const std::vector<string>& output_names,
-                                 const std::vector<string>& target_nodes,
-                                 string* handle) override;
+  Status PRunSetup(const std::vector<string>& input_names,
+                   const std::vector<string>& output_names,
+                   const std::vector<string>& target_nodes,
+                   string* handle) override;
 
   // NOTE: This API is still experimental and may change.
-  ::tensorflow::Status PRun(
-      const string& handle,
-      const std::vector<std::pair<string, Tensor> >& inputs,
-      const std::vector<string>& output_names,
-      std::vector<Tensor>* outputs) override;
+  Status PRun(const string& handle,
+              const std::vector<std::pair<string, Tensor> >& inputs,
+              const std::vector<string>& output_names,
+              std::vector<Tensor>* outputs) override;
 
   Status ListDevices(std::vector<DeviceAttributes>* response) override;
 
+  Status MakeCallable(const CallableOptions& callable_options,
+                      CallableHandle* out_handle) override;
+  Status RunCallable(CallableHandle handle,
+                     const std::vector<Tensor>& feed_tensors,
+                     std::vector<Tensor>* fetch_tensors,
+                     RunMetadata* run_metadata) override;
+  Status ReleaseCallable(CallableHandle handle) override;
+
  protected:
   // Takes ownership of `*master`.
   void SetRemoteMaster(std::unique_ptr<MasterInterface> master);
@@ -111,6 +118,8 @@ class GrpcSession : public Session {
   // The current version of the graph.
   int64 current_graph_version_ GUARDED_BY(mu_);
 
+  Status Handle(string* out_handle) LOCKS_EXCLUDED(mu_);
+
   Status RunHelper(const RunOptions& run_options,
                    const std::vector<std::pair<string, Tensor> >& inputs,
                    const std::vector<string>& output_tensor_names,
index 335c3fe..45b15a5 100644 (file)
@@ -120,6 +120,49 @@ TEST(GrpcSessionTest, BasicNonProtoAPI) {
   }
 }
 
+TEST(GrpcSessionTest, BasicCallable) {
+  GraphDef graph;
+  string node_names[3];
+  // c = a * b
+  CreateGraphDef(&graph, node_names);
+
+  std::unique_ptr<test::TestCluster> cluster;
+  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+
+  std::unique_ptr<Session> session(
+      NewRemote(Options(cluster->targets()[0], 1)));
+  ASSERT_TRUE(session != nullptr);
+
+  for (int iters = 0; iters < 25; ++iters) {
+    TF_CHECK_OK(session->Create(graph));
+    {
+      // Just run to target node
+      CallableOptions opts;
+      opts.add_target(node_names[2]);
+      Session::CallableHandle handle;
+      TF_CHECK_OK(session->MakeCallable(opts, &handle));
+      TF_CHECK_OK(session->RunCallable(handle, {}, nullptr, nullptr));
+      TF_CHECK_OK(session->ReleaseCallable(handle));
+    }
+    {
+      // Run to a target node and a real tensor
+      CallableOptions opts;
+      opts.add_target(node_names[1]);
+      opts.add_fetch(node_names[2] + ":0");
+      Session::CallableHandle handle;
+      TF_CHECK_OK(session->MakeCallable(opts, &handle));
+      std::vector<Tensor> outputs;
+      TF_CHECK_OK(session->RunCallable(handle, {}, &outputs, nullptr));
+      ASSERT_EQ(1, outputs.size());
+      ASSERT_TRUE(outputs[0].IsInitialized());
+      ASSERT_EQ(4.0, outputs[0].flat<float>()(0));
+      TF_CHECK_OK(session->ReleaseCallable(handle));
+    }
+
+    TF_CHECK_OK(session->Close());
+  }
+}
+
 TEST(GrpcSessionTest, BasicNonProtoAPIConsistentOrder) {
   GraphDef graph;
   string node_names[3];
index 0437cb1..96c9153 100644 (file)
@@ -23,6 +23,7 @@ option java_package = "org.tensorflow.distruntime";
 
 import "tensorflow/core/framework/device_attributes.proto";
 import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/framework/tensor.proto";
 import "tensorflow/core/lib/core/error_codes.proto";
 import "tensorflow/core/protobuf/config.proto";
 import "tensorflow/core/protobuf/named_tensor.proto";
@@ -264,3 +265,70 @@ message ListDevicesResponse {
   repeated DeviceAttributes local_device = 1;
   repeated DeviceAttributes remote_device = 2;
 }
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// MakeCallable method request/response protos.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message MakeCallableRequest {
+  // REQUIRED: session_handle must be returned by a CreateSession call
+  // to the same master service.
+  string session_handle = 1;
+
+  // Options that define the behavior of the created callable.
+  CallableOptions options = 2;
+}
+
+message MakeCallableResponse {
+  // A handle to the created callable.
+  int64 handle = 1;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// RunCallable method request/response protos.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message RunCallableRequest {
+  // REQUIRED: session_handle must be returned by a CreateSession call
+  // to the same master service.
+  string session_handle = 1;
+  // REQUIRED: handle must be returned by a MakeCallable call to the same
+  // master service.
+  int64 handle = 2;
+
+  // Values of the tensors passed as arguments to the callable, in the order
+  // defined in the CallableOptions.feed field passed to MakeCallable.
+  repeated TensorProto feed = 3;
+}
+
+message RunCallableResponse {
+  // Values of the tensors returned by the callable, in the order defined in the
+  // CallableOptions.fetch field passed to MakeCallable.
+  repeated TensorProto fetch = 1;
+
+  // Returned metadata if requested in the options.
+  RunMetadata metadata = 2;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// ReleaseCallable method request/response protos.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message ReleaseCallableRequest {
+  // REQUIRED: session_handle must be returned by a CreateSession call
+  // to the same master service.
+  string session_handle = 1;
+
+  // REQUIRED: handle must be returned by a MakeCallable call to the same
+  // master service.
+  int64 handle = 2;
+}
+
+message ReleaseCallableResponse {
+}
index 771c805..1170611 100644 (file)
@@ -107,4 +107,13 @@ service MasterService {
   // will no longer affect fresh ones via the resources in containers listed in
   // the ResetRequest.  See ResetRequest for more details.
   rpc Reset(ResetRequest) returns (ResetResponse);
+
+  // Registers a callable for execution with RunCallable.
+  rpc MakeCallable(MakeCallableRequest) returns (MakeCallableResponse);
+
+  // Executes a callable registered with MakeCallable.
+  rpc RunCallable(RunCallableRequest) returns (RunCallableResponse);
+
+  // Frees resources associated with a callable registered with MakeCallable.
+  rpc ReleaseCallable(ReleaseCallableRequest) returns (ReleaseCallableResponse);
 }