Checkpointable: Fix a bug where SaveableObjects in the Saver's var_list were stale...
authorAllen Lavoie <allenl@google.com>
Fri, 23 Mar 2018 00:29:47 +0000 (17:29 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 23 Mar 2018 00:32:37 +0000 (17:32 -0700)
SaveableObjects were recreated each save(), but the Saver was not recreated and so had an old var_list.

PiperOrigin-RevId: 190153011

tensorflow/contrib/eager/python/checkpointable_utils.py
tensorflow/contrib/eager/python/checkpointable_utils_test.py

index adbb92e..91a7ade 100644 (file)
@@ -631,7 +631,12 @@ class CheckpointableSaver(object):
     named_variables[_OBJECT_GRAPH_PROTO_KEY] = _NoRestoreSaveable(
         tensor=object_graph_tensor,
         name=_OBJECT_GRAPH_PROTO_KEY)
-    if self._last_save_object_graph != graph_proto:
+    if (self._last_save_object_graph != graph_proto
+        # When executing eagerly, we need to re-create SaveableObjects each time
+        # save() is called so they pick up new Tensors passed to their
+        # constructors. That means the Saver needs to be copied with a new
+        # var_list.
+        or context.executing_eagerly()):
       if self._last_save_object_graph is not None:
         self._last_save_saver = _copy_saver_with_new_var_list(
             old_saver=self._last_save_saver, new_var_list=named_variables)
index 690f3ee..a4b215b 100644 (file)
@@ -154,14 +154,18 @@ class InterfaceTests(test.TestCase):
     self.assertAllEqual([1., 1., 1.], self.evaluate(v2))
 
 
-class _MirroringSaveable(
-    core_saver.BaseSaverBuilder.ResourceVariableSaveable):
+class _MirroringSaveable(core_saver.BaseSaverBuilder.SaveableObject):
 
   def __init__(self, primary_variable, mirrored_variable, name):
     self._primary_variable = primary_variable
     self._mirrored_variable = mirrored_variable
+    tensor = self._primary_variable.read_value()
+    spec = core_saver.BaseSaverBuilder.SaveSpec(
+        tensor=tensor,
+        slice_spec="",
+        name=name)
     super(_MirroringSaveable, self).__init__(
-        self._primary_variable, "", name)
+        tensor, [spec], name)
 
   def restore(self, restored_tensors, restored_shapes):
     """Restore the same value into both variables."""
@@ -316,6 +320,12 @@ class CheckpointingTests(test.TestCase):
     checkpoint.restore(save_path).assert_consumed().initialize_or_restore()
     self.assertEqual(42., self.evaluate(v.non_dep_variable))
     self.assertEqual(42., self.evaluate(v.mirrored))
+    self.evaluate(v.non_dep_variable.assign(44.))
+    save_path = checkpoint.save(prefix)
+    self.evaluate(v.non_dep_variable.assign(45.))
+    checkpoint.restore(save_path).assert_consumed().initialize_or_restore()
+    self.assertEqual(44., self.evaluate(v.non_dep_variable))
+    self.assertEqual(44., self.evaluate(v.mirrored))
 
   @test_util.run_in_graph_and_eager_modes()
   def testMoreComplexSaveableReturnedWithGlobalName(self):