[DataPipe] implementing fork() (#63649)
author Kevin Tse <ktse@fb.com>
Tue, 31 Aug 2021 15:07:23 +0000 (08:07 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 15:32:27 +0000 (08:32 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63649
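
Adds ForkerIterDataPipe (functional name `fork`), which splits one IterDataPipe into `num_instances` child DataPipes backed by a shared buffer. `buffer_size` (default 1000) limits how far the leading child may read ahead of the slowest one, and re-iterating a child that has already started reading resets all children (with a warning if some siblings are not yet exhausted).

A minimal usage sketch (the source pipe below is a hypothetical stand-in; the fork() call mirrors the new tests in test/test_datapipe.py):

    from torch.utils.data import IterDataPipe

    class NumbersPipe(IterDataPipe):
        # Hypothetical source DataPipe, used only for illustration
        def __init__(self, n):
            self.n = n

        def __iter__(self):
            return iter(range(self.n))

    # Both children share a single buffered pass over the source
    dp1, dp2 = NumbersPipe(10).fork(num_instances=2)
    list(zip(dp1, dp2))  # [(0, 0), (1, 1), ..., (9, 9)]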

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D30493945

Pulled By: NivekT

fbshipit-source-id: 40db7d4134facd266d86bc0dc2edf2729c4e5842

test/test_datapipe.py
torch/utils/data/datapipes/iter/__init__.py
torch/utils/data/datapipes/iter/combining.py

index c35698e..842e442 100644 (file)
@@ -591,6 +591,105 @@ class TestFunctionalIterDataPipe(TestCase):
 
         self.assertEqual(list(concat_dp), list(range(10)) + list(range(5)))
 
+
+    def test_fork_datapipe(self):
+        input_dp = IDP(range(10))
+
+        # Test Case: making sure all child DataPipes share the same reference
+        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
+        self.assertTrue(all(n1 is n2 for n1, n2 in zip(dp1, dp2)))
+        self.assertTrue(all(n1 is n3 for n1, n3 in zip(dp1, dp3)))
+
+        # Test Case: each child DataPipe yields all of its values, one child at a time
+        output1, output2, output3 = list(dp1), list(dp2), list(dp3)
+        self.assertEqual(list(range(10)), output1)
+        self.assertEqual(list(range(10)), output2)
+        self.assertEqual(list(range(10)), output3)
+
+        # Test Case: two child DataPipes yield values together
+        dp1, dp2 = input_dp.fork(num_instances=2)
+        output = []
+        for n1, n2 in zip(dp1, dp2):
+            output.append((n1, n2))
+        self.assertEqual([(i, i) for i in range(10)], output)
+
+        # Test Case: one child DataPipe tries to yield all values first, but buffer_size = 5 is too small
+        dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=5)
+        it1 = iter(dp1)
+        for _ in range(5):
+            next(it1)
+        with self.assertRaises(BufferError):
+            next(it1)
+
+        # Test Case: two child DataPipes yield values together with buffer_size = 1
+        dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=1)
+        output = []
+        for n1, n2 in zip(dp1, dp2):
+            output.append((n1, n2))
+        self.assertEqual([(i, i) for i in range(10)], output)
+
+        # Test Case: make sure the logic related to slowest_ptr works properly
+        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
+        output1, output2, output3 = [], [], []
+        for i, (n1, n2) in enumerate(zip(dp1, dp2)):
+            output1.append(n1)
+            output2.append(n2)
+            if i == 4:  # yield all of dp3 when halfway through dp1, dp2
+                output3 = list(dp3)
+                break
+        self.assertEqual(list(range(5)), output1)
+        self.assertEqual(list(range(5)), output2)
+        self.assertEqual(list(range(10)), output3)
+
+        # Test Case: a child DataPipe is not reset if it hasn't been read
+        dp1, dp2 = input_dp.fork(num_instances=2)
+        i1, i2 = iter(dp1), iter(dp2)
+        output2 = []
+        for i, n2 in enumerate(i2):
+            output2.append(n2)
+            if i == 4:
+                i1 = iter(dp1)  # Doesn't reset because i1 hasn't been read
+        self.assertEqual(list(range(10)), output2)
+
+        # Test Case: child DataPipes are reset when some of them have been read
+        dp1, dp2 = input_dp.fork(num_instances=2)
+        i1, i2 = iter(dp1), iter(dp2)
+        output1, output2 = [], []
+        for i, (n1, n2) in enumerate(zip(i1, i2)):
+            output1.append(n1)
+            output2.append(n2)
+            if i == 4:
+                with warnings.catch_warnings(record=True) as wa:
+                    i1 = iter(dp1)  # Resets all child DataPipes
+                    self.assertEqual(len(wa), 1)
+                    self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
+        self.assertEqual(list(range(5)) + list(range(10)), output1)
+        self.assertEqual(list(range(5)) + list(range(10)), output2)
+
+        # Test Case: child DataPipes are reset even when other children have not been read
+        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
+        output1, output2 = list(dp1), list(dp2)
+        self.assertEqual(list(range(10)), output1)
+        self.assertEqual(list(range(10)), output2)
+        output1, output2 = list(dp1), list(dp2)
+        with warnings.catch_warnings(record=True) as wa:
+            self.assertEqual(list(range(10)), list(dp1))  # Resets even though dp3 has not been read
+            self.assertEqual(len(wa), 1)
+            self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
+        output3 = []
+        for i, n3 in enumerate(dp3):
+            output3.append(n3)
+            if i == 4:
+                with warnings.catch_warnings(record=True) as wa:
+                    output1 = list(dp1)  # Resets even though dp3 is only partially read
+                    self.assertEqual(len(wa), 1)
+                    self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
+                self.assertEqual(list(range(5)), output3)
+                self.assertEqual(list(range(10)), output1)
+                break
+        self.assertEqual(list(range(10)), list(dp3))  # dp3 has to be read from the start again
+
+
     def test_map_datapipe(self):
         input_dp = IDP(range(10))
 
@@ -1333,24 +1432,25 @@ class TestGraph(TestCase):
         expected: Dict[Any, Any] = {mapped_dp: {numbers_dp: {}}}
         self.assertEqual(expected, graph)
 
-    # TODO(VitalyFedyunin): This test is incorrect because of 'buffer' nature
-    # of the fork fake implementation, update fork first and fix this test too
     @skipIfNoDill
     def test_traverse_forked(self):
         numbers_dp = NumbersDataset(size=50)
-        dp0, dp1, dp2 = numbers_dp.fork(3)
+        dp0, dp1, dp2 = numbers_dp.fork(num_instances=3)
         dp0_upd = dp0.map(lambda x: x * 10)
         dp1_upd = dp1.filter(lambda x: x % 3 == 1)
         combined_dp = dp0_upd.mux(dp1_upd, dp2)
         graph = torch.utils.data.graph.traverse(combined_dp)
-        expected = {combined_dp: {dp0_upd: {dp0: {}}, dp1_upd: {dp1: {}}, dp2: {}}}
+        expected = {combined_dp: {dp0_upd: {dp0: {dp0.main_datapipe: {dp0.main_datapipe.main_datapipe: {}}}},
+                                  dp1_upd: {dp1: {dp1.main_datapipe: {dp1.main_datapipe.main_datapipe: {}}}},
+                                  dp2: {dp2.main_datapipe: {dp2.main_datapipe.main_datapipe: {}}}}}
         self.assertEqual(expected, graph)
 
 
 class TestSharding(TestCase):
+
     def _get_pipeline(self):
         numbers_dp = NumbersDataset(size=10)
-        dp0, dp1 = numbers_dp.fork(2)
+        dp0, dp1 = numbers_dp.fork(num_instances=2)
         dp0_upd = dp0.map(lambda x: x * 10)
         dp1_upd = dp1.filter(lambda x: x % 3 == 1)
         combined_dp = dp0_upd.mux(dp1_upd)
index b55bbf6..b460d4d 100644 (file)
@@ -8,6 +8,7 @@ from torch.utils.data.datapipes.iter.combinatorics import (
 )
 from torch.utils.data.datapipes.iter.combining import (
     ConcaterIterDataPipe as Concater,
+    ForkerIterDataPipe as Forker,
     ZipperIterDataPipe as Zipper,
 )
 from torch.utils.data.datapipes.iter.filelister import (
index 879e8be..85b3732 100644 (file)
@@ -1,7 +1,9 @@
 import functools
+import warnings
 
 from torch.utils.data import IterDataPipe, functional_datapipe
-from typing import Iterator, Optional, Sized, Tuple, TypeVar
+from typing import Any, Iterator, Optional, Sized, Tuple, TypeVar, Deque
+from collections import deque
 
 T_co = TypeVar('T_co', covariant=True)
 
@@ -46,6 +48,7 @@ class ConcaterIterDataPipe(IterDataPipe):
 # This is fake class to show API, going to be replaced by the copy from torchdata
 # TODO(VitalyFedyunin): Replace with valid version, documentation and tests
 class IterateBuffer(IterDataPipe):
+
     def __init__(self, buffer):
         self.buffer = buffer
 
@@ -56,11 +59,106 @@ class IterateBuffer(IterDataPipe):
 
 @functional_datapipe('fork')
 class ForkerIterDataPipe(IterDataPipe):
+    r""" :class:`ForkerIterDataPipe`.
+
+        Iterable DataPipe to create multiple instances of the same Iterable DataPipe.
+        args:
+            datapipe: Iterable DataPipe being copied
+            num_instances: number of instances of the datapipe to create
+            buffer_size: this restricts how far ahead the leading child DataPipe
+                can read relative to the slowest child DataPipe
+    """
+    def __new__(cls, datapipe: IterDataPipe, num_instances: int, buffer_size: int = 1000):
+        container = _ForkerIterDataPipe(datapipe, num_instances, buffer_size)
+        return [_ChildDataPipe(container, i) for i in range(num_instances)]
 
-    def __new__(cls, datapipe, instances):
-        result = []
-        buffer = list(datapipe)
-        return [IterateBuffer(buffer) for i in range(instances)]
+
+class _ForkerIterDataPipe(IterDataPipe):
+    r""" :class:`_ForkerIterDataPipe`.
+
+        Container to hold instance-specific information on behalf of ForkerIterDataPipe. It tracks
+        the state of its child DataPipes, maintains the buffer, and yields the next value
+        as requested by the child DataPipes.
+    """
+    def __init__(self, datapipe: IterDataPipe, num_instances: int, buffer_size: int = 1000):
+        self.main_datapipe = datapipe
+        self._datapipe_iterator: Optional[Iterator[Any]] = None
+        self.num_instances = num_instances
+        self.buffer: Deque = deque()
+        self.buffer_size = buffer_size
+        self.child_pointers = [0] * num_instances  # Indicates, for each child, the index of the next element to get
+        self.slowest_ptr = 0
+        self.leading_ptr = 0
+        self.end_ptr: Optional[int] = None
+
+    def get_next_element_by_instance(self, instance_id: int):
+        if self._datapipe_iterator is None:
+            self._datapipe_iterator = iter(self.main_datapipe)
+        while self.end_ptr is None or self.child_pointers[instance_id] < self.end_ptr:
+            if not self.buffer or self.child_pointers[instance_id] > self.leading_ptr:
+                self.leading_ptr = self.child_pointers[instance_id]
+                if self.leading_ptr - self.slowest_ptr + 1 > self.buffer_size:
+                    raise BufferError("ForkerIterDataPipe buffer overflow, " +
+                                      f"buffer size {self.buffer_size} is insufficient.")
+                try:
+                    self.buffer.append(next(self._datapipe_iterator))
+                    self.child_pointers[instance_id] += 1
+                    yield self.buffer[-1]
+                except StopIteration:
+                    self.end_ptr = self.leading_ptr
+            else:  # Child pointer is behind or at the leading_ptr
+                buffer_index = self.child_pointers[instance_id] - self.slowest_ptr
+                return_val = self.buffer[buffer_index]
+                self.child_pointers[instance_id] += 1
+                if self.child_pointers[instance_id] - 1 == self.slowest_ptr:
+                    new_min = min(self.child_pointers)  # Can optimize by avoiding the call to min()
+                    if self.slowest_ptr < new_min:
+                        self.slowest_ptr = new_min
+                        self.buffer.popleft()
+                yield return_val
+
+    def is_instance_started(self, instance_id: int) -> bool:
+        return self.child_pointers[instance_id] != 0
+
+    def is_every_instance_exhausted(self) -> bool:
+        return all(self.end_ptr == ptr for ptr in self.child_pointers)
+
+    def reset(self):
+        self._datapipe_iterator = iter(self.main_datapipe)
+        self.buffer = deque()
+        self.child_pointers = [0] * self.num_instances
+        self.slowest_ptr = 0
+        self.leading_ptr = 0
+        self.end_ptr = None
+
+class _ChildDataPipe(IterDataPipe):
+    r""" :class:`_ChildDataPipe`.
+
+        Iterable DataPipe that is a child of a main DataPipe. The instance of this class
+        will pass its instance_id to get the next value from its main DataPipe.
+        args:
+            main_datapipe: Main DataPipe with a method 'get_next_element_by_instance(instance_id)'
+            instance_id: integer identifier of this instance
+    """
+    def __init__(self, main_datapipe, instance_id: int):
+        required_attrs = ["get_next_element_by_instance", "is_instance_started", "is_every_instance_exhausted", "reset"]
+        required_ops = [getattr(main_datapipe, attr) for attr in required_attrs]
+        if any(not callable(op) for op in required_ops):
+            raise NotImplementedError(f"Main DataPipe must have methods {required_attrs} implemented.")
+        self.main_datapipe = main_datapipe
+        self.instance_id = instance_id
+
+    def __iter__(self):
+        if self.main_datapipe.is_instance_started(self.instance_id):  # Only reset if this child DataPipe has started reading
+            if not self.main_datapipe.is_every_instance_exhausted():
+                warnings.warn("Some child DataPipes are not exhausted when __iter__ is called. We are resetting "
+                              "the buffer and each child DataPipe will read from the start again.", UserWarning)
+            self.main_datapipe.reset()
+        # We want to separate the code for reset and yield, so that 'reset' executes before __next__ is called
+        return self.get_generator_by_instance(self.instance_id)
+
+    def get_generator_by_instance(self, instance_id: int):
+        yield from self.main_datapipe.get_next_element_by_instance(instance_id)
 
 
 @functional_datapipe('demux')
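
Illustrative sketch of the buffer_size limit and the reset warning described above (NumbersPipe is the same hypothetical stand-in used in the summary; the behavior mirrors test_fork_datapipe):

    from torch.utils.data import IterDataPipe

    class NumbersPipe(IterDataPipe):
        # Hypothetical source DataPipe, used only for illustration
        def __init__(self, n):
            self.n = n

        def __iter__(self):
            return iter(range(self.n))

    dp1, dp2 = NumbersPipe(10).fork(num_instances=2, buffer_size=5)
    it1 = iter(dp1)
    for _ in range(5):
        next(it1)          # dp1 is now 5 elements ahead of dp2, filling the buffer
    try:
        next(it1)          # a 6th read-ahead would exceed buffer_size
    except BufferError:
        pass               # "ForkerIterDataPipe buffer overflow, buffer size 5 is insufficient."

    # Calling iter() again on a child that has started reading resets all children;
    # a UserWarning is emitted if the sibling children are not yet exhausted.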