[DataPipe] implementing __len__ for fork (no valid length for demux) (#64215)
authorKevin Tse <ktse@fb.com>
Tue, 31 Aug 2021 15:07:23 +0000 (08:07 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 15:32:31 +0000 (08:32 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64215

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D30648672

Pulled By: NivekT

fbshipit-source-id: 4780f2f6a79ae15a4009092475e7d92f96dd09a2

test/test_datapipe.py
torch/utils/data/datapipes/iter/combining.py

index b6e3513..4e37f41 100644 (file)
@@ -689,6 +689,12 @@ class TestFunctionalIterDataPipe(TestCase):
                 break
         self.assertEqual(list(range(10)), list(dp3))  # dp3 has to read from the start again
 
+        # Test Case: Each DataPipe inherits the source datapipe's length
+        dp1, dp2, dp3 = input_dp.fork(num_instances=3)
+        self.assertEqual(len(input_dp), len(dp1))
+        self.assertEqual(len(input_dp), len(dp2))
+        self.assertEqual(len(input_dp), len(dp3))
+
 
     def test_demux_datapipe(self):
         input_dp = IDP(range(10))
@@ -788,6 +794,13 @@ class TestFunctionalIterDataPipe(TestCase):
         with self.assertRaises(ValueError):
             next(it1)
 
+        # Test Case: __len__ not implemented
+        dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
+        with self.assertRaises(TypeError):
+            len(dp1)  # It is not implemented as we do not know length for each child in advance
+        with self.assertRaises(TypeError):
+            len(dp2)
+
 
     def test_map_datapipe(self):
         input_dp = IDP(range(10))
index f44db96..a837c5b 100644 (file)
@@ -61,7 +61,8 @@ class ForkerIterDataPipe(IterDataPipe):
     r""" :class:`ForkerIterDataPipe`.
 
         Iterable DataPipe to create multiple instances of the same Iterable DataPipe.
-        args:
+
+        Args:
             datapipe: Iterable DataPipe being copied
             num_instances: number of instances of the datapipe to create
             buffer_size: this restricts how far ahead the leading child DataPipe
@@ -90,6 +91,9 @@ class _ForkerIterDataPipe(IterDataPipe):
         self.leading_ptr = 0
         self.end_ptr: Optional[int] = None
 
+    def __len__(self):
+        return len(self.main_datapipe)
+
     def get_next_element_by_instance(self, instance_id: int):
         if self._datapipe_iterator is None:
             self._datapipe_iterator = iter(self.main_datapipe)
@@ -135,7 +139,8 @@ class _ChildDataPipe(IterDataPipe):
 
         Iteratable Datapipe that is a child of a main DataPipe. The instance of this class
         will pass its instance_id to get the next value from its main DataPipe.
-        args:
+
+        Args:
             main_datapipe: Main DataPipe with a method 'get_next_element_by_instance(instance_id)'
             instance_id: integer identifier of this instance
     """
@@ -156,6 +161,9 @@ class _ChildDataPipe(IterDataPipe):
         # We want to separate the code for reset and yield, so that 'reset' exeutes before __next__ is called
         return self.get_generator_by_instance(self.instance_id)
 
+    def __len__(self):
+        return len(self.main_datapipe)
+
     def get_generator_by_instance(self, instance_id: int):
         yield from self.main_datapipe.get_next_element_by_instance(self.instance_id)
 
@@ -166,7 +174,8 @@ class DemultiplexerIterDataPipe(IterDataPipe):
 
         Iterable DataPipe to split the input DataPipe into multiple child DataPipes, using the given
         classification function. A list of the child DataPipes is returned from this operation.
-        args:
+
+        Args:
             datapipe: Iterable DataPipe being filtered
             num_instances: number of instances of the DataPipe to create
             classifier_fn: a function that maps values to an integer within the range [0, num_instances - 1] or None