import warnings
import tempfile
from torch import multiprocessing as mp
-from torch.utils.data import (
- ChainDataset,
- ConcatDataset,
- DataLoader,
- DataLoader2,
- Dataset,
- IterableDataset,
- Subset,
- TensorDataset,
- _utils
-)
+from torch.utils.data import _utils, Dataset, IterableDataset, TensorDataset, DataLoader, ConcatDataset, ChainDataset, Subset
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data.dataset import random_split
-from torch.utils.data.datapipes.iter import IterableAsDataPipe
from torch._utils import ExceptionWrapper
from torch.testing._internal.common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS,
IS_IN_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, slowTest,
dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000)
-@unittest.skipIf(
- TEST_WITH_TSAN,
- "Fails with TSAN with the following error: starting new threads after multi-threaded "
- "fork is not supported. Dying (set die_after_fork=0 to override)")
-class TestDataLoader2(TestCase):
- def test_basics(self):
- dp = IterableAsDataPipe(list(range(10)))
- dl = DataLoader(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
- dl2 = DataLoader2(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
- self.assertEquals(list(dl), list(dl2))
-
-
class StringDataset(Dataset):
def __init__(self):
self.s = '12345'
from torch.utils.data.dataset import (
ChainDataset,
ConcatDataset,
- DataChunk,
Dataset,
Dataset as MapDataPipe,
+ DataChunk,
IterableDataset,
IterableDataset as IterDataPipe,
Subset,
runtime_validation,
runtime_validation_disabled,
)
-from torch.utils.data.dataloader_experimental import DataLoader2
-
__all__ = ['BatchSampler',
'ChainDataset',
'ConcatDataset',
'DataLoader',
- 'DataLoader2',
'Dataset',
'DistributedSampler',
'IterDataPipe',
################################################################################
# import subpackage
################################################################################
+from torch.utils.data import datapipes
--- a/torch/utils/data/dataloader_experimental.py
+++ /dev/null
-
-import functools
-
-import torch.utils.data.backward_compatibility
-from torch.utils.data import DataLoader, IterDataPipe
-from torch.utils.data.datapipes.iter import IterableAsDataPipe
-
-class DataLoader2:
- def __new__(cls,
- dataset,
- batch_size=1,
- shuffle=False,
- sampler=None,
- batch_sampler=None,
- num_workers=0,
- collate_fn=None,
- pin_memory=False,
- drop_last=False,
- timeout=0,
- worker_init_fn=None,
- *,
- prefetch_factor=2,
- persistent_workers=False,
- batch_outside_worker=False):
- if isinstance(dataset, IterDataPipe):
- datapipe = dataset
- if batch_sampler is not None:
- raise Exception(
- 'batch_sampler is not yet supported for DataPipes')
- if sampler is not None:
- raise Exception(
- 'sampler is not yet supported for DataPipes')
- if shuffle:
- datapipe = datapipe.shuffle()
- if batch_outside_worker and pin_memory:
- raise Exception(
- 'pin_memory is not yet compatible with batch_outside_worker')
- if not batch_outside_worker:
- if batch_size is not None:
- datapipe = datapipe.batch(batch_size, drop_last=drop_last)
- if collate_fn is None:
- collate_fn = torch.utils.data._utils.collate.default_collate
-
- def sharding_worker_init_fn(worker_init_fn, worker_id):
- if worker_init_fn is not None:
- worker_init_fn(worker_id)
- torch.utils.data.backward_compatibility.worker_init_fn(
- worker_id)
-
- my_worker_init_fn = functools.partial(
- sharding_worker_init_fn, worker_init_fn)
-
- data_loader = DataLoader(datapipe,
- batch_size=None, # Replaced by .batch DataPipe
- shuffle=False, # Replaced by .shuffle DataPipe
- sampler=None,
- batch_sampler=None,
- num_workers=num_workers,
- collate_fn=collate_fn,
- pin_memory=pin_memory,
- drop_last=False, # Replaced by .batch DataPipe
- timeout=timeout,
- worker_init_fn=my_worker_init_fn,
- prefetch_factor=prefetch_factor,
- persistent_workers=persistent_workers)
-
- if not batch_outside_worker:
- return data_loader
- else:
- if collate_fn is None:
- collate_fn = torch.utils.data._utils.collate.default_collate
- datapipe = IterableAsDataPipe(data_loader).batch(
- batch_size, drop_last=drop_last).map(collate_fn)
- return datapipe
-
- else:
- return DataLoader(dataset,
- batch_size=batch_size,
- shuffle=shuffle,
- sampler=sampler,
- batch_sampler=batch_sampler,
- num_workers=num_workers,
- collate_fn=collate_fn,
- pin_memory=pin_memory,
- drop_last=drop_last,
- timeout=timeout,
- worker_init_fn=worker_init_fn,
- prefetch_factor=prefetch_factor,
- persistent_workers=persistent_workers)
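
For anyone still depending on the removed wrapper: the IterDataPipe branch of DataLoader2 above only composed existing pieces, so the same result can be had with the plain DataLoader. A minimal sketch, mirroring the deleted implementation's IterDataPipe path and assuming the IterableAsDataPipe wrapper and backward_compatibility.worker_init_fn used in the deleted file are still available in your build:

    import torch.utils.data.backward_compatibility
    from torch.utils.data import DataLoader
    from torch.utils.data.datapipes.iter import IterableAsDataPipe

    # Shuffling and batching move into the pipe, sharding into worker_init_fn,
    # exactly as the removed DataLoader2 did internally.
    datapipe = IterableAsDataPipe(list(range(10))).shuffle().batch(3, drop_last=False)

    def sharding_worker_init_fn(worker_id):
        # Shard the pipe across workers, as the removed wrapper's
        # sharding_worker_init_fn did.
        torch.utils.data.backward_compatibility.worker_init_fn(worker_id)

    dl = DataLoader(datapipe,
                    batch_size=None,   # batching handled by .batch() above
                    shuffle=False,     # shuffling handled by .shuffle() above
                    num_workers=2,
                    collate_fn=lambda x: x,
                    worker_init_fn=sharding_worker_init_fn)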