From: Yi Wang
Date: Tue, 14 Sep 2021 23:35:32 +0000 (-0700)
Subject: [Model Averaging] Simplify PostLocalSGD Optimizer API (#64885)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~210
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=3d312b3b8ee90f8b289c7d5601a13d0521b46b7e;p=platform%2Fupstream%2Fpytorch.git

[Model Averaging] Simplify PostLocalSGD Optimizer API (#64885)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64885

1) The constructor accepts a local optimizer instance instead of the local optimizer's class type and constructor arguments.
2) The parameters are read from the local optimizer's `param_groups` instead of a separate input.

Proposal: https://github.com/pytorch/pytorch/issues/59699

ghstack-source-id: 137865867

Test Plan: buck test mode/dev-nosan //caffe2/test/distributed:distributed_nccl_spawn -- test_post_localSGD_optimizer_parity

Reviewed By: rohan-varma

Differential Revision: D30888794

fbshipit-source-id: 21261b480f6bbb9b2333426020e3f350da3f73c2
---

diff --git a/torch/distributed/optim/post_localSGD_optimizer.py b/torch/distributed/optim/post_localSGD_optimizer.py
index 8a15c03..2d1c861 100644
--- a/torch/distributed/optim/post_localSGD_optimizer.py
+++ b/torch/distributed/optim/post_localSGD_optimizer.py
@@ -1,5 +1,3 @@
-from typing import Any, Iterator, Type
-
 import torch
 import torch.distributed.algorithms.model_averaging.averagers as averagers
 
@@ -11,11 +9,8 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
     After the warm-up stage, it averages parameters periodically after the local optimizer is applied.
 
     Args:
-        params: All the parameters.
-        optimizer_class: The class of the local optimizer.
+        optim: The local optimizer.
         averager: A model averager instance to run post-localSGD algorithm.
-        **defaults: A dict containing default values of optimization options,
-            which are forwarded to the local optimizer.
 
     Example::
 
@@ -37,11 +32,10 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
         >>> # Create a post-localSGD optimizer that wraps a local optimizer.
         >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
         >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
+        >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
         >>> opt = PostLocalSGDOptimizer(
-        >>>     model.parameters(),
-        >>>     optimizer_class=torch.optim.SGD,
-        >>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100),
-        >>>     lr=0.01
+        >>>     optim=local_optim,
+        >>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
         >>> )
         >>>
         >>> # In the first 100 steps, DDP runs global gradient averaging at every step.
@@ -59,13 +53,10 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
 
     def __init__(
         self,
-        params: Iterator[torch.nn.Parameter],
-        optimizer_class: Type[torch.optim.Optimizer],
-        averager: averagers.ModelAverager,
-        **defaults: Any,
+        optim: torch.optim.Optimizer,
+        averager: averagers.ModelAverager
    ):
-        self.params = list(params)
-        self.optim = optimizer_class(iter(self.params), **defaults)
+        self.optim = optim
         self.param_groups = self.optim.param_groups
         self.averager = averager
 
@@ -87,7 +78,11 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
         Performs a single optimization step (parameter update).
""" self.optim.step() - self.averager.average_parameters(iter(self.params)) + for param_group in self.param_groups: + for params in param_group["params"]: + if params.grad is None: + continue + self.averager.average_parameters(iter(params)) def zero_grad(self): self.optim.zero_grad() diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 613e23e..5f26bf4 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -4625,12 +4625,10 @@ class DistributedTest: gradient_as_bucket_view=grad_is_view, ) post_localSGD_opt = post_localSGD_optimizer.PostLocalSGDOptimizer( - params=post_localSGD_net.parameters(), - optimizer_class=torch.optim.SGD, + optim=torch.optim.SGD(post_localSGD_net.parameters(), lr=learning_rate), averager=averagers.PeriodicModelAverager( period=period, warmup_steps=warmup_steps - ), - lr=learning_rate, + ) ) input = torch.randn(dist.get_world_size() * 2, 2).cuda()