[Model Averaging] Simplify PostLocalSGD Optimizer API (#64885)
author Yi Wang <wayi@fb.com>
Tue, 14 Sep 2021 23:35:32 +0000 (16:35 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 14 Sep 2021 23:37:14 +0000 (16:37 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64885

1) The constructor accepts a local optimizer instance instead of the optimizer class type and the inputs to its constructor (see the sketch below).
2) The parameters are read from the local optimizer's `param_groups` instead of a separate `params` argument.
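
For illustration, a minimal before/after sketch of the construction (the model, learning rate, and averager settings are placeholders taken from the docstring example below):

    # Before: pass the parameters, the optimizer class, and its defaults.
    opt = PostLocalSGDOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.SGD,
        averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100),
        lr=0.01,
    )

    # After: construct the local optimizer directly and pass the instance in.
    local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
    opt = PostLocalSGDOptimizer(
        optim=local_optim,
        averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100),
    )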

Proposal: https://github.com/pytorch/pytorch/issues/59699
ghstack-source-id: 137865867

Test Plan: buck test mode/dev-nosan //caffe2/test/distributed:distributed_nccl_spawn -- test_post_localSGD_optimizer_parity

Reviewed By: rohan-varma

Differential Revision: D30888794

fbshipit-source-id: 21261b480f6bbb9b2333426020e3f350da3f73c2

torch/distributed/optim/post_localSGD_optimizer.py
torch/testing/_internal/distributed/distributed_test.py

index 8a15c03..2d1c861 100644 (file)
@@ -1,5 +1,3 @@
-from typing import Any, Iterator, Type
-
 import torch
 import torch.distributed.algorithms.model_averaging.averagers as averagers
 
@@ -11,11 +9,8 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
     After the warm-up stage, it averages parameters periodically after the local optimizer is applied.
 
     Args:
-        params: All the parameters.
-        optimizer_class: The class of the local optimizer.
+        optim: The local optimizer.
         averager: A model averager instance to run post-localSGD algorithm.
-        **defaults: A dict containing default values of optimization options,
-            which are forwarded to the local optimizer.
 
     Example::
 
@@ -37,11 +32,10 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
         >>>  # Create a post-localSGD optimizer that wraps a local optimizer.
         >>>  # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
         >>>  # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
+        >>>  local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
         >>>  opt = PostLocalSGDOptimizer(
-        >>>      model.parameters(),
-        >>>      optimizer_class=torch.optim.SGD,
-        >>>      averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100),
-        >>>      lr=0.01
+        >>>      optim=local_optim,
+        >>>      averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
         >>>  )
         >>>
         >>>  # In the first 100 steps, DDP runs global gradient averaging at every step.
@@ -59,13 +53,10 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
 
     def __init__(
         self,
-        params: Iterator[torch.nn.Parameter],
-        optimizer_class: Type[torch.optim.Optimizer],
-        averager: averagers.ModelAverager,
-        **defaults: Any,
+        optim: torch.optim.Optimizer,
+        averager: averagers.ModelAverager
     ):
-        self.params = list(params)
-        self.optim = optimizer_class(iter(self.params), **defaults)
+        self.optim = optim
         self.param_groups = self.optim.param_groups
         self.averager = averager
 
@@ -87,7 +78,11 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
         Performs a single optimization step (parameter update).
         """
         self.optim.step()
-        self.averager.average_parameters(iter(self.params))
+        # Gather the parameters from every param group and average them in a
+        # single call, so that the averager's internal step count advances
+        # once per optimizer step rather than once per parameter.
+        params = [p for group in self.param_groups for p in group["params"]]
+        self.averager.average_parameters(iter(params))
 
     def zero_grad(self):
         self.optim.zero_grad()
index 613e23e..5f26bf4 100644 (file)
@@ -4625,12 +4625,10 @@ class DistributedTest:
                 gradient_as_bucket_view=grad_is_view,
             )
             post_localSGD_opt = post_localSGD_optimizer.PostLocalSGDOptimizer(
-                params=post_localSGD_net.parameters(),
-                optimizer_class=torch.optim.SGD,
+                optim=torch.optim.SGD(post_localSGD_net.parameters(), lr=learning_rate),
                 averager=averagers.PeriodicModelAverager(
                     period=period, warmup_steps=warmup_steps
-                ),
-                lr=learning_rate,
+                )
             )
 
             input = torch.randn(dist.get_world_size() * 2, 2).cuda()
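
For context, the docstring comments above describe the schedule: during the first `warmup_steps` iterations DDP runs global gradient averaging at every step, and afterwards the averager averages the model parameters once every `period` steps. A minimal sketch of that gating, assuming the averager simply counts its `average_parameters` calls (illustrative only, not the library's actual implementation):

    def should_average(step_count: int, warmup_steps: int = 100, period: int = 4) -> bool:
        # During warm-up, skip model averaging (DDP gradient averaging runs instead);
        # afterwards, average the model parameters once every `period` steps.
        return step_count >= warmup_steps and step_count % period == 0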