Checkpointable: Utility to split a dependency for saving/loading
authorAllen Lavoie <allenl@google.com>
Tue, 3 Apr 2018 22:28:27 +0000 (15:28 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 3 Apr 2018 22:31:14 +0000 (15:31 -0700)
Useful when a single op produces Tensors which should each be saved under
different objects (or when Tensors saved with many different objects need to be
restored together as inputs to a single op).

I plan to use this for cuDNN RNN saving, which currently relies heavily on name
matching to allow cuDNN cells to be swapped out for CPU-compatible RNN
cells. This change has just the utility and some tests for it.

PiperOrigin-RevId: 191509664

tensorflow/contrib/eager/python/checkpointable_utils.py
tensorflow/contrib/eager/python/checkpointable_utils_test.py

index 91a7ade..34cb8d0 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import abc
 import collections
+import functools
 import weakref
 
 from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2
@@ -867,3 +868,115 @@ class Checkpoint(core_checkpointable.Checkpointable):
     # initialization when executing eagerly.
     self._maybe_create_save_counter()
     return status
+
+
+class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
+  """Wraps save and restore callbacks as a `SaveableObject`."""
+
+  def __init__(self, name, dtype, save_callback, restore_callback):
+    self._restore_callback = restore_callback
+    spec = saver_lib.BaseSaverBuilder.SaveSpec(
+        tensor=save_callback,
+        slice_spec="",
+        name=name,
+        dtype=dtype)
+    super(_CallbackSaveable, self).__init__(
+        save_callback, [spec], name)
+
+  def restore(self, restored_tensors, restored_shapes):
+    """Restore the same value into both variables."""
+    tensor, = restored_tensors
+    return self._restore_callback(tensor)
+
+
+class _SplitDependency(core_checkpointable.CheckpointableBase):
+  """Looks like a regular variable while synchronizing save/restores."""
+
+  def __init__(self, save_buffer, restore_buffer, name, dtype, num_components,
+               fill_save_buffer_fn, consume_restore_buffer_fn):
+    self._save_buffer = save_buffer
+    self._restore_buffer = restore_buffer
+    self._name = name
+    self._dtype = dtype
+    self._num_components = num_components
+    self._fill_save_buffer_fn = fill_save_buffer_fn
+    self._consume_restore_buffer_fn = consume_restore_buffer_fn
+
+  def _save(self):
+    """Pull from the shared buffer, populating it if necessary."""
+    if self._name not in self._save_buffer:
+      if self._save_buffer:
+        raise AssertionError(
+            ("Split dependency %s (%s) unsynchronized. Split dependencies must "
+             "be saved together.") % (self._name, self))
+      self._fill_save_buffer_fn(self._save_buffer)
+    return self._save_buffer.pop(self._name)
+
+  def _restore(self, tensor):
+    """Push into the shared buffer, flushing it if necessary."""
+    if self._name in self._restore_buffer:
+      raise AssertionError(
+          ("Split dependency %s (%s) unsynchronized. Split dependencies must "
+           "be restored together.") % (self._name, self))
+    self._restore_buffer[self._name] = tensor
+    if len(self._restore_buffer) == self._num_components:
+      op = self._consume_restore_buffer_fn(self._restore_buffer)
+      self._restore_buffer.clear()
+      return op
+    else:
+      return control_flow_ops.no_op()
+
+  def _gather_saveables_for_checkpoint(self):
+    """Looks to Checkpointable like a regular variable."""
+    return {
+        core_checkpointable.VARIABLE_VALUE_KEY:
+        functools.partial(_CallbackSaveable,
+                          dtype=self._dtype,
+                          save_callback=self._save,
+                          restore_callback=self._restore)
+    }
+
+
+def split_dependency(component_names, component_dtypes,
+                     fill_save_buffer_fn, consume_restore_buffer_fn):
+  """Creates multiple dependencies with a synchronized save/restore.
+
+  Useful when a single op produces `Tensor`s which should each be saved under
+  different objects, or when `Tensor`s saved with many different objects need to
+  be restored together as inputs to a single op (i.e. an object which uses a
+  single fused op may be swapped out for a subgraph of objects, and these two
+  programs are checkpoint compatible).
+
+  Args:
+    component_names: A sequence of names for the split
+      dependencies. `fill_save_buffer_fn` must add these keys to the dictionary
+      it is passed, and `consume_restore_buffer_fn` will receive a dictionary
+      with these keys.
+    component_dtypes: Data types for the `Tensor`s being saved and restored, a
+      sequence corresponding to `component_names`.
+    fill_save_buffer_fn: A function which takes an empty dictionary as an
+      argument and adds `Tensor`s with `component_names` as keys. These
+      `Tensor`s will be saved as if they were individual variables.
+    consume_restore_buffer_fn: A function which takes a dictionary with
+      `component_names` as keys mapping to restored individual `Tensor`s and
+      returns a restore op (or if executing eagerly, runs the restoration and
+      may return `None`).
+
+  Returns:
+    A dictionary mapping from names to Checkpointable objects. If one is
+    reachable from an object as a dependency, the others should be too; adding
+    dependencies on some but not all of the objects will result in errors.
+  """
+  save_buffer = {}
+  restore_buffer = {}
+  split_dependencies = {}
+  for name, dtype in zip(component_names, component_dtypes):
+    split_dependencies[name] = _SplitDependency(
+        save_buffer=save_buffer,
+        restore_buffer=restore_buffer,
+        name=name,
+        dtype=dtype,
+        num_components=len(component_names),
+        fill_save_buffer_fn=fill_save_buffer_fn,
+        consume_restore_buffer_fn=consume_restore_buffer_fn)
+  return split_dependencies
index 5e1b647..891c093 100644 (file)
@@ -33,6 +33,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.keras._impl.keras.engine import sequential
 from tensorflow.python.keras._impl.keras.engine import training
 from tensorflow.python.layers import core
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -69,6 +70,87 @@ class MyModel(training.Model):
     return ret
 
 
+def _split_variable_closure(variable):
+  def _fill_save_buffer_fn(save_buffer):
+    save_buffer["first_half"] = variable[:2]
+    save_buffer["second_half"] = variable[2:]
+  return _fill_save_buffer_fn
+
+
+def _combine_variable_closure(variable):
+  def _consume_restore_buffer_fn(restore_buffer):
+    return variable.assign(
+        array_ops.concat([restore_buffer["first_half"],
+                          restore_buffer["second_half"]],
+                         axis=0))
+  return _consume_restore_buffer_fn
+
+
+class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase):
+
+  def __init__(self):
+    self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.])
+    split_dependencies = checkpointable_utils.split_dependency(
+        component_names=("first_half", "second_half"),
+        component_dtypes=(self.combined.dtype,) * 2,
+        fill_save_buffer_fn=_split_variable_closure(
+            self.combined),
+        consume_restore_buffer_fn=_combine_variable_closure(
+            self.combined))
+    for name, dep in split_dependencies.items():
+      self._track_checkpointable(dep, name=name)
+
+
+class HasRegularDeps(checkpointable.Checkpointable):
+
+  def __init__(self):
+    self.first_half = resource_variable_ops.ResourceVariable([0., 0.])
+    self.second_half = resource_variable_ops.ResourceVariable([0., 0.])
+
+
+class OnlyOneDep(checkpointable.Checkpointable):
+
+  def __init__(self):
+    self.first_half = resource_variable_ops.ResourceVariable([0., 0.])
+
+
+class SplitTests(test.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+  def testSaveRestoreSplitDep(self):
+    save_checkpoint = checkpointable_utils.Checkpoint(
+        dep=SaveTensorSlicesAsDeps())
+    self.evaluate(save_checkpoint.dep.combined.assign([1., 2., 3., 4.]))
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+    save_path = save_checkpoint.save(checkpoint_prefix)
+
+    regular_deps = HasRegularDeps()
+    regular_restore_checkpoint = checkpointable_utils.Checkpoint(
+        dep=regular_deps)
+    regular_restore_checkpoint.restore(
+        save_path).assert_consumed().run_restore_ops()
+    self.assertAllEqual([1., 2.], self.evaluate(regular_deps.first_half))
+    self.assertAllEqual([3., 4.], self.evaluate(regular_deps.second_half))
+
+    one_dep = OnlyOneDep()
+    one_dep_restore_checkpoint = checkpointable_utils.Checkpoint(dep=one_dep)
+    status = one_dep_restore_checkpoint.restore(save_path)
+    with self.assertRaises(AssertionError):
+      # Missing the second dependency.
+      status.assert_consumed()
+    status.run_restore_ops()
+    self.assertAllEqual([1., 2.], self.evaluate(one_dep.first_half))
+
+    restore_checkpoint = checkpointable_utils.Checkpoint()
+    status = restore_checkpoint.restore(save_path)
+    restore_checkpoint.dep = SaveTensorSlicesAsDeps()
+    status.assert_consumed().run_restore_ops()
+    self.assertAllEqual(
+        [1., 2., 3., 4.],
+        self.evaluate(restore_checkpoint.dep.combined))
+
+
 class InterfaceTests(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)