Added a call in CheckpointSaverHook.after_create_session to always save
checkpoint before the first training step.

author:    Sherry Moore <sherrym@google.com>          Thu, 5 Apr 2018 16:33:20 +0000 (09:33 -0700)
committer: TensorFlower Gardener <gardener@tensorflow.org>  Thu, 5 Apr 2018 16:35:47 +0000 (09:35 -0700)

PiperOrigin-RevId: 191753026

tensorflow/python/estimator/estimator_test.py
tensorflow/python/estimator/replicate_model_fn_test.py
tensorflow/python/training/basic_session_run_hooks.py
tensorflow/python/training/basic_session_run_hooks_test.py

index f425509..498f529 100644 (file)
@@ -680,7 +680,7 @@ class EstimatorTrainTest(test.TestCase):
     text_format.Merge(checkpoint_file_content, ckpt)
     self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
     self.assertAllEqual(
-        ['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
+        ['model.ckpt-0', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
 
   def test_train_save_copy_reload(self):
     tmpdir = tempfile.mkdtemp()
index ad1f9c0..00035ef 100644 (file)
@@ -27,6 +27,7 @@ import six
 from tensorflow.python.estimator import estimator as estimator_lib
 from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.estimator import replicate_model_fn
+from tensorflow.python.estimator import run_config
 from tensorflow.python.estimator.canned import dnn
 from tensorflow.python.estimator.canned import optimizers
 from tensorflow.python.estimator.canned import prediction_keys
@@ -593,7 +594,8 @@ class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase):
         loss=loss,
         eval_metric_ops=metrics,
         predictions={'probabilities': predictions},
-        train_op=optimizer.minimize(loss))
+        train_op=optimizer.minimize(
+            loss, global_step=training.get_global_step()))
 
   @property
   def params(self):
@@ -612,8 +614,9 @@ class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase):
       estimator = estimator_lib.Estimator(
           model_fn=self.model_fn,
           model_dir=tempfile.mkdtemp(),
-          params=self.params)
-      estimator.train(train_input_fn, steps=1)
+          params=self.params,
+          config=run_config.RunConfig(save_checkpoints_steps=1))
+      estimator.train(train_input_fn, steps=2)
 
       self.assertEqual(7.0, estimator.get_variable_value('c'))
 
index aae757b..77d4f15 100644 (file)
@@ -429,6 +429,11 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
     for l in self._listeners:
       l.begin()
 
+  def after_create_session(self, session, coord):
+    global_step = session.run(self._global_step_tensor)
+    self._save(session, global_step)
+    self._timer.update_last_triggered_step(global_step)
+
   def before_run(self, run_context):  # pylint: disable=unused-argument
     if self._timer.last_triggered_step() is None:
       # We do write graph and saver_def at the first call of before_run.
index 2547661..4bf4a59 100644 (file)
@@ -466,8 +466,8 @@ class CheckpointSaverHookTest(test.TestCase):
     self.assertEqual(2, global_step_val)
     self.assertEqual({
         'begin': 1,
-        'before_save': 2,
-        'after_save': 2,
+        'before_save': 3,
+        'after_save': 3,
         'end': 1
     }, listener_counts)
 
@@ -490,8 +490,8 @@ class CheckpointSaverHookTest(test.TestCase):
     self.assertEqual(2, global_step_val)
     self.assertEqual({
         'begin': 1,
-        'before_save': 2,
-        'after_save': 2,
+        'before_save': 3,
+        'after_save': 3,
         'end': 1
     }, listener_counts)
 
@@ -523,8 +523,8 @@ class CheckpointSaverHookTest(test.TestCase):
     self.assertEqual(2, global_step_val)
     self.assertEqual({
         'begin': 1,
-        'before_save': 2,
-        'after_save': 2,
+        'before_save': 3,
+        'after_save': 3,
         'end': 1
     }, listener1_counts)
     self.assertEqual(listener1_counts, listener2_counts)
@@ -718,6 +718,31 @@ class CheckpointSaverHookTest(test.TestCase):
 
     fake_summary_writer.FakeSummaryWriter.uninstall()
 
+  def test_save_checkpoint_before_first_train_step(self):
+    with self.graph.as_default():
+      hook = basic_session_run_hooks.CheckpointSaverHook(
+          self.model_dir, save_steps=2, scaffold=self.scaffold)
+      hook.begin()
+      self.scaffold.finalize()
+      with session_lib.Session() as sess:
+        mon_sess = monitored_session._HookedSession(sess, [hook])
+        sess.run(self.scaffold.init_op)
+        hook.after_create_session(sess, None)
+        # Verifies that checkpoint is saved at step 0.
+        self.assertEqual(0,
+                         checkpoint_utils.load_variable(self.model_dir,
+                                                        self.global_step.name))
+        # Verifies that no checkpoint is saved after one training step.
+        mon_sess.run(self.train_op)
+        self.assertEqual(0,
+                         checkpoint_utils.load_variable(self.model_dir,
+                                                        self.global_step.name))
+        # Verifies that checkpoint is saved after save_steps.
+        mon_sess.run(self.train_op)
+        self.assertEqual(2,
+                         checkpoint_utils.load_variable(self.model_dir,
+                                                        self.global_step.name))
+
 
 class ResourceCheckpointSaverHookTest(test.TestCase):