from tensorflow.python.framework import dtypes
from tensorflow.python.keras._impl import keras
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.platform import test
try:
  import yaml  # pylint:disable=g-import-not-at-top
except ImportError:
  yaml = None
class TopologyConstructionTest(test.TestCase):
def test_get_updates_for(self):
- a = keras.layers.Input(shape=(2,))
+ a = keras.layers.Input(shape=(1,))
dense_layer = keras.layers.Dense(1)
- dense_layer.add_update(0, inputs=a)
- dense_layer.add_update(1, inputs=None)
+ dense_layer.build((None, 1))
+ update_1 = state_ops.assign_add(dense_layer.kernel, a)
+ update_2 = state_ops.assign_add(dense_layer.kernel, [[1.]])
+ dense_layer.add_update(update_1, inputs=a)
+ dense_layer.add_update(update_2, inputs=None)
- self.assertListEqual(dense_layer.get_updates_for(a), [0])
- self.assertListEqual(dense_layer.get_updates_for(None), [1])
+ self.assertListEqual(dense_layer.get_updates_for(a), [update_1])
+ self.assertListEqual(dense_layer.get_updates_for(None), [update_2])
def test_get_losses_for(self):
- a = keras.layers.Input(shape=(2,))
+ a = keras.layers.Input(shape=(1,))
dense_layer = keras.layers.Dense(1)
- dense_layer.add_loss(0, inputs=a)
- dense_layer.add_loss(1, inputs=None)
-
- self.assertListEqual(dense_layer.get_losses_for(a), [0])
- self.assertListEqual(dense_layer.get_losses_for(None), [1])
+ dense_layer.build((None, 1))
+ loss_1 = math_ops.reduce_sum(a)
+ loss_2 = math_ops.reduce_sum(dense_layer.kernel)
+ dense_layer.add_loss(loss_1, inputs=a)
+ dense_layer.add_loss(loss_2, inputs=None)
+
+ self.assertListEqual(dense_layer.get_losses_for(a), [loss_1])
+ self.assertListEqual(dense_layer.get_losses_for(None), [loss_2])
def test_trainable_weights(self):
a = keras.layers.Input(shape=(2,))
with K.name_scope('training'):
with K.name_scope(self.optimizer.__class__.__name__):
- training_updates = self.optimizer.get_updates(
+ # Training updates
+ updates = self.optimizer.get_updates(
params=self._collected_trainable_weights, loss=self.total_loss)
-
- updates = self.updates + training_updates
+ # Unconditional updates
+ updates += self.get_updates_for(None)
+ # Conditional updates relevant to this model
+ updates += self.get_updates_for(self._feed_inputs)
# Gets loss and metrics. Updates weights at each call.
self.train_function = K.function(
inputs, [self.total_loss] + self.metrics_tensors,
model.compile('sgd', 'mse')
model.train_on_batch(x, x)
- assert len(model.updates) == 2
+ self.assertEqual(len(bn.updates), 4)
+ self.assertEqual(len(model.updates), 2)
+ self.assertEqual(len(model.get_updates_for(x1)), 0)
+ self.assertEqual(len(model.get_updates_for(x2)), 2)
# Test model-level reuse
x3 = keras.layers.Input(shape=(10,))
y3 = model(x3)
- new_model = keras.models.Model(x3, y3)
- assert len(model.updates) == 2
+ new_model = keras.models.Model(x3, y3, name='new_model')
+
+ self.assertEqual(len(new_model.updates), 2)
+ self.assertEqual(len(model.updates), 4)
+ self.assertEqual(len(new_model.get_updates_for(x3)), 2)
new_model.compile('sgd', 'mse')
new_model.train_on_batch(x, x)
losses = []
for cell in self.cells:
if isinstance(cell, Layer):
- cell_losses = cell.losses
- losses += cell_losses
- return losses
+ losses += cell.losses
+ return losses + self._losses
- def get_losses_for(self, inputs=None):
- losses = []
+ @property
+ def updates(self):
+ updates = []
for cell in self.cells:
if isinstance(cell, Layer):
- cell_losses = cell.get_losses_for(inputs)
- losses += cell_losses
- return losses
+ updates += cell.updates
+ return updates + self._updates
@tf_export('keras.layers.RNN')
if self.stateful:
updates = []
for i in range(len(states)):
- updates.append((self.states[i], states[i]))
+ updates.append(K.update(self.states[i], states[i]))
self.add_update(updates, inputs)
if self.return_sequences:
@property
def losses(self):
+ losses = []
if isinstance(self.cell, Layer):
- return self.cell.losses
- return []
+ losses += self.cell.losses
+ return losses + self._losses
- def get_losses_for(self, inputs=None):
+ @property
+ def updates(self):
+ updates = []
if isinstance(self.cell, Layer):
- cell_losses = self.cell.get_losses_for(inputs)
- return cell_losses + super(RNN, self).get_losses_for(inputs)
- return super(RNN, self).get_losses_for(inputs)
+ updates += self.cell.updates
+ return updates + self._updates
@tf_export('keras.layers.SimpleRNNCell')
if self.stateful:
updates = []
for i in range(len(states)):
- updates.append((self.states[i], states[i]))
+ updates.append(K.update(self.states[i], states[i]))
self.add_update(updates, inputs)
# Properly set learning phase
self.assertAllClose(y_np, y_np_3, atol=1e-4)
def test_stacked_rnn_attributes(self):
- cells = [keras.layers.LSTMCell(3),
- keras.layers.LSTMCell(3, kernel_regularizer='l2')]
+ cells = [keras.layers.LSTMCell(1),
+ keras.layers.LSTMCell(1)]
layer = keras.layers.RNN(cells)
- layer.build((None, None, 5))
-
- # Test regularization losses
- self.assertEqual(len(layer.losses), 1)
+ layer.build((None, None, 1))
# Test weights
self.assertEqual(len(layer.trainable_weights), 6)
cells[0].trainable = False
self.assertEqual(len(layer.trainable_weights), 3)
self.assertEqual(len(layer.non_trainable_weights), 3)
- # Test `get_losses_for`
- x = keras.Input((None, 5))
- y = keras.backend.sum(x)
- cells[0].add_loss(y, inputs=x)
- self.assertEqual(layer.get_losses_for(x), [y])
+ # Test `get_losses_for` and `losses`
+ x = keras.Input((None, 1))
+ loss_1 = keras.backend.sum(x)
+ loss_2 = keras.backend.sum(cells[0].kernel)
+ cells[0].add_loss(loss_1, inputs=x)
+ cells[0].add_loss(loss_2)
+ self.assertEqual(len(layer.losses), 2)
+ self.assertEqual(layer.get_losses_for(None), [loss_2])
+ self.assertEqual(layer.get_losses_for(x), [loss_1])
+
+ # Test `get_updates_for` and `updates`
+ cells = [keras.layers.LSTMCell(1),
+ keras.layers.LSTMCell(1)]
+ layer = keras.layers.RNN(cells)
+ layer.build((None, None, 1))
+
+ x = keras.Input((None, 1))
+ update_1 = keras.backend.update_add(
+ cells[0].kernel, x[0, 0, 0] * cells[0].kernel)
+ update_2 = keras.backend.update_add(
+ cells[0].kernel, keras.backend.ones_like(cells[0].kernel))
+ cells[0].add_update(update_1, inputs=x)
+ cells[0].add_update(update_2)
+ self.assertEqual(len(layer.updates), 2)
+ self.assertEqual(layer.get_updates_for(None), [update_2])
+ self.assertEqual(layer.get_updates_for(x), [update_1])
def test_rnn_dynamic_trainability(self):
layer_class = keras.layers.SimpleRNN
@property
def updates(self):
- if hasattr(self.layer, 'updates'):
- return self.layer.updates
- return []
-
- def get_updates_for(self, inputs=None):
- # If the wrapper modifies the inputs, use the modified inputs to
- # get the updates from the inner layer.
- inner_inputs = inputs
- if inputs is not None:
- uid = tf_layers_util.object_list_uid(inputs)
- if uid in self._input_map:
- inner_inputs = self._input_map[uid]
-
- updates = self.layer.get_updates_for(inner_inputs)
- updates += super(Wrapper, self).get_updates_for(inputs)
- return updates
+ return self.layer.updates + self._updates
@property
def losses(self):
- if hasattr(self.layer, 'losses'):
- return self.layer.losses
- return []
-
- def get_losses_for(self, inputs=None):
- if inputs is None:
- losses = self.layer.get_losses_for(None)
- return losses + super(Wrapper, self).get_losses_for(None)
- return super(Wrapper, self).get_losses_for(inputs)
+ return self.layer.losses + self._losses
def get_weights(self):
return self.layer.get_weights()
# Used by Layer base class.
self._dtype = None
self._activity_regularizer = None
- self._per_input_losses = {}
- self._per_input_updates = {}
# The following properties are not actually used by Keras;
# they exist for compatibility with TF's variable scoping mechanism.
return weights
@property
- def updates(self):
- if not self.built:
- self.build()
- return self.model.updates
-
- @property
- def state_updates(self):
- if not self.built:
- self.build()
- return self.model.state_updates
-
- def get_updates_for(self, inputs):
- if not self.built:
- self.build()
- return self.model.get_updates_for(inputs)
-
- @property
- def losses(self):
- if not self.built:
- self.build()
- return self.model.losses
-
- def get_losses_for(self, inputs):
- if not self.built:
- self.build()
- return self.model.get_losses_for(inputs)
-
- @property
def regularizers(self):
if not self.built:
self.build()
self._losses = []
self._reuse = kwargs.get('_reuse')
self._graph = ops.get_default_graph()
- self._per_input_losses = {}
- self._per_input_updates = {}
self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
call_fn_args = estimator_util.fn_args(self.call)
self._compute_previous_mask = ('mask' in call_fn_args or
Arguments:
updates: Update op, or list/tuple of update ops.
- inputs: Optional input tensor(s) that the update(s) depend on. Must
- match the `inputs` argument passed to the `__call__` method at the time
- the updates are created. If `None` is passed, the updates are assumed
- to be unconditional, and will apply across all dataflows of the layer.
+ inputs: If anything other than None is passed, it signals the updates
+ are conditional on some of the layer's inputs,
+ and thus they should only be run where these inputs are available.
+ This is the case for BatchNormalization updates, for instance.
+ If None, the updates will be taken into account unconditionally,
+ and you are responsible for making sure that any dependency they might
+ have is available at runtime.
+ A step counter might fall into this category.
"""
if context.in_eager_mode():
return # Updates already applied when in eager mode.
+
updates = _to_list(updates)
- if not updates:
- return
self._updates += updates
- if inputs is not None:
- inputs = nest.flatten(inputs)
- if not inputs:
- inputs = None
- if inputs is not None:
- # We compute an ID that uniquely identifies the list of tensors.
- # This ID is order-sensitive.
- inputs_hash = layers_util.object_list_uid(inputs)
+ if inputs is None:
+ for u in updates:
+ u._unconditional_update = True # pylint: disable=protected-access
else:
- inputs_hash = None
- if inputs_hash not in self._per_input_updates:
- self._per_input_updates[inputs_hash] = []
- self._per_input_updates[inputs_hash] += updates
+ for u in updates:
+ u._unconditional_update = False # pylint: disable=protected-access
def get_updates_for(self, inputs):
"""Retrieves updates relevant to a specific set of inputs.
Arguments:
inputs: Input tensor or list/tuple of input tensors.
- Must match the `inputs` argument passed to the `__call__` method
- at the time the updates were created.
- If you pass `inputs=None`, unconditional updates are returned.
Returns:
List of update ops of the layer that depend on `inputs`.
RuntimeError: If called in Eager mode.
"""
if context.in_eager_mode():
- raise RuntimeError('Layer.get_updates_for not supported in Eager mode.')
+ raise RuntimeError('`get_updates_for()` not supported in Eager mode.')
+
+ # Updates disabled if layer is not trainable and not explicitly stateful.
if not self.trainable and not self.stateful:
return []
- if inputs is not None:
- inputs = nest.flatten(inputs)
- if not inputs:
- inputs = None
- if inputs is not None:
- inputs_hash = layers_util.object_list_uid(inputs)
- else:
- inputs_hash = None
- return self._per_input_updates.get(inputs_hash, [])
+
+ if inputs is None:
+ # Requesting unconditional updates.
+ return [x for x in self.updates if x._unconditional_update] # pylint: disable=protected-access
+
+ # Requesting input-conditional updates.
+ inputs = nest.flatten(inputs)
+ reachable = layers_util.get_reachable_from_inputs(inputs, self.updates)
+ updates = []
+ for update in self.updates:
+ if update in reachable:
+ updates.append(update)
+ return updates
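With reachability-based retrieval, the caller no longer has to pass the exact tensors that were given to `__call__`; any tensor upstream of the update works. A sketch reusing the hypothetical `MyCounter` layer above:

    from tensorflow.python.ops import array_ops

    x = array_ops.placeholder(dtypes.float32, (4,))
    scaled = x * 2.
    layer = MyCounter()
    y = layer.apply(scaled)

    # The conditional update is found both from the placeholder and from the
    # derived tensor actually passed to the layer; the old dict lookup keyed
    # on `object_list_uid(inputs)` would only have matched `scaled`.
    assert len(layer.get_updates_for([x])) == 1
    assert len(layer.get_updates_for([scaled])) == 1
    # The build-time counter increment is only returned for `inputs=None`.
    assert len(layer.get_updates_for(None)) == 1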
@property
def losses(self):
Arguments:
losses: Loss tensor, or list/tuple of tensors.
- inputs: Optional input tensor(s) that the loss(es) depend on. Must
- match the `inputs` argument passed to the `__call__` method at the time
- the losses are created. If `None` is passed, the losses are assumed
+ inputs: If anything other than None is passed, it signals the losses
+ are conditional on some of the layer's inputs,
+ and thus they should only be run where these inputs are available.
+ This is the case for activity regularization losses, for instance.
+ If `None` is passed, the losses are assumed
to be unconditional, and will apply across all dataflows of the layer
(e.g. weight regularization losses).
RuntimeError: If called in Eager mode.
"""
if context.in_eager_mode():
+ # TODO(fchollet): it should be possible (and highly desirable) to support
+ # `add_loss` in eager mode. This allows great convenience and flexibility
+ # in defining custom losses on the fly (e.g. in VAEs).
+ # Simply appending the loss value to `self._losses`
+ # is the correct behavior.
+ # The only caveat is that we need to force the user to only call
+ # `add_loss` from inside a model or Layer's `call` method
+ # (otherwise the loss computation cannot be backpropagated through).
raise RuntimeError('Layer.add_loss not supported in Eager mode.')
+
losses = _to_list(losses)
- if not losses:
- return
self._losses += losses
- if inputs is not None:
- inputs = nest.flatten(inputs)
- if not inputs:
- inputs = None
- if inputs is not None:
- # We compute an ID that uniquely identifies the list of tensors.
- # This ID is order-sensitive.
- inputs_hash = layers_util.object_list_uid(inputs)
+ if inputs is None:
+ for loss in losses:
+ loss._unconditional_loss = True # pylint: disable=protected-access
else:
- inputs_hash = None
- if inputs_hash not in self._per_input_losses:
- self._per_input_losses[inputs_hash] = []
- self._per_input_losses[inputs_hash] += losses
+ for loss in losses:
+ loss._unconditional_loss = False # pylint: disable=protected-access
+ # TODO(fchollet): deprecate collection below.
_add_elements_to_collection(losses, ops.GraphKeys.REGULARIZATION_LOSSES)
def get_losses_for(self, inputs):
Arguments:
inputs: Input tensor or list/tuple of input tensors.
- Must match the `inputs` argument passed to the `__call__`
- method at the time the losses were created.
- If you pass `inputs=None`, unconditional losses are returned,
- such as weight regularization losses.
Returns:
List of loss tensors of the layer that depend on `inputs`.
"""
if context.in_eager_mode():
raise RuntimeError('Layer.get_losses_for not supported in Eager mode.')
- if inputs is not None:
- inputs = nest.flatten(inputs)
- if not inputs:
- inputs = None
- if inputs is not None:
- inputs_hash = layers_util.object_list_uid(inputs)
- else:
- inputs_hash = None
- return self._per_input_losses.get(inputs_hash, [])
+
+ if inputs is None:
+ # Requesting unconditional losses.
+ return [x for x in self.losses if x._unconditional_loss] # pylint: disable=protected-access
+
+ # Requesting input-conditional losses.
+ inputs = nest.flatten(inputs)
+ # Retrieve the set of tensors in the TF graph that depend on `inputs`.
+ # The losses we want to return will be part of this set.
+ # To avoid unnecessary work, we stop the search in case all of
+ # `self.losses` have been retrieved.
+ reachable = layers_util.get_reachable_from_inputs(inputs, self.losses)
+ losses = []
+ for loss in self.losses:
+ if loss in reachable:
+ losses.append(loss)
+ return losses
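The loss side behaves symmetrically; a sketch (again illustrative only, with `core_layers` being the `tf.layers` core module imported in the tests below):

    from tensorflow.python.layers import core as core_layers

    x = array_ops.placeholder(dtypes.float32, (None, 8))
    dense = core_layers.Dense(4)
    y = dense.apply(x)

    # Weight-regularization-style loss: unconditional.
    dense.add_loss(math_ops.reduce_sum(dense.kernel * dense.kernel))
    # Activity-regularization-style loss: conditional on the call inputs.
    dense.add_loss(math_ops.reduce_sum(y * y), inputs=True)

    assert dense.get_losses_for(None) == dense.losses[:1]
    assert dense.get_losses_for([x]) == dense.losses[1:]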
def build(self, _):
"""Creates the variables of the layer."""
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
self.assertEqual(len(layer.trainable_variables), 1)
self.assertEqual(layer.variables[0].graph, outer_graph)
+ def testGetUpdatesFor(self):
+
+ class MyLayer(base_layers.Layer):
+
+ def build(self, input_shape):
+ self.a = self.add_variable('a',
+ (),
+ dtypes.float32,
+ trainable=False)
+ self.b = self.add_variable('b',
+ (),
+ dtypes.float32,
+ trainable=False)
+ self.add_update(state_ops.assign_add(self.b, 1., name='b_update'))
+ self.built = True
+
+ def call(self, inputs):
+ self.add_update(state_ops.assign_add(self.a, inputs, name='a_update'),
+ inputs=True)
+ return inputs + 1
+
+ layer = MyLayer()
+ inputs = array_ops.placeholder(dtypes.float32, (), 'inputs')
+ intermediate_inputs = inputs + 1
+ outputs = layer.apply(intermediate_inputs)
+
+ self.assertEqual(len(layer.updates), 2)
+ self.assertEqual(len(layer.get_updates_for(None)), 1)
+ self.assertEqual(len(layer.get_updates_for([inputs])), 1)
+ self.assertEqual(len(layer.get_updates_for([intermediate_inputs])), 1)
+ self.assertEqual(len(layer.get_updates_for([outputs])), 0)
+
+ # Call same layer on new input, creating one more conditional update
+ inputs = array_ops.placeholder(dtypes.float32, (), 'inputs')
+ intermediate_inputs = inputs + 1
+ outputs = layer.apply(intermediate_inputs)
+
+ self.assertEqual(len(layer.updates), 3)
+ self.assertEqual(len(layer.get_updates_for(None)), 1)
+ # Check that we are successfully filtering out irrelevant updates
+ self.assertEqual(len(layer.get_updates_for([inputs])), 1)
+ self.assertEqual(len(layer.get_updates_for([intermediate_inputs])), 1)
+ self.assertEqual(len(layer.get_updates_for([outputs])), 0)
+
+ def testGetLossesFor(self):
+
+ class MyLayer(base_layers.Layer):
+
+ def build(self, input_shape):
+ self.a = self.add_variable('a',
+ (),
+ dtypes.float32,
+ trainable=False)
+ self.b = self.add_variable('b',
+ (),
+ dtypes.float32,
+ trainable=False)
+ self.add_loss(self.a)
+ self.built = True
+
+ def call(self, inputs):
+ self.add_loss(inputs, inputs=True)
+ return inputs + 1
+
+ layer = MyLayer()
+ inputs = array_ops.placeholder(dtypes.float32, (), 'inputs')
+ intermediate_inputs = inputs + 1
+ outputs = layer.apply(intermediate_inputs)
+
+ self.assertEqual(len(layer.losses), 2)
+ self.assertEqual(len(layer.get_losses_for(None)), 1)
+ self.assertEqual(len(layer.get_losses_for([inputs])), 1)
+ self.assertEqual(len(layer.get_losses_for([intermediate_inputs])), 1)
+ self.assertEqual(len(layer.get_losses_for([outputs])), 0)
+
+ # Call same layer on new input, creating one more conditional loss
+ inputs = array_ops.placeholder(dtypes.float32, (), 'inputs')
+ intermediate_inputs = inputs + 1
+ outputs = layer.apply(intermediate_inputs)
+
+ self.assertEqual(len(layer.losses), 3)
+ self.assertEqual(len(layer.get_losses_for(None)), 1)
+ # Check that we are successfully filtering out irrelevant losses
+ self.assertEqual(len(layer.get_losses_for([inputs])), 1)
+ self.assertEqual(len(layer.get_losses_for([intermediate_inputs])), 1)
+ self.assertEqual(len(layer.get_losses_for([outputs])), 0)
+
if __name__ == '__main__':
test.main()
# self.input_spec
# Private attributes to implement compatibility with Layer.
- self._per_input_losses = {}
- self._per_input_updates = {}
self._updates = []
self._losses = []
self._scope = None
Will only include updates that are either
unconditional, or conditional on inputs to this model
- (e.g. will not include updates that depend on tensors
- that aren't inputs to this model).
+ (e.g. will not include updates created by this model's layers when
+ those layers were called on inputs outside of the model).
+
+ Effectively, `network.updates` behaves like `layer.updates`.
+
+ Concrete example:
+
+ ```python
+ bn = keras.layers.BatchNormalization()
+ x1 = keras.layers.Input(shape=(10,))
+ _ = bn(x1) # This creates 2 updates.
+
+ x2 = keras.layers.Input(shape=(10,))
+ y2 = bn(x2) # This creates 2 more updates.
+
+ # The BN layer has now 4 updates.
+ self.assertEqual(len(bn.updates), 4)
+
+ # Let's create a model from x2 to y2.
+ model = keras.models.Model(x2, y2)
+
+ # The model does not list all updates from its underlying layers,
+ # but only the updates that are relevant to it. Updates created by layers
+ # outside of the model are discarded.
+ self.assertEqual(len(model.updates), 2)
+
+ # If you keep calling the model, you append to its updates, just like
+ # what happens for a layer.
+ x3 = keras.layers.Input(shape=(10,))
+ y3 = model(x3)
+ self.assertEqual(len(model.updates), 4)
+
+ # But if you call the inner BN layer independently, you don't affect
+ # the model's updates.
+ x4 = keras.layers.Input(shape=(10,))
+ _ = bn(x4)
+ self.assertEqual(len(model.updates), 4)
+ ```
Returns:
A list of update ops.
"""
if not self.trainable and not self.stateful:
return []
+
updates = []
for layer in self.layers:
- if hasattr(layer, 'updates'):
- # Collect updates that are dependent on inputs
- # that are part of the model.
- for node_index, node in enumerate(layer._inbound_nodes): # pylint: disable=protected-access
- node_key = _make_node_key(layer.name, node_index)
- if node_key in self._network_nodes:
- # The model owns this layer node.
- inputs = node.input_tensors
- updates += layer.get_updates_for(inputs)
- # Collect unconditional updates.
- updates += layer.get_updates_for(None)
- return updates
+ updates += layer.updates
+
+ # `updates` might contain irrelevant updates, so it needs to be filtered
+ # with respect to inputs the model has been called on.
+ relevant_inputs = []
+ for i in range(len(self._inbound_nodes)):
+ inputs = self.get_input_at(i)
+ if isinstance(inputs, list):
+ relevant_inputs += inputs
+ else:
+ relevant_inputs.append(inputs)
+ reachable = layers_util.get_reachable_from_inputs(relevant_inputs, updates)
+ relevant_conditional_updates = [x for x in updates if x in reachable]
+ unconditional_updates = [
+ x for x in updates if x._unconditional_update] # pylint: disable=protected-access
+ # A layer could be used multiple times in a nested structure,
+ # so the updates list must be de-duped.
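+ # Note that `list(set(...))` does not preserve insertion order; callers
+ # should not rely on the ordering of the returned update ops.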
+ return list(set(
+ relevant_conditional_updates + unconditional_updates + self._updates))
@property
def losses(self):
losses += layer.losses
return losses
- # Retrieve losses for all internal layers.
for layer in self.layers:
- if hasattr(layer, 'losses'):
- # Collect losses that are dependent on inputs
- # that are part of the model.
- for node_index, node in enumerate(layer._inbound_nodes): # pylint: disable=protected-access
- node_key = _make_node_key(layer.name, node_index)
- if node_key in self._network_nodes:
- # The model owns this layer node.
- inputs = node.input_tensors
- losses += layer.get_losses_for(inputs)
- # Collect unconditional losses.
- losses += layer.get_losses_for(None)
- # Add any potential unconditional model-level loss.
- losses += self.get_losses_for(None)
- return losses
+ losses += layer.losses
+
+ relevant_inputs = []
+ for i in range(len(self._inbound_nodes)):
+ inputs = self.get_input_at(i)
+ if isinstance(inputs, list):
+ relevant_inputs += inputs
+ else:
+ relevant_inputs.append(inputs)
+ reachable = layers_util.get_reachable_from_inputs(relevant_inputs, losses)
+ relevant_conditional_losses = [x for x in losses if x in reachable]
+ unconditional_losses = [
+ x for x in losses if x._unconditional_loss] # pylint: disable=protected-access
+ return list(set(
+ relevant_conditional_losses + unconditional_losses + self._losses))
@property
def trainable_weights(self):
layer, node_index, tensor_index = self._output_coordinates[i]
shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
output_shapes.append(layers_to_output_shapes[shape_key])
-
# Store in cache.
self._output_shape_cache[cache_key] = output_shapes
else:
# Apply activity regularizer if any:
layer.add_loss(regularization_losses, computed_tensors)
- if context.in_graph_mode():
- # Update model updates and losses:
- # Keep track of updates that depend on the inputs
- # (e.g. BN updates).
- self.add_update(layer.get_updates_for(computed_tensors), inputs)
- # Keep track of unconditional updates (e.g. a counter).
- self.add_update(layer.get_updates_for(None), None)
- # Keep track of losses that depend on the inputs
- # (e.g. activity regularizers).
- self.add_loss(layer.get_losses_for(computed_tensors), inputs)
- # Keep track of unconditional losses
- # (e.g. weight regularizers).
- self.add_loss(layer.get_losses_for(None), None)
-
# Update tensor_map.
for x, y, mask in zip(reference_output_tensors, output_tensors,
output_masks):
+ '_' + layers_util.object_list_uid(masks))
self._output_tensor_cache[cache_key] = output_tensors
self._output_mask_cache[cache_key] = output_masks
+
if output_shapes is not None:
input_shapes = [layers_util.static_shape(x) for x in inputs]
cache_key = layers_util.object_list_uid(input_shapes)
from tensorflow.python.layers import core as core_layers
from tensorflow.python.layers import network as network_layers
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.platform import test
class BaseLayerCompatibilityTest(test.TestCase):
- def test_get_updates_for(self):
- a = network_layers.Input(shape=(2,))
- dense_layer = core_layers.Dense(1)
- dense_layer.add_update(0, inputs=a)
- dense_layer.add_update(1, inputs=None)
+ def test_get_updates(self):
- self.assertEqual(dense_layer.get_updates_for(a), [0])
- self.assertEqual(dense_layer.get_updates_for(None), [1])
+ class MyLayer(base_layers.Layer):
- def test_get_losses_for(self):
- a = network_layers.Input(shape=(2,))
- dense_layer = core_layers.Dense(1)
- dense_layer.add_loss(0, inputs=a)
- dense_layer.add_loss(1, inputs=None)
+ def build(self, input_shape):
+ self.a = self.add_variable('a',
+ (1, 1),
+ 'float32',
+ trainable=False)
+ self.b = self.add_variable('b',
+ (1, 1),
+ 'float32',
+ trainable=False)
+ self.add_update(state_ops.assign_add(self.a, [[1.]]))
+ self.built = True
- self.assertEqual(dense_layer.get_losses_for(a), [0])
- self.assertEqual(dense_layer.get_losses_for(None), [1])
+ def call(self, inputs):
+ self.add_update(state_ops.assign_add(self.a, inputs),
+ inputs=True)
+ return inputs + 1
+
+ x1 = network_layers.Input(shape=(1,))
+ layer = MyLayer()
+ _ = layer.apply(x1)
+
+ self.assertEqual(len(layer.updates), 2)
+ self.assertEqual(len(layer.get_updates_for(x1)), 1)
+ self.assertEqual(len(layer.get_updates_for(None)), 1)
+
+ x2 = network_layers.Input(shape=(1,))
+ y2 = layer.apply(x2)
+
+ self.assertEqual(len(layer.updates), 3)
+ self.assertEqual(len(layer.get_updates_for(x1)), 1)
+ self.assertEqual(len(layer.get_updates_for(x2)), 1)
+ self.assertEqual(len(layer.get_updates_for(None)), 1)
+
+ network = network_layers.GraphNetwork(x2, y2)
+ self.assertEqual(len(network.updates), 2)
+ self.assertEqual(len(network.get_updates_for(x1)), 0)
+ self.assertEqual(len(network.get_updates_for(x2)), 1)
+ self.assertEqual(len(network.get_updates_for(None)), 1)
+
+ x3 = network_layers.Input(shape=(1,))
+ _ = layer.apply(x3)
+ self.assertEqual(len(network.updates), 2)
+
+ x4 = network_layers.Input(shape=(1,))
+ _ = network(x4)
+ self.assertEqual(len(network.updates), 3)
+ self.assertEqual(len(network.get_updates_for(x2)), 1)
+ self.assertEqual(len(network.get_updates_for(x4)), 1)
+ self.assertEqual(len(network.get_updates_for(None)), 1)
+
+ network.add_update(state_ops.assign_add(layer.a, [[1]]))
+ self.assertEqual(len(network.updates), 4)
+ self.assertEqual(len(network.get_updates_for(None)), 2)
+
+ network.add_update(state_ops.assign_add(layer.a, x4), inputs=True)
+ self.assertEqual(len(network.updates), 5)
+ self.assertEqual(len(network.get_updates_for(x4)), 2)
+
+ def test_get_losses(self):
+
+ class MyLayer(base_layers.Layer):
+
+ def build(self, input_shape):
+ self.a = self.add_variable('a',
+ (1, 1),
+ 'float32',
+ trainable=False)
+ self.b = self.add_variable('b',
+ (1, 1),
+ 'float32',
+ trainable=False)
+ self.add_loss(math_ops.reduce_sum(self.a))
+ self.built = True
+
+ def call(self, inputs):
+ self.add_loss(math_ops.reduce_sum(inputs),
+ inputs=True)
+ return inputs + 1
+
+ x1 = network_layers.Input(shape=(1,))
+ layer = MyLayer()
+ _ = layer.apply(x1)
+
+ self.assertEqual(len(layer.losses), 2)
+ self.assertEqual(len(layer.get_losses_for(x1)), 1)
+ self.assertEqual(len(layer.get_losses_for(None)), 1)
+
+ x2 = network_layers.Input(shape=(1,))
+ y2 = layer.apply(x2)
+
+ self.assertEqual(len(layer.losses), 3)
+ self.assertEqual(len(layer.get_losses_for(x1)), 1)
+ self.assertEqual(len(layer.get_losses_for(x2)), 1)
+ self.assertEqual(len(layer.get_losses_for(None)), 1)
+
+ network = network_layers.GraphNetwork(x2, y2)
+ self.assertEqual(len(network.losses), 2)
+ self.assertEqual(len(network.get_losses_for(x1)), 0)
+ self.assertEqual(len(network.get_losses_for(x2)), 1)
+ self.assertEqual(len(network.get_losses_for(None)), 1)
+
+ x3 = network_layers.Input(shape=(1,))
+ _ = layer.apply(x3)
+ self.assertEqual(len(network.losses), 2)
+
+ x4 = network_layers.Input(shape=(1,))
+ _ = network(x4)
+ self.assertEqual(len(network.losses), 3)
+ self.assertEqual(len(network.get_losses_for(x2)), 1)
+ self.assertEqual(len(network.get_losses_for(x4)), 1)
+ self.assertEqual(len(network.get_losses_for(None)), 1)
+
+ network.add_loss(math_ops.reduce_sum(layer.a))
+ self.assertEqual(len(network.losses), 4)
+ self.assertEqual(len(network.get_losses_for(None)), 2)
+
+ network.add_loss(math_ops.reduce_sum(x4), inputs=True)
+ self.assertEqual(len(network.losses), 5)
+ self.assertEqual(len(network.get_losses_for(x4)), 2)
def testTopologicalAttributes(self):
# test layer attributes / methods related to cross-layer connectivity.
def testNetworkAttributes(self):
x = network_layers.Input(shape=(32,))
- z = core_layers.Dense(2, kernel_regularizer=lambda x: 0.01 * (x**2))(x)
+ layer = core_layers.Dense(2, kernel_regularizer=lambda x: 0.01 * (x**2))
+ z = layer(x)
dense = core_layers.Dense(2, name='dense')
- dense.add_update(1)
+ dense.add_update(state_ops.assign_add(layer.kernel, layer.kernel * 2.))
y = dense(z)
net = network_layers.GraphNetwork(x, y)
return tuple(x.get_shape().as_list())
except ValueError:
return None
+
+
+def get_reachable_from_inputs(inputs, targets=None):
+ """Returns the set of tensors reachable from `inputs`.
+
+ Stops early once all targets have been found (`targets` is optional).
+
+ Only valid in Symbolic mode, not Eager mode.
+
+ Args:
+ inputs: List of tensors.
+ targets: List of tensors.
+
+ Returns:
+ A set of tensors reachable from the inputs (includes the inputs themselves).
+ """
+ reachable = set(inputs)
+ if targets:
+ targets = set(targets)
+ queue = inputs[:]
+
+ while queue:
+ x = queue.pop()
+ outputs = []
+ try:
+ consumers = x.consumers()
+ except AttributeError:
+ # Case where x is a variable type
+ consumers = [x.op]
+ for z in consumers:
+ consumer_outputs = z.outputs
+ if consumer_outputs:  # May be empty.
+ outputs += consumer_outputs
+
+ for y in outputs:
+ if y not in reachable:
+ reachable.add(y)
+ queue.insert(0, y)
+
+ if targets and targets.issubset(reachable):
+ return reachable
+ return reachable
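A quick trace of the traversal semantics (a sketch mirroring the unit test below): each popped tensor contributes its consumers' output tensors, so the result is the forward cone of `inputs`, and `targets` allows an early exit:

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops

    a = array_ops.placeholder(dtypes.float32, ())
    b = a + 1.
    c = b * 2.

    assert get_reachable_from_inputs([a]) == {a, b, c}
    # Early exit: the search stops once `b` is found, so `c` is not visited.
    assert get_reachable_from_inputs([a], targets=[b]) == {a, b}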
from __future__ import print_function
from tensorflow.python.layers import utils
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
self.assertEqual(3, utils.deconv_output_length(4, 2, 'full', 1))
self.assertEqual(6, utils.deconv_output_length(4, 2, 'full', 2))
+
+class GraphUtilsTest(test.TestCase):
+
+ def testGetReachableFromInputs(self):
+
+ with self.test_session():
+ pl_1 = array_ops.placeholder(shape=None, dtype='float32')
+ pl_2 = array_ops.placeholder(shape=None, dtype='float32')
+ pl_3 = array_ops.placeholder(shape=None, dtype='float32')
+ x_1 = pl_1 + pl_2
+ x_2 = pl_2 * 2
+ x_3 = pl_3 + 1
+ x_4 = x_1 + x_2
+ x_5 = x_3 * pl_1
+
+ self.assertEqual(
+ utils.get_reachable_from_inputs([pl_1]),
+ {pl_1, x_1, x_4, x_5})
+ self.assertEqual(
+ utils.get_reachable_from_inputs([pl_1, pl_2]),
+ {pl_1, pl_2, x_1, x_2, x_4, x_5})
+ self.assertEqual(
+ utils.get_reachable_from_inputs([pl_3]),
+ {pl_3, x_3, x_5})
+ self.assertEqual(
+ utils.get_reachable_from_inputs([x_3]),
+ {x_3, x_5})
+
+
if __name__ == '__main__':
test.main()
}
member_method {
name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
}
member_method {
name: "get_updates_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_weights"
}
member_method {
name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
}
member_method {
name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
}
member_method {
name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
}
member_method {
name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
}
member_method {
name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
}
member_method {
name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
}
member_method {
name: "get_updates_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_weights"
}
member_method {
name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
}
member_method {
name: "get_updates_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_weights"