[DataPipe] Unlimited buffer for Forker and Demultiplexer (#64994)
author Erjia Guan <erjia@fb.com>
Mon, 20 Sep 2021 15:54:36 +0000 (08:54 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 20 Sep 2021 16:30:39 +0000 (09:30 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64994

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D30934362

Pulled By: ejguan

fbshipit-source-id: d3b774d7e28c0b9659e999511e5a68c3929857d4

test/test_datapipe.py
torch/utils/data/datapipes/iter/combining.py

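For context, a minimal usage sketch of the new `buffer_size=-1` option for `fork` (not part of this patch): it paraphrases the added test case, `RangePipe` is a hypothetical stand-in for the `IDP` helper used in test_datapipe.py, and the snippet assumes `IterDataPipe` is exported from `torch.utils.data` with the functional `fork` form registered, as it is on this branch.

import warnings

from torch.utils.data import IterDataPipe

class RangePipe(IterDataPipe):
    # Hypothetical stand-in for the `IDP` wrapper used in test_datapipe.py
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

with warnings.catch_warnings(record=True) as wa:
    warnings.simplefilter("always")
    # buffer_size=-1 lifts the read-ahead cap and emits a UserWarning
    dp1, dp2 = RangePipe(10).fork(num_instances=2, buffer_size=-1)
assert "Unlimited buffer size is set" in str(wa[0].message)

# With an unlimited buffer, one child can be drained completely
# before the other starts; both still see every element.
assert list(dp1) == list(dp2) == list(range(10))
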
test/test_datapipe.py
index b0e1d06..0da0c0a 100644
@@ -720,6 +720,12 @@ class TestFunctionalIterDataPipe(TestCase):
     def test_fork_datapipe(self):
         input_dp = IDP(range(10))
 
+        with self.assertRaises(ValueError):
+            input_dp.fork(num_instances=0)
+
+        dp1 = input_dp.fork(num_instances=1)
+        self.assertEqual(dp1, input_dp)
+
         # Test Case: making sure all child DataPipe shares the same reference
         dp1, dp2, dp3 = input_dp.fork(num_instances=3)
         self.assertTrue(all(n1 is n2 and n1 is n3 for n1, n2, n3 in zip(dp1, dp2, dp3)))
@@ -744,6 +750,17 @@ class TestFunctionalIterDataPipe(TestCase):
             next(it1)
         with self.assertRaises(BufferError):
             next(it1)
+        with self.assertRaises(BufferError):
+            list(dp2)
+
+        # Test Case: one child DataPipe yields all values first with an unlimited buffer
+        with warnings.catch_warnings(record=True) as wa:
+            dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=-1)
+            self.assertEqual(len(wa), 1)
+            self.assertRegex(str(wa[0].message), r"Unlimited buffer size is set")
+        l1, l2 = list(dp1), list(dp2)
+        for d1, d2 in zip(l1, l2):
+            self.assertEqual(d1, d2)
 
         # Test Case: two child DataPipes yield value together with buffer size 1
         dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=1)
@@ -818,10 +835,12 @@ class TestFunctionalIterDataPipe(TestCase):
         self.assertEqual(len(input_dp), len(dp2))
         self.assertEqual(len(input_dp), len(dp3))
 
-
     def test_demux_datapipe(self):
         input_dp = IDP(range(10))
 
+        with self.assertRaises(ValueError):
+            input_dp.demux(num_instances=0, classifier_fn=lambda x: 0)
+
         # Test Case: split into 2 DataPipes and output them one at a time
         dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
         output1, output2 = list(dp1), list(dp2)
@@ -840,6 +859,8 @@ class TestFunctionalIterDataPipe(TestCase):
         it1 = iter(dp1)
         with self.assertRaises(BufferError):
             next(it1)  # Buffer raises because the first 5 elements all belong to a different child
+        with self.assertRaises(BufferError):
+            list(dp2)
 
         # Test Case: values of the same classification are lumped together, and buffer_size = 5 is just enough
         dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=5)
@@ -847,6 +868,19 @@ class TestFunctionalIterDataPipe(TestCase):
         self.assertEqual(list(range(5, 10)), output1)
         self.assertEqual(list(range(0, 5)), output2)
 
+        # Test Case: values of the same classification are lumped together, with an unlimited buffer
+        with warnings.catch_warnings(record=True) as wa:
+            dp1, dp2 = input_dp.demux(
+                num_instances=2,
+                classifier_fn=lambda x: 0 if x >= 5 else 1,
+                buffer_size=-1
+            )
+            self.assertEqual(len(wa), 1)
+            self.assertRegex(str(wa[0].message), r"Unlimited buffer size is set")
+        output1, output2 = list(dp1), list(dp2)
+        self.assertEqual(list(range(5, 10)), output1)
+        self.assertEqual(list(range(0, 5)), output2)
+
         # Test Case: classifier returns a value outside of [0, num_instances - 1]
         dp = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2)
         it = iter(dp[0])
torch/utils/data/datapipes/iter/combining.py
index ed1256f..fdb731c 100644
@@ -66,9 +66,14 @@ class ForkerIterDataPipe(IterDataPipe):
             datapipe: Iterable DataPipe being copied
             num_instances: number of instances of the datapipe to create
             buffer_size: this restricts how far ahead the leading child DataPipe
-             can read relative to the slowest child DataPipe
+             can read relative to the slowest child DataPipe.
+             Use -1 for an unlimited buffer
     """
     def __new__(cls, datapipe: IterDataPipe, num_instances: int, buffer_size: int = 1000):
+        if num_instances < 1:
+            raise ValueError(f"Expected `num_instances` larger than 0, but found {num_instances}")
+        if num_instances == 1:
+            return datapipe
         container = _ForkerIterDataPipe(datapipe, num_instances, buffer_size)
         return [_ChildDataPipe(container, i) for i in range(num_instances)]
 
@@ -86,6 +91,12 @@ class _ForkerIterDataPipe(IterDataPipe):
         self.num_instances = num_instances
         self.buffer: Deque = deque()
         self.buffer_size = buffer_size
+        if self.buffer_size < 0:
+            warnings.warn(
+                "Unlimited buffer size is set for `fork`, "
+                "please be aware of OOM at random places",
+                UserWarning
+            )
         self.child_pointers = [0] * num_instances  # Indicate the indices of the next element to get
         self.slowest_ptr = 0
         self.leading_ptr = 0
@@ -100,7 +111,7 @@ class _ForkerIterDataPipe(IterDataPipe):
         while self.end_ptr is None or self.child_pointers[instance_id] < self.end_ptr:
             if not self.buffer or self.child_pointers[instance_id] > self.leading_ptr:
                 self.leading_ptr = self.child_pointers[instance_id]
-                if self.leading_ptr - self.slowest_ptr + 1 > self.buffer_size:
+                if self.buffer_size >= 0 and self.leading_ptr - self.slowest_ptr + 1 > self.buffer_size:
                     raise BufferError("ForkerIterDataPipe buffer overflow," +
                                       f"buffer size {self.buffer_size} is insufficient.")
                 try:
@@ -181,10 +192,16 @@ class DemultiplexerIterDataPipe(IterDataPipe):
             classifier_fn: a function that maps values to an integer within the range [0, num_instances - 1] or None
             drop_none: defaults to False, if True, the function will skip over elements classified as None
             buffer_size: this defines the maximum number of inputs that the buffer can hold across all child
-                DataPipes while waiting for their values to be yielded
+                DataPipes while waiting for their values to be yielded.
+                Use -1 for an unlimited buffer
     """
     def __new__(cls, datapipe: IterDataPipe, num_instances: int,
                 classifier_fn: Callable[[T_co], int], drop_none: bool = False, buffer_size: int = 1000):
+        if num_instances < 1:
+            raise ValueError(f"Expected `num_instances` larger than 0, but found {num_instances}")
+        # When num_instances == 1, demux can be replaced by filter,
+        # but keep it as Demultiplexer for the sake of consistency,
+        # e.g. it still raises an error when the classification result is out of range
         container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size)
         return [_ChildDataPipe(container, i) for i in range(num_instances)]
 
@@ -202,7 +219,13 @@ class _DemultiplexerIterDataPipe(IterDataPipe):
         self.main_datapipe = datapipe
         self._datapipe_iterator: Optional[Iterator[Any]] = None
         self.num_instances = num_instances
-        self.max_buffer_size = buffer_size
+        self.buffer_size = buffer_size
+        if self.buffer_size < 0:
+            warnings.warn(
+                "Unlimited buffer size is set for `demux`, "
+                "please be aware of OOM at random places",
+                UserWarning
+            )
         self.current_buffer_usage = 0
         self.child_buffers: List[Deque[T_co]] = [deque() for _ in range(num_instances)]
         self.instance_started: List[bool] = [False] * num_instances
@@ -226,9 +249,9 @@ class _DemultiplexerIterDataPipe(IterDataPipe):
                 return value
             self.child_buffers[classification].append(value)
             self.current_buffer_usage += 1
-            if self.current_buffer_usage > self.max_buffer_size:
+            if self.buffer_size >= 0 and self.current_buffer_usage > self.buffer_size:
                 raise BufferError(
-                    f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.max_buffer_size} is insufficient.")
+                    f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.buffer_size} is insufficient.")
 
     def get_next_element_by_instance(self, instance_id: int):
         if self._datapipe_iterator is None:
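
Similarly, a sketch of `demux` with the unlimited buffer, mirroring the new test case rather than showing a verbatim API guarantee; it reuses the hypothetical `RangePipe` helper from the sketch above, and the classifier simply splits the range at 5.

# Route 5..9 to the first child and 0..4 to the second child.
dp1, dp2 = RangePipe(10).demux(
    num_instances=2,
    classifier_fn=lambda x: 0 if x >= 5 else 1,
    buffer_size=-1,  # unlimited buffer; emits the same UserWarning as fork
)

# dp1 can be drained first even though its elements come last in the source,
# because everything routed to dp2 is held in the now-unbounded buffer.
assert list(dp1) == list(range(5, 10))
assert list(dp2) == list(range(0, 5))

# Both fork and demux now reject num_instances < 1 with a ValueError,
# and fork(num_instances=1) simply returns the input datapipe unchanged.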