from common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
TEST_NUMPY, TEST_SCIPY, download_file, PY3, PY34, to_gpu, \
get_function_arglist, load_tests
-from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
+from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_GEQ4GPU, TEST_CUDNN, \
TEST_CUDNN_VERSION
from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
module_tests, criterion_tests, loss_reference_fns, get_reduction, \
module = nn.Linear(10, 5).float().cuda()
input = Variable(torch.randn(2, 10).float().cuda())
expected_output = module(input).data
- replicas = dp.replicate(module, (0, 1))
- for i, replica in enumerate(replicas):
- for p in replica.parameters():
- self.assertEqual(p.get_device(), i)
- replica_input = input.cuda(i)
- self.assertEqual(replica(replica_input).data, expected_output)
+ for devices in [(0, 1), [[0], [1]]]:
+ replicas = dp.replicate(module, devices)
+ for i, replica in enumerate(replicas):
+ for p in replica.parameters():
+ self.assertEqual(p.get_device(), i)
+ replica_input = input.cuda(i)
+ self.assertEqual(replica(replica_input).data, expected_output)
+
+ @unittest.skipIf(not TEST_GEQ4GPU, "less than 4 GPUs")
+ def test_replicate_multi_gpu_module(self):
+ class MultiGpuModule(nn.Module):
+ def __init__(self):
+ super(MultiGpuModule, self).__init__()
+ self.net1 = torch.nn.Linear(10, 5).cuda(0)
+ self.net2 = torch.nn.Linear(5, 5).cuda(1)
+ self.bn = nn.BatchNorm2d(10).cuda(0)
+
+ def forward(self, x):
+ out = self.net1(x.cuda(self.net1.weight.get_device()))
+ return self.net2(out.cuda(self.net2.weight.get_device()))
+
+ module = MultiGpuModule()
+
+ input = torch.rand(2, 10).cuda(0)
+ expected_output = module(input).cpu()
+
+ for devices in ([[0, 1], [2, 3]], [[1, 0], [3, 2]]):
+ replicas = dp.replicate(module, devices)
+ for i, replica in enumerate(replicas):
+ self.assertEqual(replica.net1.weight.get_device(), 2 * i)
+ self.assertEqual(replica.net1.bias.get_device(), 2 * i)
+ self.assertEqual(replica.net2.weight.get_device(), 2 * i + 1)
+ self.assertEqual(replica.net2.bias.get_device(), 2 * i + 1)
+ self.assertEqual(replica.bn.running_mean.get_device(), 2 * i)
+ self.assertEqual(replica.bn.running_var.get_device(), 2 * i)
+ self.assertEqual(
+ replica.bn.num_batches_tracked.get_device(), 2 * i)
+
+ replica_input = input.cuda(2 * i)
+ replica_output = replica(replica_input)
+ self.assertEqual(replica_output.get_device(), 2 * i + 1)
+ self.assertEqual(replica_output.cpu(), expected_output)
+
+ @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+ def test_replicate_device_indices(self):
+ from torch.nn.parallel.replicate import _to_device_index as f
+
+ self.assertEqual(
+ f([['cuda:0', 'cuda:1', 'cuda:2'],
+ ['cuda:4', 'cuda:3', 'cuda:6']]),
+ [[0, 1, 2], [4, 3, 6]])
+
+ self.assertEqual(f(('cuda:0', 'cuda:1', 'cuda:2')), [0, 1, 2])
+
+ self.assertEqual(
+ len(set([0, 1, 2]).intersection(f({'cuda:0', 'cuda:1', 'cuda:2'}))),
+ 3)
+ self.assertEqual(
+ f([['cuda:0'], ['cuda:1'], ['cuda:2']]), [[0], [1], [2]])
+
+ msg = "empty device list"
+ for devices in (None, (), [], [[]]):
+ with self.assertRaisesRegex(RuntimeError, msg):
+ f(devices)
+
+ msg = "unidentical number of devices"
+ for devices in ([[0, 1], [2]], [[0], [1, 2]]):
+ with self.assertRaisesRegex(AssertionError, msg):
+ f(devices)
+
+ msg = "shared by multiple replicas"
+ for devices in ([[0, 1], [1, 2]], [[0], [1], [0]]):
+ with self.assertRaisesRegex(AssertionError, msg):
+ f(devices)
+
+ msg = "Duplicated device ids"
+ for devices in ([[0, 1, 2, 1]], [0, 1, 1], [0, 0]):
+ with self.assertRaisesRegex(AssertionError, msg):
+ f(devices)
+
+ @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+ def test_replicate_tensor_grouping_multi_gpu(self):
+ from torch.nn.parallel.replicate import _group_by_device
+
+ a = torch.Tensor(1).cuda(0)
+ b = torch.Tensor(2).cuda(0)
+ c = torch.Tensor(3).cuda(1)
+ d = torch.Tensor(4).cuda(0)
+ e = torch.Tensor(5).cuda(1)
+
+ tensors = [a, b, c, d, e]
+ for devices in ([[0, 1], [2, 3]], [[1, 4, 0], [3, 5, 2]]):
+ grouped_tensors, grouped_devices, original_index = \
+ _group_by_device(tensors, devices)
+
+ self.assertEqual(grouped_tensors, [[a, b, d], [c, e]])
+ self.assertEqual(grouped_devices, [[0, 2], [1, 3]])
+ self.assertEqual(original_index, [[0, 1, 3], [2, 4]])
+
+ msg = "missing from devices"
+ for devices in ([[0, 2], [1, 3]], [[1, 2], [0, 3]], [[2, 3], [0, 1]]):
+ with self.assertRaisesRegex(AssertionError, msg):
+ grouped_tensors, grouped_devices, original_index = \
+ _group_by_device(tensors, devices)
+
+ @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+ def test_replicate_tensor_grouping(self):
+ from torch.nn.parallel.replicate import _group_by_device
+
+ a = torch.Tensor(1).cuda(0)
+ b = torch.Tensor(2).cuda(0)
+ c = torch.Tensor(3).cuda(0)
+
+ tensors = [a, b, c]
+
+ grouped_tensors, grouped_devices, original_index = \
+ _group_by_device(tensors, [0, 1])
+
+ self.assertEqual(grouped_tensors, [[a, b, c]])
+ self.assertEqual(grouped_devices, [[0, 1]])
+ self.assertEqual(original_index, [[0, 1, 2]])
+
+ @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+ def test_replicate_reshape(self):
+ from torch.nn.parallel.replicate import _broadcast_coalesced_reshape
+
+ a = torch.Tensor(1).cuda(0)
+ b = torch.Tensor(2).cuda(0)
+ c = torch.Tensor(3).cuda(1)
+ d = torch.Tensor(4).cuda(0)
+ e = torch.Tensor(5).cuda(1)
+
+ tensors = [a, b, c, d, e]
+ outputs = _broadcast_coalesced_reshape(tensors, [[0, 1], [1, 0]])
+
+ self.assertEqual(len(outputs), 2)
+ self.assertEqual(outputs[0], [a, b, c, d, e])
+ self.assertEqual(
+ outputs[1], [a.cuda(1), b.cuda(1), c.cuda(0), d.cuda(1), e.cuda(0)])
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_replicate_buffers(self):
net = nn.Module()
net.bn = nn.BatchNorm2d(10)
net.cuda()
- replicas = dp.replicate(net, (0, 1))
- for i, replica in enumerate(replicas):
- self.assertEqual(replica.bn.running_mean.get_device(), i, 'buffer on wrong device')
- self.assertEqual(replica.bn.running_var.get_device(), i, 'buffer on wrong device')
- self.assertEqual(replica.bn.num_batches_tracked.get_device(), i, 'buffer on wrong device')
+ for devices in [(0, 1), [[0], [1]]]:
+ replicas = dp.replicate(net, devices)
+ for i, replica in enumerate(replicas):
+ self.assertEqual(replica.bn.running_mean.get_device(), i, 'buffer on wrong device')
+ self.assertEqual(replica.bn.running_var.get_device(), i, 'buffer on wrong device')
+ self.assertEqual(replica.bn.num_batches_tracked.get_device(), i, 'buffer on wrong device')
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
@skipIfRocm
return True
+def _to_device_index(devices):
+ if not devices:
+ raise RuntimeError("Cannot replicate using an empty device list.")
+
+ if isinstance(devices, list) and isinstance(devices[0], list):
+ device_ids = []
+ seen = set()
+ for i, replica_devs in enumerate(devices):
+ assert len(replica_devs) == len(devices[0]), (
+ "Cannot replicate to unidentical number of devices, but got "
+ "device list {} and {} for replica {} and {}."
+ ).format(devices[0], devices[i], 0, i)
+
+ assert len(seen.intersection(replica_devs)) == 0, (
+ "Devices {} are shared by multiple replicas."
+ ).format(seen.intersection(replica_devs))
+ seen.update(replica_devs)
+
+ device_ids.append(_to_device_index(replica_devs))
+ return device_ids
+ else:
+ assert len(devices) == len(set(devices)), (
+ "Duplicated device ids {}."
+ ).format(devices)
+
+ return list(map(lambda x: _get_device_index(x, True), devices))
+
+
def _build_param_dict(modules, module_copies, module_indices):
param_dict = {}
for module in modules:
replica._copy_method(method_name, param_list, module)
+# Group tensors on the same device together, which can later be broadcast to
+# a list of devices. For example,consider 5 tensors on 2 devices
+# a = torch.Tensor(0).cuda(0)
+# b = torch.Tensor(0).cuda(0)
+# c = torch.Tensor(0).cuda(1)
+# d = torch.Tensor(0).cuda(0)
+# e = torch.Tensor(0).cuda(1).
+# Let inputs be
+# tensors = [a, b, c, d, e] and
+# devices = [[0, 1], [2, 3]].
+# Then, outputs will be:
+# grouped_tensors = [[a, b, d], [c, e]],
+# grouped_devices = [[0, 2], [1, 3]],
+# original_index = [[0, 1, 3], [2, 4]],
+# meaning that grouped_tensors[i] will be broadcast to grouped_devices[i].
+def _group_by_device(tensors, devices):
+ if isinstance(devices[0], list):
+ # all tensor devices must appear in devices[0]
+ missing_devs = [t.device.index for t in tensors
+ if t.device.index not in devices[0]]
+ assert not missing_devs, (
+ "tensor devices {} are missing from devices[0] {}."
+ ).format(missing_devs, devices[0])
+
+ # device id to output group index, this is necessary when `tensors` only
+ # use a subset of devices in `devices[0]`
+ dev_to_group_idx = {}
+ for t in tensors:
+ if t.device.index not in dev_to_group_idx:
+ dev_to_group_idx[t.device.index] = len(dev_to_group_idx)
+
+ # Group tensors by devices and remember each tensor's original index.
+ # The original_index helps to recover the original input tensor order
+ # from grouped tensors.
+ grouped_tensors = [[] for _ in range(len(dev_to_group_idx))]
+ original_index = [[] for _ in range(len(dev_to_group_idx))]
+ for i, t in enumerate(tensors):
+ group_id = dev_to_group_idx[t.device.index]
+ original_index[group_id].append(i)
+ grouped_tensors[group_id].append(t)
+
+ # group devices together if they should be in the same broadcast call
+ grouped_devices = [[] for _ in range(len(dev_to_group_idx))]
+ transpose = list(zip(*devices))
+ for row in transpose:
+ if row[0] in dev_to_group_idx:
+ grouped_devices[dev_to_group_idx[row[0]]] = list(row)
+
+ return grouped_tensors, grouped_devices, original_index
+ else:
+ return [tensors], [devices], [list(range(len(tensors)))]
+
+
+# Return len(devices) replicas of input tensors. If input tensors reside on
+# multiple GPUs, devices must be a 2D list with devices[0] matching input
+# tensors' devices. For example,consider 5 tensors on 2 devices
+# a = torch.Tensor(0).cuda(0)
+# b = torch.Tensor(0).cuda(0)
+# c = torch.Tensor(0).cuda(1)
+# d = torch.Tensor(0).cuda(0)
+# e = torch.Tensor(0).cuda(1).
+# Let inputs be
+# tensors = [a, b, c, d, e] and
+# devices = [[0, 1], [2, 3]].
+#
+# The output will be a 2D list of tensors:
+# [[a0, b0, c0, d0, e0],
+# [a1, b1, c1, d1, e1]], where
+# a0, b0, d0 are on device 0
+# a1, b1, d1 are on device 2
+# c0, e0 are on device 1
+# c1, e1 are on device 3
+#
+# This example will be used throughout the implementation of this function.
def _broadcast_coalesced_reshape(tensors, devices, detach=False):
from ._functions import Broadcast
- if detach:
- return comm.broadcast_coalesced(tensors, devices)
- else:
- # Use the autograd function to broadcast if not detach
- if len(tensors) > 0:
- tensor_copies = Broadcast.apply(devices, *tensors)
- return [tensor_copies[i:i + len(tensors)]
- for i in range(0, len(tensor_copies), len(tensors))]
+
+ # a triply-nested list of 1) broadcast group, 2) tensor list replica,
+ # 3) tensors on the same device.
+ grouped_replicas = []
+ grouped_tensors, grouped_devices, original_index = \
+ _group_by_device(tensors, devices)
+ # For the example input described above, we have
+ # grouped_tensors =[[a, b, d], [c, e]]
+ # grouped_devices = [[0, 2], [1, 3]]
+ # original_index = [[0, 1, 3], [2, 4]]
+ for tensor_group, device_group in zip(grouped_tensors, grouped_devices):
+ if detach:
+ grouped_replicas.append(
+ comm.broadcast_coalesced(tensor_group, device_group))
else:
- return []
+ if len(tensor_group) > 0:
+ # Use the autograd function to broadcast if not detach
+ tensor_copies = Broadcast.apply(device_group, *tensor_group)
+ grouped_replicas.append(
+ [tensor_copies[i:i + len(tensor_group)]
+ for i in range(
+ 0, len(tensor_copies), len(tensor_group))])
+ else:
+ grouped_replicas.append([])
+
+ if isinstance(devices[0], list):
+ # convert the triply-nested list into a doubly-nested list of 1) replica
+ # 2) tensors in the same replica (can be on different devices)
+ #
+ # For the example input described above, we have
+ # grouped_replicas = [
+ # [[a0, b0, d0], # on device 0
+ # [a1, b1, d1]], # on device 2
+ # [[c0, e0], # on device 1
+ # [c1, e1]] # on device 3
+ # ]
+ #
+ # The code below re-organize elements in grouped_replicas to the
+ # expected form:
+ # [[a0, b0, c0, d0, e0],
+ # [a1, b1, c1, d1, e1]].
+ transpose = [0 for _ in tensors]
+ for g_idx in range(len(original_index)):
+ for t_idx in range(len(original_index[g_idx])):
+ # g_idx is the broadcast group index.
+ # t_idx is the tensor's index in a replica within a group.
+ # Tensors in grouped_replicas[g_idx, :, t_idx] are replicas of
+ # input tensor[original_index[g_idx][t_idx]]. Retrieve the
+ # column and add it as the original_index[g_idx][t_idx]'s row in
+ # transpose.
+ transpose[original_index[g_idx][t_idx]] = \
+ [replica[t_idx] for replica in grouped_replicas[g_idx]]
+
+ # transpose the result to stay consistent with the 1D devices case.
+ return list(zip(*transpose))
+ else:
+ return grouped_replicas[0]
def replicate(network, devices, detach=False):
+ r"""Replicate the input :attr:`network` to given :attr:`devices`. If
+ :attr:`network` resides on CPU or a single GPU, :attr:`devices` must be a 1D
+ list of destination devices. If :attr:`network` resides on multiple GPUs,
+ :attr:`devices` must be satisfy the following conditions:
+
+ 1. :attr:`devices` must be a 2D list,
+ 2. ``devices[0]`` must match the :attr:`network`'s devices, in any order.
+ 3. All ``devices[i]`` must have the same length.
+
+ For example, :attr:`network` is a ``Sequential`` module with two ``Linear``
+ layers stored on ``cuda:0`` and ``cuda:1`` respectively. Setting
+ :attr:`devices` to ``[[0, 1], [2, 3], [4, 5]]`` will replicate
+ :attr:`network` three times with replicas stored on devices
+ ``[cuda:0, cuda:1]``, ``[cuda:2, cuda:3]``, and ``[cuda:4, cuda:5]``
+ respectively.
+
+
+ Args:
+ network (Module): modules to be replicate
+ devices (1D or 2D list of int or torch.device): CUDA devices
+ detach (bool, optional): detached replicas from the current graph.
+ """
if not _replicatable_module(network):
raise RuntimeError("Cannot replicate network where python modules are "
"childrens of ScriptModule")
- devices = list(map(lambda x: _get_device_index(x, True), devices))
+ devices = _to_device_index(devices)
num_replicas = len(devices)
params = list(network.parameters())