[DataPipe] Make TarArchiveReader and ZipArchiveReader accepts FileSream with attempt...
authorKevin Tse <ktse@fb.com>
Wed, 15 Sep 2021 14:32:45 +0000 (07:32 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 15 Sep 2021 14:34:29 +0000 (07:34 -0700)
Summary:
ghstack is not working for the second commit so I'm manually creating this PR for now. Please only look at changes related to the second commit in this PR (there is a PR for the first commit).

This PR removes TarArchiveReader's dependency on FileLoader DataPipe, by allowing it to use a IterDataPipe of path names as input rather than a tuple of path name and a stream.

It also adds additional tests to ensure that the DataPipe is functioning properly when it is read multiple times or reset half way through reading.

The whole stack fixes https://github.com/pytorch/pytorch/issues/64281 - issues related to unclosed buffer stream.

Stack:
* __->__ https://github.com/pytorch/pytorch/issues/64788
* https://github.com/pytorch/pytorch/issues/64786

cc VitalyFedyunin ejguan

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64788

Reviewed By: jbschlosser, ejguan

Differential Revision: D30901176

Pulled By: NivekT

fbshipit-source-id: 59746a8d0144fc6d3ce0feb2d76445b82e6d414e

test/test_datapipe.py
torch/utils/data/datapipes/iter/tararchivereader.py
torch/utils/data/datapipes/iter/ziparchivereader.py
torch/utils/data/datapipes/utils/common.py

index f19a2e0..b0e1d06 100644 (file)
@@ -222,7 +222,6 @@ class TestIterableDataPipeBasic(TestCase):
                 rec[1].close()
         self.assertEqual(count, len(self.temp_files))
 
-    # TODO(VitalyFedyunin): Generates unclosed buffer warning, need to investigate
     def test_readfilesfromtar_iterable_datapipe(self):
         temp_dir = self.temp_dir.name
         temp_tarfile_pathname = os.path.join(temp_dir, "test_tar.tar")
@@ -233,14 +232,17 @@ class TestIterableDataPipeBasic(TestCase):
         datapipe1 = dp.iter.FileLister(temp_dir, '*.tar')
         datapipe2 = dp.iter.FileLoader(datapipe1)
         datapipe3 = dp.iter.TarArchiveReader(datapipe2)
-        # read extracted files before reaching the end of the tarfile
+
+        # Test Case: Read extracted files before reaching the end of the tarfile
         for rec, temp_file in itertools.zip_longest(datapipe3, self.temp_files):
             self.assertTrue(rec is not None and temp_file is not None)
             self.assertEqual(os.path.basename(rec[0]), os.path.basename(temp_file))
             with open(temp_file, 'rb') as f:
                 self.assertEqual(rec[1].read(), f.read())
             rec[1].close()
-        # read extracted files after reaching the end of the tarfile
+
+
+        # Test Case: Read extracted files after reaching the end of the tarfile
         data_refs = list(datapipe3)
         self.assertEqual(len(data_refs), len(self.temp_files))
         for data_ref, temp_file in zip(data_refs, self.temp_files):
@@ -249,7 +251,26 @@ class TestIterableDataPipeBasic(TestCase):
                 self.assertEqual(data_ref[1].read(), f.read())
             data_ref[1].close()
 
+        # Test Case: reset the DataPipe after reading part of it
+        n_elements_before_reset = 1
+        res_before_reset, res_after_reset = reset_after_n_next_calls(datapipe3, n_elements_before_reset)
+        # Check result accumulated before reset
+        self.assertEqual(len(res_before_reset), n_elements_before_reset)
+        for ele_before_reset, temp_file in zip(res_before_reset, self.temp_files):
+            self.assertEqual(os.path.basename(ele_before_reset[0]), os.path.basename(temp_file))
+            with open(temp_file, 'rb') as f:
+                self.assertEqual(ele_before_reset[1].read(), f.read())
+            ele_before_reset[1].close()
+        # Check result accumulated after reset
+        self.assertEqual(len(res_after_reset), len(self.temp_files))
+        for ele_after_reset, temp_file in zip(res_after_reset, self.temp_files):
+            self.assertEqual(os.path.basename(ele_after_reset[0]), os.path.basename(temp_file))
+            with open(temp_file, 'rb') as f:
+                self.assertEqual(ele_after_reset[1].read(), f.read())
+            ele_after_reset[1].close()
 
+    # This test throws a warning because data_stream in side ZipArchiveReader cannot be closed
+    # due to the way zipfiles.open() is implemented
     def test_readfilesfromzip_iterable_datapipe(self):
         temp_dir = self.temp_dir.name
         temp_zipfile_pathname = os.path.join(temp_dir, "test_zip.zip")
@@ -258,17 +279,18 @@ class TestIterableDataPipeBasic(TestCase):
             myzip.write(self.temp_files[1])
             myzip.write(self.temp_files[2])
         datapipe1 = dp.iter.FileLister(temp_dir, '*.zip')
-        datapipe2 = dp.iter.ZipArchiveReader(datapipe1)
+        datapipe2 = dp.iter.FileLoader(datapipe1)
+        datapipe3 = dp.iter.ZipArchiveReader(datapipe2)
 
         # Test Case: read extracted files before reaching the end of the zipfile
-        for rec, temp_file in itertools.zip_longest(datapipe2, self.temp_files):
+        for rec, temp_file in itertools.zip_longest(datapipe3, self.temp_files):
             self.assertTrue(rec is not None and temp_file is not None)
             self.assertEqual(os.path.basename(rec[0]), os.path.basename(temp_file))
             with open(temp_file, 'rb') as f:
                 self.assertEqual(rec[1].read(), f.read())
             rec[1].close()
         # Test Case: read extracted files after reaching the end of the zipile
-        data_refs = list(datapipe2)
+        data_refs = list(datapipe3)
         self.assertEqual(len(data_refs), len(self.temp_files))
         for data_ref, temp_file in zip(data_refs, self.temp_files):
             self.assertEqual(os.path.basename(data_ref[0]), os.path.basename(temp_file))
@@ -278,7 +300,7 @@ class TestIterableDataPipeBasic(TestCase):
 
         # Test Case: reset the DataPipe after reading part of it
         n_elements_before_reset = 1
-        res_before_reset, res_after_reset = reset_after_n_next_calls(datapipe2, n_elements_before_reset)
+        res_before_reset, res_after_reset = reset_after_n_next_calls(datapipe3, n_elements_before_reset)
         # Check the results accumulated before reset
         self.assertEqual(len(res_before_reset), n_elements_before_reset)
         for ele_before_reset, temp_file in zip(res_before_reset, self.temp_files):
@@ -334,7 +356,7 @@ class TestIterableDataPipeBasic(TestCase):
         datapipe4.add_handler(_png_decoder)
         _helper(cached, datapipe4, channel_first=True)
 
-    # TODO(VitalyFedyunin): Generates unclosed buffer warning, need to investigate
+
     def test_groupby_iterable_datapipe(self):
         temp_dir = self.temp_dir.name
         temp_tarfile_pathname = os.path.join(temp_dir, "test_tar.tar")
index c34583a..95c3331 100644 (file)
@@ -10,11 +10,11 @@ import warnings
 class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
     r""" :class:`TarArchiveReaderIterDataPipe`.
 
-    Iterable datapipe to extract tar binary streams from input iterable which contains tuples of
-    pathname and tar binary stream, yields pathname and extracted binary stream in a tuple.
+    Iterable datapipe to extract tar binary streams from input iterable which contains tuples of pathnames and
+    tar binary stream. This yields a tuple of pathname and extracted binary stream.
 
     Args:
-        datapipe: Iterable datapipe that provides pathname and tar binary stream in tuples
+        datapipe: Iterable datapipe that provides tuples of pathname and tar binary stream
         mode: File mode used by `tarfile.open` to read file object.
             Mode has to be a string of the form 'filemode[:compression]'
         length: a nominal length of the datapipe
@@ -22,7 +22,7 @@ class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
     Note:
         The opened file handles will be closed automatically if the default DecoderDataPipe
         is attached. Otherwise, user should be responsible to close file handles explicitly
-        or let Python's GC close them periodly.
+        or let Python's GC close them periodically.
     """
     def __init__(
         self,
@@ -36,8 +36,6 @@ class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
         self.length: int = length
 
     def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
-        if not isinstance(self.datapipe, Iterable):
-            raise TypeError("datapipe must be Iterable type but got {}".format(type(self.datapipe)))
         for data in self.datapipe:
             validate_pathname_binary_tuple(data)
             pathname, data_stream = data
@@ -52,15 +50,13 @@ class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
                         warnings.warn("failed to extract file {} from source tarfile {}".format(tarinfo.name, pathname))
                         raise tarfile.ExtractError
                     inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
-                    # Add a reference of the source tarfile into extracted_fobj, so the source
-                    # tarfile handle won't be released until all the extracted file objs are destroyed.
-                    extracted_fobj.source_ref = tar  # type: ignore[attr-defined]
-                    # typing.cast is used here to silence mypy's type checker
-                    yield (inner_pathname, cast(BufferedIOBase, extracted_fobj))
+                    yield (inner_pathname, extracted_fobj)  # type: ignore[misc]
             except Exception as e:
                 warnings.warn(
                     "Unable to extract files from corrupted tarfile stream {} due to: {}, abort!".format(pathname, e))
                 raise e
+            finally:
+                data_stream.close()
 
     def __len__(self):
         if self.length == -1:
index b0ac4a0..66af28e 100644 (file)
@@ -1,5 +1,6 @@
 from torch.utils.data import IterDataPipe
-from typing import Iterable, Iterator, Tuple
+from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple
+from typing import Iterable, Iterator, Tuple, IO, cast
 from io import BufferedIOBase
 
 import os
@@ -10,32 +11,34 @@ import warnings
 class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
     r""" :class:`ZipArchiveReaderIterDataPipe`.
 
-    Iterable data pipe to extract zip binary streams from input iterable which contains
-    pathnames, yields a tuple of pathname and extracted binary stream.
+    Iterable data pipe to extract zip binary streams from input iterable which contains a tuple of pathname and
+    zip binary stream. This yields a tuple of pathname and extracted binary stream.
 
     Args:
-        datapipe: Iterable datapipe that provides pathnames
+        datapipe: Iterable datapipe that provides tuples of pathname and zip binary stream
         length: Nominal length of the datapipe
 
     Note:
         The opened file handles will be closed automatically if the default DecoderDataPipe
         is attached. Otherwise, user should be responsible to close file handles explicitly
-        or let Python's GC close them periodly.
+        or let Python's GC close them periodically. Due to how zipfiles implements its open() method,
+        the data_stream variable below cannot be closed within the scope of this function.
     """
     def __init__(
             self,
-            datapipe: Iterable[str],
+            datapipe: Iterable[Tuple[str, BufferedIOBase]],
             length: int = -1):
         super().__init__()
-        self.datapipe: Iterable[str] = datapipe
+        self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
         self.length: int = length
 
     def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
-        for pathname in self.datapipe:
-            if not isinstance(pathname, str):
-                raise TypeError(f"pathname should be of string type, but is type {type(pathname)}")
+        for data in self.datapipe:
+            validate_pathname_binary_tuple(data)
+            pathname, data_stream = data
             try:
-                zips = zipfile.ZipFile(pathname)
+                # typing.cast is used here to silence mypy's type checker
+                zips = zipfile.ZipFile(cast(IO[bytes], data_stream))
                 for zipinfo in zips.infolist():
                     # major version should always be 3 here.
                     if sys.version_info[1] >= 6:
@@ -50,6 +53,7 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
                 warnings.warn(
                     f"Unable to extract files from corrupted zipfile stream {pathname} due to: {e}, abort!")
                 raise e
+            # We are unable to close 'data_stream' here, because it needs to be available to use later
 
     def __len__(self):
         if self.length == -1:
index ba3f6cb..bdb5a26 100644 (file)
@@ -54,7 +54,6 @@ def get_file_binaries_from_pathnames(pathnames: Iterable, mode: str):
                             .format(type(pathname)))
         yield (pathname, open(pathname, mode))
 
-
 def validate_pathname_binary_tuple(data):
     if not isinstance(data, tuple):
         raise TypeError("pathname binary data should be tuple type, but got {}".format(type(data)))