[DDP] Fix when buffers are reassigned in module (#64472)
author     Rohan Varma <rvarm1@fb.com>
Thu, 9 Sep 2021 02:13:33 +0000 (19:13 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 02:14:55 +0000 (19:14 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64472

Sometimes, a user module can reassign a tensor buffer, as in:

```
self.buffer = torch.randn(1, 2) # in init
self.buffer = self.buffer + 1 # in forward
```

In this case, `self.modules_buffers` becomes outdated, so we must repopulate
`self.modules_buffers` whenever module buffers need to be synced.

See https://github.com/pytorch/pytorch/issues/63916 for a full description of
the issue.
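
For context, a minimal standalone sketch of the failure mode (not part of this
PR; `NetWithBuffer` and `cached_buffers` are hypothetical stand-ins for a user
module and for DDP's `self.modules_buffers` snapshot):

```
import torch
import torch.nn as nn

class NetWithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("buffer", torch.zeros(1, 2))

    def forward(self):
        # Out-of-place op builds a new tensor; nn.Module.__setattr__
        # re-registers it under the same buffer name.
        self.buffer = self.buffer + 1

m = NetWithBuffer()
# Hypothetical one-time snapshot, like the one taken at DDP construction time.
cached_buffers = list(m.buffers())
m()
print(cached_buffers[0])  # stale object:   tensor([[0., 0.]])
print(m.buffer)           # current buffer: tensor([[1., 1.]])
```

Broadcasting `cached_buffers` would sync the stale tensor rather than the
module's current buffer, which is why `self.modules_buffers` is repopulated
before each broadcast.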
ghstack-source-id: 137526309

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D30745921

fbshipit-source-id: 25eb1edbf445703a481802e07f3058d38ea6fc64

torch/nn/parallel/distributed.py
torch/testing/_internal/distributed/distributed_test.py

index 60d2143..a1e24b1 100644 (file)
@@ -737,6 +737,18 @@ class DistributedDataParallel(Module, Joinable):
         # The following modules_params and modules_buffers are used for
         # param/buffer sync in _sync_params.
         self.modules_params = [list(self._get_parameters(self.module))]
+        self._assign_modules_buffers()
+
+        return parameters, expect_sparse_gradient
+
+    def _assign_modules_buffers(self):
+        """
+        Assigns module buffers to self.modules_buffers, which are then
+        broadcast across ranks when broadcast_buffers=True. Note that this
+        must be called every time buffers need to be synced because buffers
+        can be reassigned by the user module; see
+        https://github.com/pytorch/pytorch/issues/63916.
+        """
         # Collect buffers for modules, filtering out buffers that should be ignored.
         named_module_buffers = [
             [
@@ -753,7 +765,6 @@ class DistributedDataParallel(Module, Joinable):
             for module_buffers in named_module_buffers
         ]
 
-        return parameters, expect_sparse_gradient
 
     def _build_param_to_name_mapping(self, parameters):
         param_to_param_index = {parameters[0][i]: i for i in range(len(parameters[0]))}
@@ -1371,6 +1382,9 @@ class DistributedDataParallel(Module, Joinable):
                 else:
                     # The process with rank 0 is considered the authoritative copy.
                     authoritative_rank = 0
+                # Update self.modules_buffers in case any buffers were
+                # reassigned.
+                self._assign_modules_buffers()
                 self._distributed_broadcast_coalesced(
                     self.modules_buffers[0],
                     self.broadcast_bucket_size,
index 613e23e..e40d6e7 100644 (file)
@@ -7924,3 +7924,43 @@ class DistributedTest:
         )
         def test_ddp_new_tensor_in_fwd_static_graph(self):
             return self._test_ddp_new_tensor_in_fwd(static_graph=True)
+
+        @skip_if_lt_x_gpu(2)
+        @sandcastle_skip_if(
+            BACKEND != "nccl" and BACKEND != "gloo",
+            "Only Nccl & Gloo backend support DistributedDataParallel",
+        )
+        def test_ddp_broadcast_buffer(self):
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+
+            class NetWithBuffers(nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.a = nn.Linear(10, 10, bias=False)
+                    self.b = nn.Linear(10, 1, bias=False)
+                    self.register_buffer('buffer', torch.randn(1, 2))
+
+                def forward(self, x):
+                    return self.b(self.a(x))
+
+            model = NetWithBuffers().cuda(rank)
+            model_ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            inp = torch.randn(2, 10, device=rank)
+            for i in range(6):
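+                # Only rank 0 reassigns its buffer each iteration, so ranks
+                # diverge unless DDP re-broadcasts the new buffer object.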
+                if rank == 0:
+                    model_ddp.module.buffer = model_ddp.module.buffer + 1
+                loss = model_ddp(inp).sum()
+                loss.backward()
+                # Ensure all buffers are synchronized.
+                bufs = [None for _ in range(dist.get_world_size())]
+                dist.all_gather_object(bufs, model_ddp.module.buffer)
+                bufs = [b.cpu() for b in bufs]
+                rank_0_buf = bufs[0]
+                for buf in bufs[1:]:
+                    self.assertEqual(rank_0_buf, buf)