warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
-# Global switch to toggle whether or not checkpointed passes stash and restore
-# the RNG state. If True, any checkpoints making use of RNG should achieve deterministic
-# output compared to non-checkpointed passes.
-preserve_rng_state = True
+# We can't know if the run_fn will internally move some args to different devices,
+# which would require logic to preserve rng states for those devices as well.
+# We could paranoically stash and restore ALL the rng states for all visible devices,
+# but that seems very wasteful for most cases. Compromise: Stash the RNG state for
+# the device of all Tensor args.
+#
+# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
+def get_device_states(*args):
+ # This will not error out if "arg" is a CPU tensor or a non-tensor type because
+ # the conditionals short-circuit.
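+ # set() deduplicates: several args may live on the same GPU.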
+ fwd_gpu_devices = list(set(arg.get_device() for arg in args
+ if isinstance(arg, torch.Tensor) and arg.is_cuda))
+
+ fwd_gpu_states = []
+ for device in fwd_gpu_devices:
+ with torch.cuda.device(device):
+ fwd_gpu_states.append(torch.cuda.get_rng_state())
+
+ return fwd_gpu_devices, fwd_gpu_states
+
+
+def set_device_states(devices, states):
+ for device, state in zip(devices, states):
+ with torch.cuda.device(device):
+ torch.cuda.set_rng_state(state)
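+# Illustrative usage sketch (comment only, not part of the checkpoint logic):
+# assuming `x` is a CUDA tensor, capturing and later restoring its device's
+# RNG state makes a random op repeat exactly:
+#
+#   devices, states = get_device_states(x)
+#   first = torch.rand(4, device=x.device)   # advances that device's RNG
+#   set_device_states(devices, states)       # rewind to the captured state
+#   second = torch.rand(4, device=x.device)  # equal to `first`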
class CheckpointFunction(torch.autograd.Function):
@staticmethod
- def forward(ctx, run_function, *args):
+ def forward(ctx, run_function, preserve_rng_state, *args):
check_backward_validity(args)
ctx.run_function = run_function
+ ctx.preserve_rng_state = preserve_rng_state
if preserve_rng_state:
- # We can't know if the user will transfer some args from the host
- # to the device during their run_fn. Therefore, we stash both
- # the cpu and cuda rng states unconditionally.
- #
- # TODO:
- # We also can't know if the run_fn will internally move some args to a device
- # other than the current device, which would require logic to preserve
- # rng states for those devices as well. We could paranoically stash and restore
- # ALL the rng states for all visible devices, but that seems very wasteful for
- # most cases.
- ctx.fwd_cpu_rng_state = torch.get_rng_state()
+ ctx.fwd_cpu_state = torch.get_rng_state()
# Don't eagerly initialize the cuda context by accident.
# (If the user intends that the context is initialized later, within their
# run_function, we SHOULD actually stash the cuda state here. Unfortunately,
# we have no way to anticipate this will happen before we run the function.)
ctx.had_cuda_in_fwd = False
if torch.cuda._initialized:
ctx.had_cuda_in_fwd = True
- ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
+ ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
ctx.save_for_backward(*args)
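+ # Run the forward pass without building an autograd graph; the graph
+ # for this segment is rebuilt in backward by re-running run_function.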
with torch.no_grad():
outputs = run_function(*args)
# Stash the surrounding rng state, and mimic the state that was
# present at this time during forward. Restore the surrounding state
# when we're done.
- rng_devices = [torch.cuda.current_device()] if ctx.had_cuda_in_fwd else []
- with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
- if preserve_rng_state:
- torch.set_rng_state(ctx.fwd_cpu_rng_state)
+ rng_devices = []
+ if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
+ rng_devices = ctx.fwd_gpu_devices
+ with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
+ if ctx.preserve_rng_state:
+ torch.set_rng_state(ctx.fwd_cpu_state)
if ctx.had_cuda_in_fwd:
- torch.cuda.set_rng_state(ctx.fwd_cuda_rng_state)
+ set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
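+ # Detach the saved inputs so the recomputed segment has fresh leaf
+ # tensors whose .grad fields we can collect and return below.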
detached_inputs = detach_variable(inputs)
with torch.enable_grad():
outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
torch.autograd.backward(outputs, args)
- return (None,) + tuple(inp.grad for inp in detached_inputs)
+ return (None, None) + tuple(inp.grad for inp in detached_inputs)
-def checkpoint(function, *args):
+def checkpoint(function, *args, **kwargs):
r"""Checkpoint a model or part of the model
Checkpointing works by trading compute for memory. Rather than storing all
passed as the tuple. For example, in LSTM, if the user passes
``(activation, hidden)``, :attr:`function` should correctly use the
first input as ``activation`` and the second input as ``hidden``
+ preserve_rng_state(bool, optional, default=True): Stash and restore the RNG
+ state during each checkpoint, so checkpointed code that uses RNG (e.g.
+ dropout) produces the same output as a non-checkpointed run. Pass
+ ``False`` to skip this bookkeeping.
args: tuple containing inputs to the :attr:`function`
Returns:
Output of running :attr:`function` on :attr:`*args`
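+ Example (illustrative; ``run_segment`` stands in for any checkpointable callable)::
+ >>> out = checkpoint(run_segment, inp)
+ >>> # skip RNG stash/restore if run_segment uses no randomness:
+ >>> out = checkpoint(run_segment, inp, preserve_rng_state=False)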
"""
- return CheckpointFunction.apply(function, *args)
+ # Hack to mix *args with **kwargs in a Python 2.7-compliant way
+ preserve = kwargs.pop('preserve_rng_state', True)
+ if kwargs:
+ raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))
+
+ return CheckpointFunction.apply(function, preserve, *args)
-def checkpoint_sequential(functions, segments, *inputs):
+def checkpoint_sequential(functions, segments, *inputs, **kwargs):
r"""A helper function for checkpointing sequential models.
Sequential models execute a list of modules/functions in order
functions (comprising the model) to run sequentially.
segments: Number of chunks to create in the model
inputs: tuple of Tensors that are inputs to :attr:`functions`
+ preserve_rng_state(bool, optional, default=True): Stash and restore the RNG
+ state during each checkpoint, so checkpointed code that uses RNG (e.g.
+ dropout) produces the same output as a non-checkpointed run. Pass
+ ``False`` to skip this bookkeeping.
Returns:
Output of running :attr:`functions` sequentially on :attr:`*inputs`
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
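+ >>> # skip RNG stash/restore if the model uses no randomness (illustrative):
+ >>> input_var = checkpoint_sequential(model, chunks, input_var, preserve_rng_state=False)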
"""
+ # Hack to mix *args with **kwargs in a Python 2.7-compliant way
+ preserve = kwargs.pop('preserve_rng_state', True)
+ if kwargs:
+ raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))
def run_function(start, end, functions):
def forward(*inputs):
end = -1
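+ # Checkpoint every segment except the last; the final segment runs
+ # un-checkpointed via the return statement below.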
for start in range(0, segment_size * (segments - 1), segment_size):
end = start + segment_size - 1
- inputs = checkpoint(run_function(start, end, functions), *inputs)
+ inputs = checkpoint(run_function(start, end, functions), *inputs,
+ preserve_rng_state=preserve)
if not isinstance(inputs, tuple):
inputs = (inputs,)
return run_function(end + 1, len(functions) - 1, functions)(*inputs)