Recursively find tensors in DDP module output (#19360)
author: Pieter Noordhuis <pietern@fb.com>
        Thu, 18 Apr 2019 21:51:37 +0000 (14:51 -0700)
committer: Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
           Thu, 18 Apr 2019 21:57:09 +0000 (14:57 -0700)
commit: a5c4348d54cc7736e5574ac56d890695197339ed
tree: dbc70111fa462387f0f314050d96c6acceff8b06
parent: 17f05ad5e562830127dd06be6c11c10453a86d0e

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19360

DDP returns the wrapped module's output object verbatim, since it can be
an arbitrary (freeform) object. We still need to find every tensor inside
it, though, because those tensors are the starting point for figuring out
which parameters were used during this forward pass. Knowing that lets us
short-circuit reduction for any parameters that were not used.
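To illustrate why the output tensors matter, here is a minimal sketch that treats unused-parameter detection as a graph reachability problem. It assumes a simplified, hypothetical `Node` class in place of PyTorch's autograd graph and a hypothetical `find_unused_parameters` helper; it is not the reducer's actual C++ code. Starting from the nodes that produced the output tensors, it walks backwards and records every parameter it reaches; any parameter left over cannot receive a gradient in this iteration.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Set


@dataclass(eq=False)  # eq=False keeps identity-based hashing, so nodes can go in a set
class Node:
    """Hypothetical stand-in for an autograd graph node (not a PyTorch class)."""
    name: str
    next_nodes: List["Node"] = field(default_factory=list)
    # Set only on leaf nodes that accumulate a gradient into a parameter.
    param: Optional[str] = None


def find_unused_parameters(output_nodes: List[Node], params: Set[str]) -> Set[str]:
    """Return the parameters that are not reachable from any output node."""
    seen: Set[Node] = set()
    used: Set[str] = set()
    stack = list(output_nodes)
    while stack:
        node = stack.pop()
        if node in seen:
            continue  # shared subgraphs are visited only once
        seen.add(node)
        if node.param is not None:
            used.add(node.param)
        stack.extend(node.next_nodes)
    # Everything not reached gets no gradient and can be marked ready right away.
    return params - used


# Usage: w3 exists on the module but does not contribute to the output.
w1 = Node("acc_w1", param="w1")
w2 = Node("acc_w2", param="w2")
w3 = Node("acc_w3", param="w3")
mul = Node("mul", [w1])
add = Node("add", [mul, w2])
print(find_unused_parameters([add], {"w1", "w2", "w3"}))  # {'w3'}
```

If the output tensors were not all found (for example, a tensor hidden inside a dict), the walk would miss part of the graph and wrongly report used parameters as unused.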

Before this commit, only lists were handled, and the functionality was
untested. This commit adds support for dicts and arbitrarily nested
structures, and adds a test case.
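A recursive search over lists, tuples, and dict values can be sketched as follows. This is an illustrative sketch, not necessarily the code in this commit: the `find_tensors` name and its `tensor_type` parameter are assumptions, and a hypothetical `FakeTensor` class stands in for `torch.Tensor` so the example runs without PyTorch installed.

```python
from typing import Any, Iterator


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return f"FakeTensor({self.name!r})"


def find_tensors(obj: Any, tensor_type: type = FakeTensor) -> Iterator[Any]:
    """Yield every tensor in obj, descending into lists, tuples, and dict values."""
    if isinstance(obj, tensor_type):
        yield obj
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            yield from find_tensors(item, tensor_type)
    elif isinstance(obj, dict):
        # Only values are searched; dict keys are not expected to hold outputs.
        for value in obj.values():
            yield from find_tensors(value, tensor_type)
    # Any other object (ints, strings, None, ...) contains no tensors.


# Usage: a nested, freeform module output.
a, b, c = FakeTensor("a"), FakeTensor("b"), FakeTensor("c")
output = {"logits": a, "aux": [b, ("ignored", {"deep": c})], "step": 3}
print(list(find_tensors(output)))
# [FakeTensor('a'), FakeTensor('b'), FakeTensor('c')]
```

With PyTorch available, passing `tensor_type=torch.Tensor` would apply the same traversal to real outputs. Strings are deliberately not treated as sequences, so a label such as `"ignored"` is skipped instead of being iterated character by character.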

Closes #19354.

Reviewed By: mrshenli

Differential Revision: D14978016

fbshipit-source-id: 4bb6999520871fb6a9e4561608afa64d55f4f3a8
Files changed:
test/test_c10d.py
torch/csrc/distributed/c10d/reducer.cpp
torch/nn/parallel/distributed.py