Fixed DistributedDataParallel cannot kick off all-reduce in a corner case (#14675)
authorTeng Li <tengli@fb.com>
Mon, 3 Dec 2018 01:10:48 +0000 (17:10 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 3 Dec 2018 01:13:07 +0000 (17:13 -0800)
Summary:
Ok, this corner happens for translation guys, and it only happens in the following corner case:

(1) when the module is registered a parameter that does not requires grad

and

(2) this registered parameter has a unique type (say, double, or half) and it's the only unique type such that itself alone will be put into a separate bucket.

and

(3) it is the last parameter that got registered in the module, such that its bucket reduction is the first to be kicked off.

Once this corner case happens, since it does not require grad, the backward hook won't be kicked off. Now that all other buckets are waiting for its bucket to be kicked off, in this case, no bucket will be kicked off since it's blocked by the first bucket (the unique type parameter).

This PR fixes two things:
(1) Make sure that we will only bucket parameters that requires_grad
(2) Make all-reduction checks in the next iteration. As long as we detect the previous iteration's all-reduction has not been fully kicked off, we will issue an error in the next iteration.
(3) Also removed some unused variables

With this bug fixed, the only case when this error can happen is when the user changed parameters later after wrapping up the module with DDP, like the case in:
https://github.com/pytorch/pytorch/issues/12603

Test covered as well

Without the first fix, I varied that the repro in fbcode hit this error message:

```
result = self.forward(*input, **kwargs)
  File "/data/users/tengli/fbsource/fbcode/buck-out/dev/gen/language_technology/neural_mt/os/pytorch_translate/train#link-tree/torch/nn/parallel/distributed.py", line 312, in forward
    raise RuntimeError("Not all gradients are all-reduced from "
RuntimeError: Not all gradients are all-reduced from the backward of the previous iteration. This is unexpected and fatal error. Please check and ensure that the model's parameters are not changed after you wrap up the model with DistributedDataParallel.

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14675

Differential Revision: D13291083

Pulled By: teng-li

fbshipit-source-id: 2539b699fae843f104b4b8d22721ae82502ba684

test/test_distributed.py
torch/nn/parallel/distributed.py

index e7b63c6..072c2bc 100644 (file)
@@ -52,6 +52,8 @@ class Net(nn.Module):
         self.fc2 = _FC2()
         self.fc3 = nn.Linear(50, 4, bias=False)
         self.relu = nn.ReLU()
+        self.no_grad_param = nn.Parameter(torch.Tensor([2, 2]).long(),
+                                          requires_grad=False)
 
     def forward(self, x):
         x = self.relu(self.fc1(x))
@@ -1227,10 +1229,10 @@ class _DistTestBase(object):
         for p_gpu, p_DDP in zip(param_gpu, param_DDP):
             self.assertEqual(p_gpu, p_DDP)
 
-    def _test_DDP_2iter(
+    def _test_DDP_5iter(
         self, model_base, model_DDP, input, target, loss, local_bs, rank, batch_size
     ):
-        for _ in range(2):
+        for _ in range(5):
             # single cpu/gpu training
             self._test_DDP_helper(model_base, input, target, loss)
 
@@ -1292,8 +1294,8 @@ class _DistTestBase(object):
         local_bs = len(gpu_subset)
         global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)
 
-        # check two model parameters over 2 iterations
-        self._test_DDP_2iter(
+        # check two model parameters over 5 iterations
+        self._test_DDP_5iter(
             model_gpu,
             model_DDP,
             input_cpu.cuda(gpu_subset[0]),
@@ -1324,8 +1326,8 @@ class _DistTestBase(object):
         local_bs = 2
         global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)
 
-        # check two model parameters over 2 iterations
-        self._test_DDP_2iter(
+        # check two model parameters over 5 iterations
+        self._test_DDP_5iter(
             model_base, model_DDP, input_cpu, target, loss, local_bs, rank, global_bs
         )
         self._barrier()
index 9746d9f..5120238 100644 (file)
@@ -1,8 +1,6 @@
 import copy
 
 import torch
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors, \
-    _take_tensors
 
 from torch.cuda.comm import broadcast_coalesced
 from torch.cuda import nccl
@@ -156,6 +154,15 @@ class DistributedDataParallel(Module):
                        bucket can potentially overlap with backward computation.
                        bucket_cap_mb controls the bucket size in MegaBytes (MB)
                        (default: 25)
+        check_reduction: when setting to True, it enables DistributedDataParallel
+                         to automatically check if the previous iteration's
+                         backward reductions were successfully issued at the
+                         beginning of every iteration's forward function.
+                         You normally don't need this option enabled unless you
+                         are observing weird behaviors such as different ranks
+                         are getting different gradients, which should not
+                         happen if DistributedDataParallel is corrected used.
+                         (default: False)
 
     Attributes:
         module (Module): the module to be parallelized
@@ -166,7 +173,8 @@ class DistributedDataParallel(Module):
     """
     def __init__(self, module, device_ids=None,
                  output_device=None, dim=0, broadcast_buffers=True,
-                 process_group=None, bucket_cap_mb=25):
+                 process_group=None, bucket_cap_mb=25,
+                 check_reduction=False):
 
         super(DistributedDataParallel, self).__init__()
 
@@ -187,6 +195,7 @@ class DistributedDataParallel(Module):
         self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
         self.output_device = _get_device_index(output_device, True)
         self.broadcast_buffers = broadcast_buffers
+        self.check_reduction = check_reduction
 
         MB = 1024 * 1024
 
@@ -224,11 +233,22 @@ class DistributedDataParallel(Module):
 
         # This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems
         param_buckets = []
+
         # Split the parameters into buckets and by types as well
-        param_buckets = [dist._dist_bucket_tensors(list(m.parameters()),
+        # We only need to bucket and reduce parameters that require grad and
+        # this is also true for backward since only the backward hooks for
+        # parameters that require grad will be registered with gradient
+        # reduction functions
+        params_to_bucket = [[] for _ in self._module_copies]
+        for dev_idx, m in enumerate(self._module_copies):
+            for p in m.parameters():
+                if p.requires_grad:
+                    params_to_bucket[dev_idx].append(p)
+
+        param_buckets = [dist._dist_bucket_tensors(dev_params_to_bucket,
                                                    int(bucket_bytes_cap),
                                                    fine_grained=False)
-                         for m in self._module_copies]
+                         for dev_params_to_bucket in params_to_bucket]
 
         self.bucket_sizes = []
         self.bucket_map = {}
@@ -256,6 +276,11 @@ class DistributedDataParallel(Module):
         # We will always reduce the bucket following the reverse order
         # that is, alway reduces following the order of: n - 1, n - 2, ..., 0
         self.next_bucket = len(self.bucket_sizes) - 1
+        # When all buckets are reduced, this will be set to True. This flag is
+        # useful for sanity checks to ensure that each iteration's backward has
+        # always reduced all buckets
+        self.all_buckets_reduced = False
+        self.check_previous_reduction = False
         self.ready_buckets_not_reduced = set()
         self.reduction_works = [None for _ in range(len(self.bucket_sizes))]
         self.devs_ready = [0 for _ in range(len(self.bucket_sizes))]
@@ -272,6 +297,7 @@ class DistributedDataParallel(Module):
     def __setstate__(self, state):
         # If serializable, then the process group should be the default one
         self.process_group = dist.get_default_group()
+        self.check_previous_reduction = False
         super(DistributedDataParallel, self).__setstate__(state)
         self._register_grad_hooks()
 
@@ -290,7 +316,26 @@ class DistributedDataParallel(Module):
                                "init_process_group and have not passed "
                                "process_group argument to DDP constructor")
 
+    def _check_previous_reduction(self):
+        if not self.training:
+            return
+        # self.check_previous_reduction will be False in the first iteration
+        # and is then toggled to True for all future iterations.
+        if self.check_previous_reduction is False:
+            self.check_previous_reduction = True
+        else:
+            if not self.all_buckets_reduced:
+                raise RuntimeError("Not all gradients have been reduced from "
+                                   "the backward of the previous iteration. "
+                                   "This is unexpected and fatal error. Please "
+                                   "check and ensure that the model's "
+                                   "parameters are not changed after you wrap "
+                                   "up the model with DistributedDataParallel.")
+        self.all_buckets_reduced = False
+
     def forward(self, *inputs, **kwargs):
+        if self.check_reduction:
+            self._check_previous_reduction()
         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
         self._sync_params()
         if len(self.device_ids) == 1:
@@ -308,6 +353,7 @@ class DistributedDataParallel(Module):
         return gather(outputs, output_device, dim=self.dim)
 
     def train(self, mode=True):
+        self.check_previous_reduction = False
         super(DistributedDataParallel, self).train(mode)
         for module in self._module_copies[1:]:
             module.train(mode)
@@ -401,6 +447,7 @@ class DistributedDataParallel(Module):
                 if self.next_bucket == -1:
                     # A final sync for all the reduction works
                     self._sync_reduction_works()
+                    self.all_buckets_reduced = True
 
         return distributed_data_parallel_hook