Support configurable stats publishers in the grpc server.
authorBrennan Saeta <saeta@google.com>
Tue, 27 Feb 2018 01:56:15 +0000 (17:56 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
PiperOrigin-RevId: 187110497

tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h

index c4ac92d..a6f4be3 100644 (file)
@@ -106,7 +106,8 @@ GrpcServer::~GrpcServer() {
 Status GrpcServer::Init(
     ServiceInitFunction service_func,
     const RendezvousMgrCreationFunction& rendezvous_mgr_func,
-    const WorkerCreationFunction& worker_func) {
+    const WorkerCreationFunction& worker_func,
+    const StatsPublisherFactory& stats_factory) {
   mutex_lock l(mu_);
   CHECK_EQ(state_, NEW);
   master_env_.env = env_;
@@ -218,7 +219,7 @@ Status GrpcServer::Init(
   master_env_.ops = OpRegistry::Global();
   master_env_.worker_cache = worker_cache;
   master_env_.master_session_factory =
-      [config](
+      [config, stats_factory](
           SessionOptions options, const MasterEnv* env,
           std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
           std::unique_ptr<WorkerCacheInterface> worker_cache,
@@ -226,7 +227,7 @@ Status GrpcServer::Init(
         options.config.MergeFrom(config);
         return new MasterSession(options, env, std::move(remote_devs),
                                  std::move(worker_cache), std::move(device_set),
-                                 CreateNoOpStatsPublisher);
+                                 stats_factory);
       };
   master_env_.worker_cache_factory =
       [this](const WorkerCacheFactoryOptions& options,
@@ -243,6 +244,14 @@ Status GrpcServer::Init(
 
 Status GrpcServer::Init(
     ServiceInitFunction service_func,
+    const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+    const WorkerCreationFunction& worker_func) {
+  return Init(std::move(service_func), rendezvous_mgr_func, worker_func,
+              CreateNoOpStatsPublisher);
+}
+
+Status GrpcServer::Init(
+    ServiceInitFunction service_func,
     const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
   return Init(service_func, rendezvous_mgr_func, nullptr);
 }
index 8b12ac1..7c2f06f 100644 (file)
@@ -22,6 +22,7 @@ limitations under the License.
 #include "grpc++/security/credentials.h"
 
 #include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
 #include "tensorflow/core/distributed_runtime/master_env.h"
 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
@@ -70,6 +71,11 @@ class GrpcServer : public ServerInterface {
  protected:
   Status Init(ServiceInitFunction service_func,
               const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+              const WorkerCreationFunction& worker_func,
+              const StatsPublisherFactory& stats_factory);
+
+  Status Init(ServiceInitFunction service_func,
+              const RendezvousMgrCreationFunction& rendezvous_mgr_func,
               const WorkerCreationFunction& worker_func);
 
   Status Init(ServiceInitFunction service_func,