[DataLoader2] Adding Messages, Protocols, Loop wrappers (#63882)
author    Vitaly Fedyunin <vitaly.fedyunin@gmail.com>
          Mon, 30 Aug 2021 14:54:11 +0000 (07:54 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
          Mon, 30 Aug 2021 14:57:20 +0000 (07:57 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63882

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D30627452

Pulled By: VitalyFedyunin

fbshipit-source-id: 561ea2df07f3572e04401171946154024126387b

test/test_dataloader.py
torch/utils/data/__init__.py
torch/utils/data/communication/__init__.py [new file with mode: 0644]
torch/utils/data/communication/eventloop.py [new file with mode: 0644]
torch/utils/data/communication/iter.py [new file with mode: 0644]
torch/utils/data/communication/messages.py [new file with mode: 0644]
torch/utils/data/communication/protocol.py [new file with mode: 0644]
torch/utils/data/communication/queue.py [new file with mode: 0644]
torch/utils/data/dataloader_experimental.py
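
For orientation: the new torch.utils.data.communication package layers three pieces: plain message types (messages.py), a one-request-in-flight queue protocol (protocol.py), and event loops that run a DataPipe behind a request/response queue pair in a thread or process (eventloop.py). A minimal round-trip sketch, not part of the patch, mirroring the new TestDataLoader2_EventLoop test below:

    from torch.utils.data import communication
    from torch.utils.data.datapipes.iter import IterableWrapper

    dp = IterableWrapper(list(range(3)))
    # The event loop serves a pickled copy of the DataPipe behind two queues.
    (thread, req_queue, res_queue, _thread_local_dp) = \
        communication.eventloop.SpawnThreadForDataPipeline(dp)
    thread.start()

    # QueueWrapper turns the client end of the protocol back into an IterDataPipe.
    client = communication.iter.QueueWrapper(
        communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue))
    assert list(client) == [0, 1, 2]

    # Shutdown handshake: TerminateRequest -> TerminateResponse -> join.
    req_queue.put(communication.messages.TerminateRequest())
    _ = res_queue.get()
    thread.join()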

diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 6555463..c768246 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -22,6 +22,7 @@ from torch.utils.data import (
     IterableDataset,
     Subset,
     TensorDataset,
+    communication,
     _utils
 )
 from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
@@ -32,6 +33,7 @@ from torch.testing._internal.common_utils import (TestCase, run_tests, TEST_NUMP
                                                   IS_IN_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, slowTest,
                                                   load_tests, TEST_WITH_TSAN, IS_SANDCASTLE)
 
+
 try:
     import psutil
     HAS_PSUTIL = True
@@ -730,7 +732,7 @@ class TestWorkerInfoDataset(SynchronizedDataset):
 
 # Should be used as worker_init_fn with TestWorkerInfoDataset.
 # See _test_get_worker_info below for usage.
-def test_worker_info_init_fn(worker_id):
+def _test_worker_info_init_fn(worker_id):
     worker_info = torch.utils.data.get_worker_info()
     assert worker_id == worker_info.id, "worker_init_fn and worker_info should have consistent id"
     assert worker_id < worker_info.num_workers, "worker_init_fn and worker_info should have valid id"
@@ -760,7 +762,7 @@ def _test_get_worker_info():
     dataset = TestWorkerInfoDataset(6, batch_size, num_workers)
     dataloader = DataLoader(dataset, batch_size=batch_size,
                             num_workers=num_workers,
-                            worker_init_fn=test_worker_info_init_fn)
+                            worker_init_fn=_test_worker_info_init_fn)
     it = iter(dataloader)
     data = []
     for d in it:
@@ -769,7 +771,7 @@ def _test_get_worker_info():
     data = torch.cat(data, 0)
     for d in data:
         # each `d` is a [worker_id, worker_pid] pair, which is set in
-        # test_worker_info_init_fn
+        # _test_worker_info_init_fn
         assert d[1] == worker_pids[d[0]]
     # get_worker_info returns None in main proc after data loading
     assert torch.utils.data.get_worker_info() is None
@@ -1963,11 +1965,41 @@ except RuntimeError as e:
 class TestDataLoader2(TestCase):
     @skipIfNoDill
     def test_basics(self):
-        dp = IterableWrapper(list(range(10)))
+        # TODO(VitalyFedyunin): This test will start breaking if we remove guaranteed order
+        # of traversing workers
+        dp = IterableWrapper(list(range(1000)))
         dl = DataLoader(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
         dl2 = DataLoader2(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
-        self.assertEquals(list(dl), list(dl2))
+        dl2_threading = DataLoader2(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2, parallelism_mode='thread')
+        self.assertEqual(list(dl), list(dl2))
+        self.assertEqual(list(dl), list(dl2_threading))
+
+
+
+@unittest.skipIf(
+    TEST_WITH_TSAN,
+    "Fails with TSAN with the following error: starting new threads after multi-threaded "
+    "fork is not supported. Dying (set die_after_fork=0 to override)")
+class TestDataLoader2_EventLoop(TestCase):
+    @skipIfNoDill
+    def test_basic_threading(self):
+        def clean_me(process, req_queue, res_queue):
+            req_queue.put(communication.messages.TerminateRequest())
+            _ = res_queue.get()
+            process.join()
+
+        it = list(range(100))
+        numbers_dp = IterableWrapper(it)
+        (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.SpawnThreadForDataPipeline(numbers_dp)
+
+        process.start()
+        local_datapipe = communication.iter.QueueWrapper(
+            communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue))
+
+        actual = list(local_datapipe)
+        clean_me(process, req_queue, res_queue)
 
+        self.assertEqual(list(range(100)), actual)
 
 class StringDataset(Dataset):
     def __init__(self):
diff --git a/torch/utils/data/__init__.py b/torch/utils/data/__init__.py
index 0af9e61..ac0c763 100644
--- a/torch/utils/data/__init__.py
+++ b/torch/utils/data/__init__.py
@@ -35,7 +35,7 @@ from torch.utils.data._decorator import (
     runtime_validation_disabled,
 )
 from torch.utils.data.dataloader_experimental import DataLoader2
-
+from torch.utils.data import communication
 
 __all__ = ['BatchSampler',
            'ChainDataset',
@@ -56,6 +56,7 @@ __all__ = ['BatchSampler',
            'WeightedRandomSampler',
            '_DatasetKind',
            'argument_validation',
+           'communication',
            'functional_datapipe',
            'get_worker_info',
            'guaranteed_datapipes_determinism',
diff --git a/torch/utils/data/communication/__init__.py b/torch/utils/data/communication/__init__.py
new file mode 100644
index 0000000..88a395e
--- /dev/null
+++ b/torch/utils/data/communication/__init__.py
@@ -0,0 +1,5 @@
+from . import eventloop
+from . import iter
+from . import messages
+from . import protocol
+from . import queue
diff --git a/torch/utils/data/communication/eventloop.py b/torch/utils/data/communication/eventloop.py
new file mode 100644
index 0000000..75c44c5
--- /dev/null
+++ b/torch/utils/data/communication/eventloop.py
@@ -0,0 +1,41 @@
+import torch
+import threading
+import pickle
+
+from torch.utils.data import IterDataPipe, communication
+
+
+def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue):
+    if isinstance(source_datapipe, IterDataPipe):
+        pipe_type = communication.iter
+        protocol_type = communication.protocol.IterDataPipeQueueProtocolServer
+    else:
+        raise Exception('Only supports IterDataPipe, got', source_datapipe)
+        # pipe_type = communication.map
+        # protocol_type = communication.protocol.MapDataPipeQueueProtocolServer
+
+    torch.set_num_threads(1)
+    for _ in pipe_type.DataPipeBehindQueues(source_datapipe, protocol_type(req_queue, res_queue), blocking_request_get=True):
+        pass
+
+
+def SpawnProcessForDataPipeline(multiprocessing_ctx, datapipe):
+    req_queue = multiprocessing_ctx.Queue()
+    res_queue = multiprocessing_ctx.Queue()
+    process = multiprocessing_ctx.Process(
+        target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue))
+    return process, req_queue, res_queue
+
+
+def SpawnThreadForDataPipeline(datapipe):
+    req_queue = communication.queue.ThreadingQueue()
+    res_queue = communication.queue.ThreadingQueue()
+
+    try:
+        new_datapipe = pickle.loads(pickle.dumps(datapipe))
+    except Exception as e:
+        raise Exception('Unable to pickle DataPipe to make a thread-local copy', e)
+
+    process = threading.Thread(target=DataPipeToQueuesLoop, args=(
+        new_datapipe, req_queue, res_queue), daemon=True)
+    return process, req_queue, res_queue, new_datapipe
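
SpawnProcessForDataPipeline is the multiprocessing counterpart of the thread spawner above; it takes an explicit multiprocessing context and leaves start() to the caller. A hedged sketch (the 'spawn' context is an assumption; the patch does not prescribe one):

    import multiprocessing
    from torch.utils.data import communication
    from torch.utils.data.datapipes.iter import IterableWrapper

    if __name__ == '__main__':
        ctx = multiprocessing.get_context('spawn')  # datapipe must be picklable
        process, req_queue, res_queue = communication.eventloop.SpawnProcessForDataPipeline(
            ctx, IterableWrapper(list(range(5))))
        process.start()
        # ...drive req_queue/res_queue through the protocol classes defined below...
        req_queue.put(communication.messages.TerminateRequest())
        _ = res_queue.get()
        process.join()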
diff --git a/torch/utils/data/communication/iter.py b/torch/utils/data/communication/iter.py
new file mode 100644
index 0000000..594a466
--- /dev/null
+++ b/torch/utils/data/communication/iter.py
@@ -0,0 +1,173 @@
+import time
+import types
+
+from torch.utils.data import IterDataPipe, communication
+
+DEFAULT_NON_BLOCKING_SLEEP = 0.001
+
+
+def default_not_available_hook():
+    time.sleep(DEFAULT_NON_BLOCKING_SLEEP)
+
+
+class NotAvailable(Exception):
+    pass
+
+
+class InvalidStateResetRequired(Exception):
+    """
+        Returned by a DataPipe when it expects a reset request,
+        for example RouterDataPipe expecting all workers to request a reset.
+    """
+    pass
+
+
+class NonBlocking(IterDataPipe):
+    not_available_hook = default_not_available_hook
+
+    def __iter__(self):
+        self.reset_iterator()
+        return self
+
+    def __next__(self):
+        while True:
+            try:
+                return self.nonblocking_next()
+            except StopIteration:
+                raise StopIteration
+            except NotAvailable:
+                if NonBlocking.not_available_hook is not None:
+                    NonBlocking.not_available_hook()
+
+    def nonblocking_next(self):
+        raise NotImplementedError(
+            "nonblocking_next is not implemented for %s" % self.__class__)
+
+    def reset_iterator(self):
+        raise NotImplementedError(
+            "reset_iterator is not implemented for %s" % self.__class__)
+
+    @staticmethod
+    def register_not_available_hook(hook_function):
+        NonBlocking.not_available_hook = hook_function
+
+
+def EnsureNonBlockingDataPipe(validated_datapipe):
+    if not isinstance(validated_datapipe, IterDataPipe):
+        raise Exception('Not an IterDataPipe: ' +
+                        str(validated_datapipe.__class__))
+    if isinstance(validated_datapipe, NonBlocking):
+        return validated_datapipe
+    if not hasattr(validated_datapipe, '_as_iterator'):
+        validated_datapipe._as_iterator = None  # type: ignore[attr-defined]
+    if not hasattr(validated_datapipe, 'nonblocking_next'):
+        def nonblocking_next(self):
+            if self._as_iterator is None:
+                self._as_iterator = iter(self)
+            return next(self._as_iterator)
+        validated_datapipe.nonblocking_next = types.MethodType(  # type: ignore[attr-defined]
+            nonblocking_next, validated_datapipe)
+    if not hasattr(validated_datapipe, 'reset_iterator'):
+        def reset_iterator(self):
+            self._as_iterator = None
+        validated_datapipe.reset_iterator = types.MethodType(  # type: ignore[attr-defined]
+            reset_iterator, validated_datapipe)
+    return validated_datapipe
+
+
+def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_request_get=False):
+    """
+        Indefinitely iterates over req_queue, passing values from source_datapipe to res_queue.
+        If full_stop is True, the loop terminates when StopIteration is received from source_datapipe.
+    """
+    if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolServer):
+        raise Exception('Expecting IterDataPipeQueueProtocolServer, got', protocol)
+    source_datapipe = EnsureNonBlockingDataPipe(source_datapipe)
+    forever = True
+    while forever:
+
+        try:
+            # Non-blocking call is extremely slow here for python.mp; need to figure out a good workaround
+            request = protocol.get_new_request(block=blocking_request_get)
+        except communication.protocol.EmptyQueue:
+            yield True
+            continue
+
+        if isinstance(request, communication.messages.ResetIteratorRequest):
+            source_datapipe.reset_iterator()
+            protocol.response_reset()
+
+        elif isinstance(request, communication.messages.TerminateRequest):
+            forever = False
+            protocol.response_terminate()
+
+        elif isinstance(request, communication.messages.GetNextRequest):
+            while forever:
+                try:
+                    value = source_datapipe.nonblocking_next()
+                except NotAvailable:
+                    yield True
+                    continue
+                except StopIteration:
+                    protocol.response_stop()
+                    if full_stop:
+                        forever = False
+                    else:
+                        yield True
+                    break
+                except InvalidStateResetRequired:
+                    protocol.response_invalid()
+                    if full_stop:
+                        forever = False
+                    else:
+                        yield True
+                    break
+                protocol.response_next(value)
+                yield True  # Returns control
+                break
+        else:
+            raise Exception('Unrecognized type of request received', request)
+
+
+class QueueWrapper(NonBlocking):
+    """
+        An IterDataPipe which reads data from the DataLoader queues via the client protocol
+    """
+
+    def __init__(self, protocol, response_wait_time=0.00001):
+        if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolClient):
+            raise Exception('Expecting IterDataPipeQueueProtocolClient, got', protocol)
+
+        self.protocol = protocol
+        self.counter = 0
+        self._stop_iteration = False
+        self._response_wait_time = response_wait_time
+
+    def reset_iterator(self):
+        self._stop_iteration = False
+        self.counter = 0
+        self.protocol.request_reset()
+        while True:
+            try:
+                self.protocol.get_response_reset()
+                break
+            except communication.protocol.EmptyQueue:
+                if NonBlocking.not_available_hook is not None:
+                    NonBlocking.not_available_hook()
+
+    def nonblocking_next(self):
+        if self._stop_iteration:
+            raise Exception(
+                '`next` or `nonblocking_next` called after receiving StopIteration')
+        if self.protocol.can_take_request():
+            self.protocol.request_next()
+        try:
+            response = self.protocol.get_response_next(block=True, timeout=self._response_wait_time)
+        except communication.protocol.EmptyQueue:
+            raise NotAvailable
+        if isinstance(response, communication.messages.StopIterationResponse):
+            self._stop_iteration = True
+            raise StopIteration
+        if isinstance(response, communication.messages.InvalidStateResponse):
+            raise NotAvailable
+        return response.value
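
EnsureNonBlockingDataPipe monkey-patches nonblocking_next/reset_iterator onto an arbitrary IterDataPipe so DataPipeBehindQueues can poll any pipe uniformly; NonBlocking subclasses such as QueueWrapper pass through unchanged. A small sketch, not part of the patch:

    from torch.utils.data.datapipes.iter import IterableWrapper
    from torch.utils.data.communication.iter import EnsureNonBlockingDataPipe

    dp = EnsureNonBlockingDataPipe(IterableWrapper([10, 20]))
    assert dp.nonblocking_next() == 10  # lazily creates the underlying iterator
    dp.reset_iterator()                 # drops it, so iteration starts over
    assert dp.nonblocking_next() == 10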
diff --git a/torch/utils/data/communication/messages.py b/torch/utils/data/communication/messages.py
new file mode 100644
index 0000000..449cf23
--- /dev/null
+++ b/torch/utils/data/communication/messages.py
@@ -0,0 +1,75 @@
+class DataLoaderQueueMessage(object):
+    pass
+
+
+class Request(DataLoaderQueueMessage):
+    pass
+
+
+class Response(DataLoaderQueueMessage):
+    pass
+
+
+class ResetIteratorRequest(Request):
+    pass
+
+
+class ResetIteratorResponse(Response):
+    pass
+
+
+class TerminateRequest(Request):
+    pass
+
+
+class TerminateResponse(Response):
+    pass
+
+
+class LenRequest(Request):
+    pass
+
+
+class LenResponse(Response):
+    __slots__ = ('len',)
+
+    def __init__(self, len):
+        self.len = len
+
+
+class GetItemRequest(Request):
+    __slots__ = ('key',)
+
+    def __init__(self, key):
+        self.key = key
+
+
+class GetItemResponse(Response):
+    __slots__ = ('key', 'value')
+
+    def __init__(self, key, value):
+        self.key = key
+        self.value = value
+
+
+class GetNextRequest(Request):
+    pass
+
+
+class GetNextResponse(Response):
+    __slots__ = ('value',)
+
+    def __init__(self, value):
+        self.value = value
+
+
+class StopIterationResponse(Response):
+    pass
+
+
+class InvalidStateResponse(Response):
+    """
+        Returned by a DataPipe when it expects a reset request,
+        for example RouterDataPipe expecting all workers to request a reset.
+    """
+    pass
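
The messages are deliberately dumb data carriers: each Request subclass pairs with a Response subclass, and both protocol ends dispatch purely on isinstance. GetItemRequest and LenRequest anticipate the map-style server, which is still a stub in protocol.py. A trivial illustration:

    from torch.utils.data.communication import messages

    req = messages.GetNextRequest()
    assert isinstance(req, messages.Request)

    resp = messages.GetNextResponse(42)
    assert isinstance(resp, messages.Response) and resp.value == 42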
diff --git a/torch/utils/data/communication/protocol.py b/torch/utils/data/communication/protocol.py
new file mode 100644
index 0000000..68ff335
--- /dev/null
+++ b/torch/utils/data/communication/protocol.py
@@ -0,0 +1,159 @@
+from torch.utils.data import communication
+
+
+class Protocol(object):
+    __slots__ = ('request_queue', 'response_queue')
+
+    def __init__(self, request_queue, response_queue):
+        self.request_queue = request_queue
+        self.response_queue = response_queue
+
+
+class ProtocolClient(Protocol):
+    """
+        ProtocolClient takes charge of putting requests into req_queue and returning results from res_queue.
+    """
+    _req_sent = None
+
+    def __init__(self, request_queue, response_queue):
+        self.request_queue = request_queue
+        self.response_queue = response_queue
+        self._req_sent = None
+
+    def can_take_request(self):
+        return self._req_sent is None
+
+    def waiting_for_response(self):
+        return self._req_sent is not None
+
+    def request_sent(self, request=True):
+        if not self.can_take_request():
+            raise Exception('Protocol only supports one request in the Queue')
+        self._req_sent = request
+
+    def request_served(self, result=None):
+        if not self.waiting_for_response():
+            raise Exception(
+                'Expected no pending requests, but something got served', result)
+        self._req_sent = None
+
+
+class ProtocolServer(Protocol):
+    """
+        ProtocolServer takes charge of getting requests from req_queue and fetching data from source datapipe.
+    """
+    _req_received = None
+
+    def __init__(self, request_queue, response_queue):
+        self.request_queue = request_queue
+        self.response_queue = response_queue
+        self._req_received = None
+
+    def have_pending_request(self):
+        return self._req_received is not None
+
+    def get_new_request(self, block=False):
+        if self.have_pending_request():
+            raise Exception(
+                'Trying to get next request, while having one unserved')
+        try:
+            response = self.request_queue.get(block=block)
+        except Exception as e:  # TODO: Catch only timeout exceptions
+            raise EmptyQueue('queue is empty')
+        self._req_received = response
+        return response
+
+        # TODO: Validate supported requests
+
+    def response_reset(self):
+        if not self.have_pending_request():
+            raise Exception("Attempting to reply with pending request")
+        if not isinstance(self._req_received, communication.messages.ResetIteratorRequest):
+            raise Exception(
+                "Replaying with reset status to other type of message")
+        self.response_queue.put(communication.messages.ResetIteratorResponse())
+        self._req_received = None
+
+    def response_next(self, value):
+        if not self.have_pending_request():
+            raise Exception("Attempting to reply with pending request")
+        self.response_queue.put(communication.messages.GetNextResponse(value))
+        self._req_received = None
+
+    def response_stop(self):
+        if not self.have_pending_request():
+            raise Exception("Attempting to reply with pending request")
+        self.response_queue.put(communication.messages.StopIterationResponse())
+        self._req_received = None
+
+    def response_invalid(self):
+        if not self.have_pending_request():
+            raise Exception("Attempting to reply with pending request")
+        self.response_queue.put(communication.messages.InvalidStateResponse())
+        self._req_received = None
+
+    def response_terminate(self):
+        if not self.have_pending_request():
+            raise Exception("Attempting to reply with pending request")
+        if not isinstance(self._req_received, communication.messages.TerminateRequest):
+            raise Exception(
+                "Replaying with terminate status to other type of message")
+        self.response_queue.put(communication.messages.TerminateResponse())
+        self._req_received = None
+
+
+class MapDataPipeQueueProtocolClient(ProtocolClient):
+    pass
+
+
+class MapDataPipeQueueProtocolServer(ProtocolServer):
+    pass
+
+
+class EmptyQueue(Exception):
+    pass
+
+
+class IterDataPipeQueueProtocolServer(ProtocolServer):
+    pass
+
+
+class IterDataPipeQueueProtocolClient(ProtocolClient):
+    def request_reset(self):
+        if not self.can_take_request():
+            raise Exception(
+                'Cannot reset while still waiting for a response to the previous request')
+        request = communication.messages.ResetIteratorRequest()
+        self.request_queue.put(request)
+        self.request_sent(request)
+
+    def request_next(self):
+        if not self.can_take_request():
+            raise Exception(
+                'Cannot request the next item while still waiting for a response to the previous request')
+        request = communication.messages.GetNextRequest()
+        self.request_queue.put(request)
+        self.request_sent(request)
+
+    def get_response_reset(self, block=False):
+        try:
+            response = self.response_queue.get(block=block)
+        except Exception as e:  # TODO: Catch only timeout exceptions
+            raise EmptyQueue('queue is empty')
+        self.request_served(response)
+
+        if not isinstance(response, communication.messages.ResetIteratorResponse):
+            raise Exception('Invalid response received')
+
+    def get_response_next(self, block=False, timeout=None):
+        if not self.waiting_for_response():
+            raise Exception(
+                'Cannot expect any response without a submitted request')
+        try:
+            response = self.response_queue.get(block=block, timeout=timeout)
+        except Exception as e:  # TODO: Catch only timeout exceptions
+            raise EmptyQueue('queue is empty')
+        self.request_served(response)
+
+        # TODO(VitalyFedyunin): Add possible response types validation here
+        return response
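
Both ends enforce at most one request in flight: the client tracks _req_sent, the server tracks _req_received, and each raises if the invariant is broken. The handshake can be exercised in a single thread with the LocalQueue added next; a sketch, not part of the patch:

    from torch.utils.data import communication

    req_q = communication.queue.LocalQueue()
    res_q = communication.queue.LocalQueue()
    client = communication.protocol.IterDataPipeQueueProtocolClient(req_q, res_q)
    server = communication.protocol.IterDataPipeQueueProtocolServer(req_q, res_q)

    client.request_reset()                 # enqueues ResetIteratorRequest
    assert not client.can_take_request()   # one request in flight at a time
    server.get_new_request()               # server picks the request up
    server.response_reset()                # acknowledges with ResetIteratorResponse
    client.get_response_reset()            # validates the response type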
diff --git a/torch/utils/data/communication/queue.py b/torch/utils/data/communication/queue.py
new file mode 100644
index 0000000..7717697
--- /dev/null
+++ b/torch/utils/data/communication/queue.py
@@ -0,0 +1,50 @@
+import threading
+import time
+
+class LocalQueue():
+    ops = 0
+    stored = 0
+    uid = 0
+    empty = 0
+
+    def __init__(self, name='unnamed'):
+        self.items = []
+        self.name = name
+        self.uid = LocalQueue.uid
+        LocalQueue.uid += 1
+
+    def put(self, item, block=True):
+        LocalQueue.ops += 1
+        LocalQueue.stored += 1
+        self.items.append(item)
+
+    def get(self, block=True, timeout=0):
+        # TODO(VitalyFedyunin): Add support for the block and timeout arguments
+        LocalQueue.ops += 1
+        if not len(self.items):
+            LocalQueue.empty += 1
+            raise Exception('LocalQueue is empty')
+        LocalQueue.stored -= 1
+        return self.items.pop()
+
+
+class ThreadingQueue():
+    def __init__(self, name='unnamed'):
+        self.lock = threading.Lock()
+        self.items = []
+        self.name = name
+
+    def put(self, item, block=True):
+        with self.lock:
+            self.items.append(item)
+
+    def get(self, block=True, timeout=0):
+        # TODO(VitalyFedyunin): Add support for the block and timeout arguments
+        while True:
+            with self.lock:
+                if len(self.items) > 0:
+                    return self.items.pop()
+            if not block:
+                raise Exception("Not available")
+            # TODO(VitalyFedyunin): Figure out what to do if nothing is in the queue
+            time.sleep(0.000001)
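
Note that both queues pop from the tail of a Python list, so they are LIFO; this is harmless because the protocol above never has more than one message in flight. ThreadingQueue spin-waits with a tiny sleep when block=True. A sketch passing one item between threads:

    import threading
    from torch.utils.data.communication.queue import ThreadingQueue

    q = ThreadingQueue('demo')
    producer = threading.Thread(target=lambda: q.put('done'))
    producer.start()
    assert q.get() == 'done'  # spins until the producer's put lands
    producer.join()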
diff --git a/torch/utils/data/dataloader_experimental.py b/torch/utils/data/dataloader_experimental.py
index ea08529..a74c75c 100644
--- a/torch/utils/data/dataloader_experimental.py
+++ b/torch/utils/data/dataloader_experimental.py
@@ -1,10 +1,60 @@
 
 import functools
+import time
+
+from typing import Any, List
 
 import torch.utils.data.backward_compatibility
-from torch.utils.data import DataLoader, IterDataPipe
+
+import torch.utils.data.sharding
+from torch.utils.data import DataLoader, IterDataPipe, communication
 from torch.utils.data.datapipes.iter import IterableWrapper
 
+class _ThreadingDataLoader2:
+
+    def __init__(self, datapipe, num_workers=0, collate_fn=None):
+        self.threads = []
+        self.datapipes = []
+        self.collate_fn = collate_fn
+        for worker_id in range(num_workers):
+            (thread, req_queue, res_queue, thread_localdatapipe) = communication.eventloop.SpawnThreadForDataPipeline(datapipe)
+            torch.utils.data.sharding.apply_sharding(thread_localdatapipe, num_workers, worker_id)
+            thread.start()
+            self.threads.append((thread, req_queue, res_queue))
+            local_datapipe = communication.iter.QueueWrapper(
+                communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue))
+            self.datapipes.append(local_datapipe)
+
+    def __iter__(self):
+        not_available = False
+        forever = True
+        exclude_datapipes: List[Any] = []
+        while len(exclude_datapipes) < len(self.datapipes):
+            for dp in self.datapipes:
+                if dp not in exclude_datapipes:
+                    try:
+                        value = dp.nonblocking_next()
+                        yield value
+                    except StopIteration:
+                        exclude_datapipes.append(dp)
+                    except communication.iter.NotAvailable:
+                        not_available = True
+            if not_available:
+                time.sleep(0.001)
+
+    def __del__(self):
+        self._cleanup_all_threads()
+
+    def _cleanup_all_threads(self):
+        def clean_me(thread, req_queue, res_queue):
+            req_queue.put(communication.messages.TerminateRequest())
+            _ = res_queue.get()
+            thread.join()
+
+        for thread, req_queue, res_queue in self.threads:
+            clean_me(thread, req_queue, res_queue)
+
+
 class DataLoader2:
     def __new__(cls,
                 dataset,
@@ -21,15 +71,17 @@ class DataLoader2:
                 *,
                 prefetch_factor=2,
                 persistent_workers=False,
-                batch_outside_worker=False):
+                batch_outside_worker=False,
+                parallelism_mode='mp'):
         if isinstance(dataset, IterDataPipe):
-            datapipe = dataset
+            data_loader: Any = None
             if batch_sampler is not None:
                 raise Exception(
-                    'batch_sampler is not yet supported for DataPipes')
+                    'batch_sampler is not yet supported by DataPipes')
             if sampler is not None:
                 raise Exception(
-                    'sampler is not yet supported for DataPipes')
+                    'sampler is not yet supported by DataPipes')
+            datapipe = dataset
             if shuffle:
                 datapipe = datapipe.shuffle()
             if batch_outside_worker and pin_memory:
@@ -40,30 +92,43 @@ class DataLoader2:
                     datapipe = datapipe.batch(batch_size, drop_last=drop_last)
                     if collate_fn is None:
                         collate_fn = torch.utils.data._utils.collate.default_collate
+            if parallelism_mode == 'mp' or num_workers == 0:
+                def sharding_worker_init_fn(worker_init_fn, worker_id):
+                    if worker_init_fn is not None:
+                        worker_init_fn(worker_id)
+                    torch.utils.data.backward_compatibility.worker_init_fn(
+                        worker_id)
 
-            def sharding_worker_init_fn(worker_init_fn, worker_id):
-                if worker_init_fn is not None:
-                    worker_init_fn(worker_id)
-                torch.utils.data.backward_compatibility.worker_init_fn(
-                    worker_id)
-
-            my_worker_init_fn = functools.partial(
-                sharding_worker_init_fn, worker_init_fn)
-
-            data_loader = DataLoader(datapipe,
-                                     batch_size=None,  # Replaced by .batch DataPipe
-                                     shuffle=False,  # Replaced by .shuffle DataPipe
-                                     sampler=None,
-                                     batch_sampler=None,
-                                     num_workers=num_workers,
-                                     collate_fn=collate_fn,
-                                     pin_memory=pin_memory,
-                                     drop_last=False,  # Replaced by .batch DataPipe
-                                     timeout=timeout,
-                                     worker_init_fn=my_worker_init_fn,
-                                     prefetch_factor=prefetch_factor,
-                                     persistent_workers=persistent_workers)
+                my_worker_init_fn = functools.partial(
+                    sharding_worker_init_fn, worker_init_fn)
 
+                data_loader = DataLoader(datapipe,
+                                         batch_size=None,  # Replaced by .batch DataPipe
+                                         shuffle=False,  # Replaced by .shuffle DataPipe
+                                         sampler=None,
+                                         batch_sampler=None,
+                                         num_workers=num_workers,
+                                         collate_fn=collate_fn,
+                                         pin_memory=pin_memory,
+                                         drop_last=False,  # Replaced by .batch DataPipe
+                                         timeout=timeout,
+                                         worker_init_fn=my_worker_init_fn,
+                                         prefetch_factor=prefetch_factor,
+                                         persistent_workers=persistent_workers)
+            elif parallelism_mode == 'thread':
+                if collate_fn is not None and not batch_outside_worker:
+                    datapipe = datapipe.map(collate_fn)
+                if pin_memory:
+                    raise Exception(
+                        'pin_memory is not yet supported by DataPipes with Threading')
+                if worker_init_fn is not None:
+                    raise Exception(
+                        'worker_init_fn is not yet supported by DataPipes with Threading')
+                data_loader = _ThreadingDataLoader2(datapipe,
+                                                    num_workers=num_workers,
+                                                    collate_fn=collate_fn)
+            else:
+                raise Exception('Unsupported parallelism mode', parallelism_mode)
             if not batch_outside_worker:
                 return data_loader
             else:
@@ -72,8 +137,11 @@ class DataLoader2:
                 datapipe = IterableWrapper(data_loader).batch(
                     batch_size, drop_last=drop_last).map(collate_fn)
                 return datapipe
-
         else:
+            if parallelism_mode == 'thread':
+                raise Exception(
+                    'thread parallelism mode is not supported for old DataSets')
+
             return DataLoader(dataset,
                               batch_size=batch_size,
                               shuffle=shuffle,
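
With these pieces in place, the new parallelism_mode keyword routes DataPipe loading either through the existing multiprocessing DataLoader ('mp', the default) or through _ThreadingDataLoader2 ('thread'). A usage sketch mirroring the updated test_basics:

    from torch.utils.data.dataloader_experimental import DataLoader2
    from torch.utils.data.datapipes.iter import IterableWrapper

    dp = IterableWrapper(list(range(1000)))
    dl_mp = DataLoader2(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
    dl_thread = DataLoader2(dp, batch_size=3, collate_fn=lambda x: x,
                            num_workers=2, parallelism_mode='thread')
    assert list(dl_mp) == list(dl_thread)  # holds while worker traversal order is guaranteed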