From: Vitaly Fedyunin
Date: Fri, 20 Aug 2021 16:00:23 +0000 (-0700)
Subject: Adding DataLoader2 class as future replacement of DataLoader (#63523)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~849
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5a7133b87fe2fd7d025d36855ed4cc06539a9299;p=platform%2Fupstream%2Fpytorch.git

Adding DataLoader2 class as future replacement of DataLoader (#63523)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63523

Supports sharding and batching on the loader level.

* #63522 Adding IterableAsDataPipe IterDataPipe, useful for tests and simple cases

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D30426527

Pulled By: VitalyFedyunin

fbshipit-source-id: e5905d3364c4880e720dd62fb066f08881c71a6e
---
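For illustration, a minimal usage sketch of the new class (a hypothetical
example, not taken from this patch; it assumes the IterableAsDataPipe
helper added in #63522):

    from torch.utils.data import DataLoader2
    from torch.utils.data.datapipes.iter import IterableAsDataPipe

    # Wrap a plain iterable so it can enter the DataPipe graph.
    dp = IterableAsDataPipe(list(range(10)))

    # For an IterDataPipe input, batch_size is translated into a .batch()
    # DataPipe operation, and the pipe is sharded across the two workers
    # via torch.utils.data.backward_compatibility.worker_init_fn.
    dl2 = DataLoader2(dp, batch_size=3, num_workers=2)
    for batch in dl2:
        print(batch)

For map-style datasets, DataLoader2 simply falls back to the classic
DataLoader, as the else branch at the end of the diff shows.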
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index c68d7e2..71230cf 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -13,9 +13,20 @@ import itertools
 import warnings
 import tempfile
 from torch import multiprocessing as mp
-from torch.utils.data import _utils, Dataset, IterableDataset, TensorDataset, DataLoader, ConcatDataset, ChainDataset, Subset
+from torch.utils.data import (
+    ChainDataset,
+    ConcatDataset,
+    DataLoader,
+    DataLoader2,
+    Dataset,
+    IterableDataset,
+    Subset,
+    TensorDataset,
+    _utils
+)
 from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
 from torch.utils.data.dataset import random_split
+from torch.utils.data.datapipes.iter import IterableAsDataPipe
 from torch._utils import ExceptionWrapper
 from torch.testing._internal.common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS,
                                                   IS_IN_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, slowTest,
@@ -1934,6 +1945,18 @@ except RuntimeError as e:
         dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000)
 
 
+@unittest.skipIf(
+    TEST_WITH_TSAN,
+    "Fails with TSAN with the following error: starting new threads after multi-threaded "
+    "fork is not supported. Dying (set die_after_fork=0 to override)")
+class TestDataLoader2(TestCase):
+    def test_basics(self):
+        dp = IterableAsDataPipe(list(range(10)))
+        dl = DataLoader(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
+        dl2 = DataLoader2(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
+        self.assertEqual(list(dl), list(dl2))
+
+
 class StringDataset(Dataset):
     def __init__(self):
         self.s = '12345'
diff --git a/torch/utils/data/__init__.py b/torch/utils/data/__init__.py
index 1d18b7b..0af9e61 100644
--- a/torch/utils/data/__init__.py
+++ b/torch/utils/data/__init__.py
@@ -11,9 +11,9 @@ from torch.utils.data.sampler import (
 from torch.utils.data.dataset import (
     ChainDataset,
     ConcatDataset,
+    DataChunk,
     Dataset,
     Dataset as MapDataPipe,
-    DataChunk,
     IterableDataset,
     IterableDataset as IterDataPipe,
     Subset,
@@ -34,11 +34,14 @@ from torch.utils.data._decorator import (
     runtime_validation,
     runtime_validation_disabled,
 )
+from torch.utils.data.dataloader_experimental import DataLoader2
+
 
 __all__ = ['BatchSampler',
            'ChainDataset',
            'ConcatDataset',
            'DataLoader',
+           'DataLoader2',
            'Dataset',
            'DistributedSampler',
            'IterDataPipe',
@@ -68,4 +71,3 @@ assert __all__ == sorted(__all__)
 ################################################################################
 # import subpackage
 ################################################################################
-from torch.utils.data import datapipes
diff --git a/torch/utils/data/dataloader_experimental.py b/torch/utils/data/dataloader_experimental.py
new file mode 100644
index 0000000..85028af
--- /dev/null
+++ b/torch/utils/data/dataloader_experimental.py
@@ -0,0 +1,89 @@
+
+import functools
+
+import torch.utils.data.backward_compatibility
+from torch.utils.data import DataLoader, IterDataPipe
+from torch.utils.data.datapipes.iter import IterableAsDataPipe
+
+class DataLoader2:
+    def __new__(cls,
+                dataset,
+                batch_size=1,
+                shuffle=False,
+                sampler=None,
+                batch_sampler=None,
+                num_workers=0,
+                collate_fn=None,
+                pin_memory=False,
+                drop_last=False,
+                timeout=0,
+                worker_init_fn=None,
+                *,
+                prefetch_factor=2,
+                persistent_workers=False,
+                batch_outside_worker=False):
+        if isinstance(dataset, IterDataPipe):
+            datapipe = dataset
+            if batch_sampler is not None:
+                raise Exception(
+                    'batch_sampler is not yet supported for DataPipes')
+            if sampler is not None:
+                raise Exception(
+                    'sampler is not yet supported for DataPipes')
+            if shuffle:
+                datapipe = datapipe.shuffle()
+            if batch_outside_worker and pin_memory:
+                raise Exception(
+                    'pin_memory is not yet compatible with batch_outside_worker')
+            if not batch_outside_worker:
+                if batch_size is not None:
+                    datapipe = datapipe.batch(batch_size, drop_last=drop_last)
+                    if collate_fn is None:
+                        collate_fn = torch.utils.data._utils.collate.default_collate
+
+            def sharding_worker_init_fn(worker_init_fn, worker_id):
+                if worker_init_fn is not None:
+                    worker_init_fn(worker_id)
+                torch.utils.data.backward_compatibility.worker_init_fn(
+                    worker_id)
+
+            my_worker_init_fn = functools.partial(
+                sharding_worker_init_fn, worker_init_fn)
+
+            data_loader = DataLoader(datapipe,
+                                     batch_size=None,  # Replaced by .batch DataPipe
+                                     shuffle=False,  # Replaced by .shuffle DataPipe
+                                     sampler=None,
+                                     batch_sampler=None,
+                                     num_workers=num_workers,
+                                     collate_fn=collate_fn,
+                                     pin_memory=pin_memory,
+                                     drop_last=False,  # Replaced by .batch DataPipe
+                                     timeout=timeout,
+                                     worker_init_fn=my_worker_init_fn,
+                                     prefetch_factor=prefetch_factor,
+                                     persistent_workers=persistent_workers)
+
+            if not batch_outside_worker:
+                return data_loader
+            else:
+                if collate_fn is None:
+                    collate_fn = torch.utils.data._utils.collate.default_collate
+                datapipe = IterableAsDataPipe(data_loader).batch(
+                    batch_size, drop_last=drop_last).map(collate_fn)
+                return datapipe
+
+        else:
+            return DataLoader(dataset,
+                              batch_size=batch_size,
+                              shuffle=shuffle,
+                              sampler=sampler,
+                              batch_sampler=batch_sampler,
+                              num_workers=num_workers,
+                              collate_fn=collate_fn,
+                              pin_memory=pin_memory,
+                              drop_last=drop_last,
+                              timeout=timeout,
+                              worker_init_fn=worker_init_fn,
+                              prefetch_factor=prefetch_factor,
+                              persistent_workers=persistent_workers)
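
A closing note on the batch_outside_worker path above: samples cross the
worker boundary one at a time, and batching plus collation then run in the
parent process, so a single batch may mix samples produced by different
workers. A minimal, hypothetical sketch (names mirror the test above; not
part of this patch):

    from torch.utils.data import DataLoader2
    from torch.utils.data.datapipes.iter import IterableAsDataPipe

    dp = IterableAsDataPipe(list(range(10)))

    # Workers emit single samples; .batch() and the collate_fn are applied
    # to the merged stream in the parent process. In this mode the result
    # is a DataPipe rather than a DataLoader, and pin_memory must stay False.
    batches = DataLoader2(dp, batch_size=4, num_workers=2,
                          batch_outside_worker=True)
    for batch in batches:
        print(batch)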