Cleanup init_process_group (#19033)
authorKutta Srinivasan <kutta@fb.com>
Thu, 18 Apr 2019 16:31:03 +0000 (09:31 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 18 Apr 2019 16:37:38 +0000 (09:37 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19033

torch.distributed.init_process_group() has had many parameters added, but the contract isn't clear. Adding documentation, asserts, and explicit args should make this clearer to callers and more strictly enforced.

Reviewed By: mrshenli

Differential Revision: D14813070

fbshipit-source-id: 80e4e7123087745bed436eb390887db9d1876042

torch/distributed/distributed_c10d.py

index d08e1f9..5f01181 100644 (file)
@@ -306,12 +306,23 @@ def get_backend(group=group.WORLD):
 
 
 def init_process_group(backend,
-                       init_method="env://",
+                       init_method=None,
                        timeout=_default_pg_timeout,
-                       **kwargs):
+                       world_size=-1,
+                       rank=-1,
+                       store=None,
+                       group_name=''):
     """
     Initializes the default distributed process group, and this will also
-    initialize the distributed package
+    initialize the distributed package.
+
+    There are 2 main ways to initialize a process group:
+        1. Specify ``store``, ``rank``, and ``world_size`` explicitly.
+        2. Specify ``init_method`` (a URL string) which indicates where/how
+           to discover peers. Optionally specify ``rank`` and ``world_size``,
+           or encode all required parameters in the URL and omit them.
+        If neither is specified, ``init_method`` is assumed to be "env://".
+
 
     Arguments:
         backend (str or Backend): The backend to use. Depending on
@@ -323,12 +334,16 @@ def init_process_group(backend,
             must have exclusive access to every GPU it uses, as sharing GPUs
             between processes can result in deadlocks.
         init_method (str, optional): URL specifying how to initialize the
-                                     process group.
+                                     process group. Default is "env://" if no
+                                     ``init_method`` or ``store`` is specified.
+                                     Mutually exclusive with ``store``.
         world_size (int, optional): Number of processes participating in
-                                    the job.
+                                    the job. Required if ``store`` is specified.
         rank (int, optional): Rank of the current process.
-        store(Store, optional): Rendevous key/value store as an alternative
-                                to other init methods.
+                              Required if ``store`` is specified.
+        store(Store, optional): Key/value store accessible to all workers, used
+                                to exchange connection/address information.
+                                Mutually exclusive with ``init_method``.
         timeout (timedelta, optional): Timeout for operations executed against
             the process group. Default value equals 30 minutes.
             This is only applicable for the ``gloo`` backend.
@@ -351,15 +366,14 @@ def init_process_group(backend,
         raise RuntimeError("trying to initialize the default process group "
                            "twice!")
 
-    world_size = kwargs.pop('world_size', -1)
-    group_name = kwargs.pop('group_name', '')
-    rank = kwargs.pop('rank', -1)
-    store = kwargs.pop('store', None)
+    assert (store is None) or (init_method is None), \
+        "Cannot specify both init_method and store."
+
     if store is not None:
-        assert world_size > 0, 'world_size needs to be positive'
-        assert rank >= 0, 'rank needs to be non-negative'
-    assert len(kwargs) == 0, \
-        "got unexpected keyword arguments: %s" % ",".join(kwargs.keys())
+        assert world_size > 0, 'world_size must be positive if using store'
+        assert rank >= 0, 'rank must be non-negative if using store'
+    elif init_method is None:
+        init_method = "env://"
 
     backend = Backend(backend)
 
@@ -374,15 +388,15 @@ def init_process_group(backend,
             timeout=timeout)
     else:
         # backward compatible API
-        url = init_method
-        if world_size != -1 and rank != -1:
-            url += "?rank={}&world_size={}".format(rank, world_size)
-        elif rank != -1:
-            url += "?rank={}".format(rank)
-        elif world_size != -1:
-            url += "?world_size={}".format(world_size)
-
         if store is None:
+            url = init_method
+            if world_size != -1 and rank != -1:
+                url += "?rank={}&world_size={}".format(rank, world_size)
+            elif rank != -1:
+                url += "?rank={}".format(rank)
+            elif world_size != -1:
+                url += "?world_size={}".format(world_size)
+
             store, rank, world_size = next(rendezvous(url))
             store.set_timeout(timeout)