"//tensorflow/contrib/batching:batch_py",
"//tensorflow/contrib/bayesflow:bayesflow_py",
"//tensorflow/contrib/boosted_trees:init_py",
+ "//tensorflow/contrib/checkpoint/python:checkpoint",
"//tensorflow/contrib/cloud:cloud_py",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
For creating and managing dependencies:
+@@dot_graph_from_checkpoint
@@split_dependency
"""
from __future__ import print_function
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
+from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
+
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(module_name=__name__)
load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
+ name = "checkpoint",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":split_dependency",
+ ":visualize",
+ ],
+)
+
+py_library(
name = "split_dependency",
srcs = ["split_dependency.py"],
srcs_version = "PY2AND3",
"//tensorflow/python/eager:test",
],
)
+
+py_library(
+ name = "visualize",
+ srcs = ["visualize.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/python:pywrap_tensorflow",
+ ],
+)
+
+py_test(
+ name = "visualize_test",
+ srcs = ["visualize_test.py"],
+ deps = [
+ ":visualize",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python/eager:test",
+ ],
+)
--- /dev/null
+"""Utilities for visualizing dependency graphs."""
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import checkpointable_object_graph_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.training import checkpointable
+
+
+def dot_graph_from_checkpoint(save_path):
+ r"""Visualizes an object-based checkpoint (from `tf.train.Checkpoint`).
+
+ Useful for inspecting checkpoints and debugging loading issues.
+
+ Example usage from Python (requires pydot):
+ ```python
+ import tensorflow as tf
+ import pydot
+
+ dot_string = tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt')
+ parsed, = pydot.graph_from_dot_data(dot_string)
+ parsed.write_svg('/tmp/tensorflow/visualized_checkpoint.svg')
+ ```
+
+ Example command line usage:
+ ```sh
+ python -c "import tensorflow as tf;\
+ print(tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt'))"\
+ | dot -Tsvg > /tmp/tensorflow/checkpoint_viz.svg
+ ```
+
+ Args:
+ save_path: The checkpoint prefix, as returned by `tf.train.Checkpoint.save`
+ or `tf.train.latest_checkpoint`.
+ Returns:
+ A graph in DOT format as a string.
+ """
+ reader = pywrap_tensorflow.NewCheckpointReader(save_path)
+ try:
+ object_graph_string = reader.get_tensor(
+ checkpointable.OBJECT_GRAPH_PROTO_KEY)
+ except errors_impl.NotFoundError:
+ raise ValueError(
+ ('The specified checkpoint "%s" does not appear to be object-based (it '
+ 'is missing the key "%s"). Likely it was created with a name-based '
+ 'saver and does not contain an object dependency graph.') % (
+ save_path, checkpointable.OBJECT_GRAPH_PROTO_KEY))
+ shape_map = reader.get_variable_to_shape_map()
+ dtype_map = reader.get_variable_to_dtype_map()
+ object_graph = (
+ checkpointable_object_graph_pb2.CheckpointableObjectGraph())
+ object_graph.ParseFromString(object_graph_string)
+ graph = 'digraph {\n'
+ def _escape(name):
+ return name.replace('"', '\\"')
+ slot_ids = set()
+ for node in object_graph.nodes:
+ for slot_reference in node.slot_variables:
+ slot_ids.add(slot_reference.slot_variable_node_id)
+ for node_id, node in enumerate(object_graph.nodes):
+ if (len(node.attributes) == 1
+ and node.attributes[0].name == checkpointable.VARIABLE_VALUE_KEY):
+ if node_id in slot_ids:
+ color = 'orange'
+ tooltip_prefix = 'Slot variable'
+ else:
+ color = 'blue'
+ tooltip_prefix = 'Variable'
+ attribute = node.attributes[0]
+ graph += ('N_%d [shape=point label="" color=%s width=.25'
+ ' tooltip="%s %s shape=%s %s"]\n') % (
+ node_id,
+ color,
+ tooltip_prefix,
+ _escape(attribute.full_name),
+ shape_map[attribute.checkpoint_key],
+ dtype_map[attribute.checkpoint_key].name)
+ elif node.slot_variables:
+ graph += ('N_%d [shape=point label="" width=.25 color=red,'
+ 'tooltip="Optimizer"]\n') % node_id
+ else:
+ graph += 'N_%d [shape=point label="" width=.25]\n' % node_id
+ for reference in node.children:
+ graph += 'N_%d -> N_%d [label="%s"]\n' % (
+ node_id, reference.node_id, _escape(reference.local_name))
+ for slot_reference in node.slot_variables:
+ graph += 'N_%d -> N_%d [label="%s" style=dotted]\n' % (
+ node_id,
+ slot_reference.slot_variable_node_id,
+ _escape(slot_reference.slot_name))
+ graph += 'N_%d -> N_%d [style=dotted]\n' % (
+ slot_reference.original_variable_node_id,
+ slot_reference.slot_variable_node_id)
+ graph += '}\n'
+ return graph
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import os
+
+from tensorflow.contrib.checkpoint.python import visualize
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.keras._impl.keras.engine import training
+from tensorflow.python.keras._impl.keras.layers import core
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training import adam
+from tensorflow.python.training import checkpointable_utils
+
+try:
+ import pydot # pylint: disable=g-import-not-at-top
+except ImportError:
+ pydot = None
+
+
+class MyModel(training.Model):
+ """A concrete Model for testing."""
+
+ def __init__(self):
+ super(MyModel, self).__init__()
+ self._named_dense = core.Dense(1, use_bias=True)
+ self._second = core.Dense(1, use_bias=False)
+
+ def call(self, values):
+ ret = self._second(self._named_dense(values))
+ return ret
+
+
+class DotGraphTests(test.TestCase):
+
+ def testMakeDotGraph(self):
+ with context.eager_mode():
+ input_value = constant_op.constant([[3.]])
+ model = MyModel()
+ optimizer = adam.AdamOptimizer(0.001)
+ optimizer_step = resource_variable_ops.ResourceVariable(12)
+ save_checkpoint = checkpointable_utils.Checkpoint(
+ optimizer=optimizer, model=model, optimizer_step=optimizer_step)
+ optimizer.minimize(functools.partial(model, input_value))
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
+ save_path = save_checkpoint.save(checkpoint_prefix)
+ prefix = save_checkpoint.save(save_path)
+
+ dot_graph_string = visualize.dot_graph_from_checkpoint(prefix)
+
+ # The remainder of this test is more-or-less optional since it's so
+ # dependent on pydot/platform/Python versions.
+ if pydot is None:
+ self.skipTest('pydot is required for the remainder of this test.')
+ try:
+ parsed, = pydot.graph_from_dot_data(dot_graph_string)
+ except NameError as e:
+ if "name 'dot_parser' is not defined" in str(e):
+ self.skipTest("pydot isn't working")
+ else:
+ raise
+ # Check that the graph isn't completely trivial
+ self.assertEqual(
+ '"model"',
+ parsed.obj_dict['edges'][('N_0', 'N_1')][0]['attributes']['label'])
+ image_path = os.path.join(self.get_temp_dir(), 'saved.svg')
+ try:
+ parsed.write_svg(image_path)
+ except Exception as e: # pylint: disable=broad-except
+ # For some reason PyDot's "dot not available" error is an Exception, not
+ # something more specific.
+ if '"dot" not found in path' in str(e):
+ self.skipTest("pydot won't save SVGs (dot not available)")
+ else:
+ raise
+
+if __name__ == '__main__':
+ test.main()