Deprecate tfe.Network and associated utilities in favor of tf.keras.Model.
authorAllen Lavoie <allenl@google.com>
Thu, 26 Apr 2018 18:24:36 +0000 (11:24 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 26 Apr 2018 18:29:11 +0000 (11:29 -0700)
Also throws an error rather than silently saving incorrectly with tf.train.Checkpoint.

(In response to confusion over tf.train.Checkpoint with tfe.Network)

PiperOrigin-RevId: 194426679

tensorflow/contrib/eager/python/network.py
tensorflow/contrib/eager/python/network_test.py

index 2f87213..44828be 100644 (file)
@@ -28,9 +28,11 @@ from tensorflow.python.framework import ops
 from tensorflow.python.keras._impl.keras.engine import base_layer as keras_base_layer
 from tensorflow.python.layers import base
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import checkpoint_utils
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training import training_util
+from tensorflow.python.util import deprecation
 
 # pylint: disable=protected-access
 # Explanation for protected-access disable: Network has lots of same-class and
@@ -52,9 +54,40 @@ def _network_name_scope_naming(current_variable_scope):
   return current_variable_scope.name + "/"
 
 
+_NETWORK_DEPRECATION_MESSAGE = (
+    "Please inherit from `tf.keras.Model`, and see its documentation for "
+    "details. `tf.keras.Model` should be a drop-in replacement for "
+    "`tfe.Network` in most cases, but note that `track_layer` is no longer "
+    "necessary or supported. Instead, `Layer` instances are tracked on "
+    "attribute assignment (see the section of `tf.keras.Model`'s documentation "
+    "on subclassing). Since the output of `track_layer` is often assigned to "
+    "an attribute anyway, most code can be ported by simply removing the "
+    "`track_layer` calls.\n\n`tf.keras.Model` works with all TensorFlow "
+    "`Layer` instances, including those from `tf.layers`, but switching to "
+    "the `tf.keras.layers` versions along with the migration to "
+    "`tf.keras.Model` is recommended, since it will preserve variable names. "
+    "Feel free to import it with an alias to avoid excess typing :)."
+)
+
+
 class Network(base.Layer):
   """Represents the composition of a set of Layers.
 
+  *Deprecated*. Please inherit from `tf.keras.Model`, and see its documentation
+  for details. `tf.keras.Model` should be a drop-in replacement for
+  `tfe.Network` in most cases, but note that `track_layer` is no longer
+  necessary or supported. Instead, `Layer` instances are tracked on attribute
+  assignment (see the section of `tf.keras.Model`'s documentation on
+  subclassing). Since the output of `track_layer` is often assigned to an
+  attribute anyway, most code can be ported by simply removing the `track_layer`
+  calls.
+
+  `tf.keras.Model` works with all TensorFlow `Layer` instances, including those
+  from `tf.layers`, but switching to the `tf.keras.layers` versions along with
+  the migration to `tf.keras.Model` is recommended, since it will preserve
+  variable names.  Feel free to import it with an alias to avoid excess typing
+  :).
+
   `Network` implements the `Layer` interface and adds convenience methods for
   managing sub-`Layer`s, such as listing variables.
 
@@ -112,6 +145,7 @@ class Network(base.Layer):
   # - Detect layers used in __call__ that weren't registered with track_layer.
   # - Convert inputs to __call__ to tensors.
 
+  @deprecation.deprecated(date=None, instructions=_NETWORK_DEPRECATION_MESSAGE)
   def __init__(self, name=None):
     """Configure the `Network`.
 
@@ -130,6 +164,10 @@ class Network(base.Layer):
       ValueError: If `name` is not valid. Note that some naming errors will
         instead be raised when the `Network` is called.
     """
+    if context.executing_eagerly():
+      logging.warning(
+          ("** tfe.Network is deprecated and will be removed in a future "
+           "version.\n\n%s") % _NETWORK_DEPRECATION_MESSAGE)
     if isinstance(name, variable_scope.VariableScope):
       raise ValueError("VariableScopes are not valid Network names.")
     if name is not None and "/" in name:
@@ -152,6 +190,11 @@ class Network(base.Layer):
     self._variable_scope_counts_on_init = (
         variable_scope.get_variable_scope_store().variable_scopes_count)
 
+  def _gather_saveables_for_checkpoint(self):
+    raise NotImplementedError(
+        "tfe.Network does not support object-based checkpointing.\n\n%s"
+        % _NETWORK_DEPRECATION_MESSAGE)
+
   def _name_scope_name(self, current_variable_scope):
     """Overrides Layer op naming to match variable naming."""
     return _network_name_scope_naming(
@@ -706,6 +749,9 @@ def _make_prefix_stripping_map_fn(scope_name):
   return _strip_variable_prefix
 
 
+@deprecation.deprecated(date=None, instructions=(
+    "Please inherit from tf.keras.Model instead of tfe.Network, and use "
+    "tf.keras.Model.save_weights."))
 def save_network_checkpoint(
     network, save_path, global_step=None, map_func=None):
   """Save variables from the Network to a checkpoint.
@@ -905,6 +951,9 @@ def _set_restore_on_create(network, save_path, map_func, user_map_func,
     _add_deferred_restoration(network, deferred_restoration)
 
 
+@deprecation.deprecated(date=None, instructions=(
+    "Please inherit from tf.keras.Model instead of tfe.Network, and use "
+    "tf.keras.Model.load_weights."))
 def restore_network_checkpoint(network, save_path, map_func=None):
   """Restore the Network from a checkpoint.
 
index f43376d..6a51d03 100644 (file)
@@ -30,6 +30,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import checkpointable_utils
 from tensorflow.python.training import training_util
 
 
@@ -62,6 +63,12 @@ class RegularizedNetwork(network.Network):
 
 class NetworkTest(test.TestCase):
 
+  def test_checkpointing_not_implemented(self):
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint = checkpointable_utils.Checkpoint(net=MyNetwork())
+    with self.assertRaises(NotImplementedError):
+      checkpoint.save(checkpoint_directory)
+
   def _save_modify_load_network_built(self, net, global_step=None):
     checkpoint_directory = self.get_temp_dir()
     checkpoint_path = network.save_network_checkpoint(