Making dist.get_default_group private for PT1 release (#14767)
authorTeng Li <tengli@fb.com>
Wed, 5 Dec 2018 03:20:08 +0000 (19:20 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 5 Dec 2018 03:22:24 +0000 (19:22 -0800)
Summary:
When I wrote the frontend API, it is designed on not letting users use the default_group directly on any functions.  It should really be private.

All collectives are supposed to either use group.WORLD, or anything that comes out of new_group. That was the initial design.

We need to make a TODO on removing group.WORLD one day. It exists for backward compatibility reasons and adds lots of complexity.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14767

Reviewed By: pietern

Differential Revision: D13330655

Pulled By: teng-li

fbshipit-source-id: ace107e1c3a9b3910a300b22815a9e8096fafb1c

docs/source/distributed.rst
test/test_distributed.py
torch/distributed/distributed_c10d.py
torch/nn/parallel/distributed.py

index 838a1a2..d8c39e1 100644 (file)
@@ -162,8 +162,6 @@ joined.
 
 .. autofunction:: is_initialized
 
-.. autofunction:: get_default_group
-
 .. autofunction:: is_mpi_available
 
 .. autofunction:: is_nccl_available
index fa4d679..8df327b 100644 (file)
@@ -265,11 +265,6 @@ class _DistTestBase(object):
 
         self._barrier()
 
-    # GET default group
-    def test_get_default_group(self):
-        default_grp = dist.get_default_group()
-        self.assertNotEqual(default_grp, None)
-
     def test_get_backend(self):
         if dist.get_world_size() > 2:
             group = [1, 2]
index 0caffca..da91cbb 100644 (file)
@@ -129,7 +129,7 @@ def _rank_not_in_group(group):
     Helper that checks if the current process's rank is not in a given group
 
     """
-    default_backend, _ = _pg_map[get_default_group()]
+    default_backend, _ = _pg_map[_get_default_group()]
     if default_backend != Backend.MPI:
         return group == GroupMember.NON_GROUP_MEMBER
     else:
@@ -249,7 +249,7 @@ def is_initialized():
     return _default_pg is not None
 
 
-def get_default_group():
+def _get_default_group():
     """
     Getting the default process group created by init_process_group
 
@@ -447,7 +447,7 @@ def destroy_process_group(group=group.WORLD):
     global _default_pg
     global _default_pg_init_method
 
-    default_backend, _ = _pg_map[get_default_group()]
+    default_backend, _ = _pg_map[_get_default_group()]
     if (default_backend != Backend.MPI and
             group == GroupMember.NON_GROUP_MEMBER):
         return
index 2664fbc..f1a1593 100644 (file)
@@ -6,6 +6,9 @@ from torch.cuda.comm import broadcast_coalesced
 from torch.cuda import nccl
 import torch.distributed as dist
 
+if dist.is_available():
+    from torch.distributed.distributed_c10d import _get_default_group
+
 from ..modules import Module
 from .replicate import replicate
 from .scatter_gather import scatter_kwargs, gather
@@ -186,7 +189,7 @@ class DistributedDataParallel(Module):
             output_device = device_ids[0]
 
         if process_group is None:
-            self.process_group = dist.get_default_group()
+            self.process_group = _get_default_group()
         else:
             self.process_group = process_group
 
@@ -308,14 +311,15 @@ class DistributedDataParallel(Module):
 
     def __setstate__(self, state):
         # If serializable, then the process group should be the default one
-        self.process_group = dist.get_default_group()
+        self.process_group = _get_default_group()
+        self.check_previous_reduction = False
         super(DistributedDataParallel, self).__setstate__(state)
         self._ddp_init_helper()
 
     def _check_default_group(self):
         pickle_not_supported = False
         try:
-            if self.process_group != dist.get_default_group():
+            if self.process_group != _get_default_group():
                 pickle_not_supported = True
         except RuntimeError:
             pickle_not_supported = True