Allow tf.train.Saver to load object-based checkpoints (using names)
authorAllen Lavoie <allenl@google.com>
Fri, 13 Apr 2018 21:32:45 +0000 (14:32 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 13 Apr 2018 21:35:26 +0000 (14:35 -0700)
This is the second part of the compatibility story. Object-based checkpointing APIs can already read name-based checkpoints, and now the name-based APIs can read object-based checkpoints by looking up the modified keys in the object graph proto.

PiperOrigin-RevId: 192824907

tensorflow/python/training/checkpointable.py
tensorflow/python/training/checkpointable_utils.py
tensorflow/python/training/checkpointable_utils_test.py
tensorflow/python/training/saver.py
tensorflow/python/training/saver_test.py

index 9bf48df..0b84737 100644 (file)
@@ -26,6 +26,11 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_io_ops as io_ops
 from tensorflow.python.util import nest
 
+
+# Key where the object graph proto is saved in a TensorBundle
+OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"
+
+
 # A key indicating a variable's value in an object's checkpointed Tensors
 # (Checkpointable._gather_saveables_for_checkpoint). If this is the only key and
 # the object has no dependencies, then its value may be restored on object
index da99d2e..2c4677a 100644 (file)
@@ -54,8 +54,6 @@ _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"
 # attribute in checkpoint names. Used like:
 #   <path to variable>/<_OBJECT_ATTRIBUTES_NAME>/<name of attribute>
 _OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"
-# Key where the object graph proto is saved in a TensorBundle
-_OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"
 
 
 class _CheckpointRestoreCoordinator(object):
@@ -680,10 +678,11 @@ class CheckpointableSaver(object):
         object_graph_tensor = constant_op.constant(
             graph_proto.SerializeToString(), dtype=dtypes.string)
       feed_additions = None
-    assert _OBJECT_GRAPH_PROTO_KEY not in named_variables
-    named_variables[_OBJECT_GRAPH_PROTO_KEY] = _NoRestoreSaveable(
-        tensor=object_graph_tensor,
-        name=_OBJECT_GRAPH_PROTO_KEY)
+    assert checkpointable_lib.OBJECT_GRAPH_PROTO_KEY not in named_variables
+    named_variables[checkpointable_lib.OBJECT_GRAPH_PROTO_KEY] = (
+        _NoRestoreSaveable(
+            tensor=object_graph_tensor,
+            name=checkpointable_lib.OBJECT_GRAPH_PROTO_KEY))
     if (self._last_save_object_graph != graph_proto
         # When executing eagerly, we need to re-create SaveableObjects each time
         # save() is called so they pick up new Tensors passed to their
@@ -786,7 +785,8 @@ class CheckpointableSaver(object):
       file_prefix_feed_dict = None
     reader = pywrap_tensorflow.NewCheckpointReader(save_path)
     try:
-      object_graph_string = reader.get_tensor(_OBJECT_GRAPH_PROTO_KEY)
+      object_graph_string = reader.get_tensor(
+          checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)
     except errors_impl.NotFoundError:
       # The object graph proto does not exist in this checkpoint. Try again with
       # name-based saving.
index ddf9820..29fcdb7 100644 (file)
@@ -1268,9 +1268,6 @@ class CheckpointCompatibilityTests(test.TestCase):
       status.initialize_or_restore()
       self._check_sentinels(root)
 
-  # TODO(allenl): Test for the core name-based saver loading object-based
-  # checkpoints once object-based checkpointing is in core.
-
   def testSaveGraphLoadEager(self):
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
index e40b8d2..79d278c 100644 (file)
@@ -22,6 +22,7 @@ from __future__ import print_function
 import collections
 import os.path
 import re
+import sys
 import time
 import uuid
 
@@ -30,8 +31,10 @@ import six
 
 from google.protobuf import text_format
 
+from tensorflow.core.protobuf import checkpointable_object_graph_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
 from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.client import session
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
@@ -1340,6 +1343,9 @@ class Saver(object):
       self._check_saver_def()
       self._write_version = self.saver_def.version
     self._save_relative_paths = save_relative_paths
+    # For compatibility with object-based checkpoints, we may build a second
+    # Saver to read the renamed keys.
+    self._object_restore_saver = None
 
   def build(self):
     if context.executing_eagerly():
@@ -1795,11 +1801,65 @@ class Saver(object):
     if save_path is None:
       raise ValueError("Can't load save_path when it is None.")
     logging.info("Restoring parameters from %s", save_path)
-    if context.executing_eagerly():
-      self._build_eager(save_path, build_save=False, build_restore=True)
-    else:
-      sess.run(self.saver_def.restore_op_name,
-               {self.saver_def.filename_tensor_name: save_path})
+    try:
+      if context.executing_eagerly():
+        self._build_eager(save_path, build_save=False, build_restore=True)
+      else:
+        sess.run(self.saver_def.restore_op_name,
+                 {self.saver_def.filename_tensor_name: save_path})
+    except errors.NotFoundError:
+      exception_type, exception_value, exception_traceback = sys.exc_info()
+      # The checkpoint would not be loaded successfully as is. Try to parse it
+      # as an object-based checkpoint.
+      try:
+        reader = pywrap_tensorflow.NewCheckpointReader(save_path)
+        object_graph_string = reader.get_tensor(
+            checkpointable.OBJECT_GRAPH_PROTO_KEY)
+      except errors.NotFoundError:
+        # This is not an object-based checkpoint, or the checkpoint doesn't
+        # exist. Re-raise the original exception.
+        six.reraise(exception_type, exception_value, exception_traceback)
+      del exception_traceback  # avoid reference cycles
+
+      # This is an object-based checkpoint. We'll print a warning and then do
+      # the restore.
+      logging.warning(
+          # TODO(allenl): Modify instructions for using the object-based saver
+          # once that's in core.
+          "Restoring an object-based checkpoint using a name-based saver. This "
+          "may be somewhat fragile, and will re-build the Saver. Instead, "
+          "consider loading object-based checkpoints using "
+          "tf.contrib.eager.Checkpoint().")
+      self._restore_from_object_based_checkpoint(
+          sess=sess, save_path=save_path,
+          object_graph_string=object_graph_string)
+
+  def _restore_from_object_based_checkpoint(self, sess, save_path,
+                                            object_graph_string):
+    """A compatibility mode for reading object-based checkpoints."""
+    object_graph_proto = (
+        checkpointable_object_graph_pb2.CheckpointableObjectGraph())
+    object_graph_proto.ParseFromString(object_graph_string)
+    names_to_keys = {}
+    for node in object_graph_proto.nodes:
+      for attribute in node.attributes:
+        names_to_keys[attribute.full_name] = attribute.checkpoint_key
+    saveables = self._builder._ValidateAndSliceInputs(self._var_list)  # pylint: disable=protected-access
+    for saveable in saveables:
+      for spec in saveable.specs:
+        if spec.name not in names_to_keys:
+          raise errors.NotFoundError(
+              None, None,
+              message=("Attempting to load an object-based checkpoint using "
+                       "variable names, but could not find %s in the "
+                       "checkpoint.") % spec.name)
+        spec.name = names_to_keys[spec.name]
+    if self._object_restore_saver is None:
+      # Cache the Saver so multiple restore() calls don't pollute the graph when
+      # graph building. This assumes keys are consistent (i.e. this is the same
+      # type of object-based checkpoint we saw previously).
+      self._object_restore_saver = Saver(saveables)
+    self._object_restore_saver.restore(sess=sess, save_path=save_path)
 
   @staticmethod
   def _add_collection_def(meta_graph_def, key, export_scope=None):
index 14dda79..3867c0d 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import contextlib
+import functools
 import math
 import os
 import random
@@ -50,6 +51,8 @@ from tensorflow.python.framework import graph_io
 from tensorflow.python.framework import meta_graph
 from tensorflow.python.framework import ops as ops_lib
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras._impl.keras.engine import training
+from tensorflow.python.keras._impl.keras.layers import core
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -69,10 +72,12 @@ from tensorflow.python.platform import test
 from tensorflow.python.summary import summary
 from tensorflow.python.training import adam
 from tensorflow.python.training import checkpointable
+from tensorflow.python.training import checkpointable_utils
 from tensorflow.python.training import gradient_descent
 from tensorflow.python.training import queue_runner_impl
 from tensorflow.python.training import saver as saver_module
 from tensorflow.python.training import saver_test_utils
+from tensorflow.python.training import training_util
 from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
 from tensorflow.python.util import compat
 
@@ -2948,6 +2953,29 @@ class _OwnsMirroredVariables(checkpointable.CheckpointableBase):
     return self.non_dep_variable.name
 
 
+class NonLayerCheckpointable(checkpointable.Checkpointable):
+
+  def __init__(self):
+    super(NonLayerCheckpointable, self).__init__()
+    self.a_variable = checkpointable_utils.add_variable(
+        self, name="a_variable", shape=[])
+
+
+class MyModel(training.Model):
+  """A concrete Model for testing."""
+
+  def __init__(self):
+    super(MyModel, self).__init__()
+    self._named_dense = core.Dense(1, use_bias=True)
+    self._second = core.Dense(1, use_bias=False)
+    # We can still track Checkpointables which aren't Layers.
+    self._non_layer = NonLayerCheckpointable()
+
+  def call(self, values):
+    ret = self._second(self._named_dense(values))
+    return ret
+
+
 @test_util.with_c_api
 class CheckpointableCompatibilityTests(test.TestCase):
 
@@ -3011,6 +3039,128 @@ class CheckpointableCompatibilityTests(test.TestCase):
         saver.restore(sess, save_path)
         self.assertEqual(1, v.eval_count)
 
+  def _initialized_model(self):
+    input_value = constant_op.constant([[3.]])
+    model = MyModel()
+    optimizer = adam.AdamOptimizer(0.001)
+    optimizer_step = training_util.get_or_create_global_step()
+    root_checkpointable = checkpointable_utils.Checkpoint(
+        optimizer=optimizer, model=model, optimizer_step=optimizer_step)
+    train_op = optimizer.minimize(
+        functools.partial(model, input_value),
+        global_step=optimizer_step)
+    self.evaluate(checkpointable_utils.gather_initializers(
+        root_checkpointable))
+    self.evaluate(train_op)
+    # A regular variable, a slot variable, and a non-slot Optimizer variable
+    # with known values to check when loading.
+    self.evaluate(model._named_dense.bias.assign([1.]))
+    self.evaluate(optimizer.get_slot(
+        var=model._named_dense.bias, name="m").assign([2.]))
+    beta1_power, _ = optimizer._get_beta_accumulators()
+    self.evaluate(beta1_power.assign(3.))
+    return root_checkpointable
+
+  def _set_sentinels(self, root_checkpointable):
+    self.evaluate(root_checkpointable.model._named_dense.bias.assign([101.]))
+    self.evaluate(
+        root_checkpointable.optimizer.get_slot(
+            var=root_checkpointable.model._named_dense.bias, name="m")
+        .assign([102.]))
+    beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+    self.evaluate(beta1_power.assign(103.))
+
+  def _check_sentinels(self, root_checkpointable):
+    self.assertAllEqual(
+        [1.], self.evaluate(root_checkpointable.model._named_dense.bias))
+    self.assertAllEqual([2.], self.evaluate(
+        root_checkpointable.optimizer.get_slot(
+            var=root_checkpointable.model._named_dense.bias, name="m")))
+    beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+    self.assertAllEqual(3., self.evaluate(beta1_power))
+
+  def testVariableNotFoundErrorRaised(self):
+    # Restore does some tricky exception handling to figure out if it should
+    # load an object-based checkpoint. Tests that the exception handling isn't
+    # too broad.
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+
+    a = resource_variable_ops.ResourceVariable(1., name="a")
+    b = resource_variable_ops.ResourceVariable(1., name="b")
+    a_saver = saver_module.Saver([a])
+    b_saver = saver_module.Saver([b])
+    with self.test_session() as sess:
+      sess.run(a.initializer)
+      save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
+      with self.assertRaisesRegexp(
+          errors.NotFoundError, "Key b not found in checkpoint"):
+        b_saver.restore(sess=sess, save_path=save_path)
+
+  def testCheckpointNotFoundErrorRaised(self):
+    # Restore does some tricky exception handling to figure out if it should
+    # load an object-based checkpoint. Tests that the exception handling isn't
+    # too broad.
+    a = resource_variable_ops.ResourceVariable(1., name="a")
+    saver = saver_module.Saver([a])
+    with self.test_session() as sess:
+      with self.assertRaisesRegexp(
+          errors.NotFoundError,
+          "Failed to find any matching files for path_which_does_not_exist"):
+        saver.restore(sess=sess, save_path="path_which_does_not_exist")
+
+  def testLoadFromObjectBasedGraph(self):
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+
+    save_graph = ops_lib.Graph()
+    with save_graph.as_default(), self.test_session(graph=save_graph) as sess:
+      root = self._initialized_model()
+      object_saver = checkpointable_utils.CheckpointableSaver(root)
+      save_path = object_saver.save(file_prefix=checkpoint_prefix)
+
+      # An incompatible object-based checkpoint to check error messages
+      var = resource_variable_ops.ResourceVariable(1., name="a")
+      self.evaluate(var.initializer)
+      second_saver = checkpointable_utils.CheckpointableSaver(var)
+      second_path = second_saver.save(file_prefix=os.path.join(
+          checkpoint_directory, "second"))
+
+    restore_graph = ops_lib.Graph()
+    with restore_graph.as_default(), self.test_session(
+        graph=restore_graph) as sess:
+      root = self._initialized_model()
+      self._set_sentinels(root)
+      saver = saver_module.Saver()
+      saver.restore(sess=sess, save_path=save_path)
+      self._check_sentinels(root)
+      before_second_restore_ops = restore_graph.get_operations()
+      # Test that multiple restores do not pollute the graph
+      saver.restore(sess=sess, save_path=save_path)
+      self.assertEqual(before_second_restore_ops,
+                       restore_graph.get_operations())
+      with self.assertRaisesRegexp(errors.NotFoundError,
+                                   "could not find a_variable"):
+        saver.restore(sess=sess, save_path=second_path)
+
+  def testLoadFromObjectBasedEager(self):
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+
+    save_graph = ops_lib.Graph()
+    with save_graph.as_default(), self.test_session(graph=save_graph):
+      root = self._initialized_model()
+      object_saver = checkpointable_utils.CheckpointableSaver(root)
+      save_path = object_saver.save(file_prefix=checkpoint_prefix)
+
+    with context.eager_mode():
+      root = self._initialized_model()
+      self._set_sentinels(root)
+      saver = saver_module.Saver(
+          root.model.variables + root.optimizer.variables())
+      saver.restore(sess=None, save_path=save_path)
+      self._check_sentinels(root)
+
 
 if __name__ == "__main__":
   test.main()