[Pipe] Add a `WithDevice` wrapper to specify device execution for a module. (#65190)
authorPritam Damania <pritam.damania@fb.com>
Mon, 20 Sep 2021 17:39:08 +0000 (10:39 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 20 Sep 2021 18:27:27 +0000 (11:27 -0700)
commit3e64c9e17665fc1b0ce32fef85310304c60dc5bc
tree7601c92e2fc16a7194ac93cf86eedda2a23c3253
parent0a3cf8886a68e9a965f669823e640cd95ca2692e
[Pipe] Add a `WithDevice` wrapper to specify device execution for a module. (#65190)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65190

As described in https://github.com/pytorch/pytorch/issues/65093, there
could be modules which don't have any parameters/buffers. In this case, Pipe
determines that the module should be executed on CPU. However this might result
in unnecessary GPU to CPU transfers whereas the user expected the module to be
executed on the GPU itself by keeping its inputs and outputs on GPU.

For this use case, we introduce a `WithDevice` wrapper which can be used to
override which device a particular module should be executed on as part of the
pipeline.

#Closes: https://github.com/pytorch/pytorch/issues/65093
ghstack-source-id: 138376272

Test Plan:
1) waitforbuildbot
2) unit tests

Reviewed By: SciPioneer

Differential Revision: D31010027

fbshipit-source-id: 4c1c61d3c6feeef341e002e5f7e83dd33ff3a516
test/distributed/pipeline/sync/test_pipe.py
torch/distributed/pipeline/sync/__init__.py
torch/distributed/pipeline/sync/pipe.py