[Model Averaging] Allow subgroup to be None in PostLocalSGDState (#63277)
author Yi Wang <wayi@fb.com>
Mon, 16 Aug 2021 17:05:47 +0000 (10:05 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 16 Aug 2021 17:07:41 +0000 (10:07 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63277

`PostLocalSGDState` currently requires a subgroup, and creating that subgroup requires an initialized global process group. As a result, a hook state can only be constructed after distributed environment initialization, which is incompatible with the Lightning DDP plugin setup, where the hook state must be provided before the distributed environment is initialized. This change allows `subgroup` to be `None`; in that case, `post_localSGD_hook` lazily creates the intra-node subgroup the first time it runs.
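
Not part of this commit, but for context, a minimal usage sketch of the relaxed API, assuming a typical one-process-per-GPU DDP launch; the `run` entry point, the toy model, and `start_localSGD_iter=100` are illustrative:

    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD
    from torch.nn.parallel import DistributedDataParallel as DDP

    # The hook state can now be built up front (e.g., inside a Lightning plugin)
    # before any process group exists, by leaving both groups as None.
    state = post_localSGD.PostLocalSGDState(
        process_group=None, subgroup=None, start_localSGD_iter=100
    )

    def run(rank, world_size):
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        model = DDP(nn.Linear(8, 8).cuda(rank), device_ids=[rank])
        # The hook resolves subgroup=None to the intra-node subgroup on first use.
        model.register_comm_hook(state, post_localSGD.post_localSGD_hook)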

Proposal: https://github.com/pytorch/pytorch/issues/59699
ghstack-source-id: 135848575

Test Plan: buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_ddp_hook_parity_post_localSGD

Reviewed By: cbalioglu

Differential Revision: D30325041

fbshipit-source-id: 7b870166d096d306c3f2f7c69816a705cec0bebd

torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py
torch/testing/_internal/distributed/distributed_test.py

index d1669e6..927030e 100644 (file)
@@ -10,6 +10,9 @@ class PostLocalSGDState(object):
     r"""
     Stores the state for all-reducing gradients globally using ``process_group`` until step ``start_localSGD_iter``,
     and all-reducing gradients locally using ``subgroup`` afterwards.
+
+    If ``process_group`` is ``None``, the global process group will be used.
+    If ``subgroup`` is ``None``, the intra-node process group on each machine will be used.
     """
 
     __slots__ = [
@@ -91,4 +94,6 @@ def post_localSGD_hook(
     # Run allreduce using `subgroup` after the first `start_localSGD_iter` iterations.
     # From this moment, model averaging should run after the optimizer step,
     # to globally allreduce all the parameters.
+    if state.subgroup is None:
+        state.subgroup, _ = dist.new_subgroups()
     return default._allreduce_fut(state.subgroup, input_tensor)
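
For reference, `torch.distributed.new_subgroups()` partitions the ranks into intra-node subgroups (one per machine by default) and returns this rank's own subgroup together with the list of all subgroups; the hook keeps only the former and caches it on the state, so later invocations reuse the same subgroup. A standalone sketch of that lazy-default pattern (the `resolve_subgroup` helper is hypothetical, not part of this change, and assumes the distributed environment is already initialized):

    import torch.distributed as dist

    def resolve_subgroup(state):
        # Mirrors the lazy default added in post_localSGD_hook: build the
        # intra-node subgroup only when none was supplied at construction time.
        if state.subgroup is None:
            # new_subgroups() returns (this rank's subgroup, all subgroups).
            state.subgroup, _ = dist.new_subgroups()
        return state.subgroup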
index 12c88ab..d7bf0ca 100644 (file)
@@ -4173,6 +4173,17 @@ class DistributedTest:
                 state=state, hook=post_localSGD.post_localSGD_hook
             )
 
+            # When `subgroup` is None, it defaults to the intra-node subgroup on each node.
+            # For this single-node test environment, the intra-node process group is equivalent to
+            # the global process group.
+            if self.world_size == dist.get_world_size():
+                state = post_localSGD.PostLocalSGDState(
+                    process_group=None, subgroup=None, start_localSGD_iter=10
+                )
+                self._test_ddp_hook_parity(
+                    state=state, hook=post_localSGD.post_localSGD_hook
+                )
+
             # Since we start local SGD later than the total number of 100 iterations,
             # no local SGD actually is executed, and we don't even need to provide a subgroup for this case.
             state = post_localSGD.PostLocalSGDState(