Add stop-request capability to CheckpointSaverListener. An example usage:
authorMustafa Ispir <ispir@google.com>
Tue, 22 May 2018 17:42:31 +0000 (10:42 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 22 May 2018 17:45:14 +0000 (10:45 -0700)
my_estimator = tf.estimator.DNNClassifier(...)
stopper = StopTrainingBasedOnEvaluateMetrics(my_estimator)
my_estimator.train(..., saving_listeners=[stopper])

where:

class StopTrainingBasedOnEvaluateMetrics(tf.train.CheckpointSaverListener):
  """A saver listener to run evaluate with every checkpoint."""
  def __init__(self, estimator):
    self._estimator = estimator

  def after_save(self, session, global_step_value):
    eval_results = self._estimator.evaluate(...)
    if stop_if_started_overfitting(eval_results):
      return True

PiperOrigin-RevId: 197586515

tensorflow/python/estimator/estimator_test.py
tensorflow/python/training/basic_session_run_hooks.py
tensorflow/python/training/basic_session_run_hooks_test.py

index 1b70189..a9f20f7 100644 (file)
@@ -814,6 +814,7 @@ class EstimatorTrainTest(test.TestCase):
 
   def test_saving_listeners_are_used(self):
     listener = test.mock.Mock(spec=training.CheckpointSaverListener)
+    listener.after_save.return_value = None
     est = estimator.Estimator(
         model_fn=model_fn_global_step_incrementer,
         config=run_config.RunConfig(save_checkpoints_steps=10))
index df528d5..9b40817 100644 (file)
@@ -336,6 +336,8 @@ class CheckpointSaverListener(object):
 
     def after_save(self, session, global_step_value):
       print('Done writing checkpoint.')
+      if decided_to_stop_training():
+        return True
 
     def end(self, session, global_step_value):
       print('Done with the session.')
@@ -354,6 +356,11 @@ class CheckpointSaverListener(object):
   implementors should implement the `end()` method to handle actions related to
   the last checkpoint save. But the listener should not act twice if
   `after_save()` already handled this last checkpoint save.
+
+  A `CheckpointSaverListener` can request training to be stopped by returning
+  True from `after_save`. Please note that, in a replicated distributed
+  training setting, only the `chief` should use this behavior. Otherwise each
+  worker will do its own evaluation, which may waste resources.
   """
 
   def begin(self):
@@ -453,7 +460,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
       global_step = run_context.session.run(self._global_step_tensor)
       if self._timer.should_trigger_for_step(global_step):
         self._timer.update_last_triggered_step(global_step)
-        self._save(run_context.session, global_step)
+        if self._save(run_context.session, global_step):
+          run_context.request_stop()
 
   def end(self, session):
     last_step = session.run(self._global_step_tensor)
@@ -463,7 +471,7 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
       l.end(session, last_step)
 
   def _save(self, session, step):
-    """Saves the latest checkpoint."""
+    """Saves the latest checkpoint, returns should_stop."""
     logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
 
     for l in self._listeners:
@@ -475,8 +483,14 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
             status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
         step)
 
+    should_stop = False
     for l in self._listeners:
-      l.after_save(session, step)
+      if l.after_save(session, step):
+        logging.info(
+            "A CheckpointSaverListener requested that training be stopped. "
+            "listener: {}".format(l))
+        should_stop = True
+    return should_stop
 
   def _get_saver(self):
     if self._saver is not None:
index 7344ce2..21c584f 100644 (file)
@@ -58,6 +58,7 @@ class MockCheckpointSaverListener(
     self.before_save_count = 0
     self.after_save_count = 0
     self.end_count = 0
+    self.ask_for_stop = False
 
   def begin(self):
     self.begin_count += 1
@@ -67,6 +68,8 @@ class MockCheckpointSaverListener(
 
   def after_save(self, session, global_step):
     self.after_save_count += 1
+    if self.ask_for_stop:
+      return True
 
   def end(self, session, global_step):
     self.end_count += 1
@@ -471,6 +474,25 @@ class CheckpointSaverHookTest(test.TestCase):
         'end': 1
     }, listener_counts)
 
+  def test_listener_stops_training_in_after_save(self):
+    with ops.Graph().as_default():
+      scaffold = monitored_session.Scaffold()
+      variables.get_or_create_global_step()
+      train_op = training_util._increment_global_step(1)
+      listener = MockCheckpointSaverListener()
+      hook = basic_session_run_hooks.CheckpointSaverHook(
+          self.model_dir, save_steps=1, scaffold=scaffold, listeners=[listener])
+      with monitored_session.SingularMonitoredSession(
+          hooks=[hook], scaffold=scaffold,
+          checkpoint_dir=self.model_dir) as sess:
+        sess.run(train_op)
+        self.assertFalse(sess.should_stop())
+        sess.run(train_op)
+        self.assertFalse(sess.should_stop())
+        listener.ask_for_stop = True
+        sess.run(train_op)
+        self.assertTrue(sess.should_stop())
+
   def test_listener_with_default_saver(self):
     with ops.Graph().as_default():
       global_step = variables.get_or_create_global_step()