Update cuda.get/set_rng_state doc (#14324)
authorSsnL <tongzhou.wang.1994@gmail.com>
Thu, 27 Dec 2018 22:06:23 +0000 (14:06 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 27 Dec 2018 22:09:25 +0000 (14:09 -0800)
Summary:
Now that `cuda.get/set_rng_state` accept `device` objects, the default value should be an device object, and doc should mention so.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14324

Reviewed By: ezyang

Differential Revision: D13528707

Pulled By: soumith

fbshipit-source-id: 32fdac467dfea6d5b96b7e2a42dc8cfd42ba11ee

docs/source/cuda.rst
torch/_torch_docs.py
torch/cuda/random.py

index b65c64f..6da20ce 100644 (file)
@@ -9,7 +9,9 @@ torch.cuda
 Random Number Generator
 -------------------------
 .. autofunction:: get_rng_state
+.. autofunction:: get_rng_state_all
 .. autofunction:: set_rng_state
+.. autofunction:: set_rng_state_all
 .. autofunction:: manual_seed
 .. autofunction:: manual_seed_all
 .. autofunction:: seed
index 47bb873..cd5c4a8 100644 (file)
@@ -5522,7 +5522,7 @@ full_like(input, fill_value, out=None, dtype=None, layout=torch.strided, device=
 
 Returns a tensor with the same size as :attr:`input` filled with :attr:`fill_value`.
 ``torch.full_like(input, fill_value)`` is equivalent to
-``torch.full_like(input.size(), fill_value, dtype=input.dtype, layout=input.layout, device=input.device)``.
+``torch.full(input.size(), fill_value, dtype=input.dtype, layout=input.layout, device=input.device)``.
 
 Args:
     {input}
index aece89c..c9a2a63 100644 (file)
@@ -1,14 +1,19 @@
-from torch import _C
+from torch import _C, device
 from . import _lazy_init, _lazy_call, device_count, device as device_ctx_manager
 
+__all__ = ['get_rng_state', 'get_rng_state_all',
+           'set_rng_state', 'set_rng_state_all',
+           'manual_seed', 'manual_seed_all',
+           'seed', 'seed_all', 'initial_seed']
 
-def get_rng_state(device=-1):
+
+def get_rng_state(device=device('cuda')):
     r"""Returns the random number generator state of the current
     GPU as a ByteTensor.
 
     Args:
-        device (int, optional): The device to return the RNG state of.
-            Default: -1 (i.e., use the current device).
+        device (torch.device or int, optional): The device to return the RNG state of.
+            Default: ``torch.device('cuda')`` (i.e., the current CUDA device).
 
     .. warning::
         This function eagerly initializes CUDA.
@@ -28,11 +33,13 @@ def get_rng_state_all():
     return results
 
 
-def set_rng_state(new_state, device=-1):
+def set_rng_state(new_state, device=device('cuda')):
     r"""Sets the random number generator state of the current GPU.
 
     Args:
         new_state (torch.ByteTensor): The desired state
+        device (torch.device or int, optional): The device to set the RNG state.
+            Default: ``torch.device('cuda')`` (i.e., the current CUDA device).
     """
     new_state_copy = new_state.clone()