Checkpointable: Remove overzealous error checking from tf.make_template
author: Allen Lavoie <allenl@google.com>
Fri, 11 May 2018 22:58:39 +0000 (15:58 -0700)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Fri, 11 May 2018 23:01:24 +0000 (16:01 -0700)
The removed check verified that all variables in the Template's scope were dependencies of the Template, but Optimizer slot variables are created with the same name prefix (and should not be dependencies).

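Conversely, when executing eagerly, slot variables were created eagerly on restore, which meant that Templates created unnecessary (and somewhat harmful) dependencies on restored slot variables. This change fixes that by deferring slot variable creation while a variable creator scope is active.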

PiperOrigin-RevId: 196321999

tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
tensorflow/contrib/optimizer_v2/optimizer_v2.py
tensorflow/python/ops/template.py
tensorflow/python/training/checkpointable_utils_test.py
tensorflow/python/training/optimizer.py

index 87b2ecf..b1f2e9d 100644 (file)
@@ -36,8 +36,10 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.keras._impl.keras.engine import training
 from tensorflow.python.keras._impl.keras.layers import core
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import template
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.training import checkpointable
 from tensorflow.python.training import checkpointable_utils
@@ -612,6 +614,49 @@ class CheckpointingTests(test.TestCase):
         self.assertAllEqual(3., self.evaluate(beta1_power))
 
 
+class TemplateTests(test.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes()
+  def test_checkpointable_save_restore(self):
+
+    def _templated():
+      v = variable_scope.get_variable(
+          "v", shape=[1], initializer=init_ops.zeros_initializer(),
+          use_resource=True)
+      v2 = variable_scope.get_variable(
+          "v2", shape=[1], initializer=init_ops.zeros_initializer(),
+          use_resource=True)
+      return v, v + 1., v2
+
+    save_template = template.make_template("s1", _templated)
+    v1_save, _, v2_save = save_template()
+    optimizer = adam.AdamOptimizer(0.0)
+    save_root = checkpointable_utils.Checkpoint(
+        my_template=save_template, optimizer=optimizer)
+    optimizer.minimize(v1_save.read_value)
+    self.evaluate([v.initializer for v in optimizer.variables()])
+    self.evaluate(v1_save.assign([12.]))
+    self.evaluate(v2_save.assign([14.]))
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+    save_path = save_root.save(checkpoint_prefix)
+
+    load_template = template.make_template("s2", _templated)
+    load_optimizer = adam.AdamOptimizer(0.0)
+    load_root = checkpointable_utils.Checkpoint(
+        my_template=load_template, optimizer=load_optimizer)
+    status = load_root.restore(save_path)
+    var, var_plus_one, var2 = load_template()
+    load_optimizer.minimize(var.read_value)
+    self.assertEqual(2, len(load_template._checkpoint_dependencies))
+    self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
+    self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
+    status.assert_consumed().run_restore_ops()
+    self.assertAllEqual([12.], self.evaluate(var))
+    self.assertAllEqual([13.], self.evaluate(var_plus_one))
+    self.assertAllEqual([14.], self.evaluate(var2))
+
+
 class CheckpointCompatibilityTests(test.TestCase):
 
   def _initialized_model(self):
index 46bfbb7..694a3ce 100644 (file)
@@ -360,7 +360,16 @@ class _OptimizerV2State(object):
     """
     slot_variable = self.get_slot(var=variable, name=slot_name)
     if (slot_variable is None and context.executing_eagerly() and
-        slot_variable_position.is_simple_variable()):
+        slot_variable_position.is_simple_variable()
+        # Defer slot variable creation if there is an active variable creator
+        # scope. Generally we'd like to eagerly create/restore slot variables
+        # when possible, but this may mean that scopes intended to catch
+        # `variable` also catch its eagerly created slot variable
+        # unintentionally (specifically make_template would add a dependency on
+        # a slot variable if not for this case). Deferring is mostly harmless
+        # (aside from double initialization), and makes variable creator scopes
+        # behave the same way they do when graph building.
+        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
       initializer = checkpointable.CheckpointInitialValue(
           checkpoint_position=slot_variable_position)
       slot_variable = self.create_slot(
index 9b6b8c5..b46c46d 100644 (file)
@@ -295,42 +295,6 @@ class Template(checkpointable.CheckpointableBase):
     # which is not the same as whether the scope has been created.
     self._variables_created = False
 
-  @property
-  def _checkpoint_dependencies(self):
-    """Sanity checking for object-based saving.
-
-    Does not override Checkpointable dependency tracking, but checks that
-    variables accessible through Checkpointable dependencies on other `Template`
-    objects include all of the variable_scope-filtered `Template.variables`.
-
-    Returns:
-      A list of checkpointable.CheckpointableReference objects.
-    Raises:
-      ValueError: If this object is not compatible with object-based saving.
-    """
-    dependencies = super(Template, self)._checkpoint_dependencies
-    dependency_variables = []
-    for _, dependency in dependencies:
-      if isinstance(dependency, Template):
-        dependency_variables.extend(dependency.variables)
-      else:
-        dependency_variables.append(dependency)
-    dependency_variables = set(dependency_variables)
-    not_included_variables = []
-    for expected_variable in sorted(self.variables, key=lambda v: v.name):
-      if expected_variable not in dependency_variables:
-        not_included_variables.append(expected_variable)
-    if not_included_variables:
-      # Trying to save a Template which improperly tracks its variables.
-      raise ValueError(
-          ("The Template '%s' references variables which are not included via "
-           "object-based dependency tracking. Most likely a custom "
-           "getter/creator was registered which does not call Template's "
-           "custom variable creator (which is responsible for tracking "
-           "dependencies).\n\nExpected these variables to be dependencies: %s")
-          % (self, not_included_variables))
-    return dependencies
-
   def _checkpointable_custom_creator(self, next_creator, name, initial_value,
                                      checkpointable_parent=None, **kwargs):
     """A variable creation hook which adds Checkpointable dependencies.
index 84cacb6..d94cdcf 100644 (file)
@@ -1250,14 +1250,20 @@ class TemplateTests(test.TestCase):
 
     def _templated():
       v = variable_scope.get_variable(
-          "v", shape=[1], initializer=init_ops.zeros_initializer())
+          "v", shape=[1], initializer=init_ops.zeros_initializer(),
+          use_resource=True)
       v2 = variable_scope.get_variable(
-          "v2", shape=[1], initializer=init_ops.zeros_initializer())
+          "v2", shape=[1], initializer=init_ops.zeros_initializer(),
+          use_resource=True)
       return v, v + 1., v2
 
     save_template = template.make_template("s1", _templated)
-    save_root = checkpointable_utils.Checkpoint(my_template=save_template)
     v1_save, _, v2_save = save_template()
+    optimizer = adam.AdamOptimizer(0.0)
+    save_root = checkpointable_utils.Checkpoint(
+        my_template=save_template, optimizer=optimizer)
+    optimizer.minimize(v1_save.read_value)
+    self.evaluate([v.initializer for v in optimizer.variables()])
     self.evaluate(v1_save.assign([12.]))
     self.evaluate(v2_save.assign([14.]))
     checkpoint_directory = self.get_temp_dir()
@@ -1265,9 +1271,12 @@ class TemplateTests(test.TestCase):
     save_path = save_root.save(checkpoint_prefix)
 
     load_template = template.make_template("s2", _templated)
-    load_root = checkpointable_utils.Checkpoint(my_template=load_template)
+    load_optimizer = adam.AdamOptimizer(0.0)
+    load_root = checkpointable_utils.Checkpoint(
+        my_template=load_template, optimizer=load_optimizer)
     status = load_root.restore(save_path)
     var, var_plus_one, var2 = load_template()
+    load_optimizer.minimize(var.read_value)
     self.assertEqual(2, len(load_template._checkpoint_dependencies))
     self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
     self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
index 66914ba..a676ef9 100644 (file)
@@ -1175,7 +1175,16 @@ class Optimizer(
     variable_key = _var_key(variable)
     slot_variable = named_slots.get(variable_key, None)
     if (slot_variable is None and context.executing_eagerly() and
-        slot_variable_position.is_simple_variable()):
+        slot_variable_position.is_simple_variable()
+        # Defer slot variable creation if there is an active variable creator
+        # scope. Generally we'd like to eagerly create/restore slot variables
+        # when possible, but this may mean that scopes intended to catch
+        # `variable` also catch its eagerly created slot variable
+        # unintentionally (specifically make_template would add a dependency on
+        # a slot variable if not for this case). Deferring is mostly harmless
+        # (aside from double initialization), and makes variable creator scopes
+        # behave the same way they do when graph building.
+        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
       initializer = checkpointable.CheckpointInitialValue(
           checkpoint_position=slot_variable_position)
       slot_variable = self._get_or_make_slot(