text_format.Merge(checkpoint_file_content, ckpt)
self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
self.assertAllEqual(
- ['model.ckpt-0', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
+ ['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
def test_train_save_copy_reload(self):
tmpdir = tempfile.mkdtemp()
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import replicate_model_fn
-from tensorflow.python.estimator import run_config
from tensorflow.python.estimator.canned import dnn
from tensorflow.python.estimator.canned import optimizers
from tensorflow.python.estimator.canned import prediction_keys
loss=loss,
eval_metric_ops=metrics,
predictions={'probabilities': predictions},
- train_op=optimizer.minimize(
- loss, global_step=training.get_global_step()))
+ train_op=optimizer.minimize(loss))
@property
def params(self):
estimator = estimator_lib.Estimator(
model_fn=self.model_fn,
model_dir=tempfile.mkdtemp(),
- params=self.params,
- config=run_config.RunConfig(save_checkpoints_steps=1))
- estimator.train(train_input_fn, steps=2)
+ params=self.params)
+ estimator.train(train_input_fn, steps=1)
self.assertEqual(7.0, estimator.get_variable_value('c'))
for l in self._listeners:
l.begin()
- def after_create_session(self, session, coord):
- global_step = session.run(self._global_step_tensor)
- self._save(session, global_step)
- self._timer.update_last_triggered_step(global_step)
-
def before_run(self, run_context): # pylint: disable=unused-argument
if self._timer.last_triggered_step() is None:
# We do write graph and saver_def at the first call of before_run.
self.assertEqual(2, global_step_val)
self.assertEqual({
'begin': 1,
- 'before_save': 3,
- 'after_save': 3,
+ 'before_save': 2,
+ 'after_save': 2,
'end': 1
}, listener_counts)
self.assertEqual(2, global_step_val)
self.assertEqual({
'begin': 1,
- 'before_save': 3,
- 'after_save': 3,
+ 'before_save': 2,
+ 'after_save': 2,
'end': 1
}, listener_counts)
self.assertEqual(2, global_step_val)
self.assertEqual({
'begin': 1,
- 'before_save': 3,
- 'after_save': 3,
+ 'before_save': 2,
+ 'after_save': 2,
'end': 1
}, listener1_counts)
self.assertEqual(listener1_counts, listener2_counts)
fake_summary_writer.FakeSummaryWriter.uninstall()
- def test_save_checkpoint_before_first_train_step(self):
- with self.graph.as_default():
- hook = basic_session_run_hooks.CheckpointSaverHook(
- self.model_dir, save_steps=2, scaffold=self.scaffold)
- hook.begin()
- self.scaffold.finalize()
- with session_lib.Session() as sess:
- mon_sess = monitored_session._HookedSession(sess, [hook])
- sess.run(self.scaffold.init_op)
- hook.after_create_session(sess, None)
- # Verifies that checkpoint is saved at step 0.
- self.assertEqual(0,
- checkpoint_utils.load_variable(self.model_dir,
- self.global_step.name))
- # Verifies that no checkpoint is saved after one training step.
- mon_sess.run(self.train_op)
- self.assertEqual(0,
- checkpoint_utils.load_variable(self.model_dir,
- self.global_step.name))
- # Verifies that checkpoint is saved after save_steps.
- mon_sess.run(self.train_op)
- self.assertEqual(2,
- checkpoint_utils.load_variable(self.model_dir,
- self.global_step.name))
-
class ResourceCheckpointSaverHookTest(test.TestCase):