From 292edfb0879cb2ae9377637eadd282dfb6daea33 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Fri, 18 Jan 2019 12:32:20 -0800 Subject: [PATCH] Change current device in stream context manager if necessary (#16128) Summary: Fixes #16019 Pull Request resolved: https://github.com/pytorch/pytorch/pull/16128 Differential Revision: D13721850 Pulled By: mrshenli fbshipit-source-id: 422c6c0b97c1cd46e127e265b532cb8c74a3aac5 --- test/test_cuda.py | 26 ++++++++++++++++++++++++++ torch/csrc/cuda/Module.cpp | 8 +++++++- torch/cuda/__init__.py | 6 +++--- 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/test/test_cuda.py b/test/test_cuda.py index c8a0db7..5873216 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1438,6 +1438,32 @@ class TestCuda(TestCase): self.assertTrue(default_stream.query()) @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") + @skipIfRocm + def test_stream_context(self): + s0 = torch.cuda.current_stream() + s1 = torch.cuda.Stream(device=1) + s2 = torch.cuda.Stream(device=0) + + self.assertEqual(torch.cuda.current_stream(), s0) + self.assertEqual(0, torch.cuda.current_device()) + with torch.cuda.stream(s1): + self.assertEqual(torch.cuda.current_stream(), s1) + self.assertEqual(1, torch.cuda.current_device()) + with torch.cuda.stream(s2): + self.assertEqual(torch.cuda.current_stream(), s2) + self.assertEqual(0, torch.cuda.current_device()) + with torch.cuda.stream(s0): + self.assertEqual(torch.cuda.current_stream(), s0) + self.assertEqual(0, torch.cuda.current_device()) + self.assertEqual(torch.cuda.current_stream(), s2) + self.assertEqual(0, torch.cuda.current_device()) + self.assertEqual(torch.cuda.current_stream(), s1) + self.assertEqual(1, torch.cuda.current_device()) + + self.assertEqual(torch.cuda.current_stream(), s0) + self.assertEqual(0, torch.cuda.current_device()) + + @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") def test_streams_multi_gpu(self): default_stream = torch.cuda.current_stream() 
self.assertEqual(default_stream.device, 0) diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index bbace72..4408612 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -82,7 +82,13 @@ PyObject * THCPModule_setStream_wrap(PyObject *self, PyObject *obj) if (bits == static_cast<uint64_t>(-1) && PyErr_Occurred()) { throw python_error(); } - at::cuda::setCurrentCUDAStream(at::cuda::CUDAStream::unpack(bits)); + auto stream = at::cuda::CUDAStream::unpack(bits); + int device; + THCudaCheck(cudaGetDevice(&device)); + if (device != stream.device_index()) { + THCPModule_setDevice(stream.device_index()); + } + at::cuda::setCurrentCUDAStream(stream); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 2a26411..8015b20 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -313,9 +313,9 @@ def stream(stream): stream (Stream): selected stream. This manager is a no-op if it's ``None``. - .. note:: Streams are per-device, and this function changes the "current - stream" only for the currently selected device. It is illegal to select - a stream that belongs to a different device. + .. note:: Streams are per-device. If the selected stream is not on the + current device, this function will also change the current device to + match the stream. """ if stream is None: yield -- 2.7.4