logging.warn('Skipping summary for %s, cannot parse string to Summary.',
key)
continue
+ elif isinstance(dictionary[key], np.ndarray):
+ value = summary_proto.value.add()
+ value.tag = key
+ value.node_name = key
+ tensor_proto = tensor_util.make_tensor_proto(dictionary[key])
+ value.tensor.CopyFrom(tensor_proto)
+ logging.info(
+ 'Summary for np.ndarray is not visible in Tensorboard by default. '
+ 'Consider using a Tensorboard plugin for visualization (see '
+ 'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md ' # pylint:disable=line-too-long
+ 'for more information).'
+ )
else:
logging.warn(
'Skipping summary for %s, must be a float, np.float32, np.int64, '
- 'np.int32 or int or a serialized string of Summary.', key)
+ 'np.int32 or int or np.ndarray or a serialized string of Summary.',
+ key)
summary_writer.add_summary(summary_proto, current_global_step)
summary_writer.flush()
self.assertTrue('MSE' in output_values)
self.assertTrue(output_values['MSE'].HasField('histo'))
+ def testSummaryWritingWithTensor(self):
+
+ def _streaming_precition_mean_tensor(predictions,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
+ return metric_ops.streaming_mean_tensor(
+ predictions,
+ weights=weights,
+ metrics_collections=metrics_collections,
+ updates_collections=updates_collections,
+ name=name)
+
+ est = estimator.Estimator(model_fn=linear_model_fn)
+ est.fit(input_fn=boston_input_fn, steps=200)
+ est.evaluate(
+ input_fn=boston_input_fn,
+ steps=200,
+ metrics={'PMT': _streaming_precition_mean_tensor})
+ events = util_test.latest_events(est.model_dir + '/eval')
+ output_values = {}
+ for e in events:
+ if e.HasField('summary'):
+ for v in e.summary.value:
+ output_values[v.tag] = v
+ self.assertTrue('PMT' in output_values)
+ self.assertTrue(output_values['PMT'].HasField('tensor'))
+
def testLossInGraphCollection(self):
class _LossCheckerHook(session_run_hook.SessionRunHook):