Fix TestDataLoader.test_proper_exit (#15665)
authorSsnL <tongzhou.wang.1994@gmail.com>
Thu, 10 Jan 2019 16:44:32 +0000 (08:44 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 10 Jan 2019 16:47:27 +0000 (08:47 -0800)
Summary:
Currently, in `test_proper_exit`,
1. we do not kill the correct input `pid` in the `kill_pid` function
https://github.com/pytorch/pytorch/blob/fe15d6a2c231a7bc1b32781217ed336ccf9adff7/test/test_dataloader.py#L325-L329
2. the Windows command that detects process status doesn't actually work
https://github.com/pytorch/pytorch/blob/fe15d6a2c231a7bc1b32781217ed336ccf9adff7/test/test_dataloader.py#L641-L646
3. `worker_error` and `worker_kill` cases (sometimes?) are not tested because the workers may exit naturally due to the pre-fetching mechanism and a too small `dataset size / batch size`.

In this PR, I, in separate commits:
1. Install `psutil` (a python package specifically built for process monitoring) on some CI builds. (Linux builds installation are done in https://github.com/pietern/pytorch-dockerfiles/pull/29 https://github.com/pietern/pytorch-dockerfiles/pull/30  https://github.com/pytorch/ossci-job-dsl/pull/36 and https://github.com/pytorch/pytorch/pull/15795).
2. Rewrite `test_proper_exit` with `psutil` so we

    1. do not rely on the hacky `is_process_alive` https://github.com/pytorch/pytorch/blob/fe15d6a2c231a7bc1b32781217ed336ccf9adff7/test/test_dataloader.py#L640-L653
   2. increase the #task per worker so `worker_error` and `worker_kill` properly trigger
   3. test error message content to ensure that the loader exits with correct message corresponding to each exiting scenario.

3. Fix Windows data loader not having any mechanism to detect worker failures.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15665

Differential Revision: D13615527

Pulled By: soumith

fbshipit-source-id: cfb2f67837d2d87928a53f00b4d20f09754b7949

.jenkins/pytorch/macos-test.sh
.jenkins/pytorch/test.sh
.jenkins/pytorch/win-test.sh
test/test_dataloader.py
torch/utils/data/dataloader.py

index 1ce42eb..bd4c73e 100755 (executable)
@@ -16,7 +16,7 @@ fi
 export PATH="${PYTORCH_ENV_DIR}/miniconda3/bin:$PATH"
 source ${PYTORCH_ENV_DIR}/miniconda3/bin/activate
 conda install -y mkl mkl-include numpy pyyaml setuptools cmake cffi ninja six
-pip install hypothesis
+pip install hypothesis librosa>=0.6.2 psutil
 if [ -z "${IN_CIRCLECI}" ]; then
   rm -rf ${PYTORCH_ENV_DIR}/miniconda3/lib/python3.6/site-packages/torch*
 fi
index cdd2669..cb0e724 100755 (executable)
@@ -78,6 +78,10 @@ if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
   export PYTORCH_TEST_WITH_ROCM=1
   export LANG=C.UTF-8
   export LC_ALL=C.UTF-8
+
+  # ROCm CI is using Caffe2 docker images, which doesn't have several packages
+  # needed in testing. We install them here.
+  pip install -q psutil librosa>=0.6.2 --user
 fi
 
 if [[ "${JOB_BASE_NAME}" == *-NO_AVX-* ]]; then
index cde9983..c2f4c97 100755 (executable)
@@ -54,7 +54,7 @@ if NOT "%BUILD_ENVIRONMENT%"=="" (
     :: We have to pin Python version to 3.6.7, until mkl supports Python 3.7
     call conda install -y -q python=3.6.7 numpy mkl cffi pyyaml boto3
 )
-pip install ninja future hypothesis
+pip install ninja future hypothesis librosa>=0.6.2 psutil
 
 set WORKING_DIR=%CD%
 call "C:\\Program Files (x86)\\Microsoft Visual Studio\\2017\\Community\\VC\\Auxiliary\\Build\\vcvarsall.bat" x86_amd64
@@ -115,7 +115,7 @@ EOL
 cat >ci_scripts/test_libtorch.bat <<EOL
 call ci_scripts/setup_pytorch_env.bat
 dir
-dir %CD%\\test 
+dir %CD%\\test
 dir %CD%\\test\\torch
 dir %CD%\\test\\torch\\lib
 cd %CD%\\test\\torch\\lib
index cdd8139..291b9c0 100644 (file)
@@ -11,6 +11,7 @@ import traceback
 import unittest
 import subprocess
 import itertools
+import warnings
 from torch import multiprocessing as mp
 from torch.utils.data import _utils, Dataset, TensorDataset, DataLoader, ConcatDataset
 from torch.utils.data._utils import ExceptionWrapper, MP_STATUS_CHECK_INTERVAL
@@ -18,6 +19,16 @@ from torch.utils.data.dataset import random_split
 from common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, IS_PPC, NO_MULTIPROCESSING_SPAWN,
                           skipIfRocm, load_tests)
 
+try:
+    import psutil
+    HAS_PSUTIL = True
+except ImportError:
+    HAS_PSUTIL = False
+    warnings.warn(
+        "psutil not found. Some crucial data loader tests relying on it (e.g., "
+        "TestDataLoader.test_proper_exit) will not run.")
+
+
 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
 load_tests = load_tests
@@ -34,7 +45,7 @@ if not NO_MULTIPROCESSING_SPAWN:
     mp = mp.get_context(method='spawn')
 
 
-JOIN_TIMEOUT = 17.0 if IS_WINDOWS or IS_PPC else 8.5
+JOIN_TIMEOUT = 17.0 if (IS_WINDOWS or IS_PPC) else 11.0
 
 
 class TestDatasetRandomSplit(TestCase):
@@ -304,42 +315,58 @@ class TestProperExitDataset(object):
 
 # See TestDataLoader.test_proper_exit for usage
 def _test_proper_exit(use_workers, pin_memory, exit_method, hold_iter_reference,
-                      worker_pids, setup_event):
+                      loader_setup_event, tester_setup_event):
     num_workers = 2 if use_workers else 0
 
     if exit_method == 'worker_error' or exit_method == 'worker_kill':
         assert use_workers is True
 
-    ds = TestProperExitDataset(10, setup_event if exit_method == 'worker_error' else None)
+    if exit_method == 'worker_error':
+        worker_error_event = mp.Event()
+    else:
+        worker_error_event = None
 
-    loader = DataLoader(ds, batch_size=2, shuffle=False,
+    ds = TestProperExitDataset(12, worker_error_event)
+
+    loader = DataLoader(ds, batch_size=1, shuffle=False,
                         num_workers=num_workers, pin_memory=pin_memory)
-    error_it = 4
-    assert len(loader) > error_it
+    error_it = 2
+
+    if use_workers:
+        # 2 is the magical per-worker prefetch number...
+        # FIXME: change this after the number becomes configurable.
+        assert len(loader) > (error_it + 2 + 1) * num_workers
 
     it = iter(loader)
     if use_workers:
-        for i, w in enumerate(it.workers):
-            worker_pids[i] = w.pid
+        workers = it.workers
 
     def kill_pid(pid):
-        if IS_WINDOWS:
-            os.system('taskkill /PID ' + str(os.getpid()) + ' /F')
-        else:
-            os.kill(os.getpid(), signal.SIGKILL)
+        psutil_p = psutil.Process(pid)
+        psutil_p.kill()
+        psutil_p.wait(JOIN_TIMEOUT)
+        assert not psutil_p.is_running()
 
     for i, _ in enumerate(it):
         if i == 0:
             if not hold_iter_reference:
                 del it
-            setup_event.set()
+            loader_setup_event.set()
+            tester_setup_event.wait()
+            # ensure that the workers are still alive
+            if use_workers:
+                for w in workers:
+                    assert w.is_alive()
+            if worker_error_event is not None:
+                worker_error_event.set()
+
         if i == error_it:
-            if exit_method == 'main_error':
-                raise RuntimeError('Error')
-            elif exit_method == 'main_kill':
+            if exit_method == 'loader_error':
+                raise RuntimeError('Loader error')
+            elif exit_method == 'loader_kill':
                 kill_pid(os.getpid())
             elif exit_method == 'worker_kill':
-                kill_pid(worker_pids[0])
+                kill_pid(workers[0].pid)
 
     if not hold_iter_reference:
         # Tries to trigger the __del__ clean-up rather than the automatic
@@ -637,22 +664,8 @@ class TestDataLoader(TestCase):
                 pin_memory_thread.join(JOIN_TIMEOUT)
                 self.assertFalse(pin_memory_thread.is_alive())
 
-    @staticmethod
-    def _is_process_alive(pid, pname):
-        # There is a chance of a terminated child process's pid being reused by a new unrelated process,
-        # but since we are looping this check very frequently, we will know that the child process dies
-        # before the new unrelated process starts.
-        if IS_WINDOWS:
-            command = 'tasklist | find "{}" /i'.format(pid)
-        else:
-            command = 'ps -p {} -o comm='.format(pid)
-        p = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True)
-        (output, err) = p.communicate()
-        p_status = p.wait()
-        output = output.decode('utf-8')
-        return pname in output
-
     @skipIfRocm
+    @unittest.skipIf(not HAS_PSUTIL, "psutil not found")
     def test_proper_exit(self):
         (r'''There might be ConnectionResetError or leaked semaphore warning '''
          r'''(due to dirty process exit), but they are all safe to ignore''')
@@ -660,28 +673,6 @@ class TestDataLoader(TestCase):
         # TODO: test the case where the pin_memory_thread triggers an
         #       error/fatal signal. I haven't found out how to properly do that.
 
-        # Array to store the worker pids.
-        worker_pids = mp.Array('i', [-1 for _ in range(10)])
-
-        def wait_pids(pids, timeout):
-            r"""Wait for all process specified in pids to exit in given timeout."""
-            exit_status = [False for _ in pids]
-            start_time = time.time()
-            pname = 'python'
-            while True:
-                for i in range(len(pids)):
-                    pid = pids[i]
-                    if not exit_status[i]:
-                        if not TestDataLoader._is_process_alive(pid, pname):
-                            exit_status[i] = True
-                if all(exit_status):
-                    break
-                else:
-                    if time.time() - start_time > timeout:
-                        break
-                    time.sleep(0.5)
-            return exit_status
-
         for use_workers, pin_memory, hold_iter_reference in itertools.product([True, False], repeat=3):
             # `hold_iter_reference` specifies whether we hold a reference to the
             # iterator. This is interesting because Python3 error traces holds a
@@ -700,46 +691,73 @@ class TestDataLoader(TestCase):
             #   - `None` means that no error happens.
             # In all cases, all processes should end properly.
             if use_workers:
-                exit_methods = [None, 'main_error', 'main_kill', 'worker_kill', 'worker_error']
+                exit_methods = [None, 'loader_error', 'loader_kill', 'worker_kill', 'worker_error']
             else:
-                exit_methods = [None, 'main_error', 'main_kill']
+                exit_methods = [None, 'loader_error', 'loader_kill']
 
             for exit_method in exit_methods:
 
-                # clear pids array first
-                for i in range(len(worker_pids)):
-                    worker_pids[i] = -1
+                desc = []
+                desc.append('use_workers={}'.format(use_workers))
+                desc.append('pin_memory={}'.format(pin_memory))
+                desc.append('hold_iter_reference={}'.format(hold_iter_reference))
+                desc.append('exit_method={}'.format(exit_method))
+                desc = 'test_proper_exit with ' + ', '.join(desc)
 
                 # Event that the loader process uses to signal testing process
                 # that various things are setup, including that the worker pids
                 # are specified in `worker_pids` array.
-                setup_event = mp.Event()
-
-                p = ErrorTrackingProcess(target=_test_proper_exit,
-                                         args=(use_workers, pin_memory, exit_method,
-                                               hold_iter_reference, worker_pids, setup_event))
-                p.start()
+                loader_setup_event = mp.Event()
+
+                # Event that this process has finished setting up, and the
+                # loader process can now proceed to trigger error events or
+                # finish normally.
+                tester_setup_event = mp.Event()
+
+                loader_p = ErrorTrackingProcess(target=_test_proper_exit,
+                                                args=(use_workers, pin_memory, exit_method,
+                                                      hold_iter_reference, loader_setup_event,
+                                                      tester_setup_event))
+                loader_p.start()
+
+                # Wait for loader process to set everything up, e.g., starting
+                # workers.
+                loader_setup_event.wait(timeout=JOIN_TIMEOUT)
+                if not loader_setup_event.is_set():
+                    fail_msg = desc + ': loader process failed to setup with given time'
+                    if loader_p.exception is not None:
+                        self.fail(fail_msg + ', and had exception {}'.format(loader_p.exception))
+                    elif not loader_p.is_alive():
+                        self.fail(fail_msg + ', and exited with code {} but no exception'.format(loader_p.exitcode))
+                    else:
+                        self.fail(fail_msg + ', and is still alive.')
 
-                # Wait for loader process to set everything up, i.e., filling
-                # worker pids in `worker_pids`.
-                setup_event.wait(timeout=JOIN_TIMEOUT)
-                self.assertTrue(setup_event.is_set(), 'loader process setup timed out')
+                worker_psutil_p = psutil.Process(loader_p.pid).children()
 
-                pids = [pid for pid in worker_pids if pid > 0]
+                tester_setup_event.set()
 
                 try:
-                    exit_status = wait_pids(pids, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
-                    if not all(exit_status):
-                        self.fail('subprocess (pid(s) {}) not terminated'.format(
-                            ', '.join(p for p, exited in zip(pids, exit_status) if not exited)))
-                    p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
-                    self.assertFalse(p.is_alive(), 'loader process not terminated')
+                    loader_p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
+                    self.assertFalse(loader_p.is_alive(), desc + ': loader process not terminated')
+                    _, alive = psutil.wait_procs(worker_psutil_p, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
+                    if len(alive) > 0:
+                        self.fail(desc + ': worker process (pid(s) {}) not terminated'.format(
+                            ', '.join(str(p.pid) for p in alive)))
                     if exit_method is None:
-                        self.assertEqual(p.exitcode, 0)
+                        self.assertEqual(loader_p.exitcode, 0)
                     else:
-                        self.assertNotEqual(p.exitcode, 0)
+                        self.assertNotEqual(loader_p.exitcode, 0)
+                        if exit_method == 'loader_error':
+                            self.assertIsInstance(loader_p.exception, RuntimeError, desc)
+                            self.assertIn('Loader error', str(loader_p.exception), desc)
+                        elif exit_method == 'worker_kill':
+                            self.assertIsInstance(loader_p.exception, RuntimeError, desc)
+                            self.assertIn('DataLoader worker (pid', str(loader_p.exception), desc)
+                        elif exit_method == 'worker_error':
+                            self.assertIsInstance(loader_p.exception, RuntimeError, desc)
+                            self.assertIn('Worker error', str(loader_p.exception), desc)
                 finally:
-                    p.terminate()
+                    loader_p.terminate()
 
     def test_len(self):
         def check_len(dl, expected):
index a8684e9..997d0cd 100644 (file)
@@ -243,9 +243,18 @@ class _DataLoaderIter(object):
     #           happen when data in queue is corrupted (e.g., due to
     #           `cancel_join_thread` or unexpected exit).
     #
-    #           For child exit, we register SIGCHLD handler on main process,
-    #           which checks if any of the workers fail in the (Python) handler.
-    #           See DataLoader.cpp.
+    #           For child exit on Windows platform, we set a timeout whenever
+    #           we get from `data_queue`, and check the workers' status on each
+    #           timeout and error.
+    #           See `_DataLoaderiter._get_batch()` and
+    #           `_DataLoaderiter._try_get_batch()` for details
+    #
+    #           For child exit on non-Windows platforms, we register a SIGCHLD
+    #           handler (which is supported on Windows) on main process, which
+    #           checks if any of the workers fail in the (Python) handler. This
+    #           is more efficient and faster in detecting worker failures,
+    #           compared to the above strategy applied for Windows.
+    #           See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
     #
     #           For `.get()` calls where the sender(s) is not the workers, we
     #           guard them with timeouts, and check the status of the sender
@@ -417,28 +426,71 @@ class _DataLoaderIter(object):
     def __len__(self):
         return len(self.batch_sampler)
 
+    def _try_get_batch(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
+        # Tries to fetch data from `data_queue` for a given timeout. This can
+        # also be used as inner loop of fetching without timeout, with the
+        # sender status as the loop condition.
+        #
+        # This raises a RuntimeError if any worker died expectedly. This error
+        # comes from a SIGCHLD handler in `_utils/signal_handling.py` for
+        # non-Windows platforms, and comes from a manual check on errors and
+        # timeouts on Windows.
+        #
+        # Returns a 2-tuple:
+        #   (bool: whether successfully get data, any: data if successful else None)
+        try:
+            data = self.data_queue.get(timeout=timeout)
+            return (True, data)
+        except Exception as e:
+            if _utils.IS_WINDOWS:
+                # Windows doesn't have SIGCHLD handler, so at timeout and error,
+                # we need to manually check whether any worker has failed.
+                if not all(w.is_alive() for w in self.workers):
+                    pids_str = ', '.join(str(w.pid) for w in self.workers if not w.is_alive())
+                    raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
+            if isinstance(e, queue.Empty):
+                return (False, None)
+            raise
+
     def _get_batch(self):
-        # In the non-timeout case, worker exit is covered by SIGCHLD handler.
-        # But if `pin_memory=True`, we still need account for the possibility
-        # that `pin_memory_thread` dies.
+        # Fetches data from `self.data_queue`.
+        #
+        # Worker exit is covered by the SIGCHLD handler in
+        # _utils/signal_handling.py for non-Windows platforms. For Windows, we
+        # must check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
+        # which we achieve by running `self._try_get_batch(timeout=MP_STATUS_CHECK_INTERVAL)`
+        # in a loop. On Windows, The `self._try_get_batch` will check workers'
+        # status on errors and timeouts.
+        #
+        # If `pin_memory=True`, we also need check if `pin_memory_thread` had
+        # died at timeouts.
         if self.timeout > 0:
-            try:
-                return self.data_queue.get(timeout=self.timeout)
-            except queue.Empty:
+            success, data = self._try_get_batch(self.timeout)
+            if success:
+                return data
+            else:
                 raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
         elif self.pin_memory:
             while self.pin_memory_thread.is_alive():
-                try:
-                    return self.data_queue.get(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
-                except queue.Empty:
-                    continue
+                success, data = self._try_get_batch()
+                if success:
+                    return data
             else:
                 # while condition is false, i.e., pin_memory_thread died.
                 raise RuntimeError('Pin memory thread exited unexpectedly')
             # In this case, `self.data_queue` is a `queue.Queue`,. But we don't
             # need to call `.task_done()` because we don't use `.join()`.
         else:
-            return self.data_queue.get()
+            if _utils.IS_WINDOWS:
+                # Windows doesn't have SIGCHLD handler and relies on the check
+                # in `self._try_get_batch()` to detect worker failures, so we
+                # need to do a while loop here.
+                while True:
+                    success, data = self._try_get_batch()
+                    if success:
+                        return data
+            else:
+                return self.data_queue.get()
 
     def __next__(self):
         if self.num_workers == 0:  # same-process loading