Checkpointable: Add UniqueNameTracker for managing dependencies on arbitrarily named...
authorAllen Lavoie <allenl@google.com>
Fri, 11 May 2018 21:45:36 +0000 (14:45 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 11 May 2018 21:48:25 +0000 (14:48 -0700)
Makes generating object-unique dependency names easier, which will hopefully discourage people from using Graph-global names with Checkpointable.

PiperOrigin-RevId: 196311633

tensorflow/contrib/checkpoint/__init__.py
tensorflow/contrib/checkpoint/python/BUILD
tensorflow/contrib/checkpoint/python/containers.py [new file with mode: 0644]
tensorflow/contrib/checkpoint/python/containers_test.py [new file with mode: 0644]

index e529b25..c5f7072 100644 (file)
 # ==============================================================================
 """Tools for working with object-based checkpoints.
 
-
-For creating and managing dependencies:
-@@CheckpointableObjectGraph
+Visualization and inspection:
 @@dot_graph_from_checkpoint
 @@object_metadata
+
+Creating and managing dependencies:
+@@Checkpointable
+@@CheckpointableObjectGraph
 @@NoDependency
 @@split_dependency
+@@UniqueNameTracker
 """
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
 from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
 from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
 from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
+from tensorflow.python.training.checkpointable import Checkpointable
 from tensorflow.python.training.checkpointable import NoDependency
 from tensorflow.python.training.checkpointable_utils import object_metadata
 
index a5681ff..cbb9852 100644 (file)
@@ -8,12 +8,35 @@ py_library(
     name = "checkpoint",
     srcs_version = "PY2AND3",
     deps = [
+        ":containers",
         ":split_dependency",
         ":visualize",
     ],
 )
 
 py_library(
+    name = "containers",
+    srcs = ["containers.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:internal"],
+    deps = ["//tensorflow/python:checkpointable"],
+)
+
+py_test(
+    name = "containers_test",
+    srcs = ["containers_test.py"],
+    deps = [
+        ":containers",
+        "//tensorflow/python:checkpointable",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:resource_variable_ops",
+        "//tensorflow/python:training",
+        "@six_archive//:six",
+    ],
+)
+
+py_library(
     name = "split_dependency",
     srcs = ["split_dependency.py"],
     srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py
new file mode 100644 (file)
index 0000000..82aa04e
--- /dev/null
@@ -0,0 +1,77 @@
+"""Checkpointable data structures."""
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.training import checkpointable as checkpointable_lib
+
+
+class UniqueNameTracker(checkpointable_lib.CheckpointableBase):
+  """Adds dependencies on checkpointable objects with name hints.
+
+  Useful for creating dependencies with locally unique names.
+
+  Example usage:
+  ```python
+  class SlotManager(tf.contrib.checkpoint.Checkpointable):
+
+    def __init__(self):
+      # Create a dependency named "slotdeps" on the container.
+      self.slotdeps = tf.contrib.checkpoint.UniqueNameTracker()
+      slotdeps = self.slotdeps
+      slots = []
+      slots.append(slotdeps.track(tfe.Variable(3.), "x"))  # Named "x"
+      slots.append(slotdeps.track(tfe.Variable(4.), "y"))
+      slots.append(slotdeps.track(tfe.Variable(5.), "x"))  # Named "x_1"
+  ```
+  """
+
+  def __init__(self):
+    self._maybe_initialize_checkpointable()
+    self._name_counts = {}
+
+  def track(self, checkpointable, base_name):
+    """Add a dependency on `checkpointable`.
+
+    Args:
+      checkpointable: An object to add a checkpoint dependency on.
+      base_name: A name hint, which is uniquified to determine the dependency
+        name.
+    Returns:
+      `checkpointable`, for chaining.
+    Raises:
+      ValueError: If `checkpointable` is not a checkpointable object.
+    """
+
+    if not isinstance(checkpointable, checkpointable_lib.CheckpointableBase):
+      raise ValueError(
+          ("Expected a checkpointable value, got %s which does not inherit "
+           "from CheckpointableBase.") % (checkpointable,))
+
+    def _format_name(prefix, number):
+      if number > 0:
+        return "%s_%d" % (prefix, number)
+      else:
+        return prefix
+
+    count = self._name_counts.get(base_name, 0)
+    candidate = _format_name(base_name, count)
+    while self._lookup_dependency(candidate) is not None:
+      count += 1
+      candidate = _format_name(base_name, count)
+    self._name_counts[base_name] = count + 1
+    return self._track_checkpointable(checkpointable, name=candidate)
diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py
new file mode 100644 (file)
index 0000000..15775f4
--- /dev/null
@@ -0,0 +1,100 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import six
+
+from tensorflow.contrib.checkpoint.python import containers
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpointable
+from tensorflow.python.training import checkpointable_utils
+from tensorflow.python.training.checkpointable_utils import object_metadata
+
+
+class UniqueNameTrackerTests(test.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testNames(self):
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+
+    x1 = resource_variable_ops.ResourceVariable(2.)
+    x2 = resource_variable_ops.ResourceVariable(3.)
+    x3 = resource_variable_ops.ResourceVariable(4.)
+    y = resource_variable_ops.ResourceVariable(5.)
+    slots = containers.UniqueNameTracker()
+    slots.track(x1, "x")
+    slots.track(x2, "x")
+    slots.track(x3, "x_1")
+    slots.track(y, "y")
+    self.evaluate((x1.initializer, x2.initializer, x3.initializer,
+                   y.initializer))
+    save_root = checkpointable_utils.Checkpoint(slots=slots)
+    save_path = save_root.save(checkpoint_prefix)
+
+    restore_slots = checkpointable.Checkpointable()
+    restore_root = checkpointable_utils.Checkpoint(
+        slots=restore_slots)
+    status = restore_root.restore(save_path)
+    restore_slots.x = resource_variable_ops.ResourceVariable(0.)
+    restore_slots.x_1 = resource_variable_ops.ResourceVariable(0.)
+    restore_slots.x_1_1 = resource_variable_ops.ResourceVariable(0.)
+    restore_slots.y = resource_variable_ops.ResourceVariable(0.)
+    status.assert_consumed().run_restore_ops()
+    self.assertEqual(2., self.evaluate(restore_slots.x))
+    self.assertEqual(3., self.evaluate(restore_slots.x_1))
+    self.assertEqual(4., self.evaluate(restore_slots.x_1_1))
+    self.assertEqual(5., self.evaluate(restore_slots.y))
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testExample(self):
+    class SlotManager(checkpointable.Checkpointable):
+
+      def __init__(self):
+        self.slotdeps = containers.UniqueNameTracker()
+        slotdeps = self.slotdeps
+        slots = []
+        slots.append(slotdeps.track(
+            resource_variable_ops.ResourceVariable(3.), "x"))
+        slots.append(slotdeps.track(
+            resource_variable_ops.ResourceVariable(4.), "y"))
+        slots.append(slotdeps.track(
+            resource_variable_ops.ResourceVariable(5.), "x"))
+        self.slots = slots
+
+    manager = SlotManager()
+    self.evaluate([v.initializer for v in manager.slots])
+    checkpoint = checkpointable_utils.Checkpoint(slot_manager=manager)
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+    save_path = checkpoint.save(checkpoint_prefix)
+    metadata = object_metadata(save_path)
+    dependency_names = []
+    for node in metadata.nodes:
+      for child in node.children:
+        dependency_names.append(child.local_name)
+    six.assertCountEqual(
+        self,
+        dependency_names,
+        ["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"])
+
+if __name__ == "__main__":
+  test.main()