Added option to update parameters using state_dict in AveragedModel (#65495) (#65755)
author    Prabhat Roy <prabhatroy@fb.com>
Wed, 6 Oct 2021 18:13:31 +0000 (19:13 +0100)
committer GitHub <noreply@github.com>
Wed, 6 Oct 2021 18:13:31 +0000 (11:13 -0700)
* Added option to update parameters using state_dict in AveragedModel (#65495)

Summary:
While implementing [EMA](https://github.com/pytorch/vision/pull/4381) (which extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used because it did not handle state_dict(), so a custom update_parameters() had to be defined in the [EMA class](https://github.com/pytorch/vision/pull/4406). This PR handles that scenario, removing the need for the custom update_parameters() implementation.
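
A minimal usage sketch of the new option (the model, decay value, and loop below are illustrative, not taken from the PR):

```python
import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.BatchNorm1d(5),  # contributes buffers (running stats) through state_dict
    torch.nn.Linear(5, 2),
)

decay = 0.999  # illustrative EMA decay factor

# avg_fn receives (averaged value, current model value, number of models averaged so far)
def ema_avg_fn(averaged_param, model_param, num_averaged):
    return decay * averaged_param + (1 - decay) * model_param

# mode='state_dict' makes update_parameters() iterate over state_dict().values()
# instead of parameters(), so buffers such as BatchNorm running statistics are
# averaged as well.
ema_model = AveragedModel(model, avg_fn=ema_avg_fn, mode='state_dict')

for _ in range(3):  # stand-in for a training loop; optimizer steps elided
    ema_model.update_parameters(model)
```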

Discussion: https://github.com/pytorch/vision/pull/4406#pullrequestreview-753734102

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65495

Reviewed By: datumbox

Differential Revision: D31176742

Pulled By: prabhat00155

fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2
(cherry picked from commit 2ea724b1fd543304e3be7bd223cac451cd093e16)

* Added validation of mode parameter in AveragedModel (#65921)

Summary:
Discussion: https://github.com/pytorch/pytorch/pull/65495#issuecomment-930460469
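
A small sketch of the validation this adds (the invalid value below is only an illustration):

```python
import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(4, 2)

# Valid values for mode are 'parameters' (the default) and 'state_dict';
# anything else raises a ValueError.
try:
    AveragedModel(model, mode='buffers')  # 'buffers' is a deliberately invalid value
except ValueError as err:
    print(err)  # Invalid mode passed, valid values are parameters, state_dict.
```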

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65921

Reviewed By: albanD

Differential Revision: D31310105

Pulled By: prabhat00155

fbshipit-source-id: 417691832a7c793744830c11e0ce53e3972d21a3
(cherry picked from commit c7748fc172553da66368fd0b7fea3fe5661e2dc1)

test/test_optim.py
torch/optim/swa_utils.py

index 2d88d6f..4db1a49 100644
@@ -2290,6 +2290,38 @@ class TestSWAUtils(TestCase):
         for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
             self.assertEqual(p_avg, p_swa)
 
+    def test_averaged_model_exponential_use_state_dict(self):
+        # Test AveragedModel with EMA as avg_fn and use_state_dict as True.
+        dnn = torch.nn.Sequential(
+            torch.nn.Conv2d(1, 5, kernel_size=3),
+            torch.nn.BatchNorm2d(5, momentum=0.3),
+            torch.nn.Linear(5, 10)
+        )
+        alpha = 0.9
+
+        def avg_fn(p_avg, p, n_avg):
+            return alpha * p_avg + (1 - alpha) * p
+        averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, mode='state_dict')
+        averaged_params = [torch.zeros_like(param) for param in dnn.state_dict().values()
+                           if param.size() != torch.Size([])]
+        n_updates = 10
+        for i in range(n_updates):
+            updated_averaged_params = []
+            for p, p_avg in zip(dnn.state_dict().values(), averaged_params):
+                if p.size() == torch.Size([]):
+                    continue
+                p.detach().add_(torch.randn_like(p))
+                if i == 0:
+                    updated_averaged_params.append(p.clone())
+                else:
+                    updated_averaged_params.append((p_avg * alpha +
+                                                   p * (1 - alpha)).clone())
+            averaged_dnn.update_parameters(dnn)
+            averaged_params = updated_averaged_params
+
+        for p_avg, p_swa in zip(averaged_params, averaged_dnn.module.state_dict().values()):
+            self.assertEqual(p_avg, p_swa)
+
     def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):
 
         preactivation_sum = torch.zeros(dnn.n_features)
index a143ffd..e87f10e 100644
@@ -26,6 +26,8 @@ class AveragedModel(Module):
             :class:`AveragedModel` parameter, the current value of :attr:`model`
             parameter and the number of models already averaged; if None,
             equally weighted average is used (default: None)
+        mode (str, optional): whether to use ``'parameters'`` or ``'state_dict'`` for update
+            (default: ``'parameters'``)
 
     Example:
         >>> loader, optimizer, model, loss_fn = ...
@@ -84,7 +86,7 @@ class AveragedModel(Module):
         Generalizes Well:
         https://arxiv.org/abs/2001.02312
     """
-    def __init__(self, model, device=None, avg_fn=None):
+    def __init__(self, model, device=None, avg_fn=None, mode='parameters'):
         super(AveragedModel, self).__init__()
         self.module = deepcopy(model)
         if device is not None:
@@ -96,12 +98,18 @@ class AveragedModel(Module):
                 return averaged_model_parameter + \
                     (model_parameter - averaged_model_parameter) / (num_averaged + 1)
         self.avg_fn = avg_fn
+        modes = ['parameters', 'state_dict']
+        if mode not in modes:
+            raise ValueError(f'Invalid mode passed, valid values are {", ".join(modes)}.')
+        self.use_state_dict = mode == 'state_dict'
 
     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
 
     def update_parameters(self, model):
-        for p_swa, p_model in zip(self.parameters(), model.parameters()):
+        self_param = self.module.state_dict().values() if self.use_state_dict else self.parameters()
+        model_param = model.state_dict().values() if self.use_state_dict else model.parameters()
+        for p_swa, p_model in zip(self_param, model_param):
             device = p_swa.device
             p_model_ = p_model.detach().to(device)
             if self.n_averaged == 0: