def test_batchnorm_noaffine(self):
x = torch.randn(128, 128, 1, 1, requires_grad=True)
- self.assertONNX(nn.BatchNorm2d(128, affine=False), x)
+ self.assertONNX(nn.BatchNorm2d(128, affine=False, momentum=0.3), x)
def test_embedding_bags(self):
emb_bag = nn.EmbeddingBag(10, 8)
def forward(self, input):
self._check_input_dim(input)
- exponential_average_factor = 0.0
+ # exponential_average_factor is set to self.momentum
+ # (when it is available) only so that it gets updated
+ # in ONNX graph when this node is exported to ONNX.
+ if self.momentum is None:
+ exponential_average_factor = 0.0
+ else:
+ exponential_average_factor = self.momentum
if self.training and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None