* A tuple (features, labels): Where features is a `Tensor` or a
dictionary of string feature name to `Tensor` and labels is a
`Tensor` or a dictionary of string label name to `Tensor`.
-
+
steps: Int. Positive number of steps for which to evaluate model. If
`None`, evaluates until `input_fn` raises an end-of-input exception.
See `Estimator.evaluate` for details.
Raises:
ValueError: if environment variable `TF_CONFIG` is incorrectly set.
"""
+ _assert_eval_spec(eval_spec) # fail fast if eval_spec is invalid.
+
executor = _TrainingExecutor(
estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
'Got: {}'.format(type(train_spec)))
self._train_spec = train_spec
- if not isinstance(eval_spec, EvalSpec):
- raise TypeError(
- '`eval_spec` must have type `tf.estimator.EvalSpec`. '
- 'Got: {}'.format(type(eval_spec)))
+ if eval_spec and not isinstance(eval_spec, EvalSpec):
+ raise TypeError('`eval_spec` must be either `None` or have type '
+ '`tf.estimator.EvalSpec`. Got: {}'.format(
+ type(eval_spec)))
self._eval_spec = eval_spec
self._train_hooks = _validate_hooks(train_hooks)
logging.info('Skip the current checkpoint eval due to throttle secs '
'({} secs).'.format(self._eval_throttle_secs))
+ _assert_eval_spec(self._eval_spec)
+
# Final export signal: For any eval result with global_step >= train
# max_steps, the evaluator will send the final export signal. There is a
# small chance that the Estimator.train stopping logic sees a different
return True
return False
+ _assert_eval_spec(self._eval_spec)
+
if self._eval_spec.throttle_secs <= 0:
raise ValueError('eval_spec.throttle_secs should be positive, given: {}.'
'It is used do determine how long each training '
def _start_continuous_evaluation(self):
"""Repeatedly calls `Estimator` evaluate and export until training ends."""
+
+ _assert_eval_spec(self._eval_spec)
+
start_delay_secs = self._eval_spec.start_delay_secs
if start_delay_secs:
logging.info('Waiting %f secs before starting eval.', start_delay_secs)
def _execute_evaluator_once(self, evaluator, continuous_eval_listener,
throttle_secs):
"""Executes the `evaluator`."""
+
+ _assert_eval_spec(self._eval_spec)
+
start = time.time()
eval_result = None
def __init__(self, estimator, eval_spec, max_training_steps):
self._estimator = estimator
+
+ _assert_eval_spec(eval_spec)
self._eval_spec = eval_spec
+
self._is_final_export_triggered = False
self._previous_ckpt_path = None
self._last_warning_time = 0
"""
del eval_result
return True
+
+
def _assert_eval_spec(eval_spec):
  """Raise error if `eval_spec` is not of the right type."""
  # Fast path: a proper EvalSpec needs no further checking.
  if isinstance(eval_spec, EvalSpec):
    return
  raise TypeError('`eval_spec` must have type `tf.estimator.EvalSpec`. '
                  'Got: {}'.format(type(eval_spec)))
'An Exporter cannot have a name that is `None` or empty.')
_INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`'
_INVALID_EVAL_SPEC_MSG = '`eval_spec` must have type `tf.estimator.EvalSpec`'
+_EVAL_SPEC_OR_NONE_MSG = (
+ '`eval_spec` must be either `None` or have type `tf.estimator.EvalSpec`')
_INVALID_EVAL_LISTENER_MSG = 'must have type `_ContinuousEvalListener`'
_INVALID_CONFIG_FOR_STD_SERVER_MSG = 'Could not start server; .*TF_CONFIG'
_INVALID_LOCAL_TASK_WITH_CLUSTER = '`task.type` in TF_CONFIG cannot be `local`'
training.train_and_evaluate(invalid_estimator, mock_train_spec,
mock_eval_spec)
+ def test_fail_fast_if_invalid_eval_spec(self):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ invalid_eval_spec = object()
+
+ with test.mock.patch.object(training, '_TrainingExecutor') as mock_executor:
+ with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
+ training.train_and_evaluate(mock_est, mock_train_spec,
+ invalid_eval_spec)
+
+ mock_executor.assert_not_called()
+
class TrainingExecutorConstructorTest(test.TestCase):
"""Tests constructor of _TrainingExecutor."""
- def testRequiredArgumentsSet(self):
+ def test_required_arguments_set(self):
estimator = estimator_lib.Estimator(model_fn=lambda features: features)
train_spec = training.TrainSpec(input_fn=lambda: 1)
eval_spec = training.EvalSpec(input_fn=lambda: 1)
train_spec = training.TrainSpec(input_fn=lambda: 1)
invalid_eval_spec = object()
- with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
+ with self.assertRaisesRegexp(TypeError, _EVAL_SPEC_OR_NONE_MSG):
training._TrainingExecutor(estimator, train_spec, invalid_eval_spec)
+ def test_eval_spec_none(self):
+ estimator = estimator_lib.Estimator(model_fn=lambda features: features)
+ train_spec = training.TrainSpec(input_fn=lambda: 1)
+ eval_spec = None
+
+ # Tests that no error is raised.
+ training._TrainingExecutor(estimator, train_spec, eval_spec)
+
def test_invalid_train_hooks(self):
estimator = estimator_lib.Estimator(model_fn=lambda features: features)
train_spec = training.TrainSpec(input_fn=lambda: 1)
@test.mock.patch.object(time, 'sleep')
@test.mock.patch.object(server_lib, 'Server')
+ def test_train_with_no_eval_spec(self, mock_server, unused_mock_sleep):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.config = self._run_config
+ train_spec = training.TrainSpec(
+ input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
+ eval_spec = None
+ mock_server_instance = mock_server.return_value
+
+ executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
+ self._run_task(executor)
+
+ mock_server.assert_called_with(
+ mock_est.config.cluster_spec,
+ job_name=mock_est.config.task_type,
+ task_index=mock_est.config.task_id,
+ config=test.mock.ANY,
+ start=False)
+
+ self.assertTrue(mock_server_instance.start.called)
+
+ mock_est.train.assert_called_with(
+ input_fn=train_spec.input_fn,
+ max_steps=train_spec.max_steps,
+ hooks=list(train_spec.hooks),
+ saving_listeners=test.mock.ANY)
+ mock_est.evaluate.assert_not_called()
+ mock_est.export_savedmodel.assert_not_called()
+
+ @test.mock.patch.object(time, 'sleep')
+ @test.mock.patch.object(server_lib, 'Server')
def test_train_with_train_hooks(self, unused_mock_server, unused_mock_sleep):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
mock_est.config = self._run_config
@test.mock.patch.object(time, 'sleep')
@test.mock.patch.object(server_lib, 'Server')
+ def test_train_with_no_eval_spec_fails(self, mock_server, unused_mock_sleep):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
+ mock_est.config = self._run_config
+ train_spec = training.TrainSpec(
+ input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
+ eval_spec = None
+
+ executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
+ with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
+ executor.run_master()
+
+ @test.mock.patch.object(time, 'sleep')
+ @test.mock.patch.object(server_lib, 'Server')
def test_train_with_train_hooks(self, mock_server, unused_mock_sleep):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
hooks=eval_spec.hooks)
self.assertFalse(mock_est.train.called)
+ def test_evaluate_with_no_eval_spec_fails(self):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.latest_checkpoint.return_value = 'latest_it_is'
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)
+
+ eval_spec = None
+
+ executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
+
+ with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
+ executor.run_evaluator()
+
def test_evaluate_with_train_hooks(self):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
mock_est.latest_checkpoint.return_value = 'latest_it_is'
self.assertEqual(train_spec.input_fn, train_args['input_fn'])
self.assertEqual(train_spec.max_steps, train_args['max_steps'])
+ def test_train_with_no_eval_spec_fails(self):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ train_spec = training.TrainSpec(
+ input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
+ eval_spec = None
+
+ executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
+
+ with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
+ executor.run_local()
+
def test_train_hooks(self):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
mock_est.latest_checkpoint.return_value = 'checkpoint_path/'