-from typing import Any, Iterator, Type
-
import torch
import torch.distributed.algorithms.model_averaging.averagers as averagers
After the warm-up stage, it averages parameters periodically afer the local optimizer is applied.
Args:
- params: All the parameters.
- optimizer_class: The class of the local optimizer.
+ optim: The local optimizer.
averager: A model averager instance to run post-localSGD algorithm.
- **defaults: A dict containing default values of optimization options,
- which are forwarded to the local optimizer.
Example::
>>> # Create a post-localSGD optimizer that wraps a local optimizer.
>>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
>>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
+ >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
>>> opt = PostLocalSGDOptimizer(
- >>> model.parameters(),
- >>> optimizer_class=torch.optim.SGD,
- >>> averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100),
- >>> lr=0.01
+ >>> optim=local_optim,
+ >>> averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
>>> )
>>>
>>> # In the first 100 steps, DDP runs global gradient averaging at every step.
def __init__(
self,
- params: Iterator[torch.nn.Parameter],
- optimizer_class: Type[torch.optim.Optimizer],
- averager: averagers.ModelAverager,
- **defaults: Any,
+ optim: torch.optim.Optimizer,
+ averager: averagers.ModelAverager
):
- self.params = list(params)
- self.optim = optimizer_class(iter(self.params), **defaults)
+ self.optim = optim
self.param_groups = self.optim.param_groups
self.averager = averager
Performs a single optimization step (parameter update).
"""
self.optim.step()
- self.averager.average_parameters(iter(self.params))
+ for param_group in self.param_groups:
+ for params in param_group["params"]:
+ if params.grad is None:
+ continue
+ self.averager.average_parameters(iter(params))
def zero_grad(self):
self.optim.zero_grad()
gradient_as_bucket_view=grad_is_view,
)
post_localSGD_opt = post_localSGD_optimizer.PostLocalSGDOptimizer(
- params=post_localSGD_net.parameters(),
- optimizer_class=torch.optim.SGD,
+ optim=torch.optim.SGD(post_localSGD_net.parameters(), lr=learning_rate),
averager=averagers.PeriodicModelAverager(
period=period, warmup_steps=warmup_steps
- ),
- lr=learning_rate,
+ )
)
input = torch.randn(dist.get_world_size() * 2, 2).cuda()