From 45bd0f61818cd41f9f5548d3120640986a1b9165 Mon Sep 17 00:00:00 2001
From: Rohan Varma
Date: Fri, 17 Sep 2021 12:23:09 -0700
Subject: [PATCH] Back out "Revert D30745960: [DDP] Remove SPMD from
 self.modules_buffers" (#64778)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64778

Original commit changeset: d3f3fb813c45

ghstack-source-id: 138326910

Test Plan: ci

Reviewed By: H-Huang

Differential Revision: D30849443

fbshipit-source-id: 15dab8a959a29d2e2fefac6ad52b8d8168eacc02
---
 torch/nn/parallel/distributed.py | 22 ++++++++--------------
 1 file changed, 8 insertions(+), 14 deletions(-)

diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 734d42c..76260aa 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -748,21 +748,15 @@ class DistributedDataParallel(Module, Joinable):
         """
         # Collect buffers for modules, filtering out buffers that should be ignored.
         named_module_buffers = [
-            [
-                (buffer, buffer_name)
-                for buffer_name, buffer in self.module.named_buffers()
-            ]
+            (buffer, buffer_name)
+            for buffer_name, buffer in self.module.named_buffers()
         ]
         self.modules_buffers = [
-            [
-                buffer
-                for (buffer, buffer_name) in module_buffers
-                if buffer_name not in self.parameters_to_ignore
-            ]
-            for module_buffers in named_module_buffers
+            buffer
+            for (buffer, buffer_name) in named_module_buffers
+            if buffer_name not in self.parameters_to_ignore
         ]

-
     def _build_param_to_name_mapping(self, parameters):
         param_to_param_index = {parameters[0][i]: i for i in range(len(parameters[0]))}
         param_set = set(parameters[0])
@@ -1039,7 +1033,7 @@ class DistributedDataParallel(Module, Joinable):
         if self.will_sync_module_buffers():
             authoritative_rank = self._find_common_rank(self._distributed_rank, False)
             self._distributed_broadcast_coalesced(
-                self.modules_buffers[0], self.broadcast_bucket_size, authoritative_rank
+                self.modules_buffers, self.broadcast_bucket_size, authoritative_rank
             )

     # When running in join model, agrees upon a common rank and broadcast model
@@ -1345,7 +1339,7 @@ class DistributedDataParallel(Module, Joinable):
         return (
             self.require_forward_param_sync
             and self.broadcast_buffers
-            and len(self.modules_buffers[0]) > 0
+            and len(self.modules_buffers) > 0
         )

     def _find_common_rank(self, input_rank, rank_cond):
@@ -1383,7 +1377,7 @@ class DistributedDataParallel(Module, Joinable):
                 # reassigned.
                 self._assign_modules_buffers()
                 self._distributed_broadcast_coalesced(
-                    self.modules_buffers[0],
+                    self.modules_buffers,
                     self.broadcast_bucket_size,
                     authoritative_rank,
                 )
--
2.7.4
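
Context for the change: the hunks above flatten self.modules_buffers from a nested per-replica list (a remnant of the single-process multi-device "SPMD" path named in the subject) into a flat list of buffer tensors, and drop the matching [0] indexing at the three broadcast call sites. The following is a minimal, self-contained sketch of the new collection pattern; ToyModel and the parameters_to_ignore set are illustrative stand-ins rather than names from the patch, while the two list comprehensions mirror the post-patch code in _assign_modules_buffers.

    # Standalone sketch of the flattened buffer-collection pattern the patch
    # introduces. ToyModel and parameters_to_ignore are illustrative stand-ins;
    # the real logic lives in torch.nn.parallel.DistributedDataParallel.
    import torch
    import torch.nn as nn

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            # BatchNorm1d registers running_mean/running_var/num_batches_tracked buffers.
            self.bn = nn.BatchNorm1d(4)
            self.register_buffer("step_count", torch.zeros(1))

        def forward(self, x):
            return self.bn(x)

    module = ToyModel()
    parameters_to_ignore = {"step_count"}  # hypothetical ignore set

    # Post-patch layout: a flat list of buffer tensors, with ignored buffers
    # filtered out by name (no outer per-replica list, so no [0] indexing).
    named_module_buffers = [
        (buffer, buffer_name)
        for buffer_name, buffer in module.named_buffers()
    ]
    modules_buffers = [
        buffer
        for (buffer, buffer_name) in named_module_buffers
        if buffer_name not in parameters_to_ignore
    ]

    print([name for name, _ in module.named_buffers()])  # ['step_count', 'bn.running_mean', ...]
    print(len(modules_buffers))  # 3: the BatchNorm buffers; 'step_count' is filtered out

Because the outer list is gone, the broadcast and sync helpers operate on modules_buffers directly, which is exactly what the remaining hunks adjust at the _distributed_broadcast_coalesced and will_sync_module_buffers call sites.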