Add the ability to specify an explicit `training` argument when calling a Model ...
authorFrancois Chollet <fchollet@google.com>
Wed, 21 Feb 2018 23:12:16 +0000 (15:12 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Feb 2018 23:16:16 +0000 (15:16 -0800)
PiperOrigin-RevId: 186526925

tensorflow/python/keras/_impl/keras/engine/topology.py
tensorflow/python/keras/_impl/keras/engine/topology_test.py
tensorflow/python/keras/_impl/keras/models.py
tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt

index dbf9652..f562a19 100644 (file)
@@ -1260,7 +1260,7 @@ class Network(Layer):
       return specs[0]
     return specs
 
-  def call(self, inputs, mask=None):
+  def call(self, inputs, training=None, mask=None):
     """Call the model on new inputs.
 
     In this case `call` just reapplies
@@ -1269,6 +1269,8 @@ class Network(Layer):
 
     Arguments:
         inputs: A tensor or list of tensors.
+        training: Boolean or boolean scalar tensor, indicating whether to run
+          the `Network` in training mode or inference mode.
         mask: A mask or list of masks. A mask can be
             either a tensor or None (no mask).
 
@@ -1291,7 +1293,9 @@ class Network(Layer):
         # Cache hit.
         return self._output_tensor_cache[cache_key]
     # Actually apply the network graph to the new inputs.
-    outputs, _ = self._run_internal_graph(inputs, masks)
+    outputs, _ = self._run_internal_graph(inputs,
+                                          training=training,
+                                          mask=masks)
     return outputs
 
   def compute_output_shape(self, input_shape):
@@ -1393,7 +1397,7 @@ class Network(Layer):
     else:
       return tensor_shape.TensorShape(output_shapes)
 
-  def _run_internal_graph(self, inputs, masks=None):
+  def _run_internal_graph(self, inputs, training=None, mask=None):
     """Computes output tensors for new inputs.
 
     # Note:
@@ -1402,7 +1406,8 @@ class Network(Layer):
 
     Arguments:
         inputs: List of tensors
-        masks: List of masks (tensors or None).
+        training: Boolean learning phase.
+        mask: List of masks (tensors or None).
 
     Returns:
         Three lists: output_tensors, output_masks, output_shapes
@@ -1414,8 +1419,10 @@ class Network(Layer):
     # the future and 2) Keras is a major user of Network.  If you don't
     # use masking, it does not interfere with regular behavior at all and you
     # can ignore it.
-    if masks is None:
+    if mask is None:
       masks = [None for _ in range(len(inputs))]
+    else:
+      masks = mask
 
     # Dictionary mapping reference tensors to tuples
     # (computed tensor, compute mask)
@@ -1454,8 +1461,9 @@ class Network(Layer):
               computed_tensor, computed_mask = computed_data[0]
               # Ensure mask propagation if applicable.
               if 'mask' in tf_inspect.getargspec(layer.call).args:
-                if 'mask' not in kwargs:
-                  kwargs['mask'] = computed_mask
+                kwargs.setdefault('mask', computed_mask)
+              if 'training' in tf_inspect.getargspec(layer.call).args:
+                kwargs.setdefault('training', training)
 
               output_tensors = nest.flatten(
                   layer.call(computed_tensor, **kwargs))
@@ -1470,8 +1478,10 @@ class Network(Layer):
               computed_tensors = [x[0] for x in computed_data]
               computed_masks = [x[1] for x in computed_data]
               if 'mask' in tf_inspect.getargspec(layer.call).args:
-                if 'mask' not in kwargs:
-                  kwargs['mask'] = computed_masks
+                kwargs.setdefault('mask', computed_masks)
+              if 'training' in tf_inspect.getargspec(layer.call).args:
+                kwargs.setdefault('training', training)
+
               output_tensors = nest.flatten(
                   layer.call(computed_tensors, **kwargs))
               if hasattr(layer, 'compute_mask'):
index ba4d427..139621d 100644 (file)
@@ -852,6 +852,28 @@ class TopologyConstructionTest(test.TestCase):
       output_val_2 = m2.predict(x_val)
       self.assertAllClose(output_val, output_val_2, atol=1e-6)
 
+  def test_explicit_training_argument(self):
+    with self.test_session():
+      a = keras.layers.Input(shape=(2,))
+      b = keras.layers.Dropout(0.5)(a)
+      base_model = keras.models.Model(a, b)
+
+      a = keras.layers.Input(shape=(2,))
+      b = base_model(a, training=False)
+      model = keras.models.Model(a, b)
+
+      x = np.ones((100, 2))
+      y = np.ones((100, 2))
+      model.compile(optimizer='sgd', loss='mse')
+      loss = model.train_on_batch(x, y)
+      self.assertEqual(loss, 0)  # In inference mode, output is equal to input.
+
+      a = keras.layers.Input(shape=(2,))
+      b = base_model(a, training=True)
+      model = keras.models.Model(a, b)
+      preds = model.predict(x)
+      self.assertEqual(np.min(preds), 0.)  # At least one unit was dropped.
+
 
 class TestSaving(test.TestCase):
 
index 05912b2..8000eaa 100644 (file)
@@ -572,10 +572,10 @@ class Sequential(Model):
       self.build()
     return self.model.get_layer(name, index)
 
-  def call(self, inputs, mask=None):
+  def call(self, inputs, **kwargs):
     if not self.built:
       self.build()
-    return self.model.call(inputs, mask)
+    return self.model.call(inputs, **kwargs)
 
   def build(self, input_shape=None):
     if not self.inputs or not self.outputs:
index 5fb6fa3..04724e3 100644 (file)
@@ -139,7 +139,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "compile"
index 16f1afb..c94bd2f 100644 (file)
@@ -152,7 +152,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
   }
   member_method {
     name: "compile"
index 4260da3..88eb237 100644 (file)
@@ -139,7 +139,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "compile"
index 02ddb37..34f10f0 100644 (file)
@@ -152,7 +152,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
   }
   member_method {
     name: "compile"