eval_spec: An `EvalSpec` instance to specify the evaluation and export
configuration.
+ Returns:
+ A tuple of the result of the `evaluate` call to the `Estimator` and the
+ export results using the specified `Exporter`s.
+ Currently, the return value is undefined for distributed training mode.
+
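+ Example (a minimal sketch of consuming the new return value; it is only
+ defined for local, non-distributed mode):
+
+   metrics, export_results = train_and_evaluate(estimator, train_spec,
+                                                eval_spec)
+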
Raises:
ValueError: if environment variable `TF_CONFIG` is incorrectly set.
"""
procedure is `run_foo`. This `run` method invokes the procedure based on the
`RunConfig.task_type`.
+ Returns:
+ A tuple of the result of the `evaluate` call to the `Estimator` and the
+ export results using the specified `Exporter`s.
+ Currently undefined for distributed training mode.
+
Raises:
ValueError: if the `estimator.config` is misconfigured.
"""
if (not config.cluster_spec and
config.task_type != run_config_lib.TaskType.EVALUATOR):
logging.info('Running training and evaluation locally (non-distributed).')
- self.run_local()
- return
+ return self.run_local()
# Distributed case.
if not config.task_type:
evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
self._train_spec.max_steps)
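+ # Seed the results so both names are defined even if the loop below ends
+ # before an evaluation completes.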
+ eval_result = _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT)
+ export_results = []
+
while True:
self._estimator.train(
input_fn=self._train_spec.input_fn,
# `_should_stop_local_train` will then end the `while True` loop, as the
# stopping condition is satisfied (both checks use the same `global_step`
# value, i.e., there is no race condition).
- eval_result = evaluator.evaluate_and_export()
+ eval_result, export_results = evaluator.evaluate_and_export()
if eval_result.status != _EvalStatus.EVALUATED:
# This is unexpected; should never happen.
if _should_stop_local_train(
eval_result.metrics[ops.GraphKeys.GLOBAL_STEP]):
break
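+ # Surface the final evaluation metrics together with any export results
+ # collected along the way.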
+ return eval_result.metrics, export_results
def _start_std_server(self, config):
"""Creates, starts, and returns a server_lib.Server."""
# iteration of the while loop will end the continuous eval, as the stopping
# condition is satisfied (both checks use the same global_step value,
# i.e., no race condition)
- eval_result = evaluator.evaluate_and_export()
+ eval_result, _ = evaluator.evaluate_and_export()
if not self._continuous_eval_listener.after_eval(eval_result):
logging.info('Exiting evaluation, as requested by '
"""Evaluate and (maybe) export the current model.
Returns:
- An `EvalResult` instance.
+ A tuple of the `EvalResult` instance and the export results.
Raises:
RuntimeError: for any unexpected internal error.
if not latest_ckpt_path:
self._log_err_msg('Estimator is not trained yet. Will start an '
'evaluation when a checkpoint is ready.')
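+ # No checkpoint yet, so nothing was evaluated or exported; pair the
+ # status with an empty export list.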
- return _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT)
+ return _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT), []
if latest_ckpt_path == self._previous_ckpt_path:
self._log_err_msg(
'No new checkpoint ready for evaluation. Skipping the current '
'evaluation pass, as evaluation results are expected to be the '
'same for the same checkpoint.')
- return _EvalResult(status=_EvalStatus.NO_NEW_CHECKPOINT)
+ return _EvalResult(status=_EvalStatus.NO_NEW_CHECKPOINT), []
metrics = self._estimator.evaluate(
input_fn=self._eval_spec.input_fn,
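+ # An export counts as "final" only once the evaluated global step reaches
+ # `max_steps`; without a `max_steps`, no export is treated as final.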
is_the_final_export = (
eval_result.metrics[ops.GraphKeys.GLOBAL_STEP] >=
self._max_training_steps if self._max_training_steps else False)
- self._export_eval_result(eval_result, is_the_final_export)
+ export_results = self._export_eval_result(eval_result,
+ is_the_final_export)
if is_the_final_export:
logging.debug('Calling exporter with `is_the_final_export=True`.')
self._last_warning_time = 0
self._previous_ckpt_path = latest_ckpt_path
- return eval_result
+ return eval_result, export_results
def _log_err_msg(self, message):
"""Prints warning `message` every 10 mins."""
compat.as_str_any(self._estimator.model_dir),
compat.as_str_any('export'))
+ export_results = []
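+ # Collect whatever each `Exporter.export` call returns (typically the
+ # export path), one entry per configured exporter.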
for exporter in self._eval_spec.exporters:
- exporter.export(
- estimator=self._estimator,
- export_path=os.path.join(
- compat.as_str_any(export_dir_base),
- compat.as_str_any(exporter.name)),
- checkpoint_path=eval_result.checkpoint_path,
- eval_result=eval_result.metrics,
- is_the_final_export=is_the_final_export)
+ export_results.append(
+ exporter.export(
+ estimator=self._estimator,
+ export_path=os.path.join(
+ compat.as_str_any(export_dir_base),
+ compat.as_str_any(exporter.name)),
+ checkpoint_path=eval_result.checkpoint_path,
+ eval_result=eval_result.metrics,
+ is_the_final_export=is_the_final_export))
+ return export_results
class _EvalStatus(object):
def export(estimator, *args, **kwargs):
del args, kwargs
estimator.export_was_called = True
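+ # Return a sentinel so the test can assert that the exporter's result is
+ # propagated through `run_local`.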
+ return 'path_to_export'
exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
exporter.name = 'see_whether_export_is_called'
exporters=exporter)
executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
- executor.run_local()
+ # pylint: disable=assignment-from-no-return
+ _, export_results = executor.run_local()
+ # pylint: enable=assignment-from-no-return
self.assertTrue(mock_est.export_was_called)
+ self.assertEqual(export_results, ['path_to_export'])
def test_errors_out_if_evaluate_returns_empty_dict(self):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
train_spec = training.TrainSpec(input_fn=lambda: 1)
eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123)
mock_est.evaluate.return_value = 123
-
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR):
executor.run_local()
_MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR):
executor.run_local()
+ def test_train_and_evaluate_return_metrics(self):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
+ mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
+ train_spec = training.TrainSpec(
+ input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
+ eval_spec = training.EvalSpec(
+ input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='local_eval')
+ mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}
+
+ executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
+ # pylint: disable=assignment-from-no-return
+ metrics, _ = executor.run_local()
+ # pylint: enable=assignment-from-no-return
+ self.assertEqual(metrics['global_step'], 300)
+
class TrainAndEvaluateRunTest(test.TestCase):