From f1aaf8afcd7cc9ed8745aba48d091d0638d4afe3 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Thu, 9 Sep 2021 08:20:40 -0700 Subject: [PATCH] Revert D30745960: [DDP] Remove SPMD from self.modules_buffers Test Plan: revert-hammer Differential Revision: D30745960 (https://github.com/pytorch/pytorch/commit/15532595209d2daf34d35e10f8d3d3b64966aea2) Original commit changeset: 66a8f9847e9f fbshipit-source-id: d3f3fb813c45ac1b0ff15c6154b2e99e5dbab433 --- torch/nn/parallel/distributed.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 76260aa..734d42c 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -748,15 +748,21 @@ class DistributedDataParallel(Module, Joinable): """ # Collect buffers for modules, filtering out buffers that should be ignored. named_module_buffers = [ - (buffer, buffer_name) - for buffer_name, buffer in self.module.named_buffers() + [ + (buffer, buffer_name) + for buffer_name, buffer in self.module.named_buffers() + ] ] self.modules_buffers = [ - buffer - for (buffer, buffer_name) in named_module_buffers - if buffer_name not in self.parameters_to_ignore + [ + buffer + for (buffer, buffer_name) in module_buffers + if buffer_name not in self.parameters_to_ignore + ] + for module_buffers in named_module_buffers ] + def _build_param_to_name_mapping(self, parameters): param_to_param_index = {parameters[0][i]: i for i in range(len(parameters[0]))} param_set = set(parameters[0]) @@ -1033,7 +1039,7 @@ class DistributedDataParallel(Module, Joinable): if self.will_sync_module_buffers(): authoritative_rank = self._find_common_rank(self._distributed_rank, False) self._distributed_broadcast_coalesced( - self.modules_buffers, self.broadcast_bucket_size, authoritative_rank + self.modules_buffers[0], self.broadcast_bucket_size, authoritative_rank ) # When running in join model, agrees upon a common rank and broadcast model @@ -1339,7 +1345,7 @@ class DistributedDataParallel(Module, Joinable): return ( self.require_forward_param_sync and self.broadcast_buffers - and len(self.modules_buffers) > 0 + and len(self.modules_buffers[0]) > 0 ) def _find_common_rank(self, input_rank, rank_cond): @@ -1377,7 +1383,7 @@ class DistributedDataParallel(Module, Joinable): # reassigned. self._assign_modules_buffers() self._distributed_broadcast_coalesced( - self.modules_buffers, + self.modules_buffers[0], self.broadcast_bucket_size, authoritative_rank, ) -- 2.7.4