From: Yi Wang Date: Mon, 16 Aug 2021 23:33:21 +0000 (-0700) Subject: [Model Averaging] Add a few member methods of PostLocalSGDOptimizer (#63340) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~981 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=068d6fec5cb36cecd757f769c2b245a2c77555af;p=platform%2Fupstream%2Fpytorch.git [Model Averaging] Add a few member methods of PostLocalSGDOptimizer (#63340) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63340 Some methods are needed, such as those for accessing optimizer states. These are necessary for integration with PyTorch Lightning. Proposal: https://github.com/pytorch/pytorch/issues/59699 ghstack-source-id: 135912246 Test Plan: buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_ddp_hook_parity_post_localSGD Reviewed By: rohan-varma Differential Revision: D30328794 fbshipit-source-id: e585b874313bd266fdc7c79936e2af98700c7bad --- diff --git a/torch/distributed/optim/post_localSGD_optimizer.py b/torch/distributed/optim/post_localSGD_optimizer.py index 61064dd..8a15c03 100644 --- a/torch/distributed/optim/post_localSGD_optimizer.py +++ b/torch/distributed/optim/post_localSGD_optimizer.py @@ -65,22 +65,32 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer): **defaults: Any, ): self.params = list(params) - self.local_optimizer = optimizer_class(iter(self.params), **defaults) - self.param_groups = self.local_optimizer.param_groups + self.optim = optimizer_class(iter(self.params), **defaults) + self.param_groups = self.optim.param_groups self.averager = averager + @property + def state(self): + return self.optim.state + + def __repr__(self): + return self.optim.__repr__() + + def state_dict(self): + return self.optim.state_dict() + + def load_state_dict(self, state_dict): + self.optim.load_state_dict(state_dict) + def step(self): r""" Performs a single optimization step (parameter update).
""" - self.local_optimizer.step() + self.optim.step() self.averager.average_parameters(iter(self.params)) def zero_grad(self): - self.local_optimizer.zero_grad() - - def state_dict(self): - raise NotImplementedError + self.optim.zero_grad() - def load_state_dict(self, state_dict): - raise NotImplementedError + def add_param_group(self, param_group): + self.optim.add_param_group(param_group)