IterableDataset,
Subset,
TensorDataset,
+ communication,
_utils
)
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
IS_IN_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, slowTest,
load_tests, TEST_WITH_TSAN, IS_SANDCASTLE)
+
try:
import psutil
HAS_PSUTIL = True
# Should be used as worker_init_fn with TestWorkerInfoDataset.
# See _test_get_worker_info below for usage.
-def test_worker_info_init_fn(worker_id):
+def _test_worker_info_init_fn(worker_id):
worker_info = torch.utils.data.get_worker_info()
assert worker_id == worker_info.id, "worker_init_fn and worker_info should have consistent id"
assert worker_id < worker_info.num_workers, "worker_init_fn and worker_info should have valid id"
dataset = TestWorkerInfoDataset(6, batch_size, num_workers)
dataloader = DataLoader(dataset, batch_size=batch_size,
num_workers=num_workers,
- worker_init_fn=test_worker_info_init_fn)
+ worker_init_fn=_test_worker_info_init_fn)
it = iter(dataloader)
data = []
for d in it:
data = torch.cat(data, 0)
for d in data:
# each `d` is a [worker_id, worker_pid] pair, which is set in
- # test_worker_info_init_fn
+ # _test_worker_info_init_fn
assert d[1] == worker_pids[d[0]]
# get_worker_info returns None in main proc after data loading
assert torch.utils.data.get_worker_info() is None
class TestDataLoader2(TestCase):
@skipIfNoDill
def test_basics(self):
- dp = IterableWrapper(list(range(10)))
+ # TODO(VitalyFedyunin): This test will start breaking if we remove guaranteed order
+ # of traversing workers
+ dp = IterableWrapper(list(range(1000)))
dl = DataLoader(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
dl2 = DataLoader2(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
- self.assertEquals(list(dl), list(dl2))
+ dl2_threading = DataLoader2(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2, parallelism_mode='thread')
+ self.assertEqual(list(dl), list(dl2))
+ self.assertEqual(list(dl), list(dl2_threading))
+
+
+
+@unittest.skipIf(
+    TEST_WITH_TSAN,
+    "Fails with TSAN with the following error: starting new threads after multi-threaded "
+    "fork is not supported. Dying (set die_after_fork=0 to override)")
+class TestDataLoader2_EventLoop(TestCase):
+    """Exercises the queue-based event loop from ``communication`` directly."""
+
+    @skipIfNoDill
+    def test_basic_threading(self):
+        """Serve a 100-element DataPipe from a worker thread and read it all back."""
+        def clean_me(process, req_queue, res_queue):
+            # Ask the worker loop to terminate, consume its reply, then wait
+            # for the thread to exit.
+            req_queue.put(communication.messages.TerminateRequest())
+            _ = res_queue.get()
+            process.join()
+
+        it = list(range(100))
+        numbers_dp = IterableWrapper(it)
+        # Returns the (not yet started) thread, both queues and the thread-local copy of the DataPipe.
+        (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.SpawnThreadForDataPipeline(numbers_dp)
+
+        process.start()
+        # Client-side DataPipe that pulls items from the worker through the queues.
+        local_datapipe = communication.iter.QueueWrapper(
+            communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue))
+
+        actual = list(local_datapipe)
+        clean_me(process, req_queue, res_queue)
+        self.assertEqual(list(range(100)), actual)
class StringDataset(Dataset):
def __init__(self):
runtime_validation_disabled,
)
from torch.utils.data.dataloader_experimental import DataLoader2
-
+from torch.utils.data import communication
__all__ = ['BatchSampler',
'ChainDataset',
'WeightedRandomSampler',
'_DatasetKind',
'argument_validation',
+ 'communication',
'functional_datapipe',
'get_worker_info',
'guaranteed_datapipes_determinism',
--- /dev/null
+from . import eventloop
+from . import iter
+from . import messages
+from . import protocol
+from . import queue
--- /dev/null
+import torch
+import threading
+import pickle
+
+from torch.utils.data import IterDataPipe, communication
+
+
+def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue):
+    """
+    Worker entry point: serve ``source_datapipe`` over a request/response queue pair.
+
+    Runs ``communication.iter.DataPipeBehindQueues`` to exhaustion with blocking
+    request reads. Raises ``Exception`` when ``source_datapipe`` is not an
+    ``IterDataPipe``.
+    """
+    if not isinstance(source_datapipe, IterDataPipe):
+        # Map-style DataPipes (communication.map / MapDataPipeQueueProtocolServer)
+        # are not wired in yet.
+        raise Exception('Only supports IterDataPipe, got', source_datapipe)
+
+    torch.set_num_threads(1)
+    server_protocol = communication.protocol.IterDataPipeQueueProtocolServer(req_queue, res_queue)
+    event_loop = communication.iter.DataPipeBehindQueues(
+        source_datapipe, server_protocol, blocking_request_get=True)
+    # The generator only yields to hand control back; the values are irrelevant here.
+    for _ in event_loop:
+        pass
+
+
+def SpawnProcessForDataPipeline(multiprocessing_ctx, datapipe):
+    """
+    Build (but do not start) a process that serves ``datapipe`` over two queues.
+
+    Returns ``(process, req_queue, res_queue)``; queues and process are created
+    from ``multiprocessing_ctx``.
+    """
+    req_queue, res_queue = multiprocessing_ctx.Queue(), multiprocessing_ctx.Queue()
+    loop_args = (datapipe, req_queue, res_queue)
+    worker = multiprocessing_ctx.Process(target=DataPipeToQueuesLoop, args=loop_args)
+    return worker, req_queue, res_queue
+
+
+def SpawnThreadForDataPipeline(datapipe):
+    """
+    Build (but do not start) a daemon thread that serves a copy of ``datapipe``.
+
+    The DataPipe is round-tripped through ``pickle`` so the thread works on its
+    own copy. Returns ``(thread, req_queue, res_queue, copied_datapipe)``.
+    Raises ``Exception`` when the DataPipe cannot be pickled.
+    """
+    try:
+        thread_local_copy = pickle.loads(pickle.dumps(datapipe))
+    except Exception as e:
+        raise Exception('Unable to pickle DataPipe to make thread local copy', e)
+
+    req_queue = communication.queue.ThreadingQueue()
+    res_queue = communication.queue.ThreadingQueue()
+    worker_thread = threading.Thread(
+        target=DataPipeToQueuesLoop,
+        args=(thread_local_copy, req_queue, res_queue),
+        daemon=True,
+    )
+    return worker_thread, req_queue, res_queue, thread_local_copy
--- /dev/null
+import time
+import types
+
+from torch.utils.data import IterDataPipe, communication
+
+DEFAULT_NON_BLOCKING_SLEEP = 0.001
+
+
+def default_not_available_hook():
+    """Default back-off when data is not available: sleep DEFAULT_NON_BLOCKING_SLEEP seconds."""
+    time.sleep(DEFAULT_NON_BLOCKING_SLEEP)
+
+
+class NotAvailable(Exception):
+    """Raised by ``nonblocking_next`` when no item is ready yet; callers may retry."""
+    pass
+
+
+class InvalidStateResetRequired(Exception):
+    """
+    Raised by a DataPipe when it is expecting to get a reset request,
+    for example RouterDataPipe expecting all workers to request reset.
+    """
+    pass
+
+
+class NonBlocking(IterDataPipe):
+    """
+    Base class for DataPipes that expose a non-blocking ``nonblocking_next``.
+
+    Subclasses implement ``nonblocking_next`` (may raise ``NotAvailable``) and
+    ``reset_iterator``; ``__next__`` turns that into a blocking call by
+    retrying and invoking the class-wide ``not_available_hook`` between tries.
+    """
+    # Shared by all NonBlocking instances; replaced via register_not_available_hook.
+    not_available_hook = default_not_available_hook
+
+    def __iter__(self):
+        # Every new iteration starts from a reset source.
+        self.reset_iterator()
+        return self
+
+    def __next__(self):
+        # Retry until an item arrives or the source is exhausted.
+        while True:
+            try:
+                return self.nonblocking_next()
+            except StopIteration:
+                raise StopIteration
+            except NotAvailable:
+                # A hook of None disables the back-off (busy retry).
+                if NonBlocking.not_available_hook is not None:
+                    NonBlocking.not_available_hook()
+
+    def nonblocking_next(self):
+        """Return the next item, or raise ``NotAvailable`` / ``StopIteration``. Must be overridden."""
+        raise NotImplementedError(
+            "nonblocking_next is not implemented for %s" % self.__class__)
+
+    def reset_iterator(self):
+        """Restart iteration from the beginning. Must be overridden."""
+        raise NotImplementedError(
+            "reset_iterator is not implemented for %s" % self.__class__)
+
+    @staticmethod
+    def register_not_available_hook(hook_function):
+        """Replace the class-wide hook called whenever ``NotAvailable`` is caught."""
+        NonBlocking.not_available_hook = hook_function
+
+
+def EnsureNonBlockingDataPipe(validated_datapipe):
+    """
+    Make sure ``validated_datapipe`` offers ``nonblocking_next`` and ``reset_iterator``.
+
+    ``NonBlocking`` instances are returned untouched. Any other ``IterDataPipe``
+    is patched in place: missing ``_as_iterator``, ``nonblocking_next`` and
+    ``reset_iterator`` attributes are added, and the same object is returned.
+    Raises ``Exception`` when the argument is not an ``IterDataPipe``.
+    """
+    if not isinstance(validated_datapipe, IterDataPipe):
+        raise Exception('Not Iterable DataPipe ' +
+                        str(validated_datapipe.__class__))
+    if isinstance(validated_datapipe, NonBlocking):
+        return validated_datapipe
+    if not hasattr(validated_datapipe, '_as_iterator'):
+        # Lazily created iterator used by the patched nonblocking_next below.
+        validated_datapipe._as_iterator = None  # type: ignore[attr-defined]
+    if not hasattr(validated_datapipe, 'nonblocking_next'):
+        def nonblocking_next(self):
+            # Falls back to plain (blocking) iteration over the DataPipe.
+            if self._as_iterator is None:
+                self._as_iterator = iter(self)
+            return next(self._as_iterator)
+        validated_datapipe.nonblocking_next = types.MethodType(  # type: ignore[attr-defined]
+            nonblocking_next, validated_datapipe)
+    if not hasattr(validated_datapipe, 'reset_iterator'):
+        def reset_iterator(self):
+            # Dropping the iterator makes the next nonblocking_next call iter() again.
+            self._as_iterator = None
+        validated_datapipe.reset_iterator = types.MethodType(  # type: ignore[attr-defined]
+            reset_iterator, validated_datapipe)
+    return validated_datapipe
+
+
+def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_request_get=False):
+    """
+    Generator that serves requests read through ``protocol`` from ``source_datapipe``.
+
+    Loops until a ``TerminateRequest`` is received. It yields ``True`` whenever
+    it hands control back to the caller (no request available, source not
+    ready, or a request was answered).
+    If ``full_stop`` is true, the loop also ends once the source raises
+    ``StopIteration`` or ``InvalidStateResetRequired``.
+    ``blocking_request_get`` is forwarded as ``block`` to ``protocol.get_new_request``.
+    Raises ``Exception`` for a wrong protocol type or an unrecognized request.
+    """
+    if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolServer):
+        raise Exception('Expecting IterDataPipeQueueProtocolServer, got', protocol)
+    source_datapipe = EnsureNonBlockingDataPipe(source_datapipe)
+    forever = True
+    while forever:
+
+        try:
+            # Non-blocking call is extremely slow here for python.mp, need to figure out a good workaround
+            request = protocol.get_new_request(block=blocking_request_get)
+        except communication.protocol.EmptyQueue:
+            # Nothing to serve yet: give control back and poll again later.
+            yield True
+            continue
+
+        if isinstance(request, communication.messages.ResetIteratorRequest):
+            source_datapipe.reset_iterator()
+            protocol.response_reset()
+
+        elif isinstance(request, communication.messages.TerminateRequest):
+            forever = False
+            protocol.response_terminate()
+
+        elif isinstance(request, communication.messages.GetNextRequest):
+            # Keep trying until this single request has been answered.
+            while forever:
+                try:
+                    value = source_datapipe.nonblocking_next()
+                except NotAvailable:
+                    # Source has nothing ready; yield and retry the same request.
+                    yield True
+                    continue
+                except StopIteration:
+                    protocol.response_stop()
+                    if full_stop:
+                        forever = False
+                    else:
+                        yield True
+                    break
+                except InvalidStateResetRequired:
+                    protocol.response_invalid()
+                    if full_stop:
+                        forever = False
+                    else:
+                        yield True
+                    break
+                protocol.response_next(value)
+                yield True  # Returns control
+                break
+        else:
+            raise Exception('Unrecognized type of request received', request)
+
+
+class QueueWrapper(NonBlocking):
+    """
+    Creates iter.DataPipe which reads data from the DataLoader.Queue
+
+    All communication goes through an ``IterDataPipeQueueProtocolClient``;
+    ``response_wait_time`` is passed as the timeout when waiting for an item.
+    """
+
+    def __init__(self, protocol, response_wait_time=0.00001):
+        if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolClient):
+            raise Exception('Got', protocol)
+
+        self.protocol = protocol
+        # NOTE(review): counter is only ever reset in the code visible here — confirm it is used elsewhere.
+        self.counter = 0
+        # Set once a StopIterationResponse arrives; cleared by reset_iterator.
+        self._stop_iteration = False
+        self._response_wait_time = response_wait_time
+
+    def reset_iterator(self):
+        """Send a reset request and wait (polling) until the reset response arrives."""
+        self._stop_iteration = False
+        self.counter = 0
+        self.protocol.request_reset()
+        while True:
+            try:
+                self.protocol.get_response_reset()
+                break
+            except communication.protocol.EmptyQueue:
+                # No reply yet: back off via the shared hook and poll again.
+                if NonBlocking.not_available_hook is not None:
+                    NonBlocking.not_available_hook()
+
+    def nonblocking_next(self):
+        """
+        Return the next value from the remote side.
+
+        Raises ``NotAvailable`` when no response arrived in time or the remote
+        reported an invalid state, ``StopIteration`` when the remote is
+        exhausted, and ``Exception`` when called again after exhaustion
+        without a reset.
+        """
+        if self._stop_iteration:
+            raise Exception(
+                '`next` or `nonblocking_next` called after receiving StopIteration')
+        # Only send a new request when the previous one has been answered.
+        if self.protocol.can_take_request():
+            self.protocol.request_next()
+        try:
+            response = self.protocol.get_response_next(block=True, timeout=self._response_wait_time)
+        except communication.protocol.EmptyQueue:
+            raise NotAvailable
+        if isinstance(response, communication.messages.StopIterationResponse):
+            self._stop_iteration = True
+            raise StopIteration
+        if isinstance(response, communication.messages.InvalidStateResponse):
+            raise NotAvailable
+        return response.value
--- /dev/null
+class DataLoaderQueueMessage(object):
+    """Common base class of every message defined in this module."""
+    pass
+
+
+class Request(DataLoaderQueueMessage):
+    """Base class of the ``*Request`` messages."""
+    pass
+
+
+class Response(DataLoaderQueueMessage):
+    """Base class of the ``*Response`` messages."""
+    pass
+
+
+class ResetIteratorRequest(Request):
+    """Request asking the serving side to reset its source iterator."""
+    pass
+
+
+class ResetIteratorResponse(Response):
+    """Response confirming that a ResetIteratorRequest was handled."""
+    pass
+
+
+class TerminateRequest(Request):
+    """Request asking the serving loop to stop."""
+    pass
+
+
+class TerminateResponse(Response):
+    """Response confirming that a TerminateRequest was handled."""
+    pass
+
+
+class LenRequest(Request):
+    # NOTE(review): presumably asks for the length of the remote DataPipe;
+    # nothing in the visible code sends or handles it yet — confirm.
+    pass
+
+
+class LenResponse(Response):
+    """Response carrying a length value in ``len``."""
+    # A one-element tuple needs the trailing comma; ('len') is just the string 'len'.
+    __slots__ = ('len',)
+
+    def __init__(self, len):
+        self.len = len
+
+
+class GetItemRequest(Request):
+    """Request carrying the ``key`` of the item being asked for."""
+    # A one-element tuple needs the trailing comma; ('key') is just the string 'key'.
+    __slots__ = ('key',)
+
+    def __init__(self, key):
+        self.key = key
+
+
+class GetItemResponse(Response):
+    """Response carrying a ``key`` together with its ``value``."""
+    __slots__ = ('key', 'value')
+
+    def __init__(self, key, value):
+        self.key = key
+        self.value = value
+
+
+class GetNextRequest(Request):
+    """Request asking the serving side for the next item of its source."""
+    pass
+
+
+class GetNextResponse(Response):
+    """Response carrying the next item in ``value``."""
+    # A one-element tuple needs the trailing comma; ('value') is just the string 'value'.
+    __slots__ = ('value',)
+
+    def __init__(self, value):
+        self.value = value
+
+
+class StopIterationResponse(Response):
+    """Response telling the requester that the source is exhausted."""
+    pass
+
+
+class InvalidStateResponse(Response):
+    """
+    Returned by DataPipe when it is expecting to get a reset request,
+    for example RouterDataPipe expecting all workers to request reset.
+    """
+    pass
--- /dev/null
+from torch.utils.data import communication
+
+
+class Protocol(object):
+    """Holds the request/response queue pair shared by client and server protocols."""
+    __slots__ = ('request_queue', 'response_queue')
+
+    def __init__(self, request_queue, response_queue):
+        self.request_queue = request_queue
+        self.response_queue = response_queue
+
+
+class ProtocolClient(Protocol):
+    """
+    ProtocolClient takes charge of putting requests into req_queue and returning results from res_queue.
+
+    It tracks at most one outstanding request in ``_req_sent``.
+    """
+    _req_sent = None
+
+    def __init__(self, request_queue, response_queue):
+        self.request_queue = request_queue
+        self.response_queue = response_queue
+        self._req_sent = None
+
+    def can_take_request(self):
+        """Return True when no request is currently outstanding."""
+        return self._req_sent is None
+
+    def waiting_for_response(self):
+        """Return True while a sent request has not been served yet."""
+        return self._req_sent is not None
+
+    def request_sent(self, request=True):
+        """Record ``request`` as outstanding; raises if one is already outstanding."""
+        if not self.can_take_request():
+            raise Exception('Protocol only supports one request in the Queue')
+        self._req_sent = request
+
+    def request_served(self, result=None):
+        """Mark the outstanding request as served; raises if there was none."""
+        if not self.waiting_for_response():
+            raise Exception(
+                'Expected no pending requests, but something got served', result)
+        self._req_sent = None
+
+
+class ProtocolServer(Protocol):
+    """
+    ProtocolServer takes charge of getting requests from req_queue and fetching data from source datapipe.
+
+    It tracks at most one received-but-unanswered request in ``_req_received``;
+    every ``response_*`` method answers that request and clears it.
+    """
+    _req_received = None
+
+    def __init__(self, request_queue, response_queue):
+        self.request_queue = request_queue
+        self.response_queue = response_queue
+        self._req_received = None
+
+    def have_pending_request(self):
+        """Return True while a received request has not been answered."""
+        return self._req_received is not None
+
+    def _ensure_pending_request(self):
+        """Raise unless there is a received request waiting for a reply."""
+        if not self.have_pending_request():
+            raise Exception("Attempting to reply without pending request")
+
+    def _reply(self, response):
+        """Put ``response`` on the response queue and clear the pending request."""
+        self.response_queue.put(response)
+        self._req_received = None
+
+    def get_new_request(self, block=False):
+        """
+        Fetch the next request from the request queue and remember it as pending.
+
+        Raises ``Exception`` if a previous request is still unanswered and
+        ``EmptyQueue`` if the queue read fails for any reason.
+        """
+        if self.have_pending_request():
+            raise Exception(
+                'Trying to get next request, while having one unserved')
+        try:
+            request = self.request_queue.get(block=block)
+        except Exception as e:  # TODO: Catch only timeout exceptions
+            raise EmptyQueue('queue is empty') from e
+        self._req_received = request
+        return request
+
+    # TODO: Validate supported requests
+
+    def response_reset(self):
+        """Answer a pending ResetIteratorRequest."""
+        self._ensure_pending_request()
+        if not isinstance(self._req_received, communication.messages.ResetIteratorRequest):
+            raise Exception(
+                "Replying with reset status to other type of message")
+        self._reply(communication.messages.ResetIteratorResponse())
+
+    def response_next(self, value):
+        """Answer the pending request with ``value``."""
+        self._ensure_pending_request()
+        self._reply(communication.messages.GetNextResponse(value))
+
+    def response_stop(self):
+        """Answer the pending request with a StopIterationResponse."""
+        self._ensure_pending_request()
+        self._reply(communication.messages.StopIterationResponse())
+
+    def response_invalid(self):
+        """Answer the pending request with an InvalidStateResponse."""
+        self._ensure_pending_request()
+        self._reply(communication.messages.InvalidStateResponse())
+
+    def response_terminate(self):
+        """Answer a pending TerminateRequest."""
+        self._ensure_pending_request()
+        if not isinstance(self._req_received, communication.messages.TerminateRequest):
+            raise Exception(
+                "Replying with terminate status to other type of message")
+        self._reply(communication.messages.TerminateResponse())
+
+
+class MapDataPipeQueueProtocolClient(ProtocolClient):
+    # Placeholder: adds nothing on top of ProtocolClient yet.
+    pass
+
+
+class MapDataPipeQueueProtocolServer(ProtocolServer):
+    # Placeholder: adds nothing on top of ProtocolServer yet.
+    pass
+
+
+class EmptyQueue(Exception):
+    """Raised by the protocol classes when reading from a queue fails."""
+    pass
+
+
+class IterDataPipeQueueProtocolServer(ProtocolServer):
+    """Server-side protocol type for IterDataPipes; inherits everything from ProtocolServer."""
+    pass
+
+
+class IterDataPipeQueueProtocolClient(ProtocolClient):
+    """Client-side protocol for IterDataPipes: sends reset/next requests and reads their replies."""
+
+    def request_reset(self):
+        """Send a ResetIteratorRequest; raises if a request is still outstanding."""
+        if self.can_take_request():
+            reset_request = communication.messages.ResetIteratorRequest()
+            self.request_queue.put(reset_request)
+            self.request_sent(reset_request)
+        else:
+            raise Exception(
+                'Can not reset while we are still waiting response for previous request')
+
+    def request_next(self):
+        """Send a GetNextRequest; raises if a request is still outstanding."""
+        if self.can_take_request():
+            next_request = communication.messages.GetNextRequest()
+            self.request_queue.put(next_request)
+            self.request_sent(next_request)
+        else:
+            raise Exception(
+                'Can not request next item while we are still waiting response for previous request')
+
+    def get_response_reset(self, block=False):
+        """
+        Read the reply to a reset request.
+
+        Raises ``EmptyQueue`` when the queue read fails and ``Exception`` when
+        the reply is not a ResetIteratorResponse.
+        """
+        try:
+            reply = self.response_queue.get(block=block)
+        except Exception:  # TODO: Catch only timeout exceptions
+            raise EmptyQueue('queue is empty')
+        # The request counts as served even if the reply type turns out to be wrong.
+        self.request_served(reply)
+
+        if isinstance(reply, communication.messages.ResetIteratorResponse):
+            return
+        raise Exception('Invalid response received')
+
+    def get_response_next(self, block=False, timeout=None):
+        """
+        Read and return the reply to a next-item request.
+
+        Raises ``Exception`` when no request is outstanding and ``EmptyQueue``
+        when the queue read fails.
+        """
+        if not self.waiting_for_response():
+            raise Exception(
+                'Can not expect any response without submitted request')
+        try:
+            reply = self.response_queue.get(block=block, timeout=timeout)
+        except Exception:  # TODO: Catch only timeout exceptions
+            raise EmptyQueue('queue is empty')
+        self.request_served(reply)
+
+        # TODO(VitalyFedyunin): Add possible response types validation here
+        return reply
--- /dev/null
+import threading
+import time
+
+class LocalQueue():
+    """
+    Minimal in-process FIFO queue backed by a list.
+
+    The class attributes are counters shared by all instances: ``ops`` counts
+    put/get calls, ``stored`` the items currently held, ``empty`` the failed
+    gets, and ``uid`` hands out per-instance ids.
+    """
+    ops = 0
+    stored = 0
+    uid = 0
+    empty = 0
+
+    def __init__(self, name='unnamed'):
+        self.items = []
+        self.name = name
+        self.uid = LocalQueue.uid
+        LocalQueue.uid += 1
+
+    def put(self, item, block=True):
+        """Append ``item`` to the queue (``block`` is accepted but ignored)."""
+        LocalQueue.ops += 1
+        LocalQueue.stored += 1
+        self.items.append(item)
+
+    def get(self, block=True, timeout=0):
+        """Remove and return the oldest item; raises ``Exception`` when empty."""
+        # TODO(VitalyFedyunin): Add support of block and timeout arguments
+        LocalQueue.ops += 1
+        if not self.items:
+            LocalQueue.empty += 1
+            raise Exception('LocalQueue is empty')
+        LocalQueue.stored -= 1
+        # pop(0) takes the oldest item; a bare pop() would make this a stack.
+        return self.items.pop(0)
+
+
+class ThreadingQueue():
+    """Minimal lock-protected FIFO queue backed by a list."""
+
+    def __init__(self, name='unnamed'):
+        self.lock = threading.Lock()
+        self.items = []
+        self.name = name
+
+    def put(self, item, block=True):
+        """Append ``item`` under the lock (``block`` is accepted but ignored)."""
+        with self.lock:
+            self.items.append(item)
+
+    def get(self, block=True, timeout=0):
+        """
+        Remove and return the oldest item.
+
+        With ``block=False`` an empty queue raises ``Exception``; otherwise the
+        call polls until an item shows up. ``timeout`` is currently ignored.
+        """
+        # TODO(VitalyFedyunin): Add support of block and timeout arguments
+        while True:
+            with self.lock:
+                if self.items:
+                    # pop(0) takes the oldest item; a bare pop() would make this a stack.
+                    return self.items.pop(0)
+            if not block:
+                raise Exception("Not available")
+            # TODO(VitalyFedyunin): Figure out what to do if nothing in the queue
+            time.sleep(0.000001)
import functools
+import time
+
+from typing import Any, List
import torch.utils.data.backward_compatibility
-from torch.utils.data import DataLoader, IterDataPipe
+
+import torch.utils.data.sharding
+from torch.utils.data import DataLoader, IterDataPipe, communication
from torch.utils.data.datapipes.iter import IterableWrapper
+class _ThreadingDataLoader2:
+    """
+    Thread-based loader: one worker thread per shard of ``datapipe``.
+
+    Each worker serves its own sharded copy of the DataPipe through a queue
+    pair; iteration visits the workers in turn and yields what they return.
+    """
+
+    def __init__(self, datapipe, num_workers=0, collate_fn=None):
+        self.threads = []
+        self.datapipes = []
+        self.collate_fn = collate_fn
+        for worker_id in range(num_workers):
+            (thread, req_queue, res_queue, thread_localdatapipe) = communication.eventloop.SpawnThreadForDataPipeline(datapipe)
+            # Shard the thread-local copy before the worker starts serving it.
+            torch.utils.data.sharding.apply_sharding(thread_localdatapipe, num_workers, worker_id)
+            thread.start()
+            self.threads.append((thread, req_queue, res_queue))
+            local_datapipe = communication.iter.QueueWrapper(
+                communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue))
+            self.datapipes.append(local_datapipe)
+
+    def __iter__(self):
+        """Yield items from the workers in turn until every worker is exhausted."""
+        exclude_datapipes: List[Any] = []
+        while len(exclude_datapipes) < len(self.datapipes):
+            # Reset per pass so one NotAvailable does not force a sleep on
+            # every later pass.
+            not_available = False
+            for dp in self.datapipes:
+                if dp not in exclude_datapipes:
+                    try:
+                        value = dp.nonblocking_next()
+                        yield value
+                    except StopIteration:
+                        exclude_datapipes.append(dp)
+                    except communication.iter.NotAvailable:
+                        not_available = True
+            if not_available:
+                time.sleep(0.001)
+
+    def __del__(self):
+        self._cleanup_all_threads()
+
+    def _cleanup_all_threads(self):
+        """Terminate and join every worker thread; safe to call more than once."""
+        def clean_me(thread, req_queue, res_queue):
+            # A dead worker never answers, so only do the handshake with a live one.
+            if thread.is_alive():
+                req_queue.put(communication.messages.TerminateRequest())
+                _ = res_queue.get()
+            thread.join()
+
+        # __init__ may have failed before self.threads was assigned.
+        threads = getattr(self, 'threads', None)
+        # Remove each entry as it is cleaned so a repeated call (explicit, then
+        # __del__) does not wait on a thread that has already exited.
+        while threads:
+            thread, req_queue, res_queue = threads.pop(0)
+            clean_me(thread, req_queue, res_queue)
+
+
class DataLoader2:
def __new__(cls,
dataset,
*,
prefetch_factor=2,
persistent_workers=False,
- batch_outside_worker=False):
+ batch_outside_worker=False,
+ parallelism_mode='mp'):
if isinstance(dataset, IterDataPipe):
- datapipe = dataset
+ data_loader: Any = None
if batch_sampler is not None:
raise Exception(
- 'batch_sampler is not yet supported for DataPipes')
+ 'batch_sampler is not yet supported by DataPipes')
if sampler is not None:
raise Exception(
- 'sampler is not yet supported for DataPipes')
+ 'sampler is not yet supported by DataPipes')
+ datapipe = dataset
if shuffle:
datapipe = datapipe.shuffle()
if batch_outside_worker and pin_memory:
datapipe = datapipe.batch(batch_size, drop_last=drop_last)
if collate_fn is None:
collate_fn = torch.utils.data._utils.collate.default_collate
+ if parallelism_mode == 'mp' or num_workers == 0:
+ def sharding_worker_init_fn(worker_init_fn, worker_id):
+ if worker_init_fn is not None:
+ worker_init_fn(worker_id)
+ torch.utils.data.backward_compatibility.worker_init_fn(
+ worker_id)
- def sharding_worker_init_fn(worker_init_fn, worker_id):
- if worker_init_fn is not None:
- worker_init_fn(worker_id)
- torch.utils.data.backward_compatibility.worker_init_fn(
- worker_id)
-
- my_worker_init_fn = functools.partial(
- sharding_worker_init_fn, worker_init_fn)
-
- data_loader = DataLoader(datapipe,
- batch_size=None, # Replaced by .batch DataPipe
- shuffle=False, # Replaced by .shuffle DataPipe
- sampler=None,
- batch_sampler=None,
- num_workers=num_workers,
- collate_fn=collate_fn,
- pin_memory=pin_memory,
- drop_last=False, # Replaced by .batch DataPipe
- timeout=timeout,
- worker_init_fn=my_worker_init_fn,
- prefetch_factor=prefetch_factor,
- persistent_workers=persistent_workers)
+ my_worker_init_fn = functools.partial(
+ sharding_worker_init_fn, worker_init_fn)
+ data_loader = DataLoader(datapipe,
+ batch_size=None, # Replaced by .batch DataPipe
+ shuffle=False, # Replaced by .shuffle DataPipe
+ sampler=None,
+ batch_sampler=None,
+ num_workers=num_workers,
+ collate_fn=collate_fn,
+ pin_memory=pin_memory,
+ drop_last=False, # Replaced by .batch DataPipe
+ timeout=timeout,
+ worker_init_fn=my_worker_init_fn,
+ prefetch_factor=prefetch_factor,
+ persistent_workers=persistent_workers)
+ elif parallelism_mode == 'thread':
+ if collate_fn is not None and not batch_outside_worker:
+ datapipe = datapipe.map(collate_fn)
+ if pin_memory:
+ raise Exception(
+ 'pin_memory is not yet supported by DataPipes with Threading')
+ if worker_init_fn is not None:
+ raise Exception(
+ 'worker_init_fn is not yet supported by DataPipes with Threading')
+ data_loader = _ThreadingDataLoader2(datapipe,
+ num_workers=num_workers,
+ collate_fn=collate_fn)
+ else:
+ raise Exception('Unsupported parallelism mode', parallelism_mode)
if not batch_outside_worker:
return data_loader
else:
datapipe = IterableWrapper(data_loader).batch(
batch_size, drop_last=drop_last).map(collate_fn)
return datapipe
-
else:
+            # Thread-based workers exist only for DataPipes; reject that mode for
+            # classic Datasets and let every other mode fall through to DataLoader.
+            if parallelism_mode == 'thread':
+                raise Exception(
+                    'thread parallelism mode is not supported for old DataSets')
+
return DataLoader(dataset,
batch_size=batch_size,
shuffle=shuffle,