Refactor dataloader.py (#15331)
author: SsnL <tongzhou.wang.1994@gmail.com>
Wed, 19 Dec 2018 20:26:44 +0000 (12:26 -0800)
committer: Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 19 Dec 2018 20:36:03 +0000 (12:36 -0800)
Summary:
Same as #14668, and was approved there.

ailzhang, please apply this patch to Horizon's `data_streamer.py`: https://gist.github.com/SsnL/020fdb3d6b7016d81b6ba1d04cc41459 Thank you!

Below is the original description at #14668:

While working on the tasks in https://github.com/pytorch/pytorch/issues/13023, I realized how unreadable the code is, because every function that is run in multiprocessing must be defined at the top (module-global) level. Adding more functionality to `dataloader.py` will only make things worse.
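
To make the constraint concrete, here is a minimal, hypothetical illustration (not part of this PR) of why these functions must live at module level: `multiprocessing` pickles the task/target functions, and nested functions or lambdas cannot be pickled (a hard limitation on Python 2, and still true for pool tasks on Python 3):

```python
import multiprocessing as mp

def _square(x):  # module-level function: picklable, fine as a worker task
    return x * x

if __name__ == "__main__":
    with mp.Pool(2) as pool:
        print(pool.map(_square, [1, 2, 3]))      # [1, 4, 9]
        # pool.map(lambda x: x * x, [1, 2, 3])   # would raise PicklingError:
        # lambdas / nested functions can't be pickled and sent to workers.
```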

So in this PR, I refactor `dataloader.py` and move much of it into `data._utils`. E.g., `_worker_loop` and the related methods now live in `data._utils.worker`, the signal-handling code in `data._utils.signal_handling`, the collating code in `data._utils.collate`, etc. This split, IMHO, makes the code much clearer. I will base my future changes to DataLoader on top of it.
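
For internal callers (the public `torch.utils.data` API is unchanged), the import paths move roughly as in the sketch below, which mirrors the test changes further down in this diff:

```python
# Before this PR (old, now-removed import path):
#   from torch.utils.data.dataloader import default_collate, ExceptionWrapper
# After this PR:
from torch.utils.data import _utils
from torch.utils.data._utils import ExceptionWrapper, MP_STATUS_CHECK_INTERVAL

batch = _utils.collate.default_collate([0, 1, 2])  # collates ints into a LongTensor
```

The old name `torch.utils.data.dataloader.default_collate` keeps working via the BC alias added near the bottom of this diff.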

No functionality is changed, except that I added `torch._six.queue`.
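
`torch._six.queue` is just the Py2/Py3 shim added in the `torch/_six.py` hunk below (`Queue` on Python 2, `queue` on Python 3); a rough usage sketch:

```python
import multiprocessing as mp
from torch._six import queue  # resolves to `Queue` on Py2, `queue` on Py3

q = mp.Queue()
try:
    q.get(timeout=0.1)
except queue.Empty:  # same exception class regardless of Python version
    print("no data yet; caller can check sender liveness and retry")
```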
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15331

Reviewed By: yf225

Differential Revision: D13503120

Pulled By: ailzhang

fbshipit-source-id: 94df16b4d80ad1102c437cde0d5a2e62cffe1f8e

test/test_dataloader.py
torch/_six.py
torch/csrc/DataLoader.cpp
torch/utils/data/__init__.py
torch/utils/data/_utils/__init__.py [new file with mode: 0644]
torch/utils/data/_utils/collate.py [new file with mode: 0644]
torch/utils/data/_utils/pin_memory.py [new file with mode: 0644]
torch/utils/data/_utils/signal_handling.py [new file with mode: 0644]
torch/utils/data/_utils/worker.py [new file with mode: 0644]
torch/utils/data/dataloader.py

index e1ca129..9db11b0 100644 (file)
@@ -12,9 +12,9 @@ import unittest
 import subprocess
 import itertools
 from torch import multiprocessing as mp
-from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
+from torch.utils.data import _utils, Dataset, TensorDataset, DataLoader, ConcatDataset
+from torch.utils.data._utils import ExceptionWrapper, MP_STATUS_CHECK_INTERVAL
 from torch.utils.data.dataset import random_split
-from torch.utils.data.dataloader import default_collate, ExceptionWrapper, MP_STATUS_CHECK_INTERVAL
 from common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, IS_PPC, NO_MULTIPROCESSING_SPAWN,
                           skipIfRocm, load_tests)
 
@@ -788,16 +788,16 @@ class TestDataLoader(TestCase):
 
         # Should be a no-op
         arr = np.array(['a', 'b', 'c'])
-        default_collate(arr)
+        _utils.collate.default_collate(arr)
 
         arr = np.array([[['a', 'b', 'c']]])
-        self.assertRaises(TypeError, lambda: default_collate(arr))
+        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
 
         arr = np.array([object(), object(), object()])
-        self.assertRaises(TypeError, lambda: default_collate(arr))
+        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
 
         arr = np.array([[[object(), object(), object()]]])
-        self.assertRaises(TypeError, lambda: default_collate(arr))
+        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
 
 
 class StringDataset(Dataset):
index 924e641..3a4c2ad 100644 (file)
@@ -51,6 +51,12 @@ else:
     FileNotFoundError = FileNotFoundError
 
 
+if PY2:
+    import Queue as queue
+else:
+    import queue
+
+
 def with_metaclass(meta, *bases):
     """Create a base class with a metaclass."""
     # This requires a bit of explanation: the basic idea is to make a dummy
index 0b23b4d..beb7bf9 100644 (file)
@@ -1,11 +1,10 @@
 #include <torch/csrc/DataLoader.h>
 
-// In cases like DataLoader, if a worker process dies due to bus error/segfault
-// or just hang, the main process will hang waiting for data. This is difficult
-// to avoid on PyTorch side as it can be caused by limited shm, or other
-// libraries users call in the workers. The following methods is an effort to do
-// our best to provide some error message to users when such unfortunate events
-// happen.
+// Together with `torch/utils/data/_utils/signal_handling.py`, the following
+// is an effort to do our best to provide some error message to users when a
+// worker dies due to error / critical signals.
+//
+// See NOTE [ Signal handling in multiprocessing data loading ] for more details.
 
 // TODO: The following don't work on Windows. Specifically, sigaction, waitid
 // calls, and SIGCHLD handler. Currently, dummy implementations are provided
@@ -128,7 +127,7 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) {
         // workers, and trigger this again.
         pid_set->clear();
         throw std::runtime_error(oss.str());
-      }  else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) {  // killed by signal
+      } else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) {  // killed by signal
         std::ostringstream oss;
         oss << "DataLoader worker (pid " << worker_pid << ") is killed "
             << "by signal: " << strsignal(infop.si_status) << ". ";
@@ -145,18 +144,18 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) {
 
 // We don't want to exit on any SIGCHLD from any child. child_pids is a tuple
 // of pids we are interested in.
-static PyObject *THPModule_updateWorkerPIDs(PyObject *module, PyObject *args) {
+static PyObject *THPModule_setWorkerPIDs(PyObject *module, PyObject *args) {
   HANDLE_TH_ERRORS
   if (PyTuple_GET_SIZE(args) != 2) {
-    throw TypeError("_update_worker_pids expectes exactly 2 arguments.");
+    throw TypeError("_set_worker_pids expects exactly 2 arguments.");
   }
   int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0));
   if (worker_pids.find(key) != worker_pids.end()) {
-    throw ValueError("_update_worker_pids should be called only once for each _DataLoaderIter.");
+    throw ValueError("_set_worker_pids should be called only once for each _DataLoaderIter.");
   }
   PyObject *child_pids = PyTuple_GET_ITEM(args, 1);
   if (!PyTuple_Check(child_pids)) {
-    throw TypeError("_update_worker_pids expects a tuple for child_pids, but got %s.",
+    throw TypeError("_set_worker_pids expects a tuple for child_pids, but got %s.",
         Py_TYPE(child_pids)->tp_name);
   }
 
@@ -164,7 +163,7 @@ static PyObject *THPModule_updateWorkerPIDs(PyObject *module, PyObject *args) {
   auto size = PyTuple_GET_SIZE(child_pids);
   for (int idx = 0; idx < size; idx++) {
     PyObject* obj = PyTuple_GET_ITEM(child_pids, idx);
-    pids_set.insert((pid_t) THPUtils_unpackLong(obj));
+    pids_set.insert(static_cast<pid_t>(THPUtils_unpackLong(obj)));
   }
 
   worker_pids[key] = pids_set;
@@ -196,7 +195,7 @@ static PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *_
   Py_RETURN_NONE;
 }
 
-static PyObject *THPModule_updateWorkerPIDs(PyObject *module, PyObject *_ignored) {
+static PyObject *THPModule_setWorkerPIDs(PyObject *module, PyObject *_ignored) {
   Py_RETURN_NONE;
 }
 
@@ -212,7 +211,7 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module, PyObject *_ig
 
 PyMethodDef DataLoaderMethods[] = {
   {"_set_worker_signal_handlers",  (PyCFunction)THPModule_setWorkerSignalHandlers,  METH_NOARGS,   nullptr},
-  {"_update_worker_pids",          (PyCFunction)THPModule_updateWorkerPIDs,         METH_VARARGS,  nullptr},
+  {"_set_worker_pids",             (PyCFunction)THPModule_setWorkerPIDs,            METH_VARARGS,  nullptr},
   {"_remove_worker_pids",          (PyCFunction)THPModule_removeWorkerPIDs,         METH_O,        nullptr},
   {"_error_if_any_worker_fails",   (PyCFunction)THPModule_errorIfAnyWorkerFails,    METH_NOARGS,   nullptr},
   {nullptr, nullptr, 0, nullptr}
index 05c94d1..ee58707 100644 (file)
@@ -1,4 +1,3 @@
-
 from .sampler import Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler, BatchSampler
 from .distributed import DistributedSampler
 from .dataset import Dataset, TensorDataset, ConcatDataset, Subset, random_split
diff --git a/torch/utils/data/_utils/__init__.py b/torch/utils/data/_utils/__init__.py
new file mode 100644 (file)
index 0000000..05b2b65
--- /dev/null
@@ -0,0 +1,61 @@
+r"""Utility classes & functions for data loading. Code in this folder is mostly
+used by ../dataloder.py.
+
+A lot of multiprocessing is used in data loading, which only supports running
+functions defined in global environment (py2 can't serialize static methods).
+Therefore, for code tidiness we put these functions into different files in this
+folder.
+"""
+
+import sys
+import traceback
+import atexit
+
+
+IS_WINDOWS = sys.platform == "win32"
+
+
+# NOTE [ Python Traceback Reference Cycle Problem ]
+#
+# When using sys.exc_info(), it is important to **not** store the exc_info[2],
+# which is the traceback, because otherwise you will run into the traceback
+# reference cycle problem, i.e., the traceback holding a reference to the frame,
+# and the frame (which holds references to all the objects in its temporary
+# scope) holding a reference to the traceback.
+
+
+class ExceptionWrapper(object):
+    r"""Wraps an exception plus traceback to communicate across threads"""
+    def __init__(self, exc_info):
+        # It is important that we don't store exc_info, see
+        # NOTE [ Python Traceback Reference Cycle Problem ]
+        self.exc_type = exc_info[0]
+        self.exc_msg = "".join(traceback.format_exception(*exc_info))
+
+
+MP_STATUS_CHECK_INTERVAL = 5.0
+r"""Interval (in seconds) to check status of processes to avoid hanging in
+    multiprocessing data loading. This is mainly used in getting data from
+    another process, in which case we need to periodically check whether the
+    sender is alive to prevent hanging."""
+
+
+python_exit_status = False
+r"""Whether Python is shutting down. This flag is guaranteed to be set before
+the Python core library resources are freed, but Python may already be exiting
+for some time when this is set.
+
+Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar
+hook in Python 3.7 multiprocessing library:
+https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327
+"""
+
+
+def _set_python_exit_flag():
+    global python_exit_status
+    python_exit_status = True
+
+atexit.register(_set_python_exit_flag)
+
+
+from . import worker, signal_handling, pin_memory, collate
diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py
new file mode 100644 (file)
index 0000000..e0823b2
--- /dev/null
@@ -0,0 +1,68 @@
+r""""Contains definitions of the methods used by the _DataLoaderIter workers to
+collate samples fetched from dataset into Tensor(s).
+
+These **needs** to be in global scope since Py2 doesn't support serializing
+static methods.
+"""
+
+import torch
+import re
+from torch._six import container_abcs, string_classes, int_classes
+
+_use_shared_memory = False
+r"""Whether to use shared memory in default_collate"""
+
+np_str_obj_array_pattern = re.compile(r'[SaUO]')
+
+error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}"
+
+numpy_type_map = {
+    'float64': torch.DoubleTensor,
+    'float32': torch.FloatTensor,
+    'float16': torch.HalfTensor,
+    'int64': torch.LongTensor,
+    'int32': torch.IntTensor,
+    'int16': torch.ShortTensor,
+    'int8': torch.CharTensor,
+    'uint8': torch.ByteTensor,
+}
+
+
+def default_collate(batch):
+    r"""Puts each data field into a tensor with outer dimension batch size"""
+
+    elem_type = type(batch[0])
+    if isinstance(batch[0], torch.Tensor):
+        out = None
+        if _use_shared_memory:
+            # If we're in a background process, concatenate directly into a
+            # shared memory tensor to avoid an extra copy
+            numel = sum([x.numel() for x in batch])
+            storage = batch[0].storage()._new_shared(numel)
+            out = batch[0].new(storage)
+        return torch.stack(batch, 0, out=out)
+    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+            and elem_type.__name__ != 'string_':
+        elem = batch[0]
+        if elem_type.__name__ == 'ndarray':
+            # array of string classes and objects
+            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+                raise TypeError(error_msg_fmt.format(elem.dtype))
+
+            return torch.stack([torch.from_numpy(b) for b in batch], 0)
+        if elem.shape == ():  # scalars
+            py_type = float if elem.dtype.name.startswith('float') else int
+            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
+    elif isinstance(batch[0], int_classes):
+        return torch.LongTensor(batch)
+    elif isinstance(batch[0], float):
+        return torch.DoubleTensor(batch)
+    elif isinstance(batch[0], string_classes):
+        return batch
+    elif isinstance(batch[0], container_abcs.Mapping):
+        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
+    elif isinstance(batch[0], container_abcs.Sequence):
+        transposed = zip(*batch)
+        return [default_collate(samples) for samples in transposed]
+
+    raise TypeError((error_msg_fmt.format(type(batch[0]))))
diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py
new file mode 100644 (file)
index 0000000..8403b42
--- /dev/null
@@ -0,0 +1,57 @@
+r""""Contains definitions of the methods used by the _DataLoaderIter to put
+fetched tensors into pinned memory.
+
+These **needs** to be in global scope since Py2 doesn't support serializing
+static methods.
+"""
+
+import sys  # needed for sys.exc_info() in _pin_memory_loop
+import torch
+from torch._six import queue, container_abcs, string_classes
+from . import collate, MP_STATUS_CHECK_INTERVAL, ExceptionWrapper
+
+
+def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
+    torch.cuda.set_device(device_id)
+
+    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
+    # logic of this function.
+    while True:
+        try:
+            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+        except queue.Empty:
+            continue
+        except Exception:
+            if done_event.is_set():
+                # Weird things can happen when shutting down, e.g., fd being
+                # closed when tensors are shared via fds.
+                break
+            raise
+        if r is None:
+            assert done_event.is_set()
+            return
+        elif done_event.is_set():
+            # Haven't seen the final signal yet. Keep getting until None.
+            continue
+        elif isinstance(r[1], ExceptionWrapper):
+            out_queue.put(r)
+        else:
+            idx, batch = r
+            try:
+                batch = pin_memory_batch(batch)
+            except Exception:
+                out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
+            else:
+                out_queue.put((idx, batch))
+
+
+def pin_memory_batch(batch):
+    if isinstance(batch, torch.Tensor):
+        return batch.pin_memory()
+    elif isinstance(batch, string_classes):
+        return batch
+    elif isinstance(batch, container_abcs.Mapping):
+        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
+    elif isinstance(batch, container_abcs.Sequence):
+        return [pin_memory_batch(sample) for sample in batch]
+    else:
+        return batch
diff --git a/torch/utils/data/_utils/signal_handling.py b/torch/utils/data/_utils/signal_handling.py
new file mode 100644 (file)
index 0000000..9364733
--- /dev/null
@@ -0,0 +1,70 @@
+r""""Signal handling for multiprocessing data loading.
+
+NOTE [ Signal handling in multiprocessing data loading ]
+
+In cases like DataLoader, if a worker process dies due to a bus error/segfault
+or just hangs, the main process will hang waiting for data. This is difficult
+to avoid on the PyTorch side as it can be caused by limited shm, or by other
+libraries users call in the workers. In this file and `DataLoader.cpp`, we make
+our best effort to provide some error message to users when such unfortunate
+events happen.
+
+When a _DataLoaderIter starts worker processes, their pids are registered in a
+map defined in `DataLoader.cpp`: id(_DataLoaderIter) => Collection[ Worker pids ]
+via `_set_worker_pids`.
+
+When an error happens in a worker process, the main process receives a SIGCHLD,
+and Python will eventually call the handler registered below
+(in `_set_SIGCHLD_handler`). In the handler, the `_error_if_any_worker_fails`
+call checks all registered worker pids and raises a proper error message to
+prevent the main process from hanging while waiting for data from a worker.
+
+Additionally, at the beginning of each worker's `_utils.worker._worker_loop`,
+`_set_worker_signal_handlers` is called to register critical signal handlers
+(e.g., for SIGSEGV, SIGBUS, SIGFPE, SIGTERM) in C, which just print an error
+message to stderr before triggering the default handler. So a message will also
+be printed from the worker process when it is killed by such signals.
+
+See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the reasoning behind
+this signal handling design and the other mechanisms we implement to make our
+multiprocessing data loading robust to errors.
+"""
+
+import signal
+import threading
+import torch
+from torch._C import _set_worker_pids, _remove_worker_pids, \
+    _error_if_any_worker_fails, _set_worker_signal_handlers
+from . import IS_WINDOWS
+
+
+_SIGCHLD_handler_set = False
+r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
+handler needs to be set for all DataLoaders in a process."""
+
+
+def _set_SIGCHLD_handler():
+    # Windows doesn't support SIGCHLD handler
+    if IS_WINDOWS:
+        return
+    # can't set signal in child threads
+    if not isinstance(threading.current_thread(), threading._MainThread):
+        return
+    global _SIGCHLD_handler_set
+    if _SIGCHLD_handler_set:
+        return
+    previous_handler = signal.getsignal(signal.SIGCHLD)
+    if not callable(previous_handler):
+        # This doesn't catch default handler, but SIGCHLD default handler is a
+        # no-op.
+        previous_handler = None
+
+    def handler(signum, frame):
+        # The following call uses `waitid` with WNOHANG from the C side.
+        # Therefore, Python can still get and update the process status
+        # successfully.
+        _error_if_any_worker_fails()
+        if previous_handler is not None:
+            previous_handler(signum, frame)
+
+    signal.signal(signal.SIGCHLD, handler)
+    _SIGCHLD_handler_set = True
diff --git a/torch/utils/data/_utils/worker.py b/torch/utils/data/_utils/worker.py
new file mode 100644 (file)
index 0000000..2258c6f
--- /dev/null
@@ -0,0 +1,109 @@
+r""""Contains definitions of the methods used by the _DataLoaderIter workers.
+
+These **needs** to be in global scope since Py2 doesn't support serializing
+static methods.
+"""
+
+import torch
+import random
+import sys
+import os
+from torch._six import queue
+from . import collate, signal_handling, MP_STATUS_CHECK_INTERVAL, \
+    ExceptionWrapper, IS_WINDOWS
+
+if IS_WINDOWS:
+    import ctypes
+    from ctypes.wintypes import DWORD, BOOL, HANDLE
+
+    # On Windows, the parent ID of the worker process remains unchanged when the manager process
+    # is gone, and the only way to check it through the OS is to let the worker hold a process
+    # handle of the manager and ask if the process status has changed.
+    class ManagerWatchdog(object):
+        def __init__(self):
+            self.manager_pid = os.getppid()
+
+            self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
+            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
+            self.kernel32.OpenProcess.restype = HANDLE
+            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
+            self.kernel32.WaitForSingleObject.restype = DWORD
+
+            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
+            SYNCHRONIZE = 0x00100000
+            self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)
+
+            if not self.manager_handle:
+                raise ctypes.WinError(ctypes.get_last_error())
+
+            self.manager_dead = False
+
+        def is_alive(self):
+            if not self.manager_dead:
+                # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
+                self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
+            return not self.manager_dead
+else:
+    class ManagerWatchdog(object):
+        def __init__(self):
+            self.manager_pid = os.getppid()
+            self.manager_dead = False
+
+        def is_alive(self):
+            if not self.manager_dead:
+                self.manager_dead = os.getppid() != self.manager_pid
+            return not self.manager_dead
+
+
+def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id):
+    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
+    # logic of this function.
+
+    try:
+        collate._use_shared_memory = True
+
+        # Initialize C-side signal handlers for SIGBUS and SIGSEGV. Python signal
+        # module's handlers are executed after Python returns from C low-level
+        # handlers, likely when the same fatal signal had already happened
+        # again.
+        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
+        signal_handling._set_worker_signal_handlers()
+
+        torch.set_num_threads(1)
+        random.seed(seed)
+        torch.manual_seed(seed)
+
+        data_queue.cancel_join_thread()
+
+        if init_fn is not None:
+            init_fn(worker_id)
+
+        watchdog = ManagerWatchdog()
+
+        while watchdog.is_alive():
+            try:
+                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+            except queue.Empty:
+                continue
+            if r is None:
+                # Received the final signal
+                assert done_event.is_set()
+                return
+            elif done_event.is_set():
+                # Done event is set. But I haven't received the final signal
+                # (None) yet. Keep looping until I get it, and skip the
+                # processing steps.
+                continue
+            idx, batch_indices = r
+            try:
+                samples = collate_fn([dataset[i] for i in batch_indices])
+            except Exception:
+                # It is important that we don't store exc_info in a variable,
+                # see NOTE [ Python Traceback Reference Cycle Problem ]
+                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
+            else:
+                data_queue.put((idx, samples))
+                del samples
+    except KeyboardInterrupt:
+        # Main process will raise KeyboardInterrupt anyways.
+        pass
index 122272b..a8684e9 100644 (file)
-import random
+r"""Definition of the DataLoader and it's iterator _DataLoaderIter classes.
+
+To support these two classes, in `./_utils` we define many utility methods and
+functions to be run in multiprocessing. E.g., the data loading worker loop is
+in `./_utils/worker.py`.
+"""
+
 import torch
 import torch.multiprocessing as multiprocessing
-from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
-    _remove_worker_pids, _error_if_any_worker_fails
 from . import SequentialSampler, RandomSampler, BatchSampler
-import signal
-import functools
-from torch._six import container_abcs
-import re
-import sys
+from . import _utils
 import threading
-import traceback
-import os
-import time
-import atexit
-from torch._six import string_classes, int_classes, FileNotFoundError
-
-IS_WINDOWS = sys.platform == "win32"
-if IS_WINDOWS:
-    import ctypes
-    from ctypes.wintypes import DWORD, BOOL, HANDLE
-
-if sys.version_info[0] == 2:
-    import Queue as queue
-else:
-    import queue
-
-
-# NOTE [ Python Traceback Reference Cycle Problem ]
-#
-# When using sys.exc_info(), it is important to **not** store the exc_info[2],
-# which is the traceback, because otherwise you will run into the traceback
-# reference cycle problem, i.e., the traceback holding reference to the frame,
-# and the frame (which holds reference to all the object in its temporary scope)
-# holding reference the traceback.
-
-
-class ExceptionWrapper(object):
-    r"""Wraps an exception plus traceback to communicate across threads"""
-    def __init__(self, exc_info):
-        # It is important that we don't store exc_info, see
-        # NOTE [ Python Traceback Reference Cycle Problem ]
-        self.exc_type = exc_info[0]
-        self.exc_msg = "".join(traceback.format_exception(*exc_info))
-
-
-_use_shared_memory = False
-r"""Whether to use shared memory in default_collate"""
-
-MP_STATUS_CHECK_INTERVAL = 5.0
-r"""Interval (in seconds) to check status of processes to avoid hanging in
-    multiprocessing data loading. This is mainly used in getting data from
-    another process, in which case we need to periodically check whether the
-    sender is alive to prevent hanging."""
-
-if IS_WINDOWS:
-    # On Windows, the parent ID of the worker process remains unchanged when the manager process
-    # is gone, and the only way to check it through OS is to let the worker have a process handle
-    # of the manager and ask if the process status has changed.
-    class ManagerWatchdog(object):
-        def __init__(self):
-            self.manager_pid = os.getppid()
-
-            self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
-            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
-            self.kernel32.OpenProcess.restype = HANDLE
-            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
-            self.kernel32.WaitForSingleObject.restype = DWORD
-
-            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
-            SYNCHRONIZE = 0x00100000
-            self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)
-
-            if not self.manager_handle:
-                raise ctypes.WinError(ctypes.get_last_error())
-
-            self.manager_dead = False
-
-        def is_alive(self):
-            if not self.manager_dead:
-                # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
-                self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
-            return not self.manager_dead
-else:
-    class ManagerWatchdog(object):
-        def __init__(self):
-            self.manager_pid = os.getppid()
-            self.manager_dead = False
-
-        def is_alive(self):
-            if not self.manager_dead:
-                self.manager_dead = os.getppid() != self.manager_pid
-            return not self.manager_dead
-
-
-def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id):
-    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
-    # logic of this function.
-
-    try:
-        global _use_shared_memory
-        _use_shared_memory = True
-
-        # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
-        # module's handlers are executed after Python returns from C low-level
-        # handlers, likely when the same fatal signal had already happened.
-        # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
-        _set_worker_signal_handlers()
-
-        torch.set_num_threads(1)
-        random.seed(seed)
-        torch.manual_seed(seed)
-
-        data_queue.cancel_join_thread()
-
-        if init_fn is not None:
-            init_fn(worker_id)
-
-        watchdog = ManagerWatchdog()
-
-        while watchdog.is_alive():
-            try:
-                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
-            except queue.Empty:
-                continue
-            if r is None:
-                # Received the final signal
-                assert done_event.is_set()
-                return
-            elif done_event.is_set():
-                # Done event is set. But I haven't received the final signal
-                # (None) yet. I will keep continuing until get it, and skip the
-                # processing steps.
-                continue
-            idx, batch_indices = r
-            try:
-                samples = collate_fn([dataset[i] for i in batch_indices])
-            except Exception:
-                # It is important that we don't store exc_info in a variable,
-                # see NOTE [ Python Traceback Reference Cycle Problem ]
-                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
-            else:
-                data_queue.put((idx, samples))
-                del samples
-    except KeyboardInterrupt:
-        # Main process will raise KeyboardInterrupt anyways.
-        pass
-
-
-def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
-    torch.cuda.set_device(device_id)
-
-    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
-    # logic of this function.
-    while True:
-        try:
-            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
-        except queue.Empty:
-            continue
-        except Exception:
-            if done_event.is_set():
-                # Weird things can happen when shutting down, e.g., fd being
-                # closed when tensors are shared via fds.
-                break
-            raise
-        if r is None:
-            assert done_event.is_set()
-            return
-        elif done_event.is_set():
-            # Haven't seen the final signal yet. Keep getting until None.
-            continue
-        elif isinstance(r[1], ExceptionWrapper):
-            out_queue.put(r)
-        else:
-            idx, batch = r
-            try:
-                batch = pin_memory_batch(batch)
-            except Exception:
-                out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
-            else:
-                out_queue.put((idx, batch))
-
-numpy_type_map = {
-    'float64': torch.DoubleTensor,
-    'float32': torch.FloatTensor,
-    'float16': torch.HalfTensor,
-    'int64': torch.LongTensor,
-    'int32': torch.IntTensor,
-    'int16': torch.ShortTensor,
-    'int8': torch.CharTensor,
-    'uint8': torch.ByteTensor,
-}
-
-
-def default_collate(batch):
-    r"""Puts each data field into a tensor with outer dimension batch size"""
-
-    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
-    elem_type = type(batch[0])
-    if isinstance(batch[0], torch.Tensor):
-        out = None
-        if _use_shared_memory:
-            # If we're in a background process, concatenate directly into a
-            # shared memory tensor to avoid an extra copy
-            numel = sum([x.numel() for x in batch])
-            storage = batch[0].storage()._new_shared(numel)
-            out = batch[0].new(storage)
-        return torch.stack(batch, 0, out=out)
-    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
-            and elem_type.__name__ != 'string_':
-        elem = batch[0]
-        if elem_type.__name__ == 'ndarray':
-            # array of string classes and object
-            if re.search('[SaUO]', elem.dtype.str) is not None:
-                raise TypeError(error_msg.format(elem.dtype))
-
-            return torch.stack([torch.from_numpy(b) for b in batch], 0)
-        if elem.shape == ():  # scalars
-            py_type = float if elem.dtype.name.startswith('float') else int
-            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
-    elif isinstance(batch[0], int_classes):
-        return torch.LongTensor(batch)
-    elif isinstance(batch[0], float):
-        return torch.DoubleTensor(batch)
-    elif isinstance(batch[0], string_classes):
-        return batch
-    elif isinstance(batch[0], container_abcs.Mapping):
-        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
-    elif isinstance(batch[0], container_abcs.Sequence):
-        transposed = zip(*batch)
-        return [default_collate(samples) for samples in transposed]
+from torch._six import queue
 
-    raise TypeError((error_msg.format(type(batch[0]))))
 
+# This function used to be defined in this file. However, it was moved to
+# _utils/collate.py. Although it is rather hard to access this from user land
+# (one has to explicitly `import torch.utils.data.dataloader`), there
+# probably is user code out there using it. This aliasing maintains BC in this
+# aspect.
+default_collate = _utils.collate.default_collate
 
-def pin_memory_batch(batch):
-    if isinstance(batch, torch.Tensor):
-        return batch.pin_memory()
-    elif isinstance(batch, string_classes):
-        return batch
-    elif isinstance(batch, container_abcs.Mapping):
-        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
-    elif isinstance(batch, container_abcs.Sequence):
-        return [pin_memory_batch(sample) for sample in batch]
-    else:
-        return batch
 
+class DataLoader(object):
+    r"""
+    Data loader. Combines a dataset and a sampler, and provides
+    single- or multi-process iterators over the dataset.
 
-_SIGCHLD_handler_set = False
-r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
-handler needs to be set for all DataLoaders in a process."""
-
-
-def _set_SIGCHLD_handler():
-    # Windows doesn't support SIGCHLD handler
-    if sys.platform == 'win32':
-        return
-    # can't set signal in child threads
-    if not isinstance(threading.current_thread(), threading._MainThread):
-        return
-    global _SIGCHLD_handler_set
-    if _SIGCHLD_handler_set:
-        return
-    previous_handler = signal.getsignal(signal.SIGCHLD)
-    if not callable(previous_handler):
-        # This doesn't catch default handler, but SIGCHLD default handler is a
-        # no-op.
-        previous_handler = None
-
-    def handler(signum, frame):
-        # This following call uses `waitid` with WNOHANG from C side. Therefore,
-        # Python can still get and update the process status successfully.
-        _error_if_any_worker_fails()
-        if previous_handler is not None:
-            previous_handler(signum, frame)
-
-    signal.signal(signal.SIGCHLD, handler)
-    _SIGCHLD_handler_set = True
-
-
-_python_exit_status = False
-r"""Whether Python is shutting down. This flag is guaranteed to be set before
-the Python core library resources are freed, but Python may already be exiting
-for some time when this is set.
-
-Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar
-hook in Python 3.7 multiprocessing library:
-https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327
-"""
+    Arguments:
+        dataset (Dataset): dataset from which to load the data.
+        batch_size (int, optional): how many samples per batch to load
+            (default: ``1``).
+        shuffle (bool, optional): set to ``True`` to have the data reshuffled
+            at every epoch (default: ``False``).
+        sampler (Sampler, optional): defines the strategy to draw samples from
+            the dataset. If specified, ``shuffle`` must be False.
+        batch_sampler (Sampler, optional): like sampler, but returns a batch of
+            indices at a time. Mutually exclusive with :attr:`batch_size`,
+            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
+        num_workers (int, optional): how many subprocesses to use for data
+            loading. 0 means that the data will be loaded in the main process.
+            (default: ``0``)
+        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
+        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
+            into CUDA pinned memory before returning them.
+        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
+            if the dataset size is not divisible by the batch size. If ``False`` and
+            the size of dataset is not divisible by the batch size, then the last batch
+            will be smaller. (default: ``False``)
+        timeout (numeric, optional): if positive, the timeout value for collecting a batch
+            from workers. Should always be non-negative. (default: ``0``)
+        worker_init_fn (callable, optional): If not ``None``, this will be called on each
+            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
+            input, after seeding and before data loading. (default: ``None``)
+
+    .. note:: By default, each worker will have its PyTorch seed set to
+              ``base_seed + worker_id``, where ``base_seed`` is a long generated
+              by the main process using its RNG. However, seeds for other
+              libraries may be duplicated upon initializing workers (e.g.,
+              NumPy), causing each worker to return identical random numbers.
+              (See the :ref:`dataloader-workers-random-seed` section in the
+              FAQ.) You may
+              use :func:`torch.initial_seed()` to access the PyTorch seed for
+              each worker in :attr:`worker_init_fn`, and use it to set other
+              seeds before data loading.
+
+    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
+                 unpicklable object, e.g., a lambda function.
+    """
+
+    __initialized = False
+
+    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
+                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
+                 pin_memory=False, drop_last=False, timeout=0,
+                 worker_init_fn=None):
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.collate_fn = collate_fn
+        self.pin_memory = pin_memory
+        self.drop_last = drop_last
+        self.timeout = timeout
+        self.worker_init_fn = worker_init_fn
+
+        if timeout < 0:
+            raise ValueError('timeout option should be non-negative')
+
+        if batch_sampler is not None:
+            if batch_size > 1 or shuffle or sampler is not None or drop_last:
+                raise ValueError('batch_sampler option is mutually exclusive '
+                                 'with batch_size, shuffle, sampler, and '
+                                 'drop_last')
+            self.batch_size = None
+            self.drop_last = None
 
+        if sampler is not None and shuffle:
+            raise ValueError('sampler option is mutually exclusive with '
+                             'shuffle')
 
-def _set_python_exit_flag():
-    global _python_exit_status
-    _python_exit_status = True
+        if self.num_workers < 0:
+            raise ValueError('num_workers option cannot be negative; '
+                             'use num_workers=0 to disable multiprocessing.')
 
-atexit.register(_set_python_exit_flag)
+        if batch_sampler is None:
+            if sampler is None:
+                if shuffle:
+                    sampler = RandomSampler(dataset)
+                else:
+                    sampler = SequentialSampler(dataset)
+            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
+
+        self.sampler = sampler
+        self.batch_sampler = batch_sampler
+        self.__initialized = True
+
+    def __setattr__(self, attr, val):
+        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
+            raise ValueError('{} attribute should not be set after {} is '
+                             'initialized'.format(attr, self.__class__.__name__))
+
+        super(DataLoader, self).__setattr__(attr, val)
+
+    def __iter__(self):
+        return _DataLoaderIter(self)
+
+    def __len__(self):
+        return len(self.batch_sampler)
 
 
 class _DataLoaderIter(object):
@@ -354,12 +184,13 @@ class _DataLoaderIter(object):
     #      Therefore, in this case, we actually need to prevent `__del__` from
     #      being executed, and rely on the automatic termination of daemonic
     #      children. Thus, we register an `atexit` hook that sets a global flag
-    #      `_python_exit_status`. Since `atexit` hooks are executed in reverse
-    #      order of registration, we are guaranteed that this flag is set before
-    #      library resources we use are freed. (Hooks freeing those resources
-    #      are registered at importing the Python core libraries at the top of
-    #      this file.) So in `__del__`, we check if `_python_exit_status` is set
-    #      or `None` (freed), and perform no-op if so.
+    #      `_utils.python_exit_status`. Since `atexit` hooks are executed in the
+    #      reverse order of registration, we are guaranteed that this flag is
+    #      set before library resources we use are freed. (Hooks freeing those
+    #      resources are registered at importing the Python core libraries at
+    #      the top of this file.) So in `__del__`, we check if
+    #      `_utils.python_exit_status` is set or `None` (freed), and perform
+    #      no-op if so.
     #
     #      Another problem with `__del__` is also related to the library cleanup
     #      calls. When a process ends, it shuts the all its daemonic children
@@ -419,8 +250,8 @@ class _DataLoaderIter(object):
     #           For `.get()` calls where the sender(s) is not the workers, we
     #           guard them with timeouts, and check the status of the sender
     #           when timeout happens:
-    #             + in the workers, the `ManagerWatchdog` class checks the main
-    #               process status.
+    #             + in the workers, the `_utils.worker.ManagerWatchdog` class
+    #               checks the status of the main process.
     #             + if `pin_memory=True`, when getting from `pin_memory_thread`,
     #               check `pin_memory_thread` status periodically until `.get()`
     #               returns or see that `pin_memory_thread` died.
@@ -545,7 +376,7 @@ class _DataLoaderIter(object):
                 index_queue = multiprocessing.Queue()
                 index_queue.cancel_join_thread()
                 w = multiprocessing.Process(
-                    target=_worker_loop,
+                    target=_utils.worker._worker_loop,
                     args=(self.dataset, index_queue,
                           self.worker_result_queue, self.done_event,
                           self.collate_fn, base_seed + i,
@@ -564,7 +395,7 @@ class _DataLoaderIter(object):
             if self.pin_memory:
                 self.data_queue = queue.Queue()
                 pin_memory_thread = threading.Thread(
-                    target=_pin_memory_loop,
+                    target=_utils.pin_memory._pin_memory_loop,
                     args=(self.worker_result_queue, self.data_queue,
                           torch.cuda.current_device(), self.done_event))
                 pin_memory_thread.daemon = True
@@ -575,8 +406,8 @@ class _DataLoaderIter(object):
             else:
                 self.data_queue = self.worker_result_queue
 
-            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
-            _set_SIGCHLD_handler()
+            _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
+            _utils.signal_handling._set_SIGCHLD_handler()
             self.worker_pids_set = True
 
             # prime the prefetch loop
@@ -598,7 +429,7 @@ class _DataLoaderIter(object):
         elif self.pin_memory:
             while self.pin_memory_thread.is_alive():
                 try:
-                    return self.data_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+                    return self.data_queue.get(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
                 except queue.Empty:
                     continue
             else:
@@ -614,7 +445,7 @@ class _DataLoaderIter(object):
             indices = next(self.sample_iter)  # may raise StopIteration
             batch = self.collate_fn([self.dataset[i] for i in indices])
             if self.pin_memory:
-                batch = pin_memory_batch(batch)
+                batch = _utils.pin_memory.pin_memory_batch(batch)
             return batch
 
         # check if the next sample has already been generated
@@ -654,7 +485,7 @@ class _DataLoaderIter(object):
     def _process_next_batch(self, batch):
         self.rcvd_idx += 1
         self._put_indices()
-        if isinstance(batch, ExceptionWrapper):
+        if isinstance(batch, _utils.ExceptionWrapper):
             raise batch.exc_type(batch.exc_msg)
         return batch
 
@@ -669,7 +500,8 @@ class _DataLoaderIter(object):
     def _shutdown_workers(self):
         # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
         # the logic of this function.
-        if _python_exit_status is True or _python_exit_status is None:
+        python_exit_status = _utils.python_exit_status
+        if python_exit_status is True or python_exit_status is None:
             # See (2) of the note. If Python is shutting down, do no-op.
             return
         # Normal exit when last reference is gone / iterator is depleted.
@@ -679,7 +511,7 @@ class _DataLoaderIter(object):
             # Removes pids from the C side data structure first so worker
             # termination afterwards won't trigger false positive error report.
             if self.worker_pids_set:
-                _remove_worker_pids(id(self))
+                _utils.signal_handling._remove_worker_pids(id(self))
                 self.worker_pids_set = False
 
             self.done_event.set()
@@ -715,108 +547,3 @@ class _DataLoaderIter(object):
     def __del__(self):
         if self.num_workers > 0:
             self._shutdown_workers()
-
-
-class DataLoader(object):
-    r"""
-    Data loader. Combines a dataset and a sampler, and provides
-    single- or multi-process iterators over the dataset.
-
-    Arguments:
-        dataset (Dataset): dataset from which to load the data.
-        batch_size (int, optional): how many samples per batch to load
-            (default: ``1``).
-        shuffle (bool, optional): set to ``True`` to have the data reshuffled
-            at every epoch (default: ``False``).
-        sampler (Sampler, optional): defines the strategy to draw samples from
-            the dataset. If specified, ``shuffle`` must be False.
-        batch_sampler (Sampler, optional): like sampler, but returns a batch of
-            indices at a time. Mutually exclusive with :attr:`batch_size`,
-            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
-        num_workers (int, optional): how many subprocesses to use for data
-            loading. 0 means that the data will be loaded in the main process.
-            (default: ``0``)
-        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
-        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
-            into CUDA pinned memory before returning them.
-        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
-            if the dataset size is not divisible by the batch size. If ``False`` and
-            the size of dataset is not divisible by the batch size, then the last batch
-            will be smaller. (default: ``False``)
-        timeout (numeric, optional): if positive, the timeout value for collecting a batch
-            from workers. Should always be non-negative. (default: ``0``)
-        worker_init_fn (callable, optional): If not ``None``, this will be called on each
-            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
-            input, after seeding and before data loading. (default: ``None``)
-
-    .. note:: By default, each worker will have its PyTorch seed set to
-              ``base_seed + worker_id``, where ``base_seed`` is a long generated
-              by main process using its RNG. However, seeds for other libraies
-              may be duplicated upon initializing workers (w.g., NumPy), causing
-              each worker to return identical random numbers. (See
-              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
-              use :func:`torch.initial_seed()` to access the PyTorch seed for
-              each worker in :attr:`worker_init_fn`, and use it to set other
-              seeds before data loading.
-
-    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
-                 unpicklable object, e.g., a lambda function.
-    """
-
-    __initialized = False
-
-    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
-                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
-                 timeout=0, worker_init_fn=None):
-        self.dataset = dataset
-        self.batch_size = batch_size
-        self.num_workers = num_workers
-        self.collate_fn = collate_fn
-        self.pin_memory = pin_memory
-        self.drop_last = drop_last
-        self.timeout = timeout
-        self.worker_init_fn = worker_init_fn
-
-        if timeout < 0:
-            raise ValueError('timeout option should be non-negative')
-
-        if batch_sampler is not None:
-            if batch_size > 1 or shuffle or sampler is not None or drop_last:
-                raise ValueError('batch_sampler option is mutually exclusive '
-                                 'with batch_size, shuffle, sampler, and '
-                                 'drop_last')
-            self.batch_size = None
-            self.drop_last = None
-
-        if sampler is not None and shuffle:
-            raise ValueError('sampler option is mutually exclusive with '
-                             'shuffle')
-
-        if self.num_workers < 0:
-            raise ValueError('num_workers option cannot be negative; '
-                             'use num_workers=0 to disable multiprocessing.')
-
-        if batch_sampler is None:
-            if sampler is None:
-                if shuffle:
-                    sampler = RandomSampler(dataset)
-                else:
-                    sampler = SequentialSampler(dataset)
-            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
-
-        self.sampler = sampler
-        self.batch_sampler = batch_sampler
-        self.__initialized = True
-
-    def __setattr__(self, attr, val):
-        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
-            raise ValueError('{} attribute should not be set after {} is '
-                             'initialized'.format(attr, self.__class__.__name__))
-
-        super(DataLoader, self).__setattr__(attr, val)
-
-    def __iter__(self):
-        return _DataLoaderIter(self)
-
-    def __len__(self):
-        return len(self.batch_sampler)