Add support for explicit `training` argument in subclassed models.
author    Francois Chollet <fchollet@google.com>
          Fri, 16 Feb 2018 23:02:22 +0000 (15:02 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Fri, 16 Feb 2018 23:05:56 +0000 (15:05 -0800)
PiperOrigin-RevId: 186051752
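
For context, the user-facing behavior this commit enables: a subclassed model may declare an explicit `training` argument in `call`, and the built-in training/inference routines pass the appropriate value automatically. A minimal sketch of the resulting usage (the model and data names here are illustrative, not part of the commit):

```python
import numpy as np
import tensorflow as tf

class DropoutModel(tf.keras.Model):
  """Subclassed model whose `call` takes an explicit `training` flag."""

  def __init__(self):
    super(DropoutModel, self).__init__()
    self.dropout = tf.keras.layers.Dropout(0.5)
    self.dense = tf.keras.layers.Dense(1)

  def call(self, inputs, training=False):
    # Dropout is active only under training=True, which `fit` and
    # `train_on_batch` now pass; `predict` and `evaluate` pass False.
    x = self.dropout(inputs, training=training)
    return self.dense(x)

model = DropoutModel()
model.compile(loss='mse', optimizer='rmsprop')
x = np.ones((10, 4), dtype='float32')
y = np.zeros((10, 1), dtype='float32')
model.predict(x)             # forward pass built/run with training=False
model.train_on_batch(x, y)   # forward pass run with training=True
```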

tensorflow/python/keras/_impl/keras/engine/topology.py
tensorflow/python/keras/_impl/keras/engine/training.py
tensorflow/python/keras/_impl/keras/engine/training_eager.py
tensorflow/python/keras/_impl/keras/model_subclassing_test.py

diff --git a/tensorflow/python/keras/_impl/keras/engine/topology.py b/tensorflow/python/keras/_impl/keras/engine/topology.py
index dd7436e3d00f5dfa736b8d058316918cb5ef51e4..7de5af41c5e04e046e7d6798706f630374d5640f 100644
--- a/tensorflow/python/keras/_impl/keras/engine/topology.py
+++ b/tensorflow/python/keras/_impl/keras/engine/topology.py
@@ -39,6 +39,7 @@ from tensorflow.python.layers import base as tf_base_layers
 from tensorflow.python.layers import network as tf_network
 from tensorflow.python.layers import utils as tf_layers_util
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -263,7 +264,7 @@ class Layer(tf_base_layers.Layer):
 
     # Un-built subclassed network: build it
     if isinstance(self, Network) and not self.inputs:
-      self._set_inputs(inputs)
+      self._set_inputs(inputs, training=kwargs.get('training'))
 
     # Update learning phase info.
     output_tensors = _to_list(output)
@@ -702,6 +703,8 @@ class Network(tf_network.GraphNetwork, Layer):
     super(Network, self).__init__(inputs, outputs, name=name)
 
     self._is_compiled = False
+    self._expects_training_arg = False
+
     self.supports_masking = False
     self.optimizer = None
 
@@ -744,6 +747,11 @@ class Network(tf_network.GraphNetwork, Layer):
     self._layers = []
     self._is_graph_network = False
     self._is_compiled = False
+    if 'training' in tf_inspect.getargspec(self.call).args:
+      self._expects_training_arg = True
+    else:
+      self._expects_training_arg = False
+
     self.outputs = None
     self.inputs = None
     self.trainable = True
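
The `_expects_training_arg` flag set above comes from introspecting the signature of `call`. A standalone sketch of the same check, using the standard library's `inspect` (Python 3) in place of TF's internal `tf_inspect` wrapper; the helper and class names are illustrative:

```python
import inspect

def expects_training_arg(call_fn):
  # Mirrors the getargspec check added above: True iff `call`
  # declares a parameter literally named `training`.
  return 'training' in inspect.getfullargspec(call_fn).args

class WithFlag(object):
  def call(self, inputs, training=False):
    return inputs

class WithoutFlag(object):
  def call(self, inputs):
    return inputs

assert expects_training_arg(WithFlag().call)
assert not expects_training_arg(WithoutFlag().call)
```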
diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
index fd14bf3d05f13bba3b5cfdc15d3add3c0e48138f..d8ea2fe3db500d3b52d80e46b0cff22a3d1c5915 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -515,7 +515,67 @@ def _standardize_weights(y,
 
 @tf_export('keras.models.Model', 'keras.Model')
 class Model(Network):
-  """The `Model` class adds training & evaluation routines to a `Network`.
+  """`Model` groups layers into an object with training and inference features.
+
+  There are two ways to instantiate a `Model`:
+
+  1 - With the "functional API", where you start from `Input`,
+  you chain layer calls to specify the model's forward pass,
+  and finally you create your model from inputs and outputs:
+
+  ```python
+  import tensorflow as tf
+
+  inputs = tf.keras.Input(shape=(3,))
+  x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
+  outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
+  model = tf.keras.Model(inputs=inputs, outputs=outputs)
+  ```
+
+  2 - By subclassing the `Model` class: in that case, you should define your
+  layers in `__init__` and you should implement the model's forward pass
+  in `call`.
+
+  ```python
+  import tensorflow as tf
+
+  class MyModel(tf.keras.Model):
+
+    def __init__(self):
+      super(MyModel, self).__init__()
+      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
+      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
+
+    def call(self, inputs):
+      x = self.dense1(inputs)
+      return self.dense2(x)
+
+  model = MyModel()
+  ```
+
+  If you subclass `Model`, you can optionally have
+  a `training` argument (boolean) in `call`, which you can use to specify
+  a different behavior in training and inference:
+
+  ```python
+  import tensorflow as tf
+
+  class MyModel(tf.keras.Model):
+
+    def __init__(self):
+      super(MyModel, self).__init__()
+      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
+      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
+      self.dropout = tf.keras.layers.Dropout(0.5)
+
+    def call(self, inputs, training=False):
+      x = self.dense1(inputs)
+      if training:
+        x = self.dropout(x, training=training)
+      return self.dense2(x)
+
+  model = MyModel()
+  ```
   """
 
   def compile(self,
@@ -1709,7 +1767,7 @@ class Model(Network):
                          str(x[0].shape[0]) + ' samples')
     return x, y, sample_weights
 
-  def _set_inputs(self, inputs):
+  def _set_inputs(self, inputs, training=None):
     """Set model's input and output specs based on the input data received.
 
     This is to be used for Model subclasses, which do not know at instantiation
@@ -1725,11 +1783,14 @@ class Model(Network):
           when calling `fit`/etc.
         - if data tensors: the model is built on top of these tensors.
           We do not expect any Numpy data to be provided when calling `fit`/etc.
+      training: Boolean or None. Only relevant in symbolic mode. Specifies
+        whether to build the model's graph in inference mode (False), training
+        mode (True), or using the Keras learning phase (None).
     """
     if context.in_eager_mode():
       self._eager_set_inputs(inputs)
     else:
-      self._symbolic_set_inputs(inputs)
+      self._symbolic_set_inputs(inputs, training=training)
 
   def _eager_set_inputs(self, inputs):
     """Set model's input and output specs based on the input data received.
@@ -1775,7 +1836,7 @@ class Model(Network):
         'output_%d' % (i + 1) for i in range(len(dummy_output_values))]
     self.built = True
 
-  def _symbolic_set_inputs(self, inputs):
+  def _symbolic_set_inputs(self, inputs, training=None):
     """Set model's inputs based on the input data received from the user.
 
     This is to be used for Model subclasses, which do not know at instantiation
@@ -1783,6 +1844,9 @@ class Model(Network):
 
     Args:
       inputs: Argument `x` (input data) passed by the user upon first model use.
+      training: Boolean or None. Only relevant in symbolic mode. Specifies
+        whether to build the model's graph in inference mode (False), training
+        mode (True), or using the Keras learning phase (None).
 
     Raises:
       ValueError: If the model's inputs are already set.
@@ -1831,9 +1895,15 @@ class Model(Network):
 
     # Obtain symbolic outputs by calling the model.
     if len(self.inputs) == 1:
-      outputs = self.call(self.inputs[0])
+      if self._expects_training_arg:
+        outputs = self.call(self.inputs[0], training=training)
+      else:
+        outputs = self.call(self.inputs[0])
     else:
-      outputs = self.call(self.inputs)
+      if self._expects_training_arg:
+        outputs = self.call(self.inputs, training=training)
+      else:
+        outputs = self.call(self.inputs)
     if isinstance(outputs, (list, tuple)):
       outputs = list(outputs)
     else:
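
Condensed, the branching added to `_symbolic_set_inputs` (and repeated in the eager paths in training_eager.py below) amounts to the helper sketched here; the flag check matters because passing `training` to a `call` that does not declare it would raise a `TypeError`:

```python
def call_maybe_training(model, inputs, training):
  # Hypothetical helper, not a function in this commit: forward
  # `training` only when the subclass's `call` accepts it.
  if model._expects_training_arg:
    return model.call(inputs, training=training)
  return model.call(inputs)
```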
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager.py b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
index 477bb2fe7ac44f1f52191a113c495360400b8d75..3507f36e14de28e1049895da5cbfd036dbb414f7 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
@@ -98,7 +98,7 @@ def _eager_metrics_fn(model, outputs, targets):
   return metric_names, metric_results
 
 
-def _model_loss(model, inputs, targets):
+def _model_loss(model, inputs, targets, training=False):
   """Calculates the loss for a given model.
 
   Arguments:
@@ -106,6 +106,8 @@ def _model_loss(model, inputs, targets):
     inputs: The inputs of the given model. This is typically the mini batch of
              data that is fed to the model.
     targets: The predictions or targets of the given model.
+     training: Whether the model should be run in training mode (True) or
+       inference mode (False). Defaults to False.
 
   Returns:
      Returns the model output, total loss and loss value calculated using the
@@ -114,9 +115,15 @@ def _model_loss(model, inputs, targets):
   """
   total_loss = 0
   if len(inputs) == 1:
-    outs = model.call(inputs[0])
+    if model._expects_training_arg:
+      outs = model.call(inputs[0], training=training)
+    else:
+      outs = model.call(inputs[0])
   else:
-    outs = model.call(inputs)
+    if model._expects_training_arg:
+      outs = model.call(inputs, training=training)
+    else:
+      outs = model.call(inputs)
   if not isinstance(outs, list):
     outs = [outs]
 
@@ -172,7 +179,7 @@ def _model_loss(model, inputs, targets):
 
 
 def _process_single_batch(eager_model_inputs, eager_model_outputs, model,
-                          training=True):
+                          training=False):
   """Calculate the loss and gradient for one input batch.
 
      The model weights are updated if training is set to True.
@@ -195,7 +202,8 @@ def _process_single_batch(eager_model_inputs, eager_model_outputs, model,
   K.set_learning_phase(training)
   with GradientTape() as tape:
     outs, loss, loss_metrics = _model_loss(model, eager_model_inputs,
-                                           eager_model_outputs)
+                                           eager_model_outputs,
+                                           training=training)
     if loss is None:
       raise ValueError('The model cannot be run '
                        'because it has no loss to optimize.')
@@ -230,7 +238,7 @@ def train_on_batch(model, ins):
   for i in range(len(model.inputs), len(ins_batch_converted)):
     eager_model_outputs.append(ins_batch_converted[i])
   outs, loss, _ = _process_single_batch(
-      eager_model_inputs, eager_model_outputs, model)
+      eager_model_inputs, eager_model_outputs, model, training=True)
   if not isinstance(outs, list):
     outs = [outs]
   _, metrics_results = _eager_metrics_fn(
@@ -415,7 +423,8 @@ def fit_loop(
 
       outs, loss, loss_metrics = _process_single_batch(eager_model_inputs,
                                                        eager_model_outputs,
-                                                       model)
+                                                       model,
+                                                       training=True)
 
       if not isinstance(outs, list):
         outs = [outs]
@@ -517,7 +526,8 @@ def test_loop(model, ins, batch_size=None, verbose=0, steps=None):
       eager_model_outputs.append(ins_batch_converted[i])
 
     loss_outs, loss, loss_metrics = _model_loss(model, eager_model_inputs,
-                                                eager_model_outputs)
+                                                eager_model_outputs,
+                                                training=False)
     _, metrics_results = _eager_metrics_fn(model, loss_outs,
                                            eager_model_outputs)
     batch_outs = []
@@ -590,9 +600,15 @@ def predict_loop(model, ins, batch_size=32, verbose=0, steps=None):
       eager_model_inputs.append(ins_batch_converted[i])
 
     if len(eager_model_inputs) == 1:
-      batch_outs = model.call(eager_model_inputs[0])
+      if model._expects_training_arg:
+        batch_outs = model.call(eager_model_inputs[0], training=False)
+      else:
+        batch_outs = model.call(eager_model_inputs[0])
     else:
-      batch_outs = model.call(eager_model_inputs)
+      if model._expects_training_arg:
+        batch_outs = model.call(eager_model_inputs, training=False)
+      else:
+        batch_outs = model.call(eager_model_inputs)
 
     if not isinstance(batch_outs, list):
       batch_outs = [batch_outs]
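
Taken together, the hunks above hard-wire a default `training` value for each eager execution path. A summary of those defaults (a reading of this diff, not a structure that exists in the code):

```python
# Default `training` flag each eager path now passes through to
# _model_loss / model.call:
TRAINING_BY_PATH = {
    'fit_loop': True,        # gradients computed, weights updated
    'train_on_batch': True,
    'test_loop': False,      # losses/metrics only
    'predict_loop': False,   # forward pass only
}
```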
diff --git a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
index 275985aa36fc6d85768ae05f14cf65e710ad7353..3d71a620fcb34d21c41f920eed99b1fe22668899 100644
--- a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
@@ -376,11 +376,11 @@ class ModelSubclassingTest(test.TestCase):
     with self.test_session():
       model = MultiIOTestModel(num_classes=num_classes, use_bn=True)
       model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
-      model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32)
+      model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0)
       model.fit({'input_1': x1, 'input_2': x2},
                 {'output_1': y1, 'output_2': y2},
                 epochs=2, batch_size=32)
-      model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32,
+      model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0,
                 validation_data=([x1, x2], [y1, y2]))
 
       model = MultiIOTestModel(num_classes=num_classes, use_bn=True)
@@ -438,7 +438,7 @@ class ModelSubclassingTest(test.TestCase):
     with self.test_session():
       model = MultiIOTestModel(num_classes=num_classes, use_bn=True)
       model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
-      model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32)
+      model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0)
       y_ref_1, y_ref_2 = model.predict([x1, x2])
 
       fd, fname = tempfile.mkstemp('.h5')
@@ -553,6 +553,37 @@ class ModelSubclassingTest(test.TestCase):
           len(model.non_trainable_weights), 4)
       self.assertEqual(len(model.trainable_weights), 12)
 
+  @test_util.run_in_graph_and_eager_modes()
+  def test_support_for_manual_training_arg(self):
+    # In most cases, the `training` argument is left unspecified, in which
+    # case it defaults to the value corresponding to the Model method being
+    # used (fit -> True, predict -> False, etc.).
+    # If the user writes their model's `call` method to take
+    # an explicit `training` argument, we must check that the correct value
+    # is being passed to the model for each method call.
+
+    class DPNet(keras.Model):
+
+      def __init__(self):
+        super(DPNet, self).__init__()
+        self.dp = keras.layers.Dropout(0.5)
+        self.dense = keras.layers.Dense(1,
+                                        use_bias=False,
+                                        kernel_initializer='ones')
+
+      def call(self, inputs, training=False):
+        x = self.dp(inputs, training=training)
+        return self.dense(x)
+
+    with self.test_session():
+      model = DPNet()
+      x = np.ones((10, 10))
+      y = model.predict(x)
+      self.assertEqual(np.sum(y), np.sum(x))
+      model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
+      loss = model.train_on_batch(x, y)
+      self.assertGreater(loss, 0.1)
+
 
 if __name__ == '__main__':
   test.main()
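
A note on why the new test's predict assertion holds: with `training=False`, dropout is the identity, and `Dense(1, use_bias=False, kernel_initializer='ones')` computes a per-row sum, so the output total equals the input total. The arithmetic, checked in NumPy:

```python
import numpy as np

x = np.ones((10, 10))
y = x.dot(np.ones((10, 1)))    # what the all-ones Dense kernel computes
assert np.sum(y) == np.sum(x)  # both equal 100.0
```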