[DataPipe] Remove ZipArchiveReader's dependency on FileLoader (#64786)
authorKevin Tse <ktse@fb.com>
Fri, 10 Sep 2021 21:22:36 +0000 (14:22 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 10 Sep 2021 23:49:17 +0000 (16:49 -0700)
Summary:
Stack from [ghstack](https://github.com/ezyang/ghstack):
* https://github.com/pytorch/pytorch/issues/64788
* __->__ https://github.com/pytorch/pytorch/issues/64786

This PR removes ZipArchiveReader's dependency on FileLoader DataPipe, by allowing it to use a IterDataPipe of path names as input rather than a tuple of path name and a stream.

It also adds additional tests to ensure that the DataPipe is functioning properly when it is read multiple times or reset half way through reading.

The whole stack fixes issues related to unclosed buffer stream (see https://github.com/pytorch/pytorch/issues/64281).

cc VitalyFedyunin ejguan

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64786

Reviewed By: ngimel

Differential Revision: D30870968

Pulled By: NivekT

fbshipit-source-id: 64b04d1697b99772f2fa20fc141668e6b8e18c41

test/test_datapipe.py
torch/utils/data/datapipes/iter/ziparchivereader.py

index b77d0a1..c82f81b 100644 (file)
@@ -100,6 +100,17 @@ def create_temp_dir_and_files():
     return [(temp_dir, temp_file1_name, temp_file2_name, temp_file3_name),
             (temp_sub_dir, temp_sub_file1_name, temp_sub_file2_name)]
 
+# Given a DataPipe and integer n, iterate the DataPipe for n elements and store the elements into a list
+# Then, reset the DataPipe and return a tuple of two lists
+# 1. A list of elements yielded before the reset
+# 2. A list of all elements of the DataPipe after the reset
+def reset_after_n_next_calls(datapipe: IterDataPipe[T_co], n: int) -> Tuple[List[T_co], List[T_co]]:
+    it = iter(datapipe)
+    res_before_reset = []
+    for _ in range(n):
+        res_before_reset.append(next(it))
+    return res_before_reset, list(datapipe)
+
 
 class TestDataChunk(TestCase):
     def setUp(self):
@@ -231,7 +242,7 @@ class TestIterableDataPipeBasic(TestCase):
                 self.assertEqual(data_ref[1].read(), f.read())
             data_ref[1].close()
 
-    # TODO(VitalyFedyunin): Generates unclosed buffer warning, need to investigate
+
     def test_readfilesfromzip_iterable_datapipe(self):
         temp_dir = self.temp_dir.name
         temp_zipfile_pathname = os.path.join(temp_dir, "test_zip.zip")
@@ -240,17 +251,17 @@ class TestIterableDataPipeBasic(TestCase):
             myzip.write(self.temp_files[1])
             myzip.write(self.temp_files[2])
         datapipe1 = dp.iter.FileLister(temp_dir, '*.zip')
-        datapipe2 = dp.iter.FileLoader(datapipe1)
-        datapipe3 = dp.iter.ZipArchiveReader(datapipe2)
-        # read extracted files before reaching the end of the zipfile
-        for rec, temp_file in itertools.zip_longest(datapipe3, self.temp_files):
+        datapipe2 = dp.iter.ZipArchiveReader(datapipe1)
+
+        # Test Case: read extracted files before reaching the end of the zipfile
+        for rec, temp_file in itertools.zip_longest(datapipe2, self.temp_files):
             self.assertTrue(rec is not None and temp_file is not None)
             self.assertEqual(os.path.basename(rec[0]), os.path.basename(temp_file))
             with open(temp_file, 'rb') as f:
                 self.assertEqual(rec[1].read(), f.read())
             rec[1].close()
-        # read extracted files before reaching the end of the zipile
-        data_refs = list(datapipe3)
+        # Test Case: read extracted files after reaching the end of the zipile
+        data_refs = list(datapipe2)
         self.assertEqual(len(data_refs), len(self.temp_files))
         for data_ref, temp_file in zip(data_refs, self.temp_files):
             self.assertEqual(os.path.basename(data_ref[0]), os.path.basename(temp_file))
@@ -258,6 +269,24 @@ class TestIterableDataPipeBasic(TestCase):
                 self.assertEqual(data_ref[1].read(), f.read())
             data_ref[1].close()
 
+        # Test Case: reset the DataPipe after reading part of it
+        n_elements_before_reset = 1
+        res_before_reset, res_after_reset = reset_after_n_next_calls(datapipe2, n_elements_before_reset)
+        # Check the results accumulated before reset
+        self.assertEqual(len(res_before_reset), n_elements_before_reset)
+        for ele_before_reset, temp_file in zip(res_before_reset, self.temp_files):
+            self.assertEqual(os.path.basename(ele_before_reset[0]), os.path.basename(temp_file))
+            with open(temp_file, 'rb') as f:
+                self.assertEqual(ele_before_reset[1].read(), f.read())
+            ele_before_reset[1].close()
+        # Check the results accumulated after reset
+        self.assertEqual(len(res_after_reset), len(self.temp_files))
+        for ele_after_reset, temp_file in zip(res_after_reset, self.temp_files):
+            self.assertEqual(os.path.basename(ele_after_reset[0]), os.path.basename(temp_file))
+            with open(temp_file, 'rb') as f:
+                self.assertEqual(ele_after_reset[1].read(), f.read())
+            ele_after_reset[1].close()
+
     def test_routeddecoder_iterable_datapipe(self):
         temp_dir = self.temp_dir.name
         temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png")
index 881d005..b0ac4a0 100644 (file)
@@ -1,6 +1,5 @@
 from torch.utils.data import IterDataPipe
-from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple
-from typing import Iterable, Iterator, Tuple, IO, cast
+from typing import Iterable, Iterator, Tuple
 from io import BufferedIOBase
 
 import os
@@ -11,11 +10,11 @@ import warnings
 class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
     r""" :class:`ZipArchiveReaderIterDataPipe`.
 
-    Iterable data pipe to extract zip binary streams from input iterable which contains tuples of
-    pathname and zip binary stream, yields pathname and extracted binary stream in a tuple.
+    Iterable data pipe to extract zip binary streams from input iterable which contains
+    pathnames, yields a tuple of pathname and extracted binary stream.
 
     Args:
-        datapipe: Iterable datapipe that provides pathname and zip binary stream in tuples
+        datapipe: Iterable datapipe that provides pathnames
         length: Nominal length of the datapipe
 
     Note:
@@ -25,21 +24,18 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
     """
     def __init__(
             self,
-            datapipe: Iterable[Tuple[str, BufferedIOBase]],
+            datapipe: Iterable[str],
             length: int = -1):
         super().__init__()
-        self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
+        self.datapipe: Iterable[str] = datapipe
         self.length: int = length
 
     def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
-        if not isinstance(self.datapipe, Iterable):
-            raise TypeError("datapipe must be Iterable type but got {}".format(type(self.datapipe)))
-        for data in self.datapipe:
-            validate_pathname_binary_tuple(data)
-            pathname, data_stream = data
+        for pathname in self.datapipe:
+            if not isinstance(pathname, str):
+                raise TypeError(f"pathname should be of string type, but is type {type(pathname)}")
             try:
-                # typing.cast is used here to silence mypy's type checker
-                zips = zipfile.ZipFile(cast(IO[bytes], data_stream))
+                zips = zipfile.ZipFile(pathname)
                 for zipinfo in zips.infolist():
                     # major version should always be 3 here.
                     if sys.version_info[1] >= 6:
@@ -47,20 +43,15 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
                             continue
                     elif zipinfo.filename.endswith('/'):
                         continue
-
                     extracted_fobj = zips.open(zipinfo)
                     inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename))
-                    # Add a reference of the source zipfile into extracted_fobj, so the source
-                    # zipfile handle won't be released until all the extracted file objs are destroyed.
-                    extracted_fobj.source_ref = zips  # type: ignore[attr-defined]
-                    # typing.cast is used here to silence mypy's type checker
-                    yield (inner_pathname, cast(BufferedIOBase, extracted_fobj))
+                    yield (inner_pathname, extracted_fobj)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(
-                    "Unable to extract files from corrupted zipfile stream {} due to: {}, abort!".format(pathname, e))
+                    f"Unable to extract files from corrupted zipfile stream {pathname} due to: {e}, abort!")
                 raise e
 
     def __len__(self):
         if self.length == -1:
-            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
         return self.length