Add a checkpointable list data structure
authorAllen Lavoie <allenl@google.com>
Wed, 23 May 2018 17:43:28 +0000 (10:43 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 23 May 2018 17:45:45 +0000 (10:45 -0700)
Allows tracking of Layers and other checkpointable objects by number.

Fixes #19250.

PiperOrigin-RevId: 197749961

tensorflow/contrib/checkpoint/__init__.py
tensorflow/contrib/checkpoint/python/BUILD
tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py
tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
tensorflow/python/keras/BUILD
tensorflow/python/keras/engine/network.py
tensorflow/python/training/checkpointable/BUILD
tensorflow/python/training/checkpointable/base.py
tensorflow/python/training/checkpointable/data_structures.py [new file with mode: 0644]
tensorflow/python/training/checkpointable/data_structures_base.py [new file with mode: 0644]
tensorflow/python/training/checkpointable/data_structures_test.py [new file with mode: 0644]

index af8df72..bd0bc9e 100644 (file)
@@ -18,11 +18,14 @@ Visualization and inspection:
 @@dot_graph_from_checkpoint
 @@object_metadata
 
-Creating and managing dependencies:
+Managing dependencies:
 @@Checkpointable
 @@CheckpointableObjectGraph
 @@NoDependency
 @@split_dependency
+
+Checkpointable data structures:
+@@List
 @@UniqueNameTracker
 """
 
@@ -36,6 +39,7 @@ from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkp
 from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
 from tensorflow.python.training.checkpointable.base import Checkpointable
 from tensorflow.python.training.checkpointable.base import NoDependency
+from tensorflow.python.training.checkpointable.data_structures import List
 from tensorflow.python.training.checkpointable.util import object_metadata
 
 from tensorflow.python.util.all_util import remove_undocumented
index 53f4e97..0b67619 100644 (file)
@@ -11,6 +11,7 @@ py_library(
         ":containers",
         ":split_dependency",
         ":visualize",
+        "//tensorflow/python/training/checkpointable:data_structures",
     ],
 )
 
@@ -30,8 +31,8 @@ py_test(
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:resource_variable_ops",
-        "//tensorflow/python:training",
         "//tensorflow/python/training/checkpointable:base",
+        "//tensorflow/python/training/checkpointable:util",
         "@six_archive//:six",
     ],
 )
@@ -44,6 +45,7 @@ py_library(
     deps = [
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:training",
+        "//tensorflow/python/training/checkpointable:base",
     ],
 )
 
@@ -55,8 +57,9 @@ py_test(
         "//tensorflow/python:array_ops",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:resource_variable_ops",
-        "//tensorflow/python:training",
         "//tensorflow/python/eager:test",
+        "//tensorflow/python/training/checkpointable:base",
+        "//tensorflow/python/training/checkpointable:util",
     ],
 )
 
@@ -67,6 +70,8 @@ py_library(
     visibility = ["//tensorflow:internal"],
     deps = [
         "//tensorflow/python:pywrap_tensorflow",
+        "//tensorflow/python/training/checkpointable:base",
+        "//tensorflow/python/training/checkpointable:util",
     ],
 )
 
@@ -75,10 +80,13 @@ py_test(
     srcs = ["visualize_test.py"],
     deps = [
         ":visualize",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:constant_op",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:training",
+        "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
+        "//tensorflow/python/keras:engine",
+        "//tensorflow/python/keras:layers",
+        "//tensorflow/python/training/checkpointable:util",
     ],
 )
index 492adbe..5ee2176 100644 (file)
@@ -152,7 +152,7 @@ class RNNColorbot(tf.keras.Model):
     self.label_dimension = label_dimension
     self.keep_prob = keep_prob
 
-    self.cells = self._add_cells(
+    self.cells = tf.contrib.checkpoint.List(
         [tf.nn.rnn_cell.BasicLSTMCell(size) for size in rnn_cell_sizes])
     self.relu = layers.Dense(
         label_dimension, activation=tf.nn.relu, name="relu")
@@ -204,14 +204,6 @@ class RNNColorbot(tf.keras.Model):
     hidden_states = tf.gather_nd(chars, indices)
     return self.relu(hidden_states)
 
-  def _add_cells(self, cells):
-    # "Magic" required for keras.Model classes to track all the variables in
-    # a list of layers.Layer objects.
-    # TODO(ashankar): Figure out API so user code doesn't have to do this.
-    for i, c in enumerate(cells):
-      setattr(self, "cell-%d" % i, c)
-    return cells
-
 
 def loss(labels, predictions):
   """Computes mean squared loss."""
index 74701b2..c2340a2 100644 (file)
@@ -50,7 +50,7 @@ class RNN(tf.keras.Model):
   def __init__(self, hidden_dim, num_layers, keep_ratio):
     super(RNN, self).__init__()
     self.keep_ratio = keep_ratio
-    self.cells = self._add_cells([
+    self.cells = tf.contrib.checkpoint.List([
         tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim)
         for _ in range(num_layers)
     ])
@@ -74,14 +74,6 @@ class RNN(tf.keras.Model):
     # tuple (output, output_states).
     return [input_seq]
 
-  def _add_cells(self, cells):
-    # "Magic" required for keras.Model classes to track all the variables in
-    # a list of Layer objects.
-    # TODO(ashankar): Figure out API so user code doesn't have to do this.
-    for i, c in enumerate(cells):
-      setattr(self, "cell-%d" % i, c)
-    return cells
-
 
 class Embedding(layers.Layer):
   """An Embedding layer."""
index 5d73069..fe40c9f 100755 (executable)
@@ -135,6 +135,7 @@ py_library(
     deps = [
         ":backend",
         "//tensorflow/python/data",
+        "//tensorflow/python/training/checkpointable:data_structures_base",
         "@six_archive//:six",
     ],
 )
index 4a0e16f..6e818ec 100644 (file)
@@ -41,6 +41,7 @@ from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
 from tensorflow.python.keras.utils.layer_utils import print_summary as print_layer_summary
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import data_structures_base
 from tensorflow.python.training.checkpointable import util as checkpointable_utils
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_inspect
@@ -321,7 +322,10 @@ class Network(base_layer.Layer):
     no_dependency = isinstance(value, checkpointable.NoDependency)
     if no_dependency:
       value = value.value
-    if isinstance(value, (base_layer.Layer, Network)):
+    if isinstance(value, (
+        base_layer.Layer,
+        Network,
+        data_structures_base.CheckpointableDataStructureBase)):
       try:
         is_graph_network = self._is_graph_network
       except AttributeError:
index a7ae6e5..87ba4dc 100644 (file)
@@ -22,8 +22,9 @@ py_library(
         "//tensorflow/python:constant_op",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
         "//tensorflow/python:io_ops_gen",
-        "//tensorflow/python:ops",
+        "//tensorflow/python:platform",
         "//tensorflow/python:saveable_object",
         "//tensorflow/python:util",
         "//tensorflow/python/eager:context",
@@ -41,6 +42,42 @@ py_test(
 )
 
 py_library(
+    name = "data_structures_base",
+    srcs = ["data_structures_base.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":base",
+    ],
+)
+
+py_library(
+    name = "data_structures",
+    srcs = ["data_structures.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":base",
+        ":data_structures_base",
+    ],
+)
+
+py_test(
+    name = "data_structures_test",
+    srcs = ["data_structures_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":data_structures",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:layers",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:test",
+        "//tensorflow/python/keras:engine",
+        "//tensorflow/python/keras:layers",
+    ],
+)
+
+py_library(
     name = "util",
     srcs = ["util.py"],
     srcs_version = "PY2AND3",
index e378f0e..cfe7259 100644 (file)
@@ -591,11 +591,11 @@ class CheckpointableBase(object):
           self._unconditional_checkpoint_dependencies):
         if name == old_name:
           self._unconditional_checkpoint_dependencies[index] = new_reference
-    else:
+    elif current_object is None:
       self._unconditional_checkpoint_dependencies.append(new_reference)
-
-    self._unconditional_dependency_names[name] = checkpointable
-    self._handle_deferred_dependencies(name=name, checkpointable=checkpointable)
+      self._unconditional_dependency_names[name] = checkpointable
+      self._handle_deferred_dependencies(
+          name=name, checkpointable=checkpointable)
     return checkpointable
 
   def _handle_deferred_dependencies(self, name, checkpointable):
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
new file mode 100644 (file)
index 0000000..b514f7b
--- /dev/null
@@ -0,0 +1,218 @@
+"""Checkpointable data structures."""
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.training.checkpointable import base as checkpointable_lib
+from tensorflow.python.training.checkpointable import data_structures_base
+
+
+# TODO(allenl): We could track regular Python data structures which get assigned
+# to Checkpointable objects. Making this work with restore-on-create would be
+# tricky; we'd need to re-create nested structures with our own wrapped objects
+# on assignment to an attribute, and track the user's original structure to make
+# sure they don't modify it except through the wrappers (since we could save the
+# user's updated structure, but would have no way to support restore-on-create
+# for those modifications).
+# TODO(allenl): A dictionary data structure would be good too.
+class CheckpointableDataStructure(
+    data_structures_base.CheckpointableDataStructureBase):
+  """Base class for data structures which contain checkpointable objects."""
+
+  def __init__(self):
+    self._layers = []
+    self.trainable = True
+
+  def _track_value(self, value, name):
+    """Add a dependency on `value`."""
+    if isinstance(value, checkpointable_lib.CheckpointableBase):
+      self._track_checkpointable(value, name=name)
+    else:
+      raise ValueError(
+          ("Only checkpointable objects (such as Layers or Optimizers) may be "
+           "stored in a List object. Got %s, which does not inherit from "
+           "CheckpointableBase.") % (value,))
+    if isinstance(value, (
+        base_layer.Layer,
+        data_structures_base.CheckpointableDataStructureBase)):
+      if value not in self._layers:
+        self._layers.append(value)
+        if hasattr(value, "_use_resource_variables"):
+          # In subclassed models, legacy layers (tf.layers) must always use
+          # resource variables.
+          value._use_resource_variables = True  # pylint: disable=protected-access
+
+  @property
+  def layers(self):
+    return self._layers
+
+  @property
+  def trainable_weights(self):
+    if not self.trainable:
+      return []
+    weights = []
+    for layer in self.layers:
+      weights += layer.trainable_weights
+    return weights
+
+  @property
+  def non_trainable_weights(self):
+    weights = []
+    for layer in self.layers:
+      weights += layer.non_trainable_weights
+    if not self.trainable:
+      trainable_weights = []
+      for layer in self.layers:
+        trainable_weights += layer.trainable_weights
+      return trainable_weights + weights
+    return weights
+
+  @property
+  def weights(self):
+    return self.trainable_weights + self.non_trainable_weights
+
+  @property
+  def variables(self):
+    return self.weights
+
+  @property
+  def updates(self):
+    """Aggregate updates from any `Layer` instances."""
+    # Updates and conditional losses are forwarded as-is rather than being
+    # filtered based on inputs, since this is just a container and won't ever
+    # have any inputs.
+    aggregated = []
+    for layer in self.layers:
+      aggregated += layer.updates
+    return aggregated
+
+  @property
+  def losses(self):
+    """Aggregate losses from any `Layer` instances."""
+    aggregated = []
+    for layer in self.layers:
+      aggregated += layer.losses
+    return aggregated
+
+  def __hash__(self):
+    # Support object-identity hashing, so these structures can be used as keys
+    # in sets/dicts.
+    return id(self)
+
+  def __eq__(self, other):
+    # Similar to Tensors, checkpointable data structures use object-identity
+    # equality to support set/dict membership.
+    return self is other
+
+
+class List(CheckpointableDataStructure, collections.Sequence):
+  """An append-only sequence type which is checkpointable.
+
+  Maintains checkpoint dependencies on its contents (which must also be
+  checkpointable), and forwards any `Layer` metadata such as updates and losses.
+
+  Note that `List` is purely a container. It lets a `tf.keras.Model` or
+  other checkpointable object know about its contents, but does not call any
+  `Layer` instances which are added to it. To indicate a sequence of `Layer`
+  instances which should be called sequentially, use `tf.keras.Sequential`.
+
+  Example usage:
+  ```python
+  class HasList(tf.keras.Model):
+
+    def __init__(self):
+      super(HasList, self).__init__()
+      self.layer_list = tf.contrib.checkpoint.List([layers.Dense(3)])
+      self.layer_list.append(layers.Dense(4))
+
+    def call(self, x):
+      aggregation = 0.
+      for l in self.layer_list:
+        x = l(x)
+        aggregation += tf.reduce_sum(x)
+      return aggregation
+  ```
+
+  This kind of wrapping is necessary because `Checkpointable` objects do not
+  (yet) deeply inspect regular Python data structures, so for example assigning
+  a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a
+  checkpoint dependency and does not add the `Layer` instance's weights to its
+  parent `Model`.
+  """
+
+  def __init__(self, *args, **kwargs):
+    """Construct a new sequence. Arguments are passed to `list()`."""
+    super(List, self).__init__()
+    self._storage = list(*args, **kwargs)
+    for index, element in enumerate(self._storage):
+      self._track_value(element, name=self._name_element(index))
+
+  def _name_element(self, index):
+    return "%d" % (index,)
+
+  def append(self, value):
+    """Add a new checkpointable value."""
+    self._track_value(value, self._name_element(len(self._storage)))
+    self._storage.append(value)
+
+  def extend(self, values):
+    """Add a sequence of checkpointable values."""
+    for index_offset, value in enumerate(values):
+      self._track_value(
+          value, name=self._name_element(len(self._storage) + index_offset))
+    self._storage.extend(values)
+
+  def __iadd__(self, values):
+    self.extend(values)
+    return self
+
+  def __add__(self, other):
+    if isinstance(other, List):
+      return List(self._storage + other._storage)  # pylint: disable=protected-access
+    else:
+      return List(self._storage + other)
+
+  def __getitem__(self, key):
+    return self._storage[key]
+
+  def __len__(self):
+    return len(self._storage)
+
+  def __repr__(self):
+    return "List(%s)" % (repr(self._storage),)
+
+  @property
+  def updates(self):
+    """Aggregate updates from any `Layer` instances."""
+    # Updates and conditional losses are forwarded as-is rather than being
+    # filtered based on inputs, since this is just a container and won't ever
+    # have any inputs.
+    aggregated = []
+    for layer in self.layers:
+      aggregated += layer.updates
+    return aggregated
+
+  @property
+  def losses(self):
+    """Aggregate losses from any `Layer` instances."""
+    aggregated = []
+    for layer in self.layers:
+      aggregated += layer.losses
+    return aggregated
diff --git a/tensorflow/python/training/checkpointable/data_structures_base.py b/tensorflow/python/training/checkpointable/data_structures_base.py
new file mode 100644 (file)
index 0000000..f1b2cf1
--- /dev/null
@@ -0,0 +1,27 @@
+"""A trivial base class to avoid circular imports for isinstance checks."""
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.training.checkpointable import base as checkpointable_lib
+
+
+class CheckpointableDataStructureBase(checkpointable_lib.CheckpointableBase):
+  """Base class for data structures which contain checkpointable objects."""
+
+  pass
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py
new file mode 100644 (file)
index 0000000..6cabbea
--- /dev/null
@@ -0,0 +1,142 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.layers import core
+from tensorflow.python.keras.layers import normalization
+from tensorflow.python.layers import core as non_keras_core
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training.checkpointable import data_structures
+
+
+class HasList(training.Model):
+
+  def __init__(self):
+    super(HasList, self).__init__()
+    self.layer_list = data_structures.List([core.Dense(3)])
+    self.layer_list.append(core.Dense(4))
+    self.layer_list.extend(
+        [core.Dense(5),
+         core.Dense(6, kernel_regularizer=math_ops.reduce_sum)])
+    self.layer_list += [
+        core.Dense(7, bias_regularizer=math_ops.reduce_sum),
+        core.Dense(8)
+    ]
+    self.layer_list += (
+        data_structures.List([core.Dense(9)]) + data_structures.List(
+            [core.Dense(10)]))
+    self.layer_list.extend(
+        data_structures.List(
+            list(sequence=[core.Dense(11)]) + [core.Dense(12)]))
+    self.layers_with_updates = data_structures.List(
+        sequence=(normalization.BatchNormalization(),))
+
+  def call(self, x):
+    aggregation = 0.
+    for l in self.layer_list:
+      x = l(x)
+      aggregation += math_ops.reduce_sum(x)
+    bn, = self.layers_with_updates
+    return bn(x) / aggregation
+
+
+class ListTests(test.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testTracking(self):
+    model = HasList()
+    output = model(array_ops.ones([32, 2]))
+    self.assertAllEqual([32, 12], output.shape)
+    self.assertEqual(2, len(model.layers))
+    self.assertIs(model.layer_list, model.layers[0])
+    self.assertEqual(10, len(model.layers[0].layers))
+    for index in range(10):
+      self.assertEqual(3 + index, model.layers[0].layers[index].units)
+    self.assertEqual(2, len(model._checkpoint_dependencies))
+    self.assertIs(model.layer_list, model._checkpoint_dependencies[0].ref)
+    self.assertIs(model.layers_with_updates,
+                  model._checkpoint_dependencies[1].ref)
+    self.assertEqual(
+        10, len(model._checkpoint_dependencies[0].ref._checkpoint_dependencies))
+    self.evaluate([v.initializer for v in model.variables])
+    self.evaluate(model.variables[0].assign([[1., 2., 3.], [4., 5., 6.]]))
+    save_path = os.path.join(self.get_temp_dir(), "ckpt")
+    model.save_weights(save_path)
+    self.evaluate(model.variables[0].assign(array_ops.zeros([2, 3])))
+    model.load_weights(save_path)
+    self.assertAllEqual([[1., 2., 3.], [4., 5., 6.]],
+                        self.evaluate(model.variables[0]))
+
+  def testUpdatesForwarded(self):
+    with context.graph_mode():
+      model = HasList()
+      model_input = array_ops.ones([32, 2])
+      model(model_input)
+      self.assertGreater(len(model.layers_with_updates[0].updates), 0)
+      self.assertEqual(set(model.layers_with_updates[0].updates),
+                       set(model.updates))
+
+    with context.eager_mode():
+      model = HasList()
+      model_input = array_ops.ones([32, 2])
+      model(model_input)
+      self.assertEqual(0, len(model.updates))
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testLossesForwarded(self):
+    model = HasList()
+    model_input = array_ops.ones([32, 2])
+    model(model_input)
+    self.assertEqual(2, len(model.losses))
+
+  def testNotCheckpointable(self):
+    class NotCheckpointable(object):
+      pass
+
+    with self.assertRaises(ValueError):
+      data_structures.List([NotCheckpointable()])
+
+  def testCallNotImplemented(self):
+    with self.assertRaisesRegexp(TypeError, "not callable"):
+      data_structures.List()(1.)
+
+  def testNoPop(self):
+    with self.assertRaises(AttributeError):
+      data_structures.List().pop()
+
+  def testNesting(self):
+    with context.graph_mode():
+      inner = data_structures.List()
+      outer = data_structures.List([inner])
+      inner.append(non_keras_core.Dense(1))
+      inner[0](array_ops.ones([2, 3]))
+      self.assertEqual(2, len(outer.variables))
+      self.assertIsInstance(
+          outer.variables[0],
+          resource_variable_ops.ResourceVariable)
+
+
+if __name__ == "__main__":
+  test.main()