self.assertEqual(sorted(all_items), sorted(items))
+ def test_sharding_length(self):
+ # Verify __len__ of a sharded datapipe: 13 items over 3 shards gives the
+ # lower-numbered shards the remainder, i.e. ceil/floor split of 5, 4, 4.
+ numbers_dp = IDP(range(13))
+ sharded_dp0 = numbers_dp.sharding_filter()
+ torch.utils.data.sharding.apply_sharding(sharded_dp0, 3, 0)
+ sharded_dp1 = numbers_dp.sharding_filter()
+ torch.utils.data.sharding.apply_sharding(sharded_dp1, 3, 1)
+ sharded_dp2 = numbers_dp.sharding_filter()
+ torch.utils.data.sharding.apply_sharding(sharded_dp2, 3, 2)
+ # The un-sharded pipe keeps its full length; shard lengths sum to it.
+ self.assertEqual(13, len(numbers_dp))
+ self.assertEqual(5, len(sharded_dp0))
+ self.assertEqual(4, len(sharded_dp1))
+ self.assertEqual(4, len(sharded_dp2))
+
+ # Edge case: more shards than items — shard 0 gets the single item,
+ # shard 1 is empty (length 0, not an error).
+ numbers_dp = IDP(range(1))
+ sharded_dp0 = numbers_dp.sharding_filter()
+ torch.utils.data.sharding.apply_sharding(sharded_dp0, 2, 0)
+ sharded_dp1 = numbers_dp.sharding_filter()
+ torch.utils.data.sharding.apply_sharding(sharded_dp1, 2, 1)
+ self.assertEqual(1, len(sharded_dp0))
+ self.assertEqual(0, len(sharded_dp1))
+
@skipIfNoDill
def test_old_dataloader(self):
dp = self._get_pipeline()
if i % self.num_of_instances == self.instance_id:
yield item
+ def __len__(self):
+ # Length of this shard: an even floor split plus one extra element for
+ # the first (len % num_of_instances) shards. This matches the count of
+ # items __iter__ yields via `i % self.num_of_instances == self.instance_id`.
+ # Raises TypeError when the source has no length (mirrors other datapipes).
+ if isinstance(self.source_datapipe, Sized):
+ return len(self.source_datapipe) // self.num_of_instances +\
+ (1 if (self.instance_id < len(self.source_datapipe) % self.num_of_instances) else 0)
+ raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+
@functional_datapipe('batch')
class BatcherIterDataPipe(IterDataPipe[DataChunk]):
from io import IOBase
-from typing import Tuple
+from typing import Sized, Tuple
from urllib.error import HTTPError, URLError
import urllib.request as urllib
from torch.utils.data import IterDataPipe
.format(reason=e.reason, url=furl))
except Exception:
raise
+
+ def __len__(self) -> int:
+ # Straight passthrough: this datapipe emits one output per source item,
+ # so its length equals the source's. TypeError when the source is not
+ # Sized, consistent with the other datapipes' __len__ convention.
+ if isinstance(self.datapipe, Sized):
+ return len(self.datapipe)
+ raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))