self.assertEqual(list(concat_dp), list(range(10)) + list(range(5)))
+
+ def test_fork_datapipe(self):
+ input_dp = IDP(range(10))
+
+ # Test Case: making sure all child DataPipes share the same reference
+ dp1, dp2, dp3 = input_dp.fork(num_instances=3)
+ self.assertTrue(all(n1 is n2 for n1, n2 in zip(dp1, dp2)))
+ self.assertTrue(all(n1 is n3 for n1, n3 in zip(dp1, dp3)))
+
+ # Test Case: one child DataPipe yields all values at once
+ output1, output2, output3 = list(dp1), list(dp2), list(dp3)
+ self.assertEqual(list(range(10)), output1)
+ self.assertEqual(list(range(10)), output2)
+ self.assertEqual(list(range(10)), output3)
+
+ # Test Case: two child DataPipes yield value together
+ dp1, dp2 = input_dp.fork(num_instances=2)
+ output = []
+ for n1, n2 in zip(dp1, dp2):
+ output.append((n1, n2))
+ self.assertEqual([(i, i) for i in range(10)], output)
+
+ # Test Case: one child DataPipe yields all values first, but buffer_size = 5 being too small
+ dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=5)
+ it1 = iter(dp1)
+ for _ in range(5):
+ next(it1)
+ with self.assertRaises(BufferError):
+ next(it1)
+
+ # Test Case: two child DataPipes yield value together with buffer size 1
+ dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=1)
+ output = []
+ for n1, n2 in zip(dp1, dp2):
+ output.append((n1, n2))
+ self.assertEqual([(i, i) for i in range(10)], output)
+
+ # Test Case: make sure logic related to slowest_ptr is working properly
+ dp1, dp2, dp3 = input_dp.fork(num_instances=3)
+ output1, output2, output3 = [], [], []
+ for i, (n1, n2) in enumerate(zip(dp1, dp2)):
+ output1.append(n1)
+ output2.append(n2)
+ if i == 4: # yield all of dp3 when halfway through dp1, dp2
+ output3 = list(dp3)
+ break
+ self.assertEqual(list(range(5)), output1)
+ self.assertEqual(list(range(5)), output2)
+ self.assertEqual(list(range(10)), output3)
+
+ # Test Case: DataPipe doesn't reset if this pipe hasn't been read
+ dp1, dp2 = input_dp.fork(num_instances=2)
+ i1, i2 = iter(dp1), iter(dp2)
+ output2 = []
+ for i, n2 in enumerate(i2):
+ output2.append(n2)
+ if i == 4:
+ i1 = iter(dp1) # Doesn't reset because i1 hasn't been read
+ self.assertEqual(list(range(10)), output2)
+
+ # Test Case: DataPipes reset when some of them have been read
+ dp1, dp2 = input_dp.fork(num_instances=2)
+ i1, i2 = iter(dp1), iter(dp2)
+ output1, output2 = [], []
+ for i, (n1, n2) in enumerate(zip(i1, i2)):
+ output1.append(n1)
+ output2.append(n2)
+ if i == 4:
+ with warnings.catch_warnings(record=True) as wa:
+ i1 = iter(dp1) # Resets all child DataPipes
+ self.assertEqual(len(wa), 1)
+ self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
+ self.assertEqual(list(range(5)) + list(range(10)), output1)
+ self.assertEqual(list(range(5)) + list(range(10)), output2)
+
+ # Test Case: DataPipe reset, even when some other child DataPipes are not read
+ dp1, dp2, dp3 = input_dp.fork(num_instances=3)
+ output1, output2 = list(dp1), list(dp2)
+ self.assertEqual(list(range(10)), output1)
+ self.assertEqual(list(range(10)), output2)
+ output1, output2 = list(dp1), list(dp2)
+ with warnings.catch_warnings(record=True) as wa:
+ self.assertEqual(list(range(10)), list(dp1)) # Resets even though dp3 has not been read
+ self.assertEqual(len(wa), 1)
+ self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
+ output3 = []
+ for i, n3 in enumerate(dp3):
+ output3.append(n3)
+ if i == 4:
+ with warnings.catch_warnings(record=True) as wa:
+ output1 = list(dp1) # Resets even though dp3 is only partially read
+ self.assertEqual(len(wa), 1)
+ self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
+ self.assertEqual(list(range(5)), output3)
+ self.assertEqual(list(range(10)), output1)
+ break
+ self.assertEqual(list(range(10)), list(dp3)) # dp3 has to read from the start again
+
+
def test_map_datapipe(self):
input_dp = IDP(range(10))
expected: Dict[Any, Any] = {mapped_dp: {numbers_dp: {}}}
self.assertEqual(expected, graph)
- # TODO(VitalyFedyunin): This test is incorrect because of 'buffer' nature
- # of the fork fake implementation, update fork first and fix this test too
@skipIfNoDill
def test_traverse_forked(self):
numbers_dp = NumbersDataset(size=50)
- dp0, dp1, dp2 = numbers_dp.fork(3)
+ dp0, dp1, dp2 = numbers_dp.fork(num_instances=3)
dp0_upd = dp0.map(lambda x: x * 10)
dp1_upd = dp1.filter(lambda x: x % 3 == 1)
combined_dp = dp0_upd.mux(dp1_upd, dp2)
graph = torch.utils.data.graph.traverse(combined_dp)
- expected = {combined_dp: {dp0_upd: {dp0: {}}, dp1_upd: {dp1: {}}, dp2: {}}}
+ expected = {combined_dp: {dp0_upd: {dp0: {dp0.main_datapipe: {dp0.main_datapipe.main_datapipe: {}}}},
+ dp1_upd: {dp1: {dp1.main_datapipe: {dp1.main_datapipe.main_datapipe: {}}}},
+ dp2: {dp2.main_datapipe: {dp2.main_datapipe.main_datapipe: {}}}}}
self.assertEqual(expected, graph)
class TestSharding(TestCase):
+
def _get_pipeline(self):
numbers_dp = NumbersDataset(size=10)
- dp0, dp1 = numbers_dp.fork(2)
+ dp0, dp1 = numbers_dp.fork(num_instances=2)
dp0_upd = dp0.map(lambda x: x * 10)
dp1_upd = dp1.filter(lambda x: x % 3 == 1)
combined_dp = dp0_upd.mux(dp1_upd)
import functools
+import warnings
from torch.utils.data import IterDataPipe, functional_datapipe
-from typing import Iterator, Optional, Sized, Tuple, TypeVar
+from typing import Any, Iterator, Optional, Sized, Tuple, TypeVar, Deque
+from collections import deque
T_co = TypeVar('T_co', covariant=True)
# This is fake class to show API, going to be replaced by the copy from torchdata
# TODO(VitalyFedyunin): Replace with valid version, documentation and tests
class IterateBuffer(IterDataPipe):
+
def __init__(self, buffer):
self.buffer = buffer
@functional_datapipe('fork')
class ForkerIterDataPipe(IterDataPipe):
+ r""" :class:`ForkerIterDataPipe`.
+
+ Iterable DataPipe to create multiple instances of the same Iterable DataPipe.
+ args:
+ datapipe: Iterable DataPipe being copied
+ num_instances: number of instances of the datapipe to create
+ buffer_size: this restricts how far ahead the leading child DataPipe
+ can read relative to the slowest child DataPipe
+ """
+ def __new__(cls, datapipe: IterDataPipe, num_instances: int, buffer_size: int = 1000):
+ container = _ForkerIterDataPipe(datapipe, num_instances, buffer_size)
+ return [_ChildDataPipe(container, i) for i in range(num_instances)]
- def __new__(cls, datapipe, instances):
- result = []
- buffer = list(datapipe)
- return [IterateBuffer(buffer) for i in range(instances)]
+
+class _ForkerIterDataPipe(IterDataPipe):
+ r""" :class:`_ForkerIterDataPipe`.
+
+ Container to hold instance-specific information on behalf of ForkerIterDataPipe. It tracks
+ the state of its child DataPipes, maintains the buffer, and yields the next value
+ as requested by the child DataPipes.
+ """
+ def __init__(self, datapipe: IterDataPipe, num_instances: int, buffer_size: int = 1000):
+ self.main_datapipe = datapipe
+ self._datapipe_iterator: Optional[Iterator[Any]] = None
+ self.num_instances = num_instances
+ self.buffer: Deque = deque()
+ self.buffer_size = buffer_size
+ self.child_pointers = [0] * num_instances # Indicate the indices of the next element to get
+ self.slowest_ptr = 0
+ self.leading_ptr = 0
+ self.end_ptr: Optional[int] = None
+
+ def get_next_element_by_instance(self, instance_id: int):
+ if self._datapipe_iterator is None:
+ self._datapipe_iterator = iter(self.main_datapipe)
+ while self.end_ptr is None or self.child_pointers[instance_id] < self.end_ptr:
+ if not self.buffer or self.child_pointers[instance_id] > self.leading_ptr:
+ self.leading_ptr = self.child_pointers[instance_id]
+ if self.leading_ptr - self.slowest_ptr + 1 > self.buffer_size:
+ raise BufferError("ForkerIterDataPipe buffer overflow, " +
+ f"buffer size {self.buffer_size} is insufficient.")
+ try:
+ self.buffer.append(next(self._datapipe_iterator))
+ self.child_pointers[instance_id] += 1
+ yield self.buffer[-1]
+ except StopIteration:
+ self.end_ptr = self.leading_ptr
+ else: # Child pointer is slower than or equal to the leading_ptr
+ buffer_index = self.child_pointers[instance_id] - self.slowest_ptr
+ return_val = self.buffer[buffer_index]
+ self.child_pointers[instance_id] += 1
+ if self.child_pointers[instance_id] - 1 == self.slowest_ptr:
+ new_min = min(self.child_pointers) # Can optimize by avoiding the call to min()
+ if self.slowest_ptr < new_min:
+ self.slowest_ptr = new_min
+ self.buffer.popleft()
+ yield return_val
+
+ def is_instance_started(self, instance_id: int) -> bool:
+ return self.child_pointers[instance_id] != 0
+
+ def is_every_instance_exhausted(self) -> bool:
+ return all(self.end_ptr == ptr for ptr in self.child_pointers)
+
+ def reset(self):
+ self._datapipe_iterator = iter(self.main_datapipe)
+ self.buffer = deque()
+ self.child_pointers = [0] * self.num_instances
+ self.slowest_ptr = 0
+ self.leading_ptr = 0
+ self.end_ptr = None
+
+class _ChildDataPipe(IterDataPipe):
+ r""" :class:`_ChildDataPipe`.
+
+ Iterable DataPipe that is a child of a main DataPipe. The instance of this class
+ will pass its instance_id to get the next value from its main DataPipe.
+ args:
+ main_datapipe: Main DataPipe with a method 'get_next_element_by_instance(instance_id)'
+ instance_id: integer identifier of this instance
+ """
+ def __init__(self, main_datapipe, instance_id: int):
+ required_attrs = ["get_next_element_by_instance", "is_instance_started", "is_every_instance_exhausted", "reset"]
+ required_ops = [getattr(main_datapipe, attr) for attr in required_attrs]
+ if any(not callable(op) for op in required_ops):
+ raise NotImplementedError(f"Main Datapipe must have methods {required_attrs} implemented.")
+ self.main_datapipe = main_datapipe
+ self.instance_id = instance_id
+
+ def __iter__(self):
+ if self.main_datapipe.is_instance_started(self.instance_id): # Only reset if the DataPipe started to read
+ if not self.main_datapipe.is_every_instance_exhausted():
+ warnings.warn("Some child DataPipes are not exhausted when __iter__ is called. We are resetting "
+ "the buffer and each child DataPipe will read from the start again.", UserWarning)
+ self.main_datapipe.reset()
+ # We want to separate the code for reset and yield, so that 'reset' executes before __next__ is called
+ return self.get_generator_by_instance(self.instance_id)
+
+ def get_generator_by_instance(self, instance_id: int):
+ yield from self.main_datapipe.get_next_element_by_instance(instance_id)
@functional_datapipe('demux')