Adding DataLoader2 class as a future replacement for DataLoader (#63523)
author Vitaly Fedyunin <vitaly.fedyunin@gmail.com>
Fri, 20 Aug 2021 16:00:23 +0000 (09:00 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 20 Aug 2021 16:01:55 +0000 (09:01 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63523

Supports sharding and batching at the loader level (see the usage sketch below).

* #63522 Adds the IterableAsDataPipe IterDataPipe, useful for tests and simple cases.
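
A minimal usage sketch, mirroring the new test added in this PR (illustrative only; assumes the DataLoader2 and IterableAsDataPipe introduced here are importable):

    # Hypothetical usage of DataLoader2 with a DataPipe input: shuffling and
    # batching are expressed as DataPipe operations instead of loader arguments.
    from torch.utils.data import DataLoader2
    from torch.utils.data.datapipes.iter import IterableAsDataPipe

    dp = IterableAsDataPipe(list(range(10)))   # wrap a plain iterable for testing
    dl2 = DataLoader2(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
    print(list(dl2))  # yields batches of up to three elements from the pipe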

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D30426527

Pulled By: VitalyFedyunin

fbshipit-source-id: e5905d3364c4880e720dd62fb066f08881c71a6e

test/test_dataloader.py
torch/utils/data/__init__.py
torch/utils/data/dataloader_experimental.py [new file with mode: 0644]

index c68d7e2..71230cf 100644 (file)
@@ -13,9 +13,20 @@ import itertools
 import warnings
 import tempfile
 from torch import multiprocessing as mp
-from torch.utils.data import _utils, Dataset, IterableDataset, TensorDataset, DataLoader, ConcatDataset, ChainDataset, Subset
+from torch.utils.data import (
+    ChainDataset,
+    ConcatDataset,
+    DataLoader,
+    DataLoader2,
+    Dataset,
+    IterableDataset,
+    Subset,
+    TensorDataset,
+    _utils
+)
 from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
 from torch.utils.data.dataset import random_split
+from torch.utils.data.datapipes.iter import IterableAsDataPipe
 from torch._utils import ExceptionWrapper
 from torch.testing._internal.common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS,
                                                   IS_IN_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, slowTest,
@@ -1934,6 +1945,18 @@ except RuntimeError as e:
             dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000)
 
 
+@unittest.skipIf(
+    TEST_WITH_TSAN,
+    "Fails with TSAN with the following error: starting new threads after multi-threaded "
+    "fork is not supported. Dying (set die_after_fork=0 to override)")
+class TestDataLoader2(TestCase):
+    def test_basics(self):
+        dp = IterableAsDataPipe(list(range(10)))
+        dl = DataLoader(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
+        dl2 = DataLoader2(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
+        self.assertEqual(list(dl), list(dl2))
+
+
 class StringDataset(Dataset):
     def __init__(self):
         self.s = '12345'
index 1d18b7b..0af9e61 100644 (file)
@@ -11,9 +11,9 @@ from torch.utils.data.sampler import (
 from torch.utils.data.dataset import (
     ChainDataset,
     ConcatDataset,
+    DataChunk,
     Dataset,
     Dataset as MapDataPipe,
-    DataChunk,
     IterableDataset,
     IterableDataset as IterDataPipe,
     Subset,
@@ -34,11 +34,14 @@ from torch.utils.data._decorator import (
     runtime_validation,
     runtime_validation_disabled,
 )
+from torch.utils.data.dataloader_experimental import DataLoader2
+
 
 __all__ = ['BatchSampler',
            'ChainDataset',
            'ConcatDataset',
            'DataLoader',
+           'DataLoader2',
            'Dataset',
            'DistributedSampler',
            'IterDataPipe',
@@ -68,4 +71,3 @@ assert __all__ == sorted(__all__)
 ################################################################################
 # import subpackage
 ################################################################################
-from torch.utils.data import datapipes
diff --git a/torch/utils/data/dataloader_experimental.py b/torch/utils/data/dataloader_experimental.py
new file mode 100644 (file)
index 0000000..85028af
--- /dev/null
@@ -0,0 +1,89 @@
+
+import functools
+
+import torch.utils.data.backward_compatibility
+from torch.utils.data import DataLoader, IterDataPipe
+from torch.utils.data.datapipes.iter import IterableAsDataPipe
+
+class DataLoader2:
+    def __new__(cls,
+                dataset,
+                batch_size=1,
+                shuffle=False,
+                sampler=None,
+                batch_sampler=None,
+                num_workers=0,
+                collate_fn=None,
+                pin_memory=False,
+                drop_last=False,
+                timeout=0,
+                worker_init_fn=None,
+                *,
+                prefetch_factor=2,
+                persistent_workers=False,
+                batch_outside_worker=False):
+        if isinstance(dataset, IterDataPipe):
+            datapipe = dataset
+            if batch_sampler is not None:
+                raise Exception(
+                    'batch_sampler is not yet supported for DataPipes')
+            if sampler is not None:
+                raise Exception(
+                    'sampler is not yet supported for DataPipes')
+            if shuffle:
+                datapipe = datapipe.shuffle()
+            if batch_outside_worker and pin_memory:
+                raise Exception(
+                    'pin_memory is not yet compatible with batch_outside_worker')
+            if not batch_outside_worker:
+                if batch_size is not None:
+                    datapipe = datapipe.batch(batch_size, drop_last=drop_last)
+                    if collate_fn is None:
+                        collate_fn = torch.utils.data._utils.collate.default_collate
+
+            def sharding_worker_init_fn(worker_init_fn, worker_id):
+                if worker_init_fn is not None:
+                    worker_init_fn(worker_id)
+                torch.utils.data.backward_compatibility.worker_init_fn(
+                    worker_id)
+
+            my_worker_init_fn = functools.partial(
+                sharding_worker_init_fn, worker_init_fn)
+
+            data_loader = DataLoader(datapipe,
+                                     batch_size=None,  # Replaced by .batch DataPipe
+                                     shuffle=False,  # Replaced by .shuffle DataPipe
+                                     sampler=None,
+                                     batch_sampler=None,
+                                     num_workers=num_workers,
+                                     collate_fn=collate_fn,
+                                     pin_memory=pin_memory,
+                                     drop_last=False,  # Replaced by .batch DataPipe
+                                     timeout=timeout,
+                                     worker_init_fn=my_worker_init_fn,
+                                     prefetch_factor=prefetch_factor,
+                                     persistent_workers=persistent_workers)
+
+            if not batch_outside_worker:
+                return data_loader
+            else:
+                if collate_fn is None:
+                    collate_fn = torch.utils.data._utils.collate.default_collate
+                datapipe = IterableAsDataPipe(data_loader).batch(
+                    batch_size, drop_last=drop_last).map(collate_fn)
+                return datapipe
+
+        else:
+            return DataLoader(dataset,
+                              batch_size=batch_size,
+                              shuffle=shuffle,
+                              sampler=sampler,
+                              batch_sampler=batch_sampler,
+                              num_workers=num_workers,
+                              collate_fn=collate_fn,
+                              pin_memory=pin_memory,
+                              drop_last=drop_last,
+                              timeout=timeout,
+                              worker_init_fn=worker_init_fn,
+                              prefetch_factor=prefetch_factor,
+                              persistent_workers=persistent_workers)
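
A sketch of the batch_outside_worker path added above (hypothetical usage, not part of the diff): the workers yield individual samples, and DataLoader2 returns a DataPipe, not a DataLoader, because .batch() and .map(collate_fn) are applied in the main process.

    # Assumed usage of the batch_outside_worker flag defined in
    # dataloader_experimental.py; batching and collation happen outside the workers.
    from torch.utils.data import DataLoader2
    from torch.utils.data.datapipes.iter import IterableAsDataPipe

    dp = IterableAsDataPipe(list(range(10)))
    batched = DataLoader2(dp, batch_size=4, collate_fn=lambda x: x,
                          num_workers=2, batch_outside_worker=True)
    for batch in batched:
        print(batch)  # lists of up to four elements, collated in the main process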