# `checkpoint_dir` is the same as the model checkpoint directory, there are
# no conflicts during restore.
self._latest_filename = "checkpoint_" + checkpoint_prefix
+ self._first_run = True
def begin(self):
# Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
# pylint: enable=protected-access
self._checkpoint_saver_hook.begin()
- def after_create_session(self, session, coord):
+ def _restore_or_save_initial_ckpt(self, session):
+ # Ideally this should be run in after_create_session but is not for the
+ # following reason:
+ # Currently there is no way of enforcing an order of running the
+ # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook`
+ # is run *after* this hook. That is troublesome because
+ # 1. If a checkpoint exists and this hook restores it, the initializer hook
+ # will override it.
+ # 2. If no checkpoint exists, this hook will try to save an uninitialized
+ # iterator which will result in an exception.
+ #
+ # As a temporary fix we enter the following implicit contract between this
+ # hook and the _DatasetInitializerHook.
+ # 1. The _DatasetInitializerHook initializes the iterator in the call to
+ # after_create_session.
+ # 2. This hook restores or saves the iterator on the first call to
+ # `before_run()`, which is guaranteed to happen after the
+ # `after_create_session()` of every hook has been run.
+
# Check if there is an existing checkpoint. If so, restore from it.
# pylint: disable=protected-access
latest_checkpoint_path = saver_lib.latest_checkpoint(
# pylint: enable=protected-access
def before_run(self, run_context):
+ if self._first_run:
+ self._restore_or_save_initial_ckpt(run_context.session)
+ self._first_run = False
return self._checkpoint_saver_hook.before_run(run_context)
def after_run(self, run_context, run_values):