Support writing np.ndarray as tensors in Summary proto.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 19 Jan 2018 02:54:45 +0000 (18:54 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 19 Jan 2018 02:58:17 +0000 (18:58 -0800)
PiperOrigin-RevId: 182474037

tensorflow/contrib/learn/python/learn/estimators/estimator.py
tensorflow/contrib/learn/python/learn/estimators/estimator_test.py

index 2395c7e7172c5b63db4900524a64c11d78b079b5..50c74add86fcf62c738e81426bfaf842fbac2b4e 100644 (file)
@@ -360,10 +360,23 @@ def _write_dict_to_summary(output_dir, dictionary, current_global_step):
         logging.warn('Skipping summary for %s, cannot parse string to Summary.',
                      key)
         continue
+    elif isinstance(dictionary[key], np.ndarray):
+      value = summary_proto.value.add()
+      value.tag = key
+      value.node_name = key
+      tensor_proto = tensor_util.make_tensor_proto(dictionary[key])
+      value.tensor.CopyFrom(tensor_proto)
+      logging.info(
+          'Summary for np.ndarray is not visible in Tensorboard by default. '
+          'Consider using a Tensorboard plugin for visualization (see '
+          'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md '  # pylint:disable=line-too-long
+          'for more information).'
+      )
     else:
       logging.warn(
           'Skipping summary for %s, must be a float, np.float32, np.int64, '
-          'np.int32 or int or a serialized string of Summary.', key)
+          'np.int32 or int or np.ndarray or a serialized string of Summary.',
+          key)
   summary_writer.add_summary(summary_proto, current_global_step)
   summary_writer.flush()
 
index 2a13a84627df35a68a4f04b25ab26ceecad0db0d..5f682838b7afadec7a54df782cb5b89ac6746659 100644 (file)
@@ -884,6 +884,35 @@ class EstimatorTest(test.TestCase):
     self.assertTrue('MSE' in output_values)
     self.assertTrue(output_values['MSE'].HasField('histo'))
 
+  def testSummaryWritingWithTensor(self):
+
+    def _streaming_precition_mean_tensor(predictions,
+                                         weights=None,
+                                         metrics_collections=None,
+                                         updates_collections=None,
+                                         name=None):
+      return metric_ops.streaming_mean_tensor(
+          predictions,
+          weights=weights,
+          metrics_collections=metrics_collections,
+          updates_collections=updates_collections,
+          name=name)
+
+    est = estimator.Estimator(model_fn=linear_model_fn)
+    est.fit(input_fn=boston_input_fn, steps=200)
+    est.evaluate(
+        input_fn=boston_input_fn,
+        steps=200,
+        metrics={'PMT': _streaming_precition_mean_tensor})
+    events = util_test.latest_events(est.model_dir + '/eval')
+    output_values = {}
+    for e in events:
+      if e.HasField('summary'):
+        for v in e.summary.value:
+          output_values[v.tag] = v
+    self.assertTrue('PMT' in output_values)
+    self.assertTrue(output_values['PMT'].HasField('tensor'))
+
   def testLossInGraphCollection(self):
 
     class _LossCheckerHook(session_run_hook.SessionRunHook):