Add a utility to visualize object-based checkpoints
authorAllen Lavoie <allenl@google.com>
Fri, 20 Apr 2018 19:40:57 +0000 (12:40 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 20 Apr 2018 19:43:51 +0000 (12:43 -0700)
Useful for generating a warm fuzzy feeling that everything you think should be saved was saved, and for explaining what object-based checkpointing is. (Also useful on the former front will be a planned "assert that all of this Graph's trainable variables are accessible from object X" function.)

Somewhat hacky since it generates strings rather than using the pydot bindings (and so works without a pydot dependency).

PiperOrigin-RevId: 193708003

tensorflow/contrib/BUILD
tensorflow/contrib/checkpoint/__init__.py
tensorflow/contrib/checkpoint/python/BUILD
tensorflow/contrib/checkpoint/python/visualize.py [new file with mode: 0644]
tensorflow/contrib/checkpoint/python/visualize_test.py [new file with mode: 0644]

index 7e47516..d28392a 100644 (file)
@@ -25,6 +25,7 @@ py_library(
         "//tensorflow/contrib/batching:batch_py",
         "//tensorflow/contrib/bayesflow:bayesflow_py",
         "//tensorflow/contrib/boosted_trees:init_py",
+        "//tensorflow/contrib/checkpoint/python:checkpoint",
         "//tensorflow/contrib/cloud:cloud_py",
         "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
         "//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
index 70d7d2d..1192cc4 100644 (file)
@@ -16,6 +16,7 @@
 
 
 For creating and managing dependencies:
+@@dot_graph_from_checkpoint
 @@split_dependency
 """
 
@@ -24,6 +25,8 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
+from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
+
 from tensorflow.python.util.all_util import remove_undocumented
 
 remove_undocumented(module_name=__name__)
index d57b01a..a5681ff 100644 (file)
@@ -5,6 +5,15 @@ package(default_visibility = ["//tensorflow:internal"])
 load("//tensorflow:tensorflow.bzl", "py_test")
 
 py_library(
+    name = "checkpoint",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":split_dependency",
+        ":visualize",
+    ],
+)
+
+py_library(
     name = "split_dependency",
     srcs = ["split_dependency.py"],
     srcs_version = "PY2AND3",
@@ -27,3 +36,26 @@ py_test(
         "//tensorflow/python/eager:test",
     ],
 )
+
+py_library(
+    name = "visualize",
+    srcs = ["visualize.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        "//tensorflow/python:pywrap_tensorflow",
+    ],
+)
+
+py_test(
+    name = "visualize_test",
+    srcs = ["visualize_test.py"],
+    deps = [
+        ":visualize",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:resource_variable_ops",
+        "//tensorflow/python:training",
+        "//tensorflow/python/eager:test",
+    ],
+)
diff --git a/tensorflow/contrib/checkpoint/python/visualize.py b/tensorflow/contrib/checkpoint/python/visualize.py
new file mode 100644 (file)
index 0000000..86fbdb4
--- /dev/null
@@ -0,0 +1,111 @@
+"""Utilities for visualizing dependency graphs."""
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import checkpointable_object_graph_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.training import checkpointable
+
+
+def dot_graph_from_checkpoint(save_path):
+  r"""Visualizes an object-based checkpoint (from `tf.train.Checkpoint`).
+
+  Useful for inspecting checkpoints and debugging loading issues.
+
+  Example usage from Python (requires pydot):
+  ```python
+  import tensorflow as tf
+  import pydot
+
+  dot_string = tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt')
+  parsed, = pydot.graph_from_dot_data(dot_string)
+  parsed.write_svg('/tmp/tensorflow/visualized_checkpoint.svg')
+  ```
+
+  Example command line usage:
+  ```sh
+  python -c "import tensorflow as tf;\
+    print(tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt'))"\
+    | dot -Tsvg > /tmp/tensorflow/checkpoint_viz.svg
+  ```
+
+  Args:
+    save_path: The checkpoint prefix, as returned by `tf.train.Checkpoint.save`
+      or `tf.train.latest_checkpoint`.
+  Returns:
+    A graph in DOT format as a string.
+  """
+  reader = pywrap_tensorflow.NewCheckpointReader(save_path)
+  try:
+    object_graph_string = reader.get_tensor(
+        checkpointable.OBJECT_GRAPH_PROTO_KEY)
+  except errors_impl.NotFoundError:
+    raise ValueError(
+        ('The specified checkpoint "%s" does not appear to be object-based (it '
+         'is missing the key "%s"). Likely it was created with a name-based '
+         'saver and does not contain an object dependency graph.') % (
+             save_path, checkpointable.OBJECT_GRAPH_PROTO_KEY))
+  shape_map = reader.get_variable_to_shape_map()
+  dtype_map = reader.get_variable_to_dtype_map()
+  object_graph = (
+      checkpointable_object_graph_pb2.CheckpointableObjectGraph())
+  object_graph.ParseFromString(object_graph_string)
+  graph = 'digraph {\n'
+  def _escape(name):
+    return name.replace('"', '\\"')
+  slot_ids = set()
+  for node in object_graph.nodes:
+    for slot_reference in node.slot_variables:
+      slot_ids.add(slot_reference.slot_variable_node_id)
+  for node_id, node in enumerate(object_graph.nodes):
+    if (len(node.attributes) == 1
+        and node.attributes[0].name == checkpointable.VARIABLE_VALUE_KEY):
+      if node_id in slot_ids:
+        color = 'orange'
+        tooltip_prefix = 'Slot variable'
+      else:
+        color = 'blue'
+        tooltip_prefix = 'Variable'
+      attribute = node.attributes[0]
+      graph += ('N_%d [shape=point label="" color=%s width=.25'
+                ' tooltip="%s %s shape=%s %s"]\n') % (
+                    node_id,
+                    color,
+                    tooltip_prefix,
+                    _escape(attribute.full_name),
+                    shape_map[attribute.checkpoint_key],
+                    dtype_map[attribute.checkpoint_key].name)
+    elif node.slot_variables:
+      graph += ('N_%d [shape=point label="" width=.25 color=red,'
+                'tooltip="Optimizer"]\n') % node_id
+    else:
+      graph += 'N_%d [shape=point label="" width=.25]\n' % node_id
+    for reference in node.children:
+      graph += 'N_%d -> N_%d [label="%s"]\n' % (
+          node_id, reference.node_id, _escape(reference.local_name))
+    for slot_reference in node.slot_variables:
+      graph += 'N_%d -> N_%d [label="%s" style=dotted]\n' % (
+          node_id,
+          slot_reference.slot_variable_node_id,
+          _escape(slot_reference.slot_name))
+      graph += 'N_%d -> N_%d [style=dotted]\n' % (
+          slot_reference.original_variable_node_id,
+          slot_reference.slot_variable_node_id)
+  graph += '}\n'
+  return graph
diff --git a/tensorflow/contrib/checkpoint/python/visualize_test.py b/tensorflow/contrib/checkpoint/python/visualize_test.py
new file mode 100644 (file)
index 0000000..1d9ab78
--- /dev/null
@@ -0,0 +1,97 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import os
+
+from tensorflow.contrib.checkpoint.python import visualize
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.keras._impl.keras.engine import training
+from tensorflow.python.keras._impl.keras.layers import core
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training import adam
+from tensorflow.python.training import checkpointable_utils
+
+try:
+  import pydot  # pylint: disable=g-import-not-at-top
+except ImportError:
+  pydot = None
+
+
+class MyModel(training.Model):
+  """A concrete Model for testing."""
+
+  def __init__(self):
+    super(MyModel, self).__init__()
+    self._named_dense = core.Dense(1, use_bias=True)
+    self._second = core.Dense(1, use_bias=False)
+
+  def call(self, values):
+    ret = self._second(self._named_dense(values))
+    return ret
+
+
+class DotGraphTests(test.TestCase):
+
+  def testMakeDotGraph(self):
+    with context.eager_mode():
+      input_value = constant_op.constant([[3.]])
+      model = MyModel()
+      optimizer = adam.AdamOptimizer(0.001)
+      optimizer_step = resource_variable_ops.ResourceVariable(12)
+      save_checkpoint = checkpointable_utils.Checkpoint(
+          optimizer=optimizer, model=model, optimizer_step=optimizer_step)
+      optimizer.minimize(functools.partial(model, input_value))
+      checkpoint_directory = self.get_temp_dir()
+      checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
+      save_path = save_checkpoint.save(checkpoint_prefix)
+      prefix = save_checkpoint.save(save_path)
+
+    dot_graph_string = visualize.dot_graph_from_checkpoint(prefix)
+
+    # The remainder of this test is more-or-less optional since it's so
+    # dependent on pydot/platform/Python versions.
+    if pydot is None:
+      self.skipTest('pydot is required for the remainder of this test.')
+    try:
+      parsed, = pydot.graph_from_dot_data(dot_graph_string)
+    except NameError as e:
+      if "name 'dot_parser' is not defined" in str(e):
+        self.skipTest("pydot isn't working")
+      else:
+        raise
+    # Check that the graph isn't completely trivial
+    self.assertEqual(
+        '"model"',
+        parsed.obj_dict['edges'][('N_0', 'N_1')][0]['attributes']['label'])
+    image_path = os.path.join(self.get_temp_dir(), 'saved.svg')
+    try:
+      parsed.write_svg(image_path)
+    except Exception as e:  # pylint: disable=broad-except
+      # For some reason PyDot's "dot not available" error is an Exception, not
+      # something more specific.
+      if '"dot" not found in path' in str(e):
+        self.skipTest("pydot won't save SVGs (dot not available)")
+      else:
+        raise
+
+if __name__ == '__main__':
+  test.main()