Stashing checkpointing RNG states based on devices of arg tensors (#14518)
authorMichael Carilli <mcarilli@nvidia.com>
Tue, 11 Dec 2018 17:46:25 +0000 (09:46 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 11 Dec 2018 17:48:45 +0000 (09:48 -0800)
Summary:
This PR intends to address apaszke's concerns in https://github.com/pytorch/pytorch/pull/14253#issuecomment-441740016.  Preserving the rng state is now controlled by a kwarg rather than a global state, hopefully in a python 2.7-compatible way.

Additionally, the checkpointing function stashes and restores the RNG states of
1. devices associated with all input tensor args to run_fn as well as
2. the current device.

I could easily change this to only save and restore the RNG states associated 1. alone.  This would simplify the logic to create a [deduplicated, ordered](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R37) list of devices considered active.

I'm wondering if the [get_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R32) and [set_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) functions are general enough to reside elsewhere (presumably torch/random.py).  I'm also wondering if the check on [torch.cuda._initialized](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) would be better placed within `get_device_states`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14518

Differential Revision: D13356210

Pulled By: ezyang

fbshipit-source-id: afa4cc21ce7862142d5cb1dec3750018df222039

docs/source/checkpoint.rst
torch/utils/checkpoint.py

index 4fd4f02..3affd71 100644 (file)
@@ -11,10 +11,18 @@ torch.utils.checkpoint
     compared to non-checkpointed passes.  The logic to stash and restore
     RNG states can incur a moderate performance hit depending on the runtime
     of checkpointed operations.  If deterministic output compared to
-    non-checkpointed passes is not required, set the global flag
-    ``torch.utils.checkpoint.preserve_rng_state=False`` to omit stashing and
+    non-checkpointed passes is not required, supply ``preserve_rng_state=False``
+    to ``checkpoint`` or ``checkpoint_sequential`` to omit stashing and
     restoring the RNG state during each checkpoint.
 
+    The stashing logic saves and restores the RNG state for the current device
+    and the device of all cuda Tensor arguments to the ``run_fn``.
+    However, the logic has no way to anticipate if the user will move
+    Tensors to a new device within the ``run_fn`` itself.  Therefore, if you move
+    Tensors to a new device ("new" meaning not belonging to the set of
+    [current device + devices of Tensor arguments]) within ``run_fn``, deterministic
+    output compared to non-checkpointed passes is never guaranteed.
+
 .. currentmodule:: torch.utils.checkpoint
 .. autofunction:: checkpoint
 .. autofunction:: checkpoint_sequential
index 557eabd..18df734 100644 (file)
@@ -21,30 +21,42 @@ def check_backward_validity(inputs):
         warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
 
 
-# Global switch to toggle whether or not checkpointed passes stash and restore
-# the RNG state.  If True, any checkpoints making use of RNG should achieve deterministic
-# output compared to non-checkpointed passes.
-preserve_rng_state = True
+# We can't know if the run_fn will internally move some args to different devices,
+# which would require logic to preserve rng states for those devices as well.
+# We could paranoically stash and restore ALL the rng states for all visible devices,
+# but that seems very wasteful for most cases.  Compromise:  Stash the RNG state for
+# the device of all Tensor args.
+#
+# To consider:  maybe get_device_states and set_device_states should reside in torch/random.py?
+def get_device_states(*args):
+    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
+    # the conditionals short-circuit.
+    fwd_gpu_devices = list(set(arg.get_device() for arg in args
+                               if isinstance(arg, torch.Tensor) and arg.is_cuda))
+
+    fwd_gpu_states = []
+    for device in fwd_gpu_devices:
+        with torch.cuda.device(device):
+            fwd_gpu_states.append(torch.cuda.get_rng_state())
+
+    return fwd_gpu_devices, fwd_gpu_states
+
+
+def set_device_states(devices, states):
+    for device, state in zip(devices, states):
+        with torch.cuda.device(device):
+            torch.cuda.set_rng_state(state)
 
 
 class CheckpointFunction(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, preserve_rng_state, *args):
         check_backward_validity(args)
         ctx.run_function = run_function
+        ctx.preserve_rng_state = preserve_rng_state
         if preserve_rng_state:
-            # We can't know if the user will transfer some args from the host
-            # to the device during their run_fn.  Therefore, we stash both
-            # the cpu and cuda rng states unconditionally.
-            #
-            # TODO:
-            # We also can't know if the run_fn will internally move some args to a device
-            # other than the current device, which would require logic to preserve
-            # rng states for those devices as well.  We could paranoically stash and restore
-            # ALL the rng states for all visible devices, but that seems very wasteful for
-            # most cases.
-            ctx.fwd_cpu_rng_state = torch.get_rng_state()
+            ctx.fwd_cpu_state = torch.get_rng_state()
             # Don't eagerly initialize the cuda context by accident.
             # (If the user intends that the context is initialized later, within their
             # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
@@ -52,7 +64,7 @@ class CheckpointFunction(torch.autograd.Function):
             ctx.had_cuda_in_fwd = False
             if torch.cuda._initialized:
                 ctx.had_cuda_in_fwd = True
-                ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
+                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
         ctx.save_for_backward(*args)
         with torch.no_grad():
             outputs = run_function(*args)
@@ -66,12 +78,14 @@ class CheckpointFunction(torch.autograd.Function):
         # Stash the surrounding rng state, and mimic the state that was
         # present at this time during forward.  Restore the surrouding state
         # when we're done.
-        rng_devices = [torch.cuda.current_device()] if ctx.had_cuda_in_fwd else []
-        with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
-            if preserve_rng_state:
-                torch.set_rng_state(ctx.fwd_cpu_rng_state)
+        rng_devices = []
+        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
+            rng_devices = ctx.fwd_gpu_devices
+        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
+            if ctx.preserve_rng_state:
+                torch.set_rng_state(ctx.fwd_cpu_state)
                 if ctx.had_cuda_in_fwd:
-                    torch.cuda.set_rng_state(ctx.fwd_cuda_rng_state)
+                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
             detached_inputs = detach_variable(inputs)
             with torch.enable_grad():
                 outputs = ctx.run_function(*detached_inputs)
@@ -79,10 +93,10 @@ class CheckpointFunction(torch.autograd.Function):
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs,)
         torch.autograd.backward(outputs, args)
-        return (None,) + tuple(inp.grad for inp in detached_inputs)
+        return (None, None) + tuple(inp.grad for inp in detached_inputs)
 
 
-def checkpoint(function, *args):
+def checkpoint(function, *args, **kwargs):
     r"""Checkpoint a model or part of the model
 
     Checkpointing works by trading compute for memory. Rather than storing all
@@ -120,15 +134,22 @@ def checkpoint(function, *args):
             passed as the tuple. For example, in LSTM, if user passes
             ``(activation, hidden)``, :attr:`function` should correctly use the
             first input as ``activation`` and the second input as ``hidden``
+        preserve_rng_state(bool, optional, default=True):  Omit stashing and restoring
+            the RNG state during each checkpoint.
         args: tuple containing inputs to the :attr:`function`
 
     Returns:
         Output of running :attr:`function` on :attr:`*args`
     """
-    return CheckpointFunction.apply(function, *args)
+    # Hack to mix *args with **kwargs in a python 2.7-compliant way
+    preserve = kwargs.pop('preserve_rng_state', True)
+    if kwargs:
+        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
+
+    return CheckpointFunction.apply(function, preserve, *args)
 
 
-def checkpoint_sequential(functions, segments, *inputs):
+def checkpoint_sequential(functions, segments, *inputs, **kwargs):
     r"""A helper function for checkpointing sequential models.
 
     Sequential models execute a list of modules/functions in order
@@ -154,6 +175,8 @@ def checkpoint_sequential(functions, segments, *inputs):
             functions (comprising the model) to run sequentially.
         segments: Number of chunks to create in the model
         inputs: tuple of Tensors that are inputs to :attr:`functions`
+        preserve_rng_state(bool, optional, default=True):  Omit stashing and restoring
+            the RNG state during each checkpoint.
 
     Returns:
         Output of running :attr:`functions` sequentially on :attr:`*inputs`
@@ -162,6 +185,10 @@ def checkpoint_sequential(functions, segments, *inputs):
         >>> model = nn.Sequential(...)
         >>> input_var = checkpoint_sequential(model, chunks, input_var)
     """
+    # Hack to mix *args with **kwargs in a python 2.7-compliant way
+    preserve = kwargs.pop('preserve_rng_state', True)
+    if kwargs:
+        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
 
     def run_function(start, end, functions):
         def forward(*inputs):
@@ -181,7 +208,8 @@ def checkpoint_sequential(functions, segments, *inputs):
     end = -1
     for start in range(0, segment_size * (segments - 1), segment_size):
         end = start + segment_size - 1
-        inputs = checkpoint(run_function(start, end, functions), *inputs)
+        inputs = checkpoint(run_function(start, end, functions), *inputs,
+                            preserve_rng_state=preserve)
         if not isinstance(inputs, tuple):
             inputs = (inputs,)
     return run_function(end + 1, len(functions) - 1, functions)(*inputs)