-r"""Definition of the DataLoader and it's iterator _DataLoaderIter classes.
-
-To support these two classes, in `./_utils` we define many utility methods and
-functions to be run in multiprocessing. E.g., the data loading worker loop is
-in `./_utils/worker.py`.
-"""
-
+import random
import torch
import torch.multiprocessing as multiprocessing
+from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
+ _remove_worker_pids, _error_if_any_worker_fails
from . import SequentialSampler, RandomSampler, BatchSampler
-from . import _utils
+import signal
+import functools
+from torch._six import container_abcs
+import re
+import sys
import threading
-from torch._six import queue
-
-
-# This function used to be defined in this file. However, it was moved to
-# _utils/collate.py. Although it is rather hard to access this from user land
-# (one has to explicitly directly `import torch.utils.data.dataloader`), there
-# probably is user code out there using it. This aliasing maintains BC in this
-# aspect.
-default_collate = _utils.collate.default_collate
-
-
-class DataLoader(object):
- r"""
- Data loader. Combines a dataset and a sampler, and provides
- single- or multi-process iterators over the dataset.
-
- Arguments:
- dataset (Dataset): dataset from which to load the data.
- batch_size (int, optional): how many samples per batch to load
- (default: ``1``).
- shuffle (bool, optional): set to ``True`` to have the data reshuffled
- at every epoch (default: ``False``).
- sampler (Sampler, optional): defines the strategy to draw samples from
- the dataset. If specified, ``shuffle`` must be False.
- batch_sampler (Sampler, optional): like sampler, but returns a batch of
- indices at a time. Mutually exclusive with :attr:`batch_size`,
- :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
- num_workers (int, optional): how many subprocesses to use for data
- loading. 0 means that the data will be loaded in the main process.
- (default: ``0``)
- collate_fn (callable, optional): merges a list of samples to form a mini-batch.
- pin_memory (bool, optional): If ``True``, the data loader will copy tensors
- into CUDA pinned memory before returning them.
- drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
- if the dataset size is not divisible by the batch size. If ``False`` and
- the size of dataset is not divisible by the batch size, then the last batch
- will be smaller. (default: ``False``)
- timeout (numeric, optional): if positive, the timeout value for collecting a batch
- from workers. Should always be non-negative. (default: ``0``)
- worker_init_fn (callable, optional): If not ``None``, this will be called on each
- worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
- input, after seeding and before data loading. (default: ``None``)
-
- .. note:: By default, each worker will have its PyTorch seed set to
- ``base_seed + worker_id``, where ``base_seed`` is a long generated
- by main process using its RNG. However, seeds for other libraies
- may be duplicated upon initializing workers (w.g., NumPy), causing
- each worker to return identical random numbers. (See
- :ref:`dataloader-workers-random-seed` section in FAQ.) You may
- use :func:`torch.initial_seed()` to access the PyTorch seed for
- each worker in :attr:`worker_init_fn`, and use it to set other
- seeds before data loading.
-
- .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
- unpicklable object, e.g., a lambda function.
- """
-
- __initialized = False
+import traceback
+import os
+import time
+import atexit
+from torch._six import string_classes, int_classes, FileNotFoundError
+
+IS_WINDOWS = sys.platform == "win32"
+if IS_WINDOWS:
+ import ctypes
+ from ctypes.wintypes import DWORD, BOOL, HANDLE
+
+if sys.version_info[0] == 2:
+ import Queue as queue
+else:
+ import queue
+
+
+# NOTE [ Python Traceback Reference Cycle Problem ]
+#
+# When using sys.exc_info(), it is important to **not** store exc_info[2],
+# which is the traceback, because otherwise you will run into the traceback
+# reference cycle problem, i.e., the traceback holding a reference to the frame,
+# and the frame (which holds references to all the objects in its temporary
+# scope) holding a reference to the traceback.
+
+
+class ExceptionWrapper(object):
+ r"""Wraps an exception plus traceback to communicate across threads"""
+ def __init__(self, exc_info):
+ # It is important that we don't store exc_info, see
+ # NOTE [ Python Traceback Reference Cycle Problem ]
+ self.exc_type = exc_info[0]
+ self.exc_msg = "".join(traceback.format_exception(*exc_info))
+
+
+_use_shared_memory = False
+r"""Whether to use shared memory in default_collate"""
+
+MP_STATUS_CHECK_INTERVAL = 5.0
+r"""Interval (in seconds) to check status of processes to avoid hanging in
+ multiprocessing data loading. This is mainly used in getting data from
+ another process, in which case we need to periodically check whether the
+ sender is alive to prevent hanging."""
+
+if IS_WINDOWS:
+ # On Windows, the parent ID of the worker process remains unchanged when the manager process
+ # is gone, and the only way to check it through OS is to let the worker have a process handle
+ # of the manager and ask if the process status has changed.
+ class ManagerWatchdog(object):
+ def __init__(self):
+ self.manager_pid = os.getppid()
+
+ self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
+ self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
+ self.kernel32.OpenProcess.restype = HANDLE
+ self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
+ self.kernel32.WaitForSingleObject.restype = DWORD
+
+ # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
+ SYNCHRONIZE = 0x00100000
+ self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)
+
+ if not self.manager_handle:
+ raise ctypes.WinError(ctypes.get_last_error())
+
+ self.manager_dead = False
+
+ def is_alive(self):
+ if not self.manager_dead:
+ # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
+ self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
+ return not self.manager_dead
+else:
+ class ManagerWatchdog(object):
+ def __init__(self):
+ self.manager_pid = os.getppid()
+ self.manager_dead = False
+
+ def is_alive(self):
+ if not self.manager_dead:
+ self.manager_dead = os.getppid() != self.manager_pid
+ return not self.manager_dead
+
+
+def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id):
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
+ # logic of this function.
+
+ try:
+ global _use_shared_memory
+ _use_shared_memory = True
+
+ # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
+ # module's handlers are executed after Python returns from C low-level
+ # handlers, likely when the same fatal signal happened again already.
+ # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
+ _set_worker_signal_handlers()
+
+ torch.set_num_threads(1)
+ random.seed(seed)
+ torch.manual_seed(seed)
+
+ data_queue.cancel_join_thread()
+
+ if init_fn is not None:
+ init_fn(worker_id)
+
+ watchdog = ManagerWatchdog()
+
+ while watchdog.is_alive():
+ try:
+ r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+ except queue.Empty:
+ continue
+ if r is None:
+ # Received the final signal
+ assert done_event.is_set()
+ return
+ elif done_event.is_set():
+ # Done event is set, but I haven't received the final signal
+ # (None) yet. I will keep looping until I get it, skipping the
+ # processing steps.
+ continue
+ idx, batch_indices = r
+ try:
+ samples = collate_fn([dataset[i] for i in batch_indices])
+ except Exception:
+ # It is important that we don't store exc_info in a variable,
+ # see NOTE [ Python Traceback Reference Cycle Problem ]
+ data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
+ else:
+ data_queue.put((idx, samples))
+ del samples
+ except KeyboardInterrupt:
+ # Main process will raise KeyboardInterrupt anyways.
+ pass
+
+
+def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
+ torch.cuda.set_device(device_id)
+
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
+ # logic of this function.
+ while True:
+ try:
+ r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+ except queue.Empty:
+ continue
+ except Exception:
+ if done_event.is_set():
+ # Weird things can happen when shutting down, e.g., fd being
+ # closed when tensors are shared via fds.
+ break
+ raise
+ if r is None:
+ assert done_event.is_set()
+ return
+ elif done_event.is_set():
+ # Haven't seen the final signal yet. Keep getting until None.
+ continue
+ elif isinstance(r[1], ExceptionWrapper):
+ out_queue.put(r)
+ else:
+ idx, batch = r
+ try:
+ batch = pin_memory_batch(batch)
+ except Exception:
+ out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
+ else:
+ out_queue.put((idx, batch))
+
+numpy_type_map = {
+ 'float64': torch.DoubleTensor,
+ 'float32': torch.FloatTensor,
+ 'float16': torch.HalfTensor,
+ 'int64': torch.LongTensor,
+ 'int32': torch.IntTensor,
+ 'int16': torch.ShortTensor,
+ 'int8': torch.CharTensor,
+ 'uint8': torch.ByteTensor,
+}
+
+
+def default_collate(batch):
+ r"""Puts each data field into a tensor with outer dimension batch size"""
+
+ error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
+ elem_type = type(batch[0])
+ if isinstance(batch[0], torch.Tensor):
+ out = None
+ if _use_shared_memory:
+ # If we're in a background process, concatenate directly into a
+ # shared memory tensor to avoid an extra copy
+ numel = sum([x.numel() for x in batch])
+ storage = batch[0].storage()._new_shared(numel)
+ out = batch[0].new(storage)
+ return torch.stack(batch, 0, out=out)
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+ and elem_type.__name__ != 'string_':
+ elem = batch[0]
+ if elem_type.__name__ == 'ndarray':
+ # array of string classes and object
+ if re.search('[SaUO]', elem.dtype.str) is not None:
+ raise TypeError(error_msg.format(elem.dtype))
+
+ return torch.stack([torch.from_numpy(b) for b in batch], 0)
+ if elem.shape == (): # scalars
+ py_type = float if elem.dtype.name.startswith('float') else int
+ return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
+ elif isinstance(batch[0], int_classes):
+ return torch.LongTensor(batch)
+ elif isinstance(batch[0], float):
+ return torch.DoubleTensor(batch)
+ elif isinstance(batch[0], string_classes):
+ return batch
+ elif isinstance(batch[0], container_abcs.Mapping):
+ return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
+ elif isinstance(batch[0], container_abcs.Sequence):
+ transposed = zip(*batch)
+ return [default_collate(samples) for samples in transposed]
- def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
- batch_sampler=None, num_workers=0, collate_fn=default_collate,
- pin_memory=False, drop_last=False, timeout=0,
- worker_init_fn=None):
- self.dataset = dataset
- self.batch_size = batch_size
- self.num_workers = num_workers
- self.collate_fn = collate_fn
- self.pin_memory = pin_memory
- self.drop_last = drop_last
- self.timeout = timeout
- self.worker_init_fn = worker_init_fn
+ raise TypeError((error_msg.format(type(batch[0]))))
- if timeout < 0:
- raise ValueError('timeout option should be non-negative')
- if batch_sampler is not None:
- if batch_size > 1 or shuffle or sampler is not None or drop_last:
- raise ValueError('batch_sampler option is mutually exclusive '
- 'with batch_size, shuffle, sampler, and '
- 'drop_last')
- self.batch_size = None
- self.drop_last = None
+def pin_memory_batch(batch):
+ if isinstance(batch, torch.Tensor):
+ return batch.pin_memory()
+ elif isinstance(batch, string_classes):
+ return batch
+ elif isinstance(batch, container_abcs.Mapping):
+ return {k: pin_memory_batch(sample) for k, sample in batch.items()}
+ elif isinstance(batch, container_abcs.Sequence):
+ return [pin_memory_batch(sample) for sample in batch]
+ else:
+ return batch
- if sampler is not None and shuffle:
- raise ValueError('sampler option is mutually exclusive with '
- 'shuffle')
- if self.num_workers < 0:
- raise ValueError('num_workers option cannot be negative; '
- 'use num_workers=0 to disable multiprocessing.')
+_SIGCHLD_handler_set = False
+r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
+handler needs to be set for all DataLoaders in a process."""
+
+
+def _set_SIGCHLD_handler():
+ # Windows doesn't support SIGCHLD handler
+ if sys.platform == 'win32':
+ return
+ # can't set signal in child threads
+ if not isinstance(threading.current_thread(), threading._MainThread):
+ return
+ global _SIGCHLD_handler_set
+ if _SIGCHLD_handler_set:
+ return
+ previous_handler = signal.getsignal(signal.SIGCHLD)
+ if not callable(previous_handler):
+ # This doesn't catch the default handler, but the default SIGCHLD
+ # handler is a no-op.
+ previous_handler = None
+
+ def handler(signum, frame):
+ # The following call uses `waitid` with WNOHANG from the C side. Therefore,
+ # Python can still get and update the process status successfully.
+ _error_if_any_worker_fails()
+ if previous_handler is not None:
+ previous_handler(signum, frame)
+
+ signal.signal(signal.SIGCHLD, handler)
+ _SIGCHLD_handler_set = True
+
+
+_python_exit_status = False
+r"""Whether Python is shutting down. This flag is guaranteed to be set before
+the Python core library resources are freed, but Python may already be exiting
+for some time when this is set.
+
+Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar
+hook in Python 3.7 multiprocessing library:
+https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327
+"""
- if batch_sampler is None:
- if sampler is None:
- if shuffle:
- sampler = RandomSampler(dataset)
- else:
- sampler = SequentialSampler(dataset)
- batch_sampler = BatchSampler(sampler, batch_size, drop_last)
- self.sampler = sampler
- self.batch_sampler = batch_sampler
- self.__initialized = True
+def _set_python_exit_flag():
+ global _python_exit_status
+ _python_exit_status = True
- def __setattr__(self, attr, val):
- if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
- raise ValueError('{} attribute should not be set after {} is '
- 'initialized'.format(attr, self.__class__.__name__))
-
- super(DataLoader, self).__setattr__(attr, val)
-
- def __iter__(self):
- return _DataLoaderIter(self)
-
- def __len__(self):
- return len(self.batch_sampler)
+atexit.register(_set_python_exit_flag)
class _DataLoaderIter(object):
# Therefore, in this case, we actually need to prevent `__del__` from
# being executed, and rely on the automatic termination of daemonic
# children. Thus, we register an `atexit` hook that sets a global flag
- # `_utils.python_exit_status`. Since `atexit` hooks are executed in the
- # reverse order of registration, we are guaranteed that this flag is
- # set before library resources we use are freed. (Hooks freeing those
- # resources are registered at importing the Python core libraries at
- # the top of this file.) So in `__del__`, we check if
- # `_utils.python_exit_status` is set or `None` (freed), and perform
- # no-op if so.
+ # `_python_exit_status`. Since `atexit` hooks are executed in reverse
+ # order of registration, we are guaranteed that this flag is set before
+ # library resources we use are freed. (Hooks freeing those resources
+ # are registered at importing the Python core libraries at the top of
+ # this file.) So in `__del__`, we check if `_python_exit_status` is set
+ # or `None` (freed), and perform no-op if so.
#
# Another problem with `__del__` is also related to the library cleanup
# calls. When a process ends, it shuts the all its daemonic children
# For `.get()` calls where the sender(s) is not the workers, we
# guard them with timeouts, and check the status of the sender
# when timeout happens:
- # + in the workers, the `_utils.worker.ManagerWatchdog` class
- # checks the status of the main process.
+ # + in the workers, the `ManagerWatchdog` class checks the main
+ # process status.
# + if `pin_memory=True`, when getting from `pin_memory_thread`,
# check `pin_memory_thread` status periodically until `.get()`
# returns or see that `pin_memory_thread` died.
index_queue = multiprocessing.Queue()
index_queue.cancel_join_thread()
w = multiprocessing.Process(
- target=_utils.worker._worker_loop,
+ target=_worker_loop,
args=(self.dataset, index_queue,
self.worker_result_queue, self.done_event,
self.collate_fn, base_seed + i,
if self.pin_memory:
self.data_queue = queue.Queue()
pin_memory_thread = threading.Thread(
- target=_utils.pin_memory._pin_memory_loop,
+ target=_pin_memory_loop,
args=(self.worker_result_queue, self.data_queue,
torch.cuda.current_device(), self.done_event))
pin_memory_thread.daemon = True
else:
self.data_queue = self.worker_result_queue
- _utils.signal_handling._update_worker_pids(id(self), tuple(w.pid for w in self.workers))
- _utils.signal_handling._set_SIGCHLD_handler()
+ _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
+ _set_SIGCHLD_handler()
self.worker_pids_set = True
# prime the prefetch loop
elif self.pin_memory:
while self.pin_memory_thread.is_alive():
try:
- return self.data_queue.get(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
+ return self.data_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
else:
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
- batch = _utils.pin_memory.pin_memory_batch(batch)
+ batch = pin_memory_batch(batch)
return batch
# check if the next sample has already been generated
def _process_next_batch(self, batch):
self.rcvd_idx += 1
self._put_indices()
- if isinstance(batch, _utils.ExceptionWrapper):
+ if isinstance(batch, ExceptionWrapper):
raise batch.exc_type(batch.exc_msg)
return batch
def _shutdown_workers(self):
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
# the logic of this function.
- python_exit_status = _utils.python_exit_status
- if python_exit_status is True or python_exit_status is None:
+ if _python_exit_status is True or _python_exit_status is None:
# See (2) of the note. If Python is shutting down, do no-op.
return
# Normal exit when last reference is gone / iterator is depleted.
# Removes pids from the C side data structure first so worker
# termination afterwards won't trigger false positive error report.
if self.worker_pids_set:
- _utils.signal_handling._remove_worker_pids(id(self))
+ _remove_worker_pids(id(self))
self.worker_pids_set = False
self.done_event.set()
def __del__(self):
if self.num_workers > 0:
self._shutdown_workers()
+
+
+class DataLoader(object):
+ r"""
+ Data loader. Combines a dataset and a sampler, and provides
+ single- or multi-process iterators over the dataset.
+
+ Arguments:
+ dataset (Dataset): dataset from which to load the data.
+ batch_size (int, optional): how many samples per batch to load
+ (default: ``1``).
+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
+ at every epoch (default: ``False``).
+ sampler (Sampler, optional): defines the strategy to draw samples from
+ the dataset. If specified, ``shuffle`` must be False.
+ batch_sampler (Sampler, optional): like sampler, but returns a batch of
+ indices at a time. Mutually exclusive with :attr:`batch_size`,
+ :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
+ num_workers (int, optional): how many subprocesses to use for data
+ loading. 0 means that the data will be loaded in the main process.
+ (default: ``0``)
+ collate_fn (callable, optional): merges a list of samples to form a mini-batch.
+ pin_memory (bool, optional): If ``True``, the data loader will copy tensors
+ into CUDA pinned memory before returning them.
+ drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
+ if the dataset size is not divisible by the batch size. If ``False`` and
+ the size of dataset is not divisible by the batch size, then the last batch
+ will be smaller. (default: ``False``)
+ timeout (numeric, optional): if positive, the timeout value for collecting a batch
+ from workers. Should always be non-negative. (default: ``0``)
+ worker_init_fn (callable, optional): If not ``None``, this will be called on each
+ worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
+ input, after seeding and before data loading. (default: ``None``)
+
+ .. note:: By default, each worker will have its PyTorch seed set to
+ ``base_seed + worker_id``, where ``base_seed`` is a long generated
+ by main process using its RNG. However, seeds for other libraries
+ may be duplicated upon initializing workers (e.g., NumPy), causing
+ each worker to return identical random numbers. (See
+ :ref:`dataloader-workers-random-seed` section in FAQ.) You may
+ use :func:`torch.initial_seed()` to access the PyTorch seed for
+ each worker in :attr:`worker_init_fn`, and use it to set other
+ seeds before data loading.
+
+ .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
+ unpicklable object, e.g., a lambda function.
+ """
+
+ __initialized = False
+
+ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
+ num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
+ timeout=0, worker_init_fn=None):
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.collate_fn = collate_fn
+ self.pin_memory = pin_memory
+ self.drop_last = drop_last
+ self.timeout = timeout
+ self.worker_init_fn = worker_init_fn
+
+ if timeout < 0:
+ raise ValueError('timeout option should be non-negative')
+
+ if batch_sampler is not None:
+ if batch_size > 1 or shuffle or sampler is not None or drop_last:
+ raise ValueError('batch_sampler option is mutually exclusive '
+ 'with batch_size, shuffle, sampler, and '
+ 'drop_last')
+ self.batch_size = None
+ self.drop_last = None
+
+ if sampler is not None and shuffle:
+ raise ValueError('sampler option is mutually exclusive with '
+ 'shuffle')
+
+ if self.num_workers < 0:
+ raise ValueError('num_workers option cannot be negative; '
+ 'use num_workers=0 to disable multiprocessing.')
+
+ if batch_sampler is None:
+ if sampler is None:
+ if shuffle:
+ sampler = RandomSampler(dataset)
+ else:
+ sampler = SequentialSampler(dataset)
+ batch_sampler = BatchSampler(sampler, batch_size, drop_last)
+
+ self.sampler = sampler
+ self.batch_sampler = batch_sampler
+ self.__initialized = True
+
+ def __setattr__(self, attr, val):
+ if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
+ raise ValueError('{} attribute should not be set after {} is '
+ 'initialized'.format(attr, self.__class__.__name__))
+
+ super(DataLoader, self).__setattr__(attr, val)
+
+ def __iter__(self):
+ return _DataLoaderIter(self)
+
+ def __len__(self):
+ return len(self.batch_sampler)