def after_save(self, session, global_step_value):
print('Done writing checkpoint.')
+ if decided_to_stop_training():
+ return True
def end(self, session, global_step_value):
print('Done with the session.')
implementors should implement the `end()` method to handle actions related to
the last checkpoint save. But the listener should not act twice if
`after_save()` already handled this last checkpoint save.
+
+ A `CheckpointSaverListener` can request that training be stopped by returning
+ `True` from `after_save`. Note that in a replicated distributed training
+ setting, only the `chief` worker should use this behavior; otherwise each
+ worker will do its own evaluation, which may waste resources.
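+
+ For example, a minimal sketch of wiring this up on the chief (here
+ `metric_has_plateaued()` stands for a user-supplied stopping criterion and
+ `train_op` for the usual training op):
+
+ class EarlyStoppingListener(CheckpointSaverListener):
+   def after_save(self, session, global_step_value):
+     # Returning True makes the CheckpointSaverHook call run_context.request_stop().
+     if metric_has_plateaued():
+       return True
+
+ saver_hook = tf.train.CheckpointSaverHook(
+     checkpoint_dir, save_steps=1000, listeners=[EarlyStoppingListener()])
+ with tf.train.MonitoredTrainingSession(
+     chief_only_hooks=[saver_hook]) as sess:
+   while not sess.should_stop():
+     sess.run(train_op)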
"""
def begin(self):
global_step = run_context.session.run(self._global_step_tensor)
if self._timer.should_trigger_for_step(global_step):
self._timer.update_last_triggered_step(global_step)
- self._save(run_context.session, global_step)
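+ # If any listener requested a stop in after_save, end training via the run context.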
+ if self._save(run_context.session, global_step):
+ run_context.request_stop()
def end(self, session):
last_step = session.run(self._global_step_tensor)
l.end(session, last_step)
def _save(self, session, step):
- """Saves the latest checkpoint."""
+ """Saves the latest checkpoint and returns whether training should stop."""
logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
for l in self._listeners:
status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
step)
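+ # Give every listener a chance to act; stop if any after_save() returns True.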
+ should_stop = False
for l in self._listeners:
- l.after_save(session, step)
+ if l.after_save(session, step):
+ logging.info(
+ "A CheckpointSaverListener requested that training be stopped. "
+ "listener: {}".format(l))
+ should_stop = True
+ return should_stop
def _get_saver(self):
if self._saver is not None:
self.before_save_count = 0
self.after_save_count = 0
self.end_count = 0
+ self.ask_for_stop = False
def begin(self):
self.begin_count += 1
def after_save(self, session, global_step):
self.after_save_count += 1
+ if self.ask_for_stop:
+ return True
def end(self, session, global_step):
self.end_count += 1
'end': 1
}, listener_counts)
+ def test_listener_stops_training_in_after_save(self):
+ with ops.Graph().as_default():
+ scaffold = monitored_session.Scaffold()
+ variables.get_or_create_global_step()
+ train_op = training_util._increment_global_step(1)
+ listener = MockCheckpointSaverListener()
+ hook = basic_session_run_hooks.CheckpointSaverHook(
+ self.model_dir, save_steps=1, scaffold=scaffold, listeners=[listener])
+ with monitored_session.SingularMonitoredSession(
+ hooks=[hook], scaffold=scaffold,
+ checkpoint_dir=self.model_dir) as sess:
+ sess.run(train_op)
+ self.assertFalse(sess.should_stop())
+ sess.run(train_op)
+ self.assertFalse(sess.should_stop())
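+ # From now on the listener returns True in after_save, so the next save should request a stop.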
+ listener.ask_for_stop = True
+ sess.run(train_op)
+ self.assertTrue(sess.should_stop())
+
def test_listener_with_default_saver(self):
with ops.Graph().as_default():
global_step = variables.get_or_create_global_step()