From 8800a8b4281ca80f35dc338bccb40a678710509d Mon Sep 17 00:00:00 2001
From: Alban Desmaison
Date: Thu, 16 Sep 2021 06:36:29 -0700
Subject: [PATCH] Revert D30888794: [Model Averaging] Simplify PostLocalSGD Optimizer API

Test Plan: revert-hammer

Differential Revision: D30888794 (https://github.com/pytorch/pytorch/commit/3d312b3b8ee90f8b289c7d5601a13d0521b46b7e)

Original commit changeset: 21261b480f6b

fbshipit-source-id: 87abb7e8cd9ecaac909ec6c3ee053fa7c4ae1975
---
 torch/distributed/optim/post_localSGD_optimizer.py | 29 +++++++++++++---------
 .../_internal/distributed/distributed_test.py      |  6 +++--
 2 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/torch/distributed/optim/post_localSGD_optimizer.py b/torch/distributed/optim/post_localSGD_optimizer.py
index 2d1c861..8a15c03 100644
--- a/torch/distributed/optim/post_localSGD_optimizer.py
+++ b/torch/distributed/optim/post_localSGD_optimizer.py
@@ -1,3 +1,5 @@
+from typing import Any, Iterator, Type
+
 import torch
 import torch.distributed.algorithms.model_averaging.averagers as averagers
 
@@ -9,8 +11,11 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
     After the warm-up stage, it averages parameters periodically afer the local optimizer is applied.
 
     Args:
-        optim: The local optimizer.
+        params: All the parameters.
+        optimizer_class: The class of the local optimizer.
         averager: A model averager instance to run post-localSGD algorithm.
+        **defaults: A dict containing default values of optimization options,
+            which are forwarded to the local optimizer.
 
     Example::
 
@@ -32,10 +37,11 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
         >>> # Create a post-localSGD optimizer that wraps a local optimizer.
         >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
         >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
-        >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
         >>> opt = PostLocalSGDOptimizer(
-        >>>     optim=local_optim,
-        >>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
+        >>>     model.parameters(),
+        >>>     optimizer_class=torch.optim.SGD,
+        >>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100),
+        >>>     lr=0.01
         >>> )
         >>>
         >>> # In the first 100 steps, DDP runs global gradient averaging at every step.
@@ -53,10 +59,13 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
 
     def __init__(
         self,
-        optim: torch.optim.Optimizer,
-        averager: averagers.ModelAverager
+        params: Iterator[torch.nn.Parameter],
+        optimizer_class: Type[torch.optim.Optimizer],
+        averager: averagers.ModelAverager,
+        **defaults: Any,
     ):
-        self.optim = optim
+        self.params = list(params)
+        self.optim = optimizer_class(iter(self.params), **defaults)
         self.param_groups = self.optim.param_groups
         self.averager = averager
 
@@ -78,11 +87,7 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
         Performs a single optimization step (parameter update).
         """
         self.optim.step()
-        for param_group in self.param_groups:
-            for params in param_group["params"]:
-                if params.grad is None:
-                    continue
-                self.averager.average_parameters(iter(params))
+        self.averager.average_parameters(iter(self.params))
 
     def zero_grad(self):
         self.optim.zero_grad()
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 5f26bf4..613e23e 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -4625,10 +4625,12 @@ class DistributedTest:
                 gradient_as_bucket_view=grad_is_view,
             )
             post_localSGD_opt = post_localSGD_optimizer.PostLocalSGDOptimizer(
-                optim=torch.optim.SGD(post_localSGD_net.parameters(), lr=learning_rate),
+                params=post_localSGD_net.parameters(),
+                optimizer_class=torch.optim.SGD,
                 averager=averagers.PeriodicModelAverager(
                     period=period, warmup_steps=warmup_steps
-                )
+                ),
+                lr=learning_rate,
             )
 
             input = torch.randn(dist.get_world_size() * 2, 2).cuda()
-- 
2.7.4