From a0263ec04765fb9b20decc178241d2668ea58cee Mon Sep 17 00:00:00 2001 From: Pieter Noordhuis Date: Mon, 15 Apr 2019 12:24:43 -0700 Subject: [PATCH] Make DistributedDataParallel use new reducer (#18953) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18953 This removes Python side bucketing code from DistributedDataParallel and replaces it with calls to the new C++ based bucketing and reducing code. To confirm this is working well, we ran a test with both the previous implementation and the new implementation, and confirmed they are numerically equivalent. Performance is improved by a couple percent or more, including the single machine multiple GPU runs. Closes #13273. Reviewed By: mrshenli Differential Revision: D14580911 fbshipit-source-id: 44e76f8b0b7e58dd6c91644e3df4660ca2ee4ae2 --- test/test_c10d.py | 46 +++++++ torch/csrc/distributed/c10d/init.cpp | 7 + torch/csrc/distributed/c10d/reducer.cpp | 110 +++++++++++++++ torch/csrc/distributed/c10d/reducer.h | 4 + torch/nn/parallel/distributed.py | 234 ++++++++------------------------ 5 files changed, 221 insertions(+), 180 deletions(-) diff --git a/test/test_c10d.py b/test/test_c10d.py index f06dd88..0d1bc22 100644 --- a/test/test_c10d.py +++ b/test/test_c10d.py @@ -1934,6 +1934,52 @@ class ReducerTest(TestCase): optimizer.step() +class ComputeBucketAssignmentTest(TestCase): + def test_single_limit_single_dtype(self): + tensors = [ + torch.empty([100], dtype=torch.float), + torch.empty([200], dtype=torch.float), + torch.empty([100], dtype=torch.float), + torch.empty([50], dtype=torch.float), + ] + result = dist._compute_bucket_assignment_by_size(tensors, [400]) + self.assertEqual([[0], [1], [2], [3]], result) + + def test_single_limit_multi_dtype(self): + tensors = [ + torch.empty([50], dtype=torch.float), + torch.empty([25], dtype=torch.double), + torch.empty([50], dtype=torch.float), + torch.empty([25], dtype=torch.double), + torch.empty([50], dtype=torch.float), + torch.empty([25], dtype=torch.double), + ] + result = dist._compute_bucket_assignment_by_size(tensors, [400]) + self.assertEqual([[0, 2], [1, 3], [4], [5]], result) + + def test_multi_limit_single_dtype(self): + tensors = [ + torch.empty([10], dtype=torch.float), + torch.empty([10], dtype=torch.float), + torch.empty([10], dtype=torch.float), + torch.empty([10], dtype=torch.float), + ] + result = dist._compute_bucket_assignment_by_size(tensors, [40, 80]) + self.assertEqual([[0], [1, 2], [3]], result) + + def test_multi_limit_multi_dtype(self): + tensors = [ + torch.empty([50], dtype=torch.float), + torch.empty([25], dtype=torch.double), + torch.empty([50], dtype=torch.float), + torch.empty([25], dtype=torch.double), + torch.empty([50], dtype=torch.float), + torch.empty([25], dtype=torch.double), + ] + result = dist._compute_bucket_assignment_by_size(tensors, [200, 400]) + self.assertEqual([[0], [1], [2, 4], [3, 5]], result) + + if __name__ == '__main__': assert not torch.cuda._initialized, "test_distributed must not have initialized CUDA context on main process" diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index a610fc7..e505dba 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -500,6 +500,13 @@ They are used in specifying strategies for reduction collectives, e.g., py::call_guard()); #endif + module.def( + "_compute_bucket_assignment_by_size", + &::c10d::compute_bucket_assignment_by_size, + py::arg("tensors"), + py::arg("bucket_size"), + py::call_guard()); + 
Py_RETURN_TRUE; } diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index 166090a..a7b666c 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include namespace c10d { @@ -361,6 +362,13 @@ void Reducer::prepare_for_backward( bucket.pending = bucket.replicas.size(); } + // If no outputs are specified, we assume that autograd hooks for ALL + // variables will be called, and we don't have to search the autograd graph + // for presence of these hooks. + if (outputs.empty()) { + return; + } + // Seed queue with the grad functions of all outputs. for (const auto& output : outputs) { const auto& grad_fn = output.grad_fn(); @@ -433,4 +441,106 @@ void Reducer::finalize_backward() { } } +namespace { + +// Tensors may be coalesced into buckets. Buckets must contain tensors of +// the same type, on the same device, so a bucket can identified by a +// composite key of a tensor's type identifier and its device. +struct BucketKey { + BucketKey(c10::ScalarType type, c10::Device device) + : type(std::move(type)), device(std::move(device)) {} + + const c10::ScalarType type; + const c10::Device device; + + // See torch/csrc/utils/hash.h for dispatch code. + static size_t hash(const BucketKey& key) { + return torch::get_hash(key.type, key.device); + } +}; + +inline bool operator==(const BucketKey& lhs, const BucketKey& rhs) { + return lhs.type == rhs.type && lhs.device == rhs.device; +} + +} // namespace + +// This is equivalent to take_tensors but returns indices into the +// tensor list argument for bucket assignment. Also, it is aware +// of device placement and will not allow buckets to span devices. +std::vector> compute_bucket_assignment_by_size( + const std::vector& tensors, + std::vector bucket_size_limits) { + std::vector> result; + result.reserve(tensors.size()); + + // Keep iterator into the size_limit vector by tensor type and device. + // This is done so that we can use the consecutive bucket limits per type. + std::unordered_map< + BucketKey, + std::vector::iterator, + torch::hash> + bucket_size_limit_iterators; + + // Local accumulator type for a single bucket. + struct BucketAccumulator { + std::vector indices; + size_t size = 0; + }; + + // Keep vector of indices and size accumulator by tensor type and device. + std::unordered_map> + buckets; + + for (size_t i = 0; i < tensors.size(); i++) { + const auto& tensor = tensors[i]; + AT_ASSERTM(!tensor.is_sparse(), "No support for sparse tensors."); + auto key = BucketKey(tensor.scalar_type(), tensor.device()); + auto& bucket = buckets[key]; + bucket.indices.push_back(i); + bucket.size += tensor.numel() * tensor.element_size(); + + // Initialize bucket size limit iterator if necessary. + if (bucket_size_limit_iterators.count(key) == 0) { + bucket_size_limit_iterators[key] = bucket_size_limits.begin(); + } + + auto& bucket_size_limit_iterator = bucket_size_limit_iterators[key]; + const auto bucket_size_limit = *bucket_size_limit_iterator; + if (bucket.size >= bucket_size_limit) { + result.emplace_back(std::move(bucket.indices)); + bucket = BucketAccumulator(); + + // Advance to the next bucket size limit for this type/device. + auto next = bucket_size_limit_iterator + 1; + if (next != bucket_size_limits.end()) { + bucket_size_limit_iterator = next; + } + } + } + + // Add remaining buckets. 
+ for (auto& it : buckets) { + auto& bucket = it.second; + if (!bucket.indices.empty()) { + result.emplace_back(std::move(bucket.indices)); + } + } + + // Sort resulting buckets by the minimum tensor index they include. + // We assume that the order of the tensors is the order in which they are + // used (or the reverse order in which their gradients are produced). + // This sorting step ensures that the buckets are ready in consecutive order. + std::sort( + result.begin(), + result.end(), + [](const std::vector& a, const std::vector& b) { + const auto amin = std::min_element(a.begin(), a.end()); + const auto bmin = std::min_element(b.begin(), b.end()); + return *amin < *bmin; + }); + + return result; +} + } // namespace c10d diff --git a/torch/csrc/distributed/c10d/reducer.h b/torch/csrc/distributed/c10d/reducer.h index f1d5ba3..4010582 100644 --- a/torch/csrc/distributed/c10d/reducer.h +++ b/torch/csrc/distributed/c10d/reducer.h @@ -139,4 +139,8 @@ class Reducer { std::vector> backward_stats_; }; +std::vector> compute_bucket_assignment_by_size( + const std::vector& tensors, + std::vector bucket_size); + } // namespace c10d diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index e4878f9..ddca4a6 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -209,15 +209,19 @@ class DistributedDataParallel(Module): self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids)) self.output_device = _get_device_index(output_device, True) self.broadcast_buffers = broadcast_buffers - self.check_reduction = check_reduction + if check_reduction: + # This argument is no longer used since the reducer + # will ensure reduction completes even if some parameters + # do not receive gradients. + pass MB = 1024 * 1024 # used for intra-node param sync and inter-node sync as well - self.broadcast_bucket_size = 250 * MB + self.broadcast_bucket_size = int(250 * MB) # reduction bucket size - self.bucket_bytes_cap = bucket_cap_mb * MB + self.bucket_bytes_cap = int(bucket_cap_mb * MB) # Sync params and buffers module_states = list(self.module.state_dict().values()) @@ -254,60 +258,26 @@ class DistributedDataParallel(Module): self.modules_params = [list(m.parameters()) for m in self._module_copies] self.modules_buffers = [list(m.buffers()) for m in self._module_copies] - # This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems - param_buckets = [] - - # Split the parameters into buckets and by types as well - # We only need to bucket and reduce parameters that require grad and - # this is also true for backward since only the backward hooks for - # parameters that require grad will be registered with gradient - # reduction functions - params_to_bucket = [[] for _ in self._module_copies] - for dev_idx, m in enumerate(self._module_copies): - for p in m.parameters(): - if p.requires_grad: - params_to_bucket[dev_idx].append(p) - - param_buckets = [dist._dist_bucket_tensors(dev_params_to_bucket, - int(self.bucket_bytes_cap), - fine_grained=False) - for dev_params_to_bucket in params_to_bucket] - - self.bucket_sizes = [] - self.bucket_map = {} - - # We transpose param_buckets, so the loop is over buckets. - # param_buckets_tuple is a doubly-nested list with "dims": devices, bucket_elems - for bucket_idx, param_buckets_tuple in enumerate(zip(*param_buckets)): - self.bucket_sizes.append(0) - # Now, we transpose again, so we iterate over bucket_elems, but getting tuples - # of params from each device. 
- for param_tuple in zip(*param_buckets_tuple): - if not param_tuple[0].requires_grad: - continue - for p in param_tuple: - self.bucket_map[p] = (bucket_idx, self.bucket_sizes[bucket_idx]) - self.bucket_sizes[bucket_idx] += 1 - - self.buckets = [[[None for _ in range(self.bucket_sizes[i])] - for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))] - # The number of params ready in each bucket - self.buckets_ready_size = [[0 for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))] - - # coalesced bucket for only device 0 - self.buckets_coalesced = [[] for _ in range(len(self.bucket_sizes))] - # We will always reduce the bucket following the reverse order - # that is, alway reduces following the order of: n - 1, n - 2, ..., 0 - self.next_bucket = len(self.bucket_sizes) - 1 - # When all buckets are reduced, this will be set to True. This flag is - # useful for sanity checks to ensure that each iteration's backward has - # always reduced all buckets - self.all_buckets_reduced = False - self.check_previous_reduction = False - self.ready_buckets_not_reduced = set() - self.reduction_works = [None for _ in range(len(self.bucket_sizes))] - self.devs_ready = [0 for _ in range(len(self.bucket_sizes))] - self._register_grad_hooks() + param_list = [ + list(filter(lambda p: p.requires_grad, module.parameters())) + for module in self._module_copies] + + # The bucket size limit is specified in the constructor. + # Additionally, we allow for a single small bucket for parameters + # that are defined first, such that their gradients don't spill into + # a much larger bucket, adding unnecessary latency after gradient + # computation finishes. Experiments showed 1MB is a reasonable value. + bucket_indices = dist._compute_bucket_assignment_by_size( + param_list[0], + [1024 * 1024, self.bucket_bytes_cap]) + + # Note: reverse list of buckets because we want to approximate the + # order in which their gradients are produced, and assume they + # are used in the forward pass in the order they are defined. + self.reducer = dist.Reducer( + param_list, + list(reversed(bucket_indices)), + self.process_group) # passing a handle to torch.nn.SyncBatchNorm layer self._passing_sync_batchnorm_handle(self._module_copies) @@ -315,15 +285,13 @@ class DistributedDataParallel(Module): def __getstate__(self): self._check_default_group() attrs = copy.copy(self.__dict__) - del attrs['process_group'], \ - attrs['default_streams'], \ - attrs['_grad_accs'] + del attrs['process_group'] + del attrs['reducer'] return attrs def __setstate__(self, state): # If serializable, then the process group should be the default one self.process_group = _get_default_group() - self.check_previous_reduction = False super(DistributedDataParallel, self).__setstate__(state) self._ddp_init_helper() @@ -342,32 +310,28 @@ class DistributedDataParallel(Module): "init_process_group and have not passed " "process_group argument to DDP constructor") - def _check_previous_reduction(self): - if not self.training: - return - # self.check_previous_reduction will be False in the first iteration - # and is then toggled to True for all future iterations. - if self.check_previous_reduction is False: - self.check_previous_reduction = True - else: - if not self.all_buckets_reduced: - raise RuntimeError("Not all gradients have been reduced from " - "the backward of the previous iteration. " - "This is an unexpected and fatal error. 
" - "Please check and ensure that the model's " - "parameters are not changed after you wrap " - "up the model with DistributedDataParallel.") - self.all_buckets_reduced = False - def forward(self, *inputs, **kwargs): - if self.check_reduction: - self._check_previous_reduction() inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) self._sync_params() if len(self.device_ids) == 1: - return self.module(*inputs[0], **kwargs[0]) - outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs) - return self.gather(outputs, self.output_device) + output = self.module(*inputs[0], **kwargs[0]) + else: + outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs) + output = self.gather(outputs, self.output_device) + + # We'll return the output object verbatim since it is a freeform object. + # We need to find any tensors in this object, though, because we need to + # figure out which parameters were used during this forward pass, + # to ensure we short circuit reduction for any unused parameters. + output_tensors = [] + if isinstance(output, torch.Tensor): + output_tensors = [output] + if isinstance(output, (list, tuple)): + def istensor(obj): + return isinstance(obj, torch.Tensor) + output_tensors = list(filter(istensor, output)) + self.reducer.prepare_for_backward(output_tensors) + return output def scatter(self, inputs, kwargs, device_ids): return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) @@ -379,7 +343,6 @@ class DistributedDataParallel(Module): return gather(outputs, output_device, dim=self.dim) def train(self, mode=True): - self.check_previous_reduction = False super(DistributedDataParallel, self).train(mode) for module in self._module_copies[1:]: module.train(mode) @@ -398,6 +361,13 @@ class DistributedDataParallel(Module): self.modules_params[1:]): for tensor, param in zip(tensors, module_params): param.set_(tensor) + # Assume we have just run the optimizer and zeroed the + # grads of the parameters on the root model. We need + # to zero the grads on all model replicas as well. + # This snippet is copied from torch.optim.Optimizer. 
+ if param.grad is not None: + param.grad.detach_() + param.grad.zero_() # module buffer sync if self.broadcast_buffers and len(self.modules_buffers[0]) > 0: @@ -419,99 +389,3 @@ class DistributedDataParallel(Module): for layer in module.modules(): if isinstance(layer, torch.nn.modules.SyncBatchNorm): layer._specify_ddp_gpu_num(len(self.device_ids)) - - def _register_grad_hooks(self): - self._grad_accs = [] # need to keep them in scope - - # default stream tracking to launch nccl reduce kernels - self.default_streams = [] - for dev_id in self.device_ids: - with torch.cuda.device(dev_id): - self.default_streams.append(torch.cuda.current_stream()) - - for device_idx, module in enumerate(self._module_copies): - for p in module.parameters(): - if p.requires_grad: - p_tmp = p.expand_as(p) - grad_acc = p_tmp.grad_fn.next_functions[0][0] - grad_acc.register_hook(self._make_param_hook(p, device_idx)) - self._grad_accs.append(grad_acc) - - def _make_param_hook(self, param, device_idx): - bucket_idx, bucket_offset = self.bucket_map[param] - - def distributed_data_parallel_hook(*unused): - if param.grad.requires_grad: - raise RuntimeError("DistributedDataParallel only works " - "with gradients that don't require grad") - bucket = self.buckets[bucket_idx][device_idx] - bucket[bucket_offset] = param.grad.data - self.buckets_ready_size[bucket_idx][device_idx] += 1 - - # We can flush these and save memory for replicas - if device_idx > 0: - param.grad = None - with torch.no_grad(): - param.set_() - - # Current device's bucket is full - if self.buckets_ready_size[bucket_idx][device_idx] == self.bucket_sizes[bucket_idx]: - self.devs_ready[bucket_idx] += 1 - if self.devs_ready[bucket_idx] < len(self.device_ids): - return - - # Now all devices's buckets with index: bucket_idx are ready - if bucket_idx == self.next_bucket: - self._queue_reduction(bucket_idx) - self.next_bucket -= 1 - # Now reduce anything that is ready but not yet reduced - if len(self.ready_buckets_not_reduced) > 0: - sorted_todo = sorted(self.ready_buckets_not_reduced, reverse=True) - for i in sorted_todo: - # Nothing can be reduced now - if i < self.next_bucket: - break - self._queue_reduction(i) - self.ready_buckets_not_reduced.remove(i) - if i == self.next_bucket: - self.next_bucket -= 1 - else: - self.ready_buckets_not_reduced.add(bucket_idx) - - # When all devices' buckets - if self.next_bucket == -1: - # A final sync for all the reduction works - self._sync_reduction_works() - self.all_buckets_reduced = True - - return distributed_data_parallel_hook - - def _queue_reduction(self, bucket_idx): - # _queue_reduction will use a seperate CUDA stream to coalesce - # the small tensors to achieve more parallelisms, before passing the - # coalesced tensor into the c10d CUDA stream for reduction - result = dist._queue_reduction(self.process_group, - self.buckets[bucket_idx], - self.device_ids) - self.reduction_works[bucket_idx] = result[0] - self.buckets_coalesced[bucket_idx] = result[1] - - def _sync_reduction_works(self): - # Now only work on the first GPU of self.device_ids - # _sync_reduction will use a seperate CUDA stream to uncoalesce - # the coalesced tensors to achieve more parallelisms - for bucket_idx, grads_batch in enumerate(self.buckets): - dist._sync_reduction(self.reduction_works[bucket_idx], - grads_batch[0], - self.buckets_coalesced[bucket_idx]) - - # Reset the module states - self.next_bucket = len(self.bucket_sizes) - 1 - self.ready_buckets_not_reduced = set() - self.reduction_works = [None for _ in 
range(len(self.bucket_sizes))]
-        self.devs_ready = [0 for _ in range(len(self.bucket_sizes))]
-
-        self.buckets = [[[None for _ in range(self.bucket_sizes[i])]
-                         for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))]
-        self.buckets_coalesced = [[] for _ in range(len(self.bucket_sizes))]
-        self.buckets_ready_size = [[0 for _ in range(len(self.device_ids))] for i in range(len(self.bucket_sizes))]
--
2.7.4
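
The ComputeBucketAssignmentTest cases added above pin down the bucketing heuristic: tensors are grouped by dtype and device, byte sizes accumulate until the current limit for that group is reached, a full bucket is emitted and the group moves on to the next size limit, leftovers are flushed at the end, and buckets are ordered by their smallest tensor index. The pure-Python sketch below is illustrative only; the authoritative implementation is the C++ compute_bucket_assignment_by_size added in reducer.cpp, and the name bucket_by_size is made up for this note.

from collections import defaultdict

import torch


def bucket_by_size(tensors, size_limits):
    # Buckets never mix dtypes or devices, so group by a (dtype, device) key.
    pending = defaultdict(lambda: ([], 0))   # key -> (tensor indices, byte size)
    limit_pos = defaultdict(int)             # key -> position in size_limits
    result = []

    for i, t in enumerate(tensors):
        key = (t.dtype, t.device)
        indices, size = pending[key]
        indices.append(i)
        size += t.numel() * t.element_size()

        if size >= size_limits[limit_pos[key]]:
            # Bucket is full: emit it and advance to the next (typically
            # larger) size limit for this dtype/device group.
            result.append(indices)
            pending[key] = ([], 0)
            limit_pos[key] = min(limit_pos[key] + 1, len(size_limits) - 1)
        else:
            pending[key] = (indices, size)

    # Emit partially filled buckets, then order all buckets by their smallest
    # tensor index so they appear roughly in usage order.
    result.extend(indices for indices, _ in pending.values() if indices)
    result.sort(key=min)
    return result


tensors = [
    torch.empty([50]), torch.empty([25], dtype=torch.double),
    torch.empty([50]), torch.empty([25], dtype=torch.double),
    torch.empty([50]), torch.empty([25], dtype=torch.double),
]
print(bucket_by_size(tensors, [200, 400]))  # [[0], [1], [2, 4], [3, 5]], as in test_multi_limit_multi_dtype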
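
The new _ddp_init_helper passes two limits, [1024 * 1024, self.bucket_bytes_cap], so that the first parameters defined land in a small ~1MB bucket whose gradients can be reduced without waiting for a full bucket_cap_mb bucket, and it then hands the buckets to the reducer in reverse order because gradients arrive roughly in reverse definition order during backward. A small illustration of that two-limit call, assuming a build with torch.distributed available; the parameter shapes are made up and the exact grouping depends on sizes and dtypes.

import torch
import torch.distributed as dist

MB = 1024 * 1024

# Hypothetical float32 parameter shapes, listed in the order they are defined
# in the model (and therefore roughly in forward-pass usage order).
params = [torch.empty(n) for n in (1000, 300000, 400000, 500000)]

# First limit is the small "starter" bucket, second plays the role of bucket_cap_mb.
buckets = dist._compute_bucket_assignment_by_size(params, [1 * MB, 25 * MB])
print(buckets)                  # e.g. [[0, 1], [2, 3]]
print(list(reversed(buckets)))  # order handed to dist.Reducer, e.g. [[2, 3], [0, 1]]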
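
The rewritten forward pass collects whatever tensors it can find in the module output (only the top level of a list or tuple is inspected in this version) and passes them to Reducer.prepare_for_backward, which walks the autograd graph from those outputs so that parameters unreachable in this iteration are not waited on during reduction. The toy module below, purely for illustration, is the kind of model that exercises this path.

import torch
import torch.nn as nn


class SometimesSkipsABranch(nn.Module):
    """Toy model whose `branch` parameters receive no gradient on some steps."""

    def __init__(self):
        super(SometimesSkipsABranch, self).__init__()
        self.trunk = nn.Linear(8, 8)
        self.branch = nn.Linear(8, 8)

    def forward(self, x, use_branch=True):
        out = self.trunk(x)
        if use_branch:
            out = out + self.branch(out)
        # When use_branch is False, branch.weight and branch.bias are not part
        # of the autograd graph rooted at `out`; prepare_for_backward detects
        # this from the output tensors and marks their bucket entries ready
        # instead of blocking the allreduce for that bucket.
        return out

With the removed hook-based implementation, a parameter that never produced a gradient left its bucket permanently unfilled, which is what the old check_reduction machinery tried to detect after the fact; the new reducer handles the situation up front, which is why check_reduction is now accepted but ignored.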