Make tf.contrib.estimator.add_metrics work with warm-starting.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 5 Apr 2018 22:11:02 +0000 (15:11 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 5 Apr 2018 22:13:56 +0000 (15:13 -0700)
PiperOrigin-RevId: 191805682

tensorflow/contrib/estimator/python/estimator/extenders.py

index 266ae93..201699e 100644 (file)
@@ -97,7 +97,10 @@ def add_metrics(estimator, metric_fn):
   return estimator_lib.Estimator(
       model_fn=new_model_fn,
       model_dir=estimator.model_dir,
-      config=estimator.config)
+      config=estimator.config,
+      # pylint: disable=protected-access
+      warm_start_from=estimator._warm_start_settings)
+      # pylint: enable=protected-access
 
 
 def clip_gradients_by_norm(optimizer, clip_norm):