return (r, d, new_mean, new_variance)
def call(self, inputs, training=None):
+ original_training_value = training
if training is None:
training = K.learning_phase()
# Currently never reaches here since fused_batch_norm does not support
# virtual batching
outputs = undo_virtual_batching(outputs)
- if not context.executing_eagerly() and training is K.learning_phase():
+ if not context.executing_eagerly() and original_training_value is None:
outputs._uses_learning_phase = True # pylint: disable=protected-access
return outputs
if self.virtual_batch_size is not None:
outputs = undo_virtual_batching(outputs)
- if not context.executing_eagerly() and training is K.learning_phase():
+ if not context.executing_eagerly() and original_training_value is None:
outputs._uses_learning_phase = True # pylint: disable=protected-access
return outputs