rec[1].close()
self.assertEqual(count, len(self.temp_files))
- # TODO(VitalyFedyunin): Generates unclosed buffer warning, need to investigate
def test_readfilesfromtar_iterable_datapipe(self):
    """Exercise TarArchiveReader: read extracted files before exhaustion,
    after exhaustion, and after a mid-stream reset of the DataPipe."""
    temp_dir = self.temp_dir.name
    temp_tarfile_pathname = os.path.join(temp_dir, "test_tar.tar")
    # Pack the known temp files into a tar archive for the pipeline to read back.
    # NOTE(review): the archive-creation step was elided in the reviewed chunk;
    # reconstructed by analogy with the zip test below — confirm mode/contents.
    with tarfile.open(temp_tarfile_pathname, "w:gz") as tar:
        for temp_file in self.temp_files:
            tar.add(temp_file)
    datapipe1 = dp.iter.FileLister(temp_dir, '*.tar')
    datapipe2 = dp.iter.FileLoader(datapipe1)
    datapipe3 = dp.iter.TarArchiveReader(datapipe2)

    # Test Case: Read extracted files before reaching the end of the tarfile
    for rec, temp_file in itertools.zip_longest(datapipe3, self.temp_files):
        self.assertTrue(rec is not None and temp_file is not None)
        self.assertEqual(os.path.basename(rec[0]), os.path.basename(temp_file))
        with open(temp_file, 'rb') as f:
            self.assertEqual(rec[1].read(), f.read())
        rec[1].close()

    # Test Case: Read extracted files after reaching the end of the tarfile
    data_refs = list(datapipe3)
    self.assertEqual(len(data_refs), len(self.temp_files))
    for data_ref, temp_file in zip(data_refs, self.temp_files):
        self.assertEqual(os.path.basename(data_ref[0]), os.path.basename(temp_file))
        # BUG FIX: the original compared against `f.read()` where `f` had leaked
        # (already closed) from the previous loop's `with` block; reopen per file.
        with open(temp_file, 'rb') as f:
            self.assertEqual(data_ref[1].read(), f.read())
        data_ref[1].close()

    # Test Case: reset the DataPipe after reading part of it
    n_elements_before_reset = 1
    res_before_reset, res_after_reset = reset_after_n_next_calls(datapipe3, n_elements_before_reset)
    # Check result accumulated before reset
    self.assertEqual(len(res_before_reset), n_elements_before_reset)
    for ele_before_reset, temp_file in zip(res_before_reset, self.temp_files):
        self.assertEqual(os.path.basename(ele_before_reset[0]), os.path.basename(temp_file))
        with open(temp_file, 'rb') as f:
            self.assertEqual(ele_before_reset[1].read(), f.read())
        ele_before_reset[1].close()
    # Check result accumulated after reset
    self.assertEqual(len(res_after_reset), len(self.temp_files))
    for ele_after_reset, temp_file in zip(res_after_reset, self.temp_files):
        self.assertEqual(os.path.basename(ele_after_reset[0]), os.path.basename(temp_file))
        with open(temp_file, 'rb') as f:
            self.assertEqual(ele_after_reset[1].read(), f.read())
        ele_after_reset[1].close()

# This test throws a warning because data_stream inside ZipArchiveReader cannot be closed
# due to the way zipfile's open() is implemented
def test_readfilesfromzip_iterable_datapipe(self):
    """Exercise ZipArchiveReader: read extracted files before exhaustion,
    after exhaustion, and after a mid-stream reset of the DataPipe."""
    temp_dir = self.temp_dir.name
    temp_zipfile_pathname = os.path.join(temp_dir, "test_zip.zip")
    # Pack the known temp files into a zip archive for the pipeline to read back.
    # NOTE(review): archive creation was partially elided in the reviewed chunk
    # (only two of the three writes were visible); write every temp file here.
    with zipfile.ZipFile(temp_zipfile_pathname, 'w') as myzip:
        for temp_file in self.temp_files:
            myzip.write(temp_file)
    datapipe1 = dp.iter.FileLister(temp_dir, '*.zip')
    datapipe2 = dp.iter.FileLoader(datapipe1)
    datapipe3 = dp.iter.ZipArchiveReader(datapipe2)

    # Test Case: read extracted files before reaching the end of the zipfile
    for rec, temp_file in itertools.zip_longest(datapipe3, self.temp_files):
        self.assertTrue(rec is not None and temp_file is not None)
        self.assertEqual(os.path.basename(rec[0]), os.path.basename(temp_file))
        with open(temp_file, 'rb') as f:
            self.assertEqual(rec[1].read(), f.read())
        rec[1].close()

    # Test Case: read extracted files after reaching the end of the zipfile
    data_refs = list(datapipe3)
    self.assertEqual(len(data_refs), len(self.temp_files))
    for data_ref, temp_file in zip(data_refs, self.temp_files):
        self.assertEqual(os.path.basename(data_ref[0]), os.path.basename(temp_file))
        # Consistency fix: also verify the extracted contents, matching the tar test.
        with open(temp_file, 'rb') as f:
            self.assertEqual(data_ref[1].read(), f.read())
        data_ref[1].close()

    # Test Case: reset the DataPipe after reading part of it
    n_elements_before_reset = 1
    res_before_reset, res_after_reset = reset_after_n_next_calls(datapipe3, n_elements_before_reset)
    # Check the results accumulated before reset
    self.assertEqual(len(res_before_reset), n_elements_before_reset)
    for ele_before_reset, temp_file in zip(res_before_reset, self.temp_files):
        self.assertEqual(os.path.basename(ele_before_reset[0]), os.path.basename(temp_file))
        with open(temp_file, 'rb') as f:
            self.assertEqual(ele_before_reset[1].read(), f.read())
        ele_before_reset[1].close()
    # Check the results accumulated after reset
    self.assertEqual(len(res_after_reset), len(self.temp_files))
    for ele_after_reset, temp_file in zip(res_after_reset, self.temp_files):
        self.assertEqual(os.path.basename(ele_after_reset[0]), os.path.basename(temp_file))
        with open(temp_file, 'rb') as f:
            self.assertEqual(ele_after_reset[1].read(), f.read())
        ele_after_reset[1].close()
datapipe4.add_handler(_png_decoder)
_helper(cached, datapipe4, channel_first=True)
- # TODO(VitalyFedyunin): Generates unclosed buffer warning, need to investigate
+
def test_groupby_iterable_datapipe(self):
temp_dir = self.temp_dir.name
temp_tarfile_pathname = os.path.join(temp_dir, "test_tar.tar")
class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
r""" :class:`TarArchiveReaderIterDataPipe`.
- Iterable datapipe to extract tar binary streams from input iterable which contains tuples of
- pathname and tar binary stream, yields pathname and extracted binary stream in a tuple.
+ Iterable datapipe to extract tar binary streams from input iterable which contains tuples of pathnames and
+ tar binary stream. This yields a tuple of pathname and extracted binary stream.
Args:
- datapipe: Iterable datapipe that provides pathname and tar binary stream in tuples
+ datapipe: Iterable datapipe that provides tuples of pathname and tar binary stream
mode: File mode used by `tarfile.open` to read file object.
Mode has to be a string of the form 'filemode[:compression]'
length: a nominal length of the datapipe
Note:
The opened file handles will be closed automatically if the default DecoderDataPipe
is attached. Otherwise, user should be responsible to close file handles explicitly
- or let Python's GC close them periodly.
+ or let Python's GC close them periodically.
"""
def __init__(
self,
self.length: int = length
def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
- if not isinstance(self.datapipe, Iterable):
- raise TypeError("datapipe must be Iterable type but got {}".format(type(self.datapipe)))
for data in self.datapipe:
validate_pathname_binary_tuple(data)
pathname, data_stream = data
warnings.warn("failed to extract file {} from source tarfile {}".format(tarinfo.name, pathname))
raise tarfile.ExtractError
inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name))
- # Add a reference of the source tarfile into extracted_fobj, so the source
- # tarfile handle won't be released until all the extracted file objs are destroyed.
- extracted_fobj.source_ref = tar # type: ignore[attr-defined]
- # typing.cast is used here to silence mypy's type checker
- yield (inner_pathname, cast(BufferedIOBase, extracted_fobj))
+ yield (inner_pathname, extracted_fobj) # type: ignore[misc]
except Exception as e:
warnings.warn(
"Unable to extract files from corrupted tarfile stream {} due to: {}, abort!".format(pathname, e))
raise e
+ finally:
+ data_stream.close()
def __len__(self):
if self.length == -1:
from torch.utils.data import IterDataPipe
-from typing import Iterable, Iterator, Tuple
+from torch.utils.data.datapipes.utils.common import validate_pathname_binary_tuple
+from typing import Iterable, Iterator, Tuple, IO, cast
from io import BufferedIOBase
import os
class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
r""" :class:`ZipArchiveReaderIterDataPipe`.
- Iterable data pipe to extract zip binary streams from input iterable which contains
- pathnames, yields a tuple of pathname and extracted binary stream.
+ Iterable data pipe to extract zip binary streams from input iterable which contains a tuple of pathname and
+ zip binary stream. This yields a tuple of pathname and extracted binary stream.
Args:
- datapipe: Iterable datapipe that provides pathnames
+ datapipe: Iterable datapipe that provides tuples of pathname and zip binary stream
length: Nominal length of the datapipe
Note:
The opened file handles will be closed automatically if the default DecoderDataPipe
is attached. Otherwise, user should be responsible to close file handles explicitly
- or let Python's GC close them periodly.
+ or let Python's GC close them periodically. Due to how zipfiles implements its open() method,
+ the data_stream variable below cannot be closed within the scope of this function.
"""
def __init__(
self,
- datapipe: Iterable[str],
+ datapipe: Iterable[Tuple[str, BufferedIOBase]],
length: int = -1):
super().__init__()
- self.datapipe: Iterable[str] = datapipe
+ self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
self.length: int = length
def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
- for pathname in self.datapipe:
- if not isinstance(pathname, str):
- raise TypeError(f"pathname should be of string type, but is type {type(pathname)}")
+ for data in self.datapipe:
+ validate_pathname_binary_tuple(data)
+ pathname, data_stream = data
try:
- zips = zipfile.ZipFile(pathname)
+ # typing.cast is used here to silence mypy's type checker
+ zips = zipfile.ZipFile(cast(IO[bytes], data_stream))
for zipinfo in zips.infolist():
# major version should always be 3 here.
if sys.version_info[1] >= 6:
warnings.warn(
f"Unable to extract files from corrupted zipfile stream {pathname} due to: {e}, abort!")
raise e
+ # We are unable to close 'data_stream' here, because it needs to be available to use later
def __len__(self):
if self.length == -1: