[Reland] Replacing the p.data acccess in utils with tensor.set_ . Passes both test_po...
authorAayush Prakash <aayushp@fb.com>
Wed, 25 Aug 2021 18:11:08 +0000 (11:11 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 25 Aug 2021 18:12:55 +0000 (11:12 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63895

When updating the model parameter, updating `parameter.data` is no longer recommended, because this `data` field will be deprecated in the future.

The replacement is `tensor.set_`.
ghstack-source-id: 136593433

Test Plan:
buck test mode/dev-nosan //caffe2/test/distributed:distributed_nccl_spawn -- test_periodic_model_averager
buck test mode/dev-nosan //caffe2/test/distributed:distributed_nccl_spawn -- test_post_localSGD_optimizer_parity

Reviewed By: SciPioneer

Differential Revision: D30526178

fbshipit-source-id: a1ac0ec3665d8623edd5bf94f01c1132daff5c00

torch/distributed/algorithms/model_averaging/utils.py

index 44ee422..ce1fb65 100644 (file)
@@ -29,5 +29,6 @@ def average_parameters(
 
     offset = 0
     for p in params_it2:
-        p.data = flat_params[offset : offset + p.numel()].view_as(p)
+        with torch.no_grad():
+            p.set_(flat_params[offset : offset + p.numel()].view_as(p).type_as(p))  # type: ignore[call-overload]
         offset += p.numel()