From 38eb1beff5bbaed0d3cc8ad59039b50f850f7245 Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Tue, 4 Dec 2018 20:23:25 -0800 Subject: [PATCH] Revert D13289919: [pytorch][PR] [DataLoader] Refactor dataloader.py Differential Revision: D13289919 Original commit changeset: d701bc7bb48f fbshipit-source-id: c350c491fefa98a0a7c0cf22cb832e78aeb15c3d --- test/test_dataloader.py | 12 +- torch/_six.py | 6 - torch/csrc/DataLoader.cpp | 13 +- torch/utils/data/__init__.py | 1 + torch/utils/data/_utils/__init__.py | 61 ---- torch/utils/data/_utils/collate.py | 68 ---- torch/utils/data/_utils/pin_memory.py | 57 --- torch/utils/data/_utils/signal_handling.py | 70 ---- torch/utils/data/_utils/worker.py | 108 ------ torch/utils/data/dataloader.py | 541 ++++++++++++++++++++++------- 10 files changed, 421 insertions(+), 516 deletions(-) delete mode 100644 torch/utils/data/_utils/__init__.py delete mode 100644 torch/utils/data/_utils/collate.py delete mode 100644 torch/utils/data/_utils/pin_memory.py delete mode 100644 torch/utils/data/_utils/signal_handling.py delete mode 100644 torch/utils/data/_utils/worker.py diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 9db11b0..e1ca129 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -12,9 +12,9 @@ import unittest import subprocess import itertools from torch import multiprocessing as mp -from torch.utils.data import _utils, Dataset, TensorDataset, DataLoader, ConcatDataset -from torch.utils.data._utils import ExceptionWrapper, MP_STATUS_CHECK_INTERVAL +from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset from torch.utils.data.dataset import random_split +from torch.utils.data.dataloader import default_collate, ExceptionWrapper, MP_STATUS_CHECK_INTERVAL from common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, IS_PPC, NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests) @@ -788,16 +788,16 @@ class TestDataLoader(TestCase): # Should be a no-op arr = np.array(['a', 'b', 'c']) - _utils.collate.default_collate(arr) + default_collate(arr) arr = np.array([[['a', 'b', 'c']]]) - self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr)) + self.assertRaises(TypeError, lambda: default_collate(arr)) arr = np.array([object(), object(), object()]) - self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr)) + self.assertRaises(TypeError, lambda: default_collate(arr)) arr = np.array([[[object(), object(), object()]]]) - self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr)) + self.assertRaises(TypeError, lambda: default_collate(arr)) class StringDataset(Dataset): diff --git a/torch/_six.py b/torch/_six.py index 3a4c2ad..924e641 100644 --- a/torch/_six.py +++ b/torch/_six.py @@ -51,12 +51,6 @@ else: FileNotFoundError = FileNotFoundError -if PY2: - import Queue as queue -else: - import queue - - def with_metaclass(meta, *bases): """Create a base class with a metaclass.""" # This requires a bit of explanation: the basic idea is to make a dummy diff --git a/torch/csrc/DataLoader.cpp b/torch/csrc/DataLoader.cpp index d75bcf6..c5cdf64 100644 --- a/torch/csrc/DataLoader.cpp +++ b/torch/csrc/DataLoader.cpp @@ -1,10 +1,11 @@ #include "DataLoader.h" -// Together with `torch/utils/data/_utils/signal_handling.py`, the following -// is an effort to do our best to provide some error message to users when a -// worker dies due to error / critical signals. -// -// See NOTE [ Signal handling in multiprocessing data loading ] for more details. 
+// In cases like DataLoader, if a worker process dies due to bus error/segfault +// or just hang, the main process will hang waiting for data. This is difficult +// to avoid on PyTorch side as it can be caused by limited shm, or other +// libraries users call in the workers. The following methods is an effort to do +// our best to provide some error message to users when such unfortunate events +// happen. // TODO: The following don't work on Windows. Specifically, sigaction, waitid // calls, and SIGCHLD handler. Currently, dummy implementations are provided @@ -127,7 +128,7 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) { // workers, and trigger this again. pid_set->clear(); throw std::runtime_error(oss.str()); - } else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal + } else if (infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal std::ostringstream oss; oss << "DataLoader worker (pid " << worker_pid << ") is killed " << "by signal: " << strsignal(infop.si_status) << ". "; diff --git a/torch/utils/data/__init__.py b/torch/utils/data/__init__.py index ee58707..05c94d1 100644 --- a/torch/utils/data/__init__.py +++ b/torch/utils/data/__init__.py @@ -1,3 +1,4 @@ + from .sampler import Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler, BatchSampler from .distributed import DistributedSampler from .dataset import Dataset, TensorDataset, ConcatDataset, Subset, random_split diff --git a/torch/utils/data/_utils/__init__.py b/torch/utils/data/_utils/__init__.py deleted file mode 100644 index 05b2b65..0000000 --- a/torch/utils/data/_utils/__init__.py +++ /dev/null @@ -1,61 +0,0 @@ -r"""Utility classes & functions for data loading. Code in this folder is mostly -used by ../dataloder.py. - -A lot of multiprocessing is used in data loading, which only supports running -functions defined in global environment (py2 can't serialize static methods). -Therefore, for code tidiness we put these functions into different files in this -folder. -""" - -import sys -import traceback -import atexit - - -IS_WINDOWS = sys.platform == "win32" - - -# NOTE [ Python Traceback Reference Cycle Problem ] -# -# When using sys.exc_info(), it is important to **not** store the exc_info[2], -# which is the traceback, because otherwise you will run into the traceback -# reference cycle problem, i.e., the traceback holding reference to the frame, -# and the frame (which holds reference to all the object in its temporary scope) -# holding reference the traceback. - - -class ExceptionWrapper(object): - r"""Wraps an exception plus traceback to communicate across threads""" - def __init__(self, exc_info): - # It is important that we don't store exc_info, see - # NOTE [ Python Traceback Reference Cycle Problem ] - self.exc_type = exc_info[0] - self.exc_msg = "".join(traceback.format_exception(*exc_info)) - - -MP_STATUS_CHECK_INTERVAL = 5.0 -r"""Interval (in seconds) to check status of processes to avoid hanging in - multiprocessing data loading. This is mainly used in getting data from - another process, in which case we need to periodically check whether the - sender is alive to prevent hanging.""" - - -python_exit_status = False -r"""Whether Python is shutting down. This flag is guaranteed to be set before -the Python core library resources are freed, but Python may already be exiting -for some time when this is set. 
- -Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar -hook in Python 3.7 multiprocessing library: -https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327 -""" - - -def _set_python_exit_flag(): - global python_exit_status - python_exit_status = True - -atexit.register(_set_python_exit_flag) - - -from . import worker, signal_handling, pin_memory, collate diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py deleted file mode 100644 index e0823b2..0000000 --- a/torch/utils/data/_utils/collate.py +++ /dev/null @@ -1,68 +0,0 @@ -r""""Contains definitions of the methods used by the _DataLoaderIter workers to -collate samples fetched from dataset into Tensor(s). - -These **needs** to be in global scope since Py2 doesn't support serializing -static methods. -""" - -import torch -import re -from torch._six import container_abcs, string_classes, int_classes - -_use_shared_memory = False -r"""Whether to use shared memory in default_collate""" - -np_str_obj_array_pattern = re.compile(r'[SaUO]') - -error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}" - -numpy_type_map = { - 'float64': torch.DoubleTensor, - 'float32': torch.FloatTensor, - 'float16': torch.HalfTensor, - 'int64': torch.LongTensor, - 'int32': torch.IntTensor, - 'int16': torch.ShortTensor, - 'int8': torch.CharTensor, - 'uint8': torch.ByteTensor, -} - - -def default_collate(batch): - r"""Puts each data field into a tensor with outer dimension batch size""" - - elem_type = type(batch[0]) - if isinstance(batch[0], torch.Tensor): - out = None - if _use_shared_memory: - # If we're in a background process, concatenate directly into a - # shared memory tensor to avoid an extra copy - numel = sum([x.numel() for x in batch]) - storage = batch[0].storage()._new_shared(numel) - out = batch[0].new(storage) - return torch.stack(batch, 0, out=out) - elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ - and elem_type.__name__ != 'string_': - elem = batch[0] - if elem_type.__name__ == 'ndarray': - # array of string classes and object - if np_str_obj_array_pattern.search(elem.dtype.str) is not None: - raise TypeError(error_msg_fmt.format(elem.dtype)) - - return torch.stack([torch.from_numpy(b) for b in batch], 0) - if elem.shape == (): # scalars - py_type = float if elem.dtype.name.startswith('float') else int - return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) - elif isinstance(batch[0], int_classes): - return torch.LongTensor(batch) - elif isinstance(batch[0], float): - return torch.DoubleTensor(batch) - elif isinstance(batch[0], string_classes): - return batch - elif isinstance(batch[0], container_abcs.Mapping): - return {key: default_collate([d[key] for d in batch]) for key in batch[0]} - elif isinstance(batch[0], container_abcs.Sequence): - transposed = zip(*batch) - return [default_collate(samples) for samples in transposed] - - raise TypeError((error_msg_fmt.format(type(batch[0])))) diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py deleted file mode 100644 index 8403b42..0000000 --- a/torch/utils/data/_utils/pin_memory.py +++ /dev/null @@ -1,57 +0,0 @@ -r""""Contains definitions of the methods used by the _DataLoaderIter to put -fetched tensors into pinned memory. - -These **needs** to be in global scope since Py2 doesn't support serializing -static methods. 
-""" - -import torch -from torch._six import queue, container_abcs, string_classes -from . import collate, MP_STATUS_CHECK_INTERVAL, ExceptionWrapper - - -def _pin_memory_loop(in_queue, out_queue, device_id, done_event): - torch.cuda.set_device(device_id) - - # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the - # logic of this function. - while True: - try: - r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) - except queue.Empty: - continue - except Exception: - if done_event.is_set(): - # Weird things can happen when shutting down, e.g., fd being - # closed when tensors are shared via fds. - break - raise - if r is None: - assert done_event.is_set() - return - elif done_event.is_set(): - # Haven't seen the final signal yet. Keep getting until None. - continue - elif isinstance(r[1], ExceptionWrapper): - out_queue.put(r) - else: - idx, batch = r - try: - batch = pin_memory_batch(batch) - except Exception: - out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) - else: - out_queue.put((idx, batch)) - - -def pin_memory_batch(batch): - if isinstance(batch, torch.Tensor): - return batch.pin_memory() - elif isinstance(batch, string_classes): - return batch - elif isinstance(batch, container_abcs.Mapping): - return {k: pin_memory_batch(sample) for k, sample in batch.items()} - elif isinstance(batch, container_abcs.Sequence): - return [pin_memory_batch(sample) for sample in batch] - else: - return batch diff --git a/torch/utils/data/_utils/signal_handling.py b/torch/utils/data/_utils/signal_handling.py deleted file mode 100644 index 74bb2ba..0000000 --- a/torch/utils/data/_utils/signal_handling.py +++ /dev/null @@ -1,70 +0,0 @@ -r""""Signal handling for multiprocessing data loading. - -NOTE [ Signal handling in multiprocessing data loading ] - -In cases like DataLoader, if a worker process dies due to bus error/segfault -or just hang, the main process will hang waiting for data. This is difficult -to avoid on PyTorch side as it can be caused by limited shm, or other -libraries users call in the workers. In this file and `DataLoader.cpp`, we make -our best effort to provide some error message to users when such unfortunate -events happen. - -When a _DataLoaderIter starts worker processes, their pids are registered in a -defined in `DataLoader.cpp`: id(_DataLoaderIter) => Collection[ Worker pids ] -via `_update_worker_pids`. - -When an error happens in a worker process, the main process received a SIGCHLD, -and Python will eventually call the handler registered below -(in `_set_SIGCHLD_handler`). In the handler, the `_error_if_any_worker_fails` -call checks all registered worker pids and raise proper error message to -prevent main process from hanging waiting for data from worker. - -Additionally, at the beginning of each worker's `_utils.worker._worker_loop`, -`_set_worker_signal_handlers` is called to register critical signal handlers -(e.g., for SIGSEGV, SIGBUS, SIGFPE, SIGTERM) in C, which just prints an error -message to stderr before triggering the default handler. So a message will also -be printed from the worker process when it is killed by such signals. - -See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the reasoning of -this signal handling design and other mechanism we implement to make our -multiprocessing data loading robust to errors. -""" - -import signal -import threading -import torch -from torch._C import _update_worker_pids, _remove_worker_pids, \ - _error_if_any_worker_fails, _set_worker_signal_handlers -from . 
import IS_WINDOWS - - -_SIGCHLD_handler_set = False -r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one -handler needs to be set for all DataLoaders in a process.""" - - -def _set_SIGCHLD_handler(): - # Windows doesn't support SIGCHLD handler - if IS_WINDOWS: - return - # can't set signal in child threads - if not isinstance(threading.current_thread(), threading._MainThread): - return - global _SIGCHLD_handler_set - if _SIGCHLD_handler_set: - return - previous_handler = signal.getsignal(signal.SIGCHLD) - if not callable(previous_handler): - # This doesn't catch default handler, but SIGCHLD default handler is a - # no-op. - previous_handler = None - - def handler(signum, frame): - # This following call uses `waitid` with WNOHANG from C side. Therefore, - # Python can still get and update the process status successfully. - _error_if_any_worker_fails() - if previous_handler is not None: - previous_handler(signum, frame) - - signal.signal(signal.SIGCHLD, handler) - _SIGCHLD_handler_set = True diff --git a/torch/utils/data/_utils/worker.py b/torch/utils/data/_utils/worker.py deleted file mode 100644 index 2cddf3a..0000000 --- a/torch/utils/data/_utils/worker.py +++ /dev/null @@ -1,108 +0,0 @@ -r""""Contains definitions of the methods used by the _DataLoaderIter workers. - -These **needs** to be in global scope since Py2 doesn't support serializing -static methods. -""" - -import torch -import random -import sys -import os -from torch._six import queue -from . import collate, signal_handling, MP_STATUS_CHECK_INTERVAL, \ - ExceptionWrapper, IS_WINDOWS - -if IS_WINDOWS: - import ctypes - from ctypes.wintypes import DWORD, BOOL, HANDLE - - # On Windows, the parent ID of the worker process remains unchanged when the manager process - # is gone, and the only way to check it through OS is to let the worker have a process handle - # of the manager and ask if the process status has changed. - class ManagerWatchdog(object): - def __init__(self): - self.manager_pid = os.getppid() - - self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True) - self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) - self.kernel32.OpenProcess.restype = HANDLE - self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD) - self.kernel32.WaitForSingleObject.restype = DWORD - - # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx - SYNCHRONIZE = 0x00100000 - self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid) - - if not self.manager_handle: - raise ctypes.WinError(ctypes.get_last_error()) - - self.manager_dead = False - - def is_alive(self): - if not self.manager_dead: - # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx - self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0 - return not self.manager_dead -else: - class ManagerWatchdog(object): - def __init__(self): - self.manager_pid = os.getppid() - self.manager_dead = False - - def is_alive(self): - if not self.manager_dead: - self.manager_dead = os.getppid() != self.manager_pid - return not self.manager_dead - - -def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id): - # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the - # logic of this function. - - try: - collate._use_shared_memory = True - - # Intialize C side signal handlers for SIGBUS and SIGSEGV. 
Python signal - # module's handlers are executed after Python returns from C low-level - # handlers, likely when the same fatal signal happened again already. - # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1 - signal_handling._set_worker_signal_handlers() - - torch.set_num_threads(1) - random.seed(seed) - torch.manual_seed(seed) - - data_queue.cancel_join_thread() - - if init_fn is not None: - init_fn(worker_id) - - watchdog = ManagerWatchdog() - - while watchdog.is_alive(): - try: - r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) - except queue.Empty: - continue - if r is None: - # Received the final signal - assert done_event.is_set() - return - elif done_event.is_set(): - # Done event is set. But I haven't received the final signal - # (None) yet. I will keep continuing until get it, and skip the - # processing steps. - continue - idx, batch_indices = r - try: - samples = collate_fn([dataset[i] for i in batch_indices]) - except Exception: - # It is important that we don't store exc_info in a variable, - # see NOTE [ Python Traceback Reference Cycle Problem ] - data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) - else: - data_queue.put((idx, samples)) - del samples - except KeyboardInterrupt: - # Main process will raise KeyboardInterrupt anyways. - pass diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 6d5cb81..c1ee0eb 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -1,130 +1,300 @@ -r"""Definition of the DataLoader and it's iterator _DataLoaderIter classes. - -To support these two classes, in `./_utils` we define many utility methods and -functions to be run in multiprocessing. E.g., the data loading worker loop is -in `./_utils/worker.py`. -""" - +import random import torch import torch.multiprocessing as multiprocessing +from torch._C import _set_worker_signal_handlers, _update_worker_pids, \ + _remove_worker_pids, _error_if_any_worker_fails from . import SequentialSampler, RandomSampler, BatchSampler -from . import _utils +import signal +import functools +from torch._six import container_abcs +import re +import sys import threading -from torch._six import queue - - -# This function used to be defined in this file. However, it was moved to -# _utils/collate.py. Although it is rather hard to access this from user land -# (one has to explicitly directly `import torch.utils.data.dataloader`), there -# probably is user code out there using it. This aliasing maintains BC in this -# aspect. -default_collate = _utils.collate.default_collate - - -class DataLoader(object): - r""" - Data loader. Combines a dataset and a sampler, and provides - single- or multi-process iterators over the dataset. - - Arguments: - dataset (Dataset): dataset from which to load the data. - batch_size (int, optional): how many samples per batch to load - (default: ``1``). - shuffle (bool, optional): set to ``True`` to have the data reshuffled - at every epoch (default: ``False``). - sampler (Sampler, optional): defines the strategy to draw samples from - the dataset. If specified, ``shuffle`` must be False. - batch_sampler (Sampler, optional): like sampler, but returns a batch of - indices at a time. Mutually exclusive with :attr:`batch_size`, - :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`. - num_workers (int, optional): how many subprocesses to use for data - loading. 0 means that the data will be loaded in the main process. 
- (default: ``0``) - collate_fn (callable, optional): merges a list of samples to form a mini-batch. - pin_memory (bool, optional): If ``True``, the data loader will copy tensors - into CUDA pinned memory before returning them. - drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, - if the dataset size is not divisible by the batch size. If ``False`` and - the size of dataset is not divisible by the batch size, then the last batch - will be smaller. (default: ``False``) - timeout (numeric, optional): if positive, the timeout value for collecting a batch - from workers. Should always be non-negative. (default: ``0``) - worker_init_fn (callable, optional): If not ``None``, this will be called on each - worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as - input, after seeding and before data loading. (default: ``None``) - - .. note:: By default, each worker will have its PyTorch seed set to - ``base_seed + worker_id``, where ``base_seed`` is a long generated - by main process using its RNG. However, seeds for other libraies - may be duplicated upon initializing workers (w.g., NumPy), causing - each worker to return identical random numbers. (See - :ref:`dataloader-workers-random-seed` section in FAQ.) You may - use :func:`torch.initial_seed()` to access the PyTorch seed for - each worker in :attr:`worker_init_fn`, and use it to set other - seeds before data loading. - - .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an - unpicklable object, e.g., a lambda function. - """ - - __initialized = False +import traceback +import os +import time +import atexit +from torch._six import string_classes, int_classes, FileNotFoundError + +IS_WINDOWS = sys.platform == "win32" +if IS_WINDOWS: + import ctypes + from ctypes.wintypes import DWORD, BOOL, HANDLE + +if sys.version_info[0] == 2: + import Queue as queue +else: + import queue + + +# NOTE [ Python Traceback Reference Cycle Problem ] +# +# When using sys.exc_info(), it is important to **not** store the exc_info[2], +# which is the traceback, because otherwise you will run into the traceback +# reference cycle problem, i.e., the traceback holding reference to the frame, +# and the frame (which holds reference to all the object in its temporary scope) +# holding reference the traceback. + + +class ExceptionWrapper(object): + r"""Wraps an exception plus traceback to communicate across threads""" + def __init__(self, exc_info): + # It is important that we don't store exc_info, see + # NOTE [ Python Traceback Reference Cycle Problem ] + self.exc_type = exc_info[0] + self.exc_msg = "".join(traceback.format_exception(*exc_info)) + + +_use_shared_memory = False +r"""Whether to use shared memory in default_collate""" + +MP_STATUS_CHECK_INTERVAL = 5.0 +r"""Interval (in seconds) to check status of processes to avoid hanging in + multiprocessing data loading. This is mainly used in getting data from + another process, in which case we need to periodically check whether the + sender is alive to prevent hanging.""" + +if IS_WINDOWS: + # On Windows, the parent ID of the worker process remains unchanged when the manager process + # is gone, and the only way to check it through OS is to let the worker have a process handle + # of the manager and ask if the process status has changed. 
+ class ManagerWatchdog(object): + def __init__(self): + self.manager_pid = os.getppid() + + self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True) + self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) + self.kernel32.OpenProcess.restype = HANDLE + self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD) + self.kernel32.WaitForSingleObject.restype = DWORD + + # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx + SYNCHRONIZE = 0x00100000 + self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid) + + if not self.manager_handle: + raise ctypes.WinError(ctypes.get_last_error()) + + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx + self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0 + return not self.manager_dead +else: + class ManagerWatchdog(object): + def __init__(self): + self.manager_pid = os.getppid() + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + self.manager_dead = os.getppid() != self.manager_pid + return not self.manager_dead + + +def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id): + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. + + try: + global _use_shared_memory + _use_shared_memory = True + + # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal + # module's handlers are executed after Python returns from C low-level + # handlers, likely when the same fatal signal happened again already. + # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1 + _set_worker_signal_handlers() + + torch.set_num_threads(1) + random.seed(seed) + torch.manual_seed(seed) + + data_queue.cancel_join_thread() + + if init_fn is not None: + init_fn(worker_id) + + watchdog = ManagerWatchdog() + + while watchdog.is_alive(): + try: + r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + continue + if r is None: + # Received the final signal + assert done_event.is_set() + return + elif done_event.is_set(): + # Done event is set. But I haven't received the final signal + # (None) yet. I will keep continuing until get it, and skip the + # processing steps. + continue + idx, batch_indices = r + try: + samples = collate_fn([dataset[i] for i in batch_indices]) + except Exception: + # It is important that we don't store exc_info in a variable, + # see NOTE [ Python Traceback Reference Cycle Problem ] + data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) + else: + data_queue.put((idx, samples)) + del samples + except KeyboardInterrupt: + # Main process will raise KeyboardInterrupt anyways. + pass + + +def _pin_memory_loop(in_queue, out_queue, device_id, done_event): + torch.cuda.set_device(device_id) + + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. + while True: + try: + r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + continue + except Exception: + if done_event.is_set(): + # Weird things can happen when shutting down, e.g., fd being + # closed when tensors are shared via fds. + break + raise + if r is None: + assert done_event.is_set() + return + elif done_event.is_set(): + # Haven't seen the final signal yet. Keep getting until None. 
+ continue + elif isinstance(r[1], ExceptionWrapper): + out_queue.put(r) + else: + idx, batch = r + try: + batch = pin_memory_batch(batch) + except Exception: + out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) + else: + out_queue.put((idx, batch)) + +numpy_type_map = { + 'float64': torch.DoubleTensor, + 'float32': torch.FloatTensor, + 'float16': torch.HalfTensor, + 'int64': torch.LongTensor, + 'int32': torch.IntTensor, + 'int16': torch.ShortTensor, + 'int8': torch.CharTensor, + 'uint8': torch.ByteTensor, +} + + +def default_collate(batch): + r"""Puts each data field into a tensor with outer dimension batch size""" + + error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" + elem_type = type(batch[0]) + if isinstance(batch[0], torch.Tensor): + out = None + if _use_shared_memory: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum([x.numel() for x in batch]) + storage = batch[0].storage()._new_shared(numel) + out = batch[0].new(storage) + return torch.stack(batch, 0, out=out) + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + elem = batch[0] + if elem_type.__name__ == 'ndarray': + # array of string classes and object + if re.search('[SaUO]', elem.dtype.str) is not None: + raise TypeError(error_msg.format(elem.dtype)) + + return torch.stack([torch.from_numpy(b) for b in batch], 0) + if elem.shape == (): # scalars + py_type = float if elem.dtype.name.startswith('float') else int + return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) + elif isinstance(batch[0], int_classes): + return torch.LongTensor(batch) + elif isinstance(batch[0], float): + return torch.DoubleTensor(batch) + elif isinstance(batch[0], string_classes): + return batch + elif isinstance(batch[0], container_abcs.Mapping): + return {key: default_collate([d[key] for d in batch]) for key in batch[0]} + elif isinstance(batch[0], container_abcs.Sequence): + transposed = zip(*batch) + return [default_collate(samples) for samples in transposed] - def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, - batch_sampler=None, num_workers=0, collate_fn=default_collate, - pin_memory=False, drop_last=False, timeout=0, - worker_init_fn=None): - self.dataset = dataset - self.batch_size = batch_size - self.num_workers = num_workers - self.collate_fn = collate_fn - self.pin_memory = pin_memory - self.drop_last = drop_last - self.timeout = timeout - self.worker_init_fn = worker_init_fn + raise TypeError((error_msg.format(type(batch[0])))) - if timeout < 0: - raise ValueError('timeout option should be non-negative') - if batch_sampler is not None: - if batch_size > 1 or shuffle or sampler is not None or drop_last: - raise ValueError('batch_sampler option is mutually exclusive ' - 'with batch_size, shuffle, sampler, and ' - 'drop_last') - self.batch_size = None - self.drop_last = None +def pin_memory_batch(batch): + if isinstance(batch, torch.Tensor): + return batch.pin_memory() + elif isinstance(batch, string_classes): + return batch + elif isinstance(batch, container_abcs.Mapping): + return {k: pin_memory_batch(sample) for k, sample in batch.items()} + elif isinstance(batch, container_abcs.Sequence): + return [pin_memory_batch(sample) for sample in batch] + else: + return batch - if sampler is not None and shuffle: - raise ValueError('sampler option is mutually exclusive with ' - 'shuffle') - if self.num_workers < 0: - raise ValueError('num_workers option 
cannot be negative; ' - 'use num_workers=0 to disable multiprocessing.') +_SIGCHLD_handler_set = False +r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one +handler needs to be set for all DataLoaders in a process.""" + + +def _set_SIGCHLD_handler(): + # Windows doesn't support SIGCHLD handler + if sys.platform == 'win32': + return + # can't set signal in child threads + if not isinstance(threading.current_thread(), threading._MainThread): + return + global _SIGCHLD_handler_set + if _SIGCHLD_handler_set: + return + previous_handler = signal.getsignal(signal.SIGCHLD) + if not callable(previous_handler): + # This doesn't catch default handler, but SIGCHLD default handler is a + # no-op. + previous_handler = None + + def handler(signum, frame): + # This following call uses `waitid` with WNOHANG from C side. Therefore, + # Python can still get and update the process status successfully. + _error_if_any_worker_fails() + if previous_handler is not None: + previous_handler(signum, frame) + + signal.signal(signal.SIGCHLD, handler) + _SIGCHLD_handler_set = True + + +_python_exit_status = False +r"""Whether Python is shutting down. This flag is guaranteed to be set before +the Python core library resources are freed, but Python may already be exiting +for some time when this is set. + +Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar +hook in Python 3.7 multiprocessing library: +https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327 +""" - if batch_sampler is None: - if sampler is None: - if shuffle: - sampler = RandomSampler(dataset) - else: - sampler = SequentialSampler(dataset) - batch_sampler = BatchSampler(sampler, batch_size, drop_last) - self.sampler = sampler - self.batch_sampler = batch_sampler - self.__initialized = True +def _set_python_exit_flag(): + global _python_exit_status + _python_exit_status = True - def __setattr__(self, attr, val): - if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'): - raise ValueError('{} attribute should not be set after {} is ' - 'initialized'.format(attr, self.__class__.__name__)) - - super(DataLoader, self).__setattr__(attr, val) - - def __iter__(self): - return _DataLoaderIter(self) - - def __len__(self): - return len(self.batch_sampler) +atexit.register(_set_python_exit_flag) class _DataLoaderIter(object): @@ -184,13 +354,12 @@ class _DataLoaderIter(object): # Therefore, in this case, we actually need to prevent `__del__` from # being executed, and rely on the automatic termination of daemonic # children. Thus, we register an `atexit` hook that sets a global flag - # `_utils.python_exit_status`. Since `atexit` hooks are executed in the - # reverse order of registration, we are guaranteed that this flag is - # set before library resources we use are freed. (Hooks freeing those - # resources are registered at importing the Python core libraries at - # the top of this file.) So in `__del__`, we check if - # `_utils.python_exit_status` is set or `None` (freed), and perform - # no-op if so. + # `_python_exit_status`. Since `atexit` hooks are executed in reverse + # order of registration, we are guaranteed that this flag is set before + # library resources we use are freed. (Hooks freeing those resources + # are registered at importing the Python core libraries at the top of + # this file.) So in `__del__`, we check if `_python_exit_status` is set + # or `None` (freed), and perform no-op if so. 
# # Another problem with `__del__` is also related to the library cleanup # calls. When a process ends, it shuts the all its daemonic children @@ -250,8 +419,8 @@ class _DataLoaderIter(object): # For `.get()` calls where the sender(s) is not the workers, we # guard them with timeouts, and check the status of the sender # when timeout happens: - # + in the workers, the `_utils.worker.ManagerWatchdog` class - # checks the status of the main process. + # + in the workers, the `ManagerWatchdog` class checks the main + # process status. # + if `pin_memory=True`, when getting from `pin_memory_thread`, # check `pin_memory_thread` status periodically until `.get()` # returns or see that `pin_memory_thread` died. @@ -376,7 +545,7 @@ class _DataLoaderIter(object): index_queue = multiprocessing.Queue() index_queue.cancel_join_thread() w = multiprocessing.Process( - target=_utils.worker._worker_loop, + target=_worker_loop, args=(self.dataset, index_queue, self.worker_result_queue, self.done_event, self.collate_fn, base_seed + i, @@ -395,7 +564,7 @@ class _DataLoaderIter(object): if self.pin_memory: self.data_queue = queue.Queue() pin_memory_thread = threading.Thread( - target=_utils.pin_memory._pin_memory_loop, + target=_pin_memory_loop, args=(self.worker_result_queue, self.data_queue, torch.cuda.current_device(), self.done_event)) pin_memory_thread.daemon = True @@ -406,8 +575,8 @@ class _DataLoaderIter(object): else: self.data_queue = self.worker_result_queue - _utils.signal_handling._update_worker_pids(id(self), tuple(w.pid for w in self.workers)) - _utils.signal_handling._set_SIGCHLD_handler() + _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) + _set_SIGCHLD_handler() self.worker_pids_set = True # prime the prefetch loop @@ -429,7 +598,7 @@ class _DataLoaderIter(object): elif self.pin_memory: while self.pin_memory_thread.is_alive(): try: - return self.data_queue.get(timeout=_utils.MP_STATUS_CHECK_INTERVAL) + return self.data_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) except queue.Empty: continue else: @@ -445,7 +614,7 @@ class _DataLoaderIter(object): indices = next(self.sample_iter) # may raise StopIteration batch = self.collate_fn([self.dataset[i] for i in indices]) if self.pin_memory: - batch = _utils.pin_memory.pin_memory_batch(batch) + batch = pin_memory_batch(batch) return batch # check if the next sample has already been generated @@ -485,7 +654,7 @@ class _DataLoaderIter(object): def _process_next_batch(self, batch): self.rcvd_idx += 1 self._put_indices() - if isinstance(batch, _utils.ExceptionWrapper): + if isinstance(batch, ExceptionWrapper): raise batch.exc_type(batch.exc_msg) return batch @@ -500,8 +669,7 @@ class _DataLoaderIter(object): def _shutdown_workers(self): # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on # the logic of this function. - python_exit_status = _utils.python_exit_status - if python_exit_status is True or python_exit_status is None: + if _python_exit_status is True or _python_exit_status is None: # See (2) of the note. If Python is shutting down, do no-op. return # Normal exit when last reference is gone / iterator is depleted. @@ -511,7 +679,7 @@ class _DataLoaderIter(object): # Removes pids from the C side data structure first so worker # termination afterwards won't trigger false positive error report. 
if self.worker_pids_set: - _utils.signal_handling._remove_worker_pids(id(self)) + _remove_worker_pids(id(self)) self.worker_pids_set = False self.done_event.set() @@ -547,3 +715,108 @@ class _DataLoaderIter(object): def __del__(self): if self.num_workers > 0: self._shutdown_workers() + + +class DataLoader(object): + r""" + Data loader. Combines a dataset and a sampler, and provides + single- or multi-process iterators over the dataset. + + Arguments: + dataset (Dataset): dataset from which to load the data. + batch_size (int, optional): how many samples per batch to load + (default: ``1``). + shuffle (bool, optional): set to ``True`` to have the data reshuffled + at every epoch (default: ``False``). + sampler (Sampler, optional): defines the strategy to draw samples from + the dataset. If specified, ``shuffle`` must be False. + batch_sampler (Sampler, optional): like sampler, but returns a batch of + indices at a time. Mutually exclusive with :attr:`batch_size`, + :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`. + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means that the data will be loaded in the main process. + (default: ``0``) + collate_fn (callable, optional): merges a list of samples to form a mini-batch. + pin_memory (bool, optional): If ``True``, the data loader will copy tensors + into CUDA pinned memory before returning them. + drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If ``False`` and + the size of dataset is not divisible by the batch size, then the last batch + will be smaller. (default: ``False``) + timeout (numeric, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative. (default: ``0``) + worker_init_fn (callable, optional): If not ``None``, this will be called on each + worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as + input, after seeding and before data loading. (default: ``None``) + + .. note:: By default, each worker will have its PyTorch seed set to + ``base_seed + worker_id``, where ``base_seed`` is a long generated + by main process using its RNG. However, seeds for other libraies + may be duplicated upon initializing workers (w.g., NumPy), causing + each worker to return identical random numbers. (See + :ref:`dataloader-workers-random-seed` section in FAQ.) You may + use :func:`torch.initial_seed()` to access the PyTorch seed for + each worker in :attr:`worker_init_fn`, and use it to set other + seeds before data loading. + + .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an + unpicklable object, e.g., a lambda function. 
+ """ + + __initialized = False + + def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, + num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, + timeout=0, worker_init_fn=None): + self.dataset = dataset + self.batch_size = batch_size + self.num_workers = num_workers + self.collate_fn = collate_fn + self.pin_memory = pin_memory + self.drop_last = drop_last + self.timeout = timeout + self.worker_init_fn = worker_init_fn + + if timeout < 0: + raise ValueError('timeout option should be non-negative') + + if batch_sampler is not None: + if batch_size > 1 or shuffle or sampler is not None or drop_last: + raise ValueError('batch_sampler option is mutually exclusive ' + 'with batch_size, shuffle, sampler, and ' + 'drop_last') + self.batch_size = None + self.drop_last = None + + if sampler is not None and shuffle: + raise ValueError('sampler option is mutually exclusive with ' + 'shuffle') + + if self.num_workers < 0: + raise ValueError('num_workers option cannot be negative; ' + 'use num_workers=0 to disable multiprocessing.') + + if batch_sampler is None: + if sampler is None: + if shuffle: + sampler = RandomSampler(dataset) + else: + sampler = SequentialSampler(dataset) + batch_sampler = BatchSampler(sampler, batch_size, drop_last) + + self.sampler = sampler + self.batch_sampler = batch_sampler + self.__initialized = True + + def __setattr__(self, attr, val): + if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'): + raise ValueError('{} attribute should not be set after {} is ' + 'initialized'.format(attr, self.__class__.__name__)) + + super(DataLoader, self).__setattr__(attr, val) + + def __iter__(self): + return _DataLoaderIter(self) + + def __len__(self): + return len(self.batch_sampler) -- 2.7.4
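
---

The DataLoader docstring restored by this revert notes that each worker's PyTorch seed is set to ``base_seed + worker_id`` while seeds for other libraries (e.g. NumPy) are not reseeded, and suggests using ``torch.initial_seed()`` inside ``worker_init_fn`` to derive per-worker seeds. A minimal sketch of that pattern is below; the ``seed_worker`` helper and the toy ``TensorDataset`` are illustrative only and are not part of this patch.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    # torch.initial_seed() already reflects base_seed + worker_id inside each
    # worker process; fold it into the other RNGs so NumPy / random do not
    # produce identical streams in every worker.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


dataset = TensorDataset(torch.randn(8, 3))
loader = DataLoader(dataset, batch_size=2, num_workers=2,
                    worker_init_fn=seed_worker)

for batch, in loader:
    print(batch.shape)  # torch.Size([2, 3])
```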
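After this revert, ``default_collate`` lives in ``torch.utils.data.dataloader`` again (the test changes above import it from there). As a hedged illustration of the behavior implemented in the restored function, the snippet below shows how tensors are stacked along a new batch dimension, how mappings are collated per key, and how Python numbers become tensors; the sample values are made up for demonstration.

```python
import torch
from torch.utils.data.dataloader import default_collate

# Tensors are stacked along a new outer batch dimension.
batch = [torch.zeros(3), torch.ones(3)]
print(default_collate(batch).shape)      # torch.Size([2, 3])

# Mappings are collated per key; ints become a LongTensor.
samples = [{'x': torch.zeros(3), 'y': 1},
           {'x': torch.ones(3), 'y': 2}]
out = default_collate(samples)
print(out['x'].shape)                    # torch.Size([2, 3])
print(out['y'])                          # tensor([1, 2])

# Unsupported element types (e.g. arbitrary objects) raise TypeError,
# which is what the test changes above assert.
```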