        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)
+    def test_averaged_model_exponential_use_state_dict(self):
+        # Test AveragedModel with EMA as avg_fn and mode='state_dict', so
+        # that buffers are averaged together with the parameters.
+        dnn = torch.nn.Sequential(
+            torch.nn.Conv2d(1, 5, kernel_size=3),
+            torch.nn.BatchNorm2d(5, momentum=0.3),
+            torch.nn.Linear(5, 10)
+        )
+        alpha = 0.9
+
+        def avg_fn(p_avg, p, n_avg):
+            return alpha * p_avg + (1 - alpha) * p
+        averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, mode='state_dict')
+        # Scalar entries (e.g. ``num_batches_tracked``) are filtered out so
+        # that the reference values stay aligned with the averaged tensors.
+        averaged_params = [torch.zeros_like(param) for param in dnn.state_dict().values()
+                           if param.size() != torch.Size([])]
+        n_updates = 10
+        for i in range(n_updates):
+            updated_averaged_params = []
+            dnn_params = [p for p in dnn.state_dict().values()
+                          if p.size() != torch.Size([])]
+            for p, p_avg in zip(dnn_params, averaged_params):
+                p.detach().add_(torch.randn_like(p))
+                if i == 0:
+                    updated_averaged_params.append(p.clone())
+                else:
+                    updated_averaged_params.append((p_avg * alpha +
+                                                    p * (1 - alpha)).clone())
+            averaged_dnn.update_parameters(dnn)
+            averaged_params = updated_averaged_params
+
+        swa_params = [p for p in averaged_dnn.module.state_dict().values()
+                      if p.size() != torch.Size([])]
+        for p_avg, p_swa in zip(averaged_params, swa_params):
+            self.assertEqual(p_avg, p_swa)
+
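+    def test_averaged_model_state_dict_averages_buffers(self):
+        # A minimal sketch (hypothetical test, not part of the original
+        # change): in ``'state_dict'`` mode the averaged copy should track
+        # the BatchNorm running statistics of the source model, while in
+        # ``'parameters'`` mode the buffers keep their initial values.
+        dnn = torch.nn.Sequential(
+            torch.nn.Conv2d(1, 5, kernel_size=3),
+            torch.nn.BatchNorm2d(5, momentum=0.3),
+        )
+        averaged_sd = AveragedModel(dnn, mode='state_dict')
+        averaged_par = AveragedModel(dnn, mode='parameters')
+        # One forward pass in training mode moves the running stats away
+        # from their initial values (zero mean, unit variance).
+        dnn(torch.randn(8, 1, 16, 16))
+        averaged_sd.update_parameters(dnn)
+        averaged_par.update_parameters(dnn)
+        self.assertEqual(averaged_sd.module[1].running_mean,
+                         dnn[1].running_mean)
+        self.assertEqual(averaged_par.module[1].running_mean,
+                         torch.zeros_like(dnn[1].running_mean))
+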
    def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):
        preactivation_sum = torch.zeros(dnn.n_features)
            :class:`AveragedModel` parameter, the current value of :attr:`model`
            parameter and the number of models already averaged; if None,
            an equally weighted average is used (default: None)
+        mode (str, optional): whether ``update_parameters`` should average
+            only the model's ``'parameters'`` or its full ``'state_dict'``,
+            which also includes buffers such as BatchNorm running statistics;
+            see the example below (default: ``'parameters'``)
    Example:
        >>> loader, optimizer, model, loss_fn = ...
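+        >>> # A minimal sketch of the new flag (``ema_avg`` is an assumed EMA
+        >>> # averaging function, not defined in this snippet):
+        >>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg, mode='state_dict')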
        Generalizes Well:
        https://arxiv.org/abs/2001.02312
    """
-    def __init__(self, model, device=None, avg_fn=None):
+    def __init__(self, model, device=None, avg_fn=None, mode='parameters'):
        super(AveragedModel, self).__init__()
        self.module = deepcopy(model)
        if device is not None:
                return averaged_model_parameter + \
                    (model_parameter - averaged_model_parameter) / (num_averaged + 1)
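+        # The default ``avg_fn`` above is the incremental form of the mean:
+        # avg + (x - avg) / (n + 1) keeps an equally weighted average over
+        # all models seen so far.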
        self.avg_fn = avg_fn
+        modes = ['parameters', 'state_dict']
+        if mode not in modes:
+            raise ValueError(f"Invalid mode {mode!r}; valid values are {', '.join(modes)}.")
+        self.use_state_dict = mode == 'state_dict'
    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
    def update_parameters(self, model):
-        for p_swa, p_model in zip(self.parameters(), model.parameters()):
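+        # In ``'state_dict'`` mode, iterate over the full state_dict so that
+        # buffers (e.g. BatchNorm running statistics) are averaged alongside
+        # the parameters.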
+        self_param = self.module.state_dict().values() if self.use_state_dict else self.parameters()
+        model_param = model.state_dict().values() if self.use_state_dict else model.parameters()
+        for p_swa, p_model in zip(self_param, model_param):
            device = p_swa.device
            p_model_ = p_model.detach().to(device)
            if self.n_averaged == 0: