From 0f62af4ab1883cd59f4dfd71945c9755dc138644 Mon Sep 17 00:00:00 2001
From: Pieter Noordhuis
Date: Wed, 28 Nov 2018 11:32:47 -0800
Subject: [PATCH] Add timeout kwarg to init_process_group (#14435)

Summary:
This applies to the gloo backend only. Timeout support for the NCCL and
MPI backends is tracked in issues #14371 and #14372, respectively.

When creating a new process group (either the global one or any subgroup
created through `new_group`), you can specify a timeout keyword argument
of type datetime.timedelta. This timeout applies to all collective
operations executed against that process group: any operation that takes
longer than the timeout raises a RuntimeError. Using a different, more
easily catchable error type is tracked in #14433.

This fixes #14376.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14435

Differential Revision: D13234317

Pulled By: pietern

fbshipit-source-id: 973993b67994dc64861c0977cbb6f051ec9d87f6
---
 test/test_distributed.py              | 67 ++++++++++++++++++++++++++++++----
 torch/distributed/distributed_c10d.py | 30 ++++++++++++----
 2 files changed, 85 insertions(+), 12 deletions(-)

diff --git a/test/test_distributed.py b/test/test_distributed.py
index 2f6c2ad..8e8fbda 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -8,6 +8,7 @@ import time
 import tempfile
 import unittest
 from contextlib import contextmanager
+from datetime import timedelta
 from functools import reduce, wraps
 
 import torch
@@ -162,7 +163,9 @@ class Barrier(object):
             os.unlink(os.path.join(barrier_dir, f_name))
 
     @classmethod
-    def sync(cls, timeout=5):
+    def sync(cls, wait_for=None, timeout=5):
+        if wait_for is None:
+            wait_for = dist.get_world_size()
         cls.barrier_id += 1
         barrier_dir = os.path.join(TEMP_DIR, "barrier")
         pid = str(os.getpid())
@@ -180,7 +183,7 @@
                         data = f.read()
                         if int(data) >= cls.barrier_id:
                             arrived += 1
-            if arrived == dist.get_world_size():
+            if arrived == wait_for:
                 break
 
             if time.time() - start_time > timeout:
@@ -192,18 +195,18 @@ class _DistTestBase(object):
     def _barrier(self, *args, **kwargs):
         Barrier.sync(*args, **kwargs)
 
-    def _init_group_test(self):
+    def _init_group_test(self, **kwargs):
         group = [1, 2]
-        group_id = dist.new_group(group)
+        group_id = dist.new_group(group, **kwargs)
         rank = dist.get_rank()
         if rank not in group:
             return ([], None, rank)
 
         return (group, group_id, rank)
 
-    def _init_full_group_test(self):
+    def _init_full_group_test(self, **kwargs):
         group = [i for i in range(0, dist.get_world_size())]
-        group_id = dist.new_group()
+        group_id = dist.new_group(**kwargs)
         rank = dist.get_rank()
         return (group, group_id, rank)
 
@@ -331,6 +334,58 @@
         self.assertEqual(dist.get_world_size(group_id), dist.get_world_size())
         self.assertEqual(dist.get_rank(group_id), dist.get_rank())
 
+    def _test_barrier_timeout(self, group_id, timeout):
+        local_rank = dist.get_rank(group_id)
+
+        # Only execute barrier on rank == 0, causing it to time out
+        if local_rank == 0:
+            expected_time = time.time() + timeout.total_seconds()
+            with self.assertRaisesRegex(RuntimeError, " (Timed out|closed) "):
+                dist.barrier(group_id)
+            self.assertGreaterEqual(time.time(), expected_time)
+        else:
+            time.sleep(timeout.total_seconds())
+
+    @unittest.skipIf(BACKEND != "gloo", "Only gloo backend supports timeouts")
+    @unittest.skipIf(
+        not INIT_METHOD.startswith("file://"),
+        "Requires file:// initialization method. " +
+        "Both tcp:// and env:// rely on the TCP store for which "
+        "reinitialization has proven racy."
+    )
+    def test_barrier_timeout_global(self):
+        dist.destroy_process_group()
+
+        # Explicitly pass world size to the barrier because we've
+        # just destroyed any state in torch.distributed.
+        self._barrier(wait_for=int(WORLD_SIZE))
+
+        # Reinitialize global process group
+        timeout = timedelta(seconds=0.2)
+        dist.init_process_group(
+            init_method=INIT_METHOD,
+            backend=BACKEND,
+            world_size=int(WORLD_SIZE),
+            rank=self.rank,
+            timeout=timeout,
+        )
+        self._test_barrier_timeout(dist.group.WORLD, timeout)
+
+    @skip_if_small_worldsize
+    @unittest.skipIf(BACKEND != "gloo", "Only gloo backend supports timeouts")
+    def test_barrier_timeout_group(self):
+        timeout = timedelta(seconds=0.2)
+        _, group_id, _ = self._init_group_test(timeout=timeout)
+        if group_id is not None:
+            self._test_barrier_timeout(group_id, timeout)
+
+    @unittest.skipIf(BACKEND != "gloo", "Only gloo backend supports timeouts")
+    def test_barrier_timeout_full_group(self):
+        timeout = timedelta(seconds=0.2)
+        _, group_id, _ = self._init_full_group_test(timeout=timeout)
+        if group_id is not None:
+            self._test_barrier_timeout(group_id, timeout)
+
     # SEND RECV
     @unittest.skipIf(BACKEND == "nccl", "Nccl does not support send/recv")
     def test_send_recv(self):
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index e8a8b0f..e2a1fc2 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -286,6 +286,7 @@ def get_backend(group=group.WORLD):
 
 def init_process_group(backend,
                        init_method="env://",
+                       timeout=_default_pg_timeout,
                        **kwargs):
     """
     Initializes the default distributed process group, and this will also
@@ -302,6 +303,9 @@ def init_process_group(backend,
         world_size (int, optional): Number of processes participating in
             the job.
         rank (int, optional): Rank of the current process.
+        timeout (timedelta, optional): Timeout for operations executed against
+            the process group. Default value equals 30 minutes.
+            This is only applicable for the ``gloo`` backend.
         group_name (str, optional, deprecated): Group name.
 
     To enable ``backend == Backend.MPI``, PyTorch needs to built from source
@@ -314,6 +318,10 @@ def init_process_group(backend,
     global _default_pg
     global _default_pg_init_method
 
+    if not isinstance(timeout, timedelta):
+        raise RuntimeError("Expected timeout argument to be of type "
+                           "datetime.timedelta")
+
     if _default_pg is not None:
         raise RuntimeError("trying to initialize the default process group "
                            "twice!")
@@ -348,7 +356,7 @@ def init_process_group(backend,
             store,
             rank,
             world_size,
-            timeout=_default_pg_timeout)
+            timeout=timeout)
         _pg_map[_default_pg] = (Backend.GLOO, store)
         _pg_names[_default_pg] = group_name
     elif backend == Backend.NCCL:
@@ -367,7 +375,8 @@ def _new_process_group_helper(world_size,
                               rank,
                               group_ranks,
                               in_group=True,
-                              group_name=""):
+                              group_name="",
+                              timeout=_default_pg_timeout):
     """
     Create a new distributed process group. And the new process group can be
     used to perform collective operations.
@@ -385,6 +394,10 @@ def _new_process_group_helper(world_size,
         raise RuntimeError("The specified group name has already been "
                            "created, please use a different group name")
 
+    if not isinstance(timeout, timedelta):
+        raise RuntimeError("Expected timeout argument to be of type "
+                           "datetime.timedelta")
+
     default_backend, default_store = _pg_map[_default_pg]
 
     if default_backend == Backend.MPI:
@@ -402,7 +415,7 @@ def _new_process_group_helper(world_size,
                 store,
                 rank,
                 world_size,
-                timeout=_default_pg_timeout)
+                timeout=timeout)
             _pg_map[pg] = (Backend.GLOO, store)
             _pg_names[pg] = group_name
         elif default_backend == Backend.NCCL:
@@ -1162,7 +1175,7 @@ def barrier(group=group.WORLD,
     work.wait()
 
 
-def new_group(ranks=None):
+def new_group(ranks=None, timeout=_default_pg_timeout):
     """
     Creates a new distributed group.
 
@@ -1173,6 +1186,9 @@ def new_group(ranks=None):
 
     Arguments:
         ranks (list[int]): List of ranks of group members.
+        timeout (timedelta, optional): Timeout for operations executed against
+            the process group. Default value equals 30 minutes.
+            This is only applicable for the ``gloo`` backend.
 
     Returns:
         A handle of distributed group that can be given to collective calls.
@@ -1214,7 +1230,8 @@ def new_group(ranks=None):
             pg = _new_process_group_helper(group_world_size,
                                            group_rank,
                                            input_ranks,
-                                           in_group)
+                                           in_group,
+                                           timeout=timeout)
     else:
         # Release ranks not in the group
         if global_rank not in ranks:
@@ -1223,7 +1240,8 @@ def new_group(ranks=None):
         if default_backend != Backend.MPI:
             pg = _new_process_group_helper(group_world_size,
                                            group_rank,
-                                           input_ranks)
+                                           input_ranks,
+                                           timeout=timeout)
 
     # Create the global rank to group rank mapping
     _pg_group_ranks[pg] = {}
-- 
2.7.4
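
For reference, a minimal usage sketch of the timeout argument this patch
introduces. It is illustrative only and not part of the patch; it assumes a
gloo job of at least two processes whose MASTER_ADDR, MASTER_PORT, RANK, and
WORLD_SIZE are set in the environment (as the env:// init method expects),
and the 60-second and 5-second values are arbitrary.

    from datetime import timedelta

    import torch.distributed as dist

    # Collectives on the default group that take longer than 60 seconds
    # raise a RuntimeError (a more specific error type is tracked in #14433).
    dist.init_process_group(
        backend="gloo",
        init_method="env://",
        timeout=timedelta(seconds=60),
    )

    # Subgroups accept their own timeout, independent of the default
    # group's. new_group must be called by every process, even non-members.
    group = dist.new_group(ranks=[0, 1], timeout=timedelta(seconds=5))

    if dist.get_rank() in (0, 1):
        # Raises a RuntimeError (message contains "Timed out" or "closed",
        # per the test regex above) if the other member does not reach the
        # barrier within roughly 5 seconds.
        dist.barrier(group)

    dist.destroy_process_group()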