}
EXCLUDE_SCRIPT_MODULES = {
- 'test_nn_BatchNorm1d_not_tracking_stats',
- 'test_nn_BatchNorm2d_not_tracking_stats',
- 'test_nn_BatchNorm3d_not_tracking_stats',
'test_nn_AdaptiveAvgPool2d_tuple_none',
'test_nn_AdaptiveAvgPool3d_tuple_none',
'test_nn_AdaptiveMaxPool2d_tuple_none',
exponential_average_factor = 0.0
if self.training and self.track_running_stats:
- self.num_batches_tracked += 1
- if self.momentum is None: # use cumulative moving average
- exponential_average_factor = 1.0 / float(self.num_batches_tracked)
- else: # use exponential moving average
- exponential_average_factor = self.momentum
+ # TODO: if statement only here to tell the jit to skip emitting this when it is None
+ if self.num_batches_tracked is not None:
+ self.num_batches_tracked += 1
+ if self.momentum is None: # use cumulative moving average
+ exponential_average_factor = 1.0 / float(self.num_batches_tracked)
+ else: # use exponential moving average
+ exponential_average_factor = self.momentum
return F.batch_norm(
input, self.running_mean, self.running_var, self.weight, self.bias,