Fix momentum setting in BatchNorm forward pass. (#18764)
authorSpandan Tiwari <sptiwari@microsoft.com>
Mon, 8 Apr 2019 23:21:30 +0000 (16:21 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 8 Apr 2019 23:30:00 +0000 (16:30 -0700)
Summary:
This is a fix for issue https://github.com/pytorch/pytorch/issues/18525. The issue is not limited to ONNX export; it can manifest in other scenarios as well.
An existing test point in test/onnx/test_operators.py has been updated to cover this scenario as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18764

Reviewed By: zrphercule

Differential Revision: D14735166

Pulled By: houseroad

fbshipit-source-id: 5a737c648f64355929ff31eb12bd4869e744768d

test/onnx/expect/TestOperators.test_batchnorm.expect
test/onnx/expect/TestOperators.test_batchnorm_1d.expect
test/onnx/expect/TestOperators.test_batchnorm_noaffine.expect
test/onnx/test_operators.py
torch/nn/modules/batchnorm.py

index e9edb45..44226c5 100644 (file)
@@ -17,7 +17,7 @@ graph {
     }
     attribute {
       name: "momentum"
-      f: 1
+      f: 0.9
       type: FLOAT
     }
   }
index f3dac32..a8097f1 100644 (file)
@@ -27,7 +27,7 @@ graph {
     }
     attribute {
       name: "momentum"
-      f: 1
+      f: 0.9
       type: FLOAT
     }
   }
index 6e7b9e7..9f7765a 100644 (file)
@@ -43,7 +43,7 @@ graph {
     }
     attribute {
       name: "momentum"
-      f: 1
+      f: 0.7
       type: FLOAT
     }
   }
index 77764f5..0b1fe97 100644 (file)
@@ -480,7 +480,7 @@ class TestOperators(TestCase):
 
     def test_batchnorm_noaffine(self):
         x = torch.randn(128, 128, 1, 1, requires_grad=True)
-        self.assertONNX(nn.BatchNorm2d(128, affine=False), x)
+        self.assertONNX(nn.BatchNorm2d(128, affine=False, momentum=0.3), x)
 
     def test_embedding_bags(self):
         emb_bag = nn.EmbeddingBag(10, 8)
index 2d2034b..cf360bf 100644 (file)
@@ -60,7 +60,13 @@ class _BatchNorm(Module):
     def forward(self, input):
         self._check_input_dim(input)
 
-        exponential_average_factor = 0.0
+        # exponential_average_factor is set to self.momentum
+        # (when it is available) so that it gets updated
+        # in the ONNX graph when this node is exported to ONNX.
+        if self.momentum is None:
+            exponential_average_factor = 0.0
+        else:
+            exponential_average_factor = self.momentum
 
         if self.training and self.track_running_stats:
             # TODO: if statement only here to tell the jit to skip emitting this when it is None