From: David Riazati Date: Tue, 4 Dec 2018 23:09:30 +0000 (-0800) Subject: BatchNorm support not tracking stats X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2477 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c3bfa0e52bff13984a20156b1c13358bcd9ac395;p=platform%2Fupstream%2Fpytorch.git BatchNorm support not tracking stats Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14764 Differential Revision: D13325800 Pulled By: driazati fbshipit-source-id: a3e4773dc31b83565e7a4de33614d6efd4a12de9 --- diff --git a/test/test_jit.py b/test/test_jit.py index f71ee37..063ac01 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -9843,9 +9843,6 @@ EXCLUDE_PYTHON_PRINT = { } EXCLUDE_SCRIPT_MODULES = { - 'test_nn_BatchNorm1d_not_tracking_stats', - 'test_nn_BatchNorm2d_not_tracking_stats', - 'test_nn_BatchNorm3d_not_tracking_stats', 'test_nn_AdaptiveAvgPool2d_tuple_none', 'test_nn_AdaptiveAvgPool3d_tuple_none', 'test_nn_AdaptiveMaxPool2d_tuple_none', diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 8d75d83..9b09483 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -62,11 +62,13 @@ class _BatchNorm(Module): exponential_average_factor = 0.0 if self.training and self.track_running_stats: - self.num_batches_tracked += 1 - if self.momentum is None: # use cumulative moving average - exponential_average_factor = 1.0 / float(self.num_batches_tracked) - else: # use exponential moving average - exponential_average_factor = self.momentum + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: + self.num_batches_tracked += 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum return F.batch_norm( input, self.running_mean, self.running_var, self.weight, self.bias,