Revert D30745921: [DDP] Fix when buffers are reassigned in module
author    Howard Huang <howardhuang@fb.com>
Thu, 9 Sep 2021 15:20:40 +0000 (08:20 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 15:23:16 +0000 (08:23 -0700)
Test Plan: revert-hammer

Differential Revision:
D30745921 (https://github.com/pytorch/pytorch/commit/d59ecc02df70bad2273858c2fad2b4993133a3d3)

Original commit changeset: 25eb1edbf445

fbshipit-source-id: 343ead86bf1e2d0b2d4124be331ea2fa437303ad

torch/nn/parallel/distributed.py
torch/testing/_internal/distributed/distributed_test.py

diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index a1e24b1..60d2143 100644
@@ -737,18 +737,6 @@ class DistributedDataParallel(Module, Joinable):
         # The following modules_params and modules_buffers are used for
         # param/buffer sync in _sync_params.
         self.modules_params = [list(self._get_parameters(self.module))]
-        self._assign_modules_buffers()
-
-        return parameters, expect_sparse_gradient
-
-    def _assign_modules_buffers(self):
-        """
-        Assigns module buffers to self.modules_buffers which are then used to
-        broadcast across ranks when broadcast_buffers=True. Note that this
-        must be called every time buffers need to be synced because buffers can
-        be reassigned by the user's module; see
-        https://github.com/pytorch/pytorch/issues/63916.
-        """
         # Collect buffers for modules, filtering out buffers that should be ignored.
         named_module_buffers = [
             [
@@ -765,6 +753,7 @@ class DistributedDataParallel(Module, Joinable):
             for module_buffers in named_module_buffers
         ]
 
+        return parameters, expect_sparse_gradient
 
     def _build_param_to_name_mapping(self, parameters):
         param_to_param_index = {parameters[0][i]: i for i in range(len(parameters[0]))}
@@ -1382,9 +1371,6 @@ class DistributedDataParallel(Module, Joinable):
                 else:
                     # The process with rank 0 is considered the authoritative copy.
                     authoritative_rank = 0
-                # Update self.modules_buffers in case any buffers were
-                # reassigned.
-                self._assign_modules_buffers()
                 self._distributed_broadcast_coalesced(
                     self.modules_buffers[0],
                     self.broadcast_bucket_size,
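For context, the following is a minimal, self-contained sketch (not part of this commit) of the buffer-reassignment pattern described in the removed docstring and in https://github.com/pytorch/pytorch/issues/63916. NetWithBuffers mirrors the module from the removed test below; the `cached` list is only an illustrative stand-in for the buffer list DDP collects at construction time, not DDP's actual internals.

import torch
import torch.nn as nn

class NetWithBuffers(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(10, 10, bias=False)
        self.register_buffer("buffer", torch.randn(1, 2))

    def forward(self, x):
        return self.a(x)

net = NetWithBuffers()
# Illustrative stand-in for a buffer list collected once at wrap time.
cached = [net.buffer]
# User code reassigns the buffer attribute, creating a new tensor object;
# nn.Module re-registers it under the same name in _buffers.
net.buffer = net.buffer + 1
# The previously collected reference is now stale: broadcasting from `cached`
# would sync the old tensor rather than the reassigned one.
assert cached[0] is not net.buffer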
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index e40d6e7..613e23e 100644
@@ -7924,43 +7924,3 @@ class DistributedTest:
         )
         def test_ddp_new_tensor_in_fwd_static_graph(self):
             return self._test_ddp_new_tensor_in_fwd(static_graph=True)
-
-        @skip_if_lt_x_gpu(2)
-        @sandcastle_skip_if(
-            BACKEND != "nccl" and BACKEND != "gloo",
-            "Only Nccl & Gloo backend support DistributedDataParallel",
-        )
-        def test_ddp_broadcast_buffer(self):
-            rank = self.rank
-            torch.cuda.set_device(rank)
-            torch.manual_seed(rank)
-            torch.cuda.manual_seed(rank)
-
-            class NetWithBuffers(nn.Module):
-                def __init__(self):
-                    super().__init__()
-                    self.a = nn.Linear(10, 10, bias=False)
-                    self.b = nn.Linear(10, 1, bias=False)
-                    self.register_buffer('buffer', torch.randn(1, 2))
-
-                def forward(self, x):
-                    return self.b(self.a(x))
-
-            model = NetWithBuffers().cuda(rank)
-            model_ddp = torch.nn.parallel.DistributedDataParallel(
-                model,
-                device_ids=[self.rank],
-            )
-            inp = torch.randn(2, 10, device=rank)
-            for i in range(6):
-                if rank == 0:
-                    model_ddp.module.buffer = model_ddp.module.buffer + 1
-                loss = model_ddp(inp).sum()
-                loss.backward()
-                # Ensure all buffers are synchronized.
-                bufs = [None for _ in range(dist.get_world_size())]
-                dist.all_gather_object(bufs, model_ddp.module.buffer)
-                bufs = [b.cpu() for b in bufs]
-                rank_0_buf = bufs[0]
-                for buf in bufs[1:]:
-                    self.assertEqual(rank_0_buf, buf)