if not context.executing_eagerly():
return # Restore-on-create is only supported when executing eagerly
on_create_model = MyModel()
- on_create_optimizer = adam.AdamOptimizer(0.001)
+ on_create_optimizer = adam.AdamOptimizer(
+ 0.001,
+ # Preserve beta1_power and beta2_power when appying gradients so we can
+ # test that they've been restored correctly.
+ beta1=1.0, beta2=1.0)
on_create_root = checkpointable_utils.Checkpoint(
optimizer=on_create_optimizer, model=on_create_model)
# Deferred restoration
self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
self.assertAllEqual(optimizer_variables[2:],
self.evaluate(on_create_optimizer.variables()))
- on_create_optimizer._create_slots(
- [resource_variable_ops.ResourceVariable([1.])])
+ dummy_var = resource_variable_ops.ResourceVariable([1.])
+ on_create_optimizer.minimize(loss=dummy_var.read_value)
status.assert_consumed()
beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power))
# Maps names -> Checkpointable objects
self._unconditional_dependency_names = {}
# Restorations for other Checkpointable objects on which this object may
- # eventually depend.
- self._deferred_dependencies = {} # local name -> _CheckpointPosition list
+ # eventually depend. Maps local name -> _CheckpointPosition list. Optimizers
+ # tack on conditional dependencies, and so need separate management of
+ # deferred dependencies too.
+ self._unconditional_deferred_dependencies = {}
# The UID of the highest assignment to this object. Used to ensure that the
# last requested assignment determines the final value of an object.
if hasattr(self, "_update_uid"):
"""
return self._unconditional_checkpoint_dependencies
+ @property
+ def _deferred_dependencies(self):
+ """A dictionary with deferred dependencies.
+
+ Stores restorations for other Checkpointable objects on which this object
+ may eventually depend. May be overridden by sub-classes (e.g. Optimizers use
+ conditional dependencies based the current graph, and so need separate
+ management of deferred dependencies too).
+
+ Returns:
+ A dictionary mapping from local name to a list of _CheckpointPosition
+ objects.
+ """
+ return self._unconditional_deferred_dependencies
+
def _lookup_dependency(self, name):
"""Look up a dependency by name.