--- /dev/null
+ # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # pylint: disable=protected-access
+ """A `Network` is way to compose layers: the topological form of a `Model`.
+ """
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+
+ import copy
+ import json
+ import os
+ import weakref
+
+ import numpy as np
+ from six.moves import zip # pylint: disable=redefined-builtin
+
+ from tensorflow.python import pywrap_tensorflow
+ from tensorflow.python.eager import context
+ from tensorflow.python.framework import errors_impl
+ from tensorflow.python.framework import ops
+ from tensorflow.python.framework import tensor_shape
+ from tensorflow.python.keras import backend
+ from tensorflow.python.keras.engine import base_layer
+ from tensorflow.python.keras.engine import saving
+ from tensorflow.python.keras.utils import generic_utils
+ from tensorflow.python.keras.utils import tf_utils
+ from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
+ from tensorflow.python.keras.utils.layer_utils import print_summary as print_layer_summary
+ from tensorflow.python.platform import tf_logging as logging
+ from tensorflow.python.training.checkpointable import base as checkpointable
+ from tensorflow.python.training.checkpointable import util as checkpointable_utils
+ from tensorflow.python.util import nest
+ from tensorflow.python.util import tf_inspect
+
+
+ # pylint: disable=g-import-not-at-top
+ try:
+ import h5py
+ except ImportError:
+ h5py = None
+
+ try:
+ import yaml
+ except ImportError:
+ yaml = None
+ # pylint: enable=g-import-not-at-top
+
+
+ class Network(base_layer.Layer):
+ """A `Network` is a composition of layers.
+
+ It is the topological form of a "model". A `Model`
+ is simply a `Network` with added training routines.
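+
+ Two construction modes are supported; a minimal sketch of both follows.
+ The shapes and the subclass name are illustrative only, and in user code
+ you would typically use `Model`, which subclasses `Network`:
+
+ ```python
+ # Graph network: built from input and output tensors.
+ inputs = keras.layers.Input(shape=(10,))
+ outputs = keras.layers.Dense(1)(inputs)
+ net = Network(inputs=inputs, outputs=outputs)
+
+ # Subclassed network: layers are assigned as attributes and wired in `call`.
+ class MyNetwork(Network):
+
+   def __init__(self):
+     super(MyNetwork, self).__init__()
+     self.dense = keras.layers.Dense(1)
+
+   def call(self, inputs):
+     return self.dense(inputs)
+ ```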
+ """
+
+ def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called
+ # Signature detection
+ if (len(args) == 2 or
+ len(args) == 1 and 'outputs' in kwargs or
+ 'inputs' in kwargs and 'outputs' in kwargs):
+ # Graph network
+ self._init_graph_network(*args, **kwargs)
+ else:
+ # Subclassed network
+ self._init_subclassed_network(**kwargs)
+
+ def _base_init(self, name=None):
+ # The following are implemented as property functions:
+ # self.trainable_weights
+ # self.non_trainable_weights
+ # self.input_spec
+ # self.losses
+ # self.updates
+
+ self._init_set_name(name, zero_based=True)
+ self._activity_regularizer = None
+ # This acts just like the `trainable` attribute of any layer instance.
+ # It does not affect users of the underlying layers, only users of the
+ # Network instance.
+ self.trainable = True
+ self._is_compiled = False
+ self._expects_training_arg = False
+
+ self.supports_masking = False
+ if not hasattr(self, 'optimizer'):
+ # Don't reset optimizer if already set.
+ self.optimizer = None
+
+ # Private attributes to implement compatibility with Layer.
+ self._updates = [] # Used in symbolic mode only.
+ self._losses = [] # Used in symbolic mode only.
+ self._scope = None # Never used.
+ self._reuse = None # Never used.
+ if context.executing_eagerly():
+ self._graph = None
+ else:
+ self._graph = ops.get_default_graph() # Used in symbolic mode only.
+ # A Network does not create weights of its own, thus has no dtype.
+ self._dtype = None
+
+ # All layers in order of horizontal graph traversal.
+ # Entries are unique. Includes input and output layers.
+ self._layers = []
+
+ # Used in symbolic mode only, only in conjunction with graph-networks
+ self._outbound_nodes = []
+ self._inbound_nodes = []
+
+ self._checkpointable_saver = checkpointable_utils.CheckpointableSaver(
+ weakref.ref(self))
+ # A zero-argument function which should be called and set back to None as
+ # soon as the network is built (only applicable to subclassed Models). Runs
+ # restore operations when graph building.
+ self._in_progress_restore_finalizer = None
+
+ def _init_graph_network(self, inputs, outputs, name=None):
+ self._uses_inputs_arg = True
+ # Normalize and set self.inputs, self.outputs.
+ if isinstance(inputs, (list, tuple)):
+ self.inputs = list(inputs) # Tensor or list of tensors.
+ else:
+ self.inputs = [inputs]
+ if isinstance(outputs, (list, tuple)):
+ self.outputs = list(outputs)
+ else:
+ self.outputs = [outputs]
+
+ # User-provided argument validation.
+ if context.executing_eagerly():
+ # Check that all inputs/outputs are DeferredTensors.
+ for tensor in self.inputs:
+ if not isinstance(tensor, base_layer.DeferredTensor): # pylint: disable=protected-access
+ raise TypeError('When eager execution is enabled, '
+ 'inputs must come from a call to '
+ '`tf.keras.Input` (called after '
+ 'tfe.enable_eager_execution()). '
+ 'Received invalid input: ' + str(tensor))
+ for tensor in self.outputs:
+ if not isinstance(tensor, base_layer.DeferredTensor): # pylint: disable=protected-access
+ raise TypeError('When eager execution is enabled, '
+ 'outputs must come from a call to '
+ 'a layer (called after '
+ 'tfe.enable_eager_execution()). '
+ 'Received invalid output: ' + str(tensor))
+ # Check for redundancy in inputs.
+ if len(set(self.inputs)) != len(self.inputs):
+ raise ValueError('The list of inputs passed to the model '
+ 'is redundant. '
+ 'All inputs should only appear once.'
+ ' Found: ' + str(self.inputs))
+ for x in self.inputs:
+ # Check that x has appropriate `_keras_history` metadata.
+ if not hasattr(x, '_keras_history'):
+ cls_name = self.__class__.__name__
+ raise ValueError('Input tensors to a ' + cls_name + ' ' +
+ 'must come from `tf.layers.Input`. '
+ 'Received: ' + str(x) +
+ ' (missing previous layer metadata).')
+ # Check that x is an input tensor.
+ # pylint: disable=protected-access
+ layer, node_index, tensor_index = x._keras_history
+ if len(layer._inbound_nodes) > 1 or (
+ layer._inbound_nodes and layer._inbound_nodes[0].inbound_layers):
+ cls_name = self.__class__.__name__
+ logging.warning(cls_name + ' inputs must come from '
+ '`tf.layers.Input` (thus holding past layer metadata), '
+ 'they cannot be the output of '
+ 'a previous non-Input layer. '
+ 'Here, a tensor specified as '
+ 'input to "' + self.name + '" was not an Input tensor, '
+ 'it was generated by layer ' + layer.name + '.\n'
+ 'Note that input tensors are '
+ 'instantiated via `tensor = tf.layers.Input(shape)`.\n'
+ 'The tensor that caused the issue was: ' + str(x.name))
+ for x in self.outputs:
+ if not hasattr(x, '_keras_history'):
+ cls_name = self.__class__.__name__
+ raise ValueError('Output tensors to a ' + cls_name + ' must be '
+ 'the output of a TensorFlow `Layer` '
+ '(thus holding past layer metadata). Found: ' + str(x))
+
+ self._base_init(name=name)
+ self._compute_previous_mask = (
+ 'mask' in tf_inspect.getargspec(self.call).args or
+ hasattr(self, 'compute_mask'))
+ # A Network does not create weights of its own, thus it is already
+ # built.
+ self.built = True
+ self._is_graph_network = True
+
+ self._input_layers = []
+ self._output_layers = []
+ self._input_coordinates = []
+ self._output_coordinates = []
+
+ # This is for performance optimization when calling the Network on new
+ # inputs. Every time the Network is called on a set of input tensors,
+ # we compute the output tensors, output masks and output shapes in one pass,
+ # then cache them here. When any of these outputs is queried later, we
+ # retrieve it from there instead of recomputing it.
+ self._output_mask_cache = {}
+ self._output_tensor_cache = {}
+ self._output_shape_cache = {}
+
+ # Build self._output_layers:
+ for x in self.outputs:
+ layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access
+ self._output_layers.append(layer)
+ self._output_coordinates.append((layer, node_index, tensor_index))
+
+ # Build self._input_layers:
+ for x in self.inputs:
+ layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access
+ # It's supposed to be an input layer, so only one node
+ # and one tensor output.
+ assert node_index == 0
+ assert tensor_index == 0
+ self._input_layers.append(layer)
+ self._input_coordinates.append((layer, node_index, tensor_index))
+
+ # Keep track of the network's nodes and layers.
+ nodes, nodes_by_depth, layers, layers_by_depth = _map_graph_network(
+ self.inputs, self.outputs)
+ self._network_nodes = nodes
+ self._nodes_by_depth = nodes_by_depth
+ self._layers = layers
+ self._layers_by_depth = layers_by_depth
+
+ self._track_layers(layers)
+
+ # Create the node linking internal inputs to internal outputs.
+ base_layer.Node(
+ outbound_layer=self,
+ inbound_layers=[],
+ node_indices=[],
+ tensor_indices=[],
+ input_tensors=self.inputs,
+ output_tensors=self.outputs)
+
+ # Fill in the output mask cache.
+ masks = []
+ for x in self.inputs:
+ mask = x._keras_mask if hasattr(x, '_keras_mask') else None # pylint: disable=protected-access
+ masks.append(mask)
+ mask_cache_key = (generic_utils.object_list_uid(self.inputs) + '_' +
+ generic_utils.object_list_uid(masks))
+ masks = []
+ for x in self.outputs:
+ mask = x._keras_mask if hasattr(x, '_keras_mask') else None # pylint: disable=protected-access
+ masks.append(mask)
+ if len(masks) == 1:
+ mask = masks[0]
+ else:
+ mask = masks
+ self._output_mask_cache[mask_cache_key] = mask
+
+ # Build self.input_names and self.output_names.
+ self.input_names = []
+ self.output_names = []
+ self._feed_input_names = []
+ self._feed_inputs = []
+ self._feed_input_shapes = []
+ for i, layer in enumerate(self._input_layers):
+ self.input_names.append(layer.name)
+ if layer.is_placeholder:
+ self._feed_input_names.append(layer.name)
+ self._feed_input_shapes.append(backend.int_shape(self.inputs[i]))
+ # layer.input gives an error in eager mode
+ if not context.executing_eagerly():
+ self._feed_inputs.append(layer.input)
+ for layer in self._output_layers:
+ self.output_names.append(layer.name)
+
+ def _init_subclassed_network(self, name=None):
+ self._base_init(name=name)
+ self._is_graph_network = False
+ call_args = tf_inspect.getargspec(self.call).args
+ if 'training' in call_args:
+ self._expects_training_arg = True
+ else:
+ self._expects_training_arg = False
+ if 'inputs' in call_args:
+ self._uses_inputs_arg = True
+ else:
+ self._uses_inputs_arg = False
+ self.outputs = None
+ self.inputs = None
+ self.built = False
+
+ def _track_layers(self, layers):
+ """Add Checkpointable dependencies on a list of Layers."""
+ weight_layer_index = 0
+ for layer_index, layer in enumerate(layers):
+ if layer.weights:
+ # Keep a separate index for layers which have weights. This allows users
+ # to insert Layers without weights anywhere in the network without
+ # breaking checkpoints.
+ self._track_checkpointable(
+ layer, name='layer_with_weights-%d' % weight_layer_index,
+ overwrite=True)
+ weight_layer_index += 1
+ # Even if it doesn't have weights, we should still track everything in
+ # case it has/will have Checkpointable dependencies.
+ self._track_checkpointable(
+ layer, name='layer-%d' % layer_index, overwrite=True)
+
+ def __setattr__(self, name, value):
+ no_dependency = isinstance(value, checkpointable.NoDependency)
+ if no_dependency:
+ value = value.value
+ if isinstance(value, (base_layer.Layer, Network)):
+ try:
+ is_graph_network = self._is_graph_network
+ except AttributeError:
+ raise RuntimeError('It looks like you are subclassing `Model` and you '
+ 'forgot to call `super(YourClass, self).__init__()`.'
+ ' Always start with this line.')
+ if not is_graph_network:
+ if value not in self._layers:
+ self._layers.append(value)
+ if hasattr(value, '_use_resource_variables'):
+ # In subclassed models, legacy layers (tf.layers) must always use
+ # resource variables.
+ value._use_resource_variables = True
+ if (not no_dependency
+ and isinstance(value, checkpointable.CheckpointableBase)):
+ # Layer (and therefore Network/Model) inherit from CheckpointableBase
+ # rather than Checkpointable, which means there is no Checkpointable
+ # __setattr__ override (it would be a performance issue for functional
+ # layers). Therefore Model tracks Checkpointable objects itself.
+ self._track_checkpointable(
+ checkpointable=value, name=name, overwrite=True)
+ super(Network, self).__setattr__(name, value)
+
+ def add_variable(self, name, shape, dtype=None, initializer=None,
+ regularizer=None, trainable=True, constraint=None):
+ raise NotImplementedError('`add_variable` is not supported on Networks.')
+
+ def add_loss(self, *args, **kwargs):
+ if context.executing_eagerly():
+ raise NotImplementedError('`add_loss` is not supported on Networks '
+ 'when eager execution is enabled.')
+ super(Network, self).add_loss(*args, **kwargs)
+
+ @property
+ def uses_learning_phase(self):
+ return any(
+ [getattr(x, '_uses_learning_phase', False) for x in self.outputs])
+
+ @property
+ def stateful(self):
+ return any([(hasattr(layer, 'stateful') and layer.stateful)
+ for layer in self.layers])
+
+ def reset_states(self):
+ for layer in self.layers:
+ if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
+ layer.reset_states()
+
+ @property
+ def state_updates(self):
+ """Returns the `updates` from all layers that are stateful.
+
+ This is useful for separating training updates and
+ state updates, e.g. when we need to update a layer's internal state
+ during prediction.
+
+ Returns:
+ A list of update ops.
+ """
+ state_updates = []
+ for layer in self.layers:
+ if getattr(layer, 'stateful', False):
+ if hasattr(layer, 'updates'):
+ state_updates += layer.updates
+ return state_updates
+
+ def get_weights(self):
+ """Retrieves the weights of the model.
+
+ Returns:
+ A flat list of Numpy arrays.
+ """
+ weights = []
+ for layer in self.layers:
+ weights += layer.weights
+ return backend.batch_get_value(weights)
+
+ def set_weights(self, weights):
+ """Sets the weights of the model.
+
+ Arguments:
+ weights: A list of Numpy arrays with shapes and types matching
+ the output of `model.get_weights()`.
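+
+ A minimal round-trip sketch (the layer shapes are arbitrary):
+
+ ```python
+ model = keras.models.Sequential([keras.layers.Dense(2, input_shape=(3,))])
+ weights = model.get_weights()  # List of two arrays: kernel and bias.
+ model.set_weights(weights)  # Restores exactly the same values.
+ ```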
+ """
+ tuples = []
+ for layer in self.layers:
+ num_param = len(layer.weights)
+ layer_weights = weights[:num_param]
+ for sw, w in zip(layer.weights, layer_weights):
+ tuples.append((sw, w))
+ weights = weights[num_param:]
+ backend.batch_set_value(tuples)
+
+ def compute_mask(self, inputs, mask):
+ if not self._is_graph_network:
+ return None
+
+ inputs = generic_utils.to_list(inputs)
+ if mask is None:
+ masks = [None for _ in range(len(inputs))]
+ else:
+ masks = generic_utils.to_list(mask)
+ cache_key = (generic_utils.object_list_uid(inputs)
+ + '_' + generic_utils.object_list_uid(masks))
+ if cache_key in self._output_mask_cache:
+ return self._output_mask_cache[cache_key]
+ else:
+ _, output_masks = self._run_internal_graph(inputs, mask=masks)
+ return output_masks
+
+ @property
+ def layers(self):
+ return self._layers
+
+ def get_layer(self, name=None, index=None):
+ """Retrieves a layer based on either its name (unique) or index.
+
+ If `name` and `index` are both provided, `index` will take precedence.
+ Indices are based on order of horizontal graph traversal (bottom-up).
+
+ Arguments:
+ name: String, name of layer.
+ index: Integer, index of layer.
+
+ Returns:
+ A layer instance.
+
+ Raises:
+ ValueError: In case of invalid layer name or index.
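+
+ For illustration (the layer name below is a placeholder for whatever
+ names your model actually uses):
+
+ ```python
+ first_layer = model.get_layer(index=0)
+ dense = model.get_layer(name='dense_1')
+ ```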
+ """
+ # TODO(fchollet): We could build a dictionary based on layer names
+ # since they are constant, but we have not done that yet.
+ if index is not None:
+ if len(self.layers) <= index:
+ raise ValueError('Was asked to retrieve layer at index ' + str(index) +
+ ' but model only has ' + str(len(self.layers)) +
+ ' layers.')
+ else:
+ return self.layers[index]
+ else:
+ if not name:
+ raise ValueError('Provide either a layer name or layer index.')
+ for layer in self.layers:
+ if layer.name == name:
+ return layer
+ raise ValueError('No such layer: ' + name)
+
+ @property
+ def updates(self):
+ """Retrieves the network's updates.
+
+ Will only include updates that are either
+ unconditional, or conditional on inputs to this model
+ (e.g. will not include updates that were created by layers of this model
+ outside of the model).
+
+ Effectively, `network.updates` behaves like `layer.updates`.
+
+ Concrete example:
+
+ ```python
+ bn = keras.layers.BatchNormalization()
+ x1 = keras.layers.Input(shape=(10,))
+ _ = bn(x1) # This creates 2 updates.
+
+ x2 = keras.layers.Input(shape=(10,))
+ y2 = bn(x2) # This creates 2 more updates.
+
+ # The BN layer has now 4 updates.
+ self.assertEqual(len(bn.updates), 4)
+
+ # Let's create a model from x2 to y2.
+ model = keras.models.Model(x2, y2)
+
+ # The model does not list all updates from its underlying layers,
+ # but only the updates that are relevant to it. Updates created by layers
+ # outside of the model are discarded.
+ self.assertEqual(len(model.updates), 2)
+
+ # If you keep calling the model, you append to its updates, just like
+ # what happens for a layer.
+ x3 = keras.layers.Input(shape=(10,))
+ y3 = model(x3)
+ self.assertEqual(len(model.updates), 4)
+
+ # But if you call the inner BN layer independently, you don't affect
+ # the model's updates.
+ x4 = keras.layers.Input(shape=(10,))
+ _ = bn(x4)
+ self.assertEqual(len(model.updates), 4)
+ ```
+
+ Returns:
+ A list of update ops.
+ """
+ if context.executing_eagerly():
+ return []
+
+ if not self.trainable and not self.stateful:
+ return []
+
+ updates = []
+ for layer in self.layers:
+ updates += layer.updates
+
+ # `updates` might contain irrelevant updates, so it needs to be filtered
+ # with respect to inputs the model has been called on.
+ if self.inputs:
+ relevant_inputs = self.inputs[:]
+ else:
+ relevant_inputs = []
+ for i in range(1, len(self._inbound_nodes)):
+ inputs = self.get_input_at(i)
+ if isinstance(inputs, list):
+ relevant_inputs += inputs
+ else:
+ relevant_inputs.append(inputs)
+ reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, updates)
+ relevant_conditional_updates = [x for x in updates if x in reachable]
+ unconditional_updates = [
+ x for x in updates if x._unconditional_update] # pylint: disable=protected-access
+ # A layer could be used multiple times in a nested structure,
+ # so the updates list must be de-duped.
+ return list(set(
+ relevant_conditional_updates + unconditional_updates + self._updates))
+
+ @property
+ def losses(self):
+ """Retrieves the network's losses.
+
+ Will only include losses that are either
+ unconditional, or conditional on inputs to this model
+ (e.g. will not include losses that depend on tensors
+ that aren't inputs to this model).
+
+ Returns:
+ A list of loss tensors.
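+
+ For example, an activity regularizer on an inner layer contributes a loss
+ that is conditional on the model's inputs (a sketch with arbitrary shapes):
+
+ ```python
+ x = keras.layers.Input(shape=(10,))
+ y = keras.layers.Dense(3, activity_regularizer=keras.regularizers.l2(0.01))(x)
+ model = keras.models.Model(x, y)
+ # `model.losses` now contains the activity-regularization loss tensor.
+ ```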
+ """
+ losses = []
+ for layer in self.layers:
+ losses += layer.losses
+ if context.executing_eagerly():
+ return losses
+
+ if self.inputs:
+ relevant_inputs = self.inputs[:]
+ else:
+ relevant_inputs = []
+ for i in range(1, len(self._inbound_nodes)):
+ inputs = self.get_input_at(i)
+ if isinstance(inputs, list):
+ relevant_inputs += inputs
+ else:
+ relevant_inputs.append(inputs)
+ reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, losses)
+ relevant_conditional_losses = [x for x in losses if x in reachable]
+ unconditional_losses = [
+ x for x in losses if x._unconditional_loss] # pylint: disable=protected-access
+ return list(set(
+ relevant_conditional_losses + unconditional_losses + self._losses))
+
+ @property
+ def trainable_weights(self):
+ if not self.trainable:
+ return []
+ weights = []
+ for layer in self.layers:
+ weights += layer.trainable_weights
+ return weights
+
+ @property
+ def non_trainable_weights(self):
+ weights = []
+ for layer in self.layers:
+ weights += layer.non_trainable_weights
+ if not self.trainable:
+ trainable_weights = []
+ for layer in self.layers:
+ trainable_weights += layer.trainable_weights
+ return trainable_weights + weights
+ return weights
+
+ @property
+ def input_spec(self):
+ """Gets the network's input specs.
+
+ Returns:
+ A list of `InputSpec` instances (one per input to the model)
+ or a single instance if the model has only one input.
+ """
+ # If not a graph network, can't assume anything.
+ if not self._is_graph_network:
+ return None
+
+ specs = []
+ for layer in self._input_layers:
+ if layer.input_spec is None:
+ specs.append(None)
+ else:
+ if not isinstance(layer.input_spec, list):
+ raise TypeError('Layer ' + layer.name +
+ ' has an input_spec attribute that '
+ 'is not a list. We expect a list. '
+ 'Found input_spec = ' + str(layer.input_spec))
+ specs += layer.input_spec
+ if len(specs) == 1:
+ return specs[0]
+ return specs
+
+ def call(self, inputs, training=None, mask=None):
+ """Calls the model on new inputs.
+
+ In this case `call` just reapplies
+ all ops in the graph to the new inputs
+ (e.g. build a new computational graph from the provided inputs).
+
+ Arguments:
+ inputs: A tensor or list of tensors.
+ training: Boolean or boolean scalar tensor, indicating whether to run
+ the `Network` in training mode or inference mode.
+ mask: A mask or list of masks. A mask can be
+ either a tensor or None (no mask).
+
+ Returns:
+ A tensor if there is a single output, or
+ a list of tensors if there is more than one output.
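+
+ Because `call` only reapplies existing ops, a graph network can itself be
+ used like a layer (a sketch with arbitrary shapes):
+
+ ```python
+ x = keras.layers.Input(shape=(10,))
+ y = keras.layers.Dense(5)(x)
+ model = keras.models.Model(x, y)
+
+ new_x = keras.layers.Input(shape=(10,))
+ new_y = model(new_x)  # Reuses the weights created above.
+ ```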
+ """
+ inputs = nest.flatten(inputs)
+ if mask is None:
+ masks = [None for _ in range(len(inputs))]
+ else:
+ masks = nest.flatten(mask)
+
+ if not context.executing_eagerly():
+ # Try to retrieve cached outputs if the layer has already been called
+ # on these exact inputs.
+ cache_key = (generic_utils.object_list_uid(inputs)
+ + '_' + generic_utils.object_list_uid(masks))
+ if cache_key in self._output_tensor_cache:
+ # Cache hit.
+ return self._output_tensor_cache[cache_key]
+ # Actually apply the network graph to the new inputs.
+ outputs, _ = self._run_internal_graph(inputs,
+ training=training,
+ mask=masks)
+ return outputs
+
+ def compute_output_shape(self, input_shape):
+ if not self._is_graph_network:
+ raise NotImplementedError
+
+ if isinstance(input_shape, list):
+ input_shapes = []
+ for shape in input_shape:
+ if shape is not None:
+ input_shapes.append(tuple(tensor_shape.TensorShape(shape).as_list()))
+ else:
+ input_shapes.append(None)
+ else:
+ if input_shape is not None:
+ input_shapes = [tuple(tensor_shape.TensorShape(input_shape).as_list())]
+ else:
+ input_shapes = [None]
+
+ if len(input_shapes) != len(self._input_layers):
+ raise ValueError('Invalid input_shape argument ' + str(input_shape) +
+ ': model has ' + str(len(self._input_layers)) +
+ ' tensor inputs.')
+
+ cache_key = generic_utils.object_list_uid(input_shapes)
+ if cache_key not in self._output_shape_cache:
+ # Cache miss. We have to run the network graph manually (recursive calls
+ # to `compute_output_shape`).
+ layers_to_output_shapes = {}
+ for i in range(len(input_shapes)):
+ layer = self._input_layers[i]
+ input_shape = input_shapes[i]
+ # It's an input layer: then `compute_output_shape` is identity,
+ # and there is only one node and one tensor output.
+ shape_key = layer.name + '_0_0'
+ layers_to_output_shapes[shape_key] = input_shape
+
+ depth_keys = list(self._nodes_by_depth.keys())
+ depth_keys.sort(reverse=True)
+ # Iterate over nodes, by depth level.
+ if len(depth_keys) > 1:
+ for depth in depth_keys:
+ nodes = self._nodes_by_depth[depth]
+ for node in nodes:
+ # This is always a single layer, never a list.
+ layer = node.outbound_layer
+ if layer in self._input_layers:
+ # We've already covered the input layers
+ # a few lines above.
+ continue
+ # Potentially redundant list,
+ # same size as node.input_tensors.
+ input_shapes = []
+ for j in range(len(node.inbound_layers)):
+ inbound_layer = node.inbound_layers[j]
+ node_index = node.node_indices[j]
+ tensor_index = node.tensor_indices[j]
+ shape_key = inbound_layer.name + '_%s_%s' % (node_index,
+ tensor_index)
+ input_shape = layers_to_output_shapes[shape_key]
+ input_shapes.append(input_shape)
+
+ if len(input_shapes) == 1:
+ output_shape = layer.compute_output_shape(input_shapes[0])
+ else:
+ output_shape = layer.compute_output_shape(input_shapes)
+ if isinstance(output_shape, list):
+ output_shapes = [
+ tuple(tensor_shape.TensorShape(shape).as_list())
+ for shape in output_shape
+ ]
+ else:
+ output_shapes = [
+ tuple(tensor_shape.TensorShape(output_shape).as_list())
+ ]
+
+ node_index = layer._inbound_nodes.index(node) # pylint: disable=protected-access
+ for j in range(len(output_shapes)):
+ shape_key = layer.name + '_%s_%s' % (node_index, j)
+ layers_to_output_shapes[shape_key] = output_shapes[j]
+
+ # Read final output shapes from layers_to_output_shapes.
+ output_shapes = []
+ for i in range(len(self._output_layers)):
+ layer, node_index, tensor_index = self._output_coordinates[i]
+ shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
+ output_shapes.append(layers_to_output_shapes[shape_key])
+ # Store in cache.
+ self._output_shape_cache[cache_key] = output_shapes
+ else:
+ # Cache hit.
+ output_shapes = self._output_shape_cache[cache_key]
+
+ if isinstance(output_shapes, list):
+ if len(output_shapes) == 1:
+ return tensor_shape.TensorShape(output_shapes[0])
+ else:
+ return [tensor_shape.TensorShape(shape) for shape in output_shapes]
+ else:
+ return tensor_shape.TensorShape(output_shapes)
+
+ def _run_internal_graph(self, inputs, training=None, mask=None):
+ """Computes output tensors for new inputs.
+
+ # Note:
+ - Expects `inputs` to be a list (potentially with 1 element).
+ - Can be run on non-Keras tensors.
+
+ Arguments:
+ inputs: List of tensors
+ training: Boolean learning phase.
+ mask: List of masks (tensors or None).
+
+ Returns:
+ Two lists: output_tensors, output_masks.
+ """
+ # Note: masking support is relevant mainly for Keras.
+ # It cannot be factored out without having to fully reimplement the network
+ # calling logic on the Keras side. We choose to incorporate it in
+ # Network because 1) it may be useful to fully support in tf.layers in
+ # the future and 2) Keras is a major user of Network. If you don't
+ # use masking, it does not interfere with regular behavior at all and you
+ # can ignore it.
+ if mask is None:
+ masks = [None for _ in range(len(inputs))]
+ else:
+ masks = mask
+
+ # Dictionary mapping reference tensors to tuples
+ # (computed tensor, computed mask).
+ # We assume a 1:1 mapping from tensor to mask.
+ # TODO(fchollet): raise exception when a `.compute_mask()` call
+ # does not return a list the same size as `call`
+ tensor_map = {}
+ for x, y, mask in zip(self.inputs, inputs, masks):
+ tensor_map[str(id(x))] = (y, mask)
+
+ depth_keys = list(self._nodes_by_depth.keys())
+ depth_keys.sort(reverse=True)
+ for depth in depth_keys:
+ nodes = self._nodes_by_depth[depth]
+ for node in nodes:
+ # This is always a single layer, never a list.
+ layer = node.outbound_layer
+ reference_input_tensors = node.input_tensors
+ reference_output_tensors = node.output_tensors
+
+ # If all previous input tensors are available in tensor_map,
+ # then call node.inbound_layer on them.
+ computed_data = [] # List of tuples (input, mask).
+ for x in reference_input_tensors:
+ if str(id(x)) in tensor_map:
+ computed_data.append(tensor_map[str(id(x))])
+
+ if len(computed_data) == len(reference_input_tensors):
+ # Call layer (reapplying ops to new inputs).
+ with ops.name_scope(layer.name):
+ if node.arguments:
+ kwargs = node.arguments
+ else:
+ kwargs = {}
+ if len(computed_data) == 1:
+ computed_tensor, computed_mask = computed_data[0]
+ # Ensure mask propagation if applicable.
+ if 'mask' in tf_inspect.getargspec(layer.call).args:
+ kwargs.setdefault('mask', computed_mask)
+ if 'training' in tf_inspect.getargspec(layer.call).args:
+ kwargs.setdefault('training', training)
+
+ output_tensors = nest.flatten(
+ layer.call(computed_tensor, **kwargs))
+ if hasattr(layer, 'compute_mask'):
+ output_masks = layer.compute_mask(computed_tensor,
+ computed_mask)
+ if output_masks is None:
+ output_masks = [None for _ in output_tensors]
+ else:
+ output_masks = nest.flatten(output_masks)
+ else:
+ output_masks = [None for _ in output_tensors]
+ computed_tensors = [computed_tensor]
+ computed_masks = [computed_mask]
+ else:
+ computed_tensors = [x[0] for x in computed_data]
+ computed_masks = [x[1] for x in computed_data]
+ if 'mask' in tf_inspect.getargspec(layer.call).args:
+ kwargs.setdefault('mask', computed_masks)
+ if 'training' in tf_inspect.getargspec(layer.call).args:
+ kwargs.setdefault('training', training)
+
+ output_tensors = nest.flatten(
+ layer.call(computed_tensors, **kwargs))
+
+ if hasattr(layer, 'compute_mask'):
+ output_masks = layer.compute_mask(computed_tensors,
+ computed_masks)
+ if output_masks is None:
+ output_masks = [None for _ in output_tensors]
+ else:
+ output_masks = nest.flatten(output_masks)
+ else:
+ output_masks = [None for _ in output_tensors]
+
+ if not context.executing_eagerly():
+ if layer.activity_regularizer is not None:
+ regularization_losses = [
+ layer.activity_regularizer(x) for x in output_tensors
+ ]
+ # Apply activity regularizer if any:
+ layer.add_loss(regularization_losses, computed_tensors)
+
+ # Update tensor_map.
+ for x, y, mask in zip(reference_output_tensors, output_tensors,
+ output_masks):
+ tensor_map[str(id(x))] = (y, mask)
+
+ output_tensors = []
+ output_masks = []
+ output_shapes = []
+ for x in self.outputs:
+ assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x)
+ tensor, mask = tensor_map[str(id(x))]
+ output_shapes.append(backend.int_shape(x))
+ output_tensors.append(tensor)
+ output_masks.append(mask)
+
+ if len(output_tensors) == 1:
+ output_tensors = output_tensors[0]
+ if output_shapes is not None:
+ output_shapes = output_shapes[0]
+ if output_masks is not None:
+ output_masks = output_masks[0]
+
+ if not context.executing_eagerly():
+ # Update cache;
+ # keys are based on ids of input tensors and input masks.
+ cache_key = (generic_utils.object_list_uid(inputs)
+ + '_' + generic_utils.object_list_uid(masks))
+ self._output_tensor_cache[cache_key] = output_tensors
+ self._output_mask_cache[cache_key] = output_masks
+
+ if output_shapes is not None:
+ input_shapes = [backend.int_shape(x) for x in inputs]
+ cache_key = generic_utils.object_list_uid(input_shapes)
+ self._output_shape_cache[cache_key] = output_shapes
+
+ return output_tensors, output_masks
+
+ def get_config(self):
+ if not self._is_graph_network:
+ raise NotImplementedError
+
+ config = {
+ 'name': self.name,
+ }
+ node_conversion_map = {}
+ for layer in self.layers:
+ if issubclass(layer.__class__, Network):
+ # Networks start with a pre-existing node
+ # linking their input to output.
+ kept_nodes = 1
+ else:
+ kept_nodes = 0
+ for original_node_index, node in enumerate(layer._inbound_nodes):
+ node_key = _make_node_key(layer.name, original_node_index)
+ if node_key in self._network_nodes:
+ node_conversion_map[node_key] = kept_nodes
+ kept_nodes += 1
+ layer_configs = []
+ for layer in self.layers: # From the earliest layers on.
+ layer_class_name = layer.__class__.__name__
+ layer_config = layer.get_config()
+ filtered_inbound_nodes = []
+ for original_node_index, node in enumerate(layer._inbound_nodes):
+ node_key = _make_node_key(layer.name, original_node_index)
+ if node_key in self._network_nodes:
+ # The node is relevant to the model:
+ # add to filtered_inbound_nodes.
+ if node.arguments:
+ try:
+ json.dumps(node.arguments)
+ kwargs = node.arguments
+ except TypeError:
+ logging.warning(
+ 'Layer ' + layer.name +
+ ' was passed non-serializable keyword arguments: ' +
+ str(node.arguments) + '. They will not be included '
+ 'in the serialized model (and thus will be missing '
+ 'at deserialization time).')
+ kwargs = {}
+ else:
+ kwargs = {}
+ if node.inbound_layers:
+ node_data = []
+ for i in range(len(node.inbound_layers)):
+ inbound_layer = node.inbound_layers[i]
+ node_index = node.node_indices[i]
+ tensor_index = node.tensor_indices[i]
+ node_key = _make_node_key(inbound_layer.name, node_index)
+ new_node_index = node_conversion_map.get(node_key, 0)
+ node_data.append(
+ [inbound_layer.name, new_node_index, tensor_index, kwargs])
+ filtered_inbound_nodes.append(node_data)
+ layer_configs.append({
+ 'name': layer.name,
+ 'class_name': layer_class_name,
+ 'config': layer_config,
+ 'inbound_nodes': filtered_inbound_nodes,
+ })
+ config['layers'] = layer_configs
+
+ # Gather info about inputs and outputs.
+ model_inputs = []
+ for i in range(len(self._input_layers)):
+ layer, node_index, tensor_index = self._input_coordinates[i]
+ node_key = _make_node_key(layer.name, node_index)
+ if node_key not in self._network_nodes:
+ continue
+ new_node_index = node_conversion_map[node_key]
+ model_inputs.append([layer.name, new_node_index, tensor_index])
+ config['input_layers'] = model_inputs
+ model_outputs = []
+ for i in range(len(self._output_layers)):
+ layer, node_index, tensor_index = self._output_coordinates[i]
+ node_key = _make_node_key(layer.name, node_index)
+ if node_key not in self._network_nodes:
+ continue
+ new_node_index = node_conversion_map[node_key]
+ model_outputs.append([layer.name, new_node_index, tensor_index])
+ config['output_layers'] = model_outputs
+ return copy.deepcopy(config)
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ """Instantiates a Model from its config (output of `get_config()`).
+
+ Arguments:
+ config: Model config dictionary.
+ custom_objects: Optional dictionary mapping names
+ (strings) to custom classes or functions to be
+ considered during deserialization.
+
+ Returns:
+ A model instance.
+
+ Raises:
+ ValueError: In case of improperly formatted config dict.
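+
+ A typical round trip, assuming `model` is an existing graph network:
+
+ ```python
+ config = model.get_config()
+ new_model = keras.models.Model.from_config(config)
+ ```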
+ """
+ # Layer instances created during
+ # the graph reconstruction process
+ created_layers = {}
+
+ # Dictionary mapping layer instances to
+ # node data that specifies a layer call.
+ # It acts as a queue that maintains any unprocessed
+ # layer call until it becomes possible to process it
+ # (i.e. until the input tensors to the call all exist).
+ unprocessed_nodes = {}
+
+ def add_unprocessed_node(layer, node_data):
+ if layer not in unprocessed_nodes:
+ unprocessed_nodes[layer] = [node_data]
+ else:
+ unprocessed_nodes[layer].append(node_data)
+
+ def process_node(layer, node_data):
+ """Deserialize a node.
+
+ Arguments:
+ layer: layer instance.
+ node_data: node config dict.
+
+ Raises:
+ ValueError: In case of improperly formatted `node_data` dict.
+ """
+ input_tensors = []
+ for input_data in node_data:
+ inbound_layer_name = input_data[0]
+ inbound_node_index = input_data[1]
+ inbound_tensor_index = input_data[2]
+ if len(input_data) == 3:
+ kwargs = {}
+ elif len(input_data) == 4:
+ kwargs = input_data[3]
+ else:
+ raise ValueError('Improperly formatted model config.')
+ if inbound_layer_name not in created_layers:
+ add_unprocessed_node(layer, node_data)
+ return
+ inbound_layer = created_layers[inbound_layer_name]
+ if len(inbound_layer._inbound_nodes) <= inbound_node_index:
+ add_unprocessed_node(layer, node_data)
+ return
+ inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
+ input_tensors.append(inbound_node.output_tensors[inbound_tensor_index])
+ # Call layer on its inputs, thus creating the node
+ # and building the layer if needed.
+ if input_tensors:
+ if len(input_tensors) == 1:
+ layer(input_tensors[0], **kwargs)
+ else:
+ layer(input_tensors, **kwargs)
+
+ def process_layer(layer_data):
+ """Deserializes a layer, then call it on appropriate inputs.
+
+ Arguments:
+ layer_data: layer config dict.
+
+ Raises:
+ ValueError: In case of improperly formatted `layer_data` dict.
+ """
+ layer_name = layer_data['name']
+
+ # Instantiate layer.
+ from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
+
+ layer = deserialize_layer(layer_data, custom_objects=custom_objects)
+ created_layers[layer_name] = layer
+
+ # Gather layer inputs.
+ inbound_nodes_data = layer_data['inbound_nodes']
+ for node_data in inbound_nodes_data:
+ # We don't process nodes (i.e. make layer calls)
+ # on the fly because the inbound node may not yet exist,
+ # in the case of a layer shared at different topological depths
+ # (e.g. a model such as A(B(A(B(x)))))
+ add_unprocessed_node(layer, node_data)
+
+ # First, we create all layers and enqueue nodes to be processed
+ for layer_data in config['layers']:
+ process_layer(layer_data)
+ # Then we process nodes in order of layer depth.
+ # Nodes that cannot yet be processed (if the inbound node
+ # does not yet exist) are re-enqueued, and the process
+ # is repeated until all nodes are processed.
+ while unprocessed_nodes:
+ for layer_data in config['layers']:
+ layer = created_layers[layer_data['name']]
+ if layer in unprocessed_nodes:
+ for node_data in unprocessed_nodes.pop(layer):
+ process_node(layer, node_data)
+
+ name = config.get('name')
+ input_tensors = []
+ output_tensors = []
+ for layer_data in config['input_layers']:
+ layer_name, node_index, tensor_index = layer_data
+ assert layer_name in created_layers
+ layer = created_layers[layer_name]
+ layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
+ input_tensors.append(layer_output_tensors[tensor_index])
+ for layer_data in config['output_layers']:
+ layer_name, node_index, tensor_index = layer_data
+ assert layer_name in created_layers
+ layer = created_layers[layer_name]
+ layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
+ output_tensors.append(layer_output_tensors[tensor_index])
+ return cls(inputs=input_tensors, outputs=output_tensors, name=name)
+
+ def save(self, filepath, overwrite=True, include_optimizer=True):
+ """Saves the model to a single HDF5 file.
+
+ The savefile includes:
+ - The model architecture, allowing you to re-instantiate the model.
+ - The model weights.
+ - The state of the optimizer, allowing you to resume training
+ exactly where you left off.
+
+ This allows you to save the entirety of the state of a model
+ in a single file.
+
+ Saved models can be reinstantiated via `keras.models.load_model`.
+ The model returned by `load_model`
+ is a compiled model ready to be used (unless the saved model
+ was never compiled in the first place).
+
+ Arguments:
+ filepath: String, path to the file to save the weights to.
+ overwrite: Whether to silently overwrite any existing file at the
+ target location, or provide the user with a manual prompt.
+ include_optimizer: If True, save optimizer's state together.
+
+ Example:
+
+ ```python
+ from keras.models import load_model
+
+ model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
+ del model # deletes the existing model
+
+ # returns a compiled model
+ # identical to the previous one
+ model = load_model('my_model.h5')
+ ```
+ """
+ if not self._is_graph_network:
+ raise NotImplementedError
+
+ from tensorflow.python.keras.models import save_model # pylint: disable=g-import-not-at-top
+ save_model(self, filepath, overwrite, include_optimizer)
+
+ def save_weights(self, filepath, overwrite=True, save_format=None):
+ """Saves all layer weights.
+
+ Either saves in HDF5 or in TensorFlow format based on the `save_format`
+ argument.
+
+ When saving in HDF5 format, the weight file has:
+ - `layer_names` (attribute), a list of strings
+ (ordered names of model layers).
+ - For every layer, a `group` named `layer.name`
+ - For every such layer group, a group attribute `weight_names`,
+ a list of strings
+ (ordered names of the weight tensors of the layer).
+ - For every weight in the layer, a dataset
+ storing the weight value, named after the weight tensor.
+
+ When saving in TensorFlow format, all objects referenced by the network are
+ saved in the same format as `tf.train.Checkpoint`, including any `Layer`
+ instances or `Optimizer` instances assigned to object attributes. For
+ networks constructed from inputs and outputs using `tf.keras.Model(inputs,
+ outputs)`, `Layer` instances used by the network are tracked/saved
+ automatically. For user-defined classes which inherit from `tf.keras.Model`,
+ `Layer` instances must be assigned to object attributes, typically in the
+ constructor. See the documentation of `tf.train.Checkpoint` and
+ `tf.keras.Model` for details.
+
+ Arguments:
+ filepath: String, path to the file to save the weights to. When saving
+ in TensorFlow format, this is the prefix used for checkpoint files
+ (multiple files are generated). Note that the '.h5' suffix causes
+ weights to be saved in HDF5 format.
+ overwrite: Whether to silently overwrite any existing file at the
+ target location, or provide the user with a manual prompt.
+ save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
+ '.keras' will default to HDF5 if `save_format` is `None`. Otherwise
+ `None` defaults to 'tf'.
+
+ Raises:
+ ImportError: If h5py is not available when attempting to save in HDF5
+ format.
+ ValueError: For invalid/unknown format arguments.
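+
+ For illustration (the file paths are placeholders):
+
+ ```python
+ # TensorFlow format: `filepath` is a prefix; several files are written.
+ model.save_weights('/tmp/my_weights')
+
+ # HDF5 format, selected via the '.h5' suffix (requires h5py).
+ model.save_weights('/tmp/my_weights.h5')
+ ```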
+ """
+ filepath_is_h5 = _is_hdf5_filepath(filepath)
+ if save_format is None:
+ if filepath_is_h5:
+ save_format = 'h5'
+ else:
+ save_format = 'tf'
+ else:
+ user_format = save_format.lower().strip()
+ if user_format in ('tensorflow', 'tf'):
+ save_format = 'tf'
+ elif user_format in ('hdf5', 'h5', 'keras'):
+ save_format = 'h5'
+ else:
+ raise ValueError(
+ 'Unknown format "%s". Was expecting one of {"tf", "h5"}.' % (
+ save_format,))
+ if save_format == 'tf' and filepath_is_h5:
+ raise ValueError(
+ ('save_weights got save_format="tf"/"tensorflow", but the '
+ 'filepath ("%s") looks like an HDF5 file. Omit the ".h5"/".keras" '
+ 'when saving in TensorFlow format.')
+ % filepath)
+
+ if save_format == 'h5' and h5py is None:
+ raise ImportError(
+ '`save_weights` requires h5py when saving in hdf5.')
+ if save_format == 'tf':
+ check_filepath = filepath + '.index'
+ else:
+ check_filepath = filepath
+ # If file exists and should not be overwritten:
+ if not overwrite and os.path.isfile(check_filepath):
+ proceed = ask_to_proceed_with_overwrite(check_filepath)
+ if not proceed:
+ return
+ if save_format == 'h5':
+ with h5py.File(filepath, 'w') as f:
+ saving.save_weights_to_hdf5_group(f, self.layers)
+ else:
+ self._checkpointable_saver.save(filepath)
+
+ def load_weights(self, filepath, by_name=False):
+ """Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
+
+ If `by_name` is False, weights are loaded based on the network's
+ topology. This means the architecture should be the same as when the weights
+ were saved. Note that layers that don't have weights are not taken into
+ account in the topological ordering, so adding or removing layers is fine as
+ long as they don't have weights.
+
+ If `by_name` is True, weights are loaded into layers only if they share the
+ same name. This is useful for fine-tuning or transfer-learning models where
+ some of the layers have changed.
+
+ Only topological loading (`by_name=False`) is supported when loading weights
+ from the TensorFlow format. Note that topological loading differs slightly
+ between TensorFlow and HDF5 formats for user-defined classes inheriting from
+ `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the
+ TensorFlow format loads based on the object-local names of attributes to
+ which layers are assigned in the `Model`'s constructor.
+
+ Arguments:
+ filepath: String, path to the weights file to load. For weight files in
+ TensorFlow format, this is the file prefix (the same as was passed
+ to `save_weights`).
+ by_name: Boolean, whether to load weights by name or by topological
+ order. Only topological loading is supported for weight files in
+ TensorFlow format.
+
+ Returns:
+ When loading a weight file in TensorFlow format, returns the same status
+ object as `tf.train.Checkpoint.restore`. When graph building, restore
+ ops are run automatically as soon as the network is built (on first call
+ for user-defined classes inheriting from `Model`, immediately if it is
+ already built).
+
+ When loading weights in HDF5 format, returns `None`.
+
+ Raises:
+ ImportError: If h5py is not available and the weight file is in HDF5
+ format.
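+
+ A sketch of both paths (the file paths are placeholders):
+
+ ```python
+ # TensorFlow format: returns a status object that can be checked.
+ status = model.load_weights('/tmp/my_weights')
+ status.assert_consumed()
+
+ # HDF5 format: returns None.
+ model.load_weights('/tmp/my_weights.h5')
+ ```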
+ """
+ if _is_hdf5_filepath(filepath):
+ save_format = 'h5'
+ else:
+ try:
+ pywrap_tensorflow.NewCheckpointReader(filepath)
+ save_format = 'tf'
+ except errors_impl.DataLossError:
+ # The checkpoint is not readable in TensorFlow format. Try HDF5.
+ save_format = 'h5'
+ if save_format == 'tf':
+ status = self._checkpointable_saver.restore(filepath)
+ if by_name:
+ raise NotImplementedError(
+ 'Weights may only be loaded based on topology into Models when '
+ 'loading TensorFlow-formatted weights (got by_name=True to '
+ 'load_weights).')
+ if not context.executing_eagerly():
+ finalizer = status.run_restore_ops
+ if self.built:
+ finalizer()
+ else:
+ # Hold on to this status object until the network is built (for
+ # subclassed Models). Then we'll run restore ops if necessary.
+ self._in_progress_restore_finalizer = finalizer
+ return status
+ if h5py is None:
+ raise ImportError(
+ '`load_weights` requires h5py when loading weights from HDF5.')
+ if not self._is_graph_network and not self.built:
+ raise NotImplementedError(
+ 'Unable to load weights saved in HDF5 format into a subclassed '
+ 'Model which has not created its variables yet. Call the Model '
+ 'first, then load the weights.')
+ with h5py.File(filepath, 'r') as f:
+ if 'layer_names' not in f.attrs and 'model_weights' in f:
+ f = f['model_weights']
+ if by_name:
+ saving.load_weights_from_hdf5_group_by_name(f, self.layers)
+ else:
+ saving.load_weights_from_hdf5_group(f, self.layers)
+
+ def _post_build_cleanup(self):
+ super(Network, self)._post_build_cleanup()
+ if self._in_progress_restore_finalizer is not None:
+ # Runs queued restore operations left over from load_weights when graph
+ # building.
+ self._in_progress_restore_finalizer()
+ self._in_progress_restore_finalizer = None
+
+ def _updated_config(self):
+ """Util shared between different serialization methods.
+
+ Returns:
+ Model config with Keras version information added.
+ """
+ from tensorflow.python.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top
+
+ config = self.get_config()
+ model_config = {
+ 'class_name': self.__class__.__name__,
+ 'config': config,
+ 'keras_version': keras_version,
+ 'backend': backend.backend()
+ }
+ return model_config
+
+ def to_json(self, **kwargs):
+ """Returns a JSON string containing the network configuration.
+
+ To load a network from a JSON save file, use
+ `keras.models.model_from_json(json_string, custom_objects={})`.
+
+ Arguments:
+ **kwargs: Additional keyword arguments
+ to be passed to `json.dumps()`.
+
+ Returns:
+ A JSON string.
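+
+ A typical round trip (a sketch):
+
+ ```python
+ json_string = model.to_json()
+ new_model = keras.models.model_from_json(json_string)
+ ```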
+ """
+ def get_json_type(obj):
+ # If obj is any numpy type
+ if type(obj).__module__ == np.__name__:
+ return obj.item()
+
+ # If obj is a python 'type'
+ if type(obj).__name__ == type.__name__:
+ return obj.__name__
+
+ raise TypeError('Not JSON Serializable:', obj)
+
+ model_config = self._updated_config()
+ return json.dumps(model_config, default=get_json_type, **kwargs)
+
+ def to_yaml(self, **kwargs):
+ """Returns a yaml string containing the network configuration.
+
+ To load a network from a yaml save file, use
+ `keras.models.model_from_yaml(yaml_string, custom_objects={})`.
+
+ `custom_objects` should be a dictionary mapping
+ the names of custom losses / layers / etc to the corresponding
+ functions / classes.
+
+ Arguments:
+ **kwargs: Additional keyword arguments
+ to be passed to `yaml.dump()`.
+
+ Returns:
+ A YAML string.
+
+ Raises:
+ ImportError: if yaml module is not found.
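+
+ A typical round trip (a sketch; requires the yaml module):
+
+ ```python
+ yaml_string = model.to_yaml()
+ new_model = keras.models.model_from_yaml(yaml_string)
+ ```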
+ """
+ if yaml is None:
+ raise ImportError('Requires yaml module installed.')
+ return yaml.dump(self._updated_config(), **kwargs)
+
+ def summary(self, line_length=None, positions=None, print_fn=None):
+ """Prints a string summary of the network.
+
+ Arguments:
+ line_length: Total length of printed lines
+ (e.g. set this to adapt the display to different
+ terminal window sizes).
+ positions: Relative or absolute positions of log elements
+ in each line. If not provided,
+ defaults to `[.33, .55, .67, 1.]`.
+ print_fn: Print function to use. Defaults to `print`.
+ It will be called on each line of the summary.
+ You can set it to a custom function
+ in order to capture the string summary.
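+
+ For example, to capture the summary as a string rather than printing it
+ (a sketch):
+
+ ```python
+ lines = []
+ model.summary(print_fn=lines.append)
+ summary_string = '\n'.join(lines)
+ ```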
+ """
+ print_layer_summary(self,
+ line_length=line_length,
+ positions=positions,
+ print_fn=print_fn)
+
+
+ def get_source_inputs(tensor, layer=None, node_index=None):
+ """Returns the list of input tensors necessary to compute `tensor`.
+
+ Output will always be a list of tensors
+ (potentially with 1 element).
+
+ Arguments:
+ tensor: The tensor to start from.
+ layer: Origin layer of the tensor. Will be
+ determined via tensor._keras_history if not provided.
+ node_index: Origin node index of the tensor.
+
+ Returns:
+ List of input tensors.
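+
+ For illustration (shapes are arbitrary):
+
+ ```python
+ x = keras.layers.Input(shape=(4,))
+ y = keras.layers.Dense(2)(x)
+ sources = get_source_inputs(y)
+ # `sources` is `[x]`: the single Input tensor feeding `y`.
+ ```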
+ """
+ if not hasattr(tensor, '_keras_history'):
+ return tensor
+
+ if layer is None or node_index:
+ layer, node_index, _ = tensor._keras_history
+ if not layer._inbound_nodes:
+ return [tensor]
+ else:
+ node = layer._inbound_nodes[node_index]
+ if not node.inbound_layers:
+ # Reached an Input layer, stop recursion.
+ return node.input_tensors
+ else:
+ source_tensors = []
+ for i in range(len(node.inbound_layers)):
+ x = node.input_tensors[i]
+ layer = node.inbound_layers[i]
+ node_index = node.node_indices[i]
+ previous_sources = get_source_inputs(x, layer, node_index)
+ # Avoid input redundancy.
+ for x in previous_sources:
+ if x not in source_tensors:
+ source_tensors.append(x)
+ return source_tensors
+
+
+ def _is_hdf5_filepath(filepath):
+ return filepath.endswith('.h5') or filepath.endswith('.keras')
+
+
+ def _make_node_key(layer_name, node_index):
+ return layer_name + '_ib-' + str(node_index)
+
+
+ def _map_graph_network(inputs, outputs):
+ """Validates a network's topology and gather its layers and nodes.
+
+ Arguments:
+ inputs: List of input tensors.
+ outputs: List of outputs tensors.
+
+ Returns:
+ A tuple `(nodes, nodes_by_depth, layers, layers_by_depth)`.
+ - nodes: list of Node instances.
+ - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
+ - layers: list of Layer instances.
+ - layers_by_depth: dict mapping ints (depth) to lists of layer instances.
+
+ Raises:
+ ValueError: In case the network is not valid (e.g. disconnected graph).
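+
+ A small illustration of the returned structures (shapes are arbitrary):
+
+ ```python
+ x = keras.layers.Input(shape=(3,))
+ y = keras.layers.Dense(2)(x)
+ nodes, nodes_by_depth, layers, layers_by_depth = _map_graph_network([x], [y])
+ # `layers` is ordered from inputs to outputs: [input_layer, dense_layer].
+ # Depth 1 holds the Input node; depth 0 holds the Dense node.
+ ```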
+ """
+ # Network_nodes: set of nodes included in the graph of layers
+ # (not all nodes included in the layers are relevant to the current graph).
+ network_nodes = set() # ids of all nodes relevant to the Network
+ nodes_depths = {} # dict {node: depth value}
+ layers_depths = {} # dict {layer: depth value}
+ layer_indices = {} # dict {layer: index in traversal}
+ nodes_in_decreasing_depth = []
+
+ def build_map(tensor,
+ finished_nodes,
+ nodes_in_progress,
+ layer,
+ node_index,
+ tensor_index):
+ """Builds a map of the graph of layers.
+
+ This recursively updates the map `layer_indices`,
+ the list `nodes_in_decreasing_depth` and the set `network_nodes`.
+
+ Arguments:
+ tensor: Some tensor in a graph.
+ finished_nodes: Set of nodes whose subgraphs have been traversed
+ completely. Useful to prevent duplicated work.
+ nodes_in_progress: Set of nodes that are currently active on the
+ recursion stack. Useful to detect cycles.
+ layer: Layer from which `tensor` comes. If not provided,
+ will be obtained from `tensor._keras_history`.
+ node_index: Node index from which `tensor` comes.
+ tensor_index: Tensor index from which `tensor` comes.
+
+ Raises:
+ ValueError: if a cycle is detected.
+ """
+ node = layer._inbound_nodes[node_index] # pylint: disable=protected-access
+
+ # Prevent cycles.
+ if node in nodes_in_progress:
+ raise ValueError('The tensor ' + str(tensor) + ' at layer "' +
+ layer.name + '" is part of a cycle.')
+
+ # Don't repeat work for shared subgraphs
+ if node in finished_nodes:
+ return
+
+ node_key = _make_node_key(layer.name, node_index)
+ # Update network_nodes.
+ network_nodes.add(node_key)
+
+ # Store the traversal order for layer sorting.
+ if layer not in layer_indices:
+ layer_indices[layer] = len(layer_indices)
+
+ nodes_in_progress.add(node)
+
+ # Propagate to all previous tensors connected to this node.
+ for i in range(len(node.inbound_layers)):
+ x = node.input_tensors[i]
+ layer = node.inbound_layers[i]
+ node_index = node.node_indices[i]
+ tensor_index = node.tensor_indices[i]
+ build_map(x, finished_nodes, nodes_in_progress, layer,
+ node_index, tensor_index)
+
+ finished_nodes.add(node)
+ nodes_in_progress.remove(node)
+ nodes_in_decreasing_depth.append(node)
+
+ finished_nodes = set()
+ nodes_in_progress = set()
+ for x in outputs:
+ layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access
+ build_map(x, finished_nodes, nodes_in_progress,
+ layer=layer,
+ node_index=node_index,
+ tensor_index=tensor_index)
+
+ for node in reversed(nodes_in_decreasing_depth):
+ # If the depth is not set, the node has no outbound nodes (depth 0).
+ depth = nodes_depths.setdefault(node, 0)
+
+ # Update the depth of the corresponding layer
+ previous_depth = layers_depths.get(node.outbound_layer, 0)
+ # If we've seen this layer before at a higher depth,
+ # we should use that depth instead of the node depth.
+ # This is necessary for shared layers that have inputs at different
+ # depth levels in the graph.
+ depth = max(depth, previous_depth)
+ layers_depths[node.outbound_layer] = depth
+ nodes_depths[node] = depth
+
+ # Update the depth of inbound nodes.
+ # The "depth" of a node is the max of the depths
+ # of all layers it is connected to.
+ for i in range(len(node.inbound_layers)):
+ inbound_layer = node.inbound_layers[i]
+ node_index = node.node_indices[i]
+ inbound_node = inbound_layer._inbound_nodes[node_index] # pylint: disable=protected-access
+ previous_depth = nodes_depths.get(inbound_node, 0)
+ nodes_depths[inbound_node] = max(depth + 1, previous_depth)
+
+ # Build a dict {depth: list of nodes with this depth}
+ nodes_by_depth = {}
+ for node, depth in nodes_depths.items():
+ if depth not in nodes_by_depth:
+ nodes_by_depth[depth] = []
+ nodes_by_depth[depth].append(node)
+
+ # Build a dict {depth: list of layers with this depth}
+ layers_by_depth = {}
+ for layer, depth in layers_depths.items():
+ if depth not in layers_by_depth:
+ layers_by_depth[depth] = []
+ layers_by_depth[depth].append(layer)
+
+ # Get sorted list of layer depths.
+ depth_keys = list(layers_by_depth.keys())
+ depth_keys.sort(reverse=True)
+
+ # Set self.layers and self._layers_by_depth.
+ layers = []
+ for depth in depth_keys:
+ layers_for_depth = layers_by_depth[depth]
+ # Network.layers needs to have a deterministic order:
+ # here we order them by traversal order.
+ layers_for_depth.sort(key=lambda x: layer_indices[x])
+ layers.extend(layers_for_depth)
+
+ # Get sorted list of node depths.
+ depth_keys = list(nodes_by_depth.keys())
+ depth_keys.sort(reverse=True)
+
+ # Check that all tensors required are computable.
+ # computable_tensors: all tensors in the graph
+ # that can be computed from the inputs provided.
+ computable_tensors = []
+ for x in inputs:
+ computable_tensors.append(x)
+
+ layers_with_complete_input = [] # To provide a better error msg.
+ for depth in depth_keys:
+ for node in nodes_by_depth[depth]:
+ layer = node.outbound_layer
+ if layer:
+ for x in node.input_tensors:
+ if x not in computable_tensors:
+ raise ValueError('Graph disconnected: '
+ 'cannot obtain value for tensor ' + str(x) +
+ ' at layer "' + layer.name + '". '
+ 'The following previous layers '
+ 'were accessed without issue: ' +
+ str(layers_with_complete_input))
+ for x in node.output_tensors:
+ computable_tensors.append(x)
+ layers_with_complete_input.append(layer.name)
+
+ # Ensure layer name uniqueness, which is crucial for serialization
+ # (since serialized nodes refer to layers by their name).
+ all_names = [layer.name for layer in layers]
+ for name in all_names:
+ if all_names.count(name) != 1:
+ raise ValueError('The name "' + name + '" is used ' +
+ str(all_names.count(name)) + ' times in the model. '
+ 'All layer names should be unique.')
+ return network_nodes, nodes_by_depth, layers, layers_by_depth
--- /dev/null
+ # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Tests for model saving."""
+
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+
+ import os
+ import shutil
+ import tempfile
+
+ from absl.testing import parameterized
+ import numpy as np
+
+ from tensorflow.python import keras
+ from tensorflow.python.eager import context
+ from tensorflow.python.framework import constant_op
+ from tensorflow.python.framework import dtypes
+ from tensorflow.python.framework import ops
+ from tensorflow.python.framework import test_util
+ from tensorflow.python.keras.engine import training
+ from tensorflow.python.ops import array_ops
+ from tensorflow.python.ops import random_ops
+ from tensorflow.python.platform import test
+ from tensorflow.python.training import training as training_module
+
+ try:
+ import h5py # pylint:disable=g-import-not-at-top
+ except ImportError:
+ h5py = None
+
+
+ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
+
+ def test_weight_loading(self):
+ with self.test_session():
+ a = keras.layers.Input(shape=(2,))
+ x = keras.layers.Dense(3)(a)
+ b = keras.layers.Dense(1)(x)
+ model = keras.models.Model(a, b)
+
+ x = np.random.random((3, 2))
+ ref_y = model.predict(x)
+ weights = model.get_weights()
+ model.set_weights(weights)
+ y = model.predict(x)
+ self.assertAllClose(ref_y, y)
+
+ with self.assertRaises(ValueError):
+ model.set_weights(weights[1:])
+ with self.assertRaises(ValueError):
+ model.set_weights(weights[::-1])
+
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ no_extension_path = os.path.join(temp_dir, 'test')
+ model.save_weights(no_extension_path, save_format='tf')
+ model.load_weights(no_extension_path)
+ y = model.predict(x)
+ self.assertAllClose(ref_y, y)
+
+ if h5py is None:
+ return # Skip rest of test if H5py isn't available.
+
+ h5_path = os.path.join(temp_dir, 'test.h5')
+ model.save_weights(h5_path)
+ model.load_weights(h5_path)
+ y = model.predict(x)
+ self.assertAllClose(ref_y, y)
+
+ model.load_weights(h5_path, by_name=True)
+ y = model.predict(x)
+ self.assertAllClose(ref_y, y)
+
+ model.save_weights(no_extension_path, save_format='hdf5')
+ model.load_weights(no_extension_path)
+ y = model.predict(x)
+ self.assertAllClose(ref_y, y)
+
+ def test_weight_preprocessing(self):
+ input_dim = 3
+ output_dim = 3
+ size = 2
+ cases = [
+ [
+ (keras.layers.Bidirectional(keras.layers.SimpleRNN(2))),
+ [np.random.random((2, 1)), np.random.random((2, 1))],
+ (None, 3, 2),
+ ],
+ [
+ (keras.layers.TimeDistributed(keras.layers.Dense(1))),
+ [np.random.random((2, 1)), np.random.random((1,))],
+ (None, 3, 2),
+ ],
+ [
+ (keras.layers.Conv1D(output_dim, size, use_bias=False)),
+ [np.random.random((output_dim, input_dim, size, 1))],
+ (None, 4, input_dim),
+ ],
+ [
+ (keras.layers.Conv2D(output_dim, size,
+ use_bias=False, data_format='channels_first')),
+ [np.random.random((output_dim, input_dim, size, size))],
+ (None, input_dim, 4, 4),
+ ],
+ [
+ (keras.layers.Conv2DTranspose(output_dim, size,
+ use_bias=False,
+ data_format='channels_first')),
+ [np.random.random((output_dim, input_dim, size, size))],
+ (None, input_dim, 4, 4),
+ ],
+ [
+ (keras.layers.Conv2DTranspose(output_dim, size,
+ use_bias=False,
+ data_format='channels_last')),
+ [np.random.random((size, size, input_dim, output_dim))],
+ (None, 4, 4, input_dim),
+ ],
+ [
+ (keras.layers.Conv3D(output_dim, size,
+ use_bias=False, data_format='channels_first')),
+ [np.random.random((output_dim, input_dim, size, size, size))],
+ (None, input_dim, 4, 4, 4),
+ ],
+ [
+ (keras.layers.GRU(output_dim)),
+ [np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,)),
+ np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,)),
+ np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,))],
+ (None, 4, input_dim),
+ ],
+ [
+ (keras.layers.LSTM(output_dim)),
+ [np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,)),
+ np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,)),
+ np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,)),
+ np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,))],
+ (None, 4, input_dim),
+ ],
+ ]
+ for layer, weights, input_shape in cases:
+ layer.build(input_shape)
+ _ = keras.engine.saving.preprocess_weights_for_loading(
+ layer, weights, original_keras_version='1')
+
+ model = keras.models.Sequential([keras.layers.Dense(2, input_dim=2)])
+ _ = keras.engine.saving.preprocess_weights_for_loading(
+ model, model.weights, original_keras_version='1')
+
+ x = keras.Input((2,))
+ y = keras.layers.Dense(2)(x)
+ model = keras.models.Model(x, y)
+ _ = keras.engine.saving.preprocess_weights_for_loading(
+ model, model.weights, original_keras_version='1')
+
+ @parameterized.named_parameters(
+ ('gru', keras.layers.GRU, {
+ 'units': 2,
+ 'input_shape': (3, 5)
+ }),
+ ('gru_with_reset_after', keras.layers.GRU, {
+ 'units': 2,
+ 'input_shape': (3, 5),
+ 'reset_after': True
+ }),
+ ('lstm', keras.layers.LSTM, {
+ 'units': 2,
+ 'input_shape': (3, 5)
+ }),
+ ('cudnngru', keras.layers.CuDNNGRU, {
+ 'units': 2,
+ 'input_shape': (3, 5)
+ }),
+ ('cudnnlstm', keras.layers.CuDNNLSTM, {
+ 'units': 2,
+ 'input_shape': (3, 5)
+ }))
+ def test_preprocess_weights_for_loading_rnn_should_be_idempotent(
+ self, layer_class, layer_args):
+ with self.test_session():
+ layer = layer_class(**layer_args)
+ layer.build(input_shape=layer_args.get('input_shape'))
+ weights1 = layer.get_weights()
+ weights2 = keras.engine.saving.preprocess_weights_for_loading(
+ layer, weights1)
+ _ = [
+ self.assertAllClose(x, y, rtol=1e-05)
+ for (x, y) in zip(weights1, weights2)
+ ]
+
+ def test_sequential_weight_loading(self):
+ if h5py is None:
+ return
+
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+ h5_path = os.path.join(temp_dir, 'test.h5')
+
+ num_hidden = 5
+ input_dim = 3
+ batch_size = 5
+ num_classes = 2
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
+ model.add(keras.layers.Dense(num_classes))
+
+ x = np.random.random((batch_size, input_dim))
+ ref_y = model.predict(x)
+
+ model.save_weights(h5_path)
+
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
+ model.add(keras.layers.Dense(num_classes))
+ model.load_weights(h5_path)
+ y = model.predict(x)
+
+ self.assertAllClose(y, ref_y)
+
+
+ class TestWholeModelSaving(test.TestCase):
+
+ def test_sequential_model_saving(self):
+ if h5py is None:
+ self.skipTest('h5py required to run this test')
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model.compile(loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+
+ out = model.predict(x)
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+
+ new_model = keras.models.load_model(fname)
+ os.close(fd)
+ os.remove(fname)
+
+ out2 = new_model.predict(x)
+ self.assertAllClose(out, out2, atol=1e-05)
+
+ # test that new updates are the same with both models
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+ new_model.train_on_batch(x, y)
+ out = model.predict(x)
+ out2 = new_model.predict(x)
+ self.assertAllClose(out, out2, atol=1e-05)
+
+ def test_sequential_model_saving_2(self):
+ if h5py is None:
+ self.skipTest('h5py required to run this test')
+
+ with self.test_session():
+ # test with custom optimizer, loss
+
+ class CustomOp(keras.optimizers.RMSprop):
+ pass
+
+ def custom_loss(y_true, y_pred):
+ return keras.losses.mse(y_true, y_pred)
+
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.Dense(3))
+ model.compile(loss=custom_loss, optimizer=CustomOp(), metrics=['acc'])
+
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+ model.train_on_batch(x, y)
+
+ out = model.predict(x)
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+
+ model = keras.models.load_model(
+ fname,
+ custom_objects={'CustomOp': CustomOp,
+ 'custom_loss': custom_loss})
+ os.close(fd)
+ os.remove(fname)
+
+ out2 = model.predict(x)
+ self.assertAllClose(out, out2, atol=1e-05)
+
+ def test_functional_model_saving(self):
+ if h5py is None:
+ self.skipTest('h5py required to run this test')
+
+ with self.test_session():
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ output = keras.layers.Dense(3)(x)
+
+ model = keras.models.Model(inputs, output)
+ model.compile(loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[keras.metrics.categorical_accuracy])
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+ model.train_on_batch(x, y)
+
+ out = model.predict(x)
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+
+ model = keras.models.load_model(fname)
+ os.close(fd)
+ os.remove(fname)
+
+ out2 = model.predict(x)
+ self.assertAllClose(out, out2, atol=1e-05)
+
+ def test_saving_without_compilation(self):
+ if h5py is None:
+ self.skipTest('h5py required to run this test')
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.Dense(3))
+ model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
+
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+ model = keras.models.load_model(fname)
+ os.close(fd)
+ os.remove(fname)
+
+ def test_saving_with_tf_optimizer(self):
+ if h5py is None:
+ self.skipTest('h5py required to run this test')
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.Dense(3))
+ model.compile(loss='mse',
+ optimizer=training_module.AdadeltaOptimizer(0.1),
+ metrics=['acc'])
+
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+ model = keras.models.load_model(fname)
+ os.close(fd)
+ os.remove(fname)
+
+ def test_saving_right_after_compilation(self):
+ if h5py is None:
+ self.skipTest('h5py required to run this test')
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.Dense(3))
+ model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
+ model._make_train_function()
+
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+ model = keras.models.load_model(fname)
+ os.close(fd)
+ os.remove(fname)
+
+ def test_saving_lambda_numpy_array_arguments(self):
+ if h5py is None:
+ self.skipTest('h5py required to run this test')
+
+ mean = np.random.random((4, 2, 3))
+ std = np.abs(np.random.random((4, 2, 3))) + 1e-5
+ inputs = keras.layers.Input(shape=(4, 2, 3))
+ output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std,
+ arguments={'mu': mean, 'std': std})(inputs)
+ model = keras.models.Model(inputs, output)
+ model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
+
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+
+ model = keras.models.load_model(fname)
+ os.close(fd)
+ os.remove(fname)
+
+ self.assertAllClose(mean, model.layers[1].arguments['mu'])
+ self.assertAllClose(std, model.layers[1].arguments['std'])
+
+ def test_saving_model_with_long_layer_names(self):
+ if h5py is None:
+ self.skipTest('h5py required to run this test')
+
+ with self.test_session():
+ # This layer name will make the `layer_names` HDF5 attribute blow
+ # out of proportion. Note that it fits into the internal HDF5
+ # attribute memory limit on its own, but because h5py converts
+ # the list of layer names into a numpy array, which uses the same
+ # amount of memory for every item, it increases the memory
+ # requirements substantially.
+ x = keras.Input(shape=(2,), name='input_' + ('x' * (2**15)))
+ f = x
+ for i in range(4):
+ f = keras.layers.Dense(2, name='dense_%d' % (i,))(f)
+ model = keras.Model(inputs=[x], outputs=[f])
+ model.compile(loss='mse', optimizer='adam', metrics=['acc'])
+
+ x = np.random.random((1, 2))
+ y = np.random.random((1, 2))
+ model.train_on_batch(x, y)
+ out = model.predict(x)
+
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+ model = keras.models.load_model(fname)
+
+ # Check that the HDF5 file contains a chunked array
+ # of layer names.
+ with h5py.File(fname, 'r') as h5file:
+ num_names_arrays = len([attr for attr in h5file['model_weights'].attrs
+ if attr.startswith('layer_names')])
+ # The chunking of layer names array should have happened.
+ self.assertGreater(num_names_arrays, 0)
+ out2 = model.predict(x)
+ self.assertAllClose(out, out2, atol=1e-05)
+
+ # Cleanup
+ os.close(fd)
+ os.remove(fname)
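The comment at the top of this test mentions that h5py stores a list of strings as a fixed-width numpy array, so one very long name inflates every entry. A small, hypothetical illustration of why that overflows the ~64KB HDF5 attribute header limit:

    import numpy as np

    names = ['dense_0', 'dense_1', 'input_' + 'x' * (2 ** 15)]
    arr = np.array(names, dtype='S')  # fixed-width byte strings
    print(arr.dtype.itemsize)         # every item is padded to ~32 KB
    print(arr.nbytes)                 # ~96 KB total, past the attribute limit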
+
+ def test_saving_model_with_long_weights_names(self):
+ if h5py is None:
+ self.skipTest('h5py required to run this test')
+
+ with self.test_session():
+ x = keras.Input(shape=(2,), name='nested_model_input')
+ f = x
+ for i in range(4):
+ f = keras.layers.Dense(2, name='nested_model_dense_%d' % (i,))(f)
+ # This layer name will make the `weights_name`
+ # HDF5 attribute blow out of proportion.
+ f = keras.layers.Dense(2, name='nested_model_output' + ('x' * (2**14)))(f)
+ nested_model = keras.Model(inputs=[x], outputs=[f], name='nested_model')
+
+ x = keras.Input(shape=(2,), name='outer_model_input')
+ f = nested_model(x)
+ f = keras.layers.Dense(2, name='outer_model_output')(f)
+
+ model = keras.Model(inputs=[x], outputs=[f])
+ model.compile(loss='mse', optimizer='adam', metrics=['acc'])
+
+ x = np.random.random((1, 2))
+ y = np.random.random((1, 2))
+ model.train_on_batch(x, y)
+ out = model.predict(x)
+
+ fd, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+ model = keras.models.load_model(fname)
+
+ # Check that the HDF5 file contains a chunked array
+ # of weight names.
+ with h5py.File(fname, 'r') as h5file:
+ num_weight_arrays = len(
+ [attr for attr in h5file['model_weights']['nested_model'].attrs
+ if attr.startswith('weight_names')])
+ # The chunking of weight names array should have happened.
+ self.assertGreater(num_weight_arrays, 0)
+ out2 = model.predict(x)
+ self.assertAllClose(out, out2, atol=1e-05)
+
+ # Cleanup
+ os.close(fd)
+ os.remove(fname)
+
+ def test_model_saving_to_pre_created_h5py_file(self):
+ if h5py is None:
+ self.skipTest('h5py required to run this test')
+
+ with self.test_session():
+ inputs = keras.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ outputs = keras.layers.Dense(3)(x)
+
+ model = keras.Model(inputs, outputs)
+ model.compile(loss=keras.losses.MSE,
+ optimizer=keras.optimizers.Adam(),
+ metrics=[keras.metrics.categorical_accuracy])
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+ model.train_on_batch(x, y)
+
+ out = model.predict(x)
+ fd, fname = tempfile.mkstemp('.h5')
+ with h5py.File(fname, mode='r+') as h5file:
+ keras.models.save_model(model, h5file)
+ loaded_model = keras.models.load_model(h5file)
+ out2 = loaded_model.predict(x)
+ self.assertAllClose(out, out2, atol=1e-05)
+
+ # Test non-default options in h5
+ with h5py.File('_', driver='core',
+ backing_store=False) as h5file:
+ keras.models.save_model(model, h5file)
+ loaded_model = keras.models.load_model(h5file)
+ out2 = loaded_model.predict(x)
+ self.assertAllClose(out, out2, atol=1e-05)
+
+ # Cleanup
+ os.close(fd)
+ os.remove(fname)
+
+
+ class SubclassedModel(training.Model):
+
+ def __init__(self):
+ super(SubclassedModel, self).__init__()
+ self.x_layer = keras.layers.Dense(3)
+ self.b_layer = keras.layers.Dense(1)
+
+ def call(self, a):
+ return self.b_layer(self.x_layer(a))
+
+
+ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_tensorflow_format_overwrite(self):
+ with self.test_session() as session:
+ model = SubclassedModel()
+ temp_dir = self.get_temp_dir()
+ prefix = os.path.join(temp_dir, 'ckpt')
+
+ x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
+ executing_eagerly = context.executing_eagerly()
+ model(x) # pylint: disable=not-callable
+ if not executing_eagerly:
+ session.run([v.initializer for v in model.variables])
+ model.save_weights(prefix, save_format='tensorflow')
+ model.save_weights(prefix, save_format='tensorflow', overwrite=True)
+ with self.assertRaises(EOFError):
+ # Indirectly tests that the user is prompted
+ model.save_weights(prefix, save_format='tensorflow', overwrite=False)
+
+ def test_no_graph_pollution(self):
+ with context.graph_mode():
+ graph = ops.Graph()
+ with graph.as_default(), self.test_session(graph) as session:
+ model = SubclassedModel()
+ temp_dir = self.get_temp_dir()
+ prefix = os.path.join(temp_dir, 'ckpt')
+
+ x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
+ model(x) # pylint: disable=not-callable
+ session.run([v.initializer for v in model.variables])
+ model.save_weights(prefix, save_format='tensorflow')
+ op_count = len(graph.get_operations())
+ model.save_weights(prefix, save_format='tensorflow')
+ self.assertEqual(len(graph.get_operations()), op_count)
+
+ model.load_weights(prefix)
+ op_count = len(graph.get_operations())
+ model.load_weights(prefix)
+ self.assertEqual(len(graph.get_operations()), op_count)
+
+ def _weight_loading_test_template(self, make_model_fn):
+ with self.test_session() as session:
+ model = make_model_fn()
+ temp_dir = self.get_temp_dir()
+ prefix = os.path.join(temp_dir, 'ckpt')
+
+ x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
+ executing_eagerly = context.executing_eagerly()
+ ref_y_tensor = model(x)
+ if not executing_eagerly:
+ session.run([v.initializer for v in model.variables])
+ ref_y = self.evaluate(ref_y_tensor)
+ model.save_weights(prefix, save_format='tf')
+ for v in model.variables:
+ self.evaluate(
+ v.assign(random_ops.random_normal(shape=array_ops.shape(v))))
+
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ model.load_weights(prefix)
+ y = self.evaluate(model(x))
+ self.assertAllClose(ref_y, y)
+
+ # Test restore-on-create if this is a subclassed Model (graph Networks
+ # will have already created their variables).
+ load_model = make_model_fn()
+ load_model.load_weights(prefix)
+ restore_on_create_y_tensor = load_model(x)
+ restore_on_create_y = self.evaluate(restore_on_create_y_tensor)
+ self.assertAllClose(ref_y, restore_on_create_y)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_weight_loading_graph_model(self):
+ def _make_graph_model():
+ a = keras.layers.Input(shape=(2,))
+ x = keras.layers.Dense(3)(a)
+ b = keras.layers.Dense(1)(x)
+ return keras.models.Model(a, b)
+
+ self._weight_loading_test_template(_make_graph_model)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_weight_loading_subclassed_model(self):
+ self._weight_loading_test_template(SubclassedModel)
+
+ def _new_layer_weight_loading_test_template(
+ self, first_model_fn, second_model_fn, restore_init_fn):
+ with self.test_session() as session:
+ model = first_model_fn()
+ temp_dir = self.get_temp_dir()
+ prefix = os.path.join(temp_dir, 'ckpt')
+
+ x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
+ executing_eagerly = context.executing_eagerly()
+ ref_y_tensor = model(x)
+ if not executing_eagerly:
+ session.run([v.initializer for v in model.variables])
+ ref_y = self.evaluate(ref_y_tensor)
+ model.save_weights(prefix)
+ for v in model.variables:
+ self.evaluate(
+ v.assign(random_ops.random_normal(shape=array_ops.shape(v))))
+
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ second_model = second_model_fn()
+ second_model.load_weights(prefix)
+ second_model(x)
+ self.evaluate(restore_init_fn(second_model))
+ second_model.save_weights(prefix)
+ # Check that the second model's checkpoint loads into the original model
+ model.load_weights(prefix)
+ y = self.evaluate(model(x))
+ self.assertAllClose(ref_y, y)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_weight_loading_graph_model_added_layer(self):
+ def _save_graph_model():
+ a = keras.layers.Input(shape=(2,))
+ x = keras.layers.Dense(3, name='first')(a)
+ b = keras.layers.Dense(1, name='second')(x)
+ return keras.models.Model(a, b)
+ def _restore_graph_model():
+ a = keras.layers.Input(shape=(2,))
+ x = keras.layers.Dense(3, name='first')(a)
+ y = keras.layers.Dense(1, name='second')(x)
+ b = keras.layers.Dense(3, name='secondjr')(y)
+ return keras.models.Model(a, b)
+ def _restore_init_fn(restore_model):
+ return [v.initializer for v in restore_model.layers[-1].variables]
+
+ self._new_layer_weight_loading_test_template(
+ _save_graph_model, _restore_graph_model,
+ _restore_init_fn)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_weight_loading_graph_model_added_no_weight_layer(self):
+ def _save_graph_model():
+ a = keras.layers.Input(shape=(2,))
+ x = keras.layers.Dense(3, name='first')(a)
+ b = keras.layers.Dense(1, name='second')(x)
+ return keras.models.Model(a, b)
+ def _restore_graph_model():
+ a = keras.layers.Input(shape=(2,))
+ x = keras.layers.Dense(3, name='first')(a)
+ y = keras.layers.Dropout(rate=0.1)(x)
+ b = keras.layers.Dense(1, name='second')(y)
+ return keras.models.Model(a, b)
+ def _restore_init_fn(restore_model):
+ del restore_model # unused
+ return []
+
+ self._new_layer_weight_loading_test_template(
+ _save_graph_model, _restore_graph_model,
+ _restore_init_fn)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_weight_loading_subclassed_model_added_layer(self):
+
+ class SubclassedModelRestore(training.Model):
+
+ def __init__(self):
+ super(SubclassedModelRestore, self).__init__()
+ self.x_layer = keras.layers.Dense(3)
+ self.y_layer = keras.layers.Dense(3)
+ self.b_layer = keras.layers.Dense(1)
+
+ def call(self, a):
+ return self.b_layer(self.y_layer(self.x_layer(a)))
+
+ def _restore_init_fn(restore_model):
+ return [v.initializer for v in restore_model.y_layer.variables]
+
+ self._new_layer_weight_loading_test_template(
+ SubclassedModel, SubclassedModelRestore,
+ _restore_init_fn)
+
+ if __name__ == '__main__':
+ test.main()
--- /dev/null
+ # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Keras training and evaluation routines for eager execution.
+ """
+ # pylint: disable=protected-access
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+
+ import copy
+
+ import numpy as np
+
+ from tensorflow.python.data.ops import iterator_ops
+ from tensorflow.python.eager.backprop import GradientTape
+ from tensorflow.python.framework import errors
+ from tensorflow.python.framework import ops
+ from tensorflow.python.framework import tensor_util
+ from tensorflow.python.keras import backend
+ from tensorflow.python.keras import callbacks as cbks
+ from tensorflow.python.keras import losses
+ from tensorflow.python.keras import metrics as metrics_module
+ from tensorflow.python.keras.engine import training_utils
+ from tensorflow.python.keras.utils import generic_utils
+ from tensorflow.python.ops import array_ops
+ from tensorflow.python.platform import tf_logging as logging
+
+
+ def _get_metrics_info(metric, internal_output_shapes=None, loss_func=None):
+ if metric == 'accuracy' or metric == 'acc':
+ # custom handling of accuracy
+ # (because of class mode duality)
+ output_shape = internal_output_shapes
+ if output_shape[-1] == 1 or loss_func == losses.binary_crossentropy:
+ # case: binary accuracy
+ acc_fn = metrics_module.binary_accuracy
+ elif loss_func == losses.sparse_categorical_crossentropy:
+ # case: categorical accuracy with sparse targets
+ acc_fn = metrics_module.sparse_categorical_accuracy
+ else:
+ acc_fn = metrics_module.categorical_accuracy
+
+ metric_name = 'acc'
+ return metric_name, acc_fn
+ else:
+ metric_fn = metrics_module.get(metric)
+ metric_name = metric_fn.__name__
+ return metric_name, metric_fn
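A short, hypothetical sketch of the "class mode duality" handled above, showing how the same 'acc' string resolves to different metric functions depending on the output shape and loss (assumes this module's namespace):

    # Single-unit output (or binary crossentropy loss) -> binary accuracy.
    name, fn = _get_metrics_info('acc', internal_output_shapes=(None, 1),
                                 loss_func=losses.binary_crossentropy)
    assert fn is metrics_module.binary_accuracy

    # Integer targets with sparse categorical crossentropy -> sparse accuracy.
    name, fn = _get_metrics_info('acc', internal_output_shapes=(None, 10),
                                 loss_func=losses.sparse_categorical_crossentropy)
    assert fn is metrics_module.sparse_categorical_accuracy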
+
+
+ def _eager_loss_fn(outputs, targets, loss_fn, output_name):
+ with backend.name_scope(output_name + '_loss'):
+ loss = loss_fn(targets, outputs)
+ return loss
+
+
+ def _eager_metrics_fn(model, outputs, targets):
+ """Calculates the metrics for each output of the given model.
+
+ Arguments:
+ model: The model on which metrics are being calculated.
+ outputs: The outputs of the given model.
+ targets: The predictions or targets of the given model.
+
+ Returns:
+ Returns the metric names and metric results for each output of the model.
+ """
+ metric_names = []
+ metric_results = []
+ if not isinstance(outputs, list):
+ outputs = [outputs]
+
+ if not isinstance(targets, list):
+ targets = [targets]
+
+ for i in range(len(model.outputs)):
+ output_metrics = model.nested_metrics[i]
+ for nested_output_metric in output_metrics:
+ metric_name, metric_fn = _get_metrics_info(
+ nested_output_metric, backend.int_shape(model.outputs[i]),
+ model.loss_functions[i])
+
+ if len(model.output_names) > 1:
+ metric_name = model.output_names[i] + '_' + metric_name
+ if metric_name not in model.metrics_names:
+ model.metrics_names.append(metric_name)
+
+ with backend.name_scope(metric_name):
+ metric_result = metric_fn(targets[i], outputs[i])
+ metric_names.append(metric_name)
+ metric_results.append(backend.mean(metric_result))
+
+ return metric_results
+
+
+ def _model_loss(model, inputs, targets, sample_weights=None, training=False):
+ """Calculates the loss for a given model.
+
+ Arguments:
+ model: The model on which metrics are being calculated.
+ inputs: List of input arrays.
+ targets: List of target arrays.
+ sample_weights: Optional list of sample weight arrays.
+ training: Whether the model should be run in inference or training mode.
+
+ Returns:
+ Returns the model output, total loss and loss value calculated using the
+ specified loss function. The total loss includes regularization losses and
+ applies masking and sample weighting to the loss value.
+ """
+ total_loss = 0
+ if len(inputs) == 1:
+ if model._expects_training_arg:
+ outs = model.call(inputs[0], training=training)
+ else:
+ outs = model.call(inputs[0])
+ else:
+ if model._expects_training_arg:
+ outs = model.call(inputs, training=training)
+ else:
+ outs = model.call(inputs)
+ if not isinstance(outs, list):
+ outs = [outs]
+
+ if not isinstance(targets, list):
+ targets = [targets]
+
+ loss_metrics = []
+ with backend.name_scope('loss'):
+ for i, loss_fn in enumerate(model.loss_functions):
+ if sample_weights:
+ weights = sample_weights[i]
+ else:
+ weights = None
+
+ # TODO(fchollet): support masking; in practice `_keras_mask` is never
+ # set in this context currently.
+ mask = outs[i]._keras_mask
+
+ weighted_masked_fn = training_utils.weighted_masked_objective(loss_fn)
+ with backend.name_scope(model.output_names[i] + '_loss'):
+ output_loss = weighted_masked_fn(
+ targets[i], outs[i], weights, mask=mask)
+ # If the number of outputs is 1 then we don't append the loss metric
+ # associated with each model output. When there are multiple outputs
+ # associated with a model, each output's loss is calculated and returned
+ # as part of the loss_metrics.
+ if len(model.outputs) > 1:
+ loss_metrics.append(backend.mean(output_loss))
+
+ loss_weight = model.loss_weights_list[i]
+ if total_loss is None:
+ total_loss = loss_weight * output_loss
+ else:
+ total_loss += loss_weight * output_loss
+
+ total_loss = backend.mean(total_loss)
+ # Add regularization losses
+ custom_losses = []
+ for layer in model.layers:
+ if layer.losses:
+ custom_losses += layer.losses
+
+ if custom_losses:
+ total_loss += sum(custom_losses)
+
+ return outs, total_loss, loss_metrics
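Put differently, the returned `total_loss` is roughly the loss-weight-weighted sum of the per-output (weighted, masked, batch-averaged) losses plus any layer regularization terms. A hypothetical worked example with two outputs:

    output_losses = [0.8, 0.4]   # per-output mean losses for one batch
    loss_weights = [1.0, 0.5]    # model.loss_weights_list
    regularization = 0.01        # sum of layer.losses across the model

    total = sum(w * l for w, l in zip(loss_weights, output_losses))  # 1.0
    total += regularization                                          # 1.01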
+
+
+ def iterator_fit_loop(model,
+ inputs,
+ class_weight,
+ steps_per_epoch,
+ callback_model,
+ out_labels,
+ epoch_logs,
+ val_inputs=None,
+ val_targets=None,
+ val_sample_weights=None,
+ epochs=1,
+ verbose=1,
+ callbacks=None,
+ callback_metrics=None,
+ validation_steps=None,
+ do_validation=False):
+ """Fit function for eager execution when input is given as dataset iterator.
+
+ Updates the given epoch logs.
+
+ Arguments:
+ model: Instance of the `Model`.
+ inputs: Input dataset iterator.
+ class_weight: Optional class-weight array to weight the importance of
+ samples in `inputs` based on the class they belong to, as conveyed by
+ the targets from the `inputs` iterator.
+ steps_per_epoch: Total number of steps (batches of samples)
+ before declaring one epoch finished and starting the
+ next epoch.
+ callback_model: Instance of `Model` to callback.
+ out_labels: Output labels generated from model metric names.
+ epoch_logs: Dictionary of logs from every epoch.
+ val_inputs: Input data for validation.
+ val_targets: Target data for validation.
+ val_sample_weights: Sample weight data for validation.
+ epochs: Number of times to iterate over the data
+ verbose: Verbosity mode, 0, 1 or 2
+ callbacks: List of callbacks to be called during training
+ callback_metrics: List of strings, the display names of the metrics
+ passed to the callbacks. They should be the
+ concatenation of the list of display names of the outputs of
+ `f` and the list of display names of the outputs of `f_val`.
+ validation_steps: Number of steps to run validation for (only if doing
+ validation from data tensors). Ignored with default value of `None`.
+ do_validation: Boolean value indicating whether we should do validation.
+
+ Raises:
+ ValueError: In case of mismatch between given number of inputs and
+ expectations of the model.
+ """
+ assert isinstance(inputs, iterator_ops.EagerIterator)
+ for step_index in range(steps_per_epoch):
+ batch_logs = {}
+ batch_logs['batch'] = step_index
+ batch_logs['size'] = 1
+ callbacks.on_batch_begin(step_index, batch_logs)
+
+ # Get data from the iterator.
+ try:
+ next_element = inputs.get_next()
+ except errors.OutOfRangeError:
+ logging.warning(
+ 'Your dataset iterator ran out of data; '
+ 'interrupting training. Make sure that your dataset'
+ ' can generate at least `steps_per_epoch * epochs` '
+ 'batches (in this case, %d batches).' % (steps_per_epoch * epochs))
+ break
+
+ if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
+ raise ValueError('Please provide data as a list or tuple of 2 elements '
+ ' - input and target pair. Received %s' % next_element)
+ x, y = next_element
+
+ # Validate and standardize data.
+ x, y, sample_weights = model._standardize_user_data(
+ x, y, class_weight=class_weight)
+ if sample_weights:
+ sample_weights = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ if val is not None else None for val in sample_weights
+ ]
+
+ if step_index == 0 and not callback_metrics:
+ out_labels = model.metrics_names
+ if do_validation:
+ callback_metrics = copy.copy(out_labels) + [
+ 'val_' + n for n in out_labels
+ ]
+ else:
+ callback_metrics = copy.copy(out_labels)
+ callbacks.set_params({
+ 'epochs': epochs,
+ 'steps': steps_per_epoch,
+ 'verbose': verbose,
+ 'do_validation': do_validation,
+ 'metrics': callback_metrics or [],
+ })
+
+ # Train model.
+ outs, loss, loss_metrics = _process_single_batch(
+ model, x, y, sample_weights=sample_weights, training=True)
+ if not isinstance(outs, list):
+ outs = [outs]
+
+ # Calculate metrics.
+ for l, o in zip(out_labels, outs):
+ batch_logs[l] = o
+ # Required for eager execution
+ metrics_results = _eager_metrics_fn(model, outs, y)
+ batch_logs['loss'] = tensor_util.constant_value(backend.mean(loss))
+
+ for k, v in zip(model.metrics_names,
+ [backend.mean(loss)] + loss_metrics + metrics_results):
+ batch_logs[k] = tensor_util.constant_value(v)
+ callbacks.on_batch_end(step_index, batch_logs)
+ if callback_model.stop_training:
+ break
+
+ if step_index == steps_per_epoch - 1:
+ if do_validation:
+ val_outs = test_loop(
+ model,
+ val_inputs,
+ val_targets,
+ sample_weights=val_sample_weights,
+ steps=validation_steps,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(out_labels, val_outs):
+ epoch_logs['val_' + l] = o
+
+
+ def batch_fit_loop(model,
+ inputs,
+ targets,
+ epoch_logs,
+ index_array,
+ out_labels,
+ callback_model,
+ batch_size,
+ sample_weights=None,
+ val_inputs=None,
+ val_targets=None,
+ val_sample_weights=None,
+ callbacks=None,
+ shuffle=True,
+ num_train_samples=None,
+ do_validation=False):
+ """Fit function for eager execution when input is given as arrays or tensors.
+
+ Updates the given epoch logs.
+
+ Arguments:
+ model: Instance of the `Model`.
+ inputs: List of input arrays.
+ targets: List of target arrays.
+ epoch_logs: Dictionary of logs from every epoch.
+ index_array: Index array generated from number of training samples.
+ out_labels: Output labels generated from model metric names.
+ callback_model: Instance of `Model` to callback.
+ batch_size: Integer batch size or None if unknown.
+ sample_weights: Optional list of sample weight arrays.
+ val_inputs: Input data for validation.
+ val_targets: Target data for validation.
+ val_sample_weights: Sample weight data for validation.
+ callbacks: List of callbacks to be called during training.
+ shuffle: Whether to shuffle the data at the beginning of each epoch.
+ num_train_samples: Integer number of training samples.
+ do_validation: Boolean value indicating whether we should do validation.
+ """
+ # TODO(psv): Create a dataset iterator instead of manually creating batches
+ # here and in batch_test_loop, batch_predict_loop.
+ if shuffle == 'batch':
+ index_array = model._batch_shuffle(index_array, batch_size)
+ elif shuffle:
+ np.random.shuffle(index_array)
+
+ batches = generic_utils.make_batches(num_train_samples, batch_size)
+
+ for batch_index, (batch_start, batch_end) in enumerate(batches):
+ batch_ids = index_array[batch_start:batch_end]
+ inputs_batch = slice_arrays(inputs, batch_ids, contiguous=not shuffle)
+ targets_batch = slice_arrays(targets, batch_ids, contiguous=not shuffle)
+ if sample_weights:
+ sample_weights_batch = slice_arrays(
+ sample_weights, batch_ids, contiguous=not shuffle)
+ else:
+ sample_weights_batch = None
+ batch_logs = {}
+ batch_logs['batch'] = batch_index
+ batch_logs['size'] = len(batch_ids)
+
+ callbacks.on_batch_begin(batch_index, batch_logs)
+
+ inputs_batch = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ for val in inputs_batch
+ ]
+ targets_batch = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ for val in targets_batch
+ ]
+ if sample_weights:
+ sample_weights_batch = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ if val is not None else None for val in sample_weights_batch
+ ]
+
+ outs, loss, loss_metrics = _process_single_batch(
+ model,
+ inputs_batch,
+ targets_batch,
+ sample_weights=sample_weights_batch,
+ training=True)
+
+ if not isinstance(outs, list):
+ outs = [outs]
+
+ for l, o in zip(out_labels, outs):
+ batch_logs[l] = o
+ # Required for eager execution
+ metrics_results = _eager_metrics_fn(model, outs, targets_batch)
+ batch_logs['loss'] = tensor_util.constant_value(backend.mean(loss))
+
+ for k, v in zip(model.metrics_names,
+ [backend.mean(loss)] + loss_metrics + metrics_results):
+ batch_logs[k] = tensor_util.constant_value(v)
+ callbacks.on_batch_end(batch_index, batch_logs)
+ if callback_model.stop_training:
+ break
+
+ if batch_index == len(batches) - 1: # Last batch.
+ if do_validation:
+ val_outs = test_loop(
+ model,
+ val_inputs,
+ val_targets,
+ sample_weights=val_sample_weights,
+ batch_size=batch_size,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(out_labels, val_outs):
+ epoch_logs['val_' + l] = o
+
+
+ def iterator_test_loop(model, inputs, steps, verbose=0):
+ """Test function for eager execution when input is given as dataset iterator.
+
+ Arguments:
+ model: Model instance that is being evaluated in Eager mode.
+ inputs: Input dataset iterator.
+ steps: Total number of steps (batches of samples) before declaring
+ predictions finished.
+ verbose: Verbosity mode.
+
+ Returns:
+ Scalar loss (if the model has a single output and no metrics)
+ or list of scalars (if the model has multiple outputs
+ and/or metrics). The attribute `model.metrics_names` will give you
+ the display labels for the scalar outputs.
+
+ Raises:
+ ValueError: In case of mismatch between given number of inputs and
+ expectations of the model.
+ """
+ assert isinstance(inputs, iterator_ops.EagerIterator)
+ outs = []
+ num_samples = 0
+ if verbose == 1:
+ progbar = generic_utils.Progbar(target=steps)
+ for step_index in range(steps):
+ # Get data from the iterator.
+ try:
+ next_element = inputs.get_next()
+ except errors.OutOfRangeError:
+ logging.warning(
+ 'Your dataset iterator ran out of data interrupting testing. '
+ 'Make sure that your dataset can generate at least `steps` batches '
+ '(in this case, %d batches).', steps)
+ break
+
+ if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
+ raise ValueError('Please provide data as a list or tuple of 2 elements '
+ ' - input and target pair. Received %s' % next_element)
+ x, y = next_element
+
+ # Validate and standardize data.
+ x, y, sample_weights = model._standardize_user_data(x, y)
+
+ # Calculate model output, loss values.
+ loss_outs, loss, loss_metrics = _model_loss(
+ model, x, y, sample_weights=sample_weights, training=False)
+ metrics_results = _eager_metrics_fn(model, loss_outs, y)
+ batch_outs = []
+ for _, v in zip(model.metrics_names,
+ [backend.mean(loss)] + loss_metrics + metrics_results):
+ batch_outs.append(tensor_util.constant_value(v))
+
+ # Get current step size.
+ if isinstance(x, list):
+ step_size = x[0].get_shape().as_list()[0]
+ else:
+ step_size = x.get_shape().as_list()[0]
+
+ # Accumulate results in output array.
+ if not isinstance(batch_outs, list):
+ batch_outs = [batch_outs]
+ if step_index == 0:
+ for _ in enumerate(batch_outs):
+ outs.append(0.)
+ for i, batch_out in enumerate(batch_outs):
+ outs[i] += batch_out * step_size
+
+ # Calculate sample size.
+ num_samples += step_size
+ if verbose == 1:
+ progbar.update(step_index + 1)
+
+ for i in range(len(outs)):
+ outs[i] /= num_samples
+ if len(outs) == 1:
+ return outs[0]
+ return outs
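The accumulation above is a sample-weighted average: each batch's scalar outputs are scaled by that batch's size and the totals are divided by the number of samples at the end. A tiny hypothetical example with uneven batch sizes:

    outs = [0.]
    for batch_loss, step_size in [(0.5, 4), (0.8, 2)]:  # batches of sizes 4 and 2
        outs[0] += batch_loss * step_size                # 2.0, then 3.6
    outs[0] /= 4 + 2                                     # 0.6, the per-sample average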
+
+
+ def batch_test_loop(model,
+ inputs,
+ targets,
+ batch_size,
+ sample_weights=None,
+ verbose=0):
+ """Test function for eager execution when input is given as arrays or tensors.
+
+ Arguments:
+ model: Model instance that is being evaluated in Eager mode.
+ inputs: List of input arrays.
+ targets: List of target arrays.
+ batch_size: Integer batch size.
+ sample_weights: Optional list of sample weight arrays.
+ verbose: Verbosity mode.
+
+ Returns:
+ Scalar loss (if the model has a single output and no metrics)
+ or list of scalars (if the model has multiple outputs
+ and/or metrics). The attribute `model.metrics_names` will give you
+ the display labels for the scalar outputs.
+ """
+ outs = []
+ feed_data = inputs + targets
+ if sample_weights:
+ feed_data += sample_weights
+ num_samples = training_utils.check_num_samples(
+ feed_data, batch_size=batch_size)
+ if verbose == 1:
+ progbar = generic_utils.Progbar(target=num_samples)
+ batches = generic_utils.make_batches(num_samples, batch_size)
+ index_array = np.arange(num_samples)
+ for batch_index, (batch_start, batch_end) in enumerate(batches):
+ batch_ids = index_array[batch_start:batch_end]
+ inputs_batch = slice_arrays(inputs, batch_ids)
+ targets_batch = slice_arrays(targets, batch_ids)
+ if sample_weights:
+ sample_weights_batch = slice_arrays(sample_weights, batch_ids)
+ else:
+ sample_weights_batch = None
+
+ inputs_batch = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ for val in inputs_batch
+ ]
+ targets_batch = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ for val in targets_batch
+ ]
+ if sample_weights:
+ sample_weights_batch = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ if val is not None else None for val in sample_weights_batch
+ ]
+
+ loss_outs, loss, loss_metrics = _model_loss(
+ model,
+ inputs_batch,
+ targets_batch,
+ sample_weights=sample_weights_batch,
+ training=False)
+ metrics_results = _eager_metrics_fn(model, loss_outs, targets_batch)
+ batch_outs = []
+ for _, v in zip(model.metrics_names,
+ [backend.mean(loss)] + loss_metrics + metrics_results):
+ batch_outs.append(tensor_util.constant_value(v))
+
+ if isinstance(batch_outs, list):
+ if batch_index == 0:
+ for _ in enumerate(batch_outs):
+ outs.append(0.)
+ for i, batch_out in enumerate(batch_outs):
+ outs[i] += batch_out * len(batch_ids)
+ else:
+ if batch_index == 0:
+ outs.append(0.)
+ outs[0] += batch_outs * len(batch_ids)
+
+ if verbose == 1:
+ progbar.update(batch_end)
+
+ for i in range(len(outs)):
+ outs[i] /= num_samples
+ if len(outs) == 1:
+ return outs[0]
+ return outs
+
+
+ def iterator_predict_loop(model, inputs, steps, verbose=0):
+ """Predict function for eager execution when input is dataset iterator.
+
+ Arguments:
+ model: Instance of `Model`.
+ inputs: Input dataset iterator.
+ steps: Total number of steps (batches of samples) before declaring
+ `_predict_loop` finished.
+ verbose: Verbosity mode.
+
+ Returns:
+ Array of predictions (if the model has a single output)
+ or list of arrays of predictions (if the model has multiple outputs).
+
+ Raises:
+ ValueError: In case of mismatch between given number of inputs and
+ expectations of the model.
+ """
+ assert isinstance(inputs, iterator_ops.EagerIterator)
+ outs = []
+ if verbose == 1:
+ progbar = generic_utils.Progbar(target=steps)
+ for step_index in range(steps):
+ # Get data from the iterator.
+ try:
+ next_element = inputs.get_next()
+ except errors.OutOfRangeError:
+ logging.warning(
+ 'Your dataset iterator ran out of data; '
+ 'interrupting prediction. Make sure that your '
+ 'dataset can generate at least `steps` '
+ 'batches (in this case, %d batches).', steps)
+ break
+
+ if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
+ raise ValueError(
+ 'Please provide data as a list or tuple of 2 elements '
+ ' - input and target pair. Received %s. We do not use the '
+ '`target` value here.' % next_element)
+ x, _ = next_element
+
+ # Validate and standardize data.
+ x, _, _ = model._standardize_user_data(x)
+
+ if model._expects_training_arg:
+ batch_outs = model.call(x[0] if len(x) == 1 else x, training=False)
+ else:
+ batch_outs = model.call(x[0] if len(x) == 1 else x)
+ if not isinstance(batch_outs, list):
+ batch_outs = [batch_outs]
+
+ # We collect the results from every step and then concatenate them once
+ # in the end. This is an expensive process. We are doing this because we
+ # do not know the number of samples beforehand.
+ if step_index == 0:
+ for _ in batch_outs:
+ outs.append([])
+ for i, batch_out in enumerate(batch_outs):
+ outs[i].append(backend.get_value(batch_out))
+
+ if verbose == 1:
+ progbar.update(step_index + 1)
+ for i, out in enumerate(outs):
+ outs[i] = np.concatenate(tuple(out), axis=0)
+ if len(outs) == 1:
+ return outs[0]
+ return outs
+
+
+ def batch_predict_loop(model, inputs, batch_size, verbose=0):
+ """Predict function for eager execution when input is arrays or tensors.
+
+ Arguments:
+ model: Instance of `Model`.
+ inputs: List of input arrays.
+ batch_size: Integer batch size.
+ verbose: Verbosity mode.
+
+ Returns:
+ Array of predictions (if the model has a single output)
+ or list of arrays of predictions (if the model has multiple outputs).
+ """
+ outs = []
+ num_samples = training_utils.check_num_samples(inputs, batch_size)
+ if verbose == 1:
+ progbar = generic_utils.Progbar(target=num_samples)
+ batches = generic_utils.make_batches(num_samples, batch_size)
+ index_array = np.arange(num_samples)
+ for batch_index, (batch_start, batch_end) in enumerate(batches):
+ batch_ids = index_array[batch_start:batch_end]
+ inputs_batch = slice_arrays(inputs, batch_ids)
+
+ inputs_batch = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ for val in inputs_batch
+ ]
+
+ if len(inputs_batch) == 1:
+ if model._expects_training_arg:
+ batch_outs = model.call(inputs_batch[0], training=False)
+ else:
+ batch_outs = model.call(inputs_batch[0])
+ else:
+ if model._expects_training_arg:
+ batch_outs = model.call(inputs_batch, training=False)
+ else:
+ batch_outs = model.call(inputs_batch)
+
+ if not isinstance(batch_outs, list):
+ batch_outs = [batch_outs]
+ if batch_index == 0:
+ # Pre-allocate the results arrays.
+ for batch_out in batch_outs:
+ dims = batch_out.shape[1:].dims
+ dims_list = [d.value for d in dims]
+ shape = (num_samples,) + tuple(dims_list)
+ outs.append(np.zeros(shape, dtype=batch_out.dtype.as_numpy_dtype))
+ for i, batch_out in enumerate(batch_outs):
+ outs[i][batch_start:batch_end] = batch_out
+ if verbose == 1:
+ progbar.update(batch_end)
+
+ if len(outs) == 1:
+ return outs[0]
+ return outs
+
+
+ def slice_arrays(arrays, indices, contiguous=True):
+ """Slices batches out of provided arrays (workaround for eager tensors).
+
+ Unfortunately eager tensors don't have the same slicing behavior as
+ Numpy arrays (they follow the same slicing behavior as symbolic TF tensors),
+ hence we cannot use `generic_utils.slice_arrays` directly
+ and we have to implement this workaround based on `concat`. This has a
+ performance cost.
+
+ Arguments:
+ arrays: Single array or list of arrays.
+ indices: List of indices in the array that should be included in the output
+ batch.
+ contiguous: Boolean flag indicating whether the indices are contiguous.
+
+ Returns:
+ Slice of data (either single array or list of arrays).
+ """
+ if any(tensor_util.is_tensor(x) for x in arrays):
+ converted_to_list = False
+ if not isinstance(arrays, list):
+ converted_to_list = True
+ arrays = [arrays]
+ if not contiguous:
+ entries = [[x[i:i + 1] for i in indices] for x in arrays]
+ slices = [array_ops.concat(x, axis=0) for x in entries]
+ else:
+ slices = [x[indices[0]:indices[-1] + 1] for x in arrays]
+ if converted_to_list:
+ slices = slices[0]
+ return slices
+ else:
+ return generic_utils.slice_arrays(arrays, indices)
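A brief, hypothetical usage sketch of the workaround above: non-contiguous indices are gathered row by row and concatenated, while contiguous indices fall back to a single slice.

    import numpy as np
    from tensorflow.python.framework import constant_op

    x = constant_op.constant(np.arange(10).reshape(5, 2), dtype='float32')

    rows_024 = slice_arrays([x], [0, 2, 4], contiguous=False)  # per-row slices + concat
    rows_123 = slice_arrays([x], [1, 2, 3], contiguous=True)   # single slice x[1:4]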
+
+
+ def _process_single_batch(model,
+ inputs,
+ targets,
+ sample_weights=None,
+ training=False):
+ """Calculate the loss and gradient for one input batch.
+
+ The model weights are updated if training is set to True.
+
+ Arguments:
+ model: Model whose loss has to be calculated.
+ inputs: List of input arrays.
+ targets: List of target arrays.
+ sample_weights: Optional list of sample weight arrays.
+ training: Boolean indicating whether the weights of the model should be
+ updated. 'fit' methods will set this to True while 'evaluate' methods
+ will set this to False.
+
+ Returns:
+ output of the model, total loss and the loss associated with each output.
+
+ Raises:
+ ValueError: If the model has no loss to optimize.
+ """
+ with backend.learning_phase_scope(1 if training else 0):
+ with GradientTape() as tape:
+ outs, loss, loss_metrics = _model_loss(model, inputs, targets,
+ sample_weights=sample_weights,
+ training=training)
+ if loss is None:
+ raise ValueError('The model cannot be run '
+ 'because it has no loss to optimize.')
+ if training:
+ if not model._collected_trainable_weights:
+ logging.warning('The list of trainable weights is empty. Make sure that'
+ ' you are not setting model.trainable to False before '
+ 'compiling the model.')
+ else:
+ grads = tape.gradient(loss, model._collected_trainable_weights)
+ model.optimizer.apply_gradients(zip(grads,
+ model._collected_trainable_weights))
+ return outs, loss, loss_metrics
+
+
+ def train_on_batch(model, inputs, targets, sample_weights=None):
+ """Calculates the loss and gradient updates for one input batch.
+
+ Arguments:
+ model: Model whose loss has to be calculated.
+ inputs: Input batch data.
+ targets: Target batch data.
+ sample_weights: Sample weight batch data.
+
+ Returns:
+ total loss and the loss associated with each output.
+ """
+ if len(inputs) and not tensor_util.is_tensor(inputs[0]):
+ inputs = [
+ ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs
+ ]
+ targets = [
+ ops.convert_to_tensor(val, dtype=backend.floatx()) for val in targets
+ ]
+ if sample_weights:
+ sample_weights = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ if val is not None else None for val in sample_weights
+ ]
+
+ outs, loss, _ = _process_single_batch(
+ model, inputs, targets, sample_weights=sample_weights, training=True)
+ if not isinstance(outs, list):
+ outs = [outs]
+ metrics_results = _eager_metrics_fn(model, outs, targets)
+ if not isinstance(loss, list):
+ loss = [loss]
+ return loss + metrics_results
+
+
+ def test_on_batch(model, inputs, targets, sample_weights=None):
+ """Calculates the loss for one input batch.
+
+ Arguments:
+ model: Model whose loss has to be calculated.
+ inputs: Input batch data.
+ targets: Target batch data.
+ sample_weights: Sample weight batch data.
+
+ Returns:
+ total loss, loss and metrics associated with each output.
+ """
+ if len(inputs) and not tensor_util.is_tensor(inputs[0]):
+ inputs = [
+ ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs
+ ]
+ targets = [
+ ops.convert_to_tensor(val, dtype=backend.floatx()) for val in targets
+ ]
+ if sample_weights:
+ sample_weights = [
+ ops.convert_to_tensor(val, dtype=backend.floatx())
+ if val is not None else None for val in sample_weights
+ ]
+ outs, loss, loss_metrics = _model_loss(
+ model, inputs, targets, sample_weights=sample_weights, training=False)
+ if not isinstance(outs, list):
+ outs = [outs]
+ metrics_results = _eager_metrics_fn(model, outs, targets)
+ if not isinstance(loss, list):
+ loss = [loss]
+ return loss + loss_metrics + metrics_results
+
+
+ def fit_loop(model,
+ inputs,
+ targets,
+ sample_weights=None,
+ class_weight=None,
+ val_inputs=None,
+ val_targets=None,
+ val_sample_weights=None,
+ batch_size=None,
+ epochs=1,
+ verbose=1,
+ callbacks=None,
+ shuffle=True,
+ callback_metrics=None,
+ initial_epoch=0,
+ steps_per_epoch=None,
+ validation_steps=None):
+ """Fit function for eager execution.
+
+ Arguments:
+ model: Instance of the model that is being executed in Eager mode.
+ inputs: List of input arrays.
+ targets: List of target arrays.
+ sample_weights: Optional list of sample weight arrays.
+ class_weight: Optional class-weight array to weight the importance of
+ samples in `inputs` based on the class they belong to, as conveyed by
+ `targets`.
+ val_inputs: Input data for validation.
+ val_targets: Target data for validation.
+ val_sample_weights: Sample weight data for validation.
+ batch_size: Integer batch size or None if unknown.
+ epochs: Number of times to iterate over the data
+ verbose: Verbosity mode, 0, 1 or 2
+ callbacks: List of callbacks to be called during training
+ shuffle: Whether to shuffle the data at the beginning of each epoch
+ callback_metrics: List of strings, the display names of the metrics
+ passed to the callbacks. They should be the
+ concatenation of the list of display names of the outputs of
+ `f` and the list of display names of the outputs of `f_val`.
+ initial_epoch: Epoch at which to start training
+ (useful for resuming a previous training run)
+ steps_per_epoch: Total number of steps (batches of samples)
+ before declaring one epoch finished and starting the
+ next epoch. Ignored with the default value of `None`.
+ validation_steps: Number of steps to run validation for (only if doing
+ validation from data tensors). Ignored with default value of `None`.
+
+ Returns:
+ `History` object.
+
+ Raises:
+ ValueError: In case of invalid argument values.
+ """
+ # Required for eager execution
+ with backend.learning_phase_scope(1):
+ do_validation = False
+ if val_inputs:
+ do_validation = True
+ if (steps_per_epoch is None and verbose and inputs and
+ hasattr(inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')):
+ print('Train on %d samples, validate on %d samples' %
+ (inputs[0].shape[0], val_inputs[0].shape[0]))
+
+ num_train_samples = None
+ out_labels = None
+ if steps_per_epoch is None or model._is_compiled:
+ out_labels = model.metrics_names
+ if do_validation:
+ callback_metrics = copy.copy(out_labels) + [
+ 'val_' + n for n in out_labels
+ ]
+ else:
+ callback_metrics = copy.copy(out_labels)
+
+ if steps_per_epoch is None:
+ if sample_weights:
+ feed_data = inputs + targets + sample_weights
+ else:
+ feed_data = inputs + targets
+ num_train_samples = training_utils.check_num_samples(
+ feed_data,
+ batch_size=batch_size,
+ steps=steps_per_epoch,
+ steps_name='steps_per_epoch')
+
+ if num_train_samples is not None:
+ index_array = np.arange(num_train_samples)
+
+ model.history = cbks.History()
+ callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history]
+ if verbose:
+ if steps_per_epoch is not None:
+ count_mode = 'steps'
+ else:
+ count_mode = 'samples'
+ callbacks += [cbks.ProgbarLogger(count_mode)]
+ callbacks = cbks.CallbackList(callbacks)
+
+ # it's possible to callback a different model than self
+ # (used by Sequential models)
+ if hasattr(model, 'callback_model') and model.callback_model:
+ callback_model = model.callback_model
+ else:
+ callback_model = model
+
+ callbacks.set_model(callback_model)
+
+ callbacks.set_params({
+ 'batch_size': batch_size,
+ 'epochs': epochs,
+ 'steps': steps_per_epoch,
+ 'samples': num_train_samples,
+ 'verbose': verbose,
+ 'do_validation': do_validation,
+ 'metrics': callback_metrics or [],
+ })
+ callbacks.on_train_begin()
+ callback_model.stop_training = False
+ for cbk in callbacks:
+ if not val_inputs:
+ cbk.validation_data = []
+ elif isinstance(val_inputs, iterator_ops.EagerIterator):
+ cbk.validation_data = val_inputs
+ elif val_sample_weights:
+ cbk.validation_data = val_inputs + val_targets + val_sample_weights
+ else:
+ cbk.validation_data = val_inputs + val_targets
+
+ for epoch in range(initial_epoch, epochs):
+ callbacks.on_epoch_begin(epoch)
+ epoch_logs = {}
+
+ if steps_per_epoch is not None:
+ iterator_fit_loop(
+ model,
+ inputs,
+ class_weight,
+ steps_per_epoch=steps_per_epoch,
+ callback_model=callback_model,
+ out_labels=out_labels,
+ epoch_logs=epoch_logs,
+ val_inputs=val_inputs,
+ val_targets=val_targets,
+ val_sample_weights=val_sample_weights,
+ epochs=epochs,
+ verbose=verbose,
+ callbacks=callbacks,
+ callback_metrics=callback_metrics,
+ validation_steps=validation_steps,
+ do_validation=do_validation)
+ else:
+ batch_fit_loop(
+ model,
+ inputs,
+ targets,
+ epoch_logs=epoch_logs,
+ index_array=index_array,
+ out_labels=out_labels,
+ callback_model=callback_model,
+ batch_size=batch_size,
+ sample_weights=sample_weights,
+ val_inputs=val_inputs,
+ val_targets=val_targets,
+ val_sample_weights=val_sample_weights,
+ callbacks=callbacks,
+ shuffle=shuffle,
+ num_train_samples=num_train_samples,
+ do_validation=do_validation)
+ callbacks.on_epoch_end(epoch, epoch_logs)
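+ # A callback (e.g. EarlyStopping) may request termination by setting
+ # `stop_training` on the callback model.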
+ if callback_model.stop_training:
+ break
+ callbacks.on_train_end()
+ return model.history
+
+
+ def test_loop(model, inputs, targets,
+ sample_weights=None,
+ batch_size=None,
+ verbose=0,
+ steps=None):
+ """Test function for eager execution.
+
+ Arguments:
+ model: Model instance that is being evaluated in Eager mode.
+ inputs: List of input arrays.
+ targets: List of target arrays.
+ sample_weights: Optional list of sample weight arrays.
+ batch_size: integer batch size or `None`.
+ verbose: verbosity mode.
+ steps: Total number of steps (batches of samples)
+ before declaring the evaluation round finished.
+ Ignored with the default value of `None`.
+
+ Returns:
+ Scalar loss (if the model has a single output and no metrics)
+ or list of scalars (if the model has multiple outputs
+ and/or metrics). The attribute `model.metrics_names` will give you
+ the display labels for the scalar outputs.
+ """
+ with backend.learning_phase_scope(0):
+ if steps is not None:
+ return iterator_test_loop(model, inputs, steps, verbose=verbose)
+ else:
+ return batch_test_loop(
+ model,
+ inputs,
+ targets,
+ batch_size=batch_size,
+ sample_weights=sample_weights,
+ verbose=verbose)
+
+
+ def predict_loop(model, inputs,
+ batch_size=32,
+ verbose=0,
+ steps=None):
+ """Predict function for eager execution.
+
+ Arguments:
+ model: Instance of `Model`.
+ inputs: List of input arrays.
+ batch_size: integer batch size.
+ verbose: verbosity mode.
+ steps: Total number of steps (batches of samples)
+ before declaring `predict_loop` finished.
+ Ignored with the default value of `None`.
+
+ Returns:
+ Array of predictions (if the model has a single output)
+ or list of arrays of predictions
+ (if the model has multiple outputs).
+ """
+ with backend.learning_phase_scope(0):
+ if steps is not None:
+ return iterator_predict_loop(model, inputs, steps, verbose=verbose)
+ else:
+ return batch_predict_loop(
+ model, inputs, batch_size=batch_size, verbose=verbose)
from __future__ import division
from __future__ import print_function
- from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer
- from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
- from tensorflow.python.keras._impl.keras.utils.data_utils import OrderedEnqueuer
- from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence
- from tensorflow.python.keras._impl.keras.utils.data_utils import SequenceEnqueuer
- from tensorflow.python.keras._impl.keras.utils.generic_utils import custom_object_scope
- from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope
- from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
- from tensorflow.python.keras._impl.keras.utils.generic_utils import get_custom_objects
- from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
- from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
- from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix
- from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model
- from tensorflow.python.keras._impl.keras.utils.multi_gpu_utils import multi_gpu_model
- from tensorflow.python.keras._impl.keras.utils.np_utils import normalize
- from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical
- from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model
+ from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer
+ from tensorflow.python.keras.utils.data_utils import get_file
+ from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer
+ from tensorflow.python.keras.utils.data_utils import Sequence
+ from tensorflow.python.keras.utils.data_utils import SequenceEnqueuer
+ from tensorflow.python.keras.utils.generic_utils import custom_object_scope
+ from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
+ from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
+ from tensorflow.python.keras.utils.generic_utils import get_custom_objects
+ from tensorflow.python.keras.utils.generic_utils import Progbar
+ from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
+ from tensorflow.python.keras.utils.io_utils import HDF5Matrix
+ from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model
+ from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model
+ from tensorflow.python.keras.utils.np_utils import normalize
+ from tensorflow.python.keras.utils.np_utils import to_categorical
+ from tensorflow.python.keras.utils.vis_utils import plot_model
del absolute_import
del division