From: Allen Lavoie Date: Fri, 13 Apr 2018 21:32:45 +0000 (-0700) Subject: Allow tf.train.Saver to load object-based checkpoints (using names) X-Git-Tag: upstream/v1.9.0_rc1~322^2^2~10 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=8600d918a63c658b9b79ba96ee821c903ba3ee94;p=platform%2Fupstream%2Ftensorflow.git Allow tf.train.Saver to load object-based checkpoints (using names) This is the second part of the compatibility story. Object-based checkpointing APIs can already read name-based checkpoints, and now the name-based APIs can read object-based checkpoints by looking up the modified keys in the object graph proto. PiperOrigin-RevId: 192824907 --- diff --git a/tensorflow/python/training/checkpointable.py b/tensorflow/python/training/checkpointable.py index 9bf48df..0b84737 100644 --- a/tensorflow/python/training/checkpointable.py +++ b/tensorflow/python/training/checkpointable.py @@ -26,6 +26,11 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_io_ops as io_ops from tensorflow.python.util import nest + +# Key where the object graph proto is saved in a TensorBundle +OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH" + + # A key indicating a variable's value in an object's checkpointed Tensors # (Checkpointable._gather_saveables_for_checkpoint). If this is the only key and # the object has no dependencies, then its value may be restored on object diff --git a/tensorflow/python/training/checkpointable_utils.py b/tensorflow/python/training/checkpointable_utils.py index da99d2e..2c4677a 100644 --- a/tensorflow/python/training/checkpointable_utils.py +++ b/tensorflow/python/training/checkpointable_utils.py @@ -54,8 +54,6 @@ _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT" # attribute in checkpoint names. Used like: # /<_OBJECT_ATTRIBUTES_NAME>/ _OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES" -# Key where the object graph proto is saved in a TensorBundle -_OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH" class _CheckpointRestoreCoordinator(object): @@ -680,10 +678,11 @@ class CheckpointableSaver(object): object_graph_tensor = constant_op.constant( graph_proto.SerializeToString(), dtype=dtypes.string) feed_additions = None - assert _OBJECT_GRAPH_PROTO_KEY not in named_variables - named_variables[_OBJECT_GRAPH_PROTO_KEY] = _NoRestoreSaveable( - tensor=object_graph_tensor, - name=_OBJECT_GRAPH_PROTO_KEY) + assert checkpointable_lib.OBJECT_GRAPH_PROTO_KEY not in named_variables + named_variables[checkpointable_lib.OBJECT_GRAPH_PROTO_KEY] = ( + _NoRestoreSaveable( + tensor=object_graph_tensor, + name=checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)) if (self._last_save_object_graph != graph_proto # When executing eagerly, we need to re-create SaveableObjects each time # save() is called so they pick up new Tensors passed to their @@ -786,7 +785,8 @@ class CheckpointableSaver(object): file_prefix_feed_dict = None reader = pywrap_tensorflow.NewCheckpointReader(save_path) try: - object_graph_string = reader.get_tensor(_OBJECT_GRAPH_PROTO_KEY) + object_graph_string = reader.get_tensor( + checkpointable_lib.OBJECT_GRAPH_PROTO_KEY) except errors_impl.NotFoundError: # The object graph proto does not exist in this checkpoint. Try again with # name-based saving. diff --git a/tensorflow/python/training/checkpointable_utils_test.py b/tensorflow/python/training/checkpointable_utils_test.py index ddf9820..29fcdb7 100644 --- a/tensorflow/python/training/checkpointable_utils_test.py +++ b/tensorflow/python/training/checkpointable_utils_test.py @@ -1268,9 +1268,6 @@ class CheckpointCompatibilityTests(test.TestCase): status.initialize_or_restore() self._check_sentinels(root) - # TODO(allenl): Test for the core name-based saver loading object-based - # checkpoints once object-based checkpointing is in core. - def testSaveGraphLoadEager(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index e40b8d2..79d278c 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -22,6 +22,7 @@ from __future__ import print_function import collections import os.path import re +import sys import time import uuid @@ -30,8 +31,10 @@ import six from google.protobuf import text_format +from tensorflow.core.protobuf import checkpointable_object_graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python import pywrap_tensorflow from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.framework import constant_op @@ -1340,6 +1343,9 @@ class Saver(object): self._check_saver_def() self._write_version = self.saver_def.version self._save_relative_paths = save_relative_paths + # For compatibility with object-based checkpoints, we may build a second + # Saver to read the renamed keys. + self._object_restore_saver = None def build(self): if context.executing_eagerly(): @@ -1795,11 +1801,65 @@ class Saver(object): if save_path is None: raise ValueError("Can't load save_path when it is None.") logging.info("Restoring parameters from %s", save_path) - if context.executing_eagerly(): - self._build_eager(save_path, build_save=False, build_restore=True) - else: - sess.run(self.saver_def.restore_op_name, - {self.saver_def.filename_tensor_name: save_path}) + try: + if context.executing_eagerly(): + self._build_eager(save_path, build_save=False, build_restore=True) + else: + sess.run(self.saver_def.restore_op_name, + {self.saver_def.filename_tensor_name: save_path}) + except errors.NotFoundError: + exception_type, exception_value, exception_traceback = sys.exc_info() + # The checkpoint would not be loaded successfully as is. Try to parse it + # as an object-based checkpoint. + try: + reader = pywrap_tensorflow.NewCheckpointReader(save_path) + object_graph_string = reader.get_tensor( + checkpointable.OBJECT_GRAPH_PROTO_KEY) + except errors.NotFoundError: + # This is not an object-based checkpoint, or the checkpoint doesn't + # exist. Re-raise the original exception. + six.reraise(exception_type, exception_value, exception_traceback) + del exception_traceback # avoid reference cycles + + # This is an object-based checkpoint. We'll print a warning and then do + # the restore. + logging.warning( + # TODO(allenl): Modify instructions for using the object-based saver + # once that's in core. + "Restoring an object-based checkpoint using a name-based saver. This " + "may be somewhat fragile, and will re-build the Saver. Instead, " + "consider loading object-based checkpoints using " + "tf.contrib.eager.Checkpoint().") + self._restore_from_object_based_checkpoint( + sess=sess, save_path=save_path, + object_graph_string=object_graph_string) + + def _restore_from_object_based_checkpoint(self, sess, save_path, + object_graph_string): + """A compatibility mode for reading object-based checkpoints.""" + object_graph_proto = ( + checkpointable_object_graph_pb2.CheckpointableObjectGraph()) + object_graph_proto.ParseFromString(object_graph_string) + names_to_keys = {} + for node in object_graph_proto.nodes: + for attribute in node.attributes: + names_to_keys[attribute.full_name] = attribute.checkpoint_key + saveables = self._builder._ValidateAndSliceInputs(self._var_list) # pylint: disable=protected-access + for saveable in saveables: + for spec in saveable.specs: + if spec.name not in names_to_keys: + raise errors.NotFoundError( + None, None, + message=("Attempting to load an object-based checkpoint using " + "variable names, but could not find %s in the " + "checkpoint.") % spec.name) + spec.name = names_to_keys[spec.name] + if self._object_restore_saver is None: + # Cache the Saver so multiple restore() calls don't pollute the graph when + # graph building. This assumes keys are consistent (i.e. this is the same + # type of object-based checkpoint we saw previously). + self._object_restore_saver = Saver(saveables) + self._object_restore_saver.restore(sess=sess, save_path=save_path) @staticmethod def _add_collection_def(meta_graph_def, key, export_scope=None): diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index 14dda79..3867c0d 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import contextlib +import functools import math import os import random @@ -50,6 +51,8 @@ from tensorflow.python.framework import graph_io from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops as ops_lib from tensorflow.python.framework import test_util +from tensorflow.python.keras._impl.keras.engine import training +from tensorflow.python.keras._impl.keras.layers import core from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -69,10 +72,12 @@ from tensorflow.python.platform import test from tensorflow.python.summary import summary from tensorflow.python.training import adam from tensorflow.python.training import checkpointable +from tensorflow.python.training import checkpointable_utils from tensorflow.python.training import gradient_descent from tensorflow.python.training import queue_runner_impl from tensorflow.python.training import saver as saver_module from tensorflow.python.training import saver_test_utils +from tensorflow.python.training import training_util from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState from tensorflow.python.util import compat @@ -2948,6 +2953,29 @@ class _OwnsMirroredVariables(checkpointable.CheckpointableBase): return self.non_dep_variable.name +class NonLayerCheckpointable(checkpointable.Checkpointable): + + def __init__(self): + super(NonLayerCheckpointable, self).__init__() + self.a_variable = checkpointable_utils.add_variable( + self, name="a_variable", shape=[]) + + +class MyModel(training.Model): + """A concrete Model for testing.""" + + def __init__(self): + super(MyModel, self).__init__() + self._named_dense = core.Dense(1, use_bias=True) + self._second = core.Dense(1, use_bias=False) + # We can still track Checkpointables which aren't Layers. + self._non_layer = NonLayerCheckpointable() + + def call(self, values): + ret = self._second(self._named_dense(values)) + return ret + + @test_util.with_c_api class CheckpointableCompatibilityTests(test.TestCase): @@ -3011,6 +3039,128 @@ class CheckpointableCompatibilityTests(test.TestCase): saver.restore(sess, save_path) self.assertEqual(1, v.eval_count) + def _initialized_model(self): + input_value = constant_op.constant([[3.]]) + model = MyModel() + optimizer = adam.AdamOptimizer(0.001) + optimizer_step = training_util.get_or_create_global_step() + root_checkpointable = checkpointable_utils.Checkpoint( + optimizer=optimizer, model=model, optimizer_step=optimizer_step) + train_op = optimizer.minimize( + functools.partial(model, input_value), + global_step=optimizer_step) + self.evaluate(checkpointable_utils.gather_initializers( + root_checkpointable)) + self.evaluate(train_op) + # A regular variable, a slot variable, and a non-slot Optimizer variable + # with known values to check when loading. + self.evaluate(model._named_dense.bias.assign([1.])) + self.evaluate(optimizer.get_slot( + var=model._named_dense.bias, name="m").assign([2.])) + beta1_power, _ = optimizer._get_beta_accumulators() + self.evaluate(beta1_power.assign(3.)) + return root_checkpointable + + def _set_sentinels(self, root_checkpointable): + self.evaluate(root_checkpointable.model._named_dense.bias.assign([101.])) + self.evaluate( + root_checkpointable.optimizer.get_slot( + var=root_checkpointable.model._named_dense.bias, name="m") + .assign([102.])) + beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators() + self.evaluate(beta1_power.assign(103.)) + + def _check_sentinels(self, root_checkpointable): + self.assertAllEqual( + [1.], self.evaluate(root_checkpointable.model._named_dense.bias)) + self.assertAllEqual([2.], self.evaluate( + root_checkpointable.optimizer.get_slot( + var=root_checkpointable.model._named_dense.bias, name="m"))) + beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators() + self.assertAllEqual(3., self.evaluate(beta1_power)) + + def testVariableNotFoundErrorRaised(self): + # Restore does some tricky exception handling to figure out if it should + # load an object-based checkpoint. Tests that the exception handling isn't + # too broad. + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + a = resource_variable_ops.ResourceVariable(1., name="a") + b = resource_variable_ops.ResourceVariable(1., name="b") + a_saver = saver_module.Saver([a]) + b_saver = saver_module.Saver([b]) + with self.test_session() as sess: + sess.run(a.initializer) + save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix) + with self.assertRaisesRegexp( + errors.NotFoundError, "Key b not found in checkpoint"): + b_saver.restore(sess=sess, save_path=save_path) + + def testCheckpointNotFoundErrorRaised(self): + # Restore does some tricky exception handling to figure out if it should + # load an object-based checkpoint. Tests that the exception handling isn't + # too broad. + a = resource_variable_ops.ResourceVariable(1., name="a") + saver = saver_module.Saver([a]) + with self.test_session() as sess: + with self.assertRaisesRegexp( + errors.NotFoundError, + "Failed to find any matching files for path_which_does_not_exist"): + saver.restore(sess=sess, save_path="path_which_does_not_exist") + + def testLoadFromObjectBasedGraph(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + save_graph = ops_lib.Graph() + with save_graph.as_default(), self.test_session(graph=save_graph) as sess: + root = self._initialized_model() + object_saver = checkpointable_utils.CheckpointableSaver(root) + save_path = object_saver.save(file_prefix=checkpoint_prefix) + + # An incompatible object-based checkpoint to check error messages + var = resource_variable_ops.ResourceVariable(1., name="a") + self.evaluate(var.initializer) + second_saver = checkpointable_utils.CheckpointableSaver(var) + second_path = second_saver.save(file_prefix=os.path.join( + checkpoint_directory, "second")) + + restore_graph = ops_lib.Graph() + with restore_graph.as_default(), self.test_session( + graph=restore_graph) as sess: + root = self._initialized_model() + self._set_sentinels(root) + saver = saver_module.Saver() + saver.restore(sess=sess, save_path=save_path) + self._check_sentinels(root) + before_second_restore_ops = restore_graph.get_operations() + # Test that multiple restores do not pollute the graph + saver.restore(sess=sess, save_path=save_path) + self.assertEqual(before_second_restore_ops, + restore_graph.get_operations()) + with self.assertRaisesRegexp(errors.NotFoundError, + "could not find a_variable"): + saver.restore(sess=sess, save_path=second_path) + + def testLoadFromObjectBasedEager(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + save_graph = ops_lib.Graph() + with save_graph.as_default(), self.test_session(graph=save_graph): + root = self._initialized_model() + object_saver = checkpointable_utils.CheckpointableSaver(root) + save_path = object_saver.save(file_prefix=checkpoint_prefix) + + with context.eager_mode(): + root = self._initialized_model() + self._set_sentinels(root) + saver = saver_module.Saver( + root.model.variables + root.optimizer.variables()) + saver.restore(sess=None, save_path=save_path) + self._check_sentinels(root) + if __name__ == "__main__": test.main()