import abc
import collections
+import functools
import weakref
from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2
# initialization when executing eagerly.
self._maybe_create_save_counter()
return status
+
+
+class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
+  """Wraps save and restore callbacks as a `SaveableObject`."""
+
+  def __init__(self, name, dtype, save_callback, restore_callback):
+    # Held until `restore()` is called with the restored tensor.
+    self._restore_callback = restore_callback
+    # `save_callback` stands in for the tensor to save: the saver invokes it
+    # at save time to produce the value written under `name`.
+    spec = saver_lib.BaseSaverBuilder.SaveSpec(
+        tensor=save_callback,
+        slice_spec="",
+        name=name,
+        dtype=dtype)
+    super(_CallbackSaveable, self).__init__(
+        save_callback, [spec], name)
+
+  def restore(self, restored_tensors, restored_shapes):
+    """Feeds the single restored tensor to the restore callback.
+
+    Args:
+      restored_tensors: Sequence containing exactly one restored `Tensor`
+        (this saveable declares a single `SaveSpec`).
+      restored_shapes: Unused.
+
+    Returns:
+      Whatever the restore callback returns (e.g. an assign op, or `None`
+      when executing eagerly).
+    """
+    tensor, = restored_tensors
+    return self._restore_callback(tensor)
+
+
+class _SplitDependency(core_checkpointable.CheckpointableBase):
+  """Looks like a regular variable while synchronizing save/restores."""
+
+  def __init__(self, save_buffer, restore_buffer, name, dtype, num_components,
+               fill_save_buffer_fn, consume_restore_buffer_fn):
+    """Creates one component of a synchronized dependency group.
+
+    Args:
+      save_buffer: Dictionary shared by every component in the group; filled
+        all at once by `fill_save_buffer_fn`, drained one entry per `_save`.
+      restore_buffer: Dictionary shared by every component in the group;
+        accumulates restored tensors until all components have arrived.
+      name: This component's key in the shared buffers.
+      dtype: Data type of the `Tensor` saved/restored for this component.
+      num_components: Total number of components in the group; used to detect
+        when `restore_buffer` is complete.
+      fill_save_buffer_fn: Callback which populates `save_buffer` with one
+        entry per component name.
+      consume_restore_buffer_fn: Callback which consumes the completed
+        `restore_buffer` and returns a restore op (or runs the restore and
+        may return `None` when executing eagerly).
+    """
+    self._save_buffer = save_buffer
+    self._restore_buffer = restore_buffer
+    self._name = name
+    self._dtype = dtype
+    self._num_components = num_components
+    self._fill_save_buffer_fn = fill_save_buffer_fn
+    self._consume_restore_buffer_fn = consume_restore_buffer_fn
+
+  def _save(self):
+    """Pull from the shared buffer, populating it if necessary."""
+    if self._name not in self._save_buffer:
+      # A non-empty buffer that lacks our key means a previous fill was not
+      # fully drained, i.e. some sibling component was not saved.
+      if self._save_buffer:
+        raise AssertionError(
+            ("Split dependency %s (%s) unsynchronized. Split dependencies must "
+             "be saved together.") % (self._name, self))
+      self._fill_save_buffer_fn(self._save_buffer)
+    # pop() drains the buffer so the next save cycle starts empty.
+    return self._save_buffer.pop(self._name)
+
+  def _restore(self, tensor):
+    """Push into the shared buffer, flushing it if necessary."""
+    # Our key already present means the previous restore never completed.
+    if self._name in self._restore_buffer:
+      raise AssertionError(
+          ("Split dependency %s (%s) unsynchronized. Split dependencies must "
+           "be restored together.") % (self._name, self))
+    self._restore_buffer[self._name] = tensor
+    if len(self._restore_buffer) == self._num_components:
+      # Last component to arrive triggers the combined restore.
+      op = self._consume_restore_buffer_fn(self._restore_buffer)
+      self._restore_buffer.clear()
+      return op
+    else:
+      # Not all components restored yet; nothing to run for this one.
+      return control_flow_ops.no_op()
+
+  def _gather_saveables_for_checkpoint(self):
+    """Looks to Checkpointable like a regular variable."""
+    # `name` is left unbound in the partial; the checkpointing machinery
+    # supplies it when constructing the saveable.
+    return {
+        core_checkpointable.VARIABLE_VALUE_KEY:
+            functools.partial(_CallbackSaveable,
+                              dtype=self._dtype,
+                              save_callback=self._save,
+                              restore_callback=self._restore)
+    }
+
+
+def split_dependency(component_names, component_dtypes,
+                     fill_save_buffer_fn, consume_restore_buffer_fn):
+  """Creates multiple dependencies with a synchronized save/restore.
+
+  Useful when a single op produces `Tensor`s which should each be saved under
+  different objects, or when `Tensor`s saved with many different objects need to
+  be restored together as inputs to a single op (i.e. an object which uses a
+  single fused op may be swapped out for a subgraph of objects, and these two
+  programs are checkpoint compatible).
+
+  Args:
+    component_names: A sequence of names for the split
+      dependencies. `fill_save_buffer_fn` must add these keys to the dictionary
+      it is passed, and `consume_restore_buffer_fn` will receive a dictionary
+      with these keys.
+    component_dtypes: Data types for the `Tensor`s being saved and restored, a
+      sequence corresponding to `component_names`.
+    fill_save_buffer_fn: A function which takes an empty dictionary as an
+      argument and adds `Tensor`s with `component_names` as keys. These
+      `Tensor`s will be saved as if they were individual variables.
+    consume_restore_buffer_fn: A function which takes a dictionary with
+      `component_names` as keys mapping to restored individual `Tensor`s and
+      returns a restore op (or if executing eagerly, runs the restoration and
+      may return `None`).
+
+  Returns:
+    A dictionary mapping from names to Checkpointable objects. If one is
+    reachable from an object as a dependency, the others should be too; adding
+    dependencies on some but not all of the objects will result in errors.
+  """
+  # Both buffers are shared by every component so that saves and restores of
+  # the group stay synchronized (see _SplitDependency).
+  save_buffer = {}
+  restore_buffer = {}
+  split_dependencies = {}
+  for name, dtype in zip(component_names, component_dtypes):
+    split_dependencies[name] = _SplitDependency(
+        save_buffer=save_buffer,
+        restore_buffer=restore_buffer,
+        name=name,
+        dtype=dtype,
+        num_components=len(component_names),
+        fill_save_buffer_fn=fill_save_buffer_fn,
+        consume_restore_buffer_fn=consume_restore_buffer_fn)
+  return split_dependencies
from tensorflow.python.keras._impl.keras.engine import sequential
from tensorflow.python.keras._impl.keras.engine import training
from tensorflow.python.layers import core
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
return ret
+def _split_variable_closure(variable):
+  """Returns a save-buffer callback that slices `variable` into two halves.
+
+  The split point is fixed at index 2 (matches the length-4 test variable).
+  """
+  def _fill_save_buffer_fn(save_buffer):
+    save_buffer["first_half"] = variable[:2]
+    save_buffer["second_half"] = variable[2:]
+  return _fill_save_buffer_fn
+
+
+def _combine_variable_closure(variable):
+  """Returns a restore-buffer callback that reassembles `variable`."""
+  def _consume_restore_buffer_fn(restore_buffer):
+    # Concatenate the two restored halves and write the full value back with
+    # a single assign; the assign op is the restore op for the whole group.
+    return variable.assign(
+        array_ops.concat([restore_buffer["first_half"],
+                          restore_buffer["second_half"]],
+                         axis=0))
+  return _consume_restore_buffer_fn
+
+
+class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase):
+  """Checkpoints one fused variable as two split dependencies.
+
+  The length-4 `combined` variable is saved as "first_half"/"second_half"
+  components, making it checkpoint-compatible with objects that track two
+  separate length-2 variables (see HasRegularDeps).
+  """
+
+  def __init__(self):
+    self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.])
+    split_dependencies = checkpointable_utils.split_dependency(
+        component_names=("first_half", "second_half"),
+        component_dtypes=(self.combined.dtype,) * 2,
+        fill_save_buffer_fn=_split_variable_closure(
+            self.combined),
+        consume_restore_buffer_fn=_combine_variable_closure(
+            self.combined))
+    # Track each split dependency under its component name so checkpoint keys
+    # line up with a two-variable object.
+    for name, dep in split_dependencies.items():
+      self._track_checkpointable(dep, name=name)
+
+
+class HasRegularDeps(checkpointable.Checkpointable):
+  """Tracks two ordinary length-2 variables under the same dependency names."""
+
+  def __init__(self):
+    self.first_half = resource_variable_ops.ResourceVariable([0., 0.])
+    self.second_half = resource_variable_ops.ResourceVariable([0., 0.])
+
+
+class OnlyOneDep(checkpointable.Checkpointable):
+  """Tracks only one of the two component names (for partial-restore tests)."""
+
+  def __init__(self):
+    self.first_half = resource_variable_ops.ResourceVariable([0., 0.])
+
+
+class SplitTests(test.TestCase):
+  """Tests checkpoint compatibility between split and regular dependencies."""
+
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+  def testSaveRestoreSplitDep(self):
+    # Save a length-4 fused variable as two split dependencies.
+    save_checkpoint = checkpointable_utils.Checkpoint(
+        dep=SaveTensorSlicesAsDeps())
+    self.evaluate(save_checkpoint.dep.combined.assign([1., 2., 3., 4.]))
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+    save_path = save_checkpoint.save(checkpoint_prefix)
+
+    # The split checkpoint restores cleanly into two separate variables.
+    regular_deps = HasRegularDeps()
+    regular_restore_checkpoint = checkpointable_utils.Checkpoint(
+        dep=regular_deps)
+    regular_restore_checkpoint.restore(
+        save_path).assert_consumed().run_restore_ops()
+    self.assertAllEqual([1., 2.], self.evaluate(regular_deps.first_half))
+    self.assertAllEqual([3., 4.], self.evaluate(regular_deps.second_half))
+
+    # Restoring only one component works, but assert_consumed flags the
+    # unused "second_half" value in the checkpoint.
+    one_dep = OnlyOneDep()
+    one_dep_restore_checkpoint = checkpointable_utils.Checkpoint(dep=one_dep)
+    status = one_dep_restore_checkpoint.restore(save_path)
+    with self.assertRaises(AssertionError):
+      # Missing the second dependency.
+      status.assert_consumed()
+    status.run_restore_ops()
+    self.assertAllEqual([1., 2.], self.evaluate(one_dep.first_half))
+
+    # Deferred restoration: the dependency is attached after restore() and the
+    # split components are reassembled into the fused variable.
+    restore_checkpoint = checkpointable_utils.Checkpoint()
+    status = restore_checkpoint.restore(save_path)
+    restore_checkpoint.dep = SaveTensorSlicesAsDeps()
+    status.assert_consumed().run_restore_ops()
+    self.assertAllEqual(
+        [1., 2., 3., 4.],
+        self.evaluate(restore_checkpoint.dep.combined))
+
+
class InterfaceTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)