from tensorflow.python.keras._impl.keras.engine import base_layer as keras_base_layer
from tensorflow.python.layers import base
from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
+from tensorflow.python.util import deprecation
# pylint: disable=protected-access
# Explanation for protected-access disable: Network has lots of same-class and
return current_variable_scope.name + "/"
+_NETWORK_DEPRECATION_MESSAGE = (
+ "Please inherit from `tf.keras.Model`, and see its documentation for "
+ "details. `tf.keras.Model` should be a drop-in replacement for "
+ "`tfe.Network` in most cases, but note that `track_layer` is no longer "
+ "necessary or supported. Instead, `Layer` instances are tracked on "
+ "attribute assignment (see the section of `tf.keras.Model`'s documentation "
+ "on subclassing). Since the output of `track_layer` is often assigned to "
+ "an attribute anyway, most code can be ported by simply removing the "
+ "`track_layer` calls.\n\n`tf.keras.Model` works with all TensorFlow "
+ "`Layer` instances, including those from `tf.layers`, but switching to "
+ "the `tf.keras.layers` versions along with the migration to "
+ "`tf.keras.Model` is recommended, since it will preserve variable names. "
+ "Feel free to import it with an alias to avoid excess typing :)."
+)
+
+
class Network(base.Layer):
"""Represents the composition of a set of Layers.
+ *Deprecated*. Please inherit from `tf.keras.Model`, and see its documentation
+ for details. `tf.keras.Model` should be a drop-in replacement for
+ `tfe.Network` in most cases, but note that `track_layer` is no longer
+ necessary or supported. Instead, `Layer` instances are tracked on attribute
+ assignment (see the section of `tf.keras.Model`'s documentation on
+ subclassing). Since the output of `track_layer` is often assigned to an
+ attribute anyway, most code can be ported by simply removing the `track_layer`
+ calls.
+
+ `tf.keras.Model` works with all TensorFlow `Layer` instances, including those
+ from `tf.layers`, but switching to the `tf.keras.layers` versions along with
+ the migration to `tf.keras.Model` is recommended, since it will preserve
+ variable names. Feel free to import it with an alias to avoid excess typing
+ :).
+
`Network` implements the `Layer` interface and adds convenience methods for
managing sub-`Layer`s, such as listing variables.
# - Detect layers used in __call__ that weren't registered with track_layer.
# - Convert inputs to __call__ to tensors.
+ @deprecation.deprecated(date=None, instructions=_NETWORK_DEPRECATION_MESSAGE)
def __init__(self, name=None):
"""Configure the `Network`.
ValueError: If `name` is not valid. Note that some naming errors will
instead be raised when the `Network` is called.
"""
+ if context.executing_eagerly():
+ logging.warning(
+ ("** tfe.Network is deprecated and will be removed in a future "
+ "version.\n\n%s") % _NETWORK_DEPRECATION_MESSAGE)
if isinstance(name, variable_scope.VariableScope):
raise ValueError("VariableScopes are not valid Network names.")
if name is not None and "/" in name:
self._variable_scope_counts_on_init = (
variable_scope.get_variable_scope_store().variable_scopes_count)
+ def _gather_saveables_for_checkpoint(self):
+ raise NotImplementedError(
+ "tfe.Network does not support object-based checkpointing.\n\n%s"
+ % _NETWORK_DEPRECATION_MESSAGE)
+
def _name_scope_name(self, current_variable_scope):
"""Overrides Layer op naming to match variable naming."""
return _network_name_scope_naming(
return _strip_variable_prefix
+@deprecation.deprecated(date=None, instructions=(
+ "Please inherit from tf.keras.Model instead of tfe.Network, and use "
+ "tf.keras.Model.save_weights."))
def save_network_checkpoint(
network, save_path, global_step=None, map_func=None):
"""Save variables from the Network to a checkpoint.
_add_deferred_restoration(network, deferred_restoration)
+@deprecation.deprecated(date=None, instructions=(
+ "Please inherit from tf.keras.Model instead of tfe.Network, and use "
+ "tf.keras.Model.load_weights."))
def restore_network_checkpoint(network, save_path, map_func=None):
"""Restore the Network from a checkpoint.