self.assertRaises(ValueError, lambda: RandomSampler(self.dataset, num_samples=0))
def test_random_sampler_len_with_replacement(self):
    """len(RandomSampler) with replacement must equal the requested
    num_samples, and iteration / DataLoader batching must agree."""
    from torch.utils.data import RandomSampler

    # Request 5 more samples than the dataset holds.
    num_samples = len(self.dataset) + 5
    sampler = RandomSampler(self.dataset,
                            replacement=True,
                            num_samples=num_samples)

    # __len__ reports the requested count.
    self.assertEqual(num_samples, len(sampler))

    # Iteration yields exactly that many indices.
    self.assertEqual(num_samples, sum(1 for _ in sampler))

    # DataLoader batch count follows ceil(num_samples / batch_size);
    # batch_size=1 degenerates to num_samples itself.
    for batch_size in (1, 6):
        loader = DataLoader(self.dataset,
                            batch_size=batch_size,
                            sampler=sampler)
        self.assertEqual(int(math.ceil(float(num_samples) / batch_size)),
                         len(loader))
def test_duplicating_data_with_drop_last(self):
from torch.utils.data.distributed import DistributedSampler
def __init__(self, data_source, replacement=False, num_samples=None):
    """Sample elements randomly from ``data_source``.

    Args:
        data_source: dataset to sample from (must support ``len``).
        replacement (bool): if ``True``, indices are drawn with replacement.
        num_samples (int, optional): number of samples to draw; only valid
            together with ``replacement=True``. Defaults to
            ``len(data_source)``.

    Raises:
        ValueError: if ``num_samples`` is given without replacement, is not
            a positive int, or if ``replacement`` is not a bool.
    """
    self.data_source = data_source
    self.replacement = replacement
    # Stored privately; the public ``num_samples`` is a property so the
    # effective count tracks a dataset whose size changes at runtime.
    self._num_samples = num_samples

    if self._num_samples is not None and replacement is False:
        raise ValueError("With replacement=False, num_samples should not be specified, "
                         "since a random permute will be performed.")

    # Validate the effective sample count (explicit value, or dataset size).
    effective = self._num_samples if self._num_samples is not None else len(self.data_source)
    if not isinstance(effective, int) or effective <= 0:
        # fixed typo in the original message: "integeral" -> "integer"
        raise ValueError("num_samples should be a positive integer "
                         "value, but got num_samples={}".format(effective))

    # BUG FIX: this raise was unconditional in the mangled source; guard it
    # on the actual type check so valid constructions succeed.
    if not isinstance(self.replacement, bool):
        raise ValueError("replacement should be a boolean value, but got "
                         "replacement={}".format(self.replacement))
@property
def num_samples(self):
    """Effective number of samples to draw.

    Recomputed on every access because the dataset size might change at
    runtime when no explicit count was supplied.
    """
    explicit = self._num_samples
    return len(self.data_source) if explicit is None else explicit
def __iter__(self):
    """Return an iterator of random indices into ``data_source``.

    With replacement, draws exactly ``self.num_samples`` indices uniformly
    at random (duplicates allowed); without replacement, yields every index
    exactly once in a random order.
    """
    n = len(self.data_source)
    if self.replacement:
        # BUG FIX: the replacement branch lost its body in the mangled
        # source, so replacement=True fell through to the permutation
        # (ignoring num_samples) and replacement=False returned None.
        # Sample with torch.randint so num_samples (possibly repeated)
        # indices are produced.
        return iter(torch.randint(high=n, size=(self.num_samples,),
                                  dtype=torch.int64).tolist())
    return iter(torch.randperm(n).tolist())
def __len__(self):
    """Report the effective sample count, not the raw dataset size."""
    return self.num_samples
class SubsetRandomSampler(Sampler):