Simplify and extend the management of input-conditional losses and updates.
authorFrancois Chollet <fchollet@google.com>
Fri, 9 Feb 2018 19:28:21 +0000 (11:28 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 9 Feb 2018 19:32:27 +0000 (11:32 -0800)
Instead of keeping track of dependencies manually, we rely on the TF graph structure to find dependencies. The resulting implementation is cleaner and more robust.

This does not change any existing behavior. It extends the current behavior by allowing `get_updates_for(inputs)` and `get_losses_for(inputs)` to be called from *any* tensors upstream of the layer, not just the immediate layer's inputs.

PiperOrigin-RevId: 185168680

21 files changed:
tensorflow/python/keras/_impl/keras/engine/topology_test.py
tensorflow/python/keras/_impl/keras/engine/training.py
tensorflow/python/keras/_impl/keras/layers/normalization_test.py
tensorflow/python/keras/_impl/keras/layers/recurrent.py
tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
tensorflow/python/keras/_impl/keras/layers/wrappers.py
tensorflow/python/keras/_impl/keras/models.py
tensorflow/python/layers/base.py
tensorflow/python/layers/base_test.py
tensorflow/python/layers/network.py
tensorflow/python/layers/network_test.py
tensorflow/python/layers/utils.py
tensorflow/python/layers/utils_test.py
tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt
tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt

index 85979d1..0673e42 100644 (file)
@@ -26,6 +26,8 @@ import numpy as np
 from tensorflow.python.framework import dtypes
 from tensorflow.python.keras._impl import keras
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
 from tensorflow.python.platform import test
 
 try:
@@ -42,22 +44,28 @@ except ImportError:
 class TopologyConstructionTest(test.TestCase):
 
   def test_get_updates_for(self):
-    a = keras.layers.Input(shape=(2,))
+    a = keras.layers.Input(shape=(1,))
     dense_layer = keras.layers.Dense(1)
-    dense_layer.add_update(0, inputs=a)
-    dense_layer.add_update(1, inputs=None)
+    dense_layer.build((None, 1))
+    update_1 = state_ops.assign_add(dense_layer.kernel, a)
+    update_2 = state_ops.assign_add(dense_layer.kernel, [[1.]])
+    dense_layer.add_update(update_1, inputs=a)
+    dense_layer.add_update(update_2, inputs=None)
 
-    self.assertListEqual(dense_layer.get_updates_for(a), [0])
-    self.assertListEqual(dense_layer.get_updates_for(None), [1])
+    self.assertListEqual(dense_layer.get_updates_for(a), [update_1])
+    self.assertListEqual(dense_layer.get_updates_for(None), [update_2])
 
   def test_get_losses_for(self):
-    a = keras.layers.Input(shape=(2,))
+    a = keras.layers.Input(shape=(1,))
     dense_layer = keras.layers.Dense(1)
-    dense_layer.add_loss(0, inputs=a)
-    dense_layer.add_loss(1, inputs=None)
-
-    self.assertListEqual(dense_layer.get_losses_for(a), [0])
-    self.assertListEqual(dense_layer.get_losses_for(None), [1])
+    dense_layer.build((None, 1))
+    loss_1 = math_ops.reduce_sum(a)
+    loss_2 = math_ops.reduce_sum(dense_layer.kernel)
+    dense_layer.add_loss(loss_1, inputs=a)
+    dense_layer.add_loss(loss_2, inputs=None)
+
+    self.assertListEqual(dense_layer.get_losses_for(a), [loss_1])
+    self.assertListEqual(dense_layer.get_losses_for(None), [loss_2])
 
   def test_trainable_weights(self):
     a = keras.layers.Input(shape=(2,))
index faca964..1185988 100644 (file)
@@ -1020,10 +1020,13 @@ class Model(Network):
 
       with K.name_scope('training'):
         with K.name_scope(self.optimizer.__class__.__name__):
-          training_updates = self.optimizer.get_updates(
+          # Training updates
+          updates = self.optimizer.get_updates(
               params=self._collected_trainable_weights, loss=self.total_loss)
-
-        updates = self.updates + training_updates
+        # Unconditional updates
+        updates += self.get_updates_for(None)
+        # Conditional updates relevant to this model
+        updates += self.get_updates_for(self._feed_inputs)
         # Gets loss and metrics. Updates weights at each call.
         self.train_function = K.function(
             inputs, [self.total_loss] + self.metrics_tensors,
index 39a90e5..2b3628c 100644 (file)
@@ -132,13 +132,19 @@ class NormalizationLayersTest(test.TestCase):
       model.compile('sgd', 'mse')
       model.train_on_batch(x, x)
 
-      assert len(model.updates) == 2
+      self.assertEqual(len(bn.updates), 4)
+      self.assertEqual(len(model.updates), 2)
+      self.assertEqual(len(model.get_updates_for(x1)), 0)
+      self.assertEqual(len(model.get_updates_for(x2)), 2)
 
       # Test model-level reuse
       x3 = keras.layers.Input(shape=(10,))
       y3 = model(x3)
-      new_model = keras.models.Model(x3, y3)
-      assert len(model.updates) == 2
+      new_model = keras.models.Model(x3, y3, name='new_model')
+
+      self.assertEqual(len(new_model.updates), 2)
+      self.assertEqual(len(model.updates), 4)
+      self.assertEqual(len(new_model.get_updates_for(x3)), 2)
       new_model.compile('sgd', 'mse')
       new_model.train_on_batch(x, x)
 
index 5c1b523..4bf6ae9 100644 (file)
@@ -202,17 +202,16 @@ class StackedRNNCells(Layer):
     losses = []
     for cell in self.cells:
       if isinstance(cell, Layer):
-        cell_losses = cell.losses
-        losses += cell_losses
-    return losses
+        losses += cell.losses
+    return losses + self._losses
 
-  def get_losses_for(self, inputs=None):
-    losses = []
+  @property
+  def updates(self):
+    updates = []
     for cell in self.cells:
       if isinstance(cell, Layer):
-        cell_losses = cell.get_losses_for(inputs)
-        losses += cell_losses
-    return losses
+        updates += cell.updates
+    return updates + self._updates
 
 
 @tf_export('keras.layers.RNN')
@@ -617,7 +616,7 @@ class RNN(Layer):
     if self.stateful:
       updates = []
       for i in range(len(states)):
-        updates.append((self.states[i], states[i]))
+        updates.append(K.update(self.states[i], states[i]))
       self.add_update(updates, inputs)
 
     if self.return_sequences:
@@ -777,15 +776,17 @@ class RNN(Layer):
 
   @property
   def losses(self):
+    losses = []
     if isinstance(self.cell, Layer):
-      return self.cell.losses
-    return []
+      losses += self.cell.losses
+    return losses + self._losses
 
-  def get_losses_for(self, inputs=None):
+  @property
+  def updates(self):
+    updates = []
     if isinstance(self.cell, Layer):
-      cell_losses = self.cell.get_losses_for(inputs)
-      return cell_losses + super(RNN, self).get_losses_for(inputs)
-    return super(RNN, self).get_losses_for(inputs)
+      updates += self.cell.updates
+    return updates + self._updates
 
 
 @tf_export('keras.layers.SimpleRNNCell')
@@ -2463,7 +2464,7 @@ class Recurrent(Layer):
     if self.stateful:
       updates = []
       for i in range(len(states)):
-        updates.append((self.states[i], states[i]))
+        updates.append(K.update(self.states[i], states[i]))
       self.add_update(updates, inputs)
 
     # Properly set learning phase
index a1407a2..ab48a63 100644 (file)
@@ -353,13 +353,10 @@ class RNNTest(test.TestCase):
       self.assertAllClose(y_np, y_np_3, atol=1e-4)
 
   def test_stacked_rnn_attributes(self):
-    cells = [keras.layers.LSTMCell(3),
-             keras.layers.LSTMCell(3, kernel_regularizer='l2')]
+    cells = [keras.layers.LSTMCell(1),
+             keras.layers.LSTMCell(1)]
     layer = keras.layers.RNN(cells)
-    layer.build((None, None, 5))
-
-    # Test regularization losses
-    self.assertEqual(len(layer.losses), 1)
+    layer.build((None, None, 1))
 
     # Test weights
     self.assertEqual(len(layer.trainable_weights), 6)
@@ -367,11 +364,32 @@ class RNNTest(test.TestCase):
     self.assertEqual(len(layer.trainable_weights), 3)
     self.assertEqual(len(layer.non_trainable_weights), 3)
 
-    # Test `get_losses_for`
-    x = keras.Input((None, 5))
-    y = keras.backend.sum(x)
-    cells[0].add_loss(y, inputs=x)
-    self.assertEqual(layer.get_losses_for(x), [y])
+    # Test `get_losses_for` and `losses`
+    x = keras.Input((None, 1))
+    loss_1 = keras.backend.sum(x)
+    loss_2 = keras.backend.sum(cells[0].kernel)
+    cells[0].add_loss(loss_1, inputs=x)
+    cells[0].add_loss(loss_2)
+    self.assertEqual(len(layer.losses), 2)
+    self.assertEqual(layer.get_losses_for(None), [loss_2])
+    self.assertEqual(layer.get_losses_for(x), [loss_1])
+
+    # Test `get_updates_for` and `updates`
+    cells = [keras.layers.LSTMCell(1),
+             keras.layers.LSTMCell(1)]
+    layer = keras.layers.RNN(cells)
+    layer.build((None, None, 1))
+
+    x = keras.Input((None, 1))
+    update_1 = keras.backend.update_add(
+        cells[0].kernel, x[0, 0, 0] * cells[0].kernel)
+    update_2 = keras.backend.update_add(
+        cells[0].kernel, keras.backend.ones_like(cells[0].kernel))
+    cells[0].add_update(update_1, inputs=x)
+    cells[0].add_update(update_2)
+    self.assertEqual(len(layer.updates), 2)
+    self.assertEqual(layer.get_updates_for(None), [update_2])
+    self.assertEqual(layer.get_updates_for(x), [update_1])
 
   def test_rnn_dynamic_trainability(self):
     layer_class = keras.layers.SimpleRNN
index c697bce..f053aa1 100644 (file)
@@ -71,34 +71,11 @@ class Wrapper(Layer):
 
   @property
   def updates(self):
-    if hasattr(self.layer, 'updates'):
-      return self.layer.updates
-    return []
-
-  def get_updates_for(self, inputs=None):
-    # If the wrapper modifies the inputs, use the modified inputs to
-    # get the updates from the inner layer.
-    inner_inputs = inputs
-    if inputs is not None:
-      uid = tf_layers_util.object_list_uid(inputs)
-      if uid in self._input_map:
-        inner_inputs = self._input_map[uid]
-
-    updates = self.layer.get_updates_for(inner_inputs)
-    updates += super(Wrapper, self).get_updates_for(inputs)
-    return updates
+    return self.layer.updates + self._updates
 
   @property
   def losses(self):
-    if hasattr(self.layer, 'losses'):
-      return self.layer.losses
-    return []
-
-  def get_losses_for(self, inputs=None):
-    if inputs is None:
-      losses = self.layer.get_losses_for(None)
-      return losses + super(Wrapper, self).get_losses_for(None)
-    return super(Wrapper, self).get_losses_for(inputs)
+    return self.layer.losses + self._losses
 
   def get_weights(self):
     return self.layer.get_weights()
index 20736d2..f5d44ef 100644 (file)
@@ -428,8 +428,6 @@ class Sequential(Model):
     # Used by Layer base class.
     self._dtype = None
     self._activity_regularizer = None
-    self._per_input_losses = {}
-    self._per_input_updates = {}
 
     # The following properties are not actually used by Keras;
     # they exist for compatibility with TF's variable scoping mechanism.
@@ -644,34 +642,6 @@ class Sequential(Model):
     return weights
 
   @property
-  def updates(self):
-    if not self.built:
-      self.build()
-    return self.model.updates
-
-  @property
-  def state_updates(self):
-    if not self.built:
-      self.build()
-    return self.model.state_updates
-
-  def get_updates_for(self, inputs):
-    if not self.built:
-      self.build()
-    return self.model.get_updates_for(inputs)
-
-  @property
-  def losses(self):
-    if not self.built:
-      self.build()
-    return self.model.losses
-
-  def get_losses_for(self, inputs):
-    if not self.built:
-      self.build()
-    return self.model.get_losses_for(inputs)
-
-  @property
   def regularizers(self):
     if not self.built:
       self.build()
index 3a3c559..0d78ef2 100644 (file)
@@ -127,8 +127,6 @@ class Layer(object):
     self._losses = []
     self._reuse = kwargs.get('_reuse')
     self._graph = ops.get_default_graph()
-    self._per_input_losses = {}
-    self._per_input_updates = {}
     self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
     call_fn_args = estimator_util.fn_args(self.call)
     self._compute_previous_mask = ('mask' in call_fn_args or
@@ -252,39 +250,32 @@ class Layer(object):
 
     Arguments:
       updates: Update op, or list/tuple of update ops.
-      inputs: Optional input tensor(s) that the update(s) depend on. Must
-        match the `inputs` argument passed to the `__call__` method at the time
-        the updates are created. If `None` is passed, the updates are assumed
-        to be unconditional, and will apply across all dataflows of the layer.
+      inputs: If anything other than None is passed, it signals the updates
+        are conditional on some of the layer's inputs,
+        and thus they should only be run where these inputs are available.
+        This is the case for BatchNormalization updates, for instance.
+        If None, the updates will be taken into account unconditionally,
+        and you are responsible for making sure that any dependency they might
+        have is available at runtime.
+        A step counter might fall into this category.
     """
     if context.in_eager_mode():
       return  # Updates already applied when in eager mode.
+
     updates = _to_list(updates)
-    if not updates:
-      return
     self._updates += updates
-    if inputs is not None:
-      inputs = nest.flatten(inputs)
-    if not inputs:
-      inputs = None
-    if inputs is not None:
-      # We compute an ID that uniquely identifies the list of tensors.
-      # This ID is order-sensitive.
-      inputs_hash = layers_util.object_list_uid(inputs)
+    if inputs is None:
+      for u in updates:
+        u._unconditional_update = True  # pylint: disable=protected-access
     else:
-      inputs_hash = None
-    if inputs_hash not in self._per_input_updates:
-      self._per_input_updates[inputs_hash] = []
-    self._per_input_updates[inputs_hash] += updates
+      for u in updates:
+        u._unconditional_update = False  # pylint: disable=protected-access
 
   def get_updates_for(self, inputs):
     """Retrieves updates relevant to a specific set of inputs.
 
     Arguments:
       inputs: Input tensor or list/tuple of input tensors.
-        Must match the `inputs` argument passed to the `__call__` method
-        at the time the updates were created.
-        If you pass `inputs=None`, unconditional updates are returned.
 
     Returns:
       List of update ops of the layer that depend on `inputs`.
@@ -293,18 +284,24 @@ class Layer(object):
       RuntimeError: If called in Eager mode.
     """
     if context.in_eager_mode():
-      raise RuntimeError('Layer.get_updates_for not supported in Eager mode.')
+      raise RuntimeError('`get_updates_for()` not supported in Eager mode.')
+
+    # Updates disabled if layer is not trainable and not explicitly stateful.
     if not self.trainable and not self.stateful:
       return []
-    if inputs is not None:
-      inputs = nest.flatten(inputs)
-    if not inputs:
-      inputs = None
-    if inputs is not None:
-      inputs_hash = layers_util.object_list_uid(inputs)
-    else:
-      inputs_hash = None
-    return self._per_input_updates.get(inputs_hash, [])
+
+    if inputs is None:
+      # Requesting unconditional updates.
+      return [x for x in self.updates if x._unconditional_update]  # pylint: disable=protected-access
+
+    # Requesting input-conditional updates.
+    inputs = nest.flatten(inputs)
+    reachable = layers_util.get_reachable_from_inputs(inputs, self.updates)
+    updates = []
+    for update in self.updates:
+      if update in reachable:
+        updates.append(update)
+    return updates
 
   @property
   def losses(self):
@@ -344,9 +341,11 @@ class Layer(object):
 
     Arguments:
       losses: Loss tensor, or list/tuple of tensors.
-      inputs: Optional input tensor(s) that the loss(es) depend on. Must
-        match the `inputs` argument passed to the `__call__` method at the time
-        the losses are created. If `None` is passed, the losses are assumed
+      inputs: If anything other than None is passed, it signals the losses
+        are conditional on some of the layer's inputs,
+        and thus they should only be run where these inputs are available.
+        This is the case for activity regularization losses, for instance.
+        If `None` is passed, the losses are assumed
         to be unconditional, and will apply across all dataflows of the layer
         (e.g. weight regularization losses).
 
@@ -354,24 +353,25 @@ class Layer(object):
       RuntimeError: If called in Eager mode.
     """
     if context.in_eager_mode():
+      # TODO(fchollet): it should be possible (and highly desirable) to support
+      # `add_loss` in eager mode. This allows great convenience and flexibility
+      # in defining custom losses on the fly (e.g. in VAEs).
+      # Simply appending the loss value to `self._losses`
+      # is the correct behavior.
+      # The only caveat is that we need to force the user to only call
+      # `add_loss` from inside a model or Layer's `call` method
+      # (otherwise the loss computation cannot be backproped through).
       raise RuntimeError('Layer.add_loss not supported in Eager mode.')
+
     losses = _to_list(losses)
-    if not losses:
-      return
     self._losses += losses
-    if inputs is not None:
-      inputs = nest.flatten(inputs)
-    if not inputs:
-      inputs = None
-    if inputs is not None:
-      # We compute an ID that uniquely identifies the list of tensors.
-      # This ID is order-sensitive.
-      inputs_hash = layers_util.object_list_uid(inputs)
+    if inputs is None:
+      for loss in losses:
+        loss._unconditional_loss = True  # pylint: disable=protected-access
     else:
-      inputs_hash = None
-    if inputs_hash not in self._per_input_losses:
-      self._per_input_losses[inputs_hash] = []
-    self._per_input_losses[inputs_hash] += losses
+      for loss in losses:
+        loss._unconditional_loss = False  # pylint: disable=protected-access
+    # TODO(fchollet): deprecate collection below.
     _add_elements_to_collection(losses, ops.GraphKeys.REGULARIZATION_LOSSES)
 
   def get_losses_for(self, inputs):
@@ -379,10 +379,6 @@ class Layer(object):
 
     Arguments:
       inputs: Input tensor or list/tuple of input tensors.
-        Must match the `inputs` argument passed to the `__call__`
-        method at the time the losses were created.
-        If you pass `inputs=None`, unconditional losses are returned,
-        such as weight regularization losses.
 
     Returns:
       List of loss tensors of the layer that depend on `inputs`.
@@ -392,15 +388,23 @@ class Layer(object):
     """
     if context.in_eager_mode():
       raise RuntimeError('Layer.get_losses_for not supported in Eager mode.')
-    if inputs is not None:
-      inputs = nest.flatten(inputs)
-    if not inputs:
-      inputs = None
-    if inputs is not None:
-      inputs_hash = layers_util.object_list_uid(inputs)
-    else:
-      inputs_hash = None
-    return self._per_input_losses.get(inputs_hash, [])
+
+    if inputs is None:
+      # Requesting unconditional losses.
+      return [x for x in self.losses if x._unconditional_loss]  # pylint: disable=protected-access
+
+    # Requesting input-conditional losses.
+    inputs = nest.flatten(inputs)
+    # Retrieve the set of tensors in the TF graph that depend on `inputs`.
+    # The losses we want to return will be part of this set.
+    # To avoid unnecessary work, we stop the search in case all of
+    # `self.losses` have been retrieved.
+    reachable = layers_util.get_reachable_from_inputs(inputs, self.losses)
+    losses = []
+    for loss in self.losses:
+      if loss in reachable:
+        losses.append(loss)
+    return losses
 
   def build(self, _):
     """Creates the variables of the layer."""
index 06ba214..91b8988 100644 (file)
@@ -31,6 +31,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import test
 
@@ -555,6 +556,93 @@ class BaseLayerTest(test.TestCase):
         self.assertEqual(len(layer.trainable_variables), 1)
         self.assertEqual(layer.variables[0].graph, outer_graph)
 
+  def testGetUpdateFor(self):
+
+    class MyLayer(base_layers.Layer):
+
+      def build(self, input_shape):
+        self.a = self.add_variable('a',
+                                   (),
+                                   dtypes.float32,
+                                   trainable=False)
+        self.b = self.add_variable('b',
+                                   (),
+                                   dtypes.float32,
+                                   trainable=False)
+        self.add_update(state_ops.assign_add(self.a, 1., name='b_update'))
+        self.built = True
+
+      def call(self, inputs):
+        self.add_update(state_ops.assign_add(self.a, inputs, name='a_update'),
+                        inputs=True)
+        return inputs + 1
+
+    layer = MyLayer()
+    inputs = array_ops.placeholder(dtypes.float32, (), 'inputs')
+    intermediate_inputs = inputs + 1
+    outputs = layer.apply(intermediate_inputs)
+
+    self.assertEqual(len(layer.updates), 2)
+    self.assertEqual(len(layer.get_updates_for(None)), 1)
+    self.assertEqual(len(layer.get_updates_for([inputs])), 1)
+    self.assertEqual(len(layer.get_updates_for([intermediate_inputs])), 1)
+    self.assertEqual(len(layer.get_updates_for([outputs])), 0)
+
+    # Call same layer on new input, creating one more conditional update
+    inputs = array_ops.placeholder(dtypes.float32, (), 'inputs')
+    intermediate_inputs = inputs + 1
+    outputs = layer.apply(intermediate_inputs)
+
+    self.assertEqual(len(layer.updates), 3)
+    self.assertEqual(len(layer.get_updates_for(None)), 1)
+    # Check that we are successfully filtering out irrelevant updates
+    self.assertEqual(len(layer.get_updates_for([inputs])), 1)
+    self.assertEqual(len(layer.get_updates_for([intermediate_inputs])), 1)
+    self.assertEqual(len(layer.get_updates_for([outputs])), 0)
+
+  def testGetLossesFor(self):
+
+    class MyLayer(base_layers.Layer):
+
+      def build(self, input_shape):
+        self.a = self.add_variable('a',
+                                   (),
+                                   dtypes.float32,
+                                   trainable=False)
+        self.b = self.add_variable('b',
+                                   (),
+                                   dtypes.float32,
+                                   trainable=False)
+        self.add_loss(self.a)
+        self.built = True
+
+      def call(self, inputs):
+        self.add_loss(inputs, inputs=True)
+        return inputs + 1
+
+    layer = MyLayer()
+    inputs = array_ops.placeholder(dtypes.float32, (), 'inputs')
+    intermediate_inputs = inputs + 1
+    outputs = layer.apply(intermediate_inputs)
+
+    self.assertEqual(len(layer.losses), 2)
+    self.assertEqual(len(layer.get_losses_for(None)), 1)
+    self.assertEqual(len(layer.get_losses_for([inputs])), 1)
+    self.assertEqual(len(layer.get_losses_for([intermediate_inputs])), 1)
+    self.assertEqual(len(layer.get_losses_for([outputs])), 0)
+
+    # Call same layer on new input, creating one more conditional loss
+    inputs = array_ops.placeholder(dtypes.float32, (), 'inputs')
+    intermediate_inputs = inputs + 1
+    outputs = layer.apply(intermediate_inputs)
+
+    self.assertEqual(len(layer.losses), 3)
+    self.assertEqual(len(layer.get_losses_for(None)), 1)
+    # Check that we are successfully filtering out irrelevant losses
+    self.assertEqual(len(layer.get_losses_for([inputs])), 1)
+    self.assertEqual(len(layer.get_losses_for([intermediate_inputs])), 1)
+    self.assertEqual(len(layer.get_losses_for([outputs])), 0)
+
 
 if __name__ == '__main__':
   test.main()
index 6de8f35..499f53d 100644 (file)
@@ -256,8 +256,6 @@ class GraphNetwork(base.Layer):
     # self.input_spec
 
     # Private attributes to implement compatibility with Layer.
-    self._per_input_losses = {}
-    self._per_input_updates = {}
     self._updates = []
     self._losses = []
     self._scope = None
@@ -587,28 +585,72 @@ class GraphNetwork(base.Layer):
 
     Will only include updates that are either
     unconditional, or conditional on inputs to this model
-    (e.g. will not include updates that depend on tensors
-    that aren't inputs to this model).
+    (e.g. will not include updates that were created by layers of this model
+    outside of the model).
+
+    Effectively, `network.updates` behaves like `layer.updates`.
+
+    Concrete example:
+
+    ```python
+      bn = keras.layers.BatchNormalization()
+      x1 = keras.layers.Input(shape=(10,))
+      _ = bn(x1)  # This creates 2 updates.
+
+      x2 = keras.layers.Input(shape=(10,))
+      y2 = bn(x2)  # This creates 2 more updates.
+
+      # The BN layer has now 4 updates.
+      self.assertEqual(len(bn.updates), 4)
+
+      # Let's create a model from x2 to y2.
+      model = keras.models.Model(x2, y2)
+
+      # The model does not list all updates from its underlying layers,
+      # but only the updates that are relevant to it. Updates created by layers
+      # outside of the model are discarded.
+      self.assertEqual(len(model.updates), 2)
+
+      # If you keep calling the model, you append to its updates, just like
+      # what happens for a layer.
+      x3 = keras.layers.Input(shape=(10,))
+      y3 = model(x3)
+      self.assertEqual(len(model.updates), 4)
+
+      # But if you call the inner BN layer independently, you don't affect
+      # the model's updates.
+      x4 = keras.layers.Input(shape=(10,))
+      _ = bn(x4)
+      self.assertEqual(len(model.updates), 4)
+    ```
 
     Returns:
         A list of update ops.
     """
     if not self.trainable and not self.stateful:
       return []
+
     updates = []
     for layer in self.layers:
-      if hasattr(layer, 'updates'):
-        # Collect updates that are dependent on inputs
-        # that are part of the model.
-        for node_index, node in enumerate(layer._inbound_nodes):  # pylint: disable=protected-access
-          node_key = _make_node_key(layer.name, node_index)
-          if node_key in self._network_nodes:
-            # The model owns this layer node.
-            inputs = node.input_tensors
-            updates += layer.get_updates_for(inputs)
-        # Collect unconditional updates.
-        updates += layer.get_updates_for(None)
-    return updates
+      updates += layer.updates
+
+    # `updates` might contain irrelevant updates, so it needs to be filtered
+    # with respect to inputs the model has been called on.
+    relevant_inputs = []
+    for i in range(len(self._inbound_nodes)):
+      inputs = self.get_input_at(i)
+      if isinstance(inputs, list):
+        relevant_inputs += inputs
+      else:
+        relevant_inputs.append(inputs)
+    reachable = layers_util.get_reachable_from_inputs(relevant_inputs, updates)
+    relevant_conditional_updates = [x for x in updates if x in reachable]
+    unconditional_updates = [
+        x for x in updates if x._unconditional_update]  # pylint: disable=protected-access
+    # A layer could be used multiple times in a nested structure,
+    # so the updates list must be de-duped.
+    return list(set(
+        relevant_conditional_updates + unconditional_updates + self._updates))
 
   @property
   def losses(self):
@@ -628,22 +670,22 @@ class GraphNetwork(base.Layer):
         losses += layer.losses
       return losses
 
-    # Retrieve losses for all internal layers.
     for layer in self.layers:
-      if hasattr(layer, 'losses'):
-        # Collect losses that are dependent on inputs
-        # that are part of the model.
-        for node_index, node in enumerate(layer._inbound_nodes):  # pylint: disable=protected-access
-          node_key = _make_node_key(layer.name, node_index)
-          if node_key in self._network_nodes:
-            # The model owns this layer node.
-            inputs = node.input_tensors
-            losses += layer.get_losses_for(inputs)
-        # Collect unconditional losses.
-        losses += layer.get_losses_for(None)
-    # Add any potential unconditional model-level loss.
-    losses += self.get_losses_for(None)
-    return losses
+      losses += layer.losses
+
+    relevant_inputs = []
+    for i in range(len(self._inbound_nodes)):
+      inputs = self.get_input_at(i)
+      if isinstance(inputs, list):
+        relevant_inputs += inputs
+      else:
+        relevant_inputs.append(inputs)
+    reachable = layers_util.get_reachable_from_inputs(relevant_inputs, losses)
+    relevant_conditional_losses = [x for x in losses if x in reachable]
+    unconditional_losses = [
+        x for x in losses if x._unconditional_loss]  # pylint: disable=protected-access
+    return list(set(
+        relevant_conditional_losses + unconditional_losses + self._losses))
 
   @property
   def trainable_weights(self):
@@ -805,7 +847,6 @@ class GraphNetwork(base.Layer):
           layer, node_index, tensor_index = self._output_coordinates[i]
           shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
           output_shapes.append(layers_to_output_shapes[shape_key])
-
         # Store in cache.
         self._output_shape_cache[cache_key] = output_shapes
     else:
@@ -915,20 +956,6 @@ class GraphNetwork(base.Layer):
                 # Apply activity regularizer if any:
                 layer.add_loss(regularization_losses, computed_tensors)
 
-          if context.in_graph_mode():
-            # Update model updates and losses:
-            # Keep track of updates that depend on the inputs
-            # (e.g. BN updates).
-            self.add_update(layer.get_updates_for(computed_tensors), inputs)
-            # Keep track of unconditional updates (e.g. a counter).
-            self.add_update(layer.get_updates_for(None), None)
-            # Keep track of losses that depend on the inputs
-            # (e.g. activity regularizers).
-            self.add_loss(layer.get_losses_for(computed_tensors), inputs)
-            # Keep track of unconditional losses
-            # (e.g. weight regularizers).
-            self.add_loss(layer.get_losses_for(None), None)
-
           # Update tensor_map.
           for x, y, mask in zip(reference_output_tensors, output_tensors,
                                 output_masks):
@@ -958,6 +985,7 @@ class GraphNetwork(base.Layer):
                    + '_' + layers_util.object_list_uid(masks))
       self._output_tensor_cache[cache_key] = output_tensors
       self._output_mask_cache[cache_key] = output_masks
+
       if output_shapes is not None:
         input_shapes = [layers_util.static_shape(x) for x in inputs]
         cache_key = layers_util.object_list_uid(input_shapes)
index 7a2c7fb..f46ebdf 100644 (file)
@@ -27,29 +27,137 @@ from tensorflow.python.layers import base as base_layers
 from tensorflow.python.layers import core as core_layers
 from tensorflow.python.layers import network as network_layers
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import state_ops
 from tensorflow.python.platform import test
 
 
 class BaseLayerCompatibilityTest(test.TestCase):
 
-  def test_get_updates_for(self):
-    a = network_layers.Input(shape=(2,))
-    dense_layer = core_layers.Dense(1)
-    dense_layer.add_update(0, inputs=a)
-    dense_layer.add_update(1, inputs=None)
+  def test_get_updates(self):
 
-    self.assertEqual(dense_layer.get_updates_for(a), [0])
-    self.assertEqual(dense_layer.get_updates_for(None), [1])
+    class MyLayer(base_layers.Layer):
 
-  def test_get_losses_for(self):
-    a = network_layers.Input(shape=(2,))
-    dense_layer = core_layers.Dense(1)
-    dense_layer.add_loss(0, inputs=a)
-    dense_layer.add_loss(1, inputs=None)
+      def build(self, input_shape):
+        self.a = self.add_variable('a',
+                                   (1, 1),
+                                   'float32',
+                                   trainable=False)
+        self.b = self.add_variable('b',
+                                   (1, 1),
+                                   'float32',
+                                   trainable=False)
+        self.add_update(state_ops.assign_add(self.a, [[1.]]))
+        self.built = True
 
-    self.assertEqual(dense_layer.get_losses_for(a), [0])
-    self.assertEqual(dense_layer.get_losses_for(None), [1])
+      def call(self, inputs):
+        self.add_update(state_ops.assign_add(self.a, inputs),
+                        inputs=True)
+        return inputs + 1
+
+    x1 = network_layers.Input(shape=(1,))
+    layer = MyLayer()
+    _ = layer.apply(x1)
+
+    self.assertEqual(len(layer.updates), 2)
+    self.assertEqual(len(layer.get_updates_for(x1)), 1)
+    self.assertEqual(len(layer.get_updates_for(None)), 1)
+
+    x2 = network_layers.Input(shape=(1,))
+    y2 = layer.apply(x2)
+
+    self.assertEqual(len(layer.updates), 3)
+    self.assertEqual(len(layer.get_updates_for(x1)), 1)
+    self.assertEqual(len(layer.get_updates_for(x2)), 1)
+    self.assertEqual(len(layer.get_updates_for(None)), 1)
+
+    network = network_layers.GraphNetwork(x2, y2)
+    self.assertEqual(len(network.updates), 2)
+    self.assertEqual(len(network.get_updates_for(x1)), 0)
+    self.assertEqual(len(network.get_updates_for(x2)), 1)
+    self.assertEqual(len(network.get_updates_for(None)), 1)
+
+    x3 = network_layers.Input(shape=(1,))
+    _ = layer.apply(x3)
+    self.assertEqual(len(network.updates), 2)
+
+    x4 = network_layers.Input(shape=(1,))
+    _ = network(x4)
+    self.assertEqual(len(network.updates), 3)
+    self.assertEqual(len(network.get_updates_for(x2)), 1)
+    self.assertEqual(len(network.get_updates_for(x4)), 1)
+    self.assertEqual(len(network.get_updates_for(None)), 1)
+
+    network.add_update(state_ops.assign_add(layer.a, [[1]]))
+    self.assertEqual(len(network.updates), 4)
+    self.assertEqual(len(network.get_updates_for(None)), 2)
+
+    network.add_update(state_ops.assign_add(layer.a, x4), inputs=True)
+    self.assertEqual(len(network.updates), 5)
+    self.assertEqual(len(network.get_updates_for(x4)), 2)
+
+  def test_get_losses(self):
+
+    class MyLayer(base_layers.Layer):
+
+      def build(self, input_shape):
+        self.a = self.add_variable('a',
+                                   (1, 1),
+                                   'float32',
+                                   trainable=False)
+        self.b = self.add_variable('b',
+                                   (1, 1),
+                                   'float32',
+                                   trainable=False)
+        self.add_loss(math_ops.reduce_sum(self.a))
+        self.built = True
+
+      def call(self, inputs):
+        self.add_loss(math_ops.reduce_sum(inputs),
+                      inputs=True)
+        return inputs + 1
+
+    x1 = network_layers.Input(shape=(1,))
+    layer = MyLayer()
+    _ = layer.apply(x1)
+
+    self.assertEqual(len(layer.losses), 2)
+    self.assertEqual(len(layer.get_losses_for(x1)), 1)
+    self.assertEqual(len(layer.get_losses_for(None)), 1)
+
+    x2 = network_layers.Input(shape=(1,))
+    y2 = layer.apply(x2)
+
+    self.assertEqual(len(layer.losses), 3)
+    self.assertEqual(len(layer.get_losses_for(x1)), 1)
+    self.assertEqual(len(layer.get_losses_for(x2)), 1)
+    self.assertEqual(len(layer.get_losses_for(None)), 1)
+
+    network = network_layers.GraphNetwork(x2, y2)
+    self.assertEqual(len(network.losses), 2)
+    self.assertEqual(len(network.get_losses_for(x1)), 0)
+    self.assertEqual(len(network.get_losses_for(x2)), 1)
+    self.assertEqual(len(network.get_losses_for(None)), 1)
+
+    x3 = network_layers.Input(shape=(1,))
+    _ = layer.apply(x3)
+    self.assertEqual(len(network.losses), 2)
+
+    x4 = network_layers.Input(shape=(1,))
+    _ = network(x4)
+    self.assertEqual(len(network.losses), 3)
+    self.assertEqual(len(network.get_losses_for(x2)), 1)
+    self.assertEqual(len(network.get_losses_for(x4)), 1)
+    self.assertEqual(len(network.get_losses_for(None)), 1)
+
+    network.add_loss(math_ops.reduce_sum(layer.a))
+    self.assertEqual(len(network.losses), 4)
+    self.assertEqual(len(network.get_losses_for(None)), 2)
+
+    network.add_loss(math_ops.reduce_sum(x4), inputs=True)
+    self.assertEqual(len(network.losses), 5)
+    self.assertEqual(len(network.get_losses_for(x4)), 2)
 
   def testTopologicalAttributes(self):
     # test layer attributes / methods related to cross-layer connectivity.
@@ -299,9 +407,10 @@ class NetworkTest(test.TestCase):
 
   def testNetworkAttributes(self):
     x = network_layers.Input(shape=(32,))
-    z = core_layers.Dense(2, kernel_regularizer=lambda x: 0.01 * (x**2))(x)
+    layer = core_layers.Dense(2, kernel_regularizer=lambda x: 0.01 * (x**2))
+    z = layer(x)
     dense = core_layers.Dense(2, name='dense')
-    dense.add_update(1)
+    dense.add_update(state_ops.assign_add(layer.kernel, layer.kernel * 2.))
     y = dense(z)
     net = network_layers.GraphNetwork(x, y)
 
index 7407d9a..1bbf4e6 100644 (file)
@@ -255,3 +255,45 @@ def static_shape(x):
     return tuple(x.get_shape().as_list())
   except ValueError:
     return None
+
+
+def get_reachable_from_inputs(inputs, targets=None):
+  """Returns the set of tensors reachable from `inputs`.
+
+  Stops if all targets have been found (target is optional).
+
+  Only valid in Symbolic mode, not Eager mode.
+
+  Args:
+    inputs: List of tensors.
+    targets: List of tensors.
+
+  Returns:
+    A set of tensors reachable from the inputs (includes the inputs themselves).
+  """
+  reachable = set(inputs)
+  if targets:
+    targets = set(targets)
+  queue = inputs[:]
+
+  while queue:
+    x = queue.pop()
+    outputs = []
+    try:
+      consumers = x.consumers()
+    except AttributeError:
+      # Case where x is a variable type
+      consumers = [x.op]
+    for z in consumers:
+      consumer_outputs = z.outputs
+      if consumer_outputs:  # May be None
+        outputs += consumer_outputs
+
+    for y in outputs:
+      if y not in reachable:
+        reachable.add(y)
+        queue.insert(0, y)
+
+    if targets and targets.issubset(reachable):
+      return reachable
+  return reachable
index a560f6b..c941aad 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.layers import utils
+from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 
 
@@ -87,5 +88,34 @@ class ConvUtilsTest(test.TestCase):
     self.assertEqual(3, utils.deconv_output_length(4, 2, 'full', 1))
     self.assertEqual(6, utils.deconv_output_length(4, 2, 'full', 2))
 
+
+class GraphUtilsTest(test.TestCase):
+
+  def testGetReachableFromInputs(self):
+
+    with self.test_session():
+      pl_1 = array_ops.placeholder(shape=None, dtype='float32')
+      pl_2 = array_ops.placeholder(shape=None, dtype='float32')
+      pl_3 = array_ops.placeholder(shape=None, dtype='float32')
+      x_1 = pl_1 + pl_2
+      x_2 = pl_2 * 2
+      x_3 = pl_3 + 1
+      x_4 = x_1 + x_2
+      x_5 = x_3 * pl_1
+
+      self.assertEqual(
+          utils.get_reachable_from_inputs([pl_1]),
+          {pl_1, x_1, x_4, x_5})
+      self.assertEqual(
+          utils.get_reachable_from_inputs([pl_1, pl_2]),
+          {pl_1, pl_2, x_1, x_2, x_4, x_5})
+      self.assertEqual(
+          utils.get_reachable_from_inputs([pl_3]),
+          {pl_3, x_3, x_5})
+      self.assertEqual(
+          utils.get_reachable_from_inputs([x_3]),
+          {x_3, x_5})
+
+
 if __name__ == '__main__':
   test.main()
index 2f5e65a..db26c3e 100644 (file)
@@ -159,7 +159,7 @@ tf_class {
   }
   member_method {
     name: "get_losses_for"
-    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "get_output_at"
@@ -175,7 +175,7 @@ tf_class {
   }
   member_method {
     name: "get_updates_for"
-    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "get_weights"
index d0f6d2a..c741d4d 100644 (file)
@@ -227,7 +227,7 @@ tf_class {
   }
   member_method {
     name: "get_losses_for"
-    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "get_output_at"
index 0036d68..29d9cf7 100644 (file)
@@ -231,7 +231,7 @@ tf_class {
   }
   member_method {
     name: "get_losses_for"
-    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "get_output_at"
index b29f65d..ad539a7 100644 (file)
@@ -162,7 +162,7 @@ tf_class {
   }
   member_method {
     name: "get_losses_for"
-    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "get_output_at"
index b875898..6fafc77 100644 (file)
@@ -219,7 +219,7 @@ tf_class {
   }
   member_method {
     name: "get_losses_for"
-    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "get_output_at"
index db9f90c..90c37bd 100644 (file)
@@ -158,7 +158,7 @@ tf_class {
   }
   member_method {
     name: "get_losses_for"
-    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "get_output_at"
index 2a7059d..40aa782 100644 (file)
@@ -155,7 +155,7 @@ tf_class {
   }
   member_method {
     name: "get_losses_for"
-    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "get_output_at"
@@ -171,7 +171,7 @@ tf_class {
   }
   member_method {
     name: "get_updates_for"
-    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "get_weights"
index 58bffa0..27a5438 100644 (file)
@@ -154,7 +154,7 @@ tf_class {
   }
   member_method {
     name: "get_losses_for"
-    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "get_output_at"
@@ -170,7 +170,7 @@ tf_class {
   }
   member_method {
     name: "get_updates_for"
-    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "get_weights"