@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
@skipIfRocm
def test_data_parallel_model_device(self):
+ r"""Test that DataParallel checks, at forward time, that the wrapped
+ module's parameters and buffers are all on ``device_ids[0]``, and
+ raises ``RuntimeError`` naming the expected device otherwise.
+ """
l = nn.Linear(2, 2)
- error_msg = "module must have its parameters and buffers on device %d "
- self.assertRaisesRegex(
- RuntimeError, error_msg % (0), lambda: nn.DataParallel(l))
- self.assertRaisesRegex(
- RuntimeError, error_msg % (0), lambda: nn.DataParallel(l.cuda(1)))
- self.assertRaisesRegex(
- RuntimeError, error_msg % (1),
- lambda: nn.DataParallel(l.cuda(), device_ids=[1, 0]))
-
- nn.DataParallel(l.cuda())
- nn.DataParallel(l.cuda(1), device_ids=[1, 0])
+ inp = torch.randn(2, 2)
+ inp_cuda0 = inp.cuda(0)
+ inp_cuda1 = inp.cuda(1)
+
+ error_msg = "module must have its parameters and buffers on device {}"
+
+ @contextlib.contextmanager
+ def dummy_ctx_manager():
+ yield
+
+ def test(inner_m, dp_device, inp, device_ids, should_fail):
+ if device_ids is None:
+ device_ids = list(range(torch.cuda.device_count()))
+
+ if isinstance(device_ids[0], torch.device):
+ expect_device = device_ids[0]
+ else:
+ expect_device = torch.device("cuda:{}".format(device_ids[0]))
+
+ if should_fail:
+ def assert_correct():
+ return self.assertRaisesRegex(RuntimeError, error_msg.format(expect_device))
+ else:
+ assert_correct = dummy_ctx_manager
+
+ # test DataParallel module
+ dpm = nn.DataParallel(inner_m, device_ids)
+ if dp_device is not None:
+ dpm = dpm.to(dp_device)
+
+ with assert_correct():
+ dpm(inp)
+
+ # test functional
+ with assert_correct():
+ nn.parallel.data_parallel(inner_m.to(dp_device), inp, device_ids)
+
+ test(l.to('cpu'), None, inp, None, should_fail=True)
+ test(l.cuda(1), None, inp_cuda0, None, should_fail=True)
+ test(l.cuda(), None, inp_cuda0, [1, 0], should_fail=True)
+
+ test(l.cuda(), None, inp_cuda0, None, should_fail=False)
+ test(l.cpu(), 'cuda', inp_cuda0, None, should_fail=False)
+ test(l.cuda(1), None, inp_cuda1, [1, 0], should_fail=False)
+ test(l.cpu(), 'cuda:1', inp_cuda1, [1, 0], should_fail=False)
s = nn.Sequential(l.cpu())
- self.assertRaisesRegex(
- RuntimeError, error_msg % (0), lambda: nn.DataParallel(s))
+ test(s, None, inp, None, should_fail=True)
+ test(s, None, inp, [0, 1], should_fail=True)
+ test(s, None, inp, [1, 0], should_fail=True)
- s = nn.Sequential(deepcopy(l), l.cuda())
- self.assertRaisesRegex(
- RuntimeError, error_msg % (0), lambda: nn.DataParallel(s))
+ s = nn.Sequential(deepcopy(l).cpu(), l.cuda())
+ test(s, None, inp, None, should_fail=True)
+ test(s, None, inp, [0, 1], should_fail=True)
+ test(s, None, inp, [1, 0], should_fail=True)
s = nn.Sequential(l.cuda(), deepcopy(l).cuda(1))
- self.assertRaisesRegex(
- RuntimeError, error_msg % (0), lambda: nn.DataParallel(s))
- self.assertRaisesRegex(
- RuntimeError, error_msg % (1),
- lambda: nn.DataParallel(s, device_ids=[1, 0]))
+ test(s, None, inp, None, should_fail=True)
+ test(s, None, inp, [0, 1], should_fail=True)
+ test(s, None, inp, [1, 0], should_fail=True)
s = nn.Sequential(l.cuda(), deepcopy(l).cuda())
- nn.DataParallel(s)
-
- s = nn.Sequential(l.cuda(1), deepcopy(l).cuda(1))
- nn.DataParallel(s, device_ids=[1, 0])
+ test(s, None, inp, None, should_fail=False)
+ test(s, None, inp, [0, 1], should_fail=False)
+ test(s, None, inp, [1, 0], should_fail=True)
+ test(s.cpu(), None, inp, [1, 0], should_fail=True)
+ test(s.cuda(1), None, inp, [1, 0], should_fail=False)
@unittest.skipIf(not TEST_MULTIGPU or not PY3, "multi-GPU not supported")
@skipIfRocm
if output_device is None:
output_device = device_ids[0]
- if not all(t.is_cuda and t.device.index == device_ids[0]
- for t in chain(module.parameters(), module.buffers())):
- raise RuntimeError("module must have its parameters and buffers "
- "on device %d (device_ids[0])" % device_ids[0])
-
self.dim = dim
self.module = module
self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
self.output_device = _get_device_index(output_device, True)
+ self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0]))
_check_balance(self.device_ids)
def forward(self, *inputs, **kwargs):
if not self.device_ids:
return self.module(*inputs, **kwargs)
+
+ for t in chain(self.module.parameters(), self.module.buffers()):
+ if t.device != self.src_device_obj:
+ raise RuntimeError("module must have its parameters and buffers "
+ "on device {} (device_ids[0]) but found one of "
+ "them on device: {}".format(self.src_device_obj, t.device))
+
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
return self.module(*inputs[0], **kwargs[0])
if output_device is None:
output_device = device_ids[0]
+ device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+ output_device = _get_device_index(output_device, True)
+ src_device_obj = torch.device("cuda:{}".format(device_ids[0]))
+
+ for t in chain(module.parameters(), module.buffers()):
+ if t.device != src_device_obj:
+ raise RuntimeError("module must have its parameters and buffers "
+ "on device {} (device_ids[0]) but found one of "
+ "them on device: {}".format(src_device_obj, t.device))
+
inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
if len(device_ids) == 1:
return module(*inputs[0], **module_kwargs[0])