Make datasets in `ConcatDataset` not need to be sized (#64114)
authorSantiago Castro <sacastro@umich.edu>
Wed, 1 Sep 2021 22:18:14 +0000 (15:18 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 22:32:50 +0000 (15:32 -0700)
Summary:
`datasets` needs to be iterable, but also sized because the length is checked. But immediately after it's converted to a list. By changing the order of these 2 lines, it doesn't need to be sized anymore.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64114

Reviewed By: H-Huang

Differential Revision: D30641480

Pulled By: ejguan

fbshipit-source-id: 7e16548c2123afa65b83845f9929271fa07fe1e8

torch/utils/data/dataset.py

index 609e1a1..50488d1 100644 (file)
@@ -271,9 +271,8 @@ class ConcatDataset(Dataset[T_co]):
 
     def __init__(self, datasets: Iterable[Dataset]) -> None:
         super(ConcatDataset, self).__init__()
-        # Cannot verify that datasets is Sized
-        assert len(datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
         self.datasets = list(datasets)
+        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
         for d in self.datasets:
             assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
         self.cumulative_sizes = self.cumsum(self.datasets)