self._barrier()
- # GET default group
- def test_get_default_group(self):
- default_grp = dist.get_default_group()
- self.assertNotEqual(default_grp, None)
-
def test_get_backend(self):
if dist.get_world_size() > 2:
group = [1, 2]
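
The test-side change above simply drops the dedicated test_get_default_group case, since the getter is becoming private and is no longer part of the public surface. For reference, a minimal sketch of the introspection that stays public (assuming init_process_group has already been called by the test harness):

    import torch.distributed as dist

    backend = dist.get_backend()        # backend of the default (WORLD) group
    world_size = dist.get_world_size()  # size of the default group
    rank = dist.get_rank()              # this process's rank in the default group

The hunks below are in torch.distributed.distributed_c10d, where the rename itself happens.
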
Helper that checks if the current process's rank is not in a given group
"""
- default_backend, _ = _pg_map[get_default_group()]
+ default_backend, _ = _pg_map[_get_default_group()]
if default_backend != Backend.MPI:
return group == GroupMember.NON_GROUP_MEMBER
else:
return _default_pg is not None
-def get_default_group():
+def _get_default_group():
"""
Getting the default process group created by init_process_group
global _default_pg
global _default_pg_init_method
- default_backend, _ = _pg_map[get_default_group()]
+ default_backend, _ = _pg_map[_get_default_group()]
if (default_backend != Backend.MPI and
group == GroupMember.NON_GROUP_MEMBER):
return
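
With the underscore prefix, _get_default_group() is an internal helper of torch.distributed.distributed_c10d: the public collectives resolve the default group themselves when no group argument is passed, so user code never needs to fetch the ProcessGroup object explicitly. A rough sketch of the two call patterns, assuming the process group has already been initialized:

    import torch
    import torch.distributed as dist

    # Public path: collectives default to the WORLD group.
    t = torch.ones(1)
    dist.all_reduce(t)

    # Internal path (what DistributedDataParallel does below): reach for the
    # private helper directly; it raises RuntimeError if init_process_group()
    # was never called, which the try/except further down relies on.
    from torch.distributed.distributed_c10d import _get_default_group
    pg = _get_default_group()

The hunks that follow update DistributedDataParallel to use the private helper.
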
from torch.cuda import nccl
import torch.distributed as dist
+if dist.is_available():
+ from torch.distributed.distributed_c10d import _get_default_group
+
from ..modules import Module
from .replicate import replicate
from .scatter_gather import scatter_kwargs, gather
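
The dist.is_available() guard keeps this import from breaking builds compiled without distributed support. The same defensive pattern works in user code, for example:

    import torch.distributed as dist

    if dist.is_available() and dist.is_initialized():
        # Only touch distributed state when the build supports it and
        # init_process_group() has been called in this process.
        print(dist.get_rank(), dist.get_world_size())
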
output_device = device_ids[0]
if process_group is None:
- self.process_group = dist.get_default_group()
+ self.process_group = _get_default_group()
else:
self.process_group = process_group
def __setstate__(self, state):
# If serializable, then the process group should be the default one
- self.process_group = dist.get_default_group()
+ self.process_group = _get_default_group()
+ self.check_previous_reduction = False
super(DistributedDataParallel, self).__setstate__(state)
self._ddp_init_helper()
def _check_default_group(self):
pickle_not_supported = False
try:
- if self.process_group != dist.get_default_group():
+ if self.process_group != _get_default_group():
pickle_not_supported = True
except RuntimeError:
pickle_not_supported = True
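
End-user construction of DistributedDataParallel is unchanged by this: omitting process_group makes the wrapper fall back to the default group via _get_default_group(). A minimal sketch, assuming one CUDA device per process and an environment-variable rendezvous (the backend and init_method here are placeholders):

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel

    dist.init_process_group(backend="nccl", init_method="env://")
    model = torch.nn.Linear(10, 10).cuda()
    # No process_group argument: the default group is picked up internally,
    # which also keeps the module picklable per _check_default_group().
    ddp = DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])
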