self.assertEqual(list(range(10)), list(dp3)) # dp3 has to read from the start again
+ def test_demux_datapipe(self):
+ input_dp = IDP(range(10))
+
+ # Test Case: split into 2 DataPipes and output them one at a time
+ dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
+ output1, output2 = list(dp1), list(dp2)
+ self.assertEqual(list(range(0, 10, 2)), output1)
+ self.assertEqual(list(range(1, 10, 2)), output2)
+
+ # Test Case: split into 2 DataPipes and output them together
+ dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
+ output = []
+ for n1, n2 in zip(dp1, dp2):
+ output.append((n1, n2))
+ self.assertEqual([(i, i + 1) for i in range(0, 10, 2)], output)
+
+ # Test Case: values of the same classification are lumped together, and buffer_size = 4 is too small
+ dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=4)
+ it1 = iter(dp1)
+ with self.assertRaises(BufferError):
+ next(it1) # Buffer raises because the first 5 elements all belong to a different child
+
+ # Test Case: values of the same classification are lumped together, and buffer_size = 5 is just enough
+ dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=5)
+ output1, output2 = list(dp1), list(dp2)
+ self.assertEqual(list(range(5, 10)), output1)
+ self.assertEqual(list(range(0, 5)), output2)
+
+ # Test Case: classifier returns a value outside of [0, num_instances - 1]
+ dp = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2)
+ it = iter(dp[0])
+ with self.assertRaises(ValueError):
+ next(it)
+ next(it)
+
+ # Test Case: DataPipe doesn't reset when it has not been read
+ dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
+ i1 = iter(dp1)
+ output2 = []
+ i = 0
+ for i, n2 in enumerate(dp2):
+ output2.append(n2)
+ if i == 4:
+ i1 = iter(dp1)
+ self.assertEqual(list(range(1, 10, 2)), output2)
+
+ # Test Case: DataPipe reset when some of it has been read
+ dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
+ output1, output2 = [], []
+ for n1, n2 in zip(dp1, dp2):
+ output1.append(n1)
+ output2.append(n2)
+ if n1 == 4:
+ break
+ with warnings.catch_warnings(record=True) as wa:
+ i1 = iter(dp1) # Reset all child DataPipes
+ self.assertEqual(len(wa), 1)
+ self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
+ for n1, n2 in zip(dp1, dp2):
+ output1.append(n1)
+ output2.append(n2)
+ self.assertEqual([0, 2, 4] + list(range(0, 10, 2)), output1)
+ self.assertEqual([1, 3, 5] + list(range(1, 10, 2)), output2)
+
+ # Test Case: DataPipe reset, even when not all child DataPipes are exhausted
+ dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
+ output1 = list(dp1)
+ self.assertEqual(list(range(0, 10, 2)), output1)
+ with warnings.catch_warnings(record=True) as wa:
+ self.assertEqual(list(range(0, 10, 2)), list(dp1)) # Reset even when dp2 is not read
+ self.assertEqual(len(wa), 1)
+ self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
+ output2 = []
+ for i, n2 in enumerate(dp2):
+ output2.append(n2)
+ if i == 1:
+ self.assertEqual(list(range(1, 5, 2)), output2)
+ with warnings.catch_warnings(record=True) as wa:
+ self.assertEqual(list(range(0, 10, 2)), list(dp1)) # Can reset even when dp2 is partially read
+ self.assertEqual(len(wa), 1)
+ self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
+ break
+ output2 = list(dp2) # output2 has to read from beginning again
+ self.assertEqual(list(range(1, 10, 2)), output2)
+
+ # Test Case: drop_none = True
+ dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2 if x % 5 != 0 else None,
+ drop_none=True)
+ self.assertEqual([2, 4, 6, 8], list(dp1))
+ self.assertEqual([1, 3, 7, 9], list(dp2))
+
+ # Test Case: drop_none = False
+ dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2 if x % 5 != 0 else None,
+ drop_none=False)
+ it1 = iter(dp1)
+ with self.assertRaises(ValueError):
+ next(it1)
+
+
def test_map_datapipe(self):
input_dp = IDP(range(10))
)
from torch.utils.data.datapipes.iter.combining import (
ConcaterIterDataPipe as Concater,
+ DemultiplexerIterDataPipe as Demultiplexer,
ForkerIterDataPipe as Forker,
+ MultiplexerIterDataPipe as Multiplexer,
ZipperIterDataPipe as Zipper,
)
from torch.utils.data.datapipes.iter.filelister import (
-import functools
import warnings
from torch.utils.data import IterDataPipe, functional_datapipe
-from typing import Any, Iterator, Optional, Sized, Tuple, TypeVar, Deque
+from typing import Any, Callable, Iterator, List, Optional, Sized, Tuple, TypeVar, Deque
from collections import deque
T_co = TypeVar('T_co', covariant=True)
@functional_datapipe('demux')
class DemultiplexerIterDataPipe(IterDataPipe):
+ r""" :class:`DemultiplexerIterDataPipe`.
- def __new__(cls, datapipe, instances, classifier_fn):
- result = []
- buffer = list(datapipe)
+ Iterable DataPipe to split the input DataPipe into multiple child DataPipes, using the given
+ classification function. A list of the child DataPipes is returned from this operation.
+ args:
+ datapipe: Iterable DataPipe being demultiplexed into child DataPipes
+ num_instances: number of instances of the DataPipe to create
+ classifier_fn: a function that maps values to an integer within the range [0, num_instances - 1] or None
+ drop_none: defaults to False, if True, the function will skip over elements classified as None
+ buffer_size: this defines the maximum number of inputs that the buffer can hold across all child
+ DataPipes while waiting for their values to be yielded
+ """
+ def __new__(cls, datapipe: IterDataPipe, num_instances: int,
+ classifier_fn: Callable[[T_co], int], drop_none: bool = False, buffer_size: int = 1000):
+ container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size)
+ return [_ChildDataPipe(container, i) for i in range(num_instances)]
+
+
+class _DemultiplexerIterDataPipe(IterDataPipe):
+ r""" :class:`_DemultiplexerIterDataPipe`.
+
+ Container to hold instance-specific information on behalf of DemultiplexerIterDataPipe. It tracks
+ the state of its child DataPipes, maintains the buffer, classifies and yields the next correct value
+ as requested by the child DataPipes.
+ """
+
+ def __init__(self, datapipe: IterDataPipe[T_co], num_instances: int,
+ classifier_fn: Callable[[T_co], int], drop_none: bool, buffer_size: int):
+ self.main_datapipe = datapipe
+ self._datapipe_iterator: Optional[Iterator[Any]] = None
+ self.num_instances = num_instances
+ self.max_buffer_size = buffer_size
+ self.current_buffer_usage = 0
+ self.child_buffers: List[Deque[T_co]] = [deque() for _ in range(num_instances)]
+ self.instance_started: List[bool] = [False] * num_instances
+ self.classifier_fn = classifier_fn
+ self.drop_none = drop_none
+ self.main_datapipe_exhausted = False
+
+ def _find_next(self, instance_id: int) -> T_co:
+ while True:
+ if self._datapipe_iterator is None:
+ raise ValueError("_datapipe_iterator has not been set, likely because this private method is called directly "
+ "without invoking get_next_element_by_instance() first.")
+ value = next(self._datapipe_iterator)
+ classification = self.classifier_fn(value)
+ if classification is None and self.drop_none:
+ continue
+ if classification is None or classification >= self.num_instances or classification < 0:
+ raise ValueError(f"Output of the classification fn should be between 0 and {self.num_instances - 1}. " +
+ f"{classification} is returned.")
+ if classification == instance_id:
+ return value
+ self.child_buffers[classification].append(value)
+ self.current_buffer_usage += 1
+ if self.current_buffer_usage > self.max_buffer_size:
+ raise BufferError(
+ f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.max_buffer_size} is insufficient.")
- def filter_fn(classifier_fn, i, x):
- return classifier_fn(x) == i
- return [IterateBuffer(buffer).filter(functools.partial(filter_fn, classifier_fn, i)) for i in range(instances)]
+ def get_next_element_by_instance(self, instance_id: int):
+ if self._datapipe_iterator is None:
+ self._datapipe_iterator = iter(self.main_datapipe)
+ stop = False
+ self.instance_started[instance_id] = True
+ while not stop:
+ if self.child_buffers[instance_id]:
+ self.current_buffer_usage -= 1
+ yield self.child_buffers[instance_id].popleft()
+ else:
+ try:
+ yield self._find_next(instance_id)
+ except StopIteration:
+ stop = True
+ self.main_datapipe_exhausted = True
+
+ def is_instance_started(self, instance_id: int) -> bool:
+ return self.instance_started[instance_id]
+
+ def is_every_instance_exhausted(self) -> bool:
+ return self.main_datapipe_exhausted and all(not child_buffer for child_buffer in self.child_buffers)
+
+ def reset(self):
+ self._datapipe_iterator = iter(self.main_datapipe)
+ self.current_buffer_usage = 0
+ self.child_buffers = [deque() for _ in range(self.num_instances)]
+ self.instance_started = [False] * self.num_instances
+ self.main_datapipe_exhausted = False
@functional_datapipe('mux')
class MultiplexerIterDataPipe(IterDataPipe):