From: Vitaly Fedyunin
Date: Mon, 30 Aug 2021 14:54:11 +0000 (-0700)
Subject: [DataLoader2] Adding Messages, Protocols, Loop wrappers (#63882)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~603
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=82174330d0bae4e2356295e16e261052f1d0ff8c;p=platform%2Fupstream%2Fpytorch.git

[DataLoader2] Adding Messages, Protocols, Loop wrappers (#63882)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63882

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D30627452

Pulled By: VitalyFedyunin

fbshipit-source-id: 561ea2df07f3572e04401171946154024126387b
---

diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 6555463..c768246 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -22,6 +22,7 @@ from torch.utils.data import (
     IterableDataset,
     Subset,
     TensorDataset,
+    communication,
     _utils
 )
 from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
@@ -32,6 +33,7 @@ from torch.testing._internal.common_utils import (TestCase, run_tests, TEST_NUMP
                                                   IS_IN_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, slowTest,
                                                   load_tests, TEST_WITH_TSAN, IS_SANDCASTLE)
+
 try:
     import psutil
     HAS_PSUTIL = True
@@ -730,7 +732,7 @@ class TestWorkerInfoDataset(SynchronizedDataset):
 
 # Should be used as worker_init_fn with TestWorkerInfoDataset.
 # See _test_get_worker_info below for usage.
-def test_worker_info_init_fn(worker_id):
+def _test_worker_info_init_fn(worker_id):
     worker_info = torch.utils.data.get_worker_info()
     assert worker_id == worker_info.id, "worker_init_fn and worker_info should have consistent id"
     assert worker_id < worker_info.num_workers, "worker_init_fn and worker_info should have valid id"
@@ -760,7 +762,7 @@ def _test_get_worker_info():
     dataset = TestWorkerInfoDataset(6, batch_size, num_workers)
     dataloader = DataLoader(dataset, batch_size=batch_size,
                             num_workers=num_workers,
-                            worker_init_fn=test_worker_info_init_fn)
+                            worker_init_fn=_test_worker_info_init_fn)
     it = iter(dataloader)
     data = []
     for d in it:
@@ -769,7 +771,7 @@ def _test_get_worker_info():
     data = torch.cat(data, 0)
     for d in data:
         # each `d` is a [worker_id, worker_pid] pair, which is set in
-        # test_worker_info_init_fn
+        # _test_worker_info_init_fn
         assert d[1] == worker_pids[d[0]]
     # get_worker_info returns None in main proc after data loading
     assert torch.utils.data.get_worker_info() is None
@@ -1963,11 +1965,41 @@ except RuntimeError as e:
 class TestDataLoader2(TestCase):
     @skipIfNoDill
     def test_basics(self):
-        dp = IterableWrapper(list(range(10)))
+        # TODO(VitalyFedyunin): This test will start breaking if we remove guaranteed order
+        # of traversing workers
+        dp = IterableWrapper(list(range(1000)))
         dl = DataLoader(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
         dl2 = DataLoader2(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
-        self.assertEquals(list(dl), list(dl2))
+        dl2_threading = DataLoader2(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2, parallelism_mode='thread')
+        self.assertEqual(list(dl), list(dl2))
+        self.assertEqual(list(dl), list(dl2_threading))
+
+
+@unittest.skipIf(
+    TEST_WITH_TSAN,
+    "Fails with TSAN with the following error: starting new threads after multi-threaded "
+    "fork is not supported. Dying (set die_after_fork=0 to override)")
+class TestDataLoader2_EventLoop(TestCase):
+    @skipIfNoDill
+    def test_basic_threading(self):
+        def clean_me(process, req_queue, res_queue):
+            req_queue.put(communication.messages.TerminateRequest())
+            _ = res_queue.get()
+            process.join()
+
+        it = list(range(100))
+        numbers_dp = IterableWrapper(it)
+        (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.SpawnThreadForDataPipeline(numbers_dp)
+
+        process.start()
+        local_datapipe = communication.iter.QueueWrapper(
+            communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue))
+
+        actual = list(local_datapipe)
+        clean_me(process, req_queue, res_queue)
+        self.assertEqual(list(range(100)), actual)
 
 class StringDataset(Dataset):
     def __init__(self):
diff --git a/torch/utils/data/__init__.py b/torch/utils/data/__init__.py
index 0af9e61..ac0c763 100644
--- a/torch/utils/data/__init__.py
+++ b/torch/utils/data/__init__.py
@@ -35,7 +35,7 @@ from torch.utils.data._decorator import (
     runtime_validation_disabled,
 )
 from torch.utils.data.dataloader_experimental import DataLoader2
-
+from torch.utils.data import communication
 
 __all__ = ['BatchSampler',
            'ChainDataset',
@@ -56,6 +56,7 @@ __all__ = ['BatchSampler',
            'WeightedRandomSampler',
            '_DatasetKind',
            'argument_validation',
+           'communication',
            'functional_datapipe',
            'get_worker_info',
            'guaranteed_datapipes_determinism',
diff --git a/torch/utils/data/communication/__init__.py b/torch/utils/data/communication/__init__.py
new file mode 100644
index 0000000..88a395e
--- /dev/null
+++ b/torch/utils/data/communication/__init__.py
@@ -0,0 +1,5 @@
+from . import eventloop
+from . import iter
+from . import messages
+from . import protocol
+from . import queue
diff --git a/torch/utils/data/communication/eventloop.py b/torch/utils/data/communication/eventloop.py
new file mode 100644
index 0000000..75c44c5
--- /dev/null
+++ b/torch/utils/data/communication/eventloop.py
@@ -0,0 +1,41 @@
+import torch
+import threading
+import pickle
+
+from torch.utils.data import IterDataPipe, communication
+
+
+def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue):
+    if isinstance(source_datapipe, IterDataPipe):
+        pipe_type = communication.iter
+        protocol_type = communication.protocol.IterDataPipeQueueProtocolServer
+    else:
+        raise Exception('Only supports IterDataPipe, got', source_datapipe)
+        # pipe_type = communication.map
+        # protocol_type = communication.protocol.MapDataPipeQueueProtocolServer
+
+    torch.set_num_threads(1)
+    for _ in pipe_type.DataPipeBehindQueues(source_datapipe, protocol_type(req_queue, res_queue), blocking_request_get=True):
+        pass
+
+
+def SpawnProcessForDataPipeline(multiprocessing_ctx, datapipe):
+    req_queue = multiprocessing_ctx.Queue()
+    res_queue = multiprocessing_ctx.Queue()
+    process = multiprocessing_ctx.Process(
+        target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue))
+    return process, req_queue, res_queue
+
+
+def SpawnThreadForDataPipeline(datapipe):
+    req_queue = communication.queue.ThreadingQueue()
+    res_queue = communication.queue.ThreadingQueue()
+
+    try:
+        new_datapipe = pickle.loads(pickle.dumps(datapipe))
+    except Exception as e:
+        raise Exception('Unable to pickle DataPipe to make thread-local copy', e)
+
+    process = threading.Thread(target=DataPipeToQueuesLoop, args=(
+        new_datapipe, req_queue, res_queue), daemon=True)
+    return process, req_queue, res_queue, new_datapipe
diff --git a/torch/utils/data/communication/iter.py b/torch/utils/data/communication/iter.py
new file mode 100644
index 0000000..594a466
--- /dev/null
+++ b/torch/utils/data/communication/iter.py
@@ -0,0 +1,173 @@
+import time
+import types
+
+from torch.utils.data import IterDataPipe, communication
+
+DEFAULT_NON_BLOCKING_SLEEP = 0.001
+
+
+def default_not_available_hook():
+    time.sleep(DEFAULT_NON_BLOCKING_SLEEP)
+
+
+class NotAvailable(Exception):
+    pass
+
+
+class InvalidStateResetRequired(Exception):
+    """
+    Returned by a DataPipe when it expects a reset request,
+    for example a RouterDataPipe waiting for all workers to request a reset.
+    """
+    pass
+
+
+class NonBlocking(IterDataPipe):
+    not_available_hook = default_not_available_hook
+
+    def __iter__(self):
+        self.reset_iterator()
+        return self
+
+    def __next__(self):
+        while True:
+            try:
+                return self.nonblocking_next()
+            except StopIteration:
+                raise StopIteration
+            except NotAvailable:
+                if NonBlocking.not_available_hook is not None:
+                    NonBlocking.not_available_hook()
+
+    def nonblocking_next(self):
+        raise NotImplementedError(
+            "nonblocking_next is not implemented for %s" % self.__class__)
+
+    def reset_iterator(self):
+        raise NotImplementedError(
+            "reset_iterator is not implemented for %s" % self.__class__)
+
+    @staticmethod
+    def register_not_available_hook(hook_function):
+        NonBlocking.not_available_hook = hook_function
+
+
+def EnsureNonBlockingDataPipe(validated_datapipe):
+    if not isinstance(validated_datapipe, IterDataPipe):
+        raise Exception('Not Iterable DataPipe ' +
+                        str(validated_datapipe.__class__))
+    if isinstance(validated_datapipe, NonBlocking):
+        return validated_datapipe
+    if not hasattr(validated_datapipe, '_as_iterator'):
+        validated_datapipe._as_iterator = None  # type: ignore[attr-defined]
+    if not hasattr(validated_datapipe, 'nonblocking_next'):
+        def nonblocking_next(self):
+            if self._as_iterator is None:
+                self._as_iterator = iter(self)
+            return next(self._as_iterator)
+        validated_datapipe.nonblocking_next = types.MethodType(  # type: ignore[attr-defined]
+            nonblocking_next, validated_datapipe)
+    if not hasattr(validated_datapipe, 'reset_iterator'):
+        def reset_iterator(self):
+            self._as_iterator = None
+        validated_datapipe.reset_iterator = types.MethodType(  # type: ignore[attr-defined]
+            reset_iterator, validated_datapipe)
+    return validated_datapipe
+
+
+def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_request_get=False):
+    """
+    Indefinitely iterates over req_queue, passing values from source_datapipe to res_queue.
+    If full_stop is True, the loop terminates after StopIteration is received from source_datapipe.
+    """
+    if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolServer):
+        raise Exception('Expecting IterDataPipeQueueProtocolServer, got', protocol)
+    source_datapipe = EnsureNonBlockingDataPipe(source_datapipe)
+    forever = True
+    while forever:
+
+        try:
+            # Non-blocking call is extremely slow here for python multiprocessing; need to figure out a good workaround
+            request = protocol.get_new_request(block=blocking_request_get)
+        except communication.protocol.EmptyQueue:
+            yield True
+            continue
+
+        if isinstance(request, communication.messages.ResetIteratorRequest):
+            source_datapipe.reset_iterator()
+            protocol.response_reset()
+
+        elif isinstance(request, communication.messages.TerminateRequest):
+            forever = False
+            protocol.response_terminate()
+
+        elif isinstance(request, communication.messages.GetNextRequest):
+            while forever:
+                try:
+                    value = source_datapipe.nonblocking_next()
+                except NotAvailable:
+                    yield True
+                    continue
+                except StopIteration:
+                    protocol.response_stop()
+                    if full_stop:
+                        forever = False
+                    else:
+                        yield True
+                    break
+                except InvalidStateResetRequired:
+                    protocol.response_invalid()
+                    if full_stop:
+                        forever = False
+                    else:
+                        yield True
+                    break
+                protocol.response_next(value)
+                yield True  # Returns control
+                break
+        else:
+            raise Exception('Unrecognized type of request received', request)
+
+
+class QueueWrapper(NonBlocking):
+    """
+    Creates an iter DataPipe which reads data from a DataLoader queue via the given protocol client.
+    """
+
+    def __init__(self, protocol, response_wait_time=0.00001):
+        if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolClient):
+            raise Exception('Expecting IterDataPipeQueueProtocolClient, got', protocol)
+
+        self.protocol = protocol
+        self.counter = 0
+        self._stop_iteration = False
+        self._response_wait_time = response_wait_time
+
+    def reset_iterator(self):
+        self._stop_iteration = False
+        self.counter = 0
+        self.protocol.request_reset()
+        while True:
+            try:
+                self.protocol.get_response_reset()
+                break
+            except communication.protocol.EmptyQueue:
+                if NonBlocking.not_available_hook is not None:
+                    NonBlocking.not_available_hook()
+
+    def nonblocking_next(self):
+        if self._stop_iteration:
+            raise Exception(
+                '`next` or `nonblocking_next` called after receiving StopIteration')
+        if self.protocol.can_take_request():
+            self.protocol.request_next()
+        try:
+            response = self.protocol.get_response_next(block=True, timeout=self._response_wait_time)
+        except communication.protocol.EmptyQueue:
+            raise NotAvailable
+        if isinstance(response, communication.messages.StopIterationResponse):
+            self._stop_iteration = True
+            raise StopIteration
+        if isinstance(response, communication.messages.InvalidStateResponse):
+            raise NotAvailable
+        return response.value
diff --git a/torch/utils/data/communication/messages.py b/torch/utils/data/communication/messages.py
new file mode 100644
index 0000000..449cf23
--- /dev/null
+++ b/torch/utils/data/communication/messages.py
@@ -0,0 +1,75 @@
+class DataLoaderQueueMessage(object):
+    pass
+
+
+class Request(DataLoaderQueueMessage):
+    pass
+
+
+class Response(DataLoaderQueueMessage):
+    pass
+
+
+class ResetIteratorRequest(Request):
+    pass
+
+
+class ResetIteratorResponse(Response):
+    pass
+
+
+class TerminateRequest(Request):
+    pass
+
+
+class TerminateResponse(Response):
+    pass
+
+
+class LenRequest(Request):
+    pass
+
+
+class LenResponse(Response):
+    __slots__ = ('len',)
+
+    def __init__(self, len):
+        self.len = len
+
+
+class GetItemRequest(Request):
+    __slots__ = ('key',)
+
+    def __init__(self, key):
+        self.key = key
+
+
+class GetItemResponse(Response):
+    __slots__ = ('key', 'value')
+
+    def __init__(self, key, value):
+        self.key = key
+        self.value = value
+
+
+class GetNextRequest(Request):
+    pass
+
+
+class GetNextResponse(Response):
+    __slots__ = ('value',)
+
+    def __init__(self, value):
+        self.value = value
+
+
+class StopIterationResponse(Response):
+    pass
+
+
+class InvalidStateResponse(Response):
+    """
+    Returned by a DataPipe when it expects a reset request,
+    for example a RouterDataPipe waiting for all workers to request a reset.
+    """
+    pass
diff --git a/torch/utils/data/communication/protocol.py b/torch/utils/data/communication/protocol.py
new file mode 100644
index 0000000..68ff335
--- /dev/null
+++ b/torch/utils/data/communication/protocol.py
@@ -0,0 +1,159 @@
+from torch.utils.data import communication
+
+
+class Protocol(object):
+    __slots__ = ('request_queue', 'response_queue')
+
+    def __init__(self, request_queue, response_queue):
+        self.request_queue = request_queue
+        self.response_queue = response_queue
+
+
+class ProtocolClient(Protocol):
+    """
+    ProtocolClient takes charge of putting requests into req_queue and returning results from res_queue.
+    """
+    _req_sent = None
+
+    def __init__(self, request_queue, response_queue):
+        self.request_queue = request_queue
+        self.response_queue = response_queue
+        self._req_sent = None
+
+    def can_take_request(self):
+        return self._req_sent is None
+
+    def waiting_for_response(self):
+        return self._req_sent is not None
+
+    def request_sent(self, request=True):
+        if not self.can_take_request():
+            raise Exception('Protocol only supports one request in the Queue')
+        self._req_sent = request
+
+    def request_served(self, result=None):
+        if not self.waiting_for_response():
+            raise Exception(
+                'Expected no pending requests, but something got served', result)
+        self._req_sent = None
+
+
+class ProtocolServer(Protocol):
+    """
+    ProtocolServer takes charge of getting requests from req_queue and fetching data from source datapipe.
+    """
+    _req_received = None
+
+    def __init__(self, request_queue, response_queue):
+        self.request_queue = request_queue
+        self.response_queue = response_queue
+        self._req_received = None
+
+    def have_pending_request(self):
+        return self._req_received is not None
+
+    def get_new_request(self, block=False):
+        if self.have_pending_request():
+            raise Exception(
+                'Trying to get next request while having one unserved')
+        try:
+            response = self.request_queue.get(block=block)
+        except Exception:  # TODO: Catch only timeout exceptions
+            raise EmptyQueue('queue is empty')
+        self._req_received = response
+        return response
+
+    # TODO: Validate supported requests
+
+    def response_reset(self):
+        if not self.have_pending_request():
+            raise Exception("Attempting to reply without a pending request")
+        if not isinstance(self._req_received, communication.messages.ResetIteratorRequest):
+            raise Exception(
+                "Replying with reset status to another type of message")
+        self.response_queue.put(communication.messages.ResetIteratorResponse())
+        self._req_received = None
+
+    def response_next(self, value):
+        if not self.have_pending_request():
+            raise Exception("Attempting to reply without a pending request")
+        self.response_queue.put(communication.messages.GetNextResponse(value))
+        self._req_received = None
+
+    def response_stop(self):
+        if not self.have_pending_request():
+            raise Exception("Attempting to reply without a pending request")
+        self.response_queue.put(communication.messages.StopIterationResponse())
+        self._req_received = None
+
+    def response_invalid(self):
+        if not self.have_pending_request():
+            raise Exception("Attempting to reply without a pending request")
+        self.response_queue.put(communication.messages.InvalidStateResponse())
+        self._req_received = None
+
+    def response_terminate(self):
+        if not self.have_pending_request():
+            raise Exception("Attempting to reply without a pending request")
+        if not isinstance(self._req_received, communication.messages.TerminateRequest):
+            raise Exception(
+                "Replying with terminate status to another type of message")
+        self.response_queue.put(communication.messages.TerminateResponse())
+        self._req_received = None
+
+
+class MapDataPipeQueueProtocolClient(ProtocolClient):
+    pass
+
+
+class MapDataPipeQueueProtocolServer(ProtocolServer):
+    pass
+
+
+class EmptyQueue(Exception):
+    pass
+
+
+class IterDataPipeQueueProtocolServer(ProtocolServer):
+    pass
+
+
+class IterDataPipeQueueProtocolClient(ProtocolClient):
+    def request_reset(self):
+        if not self.can_take_request():
+            raise Exception(
+                'Can not reset while we are still waiting for a response to the previous request')
+        request = communication.messages.ResetIteratorRequest()
+        self.request_queue.put(request)
+        self.request_sent(request)
+
+    def request_next(self):
+        if not self.can_take_request():
+            raise Exception(
+                'Can not request next item while we are still waiting for a response to the previous request')
+        request = communication.messages.GetNextRequest()
+        self.request_queue.put(request)
+        self.request_sent(request)
+
+    def get_response_reset(self, block=False):
+        try:
+            response = self.response_queue.get(block=block)
+        except Exception:  # TODO: Catch only timeout exceptions
+            raise EmptyQueue('queue is empty')
+        self.request_served(response)
+
+        if not isinstance(response, communication.messages.ResetIteratorResponse):
+            raise Exception('Invalid response received')
+
+    def get_response_next(self, block=False, timeout=None):
+        if not self.waiting_for_response():
+            raise Exception(
+                'Can not expect any response without a submitted request')
+        try:
+            response = self.response_queue.get(block=block, timeout=timeout)
+        except Exception:  # TODO: Catch only timeout exceptions
+            raise EmptyQueue('queue is empty')
+        self.request_served(response)
+
+        # TODO(VitalyFedyunin): Add possible response types validation here
+        return response
diff --git a/torch/utils/data/communication/queue.py b/torch/utils/data/communication/queue.py
new file mode 100644
index 0000000..7717697
--- /dev/null
+++ b/torch/utils/data/communication/queue.py
@@ -0,0 +1,50 @@
+import threading
+import time
+
+class LocalQueue():
+    ops = 0
+    stored = 0
+    uid = 0
+    empty = 0
+
+    def __init__(self, name='unnamed'):
+        self.items = []
+        self.name = name
+        self.uid = LocalQueue.uid
+        LocalQueue.uid += 1
+
+    def put(self, item, block=True):
+        LocalQueue.ops += 1
+        LocalQueue.stored += 1
+        self.items.append(item)
+
+    def get(self, block=True, timeout=0):
+        # TODO(VitalyFedyunin): Add support of block and timeout arguments
+        LocalQueue.ops += 1
+        if not len(self.items):
+            LocalQueue.empty += 1
+            raise Exception('LocalQueue is empty')
+        LocalQueue.stored -= 1
+        return self.items.pop()
+
+
+class ThreadingQueue():
+    def __init__(self, name='unnamed'):
+        self.lock = threading.Lock()
+        self.items = []
+        self.name = name
+
+    def put(self, item, block=True):
+        with self.lock:
+            self.items.append(item)
+
+    def get(self, block=True, timeout=0):
+        # TODO(VitalyFedyunin): Add support of block and timeout arguments
+        while True:
+            with self.lock:
+                if len(self.items) > 0:
+                    return self.items.pop()
+                if not block:
+                    raise Exception("Not available")
+            # TODO(VitalyFedyunin): Figure out what to do if nothing in the queue
+            time.sleep(0.000001)
diff --git a/torch/utils/data/dataloader_experimental.py b/torch/utils/data/dataloader_experimental.py
index ea08529..a74c75c 100644
--- a/torch/utils/data/dataloader_experimental.py
+++ b/torch/utils/data/dataloader_experimental.py
@@ -1,10 +1,59 @@
 import functools
+import time
+
+from typing import Any, List
 
 import torch.utils.data.backward_compatibility
-from torch.utils.data import DataLoader, IterDataPipe
+
+import torch.utils.data.sharding
+from torch.utils.data import DataLoader, IterDataPipe, communication
 from torch.utils.data.datapipes.iter import IterableWrapper
 
 
+class _ThreadingDataLoader2:
+
+    def __init__(self, datapipe, num_workers=0, collate_fn=None):
+        self.threads = []
+        self.datapipes = []
+        self.collate_fn = collate_fn
+        for worker_id in range(num_workers):
+            (thread, req_queue, res_queue, thread_localdatapipe) = communication.eventloop.SpawnThreadForDataPipeline(datapipe)
+            torch.utils.data.sharding.apply_sharding(thread_localdatapipe, num_workers, worker_id)
+            thread.start()
+            self.threads.append((thread, req_queue, res_queue))
+            local_datapipe = communication.iter.QueueWrapper(
+                communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue))
+            self.datapipes.append(local_datapipe)
+
+    def __iter__(self):
+        exclude_datapipes: List[Any] = []
+        while len(exclude_datapipes) < len(self.datapipes):
+            # Reset per round, so we only sleep when no worker had data this pass
+            not_available = False
+            for dp in self.datapipes:
+                if dp not in exclude_datapipes:
+                    try:
+                        value = dp.nonblocking_next()
+                        yield value
+                    except StopIteration:
+                        exclude_datapipes.append(dp)
+                    except communication.iter.NotAvailable:
+                        not_available = True
+            if not_available:
+                time.sleep(0.001)
+
+    def __del__(self):
+        self._cleanup_all_threads()
+
+    def _cleanup_all_threads(self):
+        def clean_me(thread, req_queue, res_queue):
+            req_queue.put(communication.messages.TerminateRequest())
+            _ = res_queue.get()
+            thread.join()
+
+        for thread, req_queue, res_queue in self.threads:
+            clean_me(thread, req_queue, res_queue)
+
+
 class DataLoader2:
     def __new__(cls,
                 dataset,
@@ -21,15 +71,17 @@ class DataLoader2:
                 *,
                 prefetch_factor=2,
                 persistent_workers=False,
-                batch_outside_worker=False):
+                batch_outside_worker=False,
+                parallelism_mode='mp'):
         if isinstance(dataset, IterDataPipe):
-            datapipe = dataset
+            data_loader: Any = None
             if batch_sampler is not None:
                 raise Exception(
-                    'batch_sampler is not yet supported for DataPipes')
+                    'batch_sampler is not yet supported by DataPipes')
             if sampler is not None:
                 raise Exception(
-                    'sampler is not yet supported for DataPipes')
+                    'sampler is not yet supported by DataPipes')
+            datapipe = dataset
             if shuffle:
                 datapipe = datapipe.shuffle()
             if batch_outside_worker and pin_memory:
@@ -40,30 +92,43 @@ class DataLoader2:
                 datapipe = datapipe.batch(batch_size, drop_last=drop_last)
             if collate_fn is None:
                 collate_fn = torch.utils.data._utils.collate.default_collate
+            if parallelism_mode == 'mp' or num_workers == 0:
+                def sharding_worker_init_fn(worker_init_fn, worker_id):
+                    if worker_init_fn is not None:
+                        worker_init_fn(worker_id)
+                    torch.utils.data.backward_compatibility.worker_init_fn(
+                        worker_id)
-
-            def sharding_worker_init_fn(worker_init_fn, worker_id):
-                if worker_init_fn is not None:
-                    worker_init_fn(worker_id)
-                torch.utils.data.backward_compatibility.worker_init_fn(
-                    worker_id)
-
-            my_worker_init_fn = functools.partial(
-                sharding_worker_init_fn, worker_init_fn)
-
-            data_loader = DataLoader(datapipe,
-                                     batch_size=None,  # Replaced by .batch DataPipe
-                                     shuffle=False,  # Replaced by .shuffle DataPipe
-                                     sampler=None,
-                                     batch_sampler=None,
-                                     num_workers=num_workers,
-                                     collate_fn=collate_fn,
-                                     pin_memory=pin_memory,
-                                     drop_last=False,  # Replaced by .batch DataPipe
-                                     timeout=timeout,
-                                     worker_init_fn=my_worker_init_fn,
-                                     prefetch_factor=prefetch_factor,
-                                     persistent_workers=persistent_workers)
+                my_worker_init_fn = functools.partial(
+                    sharding_worker_init_fn, worker_init_fn)
+                data_loader = DataLoader(datapipe,
+                                         batch_size=None,  # Replaced by .batch DataPipe
+                                         shuffle=False,  # Replaced by .shuffle DataPipe
+                                         sampler=None,
+                                         batch_sampler=None,
+                                         num_workers=num_workers,
+                                         collate_fn=collate_fn,
+                                         pin_memory=pin_memory,
+                                         drop_last=False,  # Replaced by .batch DataPipe
+                                         timeout=timeout,
+                                         worker_init_fn=my_worker_init_fn,
+                                         prefetch_factor=prefetch_factor,
+                                         persistent_workers=persistent_workers)
+            elif parallelism_mode == 'thread':
+                if collate_fn is not None and not batch_outside_worker:
+                    datapipe = datapipe.map(collate_fn)
+                if pin_memory:
+                    raise Exception(
+                        'pin_memory is not yet supported by DataPipes with Threading')
+                if worker_init_fn is not None:
+                    raise Exception(
+                        'worker_init_fn is not yet supported by DataPipes with Threading')
+                data_loader = _ThreadingDataLoader2(datapipe,
+                                                    num_workers=num_workers,
+                                                    collate_fn=collate_fn)
+            else:
+                raise Exception('Unsupported parallelism mode', parallelism_mode)
             if not batch_outside_worker:
                 return data_loader
             else:
@@ -72,8 +137,11 @@
                 datapipe = IterableWrapper(data_loader).batch(
                     batch_size, drop_last=drop_last).map(collate_fn)
                 return datapipe
-        else:
+        if parallelism_mode == 'thread':
+            raise Exception(
+                'thread parallelism mode is not supported for old DataSets')
+
         return DataLoader(dataset,
                           batch_size=batch_size,
                           shuffle=shuffle,
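
The pieces above form a small request/response protocol: a loop wrapper serves a DataPipe from behind a request queue and a response queue, and a client-side QueueWrapper turns the message traffic back into an ordinary iterator. A minimal sketch of that round trip, mirroring test_basic_threading above (the local names source_dp and client are illustrative, not from the commit):

    from torch.utils.data import communication
    from torch.utils.data.datapipes.iter import IterableWrapper

    source_dp = IterableWrapper(list(range(5)))

    # Pickle a thread-local copy of the pipe and serve it from a daemon
    # thread behind a request queue and a response queue.
    thread, req_queue, res_queue, _ = \
        communication.eventloop.SpawnThreadForDataPipeline(source_dp)
    thread.start()

    # QueueWrapper drives GetNextRequest/GetNextResponse traffic and
    # raises StopIteration when a StopIterationResponse arrives.
    client = communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue)
    print(list(communication.iter.QueueWrapper(client)))  # [0, 1, 2, 3, 4]

    # Shut the event loop down: TerminateRequest -> TerminateResponse -> join.
    req_queue.put(communication.messages.TerminateRequest())
    _ = res_queue.get()
    thread.join()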
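DataLoader2 itself selects between these paths through the new parallelism_mode argument, as test_basics shows: 'mp' (the default) keeps the classic multiprocessing DataLoader, while 'thread' shards the pipe across _ThreadingDataLoader2 workers. A minimal sketch of the switch; the top-level identity function below stands in for the test's lambda, since thread mode pickles the pipe, including the mapped-in collate_fn, into each worker's copy:

    from torch.utils.data import DataLoader2
    from torch.utils.data.datapipes.iter import IterableWrapper

    def identity(batch):
        return batch  # top-level function, so the mapped pipe stays picklable without dill

    dp = IterableWrapper(list(range(12)))
    # Each of the two workers serves a sharded, thread-local copy of dp.
    dl = DataLoader2(dp, batch_size=3, collate_fn=identity, num_workers=2,
                     parallelism_mode='thread')
    print(list(dl))  # batches of up to 3 elements, interleaved across the workers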