Stashing checkpointing RNG states based on devices of arg tensors (#14518)
authorMichael Carilli <mcarilli@nvidia.com>
Tue, 11 Dec 2018 17:46:25 +0000 (09:46 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 11 Dec 2018 17:48:45 +0000 (09:48 -0800)
commit5d3a347685d826867973fd5f36d6f4c99f6b544b
treeef7a2bab75ba19a96cd7d2d5535c2d198943620c
parent25ddd659c9c4742fbaf2c77c155d359b6aae22fe
Stashing checkpointing RNG states based on devices of arg tensors (#14518)

Summary:
This PR intends to address apaszke's concerns in https://github.com/pytorch/pytorch/pull/14253#issuecomment-441740016.  Preserving the rng state is now controlled by a kwarg rather than a global state, hopefully in a python 2.7-compatible way.

Additionally, the checkpointing function stashes and restores the RNG states of
1. devices associated with all input tensor args to run_fn as well as
2. the current device.

I could easily change this to only save and restore the RNG states associated 1. alone.  This would simplify the logic to create a [deduplicated, ordered](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R37) list of devices considered active.

I'm wondering if the [get_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R32) and [set_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) functions are general enough to reside elsewhere (presumably torch/random.py).  I'm also wondering if the check on [torch.cuda._initialized](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) would be better placed within `get_device_states`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14518

Differential Revision: D13356210

Pulled By: ezyang

fbshipit-source-id: afa4cc21ce7862142d5cb1dec3750018df222039
docs/source/checkpoint.rst
torch/utils/checkpoint.py