From: Shen Li Date: Thu, 18 Apr 2019 04:18:49 +0000 (-0700) Subject: Allow DDP to wrap multi-GPU modules (#19271) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~182 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6732358bf92939789978abac29d18aab6ebc8c2d;p=platform%2Fupstream%2Fpytorch.git Allow DDP to wrap multi-GPU modules (#19271) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19271 allow DDP to take multi-gpu models Reviewed By: pietern Differential Revision: D14822375 fbshipit-source-id: 1eebfaa33371766d3129f0ac6f63a573332b2f1c --- diff --git a/test/test_c10d.py b/test/test_c10d.py index 0d1bc22..3f9dcf6 100644 --- a/test/test_c10d.py +++ b/test/test_c10d.py @@ -57,6 +57,18 @@ def skip_if_not_multigpu(func): return wrapper +def skip_if_lt_x_gpu(x): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if torch.cuda.is_available() and torch.cuda.device_count() >= x: + return func(*args, **kwargs) + sys.exit(TEST_SKIPS['multi-gpu'].exit_code) + return wrapper + + return decorator + + def skip_if_not_nccl(func): """Skips a test if NCCL is not available (for c10d).""" @wraps(func) @@ -1502,6 +1514,48 @@ class Net(nn.Module): return F.softmax(x, dim=1) +class DoubleGpuNet(nn.Module): + def __init__(self, gpus): + super(DoubleGpuNet, self).__init__() + self.fc1 = nn.Linear(2, 10, bias=False).to(gpus[0]) + self.fc2 = nn.Linear(10, 50, bias=False).to(gpus[1]) + self.fc3 = nn.Linear(50, 4, bias=False).to(gpus[1]) + self.relu = nn.ReLU() + self.no_grad_param = nn.Parameter(torch.Tensor([2, 2]).long(), + requires_grad=False).to(gpus[0]) + + def forward(self, x): + dev0 = self.fc1.weight.device + dev1 = self.fc2.weight.device + x = self.relu(self.fc1(x.to(dev0))) + x = self.relu(self.fc2(x.to(dev1))) + x = self.fc3(x) + return F.softmax(x, dim=1).to(dev0) + + +class QuadraGpuNet(nn.Module): + def __init__(self, gpus): + super(QuadraGpuNet, self).__init__() + self.fc1 = nn.Linear(2, 10, bias=False).to(gpus[0]) + self.fc2 = nn.Linear(10, 50, bias=False).to(gpus[1]) + self.fc3 = nn.Linear(50, 4, bias=False).to(gpus[2]) + self.fc4 = nn.Linear(4, 4, bias=False).to(gpus[3]) + self.relu = nn.ReLU() + self.no_grad_param = nn.Parameter(torch.Tensor([2, 2]).long(), + requires_grad=False).to(gpus[0]) + + def forward(self, x): + dev0 = self.fc1.weight.device + dev1 = self.fc2.weight.device + dev2 = self.fc3.weight.device + dev3 = self.fc4.weight.device + x = self.relu(self.fc1(x.to(dev0))) + x = self.relu(self.fc2(x.to(dev1))) + x = self.relu(self.fc3(x.to(dev2))) + x = self.fc4(x.to(dev3)) + return F.softmax(x, dim=1).to(dev0) + + class DistributedDataParallelTest(MultiProcessTestCase): def tearDown(self): @@ -1517,7 +1571,7 @@ class DistributedDataParallelTest(MultiProcessTestCase): def world_size(self): return 2 - def _test_ddp_with_process_group(self, process_group, gpus): + def _prepare_single_device_module(self, process_group, gpus, global_batch_size): model = Net() ddp_model = DistributedDataParallel( copy.deepcopy(model).cuda(gpus[0]), @@ -1527,15 +1581,47 @@ class DistributedDataParallelTest(MultiProcessTestCase): model.cuda(gpus[0]) - local_batch_size = len(gpus) - global_batch_size = self.world_size * local_batch_size input = torch.randn(global_batch_size, 2).cuda(gpus[0]) target = torch.randn(global_batch_size, 4).cuda(gpus[0]) + return model, ddp_model, input, target + + def _prepare_multi_device_module(self, process_group, gpus, global_batch_size): + self.assertTrue( + len(gpus) == 2 or len(gpus) == 4, + "unexpected devices for ddp 
tests {}".format(gpus)) + if len(gpus) == 2: + model = DoubleGpuNet(gpus) + elif len(gpus) == 4: + model = QuadraGpuNet(gpus) + + ddp_model = DistributedDataParallel( + copy.deepcopy(model), + process_group=process_group, + bucket_cap_mb=0.001) + + input = torch.randn(global_batch_size, 2).to(gpus[0]) + target = torch.randn(global_batch_size, 4) + + return model, ddp_model, input, target + + def _test_ddp_with_process_group(self, process_group, gpus, multi_gpu=False): + local_batch_size = len(gpus) + global_batch_size = self.world_size * local_batch_size + + if multi_gpu: + model, ddp_model, input, target = \ + self._prepare_multi_device_module( + process_group, gpus, global_batch_size) + else: + model, ddp_model, input, target = \ + self._prepare_single_device_module( + process_group, gpus, global_batch_size) + def step_model(model, input, target): model.train() output = model(input) - loss = F.mse_loss(output, target) + loss = F.mse_loss(output, target.to(output.device)) loss.backward() def update_parameters(model): @@ -1564,24 +1650,117 @@ class DistributedDataParallelTest(MultiProcessTestCase): torch.manual_seed(1337 + iteration) input = input[torch.randperm(global_batch_size)] - @skip_if_not_multigpu - def test_gloo_backend(self): + def _test_gloo_backend(self, gpus, multi_gpu=False, use_str=False): + if use_str: + gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus)) store = c10d.FileStore(self.file.name, self.world_size) options = c10d.ProcessGroupGloo.Options() options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")] process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options) + self._test_ddp_with_process_group(process_group, gpus, multi_gpu) + + @skip_if_not_multigpu + def test_gloo_backend(self): gpus = gpus_for_rank(self.world_size)[self.rank] - self._test_ddp_with_process_group(process_group, gpus) - self._test_ddp_with_process_group(process_group, list(map(lambda i: torch.device('cuda:' + str(i)), gpus))) + self._test_gloo_backend(gpus) + + @skip_if_not_multigpu + def test_gloo_backend_str(self): + gpus = gpus_for_rank(self.world_size)[self.rank] + self._test_gloo_backend(gpus, use_str=True) + + @skip_if_lt_x_gpu(4) + def test_gloo_backend_2gpu_module(self): + gpus = gpus_for_rank(self.world_size)[self.rank] + self._test_gloo_backend(gpus[:2], multi_gpu=True) + + @skip_if_lt_x_gpu(4) + def test_gloo_backend_2gpu_module_str(self): + gpus = gpus_for_rank(self.world_size)[self.rank] + self._test_gloo_backend(gpus[:2], multi_gpu=True, use_str=True) + + @skip_if_lt_x_gpu(8) + def test_gloo_backend_4gpu_module(self): + gpus = gpus_for_rank(self.world_size)[self.rank] + self._test_gloo_backend(gpus[:4], multi_gpu=True) + + @skip_if_lt_x_gpu(8) + def test_gloo_backend_4gpu_module_str(self): + gpus = gpus_for_rank(self.world_size)[self.rank] + self._test_gloo_backend(gpus[:4], multi_gpu=True, use_str=True) + + def _test_nccl_backend(self, gpus, multi_gpu=False, use_str=False): + if use_str: + gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus)) + store = c10d.FileStore(self.file.name, self.world_size) + process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) + self._test_ddp_with_process_group(process_group, gpus, multi_gpu) @skip_if_not_multigpu @skip_if_not_nccl def test_nccl_backend(self): + gpus = gpus_for_rank(self.world_size)[self.rank] + self._test_nccl_backend(gpus) + + @skip_if_not_multigpu + @skip_if_not_nccl + def test_nccl_backend_str(self): + gpus = gpus_for_rank(self.world_size)[self.rank] + 
self._test_nccl_backend(gpus, use_str=True) + + @skip_if_lt_x_gpu(4) + @skip_if_not_nccl + def test_nccl_backend_2gpu_module(self): + gpus = gpus_for_rank(self.world_size)[self.rank] + self._test_nccl_backend(gpus[:2], multi_gpu=True) + + @skip_if_lt_x_gpu(4) + @skip_if_not_nccl + def test_nccl_backend_2gpu_module_str(self): + gpus = gpus_for_rank(self.world_size)[self.rank] + self._test_nccl_backend(gpus[:2], multi_gpu=True, use_str=True) + + @skip_if_lt_x_gpu(8) + @skip_if_not_nccl + def test_nccl_backend_4gpu_module(self): + gpus = gpus_for_rank(self.world_size)[self.rank] + self._test_nccl_backend(gpus[:4], multi_gpu=True) + + @skip_if_lt_x_gpu(8) + @skip_if_not_nccl + def test_nccl_backend_4gpu_module_str(self): + gpus = gpus_for_rank(self.world_size)[self.rank] + self._test_nccl_backend(gpus[:4], multi_gpu=True, use_str=True) + + @skip_if_lt_x_gpu(4) + @skip_if_not_nccl + def test_ddp_multi_device_module_config(self): + gpus = gpus_for_rank(self.world_size)[self.rank] + + self.assertTrue(len(gpus) >= 2, "expecting at least 2 gpus per process") + store = c10d.FileStore(self.file.name, self.world_size) process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) - gpus = gpus_for_rank(self.world_size)[self.rank] - self._test_ddp_with_process_group(process_group, gpus) - self._test_ddp_with_process_group(process_group, list(map(lambda i: torch.device('cuda:' + str(i)), gpus))) + + gpus = gpus[:2] + model = DoubleGpuNet(gpus) + + with self.assertRaisesRegex(AssertionError, "output_device .* single-device CUDA"): + ddp_model = DistributedDataParallel( + model, output_device=gpus[1], process_group=process_group) + + with self.assertRaisesRegex(AssertionError, "device_ids .* single-device CUDA"): + ddp_model = DistributedDataParallel( + model, device_ids=gpus, process_group=process_group) + + with self.assertRaisesRegex(AssertionError, "only works with CUDA devices"): + model.fc1 = model.fc1.cpu() + ddp_model = DistributedDataParallel(model, process_group=process_group) + + model = model.cpu() + with self.assertRaisesRegex(AssertionError, "device_ids .* single-device CUDA"): + ddp_model = DistributedDataParallel( + model, device_ids=gpus, process_group=process_group) @skip_if_not_multigpu @skip_if_not_nccl diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index ddca4a6..d2b4fe9 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -153,11 +153,22 @@ class DistributedDataParallel(Module): Args: module (Module): module to be parallelized - device_ids (list of int or torch.device): CUDA devices (default: all devices) - output_device (int or torch.device): device location of output (default: device_ids[0]) + device_ids (list of int or torch.device): CUDA devices. This should + only be provided when the input module resides on a single + CUDA device. For single-device modules, the ``i``th + :attr:`module` replica is placed on ``device_ids[i]``. For + multi-device modules and CPU modules, device_ids must be None + or an empty list, and input data for the forward pass must be + placed on the correct device. (default: all devices for + single-device modules) + output_device (int or torch.device): device location of output for + single-device CUDA modules. For multi-device modules and + CPU modules, it must be None, and the module itself + dictates the output location. 
(default: device_ids[0] for + single-device modules) broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of - the module at beginning of the forward function. - (default: ``True``) + the module at beginning of the forward function. + (default: ``True``) process_group: the process group to be used for distributed data all-reduction. If ``None``, the default process group, which is created by ```torch.distributed.init_process_group```, @@ -192,12 +203,35 @@ class DistributedDataParallel(Module): super(DistributedDataParallel, self).__init__() - # Use all devices by default - if device_ids is None: - device_ids = list(range(torch.cuda.device_count())) + self.is_multi_device_module = len({p.device for p in module.parameters()}) > 1 + self.is_cuda = all([p.device.type == 'cuda' for p in module.parameters()]) - if output_device is None: - output_device = device_ids[0] + if not self.is_cuda or self.is_multi_device_module: + assert not device_ids and not output_device, ( + "DistributedDataParallel device_ids and output_device arguments " + "only work with single-device CUDA modules, but got " + "device_ids {}, output_device {}, and module parameters {}." + ).format(device_ids, output_device, {p.device for p in module.parameters()}) + + self.device_ids = None + self.output_device = None + else: + # Use all devices by default for single-device CUDA modules + if device_ids is None: + device_ids = list(range(torch.cuda.device_count())) + + self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids)) + + if output_device is None: + output_device = device_ids[0] + + self.output_device = _get_device_index(output_device, True) + + if self.is_multi_device_module: + assert self.is_cuda, ( + "DistributedDataParallel with multi-device module only works " + "with CUDA devices, but module parameters locate in {}." + ).format({p.device for p in module.parameters()}) if process_group is None: self.process_group = _get_default_group() @@ -206,9 +240,8 @@ class DistributedDataParallel(Module): self.dim = dim self.module = module - self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids)) - self.output_device = _get_device_index(output_device, True) self.broadcast_buffers = broadcast_buffers + if check_reduction: # This argument is no longer used since the reducer # will ensure reduction completes even if some parameters @@ -241,7 +274,9 @@ class DistributedDataParallel(Module): (4) registering the grad hooks (5) passing a handle of DDP to SyncBatchNorm Layer """ - if len(self.device_ids) > 1: + if self.device_ids and len(self.device_ids) > 1: + # only create replicas for single-device CUDA modules + # # TODO: we don't need to replicate params in here. 
they're always going to # be broadcasted using larger blocks in broadcast_coalesced, so it might be # better to not pollute the caches with these small blocks @@ -311,13 +346,16 @@ class DistributedDataParallel(Module): "process_group argument to DDP constructor") def forward(self, *inputs, **kwargs): - inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) self._sync_params() - if len(self.device_ids) == 1: - output = self.module(*inputs[0], **kwargs[0]) + if self.device_ids: + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if len(self.device_ids) == 1: + output = self.module(*inputs[0], **kwargs[0]) + else: + outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs) + output = self.gather(outputs, self.output_device) else: - outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs) - output = self.gather(outputs, self.output_device) + output = self.module(*inputs, **kwargs) # We'll return the output object verbatim since it is a freeform object. # We need to find any tensors in this object, though, because we need to @@ -352,7 +390,9 @@ class DistributedDataParallel(Module): def _sync_params(self): with torch.no_grad(): - if len(self.device_ids) > 1: + # only do intra-node parameters sync for replicated single-device + # CUDA modules + if self.device_ids and len(self.device_ids) > 1: # intra-node parameter sync result = broadcast_coalesced(self.modules_params[0], self.device_ids, @@ -374,7 +414,9 @@ class DistributedDataParallel(Module): # cross-node buffer sync self._dist_broadcast_coalesced(self.modules_buffers[0], self.broadcast_bucket_size) - if len(self.device_ids) > 1: + # only do intra-node buffer sync for replicated single-device + # CUDA modules + if self.device_ids and len(self.device_ids) > 1: # intra-node buffer sync result = broadcast_coalesced(self.modules_buffers[0], self.device_ids, @@ -388,4 +430,6 @@ class DistributedDataParallel(Module): for dev_idx, module in enumerate(module_copies): for layer in module.modules(): if isinstance(layer, torch.nn.modules.SyncBatchNorm): - layer._specify_ddp_gpu_num(len(self.device_ids)) + assert self.is_cuda, "SyncBatchNorm layers only work with CUDA modules" + layer._specify_ddp_gpu_num( + len(self.device_ids) if self.device_ids else 1)
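
For reference, the end-to-end usage this change enables looks roughly like the
sketch below. It is modeled on the DoubleGpuNet test above and is not part of
the patch itself: the env:// rendezvous (MASTER_ADDR, MASTER_PORT and friends
are assumed to be set), the "two GPUs per rank" device assignment, and the
TwoDeviceNet / run names are illustrative assumptions only.

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.nn.parallel import DistributedDataParallel


    class TwoDeviceNet(nn.Module):
        # Pipeline-style module whose parameters live on two CUDA devices,
        # in the spirit of the DoubleGpuNet test module above.
        def __init__(self, dev0, dev1):
            super(TwoDeviceNet, self).__init__()
            self.fc1 = nn.Linear(2, 10, bias=False).to(dev0)
            self.fc2 = nn.Linear(10, 4, bias=False).to(dev1)

        def forward(self, x):
            x = F.relu(self.fc1(x.to(self.fc1.weight.device)))
            return self.fc2(x.to(self.fc2.weight.device))


    def run(rank, world_size):
        # Assumes world_size * 2 visible GPUs; the 2*rank / 2*rank + 1
        # assignment is only an illustration.
        dist.init_process_group("nccl", init_method="env://",
                                rank=rank, world_size=world_size)
        dev0 = torch.device("cuda", 2 * rank)
        dev1 = torch.device("cuda", 2 * rank + 1)

        model = TwoDeviceNet(dev0, dev1)
        # device_ids and output_device are deliberately omitted: for
        # multi-device modules they must be None, and the module itself
        # dictates where inputs and outputs live.
        ddp_model = DistributedDataParallel(model)

        inp = torch.randn(4, 2).to(dev0)
        target = torch.randn(4, 4)
        output = ddp_model(inp)
        loss = F.mse_loss(output, target.to(output.device))
        loss.backward()  # gradients are still all-reduced across ranks

As in the updated tests, the loss is computed on output.device, so the caller
never needs to know how the module is split internally.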
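
The constructor arguments keep their old meaning only for single-device CUDA
modules. The short sketch below (again illustrative, assuming an already
initialized default process group and two visible GPUs for this rank)
contrasts the unchanged replicated path with the combinations the new
assertions reject; only the constructor behaviour is exercised here.

    import torch
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel

    dev0, dev1 = torch.device("cuda", 0), torch.device("cuda", 1)

    # Unchanged: a single-device module replicated over this process's GPUs
    # via device_ids, with scatter/gather on output_device as before.
    replicated = DistributedDataParallel(nn.Linear(2, 4).to(dev0),
                                         device_ids=[dev0, dev1],
                                         output_device=dev0)

    # Rejected after this change: device_ids / output_device together with a
    # module that spans several devices (or lives on CPU). The module here is
    # only meant to trip the constructor check; a usable multi-device module
    # also has to move activations between devices in forward(), as
    # DoubleGpuNet does.
    spread = nn.Sequential(nn.Linear(2, 10).to(dev0),
                           nn.Linear(10, 4).to(dev1))
    try:
        DistributedDataParallel(spread, device_ids=[dev0, dev1])
    except AssertionError as exc:
        print(exc)  # "... only work with single-device CUDA modules ..."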