self.fc2 = _FC2()
self.fc3 = nn.Linear(50, 4, bias=False)
self.relu = nn.ReLU()
+ self.no_grad_param = nn.Parameter(torch.Tensor([2, 2]).long(),
+ requires_grad=False)
def forward(self, x):
x = self.relu(self.fc1(x))
for p_gpu, p_DDP in zip(param_gpu, param_DDP):
self.assertEqual(p_gpu, p_DDP)
- def _test_DDP_2iter(
+ def _test_DDP_5iter(
self, model_base, model_DDP, input, target, loss, local_bs, rank, batch_size
):
- for _ in range(2):
+ for _ in range(5):
# single cpu/gpu training
self._test_DDP_helper(model_base, input, target, loss)
local_bs = len(gpu_subset)
global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)
- # check two model parameters over 2 iterations
- self._test_DDP_2iter(
+ # check two model parameters over 5 iterations
+ self._test_DDP_5iter(
model_gpu,
model_DDP,
input_cpu.cuda(gpu_subset[0]),
local_bs = 2
global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)
- # check two model parameters over 2 iterations
- self._test_DDP_2iter(
+ # check two model parameters over 5 iterations
+ self._test_DDP_5iter(
model_base, model_DDP, input_cpu, target, loss, local_bs, rank, global_bs
)
self._barrier()
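For context, the parity check that _test_DDP_5iter performs can be sketched in standalone form: run a reference model and a DDP-wrapped model on the same data for five iterations and require their parameters to stay equal. The sketch below is a single-process illustration under assumed model, loss, and optimizer choices; it ignores the local/global batch split the real test performs and is not the test suite's actual helper.

# Standalone sketch of the parameter-parity check (illustrative only).
import torch
import torch.nn as nn

def check_parameter_parity(model_base, model_ddp, data, target, iters=5, lr=0.01):
    loss_fn = nn.MSELoss()
    for _ in range(iters):
        for model in (model_base, model_ddp):
            model.zero_grad()
            loss_fn(model(data), target).backward()
            # Plain SGD step so both models update in lockstep.
            with torch.no_grad():
                for p in model.parameters():
                    if p.grad is not None:
                        p -= lr * p.grad
    for p_base, p_ddp in zip(model_base.parameters(), model_ddp.parameters()):
        assert torch.allclose(p_base, p_ddp)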
import copy
import torch
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors, \
- _take_tensors
from torch.cuda.comm import broadcast_coalesced
from torch.cuda import nccl
bucket can potentially overlap with backward computation.
bucket_cap_mb controls the bucket size in MegaBytes (MB)
(default: 25)
+ check_reduction: when set to True, DistributedDataParallel
+ automatically checks, at the beginning of every iteration's
+ forward function, whether the previous iteration's
+ backward reductions were successfully issued.
+ You normally don't need this option enabled unless you
+ are observing unexpected behaviors such as different ranks
+ getting different gradients, which should not
+ happen if DistributedDataParallel is used correctly.
+ (default: False)
Attributes:
module (Module): the module to be parallelized
"""
def __init__(self, module, device_ids=None,
output_device=None, dim=0, broadcast_buffers=True,
- process_group=None, bucket_cap_mb=25):
+ process_group=None, bucket_cap_mb=25,
+ check_reduction=False):
super(DistributedDataParallel, self).__init__()
self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
self.output_device = _get_device_index(output_device, True)
self.broadcast_buffers = broadcast_buffers
+ self.check_reduction = check_reduction
MB = 1024 * 1024
# This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems
param_buckets = []
+
# Split the parameters into buckets, grouped by type as well
- param_buckets = [dist._dist_bucket_tensors(list(m.parameters()),
+ # Only parameters that require grad need to be bucketed and reduced.
+ # The same holds for the backward pass, since gradient-reduction
+ # hooks are only registered on parameters that require grad.
+ params_to_bucket = [[] for _ in self._module_copies]
+ for dev_idx, m in enumerate(self._module_copies):
+ for p in m.parameters():
+ if p.requires_grad:
+ params_to_bucket[dev_idx].append(p)
+
+ param_buckets = [dist._dist_bucket_tensors(dev_params_to_bucket,
int(bucket_bytes_cap),
fine_grained=False)
- for m in self._module_copies]
+ for dev_params_to_bucket in params_to_bucket]
self.bucket_sizes = []
self.bucket_map = {}
# We will always reduce the buckets in reverse order,
# that is, always following the order: n - 1, n - 2, ..., 0
self.next_bucket = len(self.bucket_sizes) - 1
+ # When all buckets are reduced, this will be set to True. This flag is
+ # useful for sanity checks to ensure that each iteration's backward has
+ # always reduced all buckets
+ self.all_buckets_reduced = False
+ self.check_previous_reduction = False
self.ready_buckets_not_reduced = set()
self.reduction_works = [None for _ in range(len(self.bucket_sizes))]
self.devs_ready = [0 for _ in range(len(self.bucket_sizes))]
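A tiny sketch (assumed toy module, not the real bucketing path) of the filtering above: a parameter created with requires_grad=False, like no_grad_param in the test model, never enters the reduction buckets.

import torch
import torch.nn as nn

m = nn.Linear(4, 4)
m.register_parameter(
    "frozen", nn.Parameter(torch.zeros(2), requires_grad=False))
params_to_bucket = [p for p in m.parameters() if p.requires_grad]
assert len(params_to_bucket) == 2  # weight and bias only; "frozen" is skipped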
def __setstate__(self, state):
# If serializable, then the process group should be the default one
self.process_group = dist.get_default_group()
+ self.check_previous_reduction = False
super(DistributedDataParallel, self).__setstate__(state)
self._register_grad_hooks()
"init_process_group and have not passed "
"process_group argument to DDP constructor")
+ def _check_previous_reduction(self):
+ if not self.training:
+ return
+ # self.check_previous_reduction will be False in the first iteration
+ # and is then toggled to True for all future iterations.
+ if self.check_previous_reduction is False:
+ self.check_previous_reduction = True
+ else:
+ if not self.all_buckets_reduced:
+ raise RuntimeError("Not all gradients have been reduced from "
+ "the backward of the previous iteration. "
+ "This is an unexpected and fatal error. Please "
+ "check and ensure that the model's "
+ "parameters are not changed after wrapping "
+ "the model with DistributedDataParallel.")
+ self.all_buckets_reduced = False
+
def forward(self, *inputs, **kwargs):
+ if self.check_reduction:
+ self._check_previous_reduction()
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
self._sync_params()
if len(self.device_ids) == 1:
return gather(outputs, output_device, dim=self.dim)
def train(self, mode=True):
+ self.check_previous_reduction = False
super(DistributedDataParallel, self).train(mode)
for module in self._module_copies[1:]:
module.train(mode)
if self.next_bucket == -1:
# A final sync for all the reduction works
self._sync_reduction_works()
+ self.all_buckets_reduced = True
return distributed_data_parallel_hook
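The sanity-check state machine spread across _check_previous_reduction, forward, and the reduction hook can be condensed into a toy class. The attribute names mirror the ones added above, but the class itself is only an illustrative assumption, not part of the library.

class ReductionChecker(object):
    def __init__(self):
        self.check_previous_reduction = False  # first forward skips the check
        self.all_buckets_reduced = False       # set once backward reduces every bucket

    def on_forward(self):
        # Mirrors _check_previous_reduction (the real method also returns
        # early when the module is in eval mode).
        if not self.check_previous_reduction:
            self.check_previous_reduction = True
        else:
            if not self.all_buckets_reduced:
                raise RuntimeError("previous backward did not reduce all buckets")
            self.all_buckets_reduced = False  # reset for the upcoming backward

    def on_all_buckets_reduced(self):
        self.all_buckets_reduced = True

checker = ReductionChecker()
for _ in range(3):
    checker.on_forward()              # would raise if the prior backward was incomplete
    checker.on_all_buckets_reduced()  # stands in for the final bucket reduction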