def test_bucket_batch_datapipe(self):
input_dp = IDP(range(20))
with self.assertRaises(AssertionError):
- input_dp.bucket_batch(batch_size=0)
+ dp.iter.BucketBatcher(input_dp, batch_size=0)
input_dp_nl = IDP_NoLen(range(20))
- bucket_dp_nl = input_dp_nl.bucket_batch(batch_size=7)
+ bucket_dp_nl = dp.iter.BucketBatcher(input_dp_nl, batch_size=7)
with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
len(bucket_dp_nl)
- # Test Bucket Batch without sort_key
def _helper(**kwargs):
- arrs = list(range(100))
+ data_len = 100
+ arrs = list(range(data_len))
random.shuffle(arrs)
input_dp = IDP(arrs)
- bucket_dp = input_dp.bucket_batch(**kwargs)
- if kwargs["sort_key"] is None:
- # BatchDataset as reference
- ref_dp = input_dp.batch(batch_size=kwargs['batch_size'], drop_last=kwargs['drop_last'])
- for batch, rbatch in zip(bucket_dp, ref_dp):
- self.assertEqual(batch, rbatch)
- else:
- bucket_size = bucket_dp.bucket_size
- bucket_num = (len(input_dp) - 1) // bucket_size + 1
- it = iter(bucket_dp)
- for i in range(bucket_num):
- ref = sorted(arrs[i * bucket_size: (i + 1) * bucket_size])
- bucket: List = []
- while len(bucket) < len(ref):
- try:
- batch = next(it)
- bucket += batch
- # If drop last, stop in advance
- except StopIteration:
- break
- if len(bucket) != len(ref):
- ref = ref[:len(bucket)]
- # Sorted bucket
- self.assertEqual(bucket, ref)
-
- _helper(batch_size=7, drop_last=False, sort_key=None)
- _helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=None)
-
- # Test Bucket Batch with sort_key
+ bucket_dp = dp.iter.BucketBatcher(input_dp, **kwargs)
+
+ self.assertEqual(len(bucket_dp), data_len // 3 if kwargs['drop_last'] else data_len // 3 + 1)
+
+ def _verify_bucket_sorted(bucket):
+ # Sort batch in a bucket
+ bucket = sorted(bucket, key=lambda x: x[0])
+ flat = [item for batch in bucket for item in batch]
+ # Elements in the bucket should be sorted
+ self.assertEqual(flat, sorted(flat))
+
+ batch_num = kwargs['batch_num'] if 'batch_num' in kwargs else 100
+ bucket = []
+ for idx, d in enumerate(bucket_dp):
+ self.assertEqual(d, sorted(d))
+ bucket.append(d)
+ if idx % batch_num == batch_num - 1:
+ _verify_bucket_sorted(bucket)
+ bucket = []
+ _verify_bucket_sorted(bucket)
+
def _sort_fn(data):
- return data
+ return sorted(data)
+
+ # In-batch shuffle
+ _helper(batch_size=3, drop_last=False, batch_num=5, sort_key=_sort_fn)
+ _helper(batch_size=3, drop_last=False, batch_num=2, bucket_num=2, sort_key=_sort_fn)
+ _helper(batch_size=3, drop_last=True, batch_num=2, sort_key=_sort_fn)
+ _helper(batch_size=3, drop_last=True, batch_num=2, bucket_num=2, sort_key=_sort_fn)
- _helper(batch_size=7, drop_last=False, bucket_size_mul=5, sort_key=_sort_fn)
- _helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=_sort_fn)
def test_filter_datapipe(self):
input_ds = IDP(range(10))
import functools
import os
+import random
import warnings
from collections import defaultdict
else:
raise IndexError(f"unbatch_level {self.unbatch_level} exceeds the depth of the DataPipe")
+# TODO(ejguan): https://github.com/pytorch/pytorch/issues/63095
+def _in_batch_shuffle_fn(data: DataChunk):
+ d = list(data)
+ random.shuffle(d)
+ return DataChunk(d)
-@functional_datapipe('bucket_batch')
-class BucketBatchIterDataPipe(IterDataPipe[List[T_co]]):
- r""" :class:`BucketBatchIterDataPipe`.
+class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]):
+ r""" :class:`BucketBatcherIterDataPipe`.
Iterable DataPipe to create mini-batches of data from sorted bucket. An outer
dimension will be added as `batch_size` if `drop_last` is set to `True`,
datapipe: Iterable DataPipe being batched
batch_size: The size of each batch
drop_last: Option to drop the last batch if it's not full
- bucket_size_mul: The multiplier to specify the size of bucket
+ batch_num: Number of batches to consist a bucket
+ bucket_num: Number of buckets to consist a pool for shuffling
sort_key: Callable to specify the comparison key for sorting within bucket
+ in_batch_shuffle: Option to do in-batch shuffle or buffer shuffle
"""
datapipe: IterDataPipe[T_co]
batch_size: int
drop_last: bool
- bucket_size_mul: int
+ batch_num: int
+ bucket_num: int
sort_key: Optional[Callable]
+ in_batch_shuffle: bool
length: Optional[int]
def __init__(self,
datapipe: IterDataPipe[T_co],
batch_size: int,
drop_last: bool = False,
- bucket_size_mul: int = 100,
+ batch_num: int = 100,
+ bucket_num: int = 1,
sort_key: Optional[Callable] = None,
+ in_batch_shuffle: bool = True
) -> None:
assert batch_size > 0, "Batch size is required to be larger than 0!"
+ assert batch_num > 0, "Number of batches is required to be larger than 0!"
+ assert bucket_num > 0, "Number of buckets is required to be larger than 0!"
+
+ warnings.warn("`BucketBatcher` is going to be removed from PyTorch Core")
super().__init__()
- self.datapipe = datapipe
+
+ # TODO: Verify _datapippe is not going to be serialized twice
+ # and be able to reconstruct
+ self._datapipe = datapipe
self.batch_size = batch_size
self.drop_last = drop_last
- self.bucket_size = batch_size * bucket_size_mul
+ self.batch_num = batch_num
+ self.bucket_num = bucket_num
self.sort_key = sort_key
- if sort_key is not None and sort_key.__name__ == '<lambda>':
- warnings.warn("Lambda function is not supported for pickle, "
- "please use regular python function instead.")
- self.bucket_ds = BatchIterDataPipe(datapipe, batch_size=self.bucket_size, drop_last=False)
+ self.in_batch_shuffle = in_batch_shuffle
+
+ self.bucket_size = batch_size * batch_num
+ self.pool_size = self.bucket_size * bucket_num
+
+ if bucket_num > 1 or sort_key is None:
+ if in_batch_shuffle:
+ datapipe = datapipe.batch(batch_size=self.pool_size, drop_last=False).map(fn=_in_batch_shuffle_fn).unbatch()
+ else:
+ datapipe = datapipe.shuffle(buffer_size=self.pool_size)
+ if sort_key is not None:
+ datapipe = datapipe.batch(self.bucket_size).map(fn=sort_key).unbatch()
+ datapipe = datapipe.batch(batch_size, drop_last=drop_last)
+ if sort_key is not None:
+ # In-batch shuffle each bucket seems not that useful
+ if in_batch_shuffle:
+ datapipe = datapipe.batch(batch_size=bucket_num, drop_last=False).map(fn=_in_batch_shuffle_fn).unbatch()
+ else:
+ datapipe = datapipe.shuffle(buffer_size=self.bucket_size)
+ self.datapipe = datapipe
+
self.length = None
- def __iter__(self) -> Iterator[List[T_co]]:
- # Bucket without sorting remains same order, directly returns BatchDataset
- if self.sort_key is None:
- for element in BatchIterDataPipe(self.datapipe, batch_size=self.batch_size, drop_last=self.drop_last):
- if isinstance(element, DataChunk):
- yield list(element.raw_iterator())
- else:
- yield element
- else:
- bucket: List[T_co]
- batch: List[T_co] = []
- for bucket_or_chunk in self.bucket_ds:
- bucket = list(bucket_or_chunk)
- # In-place sort within bucket
- bucket.sort(key=self.sort_key)
- for start in range(0, len(bucket), self.batch_size):
- batch = bucket[start: start + self.batch_size]
- if len(batch) == self.batch_size or not self.drop_last:
- yield batch
+ def __iter__(self) -> Iterator:
+ yield from self.datapipe
def __len__(self) -> int:
if self.length is not None:
return self.length
- if isinstance(self.datapipe, Sized):
+ if isinstance(self._datapipe, Sized):
if self.drop_last:
- self.length = len(self.datapipe) // self.batch_size
+ self.length = len(self._datapipe) // self.batch_size
else:
- self.length = (len(self.datapipe) + self.batch_size - 1) // self.batch_size
+ self.length = (len(self._datapipe) + self.batch_size - 1) // self.batch_size
return self.length
raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))