Support BatchNorm modules that do not track running stats
author David Riazati <davidriazati@fb.com>
Tue, 4 Dec 2018 23:09:30 +0000 (15:09 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 4 Dec 2018 23:11:53 +0000 (15:11 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14764

Differential Revision: D13325800

Pulled By: driazati

fbshipit-source-id: a3e4773dc31b83565e7a4de33614d6efd4a12de9

test/test_jit.py
torch/nn/modules/batchnorm.py

diff --git a/test/test_jit.py b/test/test_jit.py
index f71ee37..063ac01 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -9843,9 +9843,6 @@ EXCLUDE_PYTHON_PRINT = {
 }
 
 EXCLUDE_SCRIPT_MODULES = {
-    'test_nn_BatchNorm1d_not_tracking_stats',
-    'test_nn_BatchNorm2d_not_tracking_stats',
-    'test_nn_BatchNorm3d_not_tracking_stats',
     'test_nn_AdaptiveAvgPool2d_tuple_none',
     'test_nn_AdaptiveAvgPool3d_tuple_none',
     'test_nn_AdaptiveMaxPool2d_tuple_none',
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index 8d75d83..9b09483 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -62,11 +62,13 @@ class _BatchNorm(Module):
         exponential_average_factor = 0.0
 
         if self.training and self.track_running_stats:
-            self.num_batches_tracked += 1
-            if self.momentum is None:  # use cumulative moving average
-                exponential_average_factor = 1.0 / float(self.num_batches_tracked)
-            else:  # use exponential moving average
-                exponential_average_factor = self.momentum
+            # TODO: if statement only here to tell the jit to skip emitting this when it is None
+            if self.num_batches_tracked is not None:
+                self.num_batches_tracked += 1
+                if self.momentum is None:  # use cumulative moving average
+                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
+                else:  # use exponential moving average
+                    exponential_average_factor = self.momentum
 
         return F.batch_norm(
             input, self.running_mean, self.running_var, self.weight, self.bias,
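
For context, a minimal usage sketch of the case this patch enables (not part of the diff itself, and assuming a PyTorch build where torch.jit.script accepts nn.Module instances): a BatchNorm layer constructed with track_running_stats=False has running_mean, running_var, and num_batches_tracked set to None, so the JIT must skip the num_batches_tracked update and normalize with batch statistics only.

    import torch
    import torch.nn as nn

    # A BatchNorm2d layer that does not track running statistics; its
    # running_mean, running_var, and num_batches_tracked buffers are None.
    bn = nn.BatchNorm2d(16, track_running_stats=False)

    # Before this change, scripting such a module failed because the JIT
    # tried to emit `self.num_batches_tracked += 1` on a None buffer.
    scripted = torch.jit.script(bn)

    x = torch.randn(4, 16, 8, 8)
    out = scripted(x)  # normalized using statistics computed from the batch itself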