_get_device_index supports parsing device strings
author     SsnL <tongzhou.wang.1994@gmail.com>
           Mon, 10 Dec 2018 05:10:39 +0000 (21:10 -0800)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
           Mon, 10 Dec 2018 05:12:46 +0000 (21:12 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14929

Reviewed By: weiyangfb

Differential Revision: D13394498

Pulled By: soumith

fbshipit-source-id: 948c6118abdf6c1e1a8a17709333954cafb2345e
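
In short, `_get_device_index` now also accepts a device string such as 'cuda:0', which it converts to a torch.device before extracting the index. A minimal sketch of the behavior this commit enables (the example calls are illustrative, not part of the diff):

    import torch
    from torch.cuda._utils import _get_device_index

    # Integers and torch.device objects were already supported.
    assert _get_device_index(0) == 0
    assert _get_device_index(torch.device('cuda:1')) == 1
    # Newly supported: a plain device string is parsed via torch.device().
    assert _get_device_index('cuda:1') == 1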

test/test_cuda.py
torch/cuda/_utils.py

diff --git a/test/test_cuda.py b/test/test_cuda.py
index 4867ed6..8226479 100644
@@ -787,7 +787,7 @@ class TestCuda(TestCase):
 
         # interlace
         torch.cuda.empty_cache()
-        gen0 = self._test_memory_stats_generator(self, device=0, N=35)
+        gen0 = self._test_memory_stats_generator(self, device='cuda:0', N=35)
         gen1 = self._test_memory_stats_generator(self, device=torch.device('cuda:1'), N=35)
         end0 = end1 = False
         while not (end0 and end1):
diff --git a/torch/cuda/_utils.py b/torch/cuda/_utils.py
index 844bf37..278664c 100644
@@ -1,4 +1,5 @@
 import torch
+import torch._six
 
 
 def _get_device_index(device, optional=False):
@@ -15,6 +16,8 @@ def _get_device_index(device, optional=False):
     If :attr:`device` is ``None``, this will return the current default CUDA
     device if :attr:`optional` is ``True``.
     """
+    if isinstance(device, torch._six.string_classes):
+        device = torch.device(device)
     if isinstance(device, torch.device):
         dev_type = device.type
         if device.type != 'cuda':
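
The hunk above is truncated. For context, a sketch of what the full helper plausibly looks like after this patch (reconstructed around the shown lines; the error messages and exact control flow below are assumptions, not part of the diff):

    import torch
    import torch._six


    def _get_device_index(device, optional=False):
        r"""Gets the device index from :attr:`device`, which can be a
        torch.device object, a Python integer, or (new) a device string."""
        if isinstance(device, torch._six.string_classes):
            # New in this commit: parse strings such as 'cuda:0' into a torch.device first.
            device = torch.device(device)
        if isinstance(device, torch.device):
            if device.type != 'cuda':
                raise ValueError('Expected a cuda device, but got: {}'.format(device))
            device_idx = device.index
        else:
            device_idx = device
        if device_idx is None:
            if optional:
                # No index given (e.g. plain 'cuda'); fall back to the current device.
                return torch.cuda.current_device()
            raise ValueError('Expected a cuda device with a specified index '
                             'or an integer, but got: {}'.format(device))
        return device_idx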