Only calls compare function if values were read from event file
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 5 Jun 2018 22:59:21 +0000 (15:59 -0700)
committerGuillaume Klein <guillaume.klein@systrangroup.com>
Sat, 9 Jun 2018 14:39:09 +0000 (16:39 +0200)
PiperOrigin-RevId: 199373169

tensorflow/python/estimator/exporter.py
tensorflow/python/estimator/exporter_test.py

index a7212bb..766ea23 100644 (file)
@@ -360,9 +360,10 @@ class BestExporter(Exporter):
           for value in event.summary.value:
             if value.HasField('simple_value'):
               event_eval_result[value.tag] = value.simple_value
-          if best_eval_result is None or self._compare_fn(
-              best_eval_result, event_eval_result):
-            best_eval_result = event_eval_result
+          if event_eval_result:
+            if best_eval_result is None or self._compare_fn(
+                best_eval_result, event_eval_result):
+              best_eval_result = event_eval_result
     return best_eval_result
 
 
index 4cb4bff..c4b0069 100644 (file)
@@ -148,6 +148,40 @@ class BestExporterTest(test.TestCase):
                                     "checkpoint_path", {"loss": 20}, False)
     self.assertEqual(None, export_result)
 
+  def test_best_exporter_with_empty_event(self):
+
+    def _serving_input_receiver_fn():
+      pass
+
+    export_dir_base = tempfile.mkdtemp()
+    gfile.MkDir(export_dir_base)
+    gfile.MkDir(export_dir_base + "/export")
+    gfile.MkDir(export_dir_base + "/eval")
+
+    eval_dir_base = os.path.join(export_dir_base, "eval_continuous")
+    estimator_lib._write_dict_to_summary(eval_dir_base, {}, 1)
+    estimator_lib._write_dict_to_summary(eval_dir_base, {"loss": 60}, 2)
+
+    exporter = exporter_lib.BestExporter(
+        name="best_exporter",
+        serving_input_receiver_fn=_serving_input_receiver_fn,
+        event_file_pattern="eval_continuous/*.tfevents.*",
+        assets_extra={"from/path": "to/path"},
+        as_text=False,
+        exports_to_keep=1)
+
+    estimator = test.mock.Mock(spec=estimator_lib.Estimator)
+    estimator.model_dir = export_dir_base
+    estimator.export_savedmodel.return_value = "export_result_path"
+
+    export_result = exporter.export(estimator, export_dir_base,
+                                    "checkpoint_path", {"loss": 100}, False)
+    self.assertEqual(None, export_result)
+
+    export_result = exporter.export(estimator, export_dir_base,
+                                    "checkpoint_path", {"loss": 10}, False)
+    self.assertEqual("export_result_path", export_result)
+
   def test_garbage_collect_exports(self):
     export_dir_base = tempfile.mkdtemp()
     gfile.MkDir(export_dir_base)