Add evaluation metrics and export results in the new train_and_evaluate API (for...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 17 May 2018 09:47:24 +0000 (02:47 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 17 May 2018 09:49:54 +0000 (02:49 -0700)
PiperOrigin-RevId: 196962253

tensorflow/python/estimator/training.py
tensorflow/python/estimator/training_test.py

index 2f14a6f..dc5edc7 100644 (file)
@@ -424,6 +424,11 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
     eval_spec: A `EvalSpec` instance to specify the evaluation and export
       specification.
 
+  Returns:
+    A tuple of the result of the `evaluate` call to the `Estimator` and the
+    export results using the specified `ExportStrategy`.
+    Currently, the return value is undefined for distributed training mode.
+
   Raises:
     ValueError: if environment variable `TF_CONFIG` is incorrectly set.
   """
@@ -510,6 +515,11 @@ class _TrainingExecutor(object):
     procedure is `run_foo'. This `run` method invoke the procedure base on the
     `RunConfig.task_type`.
 
+    Returns:
+      A tuple of the result of the `evaluate` call to the `Estimator` and the
+      export results using the specified `ExportStrategy`.
+      Currently undefined for distributed training mode.
+
     Raises:
       ValueError: if the estimator.config is mis-configured.
     """
@@ -518,8 +528,7 @@ class _TrainingExecutor(object):
     if (not config.cluster_spec and
         config.task_type != run_config_lib.TaskType.EVALUATOR):
       logging.info('Running training and evaluation locally (non-distributed).')
-      self.run_local()
-      return
+      return self.run_local()
 
     # Distributed case.
     if not config.task_type:
@@ -650,6 +659,9 @@ class _TrainingExecutor(object):
     evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
                                              self._train_spec.max_steps)
 
+    eval_result = _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT)
+    export_results = []
+
     while True:
       self._estimator.train(
           input_fn=self._train_spec.input_fn,
@@ -666,7 +678,7 @@ class _TrainingExecutor(object):
       # _should_stop_local_train will then end the while True as the stopping
       # condition is satisfied (both checks use the same global_step value,
       # i.e., no race condition)
-      eval_result = evaluator.evaluate_and_export()
+      eval_result, export_results = evaluator.evaluate_and_export()
 
       if eval_result.status != _EvalStatus.EVALUATED:
         #  This is unexpected; should never happen.
@@ -682,6 +694,7 @@ class _TrainingExecutor(object):
       if _should_stop_local_train(
           eval_result.metrics[ops.GraphKeys.GLOBAL_STEP]):
         break
+    return eval_result.metrics, export_results
 
   def _start_std_server(self, config):
     """Creates, starts, and returns a server_lib.Server."""
@@ -807,7 +820,7 @@ class _TrainingExecutor(object):
     # iteration of while loop will end the continuous eval as the stopping
     # condition is satisfied (both checks use the same global_step value,
     # i.e., no race condition)
-    eval_result = evaluator.evaluate_and_export()
+    eval_result, _ = evaluator.evaluate_and_export()
 
     if not self._continuous_eval_listener.after_eval(eval_result):
       logging.info('Exiting evaluation, as requested by '
@@ -846,7 +859,7 @@ class _TrainingExecutor(object):
       """Evaluate and (maybe) export the current model.
 
       Returns:
-        An `EvalResult` instance.
+        A tuple of `EvalResult` instance and the export results.
 
       Raises:
         RuntimeError: for any unexpected internal error.
@@ -856,14 +869,14 @@ class _TrainingExecutor(object):
       if not latest_ckpt_path:
         self._log_err_msg('Estimator is not trained yet. Will start an '
                           'evaluation when a checkpoint is ready.')
-        return _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT)
+        return _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT), []
 
       if latest_ckpt_path == self._previous_ckpt_path:
         self._log_err_msg(
             'No new checkpoint ready for evaluation. Skip the current '
             'evaluation pass as evaluation results are expected to be same '
             'for the same checkpoint.')
-        return _EvalResult(status=_EvalStatus.NO_NEW_CHECKPOINT)
+        return _EvalResult(status=_EvalStatus.NO_NEW_CHECKPOINT), []
 
       metrics = self._estimator.evaluate(
           input_fn=self._eval_spec.input_fn,
@@ -881,7 +894,8 @@ class _TrainingExecutor(object):
       is_the_final_export = (
           eval_result.metrics[ops.GraphKeys.GLOBAL_STEP] >=
           self._max_training_steps if self._max_training_steps else False)
-      self._export_eval_result(eval_result, is_the_final_export)
+      export_results = self._export_eval_result(eval_result,
+                                                is_the_final_export)
 
       if is_the_final_export:
         logging.debug('Calling exporter with the `is_the_final_export=True`.')
@@ -889,7 +903,7 @@ class _TrainingExecutor(object):
 
       self._last_warning_time = 0
       self._previous_ckpt_path = latest_ckpt_path
-      return eval_result
+      return eval_result, export_results
 
     def _log_err_msg(self, message):
       """Prints warning `message` every 10 mins."""
@@ -904,15 +918,18 @@ class _TrainingExecutor(object):
           compat.as_str_any(self._estimator.model_dir),
           compat.as_str_any('export'))
 
+      export_results = []
       for exporter in self._eval_spec.exporters:
-        exporter.export(
-            estimator=self._estimator,
-            export_path=os.path.join(
-                compat.as_str_any(export_dir_base),
-                compat.as_str_any(exporter.name)),
-            checkpoint_path=eval_result.checkpoint_path,
-            eval_result=eval_result.metrics,
-            is_the_final_export=is_the_final_export)
+        export_results.append(
+            exporter.export(
+                estimator=self._estimator,
+                export_path=os.path.join(
+                    compat.as_str_any(export_dir_base),
+                    compat.as_str_any(exporter.name)),
+                checkpoint_path=eval_result.checkpoint_path,
+                eval_result=eval_result.metrics,
+                is_the_final_export=is_the_final_export))
+      return export_results
 
 
 class _EvalStatus(object):
index 3b6f5e1..2c838db 100644 (file)
@@ -1835,6 +1835,7 @@ class TrainingExecutorRunLocalTest(test.TestCase):
     def export(estimator, *args, **kwargs):
       del args, kwargs
       estimator.export_was_called = True
+      return 'path_to_export'
 
     exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
     exporter.name = 'see_whether_export_is_called'
@@ -1848,9 +1849,12 @@ class TrainingExecutorRunLocalTest(test.TestCase):
         exporters=exporter)
 
     executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
-    executor.run_local()
+    # pylint: disable=assignment-from-no-return
+    _, export_results = executor.run_local()
+    # pylint: enable=assignment-from-no-return
 
     self.assertTrue(mock_est.export_was_called)
+    self.assertEqual(export_results, ['path_to_export'])
 
   def test_errors_out_if_evaluate_returns_empty_dict(self):
     mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
@@ -1867,7 +1871,6 @@ class TrainingExecutorRunLocalTest(test.TestCase):
     train_spec = training.TrainSpec(input_fn=lambda: 1)
     eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123)
     mock_est.evaluate.return_value = 123
-
     executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
     with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR):
       executor.run_local()
@@ -1883,6 +1886,21 @@ class TrainingExecutorRunLocalTest(test.TestCase):
                                  _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR):
       executor.run_local()
 
+  def test_train_and_evaluate_return_metrics(self):
+    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
+    mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
+    train_spec = training.TrainSpec(
+        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
+    eval_spec = training.EvalSpec(
+        input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='local_eval')
+    mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}
+
+    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
+    # pylint: disable=assignment-from-no-return
+    metrics, _ = executor.run_local()
+    # pylint: enable=assignment-from-no-return
+    self.assertEqual(metrics['global_step'], 300)
+
 
 class TrainAndEvaluateRunTest(test.TestCase):