From a4de23973ddddfa8dc26d846dc0a902942347b11 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Thu, 22 Feb 2018 12:01:57 -0800 Subject: [PATCH] Checkpointable: Re-use the Saver's SaveableObject infrastructure to create restore ops, cache them. The basic infrastructure is usable after this CL in graph and eager, but I still need to make a bunch of objects Checkpointable and make some other usability fixes. Also sets some of the groundwork for feeding and fetching Python values during save/restore (as in, save has a feed dict now; gathering feeds and placeholders from Checkpointable objects is still to do but should be relatively straightforward). PiperOrigin-RevId: 186652696 --- .../contrib/eager/python/checkpointable_utils.py | 395 ++++++++++++++------- .../eager/python/checkpointable_utils_test.py | 143 ++++---- tensorflow/python/BUILD | 1 - tensorflow/python/ops/variables.py | 12 +- tensorflow/python/training/checkpointable.py | 178 +++++----- tensorflow/python/training/checkpointable_utils.py | 78 ++++ 6 files changed, 506 insertions(+), 301 deletions(-) create mode 100644 tensorflow/python/training/checkpointable_utils.py diff --git a/tensorflow/contrib/eager/python/checkpointable_utils.py b/tensorflow/contrib/eager/python/checkpointable_utils.py index 0506af3..d9648ff 100644 --- a/tensorflow/contrib/eager/python/checkpointable_utils.py +++ b/tensorflow/contrib/eager/python/checkpointable_utils.py @@ -18,8 +18,11 @@ from __future__ import division from __future__ import print_function import collections +import weakref from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2 +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.client import session as session_lib from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -31,6 +34,7 @@ from tensorflow.python.ops import io_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import checkpointable as core_checkpointable +from tensorflow.python.training import checkpointable_utils as core_checkpointable_utils from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import saver as saver_lib @@ -214,7 +218,7 @@ def _serialize_checkpointables( object_proto.slot_variables.extend(slot_variables.get(checkpointable, ())) object_name = object_names[checkpointable] for name, saveable in ( - checkpointable._gather_tensors_for_checkpoint().items()): # pylint: disable=protected-access + checkpointable._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access attribute = object_proto.attributes.add() attribute.name = name attribute.checkpoint_key = "%s/%s/%s" % ( @@ -284,58 +288,39 @@ class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject): return control_flow_ops.no_op() -def save(file_prefix, root_checkpointable, checkpoint_number=None, - session=None): - """Save a training checkpoint. - - Args: - file_prefix: A prefix to use for the checkpoint filenames - (/path/to/directory/and_a_prefix). Names are generated based on this - prefix and the global step, if provided. - root_checkpointable: A Checkpointable object to save. The checkpoint - includes variables created by this object and any Checkpointable objects - it depends on. - checkpoint_number: An integer variable or Tensor, used to number - checkpoints. 
Typically this value is saved along with other variables in - training checkpoints, which will happen automatically if it was created by - `root_checkpointable` or one of its dependencies (via - `Checkpointable._add_variable`). - session: The session to evaluate variables in. Ignored when executing - eagerly. If not provided when graph building, the default session is used. +class CheckpointLoadStatus(object): + """Checks the status of checkpoint loading and manages restore ops. - Returns: - The full path to the checkpoint. - """ - named_variables, serialized_graph = _serialize_object_graph( - root_checkpointable) - if context.in_graph_mode(): - if session is None: - session = ops.get_default_session() - else: - session = None - assert _OBJECT_GRAPH_PROTO_KEY not in named_variables - # TODO(allenl): Feed rather than embedding a constant. - named_variables[_OBJECT_GRAPH_PROTO_KEY] = _NoRestoreSaveable( - tensor=constant_op.constant( - serialized_graph.SerializeToString(), dtype=dtypes.string), - name=_OBJECT_GRAPH_PROTO_KEY) - with ops.device("/device:CPU:0"): - save_path = saver_lib.Saver(var_list=named_variables).save( - sess=session, - save_path=file_prefix, - write_meta_graph=False, - global_step=checkpoint_number) - return save_path + Returned from `Saver.restore`. Since `restore` may defer the loading of values + in the checkpoint which don't yet have corresponding Python objects, + `CheckpointLoadStatus` provides a callback to verify that checkpoint loading + is complete (`assert_consumed`). + When graph building, `restore` does not run restore ops itself since their + creation may be deferred. The `run_restore_ops` method must be called once all + Python objects with values to restore have been created and added to the + dependency graph (this does not necessarily have to be the whole checkpoint; + calling `run_restore_ops` while `assert_consumed` fails is supported and will + partially restore the checkpoint). -class CheckpointLoadStatus(object): - """Checks the status of checkpoint loading.""" + See `Saver.restore` for usage examples. + """ - def __init__(self, checkpoint): + def __init__(self, checkpoint, feed_dict): self._checkpoint = checkpoint + self._feed_dict = feed_dict def assert_consumed(self): - """Asserts that all objects in the checkpoint have been created/matched.""" + """Asserts that all objects in the checkpoint have been created/matched. + + Returns: + `self` for chaining. + Raises: + AssertionError: If there are any Python objects in the dependency graph + which have not been restored from this checkpoint or a later `restore`, + or if there are any checkpointed values which have not been matched to + Python objects. + """ for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes): checkpointable = self._checkpoint.object_by_proto_id.get(node_id, None) if checkpointable is None: @@ -348,88 +333,256 @@ class CheckpointLoadStatus(object): # restored. raise AssertionError("Unresolved slot restorations: %s" % ( self._checkpoint.slot_restorations,)) + if self._checkpoint.unused_attributes: + raise AssertionError( + ("Unused attributes in these objects (the attributes exist in the " + "checkpoint but not in the objects): %s") % ( + self._checkpoint.unused_attributes.items(),)) return self - @property - def restore_ops(self): - """Operations to restore objects in the dependency graph.""" - return self._checkpoint.restore_ops - - -def restore(save_path, root_checkpointable, session=None): - """Restore a training checkpoint. 
- - Restores the values of variables created with `Checkpointable._add_variable` - in `root_checkpointable` and any objects that it tracks (transitive). Either - assigns values immediately if variables to restore have been created already, - or defers restoration until the variables are created. Dependencies added to - `root_checkpointable` after this call will be matched if they have a - corresponding object in the checkpoint. + def run_restore_ops(self, session=None): + """Run operations to restore objects in the dependency graph.""" + if context.in_eager_mode(): + return # Run eagerly + if session is None: + session = ops.get_default_session() + session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict) - When building a graph, restorations are added to the graph but not run. A - session is required to retrieve checkpoint metadata. - To disallow deferred loading, assert immediately that all checkpointed - variables have been matched to variable objects: +class _SessionWithFeedDictAdditions(session_lib.SessionInterface): + """Pretends to be a session, inserts extra feeds on run().""" - ```python - restore(path, root).assert_consumed() - ``` + def __init__(self, session, feed_additions): + self._wrapped_session = session + self._feed_additions = feed_additions - An exception will be raised unless every object was matched and its variables - already exist. + def run(self, fetches, feed_dict=None, **kwargs): + if feed_dict is None: + feed_dict = {} + else: + feed_dict = feed_dict.copy() + feed_dict.update(self._feed_additions) + return self._wrapped_session.run( + fetches=fetches, feed_dict=feed_dict, **kwargs) + + +class Saver(object): + """Saves and restores a `Checkpointable` object and its dependencies. + + See `Checkpointable` for details of dependency management. `Saver` wraps + `tf.train.Saver` for saving, including extra information about the graph of + dependencies between Python objects. When restoring, it uses this information + about the save-time dependency graph to more robustly match objects with their + checkpointed values. When executing eagerly, it supports restoring variables + on object creation (see `Saver.restore`). + + Values in a checkpoint are mapped to `Checkpointable` Python objects + (`Variable`s, `Optimizer`s, `Layer`s) based on the names provided when the + checkpoint was written. To avoid breaking existing checkpoints when modifying + a class, dependency names (the names of attributes to which `Checkpointable` + objects are assigned) may not change. These names are local to objects, in + contrast to the `Variable.name`-based save/restore from `tf.train.Saver`, and + so allow additional program transformations. + """ - When graph building, `assert_consumed()` indicates that all of the restore ops - which will be created for this checkpoint have been created. They are - available in the `restore_ops` property of the status object: + def __init__(self, root_checkpointable): + """Configure saving. + + Args: + root_checkpointable: The root of the object graph to save/restore. This + object and all of its dependencies are saved in the checkpoint. When + restoring, objects are matched and restored starting from this root. + """ + # Allow passing in a weak reference to avoid reference cycles when + # `Checkpointable` objects save themselves. 
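+    # For example, the test `Checkpoint` class constructs its saver as
+    # `Saver(weakref.ref(self))`; the `_root_checkpointable` property below
+    # dereferences the weakref on each access.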
+ self._root_checkpointable_ref = root_checkpointable + if context.in_graph_mode(): + self._file_prefix_placeholder = constant_op.constant("model") + else: + self._file_prefix_placeholder = None - ```python - session.run(restore(path, root).assert_consumed().restore_ops) - ``` + # Op caching for save + self._object_graph_feed_tensor = None + self._last_save_object_graph = None + self._last_save_saver = None - If the checkpoint has not been consumed completely, then the list of - `restore_ops` will grow as more objects are added to the dependency graph. + # Op caching for restore + self._object_graph_restore_tensor = None + self._last_restore_object_graph = None + self._last_restore_checkpoint = None - Args: - save_path: The path to the checkpoint, as returned by `save` or - `tf.train.latest_checkpoint`. If None (as when there is no latest - checkpoint for `tf.train.latest_checkpoint` to return), does nothing. - root_checkpointable: The root of the object graph to restore. Variables to - restore need not have been created yet, but all dependencies on other - `Checkpointable` objects should already be declared. Objects in the - dependency graph are matched to objects in the checkpointed graph, and - matching objects have their variables restored (or the checkpointed values - saved for eventual restoration when the variable is created). - session: The session to retrieve metadata with. Ignored when executing - eagerly. If not provided when graph building, the default session is used. - Returns: - A `CheckpointLoadStatus` object, which can be used to make assertions about - the status of checkpoint restoration and fetch restore ops. - """ - if save_path is None: - return - if context.in_graph_mode(): - if session is None: - session = ops.get_default_session() - else: - session = None - object_graph_string, = io_ops.restore_v2( - prefix=save_path, - tensor_names=[_OBJECT_GRAPH_PROTO_KEY], - shape_and_slices=[""], - dtypes=[dtypes.string], - name="object_graph_proto_read") - if session is not None: - object_graph_string = session.run(object_graph_string) - else: - object_graph_string = object_graph_string.numpy() - object_graph_proto = ( - checkpointable_object_graph_pb2.CheckpointableObjectGraph()) - object_graph_proto.ParseFromString(object_graph_string) - checkpoint = core_checkpointable._Checkpoint( # pylint: disable=protected-access - object_graph_proto=object_graph_proto, - save_path=save_path) - core_checkpointable._CheckpointPosition( # pylint: disable=protected-access - checkpoint=checkpoint, proto_id=0).restore(root_checkpointable) - load_status = CheckpointLoadStatus(checkpoint) - return load_status + @property + def _root_checkpointable(self): + if isinstance(self._root_checkpointable_ref, weakref.ref): + derefed = self._root_checkpointable_ref() + assert derefed is not None + return derefed + else: + return self._root_checkpointable_ref + + def save(self, file_prefix, checkpoint_number=None, session=None): + """Save a training checkpoint. + + The saved checkpoint includes variables created by this object and any + Checkpointable objects it depends on at the time `Saver.save()` is called. + + Args: + file_prefix: A prefix to use for the checkpoint filenames + (/path/to/directory/and_a_prefix). Names are generated based on this + prefix and the global step, if provided. + checkpoint_number: An integer variable or Tensor, used to number + checkpoints. 
Typically this value is saved along with other variables in + training checkpoints, which will happen automatically if it was created + by `root_checkpointable` or one of its dependencies (via + `Checkpointable._add_variable`). + session: The session to evaluate variables in. Ignored when executing + eagerly. If not provided when graph building, the default session is + used. + + Returns: + The full path to the checkpoint. + """ + named_variables, graph_proto = _serialize_object_graph( + self._root_checkpointable) + in_graph_mode = context.in_graph_mode() + if in_graph_mode: + if session is None: + session = ops.get_default_session() + if self._object_graph_feed_tensor is None: + self._object_graph_feed_tensor = constant_op.constant( + "", dtype=dtypes.string) + object_graph_tensor = self._object_graph_feed_tensor + feed_additions = {object_graph_tensor: graph_proto.SerializeToString()} + else: + session = None + object_graph_tensor = constant_op.constant( + graph_proto.SerializeToString(), dtype=dtypes.string) + feed_additions = None + assert _OBJECT_GRAPH_PROTO_KEY not in named_variables + named_variables[_OBJECT_GRAPH_PROTO_KEY] = _NoRestoreSaveable( + tensor=object_graph_tensor, + name=_OBJECT_GRAPH_PROTO_KEY) + if not in_graph_mode or self._last_save_object_graph != graph_proto: + if self._last_save_object_graph is not None and in_graph_mode: + raise NotImplementedError( + "Using a single Saver to save a mutated object graph is not " + "currently supported when graph building. Use a different Saver " + "when the object graph changes (save ops will be duplicated), or " + "file a feature request if this limitation bothers you.") + saver = saver_lib.Saver(var_list=named_variables) + if in_graph_mode: + self._last_save_saver = saver + self._last_save_object_graph = graph_proto + else: + saver = self._last_save_saver + save_path = saver.save( + sess=_SessionWithFeedDictAdditions( + session=session, feed_additions=feed_additions), + save_path=file_prefix, + write_meta_graph=False, + global_step=checkpoint_number) + return save_path + + def restore(self, save_path, session=None): + """Restore a training checkpoint. + + Restores `root_checkpointable` and any objects that it tracks + (transitive). Either assigns values immediately if variables to restore have + been created already, or defers restoration until the variables are + created. Dependencies added to the `root_checkpointable` passed to the + constructor after this call will be matched if they have a corresponding + object in the checkpoint. + + When building a graph, restorations are added to the graph but not run. A + session is required to retrieve checkpoint metadata. + + To disallow deferred loading, assert immediately that all checkpointed + variables have been matched to variable objects: + + ```python + saver = Saver(root) + saver.restore(path).assert_consumed() + ``` + + An exception will be raised unless every object was matched and its + variables already exist. + + When graph building, `assert_consumed()` indicates that all of the restore + ops which will be created for this checkpoint have been created. They can be + run via the `run_restore_ops()` function of the status object: + + ```python + saver.restore(path).assert_consumed().run_restore_ops() + ``` + + If the checkpoint has not been consumed completely, then the list of restore + ops will grow as more objects are added to the dependency graph. + + Args: + save_path: The path to the checkpoint, as returned by `save` or + `tf.train.latest_checkpoint`. 
If None (as when there is no latest + checkpoint for `tf.train.latest_checkpoint` to return), does nothing. + session: The session to retrieve metadata with. Ignored when executing + eagerly. If not provided when graph building, the default session is + used. + + Returns: + A `CheckpointLoadStatus` object, which can be used to make assertions + about the status of checkpoint restoration and run restore ops. + """ + if save_path is None: + return + in_graph_mode = context.in_graph_mode() + if in_graph_mode: + if session is None: + session = ops.get_default_session() + file_prefix_tensor = self._file_prefix_placeholder + file_prefix_feed_dict = {self._file_prefix_placeholder: save_path} + else: + session = None + file_prefix_tensor = constant_op.constant(save_path) + file_prefix_feed_dict = None + if not in_graph_mode or self._object_graph_restore_tensor is None: + object_graph_string, = io_ops.restore_v2( + prefix=file_prefix_tensor, + tensor_names=[_OBJECT_GRAPH_PROTO_KEY], + shape_and_slices=[""], + dtypes=[dtypes.string], + name="object_graph_proto_read") + if in_graph_mode: + self._object_graph_restore_tensor = object_graph_string + if in_graph_mode: + object_graph_string = session.run( + self._object_graph_restore_tensor, + feed_dict=file_prefix_feed_dict) + else: + object_graph_string = object_graph_string.numpy() + object_graph_proto = ( + checkpointable_object_graph_pb2.CheckpointableObjectGraph()) + object_graph_proto.ParseFromString(object_graph_string) + if in_graph_mode and object_graph_proto == self._last_restore_object_graph: + checkpoint = self._last_restore_checkpoint + else: + if in_graph_mode: + dtype_map = None + else: + reader = pywrap_tensorflow.NewCheckpointReader(save_path) + dtype_map = reader.get_variable_to_dtype_map() + checkpoint = core_checkpointable_utils._Checkpoint( # pylint: disable=protected-access + object_graph_proto=object_graph_proto, + save_path=file_prefix_tensor, + dtype_map=dtype_map) + if in_graph_mode: + if self._last_restore_object_graph is not None: + raise NotImplementedError( + "Using a single Saver to restore different object graphs is not " + "currently supported when graph building. 
Use a different Saver " + "for each object graph (restore ops will be duplicated), or " + "file a feature request if this limitation bothers you.") + self._last_restore_checkpoint = checkpoint + self._last_restore_object_graph = object_graph_proto + core_checkpointable._CheckpointPosition( # pylint: disable=protected-access + checkpoint=checkpoint, proto_id=0).restore(self._root_checkpointable) + load_status = CheckpointLoadStatus( + checkpoint, feed_dict=file_prefix_feed_dict) + return load_status diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/eager/python/checkpointable_utils_test.py index 21ba6ad..b7554de 100644 --- a/tensorflow/contrib/eager/python/checkpointable_utils_test.py +++ b/tensorflow/contrib/eager/python/checkpointable_utils_test.py @@ -18,7 +18,7 @@ from __future__ import print_function import functools import os -import unittest +import weakref import six @@ -141,6 +141,7 @@ class Checkpoint(checkpointable.Checkpointable): for k, v in sorted(kwargs.items(), key=lambda item: item[0]): setattr(self, k, v) self._save_counter = None + self._saver = checkpointable_utils.Saver(weakref.ref(self)) @property def save_counter(self): @@ -163,16 +164,14 @@ class Checkpoint(checkpointable.Checkpointable): if session is None: session = ops.get_default_session() session.run(assign_op) - return checkpointable_utils.save( + return self._saver.save( file_prefix=file_prefix, - root_checkpointable=self, checkpoint_number=self.save_counter, session=session) def restore(self, save_path): - return checkpointable_utils.restore( - save_path=save_path, - root_checkpointable=self) + return self._saver.restore( + save_path=save_path) class InterfaceTests(test.TestCase): @@ -399,7 +398,7 @@ class CheckpointingTests(test.TestCase): self.evaluate(state_ops.assign(m_bias_slot, [-2.])) # Immediate restoration status = root_checkpointable.restore(save_path=save_path).assert_consumed() - self.evaluate(status.restore_ops) + status.run_restore_ops() self.assertAllEqual([42.], self.evaluate(network._named_dense.variables[1])) self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter)) self.assertAllEqual([1.5], self.evaluate(m_bias_slot)) @@ -480,11 +479,8 @@ class CheckpointingTests(test.TestCase): # if no checkpoint is being loaded. This would make deferred # loading a bit more useful with graph execution. 
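+        # With the class-based Saver, the status returned by `restore` owns
+        # the (cached) restore ops, so `run_restore_ops()` replaces the old
+        # `session.run(status.restore_ops)` pattern.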
else: - status = checkpointable_utils.restore( - save_path=checkpoint_path, - root_checkpointable=root, - session=session).assert_consumed() - session.run(status.restore_ops) + status = root.restore(save_path=checkpoint_path).assert_consumed() + status.run_restore_ops() for _ in range(num_training_steps): session.run(train_op) root.save(file_prefix=checkpoint_prefix, @@ -555,14 +551,14 @@ class CheckpointingTests(test.TestCase): self.evaluate(state_ops.assign(original.dep.var, 123.)) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - save_path = checkpointable_utils.save(checkpoint_prefix, original) + save_path = checkpointable_utils.Saver(original).save(checkpoint_prefix) load_into = LateDependencies() - status = checkpointable_utils.restore(save_path, load_into) + status = checkpointable_utils.Saver(load_into).restore(save_path) with self.assertRaises(AssertionError): status.assert_consumed() load_into.add_dep() status.assert_consumed() - self.evaluate(status.restore_ops) + status.run_restore_ops() self.assertEqual(123., self.evaluate(load_into.dep.var)) @test_util.run_in_graph_and_eager_modes() @@ -586,15 +582,14 @@ class CheckpointingTests(test.TestCase): self.evaluate(state_ops.assign(dep_after_var.dep.var, -14.)) checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - save_path = checkpointable_utils.save( - checkpoint_prefix, dep_after_var) + save_path = checkpointable_utils.Saver(dep_after_var).save( + checkpoint_prefix) loaded_dep_after_var = DepAfterVar() - status = checkpointable_utils.restore( - save_path, loaded_dep_after_var) + status = checkpointable_utils.Saver(loaded_dep_after_var).restore(save_path) loaded_dep_after_var.add_dep() status.assert_consumed() - self.evaluate(status.restore_ops) + status.run_restore_ops() self.assertEqual(-14., self.evaluate(loaded_dep_after_var.dep.var)) @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) @@ -612,27 +607,25 @@ class CheckpointingTests(test.TestCase): else: optimizer.minimize(root.var.read_value) self.evaluate(state_ops.assign(root.var, 12.)) - no_slots_path = checkpointable_utils.save( - os.path.join(checkpoint_directory, "no_slots"), root) + no_slots_path = checkpointable_utils.Saver(root).save( + os.path.join(checkpoint_directory, "no_slots")) root.optimizer = optimizer self.evaluate(state_ops.assign(root.var, 13.)) self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var), 14.)) - slots_path = checkpointable_utils.save( - os.path.join(checkpoint_directory, "with_slots"), root) + slots_path = checkpointable_utils.Saver(root).save( + os.path.join(checkpoint_directory, "with_slots")) new_root = checkpointable.Checkpointable() # Load the slot-containing checkpoint (deferred), then immediately overwrite # the non-slot variable (also deferred). 
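    # (Slot restorations can be doubly deferred: a value waits in
    # `slot_restorations` until its variable exists, then moves to
    # `deferred_slot_restorations` if the optimizer has not been tracked
    # yet; see `_Checkpoint` in checkpointable_utils.py below.)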
- slot_status = checkpointable_utils.restore( - slots_path, new_root) - no_slot_status = checkpointable_utils.restore( - no_slots_path, new_root) + slot_status = checkpointable_utils.Saver(new_root).restore(slots_path) + no_slot_status = checkpointable_utils.Saver(new_root).restore(no_slots_path) with self.assertRaises(AssertionError): no_slot_status.assert_consumed() new_root.var = checkpointable_utils.add_variable( new_root, name="var", shape=[]) no_slot_status.assert_consumed() - self.evaluate(no_slot_status.restore_ops) + no_slot_status.run_restore_ops() self.assertEqual(12., self.evaluate(new_root.var)) new_root.optimizer = CheckpointableAdam(0.1) with self.assertRaisesRegexp(AssertionError, "beta1_power"): @@ -650,7 +643,7 @@ class CheckpointingTests(test.TestCase): train_op = new_root.optimizer.minimize(new_root.var) # The slot variable now exists; restore() didn't create it, but we should # now have a restore op for it. - self.evaluate(slot_status.restore_ops) + slot_status.run_restore_ops() self.assertEqual(14., self.evaluate( new_root.optimizer.get_slot(name="m", var=new_root.var))) self.evaluate(train_op) @@ -666,50 +659,43 @@ class CheckpointingTests(test.TestCase): save_root.dep.var = checkpointable_utils.add_variable( save_root.dep, name="var", initializer=0.) self.evaluate(state_ops.assign(save_root.dep.var, 12.)) - first_path = checkpointable_utils.save( - os.path.join(checkpoint_directory, "first"), save_root) + saver = checkpointable_utils.Saver(save_root) + first_path = saver.save(os.path.join(checkpoint_directory, "first")) self.evaluate(state_ops.assign(save_root.dep.var, 13.)) - second_path = checkpointable_utils.save( - os.path.join(checkpoint_directory, "second"), save_root) + second_path = saver.save(os.path.join(checkpoint_directory, "second")) first_root = checkpointable.Checkpointable() second_root = checkpointable.Checkpointable() - first_status = checkpointable_utils.restore( - first_path, first_root) - second_status = checkpointable_utils.restore( - second_path, second_root) + first_status = checkpointable_utils.Saver(first_root).restore(first_path) + second_status = checkpointable_utils.Saver(second_root).restore(second_path) load_dep = checkpointable.Checkpointable() load_dep.var = checkpointable_utils.add_variable( load_dep, name="var", shape=[]) first_root.dep = load_dep first_status.assert_consumed() - self.evaluate(first_status.restore_ops) - self.assertEqual([], second_status.restore_ops) + first_status.run_restore_ops() self.assertEqual(12., self.evaluate(load_dep.var)) second_root.dep = load_dep second_status.assert_consumed() - self.evaluate(second_status.restore_ops) + second_status.run_restore_ops() self.assertEqual(13., self.evaluate(load_dep.var)) # Try again with the order of the restore() reversed. The last restore # determines the final value. 
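    # (The ordering is decided by `restore_uid`: each `_Checkpoint` takes a
    # fresh `ops.uid()`, and `CheckpointableBase` only assigns values from a
    # checkpoint whose `restore_uid` is greater than the object's
    # `_update_uid`, so the most recent `restore()` wins.)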
first_root = checkpointable.Checkpointable() second_root = checkpointable.Checkpointable() - second_status = checkpointable_utils.restore( - second_path, second_root) - first_status = checkpointable_utils.restore( - first_path, first_root) + second_status = checkpointable_utils.Saver(second_root).restore(second_path) + first_status = checkpointable_utils.Saver(first_root).restore(first_path) load_dep = checkpointable.Checkpointable() load_dep.var = checkpointable_utils.add_variable( load_dep, name="var", shape=[]) first_root.dep = load_dep first_status.assert_consumed() - self.assertEqual([], second_status.restore_ops) - self.evaluate(first_status.restore_ops) + first_status.run_restore_ops() self.assertEqual(12., self.evaluate(load_dep.var)) second_root.dep = load_dep second_status.assert_consumed() - self.evaluate(second_status.restore_ops) + second_status.run_restore_ops() self.assertEqual(12., self.evaluate(load_dep.var)) @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) @@ -724,10 +710,10 @@ class CheckpointingTests(test.TestCase): save_root.dep_two.dep_three = dep_three checkpointable_utils.add_variable(dep_three, name="var", initializer=0.) self.evaluate(variables.global_variables_initializer()) - save_path = checkpointable_utils.save( - os.path.join(checkpoint_directory, "ckpt"), save_root) + save_path = checkpointable_utils.Saver(save_root).save( + os.path.join(checkpoint_directory, "ckpt")) load_root = checkpointable.Checkpointable() - checkpointable_utils.restore(save_path, load_root) + checkpointable_utils.Saver(load_root).restore(save_path) load_root.dep_one = checkpointable.Checkpointable() load_root.dep_two = checkpointable.Checkpointable() load_root.dep_one.dep_three = checkpointable.Checkpointable() @@ -747,8 +733,8 @@ class CheckpointingTests(test.TestCase): checkpointable_utils.add_variable( save_root.dep_two, name="var2", initializer=64., dtype=dtypes.float64) self.evaluate(variables.global_variables_initializer()) - save_path = checkpointable_utils.save( - os.path.join(checkpoint_directory, "ckpt"), save_root) + save_path = checkpointable_utils.Saver(save_root).save( + os.path.join(checkpoint_directory, "ckpt")) load_root = checkpointable.Checkpointable() load_root.dep_one = checkpointable.Checkpointable() load_root.dep_two = load_root.dep_one @@ -756,9 +742,9 @@ class CheckpointingTests(test.TestCase): load_root.dep_one, name="var1", shape=[], dtype=dtypes.float64) v2 = checkpointable_utils.add_variable( load_root.dep_one, name="var2", shape=[], dtype=dtypes.float64) - status = checkpointable_utils.restore( - save_path, load_root).assert_consumed() - self.evaluate(status.restore_ops) + status = checkpointable_utils.Saver(load_root).restore( + save_path).assert_consumed() + status.run_restore_ops() self.assertEqual(32., self.evaluate(v1)) self.assertEqual(64., self.evaluate(v2)) @@ -776,12 +762,12 @@ class CheckpointingTests(test.TestCase): second, "v2", initializer=[1., 1., 2., 3.]) self.evaluate(variables.global_variables_initializer()) checkpoint_directory = self.get_temp_dir() - save_path = checkpointable_utils.save( - os.path.join(checkpoint_directory, "ckpt"), first) + save_path = checkpointable_utils.Saver(first).save( + os.path.join(checkpoint_directory, "ckpt")) # Test deferred loading first_load = checkpointable.Checkpointable() - status = checkpointable_utils.restore(save_path, first_load) + status = checkpointable_utils.Saver(first_load).restore(save_path) second_load = checkpointable.Checkpointable() first_load.second = second_load 
second_load.first = first_load @@ -792,7 +778,7 @@ class CheckpointingTests(test.TestCase): second_load.v = checkpointable_utils.add_variable( second_load, "v2", shape=[4]) status.assert_consumed() - self.evaluate(status.restore_ops) + status.run_restore_ops() self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v)) self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v)) @@ -801,9 +787,9 @@ class CheckpointingTests(test.TestCase): self.assertAllEqual([2., 7., 1.], self.evaluate(first_load.v)) self.evaluate(second_load.v.assign([2., 7., 1., 8.])) self.assertAllEqual([2., 7., 1., 8.], self.evaluate(second_load.v)) - status = checkpointable_utils.restore( - save_path, first_load).assert_consumed() - self.evaluate(status.restore_ops) + status = checkpointable_utils.Saver(first_load).restore( + save_path).assert_consumed() + status.run_restore_ops() self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v)) self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v)) @@ -820,27 +806,24 @@ class CheckpointingTests(test.TestCase): name="blah", initializer=0.) self.evaluate(first.var1.assign(4.)) self.evaluate(first.var2.assign(8.)) - save_path = checkpointable_utils.save( - checkpoint_prefix, root_checkpointable=first) + save_path = checkpointable_utils.Saver(first).save( + checkpoint_prefix) restore_graph = ops.Graph() with restore_graph.as_default(), self.test_session(restore_graph): second = checkpointable.Checkpointable() second.var2 = variable_scope.get_variable( name="blah", initializer=0.) - status = checkpointable_utils.restore( - save_path, root_checkpointable=second) + status = checkpointable_utils.Saver(second).restore(save_path) recreated_var1 = variable_scope.get_variable( name="outside_var", initializer=0.) - self.evaluate(status.restore_ops) + status.run_restore_ops() self.assertEqual(8., self.evaluate(second.var2)) self.evaluate(recreated_var1.assign(-2.)) self.assertEqual(-2., self.evaluate(recreated_var1)) second.var1 = recreated_var1 - self.evaluate(status.restore_ops) + status.run_restore_ops() self.assertEqual(4., self.evaluate(recreated_var1)) - # TODO(allenl): Saver class that doesn't pollute the graph with constants. 
- @unittest.skip("todo") def testManySavesGraph(self): """Saves after the first should not modify the graph.""" with context.graph_mode(): @@ -853,14 +836,12 @@ class CheckpointingTests(test.TestCase): obj.opt = CheckpointableAdam(0.1) obj.opt.minimize(obj.var.read_value()) self.evaluate(variables.global_variables_initializer()) - checkpointable_utils.save( - checkpoint_prefix, root_checkpointable=obj) + saver = checkpointable_utils.Saver(obj) + saver.save(checkpoint_prefix) before_ops = graph.get_operations() - checkpointable_utils.save( - checkpoint_prefix, root_checkpointable=obj) + saver.save(checkpoint_prefix) self.assertEqual(before_ops, graph.get_operations()) - @unittest.skip("todo") def testManyRestoresGraph(self): """Restores after the first should not modify the graph.""" with context.graph_mode(): @@ -873,13 +854,11 @@ class CheckpointingTests(test.TestCase): obj.opt = CheckpointableAdam(0.1) obj.opt.minimize(obj.var.read_value()) self.evaluate(variables.global_variables_initializer()) - save_path = checkpointable_utils.save( - checkpoint_prefix, root_checkpointable=obj) - checkpointable_utils.restore( - save_path, root_checkpointable=obj) + saver = checkpointable_utils.Saver(obj) + save_path = saver.save(checkpoint_prefix) + saver.restore(save_path) before_ops = graph.get_operations() - checkpointable_utils.restore( - save_path, root_checkpointable=obj) + saver.restore(save_path) self.assertEqual(before_ops, graph.get_operations()) if __name__ == "__main__": diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 6a7ece4..4c8c735 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2860,7 +2860,6 @@ py_library( ":dtypes", ":io_ops_gen", ":ops", - ":pywrap_tensorflow", ":util", "//tensorflow/python/eager:context", ], diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index b785d0e..d382683 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -792,17 +792,7 @@ class Variable(checkpointable.CheckpointableBase): setattr(Variable, operator, _run_op) - def _scatter_tensors_from_checkpoint(self, attributes): - """For implementing `Checkpointable`. Return an assignment op to run.""" - if (len(attributes) != 1 - or checkpointable.VARIABLE_VALUE_KEY not in attributes): - raise ValueError( - ("The variable %s was restored with unexpected values (expected one " - "with key %s, got %s)") % ( - self, checkpointable.VARIABLE_VALUE_KEY, attributes)) - return self.assign(attributes[checkpointable.VARIABLE_VALUE_KEY]) - - def _gather_tensors_for_checkpoint(self): + def _gather_saveables_for_checkpoint(self): """For implementing `Checkpointable`. 
This object is saveable on its own.""" return {checkpointable.VARIABLE_VALUE_KEY: self} diff --git a/tensorflow/python/training/checkpointable.py b/tensorflow/python/training/checkpointable.py index 9d62c5f..11caa76 100644 --- a/tensorflow/python/training/checkpointable.py +++ b/tensorflow/python/training/checkpointable.py @@ -18,9 +18,7 @@ from __future__ import division from __future__ import print_function import collections -import weakref -from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -28,7 +26,7 @@ from tensorflow.python.ops import gen_io_ops as io_ops from tensorflow.python.util import nest # A key indicating a variable's value in an object's checkpointed Tensors -# (Checkpointable._gather_tensors_for_checkpoint). If this is the only key and +# (Checkpointable._gather_saveables_for_checkpoint). If this is the only key and # the object has no dependencies, then its value may be restored on object # creation (avoiding double assignment when executing eagerly). VARIABLE_VALUE_KEY = "VARIABLE_VALUE" @@ -57,7 +55,7 @@ class CheckpointInitialValue(ops.Tensor): """ def __init__(self, checkpoint_position, shape=None): - self.wrapped_value = checkpoint_position.restore_ops()[ + self.wrapped_value = checkpoint_position.value_tensors()[ VARIABLE_VALUE_KEY] if shape: # We need to set the static shape information on the initializer if @@ -168,22 +166,86 @@ class _CheckpointPosition(object): and attributes[0].name == VARIABLE_VALUE_KEY and not self.object_proto.children) - def restore_ops(self): - """Create restore ops for this object's attributes.""" - restore_tensors = {} + def value_tensors(self): + """Create value `Tensor`s for this object's attributes. + + Does not require that the Python object has been created. Used for + restore-on-create when executing eagerly. + + Returns: + A dictionary mapping from object attribute names to `Tensor`s. + """ + value_tensors = {} for serialized_tensor in self.object_proto.attributes: checkpoint_key = serialized_tensor.checkpoint_key dtype = self._checkpoint.dtype_map[checkpoint_key] base_type = dtype.base_dtype with ops.init_scope(): - restore, = io_ops.restore_v2( + value, = io_ops.restore_v2( prefix=self._checkpoint.save_path, tensor_names=[checkpoint_key], shape_and_slices=[""], dtypes=[base_type], name="%s_checkpoint_read" % (serialized_tensor.name,)) - restore_tensors[serialized_tensor.name] = restore - return restore_tensors + value_tensors[serialized_tensor.name] = value + return value_tensors + + def restore_ops(self): + """Create or fetch restore ops for this object's attributes. + + Requires that the `Checkpointable` Python object has been bound to an object + ID in the checkpoint. + + Returns: + A list of operations when graph building, or an empty list when executing + eagerly. + """ + saveables = self.checkpointable._gather_saveables_for_checkpoint() # pylint: disable=protected-access + # Name saveables based on the name this object had when it was checkpointed. + named_saveables = {} + restore_ops = [] + in_graph_mode = context.in_graph_mode() + for serialized_tensor in self.object_proto.attributes: + saveable_object = saveables.get(serialized_tensor.name, None) + if saveable_object is None: + # Purposefully does not throw an exception if attributes have been added + # or deleted. 
Stores unused attributes so an exception can be raised if + # the user decides to check that everything in the checkpoint was + # loaded. + self._checkpoint.unused_attributes.setdefault( + self.checkpointable, []).append(serialized_tensor.name) + continue + if in_graph_mode: + existing_ops = self._checkpoint.restore_ops_by_name.get( + serialized_tensor.name, None) + else: + existing_ops = None + if existing_ops is None: + named_saveables[serialized_tensor.checkpoint_key] = saveable_object + if named_saveables: + validated_saveables = ( + self._checkpoint.builder._ValidateAndSliceInputs(named_saveables)) # pylint: disable=protected-access + validated_names = set(saveable.name for saveable in validated_saveables) + if set(named_saveables.keys()) != validated_names: + raise AssertionError( + ("Saveable keys changed when validating. Got back %s, was " + "expecting %s") % (named_saveables.keys(), validated_names)) + all_tensors = self._checkpoint.builder.bulk_restore( + filename_tensor=self._checkpoint.save_path, + saveables=validated_saveables, preferred_shard=-1, + restore_sequentially=False) + saveable_index = 0 + for saveable in validated_saveables: + num_specs = len(saveable.specs) + saveable_tensors = all_tensors[ + saveable_index:saveable_index + num_specs] + saveable_index += num_specs + restore_op = saveable.restore(saveable_tensors, restored_shapes=None) + if in_graph_mode: + assert saveable.name not in self._checkpoint.restore_ops_by_name + self._checkpoint.restore_ops_by_name[saveable.name] = restore_op + restore_ops.append(restore_op) + return restore_ops @property def checkpoint(self): @@ -225,54 +287,6 @@ _SlotVariableRestoration = collections.namedtuple( ]) -class _Checkpoint(object): - """Holds the status of an object-based checkpoint load.""" - - def __init__(self, object_graph_proto, save_path): - """Specify the checkpoint being loaded. - - Args: - object_graph_proto: The CheckpointableObjectGraph protocol buffer - associated with this checkpoint. - save_path: The path to the checkpoint, as returned by - `tf.train.latest_checkpoint`. - """ - self.object_graph_proto = object_graph_proto - self.restore_uid = ops.uid() - # Dictionary mapping from an id in the protocol buffer flat array to - # Checkpointable Python objects. This mapping may be deferred if a - # checkpoint is restored before all dependencies have been tracked. Uses - # weak references so that partial restorations don't create reference cycles - # (as objects with deferred dependencies will generally have references to - # this object). - self.object_by_proto_id = weakref.WeakValueDictionary() - self.save_path = save_path - reader = pywrap_tensorflow.NewCheckpointReader(save_path) - self.dtype_map = reader.get_variable_to_dtype_map() - # When graph building, contains a list of ops to run to restore objects from - # this checkpoint. - self.restore_ops = [] - # A mapping from optimizer proto ids to lists of slot variables to be - # restored when the optimizer is tracked. Only includes slot variables whose - # regular variables have already been created, and only for optimizer - # objects which have not yet been created/tracked. - self.deferred_slot_restorations = {} - # A mapping from variable proto ids to lists of slot variables to be - # restored when the variable is created/tracked. These get shifted over to - # deferred_slot_restorations if the optimizer hasn't been created when that - # happens. 
- self.slot_restorations = {} - for node_index, node in enumerate(self.object_graph_proto.nodes): - for slot_reference in node.slot_variables: - # `node` refers to an `Optimizer`, since only these have slot variables. - self.slot_restorations.setdefault( - slot_reference.original_variable_node_id, []).append( - _SlotVariableRestoration( - optimizer_id=node_index, - slot_variable_id=slot_reference.slot_variable_node_id, - slot_name=slot_reference.slot_name)) - - class CheckpointableBase(object): """Base class for `Checkpointable` objects without automatic dependencies. @@ -415,13 +429,10 @@ class CheckpointableBase(object): Indicates that checkpoints for this object should include variables from `checkpointable`. - Variables in a checkpoint are mapped to `Checkpointable`s based on names if - provided when the checkpoint was written, but otherwise use the order those - `Checkpointable`s were declared as dependencies. - - To avoid breaking existing checkpoints when modifying a class, neither - variable names nor dependency names (the names passed to - `track_checkpointable`) may change. + Variables in a checkpoint are mapped to `Checkpointable`s based on the names + provided when the checkpoint was written. To avoid breaking existing + checkpoints when modifying a class, neither variable names nor dependency + names (the names passed to `_track_checkpointable`) may change. Args: checkpointable: A `Checkpointable` which this object depends on. @@ -493,11 +504,11 @@ class CheckpointableBase(object): # need to actually restore the object. However, we should pass the # restoration on to our dependencies. if checkpoint.restore_uid > self._update_uid: - restore_op = self._scatter_tensors_from_checkpoint( - checkpoint_position.restore_ops()) + restore_ops = checkpoint_position.restore_ops() + # TODO(allenl): Get a list of feeds for saving Python state self._update_uid = checkpoint.restore_uid else: - restore_op = () + restore_ops = () for child in checkpoint_position.object_proto.children: child_position = _CheckpointPosition( checkpoint=checkpoint, @@ -515,25 +526,21 @@ class CheckpointableBase(object): # resolution order (shallowest paths first). The caller is responsible # for emptying visit_queue. visit_queue.append(child_position) - return restore_op + return restore_ops - def _scatter_tensors_from_checkpoint(self, attributes): - """Restores this object from a checkpoint. + def _gather_saveables_for_checkpoint(self): + """Returns a dictionary of values to checkpoint with this object. - Args: - attributes: A dictionary of Tensors, with key corresponding to those - returned from _gather_tensors_for_checkpoint. - Returns: - A restore op to run (if graph building). - """ - if attributes: - raise AssertionError( - ("A Checkpointable object which was not expecting any data received " - "some from a checkpoint. (Got %s)") % (attributes,)) - return () # No restore ops + Keys in the returned dictionary are local to this object and in a separate + namespace from dependencies. Values may either be `SaveableObject`s or + variables easily converted to `SaveableObject`s (as in `tf.train.Saver`'s + `var_list` constructor argument). - def _gather_tensors_for_checkpoint(self): - """Returns a dictionary of Tensors to save with this object.""" + Returned values must be saved only by this object; if any value may be + shared, it should instead be a dependency. For example, variable objects + save their own values with the key `VARIABLE_VALUE_KEY`, but objects which + reference variables simply add a dependency. 
+ """ return {} @@ -562,8 +569,7 @@ class Checkpointable(CheckpointableBase): `Checkpointable` objects may specify `Tensor`s to be saved and restored directly (e.g. a `Variable` indicating how to save itself) rather than through dependencies on other objects. See - `Checkpointable._scatter_tensors_from_checkpoint` and - `Checkpointable._gather_tensors_for_checkpoint` for details. + `Checkpointable._gather_saveables_for_checkpoint` for details. """ def __setattr__(self, name, value): diff --git a/tensorflow/python/training/checkpointable_utils.py b/tensorflow/python/training/checkpointable_utils.py new file mode 100644 index 0000000..32123f8 --- /dev/null +++ b/tensorflow/python/training/checkpointable_utils.py @@ -0,0 +1,78 @@ +"""Utilities for saving/loading Checkpointable objects.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import weakref + +from tensorflow.python.framework import ops +from tensorflow.python.training import checkpointable +from tensorflow.python.training import saver as saver_lib + + +class _Checkpoint(object): + """Holds the status of an object-based checkpoint load.""" + + def __init__(self, object_graph_proto, save_path, dtype_map=None): + """Specify the checkpoint being loaded. + + Args: + object_graph_proto: The CheckpointableObjectGraph protocol buffer + associated with this checkpoint. + save_path: A string `Tensor`. The path to the checkpoint, as returned by + `tf.train.latest_checkpoint`. + dtype_map: When executing eagerly, specifies dtypes for creating slot + variables. None when graph building. + """ + self.builder = saver_lib.BulkSaverBuilder() + self.object_graph_proto = object_graph_proto + self.restore_uid = ops.uid() + # Maps from objects to lists of attributes which were in the checkpoint but + # not loaded into any object, for error checking. + self.unused_attributes = weakref.WeakKeyDictionary() + # Dictionary mapping from an id in the protocol buffer flat array to + # Checkpointable Python objects. This mapping may be deferred if a + # checkpoint is restored before all dependencies have been tracked. Uses + # weak references so that partial restorations don't create reference cycles + # (as objects with deferred dependencies will generally have references to + # this object). + self.object_by_proto_id = weakref.WeakValueDictionary() + self.save_path = save_path + self.dtype_map = dtype_map + # When graph building, contains a list of ops to run to restore objects from + # this checkpoint. + self.restore_ops = [] + self.restore_ops_by_name = {} + # A mapping from optimizer proto ids to lists of slot variables to be + # restored when the optimizer is tracked. 
Only includes slot variables whose + # regular variables have already been created, and only for optimizer + # objects which have not yet been created/tracked. + self.deferred_slot_restorations = {} + # A mapping from variable proto ids to lists of slot variables to be + # restored when the variable is created/tracked. These get shifted over to + # deferred_slot_restorations if the optimizer hasn't been created when that + # happens. + self.slot_restorations = {} + for node_index, node in enumerate(self.object_graph_proto.nodes): + for slot_reference in node.slot_variables: + # `node` refers to an `Optimizer`, since only these have slot variables. + self.slot_restorations.setdefault( + slot_reference.original_variable_node_id, []).append( + checkpointable._SlotVariableRestoration( # pylint: disable=protected-access + optimizer_id=node_index, + slot_variable_id=slot_reference.slot_variable_node_id, + slot_name=slot_reference.slot_name)) -- 2.7.4
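Taken together, the API this patch introduces can be exercised end to end as in the following minimal sketch (based on the tests above; the import paths, checkpoint prefix, and the assumption that eager execution is enabled are illustrative rather than part of the patch):

```python
# Minimal sketch of the Saver introduced in this patch. Assumes eager
# execution is enabled; when graph building, a session is required and
# restore ops must be run explicitly, as in testUsageGraph above.
from tensorflow.contrib.eager.python import checkpointable_utils
from tensorflow.python.training import checkpointable

root = checkpointable.Checkpointable()
root.var = checkpointable_utils.add_variable(root, name="var", initializer=3.)

# Reusing a single Saver lets graph-mode save ops be cached across calls.
saver = checkpointable_utils.Saver(root)
save_path = saver.save("/tmp/ckpt")  # illustrative prefix

# Restore into a fresh object graph. The variable does not exist yet, so its
# restoration is deferred until it is created and tracked.
new_root = checkpointable.Checkpointable()
status = checkpointable_utils.Saver(new_root).restore(save_path)
new_root.var = checkpointable_utils.add_variable(new_root, name="var", shape=[])
status.assert_consumed().run_restore_ops()  # run_restore_ops is a no-op eagerly
```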