@@dot_graph_from_checkpoint
@@object_metadata
-Creating and managing dependencies:
+Managing dependencies:
@@Checkpointable
@@CheckpointableObjectGraph
@@NoDependency
@@split_dependency
+
+Checkpointable data structures:
+@@List
@@UniqueNameTracker
"""
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
from tensorflow.python.training.checkpointable.base import Checkpointable
from tensorflow.python.training.checkpointable.base import NoDependency
+from tensorflow.python.training.checkpointable.data_structures import List
from tensorflow.python.training.checkpointable.util import object_metadata
from tensorflow.python.util.all_util import remove_undocumented
":containers",
":split_dependency",
":visualize",
+ "//tensorflow/python/training/checkpointable:data_structures",
],
)
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:training",
"//tensorflow/python/training/checkpointable:base",
+ "//tensorflow/python/training/checkpointable:util",
"@six_archive//:six",
],
)
deps = [
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:training",
+ "//tensorflow/python/training/checkpointable:base",
],
)
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:training",
"//tensorflow/python/eager:test",
+ "//tensorflow/python/training/checkpointable:base",
+ "//tensorflow/python/training/checkpointable:util",
],
)
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:pywrap_tensorflow",
+ "//tensorflow/python/training/checkpointable:base",
+ "//tensorflow/python/training/checkpointable:util",
],
)
srcs = ["visualize_test.py"],
deps = [
":visualize",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
+ "//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
+ "//tensorflow/python/keras:engine",
+ "//tensorflow/python/keras:layers",
+ "//tensorflow/python/training/checkpointable:util",
],
)
self.label_dimension = label_dimension
self.keep_prob = keep_prob
- self.cells = self._add_cells(
+ self.cells = tf.contrib.checkpoint.List(
[tf.nn.rnn_cell.BasicLSTMCell(size) for size in rnn_cell_sizes])
self.relu = layers.Dense(
label_dimension, activation=tf.nn.relu, name="relu")
hidden_states = tf.gather_nd(chars, indices)
return self.relu(hidden_states)
- def _add_cells(self, cells):
- # "Magic" required for keras.Model classes to track all the variables in
- # a list of layers.Layer objects.
- # TODO(ashankar): Figure out API so user code doesn't have to do this.
- for i, c in enumerate(cells):
- setattr(self, "cell-%d" % i, c)
- return cells
-
def loss(labels, predictions):
"""Computes mean squared loss."""
def __init__(self, hidden_dim, num_layers, keep_ratio):
super(RNN, self).__init__()
self.keep_ratio = keep_ratio
- self.cells = self._add_cells([
+ self.cells = tf.contrib.checkpoint.List([
tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim)
for _ in range(num_layers)
])
# tuple (output, output_states).
return [input_seq]
- def _add_cells(self, cells):
- # "Magic" required for keras.Model classes to track all the variables in
- # a list of Layer objects.
- # TODO(ashankar): Figure out API so user code doesn't have to do this.
- for i, c in enumerate(cells):
- setattr(self, "cell-%d" % i, c)
- return cells
-
class Embedding(layers.Layer):
"""An Embedding layer."""
deps = [
":backend",
"//tensorflow/python/data",
+ "//tensorflow/python/training/checkpointable:data_structures_base",
"@six_archive//:six",
],
)
from tensorflow.python.keras.utils.layer_utils import print_summary as print_layer_summary
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import data_structures_base
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
no_dependency = isinstance(value, checkpointable.NoDependency)
if no_dependency:
value = value.value
- if isinstance(value, (base_layer.Layer, Network)):
+ if isinstance(value, (
+ base_layer.Layer,
+ Network,
+ data_structures_base.CheckpointableDataStructureBase)):
try:
is_graph_network = self._is_graph_network
except AttributeError:
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:io_ops_gen",
- "//tensorflow/python:ops",
+ "//tensorflow/python:platform",
"//tensorflow/python:saveable_object",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
)
py_library(
+ name = "data_structures_base",
+ srcs = ["data_structures_base.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":base",
+ ],
+)
+
+py_library(
+ name = "data_structures",
+ srcs = ["data_structures.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":base",
+ ":data_structures_base",
+ ],
+)
+
+py_test(
+ name = "data_structures_test",
+ srcs = ["data_structures_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":data_structures",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python/keras:engine",
+ "//tensorflow/python/keras:layers",
+ ],
+)
+
+py_library(
name = "util",
srcs = ["util.py"],
srcs_version = "PY2AND3",
self._unconditional_checkpoint_dependencies):
if name == old_name:
self._unconditional_checkpoint_dependencies[index] = new_reference
- else:
+ elif current_object is None:
self._unconditional_checkpoint_dependencies.append(new_reference)
-
- self._unconditional_dependency_names[name] = checkpointable
- self._handle_deferred_dependencies(name=name, checkpointable=checkpointable)
+ self._unconditional_dependency_names[name] = checkpointable
+ self._handle_deferred_dependencies(
+ name=name, checkpointable=checkpointable)
return checkpointable
def _handle_deferred_dependencies(self, name, checkpointable):
--- /dev/null
+"""Checkpointable data structures."""
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.training.checkpointable import base as checkpointable_lib
+from tensorflow.python.training.checkpointable import data_structures_base
+
+
+# TODO(allenl): We could track regular Python data structures which get assigned
+# to Checkpointable objects. Making this work with restore-on-create would be
+# tricky; we'd need to re-create nested structures with our own wrapped objects
+# on assignment to an attribute, and track the user's original structure to make
+# sure they don't modify it except through the wrappers (since we could save the
+# user's updated structure, but would have no way to support restore-on-create
+# for those modifications).
+# TODO(allenl): A dictionary data structure would be good too.
+class CheckpointableDataStructure(
+ data_structures_base.CheckpointableDataStructureBase):
+ """Base class for data structures which contain checkpointable objects."""
+
+ def __init__(self):
+ self._layers = []
+ self.trainable = True
+
+ def _track_value(self, value, name):
+ """Add a dependency on `value`."""
+ if isinstance(value, checkpointable_lib.CheckpointableBase):
+ self._track_checkpointable(value, name=name)
+ else:
+ raise ValueError(
+ ("Only checkpointable objects (such as Layers or Optimizers) may be "
+ "stored in a List object. Got %s, which does not inherit from "
+ "CheckpointableBase.") % (value,))
+ if isinstance(value, (
+ base_layer.Layer,
+ data_structures_base.CheckpointableDataStructureBase)):
+ if value not in self._layers:
+ self._layers.append(value)
+ if hasattr(value, "_use_resource_variables"):
+ # In subclassed models, legacy layers (tf.layers) must always use
+ # resource variables.
+ value._use_resource_variables = True # pylint: disable=protected-access
+
+ @property
+ def layers(self):
+ return self._layers
+
+ @property
+ def trainable_weights(self):
+ if not self.trainable:
+ return []
+ weights = []
+ for layer in self.layers:
+ weights += layer.trainable_weights
+ return weights
+
+ @property
+ def non_trainable_weights(self):
+ weights = []
+ for layer in self.layers:
+ weights += layer.non_trainable_weights
+ if not self.trainable:
+ trainable_weights = []
+ for layer in self.layers:
+ trainable_weights += layer.trainable_weights
+ return trainable_weights + weights
+ return weights
+
+ @property
+ def weights(self):
+ return self.trainable_weights + self.non_trainable_weights
+
+ @property
+ def variables(self):
+ return self.weights
+
+ @property
+ def updates(self):
+ """Aggregate updates from any `Layer` instances."""
+ # Updates and conditional losses are forwarded as-is rather than being
+ # filtered based on inputs, since this is just a container and won't ever
+ # have any inputs.
+ aggregated = []
+ for layer in self.layers:
+ aggregated += layer.updates
+ return aggregated
+
+ @property
+ def losses(self):
+ """Aggregate losses from any `Layer` instances."""
+ aggregated = []
+ for layer in self.layers:
+ aggregated += layer.losses
+ return aggregated
+
+ def __hash__(self):
+ # Support object-identity hashing, so these structures can be used as keys
+ # in sets/dicts.
+ return id(self)
+
+ def __eq__(self, other):
+ # Similar to Tensors, checkpointable data structures use object-identity
+ # equality to support set/dict membership.
+ return self is other
+
+
+class List(CheckpointableDataStructure, collections.Sequence):
+ """An append-only sequence type which is checkpointable.
+
+ Maintains checkpoint dependencies on its contents (which must also be
+ checkpointable), and forwards any `Layer` metadata such as updates and losses.
+
+ Note that `List` is purely a container. It lets a `tf.keras.Model` or
+ other checkpointable object know about its contents, but does not call any
+ `Layer` instances which are added to it. To indicate a sequence of `Layer`
+ instances which should be called sequentially, use `tf.keras.Sequential`.
+
+ Example usage:
+ ```python
+ class HasList(tf.keras.Model):
+
+ def __init__(self):
+ super(HasList, self).__init__()
+ self.layer_list = tf.contrib.checkpoint.List([layers.Dense(3)])
+ self.layer_list.append(layers.Dense(4))
+
+ def call(self, x):
+ aggregation = 0.
+ for l in self.layer_list:
+ x = l(x)
+ aggregation += tf.reduce_sum(x)
+ return aggregation
+ ```
+
+ This kind of wrapping is necessary because `Checkpointable` objects do not
+ (yet) deeply inspect regular Python data structures, so for example assigning
+ a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a
+ checkpoint dependency and does not add the `Layer` instance's weights to its
+ parent `Model`.
+ """
+
+ def __init__(self, *args, **kwargs):
+ """Construct a new sequence. Arguments are passed to `list()`."""
+ super(List, self).__init__()
+ self._storage = list(*args, **kwargs)
+ for index, element in enumerate(self._storage):
+ self._track_value(element, name=self._name_element(index))
+
+ def _name_element(self, index):
+ return "%d" % (index,)
+
+ def append(self, value):
+ """Add a new checkpointable value."""
+ self._track_value(value, self._name_element(len(self._storage)))
+ self._storage.append(value)
+
+ def extend(self, values):
+ """Add a sequence of checkpointable values."""
+ for index_offset, value in enumerate(values):
+ self._track_value(
+ value, name=self._name_element(len(self._storage) + index_offset))
+ self._storage.extend(values)
+
+ def __iadd__(self, values):
+ self.extend(values)
+ return self
+
+ def __add__(self, other):
+ if isinstance(other, List):
+ return List(self._storage + other._storage) # pylint: disable=protected-access
+ else:
+ return List(self._storage + other)
+
+ def __getitem__(self, key):
+ return self._storage[key]
+
+ def __len__(self):
+ return len(self._storage)
+
+ def __repr__(self):
+ return "List(%s)" % (repr(self._storage),)
+
+ @property
+ def updates(self):
+ """Aggregate updates from any `Layer` instances."""
+ # Updates and conditional losses are forwarded as-is rather than being
+ # filtered based on inputs, since this is just a container and won't ever
+ # have any inputs.
+ aggregated = []
+ for layer in self.layers:
+ aggregated += layer.updates
+ return aggregated
+
+ @property
+ def losses(self):
+ """Aggregate losses from any `Layer` instances."""
+ aggregated = []
+ for layer in self.layers:
+ aggregated += layer.losses
+ return aggregated
--- /dev/null
+"""A trivial base class to avoid circular imports for isinstance checks."""
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.training.checkpointable import base as checkpointable_lib
+
+
+class CheckpointableDataStructureBase(checkpointable_lib.CheckpointableBase):
+ """Base class for data structures which contain checkpointable objects."""
+
+ pass
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.layers import core
+from tensorflow.python.keras.layers import normalization
+from tensorflow.python.layers import core as non_keras_core
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training.checkpointable import data_structures
+
+
+class HasList(training.Model):
+
+ def __init__(self):
+ super(HasList, self).__init__()
+ self.layer_list = data_structures.List([core.Dense(3)])
+ self.layer_list.append(core.Dense(4))
+ self.layer_list.extend(
+ [core.Dense(5),
+ core.Dense(6, kernel_regularizer=math_ops.reduce_sum)])
+ self.layer_list += [
+ core.Dense(7, bias_regularizer=math_ops.reduce_sum),
+ core.Dense(8)
+ ]
+ self.layer_list += (
+ data_structures.List([core.Dense(9)]) + data_structures.List(
+ [core.Dense(10)]))
+ self.layer_list.extend(
+ data_structures.List(
+ list(sequence=[core.Dense(11)]) + [core.Dense(12)]))
+ self.layers_with_updates = data_structures.List(
+ sequence=(normalization.BatchNormalization(),))
+
+ def call(self, x):
+ aggregation = 0.
+ for l in self.layer_list:
+ x = l(x)
+ aggregation += math_ops.reduce_sum(x)
+ bn, = self.layers_with_updates
+ return bn(x) / aggregation
+
+
+class ListTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testTracking(self):
+ model = HasList()
+ output = model(array_ops.ones([32, 2]))
+ self.assertAllEqual([32, 12], output.shape)
+ self.assertEqual(2, len(model.layers))
+ self.assertIs(model.layer_list, model.layers[0])
+ self.assertEqual(10, len(model.layers[0].layers))
+ for index in range(10):
+ self.assertEqual(3 + index, model.layers[0].layers[index].units)
+ self.assertEqual(2, len(model._checkpoint_dependencies))
+ self.assertIs(model.layer_list, model._checkpoint_dependencies[0].ref)
+ self.assertIs(model.layers_with_updates,
+ model._checkpoint_dependencies[1].ref)
+ self.assertEqual(
+ 10, len(model._checkpoint_dependencies[0].ref._checkpoint_dependencies))
+ self.evaluate([v.initializer for v in model.variables])
+ self.evaluate(model.variables[0].assign([[1., 2., 3.], [4., 5., 6.]]))
+ save_path = os.path.join(self.get_temp_dir(), "ckpt")
+ model.save_weights(save_path)
+ self.evaluate(model.variables[0].assign(array_ops.zeros([2, 3])))
+ model.load_weights(save_path)
+ self.assertAllEqual([[1., 2., 3.], [4., 5., 6.]],
+ self.evaluate(model.variables[0]))
+
+ def testUpdatesForwarded(self):
+ with context.graph_mode():
+ model = HasList()
+ model_input = array_ops.ones([32, 2])
+ model(model_input)
+ self.assertGreater(len(model.layers_with_updates[0].updates), 0)
+ self.assertEqual(set(model.layers_with_updates[0].updates),
+ set(model.updates))
+
+ with context.eager_mode():
+ model = HasList()
+ model_input = array_ops.ones([32, 2])
+ model(model_input)
+ self.assertEqual(0, len(model.updates))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testLossesForwarded(self):
+ model = HasList()
+ model_input = array_ops.ones([32, 2])
+ model(model_input)
+ self.assertEqual(2, len(model.losses))
+
+ def testNotCheckpointable(self):
+ class NotCheckpointable(object):
+ pass
+
+ with self.assertRaises(ValueError):
+ data_structures.List([NotCheckpointable()])
+
+ def testCallNotImplemented(self):
+ with self.assertRaisesRegexp(TypeError, "not callable"):
+ data_structures.List()(1.)
+
+ def testNoPop(self):
+ with self.assertRaises(AttributeError):
+ data_structures.List().pop()
+
+ def testNesting(self):
+ with context.graph_mode():
+ inner = data_structures.List()
+ outer = data_structures.List([inner])
+ inner.append(non_keras_core.Dense(1))
+ inner[0](array_ops.ones([2, 3]))
+ self.assertEqual(2, len(outer.variables))
+ self.assertIsInstance(
+ outer.variables[0],
+ resource_variable_ops.ResourceVariable)
+
+
+if __name__ == "__main__":
+ test.main()