# The following modules_params and modules_buffers are used for
# param/buffer sync in _sync_params.
self.modules_params = [list(self._get_parameters(self.module))]
- self._assign_modules_buffers()
-
- return parameters, expect_sparse_gradient
-
- def _assign_modules_buffers(self):
- """
- Assigns module buffers to self.modules_buffers which are then used to
- broadcast across ranks when broadcast_buffers=True. Note that this
- must be called every time buffers need to be synced because buffers can
- be reassigned by user module,
- see https://github.com/pytorch/pytorch/issues/63916.
- """
# Collect buffers for modules, filtering out buffers that should be ignored.
named_module_buffers = [
[
for module_buffers in named_module_buffers
]
+ return parameters, expect_sparse_gradient
def _build_param_to_name_mapping(self, parameters):
param_to_param_index = {parameters[0][i]: i for i in range(len(parameters[0]))}
else:
# The process with rank 0 is considered the authoritative copy.
authoritative_rank = 0
- # Update self.modules_buffers incase any buffers were
- # reassigned.
- self._assign_modules_buffers()
self._distributed_broadcast_coalesced(
self.modules_buffers[0],
self.broadcast_bucket_size,
)
def test_ddp_new_tensor_in_fwd_static_graph(self):
return self._test_ddp_new_tensor_in_fwd(static_graph=True)
-
- @skip_if_lt_x_gpu(2)
- @sandcastle_skip_if(
- BACKEND != "nccl" and BACKEND != "gloo",
- "Only Nccl & Gloo backend support DistributedDataParallel",
- )
- def test_ddp_broadcast_buffer(self):
- rank = self.rank
- torch.cuda.set_device(rank)
- torch.manual_seed(rank)
- torch.cuda.manual_seed(rank)
-
- class NetWithBuffers(nn.Module):
- def __init__(self):
- super().__init__()
- self.a = nn.Linear(10, 10, bias=False)
- self.b = nn.Linear(10, 1, bias=False)
- self.register_buffer('buffer', torch.randn(1, 2))
-
- def forward(self, x):
- return self.b(self.a(x))
-
- model = NetWithBuffers().cuda(rank)
- model_ddp = torch.nn.parallel.DistributedDataParallel(
- model,
- device_ids=[self.rank],
- )
- inp = torch.randn(2, 10, device=rank)
- for i in range(6):
- if rank == 0:
- model_ddp.module.buffer = model_ddp.module.buffer + 1
- loss = model_ddp(inp).sum()
- loss.backward()
- # Ensure all buffers are synchronized.
- bufs = [None for _ in range(dist.get_world_size())]
- dist.all_gather_object(bufs, model_ddp.module.buffer)
- bufs = [b.cpu() for b in bufs]
- rank_0_buf = bufs[0]
- for buf in bufs[1:]:
- self.assertEqual(rank_0_buf, buf)