Support replicating multi-GPU modules (#18687)
authorShen Li <shenli@fb.com>
Wed, 3 Apr 2019 21:37:54 +0000 (14:37 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 21:43:07 +0000 (14:43 -0700)
Summary:
If the input `network` resides on multiple GPUs, `devices` must be a 2D list with `devices[0]` matching `network`'s devices. See  #18591
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18687

Differential Revision: D14706162

Pulled By: mrshenli

fbshipit-source-id: dca630d3308f2dbcf8b75629c452d7a64092ba42

test/common_cuda.py
test/test_nn.py
torch/nn/parallel/replicate.py

index 19750df..7a7110e 100644 (file)
@@ -7,6 +7,7 @@ from common_utils import TEST_WITH_ROCM, TEST_NUMBA
 
 TEST_CUDA = torch.cuda.is_available()
 TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
+TEST_GEQ4GPU = TEST_CUDA and torch.cuda.device_count() >= 4
 CUDA_DEVICE = TEST_CUDA and torch.device("cuda:0")
 # note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
 TEST_CUDNN = TEST_CUDA and (TEST_WITH_ROCM or torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))
index 5ba8e62..2ba7167 100644 (file)
@@ -30,7 +30,7 @@ from torch.nn.parallel._functions import Broadcast
 from common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
     TEST_NUMPY, TEST_SCIPY, download_file, PY3, PY34, to_gpu, \
     get_function_arglist, load_tests
-from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
+from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_GEQ4GPU, TEST_CUDNN, \
     TEST_CUDNN_VERSION
 from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
     module_tests, criterion_tests, loss_reference_fns, get_reduction, \
@@ -3336,23 +3336,157 @@ class TestNN(NNTestCase):
         module = nn.Linear(10, 5).float().cuda()
         input = Variable(torch.randn(2, 10).float().cuda())
         expected_output = module(input).data
-        replicas = dp.replicate(module, (0, 1))
-        for i, replica in enumerate(replicas):
-            for p in replica.parameters():
-                self.assertEqual(p.get_device(), i)
-            replica_input = input.cuda(i)
-            self.assertEqual(replica(replica_input).data, expected_output)
+        for devices in [(0, 1), [[0], [1]]]:
+            replicas = dp.replicate(module, devices)
+            for i, replica in enumerate(replicas):
+                for p in replica.parameters():
+                    self.assertEqual(p.get_device(), i)
+                replica_input = input.cuda(i)
+                self.assertEqual(replica(replica_input).data, expected_output)
+
+    @unittest.skipIf(not TEST_GEQ4GPU, "less than 4 GPUs")
+    def test_replicate_multi_gpu_module(self):
+        class MultiGpuModule(nn.Module):
+            def __init__(self):
+                super(MultiGpuModule, self).__init__()
+                self.net1 = torch.nn.Linear(10, 5).cuda(0)
+                self.net2 = torch.nn.Linear(5, 5).cuda(1)
+                self.bn = nn.BatchNorm2d(10).cuda(0)
+
+            def forward(self, x):
+                out = self.net1(x.cuda(self.net1.weight.get_device()))
+                return self.net2(out.cuda(self.net2.weight.get_device()))
+
+        module = MultiGpuModule()
+
+        input = torch.rand(2, 10).cuda(0)
+        expected_output = module(input).cpu()
+
+        for devices in ([[0, 1], [2, 3]], [[1, 0], [3, 2]]):
+            replicas = dp.replicate(module, devices)
+            for i, replica in enumerate(replicas):
+                self.assertEqual(replica.net1.weight.get_device(), 2 * i)
+                self.assertEqual(replica.net1.bias.get_device(), 2 * i)
+                self.assertEqual(replica.net2.weight.get_device(), 2 * i + 1)
+                self.assertEqual(replica.net2.bias.get_device(), 2 * i + 1)
+                self.assertEqual(replica.bn.running_mean.get_device(), 2 * i)
+                self.assertEqual(replica.bn.running_var.get_device(), 2 * i)
+                self.assertEqual(
+                    replica.bn.num_batches_tracked.get_device(), 2 * i)
+
+                replica_input = input.cuda(2 * i)
+                replica_output = replica(replica_input)
+                self.assertEqual(replica_output.get_device(), 2 * i + 1)
+                self.assertEqual(replica_output.cpu(), expected_output)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_replicate_device_indices(self):
+        from torch.nn.parallel.replicate import _to_device_index as f
+
+        self.assertEqual(
+            f([['cuda:0', 'cuda:1', 'cuda:2'],
+               ['cuda:4', 'cuda:3', 'cuda:6']]),
+            [[0, 1, 2], [4, 3, 6]])
+
+        self.assertEqual(f(('cuda:0', 'cuda:1', 'cuda:2')), [0, 1, 2])
+
+        self.assertEqual(
+            len(set([0, 1, 2]).intersection(f({'cuda:0', 'cuda:1', 'cuda:2'}))),
+            3)
+        self.assertEqual(
+            f([['cuda:0'], ['cuda:1'], ['cuda:2']]), [[0], [1], [2]])
+
+        msg = "empty device list"
+        for devices in (None, (), [], [[]]):
+            with self.assertRaisesRegex(RuntimeError, msg):
+                f(devices)
+
+        msg = "unidentical number of devices"
+        for devices in ([[0, 1], [2]], [[0], [1, 2]]):
+            with self.assertRaisesRegex(AssertionError, msg):
+                f(devices)
+
+        msg = "shared by multiple replicas"
+        for devices in ([[0, 1], [1, 2]], [[0], [1], [0]]):
+            with self.assertRaisesRegex(AssertionError, msg):
+                f(devices)
+
+        msg = "Duplicated device ids"
+        for devices in ([[0, 1, 2, 1]], [0, 1, 1], [0, 0]):
+            with self.assertRaisesRegex(AssertionError, msg):
+                f(devices)
+
+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    def test_replicate_tensor_grouping_multi_gpu(self):
+        from torch.nn.parallel.replicate import _group_by_device
+
+        a = torch.Tensor(1).cuda(0)
+        b = torch.Tensor(2).cuda(0)
+        c = torch.Tensor(3).cuda(1)
+        d = torch.Tensor(4).cuda(0)
+        e = torch.Tensor(5).cuda(1)
+
+        tensors = [a, b, c, d, e]
+        for devices in ([[0, 1], [2, 3]], [[1, 4, 0], [3, 5, 2]]):
+            grouped_tensors, grouped_devices, original_index = \
+                _group_by_device(tensors, devices)
+
+            self.assertEqual(grouped_tensors, [[a, b, d], [c, e]])
+            self.assertEqual(grouped_devices, [[0, 2], [1, 3]])
+            self.assertEqual(original_index, [[0, 1, 3], [2, 4]])
+
+        msg = "missing from devices"
+        for devices in ([[0, 2], [1, 3]], [[1, 2], [0, 3]], [[2, 3], [0, 1]]):
+            with self.assertRaisesRegex(AssertionError, msg):
+                grouped_tensors, grouped_devices, original_index = \
+                    _group_by_device(tensors, devices)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_replicate_tensor_grouping(self):
+        from torch.nn.parallel.replicate import _group_by_device
+
+        a = torch.Tensor(1).cuda(0)
+        b = torch.Tensor(2).cuda(0)
+        c = torch.Tensor(3).cuda(0)
+
+        tensors = [a, b, c]
+
+        grouped_tensors, grouped_devices, original_index = \
+            _group_by_device(tensors, [0, 1])
+
+        self.assertEqual(grouped_tensors, [[a, b, c]])
+        self.assertEqual(grouped_devices, [[0, 1]])
+        self.assertEqual(original_index, [[0, 1, 2]])
+
+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    def test_replicate_reshape(self):
+        from torch.nn.parallel.replicate import _broadcast_coalesced_reshape
+
+        a = torch.Tensor(1).cuda(0)
+        b = torch.Tensor(2).cuda(0)
+        c = torch.Tensor(3).cuda(1)
+        d = torch.Tensor(4).cuda(0)
+        e = torch.Tensor(5).cuda(1)
+
+        tensors = [a, b, c, d, e]
+        outputs = _broadcast_coalesced_reshape(tensors, [[0, 1], [1, 0]])
+
+        self.assertEqual(len(outputs), 2)
+        self.assertEqual(outputs[0], [a, b, c, d, e])
+        self.assertEqual(
+            outputs[1], [a.cuda(1), b.cuda(1), c.cuda(0), d.cuda(1), e.cuda(0)])
 
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_replicate_buffers(self):
         net = nn.Module()
         net.bn = nn.BatchNorm2d(10)
         net.cuda()
-        replicas = dp.replicate(net, (0, 1))
-        for i, replica in enumerate(replicas):
-            self.assertEqual(replica.bn.running_mean.get_device(), i, 'buffer on wrong device')
-            self.assertEqual(replica.bn.running_var.get_device(), i, 'buffer on wrong device')
-            self.assertEqual(replica.bn.num_batches_tracked.get_device(), i, 'buffer on wrong device')
+        for devices in [(0, 1), [[0], [1]]]:
+            replicas = dp.replicate(net, devices)
+            for i, replica in enumerate(replicas):
+                self.assertEqual(replica.bn.running_mean.get_device(), i, 'buffer on wrong device')
+                self.assertEqual(replica.bn.running_var.get_device(), i, 'buffer on wrong device')
+                self.assertEqual(replica.bn.num_batches_tracked.get_device(), i, 'buffer on wrong device')
 
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     @skipIfRocm
index cc270a1..a68ade4 100644 (file)
@@ -56,6 +56,34 @@ def _replicatable_module(module, memo=None):
     return True
 
 
+def _to_device_index(devices):
+    if not devices:
+        raise RuntimeError("Cannot replicate using an empty device list.")
+
+    if isinstance(devices, list) and isinstance(devices[0], list):
+        device_ids = []
+        seen = set()
+        for i, replica_devs in enumerate(devices):
+            assert len(replica_devs) == len(devices[0]), (
+                "Cannot replicate to unidentical number of devices, but got "
+                "device list {} and {} for replica {} and {}."
+            ).format(devices[0], devices[i], 0, i)
+
+            assert len(seen.intersection(replica_devs)) == 0, (
+                "Devices {} are shared by multiple replicas."
+            ).format(seen.intersection(replica_devs))
+            seen.update(replica_devs)
+
+            device_ids.append(_to_device_index(replica_devs))
+        return device_ids
+    else:
+        assert len(devices) == len(set(devices)), (
+            "Duplicated device ids {}."
+        ).format(devices)
+
+        return list(map(lambda x: _get_device_index(x, True), devices))
+
+
 def _build_param_dict(modules, module_copies, module_indices):
     param_dict = {}
     for module in modules:
@@ -83,26 +111,169 @@ def _copy_scriptmodule_methods(modules, module_copies, module_indices):
             replica._copy_method(method_name, param_list, module)
 
 
+# Group tensors on the same device together, which can later be broadcast to
+# a list of devices. For example,consider 5 tensors on 2 devices
+#   a = torch.Tensor(0).cuda(0)
+#   b = torch.Tensor(0).cuda(0)
+#   c = torch.Tensor(0).cuda(1)
+#   d = torch.Tensor(0).cuda(0)
+#   e = torch.Tensor(0).cuda(1).
+# Let inputs be
+#   tensors = [a, b, c, d, e] and
+#   devices = [[0, 1], [2, 3]].
+# Then, outputs will be:
+#   grouped_tensors = [[a, b, d], [c, e]],
+#   grouped_devices = [[0, 2], [1, 3]],
+#   original_index = [[0, 1, 3], [2, 4]],
+# meaning that grouped_tensors[i] will be broadcast to grouped_devices[i].
+def _group_by_device(tensors, devices):
+    if isinstance(devices[0], list):
+        # all tensor devices must appear in devices[0]
+        missing_devs = [t.device.index for t in tensors
+                        if t.device.index not in devices[0]]
+        assert not missing_devs, (
+            "tensor devices {} are missing from devices[0] {}."
+        ).format(missing_devs, devices[0])
+
+        # device id to output group index, this is necessary when `tensors` only
+        # use a subset of devices in `devices[0]`
+        dev_to_group_idx = {}
+        for t in tensors:
+            if t.device.index not in dev_to_group_idx:
+                dev_to_group_idx[t.device.index] = len(dev_to_group_idx)
+
+        # Group tensors by devices and remember each tensor's original index.
+        # The original_index helps to recover the original input tensor order
+        # from grouped tensors.
+        grouped_tensors = [[] for _ in range(len(dev_to_group_idx))]
+        original_index = [[] for _ in range(len(dev_to_group_idx))]
+        for i, t in enumerate(tensors):
+            group_id = dev_to_group_idx[t.device.index]
+            original_index[group_id].append(i)
+            grouped_tensors[group_id].append(t)
+
+        # group devices together if they should be in the same broadcast call
+        grouped_devices = [[] for _ in range(len(dev_to_group_idx))]
+        transpose = list(zip(*devices))
+        for row in transpose:
+            if row[0] in dev_to_group_idx:
+                grouped_devices[dev_to_group_idx[row[0]]] = list(row)
+
+        return grouped_tensors, grouped_devices, original_index
+    else:
+        return [tensors], [devices], [list(range(len(tensors)))]
+
+
+# Return len(devices) replicas of input tensors. If input tensors reside on
+# multiple GPUs, devices must be a 2D list with devices[0] matching input
+# tensors' devices. For example,consider 5 tensors on 2 devices
+#   a = torch.Tensor(0).cuda(0)
+#   b = torch.Tensor(0).cuda(0)
+#   c = torch.Tensor(0).cuda(1)
+#   d = torch.Tensor(0).cuda(0)
+#   e = torch.Tensor(0).cuda(1).
+# Let inputs be
+#   tensors = [a, b, c, d, e] and
+#   devices = [[0, 1], [2, 3]].
+#
+# The output will be a 2D list of tensors:
+#   [[a0, b0, c0, d0, e0],
+#    [a1, b1, c1, d1, e1]], where
+# a0, b0, d0 are on device 0
+# a1, b1, d1 are on device 2
+# c0, e0 are on device 1
+# c1, e1 are on device 3
+#
+# This example will be used throughout the implementation of this function.
 def _broadcast_coalesced_reshape(tensors, devices, detach=False):
     from ._functions import Broadcast
-    if detach:
-        return comm.broadcast_coalesced(tensors, devices)
-    else:
-        # Use the autograd function to broadcast if not detach
-        if len(tensors) > 0:
-            tensor_copies = Broadcast.apply(devices, *tensors)
-            return [tensor_copies[i:i + len(tensors)]
-                    for i in range(0, len(tensor_copies), len(tensors))]
+
+    # a triply-nested list of 1) broadcast group, 2) tensor list replica,
+    # 3) tensors on the same device.
+    grouped_replicas = []
+    grouped_tensors, grouped_devices, original_index = \
+        _group_by_device(tensors, devices)
+    # For the example input described above, we have
+    # grouped_tensors =[[a, b, d], [c, e]]
+    # grouped_devices = [[0, 2], [1, 3]]
+    # original_index = [[0, 1, 3], [2, 4]]
+    for tensor_group, device_group in zip(grouped_tensors, grouped_devices):
+        if detach:
+            grouped_replicas.append(
+                comm.broadcast_coalesced(tensor_group, device_group))
         else:
-            return []
+            if len(tensor_group) > 0:
+                # Use the autograd function to broadcast if not detach
+                tensor_copies = Broadcast.apply(device_group, *tensor_group)
+                grouped_replicas.append(
+                    [tensor_copies[i:i + len(tensor_group)]
+                        for i in range(
+                            0, len(tensor_copies), len(tensor_group))])
+            else:
+                grouped_replicas.append([])
+
+    if isinstance(devices[0], list):
+        # convert the triply-nested list into a doubly-nested list of 1) replica
+        # 2) tensors in the same replica (can be on different devices)
+        #
+        # For the example input described above, we have
+        #   grouped_replicas = [
+        #       [[a0, b0, d0],   # on device 0
+        #        [a1, b1, d1]],  # on device 2
+        #       [[c0, e0],       # on device 1
+        #        [c1, e1]]       # on device 3
+        #   ]
+        #
+        # The code below re-organize elements in grouped_replicas to the
+        # expected form:
+        #   [[a0, b0, c0, d0, e0],
+        #    [a1, b1, c1, d1, e1]].
+        transpose = [0 for _ in tensors]
+        for g_idx in range(len(original_index)):
+            for t_idx in range(len(original_index[g_idx])):
+                # g_idx is the broadcast group index.
+                # t_idx is the tensor's index in a replica within a group.
+                # Tensors in grouped_replicas[g_idx, :, t_idx] are replicas of
+                # input tensor[original_index[g_idx][t_idx]]. Retrieve the
+                # column and add it as the original_index[g_idx][t_idx]'s row in
+                # transpose.
+                transpose[original_index[g_idx][t_idx]] = \
+                    [replica[t_idx] for replica in grouped_replicas[g_idx]]
+
+        # transpose the result to stay consistent with the 1D devices case.
+        return list(zip(*transpose))
+    else:
+        return grouped_replicas[0]
 
 
 def replicate(network, devices, detach=False):
+    r"""Replicate the input :attr:`network` to given :attr:`devices`. If
+    :attr:`network` resides on CPU or a single GPU, :attr:`devices` must be a 1D
+    list of destination devices. If :attr:`network` resides on multiple GPUs,
+    :attr:`devices` must be satisfy the following conditions:
+
+    1. :attr:`devices` must be a 2D list,
+    2. ``devices[0]`` must match the :attr:`network`'s devices, in any order.
+    3. All ``devices[i]`` must have the same length.
+
+    For example, :attr:`network` is a ``Sequential`` module with two ``Linear``
+    layers stored on ``cuda:0`` and ``cuda:1`` respectively. Setting
+    :attr:`devices` to ``[[0, 1], [2, 3], [4, 5]]`` will replicate
+    :attr:`network` three times with replicas stored on devices
+    ``[cuda:0, cuda:1]``, ``[cuda:2, cuda:3]``, and ``[cuda:4, cuda:5]``
+    respectively.
+
+
+    Args:
+        network (Module): modules to be replicate
+        devices (1D or 2D list of int or torch.device): CUDA devices
+        detach (bool, optional): detached replicas from the current graph.
+    """
     if not _replicatable_module(network):
         raise RuntimeError("Cannot replicate network where python modules are "
                            "childrens of ScriptModule")
 
-    devices = list(map(lambda x: _get_device_index(x, True), devices))
+    devices = _to_device_index(devices)
     num_replicas = len(devices)
 
     params = list(network.parameters())