Fix DataParallel(cpu_m).cuda() not working by checking at forward (#17363)
author Tongzhou Wang <tongzhou.wang.1994@gmail.com>
Fri, 22 Feb 2019 16:27:04 +0000 (08:27 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Feb 2019 16:31:36 +0000 (08:31 -0800)
Summary:
Fixes #17362
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17363
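
Before this change, DataParallel.__init__ required the wrapped module's parameters and buffers to already be on device_ids[0], so nn.DataParallel(cpu_module).cuda() failed at construction time. The check now runs in forward(), after any subsequent .cuda()/.to() move of the wrapper. A minimal sketch of the pattern this enables (assumes a CUDA-enabled machine; not part of the patch itself):

```python
import torch
import torch.nn as nn

m = nn.Linear(2, 2)                  # parameters start out on the CPU
dp = nn.DataParallel(m)              # no longer raises in __init__
dp = dp.cuda()                       # move the wrapper (and module) to cuda:0 afterwards
out = dp(torch.randn(4, 2).cuda())   # the device check now happens here, at forward time

# Leaving the module on the CPU still fails, but only once forward is called:
dp_cpu = nn.DataParallel(nn.Linear(2, 2))
try:
    dp_cpu(torch.randn(4, 2).cuda())
except RuntimeError as e:
    print(e)  # module must have its parameters and buffers on device cuda:0 ...
```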

Differential Revision: D14175151

Pulled By: soumith

fbshipit-source-id: 7b7e2335d553ed2133287deeaca3f6b6254aea4a

test/test_nn.py
torch/nn/parallel/data_parallel.py

diff --git a/test/test_nn.py b/test/test_nn.py
index 2640666..022a5a5 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3377,39 +3377,76 @@ class TestNN(NNTestCase):
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     @skipIfRocm
     def test_data_parallel_model_device(self):
+        r"""Test device[0] check at forward time.
+        """
         l = nn.Linear(2, 2)
-        error_msg = "module must have its parameters and buffers on device %d "
-        self.assertRaisesRegex(
-            RuntimeError, error_msg % (0), lambda: nn.DataParallel(l))
-        self.assertRaisesRegex(
-            RuntimeError, error_msg % (0), lambda: nn.DataParallel(l.cuda(1)))
-        self.assertRaisesRegex(
-            RuntimeError, error_msg % (1),
-            lambda: nn.DataParallel(l.cuda(), device_ids=[1, 0]))
-
-        nn.DataParallel(l.cuda())
-        nn.DataParallel(l.cuda(1), device_ids=[1, 0])
+        inp = torch.randn(2, 2)
+        inp_cuda0 = inp.cuda(0)
+        inp_cuda1 = inp.cuda(1)
+
+        error_msg = "module must have its parameters and buffers on device {}"
+
+        @contextlib.contextmanager
+        def dummy_ctx_manager():
+            yield
+
+        def test(inner_m, dp_device, inp, device_ids, should_fail):
+            if device_ids is None:
+                device_ids = list(range(torch.cuda.device_count()))
+
+            if isinstance(device_ids[0], torch.device):
+                expect_device = device_ids[0]
+            else:
+                expect_device = torch.device("cuda:{}".format(device_ids[0]))
+
+            if should_fail:
+                def assert_correct():
+                    return self.assertRaisesRegex(RuntimeError, error_msg.format(expect_device))
+            else:
+                assert_correct = dummy_ctx_manager
+
+            # test DataParallel module
+            dpm = nn.DataParallel(inner_m, device_ids)
+            if dp_device is not None:
+                dpm = dpm.to(dp_device)
+
+            with assert_correct():
+                dpm(inp)
+
+            # test functional
+            with assert_correct():
+                nn.parallel.data_parallel(inner_m.to(dp_device), inp, device_ids)
+
+        test(l.to('cpu'), None, inp, None, should_fail=True)
+        test(l.cuda(1), None, inp_cuda0, None, should_fail=True)
+        test(l.cuda(), None, inp_cuda0, [1, 0], should_fail=True)
+
+        test(l.cuda(), None, inp_cuda0, None, should_fail=False)
+        test(l.cpu(), 'cuda', inp_cuda0, None, should_fail=False)
+        test(l.cuda(1), None, inp_cuda1, [1, 0], should_fail=False)
+        test(l.cpu(), 'cuda:1', inp_cuda1, [1, 0], should_fail=False)
 
         s = nn.Sequential(l.cpu())
-        self.assertRaisesRegex(
-            RuntimeError, error_msg % (0), lambda: nn.DataParallel(s))
+        test(s, None, inp, None, should_fail=True)
+        test(s, None, inp, [0, 1], should_fail=True)
+        test(s, None, inp, [1, 0], should_fail=True)
 
-        s = nn.Sequential(deepcopy(l), l.cuda())
-        self.assertRaisesRegex(
-            RuntimeError, error_msg % (0), lambda: nn.DataParallel(s))
+        s = nn.Sequential(deepcopy(l).cpu(), l.cuda())
+        test(s, None, inp, None, should_fail=True)
+        test(s, None, inp, [0, 1], should_fail=True)
+        test(s, None, inp, [1, 0], should_fail=True)
 
         s = nn.Sequential(l.cuda(), deepcopy(l).cuda(1))
-        self.assertRaisesRegex(
-            RuntimeError, error_msg % (0), lambda: nn.DataParallel(s))
-        self.assertRaisesRegex(
-            RuntimeError, error_msg % (1),
-            lambda: nn.DataParallel(s, device_ids=[1, 0]))
+        test(s, None, inp, None, should_fail=True)
+        test(s, None, inp, [0, 1], should_fail=True)
+        test(s, None, inp, [1, 0], should_fail=True)
 
         s = nn.Sequential(l.cuda(), deepcopy(l).cuda())
-        nn.DataParallel(s)
-
-        s = nn.Sequential(l.cuda(1), deepcopy(l).cuda(1))
-        nn.DataParallel(s, device_ids=[1, 0])
+        test(s, None, inp, None, should_fail=False)
+        test(s, None, inp, [0, 1], should_fail=False)
+        test(s, None, inp, [1, 0], should_fail=True)
+        test(s.cpu(), None, inp, [1, 0], should_fail=True)
+        test(s.cuda(1), None, inp, [1, 0], should_fail=False)
 
     @unittest.skipIf(not TEST_MULTIGPU or not PY3, "multi-GPU not supported")
     @skipIfRocm
diff --git a/torch/nn/parallel/data_parallel.py b/torch/nn/parallel/data_parallel.py
index c543fc2..40c9b1b 100644
--- a/torch/nn/parallel/data_parallel.py
+++ b/torch/nn/parallel/data_parallel.py
@@ -124,15 +124,11 @@ class DataParallel(Module):
         if output_device is None:
             output_device = device_ids[0]
 
-        if not all(t.is_cuda and t.device.index == device_ids[0]
-                   for t in chain(module.parameters(), module.buffers())):
-            raise RuntimeError("module must have its parameters and buffers "
-                               "on device %d (device_ids[0])" % device_ids[0])
-
         self.dim = dim
         self.module = module
         self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
         self.output_device = _get_device_index(output_device, True)
+        self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0]))
 
         _check_balance(self.device_ids)
 
@@ -142,6 +138,13 @@ class DataParallel(Module):
     def forward(self, *inputs, **kwargs):
         if not self.device_ids:
             return self.module(*inputs, **kwargs)
+
+        for t in chain(self.module.parameters(), self.module.buffers()):
+            if t.device != self.src_device_obj:
+                raise RuntimeError("module must have its parameters and buffers "
+                                   "on device {} (device_ids[0]) but found one of "
+                                   "them on device: {}".format(self.src_device_obj, t.device))
+
         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
         if len(self.device_ids) == 1:
             return self.module(*inputs[0], **kwargs[0])
@@ -186,6 +189,16 @@ def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, mo
     if output_device is None:
         output_device = device_ids[0]
 
+    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+    output_device = _get_device_index(output_device, True)
+    src_device_obj = torch.device("cuda:{}".format(device_ids[0]))
+
+    for t in chain(module.parameters(), module.buffers()):
+        if t.device != src_device_obj:
+            raise RuntimeError("module must have its parameters and buffers "
+                               "on device {} (device_ids[0]) but found one of "
+                               "them on device: {}".format(src_device_obj, t.device))
+
     inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
     if len(device_ids) == 1:
         return module(*inputs[0], **module_kwargs[0])
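
For reference, a small usage sketch of the functional path after this change (assumes two visible GPUs, as in the test above; the module must already be on device_ids[0] when the call is made):

```python
import torch
import torch.nn as nn

m = nn.Linear(2, 2).cuda(1)
inp = torch.randn(4, 2).cuda(1)

# device_ids[0] is cuda:1 here, so the call-time check passes and the
# result is gathered back onto cuda:1 (the default output_device).
out = nn.parallel.data_parallel(m, inp, device_ids=[1, 0])

# With m left on the CPU or on cuda:0, the same call raises the
# RuntimeError added above.
```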