from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras._impl.keras.engine import training
self.evaluate(checkpointable_utils.gather_initializers(
root_checkpointable))
self.evaluate(train_op)
- named_variables, serialized_graph = (
- checkpointable_utils._serialize_object_graph(root_checkpointable))
+ named_variables, serialized_graph, _ = (
+ checkpointable_utils._serialize_object_graph(
+ root_checkpointable, saveables_cache=None))
expected_checkpoint_names = (
# Created in the root node, so no prefix.
"optimizer_step",
suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
expected_checkpoint_names = [
name + suffix for name in expected_checkpoint_names]
+ # The Dense layers also save get_config() JSON under OBJECT_CONFIG_JSON.
+ expected_checkpoint_names.extend(
+ ["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
+ "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"])
+ named_variables = {v.name: v for v in named_variables}
six.assertCountEqual(self, expected_checkpoint_names,
named_variables.keys())
# Check that we've mapped to the right variable objects (not exhaustive)
self.assertEqual(
- "global_step:0",
- named_variables["optimizer_step" + suffix].name)
+ "global_step",
+ named_variables["optimizer_step" + suffix].full_name)
self.assertEqual(
- "my_model/dense_1/kernel:0",
- named_variables["model/_second/kernel" + suffix].name)
+ "my_model/dense_1/kernel",
+ named_variables["model/_second/kernel" + suffix].full_name)
self.assertEqual(
- "my_model/dense/kernel:0",
- named_variables["model/_named_dense/kernel" + suffix].name)
+ "my_model/dense/kernel",
+ named_variables["model/_named_dense/kernel" + suffix].full_name)
self.assertEqual(
- "beta1_power:0",
- named_variables["optimizer/beta1_power" + suffix].name)
+ "beta1_power",
+ named_variables["optimizer/beta1_power" + suffix].full_name)
self.assertEqual(
- "beta2_power:0",
- named_variables["optimizer/beta2_power" + suffix].name)
+ "beta2_power",
+ named_variables["optimizer/beta2_power" + suffix].full_name)
# Spot check the generated protocol buffers.
self.assertEqual("optimizer",
serialized_graph.nodes[0].children[1].local_name)
self.assertEqual(
"my_model/dense/kernel/Adam:0",
optimizer.get_slot(
- var=named_variables["model/_named_dense/kernel" + suffix],
+ var=model._named_dense.kernel,
name="m").name)
self.assertEqual(
"model/_named_dense/kernel" + suffix,
self.evaluate(root.save_counter))
# pylint: enable=cell-var-from-loop
- def _get_checkpoint_name(self, name):
- root = checkpointable.Checkpointable()
- checkpointable_utils.add_variable(
- root, name=name, shape=[1, 2], dtype=dtypes.float64)
- named_variables, _ = checkpointable_utils._serialize_object_graph(root)
- checkpoint_name, = named_variables.keys()
- with ops.name_scope("root/" + checkpoint_name):
- pass # Make sure we can use this as an op name if we prefix it.
- return checkpoint_name
-
def testAnonymousVarsInInit(self):
class Model(training.Model):
srcs_version = "PY2AND3",
deps = [
":array_ops",
+ ":constant_op",
+ ":control_flow_ops",
":dtypes",
":io_ops_gen",
":ops",
+ ":saveable_object",
":util",
"//tensorflow/python/eager:context",
],
)
py_test(
+ name = "util_serialization_test",
+ size = "small",
+ srcs = ["util/serialization_test.py"],
+ main = "util/serialization_test.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":client_testlib",
+ ":util",
+ ],
+)
+
+py_test(
name = "future_api_test",
size = "small",
srcs = ["util/future_api_test.py"],
from tensorflow.python.keras._impl.keras.utils import conv_utils
from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import serialization
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=g-import-not-at-top
if h5py is None:
raise ImportError('`save_model` requires h5py.')
- def get_json_type(obj):
- """Serializes any object to a JSON-serializable structure.
-
- Arguments:
- obj: the object to serialize
-
- Returns:
- JSON-serializable structure representing `obj`.
-
- Raises:
- TypeError: if `obj` cannot be serialized.
- """
- # if obj is a serializable Keras class instance
- # e.g. optimizer, layer
- if hasattr(obj, 'get_config'):
- return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}
-
- # if obj is any numpy type
- if type(obj).__module__ == np.__name__:
- if isinstance(obj, np.ndarray):
- return {'type': type(obj), 'value': obj.tolist()}
- else:
- return obj.item()
-
- # misc functions (e.g. loss function)
- if callable(obj):
- return obj.__name__
-
- # if obj is a python 'type'
- if type(obj).__name__ == type.__name__:
- return obj.__name__
-
- raise TypeError('Not JSON Serializable:', obj)
-
from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top
# If file exists and should not be overwritten.
'class_name': model.__class__.__name__,
'config': model.get_config()
},
- default=get_json_type).encode('utf8')
+ default=serialization.get_json_type).encode('utf8')
model_weights_group = f.create_group('model_weights')
model_layers = model.layers
'sample_weight_mode': model.sample_weight_mode,
'loss_weights': model.loss_weights,
},
- default=get_json_type).encode('utf8')
+ default=serialization.get_json_type).encode('utf8')
# Save optimizer weights.
symbolic_weights = getattr(model.optimizer, 'weights')
from __future__ import print_function
import collections
+import functools
+import json
+import weakref
from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import saveable_object
from tensorflow.python.util import nest
+from tensorflow.python.util import serialization
# Key where the object graph proto is saved in a TensorBundle
# the object has no dependencies, then its value may be restored on object
# creation (avoiding double assignment when executing eagerly).
VARIABLE_VALUE_KEY = "VARIABLE_VALUE"
+OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"
CheckpointableReference = collections.namedtuple(
"CheckpointableReference",
return self._checkpoint_position
+class PythonStringStateSaveable(saveable_object.SaveableObject):
+ """Saves Python state in a checkpoint."""
+
+ def __init__(self, name, state_callback):
+ """Configure saving.
+
+ Args:
+ name: The checkpoint key to write to.
+ state_callback: A function taking no arguments which returns a
+ string. This function is run every time a checkpoint is written.
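+
+ Example (an illustrative sketch only; the key and callback below are
+ hypothetical and not part of this change):
+
+ saveable = PythonStringStateSaveable(
+ name="obj/.ATTRIBUTES/OBJECT_CONFIG_JSON",
+ state_callback=lambda: json.dumps({"units": 3}).encode("utf8"))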
+ """
+ if context.executing_eagerly():
+ self._save_string = (
+ lambda: constant_op.constant(state_callback(), dtype=dtypes.string))
+ else:
+ self._save_string = constant_op.constant("", dtype=dtypes.string)
+ self.feed_dict_additions = (
+ lambda: {self._save_string: state_callback()})
+ spec = saveable_object.SaveSpec(
+ self._save_string, "", name, dtype=dtypes.string)
+ super(PythonStringStateSaveable, self).__init__(
+ self._save_string, [spec], name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ # TODO(allenl): Add a Python hook for state coming out of a checkpoint
+ # (currently PythonStringStateSaveable is write-only).
+ return control_flow_ops.no_op()
+
+
class _CheckpointPosition(object):
"""Indicates a position within a `_Checkpoint`."""
# restoration on to our dependencies.
if checkpoint.restore_uid > self._update_uid:
restore_ops = checkpoint_position.restore_ops()
- # TODO(allenl): Get a list of feeds for saving Python state
self._update_uid = checkpoint.restore_uid
else:
restore_ops = ()
lambda name="global_name_for_this_object":
SaveableObject(name=name, ...)}
"""
- return {}
+ if not hasattr(self, "get_config"):
+ return {}
+ try:
+ self.get_config()
+ except NotImplementedError:
+ return {}
+ weak_self = weakref.ref(self)
+ def _state_callback():
+ dereferenced_self = weak_self()
+ if dereferenced_self:
+ return json.dumps(dereferenced_self,
+ default=serialization.get_json_type,
+ sort_keys=True).encode("utf8")
+ else:
+ return ""
+ return {OBJECT_CONFIG_JSON_KEY: functools.partial(
+ PythonStringStateSaveable,
+ state_callback=_state_callback)}
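+
+ # A minimal sketch of how the factories returned above are consumed when
+ # saving (mirrors _serialize_checkpointables; `obj` and the key prefix are
+ # hypothetical):
+ #
+ # for attr_name, factory in obj._gather_saveables_for_checkpoint().items():
+ # key = "obj/.ATTRIBUTES/" + attr_name
+ # saveable = factory(name=key) if callable(factory) else factory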
class NoDependency(object):
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import checkpointable as checkpointable_lib
from tensorflow.python.training import optimizer as optimizer_lib
+from tensorflow.python.training import saveable_object
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
def _serialize_checkpointables(
- checkpointable_objects, node_ids, object_names, slot_variables):
+ checkpointable_objects, node_ids, object_names, slot_variables,
+ saveables_cache):
"""Name non-slot `Checkpointable`s and add them to `object_graph_proto`."""
object_graph_proto = (
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
- named_saveables = {}
-
+ named_saveables = []
+ feed_additions = {}
for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
assert node_ids[checkpointable] == checkpoint_id
object_proto = object_graph_proto.nodes.add()
object_proto.slot_variables.extend(slot_variables.get(checkpointable, ()))
object_name = object_names[checkpointable]
+ if saveables_cache is not None:
+ cached_attributes = saveables_cache.setdefault(checkpointable, {})
+ else:
+ cached_attributes = None
for name, saveable_factory in (
checkpointable._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access
attribute = object_proto.attributes.add()
attribute.name = name
attribute.checkpoint_key = "%s/%s/%s" % (
object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
- if callable(saveable_factory):
- saveable = saveable_factory(name=attribute.checkpoint_key)
+ if cached_attributes is None:
+ saveables = None
else:
- saveable = saveable_factory
- # Figure out the name-based Saver's name for this variable.
- saver_dict = saver_lib.BaseSaverBuilder.OpListToDict(
- [saveable], convert_variable_to_tensor=False)
- attribute.full_name, = saver_dict.keys()
- named_saveables[attribute.checkpoint_key] = saveable
+ saveables = cached_attributes.get(name, None)
+ if saveables is not None:
+ for saveable in saveables:
+ if attribute.checkpoint_key not in saveable.name:
+ # The checkpoint key for this SaveableObject is different. We need
+ # to re-create it.
+ saveables = None
+ del cached_attributes[name]
+ break
+ if saveables is None:
+ if callable(saveable_factory):
+ maybe_saveable = saveable_factory(name=attribute.checkpoint_key)
+ else:
+ maybe_saveable = saveable_factory
+ if isinstance(maybe_saveable, saveable_object.SaveableObject):
+ saveables = (maybe_saveable,)
+ else:
+ # Figure out the name-based Saver's name for this variable. If it's
+ # already a SaveableObject we'd just get the checkpoint key back, so
+ # we leave full_name blank.
+ saver_dict = saver_lib.BaseSaverBuilder.OpListToDict(
+ [maybe_saveable], convert_variable_to_tensor=False)
+ full_name, = saver_dict.keys()
+ saveables = tuple(saver_lib.BaseSaverBuilder.SaveableObjectsForOp(
+ op=maybe_saveable, name=attribute.checkpoint_key))
+ for saveable in saveables:
+ saveable.full_name = full_name
+ for saveable in saveables:
+ if attribute.checkpoint_key not in saveable.name:
+ raise AssertionError(
+ ("The object %s produced a SaveableObject with name '%s' for "
+ "attribute '%s'. Expected a name containing '%s'.")
+ % (checkpointable, saveable.name, name,
+ attribute.checkpoint_key))
+ if cached_attributes is not None:
+ cached_attributes[name] = saveables
+
+ for saveable in saveables:
+ if hasattr(saveable, "full_name"):
+ attribute.full_name = saveable.full_name
+ saveable_feed_dict_fn = getattr(saveable, "feed_dict_additions", None)
+ if saveable_feed_dict_fn is not None:
+ saveable_feed_dict = saveable_feed_dict_fn() # pylint: disable=not-callable
+ for new_feed_key in saveable_feed_dict.keys():
+ if new_feed_key in feed_additions:
+ raise AssertionError(
+ ("The object %s tried to feed a value for the Tensor %s "
+ "when saving, but another object is already feeding a "
+ "value.")
+ % (checkpointable, new_feed_key))
+ feed_additions.update(saveable_feed_dict)
+ named_saveables.extend(saveables)
for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access
child_proto = object_proto.children.add()
child_proto.node_id = node_ids[child.ref]
child_proto.local_name = child.name
- return named_saveables, object_graph_proto
+ return named_saveables, object_graph_proto, feed_additions
-def _serialize_object_graph(root_checkpointable):
+def _serialize_object_graph(root_checkpointable, saveables_cache):
"""Determine checkpoint keys for variables and build a serialized graph.
Non-slot variables are keyed based on a shortest path from the root saveable
Args:
root_checkpointable: A `Checkpointable` object whose variables (including
the variables of dependencies, recursively) should be saved.
+ saveables_cache: A dictionary mapping `Checkpointable` objects -> attribute
+ names -> SaveableObjects, used to avoid re-creating SaveableObjects when
+ graph building. May be `None`, in which case SaveableObjects are always
+ re-created (as happens when executing eagerly).
Returns:
- A tuple of (named_variables, object_graph_proto):
+ A tuple of (named_variables, object_graph_proto, feed_additions):
- named_variables: A dictionary mapping names to variable objects.
+ named_variables: A list of SaveableObjects covering the checkpointed
+ attributes of every object reachable from `root_checkpointable`.
object_graph_proto: A CheckpointableObjectGraph protocol buffer containing
the serialized object graph and variable references.
+ feed_additions: A dictionary mapping from Tensors to values which should
+ be fed when saving.
Raises:
ValueError: If there are invalid characters in an optimizer's slot names.
checkpointable_objects=checkpointable_objects,
node_ids=node_ids,
object_names=object_names,
- slot_variables=slot_variables)
+ slot_variables=slot_variables,
+ saveables_cache=saveables_cache)
def list_objects(root_checkpointable):
self._last_restore_object_graph = None
self._last_restore_checkpoint = None
+ if context.executing_eagerly():
+ # SaveableObjects are always recreated when executing eagerly.
+ self._saveable_object_cache = None
+ else:
+ # Maps Checkpointable objects -> attribute names -> SaveableObjects, to
+ # avoid re-creating SaveableObjects when graph building.
+ self._saveable_object_cache = weakref.WeakKeyDictionary()
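+ # For example (hypothetical objects and attributes):
+ # {dense_layer: {"kernel": (kernel_saveable,), "bias": (bias_saveable,)}}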
+
@property
def _root_checkpointable(self):
if isinstance(self._root_checkpointable_ref, weakref.ref):
Returns:
The full path to the checkpoint.
"""
- named_variables, graph_proto = _serialize_object_graph(
- self._root_checkpointable)
+ named_variables, graph_proto, feed_additions = _serialize_object_graph(
+ self._root_checkpointable,
+ saveables_cache=self._saveable_object_cache)
if not context.executing_eagerly():
if session is None:
session = ops.get_default_session()
self._object_graph_feed_tensor = constant_op.constant(
"", dtype=dtypes.string)
object_graph_tensor = self._object_graph_feed_tensor
- feed_additions = {object_graph_tensor: graph_proto.SerializeToString()}
+ feed_additions.update(
+ {object_graph_tensor: graph_proto.SerializeToString()})
else:
session = None
with ops.device("/cpu:0"):
object_graph_tensor = constant_op.constant(
graph_proto.SerializeToString(), dtype=dtypes.string)
- feed_additions = None
assert checkpointable_lib.OBJECT_GRAPH_PROTO_KEY not in named_variables
- named_variables[checkpointable_lib.OBJECT_GRAPH_PROTO_KEY] = (
+ named_variables.append(
_NoRestoreSaveable(
tensor=object_graph_tensor,
name=checkpointable_lib.OBJECT_GRAPH_PROTO_KEY))
def _global_variable_names(self):
"""Generate a `tf.train.Saver`-style `var_list` using `variable.name`s."""
- named_saveables, graph_proto = _serialize_object_graph(
- self._root_checkpointable)
+ named_saveables, graph_proto, _ = _serialize_object_graph(
+ self._root_checkpointable,
+ # We destructively modify SaveableObjects, so don't do any caching.
+ saveables_cache=None)
+ named_saveables = {v.name: v for v in named_saveables}
saver_names = {}
for object_proto in graph_proto.nodes:
for attribute_proto in object_proto.attributes:
- saver_names[attribute_proto.full_name] = named_saveables[
- attribute_proto.checkpoint_key]
+ if attribute_proto.full_name:
+ # Ignore attributes, such as Python object JSON, which don't have a
+ # name-based Saver name.
+ saveable = named_saveables[attribute_proto.checkpoint_key]
+ saveable.name = attribute_proto.full_name
+ for spec in saveable.specs:
+ spec.name = spec.name.replace(attribute_proto.checkpoint_key,
+ attribute_proto.full_name)
+ saver_names[attribute_proto.full_name] = saveable
return saver_names
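+
+ # Sketch of the intended consumption (hedged; not exercised here): the
+ # returned mapping can seed a name-based saver, e.g.
+ # saver_lib.Saver(var_list=self._global_variable_names())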
def restore(self, save_path):
% (v,))
setattr(self, k, v)
self._save_counter = None # Created lazily for restore-on-create.
+ self._save_assign_op = None
self._saver = CheckpointableSaver(weakref.ref(self))
def _maybe_create_save_counter(self):
# needs to be initialized before assign_add. This is only an issue if
# restore() has not been called first.
session.run(self.save_counter.initializer)
- with ops.colocate_with(self.save_counter):
- assign_op = self.save_counter.assign_add(1)
+ if not in_graph_mode or self._save_assign_op is None:
+ with ops.colocate_with(self.save_counter):
+ assign_op = self.save_counter.assign_add(1, read_value=False)
+ if in_graph_mode:
+ self._save_assign_op = assign_op
if in_graph_mode:
- session.run(assign_op)
+ session.run(self._save_assign_op)
return self._saver.save(
file_prefix=file_prefix,
checkpoint_number=self.save_counter,
from __future__ import print_function
import functools
+import json
import os
import six
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
# The .name attribute may be globally influenced, but the checkpoint name
# won't be (tested below).
self.assertEqual("duplicate_1:0", duplicate.name)
- named_variables, _ = checkpointable_utils._serialize_object_graph(obj)
+ named_variables, _, _ = checkpointable_utils._serialize_object_graph(
+ obj, saveables_cache=None)
expected_checkpoint_names = (
"a_variable/.ATTRIBUTES/VARIABLE_VALUE",
"bare_initializer/.ATTRIBUTES/VARIABLE_VALUE",
"ones_initializer/.ATTRIBUTES/VARIABLE_VALUE",
)
six.assertCountEqual(
- self, expected_checkpoint_names, named_variables.keys())
+ self, expected_checkpoint_names, [v.name for v in named_variables])
def testInitNotCalled(self):
self.evaluate(checkpointable_utils.gather_initializers(
root_checkpointable))
self.evaluate(train_op)
- named_variables, serialized_graph = (
- checkpointable_utils._serialize_object_graph(root_checkpointable))
+ named_variables, serialized_graph, _ = (
+ checkpointable_utils._serialize_object_graph(
+ root_checkpointable, saveables_cache=None))
expected_checkpoint_names = (
# Created in the root node, so no prefix.
"optimizer_step",
suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
expected_checkpoint_names = [
name + suffix for name in expected_checkpoint_names]
+ # The Dense layers also save get_config() JSON under OBJECT_CONFIG_JSON.
+ expected_checkpoint_names.extend(
+ ["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
+ "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"])
+ named_variables = {v.name: v for v in named_variables}
six.assertCountEqual(self, expected_checkpoint_names,
named_variables.keys())
# Check that we've mapped to the right variable objects (not exhaustive)
self.assertEqual(
- "global_step:0",
- named_variables["optimizer_step" + suffix].name)
+ "global_step",
+ named_variables["optimizer_step" + suffix].full_name)
self.assertEqual(
- "my_model/dense_1/kernel:0",
- named_variables["model/_second/kernel" + suffix].name)
+ "my_model/dense_1/kernel",
+ named_variables["model/_second/kernel" + suffix].full_name)
self.assertEqual(
- "my_model/dense/kernel:0",
- named_variables["model/_named_dense/kernel" + suffix].name)
+ "my_model/dense/kernel",
+ named_variables["model/_named_dense/kernel" + suffix].full_name)
self.assertEqual(
- "beta1_power:0",
- named_variables["optimizer/beta1_power" + suffix].name)
+ "beta1_power",
+ named_variables["optimizer/beta1_power" + suffix].full_name)
self.assertEqual(
- "beta2_power:0",
- named_variables["optimizer/beta2_power" + suffix].name)
+ "beta2_power",
+ named_variables["optimizer/beta2_power" + suffix].full_name)
# Spot check the generated protocol buffers.
self.assertEqual("optimizer",
serialized_graph.nodes[0].children[1].local_name)
self.assertEqual(
"my_model/dense/kernel/Adam:0",
optimizer.get_slot(
- var=named_variables["model/_named_dense/kernel" + suffix],
+ var=model._named_dense.kernel,
name="m").name)
self.assertEqual(
"model/_named_dense/kernel" + suffix,
root = checkpointable.Checkpointable()
checkpointable_utils.add_variable(
root, name=name, shape=[1, 2], dtype=dtypes.float64)
- named_variables, _ = checkpointable_utils._serialize_object_graph(root)
- checkpoint_name, = named_variables.keys()
- with ops.name_scope("root/" + checkpoint_name):
+ (named_variable,), _, _ = checkpointable_utils._serialize_object_graph(
+ root, saveables_cache=None)
+ with ops.name_scope("root/" + named_variable.name):
pass # Make sure we can use this as an op name if we prefix it.
- return checkpoint_name
+ return named_variable.name
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testVariableNameEscaping(self):
leaf = checkpointable.Checkpointable()
root.leaf = leaf
checkpointable_utils.add_variable(leaf, name="v", shape=[])
- named_variables, _ = checkpointable_utils._serialize_object_graph(root)
- variable_name, = named_variables.keys()
- self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", variable_name)
+ (named_variable,), _, _ = checkpointable_utils._serialize_object_graph(
+ root, saveables_cache=None)
+ self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", named_variable.name)
@test_util.run_in_graph_and_eager_modes()
def testLocalNameValidation(self):
# Dots are escaped, which avoids conflicts with reserved names.
root._track_checkpointable(leaf, name=".ATTRIBUTES")
checkpointable_utils.add_variable(checkpointable=leaf, name="a", shape=[])
- named_variables, _ = checkpointable_utils._serialize_object_graph(root)
- name, = named_variables.keys()
- self.assertEqual(name, "..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE")
+ (named_variable,), _, _ = checkpointable_utils._serialize_object_graph(
+ root, saveables_cache=None)
+ self.assertEqual("..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE",
+ named_variable.name)
def testAnonymousVarsInInit(self):
root.restore(save_path).assert_consumed().run_restore_ops()
self._check_sentinels(root)
+
+class PythonMetadataTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testSaveLoad(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ dense = core.Dense(1)
+ checkpoint = checkpointable_utils.Checkpoint(dense=dense)
+ dense(constant_op.constant([[1.]]))
+ checkpoint.restore(None).initialize_or_restore()
+ save_path = checkpoint.save(checkpoint_prefix)
+
+ def _get_dense_node_from_object_graph(object_graph_proto):
+ root_node = object_graph_proto.nodes[0]
+ for child in root_node.children:
+ if child.local_name == "dense":
+ break
+ else:
+ raise AssertionError(
+ "Expected a 'dense' dependency of root, didn't find one.")
+ dense_node = object_graph_proto.nodes[child.node_id] # pylint: disable=undefined-loop-variable
+ self.assertEqual(1, len(dense_node.attributes))
+ reader = pywrap_tensorflow.NewCheckpointReader(save_path)
+ layer_json = reader.get_tensor(dense_node.attributes[0].checkpoint_key)
+ return json.loads(layer_json.decode("utf-8"))
+
+ layer_data = _get_dense_node_from_object_graph(
+ checkpointable_utils.object_metadata(save_path))
+ self.assertEqual("Dense", layer_data["class_name"])
+ self.assertEqual(1, layer_data["config"]["units"])
+
+ # Check that no new ops are added to the graph the second time we save.
+ ops.get_default_graph().finalize()
+
+ dense.units = 42
+ save_path = checkpoint.save(checkpoint_prefix)
+ layer_data = _get_dense_node_from_object_graph(
+ checkpointable_utils.object_metadata(save_path))
+ self.assertEqual("Dense", layer_data["class_name"])
+ self.assertEqual(42, layer_data["config"]["units"])
+
+
if __name__ == "__main__":
test.main()
# pylint: enable=protected-access
return names_to_saveables
+ @staticmethod
+ def SaveableObjectsForOp(op, name):
+ """Create `SaveableObject`s from an operation.
+
+ Args:
+ op: A variable, operation, or SaveableObject to coerce into a
+ SaveableObject.
+ name: A string name for the SaveableObject.
+
+ Yields:
+ `SaveableObject`s which together save/restore `op`.
+
+ Raises:
+ TypeError: If `name` is not a string.
+ ValueError: For operations with no known conversion to SaveableObject.
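+
+ Example (an illustrative sketch; assumes graph mode and the module's
+ `variables` import):
+
+ v = variables.Variable(1.0, name="v")
+ saveables = list(BaseSaverBuilder.SaveableObjectsForOp(v, "v"))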
+ """
+ if not isinstance(name, six.string_types):
+ raise TypeError(
+ "names_to_saveables must be a dict mapping string names to "
+ "checkpointable operations. Name is not a string: %s" % name)
+ if isinstance(op, BaseSaverBuilder.SaveableObject):
+ yield op
+ elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
+ if isinstance(op, variables.PartitionedVariable):
+ op = list(op)
+ # A set of slices.
+ slice_name = None
+ # pylint: disable=protected-access
+ for variable in op:
+ if not isinstance(variable, variables.Variable):
+ raise ValueError("Slices must all be Variables: %s" % variable)
+ if not variable._save_slice_info:
+ raise ValueError("Slices must all be slices: %s" % variable)
+ if slice_name is None:
+ slice_name = variable._save_slice_info.full_name
+ elif slice_name != variable._save_slice_info.full_name:
+ raise ValueError(
+ "Slices must all be from the same tensor: %s != %s" %
+ (slice_name, variable._save_slice_info.full_name))
+ if variable.op.type in ["Variable", "VariableV2",
+ "AutoReloadVariable"]:
+ yield BaseSaverBuilder.VariableSaveable(
+ variable, variable._save_slice_info.spec, name)
+ else:
+ yield BaseSaverBuilder.ResourceVariableSaveable(
+ variable, variable._save_slice_info.spec, name)
+ # pylint: enable=protected-access
+ else:
+ # A variable or tensor.
+ if context.executing_eagerly():
+ if not isinstance(op, resource_variable_ops.ResourceVariable):
+ raise ValueError("Can only save/restore ResourceVariable eager "
+ "mode is enabled, type: %s." % type(op))
+ yield BaseSaverBuilder.ResourceVariableSaveable(op, "", name)
+ else:
+ if isinstance(op, resource_variable_ops.ResourceVariable):
+ variable = op._graph_element # pylint: disable=protected-access
+ else:
+ variable = ops.internal_convert_to_tensor(op, as_ref=True)
+ if not BaseSaverBuilder._IsVariable(variable):
+ raise TypeError("names_to_saveables must be a dict mapping string "
+ "names to Tensors/Variables. Not a variable: %s" %
+ variable)
+ if variable.op.type in ["Variable", "VariableV2",
+ "AutoReloadVariable"]:
+ yield BaseSaverBuilder.VariableSaveable(variable, "", name)
+ else:
+ yield BaseSaverBuilder.ResourceVariableSaveable(
+ variable, "", name)
+
def _ValidateAndSliceInputs(self, names_to_saveables):
"""Returns the variables and names that will be used for a Saver.
saveables = []
seen_ops = set()
- for name in sorted(names_to_saveables.keys()):
- if not isinstance(name, six.string_types):
- raise TypeError(
- "names_to_saveables must be a dict mapping string names to "
- "checkpointable operations. Name is not a string: %s" % name)
- op = names_to_saveables[name]
- if isinstance(op, BaseSaverBuilder.SaveableObject):
- self._AddSaveable(saveables, seen_ops, op)
- elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
- if isinstance(op, variables.PartitionedVariable):
- op = list(op)
- # A set of slices.
- slice_name = None
- # pylint: disable=protected-access
- for variable in op:
- if not isinstance(variable, variables.Variable):
- raise ValueError("Slices must all be Variables: %s" % variable)
- if not variable._save_slice_info:
- raise ValueError("Slices must all be slices: %s" % variable)
- if slice_name is None:
- slice_name = variable._save_slice_info.full_name
- elif slice_name != variable._save_slice_info.full_name:
- raise ValueError(
- "Slices must all be from the same tensor: %s != %s" %
- (slice_name, variable._save_slice_info.full_name))
- if variable.op.type in ["Variable", "VariableV2",
- "AutoReloadVariable"]:
- saveable = BaseSaverBuilder.VariableSaveable(
- variable, variable._save_slice_info.spec, name)
- else:
- saveable = BaseSaverBuilder.ResourceVariableSaveable(
- variable, variable._save_slice_info.spec, name)
- self._AddSaveable(saveables, seen_ops, saveable)
- # pylint: enable=protected-access
- else:
- # A variable or tensor.
- if context.executing_eagerly():
- if not isinstance(op, resource_variable_ops.ResourceVariable):
- raise ValueError("Can only save/restore ResourceVariable eager "
- "mode is enabled, type: %s." % type(op))
- saveable = BaseSaverBuilder.ResourceVariableSaveable(op, "", name)
- else:
- if isinstance(op, resource_variable_ops.ResourceVariable):
- variable = op._graph_element # pylint: disable=protected-access
- else:
- variable = ops.internal_convert_to_tensor(op, as_ref=True)
- if not BaseSaverBuilder._IsVariable(variable):
- raise TypeError("names_to_saveables must be a dict mapping string "
- "names to Tensors/Variables. Not a variable: %s" %
- variable)
- if variable.op.type in ["Variable", "VariableV2",
- "AutoReloadVariable"]:
- saveable = BaseSaverBuilder.VariableSaveable(variable, "", name)
- else:
- saveable = BaseSaverBuilder.ResourceVariableSaveable(
- variable, "", name)
- self._AddSaveable(saveables, seen_ops, saveable)
+ for name, op in sorted(names_to_saveables.items(),
+ # Avoid comparing ops, sort only by name.
+ key=lambda x: x[0]):
+ for converted_saveable_object in self.SaveableObjectsForOp(op, name):
+ self._AddSaveable(saveables, seen_ops, converted_saveable_object)
return saveables
def _AddSaveable(self, saveables, seen_ops, saveable):
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for serializing Python objects."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import tensor_shape
+
+
+def get_json_type(obj):
+ """Serializes any object to a JSON-serializable structure.
+
+ Arguments:
+ obj: the object to serialize
+
+ Returns:
+ JSON-serializable structure representing `obj`.
+
+ Raises:
+ TypeError: if `obj` cannot be serialized.
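+
+ Example (illustrative sketch):
+
+ import json
+ json.dumps({'shape': tensor_shape.TensorShape([None, 3])},
+ default=get_json_type)
+ # -> '{"shape": [null, 3]}'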
+ """
+ # if obj is a serializable Keras class instance
+ # e.g. optimizer, layer
+ if hasattr(obj, 'get_config'):
+ return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}
+
+ # if obj is any numpy type
+ if type(obj).__module__ == np.__name__:
+ if isinstance(obj, np.ndarray):
+ return {'type': type(obj), 'value': obj.tolist()}
+ else:
+ return obj.item()
+
+ # misc functions (e.g. loss function)
+ if callable(obj):
+ return obj.__name__
+
+ # if obj is a python 'type'
+ if type(obj).__name__ == type.__name__:
+ return obj.__name__
+
+ if isinstance(obj, tensor_shape.Dimension):
+ return obj.value
+
+ if isinstance(obj, tensor_shape.TensorShape):
+ return obj.as_list()
+
+ raise TypeError('Not JSON Serializable:', obj)
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for serialization functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras._impl.keras.engine import input_layer
+from tensorflow.python.keras._impl.keras.engine import sequential
+from tensorflow.python.keras._impl.keras.engine import training
+from tensorflow.python.keras._impl.keras.layers import core
+from tensorflow.python.platform import test
+from tensorflow.python.util import serialization
+
+
+class SerializationTests(test.TestCase):
+
+ def test_serialize_dense(self):
+ dense = core.Dense(3)
+ dense(constant_op.constant([[4.]]))
+ round_trip = json.loads(json.dumps(
+ dense, default=serialization.get_json_type))
+ self.assertEqual(3, round_trip["config"]["units"])
+
+ def test_serialize_shape(self):
+ round_trip = json.loads(json.dumps(
+ tensor_shape.TensorShape([None, 2, 3]),
+ default=serialization.get_json_type))
+ self.assertIs(round_trip[0], None)
+ self.assertEqual(round_trip[1], 2)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_serialize_sequential(self):
+ model = sequential.Sequential()
+ model.add(core.Dense(4))
+ model.add(core.Dense(5))
+ model(constant_op.constant([[1.]]))
+ sequential_round_trip = json.loads(
+ json.dumps(model, default=serialization.get_json_type))
+ self.assertEqual(5, sequential_round_trip["config"][1]["config"]["units"])
+ input_round_trip = json.loads(
+ json.dumps(model._input_layers, default=serialization.get_json_type))
+ self.assertAllEqual([1, 1],
+ input_round_trip[0]["config"]["batch_input_shape"])
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_serialize_model(self):
+ x = input_layer.Input(shape=[3])
+ y = core.Dense(10)(x)
+ model = training.Model(x, y)
+ model(constant_op.constant([[1., 1., 1.]]))
+ model_round_trip = json.loads(
+ json.dumps(model, default=serialization.get_json_type))
+ self.assertEqual(
+ 10, model_round_trip["config"]["layers"][1]["config"]["units"])
+
+if __name__ == "__main__":
+ test.main()