Replace group_by_key by group_by IterDataPipe (#64220)
authorErjia Guan <erjia@fb.com>
Tue, 31 Aug 2021 01:41:08 +0000 (18:41 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 01:45:44 +0000 (18:45 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64220

Remove `ByKeyGrouperIterDataPipe` due to duplicated functionality.
Fix a bug in `GrouperIterDataPipe` using the existing test.

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D30650542

Pulled By: ejguan

fbshipit-source-id: 666b4d28282fb4f49f3ff101b8d08be16a50d836

test/test_datapipe.py
torch/utils/data/datapipes/iter/__init__.py
torch/utils/data/datapipes/iter/grouping.py

index 86e53fa..c35698e 100644 (file)
@@ -299,7 +299,7 @@ class TestIterableDataPipeBasic(TestCase):
         _helper(cached, datapipe4, channel_first=True)
 
     # TODO(VitalyFedyunin): Generates unclosed buffer warning, need to investigate
-    def test_groupbykey_iterable_datapipe(self):
+    def test_groupby_iterable_datapipe(self):
         temp_dir = self.temp_dir.name
         temp_tarfile_pathname = os.path.join(temp_dir, "test_tar.tar")
         file_list = [
@@ -316,13 +316,25 @@ class TestIterableDataPipeBasic(TestCase):
         datapipe1 = dp.iter.FileLister(temp_dir, '*.tar')
         datapipe2 = dp.iter.FileLoader(datapipe1)
         datapipe3 = dp.iter.TarArchiveReader(datapipe2)
-        datapipe4 = dp.iter.ByKeyGrouper(datapipe3, group_size=2)
 
-        expected_result = [("a.png", "a.json"), ("c.png", "c.json"), ("b.png", "b.json"), ("d.png", "d.json"), (
-            "f.png", "f.json"), ("g.png", "g.json"), ("e.png", "e.json"), ("h.json", "h.txt")]
+        def group_fn(data):
+            filepath, _ = data
+            return os.path.basename(filepath).split(".")[0]
+
+        datapipe4 = dp.iter.Grouper(datapipe3, group_key_fn=group_fn, group_size=2)
+
+        def order_fn(data):
+            data.sort(key=lambda f: f[0], reverse=True)
+            return data
+
+        datapipe5 = dp.iter.Mapper(datapipe4, fn=order_fn)  # type: ignore[var-annotated]
+
+        expected_result = [
+            ("a.png", "a.json"), ("c.png", "c.json"), ("b.png", "b.json"), ("d.png", "d.json"),
+            ("f.png", "f.json"), ("g.png", "g.json"), ("e.png", "e.json"), ("h.txt", "h.json")]
 
         count = 0
-        for rec, expected in zip(datapipe4, expected_result):
+        for rec, expected in zip(datapipe5, expected_result):
             count = count + 1
             self.assertEqual(os.path.basename(rec[0][0]), expected[0])
             self.assertEqual(os.path.basename(rec[1][0]), expected[1])
index f302fd3..8478577 100644 (file)
@@ -19,7 +19,7 @@ from torch.utils.data.datapipes.iter.fileloader import (
 from torch.utils.data.datapipes.iter.grouping import (
     BatcherIterDataPipe as Batcher,
     BucketBatcherIterDataPipe as BucketBatcher,
-    ByKeyGrouperIterDataPipe as ByKeyGrouper,
+    GrouperIterDataPipe as Grouper,
 )
 from torch.utils.data.datapipes.iter.httpreader import (
     HTTPReaderIterDataPipe as HttpReader,
@@ -48,12 +48,12 @@ from torch.utils.data.datapipes.iter.utils import (
 
 __all__ = ['Batcher',
            'BucketBatcher',
-           'ByKeyGrouper',
            'Collator',
            'Concater',
            'FileLister',
            'FileLoader',
            'Filter',
+           'Grouper',
            'HttpReader',
            'IterableWrapper',
            'LineReader',
index 5f44948..f47299c 100644 (file)
@@ -1,12 +1,10 @@
-import functools
-import os
 import random
 import warnings
 
 from collections import defaultdict
 
 from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
-from typing import Any, Callable, Dict, Iterator, List, Optional, Sized, Tuple, TypeVar, DefaultDict
+from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar
 
 T_co = TypeVar('T_co', covariant=True)
 
@@ -225,35 +223,6 @@ class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
         raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
 
 
-# defaut group key is the file pathname without the extension.
-# Assuming the passed in data is a tuple and 1st item is file's pathname.
-def default_group_key_fn(dataitem: Tuple[str, Any]):
-    return os.path.splitext(dataitem[0])[0]
-
-
-def default_sort_data_fn(datalist: List[Tuple[str, Any]]):
-    txt_ext = ['.json', '.jsn', '.txt', '.text']
-
-    def cmp_fn(a: Tuple[str, Any], b: Tuple[str, Any]):
-        a_is_txt = os.path.splitext(a[0])[1] in txt_ext
-        b_is_txt = os.path.splitext(b[0])[1] in txt_ext
-
-        # if a is txt but b is not, b go front
-        if a_is_txt and not b_is_txt:
-            return 1
-        # if a is not txt but b is txt, a go front
-        if not a_is_txt and b_is_txt:
-            return -1
-        # if a and b both are or are not txt, sort in alphabetic order
-        if a[0] < b[0]:
-            return -1
-        elif a[0] > b[0]:
-            return 1
-        return 0
-
-    return sorted(datalist, key=functools.cmp_to_key(cmp_fn))
-
-
 @functional_datapipe('groupby')
 class GrouperIterDataPipe(IterDataPipe):
     # TODO(VtalyFedyunin): Add inline docs and tests (they are partially available in notebooks)
@@ -309,6 +278,9 @@ class GrouperIterDataPipe(IterDataPipe):
         for x in self.datapipe:
             key = self.group_key_fn(x)
 
+            buffer_elements[key].append(x)
+            buffer_size += 1
+
             if self.group_size is not None and self.group_size == len(buffer_elements[key]):
                 yield self.wrapper_class(buffer_elements[key])
                 buffer_size -= len(buffer_elements[key])
@@ -319,92 +291,7 @@ class GrouperIterDataPipe(IterDataPipe):
                 if result_to_yield is not None:
                     yield self.wrapper_class(result_to_yield)
 
-            buffer_elements[key].append(x)
-            buffer_size += 1
-
         while buffer_size:
             (result_to_yield, buffer_size) = self._remove_biggest_key(buffer_elements, buffer_size)
             if result_to_yield is not None:
                 yield self.wrapper_class(result_to_yield)
-
-
-@functional_datapipe('group_by_key')
-class ByKeyGrouperIterDataPipe(IterDataPipe[list]):
-    r""" :class:`GroupByKeyIterDataPipe`.
-
-    Iterable datapipe to group data from input iterable by keys which are generated from `group_key_fn`,
-    yields a list with `group_size` items in it, each item in the list is a tuple of key and data
-
-    args:
-        datapipe: Iterable datapipe that provides data. (typically str key (eg. pathname) and data stream in tuples)
-        group_size: the size of group
-        max_buffer_size: the max size of stream buffer which is used to store not yet grouped but iterated data
-        group_key_fn: a function which is used to generate group key from the data in the input datapipe
-        sort_data_fn: a function which is used to sort the grouped data before yielding back
-        length: a nominal length of the datapipe
-    """
-    datapipe: IterDataPipe[Tuple[str, Any]]
-    group_size: int
-    max_buffer_size: int
-    group_key_fn: Callable
-    sort_data_fn: Callable
-    curr_buffer_size: int
-    stream_buffer: Dict[str, List[Tuple[str, Any]]]
-    length: int
-
-    def __init__(
-            self,
-            datapipe: IterDataPipe[Tuple[str, Any]],
-            *,
-            group_size: int,
-            max_buffer_size: Optional[int] = None,
-            group_key_fn: Callable = default_group_key_fn,
-            sort_data_fn: Callable = default_sort_data_fn,
-            length: int = -1):
-        super().__init__()
-
-        assert group_size > 0
-        self.datapipe = datapipe
-        self.group_size = group_size
-
-        # default max buffer size is group_size * 10
-        self.max_buffer_size = max_buffer_size if max_buffer_size is not None else group_size * 10
-        assert self.max_buffer_size >= self.group_size
-
-        self.group_key_fn = group_key_fn  # type: ignore[assignment]
-        self.sort_data_fn = sort_data_fn  # type: ignore[assignment]
-        self.curr_buffer_size = 0
-        self.stream_buffer = {}
-        self.length = length
-
-    def __iter__(self) -> Iterator[list]:
-        if self.group_size == 1:
-            for data in self.datapipe:
-                yield [data]
-        else:
-            for data in self.datapipe:
-                key = self.group_key_fn(data)
-                if key not in self.stream_buffer:
-                    self.stream_buffer[key] = []
-                res = self.stream_buffer[key]
-                res.append(data)
-                if len(res) == self.group_size:
-                    yield self.sort_data_fn(res)
-                    del self.stream_buffer[key]
-                    self.curr_buffer_size = self.curr_buffer_size - self.group_size + 1
-                else:
-                    if self.curr_buffer_size == self.max_buffer_size:
-                        raise OverflowError(
-                            "stream_buffer is overflow, please adjust the order of data "
-                            "in the input datapipe or increase the buffer size!")
-                    self.curr_buffer_size = self.curr_buffer_size + 1
-
-            if self.curr_buffer_size > 0:
-                msg = "Not able to group [{}] with group size {}.".format(
-                    ','.join([v[0] for _, vs in self.stream_buffer.items() for v in vs]), str(self.group_size))
-                raise RuntimeError(msg)
-
-    def __len__(self) -> int:
-        if self.length == -1:
-            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
-        return self.length