# The following modules_params and modules_buffers are used for
# param/buffer sync in _sync_params.
self.modules_params = [list(self._get_parameters(self.module))]
+ self._assign_modules_buffers()
+
+ return parameters, expect_sparse_gradient
+
+ def _assign_modules_buffers(self):
+ """
+ Assigns module buffers to self.modules_buffers which are then used to
+ broadcast across ranks when broadcast_buffers=True. Note that this
+ must be called every time buffers need to be synced because buffers can
+ be reassigned by the user's module;
+ see https://github.com/pytorch/pytorch/issues/63916.
+ """
# Collect buffers for modules, filtering out buffers that should be ignored.
named_module_buffers = [
[
for module_buffers in named_module_buffers
]
- return parameters, expect_sparse_gradient
def _build_param_to_name_mapping(self, parameters):
param_to_param_index = {parameters[0][i]: i for i in range(len(parameters[0]))}
else:
# The process with rank 0 is considered the authoritative copy.
authoritative_rank = 0
+ # Update self.modules_buffers in case any buffers were
+ # reassigned.
+ self._assign_modules_buffers()
self._distributed_broadcast_coalesced(
self.modules_buffers[0],
self.broadcast_bucket_size,
)
def test_ddp_new_tensor_in_fwd_static_graph(self):
    # Thin wrapper: runs the shared new-tensor-in-forward scenario with
    # static_graph enabled. _test_ddp_new_tensor_in_fwd is defined
    # elsewhere in this file.
    return self._test_ddp_new_tensor_in_fwd(static_graph=True)
+
@skip_if_lt_x_gpu(2)
@sandcastle_skip_if(
    BACKEND != "nccl" and BACKEND != "gloo",
    "Only Nccl & Gloo backend support DistributedDataParallel",
)
def test_ddp_broadcast_buffer(self):
    # Regression test: DDP with broadcast_buffers=True (the default)
    # must keep module buffers in sync even when a rank *reassigns* its
    # buffer (rather than mutating it in place), see
    # https://github.com/pytorch/pytorch/issues/63916.
    rank = self.rank
    torch.cuda.set_device(rank)
    # Seed per-rank so each rank's randomly initialized buffer differs;
    # agreement below can then only come from DDP's buffer broadcast.
    torch.manual_seed(rank)
    torch.cuda.manual_seed(rank)

    class NetWithBuffers(nn.Module):
        def __init__(self):
            super().__init__()
            self.a = nn.Linear(10, 10, bias=False)
            self.b = nn.Linear(10, 1, bias=False)
            # The buffer under test; it carries no gradient and is only
            # kept consistent across ranks via DDP's broadcast.
            self.register_buffer('buffer', torch.randn(1, 2))

        def forward(self, x):
            return self.b(self.a(x))

    model = NetWithBuffers().cuda(rank)
    model_ddp = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[self.rank],
    )
    inp = torch.randn(2, 10, device=rank)
    for i in range(6):
        if rank == 0:
            # Reassign (not in-place update) the buffer on rank 0 only;
            # the next forward pass is expected to broadcast the new
            # tensor object to all other ranks.
            model_ddp.module.buffer = model_ddp.module.buffer + 1
        loss = model_ddp(inp).sum()
        loss.backward()
        # Ensure all buffers are synchronized.
        bufs = [None for _ in range(dist.get_world_size())]
        dist.all_gather_object(bufs, model_ddp.module.buffer)
        # Gathered objects may live on different CUDA devices; move to
        # CPU so they can be compared directly.
        bufs = [b.cpu() for b in bufs]
        rank_0_buf = bufs[0]
        for buf in bufs[1:]:
            self.assertEqual(rank_0_buf, buf)