Allow override of backend in dist.new_group() (#18595)
author     Pieter Noordhuis <pietern@fb.com>
           Thu, 4 Apr 2019 21:14:50 +0000 (14:14 -0700)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
           Thu, 4 Apr 2019 21:23:03 +0000 (14:23 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18595

There is no need to force a new group's backend to be the same as that of
the global process group, as long as the requested backend is "nccl" or
"gloo".
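
For illustration, a minimal sketch of how the new argument might be used
(the init_method, ranks, and tensor below are assumptions for the example,
not part of this change):

    import torch
    import torch.distributed as dist

    # The global (default) process group uses NCCL.
    dist.init_process_group(backend="nccl", init_method="env://")

    # A subgroup of ranks 0 and 1 that overrides the backend to Gloo,
    # e.g. to run collectives on CPU tensors.
    gloo_group = dist.new_group(ranks=[0, 1], backend="gloo")

    # Collectives against the subgroup run over Gloo, while the
    # default group keeps using NCCL.
    tensor = torch.zeros(1)
    if dist.get_rank() in (0, 1):
        dist.broadcast(tensor, src=0, group=gloo_group)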

Reviewed By: mrshenli

Differential Revision: D14657204

fbshipit-source-id: 868817b9f219e3be8db0761a487f0027ed46663b

test/test_distributed.py
torch/distributed/distributed_c10d.py

diff --git a/test/test_distributed.py b/test/test_distributed.py
index ffcdacb..7143f37 100644
@@ -141,6 +141,58 @@ def skip_if_small_worldsize(func):
     return wrapper
 
 
+def require_backend(backends):
+    if BACKEND not in backends:
+        return unittest.skip("Test requires backend to be one of %s" % backends)
+    return lambda func: func
+
+
+def require_backends_available(backends):
+    def check(backend):
+        if backend == dist.Backend.GLOO:
+            return dist.is_gloo_available()
+        if backend == dist.Backend.NCCL:
+            return dist.is_nccl_available()
+        if backend == dist.Backend.MPI:
+            return dist.is_mpi_available()
+        return False
+    # Use a list (not a lazy map) so it can be both checked and formatted below.
+    backends = [dist.Backend(b) for b in backends]
+    if not all(map(check, backends)):
+        return unittest.skip(
+            "Test requires backends to be available %s" % backends)
+    return lambda func: func
+
+
+def require_world_size(world_size):
+    if int(os.environ["WORLD_SIZE"]) < world_size:
+        return unittest.skip("Test requires world size of %d" % world_size)
+    return lambda func: func
+
+
+def require_num_gpus(n):
+    """
+    Require environment to have access to at least `n` GPUs.
+    Test is skipped otherwise.
+
+    Note: this check cannot run in the parent process, because calling
+    `torch.cuda.is_initialized()` will cause lazy initialization of a
+    CUDA runtime API context, and CUDA doesn't support forking.
+    """
+    def decorator(func):
+        func.skip_if_no_gpu = True
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if not torch.cuda.is_available():
+                sys.exit(SKIP_IF_NO_CUDA_EXIT_CODE)
+            if torch.cuda.device_count() < n:
+                sys.exit(SKIP_IF_NO_GPU_EXIT_CODE)
+            return func(*args, **kwargs)
+        return wrapper
+
+    return decorator
+
+
 def apply_hack_for_nccl():
     # This is a hack for a known NCCL issue using multiprocess
     # in conjunction with multiple threads to manage different GPUs which
@@ -398,6 +450,49 @@ class _DistTestBase(object):
         if group_id is not None:
             self._test_barrier_timeout(group_id, timeout)
 
+    # This test helper can only be used with the Gloo or NCCL backend,
+    # **and** requires both the Gloo and NCCL backends to be available.
+    # See the @skip annotations below.
+    def _test_group_override_backend(self, initializer):
+        if BACKEND == "gloo":
+            new_backend = "nccl"
+        if BACKEND == "nccl":
+            new_backend = "gloo"
+
+        group, group_id, rank = initializer(backend=new_backend)
+        if group_id is None:
+            return
+
+        if new_backend == "gloo":
+            self.assertTrue(isinstance(group_id, dist.ProcessGroupGloo))
+        if new_backend == "nccl":
+            self.assertTrue(isinstance(group_id, dist.ProcessGroupNCCL))
+
+        self.assertEqual(rank, group[dist.get_rank(group_id)])
+        self.assertEqual(len(group), dist.get_world_size(group_id))
+
+        # Pin device (so we avoid NCCL race conditions/deadlocks).
+        group_rank = dist.get_rank(group_id)
+        torch.cuda.set_device(group_rank)
+
+        # Run broadcast of CUDA tensor (so it works for both Gloo and NCCL).
+        tensor = _build_tensor(2, value=group_rank).cuda()
+        dist.broadcast(tensor, src=group[0], group=group_id)
+        self.assertEqual(_build_tensor(2, value=0), tensor.to("cpu"))
+
+    @require_backend({"gloo", "nccl"})
+    @require_backends_available({"gloo", "nccl"})
+    @require_world_size(3)
+    @require_num_gpus(2)
+    def test_backend_group(self):
+        self._test_group_override_backend(self._init_group_test)
+
+    @require_backend({"gloo", "nccl"})
+    @require_backends_available({"gloo", "nccl"})
+    @require_num_gpus(3)
+    def test_backend_full_group(self):
+        self._test_group_override_backend(self._init_full_group_test)
+
     # SEND RECV
     @unittest.skipIf(BACKEND == "nccl", "Nccl does not support send/recv")
     def test_send_recv(self):
@@ -1456,7 +1551,8 @@ if BACKEND == "gloo" or BACKEND == "nccl":
             for attr in dir(cls):
                 if attr.startswith("test"):
                     fn = getattr(cls, attr)
-                    setattr(cls, attr, cls.manager_join(fn))
+                    if not getattr(fn, "__unittest_skip__", False):
+                        setattr(cls, attr, cls.manager_join(fn))
 
         def setUp(self):
             super(TestDistBackend, self).setUp()
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index ffa989b..ddd3554 100644
@@ -11,11 +11,11 @@ from . import BroadcastOptions, AllreduceOptions, ReduceOptions, \
     ScatterOptions, GatherOptions
 from . import ReduceOp
 from . import PrefixStore
-from . import ProcessGroupGloo
 
 
 _MPI_AVAILABLE = True
 _NCCL_AVAILABLE = True
+_GLOO_AVAILABLE = True
 
 
 try:
@@ -28,6 +28,11 @@ try:
 except ImportError:
     _NCCL_AVAILABLE = False
 
+try:
+    from . import ProcessGroupGloo
+except ImportError:
+    _GLOO_AVAILABLE = False
+
 
 class Backend(object):
     """
@@ -230,7 +235,7 @@ def _check_tensor_list(param, param_name):
 
 def is_mpi_available():
     """
-    Checks if MPI is available
+    Checks if the MPI backend is available.
 
     """
     return _MPI_AVAILABLE
@@ -238,12 +243,20 @@ def is_mpi_available():
 
 def is_nccl_available():
     """
-    Checks if NCCL is available
+    Checks if the NCCL backend is available.
 
     """
     return _NCCL_AVAILABLE
 
 
+def is_gloo_available():
+    """
+    Checks if the Gloo backend is available.
+
+    """
+    return _GLOO_AVAILABLE
+
+
 def is_initialized():
     """
     Checking if the default process group has been initialized
@@ -390,7 +403,8 @@ def _new_process_group_helper(world_size,
                               group_ranks,
                               in_group,
                               group_name,
-                              timeout=_default_pg_timeout):
+                              timeout=_default_pg_timeout,
+                              backend=None):
     """
     Create a new distributed process group. And the new process group can be
     used to perform collective operations.
@@ -413,8 +427,12 @@ def _new_process_group_helper(world_size,
                            "datetime.timedelta")
 
     default_backend, default_store = _pg_map[_default_pg]
+    if backend is None:
+        backend = default_backend
+    else:
+        backend = Backend(backend)
 
-    if default_backend == Backend.MPI:
+    if backend == Backend.MPI:
         if not is_mpi_available():
             raise RuntimeError("Distributed package doesn't have MPI built in")
         pg = ProcessGroupMPI(group_ranks)
@@ -424,7 +442,7 @@ def _new_process_group_helper(world_size,
         # Create the prefix store
         store = PrefixStore(group_name, default_store)
 
-        if default_backend == Backend.GLOO:
+        if backend == Backend.GLOO:
             pg = ProcessGroupGloo(
                 store,
                 rank,
@@ -432,7 +450,7 @@ def _new_process_group_helper(world_size,
                 timeout=timeout)
             _pg_map[pg] = (Backend.GLOO, store)
             _pg_names[pg] = group_name
-        elif default_backend == Backend.NCCL:
+        elif backend == Backend.NCCL:
             if not is_nccl_available():
                 raise RuntimeError("Distributed package doesn't have NCCL "
                                    "built in")
@@ -1197,7 +1215,7 @@ def barrier(group=group.WORLD,
         work.wait()
 
 
-def new_group(ranks=None, timeout=_default_pg_timeout):
+def new_group(ranks=None, timeout=_default_pg_timeout, backend=None):
     """
     Creates a new distributed group.
 
@@ -1211,6 +1229,12 @@ def new_group(ranks=None, timeout=_default_pg_timeout):
         timeout (timedelta, optional): Timeout for operations executed against
             the process group. Default value equals 30 minutes.
             This is only applicable for the ``gloo`` backend.
+        backend (str or Backend, optional): The backend to use. Depending on
+            build-time configurations, valid values are ``gloo`` and ``nccl``.
+            By default uses the same backend as the global group. This field
+            should be given as a lowercase string (e.g., ``"gloo"``), which can
+            also be accessed via :class:`Backend` attributes (e.g.,
+            ``Backend.GLOO``).
 
     Returns:
         A handle of distributed group that can be given to collective calls.
@@ -1270,12 +1294,15 @@ def new_group(ranks=None, timeout=_default_pg_timeout):
             return GroupMember.NON_GROUP_MEMBER
 
         if default_backend != Backend.MPI:
+            if backend is None:
+                backend = default_backend
             pg = _new_process_group_helper(group_world_size,
                                            group_rank,
                                            input_ranks,
                                            True,
                                            group_name,
-                                           timeout=timeout)
+                                           timeout=timeout,
+                                           backend=backend)
 
     # Create the global rank to group rank mapping
     _pg_group_ranks[pg] = {}
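
As a usage note, a hedged sketch of how the new is_gloo_available() /
is_nccl_available() checks might be combined with the backend argument of
new_group(); the helper name and fallback logic are illustrative only, not
part of this patch:

    import torch.distributed as dist

    def new_gloo_group_if_possible(ranks):
        # Hypothetical helper: prefer a Gloo subgroup (e.g. for collectives
        # on CPU tensors) and fall back to the default group's backend if
        # Gloo was not built into this installation.
        if dist.is_gloo_available():
            return dist.new_group(ranks=ranks, backend=dist.Backend.GLOO)
        return dist.new_group(ranks=ranks)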