Refactor BucketBatch (#63185)
author Erjia Guan <erjia@fb.com>
Mon, 16 Aug 2021 13:39:56 +0000 (06:39 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 16 Aug 2021 13:42:56 +0000 (06:42 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63185
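
This diff renames `BucketBatchIterDataPipe` to `BucketBatcherIterDataPipe` (exposed as `BucketBatcher`), removes the `bucket_batch` functional registration, replaces the `bucket_size_mul` argument with explicit `batch_num` and `bucket_num` arguments plus an `in_batch_shuffle` flag, and reimplements the pipe as a composition of existing `shuffle`/`batch`/`map`/`unbatch` DataPipes.

A minimal before/after sketch of the API change, following the updated tests (where `dp` abbreviates `torch.utils.data.datapipes`, `input_dp` is any IterDataPipe, and `_sort_fn` is the tests' bucket-sorting function):

    # Before: functional form registered on IterDataPipe (removed by this diff)
    batched = input_dp.bucket_batch(batch_size=7, sort_key=_sort_fn)

    # After: construct the DataPipe class directly
    batched = dp.iter.BucketBatcher(input_dp, batch_size=3, batch_num=5, sort_key=_sort_fn)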

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D30288893

Pulled By: ejguan

fbshipit-source-id: b88b792d12a83c99d8ea9e516e3b4c54a82100f6

test/test_datapipe.py
torch/utils/data/datapipes/iter/__init__.py
torch/utils/data/datapipes/iter/grouping.py

test/test_datapipe.py
index 80fe758..9a7876e 100644
@@ -709,52 +709,48 @@ class TestFunctionalIterDataPipe(TestCase):
     def test_bucket_batch_datapipe(self):
         input_dp = IDP(range(20))
         with self.assertRaises(AssertionError):
-            input_dp.bucket_batch(batch_size=0)
+            dp.iter.BucketBatcher(input_dp, batch_size=0)
 
         input_dp_nl = IDP_NoLen(range(20))
-        bucket_dp_nl = input_dp_nl.bucket_batch(batch_size=7)
+        bucket_dp_nl = dp.iter.BucketBatcher(input_dp_nl, batch_size=7)
         with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
             len(bucket_dp_nl)
 
-        # Test Bucket Batch without sort_key
         def _helper(**kwargs):
-            arrs = list(range(100))
+            data_len = 100
+            arrs = list(range(data_len))
             random.shuffle(arrs)
             input_dp = IDP(arrs)
-            bucket_dp = input_dp.bucket_batch(**kwargs)
-            if kwargs["sort_key"] is None:
-                # BatchDataset as reference
-                ref_dp = input_dp.batch(batch_size=kwargs['batch_size'], drop_last=kwargs['drop_last'])
-                for batch, rbatch in zip(bucket_dp, ref_dp):
-                    self.assertEqual(batch, rbatch)
-            else:
-                bucket_size = bucket_dp.bucket_size
-                bucket_num = (len(input_dp) - 1) // bucket_size + 1
-                it = iter(bucket_dp)
-                for i in range(bucket_num):
-                    ref = sorted(arrs[i * bucket_size: (i + 1) * bucket_size])
-                    bucket: List = []
-                    while len(bucket) < len(ref):
-                        try:
-                            batch = next(it)
-                            bucket += batch
-                        # If drop last, stop in advance
-                        except StopIteration:
-                            break
-                    if len(bucket) != len(ref):
-                        ref = ref[:len(bucket)]
-                    # Sorted bucket
-                    self.assertEqual(bucket, ref)
-
-        _helper(batch_size=7, drop_last=False, sort_key=None)
-        _helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=None)
-
-        # Test Bucket Batch with sort_key
+            bucket_dp = dp.iter.BucketBatcher(input_dp, **kwargs)
+
+            self.assertEqual(len(bucket_dp), data_len // 3 if kwargs['drop_last'] else data_len // 3 + 1)
+
+            def _verify_bucket_sorted(bucket):
+                # Sort batches within the bucket by their first element
+                bucket = sorted(bucket, key=lambda x: x[0])
+                flat = [item for batch in bucket for item in batch]
+                # Elements in the bucket should be sorted
+                self.assertEqual(flat, sorted(flat))
+
+            batch_num = kwargs['batch_num'] if 'batch_num' in kwargs else 100
+            bucket = []
+            for idx, d in enumerate(bucket_dp):
+                self.assertEqual(d, sorted(d))
+                bucket.append(d)
+                if idx % batch_num == batch_num - 1:
+                    _verify_bucket_sorted(bucket)
+                    bucket = []
+            _verify_bucket_sorted(bucket)
+
         def _sort_fn(data):
-            return data
+            return sorted(data)
+
+        # In-batch shuffle
+        _helper(batch_size=3, drop_last=False, batch_num=5, sort_key=_sort_fn)
+        _helper(batch_size=3, drop_last=False, batch_num=2, bucket_num=2, sort_key=_sort_fn)
+        _helper(batch_size=3, drop_last=True, batch_num=2, sort_key=_sort_fn)
+        _helper(batch_size=3, drop_last=True, batch_num=2, bucket_num=2, sort_key=_sort_fn)
 
-        _helper(batch_size=7, drop_last=False, bucket_size_mul=5, sort_key=_sort_fn)
-        _helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=_sort_fn)
 
     def test_filter_datapipe(self):
         input_ds = IDP(range(10))
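
Note on the rewritten test: it checks two invariants for each configuration: every emitted batch is internally sorted, and every group of `batch_num` consecutive batches (one bucket) flattens into a fully sorted sequence once the batches are reordered by their first element. A hypothetical worked example with batch_size=3 and batch_num=2:

    # Batches emitted for one bucket (order may be shuffled): [3, 4, 5], [0, 1, 2]
    # Reorder batches by first element: [0, 1, 2], [3, 4, 5]
    # Flatten: [0, 1, 2, 3, 4, 5] == sorted([...])  -> bucket verified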
torch/utils/data/datapipes/iter/__init__.py
index b7718f8..0bcfdc4 100644
@@ -13,7 +13,7 @@ from torch.utils.data.datapipes.iter.combining import (
 )
 from torch.utils.data.datapipes.iter.grouping import (
     BatchIterDataPipe as Batch,
-    BucketBatchIterDataPipe as BucketBatch,
+    BucketBatcherIterDataPipe as BucketBatcher,
     GroupByKeyIterDataPipe as GroupByKey,
 )
 from torch.utils.data.datapipes.iter.httpreader import (
@@ -45,7 +45,7 @@ from torch.utils.data.datapipes.iter.tobytes import (
 )
 
 __all__ = ['Batch',
-           'BucketBatch',
+           'BucketBatcher',
            'Collate',
            'Concat',
            'Filter',
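
For downstream code, only the new name is exported after this change; a one-line sketch (given the `__all__` above):

    from torch.utils.data.datapipes.iter import BucketBatcher  # formerly BucketBatch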
torch/utils/data/datapipes/iter/grouping.py
index 83fe797..1bd8c4c 100644
@@ -1,5 +1,6 @@
 import functools
 import os
+import random
 import warnings
 
 from collections import defaultdict
@@ -132,10 +133,14 @@ class UnBatchIterDataPipe(IterDataPipe):
             else:
                 raise IndexError(f"unbatch_level {self.unbatch_level} exceeds the depth of the DataPipe")
 
+# TODO(ejguan): https://github.com/pytorch/pytorch/issues/63095
+def _in_batch_shuffle_fn(data: DataChunk):
+    d = list(data)
+    random.shuffle(d)
+    return DataChunk(d)
 
-@functional_datapipe('bucket_batch')
-class BucketBatchIterDataPipe(IterDataPipe[List[T_co]]):
-    r""" :class:`BucketBatchIterDataPipe`.
+class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
+    r""" :class:`BucketBatcherIterDataPipe`.
 
     Iterable DataPipe to create mini-batches of data from a sorted bucket. An outer
     dimension will be added as `batch_size` if `drop_last` is set to `True`,
@@ -144,64 +149,78 @@ class BucketBatchIterDataPipe(IterDataPipe[List[T_co]]):
         datapipe: Iterable DataPipe being batched
         batch_size: The size of each batch
         drop_last: Option to drop the last batch if it's not full
-        bucket_size_mul: The multiplier to specify the size of bucket
+        batch_num: Number of batches that make up a bucket
+        bucket_num: Number of buckets that make up the pool used for shuffling
         sort_key: Callable to specify the comparison key for sorting within bucket
+        in_batch_shuffle: Option to use in-batch shuffle instead of buffer shuffle
     """
     datapipe: IterDataPipe[T_co]
     batch_size: int
     drop_last: bool
-    bucket_size_mul: int
+    batch_num: int
+    bucket_num: int
     sort_key: Optional[Callable]
+    in_batch_shuffle: bool
     length: Optional[int]
 
     def __init__(self,
                  datapipe: IterDataPipe[T_co],
                  batch_size: int,
                  drop_last: bool = False,
-                 bucket_size_mul: int = 100,
+                 batch_num: int = 100,
+                 bucket_num: int = 1,
                  sort_key: Optional[Callable] = None,
+                 in_batch_shuffle: bool = True
                  ) -> None:
         assert batch_size > 0, "Batch size is required to be larger than 0!"
+        assert batch_num > 0, "Number of batches is required to be larger than 0!"
+        assert bucket_num > 0, "Number of buckets is required to be larger than 0!"
+
+        warnings.warn("`BucketBatcher` is going to be removed from PyTorch Core")
         super().__init__()
-        self.datapipe = datapipe
+
+        # TODO: Verify _datapipe is not going to be serialized twice
+        # and can be reconstructed
+        self._datapipe = datapipe
         self.batch_size = batch_size
         self.drop_last = drop_last
-        self.bucket_size = batch_size * bucket_size_mul
+        self.batch_num = batch_num
+        self.bucket_num = bucket_num
         self.sort_key = sort_key
-        if sort_key is not None and sort_key.__name__ == '<lambda>':
-            warnings.warn("Lambda function is not supported for pickle, "
-                          "please use regular python function instead.")
-        self.bucket_ds = BatchIterDataPipe(datapipe, batch_size=self.bucket_size, drop_last=False)
+        self.in_batch_shuffle = in_batch_shuffle
+
+        self.bucket_size = batch_size * batch_num
+        self.pool_size = self.bucket_size * bucket_num
+
+        if bucket_num > 1 or sort_key is None:
+            if in_batch_shuffle:
+                datapipe = datapipe.batch(batch_size=self.pool_size, drop_last=False).map(fn=_in_batch_shuffle_fn).unbatch()
+            else:
+                datapipe = datapipe.shuffle(buffer_size=self.pool_size)
+        if sort_key is not None:
+            datapipe = datapipe.batch(self.bucket_size).map(fn=sort_key).unbatch()
+        datapipe = datapipe.batch(batch_size, drop_last=drop_last)
+        if sort_key is not None:
+            # In-batch shuffling of each bucket doesn't seem that useful
+            if in_batch_shuffle:
+                datapipe = datapipe.batch(batch_size=bucket_num, drop_last=False).map(fn=_in_batch_shuffle_fn).unbatch()
+            else:
+                datapipe = datapipe.shuffle(buffer_size=self.bucket_size)
+        self.datapipe = datapipe
+
         self.length = None
 
-    def __iter__(self) -> Iterator[List[T_co]]:
-        # Bucket without sorting remains same order, directly returns BatchDataset
-        if self.sort_key is None:
-            for element in BatchIterDataPipe(self.datapipe, batch_size=self.batch_size, drop_last=self.drop_last):
-                if isinstance(element, DataChunk):
-                    yield list(element.raw_iterator())
-                else:
-                    yield element
-        else:
-            bucket: List[T_co]
-            batch: List[T_co] = []
-            for bucket_or_chunk in self.bucket_ds:
-                bucket = list(bucket_or_chunk)
-                # In-place sort within bucket
-                bucket.sort(key=self.sort_key)
-                for start in range(0, len(bucket), self.batch_size):
-                    batch = bucket[start: start + self.batch_size]
-                    if len(batch) == self.batch_size or not self.drop_last:
-                        yield batch
+    def __iter__(self) -> Iterator:
+        yield from self.datapipe
 
     def __len__(self) -> int:
         if self.length is not None:
             return self.length
-        if isinstance(self.datapipe, Sized):
+        if isinstance(self._datapipe, Sized):
             if self.drop_last:
-                self.length = len(self.datapipe) // self.batch_size
+                self.length = len(self._datapipe) // self.batch_size
             else:
-                self.length = (len(self.datapipe) + self.batch_size - 1) // self.batch_size
+                self.length = (len(self._datapipe) + self.batch_size - 1) // self.batch_size
             return self.length
         raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
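
For reference, `BucketBatcher.__init__` now builds its behavior by composing existing DataPipes instead of implementing the batching logic in `__iter__`. Note that `sort_key` is applied to an entire bucket via `map` and must return the sorted bucket (the updated tests pass a wrapper around `sorted`); it is no longer a per-element comparison key, despite what the unchanged docstring line above still suggests. Below is a minimal standalone sketch of the equivalent buffer-shuffle path (`in_batch_shuffle=False`); `_Source` is a hypothetical stand-in for the tests' `IDP` helper:

    from torch.utils.data import IterDataPipe

    class _Source(IterDataPipe):
        """Tiny IterDataPipe wrapping an in-memory sequence."""
        def __init__(self, data):
            self.data = data

        def __iter__(self):
            return iter(self.data)

    batch_size, batch_num, bucket_num = 3, 2, 2
    bucket_size = batch_size * batch_num  # elements per sorted bucket
    pool_size = bucket_size * bucket_num  # shuffle buffer feeding the buckets

    pipe = _Source(list(range(20)))
    pipe = pipe.shuffle(buffer_size=pool_size)               # pre-shuffle a pool of elements
    pipe = pipe.batch(bucket_size).map(fn=sorted).unbatch()  # sort each bucket of elements
    pipe = pipe.batch(batch_size, drop_last=False)           # emit fixed-size batches
    pipe = pipe.shuffle(buffer_size=bucket_size)             # reorder the emitted batches

    for batch in pipe:
        print(batch)  # each printed batch is internally sorted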