From: A. Unique TensorFlower Date: Fri, 19 Jan 2018 02:54:45 +0000 (-0800) Subject: Support writing np.ndarray as tensors in Summary proto. X-Git-Tag: v1.6.0-rc0~182^2^2~2 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=e7cfe02588dc8517c23c8b4bf20ee3e6daac15ff;p=platform%2Fupstream%2Ftensorflow.git Support writing np.ndarray as tensors in Summary proto. PiperOrigin-RevId: 182474037 --- diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 2395c7e717..50c74add86 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -360,10 +360,23 @@ def _write_dict_to_summary(output_dir, dictionary, current_global_step): logging.warn('Skipping summary for %s, cannot parse string to Summary.', key) continue + elif isinstance(dictionary[key], np.ndarray): + value = summary_proto.value.add() + value.tag = key + value.node_name = key + tensor_proto = tensor_util.make_tensor_proto(dictionary[key]) + value.tensor.CopyFrom(tensor_proto) + logging.info( + 'Summary for np.ndarray is not visible in Tensorboard by default. ' + 'Consider using a Tensorboard plugin for visualization (see ' + 'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md ' # pylint:disable=line-too-long + 'for more information).' + ) else: logging.warn( 'Skipping summary for %s, must be a float, np.float32, np.int64, ' - 'np.int32 or int or a serialized string of Summary.', key) + 'np.int32 or int or np.ndarray or a serialized string of Summary.', + key) summary_writer.add_summary(summary_proto, current_global_step) summary_writer.flush() diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index 2a13a84627..5f682838b7 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -884,6 +884,35 @@ class EstimatorTest(test.TestCase): self.assertTrue('MSE' in output_values) self.assertTrue(output_values['MSE'].HasField('histo')) + def testSummaryWritingWithTensor(self): + + def _streaming_precition_mean_tensor(predictions, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + return metric_ops.streaming_mean_tensor( + predictions, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) + + est = estimator.Estimator(model_fn=linear_model_fn) + est.fit(input_fn=boston_input_fn, steps=200) + est.evaluate( + input_fn=boston_input_fn, + steps=200, + metrics={'PMT': _streaming_precition_mean_tensor}) + events = util_test.latest_events(est.model_dir + '/eval') + output_values = {} + for e in events: + if e.HasField('summary'): + for v in e.summary.value: + output_values[v.tag] = v + self.assertTrue('PMT' in output_values) + self.assertTrue(output_values['PMT'].HasField('tensor')) + def testLossInGraphCollection(self): class _LossCheckerHook(session_run_hook.SessionRunHook):