Enforce module device at DataParallel construction time (#17129)
authorShen Li <shenli@fb.com>
Fri, 15 Feb 2019 19:09:12 +0000 (11:09 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 15 Feb 2019 19:14:46 +0000 (11:14 -0800)
Summary:
closes #17065

CC douwekiela
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17129

Differential Revision: D14093353

Pulled By: mrshenli

fbshipit-source-id: 9a5a10f16e392337a7f7073223541cf69b402f82

test/test_nn.py
torch/nn/parallel/data_parallel.py

index 5b23b86..0d4ff86 100644 (file)
@@ -3351,6 +3351,43 @@ class TestNN(NNTestCase):
         out = dp.data_parallel(l, i, (0, 1))
         self.assertEqual(out, l(i))
 
+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    @skipIfRocm
+    def test_data_parallel_model_device(self):
+        l = nn.Linear(2, 2)
+        error_msg = "module must have its parameters and buffers on device %d "
+        self.assertRaisesRegex(
+            RuntimeError, error_msg % (0), lambda: nn.DataParallel(l))
+        self.assertRaisesRegex(
+            RuntimeError, error_msg % (0), lambda: nn.DataParallel(l.cuda(1)))
+        self.assertRaisesRegex(
+            RuntimeError, error_msg % (1),
+            lambda: nn.DataParallel(l.cuda(), device_ids=[1, 0]))
+
+        nn.DataParallel(l.cuda())
+        nn.DataParallel(l.cuda(1), device_ids=[1, 0])
+
+        s = nn.Sequential(l.cpu())
+        self.assertRaisesRegex(
+            RuntimeError, error_msg % (0), lambda: nn.DataParallel(s))
+
+        s = nn.Sequential(deepcopy(l), l.cuda())
+        self.assertRaisesRegex(
+            RuntimeError, error_msg % (0), lambda: nn.DataParallel(s))
+
+        s = nn.Sequential(l.cuda(), deepcopy(l).cuda(1))
+        self.assertRaisesRegex(
+            RuntimeError, error_msg % (0), lambda: nn.DataParallel(s))
+        self.assertRaisesRegex(
+            RuntimeError, error_msg % (1),
+            lambda: nn.DataParallel(s, device_ids=[1, 0]))
+
+        s = nn.Sequential(l.cuda(), deepcopy(l).cuda())
+        nn.DataParallel(s)
+
+        s = nn.Sequential(l.cuda(1), deepcopy(l).cuda(1))
+        nn.DataParallel(s, device_ids=[1, 0])
+
     @unittest.skipIf(not TEST_MULTIGPU or not PY3, "multi-GPU not supported")
     @skipIfRocm
     def test_data_parallel_model_no_refcycles(self):
index 9e04dbe..c543fc2 100644 (file)
@@ -1,6 +1,7 @@
 import operator
 import torch
 import warnings
+from itertools import chain
 from ..modules import Module
 from .scatter_gather import scatter_kwargs, gather
 from .replicate import replicate
@@ -105,7 +106,7 @@ class DataParallel(Module):
     Example::
 
         >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
-        >>> output = net(input_var)
+        >>> output = net(input_var)  # input_var can be on any device, including CPU
     """
 
     # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
@@ -123,6 +124,11 @@ class DataParallel(Module):
         if output_device is None:
             output_device = device_ids[0]
 
+        if not all(t.is_cuda and t.device.index == device_ids[0]
+                   for t in chain(module.parameters(), module.buffers())):
+            raise RuntimeError("module must have its parameters and buffers "
+                               "on device %d (device_ids[0])" % device_ids[0])
+
         self.dim = dim
         self.module = module
         self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))