A re-open PR: Avoid re-creating the random number generator in RandomSampler (#63026)
author MY_ <lartpang@163.com>
Mon, 16 Aug 2021 21:07:06 +0000 (14:07 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 16 Aug 2021 21:08:37 +0000 (14:08 -0700)
Summary:
More details can be found in the original PR: https://github.com/pytorch/pytorch/pull/53085

ejguan, thanks for your guidance. I have reopened this PR following your instructions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63026

Reviewed By: anjali411

Differential Revision: D30224920

Pulled By: ejguan

fbshipit-source-id: 2fa83bd4a2661485e553447fe3e57ce723f2716d

torch/utils/data/sampler.py

index 894f775..7903347 100644
@@ -112,16 +112,15 @@ class RandomSampler(Sampler[int]):
     def __iter__(self) -> Iterator[int]:
         n = len(self.data_source)
         if self.generator is None:
-            generator = torch.Generator()
-            generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
-        else:
-            generator = self.generator
+            self.generator = torch.Generator()
+            self.generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
+
         if self.replacement:
             for _ in range(self.num_samples // 32):
-                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
-            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
+                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=self.generator).tolist()
+            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=self.generator).tolist()
         else:
-            yield from torch.randperm(n, generator=generator).tolist()
+            yield from torch.randperm(n, generator=self.generator).tolist()
 
     def __len__(self) -> int:
         return self.num_samples
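
The pattern in the hunk above can be illustrated with a minimal sketch. This is not the PyTorch code: `LazyGeneratorSampler` is a hypothetical stand-in, it covers only the no-replacement branch, and the standard-library `random.Random` is swapped in for `torch.Generator` so the example runs without PyTorch. It shows only the behavior the diff introduces: the generator is created on the first iteration, stored on the instance, and reused on later iterations instead of being re-created each time.

```python
import random
from typing import Iterator, Optional, Sequence


class LazyGeneratorSampler:
    """Hypothetical stand-in for RandomSampler (no-replacement branch only)."""

    def __init__(self, data_source: Sequence,
                 generator: Optional[random.Random] = None) -> None:
        self.data_source = data_source
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            # Created once, on first iteration, then kept on the instance
            # (mirrors `self.generator = torch.Generator()` in the diff).
            self.generator = random.Random(random.getrandbits(64))
        indices = list(range(n))
        # Stand-in for torch.randperm(n, generator=self.generator).
        self.generator.shuffle(indices)
        yield from indices

    def __len__(self) -> int:
        return len(self.data_source)


sampler = LazyGeneratorSampler(list(range(10)))
assert sampler.generator is None           # nothing is created before iteration

first_epoch = list(sampler)
generator_after_first = sampler.generator  # now populated

second_epoch = list(sampler)
assert sampler.generator is generator_after_first  # reused, not re-created
assert sorted(first_epoch) == sorted(second_epoch) == list(range(10))
```

One consequence visible in the diff itself: iterating now mutates the sampler, because `self.generator` is populated on first use and stays set for every later iteration.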