From: Erjia Guan
Date: Mon, 16 Aug 2021 13:39:56 +0000 (-0700)
Subject: Refactor BucketBatch (#63185)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~1001
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d1cbee7b2b1da94ff1a0fa880bd0173de27fb89f;p=platform%2Fupstream%2Fpytorch.git

Refactor BucketBatch (#63185)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63185

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D30288893

Pulled By: ejguan

fbshipit-source-id: b88b792d12a83c99d8ea9e516e3b4c54a82100f6
---

diff --git a/test/test_datapipe.py b/test/test_datapipe.py
index 80fe758..9a7876e 100644
--- a/test/test_datapipe.py
+++ b/test/test_datapipe.py
@@ -709,52 +709,48 @@ class TestFunctionalIterDataPipe(TestCase):
     def test_bucket_batch_datapipe(self):
         input_dp = IDP(range(20))
         with self.assertRaises(AssertionError):
-            input_dp.bucket_batch(batch_size=0)
+            dp.iter.BucketBatcher(input_dp, batch_size=0)

         input_dp_nl = IDP_NoLen(range(20))
-        bucket_dp_nl = input_dp_nl.bucket_batch(batch_size=7)
+        bucket_dp_nl = dp.iter.BucketBatcher(input_dp_nl, batch_size=7)
         with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
             len(bucket_dp_nl)

-        # Test Bucket Batch without sort_key
         def _helper(**kwargs):
-            arrs = list(range(100))
+            data_len = 100
+            arrs = list(range(data_len))
             random.shuffle(arrs)
             input_dp = IDP(arrs)
-            bucket_dp = input_dp.bucket_batch(**kwargs)
-            if kwargs["sort_key"] is None:
-                # BatchDataset as reference
-                ref_dp = input_dp.batch(batch_size=kwargs['batch_size'], drop_last=kwargs['drop_last'])
-                for batch, rbatch in zip(bucket_dp, ref_dp):
-                    self.assertEqual(batch, rbatch)
-            else:
-                bucket_size = bucket_dp.bucket_size
-                bucket_num = (len(input_dp) - 1) // bucket_size + 1
-                it = iter(bucket_dp)
-                for i in range(bucket_num):
-                    ref = sorted(arrs[i * bucket_size: (i + 1) * bucket_size])
-                    bucket: List = []
-                    while len(bucket) < len(ref):
-                        try:
-                            batch = next(it)
-                            bucket += batch
-                        # If drop last, stop in advance
-                        except StopIteration:
-                            break
-                    if len(bucket) != len(ref):
-                        ref = ref[:len(bucket)]
-                    # Sorted bucket
-                    self.assertEqual(bucket, ref)
-
-        _helper(batch_size=7, drop_last=False, sort_key=None)
-        _helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=None)
-
-        # Test Bucket Batch with sort_key
+            bucket_dp = dp.iter.BucketBatcher(input_dp, **kwargs)
+
+            self.assertEqual(len(bucket_dp), data_len // 3 if kwargs['drop_last'] else data_len // 3 + 1)
+
+            def _verify_bucket_sorted(bucket):
+                # Sort batch in a bucket
+                bucket = sorted(bucket, key=lambda x: x[0])
+                flat = [item for batch in bucket for item in batch]
+                # Elements in the bucket should be sorted
+                self.assertEqual(flat, sorted(flat))
+
+            batch_num = kwargs['batch_num'] if 'batch_num' in kwargs else 100
+            bucket = []
+            for idx, d in enumerate(bucket_dp):
+                self.assertEqual(d, sorted(d))
+                bucket.append(d)
+                if idx % batch_num == batch_num - 1:
+                    _verify_bucket_sorted(bucket)
+                    bucket = []
+            _verify_bucket_sorted(bucket)
+
         def _sort_fn(data):
-            return data
+            return sorted(data)
+
+        # In-batch shuffle
+        _helper(batch_size=3, drop_last=False, batch_num=5, sort_key=_sort_fn)
+        _helper(batch_size=3, drop_last=False, batch_num=2, bucket_num=2, sort_key=_sort_fn)
+        _helper(batch_size=3, drop_last=True, batch_num=2, sort_key=_sort_fn)
+        _helper(batch_size=3, drop_last=True, batch_num=2, bucket_num=2, sort_key=_sort_fn)

-        _helper(batch_size=7, drop_last=False, bucket_size_mul=5, sort_key=_sort_fn)
-        _helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=_sort_fn)

     def test_filter_datapipe(self):
         input_ds = IDP(range(10))
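The test above also shows the new calling convention. As a minimal usage sketch
built only from names that appear in this patch (`dp` is assumed to alias
`torch.utils.data.datapipes` as imported in test_datapipe.py, and `IDP` is the
test suite's helper that wraps a plain iterable into an IterDataPipe):

    import random

    def _sort_fn(data):
        return sorted(data)

    arrs = list(range(100))
    random.shuffle(arrs)
    source = IDP(arrs)  # any IterDataPipe source works here

    # Each bucket holds batch_size * batch_num = 3 * 2 = 6 elements; with
    # bucket_num=2 the shuffling pool spans two buckets (12 elements).
    bucket_dp = dp.iter.BucketBatcher(
        source, batch_size=3, batch_num=2, bucket_num=2,
        sort_key=_sort_fn, in_batch_shuffle=True,
    )
    for batch in bucket_dp:
        print(batch)  # each emitted batch is sorted within its bucket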
diff --git a/torch/utils/data/datapipes/iter/__init__.py b/torch/utils/data/datapipes/iter/__init__.py
index b7718f8..0bcfdc4 100644
--- a/torch/utils/data/datapipes/iter/__init__.py
+++ b/torch/utils/data/datapipes/iter/__init__.py
@@ -13,7 +13,7 @@ from torch.utils.data.datapipes.iter.combining import (
 )
 from torch.utils.data.datapipes.iter.grouping import (
     BatchIterDataPipe as Batch,
-    BucketBatchIterDataPipe as BucketBatch,
+    BucketBatcherIterDataPipe as BucketBatcher,
     GroupByKeyIterDataPipe as GroupByKey,
 )
 from torch.utils.data.datapipes.iter.httpreader import (
@@ -45,7 +45,7 @@ from torch.utils.data.datapipes.iter.tobytes import (
 )

 __all__ = ['Batch',
-           'BucketBatch',
+           'BucketBatcher',
            'Collate',
            'Concat',
            'Filter',
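After this rename, downstream code imports the datapipe under its new name; a
short hedged snippet (the commented-out form is the import this change removes):

    from torch.utils.data.datapipes.iter import BucketBatcher
    # Previously: from torch.utils.data.datapipes.iter import BucketBatch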
diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py
index 83fe797..1bd8c4c 100644
--- a/torch/utils/data/datapipes/iter/grouping.py
+++ b/torch/utils/data/datapipes/iter/grouping.py
@@ -1,5 +1,6 @@
 import functools
 import os
+import random
 import warnings

 from collections import defaultdict
@@ -132,10 +133,14 @@ class UnBatchIterDataPipe(IterDataPipe):
         else:
             raise IndexError(f"unbatch_level {self.unbatch_level} exceeds the depth of the DataPipe")

+# TODO(ejguan): https://github.com/pytorch/pytorch/issues/63095
+def _in_batch_shuffle_fn(data: DataChunk):
+    d = list(data)
+    random.shuffle(d)
+    return DataChunk(d)

-@functional_datapipe('bucket_batch')
-class BucketBatchIterDataPipe(IterDataPipe[List[T_co]]):
-    r""" :class:`BucketBatchIterDataPipe`.
+class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
+    r""" :class:`BucketBatcherIterDataPipe`.

     Iterable DataPipe to create mini-batches of data from sorted bucket. An outer
     dimension will be added as `batch_size` if `drop_last` is set to `True`,
@@ -144,64 +149,78 @@ class BucketBatchIterDataPipe(IterDataPipe[List[T_co]]):
         datapipe: Iterable DataPipe being batched
         batch_size: The size of each batch
         drop_last: Option to drop the last batch if it's not full
-        bucket_size_mul: The multiplier to specify the size of bucket
+        batch_num: Number of batches that make up a bucket
+        bucket_num: Number of buckets that make up a pool for shuffling
         sort_key: Callable to specify the comparison key for sorting within bucket
+        in_batch_shuffle: Option to do in-batch shuffle or buffer shuffle
     """
     datapipe: IterDataPipe[T_co]
     batch_size: int
     drop_last: bool
-    bucket_size_mul: int
+    batch_num: int
+    bucket_num: int
     sort_key: Optional[Callable]
+    in_batch_shuffle: bool
     length: Optional[int]

     def __init__(self,
                  datapipe: IterDataPipe[T_co],
                  batch_size: int,
                  drop_last: bool = False,
-                 bucket_size_mul: int = 100,
+                 batch_num: int = 100,
+                 bucket_num: int = 1,
                  sort_key: Optional[Callable] = None,
+                 in_batch_shuffle: bool = True
                  ) -> None:
         assert batch_size > 0, "Batch size is required to be larger than 0!"
+        assert batch_num > 0, "Number of batches is required to be larger than 0!"
+        assert bucket_num > 0, "Number of buckets is required to be larger than 0!"
+
+        warnings.warn("`BucketBatcher` is going to be removed from PyTorch Core")
         super().__init__()
-        self.datapipe = datapipe
+
+        # TODO: Verify _datapipe is not going to be serialized twice
+        # and be able to reconstruct
+        self._datapipe = datapipe
         self.batch_size = batch_size
         self.drop_last = drop_last
-        self.bucket_size = batch_size * bucket_size_mul
+        self.batch_num = batch_num
+        self.bucket_num = bucket_num
         self.sort_key = sort_key
-        if sort_key is not None and sort_key.__name__ == '<lambda>':
-            warnings.warn("Lambda function is not supported for pickle, "
-                          "please use regular python function instead.")
-        self.bucket_ds = BatchIterDataPipe(datapipe, batch_size=self.bucket_size, drop_last=False)
+        self.in_batch_shuffle = in_batch_shuffle
+
+        self.bucket_size = batch_size * batch_num
+        self.pool_size = self.bucket_size * bucket_num
+
+        if bucket_num > 1 or sort_key is None:
+            if in_batch_shuffle:
+                datapipe = datapipe.batch(batch_size=self.pool_size, drop_last=False).map(fn=_in_batch_shuffle_fn).unbatch()
+            else:
+                datapipe = datapipe.shuffle(buffer_size=self.pool_size)
+        if sort_key is not None:
+            datapipe = datapipe.batch(self.bucket_size).map(fn=sort_key).unbatch()
+        datapipe = datapipe.batch(batch_size, drop_last=drop_last)
+        if sort_key is not None:
+            # In-batch shuffle each bucket seems not that useful
+            if in_batch_shuffle:
+                datapipe = datapipe.batch(batch_size=bucket_num, drop_last=False).map(fn=_in_batch_shuffle_fn).unbatch()
+            else:
+                datapipe = datapipe.shuffle(buffer_size=self.bucket_size)
+        self.datapipe = datapipe
+
         self.length = None

-    def __iter__(self) -> Iterator[List[T_co]]:
-        # Bucket without sorting remains same order, directly returns BatchDataset
-        if self.sort_key is None:
-            for element in BatchIterDataPipe(self.datapipe, batch_size=self.batch_size, drop_last=self.drop_last):
-                if isinstance(element, DataChunk):
-                    yield list(element.raw_iterator())
-                else:
-                    yield element
-        else:
-            bucket: List[T_co]
-            batch: List[T_co] = []
-            for bucket_or_chunk in self.bucket_ds:
-                bucket = list(bucket_or_chunk)
-                # In-place sort within bucket
-                bucket.sort(key=self.sort_key)
-                for start in range(0, len(bucket), self.batch_size):
-                    batch = bucket[start: start + self.batch_size]
-                    if len(batch) == self.batch_size or not self.drop_last:
-                        yield batch
+    def __iter__(self) -> Iterator:
+        yield from self.datapipe

     def __len__(self) -> int:
         if self.length is not None:
             return self.length
-        if isinstance(self.datapipe, Sized):
+        if isinstance(self._datapipe, Sized):
             if self.drop_last:
-                self.length = len(self.datapipe) // self.batch_size
+                self.length = len(self._datapipe) // self.batch_size
             else:
-                self.length = (len(self.datapipe) + self.batch_size - 1) // self.batch_size
+                self.length = (len(self._datapipe) + self.batch_size - 1) // self.batch_size
             return self.length
         raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
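For intuition, the pipeline assembled in `__init__` above is roughly equivalent
to the following eager pseudocode. This is a simplified sketch, not the real
implementation: it materializes the whole input, ignores `DataChunk` wrapping,
and omits the final batch-level shuffle that the real pipeline applies when
`sort_key` is given; the function name `bucket_batch_eager` is invented here
for illustration.

    import random
    from typing import Callable, Iterable, Iterator, List, Optional

    def bucket_batch_eager(data: Iterable,
                           batch_size: int,
                           drop_last: bool = False,
                           batch_num: int = 100,
                           bucket_num: int = 1,
                           sort_key: Optional[Callable[[List], List]] = None) -> Iterator[List]:
        items = list(data)
        bucket_size = batch_size * batch_num   # elements per sorted bucket
        pool_size = bucket_size * bucket_num   # elements per shuffling pool

        # Stage 1: shuffle within each pool
        # (the shuffle / _in_batch_shuffle_fn step in the real pipeline).
        if bucket_num > 1 or sort_key is None:
            for start in range(0, len(items), pool_size):
                pool = items[start:start + pool_size]
                random.shuffle(pool)
                items[start:start + pool_size] = pool

        # Stage 2: sort each bucket
        # (the batch(...).map(fn=sort_key).unbatch() step).
        if sort_key is not None:
            for start in range(0, len(items), bucket_size):
                items[start:start + bucket_size] = sort_key(items[start:start + bucket_size])

        # Stage 3: emit fixed-size batches, optionally dropping the short remainder.
        # (The real pipeline additionally shuffles batches across bucket_num
        # buckets when sort_key is given; omitted here for brevity.)
        for start in range(0, len(items), batch_size):
            batch = items[start:start + batch_size]
            if len(batch) == batch_size or not drop_last:
                yield batch

With batch_size=3, batch_num=2, bucket_num=2, and sort_key=sorted, a shuffled
range(12) yields four batches of three elements each, where every batch is
sorted and every 6-element bucket is sorted end to end, matching what the
updated test asserts about the new BucketBatcher.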