From 9d95d485679392774532d4c79a73b9c11b665e1b Mon Sep 17 00:00:00 2001
From: Kiuk Chung
Date: Wed, 25 Aug 2021 22:56:33 -0700
Subject: [PATCH] (torch.distributed) Add torch.distributed.is_torchelastic_launched() util method + make init_method=tcp:// compatible with torchelastic (#63910)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63910

Addresses the current issue that `init_method=tcp://` is not compatible with `torch.distributed.run` and `torch.distributed.launch`.

When running a training script that initializes the process group with `init_method=tcp://localhost:$port`, like so:

```
$ python -u -m torch.distributed.run --max_restarts 0 --nproc_per_node 1 --nnodes 1 --master_addr $(hostname) --master_port 6000 ~/tmp/test.py
```

an `Address in use` error is raised, because the training script tries to start a TCPStore server on port 6000, a port already taken by the TCPStore that the elastic agent is running. For details see: https://github.com/pytorch/pytorch/issues/63874.

This change does a couple of things:

1. Adds an `is_torchelastic_launched()` check function that users can call in their training scripts to see whether the script was launched via torchelastic.
2. Updates the `torch.distributed` docs page to include the new `is_torchelastic_launched()` function.
3. Makes `init_method=tcp://` torchelastic-compatible by modifying `_tcp_rendezvous_handler` in `torch.distributed.rendezvous` (this is NOT the elastic rendezvous; it is the old rendezvous module, which is slated for deprecation in future releases) to check `is_torchelastic_launched()` AND `torchelastic_use_agent_store()`, and if both hold, to only create TCPStore clients (no daemons, not even for rank 0).
4. Adds a set of unit tests to cover the different code paths.

NOTE: the issue mentions that we should fail fast with an assertion on `init_method != env://` when `is_torchelastic_launched()` is `True`. There are three registered init_methods in pytorch: `env://`, `tcp://`, and `file://`. Since this diff makes `tcp://` compatible with torchelastic, and I have validated that `file://` is already compatible, there is no need to add assertions. I did update the docs to point out that `env://` is the RECOMMENDED init_method. We should probably deprecate the other init_methods in the future, but that is out of scope for this issue.

Test Plan: Unittests.

Reviewed By: cbalioglu

Differential Revision: D30529984

fbshipit-source-id: 267aea6d4dad73eb14a2680ac921f210ff547cc5
---
 docs/source/distributed.rst | 2 +
 .../launcher/bin/test_script_init_method.py | 76 +++++++++++++
 .../bin/test_script_is_torchelastic_launched.py | 42 ++++++++
 test/distributed/launcher/run_test.py | 117 +++++++++++++++++++++
 test/distributed/test_launcher.py | 6 +-
 torch/_C/_distributed_c10d.pyi | 3 +-
 torch/distributed/distributed_c10d.py | 46 +++++---
 torch/distributed/launch.py | 10 +-
 torch/distributed/rendezvous.py | 93 ++++++++++------
 torch/distributed/run.py | 1 +
 10 files changed, 342 insertions(+), 54 deletions(-)
 create mode 100755 test/distributed/launcher/bin/test_script_init_method.py
 create mode 100755 test/distributed/launcher/bin/test_script_is_torchelastic_launched.py

diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst
index 0f4e051..c5cd727 100644
--- a/docs/source/distributed.rst
+++ b/docs/source/distributed.rst
@@ -180,6 +180,8 @@ joined.
 .. autofunction:: is_nccl_available
+.. 
autofunction:: is_torchelastic_launched + -------------------------------------------------------------------------------- Currently three initialization methods are supported: diff --git a/test/distributed/launcher/bin/test_script_init_method.py b/test/distributed/launcher/bin/test_script_init_method.py new file mode 100755 index 0000000..299839c --- /dev/null +++ b/test/distributed/launcher/bin/test_script_init_method.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch +import torch.distributed as dist +import torch.nn.functional as F + + +def parse_args(): + parser = argparse.ArgumentParser(description="test script") + + parser.add_argument( + "--init_method", + type=str, + required=True, + help="init_method to pass to `dist.init_process_group()` (e.g. env://)", + ) + parser.add_argument( + "--world_size", + type=int, + default=os.getenv("WORLD_SIZE", -1), + help="world_size to pass to `dist.init_process_group()`", + ) + parser.add_argument( + "--rank", + type=int, + default=os.getenv("RANK", -1), + help="rank to pass to `dist.init_process_group()`", + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + + dist.init_process_group( + backend="gloo", + init_method=args.init_method, + world_size=args.world_size, + rank=args.rank, + ) + + rank = dist.get_rank() + world_size = dist.get_world_size() + + # one hot (by rank) tensor of size world_size + # example: + # rank 0, world_size 4 => [1, 0, 0, 0] + # rank 1, world_size 4 => [0, 1, 0, 0] + # ... + t = F.one_hot(torch.tensor(rank), num_classes=world_size) + + # after all_reduce t = tensor.ones(size=world_size) + dist.all_reduce(t) + + # adding all elements in t should equal world_size + derived_world_size = torch.sum(t).item() + if derived_world_size != world_size: + raise RuntimeError( + f"Wrong world size derived. Expected: {world_size}, Got: {derived_world_size}" + ) + + print("Done") + + +if __name__ == "__main__": + main() diff --git a/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py b/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py new file mode 100755 index 0000000..fa9729c --- /dev/null +++ b/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +This is a test script that launches as part of the test cases in +run_test.py, to validate the correctness of +the method ``torch.distributed.is_torchelastic_launched()``. To do so, +we run this script with and without torchelastic and validate that the +boolean value written to the out_file is indeed what we expect (e.g. +should be False when not launched with torchelastic, True when launched with) +The script itself is not a test case hence no assertions are made in this script. 
+ +see: - test/distributed/launcher/run_test.py#test_is_torchelastic_launched() + - test/distributed/launcher/run_test.py#test_is_not_torchelastic_launched() +""" +import argparse + +import torch.distributed as dist + + +def parse_args(): + parser = argparse.ArgumentParser(description="test script") + parser.add_argument( + "--out_file", + help="file to write indicating whether this script was launched with torchelastic", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + with open(args.out_file, "w") as out: + out.write(f"{dist.is_torchelastic_launched()}") + + +if __name__ == "__main__": + main() diff --git a/test/distributed/launcher/run_test.py b/test/distributed/launcher/run_test.py index 079fea7..4ed824c 100644 --- a/test/distributed/launcher/run_test.py +++ b/test/distributed/launcher/run_test.py @@ -7,8 +7,10 @@ # LICENSE file in the root directory of this source tree. import multiprocessing as mp import os +import runpy import shutil import subprocess +import sys import tempfile import unittest import uuid @@ -21,6 +23,7 @@ from torch.distributed.elastic.agent.server.api import RunResult, WorkerState from torch.distributed.elastic.multiprocessing.errors import ChildFailedError from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer from torch.distributed.elastic.utils import get_socket_with_port +from torch.distributed.elastic.utils.distributed import get_free_port from torch.testing._internal.common_utils import ( TEST_WITH_DEV_DBG_ASAN, sandcastle_skip_if, @@ -475,3 +478,117 @@ class ElasticLaunchTest(unittest.TestCase): param_mock.return_value = rdzv_handler_mock launch.main(args) rdzv_handler_mock.shutdown.assert_called_once() + + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") + def test_is_torchelastic_launched(self): + # launch test script with torchelastic and validate that + # torch.distributed.is_torchelastic_launched() returns True + + out_file = f"{os.path.join(self.test_dir, 'out')}" + + launch.main( + [ + "--run_path", + "--nnodes=1", + "--nproc_per_node=1", + "--monitor_interval=1", + path("bin/test_script_is_torchelastic_launched.py"), + f"--out_file={out_file}", + ] + ) + + with open(out_file, "r") as fp: + is_torchelastic_launched = fp.readline() + self.assertEqual("True", is_torchelastic_launched) + + def test_is_not_torchelastic_launched(self): + # launch test script without torchelastic and validate that + # torch.distributed.is_torchelastic_launched() returns False + + out_file = f"{os.path.join(self.test_dir, 'out')}" + + # need to run the script with runpy in the same interpreter + # as the test because otherwise (depending on the environment) + # it will not find torch as a dependency + with patch.object( + sys, + "argv", + [ + path("bin/test_script_is_torchelastic_launched.py"), + f"--out_file={out_file}", + ], + ): + runpy.run_path(sys.argv[0], run_name="__main__") + with open(out_file, "r") as fp: + is_torchelastic_launched = fp.readline() + self.assertEqual("False", is_torchelastic_launched) + + def test_init_method_tcp(self): + port = get_free_port() + with patch.object( + sys, + "argv", + [ + path("bin/test_script_init_method.py"), + f"--init_method=tcp://localhost:{port}", + "--rank=0", + "--world_size=1", + ], + ): + runpy.run_path(sys.argv[0], run_name="__main__") + # nothing to validate, just make sure it runs + + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") + def test_init_method_tcp_with_torchelastic(self): + port = get_free_port() + 
launch.main( + [ + "--run_path", + "--nnodes=1", + "--nproc_per_node=4", + "--master_addr=localhost", + f"--master_port={port}", + "--monitor_interval=1", + path("bin/test_script_init_method.py"), + f"--init_method=tcp://localhost:{port}", + ] + ) + # nothing to validate, just make sure it runs + + def test_init_method_env(self): + port = get_free_port() + with patch.dict( + os.environ, + { + "RANK": "0", + "WORLD_SIZE": "1", + "MASTER_ADDR": "localhost", + "MASTER_PORT": str(port), + }, + ), patch.object( + sys, + "argv", + [ + path("bin/test_script_init_method.py"), + "--init_method=env://", + ], + ): + runpy.run_path(sys.argv[0], run_name="__main__") + # nothing to validate, just make sure it runs + + @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan") + def test_init_method_env_with_torchelastic(self): + port = get_free_port() + launch.main( + [ + "--run_path", + "--nnodes=1", + "--nproc_per_node=4", + "--master_addr=localhost", + f"--master_port={port}", + "--monitor_interval=1", + path("bin/test_script_init_method.py"), + "--init_method=env://", + ] + ) + # nothing to validate, just make sure it runs diff --git a/test/distributed/test_launcher.py b/test/distributed/test_launcher.py index 4565a26..422c88b 100644 --- a/test/distributed/test_launcher.py +++ b/test/distributed/test_launcher.py @@ -20,10 +20,14 @@ from torch.testing._internal.common_utils import ( def path(script): return os.path.join(os.path.dirname(__file__), script) + if TEST_WITH_DEV_DBG_ASAN: - print("Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr) + print( + "Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr + ) sys.exit(0) + class TestDistributedLaunch(TestCase): def test_launch_user_script(self): nnodes = 1 diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index cfa9c7c..50e7602 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -134,7 +134,8 @@ class TCPStore(Store): world_size: int = ..., is_master: bool = ..., timeout: timedelta = ..., - wait_for_workers: bool = ... + wait_for_workers: bool = ..., + multi_tenant: bool = ... ): ... class PrefixStore(Store): diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 1b1244d..fac096e 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1,6 +1,7 @@ import contextlib import io import logging +import os import pickle import time import warnings @@ -9,28 +10,31 @@ from typing import Dict, Optional, Tuple, Union import torch from torch._C._distributed_c10d import ( - AllreduceOptions, AllreduceCoalescedOptions, + AllreduceOptions, AllToAllOptions, BarrierOptions, BroadcastOptions, GatherOptions, PrefixStore, ProcessGroup, - ReduceOptions, ReduceOp, + ReduceOptions, ReduceScatterOptions, ScatterOptions, Store, + _DistributedDebugLevel, + _get_debug_mode, ) -from torch._C._distributed_c10d import _get_debug_mode, _DistributedDebugLevel from torch._six import string_classes +from .constants import default_pg_timeout +from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401 + + # This module is wildcard imported from torch.distributed. 
# TODO: specify __all__ -from .constants import default_pg_timeout -from .rendezvous import rendezvous, register_rendezvous_handler # noqa: F401 _MPI_AVAILABLE = True _NCCL_AVAILABLE = True @@ -244,7 +248,9 @@ def _store_based_barrier(rank, store, timeout): ) ) - logger.info(f"Rank {rank}: Completed store-based barrier for key:{store_key} with {world_size} nodes.") + logger.info( + f"Rank {rank}: Completed store-based barrier for key:{store_key} with {world_size} nodes." + ) def _rank_not_in_group(group: ProcessGroup): @@ -384,6 +390,18 @@ def is_initialized(): return GroupMember.WORLD is not None +def is_torchelastic_launched(): + """ + Checks whether this process was launched with ``torch.distributed.elastic`` + (aka torchelastic). The existence of ``TORCHELASTIC_RUN_ID`` environment + variable is used as a proxy to determine whether the current process + was launched with torchelastic. This is a reasonable proxy since + ``TORCHELASTIC_RUN_ID`` maps to the rendezvous id which is always a + non-null value indicating the job id for peer discovery purposes.. + """ + return os.getenv("TORCHELASTIC_RUN_ID") is not None + + def _get_default_group(): """ Getting the default process group created by init_process_group @@ -1778,8 +1796,8 @@ def broadcast_object_list(object_list, src=0, group=None, device=None): is_nccl_backend = group_backend == Backend.NCCL current_device = None if device is not None: - if is_nccl_backend and device.type != 'cuda': - raise ValueError('device type must be cuda for nccl backend') + if is_nccl_backend and device.type != "cuda": + raise ValueError("device type must be cuda for nccl backend") current_device = device else: current_device = torch.device("cpu") @@ -2229,7 +2247,9 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False): if _rank_not_in_group(group): return - scatter_list = [t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list] + scatter_list = [ + t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list + ] tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) my_rank = get_rank() @@ -3026,9 +3046,7 @@ def new_subgroups( if rank in ranks_in_subgroup: cur_subgroup = subgroup logger.info( - "Rank {} is assigned to subgroup {}".format( - rank, ranks_in_subgroup - ) + "Rank {} is assigned to subgroup {}".format(rank, ranks_in_subgroup) ) return cur_subgroup, subgroups @@ -3139,8 +3157,6 @@ def new_subgroups_by_enumeration( rank_to_ranks_dict[rank] = ranks if my_rank == rank: cur_subgroup = subgroup - logging.info( - "Rank {} is assigned to subgroup {}".format(rank, ranks) - ) + logging.info("Rank {} is assigned to subgroup {}".format(rank, ranks)) return cur_subgroup, subgroups diff --git a/torch/distributed/launch.py b/torch/distributed/launch.py index 5fcb3eb..4f29edd 100644 --- a/torch/distributed/launch.py +++ b/torch/distributed/launch.py @@ -97,9 +97,9 @@ or >>> # your code to run 3. In your training program, you are supposed to call the following function -at the beginning to start the distributed backend. You need to make sure that -the init_method uses ``env://``, which is the only supported ``init_method`` -by this module. +at the beginning to start the distributed backend. It is strongly recommended +that ``init_method=env://``. Other init methods (e.g. ``tcp://``) may work, +but ``env://`` is the one that is officially supported by this module. 
:: @@ -147,6 +147,7 @@ import warnings from torch.distributed.run import get_args_parser, run + logger = logging.getLogger(__name__) @@ -181,7 +182,8 @@ def main(args=None): "If your script expects `--local_rank` argument to be set, please\n" "change it to read from `os.environ['LOCAL_RANK']` instead. See \n" "https://pytorch.org/docs/stable/distributed.html#launch-utility for \n" - "further instructions\n", FutureWarning + "further instructions\n", + FutureWarning, ) args = parse_args(args) launch(args) diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index 6a5b680..6e430e2 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -1,17 +1,22 @@ try: from urllib.parse import urlparse, urlunparse except ImportError: - raise ImportError("urllib cannot be found, urlparse from python2 is no longer supported.") + raise ImportError( + "urllib cannot be found, urlparse from python2 is no longer supported." + ) -import torch._six as six import numbers import os import sys from datetime import timedelta -from typing import Optional, Dict, Union -from torch.distributed import FileStore, TCPStore, PrefixStore +from typing import Dict, Optional, Union + +import torch._six as six +from torch.distributed import FileStore, PrefixStore, Store, TCPStore + from .constants import default_pg_timeout + _rendezvous_handlers = {} @@ -73,7 +78,9 @@ def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs): query_dict["world_size"] = world_size result = result._replace( - query="{}".format("&".join(["{}={}".format(k, v) for k, v in query_dict.items()])) + query="{}".format( + "&".join(["{}={}".format(k, v) for k, v in query_dict.items()]) + ) ) url = urlunparse(result) @@ -92,8 +99,9 @@ def _file_rendezvous_handler(url: str, **kwargs): result = urlparse(url) path = result.path - if sys.platform == 'win32': + if sys.platform == "win32": import urllib.request + full_path = result.netloc + result.path path = urllib.request.url2pathname(full_path) if path: @@ -119,7 +127,41 @@ def _file_rendezvous_handler(url: str, **kwargs): raise RuntimeError("Unable to perform rerendezvous using file:// method") -def _tcp_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs): +def _torchelastic_use_agent_store() -> bool: + return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True) + + +def _create_c10d_store(hostname, port, rank, world_size, timeout) -> Store: + """ + Smartly creates a c10d Store object on ``rank`` based on whether + we need to re-use agent store. The TCPStore server is assumed to be hosted + on ``hostname:port``. + + If ``torchelastic_use_agent_store()`` is ``True``, then it is assumed that + the agent leader (node rank 0) hosts the TCPStore server (for which the + endpoint is specified by the given ``hostname:port``). Hence + ALL ranks will create and return a TCPStore client (e.g. ``start_daemon=False``). + + If ``torchelastic_use_agent_store()`` is ``False``, then rank 0 will host + the TCPStore (with multi-tenancy) and it is assumed that rank 0's hostname + and port are correctly passed via ``hostname`` and ``port``. All + non-zero ranks will create and return a TCPStore client. 
+ """ + + if _torchelastic_use_agent_store(): + attempt = os.environ["TORCHELASTIC_RESTART_COUNT"] + tcp_store = TCPStore(hostname, port, world_size, False, timeout) + return PrefixStore(f"/worker/attempt_{attempt}", tcp_store) + else: + start_daemon = rank == 0 + return TCPStore( + hostname, port, world_size, start_daemon, timeout, multi_tenant=True + ) + + +def _tcp_rendezvous_handler( + url: str, timeout: timedelta = default_pg_timeout, **kwargs +): def _error(msg): return _rendezvous_error("tcp:// rendezvous: " + msg) @@ -136,18 +178,19 @@ def _tcp_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, * rank = int(query["rank"]) world_size = int(query["world_size"]) - start_daemon = rank == 0 assert result.hostname is not None - store = TCPStore( # type: ignore[call-arg] - result.hostname, result.port, world_size, start_daemon, timeout, multi_tenant=True - ) + + store = _create_c10d_store(result.hostname, result.port, rank, world_size, timeout) + yield (store, rank, world_size) # If this configuration is invalidated, there is nothing we can do about it - raise RuntimeError("Unable to perform rerendezvous using tcp:// method") + raise RuntimeError("Unable to perform re-rendezvous using tcp:// method") -def _env_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs): +def _env_rendezvous_handler( + url: str, timeout: timedelta = default_pg_timeout, **kwargs +): def _error(msg): return _rendezvous_error("env:// rendezvous: " + msg) @@ -183,29 +226,13 @@ def _env_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, * master_addr = _get_env_or_raise("MASTER_ADDR") master_port = int(_get_env_or_raise("MASTER_PORT")) + store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout) - use_torchelastic_store = os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) - - if use_torchelastic_store == str(True): - attempt = os.environ["TORCHELASTIC_RESTART_COUNT"] - worker_process_prefix = f"/worker/attempt_{attempt}" - # When TORCHELASTIC_USE_AGENT_STORE is set up, the worker process is assumed - # to be invoked by the torchelastic agent. 
Torchelastic agent creates a tcp daemon thread - # on the GROUP_RANK=0, as a result all user worker processes should create store with: daemon=False - tcp_store = TCPStore(master_addr, master_port, world_size, False, timeout) - # Each if-else condition returns due to: https://github.com/python/mypy/issues/1191 - yield (PrefixStore(worker_process_prefix, tcp_store), rank, world_size) - else: - # Start the TCP store daemon on the rank 0 - start_daemon = rank == 0 - store = TCPStore( # type: ignore[call-arg] - master_addr, master_port, world_size, start_daemon, timeout, multi_tenant=True - ) - # Each if-else condition returns due to: https://github.com/python/mypy/issues/1191 - yield (store, rank, world_size) + yield (store, rank, world_size) # If this configuration is invalidated, there is nothing we can do about it - raise RuntimeError("Unable to perform rerendezvous using env:// method") + raise RuntimeError("Unable to perform re-rendezvous using env:// method") + register_rendezvous_handler("tcp", _tcp_rendezvous_handler) register_rendezvous_handler("env", _env_rendezvous_handler) diff --git a/torch/distributed/run.py b/torch/distributed/run.py index f21fc4e..d4428a0 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -321,6 +321,7 @@ from torch.distributed.elastic.utils import macros from torch.distributed.elastic.utils.logging import get_logger from torch.distributed.launcher.api import LaunchConfig, elastic_launch + log = get_logger() -- 2.7.4
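With this change in place, a training script can branch on how it was launched when choosing an init method. The sketch below is not part of the patch: `is_torchelastic_launched()` is the function added above, while the `gloo` backend, the fallback endpoint `tcp://localhost:29500`, and the helper name `init_distributed` are illustrative assumptions.

```
import os

import torch.distributed as dist


def init_distributed() -> None:
    if dist.is_torchelastic_launched():
        # Launched via torch.distributed.run / torchelastic: the agent already
        # exports MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE, so env:// (the
        # recommended init_method) needs no extra arguments.
        dist.init_process_group(backend="gloo", init_method="env://")
    else:
        # Stand-alone launch: fall back to an explicit TCP endpoint and pass
        # rank/world_size ourselves.
        dist.init_process_group(
            backend="gloo",
            init_method="tcp://localhost:29500",
            rank=int(os.getenv("RANK", "0")),
            world_size=int(os.getenv("WORLD_SIZE", "1")),
        )
```

Under `torch.distributed.run` the first branch is taken automatically, since the launcher sets `TORCHELASTIC_RUN_ID`, which is the proxy `is_torchelastic_launched()` checks.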
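The heart of the tcp:// fix is the client-versus-server decision made by `_create_c10d_store` in the diff above. The stand-alone sketch below restates that decision for readers; the helper name `attach_store` and the fixed 60-second timeout are assumptions, and it omits the per-attempt `PrefixStore` namespacing and the `multi_tenant=True` flag that the real helper uses.

```
import os
from datetime import timedelta

from torch.distributed import TCPStore


def attach_store(hostname: str, port: int, rank: int, world_size: int) -> TCPStore:
    # Illustrative timeout; the real code reuses the rendezvous timeout.
    timeout = timedelta(seconds=60)

    if os.environ.get("TORCHELASTIC_USE_AGENT_STORE") == str(True):
        # torchelastic: the agent on node rank 0 already runs the TCPStore
        # server on hostname:port, so every rank (including rank 0) attaches
        # as a client. Starting a second server here is what used to fail
        # with "Address in use".
        return TCPStore(hostname, port, world_size, False, timeout)

    # Stand-alone: rank 0 hosts the store, all other ranks connect to it.
    return TCPStore(hostname, port, world_size, rank == 0, timeout)
```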
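For completeness, the correctness check in `test_script_init_method.py` leans on a small identity: summing the rank-wise one-hot vectors across all ranks (which is what `all_reduce` with the default SUM op does) yields the all-ones vector of length `world_size`, so its element sum equals `world_size`. A quick local sketch of that identity, with no process group involved:

```
import torch
import torch.nn.functional as F

world_size = 4
# what each rank contributes before the collective
contributions = [
    F.one_hot(torch.tensor(rank), num_classes=world_size) for rank in range(world_size)
]
# what all_reduce(SUM) leaves on every rank
reduced = torch.stack(contributions).sum(dim=0)

assert torch.equal(reduced, torch.ones(world_size, dtype=torch.long))
assert reduced.sum().item() == world_size
```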