From d59ecc02df70bad2273858c2fad2b4993133a3d3 Mon Sep 17 00:00:00 2001
From: Rohan Varma
Date: Wed, 8 Sep 2021 19:13:33 -0700
Subject: [PATCH] [DDP] Fix when buffers are reassigned in module (#64472)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64472

Sometimes, a user module can reassign a tensor buffer, as in:

```
self.buffer = torch.randn(1, 2)  # in init
self.buffer += 1  # in forward
```

In this case, `self.modules_buffers` becomes outdated, so we must repopulate
`self.modules_buffers` whenever module buffers need to be synced.
See https://github.com/pytorch/pytorch/issues/63916 for a full description
of the issue.

ghstack-source-id: 137526309

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D30745921

fbshipit-source-id: 25eb1edbf445703a481802e07f3058d38ea6fc64
---
 torch/nn/parallel/distributed.py                | 16 ++++++++-
 .../_internal/distributed/distributed_test.py  | 40 ++++++++++++++++++++++
 2 files changed, 55 insertions(+), 1 deletion(-)

diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 60d2143..a1e24b1 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -737,6 +737,18 @@ class DistributedDataParallel(Module, Joinable):
         # The following modules_params and modules_buffers are used for
         # param/buffer sync in _sync_params.
         self.modules_params = [list(self._get_parameters(self.module))]
+        self._assign_modules_buffers()
+
+        return parameters, expect_sparse_gradient
+
+    def _assign_modules_buffers(self):
+        """
+        Assigns module buffers to self.modules_buffers which are then used to
+        broadcast across ranks when broadcast_buffers=True. Note that this
+        must be called every time buffers need to be synced because buffers can
+        be reassigned by the user module,
+        see https://github.com/pytorch/pytorch/issues/63916.
+        """
         # Collect buffers for modules, filtering out buffers that should be ignored.
         named_module_buffers = [
             [
@@ -753,7 +765,6 @@
             for module_buffers in named_module_buffers
         ]
 
-        return parameters, expect_sparse_gradient
 
     def _build_param_to_name_mapping(self, parameters):
         param_to_param_index = {parameters[0][i]: i for i in range(len(parameters[0]))}
@@ -1371,6 +1382,9 @@ class DistributedDataParallel(Module, Joinable):
             else:
                 # The process with rank 0 is considered the authoritative copy.
                 authoritative_rank = 0
+            # Update self.modules_buffers in case any buffers were
+            # reassigned.
+            self._assign_modules_buffers()
             self._distributed_broadcast_coalesced(
                 self.modules_buffers[0],
                 self.broadcast_bucket_size,
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 613e23e..e40d6e7 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -7924,3 +7924,43 @@ class DistributedTest:
         )
         def test_ddp_new_tensor_in_fwd_static_graph(self):
             return self._test_ddp_new_tensor_in_fwd(static_graph=True)
+
+        @skip_if_lt_x_gpu(2)
+        @sandcastle_skip_if(
+            BACKEND != "nccl" and BACKEND != "gloo",
+            "Only Nccl & Gloo backend support DistributedDataParallel",
+        )
+        def test_ddp_broadcast_buffer(self):
+            rank = self.rank
+            torch.cuda.set_device(rank)
+            torch.manual_seed(rank)
+            torch.cuda.manual_seed(rank)
+
+            class NetWithBuffers(nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.a = nn.Linear(10, 10, bias=False)
+                    self.b = nn.Linear(10, 1, bias=False)
+                    self.register_buffer('buffer', torch.randn(1, 2))
+
+                def forward(self, x):
+                    return self.b(self.a(x))
+
+            model = NetWithBuffers().cuda(rank)
+            model_ddp = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            inp = torch.randn(2, 10, device=rank)
+            for i in range(6):
+                if rank == 0:
+                    model_ddp.module.buffer = model_ddp.module.buffer + 1
+                loss = model_ddp(inp).sum()
+                loss.backward()
+                # Ensure all buffers are synchronized.
+                bufs = [None for _ in range(dist.get_world_size())]
+                dist.all_gather_object(bufs, model_ddp.module.buffer)
+                bufs = [b.cpu() for b in bufs]
+                rank_0_buf = bufs[0]
+                for buf in bufs[1:]:
+                    self.assertEqual(rank_0_buf, buf)
-- 
2.7.4
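
For context, below is a minimal standalone sketch (not part of the patch) of the buffer-reassignment pattern this change addresses. It assumes a CPU-only, 2-process gloo setup; names such as NetWithBuffer and run_worker are illustrative, not from the PyTorch codebase.

```
# Sketch: a rank-0 buffer reassignment that DDP's buffer broadcast must pick up.
# Per issue #63916, DDP previously kept broadcasting the stale buffer object it
# cached at construction time, so other ranks never saw the reassigned values.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class NetWithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 1, bias=False)
        self.register_buffer("buffer", torch.zeros(1, 2))

    def forward(self, x):
        return self.fc(x)


def run_worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # broadcast_buffers=True (the default) syncs buffers from the
    # authoritative rank on each forward pass.
    model = DDP(NetWithBuffer())
    for _ in range(3):
        if rank == 0:
            # Reassignment creates a brand-new tensor object, which is the
            # case the patch handles by repopulating self.modules_buffers.
            model.module.buffer = model.module.buffer + 1
        model(torch.randn(2, 4)).sum().backward()

    # With the fix, every rank ends up with rank 0's updated buffer.
    gathered = [None] * world_size
    dist.all_gather_object(gathered, model.module.buffer)
    assert all(torch.equal(gathered[0], b) for b in gathered)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run_worker, args=(2,), nprocs=2)
```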