From f3f410880a71068bf2649efd977167972a85274b Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Fri, 10 Sep 2021 14:22:36 -0700 Subject: [PATCH] [DataPipe] Remove ZipArchiveReader's dependency on FileLoader (#64786) Summary: Stack from [ghstack](https://github.com/ezyang/ghstack): * https://github.com/pytorch/pytorch/issues/64788 * __->__ https://github.com/pytorch/pytorch/issues/64786 This PR removes ZipArchiveReader's dependency on FileLoader DataPipe, by allowing it to use a IterDataPipe of path names as input rather than a tuple of path name and a stream. It also adds additional tests to ensure that the DataPipe is functioning properly when it is read multiple times or reset half way through reading. The whole stack fixes issues related to unclosed buffer stream (see https://github.com/pytorch/pytorch/issues/64281). cc VitalyFedyunin ejguan Pull Request resolved: https://github.com/pytorch/pytorch/pull/64786 Reviewed By: ngimel Differential Revision: D30870968 Pulled By: NivekT fbshipit-source-id: 64b04d1697b99772f2fa20fc141668e6b8e18c41 --- test/test_datapipe.py | 43 ++++++++++++++++++---- .../utils/data/datapipes/iter/ziparchivereader.py | 35 +++++++----------- 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/test/test_datapipe.py b/test/test_datapipe.py index b77d0a1..c82f81b 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -100,6 +100,17 @@ def create_temp_dir_and_files(): return [(temp_dir, temp_file1_name, temp_file2_name, temp_file3_name), (temp_sub_dir, temp_sub_file1_name, temp_sub_file2_name)] +# Given a DataPipe and integer n, iterate the DataPipe for n elements and store the elements into a list +# Then, reset the DataPipe and return a tuple of two lists +# 1. A list of elements yielded before the reset +# 2. A list of all elements of the DataPipe after the reset +def reset_after_n_next_calls(datapipe: IterDataPipe[T_co], n: int) -> Tuple[List[T_co], List[T_co]]: + it = iter(datapipe) + res_before_reset = [] + for _ in range(n): + res_before_reset.append(next(it)) + return res_before_reset, list(datapipe) + class TestDataChunk(TestCase): def setUp(self): @@ -231,7 +242,7 @@ class TestIterableDataPipeBasic(TestCase): self.assertEqual(data_ref[1].read(), f.read()) data_ref[1].close() - # TODO(VitalyFedyunin): Generates unclosed buffer warning, need to investigate + def test_readfilesfromzip_iterable_datapipe(self): temp_dir = self.temp_dir.name temp_zipfile_pathname = os.path.join(temp_dir, "test_zip.zip") @@ -240,17 +251,17 @@ class TestIterableDataPipeBasic(TestCase): myzip.write(self.temp_files[1]) myzip.write(self.temp_files[2]) datapipe1 = dp.iter.FileLister(temp_dir, '*.zip') - datapipe2 = dp.iter.FileLoader(datapipe1) - datapipe3 = dp.iter.ZipArchiveReader(datapipe2) - # read extracted files before reaching the end of the zipfile - for rec, temp_file in itertools.zip_longest(datapipe3, self.temp_files): + datapipe2 = dp.iter.ZipArchiveReader(datapipe1) + + # Test Case: read extracted files before reaching the end of the zipfile + for rec, temp_file in itertools.zip_longest(datapipe2, self.temp_files): self.assertTrue(rec is not None and temp_file is not None) self.assertEqual(os.path.basename(rec[0]), os.path.basename(temp_file)) with open(temp_file, 'rb') as f: self.assertEqual(rec[1].read(), f.read()) rec[1].close() - # read extracted files before reaching the end of the zipile - data_refs = list(datapipe3) + # Test Case: read extracted files after reaching the end of the zipile + data_refs = list(datapipe2) self.assertEqual(len(data_refs), len(self.temp_files)) for data_ref, temp_file in zip(data_refs, self.temp_files): self.assertEqual(os.path.basename(data_ref[0]), os.path.basename(temp_file)) @@ -258,6 +269,24 @@ class TestIterableDataPipeBasic(TestCase): self.assertEqual(data_ref[1].read(), f.read()) data_ref[1].close() + # Test Case: reset the DataPipe after reading part of it + n_elements_before_reset = 1 + res_before_reset, res_after_reset = reset_after_n_next_calls(datapipe2, n_elements_before_reset) + # Check the results accumulated before reset + self.assertEqual(len(res_before_reset), n_elements_before_reset) + for ele_before_reset, temp_file in zip(res_before_reset, self.temp_files): + self.assertEqual(os.path.basename(ele_before_reset[0]), os.path.basename(temp_file)) + with open(temp_file, 'rb') as f: + self.assertEqual(ele_before_reset[1].read(), f.read()) + ele_before_reset[1].close() + # Check the results accumulated after reset + self.assertEqual(len(res_after_reset), len(self.temp_files)) + for ele_after_reset, temp_file in zip(res_after_reset, self.temp_files): + self.assertEqual(os.path.basename(ele_after_reset[0]), os.path.basename(temp_file)) + with open(temp_file, 'rb') as f: + self.assertEqual(ele_after_reset[1].read(), f.read()) + ele_after_reset[1].close() + def test_routeddecoder_iterable_datapipe(self): temp_dir = self.temp_dir.name temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png") diff --git a/torch/utils/data/datapipes/iter/ziparchivereader.py b/torch/utils/data/datapipes/iter/ziparchivereader.py index 881d005..b0ac4a0 100644 --- a/torch/utils/data/datapipes/iter/ziparchivereader.py +++ b/torch/utils/data/datapipes/iter/ziparchivereader.py @@ -1,6 +1,5 @@ from torch.utils.data import IterDataPipe -from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple -from typing import Iterable, Iterator, Tuple, IO, cast +from typing import Iterable, Iterator, Tuple from io import BufferedIOBase import os @@ -11,11 +10,11 @@ import warnings class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]): r""" :class:`ZipArchiveReaderIterDataPipe`. - Iterable data pipe to extract zip binary streams from input iterable which contains tuples of - pathname and zip binary stream, yields pathname and extracted binary stream in a tuple. + Iterable data pipe to extract zip binary streams from input iterable which contains + pathnames, yields a tuple of pathname and extracted binary stream. Args: - datapipe: Iterable datapipe that provides pathname and zip binary stream in tuples + datapipe: Iterable datapipe that provides pathnames length: Nominal length of the datapipe Note: @@ -25,21 +24,18 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]): """ def __init__( self, - datapipe: Iterable[Tuple[str, BufferedIOBase]], + datapipe: Iterable[str], length: int = -1): super().__init__() - self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe + self.datapipe: Iterable[str] = datapipe self.length: int = length def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]: - if not isinstance(self.datapipe, Iterable): - raise TypeError("datapipe must be Iterable type but got {}".format(type(self.datapipe))) - for data in self.datapipe: - validate_pathname_binary_tuple(data) - pathname, data_stream = data + for pathname in self.datapipe: + if not isinstance(pathname, str): + raise TypeError(f"pathname should be of string type, but is type {type(pathname)}") try: - # typing.cast is used here to silence mypy's type checker - zips = zipfile.ZipFile(cast(IO[bytes], data_stream)) + zips = zipfile.ZipFile(pathname) for zipinfo in zips.infolist(): # major version should always be 3 here. if sys.version_info[1] >= 6: @@ -47,20 +43,15 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]): continue elif zipinfo.filename.endswith('/'): continue - extracted_fobj = zips.open(zipinfo) inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename)) - # Add a reference of the source zipfile into extracted_fobj, so the source - # zipfile handle won't be released until all the extracted file objs are destroyed. - extracted_fobj.source_ref = zips # type: ignore[attr-defined] - # typing.cast is used here to silence mypy's type checker - yield (inner_pathname, cast(BufferedIOBase, extracted_fobj)) + yield (inner_pathname, extracted_fobj) # type: ignore[misc] except Exception as e: warnings.warn( - "Unable to extract files from corrupted zipfile stream {} due to: {}, abort!".format(pathname, e)) + f"Unable to extract files from corrupted zipfile stream {pathname} due to: {e}, abort!") raise e def __len__(self): if self.length == -1: - raise TypeError("{} instance doesn't have valid length".format(type(self).__name__)) + raise TypeError(f"{type(self).__name__} instance doesn't have valid length") return self.length -- 2.7.4