def test_fork_datapipe(self):
input_dp = IDP(range(10))
+ with self.assertRaises(ValueError):
+ input_dp.fork(num_instances=0)
+
+ dp1 = input_dp.fork(num_instances=1)
+ self.assertEqual(dp1, input_dp)
+
# Test Case: making sure all child DataPipe shares the same reference
dp1, dp2, dp3 = input_dp.fork(num_instances=3)
self.assertTrue(all(n1 is n2 and n1 is n3 for n1, n2, n3 in zip(dp1, dp2, dp3)))
next(it1)
with self.assertRaises(BufferError):
next(it1)
+ with self.assertRaises(BufferError):
+ list(dp2)
+
+ # Test Case: one child DataPipe yields all value first with unlimited buffer
+ with warnings.catch_warnings(record=True) as wa:
+ dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=-1)
+ self.assertEqual(len(wa), 1)
+ self.assertRegex(str(wa[0].message), r"Unlimited buffer size is set")
+ l1, l2 = list(dp1), list(dp2)
+ for d1, d2 in zip(l1, l2):
+ self.assertEqual(d1, d2)
# Test Case: two child DataPipes yield value together with buffer size 1
dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=1)
self.assertEqual(len(input_dp), len(dp2))
self.assertEqual(len(input_dp), len(dp3))
-
def test_demux_datapipe(self):
input_dp = IDP(range(10))
+ with self.assertRaises(ValueError):
+ input_dp.demux(num_instances=0, classifier_fn=lambda x: 0)
+
# Test Case: split into 2 DataPipes and output them one at a time
dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
output1, output2 = list(dp1), list(dp2)
it1 = iter(dp1)
with self.assertRaises(BufferError):
next(it1) # Buffer raises because first 5 elements all belong to the a different child
+ with self.assertRaises(BufferError):
+ list(dp2)
# Test Case: values of the same classification are lumped together, and buffer_size = 5 is just enough
dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=5)
self.assertEqual(list(range(5, 10)), output1)
self.assertEqual(list(range(0, 5)), output2)
+ # Test Case: values of the same classification are lumped together, and unlimited buffer
+ with warnings.catch_warnings(record=True) as wa:
+ dp1, dp2 = input_dp.demux(
+ num_instances=2,
+ classifier_fn=lambda x: 0 if x >= 5 else 1,
+ buffer_size=-1
+ )
+ self.assertEqual(len(wa), 1)
+ self.assertRegex(str(wa[0].message), r"Unlimited buffer size is set")
+ output1, output2 = list(dp1), list(dp2)
+ self.assertEqual(list(range(5, 10)), output1)
+ self.assertEqual(list(range(0, 5)), output2)
+
# Test Case: classifer returns a value outside of [0, num_instance - 1]
dp = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2)
it = iter(dp[0])
datapipe: Iterable DataPipe being copied
num_instances: number of instances of the datapipe to create
buffer_size: this restricts how far ahead the leading child DataPipe
- can read relative to the slowest child DataPipe
+ can read relative to the slowest child DataPipe.
+             Use -1 for the unlimited buffer
"""
def __new__(cls, datapipe: IterDataPipe, num_instances: int, buffer_size: int = 1000):
+ if num_instances < 1:
+             raise ValueError(f"Expected `num_instances` larger than 0, but {num_instances} is found")
+ if num_instances == 1:
+ return datapipe
container = _ForkerIterDataPipe(datapipe, num_instances, buffer_size)
return [_ChildDataPipe(container, i) for i in range(num_instances)]
self.num_instances = num_instances
self.buffer: Deque = deque()
self.buffer_size = buffer_size
+ if self.buffer_size < 0:
+ warnings.warn(
+ "Unlimited buffer size is set for `fork`, "
+ "please be aware of OOM at random places",
+ UserWarning
+ )
self.child_pointers = [0] * num_instances # Indicate the indices of the next element to get
self.slowest_ptr = 0
self.leading_ptr = 0
while self.end_ptr is None or self.child_pointers[instance_id] < self.end_ptr:
if not self.buffer or self.child_pointers[instance_id] > self.leading_ptr:
self.leading_ptr = self.child_pointers[instance_id]
- if self.leading_ptr - self.slowest_ptr + 1 > self.buffer_size:
+ if self.buffer_size >= 0 and self.leading_ptr - self.slowest_ptr + 1 > self.buffer_size:
raise BufferError("ForkerIterDataPipe buffer overflow," +
f"buffer size {self.buffer_size} is insufficient.")
try:
classifier_fn: a function that maps values to an integer within the range [0, num_instances - 1] or None
drop_none: defaults to False, if True, the function will skip over elements classified as None
buffer_size: this defines the maximum number of inputs that the buffer can hold across all child
- DataPipes while waiting for their values to be yielded
+ DataPipes while waiting for their values to be yielded.
+ Use -1 for the unlimited buffer
"""
def __new__(cls, datapipe: IterDataPipe, num_instances: int,
classifier_fn: Callable[[T_co], int], drop_none: bool = False, buffer_size: int = 1000):
+ if num_instances < 1:
+             raise ValueError(f"Expected `num_instances` larger than 0, but {num_instances} is found")
+ # When num_instances == 1, demux can be replaced by filter,
+ # but keep it as Demultiplexer for the sake of consistency
+         # like throwing an Error when the classification result is out of range
container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size)
return [_ChildDataPipe(container, i) for i in range(num_instances)]
self.main_datapipe = datapipe
self._datapipe_iterator: Optional[Iterator[Any]] = None
self.num_instances = num_instances
- self.max_buffer_size = buffer_size
+ self.buffer_size = buffer_size
+ if self.buffer_size < 0:
+ warnings.warn(
+ "Unlimited buffer size is set for `demux`, "
+ "please be aware of OOM at random places",
+ UserWarning
+ )
self.current_buffer_usage = 0
self.child_buffers: List[Deque[T_co]] = [deque() for _ in range(num_instances)]
self.instance_started: List[bool] = [False] * num_instances
return value
self.child_buffers[classification].append(value)
self.current_buffer_usage += 1
- if self.current_buffer_usage > self.max_buffer_size:
+ if self.buffer_size >= 0 and self.current_buffer_usage > self.buffer_size:
raise BufferError(
- f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.max_buffer_size} is insufficient.")
+ f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.buffer_size} is insufficient.")
def get_next_element_by_instance(self, instance_id: int):
if self._datapipe_iterator is None: