Revert D13289919: [pytorch][PR] [DataLoader] Refactor dataloader.py
author Ailing Zhang <ailzhang@fb.com>
Wed, 5 Dec 2018 04:23:25 +0000 (20:23 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 5 Dec 2018 04:25:16 +0000 (20:25 -0800)
Differential Revision: D13289919

Original commit changeset: d701bc7bb48f

fbshipit-source-id: c350c491fefa98a0a7c0cf22cb832e78aeb15c3d
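
For context, this revert removes the torch/utils/data/_utils package and moves the collate, pin-memory, worker, and signal-handling helpers back into torch/utils/data/dataloader.py. A minimal sketch of the import paths the revert restores (assuming a build at this commit; mirrors the test_dataloader.py hunk below):

    from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
    from torch.utils.data.dataloader import (
        default_collate,           # collate helper, back in dataloader.py
        ExceptionWrapper,          # wraps worker exceptions for re-raising
        MP_STATUS_CHECK_INTERVAL,  # polling interval used by the worker loops
    )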

test/test_dataloader.py
torch/_six.py
torch/csrc/DataLoader.cpp
torch/utils/data/__init__.py
torch/utils/data/_utils/__init__.py [deleted file]
torch/utils/data/_utils/collate.py [deleted file]
torch/utils/data/_utils/pin_memory.py [deleted file]
torch/utils/data/_utils/signal_handling.py [deleted file]
torch/utils/data/_utils/worker.py [deleted file]
torch/utils/data/dataloader.py

index 9db11b0..e1ca129 100644 (file)
@@ -12,9 +12,9 @@ import unittest
 import subprocess
 import itertools
 from torch import multiprocessing as mp
-from torch.utils.data import _utils, Dataset, TensorDataset, DataLoader, ConcatDataset
-from torch.utils.data._utils import ExceptionWrapper, MP_STATUS_CHECK_INTERVAL
+from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
 from torch.utils.data.dataset import random_split
+from torch.utils.data.dataloader import default_collate, ExceptionWrapper, MP_STATUS_CHECK_INTERVAL
 from common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, IS_PPC, NO_MULTIPROCESSING_SPAWN,
                           skipIfRocm, load_tests)
 
@@ -788,16 +788,16 @@ class TestDataLoader(TestCase):
 
         # Should be a no-op
         arr = np.array(['a', 'b', 'c'])
-        _utils.collate.default_collate(arr)
+        default_collate(arr)
 
         arr = np.array([[['a', 'b', 'c']]])
-        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
+        self.assertRaises(TypeError, lambda: default_collate(arr))
 
         arr = np.array([object(), object(), object()])
-        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
+        self.assertRaises(TypeError, lambda: default_collate(arr))
 
         arr = np.array([[[object(), object(), object()]]])
-        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
+        self.assertRaises(TypeError, lambda: default_collate(arr))
 
 
 class StringDataset(Dataset):
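
The test hunk above exercises the restored default_collate on NumPy string and object arrays; a hedged sketch of the same behaviour outside the test harness (assuming torch and numpy are installed and this commit is checked out):

    import numpy as np
    from torch.utils.data.dataloader import default_collate

    # 1-D string array: collating is a no-op that returns the batch unchanged
    arr = np.array(['a', 'b', 'c'])
    print(default_collate(arr))

    # Nested string/object arrays cannot be turned into tensors,
    # so default_collate raises TypeError
    try:
        default_collate(np.array([[['a', 'b', 'c']]]))
    except TypeError as err:
        print('rejected as expected:', err)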
index 3a4c2ad..924e641 100644 (file)
@@ -51,12 +51,6 @@ else:
     FileNotFoundError = FileNotFoundError
 
 
-if PY2:
-    import Queue as queue
-else:
-    import queue
-
-
 def with_metaclass(meta, *bases):
     """Create a base class with a metaclass."""
     # This requires a bit of explanation: the basic idea is to make a dummy
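
This hunk drops the queue alias from torch._six; after the revert, dataloader.py carries its own Python 2/3 compatible import instead (restored further down in this diff). A minimal sketch of that compatibility shim:

    import sys

    if sys.version_info[0] == 2:
        import Queue as queue  # Python 2 module name
    else:
        import queue           # Python 3 module name

    # Either way, queue.Empty is available for the timeout-based .get() calls.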
index d75bcf6..c5cdf64 100644 (file)
@@ -1,10 +1,11 @@
 #include "DataLoader.h"
 
-// Together with `torch/utils/data/_utils/signal_handling.py`, the following
-// is an effort to do our best to provide some error message to users when a
-// worker dies due to error / critical signals.
-//
-// See NOTE [ Signal handling in multiprocessing data loading ] for more details.
+// In cases like DataLoader, if a worker process dies due to bus error/segfault
+// or just hangs, the main process will hang waiting for data. This is difficult
+// to avoid on PyTorch side as it can be caused by limited shm, or other
+// libraries users call in the workers. The following methods are an effort to do
+// our best to provide some error message to users when such unfortunate events
+// happen.
 
 // TODO: The following don't work on Windows. Specifically, sigaction, waitid
 // calls, and SIGCHLD handler. Currently, dummy implementations are provided
@@ -127,7 +128,7 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) {
         // workers, and trigger this again.
         pid_set->clear();
         throw std::runtime_error(oss.str());
-      } else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) {  // killed by signal
+      }  else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) {  // killed by signal
         std::ostringstream oss;
         oss << "DataLoader worker (pid " << worker_pid << ") is killed "
             << "by signal: " << strsignal(infop.si_status) << ". ";
index ee58707..05c94d1 100644 (file)
@@ -1,3 +1,4 @@
+
 from .sampler import Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler, BatchSampler
 from .distributed import DistributedSampler
 from .dataset import Dataset, TensorDataset, ConcatDataset, Subset, random_split
diff --git a/torch/utils/data/_utils/__init__.py b/torch/utils/data/_utils/__init__.py
deleted file mode 100644 (file)
index 05b2b65..0000000
+++ /dev/null
@@ -1,61 +0,0 @@
-r"""Utility classes & functions for data loading. Code in this folder is mostly
-used by ../dataloder.py.
-
-A lot of multiprocessing is used in data loading, which only supports running
-functions defined in global environment (py2 can't serialize static methods).
-Therefore, for code tidiness we put these functions into different files in this
-folder.
-"""
-
-import sys
-import traceback
-import atexit
-
-
-IS_WINDOWS = sys.platform == "win32"
-
-
-# NOTE [ Python Traceback Reference Cycle Problem ]
-#
-# When using sys.exc_info(), it is important to **not** store the exc_info[2],
-# which is the traceback, because otherwise you will run into the traceback
-# reference cycle problem, i.e., the traceback holding reference to the frame,
-# and the frame (which holds reference to all the object in its temporary scope)
-# holding reference the traceback.
-
-
-class ExceptionWrapper(object):
-    r"""Wraps an exception plus traceback to communicate across threads"""
-    def __init__(self, exc_info):
-        # It is important that we don't store exc_info, see
-        # NOTE [ Python Traceback Reference Cycle Problem ]
-        self.exc_type = exc_info[0]
-        self.exc_msg = "".join(traceback.format_exception(*exc_info))
-
-
-MP_STATUS_CHECK_INTERVAL = 5.0
-r"""Interval (in seconds) to check status of processes to avoid hanging in
-    multiprocessing data loading. This is mainly used in getting data from
-    another process, in which case we need to periodically check whether the
-    sender is alive to prevent hanging."""
-
-
-python_exit_status = False
-r"""Whether Python is shutting down. This flag is guaranteed to be set before
-the Python core library resources are freed, but Python may already be exiting
-for some time when this is set.
-
-Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar
-hook in Python 3.7 multiprocessing library:
-https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327
-"""
-
-
-def _set_python_exit_flag():
-    global python_exit_status
-    python_exit_status = True
-
-atexit.register(_set_python_exit_flag)
-
-
-from . import worker, signal_handling, pin_memory, collate
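
ExceptionWrapper (deleted here and re-added to dataloader.py below) stores only the exception type and a formatted message, never the traceback object, per the reference-cycle note above. A hedged sketch of the capture/re-raise pattern the loader relies on:

    import sys
    from torch.utils.data.dataloader import ExceptionWrapper

    # Worker side: capture only the type and the formatted traceback text,
    # never exc_info[2] itself (see the reference-cycle note above).
    try:
        1 / 0
    except Exception:
        wrapped = ExceptionWrapper(sys.exc_info())

    # Main-process side: re-raise with the worker's traceback baked into the
    # message, as _process_next_batch() does in the dataloader.py diff below.
    try:
        raise wrapped.exc_type(wrapped.exc_msg)
    except ZeroDivisionError as err:
        print(err)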
diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py
deleted file mode 100644 (file)
index e0823b2..0000000
+++ /dev/null
@@ -1,68 +0,0 @@
-r""""Contains definitions of the methods used by the _DataLoaderIter workers to
-collate samples fetched from dataset into Tensor(s).
-
-These **needs** to be in global scope since Py2 doesn't support serializing
-static methods.
-"""
-
-import torch
-import re
-from torch._six import container_abcs, string_classes, int_classes
-
-_use_shared_memory = False
-r"""Whether to use shared memory in default_collate"""
-
-np_str_obj_array_pattern = re.compile(r'[SaUO]')
-
-error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}"
-
-numpy_type_map = {
-    'float64': torch.DoubleTensor,
-    'float32': torch.FloatTensor,
-    'float16': torch.HalfTensor,
-    'int64': torch.LongTensor,
-    'int32': torch.IntTensor,
-    'int16': torch.ShortTensor,
-    'int8': torch.CharTensor,
-    'uint8': torch.ByteTensor,
-}
-
-
-def default_collate(batch):
-    r"""Puts each data field into a tensor with outer dimension batch size"""
-
-    elem_type = type(batch[0])
-    if isinstance(batch[0], torch.Tensor):
-        out = None
-        if _use_shared_memory:
-            # If we're in a background process, concatenate directly into a
-            # shared memory tensor to avoid an extra copy
-            numel = sum([x.numel() for x in batch])
-            storage = batch[0].storage()._new_shared(numel)
-            out = batch[0].new(storage)
-        return torch.stack(batch, 0, out=out)
-    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
-            and elem_type.__name__ != 'string_':
-        elem = batch[0]
-        if elem_type.__name__ == 'ndarray':
-            # array of string classes and object
-            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
-                raise TypeError(error_msg_fmt.format(elem.dtype))
-
-            return torch.stack([torch.from_numpy(b) for b in batch], 0)
-        if elem.shape == ():  # scalars
-            py_type = float if elem.dtype.name.startswith('float') else int
-            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
-    elif isinstance(batch[0], int_classes):
-        return torch.LongTensor(batch)
-    elif isinstance(batch[0], float):
-        return torch.DoubleTensor(batch)
-    elif isinstance(batch[0], string_classes):
-        return batch
-    elif isinstance(batch[0], container_abcs.Mapping):
-        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
-    elif isinstance(batch[0], container_abcs.Sequence):
-        transposed = zip(*batch)
-        return [default_collate(samples) for samples in transposed]
-
-    raise TypeError((error_msg_fmt.format(type(batch[0]))))
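
default_collate (deleted here and restored in dataloader.py below) recurses through mappings and sequences, stacking tensors along a new batch dimension. A small illustrative sketch, not tied to any particular dataset:

    import torch
    from torch.utils.data.dataloader import default_collate

    # Three samples, each a dict with a tensor "image" and an int label
    batch = [
        {'image': torch.zeros(3, 4), 'label': 0},
        {'image': torch.ones(3, 4),  'label': 1},
        {'image': torch.ones(3, 4),  'label': 2},
    ]

    collated = default_collate(batch)
    print(collated['image'].shape)  # torch.Size([3, 3, 4])
    print(collated['label'])        # LongTensor of the three labels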
diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py
deleted file mode 100644 (file)
index 8403b42..0000000
+++ /dev/null
@@ -1,57 +0,0 @@
-r""""Contains definitions of the methods used by the _DataLoaderIter to put
-fetched tensors into pinned memory.
-
-These **needs** to be in global scope since Py2 doesn't support serializing
-static methods.
-"""
-
-import torch
-from torch._six import queue, container_abcs, string_classes
-from . import collate, MP_STATUS_CHECK_INTERVAL, ExceptionWrapper
-
-
-def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
-    torch.cuda.set_device(device_id)
-
-    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
-    # logic of this function.
-    while True:
-        try:
-            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
-        except queue.Empty:
-            continue
-        except Exception:
-            if done_event.is_set():
-                # Weird things can happen when shutting down, e.g., fd being
-                # closed when tensors are shared via fds.
-                break
-            raise
-        if r is None:
-            assert done_event.is_set()
-            return
-        elif done_event.is_set():
-            # Haven't seen the final signal yet. Keep getting until None.
-            continue
-        elif isinstance(r[1], ExceptionWrapper):
-            out_queue.put(r)
-        else:
-            idx, batch = r
-            try:
-                batch = pin_memory_batch(batch)
-            except Exception:
-                out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
-            else:
-                out_queue.put((idx, batch))
-
-
-def pin_memory_batch(batch):
-    if isinstance(batch, torch.Tensor):
-        return batch.pin_memory()
-    elif isinstance(batch, string_classes):
-        return batch
-    elif isinstance(batch, container_abcs.Mapping):
-        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
-    elif isinstance(batch, container_abcs.Sequence):
-        return [pin_memory_batch(sample) for sample in batch]
-    else:
-        return batch
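
pin_memory_batch (deleted here, restored below) walks the same container structure and pins each tensor it finds; anything it does not recognise is returned unchanged. A hedged usage sketch, guarded because pinning requires a CUDA-enabled build:

    import torch
    from torch.utils.data.dataloader import pin_memory_batch

    sample = {'image': torch.zeros(2, 2), 'name': 'example'}

    if torch.cuda.is_available():
        pinned = pin_memory_batch(sample)
        print(pinned['image'].is_pinned())  # True: copied into pinned memory
        print(pinned['name'])               # strings pass through untouched
    else:
        print('CUDA not available; tensors cannot be pinned on this machine')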
diff --git a/torch/utils/data/_utils/signal_handling.py b/torch/utils/data/_utils/signal_handling.py
deleted file mode 100644 (file)
index 74bb2ba..0000000
+++ /dev/null
@@ -1,70 +0,0 @@
-r""""Signal handling for multiprocessing data loading.
-
-NOTE [ Signal handling in multiprocessing data loading ]
-
-In cases like DataLoader, if a worker process dies due to bus error/segfault
-or just hang, the main process will hang waiting for data. This is difficult
-to avoid on PyTorch side as it can be caused by limited shm, or other
-libraries users call in the workers. In this file and `DataLoader.cpp`, we make
-our best effort to provide some error message to users when such unfortunate
-events happen.
-
-When a _DataLoaderIter starts worker processes, their pids are registered in a
-defined in `DataLoader.cpp`: id(_DataLoaderIter) => Collection[ Worker pids ]
-via `_update_worker_pids`.
-
-When an error happens in a worker process, the main process received a SIGCHLD,
-and Python will eventually call the handler registered below
-(in `_set_SIGCHLD_handler`). In the handler, the `_error_if_any_worker_fails`
-call checks all registered worker pids and raise proper error message to
-prevent main process from hanging waiting for data from worker.
-
-Additionally, at the beginning of each worker's `_utils.worker._worker_loop`,
-`_set_worker_signal_handlers` is called to register critical signal handlers
-(e.g., for SIGSEGV, SIGBUS, SIGFPE, SIGTERM) in C, which just prints an error
-message to stderr before triggering the default handler. So a message will also
-be printed from the worker process when it is killed by such signals.
-
-See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the reasoning of
-this signal handling design and other mechanism we implement to make our
-multiprocessing data loading robust to errors.
-"""
-
-import signal
-import threading
-import torch
-from torch._C import _update_worker_pids, _remove_worker_pids, \
-    _error_if_any_worker_fails, _set_worker_signal_handlers
-from . import IS_WINDOWS
-
-
-_SIGCHLD_handler_set = False
-r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
-handler needs to be set for all DataLoaders in a process."""
-
-
-def _set_SIGCHLD_handler():
-    # Windows doesn't support SIGCHLD handler
-    if IS_WINDOWS:
-        return
-    # can't set signal in child threads
-    if not isinstance(threading.current_thread(), threading._MainThread):
-        return
-    global _SIGCHLD_handler_set
-    if _SIGCHLD_handler_set:
-        return
-    previous_handler = signal.getsignal(signal.SIGCHLD)
-    if not callable(previous_handler):
-        # This doesn't catch default handler, but SIGCHLD default handler is a
-        # no-op.
-        previous_handler = None
-
-    def handler(signum, frame):
-        # This following call uses `waitid` with WNOHANG from C side. Therefore,
-        # Python can still get and update the process status successfully.
-        _error_if_any_worker_fails()
-        if previous_handler is not None:
-            previous_handler(signum, frame)
-
-    signal.signal(signal.SIGCHLD, handler)
-    _SIGCHLD_handler_set = True
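
The handler above chains into whatever SIGCHLD handler was installed before it, so third-party handlers keep working. A minimal, generic sketch of that chaining pattern (plain Python, no torch internals; POSIX only, since Windows has no SIGCHLD):

    import os
    import signal

    previous_handler = signal.getsignal(signal.SIGCHLD)
    if not callable(previous_handler):
        # SIG_DFL / SIG_IGN are not callable; the default SIGCHLD action is a no-op
        previous_handler = None

    def handler(signum, frame):
        print('a child of pid %d changed state' % os.getpid())
        if previous_handler is not None:
            previous_handler(signum, frame)  # keep any pre-existing handler working

    signal.signal(signal.SIGCHLD, handler)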
diff --git a/torch/utils/data/_utils/worker.py b/torch/utils/data/_utils/worker.py
deleted file mode 100644 (file)
index 2cddf3a..0000000
+++ /dev/null
@@ -1,108 +0,0 @@
-r""""Contains definitions of the methods used by the _DataLoaderIter workers.
-
-These **needs** to be in global scope since Py2 doesn't support serializing
-static methods.
-"""
-
-import torch
-import random
-import sys
-import os
-from torch._six import queue
-from . import collate, signal_handling, MP_STATUS_CHECK_INTERVAL, \
-    ExceptionWrapper, IS_WINDOWS
-
-if IS_WINDOWS:
-    import ctypes
-    from ctypes.wintypes import DWORD, BOOL, HANDLE
-
-    # On Windows, the parent ID of the worker process remains unchanged when the manager process
-    # is gone, and the only way to check it through OS is to let the worker have a process handle
-    # of the manager and ask if the process status has changed.
-    class ManagerWatchdog(object):
-        def __init__(self):
-            self.manager_pid = os.getppid()
-
-            self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
-            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
-            self.kernel32.OpenProcess.restype = HANDLE
-            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
-            self.kernel32.WaitForSingleObject.restype = DWORD
-
-            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
-            SYNCHRONIZE = 0x00100000
-            self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)
-
-            if not self.manager_handle:
-                raise ctypes.WinError(ctypes.get_last_error())
-
-            self.manager_dead = False
-
-        def is_alive(self):
-            if not self.manager_dead:
-                # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
-                self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
-            return not self.manager_dead
-else:
-    class ManagerWatchdog(object):
-        def __init__(self):
-            self.manager_pid = os.getppid()
-            self.manager_dead = False
-
-        def is_alive(self):
-            if not self.manager_dead:
-                self.manager_dead = os.getppid() != self.manager_pid
-            return not self.manager_dead
-
-
-def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id):
-    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
-    # logic of this function.
-
-    try:
-        collate._use_shared_memory = True
-
-        # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
-        # module's handlers are executed after Python returns from C low-level
-        # handlers, likely when the same fatal signal happened again already.
-        # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
-        signal_handling._set_worker_signal_handlers()
-
-        torch.set_num_threads(1)
-        random.seed(seed)
-        torch.manual_seed(seed)
-
-        data_queue.cancel_join_thread()
-
-        if init_fn is not None:
-            init_fn(worker_id)
-
-        watchdog = ManagerWatchdog()
-
-        while watchdog.is_alive():
-            try:
-                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
-            except queue.Empty:
-                continue
-            if r is None:
-                # Received the final signal
-                assert done_event.is_set()
-                return
-            elif done_event.is_set():
-                # Done event is set. But I haven't received the final signal
-                # (None) yet. I will keep continuing until get it, and skip the
-                # processing steps.
-                continue
-            idx, batch_indices = r
-            try:
-                samples = collate_fn([dataset[i] for i in batch_indices])
-            except Exception:
-                # It is important that we don't store exc_info in a variable,
-                # see NOTE [ Python Traceback Reference Cycle Problem ]
-                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
-            else:
-                data_queue.put((idx, samples))
-                del samples
-    except KeyboardInterrupt:
-        # Main process will raise KeyboardInterrupt anyways.
-        pass
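
On non-Windows platforms the watchdog above simply compares os.getppid() against the pid recorded at startup; once the parent dies, the worker loop exits instead of blocking forever on the queue. A minimal sketch of that polling idea (SimpleWatchdog is a hypothetical stand-in, not the real class):

    import os
    import time

    class SimpleWatchdog(object):
        """Hypothetical stand-in for the non-Windows ManagerWatchdog above."""
        def __init__(self):
            self.manager_pid = os.getppid()  # pid of the process that spawned us
            self.manager_dead = False

        def is_alive(self):
            if not self.manager_dead:
                # If the parent exits, getppid() starts returning a different pid
                self.manager_dead = os.getppid() != self.manager_pid
            return not self.manager_dead

    watchdog = SimpleWatchdog()
    while watchdog.is_alive():
        # placeholder for index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        time.sleep(0.1)
        break  # single iteration, just for the sketch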
index 6d5cb81..c1ee0eb 100644 (file)
-r"""Definition of the DataLoader and it's iterator _DataLoaderIter classes.
-
-To support these two classes, in `./_utils` we define many utility methods and
-functions to be run in multiprocessing. E.g., the data loading worker loop is
-in `./_utils/worker.py`.
-"""
-
+import random
 import torch
 import torch.multiprocessing as multiprocessing
+from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
+    _remove_worker_pids, _error_if_any_worker_fails
 from . import SequentialSampler, RandomSampler, BatchSampler
-from . import _utils
+import signal
+import functools
+from torch._six import container_abcs
+import re
+import sys
 import threading
-from torch._six import queue
-
-
-# This function used to be defined in this file. However, it was moved to
-# _utils/collate.py. Although it is rather hard to access this from user land
-# (one has to explicitly directly `import torch.utils.data.dataloader`), there
-# probably is user code out there using it. This aliasing maintains BC in this
-# aspect.
-default_collate = _utils.collate.default_collate
-
-
-class DataLoader(object):
-    r"""
-    Data loader. Combines a dataset and a sampler, and provides
-    single- or multi-process iterators over the dataset.
-
-    Arguments:
-        dataset (Dataset): dataset from which to load the data.
-        batch_size (int, optional): how many samples per batch to load
-            (default: ``1``).
-        shuffle (bool, optional): set to ``True`` to have the data reshuffled
-            at every epoch (default: ``False``).
-        sampler (Sampler, optional): defines the strategy to draw samples from
-            the dataset. If specified, ``shuffle`` must be False.
-        batch_sampler (Sampler, optional): like sampler, but returns a batch of
-            indices at a time. Mutually exclusive with :attr:`batch_size`,
-            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
-        num_workers (int, optional): how many subprocesses to use for data
-            loading. 0 means that the data will be loaded in the main process.
-            (default: ``0``)
-        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
-        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
-            into CUDA pinned memory before returning them.
-        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
-            if the dataset size is not divisible by the batch size. If ``False`` and
-            the size of dataset is not divisible by the batch size, then the last batch
-            will be smaller. (default: ``False``)
-        timeout (numeric, optional): if positive, the timeout value for collecting a batch
-            from workers. Should always be non-negative. (default: ``0``)
-        worker_init_fn (callable, optional): If not ``None``, this will be called on each
-            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
-            input, after seeding and before data loading. (default: ``None``)
-
-    .. note:: By default, each worker will have its PyTorch seed set to
-              ``base_seed + worker_id``, where ``base_seed`` is a long generated
-              by main process using its RNG. However, seeds for other libraies
-              may be duplicated upon initializing workers (w.g., NumPy), causing
-              each worker to return identical random numbers. (See
-              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
-              use :func:`torch.initial_seed()` to access the PyTorch seed for
-              each worker in :attr:`worker_init_fn`, and use it to set other
-              seeds before data loading.
-
-    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
-                 unpicklable object, e.g., a lambda function.
-    """
-
-    __initialized = False
+import traceback
+import os
+import time
+import atexit
+from torch._six import string_classes, int_classes, FileNotFoundError
+
+IS_WINDOWS = sys.platform == "win32"
+if IS_WINDOWS:
+    import ctypes
+    from ctypes.wintypes import DWORD, BOOL, HANDLE
+
+if sys.version_info[0] == 2:
+    import Queue as queue
+else:
+    import queue
+
+
+# NOTE [ Python Traceback Reference Cycle Problem ]
+#
+# When using sys.exc_info(), it is important to **not** store the exc_info[2],
+# which is the traceback, because otherwise you will run into the traceback
+# reference cycle problem, i.e., the traceback holding reference to the frame,
+# and the frame (which holds reference to all the object in its temporary scope)
+# holding reference the traceback.
+
+
+class ExceptionWrapper(object):
+    r"""Wraps an exception plus traceback to communicate across threads"""
+    def __init__(self, exc_info):
+        # It is important that we don't store exc_info, see
+        # NOTE [ Python Traceback Reference Cycle Problem ]
+        self.exc_type = exc_info[0]
+        self.exc_msg = "".join(traceback.format_exception(*exc_info))
+
+
+_use_shared_memory = False
+r"""Whether to use shared memory in default_collate"""
+
+MP_STATUS_CHECK_INTERVAL = 5.0
+r"""Interval (in seconds) to check status of processes to avoid hanging in
+    multiprocessing data loading. This is mainly used in getting data from
+    another process, in which case we need to periodically check whether the
+    sender is alive to prevent hanging."""
+
+if IS_WINDOWS:
+    # On Windows, the parent ID of the worker process remains unchanged when the manager process
+    # is gone, and the only way to check it through OS is to let the worker have a process handle
+    # of the manager and ask if the process status has changed.
+    class ManagerWatchdog(object):
+        def __init__(self):
+            self.manager_pid = os.getppid()
+
+            self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
+            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
+            self.kernel32.OpenProcess.restype = HANDLE
+            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
+            self.kernel32.WaitForSingleObject.restype = DWORD
+
+            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
+            SYNCHRONIZE = 0x00100000
+            self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)
+
+            if not self.manager_handle:
+                raise ctypes.WinError(ctypes.get_last_error())
+
+            self.manager_dead = False
+
+        def is_alive(self):
+            if not self.manager_dead:
+                # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
+                self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
+            return not self.manager_dead
+else:
+    class ManagerWatchdog(object):
+        def __init__(self):
+            self.manager_pid = os.getppid()
+            self.manager_dead = False
+
+        def is_alive(self):
+            if not self.manager_dead:
+                self.manager_dead = os.getppid() != self.manager_pid
+            return not self.manager_dead
+
+
+def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id):
+    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
+    # logic of this function.
+
+    try:
+        global _use_shared_memory
+        _use_shared_memory = True
+
+        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
+        # module's handlers are executed after Python returns from C low-level
+        # handlers, likely when the same fatal signal happened again already.
+        # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
+        _set_worker_signal_handlers()
+
+        torch.set_num_threads(1)
+        random.seed(seed)
+        torch.manual_seed(seed)
+
+        data_queue.cancel_join_thread()
+
+        if init_fn is not None:
+            init_fn(worker_id)
+
+        watchdog = ManagerWatchdog()
+
+        while watchdog.is_alive():
+            try:
+                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+            except queue.Empty:
+                continue
+            if r is None:
+                # Received the final signal
+                assert done_event.is_set()
+                return
+            elif done_event.is_set():
+                # Done event is set. But I haven't received the final signal
+                # (None) yet. I will keep continuing until get it, and skip the
+                # processing steps.
+                continue
+            idx, batch_indices = r
+            try:
+                samples = collate_fn([dataset[i] for i in batch_indices])
+            except Exception:
+                # It is important that we don't store exc_info in a variable,
+                # see NOTE [ Python Traceback Reference Cycle Problem ]
+                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
+            else:
+                data_queue.put((idx, samples))
+                del samples
+    except KeyboardInterrupt:
+        # Main process will raise KeyboardInterrupt anyways.
+        pass
+
+
+def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
+    torch.cuda.set_device(device_id)
+
+    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
+    # logic of this function.
+    while True:
+        try:
+            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+        except queue.Empty:
+            continue
+        except Exception:
+            if done_event.is_set():
+                # Weird things can happen when shutting down, e.g., fd being
+                # closed when tensors are shared via fds.
+                break
+            raise
+        if r is None:
+            assert done_event.is_set()
+            return
+        elif done_event.is_set():
+            # Haven't seen the final signal yet. Keep getting until None.
+            continue
+        elif isinstance(r[1], ExceptionWrapper):
+            out_queue.put(r)
+        else:
+            idx, batch = r
+            try:
+                batch = pin_memory_batch(batch)
+            except Exception:
+                out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
+            else:
+                out_queue.put((idx, batch))
+
+numpy_type_map = {
+    'float64': torch.DoubleTensor,
+    'float32': torch.FloatTensor,
+    'float16': torch.HalfTensor,
+    'int64': torch.LongTensor,
+    'int32': torch.IntTensor,
+    'int16': torch.ShortTensor,
+    'int8': torch.CharTensor,
+    'uint8': torch.ByteTensor,
+}
+
+
+def default_collate(batch):
+    r"""Puts each data field into a tensor with outer dimension batch size"""
+
+    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
+    elem_type = type(batch[0])
+    if isinstance(batch[0], torch.Tensor):
+        out = None
+        if _use_shared_memory:
+            # If we're in a background process, concatenate directly into a
+            # shared memory tensor to avoid an extra copy
+            numel = sum([x.numel() for x in batch])
+            storage = batch[0].storage()._new_shared(numel)
+            out = batch[0].new(storage)
+        return torch.stack(batch, 0, out=out)
+    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+            and elem_type.__name__ != 'string_':
+        elem = batch[0]
+        if elem_type.__name__ == 'ndarray':
+            # array of string classes and object
+            if re.search('[SaUO]', elem.dtype.str) is not None:
+                raise TypeError(error_msg.format(elem.dtype))
+
+            return torch.stack([torch.from_numpy(b) for b in batch], 0)
+        if elem.shape == ():  # scalars
+            py_type = float if elem.dtype.name.startswith('float') else int
+            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
+    elif isinstance(batch[0], int_classes):
+        return torch.LongTensor(batch)
+    elif isinstance(batch[0], float):
+        return torch.DoubleTensor(batch)
+    elif isinstance(batch[0], string_classes):
+        return batch
+    elif isinstance(batch[0], container_abcs.Mapping):
+        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
+    elif isinstance(batch[0], container_abcs.Sequence):
+        transposed = zip(*batch)
+        return [default_collate(samples) for samples in transposed]
 
-    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
-                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
-                 pin_memory=False, drop_last=False, timeout=0,
-                 worker_init_fn=None):
-        self.dataset = dataset
-        self.batch_size = batch_size
-        self.num_workers = num_workers
-        self.collate_fn = collate_fn
-        self.pin_memory = pin_memory
-        self.drop_last = drop_last
-        self.timeout = timeout
-        self.worker_init_fn = worker_init_fn
+    raise TypeError((error_msg.format(type(batch[0]))))
 
-        if timeout < 0:
-            raise ValueError('timeout option should be non-negative')
 
-        if batch_sampler is not None:
-            if batch_size > 1 or shuffle or sampler is not None or drop_last:
-                raise ValueError('batch_sampler option is mutually exclusive '
-                                 'with batch_size, shuffle, sampler, and '
-                                 'drop_last')
-            self.batch_size = None
-            self.drop_last = None
+def pin_memory_batch(batch):
+    if isinstance(batch, torch.Tensor):
+        return batch.pin_memory()
+    elif isinstance(batch, string_classes):
+        return batch
+    elif isinstance(batch, container_abcs.Mapping):
+        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
+    elif isinstance(batch, container_abcs.Sequence):
+        return [pin_memory_batch(sample) for sample in batch]
+    else:
+        return batch
 
-        if sampler is not None and shuffle:
-            raise ValueError('sampler option is mutually exclusive with '
-                             'shuffle')
 
-        if self.num_workers < 0:
-            raise ValueError('num_workers option cannot be negative; '
-                             'use num_workers=0 to disable multiprocessing.')
+_SIGCHLD_handler_set = False
+r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
+handler needs to be set for all DataLoaders in a process."""
+
+
+def _set_SIGCHLD_handler():
+    # Windows doesn't support SIGCHLD handler
+    if sys.platform == 'win32':
+        return
+    # can't set signal in child threads
+    if not isinstance(threading.current_thread(), threading._MainThread):
+        return
+    global _SIGCHLD_handler_set
+    if _SIGCHLD_handler_set:
+        return
+    previous_handler = signal.getsignal(signal.SIGCHLD)
+    if not callable(previous_handler):
+        # This doesn't catch default handler, but SIGCHLD default handler is a
+        # no-op.
+        previous_handler = None
+
+    def handler(signum, frame):
+        # This following call uses `waitid` with WNOHANG from C side. Therefore,
+        # Python can still get and update the process status successfully.
+        _error_if_any_worker_fails()
+        if previous_handler is not None:
+            previous_handler(signum, frame)
+
+    signal.signal(signal.SIGCHLD, handler)
+    _SIGCHLD_handler_set = True
+
+
+_python_exit_status = False
+r"""Whether Python is shutting down. This flag is guaranteed to be set before
+the Python core library resources are freed, but Python may already be exiting
+for some time when this is set.
+
+Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar
+hook in Python 3.7 multiprocessing library:
+https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327
+"""
 
-        if batch_sampler is None:
-            if sampler is None:
-                if shuffle:
-                    sampler = RandomSampler(dataset)
-                else:
-                    sampler = SequentialSampler(dataset)
-            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
 
-        self.sampler = sampler
-        self.batch_sampler = batch_sampler
-        self.__initialized = True
+def _set_python_exit_flag():
+    global _python_exit_status
+    _python_exit_status = True
 
-    def __setattr__(self, attr, val):
-        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
-            raise ValueError('{} attribute should not be set after {} is '
-                             'initialized'.format(attr, self.__class__.__name__))
-
-        super(DataLoader, self).__setattr__(attr, val)
-
-    def __iter__(self):
-        return _DataLoaderIter(self)
-
-    def __len__(self):
-        return len(self.batch_sampler)
+atexit.register(_set_python_exit_flag)
 
 
 class _DataLoaderIter(object):
@@ -184,13 +354,12 @@ class _DataLoaderIter(object):
     #      Therefore, in this case, we actually need to prevent `__del__` from
     #      being executed, and rely on the automatic termination of daemonic
     #      children. Thus, we register an `atexit` hook that sets a global flag
-    #      `_utils.python_exit_status`. Since `atexit` hooks are executed in the
-    #      reverse order of registration, we are guaranteed that this flag is
-    #      set before library resources we use are freed. (Hooks freeing those
-    #      resources are registered at importing the Python core libraries at
-    #      the top of this file.) So in `__del__`, we check if
-    #      `_utils.python_exit_status` is set or `None` (freed), and perform
-    #      no-op if so.
+    #      `_python_exit_status`. Since `atexit` hooks are executed in reverse
+    #      order of registration, we are guaranteed that this flag is set before
+    #      library resources we use are freed. (Hooks freeing those resources
+    #      are registered at importing the Python core libraries at the top of
+    #      this file.) So in `__del__`, we check if `_python_exit_status` is set
+    #      or `None` (freed), and perform no-op if so.
     #
     #      Another problem with `__del__` is also related to the library cleanup
     #      calls. When a process ends, it shuts the all its daemonic children
@@ -250,8 +419,8 @@ class _DataLoaderIter(object):
     #           For `.get()` calls where the sender(s) is not the workers, we
     #           guard them with timeouts, and check the status of the sender
     #           when timeout happens:
-    #             + in the workers, the `_utils.worker.ManagerWatchdog` class
-    #               checks the status of the main process.
+    #             + in the workers, the `ManagerWatchdog` class checks the main
+    #               process status.
     #             + if `pin_memory=True`, when getting from `pin_memory_thread`,
     #               check `pin_memory_thread` status periodically until `.get()`
     #               returns or see that `pin_memory_thread` died.
@@ -376,7 +545,7 @@ class _DataLoaderIter(object):
                 index_queue = multiprocessing.Queue()
                 index_queue.cancel_join_thread()
                 w = multiprocessing.Process(
-                    target=_utils.worker._worker_loop,
+                    target=_worker_loop,
                     args=(self.dataset, index_queue,
                           self.worker_result_queue, self.done_event,
                           self.collate_fn, base_seed + i,
@@ -395,7 +564,7 @@ class _DataLoaderIter(object):
             if self.pin_memory:
                 self.data_queue = queue.Queue()
                 pin_memory_thread = threading.Thread(
-                    target=_utils.pin_memory._pin_memory_loop,
+                    target=_pin_memory_loop,
                     args=(self.worker_result_queue, self.data_queue,
                           torch.cuda.current_device(), self.done_event))
                 pin_memory_thread.daemon = True
@@ -406,8 +575,8 @@ class _DataLoaderIter(object):
             else:
                 self.data_queue = self.worker_result_queue
 
-            _utils.signal_handling._update_worker_pids(id(self), tuple(w.pid for w in self.workers))
-            _utils.signal_handling._set_SIGCHLD_handler()
+            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
+            _set_SIGCHLD_handler()
             self.worker_pids_set = True
 
             # prime the prefetch loop
@@ -429,7 +598,7 @@ class _DataLoaderIter(object):
         elif self.pin_memory:
             while self.pin_memory_thread.is_alive():
                 try:
-                    return self.data_queue.get(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
+                    return self.data_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
                 except queue.Empty:
                     continue
             else:
@@ -445,7 +614,7 @@ class _DataLoaderIter(object):
             indices = next(self.sample_iter)  # may raise StopIteration
             batch = self.collate_fn([self.dataset[i] for i in indices])
             if self.pin_memory:
-                batch = _utils.pin_memory.pin_memory_batch(batch)
+                batch = pin_memory_batch(batch)
             return batch
 
         # check if the next sample has already been generated
@@ -485,7 +654,7 @@ class _DataLoaderIter(object):
     def _process_next_batch(self, batch):
         self.rcvd_idx += 1
         self._put_indices()
-        if isinstance(batch, _utils.ExceptionWrapper):
+        if isinstance(batch, ExceptionWrapper):
             raise batch.exc_type(batch.exc_msg)
         return batch
 
@@ -500,8 +669,7 @@ class _DataLoaderIter(object):
     def _shutdown_workers(self):
         # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
         # the logic of this function.
-        python_exit_status = _utils.python_exit_status
-        if python_exit_status is True or python_exit_status is None:
+        if _python_exit_status is True or _python_exit_status is None:
             # See (2) of the note. If Python is shutting down, do no-op.
             return
         # Normal exit when last reference is gone / iterator is depleted.
@@ -511,7 +679,7 @@ class _DataLoaderIter(object):
             # Removes pids from the C side data structure first so worker
             # termination afterwards won't trigger false positive error report.
             if self.worker_pids_set:
-                _utils.signal_handling._remove_worker_pids(id(self))
+                _remove_worker_pids(id(self))
                 self.worker_pids_set = False
 
             self.done_event.set()
@@ -547,3 +715,108 @@ class _DataLoaderIter(object):
     def __del__(self):
         if self.num_workers > 0:
             self._shutdown_workers()
+
+
+class DataLoader(object):
+    r"""
+    Data loader. Combines a dataset and a sampler, and provides
+    single- or multi-process iterators over the dataset.
+
+    Arguments:
+        dataset (Dataset): dataset from which to load the data.
+        batch_size (int, optional): how many samples per batch to load
+            (default: ``1``).
+        shuffle (bool, optional): set to ``True`` to have the data reshuffled
+            at every epoch (default: ``False``).
+        sampler (Sampler, optional): defines the strategy to draw samples from
+            the dataset. If specified, ``shuffle`` must be False.
+        batch_sampler (Sampler, optional): like sampler, but returns a batch of
+            indices at a time. Mutually exclusive with :attr:`batch_size`,
+            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
+        num_workers (int, optional): how many subprocesses to use for data
+            loading. 0 means that the data will be loaded in the main process.
+            (default: ``0``)
+        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
+        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
+            into CUDA pinned memory before returning them.
+        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
+            if the dataset size is not divisible by the batch size. If ``False`` and
+            the size of dataset is not divisible by the batch size, then the last batch
+            will be smaller. (default: ``False``)
+        timeout (numeric, optional): if positive, the timeout value for collecting a batch
+            from workers. Should always be non-negative. (default: ``0``)
+        worker_init_fn (callable, optional): If not ``None``, this will be called on each
+            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
+            input, after seeding and before data loading. (default: ``None``)
+
+    .. note:: By default, each worker will have its PyTorch seed set to
+              ``base_seed + worker_id``, where ``base_seed`` is a long generated
+              by main process using its RNG. However, seeds for other libraries
+              may be duplicated upon initializing workers (e.g., NumPy), causing
+              each worker to return identical random numbers. (See
+              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
+              use :func:`torch.initial_seed()` to access the PyTorch seed for
+              each worker in :attr:`worker_init_fn`, and use it to set other
+              seeds before data loading.
+
+    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
+                 unpicklable object, e.g., a lambda function.
+    """
+
+    __initialized = False
+
+    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
+                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
+                 timeout=0, worker_init_fn=None):
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.collate_fn = collate_fn
+        self.pin_memory = pin_memory
+        self.drop_last = drop_last
+        self.timeout = timeout
+        self.worker_init_fn = worker_init_fn
+
+        if timeout < 0:
+            raise ValueError('timeout option should be non-negative')
+
+        if batch_sampler is not None:
+            if batch_size > 1 or shuffle or sampler is not None or drop_last:
+                raise ValueError('batch_sampler option is mutually exclusive '
+                                 'with batch_size, shuffle, sampler, and '
+                                 'drop_last')
+            self.batch_size = None
+            self.drop_last = None
+
+        if sampler is not None and shuffle:
+            raise ValueError('sampler option is mutually exclusive with '
+                             'shuffle')
+
+        if self.num_workers < 0:
+            raise ValueError('num_workers option cannot be negative; '
+                             'use num_workers=0 to disable multiprocessing.')
+
+        if batch_sampler is None:
+            if sampler is None:
+                if shuffle:
+                    sampler = RandomSampler(dataset)
+                else:
+                    sampler = SequentialSampler(dataset)
+            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
+
+        self.sampler = sampler
+        self.batch_sampler = batch_sampler
+        self.__initialized = True
+
+    def __setattr__(self, attr, val):
+        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
+            raise ValueError('{} attribute should not be set after {} is '
+                             'initialized'.format(attr, self.__class__.__name__))
+
+        super(DataLoader, self).__setattr__(attr, val)
+
+    def __iter__(self):
+        return _DataLoaderIter(self)
+
+    def __len__(self):
+        return len(self.batch_sampler)
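
The docstring note above says only the PyTorch seed is derived per worker; other libraries have to be reseeded in worker_init_fn. A hedged end-to-end sketch (the seed_numpy helper, dataset, and sizes are illustrative):

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def seed_numpy(worker_id):
        # torch.initial_seed() in a worker is base_seed + worker_id; reuse it for NumPy
        np.random.seed(torch.initial_seed() % 2 ** 32)

    if __name__ == '__main__':  # guard needed on spawn-based platforms
        dataset = TensorDataset(torch.arange(16).float().unsqueeze(1))
        loader = DataLoader(dataset, batch_size=4, shuffle=True,
                            num_workers=2, worker_init_fn=seed_numpy)
        for (batch,) in loader:
            print(batch.shape)  # torch.Size([4, 1])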