Support saving Python state with object-based checkpoints
authorAllen Lavoie <allenl@google.com>
Wed, 9 May 2018 22:56:43 +0000 (15:56 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 9 May 2018 22:59:21 +0000 (15:59 -0700)
Allows SaveableObjects to specify feed dict addition callbacks for object-based saving.

For now just saves get_config() with Layers. Doesn't do any loading, and there isn't quite enough information to reconstruct a Model yet (needs topology).

My plan is to get Models to the point where they can be reconstructed from object-based checkpoints (probably one more change), add in SavedModel export (assuming no dynamic control flow for now), then add this "SavedModel+Python" format to Model.save / load_model.

PiperOrigin-RevId: 196043183

tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
tensorflow/python/BUILD
tensorflow/python/keras/_impl/keras/engine/saving.py
tensorflow/python/training/checkpointable.py
tensorflow/python/training/checkpointable_utils.py
tensorflow/python/training/checkpointable_utils_test.py
tensorflow/python/training/saver.py
tensorflow/python/util/serialization.py [new file with mode: 0644]
tensorflow/python/util/serialization_test.py [new file with mode: 0644]

index 9e2858d..87b2ecf 100644 (file)
@@ -31,7 +31,6 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import function
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras._impl.keras.engine import training
@@ -139,8 +138,9 @@ class CheckpointingTests(test.TestCase):
       self.evaluate(checkpointable_utils.gather_initializers(
           root_checkpointable))
       self.evaluate(train_op)
-    named_variables, serialized_graph = (
-        checkpointable_utils._serialize_object_graph(root_checkpointable))
+    named_variables, serialized_graph, _ = (
+        checkpointable_utils._serialize_object_graph(
+            root_checkpointable, saveables_cache=None))
     expected_checkpoint_names = (
         # Created in the root node, so no prefix.
         "optimizer_step",
@@ -163,24 +163,29 @@ class CheckpointingTests(test.TestCase):
     suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
     expected_checkpoint_names = [
         name + suffix for name in expected_checkpoint_names]
+    # The Dense layers also save get_config() JSON
+    expected_checkpoint_names.extend(
+        ["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
+         "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"])
+    named_variables = {v.name: v for v in named_variables}
     six.assertCountEqual(self, expected_checkpoint_names,
                          named_variables.keys())
     # Check that we've mapped to the right variable objects (not exhaustive)
     self.assertEqual(
-        "global_step:0",
-        named_variables["optimizer_step" + suffix].name)
+        "global_step",
+        named_variables["optimizer_step" + suffix].full_name)
     self.assertEqual(
-        "my_model/dense_1/kernel:0",
-        named_variables["model/_second/kernel" + suffix].name)
+        "my_model/dense_1/kernel",
+        named_variables["model/_second/kernel" + suffix].full_name)
     self.assertEqual(
-        "my_model/dense/kernel:0",
-        named_variables["model/_named_dense/kernel" + suffix].name)
+        "my_model/dense/kernel",
+        named_variables["model/_named_dense/kernel" + suffix].full_name)
     self.assertEqual(
-        "beta1_power:0",
-        named_variables["optimizer/beta1_power" + suffix].name)
+        "beta1_power",
+        named_variables["optimizer/beta1_power" + suffix].full_name)
     self.assertEqual(
-        "beta2_power:0",
-        named_variables["optimizer/beta2_power" + suffix].name)
+        "beta2_power",
+        named_variables["optimizer/beta2_power" + suffix].full_name)
     # Spot check the generated protocol buffers.
     self.assertEqual("optimizer",
                      serialized_graph.nodes[0].children[1].local_name)
@@ -205,7 +210,7 @@ class CheckpointingTests(test.TestCase):
     self.assertEqual(
         "my_model/dense/kernel/Adam:0",
         optimizer.get_slot(
-            var=named_variables["model/_named_dense/kernel" + suffix],
+            var=model._named_dense.kernel,
             name="m").name)
     self.assertEqual(
         "model/_named_dense/kernel" + suffix,
@@ -417,16 +422,6 @@ class CheckpointingTests(test.TestCase):
                          self.evaluate(root.save_counter))
   # pylint: enable=cell-var-from-loop
 
-  def _get_checkpoint_name(self, name):
-    root = checkpointable.Checkpointable()
-    checkpointable_utils.add_variable(
-        root, name=name, shape=[1, 2], dtype=dtypes.float64)
-    named_variables, _ = checkpointable_utils._serialize_object_graph(root)
-    checkpoint_name, = named_variables.keys()
-    with ops.name_scope("root/" + checkpoint_name):
-      pass  # Make sure we can use this as an op name if we prefix it.
-    return checkpoint_name
-
   def testAnonymousVarsInInit(self):
 
     class Model(training.Model):
index f7cbaec..8b904a1 100644 (file)
@@ -3036,9 +3036,12 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":array_ops",
+        ":constant_op",
+        ":control_flow_ops",
         ":dtypes",
         ":io_ops_gen",
         ":ops",
+        ":saveable_object",
         ":util",
         "//tensorflow/python/eager:context",
     ],
@@ -3224,6 +3227,18 @@ py_test(
 )
 
 py_test(
+    name = "util_serialization_test",
+    size = "small",
+    srcs = ["util/serialization_test.py"],
+    main = "util/serialization_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":client_testlib",
+        ":util",
+    ],
+)
+
+py_test(
     name = "future_api_test",
     size = "small",
     srcs = ["util/future_api_test.py"],
index a0b709a..ee6e320 100644 (file)
@@ -30,6 +30,7 @@ from tensorflow.python.keras._impl.keras import optimizers
 from tensorflow.python.keras._impl.keras.utils import conv_utils
 from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import serialization
 from tensorflow.python.util.tf_export import tf_export
 
 # pylint: disable=g-import-not-at-top
@@ -74,40 +75,6 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
   if h5py is None:
     raise ImportError('`save_model` requires h5py.')
 
-  def get_json_type(obj):
-    """Serializes any object to a JSON-serializable structure.
-
-    Arguments:
-        obj: the object to serialize
-
-    Returns:
-        JSON-serializable structure representing `obj`.
-
-    Raises:
-        TypeError: if `obj` cannot be serialized.
-    """
-    # if obj is a serializable Keras class instance
-    # e.g. optimizer, layer
-    if hasattr(obj, 'get_config'):
-      return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}
-
-    # if obj is any numpy type
-    if type(obj).__module__ == np.__name__:
-      if isinstance(obj, np.ndarray):
-        return {'type': type(obj), 'value': obj.tolist()}
-      else:
-        return obj.item()
-
-    # misc functions (e.g. loss function)
-    if callable(obj):
-      return obj.__name__
-
-    # if obj is a python 'type'
-    if type(obj).__name__ == type.__name__:
-      return obj.__name__
-
-    raise TypeError('Not JSON Serializable:', obj)
-
   from tensorflow.python.keras._impl.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
 
   # If file exists and should not be overwritten.
@@ -124,7 +91,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
             'class_name': model.__class__.__name__,
             'config': model.get_config()
         },
-        default=get_json_type).encode('utf8')
+        default=serialization.get_json_type).encode('utf8')
 
     model_weights_group = f.create_group('model_weights')
     model_layers = model.layers
@@ -154,7 +121,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
                 'sample_weight_mode': model.sample_weight_mode,
                 'loss_weights': model.loss_weights,
             },
-            default=get_json_type).encode('utf8')
+            default=serialization.get_json_type).encode('utf8')
 
         # Save optimizer weights.
         symbolic_weights = getattr(model.optimizer, 'weights')
index d00312a..956dd66 100644 (file)
@@ -18,14 +18,21 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import functools
+import json
+import weakref
 
 from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_io_ops as io_ops
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import saveable_object
 from tensorflow.python.util import nest
+from tensorflow.python.util import serialization
 
 
 # Key where the object graph proto is saved in a TensorBundle
@@ -37,6 +44,7 @@ OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"
 # the object has no dependencies, then its value may be restored on object
 # creation (avoiding double assignment when executing eagerly).
 VARIABLE_VALUE_KEY = "VARIABLE_VALUE"
+OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"
 
 CheckpointableReference = collections.namedtuple(
     "CheckpointableReference",
@@ -85,6 +93,35 @@ class CheckpointInitialValue(ops.Tensor):
     return self._checkpoint_position
 
 
+class PythonStringStateSaveable(saveable_object.SaveableObject):
+  """Saves Python state in a checkpoint."""
+
+  def __init__(self, name, state_callback):
+    """Configure saving.
+
+    Args:
+      name: The checkpoint key to write to.
+      state_callback: A function taking no arguments which returns a
+        string. This function is run every time a checkpoint is written.
+    """
+    if context.executing_eagerly():
+      self._save_string = (
+          lambda: constant_op.constant(state_callback(), dtype=dtypes.string))
+    else:
+      self._save_string = constant_op.constant("", dtype=dtypes.string)
+      self.feed_dict_additions = (
+          lambda: {self._save_string: state_callback()})
+    spec = saveable_object.SaveSpec(
+        self._save_string, "", name, dtype=dtypes.string)
+    super(PythonStringStateSaveable, self).__init__(
+        self._save_string, [spec], name)
+
+  def restore(self, restored_tensors, restored_shapes):
+    # TODO(allenl): Add a Python hook for state coming out of a checkpoint
+    # (currently PythonStringStateSaveable is write-only).
+    return control_flow_ops.no_op()
+
+
 class _CheckpointPosition(object):
   """Indicates a position within a `_Checkpoint`."""
 
@@ -604,7 +641,6 @@ class CheckpointableBase(object):
     # restoration on to our dependencies.
     if checkpoint.restore_uid > self._update_uid:
       restore_ops = checkpoint_position.restore_ops()
-      # TODO(allenl): Get a list of feeds for saving Python state
       self._update_uid = checkpoint.restore_uid
     else:
       restore_ops = ()
@@ -656,7 +692,24 @@ class CheckpointableBase(object):
        lambda name="global_name_for_this_object":
        SaveableObject(name=name, ...)}
     """
-    return {}
+    if not hasattr(self, "get_config"):
+      return {}
+    try:
+      self.get_config()
+    except NotImplementedError:
+      return {}
+    weak_self = weakref.ref(self)
+    def _state_callback():
+      dereferenced_self = weak_self()
+      if dereferenced_self:
+        return json.dumps(self,
+                          default=serialization.get_json_type,
+                          sort_keys=True).encode("utf8")
+      else:
+        return ""
+    return {OBJECT_CONFIG_JSON_KEY: functools.partial(
+        PythonStringStateSaveable,
+        state_callback=_state_callback)}
 
 
 class NoDependency(object):
index f2a2b41..1e69096 100644 (file)
@@ -36,6 +36,7 @@ from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.training import checkpointable as checkpointable_lib
 from tensorflow.python.training import optimizer as optimizer_lib
+from tensorflow.python.training import saveable_object
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
@@ -303,42 +304,93 @@ def _serialize_slot_variables(checkpointable_objects, node_ids, object_names):
 
 
 def _serialize_checkpointables(
-    checkpointable_objects, node_ids, object_names, slot_variables):
+    checkpointable_objects, node_ids, object_names, slot_variables,
+    saveables_cache):
   """Name non-slot `Checkpointable`s and add them to `object_graph_proto`."""
   object_graph_proto = (
       checkpointable_object_graph_pb2.CheckpointableObjectGraph())
-  named_saveables = {}
-
+  named_saveables = []
+  feed_additions = {}
   for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
     assert node_ids[checkpointable] == checkpoint_id
     object_proto = object_graph_proto.nodes.add()
     object_proto.slot_variables.extend(slot_variables.get(checkpointable, ()))
     object_name = object_names[checkpointable]
+    if saveables_cache is not None:
+      cached_attributes = saveables_cache.setdefault(checkpointable, {})
+    else:
+      cached_attributes = None
     for name, saveable_factory in (
         checkpointable._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
       attribute = object_proto.attributes.add()
       attribute.name = name
       attribute.checkpoint_key = "%s/%s/%s" % (
           object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
-      if callable(saveable_factory):
-        saveable = saveable_factory(name=attribute.checkpoint_key)
+      if cached_attributes is None:
+        saveables = None
       else:
-        saveable = saveable_factory
-      # Figure out the name-based Saver's name for this variable.
-      saver_dict = saver_lib.BaseSaverBuilder.OpListToDict(
-          [saveable], convert_variable_to_tensor=False)
-      attribute.full_name, = saver_dict.keys()
-      named_saveables[attribute.checkpoint_key] = saveable
+        saveables = cached_attributes.get(name, None)
+        if saveables is not None:
+          for saveable in saveables:
+            if attribute.checkpoint_key not in saveable.name:
+              # The checkpoint key for this SaveableObject is different. We need
+              # to re-create it.
+              saveables = None
+              del cached_attributes[name]
+              break
+      if saveables is None:
+        if callable(saveable_factory):
+          maybe_saveable = saveable_factory(name=attribute.checkpoint_key)
+        else:
+          maybe_saveable = saveable_factory
+        if isinstance(maybe_saveable, saveable_object.SaveableObject):
+          saveables = (maybe_saveable,)
+        else:
+          # Figure out the name-based Saver's name for this variable. If it's
+          # already a SaveableObject we'd just get the checkpoint key back, so
+          # we leave full_name blank.
+          saver_dict = saver_lib.BaseSaverBuilder.OpListToDict(
+              [maybe_saveable], convert_variable_to_tensor=False)
+          full_name, = saver_dict.keys()
+          saveables = tuple(saver_lib.BaseSaverBuilder.SaveableObjectsForOp(
+              op=maybe_saveable, name=attribute.checkpoint_key))
+          for saveable in saveables:
+            saveable.full_name = full_name
+        for saveable in saveables:
+          if attribute.checkpoint_key not in saveable.name:
+            raise AssertionError(
+                ("The object %s produced a SaveableObject with name '%s' for "
+                 "attribute '%s'. Expected a name containing '%s'.")
+                % (checkpointable, name, saveable.name,
+                   attribute.checkpoint_key))
+        if cached_attributes is not None:
+          cached_attributes[name] = saveables
+
+      for saveable in saveables:
+        if hasattr(saveable, "full_name"):
+          attribute.full_name = saveable.full_name
+        saveable_feed_dict_fn = getattr(saveable, "feed_dict_additions", None)
+        if saveable_feed_dict_fn is not None:
+          saveable_feed_dict = saveable_feed_dict_fn()  # pylint: disable=not-callable
+          for new_feed_key in saveable_feed_dict.keys():
+            if new_feed_key in feed_additions:
+              raise AssertionError(
+                  ("The object %s tried to feed a value for the Tensor %s "
+                   "when saving, but another object is already feeding a "
+                   "value.")
+                  % (checkpointable, new_feed_key))
+          feed_additions.update(saveable_feed_dict)
+      named_saveables.extend(saveables)
 
     for child in checkpointable._checkpoint_dependencies:  # pylint: disable=protected-access
       child_proto = object_proto.children.add()
       child_proto.node_id = node_ids[child.ref]
       child_proto.local_name = child.name
 
-  return named_saveables, object_graph_proto
+  return named_saveables, object_graph_proto, feed_additions
 
 
-def _serialize_object_graph(root_checkpointable):
+def _serialize_object_graph(root_checkpointable, saveables_cache):
   """Determine checkpoint keys for variables and build a serialized graph.
 
   Non-slot variables are keyed based on a shortest path from the root saveable
@@ -351,12 +403,17 @@ def _serialize_object_graph(root_checkpointable):
   Args:
     root_checkpointable: A `Checkpointable` object whose variables (including
       the variables of dependencies, recursively) should be saved.
+    saveables_cache: A dictionary mapping `Checkpointable` objects -> attribute
+      names -> SaveableObjects, used to avoid re-creating SaveableObjects when
+      graph building.
 
   Returns:
-    A tuple of (named_variables, object_graph_proto):
+    A tuple of (named_variables, object_graph_proto, feed_additions):
       named_variables: A dictionary mapping names to variable objects.
       object_graph_proto: A CheckpointableObjectGraph protocol buffer containing
         the serialized object graph and variable references.
+      feed_additions: A dictionary mapping from Tensors to values which should
+        be fed when saving.
 
   Raises:
     ValueError: If there are invalid characters in an optimizer's slot names.
@@ -376,7 +433,8 @@ def _serialize_object_graph(root_checkpointable):
       checkpointable_objects=checkpointable_objects,
       node_ids=node_ids,
       object_names=object_names,
-      slot_variables=slot_variables)
+      slot_variables=slot_variables,
+      saveables_cache=saveables_cache)
 
 
 def list_objects(root_checkpointable):
@@ -728,6 +786,14 @@ class CheckpointableSaver(object):
     self._last_restore_object_graph = None
     self._last_restore_checkpoint = None
 
+    if context.executing_eagerly():
+      # SaveableObjects are always recreated when executing eagerly.
+      self._saveable_object_cache = None
+    else:
+      # Maps Checkpointable objects -> attribute names -> SaveableObjects, to
+      # avoid re-creating SaveableObjects when graph building.
+      self._saveable_object_cache = weakref.WeakKeyDictionary()
+
   @property
   def _root_checkpointable(self):
     if isinstance(self._root_checkpointable_ref, weakref.ref):
@@ -759,8 +825,9 @@ class CheckpointableSaver(object):
     Returns:
       The full path to the checkpoint.
     """
-    named_variables, graph_proto = _serialize_object_graph(
-        self._root_checkpointable)
+    named_variables, graph_proto, feed_additions = _serialize_object_graph(
+        self._root_checkpointable,
+        saveables_cache=self._saveable_object_cache)
     if not context.executing_eagerly():
       if session is None:
         session = ops.get_default_session()
@@ -769,15 +836,15 @@ class CheckpointableSaver(object):
           self._object_graph_feed_tensor = constant_op.constant(
               "", dtype=dtypes.string)
       object_graph_tensor = self._object_graph_feed_tensor
-      feed_additions = {object_graph_tensor: graph_proto.SerializeToString()}
+      feed_additions.update(
+          {object_graph_tensor: graph_proto.SerializeToString()})
     else:
       session = None
       with ops.device("/cpu:0"):
         object_graph_tensor = constant_op.constant(
             graph_proto.SerializeToString(), dtype=dtypes.string)
-      feed_additions = None
     assert checkpointable_lib.OBJECT_GRAPH_PROTO_KEY not in named_variables
-    named_variables[checkpointable_lib.OBJECT_GRAPH_PROTO_KEY] = (
+    named_variables.append(
         _NoRestoreSaveable(
             tensor=object_graph_tensor,
             name=checkpointable_lib.OBJECT_GRAPH_PROTO_KEY))
@@ -804,13 +871,23 @@ class CheckpointableSaver(object):
 
   def _global_variable_names(self):
     """Generate a `tf.train.Saver`-style `var_list` using `variable.name`s."""
-    named_saveables, graph_proto = _serialize_object_graph(
-        self._root_checkpointable)
+    named_saveables, graph_proto, _ = _serialize_object_graph(
+        self._root_checkpointable,
+        # We destructively modify SaveableObjects, so don't do any caching.
+        saveables_cache=None)
+    named_saveables = {v.name: v for v in named_saveables}
     saver_names = {}
     for object_proto in graph_proto.nodes:
       for attribute_proto in object_proto.attributes:
-        saver_names[attribute_proto.full_name] = named_saveables[
-            attribute_proto.checkpoint_key]
+        if attribute_proto.full_name:
+          # Ignore attributes, such as Python object JSON, which don't have a
+          # name-based Saver name.
+          saveable = named_saveables[attribute_proto.checkpoint_key]
+          saveable.name = attribute_proto.full_name
+          for spec in saveable.specs:
+            spec.name = spec.name.replace(attribute_proto.checkpoint_key,
+                                          attribute_proto.full_name)
+          saver_names[attribute_proto.full_name] = saveable
     return saver_names
 
   def restore(self, save_path):
@@ -1037,6 +1114,7 @@ class Checkpoint(checkpointable_lib.Checkpointable):
             % (v,))
       setattr(self, k, v)
     self._save_counter = None  # Created lazily for restore-on-create.
+    self._save_assign_op = None
     self._saver = CheckpointableSaver(weakref.ref(self))
 
   def _maybe_create_save_counter(self):
@@ -1089,10 +1167,13 @@ class Checkpoint(checkpointable_lib.Checkpointable):
         # needs to be initialized before assign_add. This is only an issue if
         # restore() has not been called first.
         session.run(self.save_counter.initializer)
-    with ops.colocate_with(self.save_counter):
-      assign_op = self.save_counter.assign_add(1)
+    if not in_graph_mode or self._save_assign_op is None:
+      with ops.colocate_with(self.save_counter):
+        assign_op = self.save_counter.assign_add(1, read_value=False)
+      if in_graph_mode:
+        self._save_assign_op = assign_op
     if in_graph_mode:
-      session.run(assign_op)
+      session.run(self._save_assign_op)
     return self._saver.save(
         file_prefix=file_prefix,
         checkpoint_number=self.save_counter,
index 3b8166b..dead8fd 100644 (file)
@@ -17,10 +17,12 @@ from __future__ import division
 from __future__ import print_function
 
 import functools
+import json
 import os
 
 import six
 
+from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.client import session as session_lib
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -120,7 +122,8 @@ class InterfaceTests(test.TestCase):
       # The .name attribute may be globally influenced, but the checkpoint name
       # won't be (tested below).
       self.assertEqual("duplicate_1:0", duplicate.name)
-    named_variables, _ = checkpointable_utils._serialize_object_graph(obj)
+    named_variables, _, _ = checkpointable_utils._serialize_object_graph(
+        obj, saveables_cache=None)
     expected_checkpoint_names = (
         "a_variable/.ATTRIBUTES/VARIABLE_VALUE",
         "bare_initializer/.ATTRIBUTES/VARIABLE_VALUE",
@@ -129,7 +132,7 @@ class InterfaceTests(test.TestCase):
         "ones_initializer/.ATTRIBUTES/VARIABLE_VALUE",
     )
     six.assertCountEqual(
-        self, expected_checkpoint_names, named_variables.keys())
+        self, expected_checkpoint_names, [v.name for v in named_variables])
 
   def testInitNotCalled(self):
 
@@ -245,8 +248,9 @@ class CheckpointingTests(test.TestCase):
       self.evaluate(checkpointable_utils.gather_initializers(
           root_checkpointable))
       self.evaluate(train_op)
-    named_variables, serialized_graph = (
-        checkpointable_utils._serialize_object_graph(root_checkpointable))
+    named_variables, serialized_graph, _ = (
+        checkpointable_utils._serialize_object_graph(
+            root_checkpointable, saveables_cache=None))
     expected_checkpoint_names = (
         # Created in the root node, so no prefix.
         "optimizer_step",
@@ -269,24 +273,29 @@ class CheckpointingTests(test.TestCase):
     suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
     expected_checkpoint_names = [
         name + suffix for name in expected_checkpoint_names]
+    # The Dense layers also save get_config() JSON
+    expected_checkpoint_names.extend(
+        ["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
+         "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"])
+    named_variables = {v.name: v for v in named_variables}
     six.assertCountEqual(self, expected_checkpoint_names,
                          named_variables.keys())
     # Check that we've mapped to the right variable objects (not exhaustive)
     self.assertEqual(
-        "global_step:0",
-        named_variables["optimizer_step" + suffix].name)
+        "global_step",
+        named_variables["optimizer_step" + suffix].full_name)
     self.assertEqual(
-        "my_model/dense_1/kernel:0",
-        named_variables["model/_second/kernel" + suffix].name)
+        "my_model/dense_1/kernel",
+        named_variables["model/_second/kernel" + suffix].full_name)
     self.assertEqual(
-        "my_model/dense/kernel:0",
-        named_variables["model/_named_dense/kernel" + suffix].name)
+        "my_model/dense/kernel",
+        named_variables["model/_named_dense/kernel" + suffix].full_name)
     self.assertEqual(
-        "beta1_power:0",
-        named_variables["optimizer/beta1_power" + suffix].name)
+        "beta1_power",
+        named_variables["optimizer/beta1_power" + suffix].full_name)
     self.assertEqual(
-        "beta2_power:0",
-        named_variables["optimizer/beta2_power" + suffix].name)
+        "beta2_power",
+        named_variables["optimizer/beta2_power" + suffix].full_name)
     # Spot check the generated protocol buffers.
     self.assertEqual("optimizer",
                      serialized_graph.nodes[0].children[1].local_name)
@@ -311,7 +320,7 @@ class CheckpointingTests(test.TestCase):
     self.assertEqual(
         "my_model/dense/kernel/Adam:0",
         optimizer.get_slot(
-            var=named_variables["model/_named_dense/kernel" + suffix],
+            var=model._named_dense.kernel,
             name="m").name)
     self.assertEqual(
         "model/_named_dense/kernel" + suffix,
@@ -563,11 +572,11 @@ class CheckpointingTests(test.TestCase):
     root = checkpointable.Checkpointable()
     checkpointable_utils.add_variable(
         root, name=name, shape=[1, 2], dtype=dtypes.float64)
-    named_variables, _ = checkpointable_utils._serialize_object_graph(root)
-    checkpoint_name, = named_variables.keys()
-    with ops.name_scope("root/" + checkpoint_name):
+    (named_variable,), _, _ = checkpointable_utils._serialize_object_graph(
+        root, saveables_cache=None)
+    with ops.name_scope("root/" + named_variable.name):
       pass  # Make sure we can use this as an op name if we prefix it.
-    return checkpoint_name
+    return named_variable.name
 
   @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testVariableNameEscaping(self):
@@ -585,9 +594,9 @@ class CheckpointingTests(test.TestCase):
     leaf = checkpointable.Checkpointable()
     root.leaf = leaf
     checkpointable_utils.add_variable(leaf, name="v", shape=[])
-    named_variables, _ = checkpointable_utils._serialize_object_graph(root)
-    variable_name, = named_variables.keys()
-    self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", variable_name)
+    (named_variable,), _, _ = checkpointable_utils._serialize_object_graph(
+        root, saveables_cache=None)
+    self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", named_variable.name)
 
   @test_util.run_in_graph_and_eager_modes()
   def testLocalNameValidation(self):
@@ -596,9 +605,10 @@ class CheckpointingTests(test.TestCase):
     # Dots are escaped, which avoids conflicts with reserved names.
     root._track_checkpointable(leaf, name=".ATTRIBUTES")
     checkpointable_utils.add_variable(checkpointable=leaf, name="a", shape=[])
-    named_variables, _ = checkpointable_utils._serialize_object_graph(root)
-    name, = named_variables.keys()
-    self.assertEqual(name, "..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE")
+    (named_variable,), _, _ = checkpointable_utils._serialize_object_graph(
+        root, saveables_cache=None)
+    self.assertEqual("..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE",
+                     named_variable.name)
 
   def testAnonymousVarsInInit(self):
 
@@ -1395,5 +1405,48 @@ class CheckpointCompatibilityTests(test.TestCase):
         root.restore(save_path).assert_consumed().run_restore_ops()
         self._check_sentinels(root)
 
+
+class PythonMetadataTests(test.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testSaveLoad(self):
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+    dense = core.Dense(1)
+    checkpoint = checkpointable_utils.Checkpoint(dense=dense)
+    dense(constant_op.constant([[1.]]))
+    checkpoint.restore(None).initialize_or_restore()
+    save_path = checkpoint.save(checkpoint_prefix)
+
+    def _get_dense_node_from_object_graph(object_graph_proto):
+      root_node = object_graph_proto.nodes[0]
+      for child in root_node.children:
+        if child.local_name == "dense":
+          break
+      else:
+        raise AssertionError(
+            "Expected a 'dense' dependency of root, didn't find one.")
+      dense_node = object_graph_proto.nodes[child.node_id]  # pylint: disable=undefined-loop-variable
+      self.assertEqual(1, len(dense_node.attributes))
+      reader = pywrap_tensorflow.NewCheckpointReader(save_path)
+      layer_json = reader.get_tensor(dense_node.attributes[0].checkpoint_key)
+      return json.loads(layer_json.decode("utf-8"))
+
+    layer_data = _get_dense_node_from_object_graph(
+        checkpointable_utils.object_metadata(save_path))
+    self.assertEqual("Dense", layer_data["class_name"])
+    self.assertEqual(1, layer_data["config"]["units"])
+
+    # Check that no new ops are added to the graph the second time we save.
+    ops.get_default_graph().finalize()
+
+    dense.units = 42
+    save_path = checkpoint.save(checkpoint_prefix)
+    layer_data = _get_dense_node_from_object_graph(
+        checkpointable_utils.object_metadata(save_path))
+    self.assertEqual("Dense", layer_data["class_name"])
+    self.assertEqual(42, layer_data["config"]["units"])
+
+
 if __name__ == "__main__":
   test.main()
index 53e821c..98e79a4 100644 (file)
@@ -569,6 +569,76 @@ class BaseSaverBuilder(object):
       # pylint: enable=protected-access
     return names_to_saveables
 
+  @staticmethod
+  def SaveableObjectsForOp(op, name):
+    """Create `SaveableObject`s from an operation.
+
+    Args:
+      op: A variable, operation, or SaveableObject to coerce into a
+        SaveableObject.
+      name: A string name for the SaveableObject.
+
+    Yields:
+      `SaveableObject`s which together save/restore `op`.
+
+    Raises:
+      TypeError: If `name` is not a string.
+      ValueError: For operations with no known conversion to SaveableObject.
+    """
+    if not isinstance(name, six.string_types):
+      raise TypeError(
+          "names_to_saveables must be a dict mapping string names to "
+          "checkpointable operations. Name is not a string: %s" % name)
+    if isinstance(op, BaseSaverBuilder.SaveableObject):
+      yield op
+    elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
+      if isinstance(op, variables.PartitionedVariable):
+        op = list(op)
+      # A set of slices.
+      slice_name = None
+      # pylint: disable=protected-access
+      for variable in op:
+        if not isinstance(variable, variables.Variable):
+          raise ValueError("Slices must all be Variables: %s" % variable)
+        if not variable._save_slice_info:
+          raise ValueError("Slices must all be slices: %s" % variable)
+        if slice_name is None:
+          slice_name = variable._save_slice_info.full_name
+        elif slice_name != variable._save_slice_info.full_name:
+          raise ValueError(
+              "Slices must all be from the same tensor: %s != %s" %
+              (slice_name, variable._save_slice_info.full_name))
+        if variable.op.type in ["Variable", "VariableV2",
+                                "AutoReloadVariable"]:
+          yield BaseSaverBuilder.VariableSaveable(
+              variable, variable._save_slice_info.spec, name)
+        else:
+          yield BaseSaverBuilder.ResourceVariableSaveable(
+              variable, variable._save_slice_info.spec, name)
+      # pylint: enable=protected-access
+    else:
+      # A variable or tensor.
+      if context.executing_eagerly():
+        if not isinstance(op, resource_variable_ops.ResourceVariable):
+          raise ValueError("Can only save/restore ResourceVariable eager "
+                           "mode is enabled, type: %s." % type(op))
+        yield BaseSaverBuilder.ResourceVariableSaveable(op, "", name)
+      else:
+        if isinstance(op, resource_variable_ops.ResourceVariable):
+          variable = op._graph_element  # pylint: disable=protected-access
+        else:
+          variable = ops.internal_convert_to_tensor(op, as_ref=True)
+        if not BaseSaverBuilder._IsVariable(variable):
+          raise TypeError("names_to_saveables must be a dict mapping string "
+                          "names to Tensors/Variables. Not a variable: %s" %
+                          variable)
+        if variable.op.type in ["Variable", "VariableV2",
+                                "AutoReloadVariable"]:
+          yield BaseSaverBuilder.VariableSaveable(variable, "", name)
+        else:
+          yield BaseSaverBuilder.ResourceVariableSaveable(
+              variable, "", name)
+
   def _ValidateAndSliceInputs(self, names_to_saveables):
     """Returns the variables and names that will be used for a Saver.
 
@@ -590,63 +660,11 @@ class BaseSaverBuilder(object):
 
     saveables = []
     seen_ops = set()
-    for name in sorted(names_to_saveables.keys()):
-      if not isinstance(name, six.string_types):
-        raise TypeError(
-            "names_to_saveables must be a dict mapping string names to "
-            "checkpointable operations. Name is not a string: %s" % name)
-      op = names_to_saveables[name]
-      if isinstance(op, BaseSaverBuilder.SaveableObject):
-        self._AddSaveable(saveables, seen_ops, op)
-      elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
-        if isinstance(op, variables.PartitionedVariable):
-          op = list(op)
-        # A set of slices.
-        slice_name = None
-        # pylint: disable=protected-access
-        for variable in op:
-          if not isinstance(variable, variables.Variable):
-            raise ValueError("Slices must all be Variables: %s" % variable)
-          if not variable._save_slice_info:
-            raise ValueError("Slices must all be slices: %s" % variable)
-          if slice_name is None:
-            slice_name = variable._save_slice_info.full_name
-          elif slice_name != variable._save_slice_info.full_name:
-            raise ValueError(
-                "Slices must all be from the same tensor: %s != %s" %
-                (slice_name, variable._save_slice_info.full_name))
-          if variable.op.type in ["Variable", "VariableV2",
-                                  "AutoReloadVariable"]:
-            saveable = BaseSaverBuilder.VariableSaveable(
-                variable, variable._save_slice_info.spec, name)
-          else:
-            saveable = BaseSaverBuilder.ResourceVariableSaveable(
-                variable, variable._save_slice_info.spec, name)
-          self._AddSaveable(saveables, seen_ops, saveable)
-        # pylint: enable=protected-access
-      else:
-        # A variable or tensor.
-        if context.executing_eagerly():
-          if not isinstance(op, resource_variable_ops.ResourceVariable):
-            raise ValueError("Can only save/restore ResourceVariable eager "
-                             "mode is enabled, type: %s." % type(op))
-          saveable = BaseSaverBuilder.ResourceVariableSaveable(op, "", name)
-        else:
-          if isinstance(op, resource_variable_ops.ResourceVariable):
-            variable = op._graph_element  # pylint: disable=protected-access
-          else:
-            variable = ops.internal_convert_to_tensor(op, as_ref=True)
-          if not BaseSaverBuilder._IsVariable(variable):
-            raise TypeError("names_to_saveables must be a dict mapping string "
-                            "names to Tensors/Variables. Not a variable: %s" %
-                            variable)
-          if variable.op.type in ["Variable", "VariableV2",
-                                  "AutoReloadVariable"]:
-            saveable = BaseSaverBuilder.VariableSaveable(variable, "", name)
-          else:
-            saveable = BaseSaverBuilder.ResourceVariableSaveable(
-                variable, "", name)
-        self._AddSaveable(saveables, seen_ops, saveable)
+    for name, op in sorted(names_to_saveables.items(),
+                           # Avoid comparing ops, sort only by name.
+                           key=lambda x: x[0]):
+      for converted_saveable_object in self.SaveableObjectsForOp(op, name):
+        self._AddSaveable(saveables, seen_ops, converted_saveable_object)
     return saveables
 
   def _AddSaveable(self, saveables, seen_ops, saveable):
diff --git a/tensorflow/python/util/serialization.py b/tensorflow/python/util/serialization.py
new file mode 100644 (file)
index 0000000..faf5164
--- /dev/null
@@ -0,0 +1,64 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for serializing Python objects."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import tensor_shape
+
+
+def get_json_type(obj):
+  """Serializes any object to a JSON-serializable structure.
+
+  Arguments:
+      obj: the object to serialize
+
+  Returns:
+      JSON-serializable structure representing `obj`.
+
+  Raises:
+      TypeError: if `obj` cannot be serialized.
+  """
+  # if obj is a serializable Keras class instance
+  # e.g. optimizer, layer
+  if hasattr(obj, 'get_config'):
+    return {'class_name': obj.__class__.__name__, 'config': obj.get_config()}
+
+  # if obj is any numpy type
+  if type(obj).__module__ == np.__name__:
+    if isinstance(obj, np.ndarray):
+      return {'type': type(obj), 'value': obj.tolist()}
+    else:
+      return obj.item()
+
+  # misc functions (e.g. loss function)
+  if callable(obj):
+    return obj.__name__
+
+  # if obj is a python 'type'
+  if type(obj).__name__ == type.__name__:
+    return obj.__name__
+
+  if isinstance(obj, tensor_shape.Dimension):
+    return obj.value
+
+  if isinstance(obj, tensor_shape.TensorShape):
+    return obj.as_list()
+
+  raise TypeError('Not JSON Serializable:', obj)
diff --git a/tensorflow/python/util/serialization_test.py b/tensorflow/python/util/serialization_test.py
new file mode 100644 (file)
index 0000000..f16fa53
--- /dev/null
@@ -0,0 +1,76 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for serialization functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras._impl.keras.engine import input_layer
+from tensorflow.python.keras._impl.keras.engine import sequential
+from tensorflow.python.keras._impl.keras.engine import training
+from tensorflow.python.keras._impl.keras.layers import core
+from tensorflow.python.platform import test
+from tensorflow.python.util import serialization
+
+
+class SerializationTests(test.TestCase):
+
+  def test_serialize_dense(self):
+    dense = core.Dense(3)
+    dense(constant_op.constant([[4.]]))
+    round_trip = json.loads(json.dumps(
+        dense, default=serialization.get_json_type))
+    self.assertEqual(3, round_trip["config"]["units"])
+
+  def test_serialize_shape(self):
+    round_trip = json.loads(json.dumps(
+        tensor_shape.TensorShape([None, 2, 3]),
+        default=serialization.get_json_type))
+    self.assertIs(round_trip[0], None)
+    self.assertEqual(round_trip[1], 2)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def test_serialize_sequential(self):
+    model = sequential.Sequential()
+    model.add(core.Dense(4))
+    model.add(core.Dense(5))
+    model(constant_op.constant([[1.]]))
+    sequential_round_trip = json.loads(
+        json.dumps(model, default=serialization.get_json_type))
+    self.assertEqual(5, sequential_round_trip["config"][1]["config"]["units"])
+    input_round_trip = json.loads(
+        json.dumps(model._input_layers, default=serialization.get_json_type))
+    self.assertAllEqual([1, 1],
+                        input_round_trip[0]["config"]["batch_input_shape"])
+
+  @test_util.run_in_graph_and_eager_modes()
+  def test_serialize_model(self):
+    x = input_layer.Input(shape=[3])
+    y = core.Dense(10)(x)
+    model = training.Model(x, y)
+    model(constant_op.constant([[1., 1., 1.]]))
+    model_round_trip = json.loads(
+        json.dumps(model, default=serialization.get_json_type))
+    self.assertEqual(
+        10, model_round_trip["config"]["layers"][1]["config"]["units"])
+
+if __name__ == "__main__":
+  test.main()