From: SsnL
Date: Thu, 10 Jan 2019 16:44:32 +0000 (-0800)
Subject: Fix TestDataLoader.test_proper_exit (#15665)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1927
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9b5ec2a076982c57033f2e345cee3051b55de996;p=platform%2Fupstream%2Fpytorch.git

Fix TestDataLoader.test_proper_exit (#15665)

Summary:
Currently, in `test_proper_exit`,

1. we do not kill the correct input `pid` in the `kill_pid` function
   https://github.com/pytorch/pytorch/blob/fe15d6a2c231a7bc1b32781217ed336ccf9adff7/test/test_dataloader.py#L325-L329
2. the Windows command that detects process status doesn't actually work
   https://github.com/pytorch/pytorch/blob/fe15d6a2c231a7bc1b32781217ed336ccf9adff7/test/test_dataloader.py#L641-L646
3. the `worker_error` and `worker_kill` cases are sometimes not exercised, because the workers may exit naturally thanks to the pre-fetching mechanism and a `dataset size / batch size` ratio that is too small.

In this PR, in separate commits, I:

1. Install `psutil` (a Python package built specifically for process monitoring) on some CI builds. (The Linux build installation is done in https://github.com/pietern/pytorch-dockerfiles/pull/29, https://github.com/pietern/pytorch-dockerfiles/pull/30, https://github.com/pytorch/ossci-job-dsl/pull/36, and https://github.com/pytorch/pytorch/pull/15795.)
2. Rewrite `test_proper_exit` with `psutil` (see the short sketch after this list) so that we
   1. no longer rely on the hacky `is_process_alive`
      https://github.com/pytorch/pytorch/blob/fe15d6a2c231a7bc1b32781217ed336ccf9adff7/test/test_dataloader.py#L640-L653
   2. increase the number of tasks per worker so that `worker_error` and `worker_kill` are properly triggered, and
   3. check the error message content to ensure that the loader exits with the message corresponding to each exit scenario.
3. Give the Windows data loader a mechanism to detect worker failures, which it previously lacked entirely.
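For reference, the `psutil` primitives the rewritten test leans on look roughly like this (an illustrative sketch, not the test code itself; `kill_and_check` and `assert_children_exited` are made-up helper names for this example):

```python
# Illustrative sketch of the psutil calls used for process monitoring.
import psutil


def kill_and_check(pid, timeout=10):
    # Kill a specific pid and confirm it is really gone (what `kill_pid` now does).
    p = psutil.Process(pid)
    p.kill()
    p.wait(timeout)            # reaps the process, or raises psutil.TimeoutExpired
    assert not p.is_running()


def assert_children_exited(parent_pid, timeout=10):
    # Snapshot the parent's children (the worker processes), then wait for them
    # to exit; `alive` holds whatever is still running after the timeout.
    workers = psutil.Process(parent_pid).children()
    gone, alive = psutil.wait_procs(workers, timeout=timeout)
    assert len(alive) == 0, 'worker pid(s) {} still alive'.format([p.pid for p in alive])
```

`psutil.wait_procs` returns both the exited and the still-running processes, which is what lets the test report the exact worker pid(s) that failed to terminate.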
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15665

Differential Revision: D13615527

Pulled By: soumith

fbshipit-source-id: cfb2f67837d2d87928a53f00b4d20f09754b7949
---

diff --git a/.jenkins/pytorch/macos-test.sh b/.jenkins/pytorch/macos-test.sh
index 1ce42eb..bd4c73e 100755
--- a/.jenkins/pytorch/macos-test.sh
+++ b/.jenkins/pytorch/macos-test.sh
@@ -16,7 +16,7 @@ fi
 export PATH="${PYTORCH_ENV_DIR}/miniconda3/bin:$PATH"
 source ${PYTORCH_ENV_DIR}/miniconda3/bin/activate
 conda install -y mkl mkl-include numpy pyyaml setuptools cmake cffi ninja six
-pip install hypothesis
+pip install hypothesis librosa>=0.6.2 psutil
 if [ -z "${IN_CIRCLECI}" ]; then
   rm -rf ${PYTORCH_ENV_DIR}/miniconda3/lib/python3.6/site-packages/torch*
 fi
diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh
index cdd2669..cb0e724 100755
--- a/.jenkins/pytorch/test.sh
+++ b/.jenkins/pytorch/test.sh
@@ -78,6 +78,10 @@ if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
   export PYTORCH_TEST_WITH_ROCM=1
   export LANG=C.UTF-8
   export LC_ALL=C.UTF-8
+
+  # ROCm CI is using Caffe2 docker images, which don't have several packages
+  # needed for testing. We install them here.
+  pip install -q psutil librosa>=0.6.2 --user
 fi
 
 if [[ "${JOB_BASE_NAME}" == *-NO_AVX-* ]]; then
diff --git a/.jenkins/pytorch/win-test.sh b/.jenkins/pytorch/win-test.sh
index cde9983..c2f4c97 100755
--- a/.jenkins/pytorch/win-test.sh
+++ b/.jenkins/pytorch/win-test.sh
@@ -54,7 +54,7 @@ if NOT "%BUILD_ENVIRONMENT%"=="" (
   :: We have to pin Python version to 3.6.7, until mkl supports Python 3.7
   call conda install -y -q python=3.6.7 numpy mkl cffi pyyaml boto3
 )
-pip install ninja future hypothesis
+pip install ninja future hypothesis librosa>=0.6.2 psutil
 set WORKING_DIR=%CD%
 call "C:\\Program Files (x86)\\Microsoft Visual Studio\\2017\\Community\\VC\\Auxiliary\\Build\\vcvarsall.bat" x86_amd64
 
@@ -115,7 +115,7 @@ EOL
 cat >ci_scripts/test_libtorch.bat <<EOL
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
-    assert len(loader) > error_it
+    error_it = 2
+
+    if use_workers:
+        # 2 is the magical per-worker prefetch number...
+        # FIXME: change this after the number becomes configurable.
+        assert len(loader) > (error_it + 2 + 1) * num_workers
 
     it = iter(loader)
     if use_workers:
-        for i, w in enumerate(it.workers):
-            worker_pids[i] = w.pid
+        workers = it.workers
 
     def kill_pid(pid):
-        if IS_WINDOWS:
-            os.system('taskkill /PID ' + str(os.getpid()) + ' /F')
-        else:
-            os.kill(os.getpid(), signal.SIGKILL)
+        psutil_p = psutil.Process(pid)
+        psutil_p.kill()
+        psutil_p.wait(JOIN_TIMEOUT)
+        assert not psutil_p.is_running()
 
     for i, _ in enumerate(it):
         if i == 0:
             if not hold_iter_reference:
                 del it
-            setup_event.set()
+            loader_setup_event.set()
+            tester_setup_event.wait()
+            # ensure that the workers are still alive
+            if use_workers:
+                for w in workers:
+                    assert w.is_alive()
+            if worker_error_event is not None:
+                worker_error_event.set()
+
         if i == error_it:
-            if exit_method == 'main_error':
-                raise RuntimeError('Error')
-            elif exit_method == 'main_kill':
+            if exit_method == 'loader_error':
+                raise RuntimeError('Loader error')
+            elif exit_method == 'loader_kill':
                 kill_pid(os.getpid())
             elif exit_method == 'worker_kill':
-                kill_pid(worker_pids[0])
+                kill_pid(workers[0].pid)
 
     if not hold_iter_reference:
         # Tries to trigger the __del__ clean-up rather than the automatic
@@ -637,22 +664,8 @@ class TestDataLoader(TestCase):
         pin_memory_thread.join(JOIN_TIMEOUT)
         self.assertFalse(pin_memory_thread.is_alive())
 
-    @staticmethod
-    def _is_process_alive(pid, pname):
-        # There is a chance of a terminated child process's pid being reused by a new unrelated process,
-        # but since we are looping this check very frequently, we will know that the child process dies
-        # before the new unrelated process starts.
-        if IS_WINDOWS:
-            command = 'tasklist | find "{}" /i'.format(pid)
-        else:
-            command = 'ps -p {} -o comm='.format(pid)
-        p = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True)
-        (output, err) = p.communicate()
-        p_status = p.wait()
-        output = output.decode('utf-8')
-        return pname in output
-
     @skipIfRocm
+    @unittest.skipIf(not HAS_PSUTIL, "psutil not found")
     def test_proper_exit(self):
         (r'''There might be ConnectionResetError or leaked semaphore warning '''
          r'''(due to dirty process exit), but they are all safe to ignore''')
@@ -660,28 +673,6 @@ class TestDataLoader(TestCase):
 
         # TODO: test the case where the pin_memory_thread triggers an
         # error/fatal signal. I haven't found out how to properly do that.
 
-        # Array to store the worker pids.
-        worker_pids = mp.Array('i', [-1 for _ in range(10)])
-
-        def wait_pids(pids, timeout):
-            r"""Wait for all process specified in pids to exit in given timeout."""
-            exit_status = [False for _ in pids]
-            start_time = time.time()
-            pname = 'python'
-            while True:
-                for i in range(len(pids)):
-                    pid = pids[i]
-                    if not exit_status[i]:
-                        if not TestDataLoader._is_process_alive(pid, pname):
-                            exit_status[i] = True
-                if all(exit_status):
-                    break
-                else:
-                    if time.time() - start_time > timeout:
-                        break
-                    time.sleep(0.5)
-            return exit_status
-
         for use_workers, pin_memory, hold_iter_reference in itertools.product([True, False], repeat=3):
             # `hold_iter_reference` specifies whether we hold a reference to the
             # iterator. This is interesting because Python3 error traces holds a
@@ -700,46 +691,73 @@ class TestDataLoader(TestCase):
             # - `None` means that no error happens.
             # In all cases, all processes should end properly.
             if use_workers:
-                exit_methods = [None, 'main_error', 'main_kill', 'worker_kill', 'worker_error']
+                exit_methods = [None, 'loader_error', 'loader_kill', 'worker_kill', 'worker_error']
             else:
-                exit_methods = [None, 'main_error', 'main_kill']
+                exit_methods = [None, 'loader_error', 'loader_kill']
 
             for exit_method in exit_methods:
-                # clear pids array first
-                for i in range(len(worker_pids)):
-                    worker_pids[i] = -1
+                desc = []
+                desc.append('use_workers={}'.format(use_workers))
+                desc.append('pin_memory={}'.format(pin_memory))
+                desc.append('hold_iter_reference={}'.format(hold_iter_reference))
+                desc.append('exit_method={}'.format(exit_method))
+                desc = 'test_proper_exit with ' + ', '.join(desc)
 
                 # Event that the loader process uses to signal testing process
                 # that various things are setup, including that the worker pids
                 # are specified in `worker_pids` array.
-                setup_event = mp.Event()
-
-                p = ErrorTrackingProcess(target=_test_proper_exit,
-                                         args=(use_workers, pin_memory, exit_method,
-                                               hold_iter_reference, worker_pids, setup_event))
-                p.start()
+                loader_setup_event = mp.Event()
+
+                # Event that this process has finished setting up, and the
+                # loader process can now proceed to trigger error events or
+                # finish normally.
+                tester_setup_event = mp.Event()
+
+                loader_p = ErrorTrackingProcess(target=_test_proper_exit,
+                                                args=(use_workers, pin_memory, exit_method,
+                                                      hold_iter_reference, loader_setup_event,
+                                                      tester_setup_event))
+                loader_p.start()
+
+                # Wait for loader process to set everything up, e.g., starting
+                # workers.
+                loader_setup_event.wait(timeout=JOIN_TIMEOUT)
+                if not loader_setup_event.is_set():
+                    fail_msg = desc + ': loader process failed to set up within the given time'
+                    if loader_p.exception is not None:
+                        self.fail(fail_msg + ', and had exception {}'.format(loader_p.exception))
+                    elif not loader_p.is_alive():
+                        self.fail(fail_msg + ', and exited with code {} but no exception'.format(loader_p.exitcode))
+                    else:
+                        self.fail(fail_msg + ', and is still alive.')
 
-                # Wait for loader process to set everything up, i.e., filling
-                # worker pids in `worker_pids`.
-                setup_event.wait(timeout=JOIN_TIMEOUT)
-                self.assertTrue(setup_event.is_set(), 'loader process setup timed out')
+                worker_psutil_p = psutil.Process(loader_p.pid).children()
 
-                pids = [pid for pid in worker_pids if pid > 0]
+                tester_setup_event.set()
 
                 try:
-                    exit_status = wait_pids(pids, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
-                    if not all(exit_status):
-                        self.fail('subprocess (pid(s) {}) not terminated'.format(
-                            ', '.join(p for p, exited in zip(pids, exit_status) if not exited)))
-                    p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
-                    self.assertFalse(p.is_alive(), 'loader process not terminated')
+                    loader_p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
+                    self.assertFalse(loader_p.is_alive(), desc + ': loader process not terminated')
+                    _, alive = psutil.wait_procs(worker_psutil_p, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
+                    if len(alive) > 0:
+                        self.fail(desc + ': worker process (pid(s) {}) not terminated'.format(
+                            ', '.join(str(p.pid) for p in alive)))
                     if exit_method is None:
-                        self.assertEqual(p.exitcode, 0)
+                        self.assertEqual(loader_p.exitcode, 0)
                     else:
-                        self.assertNotEqual(p.exitcode, 0)
+                        self.assertNotEqual(loader_p.exitcode, 0)
+                        if exit_method == 'loader_error':
+                            self.assertIsInstance(loader_p.exception, RuntimeError, desc)
+                            self.assertIn('Loader error', str(loader_p.exception), desc)
+                        elif exit_method == 'worker_kill':
+                            self.assertIsInstance(loader_p.exception, RuntimeError, desc)
+                            self.assertIn('DataLoader worker (pid', str(loader_p.exception), desc)
+                        elif exit_method == 'worker_error':
+                            self.assertIsInstance(loader_p.exception, RuntimeError, desc)
+                            self.assertIn('Worker error', str(loader_p.exception), desc)
                 finally:
-                    p.terminate()
+                    loader_p.terminate()
 
     def test_len(self):
         def check_len(dl, expected):
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index a8684e9..997d0cd 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -243,9 +243,18 @@ class _DataLoaderIter(object):
     #   happen when data in queue is corrupted (e.g., due to
     #   `cancel_join_thread` or unexpected exit).
     #
-    #   For child exit, we register SIGCHLD handler on main process,
-    #   which checks if any of the workers fail in the (Python) handler.
-    #   See DataLoader.cpp.
+    #   For child exit on the Windows platform, we set a timeout whenever
+    #   we get from `data_queue`, and check the workers' status on each
+    #   timeout and error.
+    #   See `_DataLoaderIter._get_batch()` and
+    #   `_DataLoaderIter._try_get_batch()` for details.
+    #
+    #   For child exit on non-Windows platforms, we register a SIGCHLD
+    #   handler (which is not supported on Windows) on the main process,
+    #   which checks if any of the workers fail in the (Python) handler.
+    #   This is more efficient and faster at detecting worker failures
+    #   than the strategy described above for Windows.
+    #   See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
     #
     #   For `.get()` calls where the sender(s) is not the workers, we
     #   guard them with timeouts, and check the status of the sender
@@ -417,28 +426,71 @@ class _DataLoaderIter(object):
     def __len__(self):
         return len(self.batch_sampler)
 
+    def _try_get_batch(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
+        # Tries to fetch data from `data_queue` for a given timeout. This can
+        # also be used as the inner loop when fetching without a timeout, with
+        # the sender status as the loop condition.
+        #
+        # This raises a RuntimeError if any worker died unexpectedly. This error
+        # comes from the SIGCHLD handler in `_utils/signal_handling.py` on
+        # non-Windows platforms, and from a manual check on errors and
+        # timeouts on Windows.
+        #
+        # Returns a 2-tuple:
+        #   (bool: whether data was successfully fetched, any: the data if successful, else None)
+        try:
+            data = self.data_queue.get(timeout=timeout)
+            return (True, data)
+        except Exception as e:
+            if _utils.IS_WINDOWS:
+                # Windows doesn't have SIGCHLD handler, so at timeout and error,
+                # we need to manually check whether any worker has failed.
+                if not all(w.is_alive() for w in self.workers):
+                    pids_str = ', '.join(str(w.pid) for w in self.workers if not w.is_alive())
+                    raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
+            if isinstance(e, queue.Empty):
+                return (False, None)
+            raise
+
     def _get_batch(self):
-        # In the non-timeout case, worker exit is covered by SIGCHLD handler.
-        # But if `pin_memory=True`, we still need account for the possibility
-        # that `pin_memory_thread` dies.
+        # Fetches data from `self.data_queue`.
+        #
+        # Worker exit is covered by the SIGCHLD handler in
+        # _utils/signal_handling.py on non-Windows platforms. On Windows, we
+        # must check the workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
+        # which we achieve by running `self._try_get_batch(timeout=MP_STATUS_CHECK_INTERVAL)`
+        # in a loop; `self._try_get_batch` then checks the workers' status on
+        # errors and timeouts.
+        #
+        # If `pin_memory=True`, we also need to check whether `pin_memory_thread`
+        # has died at timeouts.
         if self.timeout > 0:
-            try:
-                return self.data_queue.get(timeout=self.timeout)
-            except queue.Empty:
+            success, data = self._try_get_batch(self.timeout)
+            if success:
+                return data
+            else:
                 raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
         elif self.pin_memory:
             while self.pin_memory_thread.is_alive():
-                try:
-                    return self.data_queue.get(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
-                except queue.Empty:
-                    continue
+                success, data = self._try_get_batch()
+                if success:
+                    return data
             else:
                 # while condition is false, i.e., pin_memory_thread died.
                 raise RuntimeError('Pin memory thread exited unexpectedly')
             # In this case, `self.data_queue` is a `queue.Queue`,. But we don't
             # need to call `.task_done()` because we don't use `.join()`.
         else:
-            return self.data_queue.get()
+            if _utils.IS_WINDOWS:
+                # Windows doesn't have SIGCHLD handler and relies on the check
+                # in `self._try_get_batch()` to detect worker failures, so we
+                # need to do a while loop here.
+                while True:
+                    success, data = self._try_get_batch()
+                    if success:
+                        return data
+            else:
+                return self.data_queue.get()
 
     def __next__(self):
         if self.num_workers == 0:  # same-process loading
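The worker-failure detection added for Windows boils down to one pattern: fetch from the queue with a timeout, and check worker liveness whenever the fetch times out (or errors). A standalone sketch of that pattern, using plain `multiprocessing` objects and made-up names rather than the DataLoader internals:

```python
# Standalone sketch of the timeout-and-poll strategy used where no SIGCHLD
# handler is available (illustrative names only, not DataLoader code).
import multiprocessing as mp
import queue
import time


def _worker(q):
    # A long-lived producer, standing in for a DataLoader worker.
    while True:
        q.put('batch')
        time.sleep(3)


if __name__ == '__main__':
    q = mp.Queue()
    workers = [mp.Process(target=_worker, args=(q,), daemon=True) for _ in range(2)]
    for w in workers:
        w.start()
    workers[0].terminate()  # simulate a worker dying unexpectedly

    while True:
        try:
            # The timeout plays the role of MP_STATUS_CHECK_INTERVAL.
            print(q.get(timeout=1.0))
        except queue.Empty:
            # On every timeout, check liveness and fail loudly if a worker died.
            dead = [w.pid for w in workers if not w.is_alive()]
            if dead:
                raise RuntimeError('worker (pid(s) {}) exited unexpectedly'.format(dead))
```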