From: Allen Lavoie Date: Fri, 23 Mar 2018 03:02:45 +0000 (-0700) Subject: Checkpointable: Add a way for subclasses to manage deferred restorations of condition... X-Git-Tag: tflite-v0.1.7~115^2~2^2~9 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=31741634e73d7e75a0285f6977e1f61cd38f754a;p=platform%2Fupstream%2Ftensorflow.git Checkpointable: Add a way for subclasses to manage deferred restorations of conditional dependencies. PiperOrigin-RevId: 190166571 --- diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/eager/python/checkpointable_utils_test.py index a4b215b..a8c47d7 100644 --- a/tensorflow/contrib/eager/python/checkpointable_utils_test.py +++ b/tensorflow/contrib/eager/python/checkpointable_utils_test.py @@ -378,7 +378,11 @@ class CheckpointingTests(test.TestCase): if not context.executing_eagerly(): return # Restore-on-create is only supported when executing eagerly on_create_model = MyModel() - on_create_optimizer = adam.AdamOptimizer(0.001) + on_create_optimizer = adam.AdamOptimizer( + 0.001, + # Preserve beta1_power and beta2_power when appying gradients so we can + # test that they've been restored correctly. + beta1=1.0, beta2=1.0) on_create_root = checkpointable_utils.Checkpoint( optimizer=on_create_optimizer, model=on_create_model) # Deferred restoration @@ -395,8 +399,8 @@ class CheckpointingTests(test.TestCase): self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot)) self.assertAllEqual(optimizer_variables[2:], self.evaluate(on_create_optimizer.variables())) - on_create_optimizer._create_slots( - [resource_variable_ops.ResourceVariable([1.])]) + dummy_var = resource_variable_ops.ResourceVariable([1.]) + on_create_optimizer.minimize(loss=dummy_var.read_value) status.assert_consumed() beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators() self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power)) diff --git a/tensorflow/python/training/checkpointable.py b/tensorflow/python/training/checkpointable.py index 96e3c48..d0650eb 100644 --- a/tensorflow/python/training/checkpointable.py +++ b/tensorflow/python/training/checkpointable.py @@ -321,8 +321,10 @@ class CheckpointableBase(object): # Maps names -> Checkpointable objects self._unconditional_dependency_names = {} # Restorations for other Checkpointable objects on which this object may - # eventually depend. - self._deferred_dependencies = {} # local name -> _CheckpointPosition list + # eventually depend. Maps local name -> _CheckpointPosition list. Optimizers + # tack on conditional dependencies, and so need separate management of + # deferred dependencies too. + self._unconditional_deferred_dependencies = {} # The UID of the highest assignment to this object. Used to ensure that the # last requested assignment determines the final value of an object. if hasattr(self, "_update_uid"): @@ -344,6 +346,21 @@ class CheckpointableBase(object): """ return self._unconditional_checkpoint_dependencies + @property + def _deferred_dependencies(self): + """A dictionary with deferred dependencies. + + Stores restorations for other Checkpointable objects on which this object + may eventually depend. May be overridden by sub-classes (e.g. Optimizers use + conditional dependencies based the current graph, and so need separate + management of deferred dependencies too). + + Returns: + A dictionary mapping from local name to a list of _CheckpointPosition + objects. + """ + return self._unconditional_deferred_dependencies + def _lookup_dependency(self, name): """Look up a dependency by name.