From: Kevin Tse Date: Tue, 31 Aug 2021 15:07:23 +0000 (-0700) Subject: [DataPipe] implementing __len__ for fork (no valid length for demux) (#64215) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~557 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=0ef8760bf6b3e8098ef42df60f1e451234151f32;p=platform%2Fupstream%2Fpytorch.git [DataPipe] implementing __len__ for fork (no valid length for demux) (#64215) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64215 Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D30648672 Pulled By: NivekT fbshipit-source-id: 4780f2f6a79ae15a4009092475e7d92f96dd09a2 --- diff --git a/test/test_datapipe.py b/test/test_datapipe.py index b6e3513..4e37f41 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -689,6 +689,12 @@ class TestFunctionalIterDataPipe(TestCase): break self.assertEqual(list(range(10)), list(dp3)) # dp3 has to read from the start again + # Test Case: Each DataPipe inherits the source datapipe's length + dp1, dp2, dp3 = input_dp.fork(num_instances=3) + self.assertEqual(len(input_dp), len(dp1)) + self.assertEqual(len(input_dp), len(dp2)) + self.assertEqual(len(input_dp), len(dp3)) + def test_demux_datapipe(self): input_dp = IDP(range(10)) @@ -788,6 +794,13 @@ class TestFunctionalIterDataPipe(TestCase): with self.assertRaises(ValueError): next(it1) + # Test Case: __len__ not implemented + dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2) + with self.assertRaises(TypeError): + len(dp1) # It is not implemented as we do not know length for each child in advance + with self.assertRaises(TypeError): + len(dp2) + def test_map_datapipe(self): input_dp = IDP(range(10)) diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py index f44db96..a837c5b 100644 --- a/torch/utils/data/datapipes/iter/combining.py +++ b/torch/utils/data/datapipes/iter/combining.py @@ -61,7 +61,8 @@ class ForkerIterDataPipe(IterDataPipe): r""" :class:`ForkerIterDataPipe`. Iterable DataPipe to create multiple instances of the same Iterable DataPipe. - args: + + Args: datapipe: Iterable DataPipe being copied num_instances: number of instances of the datapipe to create buffer_size: this restricts how far ahead the leading child DataPipe @@ -90,6 +91,9 @@ class _ForkerIterDataPipe(IterDataPipe): self.leading_ptr = 0 self.end_ptr: Optional[int] = None + def __len__(self): + return len(self.main_datapipe) + def get_next_element_by_instance(self, instance_id: int): if self._datapipe_iterator is None: self._datapipe_iterator = iter(self.main_datapipe) @@ -135,7 +139,8 @@ class _ChildDataPipe(IterDataPipe): Iteratable Datapipe that is a child of a main DataPipe. The instance of this class will pass its instance_id to get the next value from its main DataPipe. - args: + + Args: main_datapipe: Main DataPipe with a method 'get_next_element_by_instance(instance_id)' instance_id: integer identifier of this instance """ @@ -156,6 +161,9 @@ class _ChildDataPipe(IterDataPipe): # We want to separate the code for reset and yield, so that 'reset' exeutes before __next__ is called return self.get_generator_by_instance(self.instance_id) + def __len__(self): + return len(self.main_datapipe) + def get_generator_by_instance(self, instance_id: int): yield from self.main_datapipe.get_next_element_by_instance(self.instance_id) @@ -166,7 +174,8 @@ class DemultiplexerIterDataPipe(IterDataPipe): Iterable DataPipe to split the input DataPipe into multiple child DataPipes, using the given classification function. A list of the child DataPipes is returned from this operation. - args: + + Args: datapipe: Iterable DataPipe being filtered num_instances: number of instances of the DataPipe to create classifier_fn: a function that maps values to an integer within the range [0, num_instances - 1] or None