From af85bc5ffd1d4ad52e0fed255aa7afe2fdfbc5e2 Mon Sep 17 00:00:00 2001
From: Erjia Guan
Date: Mon, 30 Aug 2021 18:41:08 -0700
Subject: [PATCH] Replace group_by_key with groupby IterDataPipe (#64220)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64220

Remove `ByKeyGrouperIterDataPipe`, whose functionality duplicates
`GrouperIterDataPipe`. Also fix a bug in `GrouperIterDataPipe`: the
incoming element was appended to the buffer only at the end of the loop
body, after the group-size check, so a group was never recognized as
complete at the moment its final element arrived. The existing test
exercises the fix.
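For downstream code the migration is mechanical. A minimal sketch,
assuming the same tar pipeline as the updated test (`group_fn` here is
illustrative; the consolidated pipe ships no default key function, so
one must be passed explicitly):

    import os
    import torch.utils.data.datapipes as dp

    def group_fn(data):
        # TarArchiveReader yields (pathname, stream) tuples; key the
        # group on the base file name, e.g. "a" for a.png and a.json.
        filepath, _ = data
        return os.path.basename(filepath).split(".")[0]

    datapipe1 = dp.iter.FileLister(".", "*.tar")
    datapipe2 = dp.iter.FileLoader(datapipe1)
    datapipe3 = dp.iter.TarArchiveReader(datapipe2)

    # Before (removed): dp.iter.ByKeyGrouper(datapipe3, group_size=2)
    datapipe4 = dp.iter.Grouper(datapipe3, group_key_fn=group_fn, group_size=2)

The same pipe is also registered functionally as `groupby`, i.e.
`datapipe3.groupby(group_key_fn=group_fn, group_size=2)`.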
Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D30650542

Pulled By: ejguan

fbshipit-source-id: 666b4d28282fb4f49f3ff101b8d08be16a50d836
---
 test/test_datapipe.py                       |  22 +++--
 torch/utils/data/datapipes/iter/__init__.py |   4 +-
 torch/utils/data/datapipes/iter/grouping.py | 121 +---------------------------
 3 files changed, 23 insertions(+), 124 deletions(-)

diff --git a/test/test_datapipe.py b/test/test_datapipe.py
index 86e53fa..c35698e 100644
--- a/test/test_datapipe.py
+++ b/test/test_datapipe.py
@@ -299,7 +299,7 @@ class TestIterableDataPipeBasic(TestCase):
         _helper(cached, datapipe4, channel_first=True)
 
     # TODO(VitalyFedyunin): Generates unclosed buffer warning, need to investigate
-    def test_groupbykey_iterable_datapipe(self):
+    def test_groupby_iterable_datapipe(self):
         temp_dir = self.temp_dir.name
         temp_tarfile_pathname = os.path.join(temp_dir, "test_tar.tar")
         file_list = [
@@ -316,13 +316,25 @@ class TestIterableDataPipeBasic(TestCase):
         datapipe1 = dp.iter.FileLister(temp_dir, '*.tar')
         datapipe2 = dp.iter.FileLoader(datapipe1)
         datapipe3 = dp.iter.TarArchiveReader(datapipe2)
-        datapipe4 = dp.iter.ByKeyGrouper(datapipe3, group_size=2)
 
-        expected_result = [("a.png", "a.json"), ("c.png", "c.json"), ("b.png", "b.json"), ("d.png", "d.json"), (
-            "f.png", "f.json"), ("g.png", "g.json"), ("e.png", "e.json"), ("h.json", "h.txt")]
+        def group_fn(data):
+            filepath, _ = data
+            return os.path.basename(filepath).split(".")[0]
+
+        datapipe4 = dp.iter.Grouper(datapipe3, group_key_fn=group_fn, group_size=2)
+
+        def order_fn(data):
+            data.sort(key=lambda f: f[0], reverse=True)
+            return data
+
+        datapipe5 = dp.iter.Mapper(datapipe4, fn=order_fn)  # type: ignore[var-annotated]
+
+        expected_result = [
+            ("a.png", "a.json"), ("c.png", "c.json"), ("b.png", "b.json"), ("d.png", "d.json"),
+            ("f.png", "f.json"), ("g.png", "g.json"), ("e.png", "e.json"), ("h.txt", "h.json")]
 
         count = 0
-        for rec, expected in zip(datapipe4, expected_result):
+        for rec, expected in zip(datapipe5, expected_result):
             count = count + 1
             self.assertEqual(os.path.basename(rec[0][0]), expected[0])
             self.assertEqual(os.path.basename(rec[1][0]), expected[1])

diff --git a/torch/utils/data/datapipes/iter/__init__.py b/torch/utils/data/datapipes/iter/__init__.py
index f302fd3..8478577 100644
--- a/torch/utils/data/datapipes/iter/__init__.py
+++ b/torch/utils/data/datapipes/iter/__init__.py
@@ -19,7 +19,7 @@ from torch.utils.data.datapipes.iter.fileloader import (
 from torch.utils.data.datapipes.iter.grouping import (
     BatcherIterDataPipe as Batcher,
     BucketBatcherIterDataPipe as BucketBatcher,
-    ByKeyGrouperIterDataPipe as ByKeyGrouper,
+    GrouperIterDataPipe as Grouper,
 )
 from torch.utils.data.datapipes.iter.httpreader import (
     HTTPReaderIterDataPipe as HttpReader,
@@ -48,12 +48,12 @@ from torch.utils.data.datapipes.iter.utils import (
 
 __all__ = ['Batcher',
            'BucketBatcher',
-           'ByKeyGrouper',
            'Collator',
            'Concater',
            'FileLister',
            'FileLoader',
            'Filter',
+           'Grouper',
            'HttpReader',
            'IterableWrapper',
            'LineReader',

diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py
index 5f44948..f47299c 100644
--- a/torch/utils/data/datapipes/iter/grouping.py
+++ b/torch/utils/data/datapipes/iter/grouping.py
@@ -1,12 +1,10 @@
-import functools
-import os
 import random
 import warnings
 
 from collections import defaultdict
 
 from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
-from typing import Any, Callable, Dict, Iterator, List, Optional, Sized, Tuple, TypeVar, DefaultDict
+from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar
 
 T_co = TypeVar('T_co', covariant=True)
 
@@ -225,35 +223,6 @@ class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
         raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
 
 
-# defaut group key is the file pathname without the extension.
-# Assuming the passed in data is a tuple and 1st item is file's pathname.
-def default_group_key_fn(dataitem: Tuple[str, Any]):
-    return os.path.splitext(dataitem[0])[0]
-
-
-def default_sort_data_fn(datalist: List[Tuple[str, Any]]):
-    txt_ext = ['.json', '.jsn', '.txt', '.text']
-
-    def cmp_fn(a: Tuple[str, Any], b: Tuple[str, Any]):
-        a_is_txt = os.path.splitext(a[0])[1] in txt_ext
-        b_is_txt = os.path.splitext(b[0])[1] in txt_ext
-
-        # if a is txt but b is not, b go front
-        if a_is_txt and not b_is_txt:
-            return 1
-        # if a is not txt but b is txt, a go front
-        if not a_is_txt and b_is_txt:
-            return -1
-        # if a and b both are or are not txt, sort in alphabetic order
-        if a[0] < b[0]:
-            return -1
-        elif a[0] > b[0]:
-            return 1
-        return 0
-
-    return sorted(datalist, key=functools.cmp_to_key(cmp_fn))
-
-
 @functional_datapipe('groupby')
 class GrouperIterDataPipe(IterDataPipe):
     # TODO(VtalyFedyunin): Add inline docs and tests (they are partially available in notebooks)
@@ -309,6 +278,9 @@ class GrouperIterDataPipe(IterDataPipe):
         for x in self.datapipe:
             key = self.group_key_fn(x)
 
+            buffer_elements[key].append(x)
+            buffer_size += 1
+
             if self.group_size is not None and self.group_size == len(buffer_elements[key]):
                 yield self.wrapper_class(buffer_elements[key])
                 buffer_size -= len(buffer_elements[key])
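The hunk above is the bug fix: the incoming element now enters the
buffer before the group-size check runs. A standalone sketch of the
corrected control flow (plain Python, independent of the DataPipe
machinery; the buffer cap and the final `_remove_biggest_key` flush are
omitted for brevity):

    from collections import defaultdict

    def groupby_fixed(iterable, group_key_fn, group_size):
        buffer_elements = defaultdict(list)
        for x in iterable:
            key = group_key_fn(x)
            # Fixed ordering: append first, then test for completeness.
            # With the old ordering the check saw a buffer that did not
            # yet contain x, so a full group was only flushed once a
            # later element with the same key arrived.
            buffer_elements[key].append(x)
            if len(buffer_elements[key]) == group_size:
                yield buffer_elements.pop(key)

    files = ["a.png", "a.json", "b.png", "b.json"]
    assert list(groupby_fixed(files, lambda f: f.split(".")[0], 2)) == \
        [["a.png", "a.json"], ["b.png", "b.json"]]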
@@ -319,92 +291,7 @@ class GrouperIterDataPipe(IterDataPipe):
                 if result_to_yield is not None:
                     yield self.wrapper_class(result_to_yield)
 
-            buffer_elements[key].append(x)
-            buffer_size += 1
-
         while buffer_size:
             (result_to_yield, buffer_size) = self._remove_biggest_key(buffer_elements, buffer_size)
             if result_to_yield is not None:
                 yield self.wrapper_class(result_to_yield)
-
-
-@functional_datapipe('group_by_key')
-class ByKeyGrouperIterDataPipe(IterDataPipe[list]):
-    r""" :class:`GroupByKeyIterDataPipe`.
-
-    Iterable datapipe to group data from input iterable by keys which are generated from `group_key_fn`,
-    yields a list with `group_size` items in it, each item in the list is a tuple of key and data
-
-    args:
-        datapipe: Iterable datapipe that provides data. (typically str key (eg. pathname) and data stream in tuples)
-        group_size: the size of group
-        max_buffer_size: the max size of stream buffer which is used to store not yet grouped but iterated data
-        group_key_fn: a function which is used to generate group key from the data in the input datapipe
-        sort_data_fn: a function which is used to sort the grouped data before yielding back
-        length: a nominal length of the datapipe
-    """
-    datapipe: IterDataPipe[Tuple[str, Any]]
-    group_size: int
-    max_buffer_size: int
-    group_key_fn: Callable
-    sort_data_fn: Callable
-    curr_buffer_size: int
-    stream_buffer: Dict[str, List[Tuple[str, Any]]]
-    length: int
-
-    def __init__(
-            self,
-            datapipe: IterDataPipe[Tuple[str, Any]],
-            *,
-            group_size: int,
-            max_buffer_size: Optional[int] = None,
-            group_key_fn: Callable = default_group_key_fn,
-            sort_data_fn: Callable = default_sort_data_fn,
-            length: int = -1):
-        super().__init__()
-
-        assert group_size > 0
-        self.datapipe = datapipe
-        self.group_size = group_size
-
-        # default max buffer size is group_size * 10
-        self.max_buffer_size = max_buffer_size if max_buffer_size is not None else group_size * 10
-        assert self.max_buffer_size >= self.group_size
-
-        self.group_key_fn = group_key_fn  # type: ignore[assignment]
-        self.sort_data_fn = sort_data_fn  # type: ignore[assignment]
-        self.curr_buffer_size = 0
-        self.stream_buffer = {}
-        self.length = length
-
-    def __iter__(self) -> Iterator[list]:
-        if self.group_size == 1:
-            for data in self.datapipe:
-                yield [data]
-        else:
-            for data in self.datapipe:
-                key = self.group_key_fn(data)
-                if key not in self.stream_buffer:
-                    self.stream_buffer[key] = []
-                res = self.stream_buffer[key]
-                res.append(data)
-                if len(res) == self.group_size:
-                    yield self.sort_data_fn(res)
-                    del self.stream_buffer[key]
-                    self.curr_buffer_size = self.curr_buffer_size - self.group_size + 1
-                else:
-                    if self.curr_buffer_size == self.max_buffer_size:
-                        raise OverflowError(
-                            "stream_buffer is overflow, please adjust the order of data "
-                            "in the input datapipe or increase the buffer size!")
-                    self.curr_buffer_size = self.curr_buffer_size + 1
-
-        if self.curr_buffer_size > 0:
-            msg = "Not able to group [{}] with group size {}.".format(
-                ','.join([v[0] for _, vs in self.stream_buffer.items() for v in vs]), str(self.group_size))
-            raise RuntimeError(msg)
-
-    def __len__(self) -> int:
-        if self.length == -1:
-            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
-        return self.length
--
2.7.4
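For reference, callers that relied on the deleted defaults can
reproduce them on top of the surviving `groupby`. A sketch under that
assumption (the two helper names are hypothetical; their bodies
condense the removed `default_group_key_fn` and `default_sort_data_fn`):

    import functools
    import os

    def group_key_fn(dataitem):
        # Key is the file pathname without its extension.
        return os.path.splitext(dataitem[0])[0]

    def sort_data_fn(datalist):
        # Text-like files sort after other files; ties break
        # alphabetically by pathname, as in the removed comparator.
        txt_ext = ['.json', '.jsn', '.txt', '.text']

        def cmp_fn(a, b):
            a_is_txt = os.path.splitext(a[0])[1] in txt_ext
            b_is_txt = os.path.splitext(b[0])[1] in txt_ext
            if a_is_txt != b_is_txt:
                return 1 if a_is_txt else -1
            return (a[0] > b[0]) - (a[0] < b[0])

        return sorted(datalist, key=functools.cmp_to_key(cmp_fn))

Passing these explicitly, `datapipe.groupby(group_key_fn=group_key_fn,
group_size=2)` followed by a `Mapper` with `fn=sort_data_fn`
approximates the old `group_by_key` behavior, mirroring the pattern
used in the updated test.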