-import random
+r"""Definition of the DataLoader and its iterator _DataLoaderIter classes.
+
+To support these two classes, in `./_utils` we define many utility methods and
+functions to be run in multiprocessing. E.g., the data loading worker loop is
+in `./_utils/worker.py`.
+"""
+
import torch
import torch.multiprocessing as multiprocessing
-from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
- _remove_worker_pids, _error_if_any_worker_fails
from . import SequentialSampler, RandomSampler, BatchSampler
-import signal
-import functools
-from torch._six import container_abcs
-import re
-import sys
+from . import _utils
import threading
-import traceback
-import os
-import time
-import atexit
-from torch._six import string_classes, int_classes, FileNotFoundError
-
-IS_WINDOWS = sys.platform == "win32"
-if IS_WINDOWS:
- import ctypes
- from ctypes.wintypes import DWORD, BOOL, HANDLE
-
-if sys.version_info[0] == 2:
- import Queue as queue
-else:
- import queue
-
-
-# NOTE [ Python Traceback Reference Cycle Problem ]
-#
-# When using sys.exc_info(), it is important to **not** store the exc_info[2],
-# which is the traceback, because otherwise you will run into the traceback
-# reference cycle problem, i.e., the traceback holding reference to the frame,
-# and the frame (which holds reference to all the object in its temporary scope)
-# holding reference the traceback.
-
-
-class ExceptionWrapper(object):
- r"""Wraps an exception plus traceback to communicate across threads"""
- def __init__(self, exc_info):
- # It is important that we don't store exc_info, see
- # NOTE [ Python Traceback Reference Cycle Problem ]
- self.exc_type = exc_info[0]
- self.exc_msg = "".join(traceback.format_exception(*exc_info))
-
-
-_use_shared_memory = False
-r"""Whether to use shared memory in default_collate"""
-
-MP_STATUS_CHECK_INTERVAL = 5.0
-r"""Interval (in seconds) to check status of processes to avoid hanging in
- multiprocessing data loading. This is mainly used in getting data from
- another process, in which case we need to periodically check whether the
- sender is alive to prevent hanging."""
-
-if IS_WINDOWS:
- # On Windows, the parent ID of the worker process remains unchanged when the manager process
- # is gone, and the only way to check it through OS is to let the worker have a process handle
- # of the manager and ask if the process status has changed.
- class ManagerWatchdog(object):
- def __init__(self):
- self.manager_pid = os.getppid()
-
- self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
- self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
- self.kernel32.OpenProcess.restype = HANDLE
- self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
- self.kernel32.WaitForSingleObject.restype = DWORD
-
- # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
- SYNCHRONIZE = 0x00100000
- self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)
-
- if not self.manager_handle:
- raise ctypes.WinError(ctypes.get_last_error())
-
- self.manager_dead = False
-
- def is_alive(self):
- if not self.manager_dead:
- # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
- self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
- return not self.manager_dead
-else:
- class ManagerWatchdog(object):
- def __init__(self):
- self.manager_pid = os.getppid()
- self.manager_dead = False
-
- def is_alive(self):
- if not self.manager_dead:
- self.manager_dead = os.getppid() != self.manager_pid
- return not self.manager_dead
-
-
-def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id):
- # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
- # logic of this function.
-
- try:
- global _use_shared_memory
- _use_shared_memory = True
-
- # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
- # module's handlers are executed after Python returns from C low-level
- # handlers, likely when the same fatal signal had already happened.
- # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
- _set_worker_signal_handlers()
-
- torch.set_num_threads(1)
- random.seed(seed)
- torch.manual_seed(seed)
-
- data_queue.cancel_join_thread()
-
- if init_fn is not None:
- init_fn(worker_id)
-
- watchdog = ManagerWatchdog()
-
- while watchdog.is_alive():
- try:
- r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
- except queue.Empty:
- continue
- if r is None:
- # Received the final signal
- assert done_event.is_set()
- return
- elif done_event.is_set():
- # Done event is set. But I haven't received the final signal
- # (None) yet. I will keep continuing until get it, and skip the
- # processing steps.
- continue
- idx, batch_indices = r
- try:
- samples = collate_fn([dataset[i] for i in batch_indices])
- except Exception:
- # It is important that we don't store exc_info in a variable,
- # see NOTE [ Python Traceback Reference Cycle Problem ]
- data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
- else:
- data_queue.put((idx, samples))
- del samples
- except KeyboardInterrupt:
- # Main process will raise KeyboardInterrupt anyways.
- pass
-
-
-def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
- torch.cuda.set_device(device_id)
-
- # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
- # logic of this function.
- while True:
- try:
- r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
- except queue.Empty:
- continue
- except Exception:
- if done_event.is_set():
- # Weird things can happen when shutting down, e.g., fd being
- # closed when tensors are shared via fds.
- break
- raise
- if r is None:
- assert done_event.is_set()
- return
- elif done_event.is_set():
- # Haven't seen the final signal yet. Keep getting until None.
- continue
- elif isinstance(r[1], ExceptionWrapper):
- out_queue.put(r)
- else:
- idx, batch = r
- try:
- batch = pin_memory_batch(batch)
- except Exception:
- out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
- else:
- out_queue.put((idx, batch))
-
-numpy_type_map = {
- 'float64': torch.DoubleTensor,
- 'float32': torch.FloatTensor,
- 'float16': torch.HalfTensor,
- 'int64': torch.LongTensor,
- 'int32': torch.IntTensor,
- 'int16': torch.ShortTensor,
- 'int8': torch.CharTensor,
- 'uint8': torch.ByteTensor,
-}
-
-
-def default_collate(batch):
- r"""Puts each data field into a tensor with outer dimension batch size"""
-
- error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
- elem_type = type(batch[0])
- if isinstance(batch[0], torch.Tensor):
- out = None
- if _use_shared_memory:
- # If we're in a background process, concatenate directly into a
- # shared memory tensor to avoid an extra copy
- numel = sum([x.numel() for x in batch])
- storage = batch[0].storage()._new_shared(numel)
- out = batch[0].new(storage)
- return torch.stack(batch, 0, out=out)
- elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
- and elem_type.__name__ != 'string_':
- elem = batch[0]
- if elem_type.__name__ == 'ndarray':
- # array of string classes and object
- if re.search('[SaUO]', elem.dtype.str) is not None:
- raise TypeError(error_msg.format(elem.dtype))
-
- return torch.stack([torch.from_numpy(b) for b in batch], 0)
- if elem.shape == (): # scalars
- py_type = float if elem.dtype.name.startswith('float') else int
- return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
- elif isinstance(batch[0], int_classes):
- return torch.LongTensor(batch)
- elif isinstance(batch[0], float):
- return torch.DoubleTensor(batch)
- elif isinstance(batch[0], string_classes):
- return batch
- elif isinstance(batch[0], container_abcs.Mapping):
- return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
- elif isinstance(batch[0], container_abcs.Sequence):
- transposed = zip(*batch)
- return [default_collate(samples) for samples in transposed]
+from torch._six import queue
- raise TypeError((error_msg.format(type(batch[0]))))
+# This function used to be defined in this file. However, it was moved to
+# _utils/collate.py. Although it is rather hard to access this from user land
+# (one has to explicitly directly `import torch.utils.data.dataloader`), there
+# probably is user code out there using it. This aliasing maintains BC in this
+# aspect.
+default_collate = _utils.collate.default_collate
-def pin_memory_batch(batch):
- if isinstance(batch, torch.Tensor):
- return batch.pin_memory()
- elif isinstance(batch, string_classes):
- return batch
- elif isinstance(batch, container_abcs.Mapping):
- return {k: pin_memory_batch(sample) for k, sample in batch.items()}
- elif isinstance(batch, container_abcs.Sequence):
- return [pin_memory_batch(sample) for sample in batch]
- else:
- return batch
+class DataLoader(object):
+ r"""
+ Data loader. Combines a dataset and a sampler, and provides
+ single- or multi-process iterators over the dataset.
-_SIGCHLD_handler_set = False
-r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
-handler needs to be set for all DataLoaders in a process."""
-
-
-def _set_SIGCHLD_handler():
- # Windows doesn't support SIGCHLD handler
- if sys.platform == 'win32':
- return
- # can't set signal in child threads
- if not isinstance(threading.current_thread(), threading._MainThread):
- return
- global _SIGCHLD_handler_set
- if _SIGCHLD_handler_set:
- return
- previous_handler = signal.getsignal(signal.SIGCHLD)
- if not callable(previous_handler):
- # This doesn't catch default handler, but SIGCHLD default handler is a
- # no-op.
- previous_handler = None
-
- def handler(signum, frame):
- # This following call uses `waitid` with WNOHANG from C side. Therefore,
- # Python can still get and update the process status successfully.
- _error_if_any_worker_fails()
- if previous_handler is not None:
- previous_handler(signum, frame)
-
- signal.signal(signal.SIGCHLD, handler)
- _SIGCHLD_handler_set = True
-
-
-_python_exit_status = False
-r"""Whether Python is shutting down. This flag is guaranteed to be set before
-the Python core library resources are freed, but Python may already be exiting
-for some time when this is set.
-
-Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar
-hook in Python 3.7 multiprocessing library:
-https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327
-"""
+ Arguments:
+ dataset (Dataset): dataset from which to load the data.
+ batch_size (int, optional): how many samples per batch to load
+ (default: ``1``).
+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
+ at every epoch (default: ``False``).
+ sampler (Sampler, optional): defines the strategy to draw samples from
+ the dataset. If specified, ``shuffle`` must be False.
+ batch_sampler (Sampler, optional): like sampler, but returns a batch of
+ indices at a time. Mutually exclusive with :attr:`batch_size`,
+ :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
+ num_workers (int, optional): how many subprocesses to use for data
+ loading. 0 means that the data will be loaded in the main process.
+ (default: ``0``)
+ collate_fn (callable, optional): merges a list of samples to form a mini-batch.
+ pin_memory (bool, optional): If ``True``, the data loader will copy tensors
+ into CUDA pinned memory before returning them.
+ drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
+ if the dataset size is not divisible by the batch size. If ``False`` and
+ the size of dataset is not divisible by the batch size, then the last batch
+ will be smaller. (default: ``False``)
+ timeout (numeric, optional): if positive, the timeout value for collecting a batch
+ from workers. Should always be non-negative. (default: ``0``)
+ worker_init_fn (callable, optional): If not ``None``, this will be called on each
+ worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
+ input, after seeding and before data loading. (default: ``None``)
+
+ .. note:: By default, each worker will have its PyTorch seed set to
+ ``base_seed + worker_id``, where ``base_seed`` is a long generated
+ by main process using its RNG. However, seeds for other libraries
+ may be duplicated upon initializing workers (e.g., NumPy), causing
+ each worker to return identical random numbers. (See
+ :ref:`dataloader-workers-random-seed` section in FAQ.) You may
+ use :func:`torch.initial_seed()` to access the PyTorch seed for
+ each worker in :attr:`worker_init_fn`, and use it to set other
+ seeds before data loading.
+
+ .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
+ unpicklable object, e.g., a lambda function.
+ """
+
+ __initialized = False
+
+ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
+ batch_sampler=None, num_workers=0, collate_fn=default_collate,
+ pin_memory=False, drop_last=False, timeout=0,
+ worker_init_fn=None):
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.collate_fn = collate_fn
+ self.pin_memory = pin_memory
+ self.drop_last = drop_last
+ self.timeout = timeout
+ self.worker_init_fn = worker_init_fn
+
+ if timeout < 0:
+ raise ValueError('timeout option should be non-negative')
+
+ if batch_sampler is not None:
+ if batch_size > 1 or shuffle or sampler is not None or drop_last:
+ raise ValueError('batch_sampler option is mutually exclusive '
+ 'with batch_size, shuffle, sampler, and '
+ 'drop_last')
+ self.batch_size = None
+ self.drop_last = None
+ if sampler is not None and shuffle:
+ raise ValueError('sampler option is mutually exclusive with '
+ 'shuffle')
-def _set_python_exit_flag():
- global _python_exit_status
- _python_exit_status = True
+ if self.num_workers < 0:
+ raise ValueError('num_workers option cannot be negative; '
+ 'use num_workers=0 to disable multiprocessing.')
-atexit.register(_set_python_exit_flag)
+ if batch_sampler is None:
+ if sampler is None:
+ if shuffle:
+ sampler = RandomSampler(dataset)
+ else:
+ sampler = SequentialSampler(dataset)
+ batch_sampler = BatchSampler(sampler, batch_size, drop_last)
+
+ self.sampler = sampler
+ self.batch_sampler = batch_sampler
+ self.__initialized = True
+
+ def __setattr__(self, attr, val):
+ if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
+ raise ValueError('{} attribute should not be set after {} is '
+ 'initialized'.format(attr, self.__class__.__name__))
+
+ super(DataLoader, self).__setattr__(attr, val)
+
+ def __iter__(self):
+ return _DataLoaderIter(self)
+
+ def __len__(self):
+ return len(self.batch_sampler)
class _DataLoaderIter(object):
# Therefore, in this case, we actually need to prevent `__del__` from
# being executed, and rely on the automatic termination of daemonic
# children. Thus, we register an `atexit` hook that sets a global flag
- # `_python_exit_status`. Since `atexit` hooks are executed in reverse
- # order of registration, we are guaranteed that this flag is set before
- # library resources we use are freed. (Hooks freeing those resources
- # are registered at importing the Python core libraries at the top of
- # this file.) So in `__del__`, we check if `_python_exit_status` is set
- # or `None` (freed), and perform no-op if so.
+ # `_utils.python_exit_status`. Since `atexit` hooks are executed in the
+ # reverse order of registration, we are guaranteed that this flag is
+ # set before library resources we use are freed. (Hooks freeing those
+ # resources are registered at importing the Python core libraries at
+ # the top of this file.) So in `__del__`, we check if
+ # `_utils.python_exit_status` is set or `None` (freed), and perform
+ # no-op if so.
#
# Another problem with `__del__` is also related to the library cleanup
# calls. When a process ends, it shuts the all its daemonic children
# For `.get()` calls where the sender(s) is not the workers, we
# guard them with timeouts, and check the status of the sender
# when timeout happens:
- # + in the workers, the `ManagerWatchdog` class checks the main
- # process status.
+ # + in the workers, the `_utils.worker.ManagerWatchdog` class
+ # checks the status of the main process.
# + if `pin_memory=True`, when getting from `pin_memory_thread`,
# check `pin_memory_thread` status periodically until `.get()`
# returns or see that `pin_memory_thread` died.
index_queue = multiprocessing.Queue()
index_queue.cancel_join_thread()
w = multiprocessing.Process(
- target=_worker_loop,
+ target=_utils.worker._worker_loop,
args=(self.dataset, index_queue,
self.worker_result_queue, self.done_event,
self.collate_fn, base_seed + i,
if self.pin_memory:
self.data_queue = queue.Queue()
pin_memory_thread = threading.Thread(
- target=_pin_memory_loop,
+ target=_utils.pin_memory._pin_memory_loop,
args=(self.worker_result_queue, self.data_queue,
torch.cuda.current_device(), self.done_event))
pin_memory_thread.daemon = True
else:
self.data_queue = self.worker_result_queue
- _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
- _set_SIGCHLD_handler()
+ _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
+ _utils.signal_handling._set_SIGCHLD_handler()
self.worker_pids_set = True
# prime the prefetch loop
elif self.pin_memory:
while self.pin_memory_thread.is_alive():
try:
- return self.data_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+ return self.data_queue.get(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
else:
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
- batch = pin_memory_batch(batch)
+ batch = _utils.pin_memory.pin_memory_batch(batch)
return batch
# check if the next sample has already been generated
def _process_next_batch(self, batch):
self.rcvd_idx += 1
self._put_indices()
- if isinstance(batch, ExceptionWrapper):
+ if isinstance(batch, _utils.ExceptionWrapper):
raise batch.exc_type(batch.exc_msg)
return batch
def _shutdown_workers(self):
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
# the logic of this function.
- if _python_exit_status is True or _python_exit_status is None:
+ python_exit_status = _utils.python_exit_status
+ if python_exit_status is True or python_exit_status is None:
# See (2) of the note. If Python is shutting down, do no-op.
return
# Normal exit when last reference is gone / iterator is depleted.
# Removes pids from the C side data structure first so worker
# termination afterwards won't trigger false positive error report.
if self.worker_pids_set:
- _remove_worker_pids(id(self))
+ _utils.signal_handling._remove_worker_pids(id(self))
self.worker_pids_set = False
self.done_event.set()
def __del__(self):
if self.num_workers > 0:
self._shutdown_workers()
-
-
-class DataLoader(object):
- r"""
- Data loader. Combines a dataset and a sampler, and provides
- single- or multi-process iterators over the dataset.
-
- Arguments:
- dataset (Dataset): dataset from which to load the data.
- batch_size (int, optional): how many samples per batch to load
- (default: ``1``).
- shuffle (bool, optional): set to ``True`` to have the data reshuffled
- at every epoch (default: ``False``).
- sampler (Sampler, optional): defines the strategy to draw samples from
- the dataset. If specified, ``shuffle`` must be False.
- batch_sampler (Sampler, optional): like sampler, but returns a batch of
- indices at a time. Mutually exclusive with :attr:`batch_size`,
- :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
- num_workers (int, optional): how many subprocesses to use for data
- loading. 0 means that the data will be loaded in the main process.
- (default: ``0``)
- collate_fn (callable, optional): merges a list of samples to form a mini-batch.
- pin_memory (bool, optional): If ``True``, the data loader will copy tensors
- into CUDA pinned memory before returning them.
- drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
- if the dataset size is not divisible by the batch size. If ``False`` and
- the size of dataset is not divisible by the batch size, then the last batch
- will be smaller. (default: ``False``)
- timeout (numeric, optional): if positive, the timeout value for collecting a batch
- from workers. Should always be non-negative. (default: ``0``)
- worker_init_fn (callable, optional): If not ``None``, this will be called on each
- worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
- input, after seeding and before data loading. (default: ``None``)
-
- .. note:: By default, each worker will have its PyTorch seed set to
- ``base_seed + worker_id``, where ``base_seed`` is a long generated
- by main process using its RNG. However, seeds for other libraies
- may be duplicated upon initializing workers (w.g., NumPy), causing
- each worker to return identical random numbers. (See
- :ref:`dataloader-workers-random-seed` section in FAQ.) You may
- use :func:`torch.initial_seed()` to access the PyTorch seed for
- each worker in :attr:`worker_init_fn`, and use it to set other
- seeds before data loading.
-
- .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
- unpicklable object, e.g., a lambda function.
- """
-
- __initialized = False
-
- def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
- num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
- timeout=0, worker_init_fn=None):
- self.dataset = dataset
- self.batch_size = batch_size
- self.num_workers = num_workers
- self.collate_fn = collate_fn
- self.pin_memory = pin_memory
- self.drop_last = drop_last
- self.timeout = timeout
- self.worker_init_fn = worker_init_fn
-
- if timeout < 0:
- raise ValueError('timeout option should be non-negative')
-
- if batch_sampler is not None:
- if batch_size > 1 or shuffle or sampler is not None or drop_last:
- raise ValueError('batch_sampler option is mutually exclusive '
- 'with batch_size, shuffle, sampler, and '
- 'drop_last')
- self.batch_size = None
- self.drop_last = None
-
- if sampler is not None and shuffle:
- raise ValueError('sampler option is mutually exclusive with '
- 'shuffle')
-
- if self.num_workers < 0:
- raise ValueError('num_workers option cannot be negative; '
- 'use num_workers=0 to disable multiprocessing.')
-
- if batch_sampler is None:
- if sampler is None:
- if shuffle:
- sampler = RandomSampler(dataset)
- else:
- sampler = SequentialSampler(dataset)
- batch_sampler = BatchSampler(sampler, batch_size, drop_last)
-
- self.sampler = sampler
- self.batch_sampler = batch_sampler
- self.__initialized = True
-
- def __setattr__(self, attr, val):
- if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
- raise ValueError('{} attribute should not be set after {} is '
- 'initialized'.format(attr, self.__class__.__name__))
-
- super(DataLoader, self).__setattr__(attr, val)
-
- def __iter__(self):
- return _DataLoaderIter(self)
-
- def __len__(self):
- return len(self.batch_sampler)