From dd5ef1b9fc22b37e5eec87d659a3af064ca54b8b Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Fri, 4 May 2018 18:25:18 -0700 Subject: [PATCH] Checkpointable: A small utility for exempting objects from __setattr__ tracking Exposes it as tf.contrib.checkpoint.NoDependency. Objects wrapped in a NoDependency object get unwrapped in __setattr__ and not tracked. Removes the _save_counter dependency from tf.train.Checkpoint (the save counter is still tracked as "save_counter" and always has been, so this is a backwards-compatible dependency removal). PiperOrigin-RevId: 195502562 --- tensorflow/contrib/checkpoint/__init__.py | 2 ++ .../python/keras/_impl/keras/engine/network.py | 6 ++++- .../keras/_impl/keras/model_subclassing_test.py | 18 +++++++++++++ tensorflow/python/training/checkpointable.py | 30 +++++++++++++++++++++- tensorflow/python/training/checkpointable_test.py | 10 ++++++++ tensorflow/python/training/checkpointable_utils.py | 7 +++-- 6 files changed, 69 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index d2c30f1..e529b25 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -19,6 +19,7 @@ For creating and managing dependencies: @@CheckpointableObjectGraph @@dot_graph_from_checkpoint @@object_metadata +@@NoDependency @@split_dependency """ @@ -29,6 +30,7 @@ from __future__ import print_function from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph +from tensorflow.python.training.checkpointable import NoDependency from tensorflow.python.training.checkpointable_utils import object_metadata from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/python/keras/_impl/keras/engine/network.py b/tensorflow/python/keras/_impl/keras/engine/network.py index a0229be..9e75096 100644 --- a/tensorflow/python/keras/_impl/keras/engine/network.py +++ b/tensorflow/python/keras/_impl/keras/engine/network.py @@ -318,6 +318,9 @@ class Network(base_layer.Layer): layer, name='layer-%d' % layer_index, overwrite=True) def __setattr__(self, name, value): + no_dependency = isinstance(value, checkpointable.NoDependency) + if no_dependency: + value = value.value if isinstance(value, (base_layer.Layer, Network)): try: is_graph_network = self._is_graph_network @@ -332,7 +335,8 @@ class Network(base_layer.Layer): # In subclassed models, legacy layers (tf.layers) must always use # resource variables. value._use_resource_variables = True - if isinstance(value, checkpointable.CheckpointableBase): + if (not no_dependency + and isinstance(value, checkpointable.CheckpointableBase)): # Layer (and therefore Network/Model) inherit from CheckpointableBase # rather than Checkpointable, which means there is no Checkpointable # __setattr__ override (it would be a performance issue for functional diff --git a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py index 295ad47..3f850e5 100644 --- a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py +++ b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py @@ -28,7 +28,9 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.keras._impl import keras from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test +from tensorflow.python.training import checkpointable from tensorflow.python.training.rmsprop import RMSPropOptimizer try: @@ -583,6 +585,22 @@ class ModelSubclassingTest(test.TestCase): loss = model.train_on_batch(x, y) self.assertGreater(loss, 0.1) + def test_no_dependency(self): + class Foo(keras.Model): + + def __init__(self): + super(Foo, self).__init__() + self.isdep = keras.layers.Dense(1) + self.notdep = checkpointable.NoDependency(keras.layers.Dense(2)) + self.notdep_var = checkpointable.NoDependency( + resource_variable_ops.ResourceVariable(1., name='notdep_var')) + + m = Foo() + self.assertEqual([m.isdep, m.notdep], m.layers) + self.assertEqual(1, len(m._checkpoint_dependencies)) + self.assertIs(m.isdep, m._checkpoint_dependencies[0].ref) + self.assertEqual('notdep_var:0', m.notdep_var.name) + class CustomCallModel(keras.Model): diff --git a/tensorflow/python/training/checkpointable.py b/tensorflow/python/training/checkpointable.py index 05afd37..d00312a 100644 --- a/tensorflow/python/training/checkpointable.py +++ b/tensorflow/python/training/checkpointable.py @@ -659,6 +659,31 @@ class CheckpointableBase(object): return {} +class NoDependency(object): + """Allows attribute assignment to `Checkpointable` objects with no dependency. + + Example usage: + ```python + obj = Checkpointable() + obj.has_dependency = tf.Variable(0., name="dep") + obj.no_dependency = NoDependency(tf.Variable(1., name="nodep")) + assert obj.no_dependency.name == "nodep:0" + ``` + + `obj` in this example has a dependency on the variable "dep", and both + attributes contain un-wrapped `Variable` objects. + + `NoDependency` also works with `tf.keras.Model`, but only for checkpoint + dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped) + `Layer` to the attribute without a checkpoint dependency, but the `Model` will + still track the `Layer` (so it will appear in `Model.layers`, and its + variables will appear in `Model.variables`). + """ + + def __init__(self, value): + self.value = value + + class Checkpointable(CheckpointableBase): """Manages dependencies on other objects. @@ -691,8 +716,11 @@ class Checkpointable(CheckpointableBase): """Support self.foo = checkpointable syntax.""" # Perform the attribute assignment, and potentially call other __setattr__ # overrides such as that for tf.keras.Model. + no_dependency = isinstance(value, NoDependency) + if no_dependency: + value = value.value super(Checkpointable, self).__setattr__(name, value) - if isinstance(value, CheckpointableBase): + if not no_dependency and isinstance(value, CheckpointableBase): self._track_checkpointable( value, name=name, # Allow the user to switch the Checkpointable which is tracked by this diff --git a/tensorflow/python/training/checkpointable_test.py b/tensorflow/python/training/checkpointable_test.py index e79acb4..85802cb 100644 --- a/tensorflow/python/training/checkpointable_test.py +++ b/tensorflow/python/training/checkpointable_test.py @@ -34,6 +34,16 @@ class InterfaceTests(test.TestCase): root.leaf = duplicate_name_dep root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True) + def testNoDependency(self): + root = checkpointable.Checkpointable() + hasdep = checkpointable.Checkpointable() + root.hasdep = hasdep + nodep = checkpointable.Checkpointable() + root.nodep = checkpointable.NoDependency(nodep) + self.assertEqual(1, len(root._checkpoint_dependencies)) + self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep) + self.assertIs(root.hasdep, hasdep) + self.assertIs(root.nodep, nodep) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/checkpointable_utils.py b/tensorflow/python/training/checkpointable_utils.py index cf4112f..f2a2b41 100644 --- a/tensorflow/python/training/checkpointable_utils.py +++ b/tensorflow/python/training/checkpointable_utils.py @@ -1044,8 +1044,11 @@ class Checkpoint(checkpointable_lib.Checkpointable): if self._save_counter is None: # Initialized to 0 and incremented before saving. with ops.device("/cpu:0"): - self._save_counter = add_variable( - self, name="save_counter", initializer=0, dtype=dtypes.int64) + # add_variable creates a dependency named "save_counter"; NoDependency + # prevents creating a second dependency named "_save_counter". + self._save_counter = checkpointable_lib.NoDependency( + add_variable(self, name="save_counter", initializer=0, + dtype=dtypes.int64)) @property def save_counter(self): -- 2.7.4