Add RunMetadata logging to tf.train.ProfilerHook for Tensorboard Memeory/CPU usage...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Apr 2018 01:50:35 +0000 (18:50 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 01:52:44 +0000 (18:52 -0700)
PiperOrigin-RevId: 191832832

tensorflow/contrib/testing/python/framework/fake_summary_writer.py
tensorflow/python/training/basic_session_run_hooks.py
tensorflow/python/training/basic_session_run_hooks_test.py

index 15a415d..eac34af 100644 (file)
@@ -52,6 +52,7 @@ class FakeSummaryWriter(object):
     self._added_graphs = []
     self._added_meta_graphs = []
     self._added_session_logs = []
+    self._added_run_metadata = {}
 
   @property
   def summaries(self):
@@ -127,6 +128,11 @@ class FakeSummaryWriter(object):
     # pylint: disable=unused-argument
     self._added_session_logs.append(session_log)
 
+  def add_run_metadata(self, run_metadata, tag, global_step=None):
+    if (global_step is not None) and (global_step < 0):
+      raise ValueError('Invalid global_step %s.' % global_step)
+    self._added_run_metadata[tag] = run_metadata
+
   def flush(self):
     pass
 
index aae757b..094a9e8 100644 (file)
@@ -859,6 +859,7 @@ class ProfilerHook(session_run_hook.SessionRunHook):
           showing the sizes and lifetimes of tensors.
     """
     self._output_file = os.path.join(output_dir, "timeline-{}.json")
+    self._file_writer = SummaryWriterCache.get(output_dir)
     self._show_dataflow = show_dataflow
     self._show_memory = show_memory
     self._timer = SecondOrStepTimer(
@@ -889,6 +890,8 @@ class ProfilerHook(session_run_hook.SessionRunHook):
       self._save(global_step,
                  self._output_file.format(global_step),
                  run_values.run_metadata.step_stats)
+      self._file_writer.add_run_metadata(run_values.run_metadata,
+                                         "step_%d" % global_step)
 
     self._next_step = global_step + 1
 
index 2547661..f39a526 100644 (file)
@@ -1274,6 +1274,19 @@ class ProfilerHookTest(test.TestCase):
         sess.run(self.train_op)  # Saved.
         self.assertEqual(3, self._count_timeline_files())
 
+  def test_run_metadata_saves_in_first_step(self):
+    writer_cache.FileWriterCache.clear()
+    fake_summary_writer.FakeSummaryWriter.install()
+    fake_writer = writer_cache.FileWriterCache.get(self.output_dir)
+    with self.graph.as_default():
+      hook = basic_session_run_hooks.ProfilerHook(
+          save_secs=2, output_dir=self.output_dir)
+      with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
+        sess.run(self.train_op)  # Saved.
+        self.assertEqual(
+            list(fake_writer._added_run_metadata.keys()), ['step_1'])
+    fake_summary_writer.FakeSummaryWriter.uninstall()
+
 
 if __name__ == '__main__':
   test.main()