def __iter__(self) -> Iterator[int]:
n = len(self.data_source)
if self.generator is None:
- generator = torch.Generator()
- generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
- else:
- generator = self.generator
+ self.generator = torch.Generator()
+ self.generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
+
if self.replacement:
for _ in range(self.num_samples // 32):
- yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
- yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
+ yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=self.generator).tolist()
+ yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=self.generator).tolist()
else:
- yield from torch.randperm(n, generator=generator).tolist()
+ yield from torch.randperm(n, generator=self.generator).tolist()
def __len__(self) -> int:
return self.num_samples