Fixed DistributedDataParallel failing to kick off all-reduce in a corner case (#14675)
Summary:
This corner case was hit by the translation team. It happens only when all of the following hold:
(1) the module has a registered parameter that does not require grad,
and
(2) this parameter has a dtype (say, double or half) that no other parameter shares, so it is placed alone in its own bucket,
and
(3) it is the last parameter registered in the module, so its bucket's reduction is the first to be kicked off.
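To make conditions (2) and (3) concrete, here is a minimal sketch of how such a parameter ends up alone in the first bucket. It assumes a simplified model, not DDP's real bucketing code: one bucket per dtype, no size cap, and buckets ordered by walking parameters in reverse registration order. `Param` and `bucket_by_dtype` are hypothetical names for illustration only.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Param:
    name: str
    dtype: str           # e.g. "float" or "double"
    requires_grad: bool  # ignored by the pre-fix bucketing below


def bucket_by_dtype(params: List[Param]) -> List[List[str]]:
    """Group parameter names into one bucket per dtype, in kick-off order.

    Hypothetical simplified model: walking parameters in reverse registration
    order means the last-registered parameter's bucket comes first.
    Like the pre-fix behavior described above, requires_grad is not checked.
    """
    buckets: Dict[str, List[str]] = {}  # dicts preserve insertion order
    for p in reversed(params):
        buckets.setdefault(p.dtype, []).append(p.name)
    return list(buckets.values())


params = [
    Param("w1", "float", True),
    Param("w2", "float", True),
    Param("scale", "double", False),  # unique dtype, no grad, registered last
]
print(bucket_by_dtype(params))  # [['scale'], ['w2', 'w1']]
```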
When this happens, no gradient is ever computed for that parameter because it does not require grad, so its backward hook never fires and its bucket never becomes ready. Since buckets are kicked off in order and every other bucket waits for that first bucket, no all-reduce is kicked off at all.
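The stall can be reproduced with a small simulation. This is a sketch under assumed, simplified rules rather than DDP's actual reducer: buckets must be launched strictly in order, a bucket is ready only once the hook has fired for every parameter in it, and hooks fire only for parameters that require grad. `simulate_backward` is a hypothetical helper.

```python
from typing import Dict, List, Set


def simulate_backward(buckets: List[List[str]],
                      requires_grad: Dict[str, bool]) -> List[int]:
    """Return the indices of buckets whose all-reduce gets launched."""
    ready: Set[str] = set()
    launched: List[int] = []
    next_bucket = 0
    for bucket in buckets:
        for name in bucket:
            # Autograd only runs the hook for parameters that require grad.
            if not requires_grad[name]:
                continue
            ready.add(name)
            # Launch as many buckets as possible, but strictly in order:
            # a bucket that never becomes ready blocks all later ones.
            while next_bucket < len(buckets) and all(
                p in ready for p in buckets[next_bucket]
            ):
                launched.append(next_bucket)
                next_bucket += 1
    return launched


buckets = [["scale"], ["w2", "w1"]]
print(simulate_backward(buckets, {"scale": False, "w2": True, "w1": True}))  # []
print(simulate_backward(buckets, {"scale": True, "w2": True, "w1": True}))   # [0, 1]
```

In this model a single never-ready bucket at the head of the queue is enough to block every all-reduce, which matches the hang described above.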
This PR makes the following changes:
(1) Only bucket parameters that have requires_grad=True.
(2) Check at the start of the next iteration whether all of the previous iteration's all-reduces were kicked off; if they were not, raise an error.
(3) Remove some unused variables.
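Fixes (1) and (2) can be sketched as follows. This is an illustration under stated assumptions, not the code in this PR: `Param`, `params_to_bucket`, and `AllReduceTracker` are hypothetical stand-ins for DDP's internal bookkeeping, and the error text is abbreviated from the message shown in the traceback below.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Param:
    name: str
    requires_grad: bool


def params_to_bucket(params: List[Param]) -> List[Param]:
    # Fix (1): only parameters that require grad ever fire a backward hook,
    # so only they are eligible for bucketing.
    return [p for p in params if p.requires_grad]


class AllReduceTracker:
    """Hypothetical bookkeeping for fix (2), for illustration only."""

    def __init__(self, num_buckets: int) -> None:
        self.num_buckets = num_buckets
        self.launched = num_buckets  # nothing outstanding before iteration 1

    def on_bucket_launched(self) -> None:
        # Called from backward each time a bucket's all-reduce is kicked off.
        self.launched += 1

    def forward(self) -> None:
        # Fix (2): at the next iteration, verify that the previous backward
        # kicked off every bucket's all-reduce.
        if self.launched != self.num_buckets:
            raise RuntimeError("Not all gradients are all-reduced from "
                               "the backward of the previous iteration.")
        self.launched = 0  # this iteration's backward must launch all again


tracker = AllReduceTracker(num_buckets=2)
tracker.forward()
tracker.on_bucket_launched()  # only 1 of 2 buckets launched in backward
try:
    tracker.forward()
except RuntimeError as e:
    print("caught:", e)
```

Raising on the next forward, rather than hanging inside backward, turns a silent stall into an explicit error that points at the likely cause.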
With this bug fixed, the only remaining way to hit this error is for the user to change the module's parameters after wrapping it with DDP, as in:
https://github.com/pytorch/pytorch/issues/12603
Tests are included as well.
Without the first fix, I verified that the repro in fbcode hits this error message:
```
result = self.forward(*input, **kwargs)
File "/data/users/tengli/fbsource/fbcode/buck-out/dev/gen/language_technology/neural_mt/os/pytorch_translate/train#link-tree/torch/nn/parallel/distributed.py", line 312, in forward
raise RuntimeError("Not all gradients are all-reduced from "
RuntimeError: Not all gradients are all-reduced from the backward of the previous iteration. This is unexpected and fatal error. Please check and ensure that the model's parameters are not changed after you wrap up the model with DistributedDataParallel.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14675
Differential Revision: D13291083
Pulled By: teng-li
fbshipit-source-id: 2539b699fae843f104b4b8d22721ae82502ba684