return [(temp_dir, temp_file1_name, temp_file2_name, temp_file3_name),
(temp_sub_dir, temp_sub_file1_name, temp_sub_file2_name)]
# Consume the first `n` elements of a DataPipe, then reset it and drain it fully.
# Returns a tuple of two lists:
#   1. the elements yielded before the reset
#   2. every element of the DataPipe after the reset
def reset_after_n_next_calls(datapipe: IterDataPipe[T_co], n: int) -> Tuple[List[T_co], List[T_co]]:
    it = iter(datapipe)
    collected_before_reset: List[T_co] = []
    while len(collected_before_reset) < n:
        # Deliberately uses next() so a too-short DataPipe raises StopIteration.
        collected_before_reset.append(next(it))
    # Starting a new iteration over the DataPipe resets it; list() drains the fresh pass.
    return collected_before_reset, list(datapipe)
class TestDataChunk(TestCase):
def setUp(self):
self.assertEqual(data_ref[1].read(), f.read())
data_ref[1].close()
- # TODO(VitalyFedyunin): Generates unclosed buffer warning, need to investigate
+
def test_readfilesfromzip_iterable_datapipe(self):
+ # NOTE(review): diff hunk with elided context — "myzip" below presumably comes from a
+ # "with zipfile.ZipFile(temp_zipfile_pathname, 'w') as myzip:" line not shown in this chunk; confirm against full file.
temp_dir = self.temp_dir.name
temp_zipfile_pathname = os.path.join(temp_dir, "test_zip.zip")
myzip.write(self.temp_files[1])
myzip.write(self.temp_files[2])
datapipe1 = dp.iter.FileLister(temp_dir, '*.zip')
- datapipe2 = dp.iter.FileLoader(datapipe1)
- datapipe3 = dp.iter.ZipArchiveReader(datapipe2)
- # read extracted files before reaching the end of the zipfile
- for rec, temp_file in itertools.zip_longest(datapipe3, self.temp_files):
+ datapipe2 = dp.iter.ZipArchiveReader(datapipe1)
+ # NOTE(review): the FileLoader stage is dropped — ZipArchiveReader now consumes pathnames
+ # straight from FileLister; verify no other test relied on the old datapipe3 pipeline.
+
+ # Test Case: read extracted files before reaching the end of the zipfile
+ for rec, temp_file in itertools.zip_longest(datapipe2, self.temp_files):
self.assertTrue(rec is not None and temp_file is not None)
self.assertEqual(os.path.basename(rec[0]), os.path.basename(temp_file))
with open(temp_file, 'rb') as f:
self.assertEqual(rec[1].read(), f.read())
rec[1].close()
- # read extracted files before reaching the end of the zipile
- data_refs = list(datapipe3)
+ # Test Case: read extracted files after reaching the end of the zipfile
+ data_refs = list(datapipe2)
self.assertEqual(len(data_refs), len(self.temp_files))
for data_ref, temp_file in zip(data_refs, self.temp_files):
self.assertEqual(os.path.basename(data_ref[0]), os.path.basename(temp_file))
+ # NOTE(review): "f" is reused from an elided "with open(temp_file, 'rb') as f:" context line — confirm against full file.
self.assertEqual(data_ref[1].read(), f.read())
data_ref[1].close()
+ # Test Case: reset the DataPipe after reading part of it
+ n_elements_before_reset = 1
+ res_before_reset, res_after_reset = reset_after_n_next_calls(datapipe2, n_elements_before_reset)
+ # Check the results accumulated before reset
+ self.assertEqual(len(res_before_reset), n_elements_before_reset)
+ for ele_before_reset, temp_file in zip(res_before_reset, self.temp_files):
+ self.assertEqual(os.path.basename(ele_before_reset[0]), os.path.basename(temp_file))
+ with open(temp_file, 'rb') as f:
+ self.assertEqual(ele_before_reset[1].read(), f.read())
+ ele_before_reset[1].close()
+ # Check the results accumulated after reset
+ self.assertEqual(len(res_after_reset), len(self.temp_files))
+ for ele_after_reset, temp_file in zip(res_after_reset, self.temp_files):
+ self.assertEqual(os.path.basename(ele_after_reset[0]), os.path.basename(temp_file))
+ with open(temp_file, 'rb') as f:
+ self.assertEqual(ele_after_reset[1].read(), f.read())
+ ele_after_reset[1].close()
+
def test_routeddecoder_iterable_datapipe(self):
temp_dir = self.temp_dir.name
temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png")
from torch.utils.data import IterDataPipe
-from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple
-from typing import Iterable, Iterator, Tuple, IO, cast
+from typing import Iterable, Iterator, Tuple
from io import BufferedIOBase
import os
class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
    r""" :class:`ZipArchiveReaderIterDataPipe`.

    Iterable data pipe to extract zip binary streams from input iterable which contains
    pathnames, yields a tuple of pathname and extracted binary stream.

    Args:
        datapipe: Iterable datapipe that provides pathnames
        length: Nominal length of the datapipe

    Note:
        The yielded file streams are open; the consumer is responsible for
        closing them when done.
    """
    def __init__(
            self,
            datapipe: Iterable[str],
            length: int = -1):
        super().__init__()
        self.datapipe: Iterable[str] = datapipe
        # Nominal length; -1 means "unknown" and makes __len__ raise TypeError.
        self.length: int = length

    def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
        """For each pathname from the source datapipe, open the zip archive and
        yield ``(inner_pathname, extracted_stream)`` for every member file."""
        for pathname in self.datapipe:
            if not isinstance(pathname, str):
                raise TypeError(f"pathname should be of string type, but is type {type(pathname)}")
            try:
                zips = zipfile.ZipFile(pathname)
                for zipinfo in zips.infolist():
                    # Skip directory entries. Major version should always be 3 here;
                    # ZipInfo.is_dir() exists on Python >= 3.6, fall back to the
                    # trailing-slash convention on older minors.
                    # (Previously the is_dir() check was missing, so EVERY entry
                    # was skipped on Python >= 3.6 and nothing was yielded.)
                    if sys.version_info[1] >= 6:
                        if zipinfo.is_dir():
                            continue
                    elif zipinfo.filename.endswith('/'):
                        continue
                    extracted_fobj = zips.open(zipinfo)
                    inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename))
                    # Keep a reference to the source ZipFile on the extracted stream so the
                    # archive handle is not garbage-collected (and its shared file closed)
                    # while extracted file objects are still being read.
                    extracted_fobj.source_ref = zips  # type: ignore[attr-defined]
                    yield (inner_pathname, extracted_fobj)  # type: ignore[misc]
            except Exception as e:
                warnings.warn(
                    f"Unable to extract files from corrupted zipfile stream {pathname} due to: {e}, abort!")
                # Bare raise preserves the original traceback ("raise e" would reset it).
                raise

    def __len__(self):
        """Return the nominal length; raise TypeError when it was not provided."""
        if self.length == -1:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
        return self.length