Allow omitting eval_spec when evaluation will not necessarily be run.
author    A. Unique TensorFlower <gardener@tensorflow.org>
Sat, 28 Apr 2018 17:40:49 +0000 (10:40 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Sat, 28 Apr 2018 17:43:18 +0000 (10:43 -0700)
PiperOrigin-RevId: 194661814

tensorflow/python/estimator/training.py
tensorflow/python/estimator/training_test.py
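
In short, _TrainingExecutor now accepts eval_spec=None, while the code paths that actually evaluate (run_master, run_local, run_evaluator, and the final-export _Evaluator) still assert that a tf.estimator.EvalSpec was provided; the public train_and_evaluate entry point also still fails fast on an invalid eval_spec. A minimal sketch of the new behavior, assuming an Estimator and TrainSpec built elsewhere (illustrative only; _TrainingExecutor is internal API):

    from tensorflow.python.estimator import training

    # Training-only tasks no longer need an EvalSpec (worker/chief code paths).
    executor = training._TrainingExecutor(
        estimator=estimator, train_spec=train_spec, eval_spec=None)
    executor.run_worker()     # trains without ever calling estimator.evaluate()

    # Code paths that evaluate still require one and fail fast:
    executor.run_evaluator()  # TypeError: `eval_spec` must have type `tf.estimator.EvalSpec`.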

diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 9d27175..534c357 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -201,7 +201,7 @@ class EvalSpec(
           * A tuple (features, labels): Where features is a `Tensor` or a
             dictionary of string feature name to `Tensor` and labels is a
             `Tensor` or a dictionary of string label name to `Tensor`.
-            
+
       steps: Int. Positive number of steps for which to evaluate model. If
         `None`, evaluates until `input_fn` raises an end-of-input exception.
         See `Estimator.evaluate` for details.
@@ -427,6 +427,8 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
   Raises:
     ValueError: if environment variable `TF_CONFIG` is incorrectly set.
   """
+  _assert_eval_spec(eval_spec)  # fail fast if eval_spec is invalid.
+
   executor = _TrainingExecutor(
       estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
 
@@ -481,10 +483,10 @@ class _TrainingExecutor(object):
           'Got: {}'.format(type(train_spec)))
     self._train_spec = train_spec
 
-    if not isinstance(eval_spec, EvalSpec):
-      raise TypeError(
-          '`eval_spec` must have type `tf.estimator.EvalSpec`. '
-          'Got: {}'.format(type(eval_spec)))
+    if eval_spec and not isinstance(eval_spec, EvalSpec):
+      raise TypeError('`eval_spec` must be either `None` or have type '
+                      '`tf.estimator.EvalSpec`. Got: {}'.format(
+                          type(eval_spec)))
     self._eval_spec = eval_spec
 
     self._train_hooks = _validate_hooks(train_hooks)
@@ -580,6 +582,8 @@ class _TrainingExecutor(object):
           logging.info('Skip the current checkpoint eval due to throttle secs '
                        '({} secs).'.format(self._eval_throttle_secs))
 
+    _assert_eval_spec(self._eval_spec)
+
     # Final export signal: For any eval result with global_step >= train
     # max_steps, the evaluator will send the final export signal. There is a
     # small chance that the Estimator.train stopping logic sees a different
@@ -628,6 +632,8 @@ class _TrainingExecutor(object):
         return True
       return False
 
+    _assert_eval_spec(self._eval_spec)
+
     if self._eval_spec.throttle_secs <= 0:
       raise ValueError('eval_spec.throttle_secs should be positive, given: {}.'
                        'It is used do determine how long each training '
@@ -741,6 +747,9 @@ class _TrainingExecutor(object):
 
   def _start_continuous_evaluation(self):
     """Repeatedly calls `Estimator` evaluate and export until training ends."""
+
+    _assert_eval_spec(self._eval_spec)
+
     start_delay_secs = self._eval_spec.start_delay_secs
     if start_delay_secs:
       logging.info('Waiting %f secs before starting eval.', start_delay_secs)
@@ -769,6 +778,9 @@ class _TrainingExecutor(object):
   def _execute_evaluator_once(self, evaluator, continuous_eval_listener,
                               throttle_secs):
     """Executes the `evaluator`."""
+
+    _assert_eval_spec(self._eval_spec)
+
     start = time.time()
 
     eval_result = None
@@ -807,7 +819,10 @@ class _TrainingExecutor(object):
 
     def __init__(self, estimator, eval_spec, max_training_steps):
       self._estimator = estimator
+
+      _assert_eval_spec(eval_spec)
       self._eval_spec = eval_spec
+
       self._is_final_export_triggered = False
       self._previous_ckpt_path = None
       self._last_warning_time = 0
@@ -996,3 +1011,10 @@ class _ContinuousEvalListener(object):
     """
     del eval_result
     return True
+
+
+def _assert_eval_spec(eval_spec):
+  """Raise error if `eval_spec` is not of the right type."""
+  if not isinstance(eval_spec, EvalSpec):
+    raise TypeError('`eval_spec` must have type `tf.estimator.EvalSpec`. '
+                    'Got: {}'.format(type(eval_spec)))
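
A quick illustration of the new helper's contract (hedged; the calls below are examples, not part of the diff):

    _assert_eval_spec(training.EvalSpec(input_fn=lambda: 1))  # passes silently
    _assert_eval_spec(None)      # TypeError: None is not accepted here
    _assert_eval_spec(object())  # TypeError, same message

This is why the executor constructor only type-checks a non-None eval_spec, while each evaluating code path above calls _assert_eval_spec to reject None as before.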
diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py
index 4f7da84..c04905a 100644
--- a/tensorflow/python/estimator/training_test.py
+++ b/tensorflow/python/estimator/training_test.py
@@ -72,6 +72,8 @@ _NONE_EXPORTER_NAME_MSG = (
     'An Exporter cannot have a name that is `None` or empty.')
 _INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`'
 _INVALID_EVAL_SPEC_MSG = '`eval_spec` must have type `tf.estimator.EvalSpec`'
+_EVAL_SPEC_OR_NONE_MSG = (
+    '`eval_spec` must be either `None` or have type `tf.estimator.EvalSpec`')
 _INVALID_EVAL_LISTENER_MSG = 'must have type `_ContinuousEvalListener`'
 _INVALID_CONFIG_FOR_STD_SERVER_MSG = 'Could not start server; .*TF_CONFIG'
 _INVALID_LOCAL_TASK_WITH_CLUSTER = '`task.type` in TF_CONFIG cannot be `local`'
@@ -356,11 +358,23 @@ class TrainAndEvaluateTest(test.TestCase):
       training.train_and_evaluate(invalid_estimator, mock_train_spec,
                                   mock_eval_spec)
 
+  def test_fail_fast_if_invalid_eval_spec(self):
+    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+    invalid_eval_spec = object()
+
+    with test.mock.patch.object(training, '_TrainingExecutor') as mock_executor:
+      with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
+        training.train_and_evaluate(mock_est, mock_train_spec,
+                                    invalid_eval_spec)
+
+      mock_executor.assert_not_called()
+
 
 class TrainingExecutorConstructorTest(test.TestCase):
   """Tests constructor of _TrainingExecutor."""
 
-  def testRequiredArgumentsSet(self):
+  def test_required_arguments_set(self):
     estimator = estimator_lib.Estimator(model_fn=lambda features: features)
     train_spec = training.TrainSpec(input_fn=lambda: 1)
     eval_spec = training.EvalSpec(input_fn=lambda: 1)
@@ -389,9 +403,17 @@ class TrainingExecutorConstructorTest(test.TestCase):
     train_spec = training.TrainSpec(input_fn=lambda: 1)
     invalid_eval_spec = object()
 
-    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
+    with self.assertRaisesRegexp(TypeError, _EVAL_SPEC_OR_NONE_MSG):
       training._TrainingExecutor(estimator, train_spec, invalid_eval_spec)
 
+  def test_eval_spec_none(self):
+    estimator = estimator_lib.Estimator(model_fn=lambda features: features)
+    train_spec = training.TrainSpec(input_fn=lambda: 1)
+    eval_spec = None
+
+    # Tests that no error is raised.
+    training._TrainingExecutor(estimator, train_spec, eval_spec)
+
   def test_invalid_train_hooks(self):
     estimator = estimator_lib.Estimator(model_fn=lambda features: features)
     train_spec = training.TrainSpec(input_fn=lambda: 1)
@@ -459,6 +481,36 @@ class _TrainingExecutorTrainingTest(object):
 
   @test.mock.patch.object(time, 'sleep')
   @test.mock.patch.object(server_lib, 'Server')
+  def test_train_with_no_eval_spec(self, mock_server, unused_mock_sleep):
+    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+    mock_est.config = self._run_config
+    train_spec = training.TrainSpec(
+        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
+    eval_spec = None
+    mock_server_instance = mock_server.return_value
+
+    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
+    self._run_task(executor)
+
+    mock_server.assert_called_with(
+        mock_est.config.cluster_spec,
+        job_name=mock_est.config.task_type,
+        task_index=mock_est.config.task_id,
+        config=test.mock.ANY,
+        start=False)
+
+    self.assertTrue(mock_server_instance.start.called)
+
+    mock_est.train.assert_called_with(
+        input_fn=train_spec.input_fn,
+        max_steps=train_spec.max_steps,
+        hooks=list(train_spec.hooks),
+        saving_listeners=test.mock.ANY)
+    mock_est.evaluate.assert_not_called()
+    mock_est.export_savedmodel.assert_not_called()
+
+  @test.mock.patch.object(time, 'sleep')
+  @test.mock.patch.object(server_lib, 'Server')
   def test_train_with_train_hooks(self, unused_mock_server, unused_mock_sleep):
     mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
     mock_est.config = self._run_config
@@ -685,6 +737,20 @@ class TrainingExecutorRunMasterTest(test.TestCase):
 
   @test.mock.patch.object(time, 'sleep')
   @test.mock.patch.object(server_lib, 'Server')
+  def test_train_with_no_eval_spec_fails(self, mock_server, unused_mock_sleep):
+    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
+    mock_est.config = self._run_config
+    train_spec = training.TrainSpec(
+        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
+    eval_spec = None
+
+    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
+    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
+      executor.run_master()
+
+  @test.mock.patch.object(time, 'sleep')
+  @test.mock.patch.object(server_lib, 'Server')
   def test_train_with_train_hooks(self, mock_server, unused_mock_sleep):
     mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
     mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
@@ -980,6 +1046,19 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
         hooks=eval_spec.hooks)
     self.assertFalse(mock_est.train.called)
 
+  def test_evaluate_with_no_eval_spec_fails(self):
+    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+    mock_est.latest_checkpoint.return_value = 'latest_it_is'
+    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)
+
+    eval_spec = None
+
+    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
+
+    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
+      executor.run_evaluator()
+
   def test_evaluate_with_train_hooks(self):
     mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
     mock_est.latest_checkpoint.return_value = 'latest_it_is'
@@ -1635,6 +1714,17 @@ class TrainingExecutorRunLocalTest(test.TestCase):
     self.assertEqual(train_spec.input_fn, train_args['input_fn'])
     self.assertEqual(train_spec.max_steps, train_args['max_steps'])
 
+  def test_train_with_no_eval_spec_fails(self):
+    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+    train_spec = training.TrainSpec(
+        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
+    eval_spec = None
+
+    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
+
+    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
+      executor.run_local()
+
   def test_train_hooks(self):
     mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
     mock_est.latest_checkpoint.return_value = 'checkpoint_path/'