Add local_init_run_options to SessionManager and Supervisor so that
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 24 May 2018 22:27:00 +0000 (15:27 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 24 May 2018 22:29:42 +0000 (15:29 -0700)
collective_graph_key can be passed in when collective ops are used
in variable initialization.

PiperOrigin-RevId: 197964316

tensorflow/python/training/session_manager.py
tensorflow/python/training/supervisor.py
tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt
tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt

index 3cb3877..974f757 100644 (file)
@@ -95,7 +95,8 @@ class SessionManager(object):
                ready_op=None,
                ready_for_local_init_op=None,
                graph=None,
-               recovery_wait_secs=30):
+               recovery_wait_secs=30,
+               local_init_run_options=None):
     """Creates a SessionManager.
 
     The `local_init_op` is an `Operation` that is run always after a new session
@@ -127,6 +128,8 @@ class SessionManager(object):
          to run local_init_op.
       graph: The `Graph` that the model will use.
       recovery_wait_secs: Seconds between checks for the model to be ready.
+      local_init_run_options: RunOptions to be passed to session.run when
+        executing the local_init_op.
 
     Raises:
       ValueError: If ready_for_local_init_op is not None but local_init_op is
@@ -141,6 +144,7 @@ class SessionManager(object):
     self._graph = graph
     self._recovery_wait_secs = recovery_wait_secs
     self._target = None
+    self._local_init_run_options = local_init_run_options
     if ready_for_local_init_op is not None and local_init_op is None:
       raise ValueError("If you pass a ready_for_local_init_op "
                        "you must also pass a local_init_op "
@@ -485,7 +489,7 @@ class SessionManager(object):
       is_ready_for_local_init, msg = self._model_ready_for_local_init(sess)
       if is_ready_for_local_init:
         logging.info("Running local_init_op.")
-        sess.run(self._local_init_op)
+        sess.run(self._local_init_op, options=self._local_init_run_options)
         logging.info("Done running local_init_op.")
         return True, None
       else:
index 7389e34..372ea41 100644 (file)
@@ -225,7 +225,8 @@ class Supervisor(object):
                checkpoint_basename="model.ckpt",
                session_manager=None,
                summary_writer=USE_DEFAULT,
-               init_fn=None):
+               init_fn=None,
+               local_init_run_options=None):
     """Create a `Supervisor`.
 
     Args:
@@ -294,6 +295,8 @@ class Supervisor(object):
       init_fn: Optional callable used to initialize the model. Called
         after the optional `init_op` is called.  The callable must accept one
         argument, the session being initialized.
+      local_init_run_options: RunOptions to be passed as the SessionManager
+        local_init_run_options parameter.
 
     Returns:
       A `Supervisor`.
@@ -327,6 +330,7 @@ class Supervisor(object):
     self._recovery_wait_secs = recovery_wait_secs
     self._stop_grace_secs = stop_grace_secs
     self._init_fn = init_fn
+    self._local_init_run_options = local_init_run_options
 
     # Set all attributes related to checkpointing and writing events to None.
     # Afterwards, set them appropriately for chief supervisors, as these are
@@ -362,7 +366,8 @@ class Supervisor(object):
           ready_op=self._ready_op,
           ready_for_local_init_op=self._ready_for_local_init_op,
           graph=self._graph,
-          recovery_wait_secs=self._recovery_wait_secs)
+          recovery_wait_secs=self._recovery_wait_secs,
+          local_init_run_options=self._local_init_run_options)
     else:
       self._session_manager = session_manager
 
index cc31bb4..448764f 100644 (file)
@@ -4,7 +4,7 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\'], "
+    argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\', \'local_init_run_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\', \'None\'], "
   }
   member_method {
     name: "prepare_session"
index 1f0e59a..9677e5a 100644 (file)
@@ -104,7 +104,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'graph\', \'ready_op\', \'ready_for_local_init_op\', \'is_chief\', \'init_op\', \'init_feed_dict\', \'local_init_op\', \'logdir\', \'summary_op\', \'saver\', \'global_step\', \'save_summaries_secs\', \'save_model_secs\', \'recovery_wait_secs\', \'stop_grace_secs\', \'checkpoint_basename\', \'session_manager\', \'summary_writer\', \'init_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'0\', \'True\', \'0\', \'None\', \'0\', \'None\', \'0\', \'0\', \'0\', \'120\', \'600\', \'30\', \'120\', \'model.ckpt\', \'None\', \'0\', \'None\'], "
+    argspec: "args=[\'self\', \'graph\', \'ready_op\', \'ready_for_local_init_op\', \'is_chief\', \'init_op\', \'init_feed_dict\', \'local_init_op\', \'logdir\', \'summary_op\', \'saver\', \'global_step\', \'save_summaries_secs\', \'save_model_secs\', \'recovery_wait_secs\', \'stop_grace_secs\', \'checkpoint_basename\', \'session_manager\', \'summary_writer\', \'init_fn\', \'local_init_run_options\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'0\', \'True\', \'0\', \'None\', \'0\', \'None\', \'0\', \'0\', \'0\', \'120\', \'600\', \'30\', \'120\', \'model.ckpt\', \'None\', \'0\', \'None\', \'None\'], "
   }
   member_method {
     name: "loop"