Checkpointable: Re-use the Saver's SaveableObject infrastructure to create
authorAllen Lavoie <allenl@google.com>
Thu, 22 Feb 2018 20:01:57 +0000 (12:01 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Feb 2018 20:05:52 +0000 (12:05 -0800)
restore ops, cache them.

The basic infrastructure is usable after this CL in graph and eager, but I still need to make a bunch of objects Checkpointable and make some other usability fixes.

Also sets some of the groundwork for feeding and fetching Python values during save/restore (as in, save has a feed dict now; gathering feeds and placeholders from Checkpointable objects is still to do but should be relatively straightforward).

PiperOrigin-RevId: 186652696

tensorflow/contrib/eager/python/checkpointable_utils.py
tensorflow/contrib/eager/python/checkpointable_utils_test.py
tensorflow/python/BUILD
tensorflow/python/ops/variables.py
tensorflow/python/training/checkpointable.py
tensorflow/python/training/checkpointable_utils.py [new file with mode: 0644]

index 0506af3..d9648ff 100644 (file)
@@ -18,8 +18,11 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import weakref
 
 from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.client import session as session_lib
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -31,6 +34,7 @@ from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.training import checkpointable as core_checkpointable
+from tensorflow.python.training import checkpointable_utils as core_checkpointable_utils
 from tensorflow.python.training import optimizer as optimizer_lib
 from tensorflow.python.training import saver as saver_lib
 
@@ -214,7 +218,7 @@ def _serialize_checkpointables(
     object_proto.slot_variables.extend(slot_variables.get(checkpointable, ()))
     object_name = object_names[checkpointable]
     for name, saveable in (
-        checkpointable._gather_tensors_for_checkpoint().items()):  # pylint: disable=protected-access
+        checkpointable._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
       attribute = object_proto.attributes.add()
       attribute.name = name
       attribute.checkpoint_key = "%s/%s/%s" % (
@@ -284,58 +288,39 @@ class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
     return control_flow_ops.no_op()
 
 
-def save(file_prefix, root_checkpointable, checkpoint_number=None,
-         session=None):
-  """Save a training checkpoint.
-
-  Args:
-    file_prefix: A prefix to use for the checkpoint filenames
-      (/path/to/directory/and_a_prefix). Names are generated based on this
-      prefix and the global step, if provided.
-    root_checkpointable: A Checkpointable object to save. The checkpoint
-      includes variables created by this object and any Checkpointable objects
-      it depends on.
-    checkpoint_number: An integer variable or Tensor, used to number
-      checkpoints. Typically this value is saved along with other variables in
-      training checkpoints, which will happen automatically if it was created by
-      `root_checkpointable` or one of its dependencies (via
-      `Checkpointable._add_variable`).
-    session: The session to evaluate variables in. Ignored when executing
-      eagerly. If not provided when graph building, the default session is used.
+class CheckpointLoadStatus(object):
+  """Checks the status of checkpoint loading and manages restore ops.
 
-  Returns:
-    The full path to the checkpoint.
-  """
-  named_variables, serialized_graph = _serialize_object_graph(
-      root_checkpointable)
-  if context.in_graph_mode():
-    if session is None:
-      session = ops.get_default_session()
-  else:
-    session = None
-  assert _OBJECT_GRAPH_PROTO_KEY not in named_variables
-  # TODO(allenl): Feed rather than embedding a constant.
-  named_variables[_OBJECT_GRAPH_PROTO_KEY] = _NoRestoreSaveable(
-      tensor=constant_op.constant(
-          serialized_graph.SerializeToString(), dtype=dtypes.string),
-      name=_OBJECT_GRAPH_PROTO_KEY)
-  with ops.device("/device:CPU:0"):
-    save_path = saver_lib.Saver(var_list=named_variables).save(
-        sess=session,
-        save_path=file_prefix,
-        write_meta_graph=False,
-        global_step=checkpoint_number)
-  return save_path
+  Returned from `Saver.restore`. Since `restore` may defer the loading of values
+  in the checkpoint which don't yet have corresponding Python objects,
+  `CheckpointLoadStatus` provides a callback to verify that checkpoint loading
+  is complete (`assert_consumed`).
 
+  When graph building, `restore` does not run restore ops itself since their
+  creation may be deferred. The `run_restore_ops` method must be called once all
+  Python objects with values to restore have been created and added to the
+  dependency graph (this does not necessarily have to be the whole checkpoint;
+  calling `run_restore_ops` while `assert_consumed` fails is supported and will
+  partially restore the checkpoint).
 
-class CheckpointLoadStatus(object):
-  """Checks the status of checkpoint loading."""
+  See `Saver.restore` for usage examples.
+  """
 
-  def __init__(self, checkpoint):
+  def __init__(self, checkpoint, feed_dict):
     self._checkpoint = checkpoint
+    self._feed_dict = feed_dict
 
   def assert_consumed(self):
-    """Asserts that all objects in the checkpoint have been created/matched."""
+    """Asserts that all objects in the checkpoint have been created/matched.
+
+    Returns:
+      `self` for chaining.
+    Raises:
+      AssertionError: If there are any Python objects in the dependency graph
+        which have not been restored from this checkpoint or a later `restore`,
+        or if there are any checkpointed values which have not been matched to
+        Python objects.
+    """
     for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
       checkpointable = self._checkpoint.object_by_proto_id.get(node_id, None)
       if checkpointable is None:
@@ -348,88 +333,256 @@ class CheckpointLoadStatus(object):
       # restored.
       raise AssertionError("Unresolved slot restorations: %s" % (
           self._checkpoint.slot_restorations,))
+    if self._checkpoint.unused_attributes:
+      raise AssertionError(
+          ("Unused attributes in these objects (the attributes exist in the "
+           "checkpoint but not in the objects): %s") % (
+               self._checkpoint.unused_attributes.items(),))
     return self
 
-  @property
-  def restore_ops(self):
-    """Operations to restore objects in the dependency graph."""
-    return self._checkpoint.restore_ops
-
-
-def restore(save_path, root_checkpointable, session=None):
-  """Restore a training checkpoint.
-
-  Restores the values of variables created with `Checkpointable._add_variable`
-  in `root_checkpointable` and any objects that it tracks (transitive). Either
-  assigns values immediately if variables to restore have been created already,
-  or defers restoration until the variables are created. Dependencies added to
-  `root_checkpointable` after this call will be matched if they have a
-  corresponding object in the checkpoint.
+  def run_restore_ops(self, session=None):
+    """Run operations to restore objects in the dependency graph."""
+    if context.in_eager_mode():
+      return  # Run eagerly
+    if session is None:
+      session = ops.get_default_session()
+    session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict)
 
-  When building a graph, restorations are added to the graph but not run. A
-  session is required to retrieve checkpoint metadata.
 
-  To disallow deferred loading, assert immediately that all checkpointed
-  variables have been matched to variable objects:
+class _SessionWithFeedDictAdditions(session_lib.SessionInterface):
+  """Pretends to be a session, inserts extra feeds on run()."""
 
-  ```python
-  restore(path, root).assert_consumed()
-  ```
+  def __init__(self, session, feed_additions):
+    self._wrapped_session = session
+    self._feed_additions = feed_additions
 
-  An exception will be raised unless every object was matched and its variables
-  already exist.
+  def run(self, fetches, feed_dict=None, **kwargs):
+    if feed_dict is None:
+      feed_dict = {}
+    else:
+      feed_dict = feed_dict.copy()
+    feed_dict.update(self._feed_additions)
+    return self._wrapped_session.run(
+        fetches=fetches, feed_dict=feed_dict, **kwargs)
+
+
+class Saver(object):
+  """Saves and restores a `Checkpointable` object and its dependencies.
+
+  See `Checkpointable` for details of dependency management. `Saver` wraps
+  `tf.train.Saver` for saving, including extra information about the graph of
+  dependencies between Python objects. When restoring, it uses this information
+  about the save-time dependency graph to more robustly match objects with their
+  checkpointed values. When executing eagerly, it supports restoring variables
+  on object creation (see `Saver.restore`).
+
+  Values in a checkpoint are mapped to `Checkpointable` Python objects
+  (`Variable`s, `Optimizer`s, `Layer`s) based on the names provided when the
+  checkpoint was written. To avoid breaking existing checkpoints when modifying
+  a class, dependency names (the names of attributes to which `Checkpointable`
+  objects are assigned) may not change. These names are local to objects, in
+  contrast to the `Variable.name`-based save/restore from `tf.train.Saver`, and
+  so allow additional program transformations.
+  """
 
-  When graph building, `assert_consumed()` indicates that all of the restore ops
-  which will be created for this checkpoint have been created. They are
-  available in the `restore_ops` property of the status object:
+  def __init__(self, root_checkpointable):
+    """Configure saving.
+
+    Args:
+      root_checkpointable: The root of the object graph to save/restore. This
+        object and all of its dependencies are saved in the checkpoint. When
+        restoring, objects are matched and restored starting from this root.
+    """
+    # Allow passing in a weak reference to avoid reference cycles when
+    # `Checkpointable` objects save themselves.
+    self._root_checkpointable_ref = root_checkpointable
+    if context.in_graph_mode():
+      self._file_prefix_placeholder = constant_op.constant("model")
+    else:
+      self._file_prefix_placeholder = None
 
-  ```python
-  session.run(restore(path, root).assert_consumed().restore_ops)
-  ```
+    # Op caching for save
+    self._object_graph_feed_tensor = None
+    self._last_save_object_graph = None
+    self._last_save_saver = None
 
-  If the checkpoint has not been consumed completely, then the list of
-  `restore_ops` will grow as more objects are added to the dependency graph.
+    # Op caching for restore
+    self._object_graph_restore_tensor = None
+    self._last_restore_object_graph = None
+    self._last_restore_checkpoint = None
 
-  Args:
-    save_path: The path to the checkpoint, as returned by `save` or
-      `tf.train.latest_checkpoint`. If None (as when there is no latest
-      checkpoint for `tf.train.latest_checkpoint` to return), does nothing.
-    root_checkpointable: The root of the object graph to restore. Variables to
-      restore need not have been created yet, but all dependencies on other
-      `Checkpointable` objects should already be declared. Objects in the
-      dependency graph are matched to objects in the checkpointed graph, and
-      matching objects have their variables restored (or the checkpointed values
-      saved for eventual restoration when the variable is created).
-    session: The session to retrieve metadata with. Ignored when executing
-      eagerly. If not provided when graph building, the default session is used.
-  Returns:
-    A `CheckpointLoadStatus` object, which can be used to make assertions about
-    the status of checkpoint restoration and fetch restore ops.
-  """
-  if save_path is None:
-    return
-  if context.in_graph_mode():
-    if session is None:
-      session = ops.get_default_session()
-  else:
-    session = None
-  object_graph_string, = io_ops.restore_v2(
-      prefix=save_path,
-      tensor_names=[_OBJECT_GRAPH_PROTO_KEY],
-      shape_and_slices=[""],
-      dtypes=[dtypes.string],
-      name="object_graph_proto_read")
-  if session is not None:
-    object_graph_string = session.run(object_graph_string)
-  else:
-    object_graph_string = object_graph_string.numpy()
-  object_graph_proto = (
-      checkpointable_object_graph_pb2.CheckpointableObjectGraph())
-  object_graph_proto.ParseFromString(object_graph_string)
-  checkpoint = core_checkpointable._Checkpoint(  # pylint: disable=protected-access
-      object_graph_proto=object_graph_proto,
-      save_path=save_path)
-  core_checkpointable._CheckpointPosition(  # pylint: disable=protected-access
-      checkpoint=checkpoint, proto_id=0).restore(root_checkpointable)
-  load_status = CheckpointLoadStatus(checkpoint)
-  return load_status
+  @property
+  def _root_checkpointable(self):
+    if isinstance(self._root_checkpointable_ref, weakref.ref):
+      derefed = self._root_checkpointable_ref()
+      assert derefed is not None
+      return derefed
+    else:
+      return self._root_checkpointable_ref
+
+  def save(self, file_prefix, checkpoint_number=None, session=None):
+    """Save a training checkpoint.
+
+    The saved checkpoint includes variables created by this object and any
+    Checkpointable objects it depends on at the time `Saver.save()` is called.
+
+    Args:
+      file_prefix: A prefix to use for the checkpoint filenames
+        (/path/to/directory/and_a_prefix). Names are generated based on this
+        prefix and the global step, if provided.
+      checkpoint_number: An integer variable or Tensor, used to number
+        checkpoints. Typically this value is saved along with other variables in
+        training checkpoints, which will happen automatically if it was created
+        by `root_checkpointable` or one of its dependencies (via
+        `Checkpointable._add_variable`).
+      session: The session to evaluate variables in. Ignored when executing
+        eagerly. If not provided when graph building, the default session is
+        used.
+
+    Returns:
+      The full path to the checkpoint.
+    """
+    named_variables, graph_proto = _serialize_object_graph(
+        self._root_checkpointable)
+    in_graph_mode = context.in_graph_mode()
+    if in_graph_mode:
+      if session is None:
+        session = ops.get_default_session()
+      if self._object_graph_feed_tensor is None:
+        self._object_graph_feed_tensor = constant_op.constant(
+            "", dtype=dtypes.string)
+      object_graph_tensor = self._object_graph_feed_tensor
+      feed_additions = {object_graph_tensor: graph_proto.SerializeToString()}
+    else:
+      session = None
+      object_graph_tensor = constant_op.constant(
+          graph_proto.SerializeToString(), dtype=dtypes.string)
+      feed_additions = None
+    assert _OBJECT_GRAPH_PROTO_KEY not in named_variables
+    named_variables[_OBJECT_GRAPH_PROTO_KEY] = _NoRestoreSaveable(
+        tensor=object_graph_tensor,
+        name=_OBJECT_GRAPH_PROTO_KEY)
+    if not in_graph_mode or self._last_save_object_graph != graph_proto:
+      if self._last_save_object_graph is not None and in_graph_mode:
+        raise NotImplementedError(
+            "Using a single Saver to save a mutated object graph is not "
+            "currently supported when graph building. Use a different Saver "
+            "when the object graph changes (save ops will be duplicated), or "
+            "file a feature request if this limitation bothers you.")
+      saver = saver_lib.Saver(var_list=named_variables)
+      if in_graph_mode:
+        self._last_save_saver = saver
+        self._last_save_object_graph = graph_proto
+    else:
+      saver = self._last_save_saver
+    save_path = saver.save(
+        sess=_SessionWithFeedDictAdditions(
+            session=session, feed_additions=feed_additions),
+        save_path=file_prefix,
+        write_meta_graph=False,
+        global_step=checkpoint_number)
+    return save_path
+
+  def restore(self, save_path, session=None):
+    """Restore a training checkpoint.
+
+    Restores `root_checkpointable` and any objects that it tracks
+    (transitive). Either assigns values immediately if variables to restore have
+    been created already, or defers restoration until the variables are
+    created. Dependencies added to the `root_checkpointable` passed to the
+    constructor after this call will be matched if they have a corresponding
+    object in the checkpoint.
+
+    When building a graph, restorations are added to the graph but not run. A
+    session is required to retrieve checkpoint metadata.
+
+    To disallow deferred loading, assert immediately that all checkpointed
+    variables have been matched to variable objects:
+
+    ```python
+    saver = Saver(root)
+    saver.restore(path).assert_consumed()
+    ```
+
+    An exception will be raised unless every object was matched and its
+    variables already exist.
+
+    When graph building, `assert_consumed()` indicates that all of the restore
+    ops which will be created for this checkpoint have been created. They can be
+    run via the `run_restore_ops()` function of the status object:
+
+    ```python
+    saver.restore(path).assert_consumed().run_restore_ops()
+    ```
+
+    If the checkpoint has not been consumed completely, then the list of restore
+    ops will grow as more objects are added to the dependency graph.
+
+    Args:
+      save_path: The path to the checkpoint, as returned by `save` or
+        `tf.train.latest_checkpoint`. If None (as when there is no latest
+        checkpoint for `tf.train.latest_checkpoint` to return), does nothing.
+      session: The session to retrieve metadata with. Ignored when executing
+        eagerly. If not provided when graph building, the default session is
+        used.
+
+    Returns:
+      A `CheckpointLoadStatus` object, which can be used to make assertions
+      about the status of checkpoint restoration and run restore ops.
+    """
+    if save_path is None:
+      return
+    in_graph_mode = context.in_graph_mode()
+    if in_graph_mode:
+      if session is None:
+        session = ops.get_default_session()
+      file_prefix_tensor = self._file_prefix_placeholder
+      file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
+    else:
+      session = None
+      file_prefix_tensor = constant_op.constant(save_path)
+      file_prefix_feed_dict = None
+    if not in_graph_mode or self._object_graph_restore_tensor is None:
+      object_graph_string, = io_ops.restore_v2(
+          prefix=file_prefix_tensor,
+          tensor_names=[_OBJECT_GRAPH_PROTO_KEY],
+          shape_and_slices=[""],
+          dtypes=[dtypes.string],
+          name="object_graph_proto_read")
+      if in_graph_mode:
+        self._object_graph_restore_tensor = object_graph_string
+    if in_graph_mode:
+      object_graph_string = session.run(
+          self._object_graph_restore_tensor,
+          feed_dict=file_prefix_feed_dict)
+    else:
+      object_graph_string = object_graph_string.numpy()
+    object_graph_proto = (
+        checkpointable_object_graph_pb2.CheckpointableObjectGraph())
+    object_graph_proto.ParseFromString(object_graph_string)
+    if in_graph_mode and object_graph_proto == self._last_restore_object_graph:
+      checkpoint = self._last_restore_checkpoint
+    else:
+      if in_graph_mode:
+        dtype_map = None
+      else:
+        reader = pywrap_tensorflow.NewCheckpointReader(save_path)
+        dtype_map = reader.get_variable_to_dtype_map()
+      checkpoint = core_checkpointable_utils._Checkpoint(  # pylint: disable=protected-access
+          object_graph_proto=object_graph_proto,
+          save_path=file_prefix_tensor,
+          dtype_map=dtype_map)
+      if in_graph_mode:
+        if self._last_restore_object_graph is not None:
+          raise NotImplementedError(
+              "Using a single Saver to restore different object graphs is not "
+              "currently supported when graph building. Use a different Saver "
+              "for each object graph (restore ops will be duplicated), or "
+              "file a feature request if this limitation bothers you.")
+        self._last_restore_checkpoint = checkpoint
+        self._last_restore_object_graph = object_graph_proto
+    core_checkpointable._CheckpointPosition(  # pylint: disable=protected-access
+        checkpoint=checkpoint, proto_id=0).restore(self._root_checkpointable)
+    load_status = CheckpointLoadStatus(
+        checkpoint, feed_dict=file_prefix_feed_dict)
+    return load_status
index 21ba6ad..b7554de 100644 (file)
@@ -18,7 +18,7 @@ from __future__ import print_function
 
 import functools
 import os
-import unittest
+import weakref
 
 import six
 
@@ -141,6 +141,7 @@ class Checkpoint(checkpointable.Checkpointable):
     for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
       setattr(self, k, v)
     self._save_counter = None
+    self._saver = checkpointable_utils.Saver(weakref.ref(self))
 
   @property
   def save_counter(self):
@@ -163,16 +164,14 @@ class Checkpoint(checkpointable.Checkpointable):
       if session is None:
         session = ops.get_default_session()
       session.run(assign_op)
-    return checkpointable_utils.save(
+    return self._saver.save(
         file_prefix=file_prefix,
-        root_checkpointable=self,
         checkpoint_number=self.save_counter,
         session=session)
 
   def restore(self, save_path):
-    return checkpointable_utils.restore(
-        save_path=save_path,
-        root_checkpointable=self)
+    return self._saver.restore(
+        save_path=save_path)
 
 
 class InterfaceTests(test.TestCase):
@@ -399,7 +398,7 @@ class CheckpointingTests(test.TestCase):
     self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
     # Immediate restoration
     status = root_checkpointable.restore(save_path=save_path).assert_consumed()
-    self.evaluate(status.restore_ops)
+    status.run_restore_ops()
     self.assertAllEqual([42.], self.evaluate(network._named_dense.variables[1]))
     self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter))
     self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
@@ -480,11 +479,8 @@ class CheckpointingTests(test.TestCase):
               # if no checkpoint is being loaded. This would make deferred
               # loading a bit more useful with graph execution.
             else:
-              status = checkpointable_utils.restore(
-                  save_path=checkpoint_path,
-                  root_checkpointable=root,
-                  session=session).assert_consumed()
-              session.run(status.restore_ops)
+              status = root.restore(save_path=checkpoint_path).assert_consumed()
+              status.run_restore_ops()
             for _ in range(num_training_steps):
               session.run(train_op)
             root.save(file_prefix=checkpoint_prefix,
@@ -555,14 +551,14 @@ class CheckpointingTests(test.TestCase):
     self.evaluate(state_ops.assign(original.dep.var, 123.))
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
-    save_path = checkpointable_utils.save(checkpoint_prefix, original)
+    save_path = checkpointable_utils.Saver(original).save(checkpoint_prefix)
     load_into = LateDependencies()
-    status = checkpointable_utils.restore(save_path, load_into)
+    status = checkpointable_utils.Saver(load_into).restore(save_path)
     with self.assertRaises(AssertionError):
       status.assert_consumed()
     load_into.add_dep()
     status.assert_consumed()
-    self.evaluate(status.restore_ops)
+    status.run_restore_ops()
     self.assertEqual(123., self.evaluate(load_into.dep.var))
 
   @test_util.run_in_graph_and_eager_modes()
@@ -586,15 +582,14 @@ class CheckpointingTests(test.TestCase):
     self.evaluate(state_ops.assign(dep_after_var.dep.var, -14.))
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
-    save_path = checkpointable_utils.save(
-        checkpoint_prefix, dep_after_var)
+    save_path = checkpointable_utils.Saver(dep_after_var).save(
+        checkpoint_prefix)
 
     loaded_dep_after_var = DepAfterVar()
-    status = checkpointable_utils.restore(
-        save_path, loaded_dep_after_var)
+    status = checkpointable_utils.Saver(loaded_dep_after_var).restore(save_path)
     loaded_dep_after_var.add_dep()
     status.assert_consumed()
-    self.evaluate(status.restore_ops)
+    status.run_restore_ops()
     self.assertEqual(-14., self.evaluate(loaded_dep_after_var.dep.var))
 
   @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
@@ -612,27 +607,25 @@ class CheckpointingTests(test.TestCase):
     else:
       optimizer.minimize(root.var.read_value)
     self.evaluate(state_ops.assign(root.var, 12.))
-    no_slots_path = checkpointable_utils.save(
-        os.path.join(checkpoint_directory, "no_slots"), root)
+    no_slots_path = checkpointable_utils.Saver(root).save(
+        os.path.join(checkpoint_directory, "no_slots"))
     root.optimizer = optimizer
     self.evaluate(state_ops.assign(root.var, 13.))
     self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                    14.))
-    slots_path = checkpointable_utils.save(
-        os.path.join(checkpoint_directory, "with_slots"), root)
+    slots_path = checkpointable_utils.Saver(root).save(
+        os.path.join(checkpoint_directory, "with_slots"))
     new_root = checkpointable.Checkpointable()
     # Load the slot-containing checkpoint (deferred), then immediately overwrite
     # the non-slot variable (also deferred).
-    slot_status = checkpointable_utils.restore(
-        slots_path, new_root)
-    no_slot_status = checkpointable_utils.restore(
-        no_slots_path, new_root)
+    slot_status = checkpointable_utils.Saver(new_root).restore(slots_path)
+    no_slot_status = checkpointable_utils.Saver(new_root).restore(no_slots_path)
     with self.assertRaises(AssertionError):
       no_slot_status.assert_consumed()
     new_root.var = checkpointable_utils.add_variable(
         new_root, name="var", shape=[])
     no_slot_status.assert_consumed()
-    self.evaluate(no_slot_status.restore_ops)
+    no_slot_status.run_restore_ops()
     self.assertEqual(12., self.evaluate(new_root.var))
     new_root.optimizer = CheckpointableAdam(0.1)
     with self.assertRaisesRegexp(AssertionError, "beta1_power"):
@@ -650,7 +643,7 @@ class CheckpointingTests(test.TestCase):
       train_op = new_root.optimizer.minimize(new_root.var)
       # The slot variable now exists; restore() didn't create it, but we should
       # now have a restore op for it.
-      self.evaluate(slot_status.restore_ops)
+      slot_status.run_restore_ops()
       self.assertEqual(14., self.evaluate(
           new_root.optimizer.get_slot(name="m", var=new_root.var)))
       self.evaluate(train_op)
@@ -666,50 +659,43 @@ class CheckpointingTests(test.TestCase):
     save_root.dep.var = checkpointable_utils.add_variable(
         save_root.dep, name="var", initializer=0.)
     self.evaluate(state_ops.assign(save_root.dep.var, 12.))
-    first_path = checkpointable_utils.save(
-        os.path.join(checkpoint_directory, "first"), save_root)
+    saver = checkpointable_utils.Saver(save_root)
+    first_path = saver.save(os.path.join(checkpoint_directory, "first"))
     self.evaluate(state_ops.assign(save_root.dep.var, 13.))
-    second_path = checkpointable_utils.save(
-        os.path.join(checkpoint_directory, "second"), save_root)
+    second_path = saver.save(os.path.join(checkpoint_directory, "second"))
 
     first_root = checkpointable.Checkpointable()
     second_root = checkpointable.Checkpointable()
-    first_status = checkpointable_utils.restore(
-        first_path, first_root)
-    second_status = checkpointable_utils.restore(
-        second_path, second_root)
+    first_status = checkpointable_utils.Saver(first_root).restore(first_path)
+    second_status = checkpointable_utils.Saver(second_root).restore(second_path)
     load_dep = checkpointable.Checkpointable()
     load_dep.var = checkpointable_utils.add_variable(
         load_dep, name="var", shape=[])
     first_root.dep = load_dep
     first_status.assert_consumed()
-    self.evaluate(first_status.restore_ops)
-    self.assertEqual([], second_status.restore_ops)
+    first_status.run_restore_ops()
     self.assertEqual(12., self.evaluate(load_dep.var))
     second_root.dep = load_dep
     second_status.assert_consumed()
-    self.evaluate(second_status.restore_ops)
+    second_status.run_restore_ops()
     self.assertEqual(13., self.evaluate(load_dep.var))
 
     # Try again with the order of the restore() reversed. The last restore
     # determines the final value.
     first_root = checkpointable.Checkpointable()
     second_root = checkpointable.Checkpointable()
-    second_status = checkpointable_utils.restore(
-        second_path, second_root)
-    first_status = checkpointable_utils.restore(
-        first_path, first_root)
+    second_status = checkpointable_utils.Saver(second_root).restore(second_path)
+    first_status = checkpointable_utils.Saver(first_root).restore(first_path)
     load_dep = checkpointable.Checkpointable()
     load_dep.var = checkpointable_utils.add_variable(
         load_dep, name="var", shape=[])
     first_root.dep = load_dep
     first_status.assert_consumed()
-    self.assertEqual([], second_status.restore_ops)
-    self.evaluate(first_status.restore_ops)
+    first_status.run_restore_ops()
     self.assertEqual(12., self.evaluate(load_dep.var))
     second_root.dep = load_dep
     second_status.assert_consumed()
-    self.evaluate(second_status.restore_ops)
+    second_status.run_restore_ops()
     self.assertEqual(12., self.evaluate(load_dep.var))
 
   @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
@@ -724,10 +710,10 @@ class CheckpointingTests(test.TestCase):
     save_root.dep_two.dep_three = dep_three
     checkpointable_utils.add_variable(dep_three, name="var", initializer=0.)
     self.evaluate(variables.global_variables_initializer())
-    save_path = checkpointable_utils.save(
-        os.path.join(checkpoint_directory, "ckpt"), save_root)
+    save_path = checkpointable_utils.Saver(save_root).save(
+        os.path.join(checkpoint_directory, "ckpt"))
     load_root = checkpointable.Checkpointable()
-    checkpointable_utils.restore(save_path, load_root)
+    checkpointable_utils.Saver(load_root).restore(save_path)
     load_root.dep_one = checkpointable.Checkpointable()
     load_root.dep_two = checkpointable.Checkpointable()
     load_root.dep_one.dep_three = checkpointable.Checkpointable()
@@ -747,8 +733,8 @@ class CheckpointingTests(test.TestCase):
     checkpointable_utils.add_variable(
         save_root.dep_two, name="var2", initializer=64., dtype=dtypes.float64)
     self.evaluate(variables.global_variables_initializer())
-    save_path = checkpointable_utils.save(
-        os.path.join(checkpoint_directory, "ckpt"), save_root)
+    save_path = checkpointable_utils.Saver(save_root).save(
+        os.path.join(checkpoint_directory, "ckpt"))
     load_root = checkpointable.Checkpointable()
     load_root.dep_one = checkpointable.Checkpointable()
     load_root.dep_two = load_root.dep_one
@@ -756,9 +742,9 @@ class CheckpointingTests(test.TestCase):
         load_root.dep_one, name="var1", shape=[], dtype=dtypes.float64)
     v2 = checkpointable_utils.add_variable(
         load_root.dep_one, name="var2", shape=[], dtype=dtypes.float64)
-    status = checkpointable_utils.restore(
-        save_path, load_root).assert_consumed()
-    self.evaluate(status.restore_ops)
+    status = checkpointable_utils.Saver(load_root).restore(
+        save_path).assert_consumed()
+    status.run_restore_ops()
     self.assertEqual(32., self.evaluate(v1))
     self.assertEqual(64., self.evaluate(v2))
 
@@ -776,12 +762,12 @@ class CheckpointingTests(test.TestCase):
         second, "v2", initializer=[1., 1., 2., 3.])
     self.evaluate(variables.global_variables_initializer())
     checkpoint_directory = self.get_temp_dir()
-    save_path = checkpointable_utils.save(
-        os.path.join(checkpoint_directory, "ckpt"), first)
+    save_path = checkpointable_utils.Saver(first).save(
+        os.path.join(checkpoint_directory, "ckpt"))
 
     # Test deferred loading
     first_load = checkpointable.Checkpointable()
-    status = checkpointable_utils.restore(save_path, first_load)
+    status = checkpointable_utils.Saver(first_load).restore(save_path)
     second_load = checkpointable.Checkpointable()
     first_load.second = second_load
     second_load.first = first_load
@@ -792,7 +778,7 @@ class CheckpointingTests(test.TestCase):
     second_load.v = checkpointable_utils.add_variable(
         second_load, "v2", shape=[4])
     status.assert_consumed()
-    self.evaluate(status.restore_ops)
+    status.run_restore_ops()
     self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v))
     self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v))
 
@@ -801,9 +787,9 @@ class CheckpointingTests(test.TestCase):
     self.assertAllEqual([2., 7., 1.], self.evaluate(first_load.v))
     self.evaluate(second_load.v.assign([2., 7., 1., 8.]))
     self.assertAllEqual([2., 7., 1., 8.], self.evaluate(second_load.v))
-    status = checkpointable_utils.restore(
-        save_path, first_load).assert_consumed()
-    self.evaluate(status.restore_ops)
+    status = checkpointable_utils.Saver(first_load).restore(
+        save_path).assert_consumed()
+    status.run_restore_ops()
     self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v))
     self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v))
 
@@ -820,27 +806,24 @@ class CheckpointingTests(test.TestCase):
           name="blah", initializer=0.)
       self.evaluate(first.var1.assign(4.))
       self.evaluate(first.var2.assign(8.))
-      save_path = checkpointable_utils.save(
-          checkpoint_prefix, root_checkpointable=first)
+      save_path = checkpointable_utils.Saver(first).save(
+          checkpoint_prefix)
     restore_graph = ops.Graph()
     with restore_graph.as_default(), self.test_session(restore_graph):
       second = checkpointable.Checkpointable()
       second.var2 = variable_scope.get_variable(
           name="blah", initializer=0.)
-      status = checkpointable_utils.restore(
-          save_path, root_checkpointable=second)
+      status = checkpointable_utils.Saver(second).restore(save_path)
       recreated_var1 = variable_scope.get_variable(
           name="outside_var", initializer=0.)
-      self.evaluate(status.restore_ops)
+      status.run_restore_ops()
       self.assertEqual(8., self.evaluate(second.var2))
       self.evaluate(recreated_var1.assign(-2.))
       self.assertEqual(-2., self.evaluate(recreated_var1))
       second.var1 = recreated_var1
-      self.evaluate(status.restore_ops)
+      status.run_restore_ops()
       self.assertEqual(4., self.evaluate(recreated_var1))
 
-  # TODO(allenl): Saver class that doesn't pollute the graph with constants.
-  @unittest.skip("todo")
   def testManySavesGraph(self):
     """Saves after the first should not modify the graph."""
     with context.graph_mode():
@@ -853,14 +836,12 @@ class CheckpointingTests(test.TestCase):
         obj.opt = CheckpointableAdam(0.1)
         obj.opt.minimize(obj.var.read_value())
         self.evaluate(variables.global_variables_initializer())
-        checkpointable_utils.save(
-            checkpoint_prefix, root_checkpointable=obj)
+        saver = checkpointable_utils.Saver(obj)
+        saver.save(checkpoint_prefix)
         before_ops = graph.get_operations()
-        checkpointable_utils.save(
-            checkpoint_prefix, root_checkpointable=obj)
+        saver.save(checkpoint_prefix)
         self.assertEqual(before_ops, graph.get_operations())
 
-  @unittest.skip("todo")
   def testManyRestoresGraph(self):
     """Restores after the first should not modify the graph."""
     with context.graph_mode():
@@ -873,13 +854,11 @@ class CheckpointingTests(test.TestCase):
         obj.opt = CheckpointableAdam(0.1)
         obj.opt.minimize(obj.var.read_value())
         self.evaluate(variables.global_variables_initializer())
-        save_path = checkpointable_utils.save(
-            checkpoint_prefix, root_checkpointable=obj)
-        checkpointable_utils.restore(
-            save_path, root_checkpointable=obj)
+        saver = checkpointable_utils.Saver(obj)
+        save_path = saver.save(checkpoint_prefix)
+        saver.restore(save_path)
         before_ops = graph.get_operations()
-        checkpointable_utils.restore(
-            save_path, root_checkpointable=obj)
+        saver.restore(save_path)
         self.assertEqual(before_ops, graph.get_operations())
 
 if __name__ == "__main__":
index 6a7ece4..4c8c735 100644 (file)
@@ -2860,7 +2860,6 @@ py_library(
         ":dtypes",
         ":io_ops_gen",
         ":ops",
-        ":pywrap_tensorflow",
         ":util",
         "//tensorflow/python/eager:context",
     ],
index b785d0e..d382683 100644 (file)
@@ -792,17 +792,7 @@ class Variable(checkpointable.CheckpointableBase):
 
     setattr(Variable, operator, _run_op)
 
-  def _scatter_tensors_from_checkpoint(self, attributes):
-    """For implementing `Checkpointable`. Return an assignment op to run."""
-    if (len(attributes) != 1
-        or checkpointable.VARIABLE_VALUE_KEY not in attributes):
-      raise ValueError(
-          ("The variable %s was restored with unexpected values (expected one "
-           "with key %s, got %s)") % (
-               self, checkpointable.VARIABLE_VALUE_KEY, attributes))
-    return self.assign(attributes[checkpointable.VARIABLE_VALUE_KEY])
-
-  def _gather_tensors_for_checkpoint(self):
+  def _gather_saveables_for_checkpoint(self):
     """For implementing `Checkpointable`. This object is saveable on its own."""
     return {checkpointable.VARIABLE_VALUE_KEY: self}
 
index 9d62c5f..11caa76 100644 (file)
@@ -18,9 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
-import weakref
 
-from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -28,7 +26,7 @@ from tensorflow.python.ops import gen_io_ops as io_ops
 from tensorflow.python.util import nest
 
 # A key indicating a variable's value in an object's checkpointed Tensors
-# (Checkpointable._gather_tensors_for_checkpoint). If this is the only key and
+# (Checkpointable._gather_saveables_for_checkpoint). If this is the only key and
 # the object has no dependencies, then its value may be restored on object
 # creation (avoiding double assignment when executing eagerly).
 VARIABLE_VALUE_KEY = "VARIABLE_VALUE"
@@ -57,7 +55,7 @@ class CheckpointInitialValue(ops.Tensor):
   """
 
   def __init__(self, checkpoint_position, shape=None):
-    self.wrapped_value = checkpoint_position.restore_ops()[
+    self.wrapped_value = checkpoint_position.value_tensors()[
         VARIABLE_VALUE_KEY]
     if shape:
       # We need to set the static shape information on the initializer if
@@ -168,22 +166,86 @@ class _CheckpointPosition(object):
             and attributes[0].name == VARIABLE_VALUE_KEY
             and not self.object_proto.children)
 
-  def restore_ops(self):
-    """Create restore ops for this object's attributes."""
-    restore_tensors = {}
+  def value_tensors(self):
+    """Create value `Tensor`s for this object's attributes.
+
+    Does not require that the Python object has been created. Used for
+    restore-on-create when executing eagerly.
+
+    Returns:
+      A dictionary mapping from object attribute names to `Tensor`s.
+    """
+    value_tensors = {}
     for serialized_tensor in self.object_proto.attributes:
       checkpoint_key = serialized_tensor.checkpoint_key
       dtype = self._checkpoint.dtype_map[checkpoint_key]
       base_type = dtype.base_dtype
       with ops.init_scope():
-        restore, = io_ops.restore_v2(
+        value, = io_ops.restore_v2(
             prefix=self._checkpoint.save_path,
             tensor_names=[checkpoint_key],
             shape_and_slices=[""],
             dtypes=[base_type],
             name="%s_checkpoint_read" % (serialized_tensor.name,))
-        restore_tensors[serialized_tensor.name] = restore
-      return restore_tensors
+        value_tensors[serialized_tensor.name] = value
+      return value_tensors
+
+  def restore_ops(self):
+    """Create or fetch restore ops for this object's attributes.
+
+    Requires that the `Checkpointable` Python object has been bound to an object
+    ID in the checkpoint.
+
+    Returns:
+      A list of operations when graph building, or an empty list when executing
+      eagerly.
+    """
+    saveables = self.checkpointable._gather_saveables_for_checkpoint()  # pylint: disable=protected-access
+    # Name saveables based on the name this object had when it was checkpointed.
+    named_saveables = {}
+    restore_ops = []
+    in_graph_mode = context.in_graph_mode()
+    for serialized_tensor in self.object_proto.attributes:
+      saveable_object = saveables.get(serialized_tensor.name, None)
+      if saveable_object is None:
+        # Purposefully does not throw an exception if attributes have been added
+        # or deleted. Stores unused attributes so an exception can be raised if
+        # the user decides to check that everything in the checkpoint was
+        # loaded.
+        self._checkpoint.unused_attributes.setdefault(
+            self.checkpointable, []).append(serialized_tensor.name)
+        continue
+      if in_graph_mode:
+        existing_ops = self._checkpoint.restore_ops_by_name.get(
+            serialized_tensor.name, None)
+      else:
+        existing_ops = None
+      if existing_ops is None:
+        named_saveables[serialized_tensor.checkpoint_key] = saveable_object
+    if named_saveables:
+      validated_saveables = (
+          self._checkpoint.builder._ValidateAndSliceInputs(named_saveables))  # pylint: disable=protected-access
+      validated_names = set(saveable.name for saveable in validated_saveables)
+      if set(named_saveables.keys()) != validated_names:
+        raise AssertionError(
+            ("Saveable keys changed when validating. Got back %s, was "
+             "expecting %s") % (named_saveables.keys(), validated_names))
+      all_tensors = self._checkpoint.builder.bulk_restore(
+          filename_tensor=self._checkpoint.save_path,
+          saveables=validated_saveables, preferred_shard=-1,
+          restore_sequentially=False)
+      saveable_index = 0
+      for saveable in validated_saveables:
+        num_specs = len(saveable.specs)
+        saveable_tensors = all_tensors[
+            saveable_index:saveable_index + num_specs]
+        saveable_index += num_specs
+        restore_op = saveable.restore(saveable_tensors, restored_shapes=None)
+        if in_graph_mode:
+          assert saveable.name not in self._checkpoint.restore_ops_by_name
+          self._checkpoint.restore_ops_by_name[saveable.name] = restore_op
+          restore_ops.append(restore_op)
+    return restore_ops
 
   @property
   def checkpoint(self):
@@ -225,54 +287,6 @@ _SlotVariableRestoration = collections.namedtuple(
     ])
 
 
-class _Checkpoint(object):
-  """Holds the status of an object-based checkpoint load."""
-
-  def __init__(self, object_graph_proto, save_path):
-    """Specify the checkpoint being loaded.
-
-    Args:
-      object_graph_proto: The CheckpointableObjectGraph protocol buffer
-        associated with this checkpoint.
-      save_path: The path to the checkpoint, as returned by
-        `tf.train.latest_checkpoint`.
-    """
-    self.object_graph_proto = object_graph_proto
-    self.restore_uid = ops.uid()
-    # Dictionary mapping from an id in the protocol buffer flat array to
-    # Checkpointable Python objects. This mapping may be deferred if a
-    # checkpoint is restored before all dependencies have been tracked. Uses
-    # weak references so that partial restorations don't create reference cycles
-    # (as objects with deferred dependencies will generally have references to
-    # this object).
-    self.object_by_proto_id = weakref.WeakValueDictionary()
-    self.save_path = save_path
-    reader = pywrap_tensorflow.NewCheckpointReader(save_path)
-    self.dtype_map = reader.get_variable_to_dtype_map()
-    # When graph building, contains a list of ops to run to restore objects from
-    # this checkpoint.
-    self.restore_ops = []
-    # A mapping from optimizer proto ids to lists of slot variables to be
-    # restored when the optimizer is tracked. Only includes slot variables whose
-    # regular variables have already been created, and only for optimizer
-    # objects which have not yet been created/tracked.
-    self.deferred_slot_restorations = {}
-    # A mapping from variable proto ids to lists of slot variables to be
-    # restored when the variable is created/tracked. These get shifted over to
-    # deferred_slot_restorations if the optimizer hasn't been created when that
-    # happens.
-    self.slot_restorations = {}
-    for node_index, node in enumerate(self.object_graph_proto.nodes):
-      for slot_reference in node.slot_variables:
-        # `node` refers to an `Optimizer`, since only these have slot variables.
-        self.slot_restorations.setdefault(
-            slot_reference.original_variable_node_id, []).append(
-                _SlotVariableRestoration(
-                    optimizer_id=node_index,
-                    slot_variable_id=slot_reference.slot_variable_node_id,
-                    slot_name=slot_reference.slot_name))
-
-
 class CheckpointableBase(object):
   """Base class for `Checkpointable` objects without automatic dependencies.
 
@@ -415,13 +429,10 @@ class CheckpointableBase(object):
     Indicates that checkpoints for this object should include variables from
     `checkpointable`.
 
-    Variables in a checkpoint are mapped to `Checkpointable`s based on names if
-    provided when the checkpoint was written, but otherwise use the order those
-    `Checkpointable`s were declared as dependencies.
-
-    To avoid breaking existing checkpoints when modifying a class, neither
-    variable names nor dependency names (the names passed to
-    `track_checkpointable`) may change.
+    Variables in a checkpoint are mapped to `Checkpointable`s based on the names
+    provided when the checkpoint was written. To avoid breaking existing
+    checkpoints when modifying a class, neither variable names nor dependency
+    names (the names passed to `_track_checkpointable`) may change.
 
     Args:
       checkpointable: A `Checkpointable` which this object depends on.
@@ -493,11 +504,11 @@ class CheckpointableBase(object):
     # need to actually restore the object. However, we should pass the
     # restoration on to our dependencies.
     if checkpoint.restore_uid > self._update_uid:
-      restore_op = self._scatter_tensors_from_checkpoint(
-          checkpoint_position.restore_ops())
+      restore_ops = checkpoint_position.restore_ops()
+      # TODO(allenl): Get a list of feeds for saving Python state
       self._update_uid = checkpoint.restore_uid
     else:
-      restore_op = ()
+      restore_ops = ()
     for child in checkpoint_position.object_proto.children:
       child_position = _CheckpointPosition(
           checkpoint=checkpoint,
@@ -515,25 +526,21 @@ class CheckpointableBase(object):
           # resolution order (shallowest paths first). The caller is responsible
           # for emptying visit_queue.
           visit_queue.append(child_position)
-    return restore_op
+    return restore_ops
 
-  def _scatter_tensors_from_checkpoint(self, attributes):
-    """Restores this object from a checkpoint.
+  def _gather_saveables_for_checkpoint(self):
+    """Returns a dictionary of values to checkpoint with this object.
 
-    Args:
-      attributes: A dictionary of Tensors, with key corresponding to those
-        returned from _gather_tensors_for_checkpoint.
-    Returns:
-      A restore op to run (if graph building).
-    """
-    if attributes:
-      raise AssertionError(
-          ("A Checkpointable object which was not expecting any data received "
-           "some from a checkpoint. (Got %s)") % (attributes,))
-    return ()  # No restore ops
+    Keys in the returned dictionary are local to this object and in a separate
+    namespace from dependencies. Values may either be `SaveableObject`s or
+    variables easily converted to `SaveableObject`s (as in `tf.train.Saver`'s
+    `var_list` constructor argument).
 
-  def _gather_tensors_for_checkpoint(self):
-    """Returns a dictionary of Tensors to save with this object."""
+    Returned values must be saved only by this object; if any value may be
+    shared, it should instead be a dependency. For example, variable objects
+    save their own values with the key `VARIABLE_VALUE_KEY`, but objects which
+    reference variables simply add a dependency.
+    """
     return {}
 
 
@@ -562,8 +569,7 @@ class Checkpointable(CheckpointableBase):
   `Checkpointable` objects may specify `Tensor`s to be saved and restored
   directly (e.g. a `Variable` indicating how to save itself) rather than through
   dependencies on other objects. See
-  `Checkpointable._scatter_tensors_from_checkpoint` and
-  `Checkpointable._gather_tensors_for_checkpoint` for details.
+  `Checkpointable._gather_saveables_for_checkpoint` for details.
   """
 
   def __setattr__(self, name, value):
diff --git a/tensorflow/python/training/checkpointable_utils.py b/tensorflow/python/training/checkpointable_utils.py
new file mode 100644 (file)
index 0000000..32123f8
--- /dev/null
@@ -0,0 +1,78 @@
+"""Utilities for saving/loading Checkpointable objects."""
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import weakref
+
+from tensorflow.python.framework import ops
+from tensorflow.python.training import checkpointable
+from tensorflow.python.training import saver as saver_lib
+
+
+class _Checkpoint(object):
+  """Holds the status of an object-based checkpoint load."""
+
+  def __init__(self, object_graph_proto, save_path, dtype_map=None):
+    """Specify the checkpoint being loaded.
+
+    Args:
+      object_graph_proto: The CheckpointableObjectGraph protocol buffer
+        associated with this checkpoint.
+      save_path: A string `Tensor`. The path to the checkpoint, as returned by
+        `tf.train.latest_checkpoint`.
+      dtype_map: When executing eagerly, specifies dtypes for creating slot
+        variables. None when graph building.
+    """
+    self.builder = saver_lib.BulkSaverBuilder()
+    self.object_graph_proto = object_graph_proto
+    self.restore_uid = ops.uid()
+    # Maps from objects to lists of attributes which were in the checkpoint but
+    # not loaded into any object, for error checking.
+    self.unused_attributes = weakref.WeakKeyDictionary()
+    # Dictionary mapping from an id in the protocol buffer flat array to
+    # Checkpointable Python objects. This mapping may be deferred if a
+    # checkpoint is restored before all dependencies have been tracked. Uses
+    # weak references so that partial restorations don't create reference cycles
+    # (as objects with deferred dependencies will generally have references to
+    # this object).
+    self.object_by_proto_id = weakref.WeakValueDictionary()
+    self.save_path = save_path
+    self.dtype_map = dtype_map
+    # When graph building, contains a list of ops to run to restore objects from
+    # this checkpoint.
+    self.restore_ops = []
+    self.restore_ops_by_name = {}
+    # A mapping from optimizer proto ids to lists of slot variables to be
+    # restored when the optimizer is tracked. Only includes slot variables whose
+    # regular variables have already been created, and only for optimizer
+    # objects which have not yet been created/tracked.
+    self.deferred_slot_restorations = {}
+    # A mapping from variable proto ids to lists of slot variables to be
+    # restored when the variable is created/tracked. These get shifted over to
+    # deferred_slot_restorations if the optimizer hasn't been created when that
+    # happens.
+    self.slot_restorations = {}
+    for node_index, node in enumerate(self.object_graph_proto.nodes):
+      for slot_reference in node.slot_variables:
+        # `node` refers to an `Optimizer`, since only these have slot variables.
+        self.slot_restorations.setdefault(
+            slot_reference.original_variable_node_id, []).append(
+                checkpointable._SlotVariableRestoration(  # pylint: disable=protected-access
+                    optimizer_id=node_index,
+                    slot_variable_id=slot_reference.slot_variable_node_id,
+                    slot_name=slot_reference.slot_name))