from tensorflow.python.keras._impl.keras.engine import training
from tensorflow.python.keras._impl.keras.layers import core
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import checkpointable
from tensorflow.python.training import checkpointable_utils
self.assertAllEqual(3., self.evaluate(beta1_power))
+class TemplateTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_checkpointable_save_restore(self):
+
+ def _templated():
+ v = variable_scope.get_variable(
+ "v", shape=[1], initializer=init_ops.zeros_initializer(),
+ use_resource=True)
+ v2 = variable_scope.get_variable(
+ "v2", shape=[1], initializer=init_ops.zeros_initializer(),
+ use_resource=True)
+ return v, v + 1., v2
+
+ save_template = template.make_template("s1", _templated)
+ v1_save, _, v2_save = save_template()
+ optimizer = adam.AdamOptimizer(0.0)
+ save_root = checkpointable_utils.Checkpoint(
+ my_template=save_template, optimizer=optimizer)
+ optimizer.minimize(v1_save.read_value)
+ self.evaluate([v.initializer for v in optimizer.variables()])
+ self.evaluate(v1_save.assign([12.]))
+ self.evaluate(v2_save.assign([14.]))
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ save_path = save_root.save(checkpoint_prefix)
+
+ load_template = template.make_template("s2", _templated)
+ load_optimizer = adam.AdamOptimizer(0.0)
+ load_root = checkpointable_utils.Checkpoint(
+ my_template=load_template, optimizer=load_optimizer)
+ status = load_root.restore(save_path)
+ var, var_plus_one, var2 = load_template()
+ load_optimizer.minimize(var.read_value)
+ self.assertEqual(2, len(load_template._checkpoint_dependencies))
+ self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
+ self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
+ status.assert_consumed().run_restore_ops()
+ self.assertAllEqual([12.], self.evaluate(var))
+ self.assertAllEqual([13.], self.evaluate(var_plus_one))
+ self.assertAllEqual([14.], self.evaluate(var2))
+
+
class CheckpointCompatibilityTests(test.TestCase):
def _initialized_model(self):
"""
slot_variable = self.get_slot(var=variable, name=slot_name)
if (slot_variable is None and context.executing_eagerly() and
- slot_variable_position.is_simple_variable()):
+ slot_variable_position.is_simple_variable()
+ # Defer slot variable creation if there is an active variable creator
+ # scope. Generally we'd like to eagerly create/restore slot variables
+ # when possible, but this may mean that scopes intended to catch
+ # `variable` also catch its eagerly created slot variable
+ # unintentionally (specifically make_template would add a dependency on
+ # a slot variable if not for this case). Deferring is mostly harmless
+ # (aside from double initialization), and makes variable creator scopes
+ # behave the same way they do when graph building.
+ and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
initializer = checkpointable.CheckpointInitialValue(
checkpoint_position=slot_variable_position)
slot_variable = self.create_slot(
# which is not the same as whether the scope has been created.
self._variables_created = False
- @property
- def _checkpoint_dependencies(self):
- """Sanity checking for object-based saving.
-
- Does not override Checkpointable dependency tracking, but checks that
- variables accessible through Checkpointable dependencies on other `Template`
- objects include all of the variable_scope-filtered `Template.variables`.
-
- Returns:
- A list of checkpointable.CheckpointableReference objects.
- Raises:
- ValueError: If this object is not compatible with object-based saving.
- """
- dependencies = super(Template, self)._checkpoint_dependencies
- dependency_variables = []
- for _, dependency in dependencies:
- if isinstance(dependency, Template):
- dependency_variables.extend(dependency.variables)
- else:
- dependency_variables.append(dependency)
- dependency_variables = set(dependency_variables)
- not_included_variables = []
- for expected_variable in sorted(self.variables, key=lambda v: v.name):
- if expected_variable not in dependency_variables:
- not_included_variables.append(expected_variable)
- if not_included_variables:
- # Trying to save a Template which improperly tracks its variables.
- raise ValueError(
- ("The Template '%s' references variables which are not included via "
- "object-based dependency tracking. Most likely a custom "
- "getter/creator was registered which does not call Template's "
- "custom variable creator (which is responsible for tracking "
- "dependencies).\n\nExpected these variables to be dependencies: %s")
- % (self, not_included_variables))
- return dependencies
-
def _checkpointable_custom_creator(self, next_creator, name, initial_value,
checkpointable_parent=None, **kwargs):
"""A variable creation hook which adds Checkpointable dependencies.
def _templated():
v = variable_scope.get_variable(
- "v", shape=[1], initializer=init_ops.zeros_initializer())
+ "v", shape=[1], initializer=init_ops.zeros_initializer(),
+ use_resource=True)
v2 = variable_scope.get_variable(
- "v2", shape=[1], initializer=init_ops.zeros_initializer())
+ "v2", shape=[1], initializer=init_ops.zeros_initializer(),
+ use_resource=True)
return v, v + 1., v2
save_template = template.make_template("s1", _templated)
- save_root = checkpointable_utils.Checkpoint(my_template=save_template)
v1_save, _, v2_save = save_template()
+ optimizer = adam.AdamOptimizer(0.0)
+ save_root = checkpointable_utils.Checkpoint(
+ my_template=save_template, optimizer=optimizer)
+ optimizer.minimize(v1_save.read_value)
+ self.evaluate([v.initializer for v in optimizer.variables()])
self.evaluate(v1_save.assign([12.]))
self.evaluate(v2_save.assign([14.]))
checkpoint_directory = self.get_temp_dir()
save_path = save_root.save(checkpoint_prefix)
load_template = template.make_template("s2", _templated)
- load_root = checkpointable_utils.Checkpoint(my_template=load_template)
+ load_optimizer = adam.AdamOptimizer(0.0)
+ load_root = checkpointable_utils.Checkpoint(
+ my_template=load_template, optimizer=load_optimizer)
status = load_root.restore(save_path)
var, var_plus_one, var2 = load_template()
+ load_optimizer.minimize(var.read_value)
self.assertEqual(2, len(load_template._checkpoint_dependencies))
self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
variable_key = _var_key(variable)
slot_variable = named_slots.get(variable_key, None)
if (slot_variable is None and context.executing_eagerly() and
- slot_variable_position.is_simple_variable()):
+ slot_variable_position.is_simple_variable()
+ # Defer slot variable creation if there is an active variable creator
+ # scope. Generally we'd like to eagerly create/restore slot variables
+ # when possible, but this may mean that scopes intended to catch
+ # `variable` also catch its eagerly created slot variable
+ # unintentionally (specifically make_template would add a dependency on
+ # a slot variable if not for this case). Deferring is mostly harmless
+ # (aside from double initialization), and makes variable creator scopes
+ # behave the same way they do when graph building.
+ and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
initializer = checkpointable.CheckpointInitialValue(
checkpoint_position=slot_variable_position)
slot_variable = self._get_or_make_slot(