From ef58a46b730155717f1b03abb20767c1924ad05e Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Wed, 9 May 2018 15:56:43 -0700 Subject: [PATCH] Support saving Python state with object-based checkpoints Allows SaveableObjects to specify feed dict addition callbacks for object-based saving. For now just saves get_config() with Layers. Doesn't do any loading, and there isn't quite enough information to reconstruct a Model yet (needs topology). My plan is to get Models to the point where they can be reconstructed from object-based checkpoints (probably one more change), add in SavedModel export (assuming no dynamic control flow for now), then add this "SavedModel+Python" format to Model.save / load_model. PiperOrigin-RevId: 196043183 --- .../optimizer_v2/checkpointable_utils_test.py | 43 +++---- tensorflow/python/BUILD | 15 +++ .../python/keras/_impl/keras/engine/saving.py | 39 +----- tensorflow/python/training/checkpointable.py | 57 ++++++++- tensorflow/python/training/checkpointable_utils.py | 135 ++++++++++++++++----- .../python/training/checkpointable_utils_test.py | 103 ++++++++++++---- tensorflow/python/training/saver.py | 132 +++++++++++--------- tensorflow/python/util/serialization.py | 64 ++++++++++ tensorflow/python/util/serialization_test.py | 76 ++++++++++++ 9 files changed, 493 insertions(+), 171 deletions(-) create mode 100644 tensorflow/python/util/serialization.py create mode 100644 tensorflow/python/util/serialization_test.py diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index 9e2858d..87b2ecf 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -31,7 +31,6 @@ from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras._impl.keras.engine import training @@ -139,8 +138,9 @@ class CheckpointingTests(test.TestCase): self.evaluate(checkpointable_utils.gather_initializers( root_checkpointable)) self.evaluate(train_op) - named_variables, serialized_graph = ( - checkpointable_utils._serialize_object_graph(root_checkpointable)) + named_variables, serialized_graph, _ = ( + checkpointable_utils._serialize_object_graph( + root_checkpointable, saveables_cache=None)) expected_checkpoint_names = ( # Created in the root node, so no prefix. "optimizer_step", @@ -163,24 +163,29 @@ class CheckpointingTests(test.TestCase): suffix = "/.ATTRIBUTES/VARIABLE_VALUE" expected_checkpoint_names = [ name + suffix for name in expected_checkpoint_names] + # The Dense layers also save get_config() JSON + expected_checkpoint_names.extend( + ["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON", + "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"]) + named_variables = {v.name: v for v in named_variables} six.assertCountEqual(self, expected_checkpoint_names, named_variables.keys()) # Check that we've mapped to the right variable objects (not exhaustive) self.assertEqual( - "global_step:0", - named_variables["optimizer_step" + suffix].name) + "global_step", + named_variables["optimizer_step" + suffix].full_name) self.assertEqual( - "my_model/dense_1/kernel:0", - named_variables["model/_second/kernel" + suffix].name) + "my_model/dense_1/kernel", + named_variables["model/_second/kernel" + suffix].full_name) self.assertEqual( - "my_model/dense/kernel:0", - named_variables["model/_named_dense/kernel" + suffix].name) + "my_model/dense/kernel", + named_variables["model/_named_dense/kernel" + suffix].full_name) self.assertEqual( - "beta1_power:0", - named_variables["optimizer/beta1_power" + suffix].name) + "beta1_power", + named_variables["optimizer/beta1_power" + suffix].full_name) self.assertEqual( - "beta2_power:0", - named_variables["optimizer/beta2_power" + suffix].name) + "beta2_power", + named_variables["optimizer/beta2_power" + suffix].full_name) # Spot check the generated protocol buffers. self.assertEqual("optimizer", serialized_graph.nodes[0].children[1].local_name) @@ -205,7 +210,7 @@ class CheckpointingTests(test.TestCase): self.assertEqual( "my_model/dense/kernel/Adam:0", optimizer.get_slot( - var=named_variables["model/_named_dense/kernel" + suffix], + var=model._named_dense.kernel, name="m").name) self.assertEqual( "model/_named_dense/kernel" + suffix, @@ -417,16 +422,6 @@ class CheckpointingTests(test.TestCase): self.evaluate(root.save_counter)) # pylint: enable=cell-var-from-loop - def _get_checkpoint_name(self, name): - root = checkpointable.Checkpointable() - checkpointable_utils.add_variable( - root, name=name, shape=[1, 2], dtype=dtypes.float64) - named_variables, _ = checkpointable_utils._serialize_object_graph(root) - checkpoint_name, = named_variables.keys() - with ops.name_scope("root/" + checkpoint_name): - pass # Make sure we can use this as an op name if we prefix it. - return checkpoint_name - def testAnonymousVarsInInit(self): class Model(training.Model): diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index f7cbaec..8b904a1 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3036,9 +3036,12 @@ py_library( srcs_version = "PY2AND3", deps = [ ":array_ops", + ":constant_op", + ":control_flow_ops", ":dtypes", ":io_ops_gen", ":ops", + ":saveable_object", ":util", "//tensorflow/python/eager:context", ], @@ -3224,6 +3227,18 @@ py_test( ) py_test( + name = "util_serialization_test", + size = "small", + srcs = ["util/serialization_test.py"], + main = "util/serialization_test.py", + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":util", + ], +) + +py_test( name = "future_api_test", size = "small", srcs = ["util/future_api_test.py"], diff --git a/tensorflow/python/keras/_impl/keras/engine/saving.py b/tensorflow/python/keras/_impl/keras/engine/saving.py index a0b709a..ee6e320 100644 --- a/tensorflow/python/keras/_impl/keras/engine/saving.py +++ b/tensorflow/python/keras/_impl/keras/engine/saving.py @@ -30,6 +30,7 @@ from tensorflow.python.keras._impl.keras import optimizers from tensorflow.python.keras._impl.keras.utils import conv_utils from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import serialization from tensorflow.python.util.tf_export import tf_export # pylint: disable=g-import-not-at-top @@ -74,40 +75,6 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True): if h5py is None: raise ImportError('`save_model` requires h5py.') - def get_json_type(obj): - """Serializes any object to a JSON-serializable structure. - - Arguments: - obj: the object to serialize - - Returns: - JSON-serializable structure representing `obj`. - - Raises: - TypeError: if `obj` cannot be serialized. - """ - # if obj is a serializable Keras class instance - # e.g. optimizer, layer - if hasattr(obj, 'get_config'): - return {'class_name': obj.__class__.__name__, 'config': obj.get_config()} - - # if obj is any numpy type - if type(obj).__module__ == np.__name__: - if isinstance(obj, np.ndarray): - return {'type': type(obj), 'value': obj.tolist()} - else: - return obj.item() - - # misc functions (e.g. loss function) - if callable(obj): - return obj.__name__ - - # if obj is a python 'type' - if type(obj).__name__ == type.__name__: - return obj.__name__ - - raise TypeError('Not JSON Serializable:', obj) - from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top # If file exists and should not be overwritten. @@ -124,7 +91,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True): 'class_name': model.__class__.__name__, 'config': model.get_config() }, - default=get_json_type).encode('utf8') + default=serialization.get_json_type).encode('utf8') model_weights_group = f.create_group('model_weights') model_layers = model.layers @@ -154,7 +121,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True): 'sample_weight_mode': model.sample_weight_mode, 'loss_weights': model.loss_weights, }, - default=get_json_type).encode('utf8') + default=serialization.get_json_type).encode('utf8') # Save optimizer weights. symbolic_weights = getattr(model.optimizer, 'weights') diff --git a/tensorflow/python/training/checkpointable.py b/tensorflow/python/training/checkpointable.py index d00312a..956dd66 100644 --- a/tensorflow/python/training/checkpointable.py +++ b/tensorflow/python/training/checkpointable.py @@ -18,14 +18,21 @@ from __future__ import division from __future__ import print_function import collections +import functools +import json +import weakref from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_io_ops as io_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import saveable_object from tensorflow.python.util import nest +from tensorflow.python.util import serialization # Key where the object graph proto is saved in a TensorBundle @@ -37,6 +44,7 @@ OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH" # the object has no dependencies, then its value may be restored on object # creation (avoiding double assignment when executing eagerly). VARIABLE_VALUE_KEY = "VARIABLE_VALUE" +OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON" CheckpointableReference = collections.namedtuple( "CheckpointableReference", @@ -85,6 +93,35 @@ class CheckpointInitialValue(ops.Tensor): return self._checkpoint_position +class PythonStringStateSaveable(saveable_object.SaveableObject): + """Saves Python state in a checkpoint.""" + + def __init__(self, name, state_callback): + """Configure saving. + + Args: + name: The checkpoint key to write to. + state_callback: A function taking no arguments which returns a + string. This function is run every time a checkpoint is written. + """ + if context.executing_eagerly(): + self._save_string = ( + lambda: constant_op.constant(state_callback(), dtype=dtypes.string)) + else: + self._save_string = constant_op.constant("", dtype=dtypes.string) + self.feed_dict_additions = ( + lambda: {self._save_string: state_callback()}) + spec = saveable_object.SaveSpec( + self._save_string, "", name, dtype=dtypes.string) + super(PythonStringStateSaveable, self).__init__( + self._save_string, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + # TODO(allenl): Add a Python hook for state coming out of a checkpoint + # (currently PythonStringStateSaveable is write-only). + return control_flow_ops.no_op() + + class _CheckpointPosition(object): """Indicates a position within a `_Checkpoint`.""" @@ -604,7 +641,6 @@ class CheckpointableBase(object): # restoration on to our dependencies. if checkpoint.restore_uid > self._update_uid: restore_ops = checkpoint_position.restore_ops() - # TODO(allenl): Get a list of feeds for saving Python state self._update_uid = checkpoint.restore_uid else: restore_ops = () @@ -656,7 +692,24 @@ class CheckpointableBase(object): lambda name="global_name_for_this_object": SaveableObject(name=name, ...)} """ - return {} + if not hasattr(self, "get_config"): + return {} + try: + self.get_config() + except NotImplementedError: + return {} + weak_self = weakref.ref(self) + def _state_callback(): + dereferenced_self = weak_self() + if dereferenced_self: + return json.dumps(self, + default=serialization.get_json_type, + sort_keys=True).encode("utf8") + else: + return "" + return {OBJECT_CONFIG_JSON_KEY: functools.partial( + PythonStringStateSaveable, + state_callback=_state_callback)} class NoDependency(object): diff --git a/tensorflow/python/training/checkpointable_utils.py b/tensorflow/python/training/checkpointable_utils.py index f2a2b41..1e69096 100644 --- a/tensorflow/python/training/checkpointable_utils.py +++ b/tensorflow/python/training/checkpointable_utils.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import checkpointable as checkpointable_lib from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.training import saveable_object from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export @@ -303,42 +304,93 @@ def _serialize_slot_variables(checkpointable_objects, node_ids, object_names): def _serialize_checkpointables( - checkpointable_objects, node_ids, object_names, slot_variables): + checkpointable_objects, node_ids, object_names, slot_variables, + saveables_cache): """Name non-slot `Checkpointable`s and add them to `object_graph_proto`.""" object_graph_proto = ( checkpointable_object_graph_pb2.CheckpointableObjectGraph()) - named_saveables = {} - + named_saveables = [] + feed_additions = {} for checkpoint_id, checkpointable in enumerate(checkpointable_objects): assert node_ids[checkpointable] == checkpoint_id object_proto = object_graph_proto.nodes.add() object_proto.slot_variables.extend(slot_variables.get(checkpointable, ())) object_name = object_names[checkpointable] + if saveables_cache is not None: + cached_attributes = saveables_cache.setdefault(checkpointable, {}) + else: + cached_attributes = None for name, saveable_factory in ( checkpointable._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access attribute = object_proto.attributes.add() attribute.name = name attribute.checkpoint_key = "%s/%s/%s" % ( object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name)) - if callable(saveable_factory): - saveable = saveable_factory(name=attribute.checkpoint_key) + if cached_attributes is None: + saveables = None else: - saveable = saveable_factory - # Figure out the name-based Saver's name for this variable. - saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( - [saveable], convert_variable_to_tensor=False) - attribute.full_name, = saver_dict.keys() - named_saveables[attribute.checkpoint_key] = saveable + saveables = cached_attributes.get(name, None) + if saveables is not None: + for saveable in saveables: + if attribute.checkpoint_key not in saveable.name: + # The checkpoint key for this SaveableObject is different. We need + # to re-create it. + saveables = None + del cached_attributes[name] + break + if saveables is None: + if callable(saveable_factory): + maybe_saveable = saveable_factory(name=attribute.checkpoint_key) + else: + maybe_saveable = saveable_factory + if isinstance(maybe_saveable, saveable_object.SaveableObject): + saveables = (maybe_saveable,) + else: + # Figure out the name-based Saver's name for this variable. If it's + # already a SaveableObject we'd just get the checkpoint key back, so + # we leave full_name blank. + saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( + [maybe_saveable], convert_variable_to_tensor=False) + full_name, = saver_dict.keys() + saveables = tuple(saver_lib.BaseSaverBuilder.SaveableObjectsForOp( + op=maybe_saveable, name=attribute.checkpoint_key)) + for saveable in saveables: + saveable.full_name = full_name + for saveable in saveables: + if attribute.checkpoint_key not in saveable.name: + raise AssertionError( + ("The object %s produced a SaveableObject with name '%s' for " + "attribute '%s'. Expected a name containing '%s'.") + % (checkpointable, name, saveable.name, + attribute.checkpoint_key)) + if cached_attributes is not None: + cached_attributes[name] = saveables + + for saveable in saveables: + if hasattr(saveable, "full_name"): + attribute.full_name = saveable.full_name + saveable_feed_dict_fn = getattr(saveable, "feed_dict_additions", None) + if saveable_feed_dict_fn is not None: + saveable_feed_dict = saveable_feed_dict_fn() # pylint: disable=not-callable + for new_feed_key in saveable_feed_dict.keys(): + if new_feed_key in feed_additions: + raise AssertionError( + ("The object %s tried to feed a value for the Tensor %s " + "when saving, but another object is already feeding a " + "value.") + % (checkpointable, new_feed_key)) + feed_additions.update(saveable_feed_dict) + named_saveables.extend(saveables) for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access child_proto = object_proto.children.add() child_proto.node_id = node_ids[child.ref] child_proto.local_name = child.name - return named_saveables, object_graph_proto + return named_saveables, object_graph_proto, feed_additions -def _serialize_object_graph(root_checkpointable): +def _serialize_object_graph(root_checkpointable, saveables_cache): """Determine checkpoint keys for variables and build a serialized graph. Non-slot variables are keyed based on a shortest path from the root saveable @@ -351,12 +403,17 @@ def _serialize_object_graph(root_checkpointable): Args: root_checkpointable: A `Checkpointable` object whose variables (including the variables of dependencies, recursively) should be saved. + saveables_cache: A dictionary mapping `Checkpointable` objects -> attribute + names -> SaveableObjects, used to avoid re-creating SaveableObjects when + graph building. Returns: - A tuple of (named_variables, object_graph_proto): + A tuple of (named_variables, object_graph_proto, feed_additions): named_variables: A dictionary mapping names to variable objects. object_graph_proto: A CheckpointableObjectGraph protocol buffer containing the serialized object graph and variable references. + feed_additions: A dictionary mapping from Tensors to values which should + be fed when saving. Raises: ValueError: If there are invalid characters in an optimizer's slot names. @@ -376,7 +433,8 @@ def _serialize_object_graph(root_checkpointable): checkpointable_objects=checkpointable_objects, node_ids=node_ids, object_names=object_names, - slot_variables=slot_variables) + slot_variables=slot_variables, + saveables_cache=saveables_cache) def list_objects(root_checkpointable): @@ -728,6 +786,14 @@ class CheckpointableSaver(object): self._last_restore_object_graph = None self._last_restore_checkpoint = None + if context.executing_eagerly(): + # SaveableObjects are always recreated when executing eagerly. + self._saveable_object_cache = None + else: + # Maps Checkpointable objects -> attribute names -> SaveableObjects, to + # avoid re-creating SaveableObjects when graph building. + self._saveable_object_cache = weakref.WeakKeyDictionary() + @property def _root_checkpointable(self): if isinstance(self._root_checkpointable_ref, weakref.ref): @@ -759,8 +825,9 @@ class CheckpointableSaver(object): Returns: The full path to the checkpoint. """ - named_variables, graph_proto = _serialize_object_graph( - self._root_checkpointable) + named_variables, graph_proto, feed_additions = _serialize_object_graph( + self._root_checkpointable, + saveables_cache=self._saveable_object_cache) if not context.executing_eagerly(): if session is None: session = ops.get_default_session() @@ -769,15 +836,15 @@ class CheckpointableSaver(object): self._object_graph_feed_tensor = constant_op.constant( "", dtype=dtypes.string) object_graph_tensor = self._object_graph_feed_tensor - feed_additions = {object_graph_tensor: graph_proto.SerializeToString()} + feed_additions.update( + {object_graph_tensor: graph_proto.SerializeToString()}) else: session = None with ops.device("/cpu:0"): object_graph_tensor = constant_op.constant( graph_proto.SerializeToString(), dtype=dtypes.string) - feed_additions = None assert checkpointable_lib.OBJECT_GRAPH_PROTO_KEY not in named_variables - named_variables[checkpointable_lib.OBJECT_GRAPH_PROTO_KEY] = ( + named_variables.append( _NoRestoreSaveable( tensor=object_graph_tensor, name=checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)) @@ -804,13 +871,23 @@ class CheckpointableSaver(object): def _global_variable_names(self): """Generate a `tf.train.Saver`-style `var_list` using `variable.name`s.""" - named_saveables, graph_proto = _serialize_object_graph( - self._root_checkpointable) + named_saveables, graph_proto, _ = _serialize_object_graph( + self._root_checkpointable, + # We destructively modify SaveableObjects, so don't do any caching. + saveables_cache=None) + named_saveables = {v.name: v for v in named_saveables} saver_names = {} for object_proto in graph_proto.nodes: for attribute_proto in object_proto.attributes: - saver_names[attribute_proto.full_name] = named_saveables[ - attribute_proto.checkpoint_key] + if attribute_proto.full_name: + # Ignore attributes, such as Python object JSON, which don't have a + # name-based Saver name. + saveable = named_saveables[attribute_proto.checkpoint_key] + saveable.name = attribute_proto.full_name + for spec in saveable.specs: + spec.name = spec.name.replace(attribute_proto.checkpoint_key, + attribute_proto.full_name) + saver_names[attribute_proto.full_name] = saveable return saver_names def restore(self, save_path): @@ -1037,6 +1114,7 @@ class Checkpoint(checkpointable_lib.Checkpointable): % (v,)) setattr(self, k, v) self._save_counter = None # Created lazily for restore-on-create. + self._save_assign_op = None self._saver = CheckpointableSaver(weakref.ref(self)) def _maybe_create_save_counter(self): @@ -1089,10 +1167,13 @@ class Checkpoint(checkpointable_lib.Checkpointable): # needs to be initialized before assign_add. This is only an issue if # restore() has not been called first. session.run(self.save_counter.initializer) - with ops.colocate_with(self.save_counter): - assign_op = self.save_counter.assign_add(1) + if not in_graph_mode or self._save_assign_op is None: + with ops.colocate_with(self.save_counter): + assign_op = self.save_counter.assign_add(1, read_value=False) + if in_graph_mode: + self._save_assign_op = assign_op if in_graph_mode: - session.run(assign_op) + session.run(self._save_assign_op) return self._saver.save( file_prefix=file_prefix, checkpoint_number=self.save_counter, diff --git a/tensorflow/python/training/checkpointable_utils_test.py b/tensorflow/python/training/checkpointable_utils_test.py index 3b8166b..dead8fd 100644 --- a/tensorflow/python/training/checkpointable_utils_test.py +++ b/tensorflow/python/training/checkpointable_utils_test.py @@ -17,10 +17,12 @@ from __future__ import division from __future__ import print_function import functools +import json import os import six +from tensorflow.python import pywrap_tensorflow from tensorflow.python.client import session as session_lib from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -120,7 +122,8 @@ class InterfaceTests(test.TestCase): # The .name attribute may be globally influenced, but the checkpoint name # won't be (tested below). self.assertEqual("duplicate_1:0", duplicate.name) - named_variables, _ = checkpointable_utils._serialize_object_graph(obj) + named_variables, _, _ = checkpointable_utils._serialize_object_graph( + obj, saveables_cache=None) expected_checkpoint_names = ( "a_variable/.ATTRIBUTES/VARIABLE_VALUE", "bare_initializer/.ATTRIBUTES/VARIABLE_VALUE", @@ -129,7 +132,7 @@ class InterfaceTests(test.TestCase): "ones_initializer/.ATTRIBUTES/VARIABLE_VALUE", ) six.assertCountEqual( - self, expected_checkpoint_names, named_variables.keys()) + self, expected_checkpoint_names, [v.name for v in named_variables]) def testInitNotCalled(self): @@ -245,8 +248,9 @@ class CheckpointingTests(test.TestCase): self.evaluate(checkpointable_utils.gather_initializers( root_checkpointable)) self.evaluate(train_op) - named_variables, serialized_graph = ( - checkpointable_utils._serialize_object_graph(root_checkpointable)) + named_variables, serialized_graph, _ = ( + checkpointable_utils._serialize_object_graph( + root_checkpointable, saveables_cache=None)) expected_checkpoint_names = ( # Created in the root node, so no prefix. "optimizer_step", @@ -269,24 +273,29 @@ class CheckpointingTests(test.TestCase): suffix = "/.ATTRIBUTES/VARIABLE_VALUE" expected_checkpoint_names = [ name + suffix for name in expected_checkpoint_names] + # The Dense layers also save get_config() JSON + expected_checkpoint_names.extend( + ["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON", + "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"]) + named_variables = {v.name: v for v in named_variables} six.assertCountEqual(self, expected_checkpoint_names, named_variables.keys()) # Check that we've mapped to the right variable objects (not exhaustive) self.assertEqual( - "global_step:0", - named_variables["optimizer_step" + suffix].name) + "global_step", + named_variables["optimizer_step" + suffix].full_name) self.assertEqual( - "my_model/dense_1/kernel:0", - named_variables["model/_second/kernel" + suffix].name) + "my_model/dense_1/kernel", + named_variables["model/_second/kernel" + suffix].full_name) self.assertEqual( - "my_model/dense/kernel:0", - named_variables["model/_named_dense/kernel" + suffix].name) + "my_model/dense/kernel", + named_variables["model/_named_dense/kernel" + suffix].full_name) self.assertEqual( - "beta1_power:0", - named_variables["optimizer/beta1_power" + suffix].name) + "beta1_power", + named_variables["optimizer/beta1_power" + suffix].full_name) self.assertEqual( - "beta2_power:0", - named_variables["optimizer/beta2_power" + suffix].name) + "beta2_power", + named_variables["optimizer/beta2_power" + suffix].full_name) # Spot check the generated protocol buffers. self.assertEqual("optimizer", serialized_graph.nodes[0].children[1].local_name) @@ -311,7 +320,7 @@ class CheckpointingTests(test.TestCase): self.assertEqual( "my_model/dense/kernel/Adam:0", optimizer.get_slot( - var=named_variables["model/_named_dense/kernel" + suffix], + var=model._named_dense.kernel, name="m").name) self.assertEqual( "model/_named_dense/kernel" + suffix, @@ -563,11 +572,11 @@ class CheckpointingTests(test.TestCase): root = checkpointable.Checkpointable() checkpointable_utils.add_variable( root, name=name, shape=[1, 2], dtype=dtypes.float64) - named_variables, _ = checkpointable_utils._serialize_object_graph(root) - checkpoint_name, = named_variables.keys() - with ops.name_scope("root/" + checkpoint_name): + (named_variable,), _, _ = checkpointable_utils._serialize_object_graph( + root, saveables_cache=None) + with ops.name_scope("root/" + named_variable.name): pass # Make sure we can use this as an op name if we prefix it. - return checkpoint_name + return named_variable.name @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testVariableNameEscaping(self): @@ -585,9 +594,9 @@ class CheckpointingTests(test.TestCase): leaf = checkpointable.Checkpointable() root.leaf = leaf checkpointable_utils.add_variable(leaf, name="v", shape=[]) - named_variables, _ = checkpointable_utils._serialize_object_graph(root) - variable_name, = named_variables.keys() - self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", variable_name) + (named_variable,), _, _ = checkpointable_utils._serialize_object_graph( + root, saveables_cache=None) + self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", named_variable.name) @test_util.run_in_graph_and_eager_modes() def testLocalNameValidation(self): @@ -596,9 +605,10 @@ class CheckpointingTests(test.TestCase): # Dots are escaped, which avoids conflicts with reserved names. root._track_checkpointable(leaf, name=".ATTRIBUTES") checkpointable_utils.add_variable(checkpointable=leaf, name="a", shape=[]) - named_variables, _ = checkpointable_utils._serialize_object_graph(root) - name, = named_variables.keys() - self.assertEqual(name, "..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE") + (named_variable,), _, _ = checkpointable_utils._serialize_object_graph( + root, saveables_cache=None) + self.assertEqual("..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE", + named_variable.name) def testAnonymousVarsInInit(self): @@ -1395,5 +1405,48 @@ class CheckpointCompatibilityTests(test.TestCase): root.restore(save_path).assert_consumed().run_restore_ops() self._check_sentinels(root) + +class PythonMetadataTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testSaveLoad(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + dense = core.Dense(1) + checkpoint = checkpointable_utils.Checkpoint(dense=dense) + dense(constant_op.constant([[1.]])) + checkpoint.restore(None).initialize_or_restore() + save_path = checkpoint.save(checkpoint_prefix) + + def _get_dense_node_from_object_graph(object_graph_proto): + root_node = object_graph_proto.nodes[0] + for child in root_node.children: + if child.local_name == "dense": + break + else: + raise AssertionError( + "Expected a 'dense' dependency of root, didn't find one.") + dense_node = object_graph_proto.nodes[child.node_id] # pylint: disable=undefined-loop-variable + self.assertEqual(1, len(dense_node.attributes)) + reader = pywrap_tensorflow.NewCheckpointReader(save_path) + layer_json = reader.get_tensor(dense_node.attributes[0].checkpoint_key) + return json.loads(layer_json.decode("utf-8")) + + layer_data = _get_dense_node_from_object_graph( + checkpointable_utils.object_metadata(save_path)) + self.assertEqual("Dense", layer_data["class_name"]) + self.assertEqual(1, layer_data["config"]["units"]) + + # Check that no new ops are added to the graph the second time we save. + ops.get_default_graph().finalize() + + dense.units = 42 + save_path = checkpoint.save(checkpoint_prefix) + layer_data = _get_dense_node_from_object_graph( + checkpointable_utils.object_metadata(save_path)) + self.assertEqual("Dense", layer_data["class_name"]) + self.assertEqual(42, layer_data["config"]["units"]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 53e821c..98e79a4 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -569,6 +569,76 @@ class BaseSaverBuilder(object): # pylint: enable=protected-access return names_to_saveables + @staticmethod + def SaveableObjectsForOp(op, name): + """Create `SaveableObject`s from an operation. + + Args: + op: A variable, operation, or SaveableObject to coerce into a + SaveableObject. + name: A string name for the SaveableObject. + + Yields: + `SaveableObject`s which together save/restore `op`. + + Raises: + TypeError: If `name` is not a string. + ValueError: For operations with no known conversion to SaveableObject. + """ + if not isinstance(name, six.string_types): + raise TypeError( + "names_to_saveables must be a dict mapping string names to " + "checkpointable operations. Name is not a string: %s" % name) + if isinstance(op, BaseSaverBuilder.SaveableObject): + yield op + elif isinstance(op, (list, tuple, variables.PartitionedVariable)): + if isinstance(op, variables.PartitionedVariable): + op = list(op) + # A set of slices. + slice_name = None + # pylint: disable=protected-access + for variable in op: + if not isinstance(variable, variables.Variable): + raise ValueError("Slices must all be Variables: %s" % variable) + if not variable._save_slice_info: + raise ValueError("Slices must all be slices: %s" % variable) + if slice_name is None: + slice_name = variable._save_slice_info.full_name + elif slice_name != variable._save_slice_info.full_name: + raise ValueError( + "Slices must all be from the same tensor: %s != %s" % + (slice_name, variable._save_slice_info.full_name)) + if variable.op.type in ["Variable", "VariableV2", + "AutoReloadVariable"]: + yield BaseSaverBuilder.VariableSaveable( + variable, variable._save_slice_info.spec, name) + else: + yield BaseSaverBuilder.ResourceVariableSaveable( + variable, variable._save_slice_info.spec, name) + # pylint: enable=protected-access + else: + # A variable or tensor. + if context.executing_eagerly(): + if not isinstance(op, resource_variable_ops.ResourceVariable): + raise ValueError("Can only save/restore ResourceVariable eager " + "mode is enabled, type: %s." % type(op)) + yield BaseSaverBuilder.ResourceVariableSaveable(op, "", name) + else: + if isinstance(op, resource_variable_ops.ResourceVariable): + variable = op._graph_element # pylint: disable=protected-access + else: + variable = ops.internal_convert_to_tensor(op, as_ref=True) + if not BaseSaverBuilder._IsVariable(variable): + raise TypeError("names_to_saveables must be a dict mapping string " + "names to Tensors/Variables. Not a variable: %s" % + variable) + if variable.op.type in ["Variable", "VariableV2", + "AutoReloadVariable"]: + yield BaseSaverBuilder.VariableSaveable(variable, "", name) + else: + yield BaseSaverBuilder.ResourceVariableSaveable( + variable, "", name) + def _ValidateAndSliceInputs(self, names_to_saveables): """Returns the variables and names that will be used for a Saver. @@ -590,63 +660,11 @@ class BaseSaverBuilder(object): saveables = [] seen_ops = set() - for name in sorted(names_to_saveables.keys()): - if not isinstance(name, six.string_types): - raise TypeError( - "names_to_saveables must be a dict mapping string names to " - "checkpointable operations. Name is not a string: %s" % name) - op = names_to_saveables[name] - if isinstance(op, BaseSaverBuilder.SaveableObject): - self._AddSaveable(saveables, seen_ops, op) - elif isinstance(op, (list, tuple, variables.PartitionedVariable)): - if isinstance(op, variables.PartitionedVariable): - op = list(op) - # A set of slices. - slice_name = None - # pylint: disable=protected-access - for variable in op: - if not isinstance(variable, variables.Variable): - raise ValueError("Slices must all be Variables: %s" % variable) - if not variable._save_slice_info: - raise ValueError("Slices must all be slices: %s" % variable) - if slice_name is None: - slice_name = variable._save_slice_info.full_name - elif slice_name != variable._save_slice_info.full_name: - raise ValueError( - "Slices must all be from the same tensor: %s != %s" % - (slice_name, variable._save_slice_info.full_name)) - if variable.op.type in ["Variable", "VariableV2", - "AutoReloadVariable"]: - saveable = BaseSaverBuilder.VariableSaveable( - variable, variable._save_slice_info.spec, name) - else: - saveable = BaseSaverBuilder.ResourceVariableSaveable( - variable, variable._save_slice_info.spec, name) - self._AddSaveable(saveables, seen_ops, saveable) - # pylint: enable=protected-access - else: - # A variable or tensor. - if context.executing_eagerly(): - if not isinstance(op, resource_variable_ops.ResourceVariable): - raise ValueError("Can only save/restore ResourceVariable eager " - "mode is enabled, type: %s." % type(op)) - saveable = BaseSaverBuilder.ResourceVariableSaveable(op, "", name) - else: - if isinstance(op, resource_variable_ops.ResourceVariable): - variable = op._graph_element # pylint: disable=protected-access - else: - variable = ops.internal_convert_to_tensor(op, as_ref=True) - if not BaseSaverBuilder._IsVariable(variable): - raise TypeError("names_to_saveables must be a dict mapping string " - "names to Tensors/Variables. Not a variable: %s" % - variable) - if variable.op.type in ["Variable", "VariableV2", - "AutoReloadVariable"]: - saveable = BaseSaverBuilder.VariableSaveable(variable, "", name) - else: - saveable = BaseSaverBuilder.ResourceVariableSaveable( - variable, "", name) - self._AddSaveable(saveables, seen_ops, saveable) + for name, op in sorted(names_to_saveables.items(), + # Avoid comparing ops, sort only by name. + key=lambda x: x[0]): + for converted_saveable_object in self.SaveableObjectsForOp(op, name): + self._AddSaveable(saveables, seen_ops, converted_saveable_object) return saveables def _AddSaveable(self, saveables, seen_ops, saveable): diff --git a/tensorflow/python/util/serialization.py b/tensorflow/python/util/serialization.py new file mode 100644 index 0000000..faf5164 --- /dev/null +++ b/tensorflow/python/util/serialization.py @@ -0,0 +1,64 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for serializing Python objects.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import tensor_shape + + +def get_json_type(obj): + """Serializes any object to a JSON-serializable structure. + + Arguments: + obj: the object to serialize + + Returns: + JSON-serializable structure representing `obj`. + + Raises: + TypeError: if `obj` cannot be serialized. + """ + # if obj is a serializable Keras class instance + # e.g. optimizer, layer + if hasattr(obj, 'get_config'): + return {'class_name': obj.__class__.__name__, 'config': obj.get_config()} + + # if obj is any numpy type + if type(obj).__module__ == np.__name__: + if isinstance(obj, np.ndarray): + return {'type': type(obj), 'value': obj.tolist()} + else: + return obj.item() + + # misc functions (e.g. loss function) + if callable(obj): + return obj.__name__ + + # if obj is a python 'type' + if type(obj).__name__ == type.__name__: + return obj.__name__ + + if isinstance(obj, tensor_shape.Dimension): + return obj.value + + if isinstance(obj, tensor_shape.TensorShape): + return obj.as_list() + + raise TypeError('Not JSON Serializable:', obj) diff --git a/tensorflow/python/util/serialization_test.py b/tensorflow/python/util/serialization_test.py new file mode 100644 index 0000000..f16fa53 --- /dev/null +++ b/tensorflow/python/util/serialization_test.py @@ -0,0 +1,76 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for serialization functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.keras._impl.keras.engine import input_layer +from tensorflow.python.keras._impl.keras.engine import sequential +from tensorflow.python.keras._impl.keras.engine import training +from tensorflow.python.keras._impl.keras.layers import core +from tensorflow.python.platform import test +from tensorflow.python.util import serialization + + +class SerializationTests(test.TestCase): + + def test_serialize_dense(self): + dense = core.Dense(3) + dense(constant_op.constant([[4.]])) + round_trip = json.loads(json.dumps( + dense, default=serialization.get_json_type)) + self.assertEqual(3, round_trip["config"]["units"]) + + def test_serialize_shape(self): + round_trip = json.loads(json.dumps( + tensor_shape.TensorShape([None, 2, 3]), + default=serialization.get_json_type)) + self.assertIs(round_trip[0], None) + self.assertEqual(round_trip[1], 2) + + @test_util.run_in_graph_and_eager_modes() + def test_serialize_sequential(self): + model = sequential.Sequential() + model.add(core.Dense(4)) + model.add(core.Dense(5)) + model(constant_op.constant([[1.]])) + sequential_round_trip = json.loads( + json.dumps(model, default=serialization.get_json_type)) + self.assertEqual(5, sequential_round_trip["config"][1]["config"]["units"]) + input_round_trip = json.loads( + json.dumps(model._input_layers, default=serialization.get_json_type)) + self.assertAllEqual([1, 1], + input_round_trip[0]["config"]["batch_input_shape"]) + + @test_util.run_in_graph_and_eager_modes() + def test_serialize_model(self): + x = input_layer.Input(shape=[3]) + y = core.Dense(10)(x) + model = training.Model(x, y) + model(constant_op.constant([[1., 1., 1.]])) + model_round_trip = json.loads( + json.dumps(model, default=serialization.get_json_type)) + self.assertEqual( + 10, model_round_trip["config"]["layers"][1]["config"]["units"]) + +if __name__ == "__main__": + test.main() -- 2.7.4