From 344acaa0caa91d3acb90b2e1f89be53d153dadc1 Mon Sep 17 00:00:00 2001
From: Shen Li
Date: Tue, 16 Apr 2019 09:35:36 -0700
Subject: [PATCH] Revert replicate.py to disallow replicating multi-device modules (#19278)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19278

Based on discussion in https://github.com/pytorch/pytorch/pull/19278 and
https://github.com/pytorch/pytorch/pull/18687, changes to replicate.py will
be reverted to disallow replicating multi-device modules.

Reviewed By: pietern

Differential Revision: D14940018

fbshipit-source-id: 7504c0f4325c2639264c52dcbb499e61c9ad2c26
---
 test/common_cuda.py            |   1 -
 test/test_nn.py                | 139 +-----------------------
 torch/nn/parallel/replicate.py | 191 +++--------------------------------
 3 files changed, 13 insertions(+), 318 deletions(-)

diff --git a/test/common_cuda.py b/test/common_cuda.py
index 7a7110e..19750df 100644
--- a/test/common_cuda.py
+++ b/test/common_cuda.py
@@ -7,7 +7,6 @@ from common_utils import TEST_WITH_ROCM, TEST_NUMBA
 
 TEST_CUDA = torch.cuda.is_available()
 TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
-TEST_GEQ4GPU = TEST_CUDA and torch.cuda.device_count() >= 4
 CUDA_DEVICE = TEST_CUDA and torch.device("cuda:0")
 # note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
 TEST_CUDNN = TEST_CUDA and (TEST_WITH_ROCM or torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))
diff --git a/test/test_nn.py b/test/test_nn.py
index d489e50..2366806 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -30,8 +30,7 @@ from torch.nn.parallel._functions import Broadcast
 from common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
     TEST_NUMPY, TEST_SCIPY, download_file, PY3, PY34, to_gpu, \
     get_function_arglist, load_tests
-from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_GEQ4GPU, TEST_CUDNN, \
-    TEST_CUDNN_VERSION
+from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
 from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
     module_tests, criterion_tests, loss_reference_fns, get_reduction, \
     get_weight, smoothl1loss_reference, kldivloss_reference, \
@@ -3709,7 +3708,7 @@ class TestNN(NNTestCase):
         module = nn.Linear(10, 5).float().cuda()
         input = Variable(torch.randn(2, 10).float().cuda())
         expected_output = module(input).data
-        for devices in [(0, 1), [[0], [1]]]:
+        for devices in [(0, 1), [0, 1]]:
             replicas = dp.replicate(module, devices)
             for i, replica in enumerate(replicas):
                 for p in replica.parameters():
@@ -3717,144 +3716,12 @@
                 replica_input = input.cuda(i)
                 self.assertEqual(replica(replica_input).data, expected_output)
 
-    @unittest.skipIf(not TEST_GEQ4GPU, "less than 4 GPUs")
-    def test_replicate_multi_gpu_module(self):
-        class MultiGpuModule(nn.Module):
-            def __init__(self):
-                super(MultiGpuModule, self).__init__()
-                self.net1 = torch.nn.Linear(10, 5).cuda(0)
-                self.net2 = torch.nn.Linear(5, 5).cuda(1)
-                self.bn = nn.BatchNorm2d(10).cuda(0)
-
-            def forward(self, x):
-                out = self.net1(x.cuda(self.net1.weight.get_device()))
-                return self.net2(out.cuda(self.net2.weight.get_device()))
-
-        module = MultiGpuModule()
-
-        input = torch.rand(2, 10).cuda(0)
-        expected_output = module(input).cpu()
-
-        for devices in ([[0, 1], [2, 3]], [[1, 0], [3, 2]]):
-            replicas = dp.replicate(module, devices)
-            for i, replica in enumerate(replicas):
-                self.assertEqual(replica.net1.weight.get_device(), 2 * i)
-                self.assertEqual(replica.net1.bias.get_device(), 2 * i)
-                self.assertEqual(replica.net2.weight.get_device(), 2 * i + 1)
-                self.assertEqual(replica.net2.bias.get_device(), 2 * i + 1)
-                self.assertEqual(replica.bn.running_mean.get_device(), 2 * i)
-                self.assertEqual(replica.bn.running_var.get_device(), 2 * i)
-                self.assertEqual(
-                    replica.bn.num_batches_tracked.get_device(), 2 * i)
-
-                replica_input = input.cuda(2 * i)
-                replica_output = replica(replica_input)
-                self.assertEqual(replica_output.get_device(), 2 * i + 1)
-                self.assertEqual(replica_output.cpu(), expected_output)
-
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_replicate_device_indices(self):
-        from torch.nn.parallel.replicate import _to_device_index as f
-
-        self.assertEqual(
-            f([['cuda:0', 'cuda:1', 'cuda:2'],
-               ['cuda:4', 'cuda:3', 'cuda:6']]),
-            [[0, 1, 2], [4, 3, 6]])
-
-        self.assertEqual(f(('cuda:0', 'cuda:1', 'cuda:2')), [0, 1, 2])
-
-        self.assertEqual(
-            len(set([0, 1, 2]).intersection(f({'cuda:0', 'cuda:1', 'cuda:2'}))),
-            3)
-        self.assertEqual(
-            f([['cuda:0'], ['cuda:1'], ['cuda:2']]), [[0], [1], [2]])
-
-        msg = "empty device list"
-        for devices in (None, (), [], [[]]):
-            with self.assertRaisesRegex(RuntimeError, msg):
-                f(devices)
-
-        msg = "unidentical number of devices"
-        for devices in ([[0, 1], [2]], [[0], [1, 2]]):
-            with self.assertRaisesRegex(AssertionError, msg):
-                f(devices)
-
-        msg = "shared by multiple replicas"
-        for devices in ([[0, 1], [1, 2]], [[0], [1], [0]]):
-            with self.assertRaisesRegex(AssertionError, msg):
-                f(devices)
-
-        msg = "Duplicated device ids"
-        for devices in ([[0, 1, 2, 1]], [0, 1, 1], [0, 0]):
-            with self.assertRaisesRegex(AssertionError, msg):
-                f(devices)
-
-    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
-    def test_replicate_tensor_grouping_multi_gpu(self):
-        from torch.nn.parallel.replicate import _group_by_device
-
-        a = torch.Tensor(1).cuda(0)
-        b = torch.Tensor(2).cuda(0)
-        c = torch.Tensor(3).cuda(1)
-        d = torch.Tensor(4).cuda(0)
-        e = torch.Tensor(5).cuda(1)
-
-        tensors = [a, b, c, d, e]
-        for devices in ([[0, 1], [2, 3]], [[1, 4, 0], [3, 5, 2]]):
-            grouped_tensors, grouped_devices, original_index = \
-                _group_by_device(tensors, devices)
-
-            self.assertEqual(grouped_tensors, [[a, b, d], [c, e]])
-            self.assertEqual(grouped_devices, [[0, 2], [1, 3]])
-            self.assertEqual(original_index, [[0, 1, 3], [2, 4]])
-
-        msg = "missing from devices"
-        for devices in ([[0, 2], [1, 3]], [[1, 2], [0, 3]], [[2, 3], [0, 1]]):
-            with self.assertRaisesRegex(AssertionError, msg):
-                grouped_tensors, grouped_devices, original_index = \
-                    _group_by_device(tensors, devices)
-
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_replicate_tensor_grouping(self):
-        from torch.nn.parallel.replicate import _group_by_device
-
-        a = torch.Tensor(1).cuda(0)
-        b = torch.Tensor(2).cuda(0)
-        c = torch.Tensor(3).cuda(0)
-
-        tensors = [a, b, c]
-
-        grouped_tensors, grouped_devices, original_index = \
-            _group_by_device(tensors, [0, 1])
-
-        self.assertEqual(grouped_tensors, [[a, b, c]])
-        self.assertEqual(grouped_devices, [[0, 1]])
-        self.assertEqual(original_index, [[0, 1, 2]])
-
-    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
-    def test_replicate_reshape(self):
-        from torch.nn.parallel.replicate import _broadcast_coalesced_reshape
-
-        a = torch.Tensor(1).cuda(0)
-        b = torch.Tensor(2).cuda(0)
-        c = torch.Tensor(3).cuda(1)
-        d = torch.Tensor(4).cuda(0)
-        e = torch.Tensor(5).cuda(1)
-
-        tensors = [a, b, c, d, e]
-        outputs = _broadcast_coalesced_reshape(tensors, [[0, 1], [1, 0]])
-
-        self.assertEqual(len(outputs), 2)
-        self.assertEqual(outputs[0], [a, b, c, d, e])
-        self.assertEqual(
-            outputs[1], [a.cuda(1), b.cuda(1), c.cuda(0), d.cuda(1), e.cuda(0)])
-
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_replicate_buffers(self):
         net = nn.Module()
         net.bn = nn.BatchNorm2d(10)
         net.cuda()
-        for devices in [(0, 1), [[0], [1]]]:
+        for devices in [(0, 1), [0, 1]]:
             replicas = dp.replicate(net, devices)
             for i, replica in enumerate(replicas):
                 self.assertEqual(replica.bn.running_mean.get_device(), i, 'buffer on wrong device')
diff --git a/torch/nn/parallel/replicate.py b/torch/nn/parallel/replicate.py
index a68ade4..cc270a1 100644
--- a/torch/nn/parallel/replicate.py
+++ b/torch/nn/parallel/replicate.py
@@ -56,34 +56,6 @@ def _replicatable_module(module, memo=None):
     return True
 
 
-def _to_device_index(devices):
-    if not devices:
-        raise RuntimeError("Cannot replicate using an empty device list.")
-
-    if isinstance(devices, list) and isinstance(devices[0], list):
-        device_ids = []
-        seen = set()
-        for i, replica_devs in enumerate(devices):
-            assert len(replica_devs) == len(devices[0]), (
-                "Cannot replicate to unidentical number of devices, but got "
-                "device list {} and {} for replica {} and {}."
-            ).format(devices[0], devices[i], 0, i)
-
-            assert len(seen.intersection(replica_devs)) == 0, (
-                "Devices {} are shared by multiple replicas."
-            ).format(seen.intersection(replica_devs))
-            seen.update(replica_devs)
-
-            device_ids.append(_to_device_index(replica_devs))
-        return device_ids
-    else:
-        assert len(devices) == len(set(devices)), (
-            "Duplicated device ids {}."
-        ).format(devices)
-
-        return list(map(lambda x: _get_device_index(x, True), devices))
-
-
 def _build_param_dict(modules, module_copies, module_indices):
     param_dict = {}
     for module in modules:
@@ -111,169 +83,26 @@ def _copy_scriptmodule_methods(modules, module_copies, module_indices):
             replica._copy_method(method_name, param_list, module)
 
 
-# Group tensors on the same device together, which can later be broadcast to
-# a list of devices. For example, consider 5 tensors on 2 devices
-# a = torch.Tensor(0).cuda(0)
-# b = torch.Tensor(0).cuda(0)
-# c = torch.Tensor(0).cuda(1)
-# d = torch.Tensor(0).cuda(0)
-# e = torch.Tensor(0).cuda(1).
-# Let inputs be
-# tensors = [a, b, c, d, e] and
-# devices = [[0, 1], [2, 3]].
-# Then, outputs will be:
-# grouped_tensors = [[a, b, d], [c, e]],
-# grouped_devices = [[0, 2], [1, 3]],
-# original_index = [[0, 1, 3], [2, 4]],
-# meaning that grouped_tensors[i] will be broadcast to grouped_devices[i].
-def _group_by_device(tensors, devices):
-    if isinstance(devices[0], list):
-        # all tensor devices must appear in devices[0]
-        missing_devs = [t.device.index for t in tensors
-                        if t.device.index not in devices[0]]
-        assert not missing_devs, (
-            "tensor devices {} are missing from devices[0] {}."
-        ).format(missing_devs, devices[0])
-
-        # device id to output group index; this is necessary when `tensors`
-        # only use a subset of devices in `devices[0]`
-        dev_to_group_idx = {}
-        for t in tensors:
-            if t.device.index not in dev_to_group_idx:
-                dev_to_group_idx[t.device.index] = len(dev_to_group_idx)
-
-        # Group tensors by devices and remember each tensor's original index.
-        # The original_index helps to recover the original input tensor order
-        # from grouped tensors.
-        grouped_tensors = [[] for _ in range(len(dev_to_group_idx))]
-        original_index = [[] for _ in range(len(dev_to_group_idx))]
-        for i, t in enumerate(tensors):
-            group_id = dev_to_group_idx[t.device.index]
-            original_index[group_id].append(i)
-            grouped_tensors[group_id].append(t)
-
-        # group devices together if they should be in the same broadcast call
-        grouped_devices = [[] for _ in range(len(dev_to_group_idx))]
-        transpose = list(zip(*devices))
-        for row in transpose:
-            if row[0] in dev_to_group_idx:
-                grouped_devices[dev_to_group_idx[row[0]]] = list(row)
-
-        return grouped_tensors, grouped_devices, original_index
-    else:
-        return [tensors], [devices], [list(range(len(tensors)))]
-
-
-# Return len(devices) replicas of input tensors. If input tensors reside on
-# multiple GPUs, devices must be a 2D list with devices[0] matching input
-# tensors' devices. For example, consider 5 tensors on 2 devices
-# a = torch.Tensor(0).cuda(0)
-# b = torch.Tensor(0).cuda(0)
-# c = torch.Tensor(0).cuda(1)
-# d = torch.Tensor(0).cuda(0)
-# e = torch.Tensor(0).cuda(1).
-# Let inputs be
-# tensors = [a, b, c, d, e] and
-# devices = [[0, 1], [2, 3]].
-#
-# The output will be a 2D list of tensors:
-# [[a0, b0, c0, d0, e0],
-#  [a1, b1, c1, d1, e1]], where
-# a0, b0, d0 are on device 0
-# a1, b1, d1 are on device 2
-# c0, e0 are on device 1
-# c1, e1 are on device 3
-#
-# This example will be used throughout the implementation of this function.
 def _broadcast_coalesced_reshape(tensors, devices, detach=False):
     from ._functions import Broadcast
-
-    # a triply-nested list of 1) broadcast group, 2) tensor list replica,
-    # 3) tensors on the same device.
-    grouped_replicas = []
-    grouped_tensors, grouped_devices, original_index = \
-        _group_by_device(tensors, devices)
-    # For the example input described above, we have
-    # grouped_tensors = [[a, b, d], [c, e]]
-    # grouped_devices = [[0, 2], [1, 3]]
-    # original_index = [[0, 1, 3], [2, 4]]
-    for tensor_group, device_group in zip(grouped_tensors, grouped_devices):
-        if detach:
-            grouped_replicas.append(
-                comm.broadcast_coalesced(tensor_group, device_group))
-        else:
-            if len(tensor_group) > 0:
-                # Use the autograd function to broadcast if not detach
-                tensor_copies = Broadcast.apply(device_group, *tensor_group)
-                grouped_replicas.append(
-                    [tensor_copies[i:i + len(tensor_group)]
-                     for i in range(
-                         0, len(tensor_copies), len(tensor_group))])
-            else:
-                grouped_replicas.append([])
-
-    if isinstance(devices[0], list):
-        # convert the triply-nested list into a doubly-nested list of 1) replica
-        # 2) tensors in the same replica (can be on different devices)
-        #
-        # For the example input described above, we have
-        # grouped_replicas = [
-        #   [[a0, b0, d0],  # on device 0
-        #    [a1, b1, d1]], # on device 2
-        #   [[c0, e0],      # on device 1
-        #    [c1, e1]]      # on device 3
-        # ]
-        #
-        # The code below re-organizes elements in grouped_replicas to the
-        # expected form:
-        # [[a0, b0, c0, d0, e0],
-        #  [a1, b1, c1, d1, e1]].
-        transpose = [0 for _ in tensors]
-        for g_idx in range(len(original_index)):
-            for t_idx in range(len(original_index[g_idx])):
-                # g_idx is the broadcast group index.
-                # t_idx is the tensor's index in a replica within a group.
-                # Tensors in grouped_replicas[g_idx, :, t_idx] are replicas of
-                # input tensor[original_index[g_idx][t_idx]]. Retrieve the
-                # column and add it as the original_index[g_idx][t_idx]'s row in
-                # transpose.
-                transpose[original_index[g_idx][t_idx]] = \
-                    [replica[t_idx] for replica in grouped_replicas[g_idx]]
-
-        # transpose the result to stay consistent with the 1D devices case.
-        return list(zip(*transpose))
+    if detach:
+        return comm.broadcast_coalesced(tensors, devices)
     else:
-        return grouped_replicas[0]
+        # Use the autograd function to broadcast if not detach
+        if len(tensors) > 0:
+            tensor_copies = Broadcast.apply(devices, *tensors)
+            return [tensor_copies[i:i + len(tensors)]
+                    for i in range(0, len(tensor_copies), len(tensors))]
+        else:
+            return []
 
 
 def replicate(network, devices, detach=False):
-    r"""Replicate the input :attr:`network` to given :attr:`devices`. If
-    :attr:`network` resides on CPU or a single GPU, :attr:`devices` must be a 1D
-    list of destination devices. If :attr:`network` resides on multiple GPUs,
-    :attr:`devices` must satisfy the following conditions:
-
-    1. :attr:`devices` must be a 2D list,
-    2. ``devices[0]`` must match the :attr:`network`'s devices, in any order.
-    3. All ``devices[i]`` must have the same length.
-
-    For example, :attr:`network` is a ``Sequential`` module with two ``Linear``
-    layers stored on ``cuda:0`` and ``cuda:1`` respectively. Setting
-    :attr:`devices` to ``[[0, 1], [2, 3], [4, 5]]`` will replicate
-    :attr:`network` three times with replicas stored on devices
-    ``[cuda:0, cuda:1]``, ``[cuda:2, cuda:3]``, and ``[cuda:4, cuda:5]``
-    respectively.
-
-
-    Args:
-        network (Module): module to be replicated
-        devices (1D or 2D list of int or torch.device): CUDA devices
-        detach (bool, optional): detach replicas from the current graph.
-    """
     if not _replicatable_module(network):
         raise RuntimeError("Cannot replicate network where python modules are "
                            "childrens of ScriptModule")
 
-    devices = _to_device_index(devices)
+    devices = list(map(lambda x: _get_device_index(x, True), devices))
     num_replicas = len(devices)
 
     params = list(network.parameters())
-- 
2.7.4
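
Editorial note, not part of the patch: below is a minimal sketch of the usage
that remains supported after this revert. replicate() takes a flat, 1D list of
device ids and a module residing on a single device; the nested 2D device
lists exercised by the removed tests are no longer accepted. The module,
tensor shapes, and device count here are illustrative assumptions, not taken
from the patch.

    import torch
    import torch.nn as nn
    from torch.nn.parallel import replicate

    if torch.cuda.device_count() >= 2:
        # A single-device module: the only kind replicate() handles after
        # the revert (multi-device modules were supported only by the
        # reverted implementation).
        module = nn.Linear(10, 5).cuda(0)

        # devices must be a flat 1D list; nested lists such as [[0], [1]]
        # were understood only by the removed _to_device_index helper.
        replicas = replicate(module, [0, 1])

        inp = torch.randn(2, 10)
        for i, rep in enumerate(replicas):
            out = rep(inp.cuda(i))  # each replica computes on its own device
            assert out.get_device() == i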