Add mode to TarArchiveReader (#63332)
author Erjia Guan <erjia@fb.com>
Tue, 17 Aug 2021 14:26:08 +0000 (07:26 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 17 Aug 2021 14:28:37 +0000 (07:28 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63332

Ports the corresponding PR from [torchdata](https://github.com/facebookexternal/torchdata/pull/101).

Test Plan: Imported from OSS

Reviewed By: astaff

Differential Revision: D30350151

Pulled By: ejguan

fbshipit-source-id: bced4a1ee1ce89d4e91e678327342e1c095dbb9e

torch/utils/data/datapipes/iter/readfilesfromtar.py

index 58f3ec5..f456602 100644 (file)
@@ -8,12 +8,13 @@ import tarfile
 import warnings
 
 class ReadFilesFromTarIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
-    r""" :class:`ReadFilesFromTarIDP`.
+    r""":class:`ReadFilesFromTarIterDataPipe`.
 
     Iterable datapipe to extract tar binary streams from input iterable which contains tuples of
     pathname and tar binary stream, yields pathname and extracted binary stream in a tuple.
     args:
         datapipe: Iterable datapipe that provides pathname and tar binary stream in tuples
+        mode: File mode used by `tarfile.open` to read file object. Mode has to be a string of the form 'filemode[:compression]'
         length: a nominal length of the datapipe
 
     Note:
@@ -22,13 +23,15 @@ class ReadFilesFromTarIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
         or let Python's GC close them periodly.
     """
     def __init__(
-            self,
-            datapipe : Iterable[Tuple[str, BufferedIOBase]],
-            length : int = -1):
+        self,
+        datapipe : Iterable[Tuple[str, BufferedIOBase]],
+        mode: str = "r:*",
+        length : int = -1
+    ):
         super().__init__()
-        self.datapipe : Iterable[Tuple[str, BufferedIOBase]] = datapipe
-        self.length : int = length
-
+        self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
+        self.mode = mode
+        self.length: int = length
 
     def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
         if not isinstance(self.datapipe, Iterable):
@@ -38,7 +41,7 @@ class ReadFilesFromTarIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
             pathname, data_stream = data
             try:
                 # typing.cast is used here to silence mypy's type checker
-                tar = tarfile.open(fileobj=cast(Optional[IO[bytes]], data_stream), mode="r:*")
+                tar = tarfile.open(fileobj=cast(Optional[IO[bytes]], data_stream), mode=self.mode)
                 for tarinfo in tar:
                     if not tarinfo.isfile():
                         continue
@@ -57,7 +60,6 @@ class ReadFilesFromTarIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
                     "Unable to extract files from corrupted tarfile stream {} due to: {}, abort!".format(pathname, e))
                 raise e
 
-
     def __len__(self):
         if self.length == -1:
             raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))