From: Shen Li
Date: Fri, 15 Feb 2019 19:09:12 +0000 (-0800)
Subject: Enforce module device at DataParallel construction time (#17129)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1250
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=472cfc0f2cc617d70ef28bdf3a623e1ff2e759d5;p=platform%2Fupstream%2Fpytorch.git

Enforce module device at DataParallel construction time (#17129)

Summary:
closes #17065

CC douwekiela

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17129

Differential Revision: D14093353

Pulled By: mrshenli

fbshipit-source-id: 9a5a10f16e392337a7f7073223541cf69b402f82
---

diff --git a/test/test_nn.py b/test/test_nn.py
index 5b23b86..0d4ff86 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3351,6 +3351,43 @@ class TestNN(NNTestCase):
         out = dp.data_parallel(l, i, (0, 1))
         self.assertEqual(out, l(i))
 
+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    @skipIfRocm
+    def test_data_parallel_model_device(self):
+        l = nn.Linear(2, 2)
+        error_msg = "module must have its parameters and buffers on device %d "
+        self.assertRaisesRegex(
+            RuntimeError, error_msg % (0), lambda: nn.DataParallel(l))
+        self.assertRaisesRegex(
+            RuntimeError, error_msg % (0), lambda: nn.DataParallel(l.cuda(1)))
+        self.assertRaisesRegex(
+            RuntimeError, error_msg % (1),
+            lambda: nn.DataParallel(l.cuda(), device_ids=[1, 0]))
+
+        nn.DataParallel(l.cuda())
+        nn.DataParallel(l.cuda(1), device_ids=[1, 0])
+
+        s = nn.Sequential(l.cpu())
+        self.assertRaisesRegex(
+            RuntimeError, error_msg % (0), lambda: nn.DataParallel(s))
+
+        s = nn.Sequential(deepcopy(l), l.cuda())
+        self.assertRaisesRegex(
+            RuntimeError, error_msg % (0), lambda: nn.DataParallel(s))
+
+        s = nn.Sequential(l.cuda(), deepcopy(l).cuda(1))
+        self.assertRaisesRegex(
+            RuntimeError, error_msg % (0), lambda: nn.DataParallel(s))
+        self.assertRaisesRegex(
+            RuntimeError, error_msg % (1),
+            lambda: nn.DataParallel(s, device_ids=[1, 0]))
+
+        s = nn.Sequential(l.cuda(), deepcopy(l).cuda())
+        nn.DataParallel(s)
+
+        s = nn.Sequential(l.cuda(1), deepcopy(l).cuda(1))
+        nn.DataParallel(s, device_ids=[1, 0])
+
     @unittest.skipIf(not TEST_MULTIGPU or not PY3, "multi-GPU not supported")
     @skipIfRocm
     def test_data_parallel_model_no_refcycles(self):
diff --git a/torch/nn/parallel/data_parallel.py b/torch/nn/parallel/data_parallel.py
index 9e04dbe..c543fc2 100644
--- a/torch/nn/parallel/data_parallel.py
+++ b/torch/nn/parallel/data_parallel.py
@@ -1,6 +1,7 @@
 import operator
 import torch
 import warnings
+from itertools import chain
 from ..modules import Module
 from .scatter_gather import scatter_kwargs, gather
 from .replicate import replicate
@@ -105,7 +106,7 @@ class DataParallel(Module):
     Example::
 
         >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
-        >>> output = net(input_var)
+        >>> output = net(input_var)  # input_var can be on any device, including CPU
     """
 
     # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
@@ -123,6 +124,11 @@ class DataParallel(Module):
         if output_device is None:
             output_device = device_ids[0]
 
+        if not all(t.is_cuda and t.device.index == device_ids[0]
+                   for t in chain(module.parameters(), module.buffers())):
+            raise RuntimeError("module must have its parameters and buffers "
+                               "on device %d (device_ids[0])" % device_ids[0])
+
         self.dim = dim
         self.module = module
         self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
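
Illustrative usage note (not part of the commit): a minimal sketch of the behavior this patch enforces, assuming a host with at least two CUDA devices. With the new check, nn.DataParallel fails fast at construction time if the wrapped module's parameters and buffers are not already on device_ids[0], rather than surfacing the problem only later during the forward pass.

    import torch
    import torch.nn as nn

    model = nn.Linear(2, 2)  # parameters live on the CPU

    # After this patch, wrapping a CPU module raises immediately:
    # "module must have its parameters and buffers on device 0 (device_ids[0])"
    try:
        nn.DataParallel(model)
    except RuntimeError as exc:
        print(exc)

    # Moving the module to device_ids[0] first satisfies the check; the input
    # itself may stay on the CPU, as the updated docstring notes, and is
    # scattered across the replicas by DataParallel.
    net = nn.DataParallel(model.cuda(0), device_ids=[0, 1])
    output = net(torch.randn(20, 2))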