[Reland] [Model Averaging] Simplify PostLocalSGD Optimizer API (#65197)
author Yi Wang <wayi@fb.com>
Fri, 17 Sep 2021 17:00:13 +0000 (10:00 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 17 Sep 2021 17:31:58 +0000 (10:31 -0700)
commit c1415a0a72a09e16a93bb1900eb1a8541bb448d6
tree 0d230433eb53f6f7909787b381369aa4c4ba8e04
parent 752a8202303089386a3973dee753dc78d2a657e2

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65197

1. The constructor accepts a local optimizer instance instead of the local optimizer's class type and its constructor inputs.
2. The parameters are read from the local optimizer's param_groups instead of being passed as a separate input.
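The two API changes can be illustrated with a minimal, torch-free sketch. Everything in it is hypothetical: ToyParam, ToySGD, RecordingAverager and PostLocalSGDSketch only mimic the shape of the simplified wrapper (an already-constructed optimizer is passed in, and parameters are read from its param_groups rather than from a separate argument). They are not the PyTorch classes, and a real averager would all-reduce parameters across workers instead of just recording them.

```python
from typing import Any, Dict, List


class ToyParam:
    """Hypothetical stand-in for a parameter: a float value plus its gradient."""

    def __init__(self, data: float, grad: float = 0.0) -> None:
        self.data = data
        self.grad = grad


class ToySGD:
    """Hypothetical local optimizer exposing a param_groups list like torch.optim."""

    def __init__(self, params: List[ToyParam], lr: float) -> None:
        self.param_groups: List[Dict[str, Any]] = [{"params": list(params), "lr": lr}]

    def step(self) -> None:
        # Plain SGD update: data <- data - lr * grad.
        for group in self.param_groups:
            for p in group["params"]:
                p.data -= group["lr"] * p.grad


class RecordingAverager:
    """Hypothetical averager that records which parameters it was asked to average."""

    def __init__(self) -> None:
        self.calls: List[List[ToyParam]] = []

    def average_parameters(self, params: List[ToyParam]) -> None:
        # A real model averager would all-reduce these parameters across workers.
        self.calls.append(list(params))


class PostLocalSGDSketch:
    """Wraps an existing optimizer instance instead of building one from a class type."""

    def __init__(self, optim: ToySGD, averager: RecordingAverager) -> None:
        self.optim = optim
        self.averager = averager
        # Parameters come from the wrapped optimizer, not from a separate input.
        self.param_groups = optim.param_groups

    def step(self) -> None:
        # Take a local step first, then hand each group's parameters to the averager.
        self.optim.step()
        for group in self.param_groups:
            self.averager.average_parameters(group["params"])


if __name__ == "__main__":
    w = ToyParam(1.0, grad=0.5)
    b = ToyParam(0.0, grad=-1.0)
    opt = PostLocalSGDSketch(ToySGD([w, b], lr=0.1), RecordingAverager())
    opt.step()
    print(w.data, b.data)
```

The caller constructs the local optimizer once, in the usual way, and the wrapper never needs to know its class or constructor arguments. Sharing param_groups by reference also keeps the wrapper and the local optimizer looking at the same parameters.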

Proposal: https://github.com/pytorch/pytorch/issues/59699
ghstack-source-id: 138307226

Test Plan: buck test mode/dev-nosan //caffe2/test/distributed:distributed_nccl_spawn -- test_post_localSGD_optimizer_parity

Reviewed By: rohan-varma

Differential Revision: D31007439

fbshipit-source-id: bbb0526e6763ef76775b85088571506b3942c722
torch/distributed/algorithms/model_averaging/utils.py
torch/distributed/optim/post_localSGD_optimizer.py
torch/testing/_internal/distributed/distributed_test.py