[Model Averaging] Add a few member methods of PostLocalSGDOptimizer (#63340)
author Yi Wang <wayi@fb.com>
Mon, 16 Aug 2021 23:33:21 +0000 (16:33 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 16 Aug 2021 23:39:01 +0000 (16:39 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63340

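Add methods for accessing optimizer state (`state`, `state_dict`, `load_state_dict`), along with `__repr__` and `add_param_group`. All of them delegate to the wrapped local optimizer. These are necessary for integration with PyTorch Lightning.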

Proposal: https://github.com/pytorch/pytorch/issues/59699
ghstack-source-id: 135912246

Test Plan: buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_ddp_hook_parity_post_localSGD

Reviewed By: rohan-varma

Differential Revision: D30328794

fbshipit-source-id: e585b874313bd266fdc7c79936e2af98700c7bad

torch/distributed/optim/post_localSGD_optimizer.py

index 61064dd..8a15c03 100644
@@ -65,22 +65,32 @@ class PostLocalSGDOptimizer(torch.optim.Optimizer):
         **defaults: Any,
     ):
         self.params = list(params)
-        self.local_optimizer = optimizer_class(iter(self.params), **defaults)
-        self.param_groups = self.local_optimizer.param_groups
+        self.optim = optimizer_class(iter(self.params), **defaults)
+        self.param_groups = self.optim.param_groups
         self.averager = averager
 
+    @property
+    def state(self):
+        return self.optim.state
+
+    def __repr__(self):
+        return self.optim.__repr__()
+
+    def state_dict(self):
+        return self.optim.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self.optim.load_state_dict(state_dict)
+
     def step(self):
         r"""
         Performs a single optimization step (parameter update).
         """
-        self.local_optimizer.step()
+        self.optim.step()
         self.averager.average_parameters(iter(self.params))
 
     def zero_grad(self):
-        self.local_optimizer.zero_grad()
-
-    def state_dict(self):
-        raise NotImplementedError
+        self.optim.zero_grad()
 
-    def load_state_dict(self, state_dict):
-        raise NotImplementedError
+    def add_param_group(self, param_group):
+        self.optim.add_param_group(param_group)
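
The delegation pattern in this diff can be illustrated with a minimal, torch-free sketch. Everything below is hypothetical and for illustration only: `ToySGD`, `NoOpAverager`, and `DelegatingOptimizer` are stand-ins invented here, not PyTorch classes, and the toy state layout (a single step counter) is an assumption. The sketch only mirrors the shape of the change: the wrapper forwards `state`, `state_dict`, `load_state_dict`, and `add_param_group` to the inner optimizer, and `step` runs the local update before averaging.

```python
from typing import Any, Dict, Iterable, List


class ToySGD:
    """Hypothetical stand-in for a local optimizer (not a torch.optim class)."""

    def __init__(self, params: Iterable[Dict[str, float]], lr: float = 0.1):
        self.param_groups: List[Dict[str, Any]] = [{"params": list(params), "lr": lr}]
        self.state: Dict[str, int] = {"steps": 0}

    def step(self) -> None:
        # Plain gradient descent on dict-based "parameters".
        for group in self.param_groups:
            for p in group["params"]:
                p["value"] -= group["lr"] * p["grad"]
        self.state["steps"] += 1

    def state_dict(self) -> Dict[str, Any]:
        # Return a copy so later steps do not mutate the snapshot.
        return {"state": dict(self.state)}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.state = dict(state_dict["state"])

    def add_param_group(self, param_group: Dict[str, Any]) -> None:
        self.param_groups.append(param_group)


class NoOpAverager:
    """Hypothetical averager: counts calls instead of communicating."""

    def __init__(self) -> None:
        self.calls = 0

    def average_parameters(self, params: Iterable[Dict[str, float]]) -> None:
        self.calls += 1


class DelegatingOptimizer:
    """Sketch of a wrapper that forwards state handling to the inner optimizer."""

    def __init__(self, params, optimizer_class, averager, **defaults: Any):
        self.params = list(params)
        self.optim = optimizer_class(iter(self.params), **defaults)
        # Same list object as the inner optimizer's, so new groups stay visible.
        self.param_groups = self.optim.param_groups
        self.averager = averager

    @property
    def state(self):
        return self.optim.state

    def __repr__(self) -> str:
        return self.optim.__repr__()

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict) -> None:
        self.optim.load_state_dict(state_dict)

    def step(self) -> None:
        # Local update first, then parameter averaging.
        self.optim.step()
        self.averager.average_parameters(iter(self.params))

    def add_param_group(self, param_group) -> None:
        self.optim.add_param_group(param_group)


if __name__ == "__main__":
    p = {"value": 1.0, "grad": 0.5}
    opt = DelegatingOptimizer([p], ToySGD, NoOpAverager(), lr=0.1)
    opt.step()
    checkpoint = opt.state_dict()   # snapshot taken after one step
    opt.step()
    opt.load_state_dict(checkpoint)  # restores the inner optimizer's state
    print(opt.state)
```

Exposing `state` as a property rather than copying it in `__init__` keeps the wrapper in sync even when `load_state_dict` replaces the inner optimizer's state object, which is likely why the diff takes that route.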