Checkpointable: Add a way for subclasses to manage deferred restorations of conditional dependencies
authorAllen Lavoie <allenl@google.com>
Fri, 23 Mar 2018 03:02:45 +0000 (20:02 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 23 Mar 2018 03:05:18 +0000 (20:05 -0700)
PiperOrigin-RevId: 190166571

tensorflow/contrib/eager/python/checkpointable_utils_test.py
tensorflow/python/training/checkpointable.py

index a4b215b..a8c47d7 100644 (file)
@@ -378,7 +378,11 @@ class CheckpointingTests(test.TestCase):
     if not context.executing_eagerly():
       return  # Restore-on-create is only supported when executing eagerly
     on_create_model = MyModel()
-    on_create_optimizer = adam.AdamOptimizer(0.001)
+    on_create_optimizer = adam.AdamOptimizer(
+        0.001,
+        # Preserve beta1_power and beta2_power when applying gradients so we can
+        # test that they've been restored correctly.
+        beta1=1.0, beta2=1.0)
     on_create_root = checkpointable_utils.Checkpoint(
         optimizer=on_create_optimizer, model=on_create_model)
     # Deferred restoration
@@ -395,8 +399,8 @@ class CheckpointingTests(test.TestCase):
     self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
     self.assertAllEqual(optimizer_variables[2:],
                         self.evaluate(on_create_optimizer.variables()))
-    on_create_optimizer._create_slots(
-        [resource_variable_ops.ResourceVariable([1.])])
+    dummy_var = resource_variable_ops.ResourceVariable([1.])
+    on_create_optimizer.minimize(loss=dummy_var.read_value)
     status.assert_consumed()
     beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
     self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power))
index 96e3c48..d0650eb 100644 (file)
@@ -321,8 +321,10 @@ class CheckpointableBase(object):
     # Maps names -> Checkpointable objects
     self._unconditional_dependency_names = {}
     # Restorations for other Checkpointable objects on which this object may
-    # eventually depend.
-    self._deferred_dependencies = {}  # local name -> _CheckpointPosition list
+    # eventually depend. Maps local name -> _CheckpointPosition list. Optimizers
+    # tack on conditional dependencies, and so need separate management of
+    # deferred dependencies too.
+    self._unconditional_deferred_dependencies = {}
     # The UID of the highest assignment to this object. Used to ensure that the
     # last requested assignment determines the final value of an object.
     if hasattr(self, "_update_uid"):
@@ -344,6 +346,21 @@ class CheckpointableBase(object):
     """
     return self._unconditional_checkpoint_dependencies
 
+  @property
+  def _deferred_dependencies(self):
+    """A dictionary with deferred dependencies.
+
+    Stores restorations for other Checkpointable objects on which this object
+    may eventually depend. May be overridden by sub-classes (e.g. Optimizers use
+    conditional dependencies based on the current graph, and so need separate
+    conditional dependencies based the current graph, and so need separate
+    management of deferred dependencies too).
+
+    Returns:
+      A dictionary mapping from local name to a list of _CheckpointPosition
+      objects.
+    """
+    return self._unconditional_deferred_dependencies
+
   def _lookup_dependency(self, name):
     """Look up a dependency by name.