From 53cd5c01407451cf918c1d1c1f5ca640b7d5dbc8 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Thu, 24 May 2018 10:30:41 -0700 Subject: [PATCH] Add a checkpointable map data structure PiperOrigin-RevId: 197913890 --- tensorflow/contrib/checkpoint/__init__.py | 3 + tensorflow/contrib/checkpoint/python/BUILD | 5 +- .../training/checkpointable/data_structures.py | 67 ++++++++++++++----- .../checkpointable/data_structures_test.py | 77 ++++++++++++++++++++++ 4 files changed, 134 insertions(+), 18 deletions(-) diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index bd0bc9e..8ae493b 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -26,6 +26,7 @@ Managing dependencies: Checkpointable data structures: @@List +@@Mapping @@UniqueNameTracker """ @@ -40,8 +41,10 @@ from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import Checkpointa from tensorflow.python.training.checkpointable.base import Checkpointable from tensorflow.python.training.checkpointable.base import NoDependency from tensorflow.python.training.checkpointable.data_structures import List +from tensorflow.python.training.checkpointable.data_structures import Mapping from tensorflow.python.training.checkpointable.util import object_metadata from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(module_name=__name__) + diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index 0b67619..7b200a2 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -20,7 +20,10 @@ py_library( srcs = ["containers.py"], srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], - deps = ["//tensorflow/python/training/checkpointable:base"], + deps = [ + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:data_structures", + ], ) py_test( diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py index b514f7b..62cefa4 100644 --- a/tensorflow/python/training/checkpointable/data_structures.py +++ b/tensorflow/python/training/checkpointable/data_structures.py @@ -19,6 +19,8 @@ from __future__ import print_function import collections +import six + from tensorflow.python.keras.engine import base_layer from tensorflow.python.training.checkpointable import base as checkpointable_lib from tensorflow.python.training.checkpointable import data_structures_base @@ -198,21 +200,52 @@ class List(CheckpointableDataStructure, collections.Sequence): def __repr__(self): return "List(%s)" % (repr(self._storage),) - @property - def updates(self): - """Aggregate updates from any `Layer` instances.""" - # Updates and conditional losses are forwarded as-is rather than being - # filtered based on inputs, since this is just a container and won't ever - # have any inputs. - aggregated = [] - for layer in self.layers: - aggregated += layer.updates - return aggregated - @property - def losses(self): - """Aggregate losses from any `Layer` instances.""" - aggregated = [] - for layer in self.layers: - aggregated += layer.losses - return aggregated +class Mapping(CheckpointableDataStructure, collections.Mapping): + """An append-only checkpointable mapping data structure with string keys. + + Maintains checkpoint dependencies on its contents (which must also be + checkpointable), named based on its keys. + + Note that once a key has been added, it may not be deleted or replaced. If + names may not be unique, see `tf.contrib.checkpoint.UniqueNameTracker`. + """ + + def __init__(self, *args, **kwargs): + """Construct a new sequence. Arguments are passed to `dict()`.""" + super(Mapping, self).__init__() + self._storage = dict(*args, **kwargs) + for key, value in self._storage.items(): + self._track_value(value, name=self._name_element(key)) + + def _name_element(self, key): + if not isinstance(key, six.string_types): + raise TypeError( + "Mapping accepts only string keys, but got a key %s." + % repr(key)) + return str(key) + + def __setitem__(self, key, value): + current_value = self._storage.setdefault(key, value) + if current_value is not value: + raise ValueError( + ("Mappings are an append-only data structure. Tried to overwrite the " + "key '%s' with value %s, but it already contains %s") + % (key, value, current_value)) + self._track_value(value, name=self._name_element(key)) + + def update(self, *args, **kwargs): + for key, value in dict(*args, **kwargs).items(): + self[key] = value + + def __getitem__(self, key): + return self._storage[key] + + def __len__(self): + return len(self._storage) + + def __repr__(self): + return "Mapping(%s)" % (repr(self._storage),) + + def __iter__(self): + return iter(self._storage) diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py index 6cabbea..31a0e8b 100644 --- a/tensorflow/python/training/checkpointable/data_structures_test.py +++ b/tensorflow/python/training/checkpointable/data_structures_test.py @@ -18,6 +18,8 @@ from __future__ import print_function import os +import numpy + from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import test_util @@ -137,6 +139,81 @@ class ListTests(test.TestCase): outer.variables[0], resource_variable_ops.ResourceVariable) + def testHashing(self): + has_sequences = set([data_structures.List(), + data_structures.List()]) + self.assertEqual(2, len(has_sequences)) + self.assertNotIn(data_structures.List(), has_sequences) + + +class HasMapping(training.Model): + + def __init__(self): + super(HasMapping, self).__init__() + self.layer_dict = data_structures.Mapping(output=core.Dense(7)) + self.layer_dict["norm"] = data_structures.List() + self.layer_dict["dense"] = data_structures.List() + self.layer_dict["dense"].extend( + [core.Dense(5), + core.Dense(6, kernel_regularizer=math_ops.reduce_sum)]) + self.layer_dict["norm"].append( + normalization.BatchNormalization()) + self.layer_dict["norm"].append( + normalization.BatchNormalization()) + + def call(self, x): + aggregation = 0. + for norm, dense in zip(self.layer_dict["norm"], self.layer_dict["dense"]): + x = norm(dense(x)) + aggregation += math_ops.reduce_sum(x) + return self.layer_dict["output"](x) / aggregation + + +class MappingTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testTracking(self): + model = HasMapping() + output = model(array_ops.ones([32, 2])) + self.assertAllEqual([32, 7], output.shape) + self.assertEqual(1, len(model.layers)) + self.assertIs(model.layer_dict, model.layers[0]) + self.assertEqual(3, len(model.layers[0].layers)) + self.assertEqual(1, len(model._checkpoint_dependencies)) + self.assertIs(model.layer_dict, model._checkpoint_dependencies[0].ref) + self.evaluate([v.initializer for v in model.variables]) + test_var = model.layer_dict["output"].kernel + self.evaluate(test_var.assign(array_ops.ones([6, 7]))) + save_path = os.path.join(self.get_temp_dir(), "ckpt") + model.save_weights(save_path) + self.evaluate(test_var.assign(array_ops.zeros([6, 7]))) + model.load_weights(save_path) + self.assertAllEqual(numpy.ones([6, 7]), + self.evaluate(test_var)) + + def testNoOverwrite(self): + mapping = data_structures.Mapping() + original = data_structures.List() + mapping["a"] = original + with self.assertRaises(ValueError): + mapping["a"] = data_structures.List() + self.assertIs(original, mapping["a"]) + with self.assertRaises(AttributeError): + del mapping["a"] + mapping.update(b=data_structures.Mapping()) + with self.assertRaises(ValueError): + mapping.update({"b": data_structures.Mapping()}) + + def testNonStringKeys(self): + mapping = data_structures.Mapping() + with self.assertRaises(TypeError): + mapping[1] = data_structures.List() + + def testHashing(self): + has_mappings = set([data_structures.Mapping(), + data_structures.Mapping()]) + self.assertEqual(2, len(has_mappings)) + self.assertNotIn(data_structures.Mapping(), has_mappings) if __name__ == "__main__": test.main() -- 2.7.4