From 7b78417a00e6805557d530c1f1fcc8b2a44d6e2e Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Wed, 23 May 2018 10:43:28 -0700 Subject: [PATCH] Add a checkpointable list data structure Allows tracking of Layers and other checkpointable objects by number. Fixes #19250. PiperOrigin-RevId: 197749961 --- tensorflow/contrib/checkpoint/__init__.py | 6 +- tensorflow/contrib/checkpoint/python/BUILD | 16 +- .../python/examples/rnn_colorbot/rnn_colorbot.py | 10 +- .../eager/python/examples/rnn_ptb/rnn_ptb.py | 10 +- tensorflow/python/keras/BUILD | 1 + tensorflow/python/keras/engine/network.py | 6 +- tensorflow/python/training/checkpointable/BUILD | 39 +++- tensorflow/python/training/checkpointable/base.py | 8 +- .../training/checkpointable/data_structures.py | 218 +++++++++++++++++++++ .../checkpointable/data_structures_base.py | 27 +++ .../checkpointable/data_structures_test.py | 142 ++++++++++++++ 11 files changed, 454 insertions(+), 29 deletions(-) create mode 100644 tensorflow/python/training/checkpointable/data_structures.py create mode 100644 tensorflow/python/training/checkpointable/data_structures_base.py create mode 100644 tensorflow/python/training/checkpointable/data_structures_test.py diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index af8df72..bd0bc9e 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -18,11 +18,14 @@ Visualization and inspection: @@dot_graph_from_checkpoint @@object_metadata -Creating and managing dependencies: +Managing dependencies: @@Checkpointable @@CheckpointableObjectGraph @@NoDependency @@split_dependency + +Checkpointable data structures: +@@List @@UniqueNameTracker """ @@ -36,6 +39,7 @@ from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkp from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph from tensorflow.python.training.checkpointable.base import Checkpointable from tensorflow.python.training.checkpointable.base import NoDependency +from tensorflow.python.training.checkpointable.data_structures import List from tensorflow.python.training.checkpointable.util import object_metadata from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index 53f4e97..0b67619 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -11,6 +11,7 @@ py_library( ":containers", ":split_dependency", ":visualize", + "//tensorflow/python/training/checkpointable:data_structures", ], ) @@ -30,8 +31,8 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", "@six_archive//:six", ], ) @@ -44,6 +45,7 @@ py_library( deps = [ "//tensorflow/python:control_flow_ops", "//tensorflow/python:training", + "//tensorflow/python/training/checkpointable:base", ], ) @@ -55,8 +57,9 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", "//tensorflow/python/eager:test", + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", ], ) @@ -67,6 +70,8 @@ py_library( visibility = 
["//tensorflow:internal"], deps = [ "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", ], ) @@ -75,10 +80,13 @@ py_test( srcs = ["visualize_test.py"], deps = [ ":visualize", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python:constant_op", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", + "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "//tensorflow/python/keras:engine", + "//tensorflow/python/keras:layers", + "//tensorflow/python/training/checkpointable:util", ], ) diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 492adbe..5ee2176 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -152,7 +152,7 @@ class RNNColorbot(tf.keras.Model): self.label_dimension = label_dimension self.keep_prob = keep_prob - self.cells = self._add_cells( + self.cells = tf.contrib.checkpoint.List( [tf.nn.rnn_cell.BasicLSTMCell(size) for size in rnn_cell_sizes]) self.relu = layers.Dense( label_dimension, activation=tf.nn.relu, name="relu") @@ -204,14 +204,6 @@ class RNNColorbot(tf.keras.Model): hidden_states = tf.gather_nd(chars, indices) return self.relu(hidden_states) - def _add_cells(self, cells): - # "Magic" required for keras.Model classes to track all the variables in - # a list of layers.Layer objects. - # TODO(ashankar): Figure out API so user code doesn't have to do this. - for i, c in enumerate(cells): - setattr(self, "cell-%d" % i, c) - return cells - def loss(labels, predictions): """Computes mean squared loss.""" diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index 74701b2..c2340a2 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -50,7 +50,7 @@ class RNN(tf.keras.Model): def __init__(self, hidden_dim, num_layers, keep_ratio): super(RNN, self).__init__() self.keep_ratio = keep_ratio - self.cells = self._add_cells([ + self.cells = tf.contrib.checkpoint.List([ tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim) for _ in range(num_layers) ]) @@ -74,14 +74,6 @@ class RNN(tf.keras.Model): # tuple (output, output_states). return [input_seq] - def _add_cells(self, cells): - # "Magic" required for keras.Model classes to track all the variables in - # a list of Layer objects. - # TODO(ashankar): Figure out API so user code doesn't have to do this. 
- for i, c in enumerate(cells): - setattr(self, "cell-%d" % i, c) - return cells - class Embedding(layers.Layer): """An Embedding layer.""" diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 5d73069..fe40c9f 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -135,6 +135,7 @@ py_library( deps = [ ":backend", "//tensorflow/python/data", + "//tensorflow/python/training/checkpointable:data_structures_base", "@six_archive//:six", ], ) diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 4a0e16f..6e818ec 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -41,6 +41,7 @@ from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.keras.utils.layer_utils import print_summary as print_layer_summary from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import data_structures_base from tensorflow.python.training.checkpointable import util as checkpointable_utils from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect @@ -321,7 +322,10 @@ class Network(base_layer.Layer): no_dependency = isinstance(value, checkpointable.NoDependency) if no_dependency: value = value.value - if isinstance(value, (base_layer.Layer, Network)): + if isinstance(value, ( + base_layer.Layer, + Network, + data_structures_base.CheckpointableDataStructureBase)): try: is_graph_network = self._is_graph_network except AttributeError: diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD index a7ae6e5..87ba4dc 100644 --- a/tensorflow/python/training/checkpointable/BUILD +++ b/tensorflow/python/training/checkpointable/BUILD @@ -22,8 +22,9 @@ py_library( "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//tensorflow/python:io_ops_gen", - "//tensorflow/python:ops", + "//tensorflow/python:platform", "//tensorflow/python:saveable_object", "//tensorflow/python:util", "//tensorflow/python/eager:context", @@ -41,6 +42,42 @@ py_test( ) py_library( + name = "data_structures_base", + srcs = ["data_structures_base.py"], + srcs_version = "PY2AND3", + deps = [ + ":base", + ], +) + +py_library( + name = "data_structures", + srcs = ["data_structures.py"], + srcs_version = "PY2AND3", + deps = [ + ":base", + ":data_structures_base", + ], +) + +py_test( + name = "data_structures_test", + srcs = ["data_structures_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":data_structures", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:layers", + "//tensorflow/python:math_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "//tensorflow/python/keras:engine", + "//tensorflow/python/keras:layers", + ], +) + +py_library( name = "util", srcs = ["util.py"], srcs_version = "PY2AND3", diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py index e378f0e..cfe7259 100644 --- a/tensorflow/python/training/checkpointable/base.py +++ b/tensorflow/python/training/checkpointable/base.py @@ -591,11 +591,11 @@ class CheckpointableBase(object): self._unconditional_checkpoint_dependencies): if name == old_name: 
self._unconditional_checkpoint_dependencies[index] = new_reference - else: + elif current_object is None: self._unconditional_checkpoint_dependencies.append(new_reference) - - self._unconditional_dependency_names[name] = checkpointable - self._handle_deferred_dependencies(name=name, checkpointable=checkpointable) + self._unconditional_dependency_names[name] = checkpointable + self._handle_deferred_dependencies( + name=name, checkpointable=checkpointable) return checkpointable def _handle_deferred_dependencies(self, name, checkpointable): diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py new file mode 100644 index 0000000..b514f7b --- /dev/null +++ b/tensorflow/python/training/checkpointable/data_structures.py @@ -0,0 +1,218 @@ +"""Checkpointable data structures.""" +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.keras.engine import base_layer +from tensorflow.python.training.checkpointable import base as checkpointable_lib +from tensorflow.python.training.checkpointable import data_structures_base + + +# TODO(allenl): We could track regular Python data structures which get assigned +# to Checkpointable objects. Making this work with restore-on-create would be +# tricky; we'd need to re-create nested structures with our own wrapped objects +# on assignment to an attribute, and track the user's original structure to make +# sure they don't modify it except through the wrappers (since we could save the +# user's updated structure, but would have no way to support restore-on-create +# for those modifications). +# TODO(allenl): A dictionary data structure would be good too. +class CheckpointableDataStructure( + data_structures_base.CheckpointableDataStructureBase): + """Base class for data structures which contain checkpointable objects.""" + + def __init__(self): + self._layers = [] + self.trainable = True + + def _track_value(self, value, name): + """Add a dependency on `value`.""" + if isinstance(value, checkpointable_lib.CheckpointableBase): + self._track_checkpointable(value, name=name) + else: + raise ValueError( + ("Only checkpointable objects (such as Layers or Optimizers) may be " + "stored in a List object. Got %s, which does not inherit from " + "CheckpointableBase.") % (value,)) + if isinstance(value, ( + base_layer.Layer, + data_structures_base.CheckpointableDataStructureBase)): + if value not in self._layers: + self._layers.append(value) + if hasattr(value, "_use_resource_variables"): + # In subclassed models, legacy layers (tf.layers) must always use + # resource variables. 
+ value._use_resource_variables = True # pylint: disable=protected-access + + @property + def layers(self): + return self._layers + + @property + def trainable_weights(self): + if not self.trainable: + return [] + weights = [] + for layer in self.layers: + weights += layer.trainable_weights + return weights + + @property + def non_trainable_weights(self): + weights = [] + for layer in self.layers: + weights += layer.non_trainable_weights + if not self.trainable: + trainable_weights = [] + for layer in self.layers: + trainable_weights += layer.trainable_weights + return trainable_weights + weights + return weights + + @property + def weights(self): + return self.trainable_weights + self.non_trainable_weights + + @property + def variables(self): + return self.weights + + @property + def updates(self): + """Aggregate updates from any `Layer` instances.""" + # Updates and conditional losses are forwarded as-is rather than being + # filtered based on inputs, since this is just a container and won't ever + # have any inputs. + aggregated = [] + for layer in self.layers: + aggregated += layer.updates + return aggregated + + @property + def losses(self): + """Aggregate losses from any `Layer` instances.""" + aggregated = [] + for layer in self.layers: + aggregated += layer.losses + return aggregated + + def __hash__(self): + # Support object-identity hashing, so these structures can be used as keys + # in sets/dicts. + return id(self) + + def __eq__(self, other): + # Similar to Tensors, checkpointable data structures use object-identity + # equality to support set/dict membership. + return self is other + + +class List(CheckpointableDataStructure, collections.Sequence): + """An append-only sequence type which is checkpointable. + + Maintains checkpoint dependencies on its contents (which must also be + checkpointable), and forwards any `Layer` metadata such as updates and losses. + + Note that `List` is purely a container. It lets a `tf.keras.Model` or + other checkpointable object know about its contents, but does not call any + `Layer` instances which are added to it. To indicate a sequence of `Layer` + instances which should be called sequentially, use `tf.keras.Sequential`. + + Example usage: + ```python + class HasList(tf.keras.Model): + + def __init__(self): + super(HasList, self).__init__() + self.layer_list = tf.contrib.checkpoint.List([layers.Dense(3)]) + self.layer_list.append(layers.Dense(4)) + + def call(self, x): + aggregation = 0. + for l in self.layer_list: + x = l(x) + aggregation += tf.reduce_sum(x) + return aggregation + ``` + + This kind of wrapping is necessary because `Checkpointable` objects do not + (yet) deeply inspect regular Python data structures, so for example assigning + a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a + checkpoint dependency and does not add the `Layer` instance's weights to its + parent `Model`. + """ + + def __init__(self, *args, **kwargs): + """Construct a new sequence. 
Arguments are passed to `list()`.""" + super(List, self).__init__() + self._storage = list(*args, **kwargs) + for index, element in enumerate(self._storage): + self._track_value(element, name=self._name_element(index)) + + def _name_element(self, index): + return "%d" % (index,) + + def append(self, value): + """Add a new checkpointable value.""" + self._track_value(value, self._name_element(len(self._storage))) + self._storage.append(value) + + def extend(self, values): + """Add a sequence of checkpointable values.""" + for index_offset, value in enumerate(values): + self._track_value( + value, name=self._name_element(len(self._storage) + index_offset)) + self._storage.extend(values) + + def __iadd__(self, values): + self.extend(values) + return self + + def __add__(self, other): + if isinstance(other, List): + return List(self._storage + other._storage) # pylint: disable=protected-access + else: + return List(self._storage + other) + + def __getitem__(self, key): + return self._storage[key] + + def __len__(self): + return len(self._storage) + + def __repr__(self): + return "List(%s)" % (repr(self._storage),) + + @property + def updates(self): + """Aggregate updates from any `Layer` instances.""" + # Updates and conditional losses are forwarded as-is rather than being + # filtered based on inputs, since this is just a container and won't ever + # have any inputs. + aggregated = [] + for layer in self.layers: + aggregated += layer.updates + return aggregated + + @property + def losses(self): + """Aggregate losses from any `Layer` instances.""" + aggregated = [] + for layer in self.layers: + aggregated += layer.losses + return aggregated diff --git a/tensorflow/python/training/checkpointable/data_structures_base.py b/tensorflow/python/training/checkpointable/data_structures_base.py new file mode 100644 index 0000000..f1b2cf1 --- /dev/null +++ b/tensorflow/python/training/checkpointable/data_structures_base.py @@ -0,0 +1,27 @@ +"""A trivial base class to avoid circular imports for isinstance checks.""" +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.python.training.checkpointable import base as checkpointable_lib + + +class CheckpointableDataStructureBase(checkpointable_lib.CheckpointableBase): + """Base class for data structures which contain checkpointable objects.""" + + pass diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py new file mode 100644 index 0000000..6cabbea --- /dev/null +++ b/tensorflow/python/training/checkpointable/data_structures_test.py @@ -0,0 +1,142 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core +from tensorflow.python.keras.layers import normalization +from tensorflow.python.layers import core as non_keras_core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training.checkpointable import data_structures + + +class HasList(training.Model): + + def __init__(self): + super(HasList, self).__init__() + self.layer_list = data_structures.List([core.Dense(3)]) + self.layer_list.append(core.Dense(4)) + self.layer_list.extend( + [core.Dense(5), + core.Dense(6, kernel_regularizer=math_ops.reduce_sum)]) + self.layer_list += [ + core.Dense(7, bias_regularizer=math_ops.reduce_sum), + core.Dense(8) + ] + self.layer_list += ( + data_structures.List([core.Dense(9)]) + data_structures.List( + [core.Dense(10)])) + self.layer_list.extend( + data_structures.List( + list(sequence=[core.Dense(11)]) + [core.Dense(12)])) + self.layers_with_updates = data_structures.List( + sequence=(normalization.BatchNormalization(),)) + + def call(self, x): + aggregation = 0. 
+ for l in self.layer_list: + x = l(x) + aggregation += math_ops.reduce_sum(x) + bn, = self.layers_with_updates + return bn(x) / aggregation + + +class ListTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testTracking(self): + model = HasList() + output = model(array_ops.ones([32, 2])) + self.assertAllEqual([32, 12], output.shape) + self.assertEqual(2, len(model.layers)) + self.assertIs(model.layer_list, model.layers[0]) + self.assertEqual(10, len(model.layers[0].layers)) + for index in range(10): + self.assertEqual(3 + index, model.layers[0].layers[index].units) + self.assertEqual(2, len(model._checkpoint_dependencies)) + self.assertIs(model.layer_list, model._checkpoint_dependencies[0].ref) + self.assertIs(model.layers_with_updates, + model._checkpoint_dependencies[1].ref) + self.assertEqual( + 10, len(model._checkpoint_dependencies[0].ref._checkpoint_dependencies)) + self.evaluate([v.initializer for v in model.variables]) + self.evaluate(model.variables[0].assign([[1., 2., 3.], [4., 5., 6.]])) + save_path = os.path.join(self.get_temp_dir(), "ckpt") + model.save_weights(save_path) + self.evaluate(model.variables[0].assign(array_ops.zeros([2, 3]))) + model.load_weights(save_path) + self.assertAllEqual([[1., 2., 3.], [4., 5., 6.]], + self.evaluate(model.variables[0])) + + def testUpdatesForwarded(self): + with context.graph_mode(): + model = HasList() + model_input = array_ops.ones([32, 2]) + model(model_input) + self.assertGreater(len(model.layers_with_updates[0].updates), 0) + self.assertEqual(set(model.layers_with_updates[0].updates), + set(model.updates)) + + with context.eager_mode(): + model = HasList() + model_input = array_ops.ones([32, 2]) + model(model_input) + self.assertEqual(0, len(model.updates)) + + @test_util.run_in_graph_and_eager_modes() + def testLossesForwarded(self): + model = HasList() + model_input = array_ops.ones([32, 2]) + model(model_input) + self.assertEqual(2, len(model.losses)) + + def testNotCheckpointable(self): + class NotCheckpointable(object): + pass + + with self.assertRaises(ValueError): + data_structures.List([NotCheckpointable()]) + + def testCallNotImplemented(self): + with self.assertRaisesRegexp(TypeError, "not callable"): + data_structures.List()(1.) + + def testNoPop(self): + with self.assertRaises(AttributeError): + data_structures.List().pop() + + def testNesting(self): + with context.graph_mode(): + inner = data_structures.List() + outer = data_structures.List([inner]) + inner.append(non_keras_core.Dense(1)) + inner[0](array_ops.ones([2, 3])) + self.assertEqual(2, len(outer.variables)) + self.assertIsInstance( + outer.variables[0], + resource_variable_ops.ResourceVariable) + + +if __name__ == "__main__": + test.main() -- 2.7.4
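
Note for reviewers: below is a minimal usage sketch of the container this patch introduces, adapted from the `List` docstring above. It assumes a build that already includes this change (so `tf.contrib.checkpoint.List` is importable); the layer sizes and input shape are illustrative only, not part of the patch.

```python
import tensorflow as tf


class HasList(tf.keras.Model):
  """Tracks a variable number of layers via tf.contrib.checkpoint.List."""

  def __init__(self):
    super(HasList, self).__init__()
    # Assigning the List as an attribute creates checkpoint dependencies on
    # its contents and exposes their weights through this Model.
    self.layer_list = tf.contrib.checkpoint.List([tf.keras.layers.Dense(3)])
    self.layer_list.append(tf.keras.layers.Dense(4))  # append-only container

  def call(self, x):
    aggregation = 0.
    for l in self.layer_list:
      x = l(x)
      aggregation += tf.reduce_sum(x)
    return aggregation


model = HasList()
model(tf.ones([1, 2]))       # Build the layers so their variables exist.
print(len(model.variables))  # 4: kernel and bias for each Dense layer.
```

This is the pattern that replaces the `_add_cells` setattr workaround removed from the rnn_colorbot and rnn_ptb examples; the save/restore round-trip through `Model.save_weights`/`load_weights` is exercised in `data_structures_test.py`.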