return specs[0]
return specs
- def call(self, inputs, mask=None):
+ def call(self, inputs, training=None, mask=None):
"""Call the model on new inputs.
In this case `call` just reapplies all ops in the graph to the new inputs
(e.g. build a new computational graph from the provided inputs).
Arguments:
inputs: A tensor or list of tensors.
+ training: Boolean or boolean scalar tensor, indicating whether to run
+ the `Network` in training mode or inference mode.
mask: A mask or list of masks. A mask can be
either a tensor or None (no mask).
# Cache hit.
return self._output_tensor_cache[cache_key]
# Actually apply the network graph to the new inputs.
- outputs, _ = self._run_internal_graph(inputs, masks)
+ outputs, _ = self._run_internal_graph(inputs,
+ training=training,
+ mask=masks)
return outputs
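# A minimal usage sketch (names are hypothetical; assumes the `keras` import
# used by the tests below): with this change, a `Network` called as a layer
# can pin its learning phase per call, and an explicitly passed value takes
# precedence over the ambient learning phase of the enclosing model.
inner_in = keras.layers.Input(shape=(2,))
inner = keras.models.Model(inner_in, keras.layers.Dropout(0.5)(inner_in))
outer_in = keras.layers.Input(shape=(2,))
frozen = inner(outer_in, training=False)  # Dropout stays inactive, even in fit().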
def compute_output_shape(self, input_shape):
else:
return tensor_shape.TensorShape(output_shapes)
- def _run_internal_graph(self, inputs, masks=None):
+ def _run_internal_graph(self, inputs, training=None, mask=None):
"""Computes output tensors for new inputs.
# Note:
Arguments:
inputs: List of tensors
- masks: List of masks (tensors or None).
+ training: Boolean learning phase (True runs the layers in training
+ mode, False in inference mode).
+ mask: List of masks (tensors or None).
Returns:
Three lists: output_tensors, output_masks, output_shapes
# the future and 2) Keras is a major user of Network. If you don't
# use masking, it does not interfere with regular behavior at all and you
# can ignore it.
- if masks is None:
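+ # Normalize the `mask` argument to one (possibly None) mask per input.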
+ if mask is None:
masks = [None for _ in range(len(inputs))]
+ else:
+ masks = mask
# Dictionary mapping reference tensors to tuples
# (computed tensor, compute mask)
computed_tensor, computed_mask = computed_data[0]
# Ensure mask propagation if applicable.
if 'mask' in tf_inspect.getargspec(layer.call).args:
- if 'mask' not in kwargs:
- kwargs['mask'] = computed_mask
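+ # Use `setdefault` so that a mask or training value passed explicitly
+ # at call time wins; the propagated value is only a fallback.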
+ kwargs.setdefault('mask', computed_mask)
+ if 'training' in tf_inspect.getargspec(layer.call).args:
+ kwargs.setdefault('training', training)
output_tensors = nest.flatten(
layer.call(computed_tensor, **kwargs))
computed_tensors = [x[0] for x in computed_data]
computed_masks = [x[1] for x in computed_data]
if 'mask' in tf_inspect.getargspec(layer.call).args:
- if 'mask' not in kwargs:
- kwargs['mask'] = computed_masks
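+ # As above: explicit call-time kwargs take precedence over the
+ # propagated mask and training values.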
+ kwargs.setdefault('mask', computed_masks)
+ if 'training' in tf_inspect.getargspec(layer.call).args:
+ kwargs.setdefault('training', training)
+
output_tensors = nest.flatten(
layer.call(computed_tensors, **kwargs))
if hasattr(layer, 'compute_mask'):
output_val_2 = m2.predict(x_val)
self.assertAllClose(output_val, output_val_2, atol=1e-6)
+ def test_explicit_training_argument(self):
+ with self.test_session():
+ a = keras.layers.Input(shape=(2,))
+ b = keras.layers.Dropout(0.5)(a)
+ base_model = keras.models.Model(a, b)
+
+ a = keras.layers.Input(shape=(2,))
+ b = base_model(a, training=False)
+ model = keras.models.Model(a, b)
+
+ x = np.ones((100, 2))
+ y = np.ones((100, 2))
+ model.compile(optimizer='sgd', loss='mse')
+ loss = model.train_on_batch(x, y)
+ self.assertEqual(loss, 0) # In inference mode, output is equal to input.
+
+ a = keras.layers.Input(shape=(2,))
+ b = base_model(a, training=True)
+ model = keras.models.Model(a, b)
+ preds = model.predict(x)
+ self.assertEqual(np.min(preds), 0.) # At least one unit was dropped.
+
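# A contrast sketch, reusing `base_model`, `a`, and `x` from the test above:
# omitting the flag leaves `training=None`, so the sub-model's Dropout defers
# to the ambient learning phase and predict() passes inputs through unchanged.
model_default = keras.models.Model(a, base_model(a))
preds_default = model_default.predict(x)  # inference phase: no dropout applied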
class TestSaving(test.TestCase):