Change current device in stream context manager if necessary (#16128)
authorShen Li <shenli@fb.com>
Fri, 18 Jan 2019 20:32:20 +0000 (12:32 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 18 Jan 2019 20:39:51 +0000 (12:39 -0800)
Summary:
Fixes #16019
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16128

Differential Revision: D13721850

Pulled By: mrshenli

fbshipit-source-id: 422c6c0b97c1cd46e127e265b532cb8c74a3aac5

test/test_cuda.py
torch/csrc/cuda/Module.cpp
torch/cuda/__init__.py

index c8a0db7..5873216 100644 (file)
@@ -1438,6 +1438,32 @@ class TestCuda(TestCase):
         self.assertTrue(default_stream.query())
 
     @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
+    @skipIfRocm
+    def test_stream_context(self):
+        s0 = torch.cuda.current_stream()
+        s1 = torch.cuda.Stream(device=1)
+        s2 = torch.cuda.Stream(device=0)
+
+        self.assertEqual(torch.cuda.current_stream(), s0)
+        self.assertEqual(0, torch.cuda.current_device())
+        with torch.cuda.stream(s1):
+            self.assertEqual(torch.cuda.current_stream(), s1)
+            self.assertEqual(1, torch.cuda.current_device())
+            with torch.cuda.stream(s2):
+                self.assertEqual(torch.cuda.current_stream(), s2)
+                self.assertEqual(0, torch.cuda.current_device())
+                with torch.cuda.stream(s0):
+                    self.assertEqual(torch.cuda.current_stream(), s0)
+                    self.assertEqual(0, torch.cuda.current_device())
+                self.assertEqual(torch.cuda.current_stream(), s2)
+                self.assertEqual(0, torch.cuda.current_device())
+            self.assertEqual(torch.cuda.current_stream(), s1)
+            self.assertEqual(1, torch.cuda.current_device())
+
+        self.assertEqual(torch.cuda.current_stream(), s0)
+        self.assertEqual(0, torch.cuda.current_device())
+
+    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
     def test_streams_multi_gpu(self):
         default_stream = torch.cuda.current_stream()
         self.assertEqual(default_stream.device, 0)
index bbace72..4408612 100644 (file)
@@ -82,7 +82,13 @@ PyObject * THCPModule_setStream_wrap(PyObject *self, PyObject *obj)
   if (bits == static_cast<uint64_t>(-1) && PyErr_Occurred()) {
     throw python_error();
   }
-  at::cuda::setCurrentCUDAStream(at::cuda::CUDAStream::unpack(bits));
+  auto stream = at::cuda::CUDAStream::unpack(bits);
+  int device;
+  THCudaCheck(cudaGetDevice(&device));
+  if (device != stream.device_index()) {
+    THCPModule_setDevice(stream.device_index());
+  }
+  at::cuda::setCurrentCUDAStream(stream);
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }
index 2a26411..8015b20 100644 (file)
@@ -313,9 +313,9 @@ def stream(stream):
         stream (Stream): selected stream. This manager is a no-op if it's
             ``None``.
 
-    .. note:: Streams are per-device, and this function changes the "current
-       stream" only for the currently selected device.  It is illegal to select
-       a stream that belongs to a different device.
+    .. note:: Streams are per-device. If the selected stream is not on the
+        current device, this function will also change the current device to
+        match the stream.
     """
     if stream is None:
         yield