self.assertTrue(default_stream.query())
@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")  # the test creates a stream on device 1
+ @skipIfRocm
+ def test_stream_context(self):  # checks the current stream AND current device inside and after torch.cuda.stream(...) contexts
+ s0 = torch.cuda.current_stream()  # stream that is current before any context is entered
+ s1 = torch.cuda.Stream(device=1)  # new stream created on device 1
+ s2 = torch.cuda.Stream(device=0)  # new stream created on device 0
+ # Baseline before entering any context: s0 is current and the device is 0.
+ self.assertEqual(torch.cuda.current_stream(), s0)
+ self.assertEqual(0, torch.cuda.current_device())
+ with torch.cuda.stream(s1):  # s1 is on device 1; the device is expected to follow the stream
+ self.assertEqual(torch.cuda.current_stream(), s1)
+ self.assertEqual(1, torch.cuda.current_device())
+ with torch.cuda.stream(s2):  # NOTE(review): indentation is stripped here; presumably nested inside the s1 context -- confirm
+ self.assertEqual(torch.cuda.current_stream(), s2)
+ self.assertEqual(0, torch.cuda.current_device())
+ with torch.cuda.stream(s0):  # s0 is asserted to be on device 0, like s2: stream changes, device stays 0
+ self.assertEqual(torch.cuda.current_stream(), s0)
+ self.assertEqual(0, torch.cuda.current_device())
+ self.assertEqual(torch.cuda.current_stream(), s2)  # presumably after the s0 context exits: s2 is current again
+ self.assertEqual(0, torch.cuda.current_device())
+ self.assertEqual(torch.cuda.current_stream(), s1)  # presumably after the s2 context exits: s1 is current again
+ self.assertEqual(1, torch.cuda.current_device())  # ...and the device is back to 1
+ # Final check: the initial stream s0 and device 0 are expected to be current again.
+ self.assertEqual(torch.cuda.current_stream(), s0)
+ self.assertEqual(0, torch.cuda.current_device())
+
+ @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
def test_streams_multi_gpu(self):
default_stream = torch.cuda.current_stream()
self.assertEqual(default_stream.device, 0)
if (bits == static_cast<uint64_t>(-1) && PyErr_Occurred()) {
throw python_error();
}
- at::cuda::setCurrentCUDAStream(at::cuda::CUDAStream::unpack(bits));
+ auto stream = at::cuda::CUDAStream::unpack(bits);
+ int device;
+ THCudaCheck(cudaGetDevice(&device));
+ if (device != stream.device_index()) {
+ THCPModule_setDevice(stream.device_index());
+ }
+ at::cuda::setCurrentCUDAStream(stream);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
stream (Stream): selected stream. This manager is a no-op if it's
``None``.
- .. note:: Streams are per-device, and this function changes the "current
- stream" only for the currently selected device. It is illegal to select
- a stream that belongs to a different device.
+ .. note:: Streams are per-device. If the selected stream is not on the
+ current device, this function will also change the current device to
+ match the stream.
"""
if stream is None:
yield