From: MY_ Date: Mon, 16 Aug 2021 21:07:06 +0000 (-0700) Subject: A re-open PR: Avoid re-creating the random number generator in RandomSampler (#63026) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~990 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=dc5ce22a1a8bd075141f7310fa6cb503db4d6e08;p=platform%2Fupstream%2Fpytorch.git A re-open PR: Avoid re-creating the random number generator in RandomSampler (#63026) Summary: More details can be found in the old pr: https://github.com/pytorch/pytorch/pull/53085 ejguan Thanks for your guidance. I tried to reopen this PR following your instructions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/63026 Reviewed By: anjali411 Differential Revision: D30224920 Pulled By: ejguan fbshipit-source-id: 2fa83bd4a2661485e553447fe3e57ce723f2716d --- diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index 894f775..7903347 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -112,16 +112,15 @@ class RandomSampler(Sampler[int]): def __iter__(self) -> Iterator[int]: n = len(self.data_source) if self.generator is None: - generator = torch.Generator() - generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item())) - else: - generator = self.generator + self.generator = torch.Generator() + self.generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item())) + if self.replacement: for _ in range(self.num_samples // 32): - yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist() - yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist() + yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=self.generator).tolist() + yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=self.generator).tolist() else: - yield from torch.randperm(n, generator=generator).tolist() + yield from torch.randperm(n, generator=self.generator).tolist() def __len__(self) -> int: return self.num_samples