From 49f52b6c074a049f7c301d4f4f2e78788b4cbbc1 Mon Sep 17 00:00:00 2001
From: Prabhat Roy
Date: Fri, 8 Oct 2021 15:17:47 +0100
Subject: [PATCH] Revert "Added option to update parameters using state_dict in
 AveragedModel (#65495) (#65755)" (#66308)

This reverts commit 5f1a434599b46afd99607839d15892e09269a1c4.
---
 test/test_optim.py       | 32 --------------------------------
 torch/optim/swa_utils.py | 12 ++----------
 2 files changed, 2 insertions(+), 42 deletions(-)

diff --git a/test/test_optim.py b/test/test_optim.py
index 4db1a49..2d88d6f 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -2290,38 +2290,6 @@ class TestSWAUtils(TestCase):
         for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
             self.assertEqual(p_avg, p_swa)

-    def test_averaged_model_exponential_use_state_dict(self):
-        # Test AveragedModel with EMA as avg_fn and use_state_dict as True.
-        dnn = torch.nn.Sequential(
-            torch.nn.Conv2d(1, 5, kernel_size=3),
-            torch.nn.BatchNorm2d(5, momentum=0.3),
-            torch.nn.Linear(5, 10)
-        )
-        alpha = 0.9
-
-        def avg_fn(p_avg, p, n_avg):
-            return alpha * p_avg + (1 - alpha) * p
-        averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, mode='state_dict')
-        averaged_params = [torch.zeros_like(param) for param in dnn.state_dict().values()
-                           if param.size() != torch.Size([])]
-        n_updates = 10
-        for i in range(n_updates):
-            updated_averaged_params = []
-            for p, p_avg in zip(dnn.state_dict().values(), averaged_params):
-                if p.size() == torch.Size([]):
-                    continue
-                p.detach().add_(torch.randn_like(p))
-                if i == 0:
-                    updated_averaged_params.append(p.clone())
-                else:
-                    updated_averaged_params.append((p_avg * alpha +
-                                                    p * (1 - alpha)).clone())
-            averaged_dnn.update_parameters(dnn)
-            averaged_params = updated_averaged_params
-
-        for p_avg, p_swa in zip(averaged_params, averaged_dnn.module.state_dict().values()):
-            self.assertEqual(p_avg, p_swa)
-
     def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):
         preactivation_sum = torch.zeros(dnn.n_features)

diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py
index e87f10e..a143ffd 100644
--- a/torch/optim/swa_utils.py
+++ b/torch/optim/swa_utils.py
@@ -26,8 +26,6 @@ class AveragedModel(Module):
             :class:`AveragedModel` parameter, the current value of :attr:`model`
             parameter and the number of models already averaged; if None, equally
             weighted average is used (default: None)
-        mode (str, optional): whether to use ``'parameters'`` or ``'state_dict'`` for update
-            (default: ``'parameters'``)

     Example:
         >>> loader, optimizer, model, loss_fn = ...
@@ -86,7 +84,7 @@ class AveragedModel(Module):
         Generalizes Well:
         https://arxiv.org/abs/2001.02312
     """
-    def __init__(self, model, device=None, avg_fn=None, mode='parameters'):
+    def __init__(self, model, device=None, avg_fn=None):
         super(AveragedModel, self).__init__()
         self.module = deepcopy(model)
         if device is not None:
@@ -98,18 +96,12 @@ class AveragedModel(Module):
                 return averaged_model_parameter + \
                     (model_parameter - averaged_model_parameter) / (num_averaged + 1)
         self.avg_fn = avg_fn
-        modes = ['parameters', 'state_dict']
-        if mode not in modes:
-            raise ValueError(f'Invalid mode passed, valid values are {", ".join(modes)}.')
-        self.use_state_dict = mode == 'state_dict'

     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)

     def update_parameters(self, model):
-        self_param = self.module.state_dict().values() if self.use_state_dict else self.parameters()
-        model_param = model.state_dict().values() if self.use_state_dict else model.parameters()
-        for p_swa, p_model in zip(self_param, model_param):
+        for p_swa, p_model in zip(self.parameters(), model.parameters()):
             device = p_swa.device
             p_model_ = p_model.detach().to(device)
             if self.n_averaged == 0:
--
2.7.4
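Note (not part of the patch): after this revert, update_parameters once again iterates only over parameters(), so an exponential moving average is expressed purely through the existing avg_fn hook, and buffers such as BatchNorm running statistics are not folded into the average. A minimal sketch of that usage follows; the model, the alpha value, and the loop are illustrative assumptions, not code from this patch.

    import torch
    from torch.optim.swa_utils import AveragedModel

    model = torch.nn.Linear(10, 2)  # illustrative model
    alpha = 0.9                     # assumed EMA decay factor

    # avg_fn receives (averaged_param, current_param, num_averaged) and returns
    # the new averaged value; here it computes an exponential moving average.
    def ema_avg(p_avg, p, n_avg):
        return alpha * p_avg + (1 - alpha) * p

    ema_model = AveragedModel(model, avg_fn=ema_avg)

    for _ in range(10):
        # ... an optimizer step on `model` would normally happen here ...
        ema_model.update_parameters(model)  # averages parameters only, not buffers

Buffers that still need refreshing (e.g. BatchNorm running statistics) can be recomputed with torch.optim.swa_utils.update_bn after training, as the swa_utils module already documents.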