From 0ad7f20ed6876809a2b804365293a5c21dbcd374 Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Mon, 21 May 2018 16:43:53 -0700 Subject: [PATCH] Ensure that saving/restoring iterator in CheckpointInputPipelineHook is performed *after* the _DatasetInitializerHook has been run. In the TPUEstimator the _DatasetInitializerHook is present in the EstimatorSpec.training_hooks. Since these are executed after the `hooks` passed to Estimator.train the input pipeline checkpointing hook fails since it finds an uninitialized iterator. PiperOrigin-RevId: 197482609 --- tensorflow/contrib/data/python/ops/iterator_ops.py | 24 +++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index f1d0e5c..0d71be6 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -170,6 +170,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): # `checkpoint_dir` is the same as the model checkpoint directory, there are # no conflicts during restore. self._latest_filename = "checkpoint_" + checkpoint_prefix + self._first_run = True def begin(self): # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS` @@ -184,7 +185,25 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): # pylint: enable=protected-access self._checkpoint_saver_hook.begin() - def after_create_session(self, session, coord): + def _restore_or_save_initial_ckpt(self, session): + # Ideally this should be run in after_create_session but is not for the + # following reason: + # Currently there is no way of enforcing an order of running the + # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook` + # is run *after* this hook. That is troublesome because + # 1. If a checkpoint exists and this hook restores it, the initializer hook + # will override it. + # 2. 
If no checkpoint exists, this hook will try to save an uninitialized + # iterator which will result in an exception. + # + # As a temporary fix we enter the following implicit contract between this + # hook and the _DatasetInitializerHook. + # 1. The _DatasetInitializerHook initializes the iterator in the call to + # after_create_session. + # 2. This hook saves the iterator on the first call to `before_run()`, which + # is guaranteed to happen after `after_create_session()` of all hooks + # have been run. + # Check if there is an existing checkpoint. If so, restore from it. # pylint: disable=protected-access latest_checkpoint_path = saver_lib.latest_checkpoint( @@ -202,6 +221,9 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): # pylint: enable=protected-access def before_run(self, run_context): + if self._first_run: + self._restore_or_save_initial_ckpt(run_context.session) + self._first_run = False return self._checkpoint_saver_hook.before_run(run_context) def after_run(self, run_context, run_values): -- 2.7.4