Adds return type annotation for fork_rng function (#63724)
authorAswin Murali <aswinmurali.co@gmail.com>
Fri, 27 Aug 2021 16:02:22 +0000 (09:02 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 16:03:40 +0000 (09:03 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/63723

Since it's a generator function the type annotation shall be `Generator`.
![image](https://user-images.githubusercontent.com/47299190/130318830-29ef9529-0daa-463c-90b2-1b11f63ade8a.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63724

Reviewed By: iramazanli

Differential Revision: D30543098

Pulled By: heitorschueroff

fbshipit-source-id: ebdd34749defe1e26c899146786a0357ab4b4b9b

torch/random.py

index d774634..f5156bf 100644 (file)
@@ -1,4 +1,5 @@
 import contextlib
+from typing import Generator
 import warnings
 
 from torch._C import default_generator
@@ -65,7 +66,7 @@ _fork_rng_warned_already = False
 
 
 @contextlib.contextmanager
-def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices"):
+def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices") -> Generator:
     """
     Forks the RNG, so that when you return, the RNG is reset
     to the state that it was previously in.