Revert D30745960: [DDP] Remove SPMD from self.modules_buffers
author	Howard Huang <howardhuang@fb.com>
Thu, 9 Sep 2021 15:20:40 +0000 (08:20 -0700)
committer	Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 15:22:12 +0000 (08:22 -0700)
Test Plan: revert-hammer

Differential Revision:
D30745960 (https://github.com/pytorch/pytorch/commit/15532595209d2daf34d35e10f8d3d3b64966aea2)

Original commit changeset: 66a8f9847e9f

fbshipit-source-id: d3f3fb813c45ac1b0ff15c6154b2e99e5dbab433

torch/nn/parallel/distributed.py

index 76260aa..734d42c 100644
@@ -748,15 +748,21 @@ class DistributedDataParallel(Module, Joinable):
         """
         # Collect buffers for modules, filtering out buffers that should be ignored.
         named_module_buffers = [
-            (buffer, buffer_name)
-            for buffer_name, buffer in self.module.named_buffers()
+            [
+                (buffer, buffer_name)
+                for buffer_name, buffer in self.module.named_buffers()
+            ]
         ]
         self.modules_buffers = [
-            buffer
-            for (buffer, buffer_name) in named_module_buffers
-            if buffer_name not in self.parameters_to_ignore
+            [
+                buffer
+                for (buffer, buffer_name) in module_buffers
+                if buffer_name not in self.parameters_to_ignore
+            ]
+            for module_buffers in named_module_buffers
         ]
 
+
     def _build_param_to_name_mapping(self, parameters):
         param_to_param_index = {parameters[0][i]: i for i in range(len(parameters[0]))}
         param_set = set(parameters[0])
@@ -1033,7 +1039,7 @@ class DistributedDataParallel(Module, Joinable):
         if self.will_sync_module_buffers():
             authoritative_rank = self._find_common_rank(self._distributed_rank, False)
             self._distributed_broadcast_coalesced(
-                self.modules_buffers, self.broadcast_bucket_size, authoritative_rank
+                self.modules_buffers[0], self.broadcast_bucket_size, authoritative_rank
             )
 
     # When running in join model, agrees upon a common rank and broadcast model
@@ -1339,7 +1345,7 @@ class DistributedDataParallel(Module, Joinable):
         return (
             self.require_forward_param_sync
             and self.broadcast_buffers
-            and len(self.modules_buffers) > 0
+            and len(self.modules_buffers[0]) > 0
         )
 
     def _find_common_rank(self, input_rank, rank_cond):
@@ -1377,7 +1383,7 @@ class DistributedDataParallel(Module, Joinable):
                 # reassigned.
                 self._assign_modules_buffers()
                 self._distributed_broadcast_coalesced(
-                    self.modules_buffers,
+                    self.modules_buffers[0],
                     self.broadcast_bucket_size,
                     authoritative_rank,
                 )
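For context, the restored code path stores buffers as a list of per-module-replica buffer lists (the SPMD layout) rather than a flat list, which is why the broadcast and sync sites above index `self.modules_buffers[0]` for the single local replica. A minimal sketch of that layout follows; the toy module and the empty ignore set are assumptions for illustration only, not DDP internals.

```python
# Sketch of the buffer layout this revert restores (not the DDP class itself).
# The BatchNorm module and empty `parameters_to_ignore` are illustrative assumptions.
import torch.nn as nn

module = nn.BatchNorm1d(4)      # has running_mean, running_var, num_batches_tracked buffers
parameters_to_ignore = set()    # fully qualified buffer names excluded from syncing

# Restored shape: a singleton outer list of per-replica (buffer, name) lists.
named_module_buffers = [
    [(buffer, buffer_name) for buffer_name, buffer in module.named_buffers()]
]
modules_buffers = [
    [
        buffer
        for (buffer, buffer_name) in module_buffers
        if buffer_name not in parameters_to_ignore
    ]
    for module_buffers in named_module_buffers
]

# Broadcast/sync call sites therefore index the first (and only) replica.
local_replica_buffers = modules_buffers[0]
print(len(modules_buffers), len(local_replica_buffers))  # 1 3
```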