Revert D30426527: Adding DataLoader2 class as future replacement of DataLoader
author Alban Desmaison <albandes@fb.com>
Fri, 20 Aug 2021 19:05:32 +0000 (12:05 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 20 Aug 2021 19:06:52 +0000 (12:06 -0700)
Test Plan: revert-hammer

Differential Revision:
D30426527 (https://github.com/pytorch/pytorch/commit/5a7133b87fe2fd7d025d36855ed4cc06539a9299)

Original commit changeset: e5905d3364c4

fbshipit-source-id: 794d8a4e9256ccff8cf894aee10eff6adc30d502

test/test_dataloader.py
torch/utils/data/__init__.py
torch/utils/data/dataloader_experimental.py [deleted file]

diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 71230cf..c68d7e2 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -13,20 +13,9 @@ import itertools
 import warnings
 import tempfile
 from torch import multiprocessing as mp
-from torch.utils.data import (
-    ChainDataset,
-    ConcatDataset,
-    DataLoader,
-    DataLoader2,
-    Dataset,
-    IterableDataset,
-    Subset,
-    TensorDataset,
-    _utils
-)
+from torch.utils.data import _utils, Dataset, IterableDataset, TensorDataset, DataLoader, ConcatDataset, ChainDataset, Subset
 from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
 from torch.utils.data.dataset import random_split
-from torch.utils.data.datapipes.iter import IterableAsDataPipe
 from torch._utils import ExceptionWrapper
 from torch.testing._internal.common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS,
                                                   IS_IN_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, slowTest,
@@ -1945,18 +1934,6 @@ except RuntimeError as e:
             dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000)
 
 
-@unittest.skipIf(
-    TEST_WITH_TSAN,
-    "Fails with TSAN with the following error: starting new threads after multi-threaded "
-    "fork is not supported. Dying (set die_after_fork=0 to override)")
-class TestDataLoader2(TestCase):
-    def test_basics(self):
-        dp = IterableAsDataPipe(list(range(10)))
-        dl = DataLoader(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
-        dl2 = DataLoader2(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
-        self.assertEquals(list(dl), list(dl2))
-
-
 class StringDataset(Dataset):
     def __init__(self):
         self.s = '12345'
diff --git a/torch/utils/data/__init__.py b/torch/utils/data/__init__.py
index 0af9e61..1d18b7b 100644
--- a/torch/utils/data/__init__.py
+++ b/torch/utils/data/__init__.py
@@ -11,9 +11,9 @@ from torch.utils.data.sampler import (
 from torch.utils.data.dataset import (
     ChainDataset,
     ConcatDataset,
-    DataChunk,
     Dataset,
     Dataset as MapDataPipe,
+    DataChunk,
     IterableDataset,
     IterableDataset as IterDataPipe,
     Subset,
@@ -34,14 +34,11 @@ from torch.utils.data._decorator import (
     runtime_validation,
     runtime_validation_disabled,
 )
-from torch.utils.data.dataloader_experimental import DataLoader2
-
 
 __all__ = ['BatchSampler',
            'ChainDataset',
            'ConcatDataset',
            'DataLoader',
-           'DataLoader2',
            'Dataset',
            'DistributedSampler',
            'IterDataPipe',
@@ -71,3 +68,4 @@ assert __all__ == sorted(__all__)
 ################################################################################
 # import subpackage
 ################################################################################
+from torch.utils.data import datapipes
diff --git a/torch/utils/data/dataloader_experimental.py b/torch/utils/data/dataloader_experimental.py
deleted file mode 100644
index 85028af..0000000
--- a/torch/utils/data/dataloader_experimental.py
+++ /dev/null
@@ -1,89 +0,0 @@
-
-import functools
-
-import torch.utils.data.backward_compatibility
-from torch.utils.data import DataLoader, IterDataPipe
-from torch.utils.data.datapipes.iter import IterableAsDataPipe
-
-class DataLoader2:
-    def __new__(cls,
-                dataset,
-                batch_size=1,
-                shuffle=False,
-                sampler=None,
-                batch_sampler=None,
-                num_workers=0,
-                collate_fn=None,
-                pin_memory=False,
-                drop_last=False,
-                timeout=0,
-                worker_init_fn=None,
-                *,
-                prefetch_factor=2,
-                persistent_workers=False,
-                batch_outside_worker=False):
-        if isinstance(dataset, IterDataPipe):
-            datapipe = dataset
-            if batch_sampler is not None:
-                raise Exception(
-                    'batch_sampler is not yet supported for DataPipes')
-            if sampler is not None:
-                raise Exception(
-                    'sampler is not yet supported for DataPipes')
-            if shuffle:
-                datapipe = datapipe.shuffle()
-            if batch_outside_worker and pin_memory:
-                raise Exception(
-                    'pin_memory is not yet compatible with batch_outside_worker')
-            if not batch_outside_worker:
-                if batch_size is not None:
-                    datapipe = datapipe.batch(batch_size, drop_last=drop_last)
-                    if collate_fn is None:
-                        collate_fn = torch.utils.data._utils.collate.default_collate
-
-            def sharding_worker_init_fn(worker_init_fn, worker_id):
-                if worker_init_fn is not None:
-                    worker_init_fn(worker_id)
-                torch.utils.data.backward_compatibility.worker_init_fn(
-                    worker_id)
-
-            my_worker_init_fn = functools.partial(
-                sharding_worker_init_fn, worker_init_fn)
-
-            data_loader = DataLoader(datapipe,
-                                     batch_size=None,  # Replaced by .batch DataPipe
-                                     shuffle=False,  # Replaced by .shuffle DataPipe
-                                     sampler=None,
-                                     batch_sampler=None,
-                                     num_workers=num_workers,
-                                     collate_fn=collate_fn,
-                                     pin_memory=pin_memory,
-                                     drop_last=False,  # Replaced by .batch DataPipe
-                                     timeout=timeout,
-                                     worker_init_fn=my_worker_init_fn,
-                                     prefetch_factor=prefetch_factor,
-                                     persistent_workers=persistent_workers)
-
-            if not batch_outside_worker:
-                return data_loader
-            else:
-                if collate_fn is None:
-                    collate_fn = torch.utils.data._utils.collate.default_collate
-                datapipe = IterableAsDataPipe(data_loader).batch(
-                    batch_size, drop_last=drop_last).map(collate_fn)
-                return datapipe
-
-        else:
-            return DataLoader(dataset,
-                              batch_size=batch_size,
-                              shuffle=shuffle,
-                              sampler=sampler,
-                              batch_sampler=batch_sampler,
-                              num_workers=num_workers,
-                              collate_fn=collate_fn,
-                              pin_memory=pin_memory,
-                              drop_last=drop_last,
-                              timeout=timeout,
-                              worker_init_fn=worker_init_fn,
-                              prefetch_factor=prefetch_factor,
-                              persistent_workers=persistent_workers)
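
For reference, the reverted DataLoader2 was a thin dispatcher: for IterDataPipe inputs it rewrote batch_size, shuffle and drop_last into .shuffle()/.batch() DataPipe operations (plus a sharding-aware worker_init_fn) before handing off to DataLoader, while any other dataset had all arguments forwarded to DataLoader unchanged. Below is a minimal usage sketch adapted from the deleted TestDataLoader2.test_basics above; the DataLoader2 and IterableAsDataPipe names come from the reverted change and are no longer importable once this revert lands.

    # Sketch of the usage removed by this revert (adapted from the deleted
    # test_basics). Assumes the reverted change is still applied, so that
    # DataLoader2 and IterableAsDataPipe exist under torch.utils.data.
    from torch.utils.data import DataLoader, DataLoader2
    from torch.utils.data.datapipes.iter import IterableAsDataPipe

    dp = IterableAsDataPipe(list(range(10)))

    # Classic DataLoader: batching and collation happen inside the loader.
    dl = DataLoader(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)

    # DataLoader2: for an IterDataPipe input, batching is rewritten into a
    # .batch() DataPipe and the underlying DataLoader is built with
    # batch_size=None.
    dl2 = DataLoader2(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)

    # Both paths yield the same batches.
    assert list(dl) == list(dl2)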