From bd946a5bd7b59be8bb276fdd93e0a97653dedbfd Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Fri, 23 Feb 2018 15:51:23 -0800 Subject: [PATCH] Checkpointable: Utility to gather initialization ops A bit safer, since only variables which will be saved get initialized. Graph building then raises an error when you've used one which won't be saved. Reduces the need for the global collection. Makes it a bit easier to deal with initialization when writing graph/eager agnostic programs. PiperOrigin-RevId: 186835744 --- .../contrib/eager/python/checkpointable_utils.py | 128 ++++++++++++++++++++- .../eager/python/checkpointable_utils_test.py | 86 ++++++++++---- tensorflow/python/framework/test_util.py | 1 + 3 files changed, 186 insertions(+), 29 deletions(-) diff --git a/tensorflow/contrib/eager/python/checkpointable_utils.py b/tensorflow/contrib/eager/python/checkpointable_utils.py index d9648ff..e26ecc7 100644 --- a/tensorflow/contrib/eager/python/checkpointable_utils.py +++ b/tensorflow/contrib/eager/python/checkpointable_utils.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import collections import weakref @@ -278,6 +279,37 @@ def _serialize_object_graph(root_checkpointable): slot_variables=slot_variables) +def gather_initializers(root_checkpointable): + """Traverse the object graph and find initialization ops. + + Looks for `Checkpointable` objects which are dependencies of + `root_checkpointable` and which have an `initializer` property. Includes + initializers for slot variables only if the variable they are slotting for and + the optimizer are dependencies of `root_checkpointable` (i.e. if they would be + saved with a checkpoint). + + Args: + root_checkpointable: A `Checkpointable` object to gather initializers for. + Returns: + A list of initialization ops. + """ + # TODO(allenl): Extract out gathering logic so the naming logic doesn't have + # to run. + checkpointable_objects, path_to_root = ( + _breadth_first_checkpointable_traversal(root_checkpointable)) + object_names = { + obj: _object_prefix_from_path(path) + for obj, path in path_to_root.items()} + node_ids = {node: node_id for node_id, node + in enumerate(checkpointable_objects)} + _serialize_slot_variables( + checkpointable_objects=checkpointable_objects, + node_ids=node_ids, + object_names=object_names) + return [c.initializer for c in checkpointable_objects + if hasattr(c, "initializer") and c.initializer is not None] + + class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject): def __init__(self, tensor, name): @@ -288,7 +320,26 @@ class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject): return control_flow_ops.no_op() -class CheckpointLoadStatus(object): +class _LoadStatus(object): + """Abstract base for load status callbacks.""" + + @abc.abstractmethod + def assert_consumed(self): + """Raises an exception unless a non-trivial restoration has completed.""" + pass + + @abc.abstractmethod + def run_restore_ops(self, session=None): + """Runs restore ops from the checkpoint. Requires a valid checkpoint.""" + pass + + @abc.abstractmethod + def initialize_or_restore(self, session=None): + """Runs restore ops from the checkpoint, or initializes variables.""" + pass + + +class CheckpointLoadStatus(_LoadStatus): """Checks the status of checkpoint loading and manages restore ops. Returned from `Saver.restore`. Since `restore` may defer the loading of values @@ -348,6 +399,70 @@ class CheckpointLoadStatus(object): session = ops.get_default_session() session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict) + def initialize_or_restore(self, session=None): + """Alias for `run_restore_ops`. + + This method has a sibling in `InitializationOnlyStatus` which instead + initializes variables. That type is returned if no checkpoint is specified + in `Saver.restore`. + + Args: + session: The session to run restore ops in. If `None`, uses the default + session. + """ + self.run_restore_ops(session=session) + + +class InitializationOnlyStatus(_LoadStatus): + """Returned from `Saver.restore` when no checkpoint has been specified. + + Objects of this type have the same `assert_consumed` method as + `CheckpointLoadStatus`, but it always fails. However, + `initialize_or_restore` works on objects of both types, and will + initialize variables in `InitializationOnlyStatus` objects or restore them + otherwise. + """ + + def __init__(self, root_checkpointable): + self._root_checkpointable = root_checkpointable + + def assert_consumed(self): + """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" + raise AssertionError( + "No checkpoint specified (save_path=None); nothing is being restored.") + + def run_restore_ops(self, session=None): + """For consistency with `CheckpointLoadStatus`. + + Use `initialize_or_restore` for initializing if no checkpoint was passed + to `Saver.restore` and restoring otherwise. + + Args: + session: Not used. + """ + raise AssertionError( + "No checkpoint specified, so no restore ops are available " + "(save_path=None to Saver.restore).") + + def initialize_or_restore(self, session=None): + """Runs initialization ops for variables. + + Only objects which would be saved by `Saver.save` will be initialized. See + `gather_initializers` for details. + + This method does nothing when executing eagerly (initializers get run + eagerly). + + Args: + session: The session to run initialization ops in. If `None`, uses the + default session. + """ + if context.in_eager_mode(): + return # run eagerly + if session is None: + session = ops.get_default_session() + session.run(gather_initializers(self._root_checkpointable)) + class _SessionWithFeedDictAdditions(session_lib.SessionInterface): """Pretends to be a session, inserts extra feeds on run().""" @@ -521,17 +636,20 @@ class Saver(object): Args: save_path: The path to the checkpoint, as returned by `save` or `tf.train.latest_checkpoint`. If None (as when there is no latest - checkpoint for `tf.train.latest_checkpoint` to return), does nothing. + checkpoint for `tf.train.latest_checkpoint` to return), returns an + object which may run initializers for objects in the dependency graph. session: The session to retrieve metadata with. Ignored when executing eagerly. If not provided when graph building, the default session is used. Returns: - A `CheckpointLoadStatus` object, which can be used to make assertions - about the status of checkpoint restoration and run restore ops. + A load status object, which can be used to make assertions about the + status of checkpoint restoration and run initialization/restore ops + (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if + `save_path` is `None`). """ if save_path is None: - return + return InitializationOnlyStatus(self._root_checkpointable) in_graph_mode = context.in_graph_mode() if in_graph_mode: if session is None: diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/eager/python/checkpointable_utils_test.py index b7554de..6b86d41 100644 --- a/tensorflow/contrib/eager/python/checkpointable_utils_test.py +++ b/tensorflow/contrib/eager/python/checkpointable_utils_test.py @@ -36,7 +36,6 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables from tensorflow.python.training import adam from tensorflow.python.training import checkpointable from tensorflow.python.training import saver as core_saver @@ -140,7 +139,7 @@ class Checkpoint(checkpointable.Checkpointable): super(Checkpoint, self).__init__() for k, v in sorted(kwargs.items(), key=lambda item: item[0]): setattr(self, k, v) - self._save_counter = None + self._save_counter = None # Created lazily for restore-on-create. self._saver = checkpointable_utils.Saver(weakref.ref(self)) @property @@ -170,8 +169,12 @@ class Checkpoint(checkpointable.Checkpointable): session=session) def restore(self, save_path): - return self._saver.restore( - save_path=save_path) + status = self._saver.restore(save_path=save_path) + # Create the save counter now so it gets initialized with other variables + # when graph building. Creating it earlier would lead to double + # initialization when executing eagerly. + self.save_counter # pylint: disable=pointless-statement + return status class InterfaceTests(test.TestCase): @@ -206,8 +209,7 @@ class InterfaceTests(test.TestCase): with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"): checkpointable_utils.add_variable(obj, name="duplicate", shape=[]) - if context.in_graph_mode(): - self.evaluate(variables.global_variables_initializer()) + self.evaluate(checkpointable_utils.gather_initializers(obj)) self.assertEqual("constant_initializer:0", constant_initializer.name) self.assertEqual(1, self.evaluate(constant_initializer)) self.assertEqual("some_variable_scope/ones_initializer:0", @@ -287,7 +289,8 @@ class CheckpointingTests(test.TestCase): optimizer.minimize( other_network(input_value), global_step=optimizer_step) - self.evaluate(variables.global_variables_initializer()) + self.evaluate(checkpointable_utils.gather_initializers( + root_checkpointable)) self.evaluate(train_op) named_variables, serialized_graph = ( checkpointable_utils._serialize_object_graph(root_checkpointable)) @@ -385,7 +388,8 @@ class CheckpointingTests(test.TestCase): train_op = optimizer.minimize(network(input_value)) # TODO(allenl): Make initialization more pleasant when graph building. root_checkpointable.save_counter # pylint: disable=pointless-statement - self.evaluate(variables.global_variables_initializer()) + self.evaluate(checkpointable_utils.gather_initializers( + root_checkpointable)) self.evaluate(train_op) prefix = os.path.join(self.get_temp_dir(), "ckpt") self.evaluate(state_ops.assign(network._named_dense.variables[1], [42.])) @@ -429,6 +433,7 @@ class CheckpointingTests(test.TestCase): self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power)) self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power)) + # TODO(allenl): Debug garbage created by this test in python3. def testDeferredRestorationUsageEager(self): """An idiomatic eager execution example.""" num_training_steps = 10 @@ -468,28 +473,57 @@ class CheckpointingTests(test.TestCase): train_op = optimizer.minimize( network(input_value), global_step=root.global_step) - root.save_counter # pylint: disable=pointless-statement - init_op = variables.global_variables_initializer() checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) with self.test_session(graph=ops.get_default_graph()) as session: + status = root.restore(save_path=checkpoint_path) + status.initialize_or_restore(session=session) if checkpoint_path is None: self.assertEqual(0, training_continuation) - session.run(init_op) - # Another alternative would be to run initializers automatically - # if no checkpoint is being loaded. This would make deferred - # loading a bit more useful with graph execution. + with self.assertRaises(AssertionError): + status.assert_consumed() else: - status = root.restore(save_path=checkpoint_path).assert_consumed() - status.run_restore_ops() + status.assert_consumed() for _ in range(num_training_steps): session.run(train_op) - root.save(file_prefix=checkpoint_prefix, - session=session) + root.save(file_prefix=checkpoint_prefix, session=session) self.assertEqual((training_continuation + 1) * num_training_steps, session.run(root.global_step)) self.assertEqual(training_continuation + 1, session.run(root.save_counter)) + @test_util.run_in_graph_and_eager_modes() + def testAgnosticUsage(self): + """Graph/eager agnostic usage.""" + # Does create garbage when executing eagerly due to ops.Graph() creation. + num_training_steps = 10 + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + for training_continuation in range(3): + with ops.Graph().as_default(), self.test_session( + graph=ops.get_default_graph()): + network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root = Checkpoint( + optimizer=optimizer, network=network, + global_step=training_util.get_or_create_global_step()) + checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + status = root.restore(save_path=checkpoint_path) + input_value = constant_op.constant([[3.]]) + train_fn = functools.partial( + optimizer.minimize, + functools.partial(network, input_value), + global_step=root.global_step) + if context.in_graph_mode(): + train_fn = functools.partial(self.evaluate, train_fn()) + status.initialize_or_restore() + for _ in range(num_training_steps): + train_fn() + root.save(file_prefix=checkpoint_prefix) + self.assertEqual((training_continuation + 1) * num_training_steps, + self.evaluate(root.global_step)) + self.assertEqual(training_continuation + 1, + self.evaluate(root.save_counter)) + def _get_checkpoint_name(self, name): root = checkpointable.Checkpointable() checkpointable_utils.add_variable( @@ -602,7 +636,11 @@ class CheckpointingTests(test.TestCase): optimizer = CheckpointableAdam(0.1) if context.in_graph_mode(): train_op = optimizer.minimize(root.var) - self.evaluate(variables.global_variables_initializer()) + # Note that `optimizer` has not been added as a dependency of + # `root`. Create a one-off grouping so that slot variables for `root.var` + # get initialized too. + self.evaluate(checkpointable_utils.gather_initializers( + Checkpoint(root=root, optimizer=optimizer))) self.evaluate(train_op) else: optimizer.minimize(root.var.read_value) @@ -709,7 +747,7 @@ class CheckpointingTests(test.TestCase): save_root.dep_one.dep_three = dep_three save_root.dep_two.dep_three = dep_three checkpointable_utils.add_variable(dep_three, name="var", initializer=0.) - self.evaluate(variables.global_variables_initializer()) + self.evaluate(checkpointable_utils.gather_initializers(save_root)) save_path = checkpointable_utils.Saver(save_root).save( os.path.join(checkpoint_directory, "ckpt")) load_root = checkpointable.Checkpointable() @@ -732,7 +770,7 @@ class CheckpointingTests(test.TestCase): save_root.dep_one, name="var1", initializer=32., dtype=dtypes.float64) checkpointable_utils.add_variable( save_root.dep_two, name="var2", initializer=64., dtype=dtypes.float64) - self.evaluate(variables.global_variables_initializer()) + self.evaluate(checkpointable_utils.gather_initializers(save_root)) save_path = checkpointable_utils.Saver(save_root).save( os.path.join(checkpoint_directory, "ckpt")) load_root = checkpointable.Checkpointable() @@ -760,7 +798,7 @@ class CheckpointingTests(test.TestCase): first, "v1", initializer=[3., 1., 4.]) second.v = checkpointable_utils.add_variable( second, "v2", initializer=[1., 1., 2., 3.]) - self.evaluate(variables.global_variables_initializer()) + self.evaluate(checkpointable_utils.gather_initializers(first)) checkpoint_directory = self.get_temp_dir() save_path = checkpointable_utils.Saver(first).save( os.path.join(checkpoint_directory, "ckpt")) @@ -835,7 +873,7 @@ class CheckpointingTests(test.TestCase): obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = CheckpointableAdam(0.1) obj.opt.minimize(obj.var.read_value()) - self.evaluate(variables.global_variables_initializer()) + self.evaluate(checkpointable_utils.gather_initializers(obj)) saver = checkpointable_utils.Saver(obj) saver.save(checkpoint_prefix) before_ops = graph.get_operations() @@ -853,7 +891,7 @@ class CheckpointingTests(test.TestCase): obj.var = variable_scope.get_variable(name="v", initializer=0.) obj.opt = CheckpointableAdam(0.1) obj.opt.minimize(obj.var.read_value()) - self.evaluate(variables.global_variables_initializer()) + self.evaluate(checkpointable_utils.gather_initializers(obj)) saver = checkpointable_utils.Saver(obj) save_path = saver.save(checkpoint_prefix) saver.restore(save_path) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index e1c37a5..aabf89a 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -588,6 +588,7 @@ def run_in_graph_and_eager_modes(__unused__=None, # This decorator runs the wrapped test twice. # Reset the test environment between runs. self.tearDown() + self._tempdir = None self.setUp() def run_eager_mode(self, **kwargs): -- 2.7.4