When not necessary, avoid the creation of a `placeholder_with_default` in BN (not...
author    Francois Chollet <fchollet@google.com>
Wed, 11 Apr 2018 20:46:03 +0000 (13:46 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 11 Apr 2018 20:48:27 +0000 (13:48 -0700)
PiperOrigin-RevId: 192502020

tensorflow/python/keras/_impl/keras/layers/normalization.py

index b73025a..5462a95 100644
@@ -489,6 +489,7 @@ class BatchNormalization(Layer):
     return (r, d, new_mean, new_variance)
 
   def call(self, inputs, training=None):
+    original_training_value = training
     if training is None:
       training = K.learning_phase()
 
@@ -512,7 +513,7 @@ class BatchNormalization(Layer):
         # Currently never reaches here since fused_batch_norm does not support
         # virtual batching
         outputs = undo_virtual_batching(outputs)
-      if not context.executing_eagerly() and training is K.learning_phase():
+      if not context.executing_eagerly() and original_training_value is None:
         outputs._uses_learning_phase = True  # pylint: disable=protected-access
       return outputs
 
@@ -628,7 +629,7 @@ class BatchNormalization(Layer):
 
     if self.virtual_batch_size is not None:
       outputs = undo_virtual_batching(outputs)
-    if not context.executing_eagerly() and training is K.learning_phase():
+    if not context.executing_eagerly() and original_training_value is None:
       outputs._uses_learning_phase = True  # pylint: disable=protected-access
     return outputs