Checkpointable: allow using Checkpointable objects in a tf.train.Saver()
authorAllen Lavoie <allenl@google.com>
Sat, 24 Feb 2018 00:53:21 +0000 (16:53 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 24 Feb 2018 01:03:48 +0000 (17:03 -0800)
Checkpointable objects in a Saver's var_list will be unpacked into their
SaveableObjects, possibly running some Python logic along the way.

This should help keep the transition from name-based saving smooth: to save
either way, just override CheckpointableBase._gather_saveables_for_checkpoint.

PiperOrigin-RevId: 186843857

tensorflow/python/training/saver.py
tensorflow/python/training/saver_test.py

index 83e848d..9afd1e6 100644 (file)
@@ -50,6 +50,7 @@ from tensorflow.python.ops import string_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpointable
 from tensorflow.python.training import training_util
 from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
 from tensorflow.python.util import compat
@@ -577,6 +578,11 @@ class BaseSaverBuilder(object):
           names_to_saveables[name].append(var)
         else:
           names_to_saveables[name] = [var]
+      elif (isinstance(var, checkpointable.CheckpointableBase)
+            and not isinstance(var, variables.Variable)):
+        names_to_saveables.update(
+            BaseSaverBuilder.OpListToDict(
+                list(var._gather_saveables_for_checkpoint().values())))
       else:
         if context.in_graph_mode():
           if convert_variable_to_tensor:
index c5a6f49..f00f98d 100644 (file)
@@ -66,6 +66,7 @@ from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
 from tensorflow.python.summary import summary
 from tensorflow.python.training import adam
+from tensorflow.python.training import checkpointable
 from tensorflow.python.training import gradient_descent
 from tensorflow.python.training import queue_runner_impl
 from tensorflow.python.training import saver as saver_module
@@ -2660,5 +2661,92 @@ class ScopedGraphTest(test.TestCase):
       self.assertEqual(2.0, var_dict2["variable2:0"].eval())
 
 
+class _OwnsAVariableSimple(checkpointable.CheckpointableBase):
+  """A Checkpointable object which can be saved using a tf.train.Saver."""
+
+  def __init__(self):
+    self.non_dep_variable = variable_scope.get_variable(
+        name="non_dep_variable", initializer=6., use_resource=True)
+
+  def _gather_saveables_for_checkpoint(self):
+    return {checkpointable.VARIABLE_VALUE_KEY: self.non_dep_variable}
+
+  # The Saver sorts by name before parsing, so we need a name property.
+  @property
+  def name(self):
+    return self.non_dep_variable.name
+
+
+class _MirroringSaveable(
+    saver_module.BaseSaverBuilder.ResourceVariableSaveable):
+
+  def __init__(self, primary_variable, mirrored_variable):
+    self._primary_variable = primary_variable
+    self._mirrored_variable = mirrored_variable
+    super(_MirroringSaveable, self).__init__(
+        self._primary_variable, "", self._primary_variable.name)
+
+  def restore(self, restored_tensors, restored_shapes):
+    """Restore the same value into both variables."""
+    tensor, = restored_tensors
+    return control_flow_ops.group(
+        self._primary_variable.assign(tensor),
+        self._mirrored_variable.assign(tensor))
+
+
+class _OwnsMirroredVariables(checkpointable.CheckpointableBase):
+  """A Checkpointable object which returns a more complex SaveableObject."""
+
+  def __init__(self):
+    self.non_dep_variable = variable_scope.get_variable(
+        name="non_dep_variable", initializer=6., use_resource=True)
+    self.mirrored = variable_scope.get_variable(
+        name="mirrored", initializer=15., use_resource=True)
+
+  def _gather_saveables_for_checkpoint(self):
+    saveable = _MirroringSaveable(
+        primary_variable=self.non_dep_variable,
+        mirrored_variable=self.mirrored)
+    return {checkpointable.VARIABLE_VALUE_KEY: saveable}
+
+  # The Saver sorts by name before parsing, so we need a name property.
+  @property
+  def name(self):
+    return self.non_dep_variable.name
+
+
+@test_util.with_c_api
+class CheckpointableCompatibilityTests(test.TestCase):
+
+  # TODO(allenl): Track down python3 reference cycles in these tests.
+  @test_util.run_in_graph_and_eager_modes()
+  def testNotSaveableButIsCheckpointable(self):
+    v = _OwnsAVariableSimple()
+    saver = saver_module.Saver(var_list=[v])
+    test_dir = self.get_temp_dir()
+    prefix = os.path.join(test_dir, "ckpt")
+    self.evaluate(v.non_dep_variable.assign(42.))
+    with self.test_session() as sess:
+      save_path = saver.save(sess, prefix)
+      self.evaluate(v.non_dep_variable.assign(43.))
+      saver.restore(sess, save_path)
+      self.assertEqual(42., self.evaluate(v.non_dep_variable))
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testMoreComplexSaveableReturned(self):
+    v = _OwnsMirroredVariables()
+    saver = saver_module.Saver(var_list=[v])
+    test_dir = self.get_temp_dir()
+    prefix = os.path.join(test_dir, "ckpt")
+    self.evaluate(v.non_dep_variable.assign(42.))
+    with self.test_session() as sess:
+      save_path = saver.save(sess, prefix)
+      self.evaluate(v.non_dep_variable.assign(43.))
+      self.evaluate(v.mirrored.assign(44.))
+      saver.restore(sess, save_path)
+      self.assertEqual(42., self.evaluate(v.non_dep_variable))
+      self.assertEqual(42., self.evaluate(v.mirrored))
+
+
 if __name__ == "__main__":
   test.main()