From dc5ce22a1a8bd075141f7310fa6cb503db4d6e08 Mon Sep 17 00:00:00 2001
From: MY_
Date: Mon, 16 Aug 2021 14:07:06 -0700
Subject: [PATCH] A re-open PR: Avoid re-creating the random number generator in RandomSampler (#63026)

Summary:
More details can be found in the old PR: https://github.com/pytorch/pytorch/pull/53085

ejguan Thanks for your guidance. I tried to reopen this PR following your instructions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63026

Reviewed By: anjali411

Differential Revision: D30224920

Pulled By: ejguan

fbshipit-source-id: 2fa83bd4a2661485e553447fe3e57ce723f2716d
---
 torch/utils/data/sampler.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py
index 894f775..7903347 100644
--- a/torch/utils/data/sampler.py
+++ b/torch/utils/data/sampler.py
@@ -112,16 +112,15 @@ class RandomSampler(Sampler[int]):
     def __iter__(self) -> Iterator[int]:
         n = len(self.data_source)
         if self.generator is None:
-            generator = torch.Generator()
-            generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
-        else:
-            generator = self.generator
+            self.generator = torch.Generator()
+            self.generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
+
         if self.replacement:
             for _ in range(self.num_samples // 32):
-                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
-            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
+                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=self.generator).tolist()
+            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=self.generator).tolist()
         else:
-            yield from torch.randperm(n, generator=generator).tolist()
+            yield from torch.randperm(n, generator=self.generator).tolist()
 
     def __len__(self) -> int:
         return self.num_samples
-- 
2.7.4