Allow DDP to wrap multi-GPU modules (#19271)
authorShen Li <shenli@fb.com>
Thu, 18 Apr 2019 04:18:49 +0000 (21:18 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 18 Apr 2019 04:21:54 +0000 (21:21 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19271

Allow DistributedDataParallel (DDP) to wrap modules whose parameters span multiple GPUs. For such multi-device modules (and for CPU modules), device_ids and output_device must be left unset; DDP skips intra-process replication, scatter, and gather, and the wrapped module itself is responsible for placing inputs and outputs on the right devices.
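
A minimal usage sketch of what this change enables, assuming torch.distributed.init_process_group has already been called and each process sees at least two GPUs. ToyPipelineNet and the device indices below are illustrative only, not part of this change:

    import torch
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel


    class ToyPipelineNet(nn.Module):
        """Toy module whose parameters deliberately span two CUDA devices."""

        def __init__(self, dev0, dev1):
            super(ToyPipelineNet, self).__init__()
            self.fc1 = nn.Linear(2, 10).to(dev0)
            self.fc2 = nn.Linear(10, 4).to(dev1)

        def forward(self, x):
            # The module itself moves activations between its devices.
            x = torch.relu(self.fc1(x.to(self.fc1.weight.device)))
            return self.fc2(x.to(self.fc2.weight.device))


    # torch.distributed.init_process_group(...) is assumed to have run already.
    model = ToyPipelineNet(torch.device('cuda:0'), torch.device('cuda:1'))

    # For a multi-device module, device_ids and output_device must be left
    # unset; DDP only handles gradient all-reduce and parameter broadcast.
    ddp_model = DistributedDataParallel(model)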

Reviewed By: pietern

Differential Revision: D14822375

fbshipit-source-id: 1eebfaa33371766d3129f0ac6f63a573332b2f1c

test/test_c10d.py
torch/nn/parallel/distributed.py

index 0d1bc22..3f9dcf6 100644
@@ -57,6 +57,18 @@ def skip_if_not_multigpu(func):
     return wrapper
 
 
+def skip_if_lt_x_gpu(x):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
+                return func(*args, **kwargs)
+            sys.exit(TEST_SKIPS['multi-gpu'].exit_code)
+        return wrapper
+
+    return decorator
+
+
 def skip_if_not_nccl(func):
     """Skips a test if NCCL is not available (for c10d)."""
     @wraps(func)
@@ -1502,6 +1514,48 @@ class Net(nn.Module):
         return F.softmax(x, dim=1)
 
 
+class DoubleGpuNet(nn.Module):
+    def __init__(self, gpus):
+        super(DoubleGpuNet, self).__init__()
+        self.fc1 = nn.Linear(2, 10, bias=False).to(gpus[0])
+        self.fc2 = nn.Linear(10, 50, bias=False).to(gpus[1])
+        self.fc3 = nn.Linear(50, 4, bias=False).to(gpus[1])
+        self.relu = nn.ReLU()
+        self.no_grad_param = nn.Parameter(torch.Tensor([2, 2]).long(),
+                                          requires_grad=False).to(gpus[0])
+
+    def forward(self, x):
+        dev0 = self.fc1.weight.device
+        dev1 = self.fc2.weight.device
+        x = self.relu(self.fc1(x.to(dev0)))
+        x = self.relu(self.fc2(x.to(dev1)))
+        x = self.fc3(x)
+        return F.softmax(x, dim=1).to(dev0)
+
+
+class QuadraGpuNet(nn.Module):
+    def __init__(self, gpus):
+        super(QuadraGpuNet, self).__init__()
+        self.fc1 = nn.Linear(2, 10, bias=False).to(gpus[0])
+        self.fc2 = nn.Linear(10, 50, bias=False).to(gpus[1])
+        self.fc3 = nn.Linear(50, 4, bias=False).to(gpus[2])
+        self.fc4 = nn.Linear(4, 4, bias=False).to(gpus[3])
+        self.relu = nn.ReLU()
+        self.no_grad_param = nn.Parameter(torch.Tensor([2, 2]).long(),
+                                          requires_grad=False).to(gpus[0])
+
+    def forward(self, x):
+        dev0 = self.fc1.weight.device
+        dev1 = self.fc2.weight.device
+        dev2 = self.fc3.weight.device
+        dev3 = self.fc4.weight.device
+        x = self.relu(self.fc1(x.to(dev0)))
+        x = self.relu(self.fc2(x.to(dev1)))
+        x = self.relu(self.fc3(x.to(dev2)))
+        x = self.fc4(x.to(dev3))
+        return F.softmax(x, dim=1).to(dev0)
+
+
 class DistributedDataParallelTest(MultiProcessTestCase):
 
     def tearDown(self):
@@ -1517,7 +1571,7 @@ class DistributedDataParallelTest(MultiProcessTestCase):
     def world_size(self):
         return 2
 
-    def _test_ddp_with_process_group(self, process_group, gpus):
+    def _prepare_single_device_module(self, process_group, gpus, global_batch_size):
         model = Net()
         ddp_model = DistributedDataParallel(
             copy.deepcopy(model).cuda(gpus[0]),
@@ -1527,15 +1581,47 @@ class DistributedDataParallelTest(MultiProcessTestCase):
 
         model.cuda(gpus[0])
 
-        local_batch_size = len(gpus)
-        global_batch_size = self.world_size * local_batch_size
         input = torch.randn(global_batch_size, 2).cuda(gpus[0])
         target = torch.randn(global_batch_size, 4).cuda(gpus[0])
 
+        return model, ddp_model, input, target
+
+    def _prepare_multi_device_module(self, process_group, gpus, global_batch_size):
+        self.assertTrue(
+            len(gpus) == 2 or len(gpus) == 4,
+            "unexpected devices for ddp tests {}".format(gpus))
+        if len(gpus) == 2:
+            model = DoubleGpuNet(gpus)
+        elif len(gpus) == 4:
+            model = QuadraGpuNet(gpus)
+
+        ddp_model = DistributedDataParallel(
+            copy.deepcopy(model),
+            process_group=process_group,
+            bucket_cap_mb=0.001)
+
+        input = torch.randn(global_batch_size, 2).to(gpus[0])
+        target = torch.randn(global_batch_size, 4)
+
+        return model, ddp_model, input, target
+
+    def _test_ddp_with_process_group(self, process_group, gpus, multi_gpu=False):
+        local_batch_size = len(gpus)
+        global_batch_size = self.world_size * local_batch_size
+
+        if multi_gpu:
+            model, ddp_model, input, target = \
+                self._prepare_multi_device_module(
+                    process_group, gpus, global_batch_size)
+        else:
+            model, ddp_model, input, target = \
+                self._prepare_single_device_module(
+                    process_group, gpus, global_batch_size)
+
         def step_model(model, input, target):
             model.train()
             output = model(input)
-            loss = F.mse_loss(output, target)
+            loss = F.mse_loss(output, target.to(output.device))
             loss.backward()
 
         def update_parameters(model):
@@ -1564,24 +1650,117 @@ class DistributedDataParallelTest(MultiProcessTestCase):
             torch.manual_seed(1337 + iteration)
             input = input[torch.randperm(global_batch_size)]
 
-    @skip_if_not_multigpu
-    def test_gloo_backend(self):
+    def _test_gloo_backend(self, gpus, multi_gpu=False, use_str=False):
+        if use_str:
+            gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus))
         store = c10d.FileStore(self.file.name, self.world_size)
         options = c10d.ProcessGroupGloo.Options()
         options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
         process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
+        self._test_ddp_with_process_group(process_group, gpus, multi_gpu)
+
+    @skip_if_not_multigpu
+    def test_gloo_backend(self):
         gpus = gpus_for_rank(self.world_size)[self.rank]
-        self._test_ddp_with_process_group(process_group, gpus)
-        self._test_ddp_with_process_group(process_group, list(map(lambda i: torch.device('cuda:' + str(i)), gpus)))
+        self._test_gloo_backend(gpus)
+
+    @skip_if_not_multigpu
+    def test_gloo_backend_str(self):
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_gloo_backend(gpus, use_str=True)
+
+    @skip_if_lt_x_gpu(4)
+    def test_gloo_backend_2gpu_module(self):
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_gloo_backend(gpus[:2], multi_gpu=True)
+
+    @skip_if_lt_x_gpu(4)
+    def test_gloo_backend_2gpu_module_str(self):
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_gloo_backend(gpus[:2], multi_gpu=True, use_str=True)
+
+    @skip_if_lt_x_gpu(8)
+    def test_gloo_backend_4gpu_module(self):
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_gloo_backend(gpus[:4], multi_gpu=True)
+
+    @skip_if_lt_x_gpu(8)
+    def test_gloo_backend_4gpu_module_str(self):
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_gloo_backend(gpus[:4], multi_gpu=True, use_str=True)
+
+    def _test_nccl_backend(self, gpus, multi_gpu=False, use_str=False):
+        if use_str:
+            gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus))
+        store = c10d.FileStore(self.file.name, self.world_size)
+        process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
+        self._test_ddp_with_process_group(process_group, gpus, multi_gpu)
 
     @skip_if_not_multigpu
     @skip_if_not_nccl
     def test_nccl_backend(self):
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_nccl_backend(gpus)
+
+    @skip_if_not_multigpu
+    @skip_if_not_nccl
+    def test_nccl_backend_str(self):
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_nccl_backend(gpus, use_str=True)
+
+    @skip_if_lt_x_gpu(4)
+    @skip_if_not_nccl
+    def test_nccl_backend_2gpu_module(self):
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_nccl_backend(gpus[:2], multi_gpu=True)
+
+    @skip_if_lt_x_gpu(4)
+    @skip_if_not_nccl
+    def test_nccl_backend_2gpu_module_str(self):
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_nccl_backend(gpus[:2], multi_gpu=True, use_str=True)
+
+    @skip_if_lt_x_gpu(8)
+    @skip_if_not_nccl
+    def test_nccl_backend_4gpu_module(self):
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_nccl_backend(gpus[:4], multi_gpu=True)
+
+    @skip_if_lt_x_gpu(8)
+    @skip_if_not_nccl
+    def test_nccl_backend_4gpu_module_str(self):
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_nccl_backend(gpus[:4], multi_gpu=True, use_str=True)
+
+    @skip_if_lt_x_gpu(4)
+    @skip_if_not_nccl
+    def test_ddp_multi_device_module_config(self):
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+
+        self.assertTrue(len(gpus) >= 2, "expecting at least 2 gpus per process")
+
         store = c10d.FileStore(self.file.name, self.world_size)
         process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
-        gpus = gpus_for_rank(self.world_size)[self.rank]
-        self._test_ddp_with_process_group(process_group, gpus)
-        self._test_ddp_with_process_group(process_group, list(map(lambda i: torch.device('cuda:' + str(i)), gpus)))
+
+        gpus = gpus[:2]
+        model = DoubleGpuNet(gpus)
+
+        with self.assertRaisesRegex(AssertionError, "output_device .* single-device CUDA"):
+            ddp_model = DistributedDataParallel(
+                model, output_device=gpus[1], process_group=process_group)
+
+        with self.assertRaisesRegex(AssertionError, "device_ids .* single-device CUDA"):
+            ddp_model = DistributedDataParallel(
+                model, device_ids=gpus, process_group=process_group)
+
+        with self.assertRaisesRegex(AssertionError, "only works with CUDA devices"):
+            model.fc1 = model.fc1.cpu()
+            ddp_model = DistributedDataParallel(model, process_group=process_group)
+
+        model = model.cpu()
+        with self.assertRaisesRegex(AssertionError, "device_ids .* single-device CUDA"):
+            ddp_model = DistributedDataParallel(
+                model, device_ids=gpus, process_group=process_group)
 
     @skip_if_not_multigpu
     @skip_if_not_nccl
torch/nn/parallel/distributed.py
index ddca4a6..d2b4fe9 100644
@@ -153,11 +153,22 @@ class DistributedDataParallel(Module):
 
     Args:
         module (Module): module to be parallelized
-        device_ids (list of int or torch.device): CUDA devices (default: all devices)
-        output_device (int or torch.device): device location of output (default: device_ids[0])
+        device_ids (list of int or torch.device): CUDA devices. This should
+                   only be provided when the input module resides on a single
+                   CUDA device. For single-device modules, the ``i``th
+                   :attr:`module` replica is placed on ``device_ids[i]``. For
+                   multi-device modules and CPU modules, device_ids must be None
+                   or an empty list, and input data for the forward pass must be
+                   placed on the correct device. (default: all devices for
+                   single-device modules)
+        output_device (int or torch.device): device location of output for
+                      single-device CUDA modules. For multi-device modules and
+                      CPU modules, it must be None, and the module itself
+                      dictates the output location. (default: device_ids[0] for
+                      single-device modules)
         broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
-                           the module at beginning of the forward function.
-                           (default: ``True``)
+                          the module at beginning of the forward function.
+                          (default: ``True``)
         process_group: the process group to be used for distributed data
                        all-reduction. If ``None``, the default process group, which
                        is created by ```torch.distributed.init_process_group```,
@@ -192,12 +203,35 @@ class DistributedDataParallel(Module):
 
         super(DistributedDataParallel, self).__init__()
 
-        # Use all devices by default
-        if device_ids is None:
-            device_ids = list(range(torch.cuda.device_count()))
+        self.is_multi_device_module = len({p.device for p in module.parameters()}) > 1
+        self.is_cuda = all([p.device.type == 'cuda' for p in module.parameters()])
 
-        if output_device is None:
-            output_device = device_ids[0]
+        if not self.is_cuda or self.is_multi_device_module:
+            assert not device_ids and not output_device, (
+                "DistributedDataParallel device_ids and output_device arguments "
+                "only work with single-device CUDA modules, but got "
+                "device_ids {}, output_device {}, and module parameters {}."
+            ).format(device_ids, output_device, {p.device for p in module.parameters()})
+
+            self.device_ids = None
+            self.output_device = None
+        else:
+            # Use all devices by default for single-device CUDA modules
+            if device_ids is None:
+                device_ids = list(range(torch.cuda.device_count()))
+
+            self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+
+            if output_device is None:
+                output_device = device_ids[0]
+
+            self.output_device = _get_device_index(output_device, True)
+
+        if self.is_multi_device_module:
+            assert self.is_cuda, (
+                "DistributedDataParallel with multi-device module only works "
+                "with CUDA devices, but module parameters are located on {}."
+            ).format({p.device for p in module.parameters()})
 
         if process_group is None:
             self.process_group = _get_default_group()
@@ -206,9 +240,8 @@ class DistributedDataParallel(Module):
 
         self.dim = dim
         self.module = module
-        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
-        self.output_device = _get_device_index(output_device, True)
         self.broadcast_buffers = broadcast_buffers
+
         if check_reduction:
             # This argument is no longer used since the reducer
             # will ensure reduction completes even if some parameters
@@ -241,7 +274,9 @@ class DistributedDataParallel(Module):
         (4) registering the grad hooks
         (5) passing a handle of DDP to SyncBatchNorm Layer
         """
-        if len(self.device_ids) > 1:
+        if self.device_ids and len(self.device_ids) > 1:
+            # only create replicas for single-device CUDA modules
+            #
             # TODO: we don't need to replicate params in here. they're always going to
             # be broadcasted using larger blocks in broadcast_coalesced, so it might be
             # better to not pollute the caches with these small blocks
@@ -311,13 +346,16 @@ class DistributedDataParallel(Module):
                                "process_group argument to DDP constructor")
 
     def forward(self, *inputs, **kwargs):
-        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
         self._sync_params()
-        if len(self.device_ids) == 1:
-            output = self.module(*inputs[0], **kwargs[0])
+        if self.device_ids:
+            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+            if len(self.device_ids) == 1:
+                output = self.module(*inputs[0], **kwargs[0])
+            else:
+                outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
+                output = self.gather(outputs, self.output_device)
         else:
-            outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
-            output = self.gather(outputs, self.output_device)
+            output = self.module(*inputs, **kwargs)
 
         # We'll return the output object verbatim since it is a freeform object.
         # We need to find any tensors in this object, though, because we need to
@@ -352,7 +390,9 @@ class DistributedDataParallel(Module):
 
     def _sync_params(self):
         with torch.no_grad():
-            if len(self.device_ids) > 1:
+            # only do intra-node parameters sync for replicated single-device
+            # CUDA modules
+            if self.device_ids and len(self.device_ids) > 1:
                 # intra-node parameter sync
                 result = broadcast_coalesced(self.modules_params[0],
                                              self.device_ids,
@@ -374,7 +414,9 @@ class DistributedDataParallel(Module):
                 # cross-node buffer sync
                 self._dist_broadcast_coalesced(self.modules_buffers[0],
                                                self.broadcast_bucket_size)
-                if len(self.device_ids) > 1:
+                # only do intra-node buffer sync for replicated single-device
+                # CUDA modules
+                if self.device_ids and len(self.device_ids) > 1:
                     # intra-node buffer sync
                     result = broadcast_coalesced(self.modules_buffers[0],
                                                  self.device_ids,
@@ -388,4 +430,6 @@ class DistributedDataParallel(Module):
         for dev_idx, module in enumerate(module_copies):
             for layer in module.modules():
                 if isinstance(layer, torch.nn.modules.SyncBatchNorm):
-                    layer._specify_ddp_gpu_num(len(self.device_ids))
+                    assert self.is_cuda, "SyncBatchNorm layers only work with CUDA modules"
+                    layer._specify_ddp_gpu_num(
+                        len(self.device_ids) if self.device_ids else 1)
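
For reference, a hedged sketch of a training step against a multi-device DDP module, mirroring the updated tests above: the forward call passes inputs through verbatim (no scatter/gather), the output lands wherever the wrapped module puts it, and the caller moves the target to that device. ddp_model, optimizer, input, and target are assumed to come from a setup like _prepare_multi_device_module:

    import torch.nn.functional as F

    def step_multi_device(ddp_model, optimizer, input, target):
        ddp_model.train()
        # No scatter: the whole batch goes straight into the wrapped module.
        output = ddp_model(input)
        # The output device is dictated by the module, so move the target to it.
        loss = F.mse_loss(output, target.to(output.device))
        loss.backward()  # gradients are still all-reduced across ranks
        optimizer.step()
        optimizer.zero_grad()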