From 538cab2870fd02b4d89e0534ad52573e67f16606 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 5 Apr 2018 18:50:35 -0700 Subject: [PATCH] Add RunMetadata logging to tf.train.ProfilerHook for TensorBoard Memory/CPU usage visualization PiperOrigin-RevId: 191832832 --- .../contrib/testing/python/framework/fake_summary_writer.py | 6 ++++++ tensorflow/python/training/basic_session_run_hooks.py | 3 +++ tensorflow/python/training/basic_session_run_hooks_test.py | 13 +++++++++++++ 3 files changed, 22 insertions(+) diff --git a/tensorflow/contrib/testing/python/framework/fake_summary_writer.py b/tensorflow/contrib/testing/python/framework/fake_summary_writer.py index 15a415d..eac34af 100644 --- a/tensorflow/contrib/testing/python/framework/fake_summary_writer.py +++ b/tensorflow/contrib/testing/python/framework/fake_summary_writer.py @@ -52,6 +52,7 @@ class FakeSummaryWriter(object): self._added_graphs = [] self._added_meta_graphs = [] self._added_session_logs = [] + self._added_run_metadata = {} @property def summaries(self): @@ -127,6 +128,11 @@ class FakeSummaryWriter(object): # pylint: disable=unused-argument self._added_session_logs.append(session_log) + def add_run_metadata(self, run_metadata, tag, global_step=None): + if (global_step is not None) and (global_step < 0): + raise ValueError('Invalid global_step %s.' % global_step) + self._added_run_metadata[tag] = run_metadata + def flush(self): pass diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index aae757b..094a9e8 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -859,6 +859,7 @@ class ProfilerHook(session_run_hook.SessionRunHook): showing the sizes and lifetimes of tensors. 
""" self._output_file = os.path.join(output_dir, "timeline-{}.json") + self._file_writer = SummaryWriterCache.get(output_dir) self._show_dataflow = show_dataflow self._show_memory = show_memory self._timer = SecondOrStepTimer( @@ -889,6 +890,8 @@ class ProfilerHook(session_run_hook.SessionRunHook): self._save(global_step, self._output_file.format(global_step), run_values.run_metadata.step_stats) + self._file_writer.add_run_metadata(run_values.run_metadata, + "step_%d" % global_step) self._next_step = global_step + 1 diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py index 2547661..f39a526 100644 --- a/tensorflow/python/training/basic_session_run_hooks_test.py +++ b/tensorflow/python/training/basic_session_run_hooks_test.py @@ -1274,6 +1274,19 @@ class ProfilerHookTest(test.TestCase): sess.run(self.train_op) # Saved. self.assertEqual(3, self._count_timeline_files()) + def test_run_metadata_saves_in_first_step(self): + writer_cache.FileWriterCache.clear() + fake_summary_writer.FakeSummaryWriter.install() + fake_writer = writer_cache.FileWriterCache.get(self.output_dir) + with self.graph.as_default(): + hook = basic_session_run_hooks.ProfilerHook( + save_secs=2, output_dir=self.output_dir) + with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: + sess.run(self.train_op) # Saved. + self.assertEqual( + list(fake_writer._added_run_metadata.keys()), ['step_1']) + fake_summary_writer.FakeSummaryWriter.uninstall() + if __name__ == '__main__': test.main() -- 2.7.4