Restore current streams on dst device after switching streams (#17439)
authorShen Li <shenli@fb.com>
Mon, 25 Feb 2019 20:02:35 +0000 (12:02 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 25 Feb 2019 20:06:41 +0000 (12:06 -0800)
Summary:
When switching back to `d0` from a stream on a different device `d1`, we need to restore the current streams on both `d0` and `d1`. The current implementation only does that for `d0`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17439

Differential Revision: D14208919

Pulled By: mrshenli

fbshipit-source-id: 89f2565b9977206256efbec42adbd789329ccad8

test/test_cuda.py
torch/cuda/__init__.py

index 79c4f2a..2729bbc 100644 (file)
@@ -1559,6 +1559,9 @@ class TestCuda(TestCase):
         s1 = torch.cuda.Stream(device=1)
         s2 = torch.cuda.Stream(device=0)
 
+        with torch.cuda.device(s1.device):
+            prev_stream_on_cuda1 = torch.cuda.current_stream()
+
         self.assertEqual(torch.cuda.current_stream(), s0)
         self.assertEqual(0, torch.cuda.current_device())
         with torch.cuda.stream(s1):
@@ -1575,6 +1578,9 @@ class TestCuda(TestCase):
             self.assertEqual(torch.cuda.current_stream(), s1)
             self.assertEqual(1, torch.cuda.current_device())
 
+        with torch.cuda.device(s1.device):
+            self.assertEqual(prev_stream_on_cuda1, torch.cuda.current_stream())
+
         self.assertEqual(torch.cuda.current_stream(), s0)
         self.assertEqual(0, torch.cuda.current_device())
 
index 63bc75e..08d8565 100644 (file)
@@ -321,12 +321,21 @@ def stream(stream):
     if stream is None:
         yield
         return
-    prev_stream = current_stream()
+    src_prev_stream = current_stream()
+
+    if src_prev_stream.device != stream.device:
+        # The given stream is on a different device; have to restore the
+        # current_stream on that device on exit as well
+        with device(stream.device):
+            dst_prev_stream = current_stream()
+
     torch._C._cuda_setStream(stream._cdata)
     try:
         yield
     finally:
-        torch._C._cuda_setStream(prev_stream._cdata)
+        if src_prev_stream.device != stream.device:
+            torch._C._cuda_setStream(dst_prev_stream._cdata)
+        torch._C._cuda_setStream(src_prev_stream._cdata)
 
 
 def device_count():