Enable checkpointless eval and predict for tf.estimator.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 1 May 2018 21:04:59 +0000 (14:04 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 1 May 2018 21:07:19 +0000 (14:07 -0700)
PiperOrigin-RevId: 194993191

tensorflow/python/estimator/estimator.py
tensorflow/python/estimator/estimator_test.py

index 2363845..63099b4 100644 (file)
@@ -400,7 +400,9 @@ class Estimator(object):
       hooks: List of `SessionRunHook` subclass instances. Used for callbacks
         inside the evaluation call.
       checkpoint_path: Path of a specific checkpoint to evaluate. If `None`, the
-        latest checkpoint in `model_dir` is used.
+        latest checkpoint in `model_dir` is used.  If there are no checkpoints
+        in `model_dir`, evaluation is run with newly initialized `Variables`
+        instead of restored from checkpoint.
       name: Name of the evaluation if user needs to run multiple evaluations on
         different data sets, such as on training data vs test data. Metrics for
         different evaluations are saved in separate folders, and appear
@@ -464,7 +466,9 @@ class Estimator(object):
       hooks: List of `SessionRunHook` subclass instances. Used for callbacks
         inside the prediction call.
       checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
-        latest checkpoint in `model_dir` is used.
+        latest checkpoint in `model_dir` is used.  If there are no checkpoints
+        in `model_dir`, prediction is run with newly initialized `Variables`
+        instead of restored from checkpoint.
       yield_single_examples: If False, yield the whole batch as returned by the
         `model_fn` instead of decomposing the batch into individual elements.
         This is useful if `model_fn` returns some tensors whose first dimension
@@ -487,9 +491,8 @@ class Estimator(object):
       if not checkpoint_path:
         checkpoint_path = saver.latest_checkpoint(self._model_dir)
       if not checkpoint_path:
-        raise ValueError(
-            'Could not find trained model in model_dir: {}.'.format(
-                self._model_dir))
+        logging.info('Could not find trained model in model_dir: {}, running '
+                     'initialization to predict.'.format(self._model_dir))
 
       with ops.Graph().as_default() as g:
         random_seed.set_random_seed(self._config.tf_random_seed)
@@ -1068,8 +1071,8 @@ class Estimator(object):
     if not checkpoint_path:
       latest_path = saver.latest_checkpoint(self._model_dir)
       if not latest_path:
-        raise ValueError('Could not find trained model in model_dir: {}.'.
-                         format(self._model_dir))
+        logging.info('Could not find trained model in model_dir: {}, running '
+                     'initialization to evaluate.'.format(self._model_dir))
       checkpoint_path = latest_path
 
     # Setup output directory.
index 0fea861..74114fa 100644 (file)
@@ -1067,11 +1067,19 @@ class EstimatorEvaluateTest(test.TestCase):
         ValueError, 'model_fn should return an EstimatorSpec'):
       est.evaluate(dummy_input_fn, steps=1)
 
-  def test_no_trained_model(self):
-    est = estimator.Estimator(model_fn=_model_fn_with_eval_metric_ops)
-    with self.assertRaisesRegexp(
-        ValueError, 'Could not find trained model in model_dir'):
-      est.evaluate(dummy_input_fn, steps=1)
+  def test_no_checkpoint_uses_init(self):
+    def _model_fn(features, labels, mode, params):
+      del features, labels, params
+      return model_fn_lib.EstimatorSpec(
+          mode,
+          loss=constant_op.constant(1.),
+          eval_metric_ops={'metric': metrics_lib.mean(
+              variables.Variable(2.) + 1)})
+    est = estimator.Estimator(model_fn=_model_fn)
+    metrics = est.evaluate(dummy_input_fn, steps=1)
+    # Metric value here is set to 1 + the value of the Variable that is newly
+    # initialized (since there is no checkpoint).
+    self.assertEqual(3., metrics['metric'])
 
   def test_scores(self):
     est = estimator.Estimator(
@@ -1331,11 +1339,15 @@ class EstimatorPredictTest(test.TestCase):
     next(est.predict(_input_fn))
     self.assertEqual(1, input_fn_call_count[0])
 
-  def test_no_trained_model_in_model_dir(self):
-    est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
-    with self.assertRaisesRegexp(ValueError,
-                                 'Could not find trained model in model_dir'):
-      next(est.predict(dummy_input_fn))
+  def test_no_checkpoint_uses_init(self):
+    def _model_fn(features, labels, mode, params, config):
+      del features, labels, params, config
+      x = variables.Variable([[3.]], name='x')
+      return model_fn_lib.EstimatorSpec(mode, predictions=math_ops.add(x, 1.))
+    est = estimator.Estimator(model_fn=_model_fn)
+    # Expected prediction value is 1 + the value of the Variable that is newly
+    # initialized (since there is no checkpoint).
+    self.assertEqual(4., next(est.predict(dummy_input_fn)))
 
   def test_no_trained_model_invalid_checkpoint_path(self):
     est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)