From: SsnL Date: Wed, 19 Dec 2018 20:26:44 +0000 (-0800) Subject: Refactor dataloader.py (#15331) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2166 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9217bde807115bf8e161dc54faeed0851a247780;p=platform%2Fupstream%2Fpytorch.git Refactor dataloader.py (#15331) Summary: Same as #14668, and was approved there. ailzhang , please apply this patch to Horizon's `data_streamer.py`: https://gist.github.com/SsnL/020fdb3d6b7016d81b6ba1d04cc41459 Thank you! Below is the original description at #14668: As I am working on tasks in https://github.com/pytorch/pytorch/issues/13023, I realized how unreadable the code is because all functions to be run in multiprocessing must be at top global level. Adding more functionalities to `dataloader.py` will only make things worse. So in this PR, I refactor `dataloader.py` and move much of it into `data._utils`. E.g., the `_worker_loop` and related methods are now in `data._utils.worker`, signal handling code in `data._utils.signal_handling`, collating code in `data._utils.collate`, etc. This split, IMHO, makes code much clearer. I will base my future changes to DataLoader on top of this. No functionality is changed, except that I added `torch._six.queue`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15331 Reviewed By: yf225 Differential Revision: D13503120 Pulled By: ailzhang fbshipit-source-id: 94df16b4d80ad1102c437cde0d5a2e62cffe1f8e --- diff --git a/test/test_dataloader.py b/test/test_dataloader.py index e1ca129..9db11b0 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -12,9 +12,9 @@ import unittest import subprocess import itertools from torch import multiprocessing as mp -from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset +from torch.utils.data import _utils, Dataset, TensorDataset, DataLoader, ConcatDataset +from torch.utils.data._utils import ExceptionWrapper, MP_STATUS_CHECK_INTERVAL from torch.utils.data.dataset import random_split -from torch.utils.data.dataloader import default_collate, ExceptionWrapper, MP_STATUS_CHECK_INTERVAL from common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, IS_PPC, NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests) @@ -788,16 +788,16 @@ class TestDataLoader(TestCase): # Should be a no-op arr = np.array(['a', 'b', 'c']) - default_collate(arr) + _utils.collate.default_collate(arr) arr = np.array([[['a', 'b', 'c']]]) - self.assertRaises(TypeError, lambda: default_collate(arr)) + self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr)) arr = np.array([object(), object(), object()]) - self.assertRaises(TypeError, lambda: default_collate(arr)) + self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr)) arr = np.array([[[object(), object(), object()]]]) - self.assertRaises(TypeError, lambda: default_collate(arr)) + self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr)) class StringDataset(Dataset): diff --git a/torch/_six.py b/torch/_six.py index 924e641..3a4c2ad 100644 --- a/torch/_six.py +++ b/torch/_six.py @@ -51,6 +51,12 @@ else: FileNotFoundError = FileNotFoundError +if PY2: + import Queue as queue +else: + import queue + + def with_metaclass(meta, *bases): """Create a base class with a metaclass.""" # This requires a bit of explanation: the basic idea is to make a dummy diff --git a/torch/csrc/DataLoader.cpp b/torch/csrc/DataLoader.cpp index 0b23b4d..beb7bf9 100644 --- a/torch/csrc/DataLoader.cpp +++ b/torch/csrc/DataLoader.cpp @@ 
-1,11 +1,10 @@ #include -// In cases like DataLoader, if a worker process dies due to bus error/segfault -// or just hang, the main process will hang waiting for data. This is difficult -// to avoid on PyTorch side as it can be caused by limited shm, or other -// libraries users call in the workers. The following methods is an effort to do -// our best to provide some error message to users when such unfortunate events -// happen. +// Together with `torch/utils/data/_utils/signal_handling.py`, the following +// is an effort to do our best to provide some error message to users when a +// worker dies due to error / critical signals. +// +// See NOTE [ Signal handling in multiprocessing data loading ] for more details. // TODO: The following don't work on Windows. Specifically, sigaction, waitid // calls, and SIGCHLD handler. Currently, dummy implementations are provided @@ -128,7 +127,7 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) { // workers, and trigger this again. pid_set->clear(); throw std::runtime_error(oss.str()); - } else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal + } else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal std::ostringstream oss; oss << "DataLoader worker (pid " << worker_pid << ") is killed " << "by signal: " << strsignal(infop.si_status) << ". "; @@ -145,18 +144,18 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) { // We don't want to exit on any SIGCHLD from any child. child_pids is a tuple // of pids we are interested in. -static PyObject *THPModule_updateWorkerPIDs(PyObject *module, PyObject *args) { +static PyObject *THPModule_setWorkerPIDs(PyObject *module, PyObject *args) { HANDLE_TH_ERRORS if (PyTuple_GET_SIZE(args) != 2) { - throw TypeError("_update_worker_pids expectes exactly 2 arguments."); + throw TypeError("_set_worker_pids expects exactly 2 arguments."); } int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0)); if (worker_pids.find(key) != worker_pids.end()) { - throw ValueError("_update_worker_pids should be called only once for each _DataLoaderIter."); + throw ValueError("_set_worker_pids should be called only once for each _DataLoaderIter."); } PyObject *child_pids = PyTuple_GET_ITEM(args, 1); if (!PyTuple_Check(child_pids)) { - throw TypeError("_update_worker_pids expects a tuple for child_pids, but got %s.", + throw TypeError("_set_worker_pids expects a tuple for child_pids, but got %s.", Py_TYPE(child_pids)->tp_name); } @@ -164,7 +163,7 @@ static PyObject *THPModule_updateWorkerPIDs(PyObject *module, PyObject *args) { auto size = PyTuple_GET_SIZE(child_pids); for (int idx = 0; idx < size; idx++) { PyObject* obj = PyTuple_GET_ITEM(child_pids, idx); - pids_set.insert((pid_t) THPUtils_unpackLong(obj)); + pids_set.insert(static_cast(THPUtils_unpackLong(obj))); } worker_pids[key] = pids_set; @@ -196,7 +195,7 @@ static PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *_ Py_RETURN_NONE; } -static PyObject *THPModule_updateWorkerPIDs(PyObject *module, PyObject *_ignored) { +static PyObject *THPModule_setWorkerPIDs(PyObject *module, PyObject *_ignored) { Py_RETURN_NONE; } @@ -212,7 +211,7 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module, PyObject *_ig PyMethodDef DataLoaderMethods[] = { {"_set_worker_signal_handlers", (PyCFunction)THPModule_setWorkerSignalHandlers, METH_NOARGS, nullptr}, - {"_update_worker_pids", (PyCFunction)THPModule_updateWorkerPIDs, METH_VARARGS, nullptr}, + 
{"_set_worker_pids", (PyCFunction)THPModule_setWorkerPIDs, METH_VARARGS, nullptr}, {"_remove_worker_pids", (PyCFunction)THPModule_removeWorkerPIDs, METH_O, nullptr}, {"_error_if_any_worker_fails", (PyCFunction)THPModule_errorIfAnyWorkerFails, METH_NOARGS, nullptr}, {nullptr, nullptr, 0, nullptr} diff --git a/torch/utils/data/__init__.py b/torch/utils/data/__init__.py index 05c94d1..ee58707 100644 --- a/torch/utils/data/__init__.py +++ b/torch/utils/data/__init__.py @@ -1,4 +1,3 @@ - from .sampler import Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler, BatchSampler from .distributed import DistributedSampler from .dataset import Dataset, TensorDataset, ConcatDataset, Subset, random_split diff --git a/torch/utils/data/_utils/__init__.py b/torch/utils/data/_utils/__init__.py new file mode 100644 index 0000000..05b2b65 --- /dev/null +++ b/torch/utils/data/_utils/__init__.py @@ -0,0 +1,61 @@ +r"""Utility classes & functions for data loading. Code in this folder is mostly +used by ../dataloder.py. + +A lot of multiprocessing is used in data loading, which only supports running +functions defined in global environment (py2 can't serialize static methods). +Therefore, for code tidiness we put these functions into different files in this +folder. +""" + +import sys +import traceback +import atexit + + +IS_WINDOWS = sys.platform == "win32" + + +# NOTE [ Python Traceback Reference Cycle Problem ] +# +# When using sys.exc_info(), it is important to **not** store the exc_info[2], +# which is the traceback, because otherwise you will run into the traceback +# reference cycle problem, i.e., the traceback holding reference to the frame, +# and the frame (which holds reference to all the object in its temporary scope) +# holding reference the traceback. + + +class ExceptionWrapper(object): + r"""Wraps an exception plus traceback to communicate across threads""" + def __init__(self, exc_info): + # It is important that we don't store exc_info, see + # NOTE [ Python Traceback Reference Cycle Problem ] + self.exc_type = exc_info[0] + self.exc_msg = "".join(traceback.format_exception(*exc_info)) + + +MP_STATUS_CHECK_INTERVAL = 5.0 +r"""Interval (in seconds) to check status of processes to avoid hanging in + multiprocessing data loading. This is mainly used in getting data from + another process, in which case we need to periodically check whether the + sender is alive to prevent hanging.""" + + +python_exit_status = False +r"""Whether Python is shutting down. This flag is guaranteed to be set before +the Python core library resources are freed, but Python may already be exiting +for some time when this is set. + +Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar +hook in Python 3.7 multiprocessing library: +https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327 +""" + + +def _set_python_exit_flag(): + global python_exit_status + python_exit_status = True + +atexit.register(_set_python_exit_flag) + + +from . import worker, signal_handling, pin_memory, collate diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py new file mode 100644 index 0000000..e0823b2 --- /dev/null +++ b/torch/utils/data/_utils/collate.py @@ -0,0 +1,68 @@ +r""""Contains definitions of the methods used by the _DataLoaderIter workers to +collate samples fetched from dataset into Tensor(s). + +These **needs** to be in global scope since Py2 doesn't support serializing +static methods. 
+""" + +import torch +import re +from torch._six import container_abcs, string_classes, int_classes + +_use_shared_memory = False +r"""Whether to use shared memory in default_collate""" + +np_str_obj_array_pattern = re.compile(r'[SaUO]') + +error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}" + +numpy_type_map = { + 'float64': torch.DoubleTensor, + 'float32': torch.FloatTensor, + 'float16': torch.HalfTensor, + 'int64': torch.LongTensor, + 'int32': torch.IntTensor, + 'int16': torch.ShortTensor, + 'int8': torch.CharTensor, + 'uint8': torch.ByteTensor, +} + + +def default_collate(batch): + r"""Puts each data field into a tensor with outer dimension batch size""" + + elem_type = type(batch[0]) + if isinstance(batch[0], torch.Tensor): + out = None + if _use_shared_memory: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum([x.numel() for x in batch]) + storage = batch[0].storage()._new_shared(numel) + out = batch[0].new(storage) + return torch.stack(batch, 0, out=out) + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + elem = batch[0] + if elem_type.__name__ == 'ndarray': + # array of string classes and object + if np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError(error_msg_fmt.format(elem.dtype)) + + return torch.stack([torch.from_numpy(b) for b in batch], 0) + if elem.shape == (): # scalars + py_type = float if elem.dtype.name.startswith('float') else int + return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) + elif isinstance(batch[0], int_classes): + return torch.LongTensor(batch) + elif isinstance(batch[0], float): + return torch.DoubleTensor(batch) + elif isinstance(batch[0], string_classes): + return batch + elif isinstance(batch[0], container_abcs.Mapping): + return {key: default_collate([d[key] for d in batch]) for key in batch[0]} + elif isinstance(batch[0], container_abcs.Sequence): + transposed = zip(*batch) + return [default_collate(samples) for samples in transposed] + + raise TypeError((error_msg_fmt.format(type(batch[0])))) diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py new file mode 100644 index 0000000..8403b42 --- /dev/null +++ b/torch/utils/data/_utils/pin_memory.py @@ -0,0 +1,57 @@ +r""""Contains definitions of the methods used by the _DataLoaderIter to put +fetched tensors into pinned memory. + +These **needs** to be in global scope since Py2 doesn't support serializing +static methods. +""" + +import torch +from torch._six import queue, container_abcs, string_classes +from . import collate, MP_STATUS_CHECK_INTERVAL, ExceptionWrapper + + +def _pin_memory_loop(in_queue, out_queue, device_id, done_event): + torch.cuda.set_device(device_id) + + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. + while True: + try: + r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + continue + except Exception: + if done_event.is_set(): + # Weird things can happen when shutting down, e.g., fd being + # closed when tensors are shared via fds. + break + raise + if r is None: + assert done_event.is_set() + return + elif done_event.is_set(): + # Haven't seen the final signal yet. Keep getting until None. 
+ continue + elif isinstance(r[1], ExceptionWrapper): + out_queue.put(r) + else: + idx, batch = r + try: + batch = pin_memory_batch(batch) + except Exception: + out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) + else: + out_queue.put((idx, batch)) + + +def pin_memory_batch(batch): + if isinstance(batch, torch.Tensor): + return batch.pin_memory() + elif isinstance(batch, string_classes): + return batch + elif isinstance(batch, container_abcs.Mapping): + return {k: pin_memory_batch(sample) for k, sample in batch.items()} + elif isinstance(batch, container_abcs.Sequence): + return [pin_memory_batch(sample) for sample in batch] + else: + return batch diff --git a/torch/utils/data/_utils/signal_handling.py b/torch/utils/data/_utils/signal_handling.py new file mode 100644 index 0000000..9364733 --- /dev/null +++ b/torch/utils/data/_utils/signal_handling.py @@ -0,0 +1,70 @@ +r""""Signal handling for multiprocessing data loading. + +NOTE [ Signal handling in multiprocessing data loading ] + +In cases like DataLoader, if a worker process dies due to bus error/segfault +or just hang, the main process will hang waiting for data. This is difficult +to avoid on PyTorch side as it can be caused by limited shm, or other +libraries users call in the workers. In this file and `DataLoader.cpp`, we make +our best effort to provide some error message to users when such unfortunate +events happen. + +When a _DataLoaderIter starts worker processes, their pids are registered in a +defined in `DataLoader.cpp`: id(_DataLoaderIter) => Collection[ Worker pids ] +via `_set_worker_pids`. + +When an error happens in a worker process, the main process received a SIGCHLD, +and Python will eventually call the handler registered below +(in `_set_SIGCHLD_handler`). In the handler, the `_error_if_any_worker_fails` +call checks all registered worker pids and raise proper error message to +prevent main process from hanging waiting for data from worker. + +Additionally, at the beginning of each worker's `_utils.worker._worker_loop`, +`_set_worker_signal_handlers` is called to register critical signal handlers +(e.g., for SIGSEGV, SIGBUS, SIGFPE, SIGTERM) in C, which just prints an error +message to stderr before triggering the default handler. So a message will also +be printed from the worker process when it is killed by such signals. + +See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the reasoning of +this signal handling design and other mechanism we implement to make our +multiprocessing data loading robust to errors. +""" + +import signal +import threading +import torch +from torch._C import _set_worker_pids, _remove_worker_pids, \ + _error_if_any_worker_fails, _set_worker_signal_handlers +from . import IS_WINDOWS + + +_SIGCHLD_handler_set = False +r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one +handler needs to be set for all DataLoaders in a process.""" + + +def _set_SIGCHLD_handler(): + # Windows doesn't support SIGCHLD handler + if IS_WINDOWS: + return + # can't set signal in child threads + if not isinstance(threading.current_thread(), threading._MainThread): + return + global _SIGCHLD_handler_set + if _SIGCHLD_handler_set: + return + previous_handler = signal.getsignal(signal.SIGCHLD) + if not callable(previous_handler): + # This doesn't catch default handler, but SIGCHLD default handler is a + # no-op. + previous_handler = None + + def handler(signum, frame): + # This following call uses `waitid` with WNOHANG from C side. 
Therefore, + # Python can still get and update the process status successfully. + _error_if_any_worker_fails() + if previous_handler is not None: + previous_handler(signum, frame) + + signal.signal(signal.SIGCHLD, handler) + _SIGCHLD_handler_set = True diff --git a/torch/utils/data/_utils/worker.py b/torch/utils/data/_utils/worker.py new file mode 100644 index 0000000..2258c6f --- /dev/null +++ b/torch/utils/data/_utils/worker.py @@ -0,0 +1,109 @@ +r""""Contains definitions of the methods used by the _DataLoaderIter workers. + +These **needs** to be in global scope since Py2 doesn't support serializing +static methods. +""" + +import torch +import random +import sys +import os +from torch._six import queue +from . import collate, signal_handling, MP_STATUS_CHECK_INTERVAL, \ + ExceptionWrapper, IS_WINDOWS + +if IS_WINDOWS: + import ctypes + from ctypes.wintypes import DWORD, BOOL, HANDLE + + # On Windows, the parent ID of the worker process remains unchanged when the manager process + # is gone, and the only way to check it through OS is to let the worker have a process handle + # of the manager and ask if the process status has changed. + class ManagerWatchdog(object): + def __init__(self): + self.manager_pid = os.getppid() + + self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True) + self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) + self.kernel32.OpenProcess.restype = HANDLE + self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD) + self.kernel32.WaitForSingleObject.restype = DWORD + + # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx + SYNCHRONIZE = 0x00100000 + self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid) + + if not self.manager_handle: + raise ctypes.WinError(ctypes.get_last_error()) + + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx + self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0 + return not self.manager_dead +else: + class ManagerWatchdog(object): + def __init__(self): + self.manager_pid = os.getppid() + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + self.manager_dead = os.getppid() != self.manager_pid + return not self.manager_dead + + +def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id): + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. + + try: + collate._use_shared_memory = True + + # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal + # module's handlers are executed after Python returns from C low-level + # handlers, likely when the same fatal signal had already happened + # again. + # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers + signal_handling._set_worker_signal_handlers() + + torch.set_num_threads(1) + random.seed(seed) + torch.manual_seed(seed) + + data_queue.cancel_join_thread() + + if init_fn is not None: + init_fn(worker_id) + + watchdog = ManagerWatchdog() + + while watchdog.is_alive(): + try: + r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + continue + if r is None: + # Received the final signal + assert done_event.is_set() + return + elif done_event.is_set(): + # Done event is set. But I haven't received the final signal + # (None) yet. 
I will keep continuing until get it, and skip the + # processing steps. + continue + idx, batch_indices = r + try: + samples = collate_fn([dataset[i] for i in batch_indices]) + except Exception: + # It is important that we don't store exc_info in a variable, + # see NOTE [ Python Traceback Reference Cycle Problem ] + data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) + else: + data_queue.put((idx, samples)) + del samples + except KeyboardInterrupt: + # Main process will raise KeyboardInterrupt anyways. + pass diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 122272b..a8684e9 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -1,300 +1,130 @@ -import random +r"""Definition of the DataLoader and it's iterator _DataLoaderIter classes. + +To support these two classes, in `./_utils` we define many utility methods and +functions to be run in multiprocessing. E.g., the data loading worker loop is +in `./_utils/worker.py`. +""" + import torch import torch.multiprocessing as multiprocessing -from torch._C import _set_worker_signal_handlers, _update_worker_pids, \ - _remove_worker_pids, _error_if_any_worker_fails from . import SequentialSampler, RandomSampler, BatchSampler -import signal -import functools -from torch._six import container_abcs -import re -import sys +from . import _utils import threading -import traceback -import os -import time -import atexit -from torch._six import string_classes, int_classes, FileNotFoundError - -IS_WINDOWS = sys.platform == "win32" -if IS_WINDOWS: - import ctypes - from ctypes.wintypes import DWORD, BOOL, HANDLE - -if sys.version_info[0] == 2: - import Queue as queue -else: - import queue - - -# NOTE [ Python Traceback Reference Cycle Problem ] -# -# When using sys.exc_info(), it is important to **not** store the exc_info[2], -# which is the traceback, because otherwise you will run into the traceback -# reference cycle problem, i.e., the traceback holding reference to the frame, -# and the frame (which holds reference to all the object in its temporary scope) -# holding reference the traceback. - - -class ExceptionWrapper(object): - r"""Wraps an exception plus traceback to communicate across threads""" - def __init__(self, exc_info): - # It is important that we don't store exc_info, see - # NOTE [ Python Traceback Reference Cycle Problem ] - self.exc_type = exc_info[0] - self.exc_msg = "".join(traceback.format_exception(*exc_info)) - - -_use_shared_memory = False -r"""Whether to use shared memory in default_collate""" - -MP_STATUS_CHECK_INTERVAL = 5.0 -r"""Interval (in seconds) to check status of processes to avoid hanging in - multiprocessing data loading. This is mainly used in getting data from - another process, in which case we need to periodically check whether the - sender is alive to prevent hanging.""" - -if IS_WINDOWS: - # On Windows, the parent ID of the worker process remains unchanged when the manager process - # is gone, and the only way to check it through OS is to let the worker have a process handle - # of the manager and ask if the process status has changed. 
- class ManagerWatchdog(object): - def __init__(self): - self.manager_pid = os.getppid() - - self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True) - self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) - self.kernel32.OpenProcess.restype = HANDLE - self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD) - self.kernel32.WaitForSingleObject.restype = DWORD - - # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx - SYNCHRONIZE = 0x00100000 - self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid) - - if not self.manager_handle: - raise ctypes.WinError(ctypes.get_last_error()) - - self.manager_dead = False - - def is_alive(self): - if not self.manager_dead: - # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx - self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0 - return not self.manager_dead -else: - class ManagerWatchdog(object): - def __init__(self): - self.manager_pid = os.getppid() - self.manager_dead = False - - def is_alive(self): - if not self.manager_dead: - self.manager_dead = os.getppid() != self.manager_pid - return not self.manager_dead - - -def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id): - # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the - # logic of this function. - - try: - global _use_shared_memory - _use_shared_memory = True - - # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal - # module's handlers are executed after Python returns from C low-level - # handlers, likely when the same fatal signal had already happened. - # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1 - _set_worker_signal_handlers() - - torch.set_num_threads(1) - random.seed(seed) - torch.manual_seed(seed) - - data_queue.cancel_join_thread() - - if init_fn is not None: - init_fn(worker_id) - - watchdog = ManagerWatchdog() - - while watchdog.is_alive(): - try: - r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) - except queue.Empty: - continue - if r is None: - # Received the final signal - assert done_event.is_set() - return - elif done_event.is_set(): - # Done event is set. But I haven't received the final signal - # (None) yet. I will keep continuing until get it, and skip the - # processing steps. - continue - idx, batch_indices = r - try: - samples = collate_fn([dataset[i] for i in batch_indices]) - except Exception: - # It is important that we don't store exc_info in a variable, - # see NOTE [ Python Traceback Reference Cycle Problem ] - data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) - else: - data_queue.put((idx, samples)) - del samples - except KeyboardInterrupt: - # Main process will raise KeyboardInterrupt anyways. - pass - - -def _pin_memory_loop(in_queue, out_queue, device_id, done_event): - torch.cuda.set_device(device_id) - - # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the - # logic of this function. - while True: - try: - r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) - except queue.Empty: - continue - except Exception: - if done_event.is_set(): - # Weird things can happen when shutting down, e.g., fd being - # closed when tensors are shared via fds. - break - raise - if r is None: - assert done_event.is_set() - return - elif done_event.is_set(): - # Haven't seen the final signal yet. Keep getting until None. 
- continue - elif isinstance(r[1], ExceptionWrapper): - out_queue.put(r) - else: - idx, batch = r - try: - batch = pin_memory_batch(batch) - except Exception: - out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) - else: - out_queue.put((idx, batch)) - -numpy_type_map = { - 'float64': torch.DoubleTensor, - 'float32': torch.FloatTensor, - 'float16': torch.HalfTensor, - 'int64': torch.LongTensor, - 'int32': torch.IntTensor, - 'int16': torch.ShortTensor, - 'int8': torch.CharTensor, - 'uint8': torch.ByteTensor, -} - - -def default_collate(batch): - r"""Puts each data field into a tensor with outer dimension batch size""" - - error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" - elem_type = type(batch[0]) - if isinstance(batch[0], torch.Tensor): - out = None - if _use_shared_memory: - # If we're in a background process, concatenate directly into a - # shared memory tensor to avoid an extra copy - numel = sum([x.numel() for x in batch]) - storage = batch[0].storage()._new_shared(numel) - out = batch[0].new(storage) - return torch.stack(batch, 0, out=out) - elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ - and elem_type.__name__ != 'string_': - elem = batch[0] - if elem_type.__name__ == 'ndarray': - # array of string classes and object - if re.search('[SaUO]', elem.dtype.str) is not None: - raise TypeError(error_msg.format(elem.dtype)) - - return torch.stack([torch.from_numpy(b) for b in batch], 0) - if elem.shape == (): # scalars - py_type = float if elem.dtype.name.startswith('float') else int - return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) - elif isinstance(batch[0], int_classes): - return torch.LongTensor(batch) - elif isinstance(batch[0], float): - return torch.DoubleTensor(batch) - elif isinstance(batch[0], string_classes): - return batch - elif isinstance(batch[0], container_abcs.Mapping): - return {key: default_collate([d[key] for d in batch]) for key in batch[0]} - elif isinstance(batch[0], container_abcs.Sequence): - transposed = zip(*batch) - return [default_collate(samples) for samples in transposed] +from torch._six import queue - raise TypeError((error_msg.format(type(batch[0])))) +# This function used to be defined in this file. However, it was moved to +# _utils/collate.py. Although it is rather hard to access this from user land +# (one has to explicitly directly `import torch.utils.data.dataloader`), there +# probably is user code out there using it. This aliasing maintains BC in this +# aspect. +default_collate = _utils.collate.default_collate -def pin_memory_batch(batch): - if isinstance(batch, torch.Tensor): - return batch.pin_memory() - elif isinstance(batch, string_classes): - return batch - elif isinstance(batch, container_abcs.Mapping): - return {k: pin_memory_batch(sample) for k, sample in batch.items()} - elif isinstance(batch, container_abcs.Sequence): - return [pin_memory_batch(sample) for sample in batch] - else: - return batch +class DataLoader(object): + r""" + Data loader. Combines a dataset and a sampler, and provides + single- or multi-process iterators over the dataset. -_SIGCHLD_handler_set = False -r"""Whether SIGCHLD handler is set for DataLoader worker failures. 
Only one -handler needs to be set for all DataLoaders in a process.""" - - -def _set_SIGCHLD_handler(): - # Windows doesn't support SIGCHLD handler - if sys.platform == 'win32': - return - # can't set signal in child threads - if not isinstance(threading.current_thread(), threading._MainThread): - return - global _SIGCHLD_handler_set - if _SIGCHLD_handler_set: - return - previous_handler = signal.getsignal(signal.SIGCHLD) - if not callable(previous_handler): - # This doesn't catch default handler, but SIGCHLD default handler is a - # no-op. - previous_handler = None - - def handler(signum, frame): - # This following call uses `waitid` with WNOHANG from C side. Therefore, - # Python can still get and update the process status successfully. - _error_if_any_worker_fails() - if previous_handler is not None: - previous_handler(signum, frame) - - signal.signal(signal.SIGCHLD, handler) - _SIGCHLD_handler_set = True - - -_python_exit_status = False -r"""Whether Python is shutting down. This flag is guaranteed to be set before -the Python core library resources are freed, but Python may already be exiting -for some time when this is set. - -Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar -hook in Python 3.7 multiprocessing library: -https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327 -""" + Arguments: + dataset (Dataset): dataset from which to load the data. + batch_size (int, optional): how many samples per batch to load + (default: ``1``). + shuffle (bool, optional): set to ``True`` to have the data reshuffled + at every epoch (default: ``False``). + sampler (Sampler, optional): defines the strategy to draw samples from + the dataset. If specified, ``shuffle`` must be False. + batch_sampler (Sampler, optional): like sampler, but returns a batch of + indices at a time. Mutually exclusive with :attr:`batch_size`, + :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`. + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means that the data will be loaded in the main process. + (default: ``0``) + collate_fn (callable, optional): merges a list of samples to form a mini-batch. + pin_memory (bool, optional): If ``True``, the data loader will copy tensors + into CUDA pinned memory before returning them. + drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If ``False`` and + the size of dataset is not divisible by the batch size, then the last batch + will be smaller. (default: ``False``) + timeout (numeric, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative. (default: ``0``) + worker_init_fn (callable, optional): If not ``None``, this will be called on each + worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as + input, after seeding and before data loading. (default: ``None``) + + .. note:: By default, each worker will have its PyTorch seed set to + ``base_seed + worker_id``, where ``base_seed`` is a long generated + by main process using its RNG. However, seeds for other libraies + may be duplicated upon initializing workers (w.g., NumPy), causing + each worker to return identical random numbers. (See + :ref:`dataloader-workers-random-seed` section in FAQ.) 
You may + use :func:`torch.initial_seed()` to access the PyTorch seed for + each worker in :attr:`worker_init_fn`, and use it to set other + seeds before data loading. + + .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an + unpicklable object, e.g., a lambda function. + """ + + __initialized = False + + def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, + batch_sampler=None, num_workers=0, collate_fn=default_collate, + pin_memory=False, drop_last=False, timeout=0, + worker_init_fn=None): + self.dataset = dataset + self.batch_size = batch_size + self.num_workers = num_workers + self.collate_fn = collate_fn + self.pin_memory = pin_memory + self.drop_last = drop_last + self.timeout = timeout + self.worker_init_fn = worker_init_fn + + if timeout < 0: + raise ValueError('timeout option should be non-negative') + + if batch_sampler is not None: + if batch_size > 1 or shuffle or sampler is not None or drop_last: + raise ValueError('batch_sampler option is mutually exclusive ' + 'with batch_size, shuffle, sampler, and ' + 'drop_last') + self.batch_size = None + self.drop_last = None + if sampler is not None and shuffle: + raise ValueError('sampler option is mutually exclusive with ' + 'shuffle') -def _set_python_exit_flag(): - global _python_exit_status - _python_exit_status = True + if self.num_workers < 0: + raise ValueError('num_workers option cannot be negative; ' + 'use num_workers=0 to disable multiprocessing.') -atexit.register(_set_python_exit_flag) + if batch_sampler is None: + if sampler is None: + if shuffle: + sampler = RandomSampler(dataset) + else: + sampler = SequentialSampler(dataset) + batch_sampler = BatchSampler(sampler, batch_size, drop_last) + + self.sampler = sampler + self.batch_sampler = batch_sampler + self.__initialized = True + + def __setattr__(self, attr, val): + if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'): + raise ValueError('{} attribute should not be set after {} is ' + 'initialized'.format(attr, self.__class__.__name__)) + + super(DataLoader, self).__setattr__(attr, val) + + def __iter__(self): + return _DataLoaderIter(self) + + def __len__(self): + return len(self.batch_sampler) class _DataLoaderIter(object): @@ -354,12 +184,13 @@ class _DataLoaderIter(object): # Therefore, in this case, we actually need to prevent `__del__` from # being executed, and rely on the automatic termination of daemonic # children. Thus, we register an `atexit` hook that sets a global flag - # `_python_exit_status`. Since `atexit` hooks are executed in reverse - # order of registration, we are guaranteed that this flag is set before - # library resources we use are freed. (Hooks freeing those resources - # are registered at importing the Python core libraries at the top of - # this file.) So in `__del__`, we check if `_python_exit_status` is set - # or `None` (freed), and perform no-op if so. + # `_utils.python_exit_status`. Since `atexit` hooks are executed in the + # reverse order of registration, we are guaranteed that this flag is + # set before library resources we use are freed. (Hooks freeing those + # resources are registered at importing the Python core libraries at + # the top of this file.) So in `__del__`, we check if + # `_utils.python_exit_status` is set or `None` (freed), and perform + # no-op if so. # # Another problem with `__del__` is also related to the library cleanup # calls. 
When a process ends, it shuts the all its daemonic children @@ -419,8 +250,8 @@ class _DataLoaderIter(object): # For `.get()` calls where the sender(s) is not the workers, we # guard them with timeouts, and check the status of the sender # when timeout happens: - # + in the workers, the `ManagerWatchdog` class checks the main - # process status. + # + in the workers, the `_utils.worker.ManagerWatchdog` class + # checks the status of the main process. # + if `pin_memory=True`, when getting from `pin_memory_thread`, # check `pin_memory_thread` status periodically until `.get()` # returns or see that `pin_memory_thread` died. @@ -545,7 +376,7 @@ class _DataLoaderIter(object): index_queue = multiprocessing.Queue() index_queue.cancel_join_thread() w = multiprocessing.Process( - target=_worker_loop, + target=_utils.worker._worker_loop, args=(self.dataset, index_queue, self.worker_result_queue, self.done_event, self.collate_fn, base_seed + i, @@ -564,7 +395,7 @@ class _DataLoaderIter(object): if self.pin_memory: self.data_queue = queue.Queue() pin_memory_thread = threading.Thread( - target=_pin_memory_loop, + target=_utils.pin_memory._pin_memory_loop, args=(self.worker_result_queue, self.data_queue, torch.cuda.current_device(), self.done_event)) pin_memory_thread.daemon = True @@ -575,8 +406,8 @@ class _DataLoaderIter(object): else: self.data_queue = self.worker_result_queue - _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) - _set_SIGCHLD_handler() + _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers)) + _utils.signal_handling._set_SIGCHLD_handler() self.worker_pids_set = True # prime the prefetch loop @@ -598,7 +429,7 @@ class _DataLoaderIter(object): elif self.pin_memory: while self.pin_memory_thread.is_alive(): try: - return self.data_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + return self.data_queue.get(timeout=_utils.MP_STATUS_CHECK_INTERVAL) except queue.Empty: continue else: @@ -614,7 +445,7 @@ class _DataLoaderIter(object): indices = next(self.sample_iter) # may raise StopIteration batch = self.collate_fn([self.dataset[i] for i in indices]) if self.pin_memory: - batch = pin_memory_batch(batch) + batch = _utils.pin_memory.pin_memory_batch(batch) return batch # check if the next sample has already been generated @@ -654,7 +485,7 @@ class _DataLoaderIter(object): def _process_next_batch(self, batch): self.rcvd_idx += 1 self._put_indices() - if isinstance(batch, ExceptionWrapper): + if isinstance(batch, _utils.ExceptionWrapper): raise batch.exc_type(batch.exc_msg) return batch @@ -669,7 +500,8 @@ class _DataLoaderIter(object): def _shutdown_workers(self): # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on # the logic of this function. - if _python_exit_status is True or _python_exit_status is None: + python_exit_status = _utils.python_exit_status + if python_exit_status is True or python_exit_status is None: # See (2) of the note. If Python is shutting down, do no-op. return # Normal exit when last reference is gone / iterator is depleted. @@ -679,7 +511,7 @@ class _DataLoaderIter(object): # Removes pids from the C side data structure first so worker # termination afterwards won't trigger false positive error report. 
if self.worker_pids_set: - _remove_worker_pids(id(self)) + _utils.signal_handling._remove_worker_pids(id(self)) self.worker_pids_set = False self.done_event.set() @@ -715,108 +547,3 @@ class _DataLoaderIter(object): def __del__(self): if self.num_workers > 0: self._shutdown_workers() - - -class DataLoader(object): - r""" - Data loader. Combines a dataset and a sampler, and provides - single- or multi-process iterators over the dataset. - - Arguments: - dataset (Dataset): dataset from which to load the data. - batch_size (int, optional): how many samples per batch to load - (default: ``1``). - shuffle (bool, optional): set to ``True`` to have the data reshuffled - at every epoch (default: ``False``). - sampler (Sampler, optional): defines the strategy to draw samples from - the dataset. If specified, ``shuffle`` must be False. - batch_sampler (Sampler, optional): like sampler, but returns a batch of - indices at a time. Mutually exclusive with :attr:`batch_size`, - :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`. - num_workers (int, optional): how many subprocesses to use for data - loading. 0 means that the data will be loaded in the main process. - (default: ``0``) - collate_fn (callable, optional): merges a list of samples to form a mini-batch. - pin_memory (bool, optional): If ``True``, the data loader will copy tensors - into CUDA pinned memory before returning them. - drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, - if the dataset size is not divisible by the batch size. If ``False`` and - the size of dataset is not divisible by the batch size, then the last batch - will be smaller. (default: ``False``) - timeout (numeric, optional): if positive, the timeout value for collecting a batch - from workers. Should always be non-negative. (default: ``0``) - worker_init_fn (callable, optional): If not ``None``, this will be called on each - worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as - input, after seeding and before data loading. (default: ``None``) - - .. note:: By default, each worker will have its PyTorch seed set to - ``base_seed + worker_id``, where ``base_seed`` is a long generated - by main process using its RNG. However, seeds for other libraies - may be duplicated upon initializing workers (w.g., NumPy), causing - each worker to return identical random numbers. (See - :ref:`dataloader-workers-random-seed` section in FAQ.) You may - use :func:`torch.initial_seed()` to access the PyTorch seed for - each worker in :attr:`worker_init_fn`, and use it to set other - seeds before data loading. - - .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an - unpicklable object, e.g., a lambda function. 
- """ - - __initialized = False - - def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, - num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, - timeout=0, worker_init_fn=None): - self.dataset = dataset - self.batch_size = batch_size - self.num_workers = num_workers - self.collate_fn = collate_fn - self.pin_memory = pin_memory - self.drop_last = drop_last - self.timeout = timeout - self.worker_init_fn = worker_init_fn - - if timeout < 0: - raise ValueError('timeout option should be non-negative') - - if batch_sampler is not None: - if batch_size > 1 or shuffle or sampler is not None or drop_last: - raise ValueError('batch_sampler option is mutually exclusive ' - 'with batch_size, shuffle, sampler, and ' - 'drop_last') - self.batch_size = None - self.drop_last = None - - if sampler is not None and shuffle: - raise ValueError('sampler option is mutually exclusive with ' - 'shuffle') - - if self.num_workers < 0: - raise ValueError('num_workers option cannot be negative; ' - 'use num_workers=0 to disable multiprocessing.') - - if batch_sampler is None: - if sampler is None: - if shuffle: - sampler = RandomSampler(dataset) - else: - sampler = SequentialSampler(dataset) - batch_sampler = BatchSampler(sampler, batch_size, drop_last) - - self.sampler = sampler - self.batch_sampler = batch_sampler - self.__initialized = True - - def __setattr__(self, attr, val): - if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'): - raise ValueError('{} attribute should not be set after {} is ' - 'initialized'.format(attr, self.__class__.__name__)) - - super(DataLoader, self).__setattr__(attr, val) - - def __iter__(self): - return _DataLoaderIter(self) - - def __len__(self): - return len(self.batch_sampler)