From: Kutta Srinivasan Date: Thu, 18 Apr 2019 16:31:03 +0000 (-0700) Subject: Cleanup init_process_group (#19033) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~161 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b7323a94ad02f4de5cdf66cb82284ba43d27e517;p=platform%2Fupstream%2Fpytorch.git Cleanup init_process_group (#19033) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19033 torch.distributed.init_process_group() has had many parameters added, but the contract isn't clear. Adding documentation, asserts, and explicit args should make this clearer to callers and more strictly enforced. Reviewed By: mrshenli Differential Revision: D14813070 fbshipit-source-id: 80e4e7123087745bed436eb390887db9d1876042 --- diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index d08e1f9..5f01181 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -306,12 +306,23 @@ def get_backend(group=group.WORLD): def init_process_group(backend, - init_method="env://", + init_method=None, timeout=_default_pg_timeout, - **kwargs): + world_size=-1, + rank=-1, + store=None, + group_name=''): """ Initializes the default distributed process group, and this will also - initialize the distributed package + initialize the distributed package. + + There are 2 main ways to initialize a process group: + 1. Specify ``store``, ``rank``, and ``world_size`` explicitly. + 2. Specify ``init_method`` (a URL string) which indicates where/how + to discover peers. Optionally specify ``rank`` and ``world_size``, + or encode all required parameters in the URL and omit them. + If neither is specified, ``init_method`` is assumed to be "env://". + Arguments: backend (str or Backend): The backend to use. Depending on @@ -323,12 +334,16 @@ def init_process_group(backend, must have exclusive access to every GPU it uses, as sharing GPUs between processes can result in deadlocks. init_method (str, optional): URL specifying how to initialize the - process group. + process group. Default is "env://" if no + ``init_method`` or ``store`` is specified. + Mutually exclusive with ``store``. world_size (int, optional): Number of processes participating in - the job. + the job. Required if ``store`` is specified. rank (int, optional): Rank of the current process. - store(Store, optional): Rendevous key/value store as an alternative - to other init methods. + Required if ``store`` is specified. + store(Store, optional): Key/value store accessible to all workers, used + to exchange connection/address information. + Mutually exclusive with ``init_method``. timeout (timedelta, optional): Timeout for operations executed against the process group. Default value equals 30 minutes. This is only applicable for the ``gloo`` backend. @@ -351,15 +366,14 @@ def init_process_group(backend, raise RuntimeError("trying to initialize the default process group " "twice!") - world_size = kwargs.pop('world_size', -1) - group_name = kwargs.pop('group_name', '') - rank = kwargs.pop('rank', -1) - store = kwargs.pop('store', None) + assert (store is None) or (init_method is None), \ + "Cannot specify both init_method and store." + if store is not None: - assert world_size > 0, 'world_size needs to be positive' - assert rank >= 0, 'rank needs to be non-negative' - assert len(kwargs) == 0, \ - "got unexpected keyword arguments: %s" % ",".join(kwargs.keys()) + assert world_size > 0, 'world_size must be positive if using store' + assert rank >= 0, 'rank must be non-negative if using store' + elif init_method is None: + init_method = "env://" backend = Backend(backend) @@ -374,15 +388,15 @@ def init_process_group(backend, timeout=timeout) else: # backward compatible API - url = init_method - if world_size != -1 and rank != -1: - url += "?rank={}&world_size={}".format(rank, world_size) - elif rank != -1: - url += "?rank={}".format(rank) - elif world_size != -1: - url += "?world_size={}".format(world_size) - if store is None: + url = init_method + if world_size != -1 and rank != -1: + url += "?rank={}&world_size={}".format(rank, world_size) + elif rank != -1: + url += "?rank={}".format(rank) + elif world_size != -1: + url += "?world_size={}".format(world_size) + store, rank, world_size = next(rendezvous(url)) store.set_timeout(timeout)