Automated g4 rollback of changelist 191753026
author: A. Unique TensorFlower <gardener@tensorflow.org>
Thu, 5 Apr 2018 20:00:06 +0000 (13:00 -0700)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Thu, 5 Apr 2018 20:02:56 +0000 (13:02 -0700)
PiperOrigin-RevId: 191784709

tensorflow/python/estimator/estimator_test.py
tensorflow/python/estimator/replicate_model_fn_test.py
tensorflow/python/training/basic_session_run_hooks.py
tensorflow/python/training/basic_session_run_hooks_test.py

index 498f529..f425509 100644 (file)
@@ -680,7 +680,7 @@ class EstimatorTrainTest(test.TestCase):
     text_format.Merge(checkpoint_file_content, ckpt)
     self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
     self.assertAllEqual(
-        ['model.ckpt-0', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
+        ['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
 
   def test_train_save_copy_reload(self):
     tmpdir = tempfile.mkdtemp()
index 00035ef..ad1f9c0 100644 (file)
@@ -27,7 +27,6 @@ import six
 from tensorflow.python.estimator import estimator as estimator_lib
 from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.estimator import replicate_model_fn
-from tensorflow.python.estimator import run_config
 from tensorflow.python.estimator.canned import dnn
 from tensorflow.python.estimator.canned import optimizers
 from tensorflow.python.estimator.canned import prediction_keys
@@ -594,8 +593,7 @@ class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase):
         loss=loss,
         eval_metric_ops=metrics,
         predictions={'probabilities': predictions},
-        train_op=optimizer.minimize(
-            loss, global_step=training.get_global_step()))
+        train_op=optimizer.minimize(loss))
 
   @property
   def params(self):
@@ -614,9 +612,8 @@ class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase):
       estimator = estimator_lib.Estimator(
           model_fn=self.model_fn,
           model_dir=tempfile.mkdtemp(),
-          params=self.params,
-          config=run_config.RunConfig(save_checkpoints_steps=1))
-      estimator.train(train_input_fn, steps=2)
+          params=self.params)
+      estimator.train(train_input_fn, steps=1)
 
       self.assertEqual(7.0, estimator.get_variable_value('c'))
 
index 77d4f15..aae757b 100644 (file)
@@ -429,11 +429,6 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
     for l in self._listeners:
       l.begin()
 
-  def after_create_session(self, session, coord):
-    global_step = session.run(self._global_step_tensor)
-    self._save(session, global_step)
-    self._timer.update_last_triggered_step(global_step)
-
   def before_run(self, run_context):  # pylint: disable=unused-argument
     if self._timer.last_triggered_step() is None:
       # We do write graph and saver_def at the first call of before_run.
index 4bf4a59..2547661 100644 (file)
@@ -466,8 +466,8 @@ class CheckpointSaverHookTest(test.TestCase):
     self.assertEqual(2, global_step_val)
     self.assertEqual({
         'begin': 1,
-        'before_save': 3,
-        'after_save': 3,
+        'before_save': 2,
+        'after_save': 2,
         'end': 1
     }, listener_counts)
 
@@ -490,8 +490,8 @@ class CheckpointSaverHookTest(test.TestCase):
     self.assertEqual(2, global_step_val)
     self.assertEqual({
         'begin': 1,
-        'before_save': 3,
-        'after_save': 3,
+        'before_save': 2,
+        'after_save': 2,
         'end': 1
     }, listener_counts)
 
@@ -523,8 +523,8 @@ class CheckpointSaverHookTest(test.TestCase):
     self.assertEqual(2, global_step_val)
     self.assertEqual({
         'begin': 1,
-        'before_save': 3,
-        'after_save': 3,
+        'before_save': 2,
+        'after_save': 2,
         'end': 1
     }, listener1_counts)
     self.assertEqual(listener1_counts, listener2_counts)
@@ -718,31 +718,6 @@ class CheckpointSaverHookTest(test.TestCase):
 
     fake_summary_writer.FakeSummaryWriter.uninstall()
 
-  def test_save_checkpoint_before_first_train_step(self):
-    with self.graph.as_default():
-      hook = basic_session_run_hooks.CheckpointSaverHook(
-          self.model_dir, save_steps=2, scaffold=self.scaffold)
-      hook.begin()
-      self.scaffold.finalize()
-      with session_lib.Session() as sess:
-        mon_sess = monitored_session._HookedSession(sess, [hook])
-        sess.run(self.scaffold.init_op)
-        hook.after_create_session(sess, None)
-        # Verifies that checkpoint is saved at step 0.
-        self.assertEqual(0,
-                         checkpoint_utils.load_variable(self.model_dir,
-                                                        self.global_step.name))
-        # Verifies that no checkpoint is saved after one training step.
-        mon_sess.run(self.train_op)
-        self.assertEqual(0,
-                         checkpoint_utils.load_variable(self.model_dir,
-                                                        self.global_step.name))
-        # Verifies that checkpoint is saved after save_steps.
-        mon_sess.run(self.train_op)
-        self.assertEqual(2,
-                         checkpoint_utils.load_variable(self.model_dir,
-                                                        self.global_step.name))
-
 
 class ResourceCheckpointSaverHookTest(test.TestCase):