Modify inline doc for DataPipe (#64221)
authorErjia Guan <erjia@fb.com>
Tue, 31 Aug 2021 01:41:08 +0000 (18:41 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 01:45:46 +0000 (18:45 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64221

List of tasks in this PR
- [x]  Add inline doc for DataPipe
- [x] Improve the inline doc
- [x] Expose DataPipe to `datapipes.iter` (`UnBatcher`) Note: `Forker`, `Demux`, `Mux` are exposed in another PR authored by Kevin
- [x] Add correct typing to DataPipe
- [x] Unify the argument to `datapipe` rather than `source_datapipe`

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D30650541

Pulled By: ejguan

fbshipit-source-id: c09d1b9742b8097d8e645c15947cef80c876877b

15 files changed:
torch/utils/data/datapipes/iter/__init__.py
torch/utils/data/datapipes/iter/callable.py
torch/utils/data/datapipes/iter/combinatorics.py
torch/utils/data/datapipes/iter/combining.py
torch/utils/data/datapipes/iter/filelister.py
torch/utils/data/datapipes/iter/fileloader.py
torch/utils/data/datapipes/iter/grouping.py
torch/utils/data/datapipes/iter/httpreader.py
torch/utils/data/datapipes/iter/linereader.py
torch/utils/data/datapipes/iter/routeddecoder.py
torch/utils/data/datapipes/iter/selecting.py
torch/utils/data/datapipes/iter/streamreader.py
torch/utils/data/datapipes/iter/tararchivereader.py
torch/utils/data/datapipes/iter/utils.py
torch/utils/data/datapipes/iter/ziparchivereader.py

index 8478577..b55bbf6 100644 (file)
@@ -20,6 +20,7 @@ from torch.utils.data.datapipes.iter.grouping import (
     BatcherIterDataPipe as Batcher,
     BucketBatcherIterDataPipe as BucketBatcher,
     GrouperIterDataPipe as Grouper,
+    UnBatcherIterDataPipe as UnBatcher,
 )
 from torch.utils.data.datapipes.iter.httpreader import (
     HTTPReaderIterDataPipe as HttpReader,
@@ -63,6 +64,7 @@ __all__ = ['Batcher',
            'Shuffler',
            'StreamReader',
            'TarArchiveReader',
+           'UnBatcher',
            'ZipArchiveReader',
            'Zipper']
 
index 18f6f17..2c5ca3d 100644 (file)
@@ -31,14 +31,15 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
     Iterable DataPipe to run a function over each item from the source DataPipe.
     The function can be any regular python function or partial object. Lambda
     function is not recommended as it is not supported by pickle.
-    args:
+
+    Args:
         datapipe: Source Iterable DataPipe
         fn: Function called over each item
         fn_args: Positional arguments for `fn`
         fn_kwargs: Keyword arguments for `fn`
-        nesting_level: Determines which level the fn gets applied to, by default it applies to the top level (= 0)
-        This also accepts -1 as input to apply the function to the lowest nesting level. It currently doesn't support
-        argument < -1.
+        nesting_level: Determines which level the fn gets applied to, by default it applies to the top level (= 0).
+            This also accepts -1 as input to apply the function to the lowest nesting level. It currently doesn't support
+            argument < -1.
     """
     datapipe: IterDataPipe
     fn: Callable
@@ -112,10 +113,11 @@ class CollatorIterDataPipe(MapperIterDataPipe):
 
     Iterable DataPipe to collate samples from datapipe to Tensor(s) by `util_.collate.default_collate`,
     or customized Data Structure by collate_fn.
-    args:
+
+    Args:
         datapipe: Iterable DataPipe being collated
         collate_fn: Customized collate function to collect and combine data or a batch of data.
-                    Default function collates to Tensor(s) based on data type.
+            Default function collates to Tensor(s) based on data type.
         fn_args: Positional arguments for `collate_fn`
         fn_kwargs: Keyword arguments for `collate_fn`
 
index d1a7dd0..4d6fac7 100644 (file)
@@ -10,10 +10,11 @@ class SamplerIterDataPipe(IterDataPipe[T_co]):
     r""" :class:`SamplerIterDataPipe`.
 
     Iterable DataPipe to generate sample elements.
-    args:
-        datapipe: IterDataPipe sampled from
+
+    Args:
+        datapipe: IterDataPipe to sample from
         sampler: Sampler class to genereate sample elements from input DataPipe.
-                    Default is :class:`SequentialSampler` for IterDataPipe
+            Default is :class:`SequentialSampler` for IterDataPipe
     """
     datapipe: IterDataPipe
     sampler: Sampler
@@ -63,7 +64,7 @@ class ShufflerIterDataPipe(IterDataPipe[T_co]):
     mode (:attr:`num_worker > 0`), `worker_init_fn` is used to set up a random seed
     for each worker process.
 
-    args:
+    Args:
         datapipe: The IterDataPipe being shuffled
         buffer_size: The buffer size for shuffling (default to 10000)
         unbatch_level: Specifies if it necessary to unbatch source data before
index 4b28e09..879e8be 100644 (file)
@@ -11,7 +11,8 @@ class ConcaterIterDataPipe(IterDataPipe):
     r""" :class:`ConcaterIterDataPipe`.
 
     Iterable DataPipe to concatenate multiple Iterable DataPipes.
-    args:
+
+    Args:
         datapipes: Iterable DataPipes being concatenated
     """
     datapipes: Tuple[IterDataPipe]
@@ -97,12 +98,13 @@ class MultiplexerIterDataPipe(IterDataPipe):
 
 @functional_datapipe('zip')
 class ZipperIterDataPipe(IterDataPipe[Tuple[T_co]]):
-    r""" :class:`ZipIterDataPipe`.
+    r""" :class:`ZipperIterDataPipe`.
 
     Iterable DataPipe aggregates elements into a tuple from each of
     the input DataPipe. The output DataPipe is stopped when the
     shortest input DataPipe is exhausted.
-    args:
+
+    Args:
         *datapipes: Iterable DataPipes being aggregated
     """
     datapipes: Tuple[IterDataPipe]
index 48fdce9..aef147d 100644 (file)
@@ -6,11 +6,12 @@ class FileListerIterDataPipe(IterDataPipe[str]):
     r""" :class:`FileListerIterDataPipe`
 
     Iterable DataPipe to load file pathname(s) (path + filename), yield pathname from given disk root dir.
-    args:
-        root : root dir
-        mask : a unix style filter string or string list for filtering file name(s)
-        abspath : whether to return relative pathname or absolute pathname
-        length : a nominal length of the datapipe
+
+    Args:
+        root: Root directory
+        mask: Unix style filter string or string list for filtering file name(s)
+        abspath: Whether to return relative pathname or absolute pathname
+        length: Nominal length of the datapipe
     """
 
     def __init__(
@@ -22,11 +23,11 @@ class FileListerIterDataPipe(IterDataPipe[str]):
             abspath: bool = False,
             length: int = -1):
         super().__init__()
-        self.root : str = root
-        self.masks : Union[str, List[str]] = masks
-        self.recursive : bool = recursive
-        self.abspath : bool = abspath
-        self.length : int = length
+        self.root: str = root
+        self.masks: Union[str, List[str]] = masks
+        self.recursive: bool = recursive
+        self.abspath: bool = abspath
+        self.length: int = length
 
     def __iter__(self) -> Iterator[str] :
         yield from get_file_pathnames_from_root(self.root, self.masks, self.recursive, self.abspath)
index 2b73e4e..7c048fc 100644 (file)
@@ -10,13 +10,14 @@ class FileLoaderIterDataPipe(IterDataPipe[Tuple[str, IOBase]]):
 
     Iterable Datapipe to load file streams from given pathnames,
     yield pathname and file stream in a tuple.
-    args:
+
+    Args:
         datapipe: Iterable datapipe that provides pathnames
         mode: An optional string that specifies the mode in which
             the file is opened by `open()`. It defaults to 'b' which
             means open for reading in binary mode. Another option is
             't' for text mode
-        length: a nominal length of the datapipe
+        length: Nominal length of the datapipe
 
     Note:
         The opened file handles will be closed by Python's GC periodly. Users can choose
index f47299c..aece256 100644 (file)
@@ -30,26 +30,27 @@ class ShardingFilterIterDataPipe(IterDataPipe):
 
 
 @functional_datapipe('batch')
-class BatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
+class BatcherIterDataPipe(IterDataPipe[DataChunk]):
     r""" :class:`BatcherIterDataPipe`.
 
     Iterable DataPipe to create mini-batches of data. An outer dimension will be added as
     `batch_size` if `drop_last` is set to `True`, or `length % batch_size` for the
     last batch if `drop_last` is set to `False`.
-    args:
+
+    Args:
         datapipe: Iterable DataPipe being batched
         batch_size: The size of each batch
         drop_last: Option to drop the last batch if it's not full
         unbatch_level: Specifies if it necessary to unbatch source data before
             applying new batching rule
     """
-    datapipe: IterDataPipe[T_co]
+    datapipe: IterDataPipe
     batch_size: int
     drop_last: bool
     length: Optional[int]
 
     def __init__(self,
-                 datapipe: IterDataPipe[T_co],
+                 datapipe: IterDataPipe,
                  batch_size: int,
                  drop_last: bool = False,
                  unbatch_level: int = 0,
@@ -66,8 +67,8 @@ class BatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
         self.length = None
         self.wrapper_class = DataChunk
 
-    def __iter__(self) -> Iterator[DataChunk[T_co]]:
-        batch: List[T_co] = []
+    def __iter__(self) -> Iterator[DataChunk]:
+        batch: List = []
         for x in self.datapipe:
             batch.append(x)
             if len(batch) == self.batch_size:
@@ -96,13 +97,16 @@ class UnBatcherIterDataPipe(IterDataPipe):
 
     Iterable DataPipe to undo batching of data. In other words, it flattens the data up to the specified level
     within a batched DataPipe.
-    args:
+
+    Args:
         datapipe: Iterable DataPipe being un-batched
         unbatch_level: Defaults to `1` (only flattening the top level). If set to `2`, it will flatten the top 2 levels,
-        and `-1` will flatten the entire DataPipe.
+            and `-1` will flatten the entire DataPipe.
     """
 
-    def __init__(self, datapipe, unbatch_level: int = 1):
+    def __init__(self,
+                 datapipe: IterDataPipe,
+                 unbatch_level: int = 1):
         self.datapipe = datapipe
         self.unbatch_level = unbatch_level
 
@@ -143,7 +147,8 @@ class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
     Iterable DataPipe to create mini-batches of data from sorted bucket. An outer
     dimension will be added as `batch_size` if `drop_last` is set to `True`,
     or `length % batch_size` for the last batch if `drop_last` is set to `False`.
-        args:
+
+    Args:
         datapipe: Iterable DataPipe being batched
         batch_size: The size of each batch
         drop_last: Option to drop the last batch if it's not full
@@ -224,8 +229,21 @@ class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
 
 
 @functional_datapipe('groupby')
-class GrouperIterDataPipe(IterDataPipe):
-    # TODO(VtalyFedyunin): Add inline docs and tests (they are partially available in notebooks)
+class GrouperIterDataPipe(IterDataPipe[DataChunk]):
+    r""":class:`GrouperIterDataPipe`.
+
+    Iterable datapipe to group data from input IterDataPipe by keys which are generated from `group_key_fn`,
+    and yield a DataChunk with size ranging from `guaranteed_group_size` to `group_size`.
+
+    Args:
+        datapipe: Iterable datapipe to be grouped
+        group_key_fn: Function used to generate group key from the data of the source datapipe
+        buffer_size: The size of buffer for ungrouped data
+        group_size: The size of each group
+        unbatch_level: Specifies if it necessary to unbatch source data before grouping
+        guaranteed_group_size: The guaranteed minimum group size
+        drop_remaining: Specifies if the group smaller than `guaranteed_group_size` will be dropped from buffer
+    """
     def __init__(self,
                  datapipe: IterDataPipe[T_co],
                  group_key_fn: Callable,
index c663a18..747b5d5 100644 (file)
@@ -10,16 +10,18 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, IOBase]]):
 
     Iterable DataPipe to load file url(s) (http url(s) pointing to file(s)),
     yield file url and IO stream in a tuple
-    args:
-        timeout : timeout for http request
+
+    Args:
+        datapipe: Iterable DataPipe providing urls
+        timeout: Timeout for http request
     """
 
-    def __init__(self, source_datapipe, timeout=None):
-        self.source_datapipe = source_datapipe
+    def __init__(self, datapipe, timeout=None):
+        self.datapipe = datapipe
         self.timeout = timeout
 
     def __iter__(self):
-        for furl in self.source_datapipe:
+        for furl in self.datapipe:
             try:
                 if self.timeout is None:
                     r = urllib.urlopen(furl)
index 2b15b93..04b992d 100644 (file)
@@ -7,12 +7,15 @@ class LineReaderIterDataPipe(IterDataPipe[Tuple[str, str]]):
 
     Iterable DataPipe to load file name and stream as source IterDataPipe
     and yield filename and line(s).
+
+    Args:
+        datapipe: Iterable DataPipe providing file name and string file stream
     """
 
-    def __init__(self, source_datapipe):
-        self.source_datapipe = source_datapipe
+    def __init__(self, datapipe):
+        self.datapipe = datapipe
 
     def __iter__(self):
-        for file_name, stream in self.source_datapipe:
+        for file_name, stream in self.datapipe:
             for line in stream:
                 yield file_name, line
index f149c07..ea47742 100644 (file)
@@ -6,7 +6,8 @@ from torch.utils.data.datapipes.utils.decoder import (
     Decoder,
     basichandlers as decoder_basichandlers,
     imagehandler as decoder_imagehandler,
-    extension_extract_fn)
+    extension_extract_fn
+)
 
 
 @functional_datapipe('decode')
@@ -15,7 +16,8 @@ class RoutedDecoderIterDataPipe(IterDataPipe[Tuple[str, Any]]):
 
     Iterable datapipe to decode binary streams from input DataPipe, yield pathname
     and decoded data in a tuple.
-    args:
+
+    Args:
         datapipe: Iterable datapipe that provides pathname and binary stream in tuples
         handlers: Optional user defined decoder handlers. If None, basic and image decoder
             handlers will be set as default. If multiple handles are provided, the priority
index 83872ce..a89bfdf 100644 (file)
@@ -11,15 +11,16 @@ class FilterIterDataPipe(MapperIterDataPipe):
     r""" :class:`FilterIterDataPipe`.
 
     Iterable DataPipe to filter elements from datapipe according to filter_fn.
-    args:
+
+    Args:
         datapipe: Iterable DataPipe being filtered
         filter_fn: Customized function mapping an element to a boolean.
         fn_args: Positional arguments for `filter_fn`
         fn_kwargs: Keyword arguments for `filter_fn`
         drop_empty_batches: By default, drops batch if it is empty after filtering instead of keeping an empty list
         nesting_level: Determines which level the fn gets applied to, by default it applies to the top level (= 0).
-        This also accepts -1 as input to apply filtering to the lowest nesting level. It currently doesn't support
-        argument < -1.
+            This also accepts -1 as input to apply filtering to the lowest nesting level.
+            It currently doesn't support argument < -1.
     """
     drop_empty_batches: bool
 
index f74efe7..197fb8e 100644 (file)
@@ -7,16 +7,18 @@ class StreamReaderIterDataPipe(IterDataPipe[Tuple[str, bytes]]):
 
     Iterable DataPipe to load IO stream with label name,
     and to yield bytes with label name in a tuple
-    args:
-        chunk : bytes to read from stream on each iteration.
-                If None, stream reads to the EOF.
+
+    Args:
+        datapipe: Iterable DataPipe provides url and byte stream
+        chunk: Number of bytes to be read from stream per iteration.
+            If None, all bytes will be read util the EOF.
     """
-    def __init__(self, source_datapipe, chunk=None):
-        self.source_datapipe = source_datapipe
+    def __init__(self, datapipe, chunk=None):
+        self.datapipe = datapipe
         self.chunk = chunk
 
     def __iter__(self):
-        for (furl, stream) in self.source_datapipe:
+        for furl, stream in self.datapipe:
             while True:
                 d = stream.read(self.chunk)
                 if not d:
index 9145f5f..c34583a 100644 (file)
@@ -12,9 +12,11 @@ class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
 
     Iterable datapipe to extract tar binary streams from input iterable which contains tuples of
     pathname and tar binary stream, yields pathname and extracted binary stream in a tuple.
-    args:
+
+    Args:
         datapipe: Iterable datapipe that provides pathname and tar binary stream in tuples
-        mode: File mode used by `tarfile.open` to read file object. Mode has to be a string of the form 'filemode[:compression]'
+        mode: File mode used by `tarfile.open` to read file object.
+            Mode has to be a string of the form 'filemode[:compression]'
         length: a nominal length of the datapipe
 
     Note:
@@ -24,13 +26,13 @@ class TarArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
     """
     def __init__(
         self,
-        datapipe : Iterable[Tuple[str, BufferedIOBase]],
+        datapipe: Iterable[Tuple[str, BufferedIOBase]],
         mode: str = "r:*",
-        length : int = -1
+        length: int = -1
     ):
         super().__init__()
         self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
-        self.mode = mode
+        self.mode: str = mode
         self.length: int = length
 
     def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
index ee04abc..9ba80e3 100644 (file)
@@ -2,6 +2,13 @@ from torch.utils.data import IterDataPipe
 
 
 class IterableWrapperIterDataPipe(IterDataPipe):
+    r""":class:`IterableWrapperIterDataPipe`.
+
+    Iterable datapipe that wraps an iterable object.
+
+    Args:
+        iterable: Iterable object to be wrapped into an IterDataPipe
+    """
     def __init__(self, iterable):
         self.iterable = iterable
 
index e98bd17..881d005 100644 (file)
@@ -13,9 +13,10 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
 
     Iterable data pipe to extract zip binary streams from input iterable which contains tuples of
     pathname and zip binary stream, yields pathname and extracted binary stream in a tuple.
-    args:
+
+    Args:
         datapipe: Iterable datapipe that provides pathname and zip binary stream in tuples
-        length: a nominal length of the datapipe
+        length: Nominal length of the datapipe
 
     Note:
         The opened file handles will be closed automatically if the default DecoderDataPipe
@@ -24,12 +25,11 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
     """
     def __init__(
             self,
-            datapipe : Iterable[Tuple[str, BufferedIOBase]],
-            length : int = -1):
+            datapipe: Iterable[Tuple[str, BufferedIOBase]],
+            length: int = -1):
         super().__init__()
-        self.datapipe : Iterable[Tuple[str, BufferedIOBase]] = datapipe
-        self.length : int = length
-
+        self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
+        self.length: int = length
 
     def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
         if not isinstance(self.datapipe, Iterable):
@@ -60,7 +60,6 @@ class ZipArchiveReaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
                     "Unable to extract files from corrupted zipfile stream {} due to: {}, abort!".format(pathname, e))
                 raise e
 
-
     def __len__(self):
         if self.length == -1:
             raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))