)
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data.dataset import random_split
-from torch.utils.data.datapipes.iter import IterableAsDataPipe
+from torch.utils.data.datapipes.iter import IterableWrapper
from torch._utils import ExceptionWrapper
from torch.testing._internal.common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS,
IS_IN_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, slowTest,
class TestDataLoader2(TestCase):
    """Checks that DataLoader2 produces the same batches as DataLoader
    for a simple wrapped-list IterDataPipe."""

    @skipIfNoDill
    def test_basics(self):
        # Wrap a plain list so it can be consumed as an IterDataPipe.
        dp = IterableWrapper(list(range(10)))
        dl = DataLoader(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
        dl2 = DataLoader2(dp, batch_size=3, collate_fn=lambda x: x, num_workers=2)
        # assertEqual, not assertEquals: the latter is a deprecated alias
        # that was removed from unittest in Python 3.12.
        self.assertEqual(list(dl), list(dl2))
import torch.utils.data.backward_compatibility
from torch.utils.data import DataLoader, IterDataPipe
-from torch.utils.data.datapipes.iter import IterableAsDataPipe
+from torch.utils.data.datapipes.iter import IterableWrapper
class DataLoader2:
def __new__(cls,
else:
if collate_fn is None:
collate_fn = torch.utils.data._utils.collate.default_collate
- datapipe = IterableAsDataPipe(data_loader).batch(
+ datapipe = IterableWrapper(data_loader).batch(
batch_size, drop_last=drop_last).map(collate_fn)
return datapipe
ZipArchiveReaderIterDataPipe as ZipArchiveReader,
)
from torch.utils.data.datapipes.iter.utils import (
- IterableAsDataPipeIterDataPipe as IterableAsDataPipe,
+ IterableWrapperIterDataPipe as IterableWrapper,
)
__all__ = ['Batcher',
'FileLoader',
'Filter',
'HttpReader',
- 'IterableAsDataPipe',
+ 'IterableWrapper',
'LineReader',
'Mapper',
'RoutedDecoder',
from torch.utils.data import IterDataPipe


class IterableWrapperIterDataPipe(IterDataPipe):
    """Wraps an iterable object so it can be consumed as an ``IterDataPipe``.

    Args:
        iterable: Iterable source whose elements are yielded unchanged.
    """

    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        # `yield from` is the idiomatic equivalent of the manual
        # `for data in self.iterable: yield data` loop.
        yield from self.iterable

    def __len__(self):
        # NOTE(review): requires a Sized iterable (list, tuple, ...); a
        # one-shot generator passed as `iterable` will raise TypeError here.
        return len(self.iterable)