From 3fb89650a1e7f5cc4c04f091170fac504ba10021 Mon Sep 17 00:00:00 2001 From: Sherry Moore Date: Thu, 5 Apr 2018 09:33:20 -0700 Subject: [PATCH] Added a call in CheckpointSaverHook.after_create_session to always save checkpoint before the first training step. PiperOrigin-RevId: 191753026 --- tensorflow/python/estimator/estimator_test.py | 2 +- .../python/estimator/replicate_model_fn_test.py | 9 ++++-- .../python/training/basic_session_run_hooks.py | 5 +++ .../training/basic_session_run_hooks_test.py | 37 ++++++++++++++++++---- 4 files changed, 43 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index f425509..498f529 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -680,7 +680,7 @@ class EstimatorTrainTest(test.TestCase): text_format.Merge(checkpoint_file_content, ckpt) self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5') self.assertAllEqual( - ['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths) + ['model.ckpt-0', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths) def test_train_save_copy_reload(self): tmpdir = tempfile.mkdtemp() diff --git a/tensorflow/python/estimator/replicate_model_fn_test.py b/tensorflow/python/estimator/replicate_model_fn_test.py index ad1f9c0..00035ef 100644 --- a/tensorflow/python/estimator/replicate_model_fn_test.py +++ b/tensorflow/python/estimator/replicate_model_fn_test.py @@ -27,6 +27,7 @@ import six from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import replicate_model_fn +from tensorflow.python.estimator import run_config from tensorflow.python.estimator.canned import dnn from tensorflow.python.estimator.canned import optimizers from tensorflow.python.estimator.canned import prediction_keys @@ -593,7 +594,8 @@ class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase): loss=loss, eval_metric_ops=metrics, predictions={'probabilities': predictions}, - train_op=optimizer.minimize(loss)) + train_op=optimizer.minimize( + loss, global_step=training.get_global_step())) @property def params(self): @@ -612,8 +614,9 @@ class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase): estimator = estimator_lib.Estimator( model_fn=self.model_fn, model_dir=tempfile.mkdtemp(), - params=self.params) - estimator.train(train_input_fn, steps=1) + params=self.params, + config=run_config.RunConfig(save_checkpoints_steps=1)) + estimator.train(train_input_fn, steps=2) self.assertEqual(7.0, estimator.get_variable_value('c')) diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index aae757b..77d4f15 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -429,6 +429,11 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): for l in self._listeners: l.begin() + def after_create_session(self, session, coord): + global_step = session.run(self._global_step_tensor) + self._save(session, global_step) + self._timer.update_last_triggered_step(global_step) + def before_run(self, run_context): # pylint: disable=unused-argument if self._timer.last_triggered_step() is None: # We do write graph and saver_def at the first call of before_run. diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py index 2547661..4bf4a59 100644 --- a/tensorflow/python/training/basic_session_run_hooks_test.py +++ b/tensorflow/python/training/basic_session_run_hooks_test.py @@ -466,8 +466,8 @@ class CheckpointSaverHookTest(test.TestCase): self.assertEqual(2, global_step_val) self.assertEqual({ 'begin': 1, - 'before_save': 2, - 'after_save': 2, + 'before_save': 3, + 'after_save': 3, 'end': 1 }, listener_counts) @@ -490,8 +490,8 @@ class CheckpointSaverHookTest(test.TestCase): self.assertEqual(2, global_step_val) self.assertEqual({ 'begin': 1, - 'before_save': 2, - 'after_save': 2, + 'before_save': 3, + 'after_save': 3, 'end': 1 }, listener_counts) @@ -523,8 +523,8 @@ class CheckpointSaverHookTest(test.TestCase): self.assertEqual(2, global_step_val) self.assertEqual({ 'begin': 1, - 'before_save': 2, - 'after_save': 2, + 'before_save': 3, + 'after_save': 3, 'end': 1 }, listener1_counts) self.assertEqual(listener1_counts, listener2_counts) @@ -718,6 +718,31 @@ class CheckpointSaverHookTest(test.TestCase): fake_summary_writer.FakeSummaryWriter.uninstall() + def test_save_checkpoint_before_first_train_step(self): + with self.graph.as_default(): + hook = basic_session_run_hooks.CheckpointSaverHook( + self.model_dir, save_steps=2, scaffold=self.scaffold) + hook.begin() + self.scaffold.finalize() + with session_lib.Session() as sess: + mon_sess = monitored_session._HookedSession(sess, [hook]) + sess.run(self.scaffold.init_op) + hook.after_create_session(sess, None) + # Verifies that checkpoint is saved at step 0. + self.assertEqual(0, + checkpoint_utils.load_variable(self.model_dir, + self.global_step.name)) + # Verifies that no checkpoint is saved after one training step. + mon_sess.run(self.train_op) + self.assertEqual(0, + checkpoint_utils.load_variable(self.model_dir, + self.global_step.name)) + # Verifies that checkpoint is saved after save_steps. + mon_sess.run(self.train_op) + self.assertEqual(2, + checkpoint_utils.load_variable(self.model_dir, + self.global_step.name)) + class ResourceCheckpointSaverHookTest(test.TestCase): -- 2.7.4