Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64903
Fix the accuracy regression caused by https://github.com/pytorch/pytorch/pull/63895.
Test Plan:
buck test mode/dev-nosan //caffe2/test/distributed:distributed_nccl_spawn -- test_periodic_model_averager
buck test mode/dev-nosan //caffe2/test/distributed:distributed_nccl_spawn -- test_post_localSGD_optimizer_parity
Reviewed By: rohan-varma
Differential Revision:
D30894688
fbshipit-source-id:
fe00b8b23b860d9f806f87c1b6caba1d0b807485
offset = 0
for p in params_it2:
- with torch.no_grad():
- p.set_(flat_params[offset : offset + p.numel()].view_as(p).type_as(p)) # type: ignore[call-overload]
+ p.data = flat_params[offset : offset + p.numel()].view_as(p)
offset += p.numel()