fix RandomSampler length (#15991)
authorkyryl <truskovskiyk@gmail.com>
Mon, 14 Jan 2019 07:07:16 +0000 (23:07 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 14 Jan 2019 07:09:51 +0000 (23:09 -0800)
Summary:
Hi!

This PR addresses issue #15537.
Please review.

Thanks!

Differential Revision: D13649890

Pulled By: soumith

fbshipit-source-id: 166212ae383331345423236dfc4fa2ea907d265d

test/test_dataloader.py
torch/utils/data/sampler.py

index 291b9c0..6d2701d 100644 (file)
@@ -585,6 +585,33 @@ class TestDataLoader(TestCase):
 
         self.assertRaises(ValueError, lambda: RandomSampler(self.dataset, num_samples=0))
 
+    def test_random_sampler_len_with_replacement(self):
+        from torch.utils.data import RandomSampler
+        # add 5 extra samples
+        num_samples = len(self.dataset) + 5
+        sampler = RandomSampler(self.dataset,
+                                replacement=True,
+                                num_samples=num_samples)
+        # test len method
+        self.assertEqual(num_samples, len(sampler))
+
+        # test with iteration
+        count_num_samples = sum(1 for _ in sampler)
+        self.assertEqual(num_samples, count_num_samples)
+
+        # test with dataloader, batch_size = 1
+        batch_size = 1
+        count_num_samples_in_data_loader = len(DataLoader(
+            self.dataset, batch_size=batch_size, sampler=sampler))
+        self.assertEqual(num_samples, count_num_samples_in_data_loader)
+
+        # test with dataloader, batch_size = 6
+        batch_size = 6
+        count_num_samples_in_data_loader = len(DataLoader(
+            self.dataset, batch_size=batch_size, sampler=sampler))
+        self.assertEqual(int(math.ceil(float(num_samples) / batch_size)),
+                         count_num_samples_in_data_loader)
+
     def test_duplicating_data_with_drop_last(self):
 
         from torch.utils.data.distributed import DistributedSampler
index 37303f0..d0ecdb4 100644 (file)
@@ -50,15 +50,12 @@ class RandomSampler(Sampler):
     def __init__(self, data_source, replacement=False, num_samples=None):
         self.data_source = data_source
         self.replacement = replacement
-        self.num_samples = num_samples
+        self._num_samples = num_samples
 
-        if self.num_samples is not None and replacement is False:
+        if self._num_samples is not None and replacement is False:
             raise ValueError("With replacement=False, num_samples should not be specified, "
                              "since a random permute will be performed.")
 
-        if self.num_samples is None:
-            self.num_samples = len(self.data_source)
-
         if not isinstance(self.num_samples, int) or self.num_samples <= 0:
             raise ValueError("num_samples should be a positive integeral "
                              "value, but got num_samples={}".format(self.num_samples))
@@ -66,6 +63,13 @@ class RandomSampler(Sampler):
             raise ValueError("replacement should be a boolean value, but got "
                              "replacement={}".format(self.replacement))
 
+    @property
+    def num_samples(self):
+        # dataset size might change at runtime
+        if self._num_samples is None:
+            return len(self.data_source)
+        return self._num_samples
+
     def __iter__(self):
         n = len(self.data_source)
         if self.replacement:
@@ -73,7 +77,7 @@ class RandomSampler(Sampler):
         return iter(torch.randperm(n).tolist())
 
     def __len__(self):
-        return len(self.data_source)
+        return self.num_samples
 
 
 class SubsetRandomSampler(Sampler):