[DDP] Remove SPMD from self.modules_buffers (#64474)
author Rohan Varma <rvarm1@fb.com>
Thu, 9 Sep 2021 02:13:33 +0000 (19:13 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 02:16:15 +0000 (19:16 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64474

Remove the SPMD-era nesting: self.modules_buffers was only ever read as self.modules_buffers[0], so it can be a flat list of buffers instead of a nested list (a short equivalence sketch follows the commit metadata below).
ghstack-source-id: 137526312

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D30745960

fbshipit-source-id: 66a8f9847e9fe1e02c51b79647e93bf7665cf4d9
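
The patch relies on the old nested layout always containing exactly one inner list. A minimal standalone sketch of that equivalence, using a toy module and an illustrative parameters_to_ignore set (both the module and the ignore set are assumptions for the example, not part of the patch):

import torch.nn as nn

# Toy module with buffers; BatchNorm1d registers running_mean, running_var, num_batches_tracked.
module = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

# Illustrative stand-in for DDP's parameters_to_ignore (fully qualified buffer names to skip).
parameters_to_ignore = {"1.num_batches_tracked"}

# Old SPMD-era layout: a single inner list, always read back as modules_buffers[0].
nested_modules_buffers = [
    [
        buffer
        for buffer_name, buffer in module.named_buffers()
        if buffer_name not in parameters_to_ignore
    ]
]

# New flat layout from this patch: one list of buffer tensors.
flat_modules_buffers = [
    buffer
    for buffer_name, buffer in module.named_buffers()
    if buffer_name not in parameters_to_ignore
]

# The broadcast call sites switch from modules_buffers[0] to modules_buffers,
# but the tensors handed to _distributed_broadcast_coalesced are identical.
assert all(a is b for a, b in zip(nested_modules_buffers[0], flat_modules_buffers))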

torch/nn/parallel/distributed.py

index 734d42c..76260aa 100644
@@ -748,21 +748,15 @@ class DistributedDataParallel(Module, Joinable):
         """
         # Collect buffers for modules, filtering out buffers that should be ignored.
         named_module_buffers = [
-            [
-                (buffer, buffer_name)
-                for buffer_name, buffer in self.module.named_buffers()
-            ]
+            (buffer, buffer_name)
+            for buffer_name, buffer in self.module.named_buffers()
         ]
         self.modules_buffers = [
-            [
-                buffer
-                for (buffer, buffer_name) in module_buffers
-                if buffer_name not in self.parameters_to_ignore
-            ]
-            for module_buffers in named_module_buffers
+            buffer
+            for (buffer, buffer_name) in named_module_buffers
+            if buffer_name not in self.parameters_to_ignore
         ]
 
-
     def _build_param_to_name_mapping(self, parameters):
         param_to_param_index = {parameters[0][i]: i for i in range(len(parameters[0]))}
         param_set = set(parameters[0])
@@ -1039,7 +1033,7 @@ class DistributedDataParallel(Module, Joinable):
         if self.will_sync_module_buffers():
             authoritative_rank = self._find_common_rank(self._distributed_rank, False)
             self._distributed_broadcast_coalesced(
-                self.modules_buffers[0], self.broadcast_bucket_size, authoritative_rank
+                self.modules_buffers, self.broadcast_bucket_size, authoritative_rank
             )
 
     # When running in join model, agrees upon a common rank and broadcast model
@@ -1345,7 +1339,7 @@ class DistributedDataParallel(Module, Joinable):
         return (
             self.require_forward_param_sync
             and self.broadcast_buffers
-            and len(self.modules_buffers[0]) > 0
+            and len(self.modules_buffers) > 0
         )
 
     def _find_common_rank(self, input_rank, rank_cond):
@@ -1383,7 +1377,7 @@ class DistributedDataParallel(Module, Joinable):
                 # reassigned.
                 self._assign_modules_buffers()
                 self._distributed_broadcast_coalesced(
-                    self.modules_buffers[0],
+                    self.modules_buffers,
                     self.broadcast_bucket_size,
                     authoritative_rank,
                 )