From: Tongzhou Wang
Date: Fri, 22 Feb 2019 16:27:04 +0000 (-0800)
Subject: Fix DataParallel(cpu_m).cuda() not working by checking at forward (#17363)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1143
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=3d5968d3661fae5baa52b56af4535a648ca44b27;p=platform%2Fupstream%2Fpytorch.git

Fix DataParallel(cpu_m).cuda() not working by checking at forward (#17363)

Summary:
Fixes #17362
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17363

Differential Revision: D14175151

Pulled By: soumith

fbshipit-source-id: 7b7e2335d553ed2133287deeaca3f6b6254aea4a
---

diff --git a/test/test_nn.py b/test/test_nn.py
index 2640666..022a5a5 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3377,39 +3377,76 @@ class TestNN(NNTestCase):
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     @skipIfRocm
     def test_data_parallel_model_device(self):
+        r"""Test device[0] check at forward time.
+        """
         l = nn.Linear(2, 2)
-        error_msg = "module must have its parameters and buffers on device %d "
-        self.assertRaisesRegex(
-            RuntimeError, error_msg % (0), lambda: nn.DataParallel(l))
-        self.assertRaisesRegex(
-            RuntimeError, error_msg % (0), lambda: nn.DataParallel(l.cuda(1)))
-        self.assertRaisesRegex(
-            RuntimeError, error_msg % (1),
-            lambda: nn.DataParallel(l.cuda(), device_ids=[1, 0]))
-
-        nn.DataParallel(l.cuda())
-        nn.DataParallel(l.cuda(1), device_ids=[1, 0])
+        inp = torch.randn(2, 2)
+        inp_cuda0 = inp.cuda(0)
+        inp_cuda1 = inp.cuda(1)
+
+        error_msg = "module must have its parameters and buffers on device {}"
+
+        @contextlib.contextmanager
+        def dummy_ctx_manager():
+            yield
+
+        def test(inner_m, dp_device, inp, device_ids, should_fail):
+            if device_ids is None:
+                device_ids = list(range(torch.cuda.device_count()))
+
+            if isinstance(device_ids[0], torch.device):
+                expect_device = device_ids[0]
+            else:
+                expect_device = torch.device("cuda:{}".format(device_ids[0]))
+
+            if should_fail:
+                def assert_correct():
+                    return self.assertRaisesRegex(RuntimeError, error_msg.format(expect_device))
+            else:
+                assert_correct = dummy_ctx_manager
+
+            # test DataParallel module
+            dpm = nn.DataParallel(inner_m, device_ids)
+            if dp_device is not None:
+                dpm = dpm.to(dp_device)
+
+            with assert_correct():
+                dpm(inp)
+
+            # test functional
+            with assert_correct():
+                nn.parallel.data_parallel(inner_m.to(dp_device), inp, device_ids)
+
+        test(l.to('cpu'), None, inp, None, should_fail=True)
+        test(l.cuda(1), None, inp_cuda0, None, should_fail=True)
+        test(l.cuda(), None, inp_cuda0, [1, 0], should_fail=True)
+
+        test(l.cuda(), None, inp_cuda0, None, should_fail=False)
+        test(l.cpu(), 'cuda', inp_cuda0, None, should_fail=False)
+        test(l.cuda(1), None, inp_cuda1, [1, 0], should_fail=False)
+        test(l.cpu(), 'cuda:1', inp_cuda1, [1, 0], should_fail=False)
 
         s = nn.Sequential(l.cpu())
-        self.assertRaisesRegex(
-            RuntimeError, error_msg % (0), lambda: nn.DataParallel(s))
+        test(s, None, inp, None, should_fail=True)
+        test(s, None, inp, [0, 1], should_fail=True)
+        test(s, None, inp, [1, 0], should_fail=True)
 
-        s = nn.Sequential(deepcopy(l), l.cuda())
-        self.assertRaisesRegex(
-            RuntimeError, error_msg % (0), lambda: nn.DataParallel(s))
+        s = nn.Sequential(deepcopy(l).cpu(), l.cuda())
+        test(s, None, inp, None, should_fail=True)
+        test(s, None, inp, [0, 1], should_fail=True)
+        test(s, None, inp, [1, 0], should_fail=True)
 
         s = nn.Sequential(l.cuda(), deepcopy(l).cuda(1))
-        self.assertRaisesRegex(
-            RuntimeError, error_msg % (0), lambda: nn.DataParallel(s))
-        self.assertRaisesRegex(
-            RuntimeError, error_msg % (1),
-            lambda: nn.DataParallel(s, device_ids=[1, 0]))
+        test(s, None, inp, None, should_fail=True)
+        test(s, None, inp, [0, 1], should_fail=True)
+        test(s, None, inp, [1, 0], should_fail=True)
 
         s = nn.Sequential(l.cuda(), deepcopy(l).cuda())
-        nn.DataParallel(s)
-
-        s = nn.Sequential(l.cuda(1), deepcopy(l).cuda(1))
-        nn.DataParallel(s, device_ids=[1, 0])
+        test(s, None, inp, None, should_fail=False)
+        test(s, None, inp, [0, 1], should_fail=False)
+        test(s, None, inp, [1, 0], should_fail=True)
+        test(s.cpu(), None, inp, [1, 0], should_fail=True)
+        test(s.cuda(1), None, inp, [1, 0], should_fail=False)
 
     @unittest.skipIf(not TEST_MULTIGPU or not PY3, "multi-GPU not supported")
     @skipIfRocm
diff --git a/torch/nn/parallel/data_parallel.py b/torch/nn/parallel/data_parallel.py
index c543fc2..40c9b1b 100644
--- a/torch/nn/parallel/data_parallel.py
+++ b/torch/nn/parallel/data_parallel.py
@@ -124,15 +124,11 @@ class DataParallel(Module):
         if output_device is None:
             output_device = device_ids[0]
 
-        if not all(t.is_cuda and t.device.index == device_ids[0]
-                   for t in chain(module.parameters(), module.buffers())):
-            raise RuntimeError("module must have its parameters and buffers "
-                               "on device %d (device_ids[0])" % device_ids[0])
-
         self.dim = dim
         self.module = module
         self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
         self.output_device = _get_device_index(output_device, True)
+        self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0]))
 
         _check_balance(self.device_ids)
 
@@ -142,6 +138,13 @@ class DataParallel(Module):
     def forward(self, *inputs, **kwargs):
         if not self.device_ids:
             return self.module(*inputs, **kwargs)
+
+        for t in chain(self.module.parameters(), self.module.buffers()):
+            if t.device != self.src_device_obj:
+                raise RuntimeError("module must have its parameters and buffers "
+                                   "on device {} (device_ids[0]) but found one of "
+                                   "them on device: {}".format(self.src_device_obj, t.device))
+
         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
         if len(self.device_ids) == 1:
             return self.module(*inputs[0], **kwargs[0])
@@ -186,6 +189,16 @@ def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, mo
     if output_device is None:
         output_device = device_ids[0]
 
+    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+    output_device = _get_device_index(output_device, True)
+    src_device_obj = torch.device("cuda:{}".format(device_ids[0]))
+
+    for t in chain(module.parameters(), module.buffers()):
+        if t.device != src_device_obj:
+            raise RuntimeError("module must have its parameters and buffers "
+                               "on device {} (device_ids[0]) but found one of "
+                               "them on device: {}".format(src_device_obj, t.device))
+
     inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
     if len(device_ids) == 1:
         return module(*inputs[0], **module_kwargs[0])
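With the device check moved from __init__ to forward, wrapping a CPU module in DataParallel and moving the wrapper to CUDA afterwards works. A minimal illustrative sketch, assuming a machine with at least one CUDA device (the names cpu_m, dp, and bad below are made up for the example):

    import torch
    import torch.nn as nn

    cpu_m = nn.Linear(2, 2)                  # parameters start out on CPU
    dp = nn.DataParallel(cpu_m).cuda()       # no longer raises at construction
    out = dp(torch.randn(4, 2).cuda())       # device check happens here and passes

    # A module left off device_ids[0] is still rejected, but only at forward time:
    bad = nn.DataParallel(nn.Linear(2, 2))   # wrapped module stays on CPU
    # bad(torch.randn(4, 2))                 # would raise: "module must have its
    #                                        # parameters and buffers on device cuda:0
    #                                        # (device_ids[0]) but found one of them
    #                                        # on device: cpu"

The functional nn.parallel.data_parallel path gains the same check (last hunk above), so both entry points report the offending device instead of failing eagerly at wrapper construction.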