r"""
Stores the state for all-reducing gradients globally using ``process_group`` until step ``start_localSGD_iter``,
and all-reducing gradients locally using ``subgroup`` afterwards.
+
+ If ``process_group`` is ``None``, the global process group will be used.
+ If ``subgroup`` is ``None``, the intra-node process group on each machine will be used.
"""
__slots__ = [
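For reference, a minimal sketch of how this state object is typically wired into a DDP training loop; ``net``, ``optimizer``, ``loader``, and ``rank`` are placeholder names, not part of this patch:

    import torch.distributed as dist
    from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD
    from torch.distributed.algorithms.model_averaging.averagers import PeriodicModelAverager
    from torch.nn.parallel import DistributedDataParallel as DDP

    model = DDP(net, device_ids=[rank])
    state = post_localSGD.PostLocalSGDState(
        process_group=None,       # None -> the global process group
        subgroup=None,            # None -> an intra-node subgroup, created lazily
        start_localSGD_iter=100,  # switch from global to local allreduce at step 100
    )
    model.register_comm_hook(state, post_localSGD.post_localSGD_hook)
    # Once local SGD starts, parameters must be periodically averaged
    # across nodes after the optimizer step (see the hook comments below).
    averager = PeriodicModelAverager(period=4, warmup_steps=100)
    for batch in loader:
        loss_fn(model(batch)).backward()
        optimizer.step()
        optimizer.zero_grad()
        averager.average_parameters(model.parameters())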
# Run allreduce using `subgroup` after the first `start_localSGD_iter` iterations.
# From this moment, model averaging should run after the optimizer step,
# to globally allreduce all the parameters.
+ if state.subgroup is None:
+     state.subgroup, _ = dist.new_subgroups()
return default._allreduce_fut(state.subgroup, input_tensor)
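``dist.new_subgroups()`` called with no arguments partitions the global group into equal-sized intra-node subgroups (by default one per machine, sized by the machine's CUDA device count) and returns the subgroup containing the current rank along with the list of all subgroups. A sketch of the layout this lazy path relies on, assuming a hypothetical 2-node, 4-GPU-per-node job:

    # world_size == 8 with 4 GPUs per node: ranks 0-3 and ranks 4-7
    # each form one intra-node subgroup.
    cur_subgroup, all_subgroups = dist.new_subgroups()
    # `cur_subgroup` contains only the ranks on this machine, so an
    # allreduce over it never crosses the inter-node network.
    dist.all_reduce(grad, group=cur_subgroup)  # `grad` is a placeholder tensor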
state=state, hook=post_localSGD.post_localSGD_hook
)
+ # When `subgroup` is None, it is equivalent to the intra-node subgroup on each node.
+ # In this single-node test environment, the intra-node process group is equivalent to
+ # the global process group.
+ if self.world_size == dist.get_world_size():
+     state = post_localSGD.PostLocalSGDState(
+         process_group=None, subgroup=None, start_localSGD_iter=10
+     )
+     self._test_ddp_hook_parity(
+         state=state, hook=post_localSGD.post_localSGD_hook
+     )
+
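The ``self.world_size == dist.get_world_size()`` guard matters because the parity check presumably compares against DDP's default globally-allreduced gradients; a sketch of why the equivalence holds on a single node (an illustrative assertion, not part of the test):

    # On one node, the lazily created intra-node subgroup spans every
    # rank, so the "local" allreduce equals a global allreduce.
    cur_subgroup, _ = dist.new_subgroups()
    assert dist.get_world_size(cur_subgroup) == dist.get_world_size()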
# Since local SGD is set to start only after the total number of 100 iterations has run,
# no local SGD is actually executed, and we don't even need to provide a subgroup for this case.
state = post_localSGD.PostLocalSGDState(