import unittest
import subprocess
import itertools
+import warnings
from torch import multiprocessing as mp
from torch.utils.data import _utils, Dataset, TensorDataset, DataLoader, ConcatDataset
from torch.utils.data._utils import ExceptionWrapper, MP_STATUS_CHECK_INTERVAL
from common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, IS_PPC, NO_MULTIPROCESSING_SPAWN,
skipIfRocm, load_tests)
+try:
+ import psutil
+ HAS_PSUTIL = True
+except ImportError:
+ HAS_PSUTIL = False
+ warnings.warn(
+ "psutil not found. Some crucial data loader tests relying on it (e.g., "
+ "TestDataLoader.test_proper_exit) will not run.")
+
+
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
mp = mp.get_context(method='spawn')
-JOIN_TIMEOUT = 17.0 if IS_WINDOWS or IS_PPC else 8.5
+JOIN_TIMEOUT = 17.0 if (IS_WINDOWS or IS_PPC) else 11.0
class TestDatasetRandomSplit(TestCase):
# See TestDataLoader.test_proper_exit for usage
def _test_proper_exit(use_workers, pin_memory, exit_method, hold_iter_reference,
- worker_pids, setup_event):
+ loader_setup_event, tester_setup_event):
num_workers = 2 if use_workers else 0
if exit_method == 'worker_error' or exit_method == 'worker_kill':
assert use_workers is True
- ds = TestProperExitDataset(10, setup_event if exit_method == 'worker_error' else None)
+ if exit_method == 'worker_error':
+ worker_error_event = mp.Event()
+ else:
+ worker_error_event = None
- loader = DataLoader(ds, batch_size=2, shuffle=False,
+ ds = TestProperExitDataset(12, worker_error_event)
+
+ loader = DataLoader(ds, batch_size=1, shuffle=False,
num_workers=num_workers, pin_memory=pin_memory)
- error_it = 4
- assert len(loader) > error_it
+ error_it = 2
+
+ if use_workers:
+ # 2 is the magical per-worker prefetch number...
+ # FIXME: change this after the number becomes configurable.
+ assert len(loader) > (error_it + 2 + 1) * num_workers
it = iter(loader)
if use_workers:
- for i, w in enumerate(it.workers):
- worker_pids[i] = w.pid
+ workers = it.workers
def kill_pid(pid):
- if IS_WINDOWS:
- os.system('taskkill /PID ' + str(os.getpid()) + ' /F')
- else:
- os.kill(os.getpid(), signal.SIGKILL)
+ psutil_p = psutil.Process(pid)
+ psutil_p.kill()
+ psutil_p.wait(JOIN_TIMEOUT)
+ assert not psutil_p.is_running()
for i, _ in enumerate(it):
if i == 0:
if not hold_iter_reference:
del it
- setup_event.set()
+ loader_setup_event.set()
+ tester_setup_event.wait()
+ # ensure that the workers are still alive
+ if use_workers:
+ for w in workers:
+ assert w.is_alive()
+ if worker_error_event is not None:
+ worker_error_event.set()
+
if i == error_it:
- if exit_method == 'main_error':
- raise RuntimeError('Error')
- elif exit_method == 'main_kill':
+ if exit_method == 'loader_error':
+ raise RuntimeError('Loader error')
+ elif exit_method == 'loader_kill':
kill_pid(os.getpid())
elif exit_method == 'worker_kill':
- kill_pid(worker_pids[0])
+ kill_pid(workers[0].pid)
if not hold_iter_reference:
# Tries to trigger the __del__ clean-up rather than the automatic
pin_memory_thread.join(JOIN_TIMEOUT)
self.assertFalse(pin_memory_thread.is_alive())
- @staticmethod
- def _is_process_alive(pid, pname):
- # There is a chance of a terminated child process's pid being reused by a new unrelated process,
- # but since we are looping this check very frequently, we will know that the child process dies
- # before the new unrelated process starts.
- if IS_WINDOWS:
- command = 'tasklist | find "{}" /i'.format(pid)
- else:
- command = 'ps -p {} -o comm='.format(pid)
- p = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True)
- (output, err) = p.communicate()
- p_status = p.wait()
- output = output.decode('utf-8')
- return pname in output
-
@skipIfRocm
+ @unittest.skipIf(not HAS_PSUTIL, "psutil not found")
def test_proper_exit(self):
(r'''There might be ConnectionResetError or leaked semaphore warning '''
r'''(due to dirty process exit), but they are all safe to ignore''')
# TODO: test the case where the pin_memory_thread triggers an
# error/fatal signal. I haven't found out how to properly do that.
- # Array to store the worker pids.
- worker_pids = mp.Array('i', [-1 for _ in range(10)])
-
- def wait_pids(pids, timeout):
- r"""Wait for all process specified in pids to exit in given timeout."""
- exit_status = [False for _ in pids]
- start_time = time.time()
- pname = 'python'
- while True:
- for i in range(len(pids)):
- pid = pids[i]
- if not exit_status[i]:
- if not TestDataLoader._is_process_alive(pid, pname):
- exit_status[i] = True
- if all(exit_status):
- break
- else:
- if time.time() - start_time > timeout:
- break
- time.sleep(0.5)
- return exit_status
-
for use_workers, pin_memory, hold_iter_reference in itertools.product([True, False], repeat=3):
# `hold_iter_reference` specifies whether we hold a reference to the
# iterator. This is interesting because Python3 error traces holds a
# - `None` means that no error happens.
# In all cases, all processes should end properly.
if use_workers:
- exit_methods = [None, 'main_error', 'main_kill', 'worker_kill', 'worker_error']
+ exit_methods = [None, 'loader_error', 'loader_kill', 'worker_kill', 'worker_error']
else:
- exit_methods = [None, 'main_error', 'main_kill']
+ exit_methods = [None, 'loader_error', 'loader_kill']
for exit_method in exit_methods:
- # clear pids array first
- for i in range(len(worker_pids)):
- worker_pids[i] = -1
+ desc = []
+ desc.append('use_workers={}'.format(use_workers))
+ desc.append('pin_memory={}'.format(pin_memory))
+ desc.append('hold_iter_reference={}'.format(hold_iter_reference))
+ desc.append('exit_method={}'.format(exit_method))
+ desc = 'test_proper_exit with ' + ', '.join(desc)
# Event that the loader process uses to signal the testing process
# that various things are set up, e.g., that the worker processes
# have been started.
- setup_event = mp.Event()
-
- p = ErrorTrackingProcess(target=_test_proper_exit,
- args=(use_workers, pin_memory, exit_method,
- hold_iter_reference, worker_pids, setup_event))
- p.start()
+ loader_setup_event = mp.Event()
+
+ # Event that this process has finished setting up, and the
+ # loader process can now proceed to trigger error events or
+ # finish normally.
+ tester_setup_event = mp.Event()
+
+ loader_p = ErrorTrackingProcess(target=_test_proper_exit,
+ args=(use_workers, pin_memory, exit_method,
+ hold_iter_reference, loader_setup_event,
+ tester_setup_event))
+ loader_p.start()
+
+ # Wait for loader process to set everything up, e.g., starting
+ # workers.
+ loader_setup_event.wait(timeout=JOIN_TIMEOUT)
+ if not loader_setup_event.is_set():
+ fail_msg = desc + ': loader process failed to setup with given time'
+ if loader_p.exception is not None:
+ self.fail(fail_msg + ', and had exception {}'.format(loader_p.exception))
+ elif not loader_p.is_alive():
+ self.fail(fail_msg + ', and exited with code {} but no exception'.format(loader_p.exitcode))
+ else:
+ self.fail(fail_msg + ', and is still alive.')
- # Wait for loader process to set everything up, i.e., filling
- # worker pids in `worker_pids`.
- setup_event.wait(timeout=JOIN_TIMEOUT)
- self.assertTrue(setup_event.is_set(), 'loader process setup timed out')
+ worker_psutil_p = psutil.Process(loader_p.pid).children()
- pids = [pid for pid in worker_pids if pid > 0]
+ tester_setup_event.set()
try:
- exit_status = wait_pids(pids, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
- if not all(exit_status):
- self.fail('subprocess (pid(s) {}) not terminated'.format(
- ', '.join(p for p, exited in zip(pids, exit_status) if not exited)))
- p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
- self.assertFalse(p.is_alive(), 'loader process not terminated')
+ loader_p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
+ self.assertFalse(loader_p.is_alive(), desc + ': loader process not terminated')
+ _, alive = psutil.wait_procs(worker_psutil_p, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
+ if len(alive) > 0:
+ self.fail(desc + ': worker process (pid(s) {}) not terminated'.format(
+ ', '.join(str(p.pid) for p in alive)))
if exit_method is None:
- self.assertEqual(p.exitcode, 0)
+ self.assertEqual(loader_p.exitcode, 0)
else:
- self.assertNotEqual(p.exitcode, 0)
+ self.assertNotEqual(loader_p.exitcode, 0)
+ if exit_method == 'loader_error':
+ self.assertIsInstance(loader_p.exception, RuntimeError, desc)
+ self.assertIn('Loader error', str(loader_p.exception), desc)
+ elif exit_method == 'worker_kill':
+ self.assertIsInstance(loader_p.exception, RuntimeError, desc)
+ self.assertIn('DataLoader worker (pid', str(loader_p.exception), desc)
+ elif exit_method == 'worker_error':
+ self.assertIsInstance(loader_p.exception, RuntimeError, desc)
+ self.assertIn('Worker error', str(loader_p.exception), desc)
finally:
- p.terminate()
+ loader_p.terminate()
def test_len(self):
def check_len(dl, expected):
# happen when data in queue is corrupted (e.g., due to
# `cancel_join_thread` or unexpected exit).
#
- # For child exit, we register SIGCHLD handler on main process,
- # which checks if any of the workers fail in the (Python) handler.
- # See DataLoader.cpp.
+ # For child exit on Windows platform, we set a timeout whenever
+ # we get from `data_queue`, and check the workers' status on each
+ # timeout and error.
+ # See `_DataLoaderIter._get_batch()` and
+ # `_DataLoaderIter._try_get_batch()` for details.
+ #
+ # For child exit on non-Windows platforms, we register a SIGCHLD
+ # handler (which is not supported on Windows) on main process, which
+ # checks if any of the workers fail in the (Python) handler. This
+ # is more efficient and faster in detecting worker failures,
+ # compared to the above strategy applied for Windows.
+ # See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
#
# For `.get()` calls where the sender(s) is not the workers, we
# guard them with timeouts, and check the status of the sender
def __len__(self):
return len(self.batch_sampler)
+ def _try_get_batch(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
+ # Tries to fetch data from `data_queue` for a given timeout. This can
+ # also be used as inner loop of fetching without timeout, with the
+ # sender status as the loop condition.
+ #
+ # This raises a RuntimeError if any worker died unexpectedly. This error
+ # comes from a SIGCHLD handler in `_utils/signal_handling.py` for
+ # non-Windows platforms, and comes from a manual check on errors and
+ # timeouts on Windows.
+ #
+ # Returns a 2-tuple:
+ # (bool: whether data was successfully fetched, any: the data if successful else None)
+ try:
+ data = self.data_queue.get(timeout=timeout)
+ return (True, data)
+ except Exception as e:
+ if _utils.IS_WINDOWS:
+ # Windows doesn't have SIGCHLD handler, so at timeout and error,
+ # we need to manually check whether any worker has failed.
+ if not all(w.is_alive() for w in self.workers):
+ pids_str = ', '.join(str(w.pid) for w in self.workers if not w.is_alive())
+ raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
+ if isinstance(e, queue.Empty):
+ return (False, None)
+ raise
+
def _get_batch(self):
- # In the non-timeout case, worker exit is covered by SIGCHLD handler.
- # But if `pin_memory=True`, we still need account for the possibility
- # that `pin_memory_thread` dies.
+ # Fetches data from `self.data_queue`.
+ #
+ # Worker exit is covered by the SIGCHLD handler in
+ # _utils/signal_handling.py for non-Windows platforms. For Windows, we
+ # must check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
+ # which we achieve by running `self._try_get_batch(timeout=MP_STATUS_CHECK_INTERVAL)`
+ # in a loop. On Windows, `self._try_get_batch` will check workers'
+ # status on errors and timeouts.
+ #
+ # If `pin_memory=True`, we also need to check at timeouts whether
+ # `pin_memory_thread` has died.
if self.timeout > 0:
- try:
- return self.data_queue.get(timeout=self.timeout)
- except queue.Empty:
+ success, data = self._try_get_batch(self.timeout)
+ if success:
+ return data
+ else:
raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
elif self.pin_memory:
while self.pin_memory_thread.is_alive():
- try:
- return self.data_queue.get(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
- except queue.Empty:
- continue
+ success, data = self._try_get_batch()
+ if success:
+ return data
else:
# while condition is false, i.e., pin_memory_thread died.
raise RuntimeError('Pin memory thread exited unexpectedly')
# In this case, `self.data_queue` is a `queue.Queue`. But we don't
# need to call `.task_done()` because we don't use `.join()`.
else:
- return self.data_queue.get()
+ if _utils.IS_WINDOWS:
+ # Windows doesn't have SIGCHLD handler and relies on the check
+ # in `self._try_get_batch()` to detect worker failures, so we
+ # need to do a while loop here.
+ while True:
+ success, data = self._try_get_batch()
+ if success:
+ return data
+ else:
+ return self.data_queue.get()
def __next__(self):
if self.num_workers == 0: # same-process loading