From 2d3cf98b49541b0ba4d9db3b2f86cbbf3c5c71a4 Mon Sep 17 00:00:00 2001
From: Teng Li
Date: Tue, 4 Dec 2018 19:20:08 -0800
Subject: [PATCH] Making dist.get_default_group private for PT1 release (#14767)

Summary:
When I wrote the frontend API, it was designed so that users would not use
the default group directly in any functions; it should really be private.
All collectives are supposed to use either group.WORLD or a group that comes
out of new_group. That was the initial design.

We should add a TODO on removing group.WORLD one day; it exists for backward
compatibility reasons and adds a lot of complexity.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/14767

Reviewed By: pietern

Differential Revision: D13330655

Pulled By: teng-li

fbshipit-source-id: ace107e1c3a9b3910a300b22815a9e8096fafb1c
---
 docs/source/distributed.rst           |  2 --
 test/test_distributed.py              |  5 -----
 torch/distributed/distributed_c10d.py |  6 +++---
 torch/nn/parallel/distributed.py      | 10 +++++++---
 4 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst
index 838a1a2..d8c39e1 100644
--- a/docs/source/distributed.rst
+++ b/docs/source/distributed.rst
@@ -162,8 +162,6 @@ joined.
 
 .. autofunction:: is_initialized
 
-.. autofunction:: get_default_group
-
 .. autofunction:: is_mpi_available
 
 .. autofunction:: is_nccl_available
diff --git a/test/test_distributed.py b/test/test_distributed.py
index fa4d679..8df327b 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -265,11 +265,6 @@ class _DistTestBase(object):
 
         self._barrier()
 
-    # GET default group
-    def test_get_default_group(self):
-        default_grp = dist.get_default_group()
-        self.assertNotEqual(default_grp, None)
-
     def test_get_backend(self):
         if dist.get_world_size() > 2:
             group = [1, 2]
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 0caffca..da91cbb 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -129,7 +129,7 @@ def _rank_not_in_group(group):
     Helper that checks if the current process's rank is not in a given group
 
     """
-    default_backend, _ = _pg_map[get_default_group()]
+    default_backend, _ = _pg_map[_get_default_group()]
     if default_backend != Backend.MPI:
         return group == GroupMember.NON_GROUP_MEMBER
     else:
@@ -249,7 +249,7 @@ def is_initialized():
     return _default_pg is not None
 
 
-def get_default_group():
+def _get_default_group():
     """
     Getting the default process group created by init_process_group
 
@@ -447,7 +447,7 @@ def destroy_process_group(group=group.WORLD):
     global _default_pg
     global _default_pg_init_method
 
-    default_backend, _ = _pg_map[get_default_group()]
+    default_backend, _ = _pg_map[_get_default_group()]
     if (default_backend != Backend.MPI and
             group == GroupMember.NON_GROUP_MEMBER):
         return
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 2664fbc..f1a1593 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -6,6 +6,9 @@ from torch.cuda.comm import broadcast_coalesced
 from torch.cuda import nccl
 import torch.distributed as dist
 
+if dist.is_available():
+    from torch.distributed.distributed_c10d import _get_default_group
+
 from ..modules import Module
 from .replicate import replicate
 from .scatter_gather import scatter_kwargs, gather
@@ -186,7 +189,7 @@ class DistributedDataParallel(Module):
             output_device = device_ids[0]
 
         if process_group is None:
-            self.process_group = dist.get_default_group()
+            self.process_group = _get_default_group()
         else:
             self.process_group = process_group
 
@@ -308,14 +311,15 @@ class DistributedDataParallel(Module):
 
     def __setstate__(self, state):
         # If serializable, then the process group should be the default one
-        self.process_group = dist.get_default_group()
+        self.process_group = _get_default_group()
+        self.check_previous_reduction = False
         super(DistributedDataParallel, self).__setstate__(state)
         self._ddp_init_helper()
 
     def _check_default_group(self):
         pickle_not_supported = False
         try:
-            if self.process_group != dist.get_default_group():
+            if self.process_group != _get_default_group():
                 pickle_not_supported = True
         except RuntimeError:
             pickle_not_supported = True
-- 
2.7.4
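
Editor's note: below is a minimal usage sketch (not part of the patch) illustrating the intent described in the summary: user code targets collectives with the implicit world group, group.WORLD, or a group returned by new_group(), while _get_default_group() stays private. The single-process gloo setup, environment variables, and tensor values are illustrative assumptions, not code from this PR.

# Sketch of the public API surface the summary describes. Assumes a
# one-process gloo job; rendezvous settings and tensor values are placeholders.
import os
import torch
import torch.distributed as dist

def main():
    # Illustrative rendezvous settings for a single-process job.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    t = torch.ones(4)

    # Collectives default to the world group; passing group.WORLD is the
    # explicit, public way to target it -- no _get_default_group() needed.
    dist.all_reduce(t, op=dist.ReduceOp.SUM, group=dist.group.WORLD)

    # Subsets of ranks come from new_group(); here it trivially holds rank 0.
    subgroup = dist.new_group(ranks=[0])
    dist.all_reduce(t, op=dist.ReduceOp.SUM, group=subgroup)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()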