Make checkpoint_sequential work with multiple arguments (#14278)
author Andy Chen <andersenchen@fb.com>
Wed, 5 Dec 2018 02:45:45 +0000 (18:45 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 5 Dec 2018 02:47:43 +0000 (18:47 -0800)
commit 33ea7eafefb6a74d4d87b4e02f8d182640051ffc
tree bc774a0f61e3f986370a0a62b8166ff3f8836d70
parent 3237103624f776016c5445c0b957f0ea6e9a02bd
Make checkpoint_sequential work with multiple arguments (#14278)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14278

In this commit, we make checkpoint_sequential work for models with multiple tensor inputs. Previously, it only processed the first tensor and ignored the rest.
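
The core of the change can be sketched as follows (an illustrative sketch of the approach, not the literal diff): the per-segment runner unpacks a tuple of inputs and forwards every tensor to each layer, instead of propagating only the first one.

```python
# Illustrative sketch of the segment-runner idea (not the exact code in
# torch/utils/checkpoint.py): a tuple of inputs is unpacked and threaded
# through every layer in the segment, so no input tensor is dropped.
def run_function(start, end, functions):
    def forward(*inputs):
        for j in range(start, end + 1):
            if isinstance(inputs, tuple):
                # Multiple tensors in flight: forward all of them to the layer.
                inputs = functions[j](*inputs)
            else:
                # A layer returned a single tensor; pass it through directly.
                inputs = functions[j](inputs)
        return inputs
    return forward
```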

We add a new test in test/test_utils.py that reproduces the failure reported in [this GitHub issue](https://github.com/pytorch/pytorch/issues/11093), and we change the behavior of checkpoint_sequential to process all input tensors so that the test passes.
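
For context, here is a minimal usage sketch of the multi-input case the new test exercises, assuming the post-fix signature `checkpoint_sequential(functions, segments, *inputs)`. The layer class, sizes, and segment count below are made up for illustration, and later PyTorch releases deprecated passing multiple inputs this way.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

class TwoInputLayer(nn.Module):
    """Hypothetical toy layer that takes two tensors and returns two tensors."""
    def __init__(self, dim):
        super().__init__()
        self.linear_a = nn.Linear(dim, dim)
        self.linear_b = nn.Linear(dim, dim)

    def forward(self, a, b):
        return self.linear_a(a) + b, self.linear_b(b)

model = nn.Sequential(*[TwoInputLayer(8) for _ in range(4)])
a = torch.randn(2, 8, requires_grad=True)
b = torch.randn(2, 8, requires_grad=True)

# With this fix, both tensors are threaded through every checkpointed segment;
# previously only the first input reached the layers.
out_a, out_b = checkpoint_sequential(model, 2, a, b)
(out_a.sum() + out_b.sum()).backward()
```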

Reviewed By: ezyang

Differential Revision: D13144672

fbshipit-source-id: 24f58233a65a0f5b80b89c8d8cbced6f814004f7
test/test_utils.py
torch/utils/checkpoint.py