[DataPipe] implementing demux() (#63650)
author Kevin Tse <ktse@fb.com>
Tue, 31 Aug 2021 15:07:23 +0000 (08:07 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 15:32:29 +0000 (08:32 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63650

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D30493944

Pulled By: NivekT

fbshipit-source-id: 0aa06dee8c7fb1744975b8f6a0694b90c11ef80d
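
A minimal sketch of the new API, for reviewers (RangePipe below is a hypothetical stand-in for the IDP helper used in test_datapipe.py; the printed values mirror the test cases in this diff):

    from torch.utils.data import IterDataPipe

    # Hypothetical source pipe for illustration only
    class RangePipe(IterDataPipe):
        def __init__(self, n):
            self.n = n

        def __iter__(self):
            return iter(range(self.n))

    # demux() splits one pipe into two children: evens to dp1, odds to dp2
    dp1, dp2 = RangePipe(10).demux(num_instances=2, classifier_fn=lambda x: x % 2)
    print(list(dp1))  # [0, 2, 4, 6, 8]
    print(list(dp2))  # [1, 3, 5, 7, 9]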

test/test_datapipe.py
torch/utils/data/datapipes/iter/__init__.py
torch/utils/data/datapipes/iter/combining.py

index 842e442..b6e3513 100644 (file)
@@ -690,6 +690,105 @@ class TestFunctionalIterDataPipe(TestCase):
         self.assertEqual(list(range(10)), list(dp3))  # dp3 has to read from the start again
 
 
+    def test_demux_datapipe(self):
+        input_dp = IDP(range(10))
+
+        # Test Case: split into 2 DataPipes and output them one at a time
+        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
+        output1, output2 = list(dp1), list(dp2)
+        self.assertEqual(list(range(0, 10, 2)), output1)
+        self.assertEqual(list(range(1, 10, 2)), output2)
+
+        # Test Case: split into 2 DataPipes and output them together
+        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
+        output = []
+        for n1, n2 in zip(dp1, dp2):
+            output.append((n1, n2))
+        self.assertEqual([(i, i + 1) for i in range(0, 10, 2)], output)
+
+        # Test Case: values of the same classification are lumped together, and buffer_size = 4 is too small
+        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=4)
+        it1 = iter(dp1)
+        with self.assertRaises(BufferError):
+            next(it1)  # Buffer overflows because the first 5 elements all belong to a different child
+
+        # Test Case: values of the same classification are lumped together, and buffer_size = 5 is just enough
+        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=5)
+        output1, output2 = list(dp1), list(dp2)
+        self.assertEqual(list(range(5, 10)), output1)
+        self.assertEqual(list(range(0, 5)), output2)
+
+        # Test Case: classifier returns a value outside of [0, num_instances - 1]
+        dp = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2)
+        it = iter(dp[0])
+        with self.assertRaises(ValueError):
+            next(it)
+            next(it)
+
+        # Test Case: a child DataPipe is not reset when it has not been read
+        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
+        i1 = iter(dp1)
+        output2 = []
+        for i, n2 in enumerate(dp2):
+            output2.append(n2)
+            if i == 4:
+                i1 = iter(dp1)
+        self.assertEqual(list(range(1, 10, 2)), output2)
+
+        # Test Case: DataPipe resets when some of it has been read
+        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
+        output1, output2 = [], []
+        for n1, n2 in zip(dp1, dp2):
+            output1.append(n1)
+            output2.append(n2)
+            if n1 == 4:
+                break
+        with warnings.catch_warnings(record=True) as wa:
+            i1 = iter(dp1)  # Reset all child DataPipes
+            self.assertEqual(len(wa), 1)
+            self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
+        for n1, n2 in zip(dp1, dp2):
+            output1.append(n1)
+            output2.append(n2)
+        self.assertEqual([0, 2, 4] + list(range(0, 10, 2)), output1)
+        self.assertEqual([1, 3, 5] + list(range(1, 10, 2)), output2)
+
+        # Test Case: DataPipe resets even when not all child DataPipes are exhausted
+        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
+        output1 = list(dp1)
+        self.assertEqual(list(range(0, 10, 2)), output1)
+        with warnings.catch_warnings(record=True) as wa:
+            self.assertEqual(list(range(0, 10, 2)), list(dp1))  # Reset even when dp2 is not read
+            self.assertEqual(len(wa), 1)
+            self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
+        output2 = []
+        for i, n2 in enumerate(dp2):
+            output2.append(n2)
+            if i == 1:
+                self.assertEqual(list(range(1, 5, 2)), output2)
+                with warnings.catch_warnings(record=True) as wa:
+                    self.assertEqual(list(range(0, 10, 2)), list(dp1))  # Can reset even when dp2 is partially read
+                    self.assertEqual(len(wa), 1)
+                    self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
+                break
+        output2 = list(dp2)  # output2 has to read from beginning again
+        self.assertEqual(list(range(1, 10, 2)), output2)
+
+        # Test Case: drop_none = True
+        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2 if x % 5 != 0 else None,
+                                  drop_none=True)
+        self.assertEqual([2, 4, 6, 8], list(dp1))
+        self.assertEqual([1, 3, 7, 9], list(dp2))
+
+        # Test Case: drop_none = False
+        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2 if x % 5 != 0 else None,
+                                  drop_none=False)
+        it1 = iter(dp1)
+        with self.assertRaises(ValueError):
+            next(it1)
+
+
     def test_map_datapipe(self):
         input_dp = IDP(range(10))
 
index b460d4d..d4baef7 100644 (file)
@@ -8,7 +8,9 @@ from torch.utils.data.datapipes.iter.combinatorics import (
 )
 from torch.utils.data.datapipes.iter.combining import (
     ConcaterIterDataPipe as Concater,
+    DemultiplexerIterDataPipe as Demultiplexer,
     ForkerIterDataPipe as Forker,
+    MultiplexerIterDataPipe as Multiplexer,
     ZipperIterDataPipe as Zipper,
 )
 from torch.utils.data.datapipes.iter.filelister import (
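
With these exports, the same split can also be built by constructing the class directly instead of going through the functional demux form; a sketch reusing the hypothetical RangePipe helper from above:

    from torch.utils.data.datapipes.iter import Demultiplexer

    # Equivalent to RangePipe(10).demux(num_instances=2, classifier_fn=lambda x: x % 2)
    dp1, dp2 = Demultiplexer(RangePipe(10), num_instances=2, classifier_fn=lambda x: x % 2)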
index 85b3732..f44db96 100644 (file)
@@ -1,8 +1,7 @@
-import functools
 import warnings
 
 from torch.utils.data import IterDataPipe, functional_datapipe
-from typing import Any, Iterator, Optional, Sized, Tuple, TypeVar, Deque
+from typing import Any, Callable, Iterator, List, Optional, Sized, Tuple, TypeVar, Deque
 from collections import deque
 
 T_co = TypeVar('T_co', covariant=True)
@@ -163,14 +162,93 @@ class _ChildDataPipe(IterDataPipe):
 
 @functional_datapipe('demux')
 class DemultiplexerIterDataPipe(IterDataPipe):
+    r""" :class:`DemultiplexerIterDataPipe`.
 
-    def __new__(cls, datapipe, instances, classifier_fn):
-        result = []
-        buffer = list(datapipe)
+        Iterable DataPipe that splits the input DataPipe into multiple child DataPipes, using the given
+        classification function. A list of the child DataPipes is returned from this operation.
+
+        Args:
+            datapipe: Iterable DataPipe being split into child DataPipes
+            num_instances: number of child DataPipes to create
+            classifier_fn: a function that maps values to an integer within the range [0, num_instances - 1] or None
+            drop_none: defaults to False; if True, elements classified as None are skipped
+            buffer_size: the maximum number of inputs that the buffer can hold across all child
+                DataPipes while waiting for their values to be yielded
+    """
+    def __new__(cls, datapipe: IterDataPipe, num_instances: int,
+                classifier_fn: Callable[[T_co], int], drop_none: bool = False, buffer_size: int = 1000):
+        container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size)
+        return [_ChildDataPipe(container, i) for i in range(num_instances)]
+
+
+class _DemultiplexerIterDataPipe(IterDataPipe):
+    r""" :class:`_DemultiplexerIterDataPipe`.
+
+        Container to hold instance-specific information on behalf of DemultiplexerIterDataPipe. It tracks
+        the state of its child DataPipes, maintains the buffer, classifies and yields the next correct value
+        as requested by the child DataPipes.
+    """
+
+    def __init__(self, datapipe: IterDataPipe[T_co], num_instances: int,
+                 classifier_fn: Callable[[T_co], int], drop_none: bool, buffer_size: int):
+        self.main_datapipe = datapipe
+        self._datapipe_iterator: Optional[Iterator[Any]] = None
+        self.num_instances = num_instances
+        self.max_buffer_size = buffer_size
+        self.current_buffer_usage = 0
+        self.child_buffers: List[Deque[T_co]] = [deque() for _ in range(num_instances)]
+        self.instance_started: List[bool] = [False] * num_instances
+        self.classifier_fn = classifier_fn
+        self.drop_none = drop_none
+        self.main_datapipe_exhausted = False
+
+    def _find_next(self, instance_id: int) -> T_co:
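+        # Advance the shared source iterator until a value classified for
+        # `instance_id` appears, buffering any values that belong to other children.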
+        while True:
+            if self._datapipe_iterator is None:
+                raise ValueError("_datapipe_iterator has not been set, likely because this private method is called directly "
+                                 "without invoking get_next_element_by_instance() first.")
+            value = next(self._datapipe_iterator)
+            classification = self.classifier_fn(value)
+            if classification is None and self.drop_none:
+                continue
+            if classification is None or classification >= self.num_instances or classification < 0:
+                raise ValueError(f"Output of the classification fn should be between 0 and {self.num_instances - 1}, "
+                                 f"but {classification} was returned.")
+            if classification == instance_id:
+                return value
+            self.child_buffers[classification].append(value)
+            self.current_buffer_usage += 1
+            if self.current_buffer_usage > self.max_buffer_size:
+                raise BufferError(
+                    f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.max_buffer_size} is insufficient.")
 
-        def filter_fn(classifier_fn, i, x):
-            return classifier_fn(x) == i
-        return [IterateBuffer(buffer).filter(functools.partial(filter_fn, classifier_fn, i)) for i in range(instances)]
+    def get_next_element_by_instance(self, instance_id: int):
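+        # Serve values parked in this child's buffer first; otherwise pull fresh
+        # values from the shared iterator via _find_next().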
+        if self._datapipe_iterator is None:
+            self._datapipe_iterator = iter(self.main_datapipe)
+        stop = False
+        self.instance_started[instance_id] = True
+        while not stop:
+            if self.child_buffers[instance_id]:
+                self.current_buffer_usage -= 1
+                yield self.child_buffers[instance_id].popleft()
+            else:
+                try:
+                    yield self._find_next(instance_id)
+                except StopIteration:
+                    stop = True
+                    self.main_datapipe_exhausted = True
+
+    def is_instance_started(self, instance_id: int) -> bool:
+        return self.instance_started[instance_id]
+
+    def is_every_instance_exhausted(self) -> bool:
+        return self.main_datapipe_exhausted and all(not child_buffer for child_buffer in self.child_buffers)
+
+    def reset(self):
+        self._datapipe_iterator = iter(self.main_datapipe)
+        self.current_buffer_usage = 0
+        self.child_buffers = [deque() for _ in range(self.num_instances)]
+        self.instance_started = [False] * self.num_instances
+        self.main_datapipe_exhausted = False
 
 @functional_datapipe('mux')
 class MultiplexerIterDataPipe(IterDataPipe):
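
To make the buffering contract above concrete: values classified for a child other than the caller are parked in that child's deque, and the shared usage counter raises BufferError once it exceeds buffer_size. A sketch of the failure mode, again reusing the hypothetical RangePipe helper:

    # 0..4 all classify to the second child, so asking the first child for its
    # first value must buffer five elements; buffer_size=4 overflows on the fifth
    dp1, dp2 = RangePipe(10).demux(num_instances=2,
                                   classifier_fn=lambda x: 0 if x >= 5 else 1,
                                   buffer_size=4)
    try:
        next(iter(dp1))
    except BufferError as e:
        print(e)  # "... buffer size 4 is insufficient."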