From e6eb02f9babe07ab47852f9defe7f6d512164473 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 24 May 2018 15:27:00 -0700 Subject: [PATCH] Add local_init_run_options to SessionManager and Supervisor so that collective_graph_key can be passed in when collective ops are used in variable initialization. PiperOrigin-RevId: 197964316 --- tensorflow/python/training/session_manager.py | 8 ++++++-- tensorflow/python/training/supervisor.py | 9 +++++++-- .../tools/api/golden/tensorflow.train.-session-manager.pbtxt | 2 +- tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt | 2 +- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py index 3cb3877..974f757 100644 --- a/tensorflow/python/training/session_manager.py +++ b/tensorflow/python/training/session_manager.py @@ -95,7 +95,8 @@ class SessionManager(object): ready_op=None, ready_for_local_init_op=None, graph=None, - recovery_wait_secs=30): + recovery_wait_secs=30, + local_init_run_options=None): """Creates a SessionManager. The `local_init_op` is an `Operation` that is run always after a new session @@ -127,6 +128,8 @@ class SessionManager(object): to run local_init_op. graph: The `Graph` that the model will use. recovery_wait_secs: Seconds between checks for the model to be ready. + local_init_run_options: RunOptions to be passed to session.run when + executing the local_init_op. Raises: ValueError: If ready_for_local_init_op is not None but local_init_op is @@ -141,6 +144,7 @@ class SessionManager(object): self._graph = graph self._recovery_wait_secs = recovery_wait_secs self._target = None + self._local_init_run_options = local_init_run_options if ready_for_local_init_op is not None and local_init_op is None: raise ValueError("If you pass a ready_for_local_init_op " "you must also pass a local_init_op " @@ -485,7 +489,7 @@ class SessionManager(object): is_ready_for_local_init, msg = self._model_ready_for_local_init(sess) if is_ready_for_local_init: logging.info("Running local_init_op.") - sess.run(self._local_init_op) + sess.run(self._local_init_op, options=self._local_init_run_options) logging.info("Done running local_init_op.") return True, None else: diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py index 7389e34..372ea41 100644 --- a/tensorflow/python/training/supervisor.py +++ b/tensorflow/python/training/supervisor.py @@ -225,7 +225,8 @@ class Supervisor(object): checkpoint_basename="model.ckpt", session_manager=None, summary_writer=USE_DEFAULT, - init_fn=None): + init_fn=None, + local_init_run_options=None): """Create a `Supervisor`. Args: @@ -294,6 +295,8 @@ class Supervisor(object): init_fn: Optional callable used to initialize the model. Called after the optional `init_op` is called. The callable must accept one argument, the session being initialized. + local_init_run_options: RunOptions to be passed as the SessionManager + local_init_run_options parameter. Returns: A `Supervisor`. @@ -327,6 +330,7 @@ class Supervisor(object): self._recovery_wait_secs = recovery_wait_secs self._stop_grace_secs = stop_grace_secs self._init_fn = init_fn + self._local_init_run_options = local_init_run_options # Set all attributes related to checkpointing and writing events to None. # Afterwards, set them appropriately for chief supervisors, as these are @@ -362,7 +366,8 @@ class Supervisor(object): ready_op=self._ready_op, ready_for_local_init_op=self._ready_for_local_init_op, graph=self._graph, - recovery_wait_secs=self._recovery_wait_secs) + recovery_wait_secs=self._recovery_wait_secs, + local_init_run_options=self._local_init_run_options) else: self._session_manager = session_manager diff --git a/tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt index cc31bb4..448764f 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-session-manager.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\'], " + argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\', \'local_init_run_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\', \'None\'], " } member_method { name: "prepare_session" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt index 1f0e59a..9677e5a 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-supervisor.pbtxt @@ -104,7 +104,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'graph\', \'ready_op\', \'ready_for_local_init_op\', \'is_chief\', \'init_op\', \'init_feed_dict\', \'local_init_op\', \'logdir\', \'summary_op\', \'saver\', \'global_step\', \'save_summaries_secs\', \'save_model_secs\', \'recovery_wait_secs\', \'stop_grace_secs\', \'checkpoint_basename\', \'session_manager\', \'summary_writer\', \'init_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'0\', \'True\', \'0\', \'None\', \'0\', \'None\', \'0\', \'0\', \'0\', \'120\', \'600\', \'30\', \'120\', \'model.ckpt\', \'None\', \'0\', \'None\'], " + argspec: "args=[\'self\', \'graph\', \'ready_op\', \'ready_for_local_init_op\', \'is_chief\', \'init_op\', \'init_feed_dict\', \'local_init_op\', \'logdir\', \'summary_op\', \'saver\', \'global_step\', \'save_summaries_secs\', \'save_model_secs\', \'recovery_wait_secs\', \'stop_grace_secs\', \'checkpoint_basename\', \'session_manager\', \'summary_writer\', \'init_fn\', \'local_init_run_options\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'0\', \'True\', \'0\', \'None\', \'0\', \'None\', \'0\', \'0\', \'0\', \'120\', \'600\', \'30\', \'120\', \'model.ckpt\', \'None\', \'0\', \'None\', \'None\'], " } member_method { name: "loop" -- 2.7.4