Ensure that saving/restoring iterator in CheckpointInputPipelineHook is performed...
authorSaurabh Saxena <srbs@google.com>
Mon, 21 May 2018 23:43:53 +0000 (16:43 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 21 May 2018 23:45:55 +0000 (16:45 -0700)
In the TPUEstimator the _DatasetInitializerHook is present in the
EstimatorSpec.training_hooks. Since these are executed after the `hooks`
passed to Estimator.train the input pipeline checkpointing hook fails
since it finds an uninitialized iterator.

PiperOrigin-RevId: 197482609

tensorflow/contrib/data/python/ops/iterator_ops.py

index f1d0e5c..0d71be6 100644 (file)
@@ -170,6 +170,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
     # `checkpoint_dir` is the same as the model checkpoint directory, there are
     # no conflicts during restore.
     self._latest_filename = "checkpoint_" + checkpoint_prefix
+    self._first_run = True
 
   def begin(self):
     # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
@@ -184,7 +185,25 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
     # pylint: enable=protected-access
     self._checkpoint_saver_hook.begin()
 
-  def after_create_session(self, session, coord):
+  def _restore_or_save_initial_ckpt(self, session):
+    # Ideally this should be run in after_create_session but is not for the
+    # following reason:
+    # Currently there is no way of enforcing an order of running the
+    # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook`
+    # is run *after* this hook. That is troublesome because
+    # 1. If a checkpoint exists and this hook restores it, the initializer hook
+    #    will override it.
+    # 2. If no checkpoint exists, this hook will try to save an initialized
+    #    iterator which will result in an exception.
+    #
+    # As a temporary fix we enter the following implicit contract between this
+    # hook and the _DatasetInitializerHook.
+    # 1. The _DatasetInitializerHook initializes the iterator in the call to
+    #    after_create_session.
+    # 2. This hook saves the iterator on the first call to `before_run()`, which
+    #    is guaranteed to happen after `after_create_session()` of all hooks
+    #    have been run.
+
     # Check if there is an existing checkpoint. If so, restore from it.
     # pylint: disable=protected-access
     latest_checkpoint_path = saver_lib.latest_checkpoint(
@@ -202,6 +221,9 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
     # pylint: enable=protected-access
 
   def before_run(self, run_context):
+    if self._first_run:
+      self._restore_or_save_initial_ckpt(run_context.session)
+      self._first_run = False
     return self._checkpoint_saver_hook.before_run(run_context)
 
   def after_run(self, run_context, run_values):