[DataPipe] adding/removing __len__ for different DataPipe (#64398)
author Kevin Tse <ktse@fb.com>
Thu, 2 Sep 2021 20:06:18 +0000 (13:06 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 2 Sep 2021 20:08:32 +0000 (13:08 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64398

cc VitalyFedyunin ejguan

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D30710437

Pulled By: NivekT

fbshipit-source-id: 524eda43a2faa0db0c1a662bf9bb4283f0ade83c

test/test_datapipe.py
torch/utils/data/datapipes/iter/grouping.py
torch/utils/data/datapipes/iter/httpreader.py
torch/utils/data/datapipes/iter/selecting.py

index 24d0ce2..f09583b 100644 (file)
@@ -1626,6 +1626,27 @@ class TestSharding(TestCase):
 
         self.assertEqual(sorted(all_items), sorted(items))
 
+    def test_sharding_length(self):
+        numbers_dp = IDP(range(13))
+        sharded_dp0 = numbers_dp.sharding_filter()
+        torch.utils.data.sharding.apply_sharding(sharded_dp0, 3, 0)
+        sharded_dp1 = numbers_dp.sharding_filter()
+        torch.utils.data.sharding.apply_sharding(sharded_dp1, 3, 1)
+        sharded_dp2 = numbers_dp.sharding_filter()
+        torch.utils.data.sharding.apply_sharding(sharded_dp2, 3, 2)
+        self.assertEqual(13, len(numbers_dp))
+        self.assertEqual(5, len(sharded_dp0))
+        self.assertEqual(4, len(sharded_dp1))
+        self.assertEqual(4, len(sharded_dp2))
+
+        numbers_dp = IDP(range(1))
+        sharded_dp0 = numbers_dp.sharding_filter()
+        torch.utils.data.sharding.apply_sharding(sharded_dp0, 2, 0)
+        sharded_dp1 = numbers_dp.sharding_filter()
+        torch.utils.data.sharding.apply_sharding(sharded_dp1, 2, 1)
+        self.assertEqual(1, len(sharded_dp0))
+        self.assertEqual(0, len(sharded_dp1))
+
     @skipIfNoDill
     def test_old_dataloader(self):
         dp = self._get_pipeline()
index aece256..d90ad08 100644 (file)
@@ -28,6 +28,12 @@ class ShardingFilterIterDataPipe(IterDataPipe):
             if i % self.num_of_instances == self.instance_id:
                 yield item
 
+    def __len__(self):
+        if isinstance(self.source_datapipe, Sized):
+            return len(self.source_datapipe) // self.num_of_instances +\
+                (1 if (self.instance_id < len(self.source_datapipe) % self.num_of_instances) else 0)
+        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+
 
 @functional_datapipe('batch')
 class BatcherIterDataPipe(IterDataPipe[DataChunk]):
index 747b5d5..0c8e2fc 100644 (file)
@@ -1,5 +1,5 @@
 from io import IOBase
-from typing import Tuple
+from typing import Sized, Tuple
 from urllib.error import HTTPError, URLError
 import urllib.request as urllib
 from torch.utils.data import IterDataPipe
@@ -39,3 +39,8 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, IOBase]]):
                                 .format(reason=e.reason, url=furl))
             except Exception:
                 raise
+
+    def __len__(self) -> int:
+        if isinstance(self.datapipe, Sized):
+            return len(self.datapipe)
+        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
index a89bfdf..4e8703c 100644 (file)
@@ -77,6 +77,5 @@ class FilterIterDataPipe(MapperIterDataPipe):
             not (isinstance(data, list) and len(data) == 0 and self.drop_empty_batches)
         return r
 
-
     def __len__(self):
         raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))