named_variables[_OBJECT_GRAPH_PROTO_KEY] = _NoRestoreSaveable(
tensor=object_graph_tensor,
name=_OBJECT_GRAPH_PROTO_KEY)
- if self._last_save_object_graph != graph_proto:
+ if (self._last_save_object_graph != graph_proto
+ # When executing eagerly, we need to re-create SaveableObjects each time
+ # save() is called so they pick up new Tensors passed to their
+ # constructors. That means the Saver needs to be copied with a new
+ # var_list.
+ or context.executing_eagerly()):
if self._last_save_object_graph is not None:
self._last_save_saver = _copy_saver_with_new_var_list(
old_saver=self._last_save_saver, new_var_list=named_variables)
self.assertAllEqual([1., 1., 1.], self.evaluate(v2))
-class _MirroringSaveable(
- core_saver.BaseSaverBuilder.ResourceVariableSaveable):
+class _MirroringSaveable(core_saver.BaseSaverBuilder.SaveableObject):
def __init__(self, primary_variable, mirrored_variable, name):
self._primary_variable = primary_variable
self._mirrored_variable = mirrored_variable
+ tensor = self._primary_variable.read_value()
+ spec = core_saver.BaseSaverBuilder.SaveSpec(
+ tensor=tensor,
+ slice_spec="",
+ name=name)
super(_MirroringSaveable, self).__init__(
- self._primary_variable, "", name)
+ tensor, [spec], name)
def restore(self, restored_tensors, restored_shapes):
"""Restore the same value into both variables."""
checkpoint.restore(save_path).assert_consumed().initialize_or_restore()
self.assertEqual(42., self.evaluate(v.non_dep_variable))
self.assertEqual(42., self.evaluate(v.mirrored))
+ self.evaluate(v.non_dep_variable.assign(44.))
+ save_path = checkpoint.save(prefix)
+ self.evaluate(v.non_dep_variable.assign(45.))
+ checkpoint.restore(save_path).assert_consumed().initialize_or_restore()
+ self.assertEqual(44., self.evaluate(v.non_dep_variable))
+ self.assertEqual(44., self.evaluate(v.mirrored))
@test_util.run_in_graph_and_eager_modes()
def testMoreComplexSaveableReturnedWithGlobalName(self):