Move all Stream and Event Python implementation to C++ (#15937)
authorShen Li <shenli@fb.com>
Thu, 17 Jan 2019 15:22:42 +0000 (07:22 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 17 Jan 2019 15:29:22 +0000 (07:29 -0800)
Summary:
1. Added `torch/csrc/cuda/Event.h` and `torch/csrc/cuda/Event.cpp` to bind Python Event class to C++ implementation.
2. Move all CUDA runtime invocations from `torch/cuda/streams.py` to C++
3. Added tests to cover Stream and Event APIs. ~(event IPC handle tests is introduced in #15974)~
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15937

Differential Revision: D13649001

Pulled By: mrshenli

fbshipit-source-id: 84ca58f35f6ba679a4ba33150ceba678d760d240

16 files changed:
aten/src/ATen/cuda/CUDAEvent.h
aten/src/ATen/test/cuda_stream_test.cpp
c10/cuda/CUDAStream.h
test/test_cuda.py
test/test_multiprocessing.py
tools/autograd/templates/python_variable_methods.cpp
torch/CMakeLists.txt
torch/csrc/Module.cpp
torch/csrc/cuda/Event.cpp [new file with mode: 0644]
torch/csrc/cuda/Event.h [new file with mode: 0644]
torch/csrc/cuda/Stream.cpp
torch/csrc/cuda/Stream.h
torch/csrc/cuda/THCP.h
torch/cuda/__init__.py
torch/cuda/streams.py
torch/multiprocessing/reductions.py

index b12f0b7..1b14685 100644 (file)
@@ -17,10 +17,12 @@ namespace at { namespace cuda {
 /*
 * CUDAEvents are movable not copyable wrappers around CUDA's events.
 *
-* CUDAEvents are constructed lazily when recorded on streams. The events
-* have a device, and this device is acquired from the first recording stream.
-* Later streams that record to the event must share this device, but streams
-* on any device can wait on the event.
+* CUDAEvents are constructed lazily when first recorded unless it is
+* reconstructed from a cudaIpcEventHandle_t. The event has a device, and this
+* device is acquired from the first recording stream. However, if reconstructed
+* from a handle, the device should be explicitly specified; or if ipc_handle() is
+* called before the event is ever recorded, it will use the current device.
+* Later streams that record the event must match this device.
 */
 struct AT_CUDA_API CUDAEvent {
   // Constants
@@ -30,12 +32,25 @@ struct AT_CUDA_API CUDAEvent {
   CUDAEvent(unsigned int flags = DEFAULT_FLAGS)
   : flags_{flags} { }
 
+  CUDAEvent(
+      DeviceIndex device_index, const cudaIpcEventHandle_t* handle) {
+    #ifndef __HIP_PLATFORM_HCC__
+      device_index_ = device_index;
+      CUDAGuard guard(device_index_);
+
+      AT_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle));
+      is_created_ = true;
+    #else
+      AT_ERROR("cuIpcOpenEventHandle with HIP is not supported");
+    #endif
+  }
+
   // Note: event destruction done on creating device to avoid creating a
   // CUDA context on other devices.
   ~CUDAEvent() {
     try {
       if (is_created_) {
-        CUDAGuard device_guard(static_cast<int16_t>(device_index_));
+        CUDAGuard guard(device_index_);
         cudaEventDestroy(event_);
       }
     } catch (...) { /* No throw */ }
@@ -57,13 +72,28 @@ struct AT_CUDA_API CUDAEvent {
     return left.event_ < right.event_;
   }
 
+  at::Device device() const {
+    return at::Device(at::kCUDA, device_index_);
+  }
+
   bool isCreated() const { return is_created_; }
-  int64_t device() const { return device_index_; }
+  DeviceIndex device_index() const {return device_index_;}
   cudaEvent_t event() const { return event_; }
 
   // Note: cudaEventQuery can be safely called from any device
-  bool happened() const {
-    return (was_recorded_ && cudaEventQuery(event_) == cudaSuccess);
+  bool query() const {
+    if (!is_created_) {
+      return true;
+    }
+
+    cudaError_t err = cudaEventQuery(event_);
+    if (err == cudaSuccess) {
+      return true;
+    } else if (err != cudaErrorNotReady) {
+      C10_CUDA_CHECK(err);
+    }
+
+    return false;
   }
 
   void record() { record(getCurrentCUDAStream()); }
@@ -72,18 +102,15 @@ struct AT_CUDA_API CUDAEvent {
     if (!was_recorded_) record(stream);
   }
 
-  // Note: cudaEventRecord must be called on the same device as the stream.
+  // Note: cudaEventRecord must be called on the same device as the event.
   void record(const CUDAStream& stream) {
-    CUDAGuard guard(static_cast<int16_t>(stream.device_index()));
-
-    if (is_created_) {
-      AT_ASSERT(device_index_ == stream.device_index());
-    } else {
-      AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
-      is_created_ = true;
-      device_index_ = stream.device_index();
+    if (!is_created_) {
+      createEvent(stream.device_index());
     }
 
+    AT_CHECK(device_index_ == stream.device_index(), "Event device ", device_index_,
+      " does not match recording stream's device ", stream.device_index(), ".");
+    CUDAGuard guard(device_index_);
     AT_CUDA_CHECK(cudaEventRecord(event_, stream));
     was_recorded_ = true;
   }
@@ -92,18 +119,57 @@ struct AT_CUDA_API CUDAEvent {
   // The event has no actual GPU resources associated with it.
   void block(const CUDAStream& stream) {
     if (is_created_) {
-      CUDAGuard guard(static_cast<int16_t>(stream.device_index()));
+      CUDAGuard guard(stream.device_index());
       AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, 0));
     }
   }
 
+  // Note: cudaEventElapsedTime can be safely called from any device
+  float elapsed_time(const CUDAEvent& other) const {
+    AT_CHECK(is_created_ && other.isCreated(),
+      "Both events must be recorded before calculating elapsed time.");
+    float time_ms = 0;
+    // raise cudaErrorNotReady if either event is recorded but not yet completed
+    AT_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_));
+    return time_ms;
+  }
+
+  // Note: cudaEventSynchronize can be safely called from any device
+  void synchronize() const {
+    if (is_created_) {
+      AT_CUDA_CHECK(cudaEventSynchronize(event_));
+    }
+  }
+
+  // Note: cudaIpcGetEventHandle must be called on the same device as the event
+  void ipc_handle(cudaIpcEventHandle_t * handle) {
+    #ifndef __HIP_PLATFORM_HCC__
+      if (!is_created_) {
+        // this CUDAEvent object was initially constructed from flags but event_
+        // is not created yet.
+        createEvent(getCurrentCUDAStream().device_index());
+      }
+      CUDAGuard guard(device_index_);
+      AT_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_));
+    #else
+      AT_ERROR("cuIpcGetEventHandle with HIP is not supported");
+    #endif
+  }
+
 private:
   unsigned int flags_ = DEFAULT_FLAGS;
   bool is_created_ = false;
   bool was_recorded_ = false;
-  int64_t device_index_ = -1;
+  DeviceIndex device_index_ = -1;
   cudaEvent_t event_;
 
+  void createEvent(DeviceIndex device_index) {
+    device_index_ = device_index;
+    CUDAGuard guard(device_index_);
+    AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
+    is_created_ = true;
+  }
+
   void moveHelper(CUDAEvent&& other) {
     std::swap(flags_, other.flags_);
     std::swap(is_created_, other.is_created_);
index 22cd5f7..c245cb6 100644 (file)
@@ -217,7 +217,7 @@ TEST(TestStream, CUDAEventSyncTest) {
   const auto stream = at::cuda::getStreamFromPool();
   at::cuda::CUDAEvent event;
 
-  ASSERT_FALSE(event.happened());
+  ASSERT_TRUE(event.query());
 
   event.recordOnce(stream);
 
@@ -228,7 +228,7 @@ TEST(TestStream, CUDAEventSyncTest) {
   event.block(wait_stream1);
 
   cudaStreamSynchronize(wait_stream0);
-  ASSERT_TRUE(event.happened());
+  ASSERT_TRUE(event.query());
 }
 
 // Cross-Device Events
@@ -249,10 +249,10 @@ TEST(TestStream, CrossDeviceTest) {
 
   event0 = std::move(event1);
 
-  ASSERT_EQ_CUDA(event0.device(), 1);
+  ASSERT_EQ_CUDA(event0.device(), at::Device(at::kCUDA, 1));
 
   event0.block(stream0);
 
   cudaStreamSynchronize(stream0);
-  ASSERT_TRUE(event0.happened());
+  ASSERT_TRUE(event0.query());
 }
index 7548e72..7f7f864 100644 (file)
@@ -102,16 +102,32 @@ public:
   StreamId id() const { return stream_.id(); }
 
   bool query() const {
-    DeviceGuard device_guard{stream_.device()};
+    DeviceGuard guard{stream_.device()};
     cudaError_t err = cudaStreamQuery(stream());
 
-    if (err == cudaErrorNotReady) {
-      return false;
-    } else if (err != cudaSuccess) {
+    if (err == cudaSuccess) {
+      return true;
+    } else if (err != cudaErrorNotReady) {
       C10_CUDA_CHECK(err);
     }
 
-    return true;
+    return false;
+  }
+
+  void synchronize() const {
+    DeviceGuard guard{stream_.device()};
+    C10_CUDA_CHECK(cudaStreamSynchronize(stream()));
+  }
+
+  int priority() const {
+    #ifndef __HIP_PLATFORM_HCC__
+      DeviceGuard guard{stream_.device()};
+      int priority = 0;
+      C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority));
+      return priority;
+    #else
+      AT_ERROR("cuStreamGetPriority with HIP is not supported");
+    #endif
   }
 
   /// Explicit conversion to cudaStream_t.
@@ -137,6 +153,17 @@ public:
     return CUDAStream(Stream::unpack(bits));
   }
 
+  static std::tuple<int, int> priority_range() {
+    #ifndef __HIP_PLATFORM_HCC__
+      int least_priority, greatest_priority;
+      C10_CUDA_CHECK(
+        cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority));
+      return std::make_tuple(least_priority, greatest_priority);
+    #else
+      AT_ERROR("cuDeviceGetStreamPriorityRange with HIP is not supported");
+    #endif
+  }
+
   // Deleted for now; use CUDAEvent::block instead
   // void synchronize_with(const CUDAEvent& event) const;
 
index c22ddb0..c8a0db7 100644 (file)
@@ -1471,7 +1471,8 @@ class TestCuda(TestCase):
             self.assertTrue(s0.query())
             self.assertFalse(s1.query())
 
-        with torch.cuda.device(d1):
+        # deliberately using a different device
+        with torch.cuda.device(d0):
             s1.synchronize()
 
         self.assertTrue(s0.query())
@@ -1517,6 +1518,20 @@ class TestCuda(TestCase):
         self.assertNotEqual(hash(s0), hash(s3))
 
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    @skipIfRocm
+    def test_streams_priority(self):
+        low, high = torch.cuda.Stream.priority_range()
+        s0 = torch.cuda.Stream(device=0, priority=low)
+
+        self.assertEqual(low, s0.priority)
+        self.assertEqual(0, s0.device)
+
+        s1 = torch.cuda.Stream(device=1, priority=high)
+
+        self.assertEqual(high, s1.priority)
+        self.assertEqual(1, s1.device)
+
+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_tensor_device(self):
         self.assertEqual(torch.cuda.FloatTensor(1).get_device(), 0)
         self.assertEqual(torch.cuda.FloatTensor(1, device=1).get_device(), 1)
@@ -1539,6 +1554,113 @@ class TestCuda(TestCase):
         self.assertTrue(event.query())
         self.assertGreater(start_event.elapsed_time(event), 0)
 
+    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
+    @skipIfRocm
+    def test_events_wait(self):
+        d0 = torch.device('cuda:0')
+        d1 = torch.device('cuda:1')
+
+        with torch.cuda.device(d0):
+            s0 = torch.cuda.current_stream()
+            torch.cuda._sleep(50000000)  # spin for about 50 ms on device1
+            e0 = torch.cuda.Event()
+            s0.record_event(e0)
+
+        with torch.cuda.device(d1):
+            s1 = torch.cuda.current_stream()
+
+        self.assertFalse(s0.query())
+        self.assertTrue(s1.query())
+
+        s1.wait_event(e0)
+        s1.synchronize()
+
+        self.assertTrue(e0.query())
+        self.assertTrue(s0.query())
+        self.assertTrue(s1.query())
+
+    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
+    @skipIfRocm
+    def test_events_multi_gpu_query(self):
+        d0 = torch.device('cuda:0')
+        d1 = torch.device('cuda:1')
+
+        with torch.cuda.device(d0):
+            s0 = torch.cuda.current_stream()
+            e0 = s0.record_event()
+
+        with torch.cuda.device(d1):
+            s1 = torch.cuda.current_stream()
+            torch.cuda._sleep(50000000)  # spin for about 50 ms on device1
+            e1 = s1.record_event()
+
+        self.assertTrue(e0.query())
+        self.assertFalse(e1.query())
+
+        with torch.cuda.device(d0):
+            self.assertTrue(e0.query())
+            self.assertFalse(e1.query())
+
+        with torch.cuda.device(d1):
+            self.assertTrue(e0.query())
+            self.assertFalse(e1.query())
+
+        # deliberately using a different device
+        with torch.cuda.device(d0):
+            e1.synchronize()
+
+        self.assertTrue(e0.query())
+        self.assertTrue(e1.query())
+
+        with torch.cuda.device(d0):
+            self.assertTrue(e0.query())
+            self.assertTrue(e1.query())
+
+        with torch.cuda.device(d1):
+            self.assertTrue(e0.query())
+            self.assertTrue(e1.query())
+
+    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
+    @skipIfRocm
+    def test_events_multi_gpu_elapsed_time(self):
+        d0 = torch.device('cuda:0')
+        d1 = torch.device('cuda:1')
+
+        with torch.cuda.device(d0):
+            s0 = torch.cuda.current_stream()
+            e0 = torch.cuda.Event(enable_timing=True)
+            torch.cuda._sleep(10)  # spin for about 50 ms on device1
+            s0.record_event(e0)
+
+        with torch.cuda.device(d1):
+            s1 = torch.cuda.current_stream()
+            e1 = torch.cuda.Event(enable_timing=True)
+            torch.cuda._sleep(30000000)  # spin for about 50 ms on device1
+            s1.record_event(e1)
+
+        e0.synchronize()
+        e1.synchronize()
+        with torch.cuda.device(d0):
+            with self.assertRaises(RuntimeError):
+                self.assertGreater(e0.elapsed_time(e1), 0)
+
+        with torch.cuda.device(d1):
+            with self.assertRaises(RuntimeError):
+                self.assertGreater(e0.elapsed_time(e1), 0)
+
+        with torch.cuda.device(d0):
+            s0 = torch.cuda.current_stream()
+            e2 = torch.cuda.Event(enable_timing=True)
+            torch.cuda._sleep(30000000)  # spin for about 50 ms on device1
+            s0.record_event(e2)
+            s0.synchronize()
+
+        self.assertGreater(e0.elapsed_time(e2), 0)
+
+        # deliberately calling from a different device
+        with torch.cuda.device(d1):
+            self.assertGreater(e0.elapsed_time(e2), 0)
+
     @skipIfRocm
     def test_record_stream(self):
         cycles_per_ms = get_cycles_per_ms()
index f56cd2c..fadeb43 100644 (file)
@@ -402,6 +402,129 @@ class TestMultiprocessing(TestCase):
             self.assertEqual(list(tensor), [4, 4, 4, 4])
         p.join()
 
+    def _test_event_multiprocess_child(event, p2c, c2p):
+        c2p.put(0)  # notify parent child is ready
+        p2c.get()  # wait for record in parent
+        event.synchronize()
+        c2p.put(1)  # notify parent synchronization is done
+
+    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
+                     don't support multiprocessing with spawn start method")
+    @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
+    def test_event_multiprocess(self):
+        event = torch.cuda.Event(enable_timing=False, interprocess=True)
+        self.assertTrue(event.query())
+
+        ctx = mp.get_context('spawn')
+        p2c = ctx.SimpleQueue()
+        c2p = ctx.SimpleQueue()
+        p = ctx.Process(
+            target=TestMultiprocessing._test_event_multiprocess_child,
+            args=(event, p2c, c2p))
+        p.start()
+
+        c2p.get()  # wait for until child process is ready
+        torch.cuda._sleep(50000000)  # spin for about 50 ms
+        event.record()
+        p2c.put(0)  # notify child event is recorded
+
+        self.assertFalse(event.query())
+        c2p.get()  # wait for synchronization in child
+        self.assertTrue(event.query())
+        p.join()
+
+    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
+                     don't support multiprocessing with spawn start method")
+    @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
+    def test_event_handle_multi_gpu(self):
+        d0 = torch.device('cuda:0')
+        d1 = torch.device('cuda:1')
+        with torch.cuda.device(d0):
+            e0 = torch.cuda.Event(enable_timing=False, interprocess=True)
+
+        with torch.cuda.device(d1):
+            # create handle on different device from un-recorded event
+            e0.ipc_handle()
+
+        with torch.cuda.device(d0):
+            e1 = torch.cuda.Event(enable_timing=False, interprocess=True)
+            stream = torch.cuda.Stream()
+            torch.cuda._sleep(50000000)  # spin for about 50 ms
+            e1.record(stream)
+
+        with torch.cuda.device(d1):
+            # create handle on different device from recorded event
+            e1.ipc_handle()
+
+    def _test_event_handle_importer_consumer(handle, p2c, c2p):
+        e1 = torch.cuda.Event.from_ipc_handle(
+            torch.cuda.current_device(), handle)
+        c2p.put(0)  # notify parent child is ready
+        p2c.get()  # wait for record in parent
+        e1.synchronize()
+        c2p.put(1)  # nofity synchronization is done in child
+        p2c.get()  # wait for parent to finish before destructing child event
+
+    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
+                     don't support multiprocessing with spawn start method")
+    @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
+    def test_event_handle_importer(self):
+        e0 = torch.cuda.Event(enable_timing=False, interprocess=True)
+        self.assertTrue(e0.query())
+
+        ctx = mp.get_context('spawn')
+        p2c = ctx.SimpleQueue()
+        c2p = ctx.SimpleQueue()
+        p = ctx.Process(
+            target=TestMultiprocessing._test_event_handle_importer_consumer,
+            args=(e0.ipc_handle(), p2c, c2p))
+        p.start()
+
+        c2p.get()  # wait for child to become ready
+        torch.cuda._sleep(50000000)  # spin for about 50 ms
+        e0.record()
+        p2c.put(0)  # notify child event is recorded
+
+        self.assertFalse(e0.query())
+        c2p.get()  # wait for synchronization in child
+        self.assertTrue(e0.query())
+        p2c.put(1)  # notify child that parent is done
+        p.join()
+
+    def _test_event_handle_exporter_consumer(handle, p2c, c2p):
+        stream = torch.cuda.Stream()
+        with torch.cuda.stream(stream):
+            e1 = torch.cuda.Event.from_ipc_handle(
+                torch.cuda.current_device(), handle)
+            torch.cuda._sleep(50000000)  # spin for about 50 ms
+            e1.record()
+            c2p.put(0)
+            # wait for parent process finished synchronization before
+            # destructing e1
+            p2c.get()
+
+    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
+                     don't support multiprocessing with spawn start method")
+    @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
+    def test_event_handle_exporter(self):
+        e0 = torch.cuda.Event(enable_timing=False, interprocess=True)
+
+        ctx = mp.get_context('spawn')
+        p2c = ctx.SimpleQueue()
+        c2p = ctx.SimpleQueue()
+        p = ctx.Process(
+            target=TestMultiprocessing._test_event_handle_exporter_consumer,
+            args=(e0.ipc_handle(), p2c, c2p))
+        p.start()
+        # wait for event in child process is recorded
+        c2p.get()
+
+        self.assertFalse(e0.query())
+        e0.synchronize()
+        self.assertTrue(e0.query())
+        p2c.put(0)
+        p.join()
+
     def _test_empty_tensor_sharing(self, dtype, device):
         q = mp.Queue()
         empty = torch.tensor([], dtype=dtype, device=device)
index 08ff373..241d187 100644 (file)
@@ -12,6 +12,7 @@
 #include "torch/csrc/jit/tracer.h"
 #ifdef USE_CUDA
 #include "torch/csrc/cuda/Stream.h"
+#include "torch/csrc/cuda/Event.h"
 #endif
 #include "torch/csrc/utils/cuda_lazy_init.h"
 #include "torch/csrc/utils/object_ptr.h"
index 9254816..879c5d4 100644 (file)
@@ -624,6 +624,7 @@ if (BUILD_PYTHON)
       ${TORCH_SRC_DIR}/csrc/cuda/Module.cpp
       ${TORCH_SRC_DIR}/csrc/cuda/Storage.cpp
       ${TORCH_SRC_DIR}/csrc/cuda/Stream.cpp
+      ${TORCH_SRC_DIR}/csrc/cuda/Event.cpp
       ${TORCH_SRC_DIR}/csrc/cuda/utils.cpp
       ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
       ${TORCH_SRC_DIR}/csrc/cuda/python_comm.cpp
@@ -669,6 +670,7 @@ if (BUILD_PYTHON)
       ${TORCH_SRC_DIR}/csrc/cuda/Module.cpp
       ${TORCH_SRC_DIR}/csrc/cuda/Storage.cpp
       ${TORCH_SRC_DIR}/csrc/cuda/Stream.cpp
+      ${TORCH_SRC_DIR}/csrc/cuda/Event.cpp
       ${TORCH_SRC_DIR}/csrc/cuda/utils.cpp
       ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
       ${TORCH_SRC_DIR}/csrc/cuda/python_comm.cpp
index 812ed09..37c2d0c 100644 (file)
@@ -461,7 +461,8 @@ bool THCPShortStorage_init(PyObject *module);
 bool THCPCharStorage_init(PyObject *module);
 bool THCPByteStorage_init(PyObject *module);
 
-bool THCPStream_init(PyObject *module);
+void THCPStream_init(PyObject *module);
+void THCPEvent_init(PyObject *module);
 
 #ifdef USE_CUDA
 PyMethodDef* THCPModule_methods();
@@ -607,7 +608,8 @@ PyObject* initModule() {
   ASSERT_TRUE(THCPCharStorage_init(module));
   ASSERT_TRUE(THCPByteStorage_init(module));
 
-  ASSERT_TRUE(THCPStream_init(module));
+  THCPStream_init(module);
+  THCPEvent_init(module);
 #endif
 
   auto set_module_attr = [&](const char* name, PyObject* v, bool incref = true) {
diff --git a/torch/csrc/cuda/Event.cpp b/torch/csrc/cuda/Event.cpp
new file mode 100644 (file)
index 0000000..ce64c61
--- /dev/null
@@ -0,0 +1,211 @@
+#include <torch/csrc/cuda/Event.h>
+#include <torch/csrc/cuda/Stream.h>
+
+#include <torch/csrc/THP.h>
+#include <torch/csrc/cuda/Module.h>
+
+#include <c10/cuda/CUDAGuard.h>
+
+#include <structmember.h>
+#include <cuda_runtime_api.h>
+
+PyObject *THCPEventClass = nullptr;
+
+static PyObject * THCPEvent_pynew(
+    PyTypeObject *type, PyObject *args, PyObject *kwargs) {
+  HANDLE_TH_ERRORS
+  unsigned char enable_timing = 0;
+  unsigned char blocking = 0;
+  unsigned char interprocess = 0;
+
+  static char *kwlist[] =
+    {"enable_timing", "blocking", "interprocess", nullptr};
+  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|bbb", kwlist,
+      &enable_timing, &blocking, &interprocess)) {
+    return nullptr;
+  }
+
+  THPObjectPtr ptr(type->tp_alloc(type, 0));
+  if (!ptr) {
+    return nullptr;
+  }
+
+  THCPEvent* self = (THCPEvent *)ptr.get();
+  unsigned int flags =
+    (blocking ? cudaEventBlockingSync : cudaEventDefault) |
+    (enable_timing ? cudaEventDefault : cudaEventDisableTiming) |
+    (interprocess ? cudaEventInterprocess : cudaEventDefault);
+
+  new (&self->cuda_event) at::cuda::CUDAEvent(flags);
+
+  return (PyObject *)ptr.release();
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPEvent_from_ipc_handle(
+    PyTypeObject *type, PyObject *args) {
+  HANDLE_TH_ERRORS
+  long long device_index = -1;
+  const char *handle_bytes = nullptr;
+  int handle_size = 0;
+
+  // cannot use bool 'p' and bytearray 'Y' as they are not available in Python 2
+  if (!PyArg_ParseTuple(
+      args, "Ls#", &device_index, &handle_bytes, &handle_size)) {
+    return nullptr;
+  }
+
+  AT_CHECK(handle_size == sizeof(cudaIpcEventHandle_t),
+    "cudaIpcEventHandle_t expects byte-like object of size ",
+    sizeof(cudaIpcEventHandle_t), ", but got ", handle_size);
+  AT_CHECK(device_index >= 0, "Reconstructing event from handle requires "
+    "a non-negtive device index, but got ", device_index)
+
+  // no need to release the handle byte array as it is automatically managed
+  // by the corresponding THCPEvent python object.
+  // see https://docs.python.org/3/c-api/arg.html#strings-and-buffers
+
+  THPObjectPtr ptr(type->tp_alloc(type, 0));
+  if (!ptr) {
+    return nullptr;
+  }
+  THCPEvent* self = (THCPEvent *)ptr.get();
+
+  cudaIpcEventHandle_t handle;
+  std::memcpy(&handle, handle_bytes, handle_size);
+  new (&self->cuda_event) at::cuda::CUDAEvent(device_index, &handle);
+
+  return (PyObject *)ptr.release();
+  END_HANDLE_TH_ERRORS
+}
+
+static void THCPEvent_dealloc(THCPEvent *self) {
+  self->cuda_event.~CUDAEvent();
+  Py_TYPE(self)->tp_free((PyObject*)self);
+}
+
+static PyObject * THCPEvent_get_cuda_event(THCPEvent *self) {
+  HANDLE_TH_ERRORS
+  return PyLong_FromVoidPtr(self->cuda_event.event());
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPEvent_get_device(THCPEvent *self) {
+  HANDLE_TH_ERRORS
+  return THPUtils_packInt64(self->cuda_event.device_index());
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPEvent_record(THCPEvent *self, THCPStream *stream) {
+  HANDLE_TH_ERRORS
+  self->cuda_event.record(stream->cuda_stream);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPEvent_wait(THCPEvent *self, THCPStream *stream) {
+  HANDLE_TH_ERRORS
+  self->cuda_event.block(stream->cuda_stream);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPEvent_query(THCPEvent *self) {
+  HANDLE_TH_ERRORS
+  return PyBool_FromLong(self->cuda_event.query());
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPEvent_elapsed_time(THCPEvent *self, THCPEvent *other) {
+  HANDLE_TH_ERRORS
+  return PyFloat_FromDouble(self->cuda_event.elapsed_time(other->cuda_event));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPEvent_synchronize(THCPEvent *self) {
+  HANDLE_TH_ERRORS
+  self->cuda_event.synchronize();
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPEvent_ipc_handle(THCPEvent *self) {
+  HANDLE_TH_ERRORS
+  cudaIpcEventHandle_t handle;
+  self->cuda_event.ipc_handle(&handle);
+  return PyBytes_FromStringAndSize((const char *)&handle, sizeof(handle));
+  END_HANDLE_TH_ERRORS
+}
+
+static struct PyGetSetDef THCPEvent_properties[] = {
+  {"device", (getter)THCPEvent_get_device, nullptr, nullptr, nullptr},
+  {"cuda_event", (getter)THCPEvent_get_cuda_event, nullptr, nullptr, nullptr},
+  {nullptr}
+};
+
+static PyMethodDef THCPEvent_methods[] = {
+  {(char*)"from_ipc_handle", (PyCFunction)THCPEvent_from_ipc_handle,
+    METH_CLASS | METH_VARARGS, nullptr},
+  {(char*)"record", (PyCFunction)THCPEvent_record, METH_O, nullptr},
+  {(char*)"wait", (PyCFunction)THCPEvent_wait, METH_O, nullptr},
+  {(char*)"query", (PyCFunction)THCPEvent_query, METH_NOARGS, nullptr},
+  {(char*)"elapsed_time", (PyCFunction)THCPEvent_elapsed_time, METH_O, nullptr},
+  {(char*)"synchronize", (PyCFunction)THCPEvent_synchronize,
+    METH_NOARGS, nullptr},
+  {(char*)"ipc_handle", (PyCFunction)THCPEvent_ipc_handle,
+    METH_NOARGS, nullptr},
+  {nullptr}
+};
+
+PyTypeObject THCPEventType = {
+  PyVarObject_HEAD_INIT(nullptr, 0)
+  "torch._C._CudaEventBase",             /* tp_name */
+  sizeof(THCPEvent),                     /* tp_basicsize */
+  0,                                     /* tp_itemsize */
+  (destructor)THCPEvent_dealloc,         /* tp_dealloc */
+  0,                                     /* tp_print */
+  0,                                     /* tp_getattr */
+  0,                                     /* tp_setattr */
+  0,                                     /* tp_reserved */
+  0,                                     /* tp_repr */
+  0,                                     /* tp_as_number */
+  0,                                     /* tp_as_sequence */
+  0,                                     /* tp_as_mapping */
+  0,                                     /* tp_hash  */
+  0,                                     /* tp_call */
+  0,                                     /* tp_str */
+  0,                                     /* tp_getattro */
+  0,                                     /* tp_setattro */
+  0,                                     /* tp_as_buffer */
+  Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
+  nullptr,                                  /* tp_doc */
+  0,                                     /* tp_traverse */
+  0,                                     /* tp_clear */
+  0,                                     /* tp_richcompare */
+  0,                                     /* tp_weaklistoffset */
+  0,                                     /* tp_iter */
+  0,                                     /* tp_iternext */
+  THCPEvent_methods,                     /* tp_methods */
+  0,                                     /* tp_members */
+  THCPEvent_properties,                  /* tp_getset */
+  0,                                     /* tp_base */
+  0,                                     /* tp_dict */
+  0,                                     /* tp_descr_get */
+  0,                                     /* tp_descr_set */
+  0,                                     /* tp_dictoffset */
+  0,                                     /* tp_init */
+  0,                                     /* tp_alloc */
+  THCPEvent_pynew,                       /* tp_new */
+};
+
+void THCPEvent_init(PyObject *module) {
+  THCPEventClass = (PyObject*)&THCPEventType;
+  if (PyType_Ready(&THCPEventType) < 0) {
+    throw python_error();
+  }
+  Py_INCREF(&THCPEventType);
+  if (PyModule_AddObject(
+      module, "_CudaEventBase", (PyObject *)&THCPEventType) < 0) {
+    throw python_error();
+  }
+}
diff --git a/torch/csrc/cuda/Event.h b/torch/csrc/cuda/Event.h
new file mode 100644 (file)
index 0000000..214f3cd
--- /dev/null
@@ -0,0 +1,20 @@
+#ifndef THCP_EVENT_INC
+#define THCP_EVENT_INC
+
+#include <ATen/cuda/CUDAEvent.h>
+#include <torch/csrc/python_headers.h>
+#include <THC/THC.h>
+
+struct THCPEvent {
+  PyObject_HEAD
+  at::cuda::CUDAEvent cuda_event;
+};
+extern PyObject *THCPEventClass;
+
+void THCPEvent_init(PyObject *module);
+
+inline bool THCPEvent_Check(PyObject* obj) {
+  return THCPEventClass && PyObject_IsInstance(obj, THCPEventClass);
+}
+
+#endif // THCP_EVENT_INC
index 9ad9c77..3a0d4a4 100644 (file)
@@ -10,8 +10,8 @@
 
 PyObject *THCPStreamClass = nullptr;
 
-static PyObject * THCPStream_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs)
-{
+static PyObject * THCPStream_pynew(
+  PyTypeObject *type, PyObject *args, PyObject *kwargs) {
   HANDLE_TH_ERRORS
 
   int current_device;
@@ -21,7 +21,8 @@ static PyObject * THCPStream_pynew(PyTypeObject *type, PyObject *args, PyObject
   uint64_t cdata = 0;
 
   static char *kwlist[] = {"priority", "_cdata", nullptr};
-  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|iK", kwlist, &priority, &cdata)) {
+  if (!PyArg_ParseTupleAndKeywords(
+      args, kwargs, "|iK", kwlist, &priority, &cdata)) {
     return nullptr;
   }
 
@@ -33,7 +34,8 @@ static PyObject * THCPStream_pynew(PyTypeObject *type, PyObject *args, PyObject
   at::cuda::CUDAStream stream =
     cdata ?
     at::cuda::CUDAStream::unpack(cdata) :
-    at::cuda::getStreamFromPool(/* isHighPriority */ priority < 0 ? true : false);
+    at::cuda::getStreamFromPool(
+      /* isHighPriority */ priority < 0 ? true : false);
 
   THCPStream* self = (THCPStream *)ptr.get();
   self->cdata = stream.pack();
@@ -48,24 +50,46 @@ static void THCPStream_dealloc(THCPStream *self) {
   Py_TYPE(self)->tp_free((PyObject*)self);
 }
 
-static PyObject * THPVariable_get_device(THCPStream *self) {
+static PyObject * THCPStream_get_device(THCPStream *self) {
   HANDLE_TH_ERRORS
   return THPUtils_packInt64(self->cuda_stream.device_index());
   END_HANDLE_TH_ERRORS
 }
 
-static PyObject * THPVariable_get_cuda_stream(THCPStream *self) {
+static PyObject * THCPStream_get_cuda_stream(THCPStream *self) {
   HANDLE_TH_ERRORS
   return PyLong_FromVoidPtr(self->cuda_stream.stream());
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject * THCPStream_get_priority(THCPStream *self) {
+  HANDLE_TH_ERRORS
+  return PyLong_FromLong(self->cuda_stream.priority());
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPStream_priority_range() {
+  HANDLE_TH_ERRORS
+  int least_priority, greatest_priority;
+  std::tie(least_priority, greatest_priority) =
+    at::cuda::CUDAStream::priority_range();
+  return Py_BuildValue("(ii)", least_priority, greatest_priority);
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject * THCPStream_query(THCPStream *self) {
   HANDLE_TH_ERRORS
   return PyBool_FromLong(self->cuda_stream.query());
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject * THCPStream_synchronize(THCPStream *self) {
+  HANDLE_TH_ERRORS
+  self->cuda_stream.synchronize();
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject * THCPStream_eq(THCPStream *self, THCPStream *other) {
   HANDLE_TH_ERRORS
   return PyBool_FromLong(self->cuda_stream == other->cuda_stream);
@@ -73,26 +97,32 @@ static PyObject * THCPStream_eq(THCPStream *self, THCPStream *other) {
 }
 
 static struct PyMemberDef THCPStream_members[] = {
-  {(char*)"_cdata", T_ULONGLONG, offsetof(THCPStream, cdata), READONLY, nullptr},
+  {(char*)"_cdata",
+    T_ULONGLONG, offsetof(THCPStream, cdata), READONLY, nullptr},
   {nullptr}
 };
 
-static struct PyGetSetDef THPVariable_properties[] = {
-  {"device", (getter)THPVariable_get_device, nullptr, nullptr, nullptr},
+static struct PyGetSetDef THCPStream_properties[] = {
+  {"device", (getter)THCPStream_get_device, nullptr, nullptr, nullptr},
   {"cuda_stream",
-    (getter)THPVariable_get_cuda_stream, nullptr, nullptr, nullptr},
+    (getter)THCPStream_get_cuda_stream, nullptr, nullptr, nullptr},
+  {"priority", (getter)THCPStream_get_priority, nullptr, nullptr, nullptr},
   {nullptr}
 };
 
 static PyMethodDef THCPStream_methods[] = {
   {(char*)"query", (PyCFunction)THCPStream_query, METH_NOARGS, nullptr},
+  {(char*)"synchronize",
+    (PyCFunction)THCPStream_synchronize, METH_NOARGS, nullptr},
+  {(char*)"priority_range",
+    (PyCFunction)THCPStream_priority_range, METH_STATIC | METH_NOARGS, nullptr},
   {(char*)"__eq__", (PyCFunction)THCPStream_eq, METH_O, nullptr},
   {nullptr}
 };
 
 PyTypeObject THCPStreamType = {
   PyVarObject_HEAD_INIT(nullptr, 0)
-  "torch._C._CudaStreamBase",             /* tp_name */
+  "torch._C._CudaStreamBase",            /* tp_name */
   sizeof(THCPStream),                    /* tp_basicsize */
   0,                                     /* tp_itemsize */
   (destructor)THCPStream_dealloc,        /* tp_dealloc */
@@ -120,7 +150,7 @@ PyTypeObject THCPStreamType = {
   0,                                     /* tp_iternext */
   THCPStream_methods,                    /* tp_methods */
   THCPStream_members,                    /* tp_members */
-  THPVariable_properties,                /* tp_getset */
+  THCPStream_properties,                /* tp_getset */
   0,                                     /* tp_base */
   0,                                     /* tp_dict */
   0,                                     /* tp_descr_get */
@@ -132,12 +162,15 @@ PyTypeObject THCPStreamType = {
 };
 
 
-bool THCPStream_init(PyObject *module)
+void THCPStream_init(PyObject *module)
 {
   THCPStreamClass = (PyObject*)&THCPStreamType;
-  if (PyType_Ready(&THCPStreamType) < 0)
-    return false;
+  if (PyType_Ready(&THCPStreamType) < 0) {
+    throw python_error();
+  }
   Py_INCREF(&THCPStreamType);
-  PyModule_AddObject(module, "_CudaStreamBase", (PyObject *)&THCPStreamType);
-  return true;
+  if (PyModule_AddObject(
+      module, "_CudaStreamBase", (PyObject *)&THCPStreamType) < 0) {
+    throw python_error();
+  }
 }
index f0ed77e..c98d135 100644 (file)
@@ -12,7 +12,7 @@ struct THCPStream {
 };
 extern PyObject *THCPStreamClass;
 
-bool THCPStream_init(PyObject *module);
+void THCPStream_init(PyObject *module);
 
 inline bool THCPStream_Check(PyObject* obj) {
   return THCPStreamClass && PyObject_IsInstance(obj, THCPStreamClass);
index c81ca05..eb5f7ae 100644 (file)
@@ -12,6 +12,7 @@
 #include <torch/csrc/cuda/Module.h>
 #include <torch/csrc/cuda/Storage.h>
 #include <torch/csrc/cuda/Stream.h>
+#include <torch/csrc/cuda/Event.h>
 #ifdef _THP_CORE
 #include <torch/csrc/cuda/utils.h>
 #endif
index 6534446..2a26411 100644 (file)
@@ -536,6 +536,7 @@ if not hasattr(torch._C, 'CudaDoubleStorageBase'):
         torch._C.__dict__[tensor_name] = _dummy_type(tensor_name)
 
     torch._C.__dict__['_CudaStreamBase'] = _dummy_type('CudaStreamBase')
+    torch._C.__dict__['_CudaEventBase'] = _dummy_type('CudaEventBase')
 
 
 @staticmethod
index 64fe7e8..1736e3e 100644 (file)
@@ -1,8 +1,5 @@
 import ctypes
 import torch
-from . import cudart, check_error, cudaStatus
-from ._utils import _get_device_index
-from torch._C import _add_docstr
 
 
 class Stream(torch._C._CudaStreamBase):
@@ -39,7 +36,7 @@ class Stream(torch._C._CudaStreamBase):
         .. _CUDA documentation:
            http://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
         """
-        check_error(cudart().cudaStreamWaitEvent(self, event, ctypes.c_int(0)))
+        event.wait(self)
 
     def wait_stream(self, stream):
         r"""Synchronizes with another stream.
@@ -67,7 +64,7 @@ class Stream(torch._C._CudaStreamBase):
         """
         if event is None:
             event = Event()
-        check_error(cudart().cudaEventRecord(event, self))
+        event.record(self)
         return event
 
     def query(self):
@@ -86,21 +83,7 @@ class Stream(torch._C._CudaStreamBase):
         .. _CUDA documentation:
            http://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
         """
-        check_error(cudart().cudaStreamSynchronize(self))
-
-    @staticmethod
-    def priority_range():
-        least_priority = ctypes.c_int()
-        greatest_priority = ctypes.c_int()
-        check_error(cudart().cudaDeviceGetStreamPriorityRange(
-            ctypes.byref(least_priority), ctypes.byref(greatest_priority)))
-        return (least_priority.value, greatest_priority.value)
-
-    @property
-    def priority(self):
-        priority = ctypes.c_int()
-        check_error(cudart().cudaStreamGetPriority(self, ctypes.byref(priority)))
-        return priority.value
+        super(Stream, self).synchronize()
 
     @property
     def _as_parameter_(self):
@@ -119,90 +102,94 @@ class Stream(torch._C._CudaStreamBase):
                 .format(self.device, self.cuda_stream))
 
 
-class EventHandle(ctypes.Structure):
-    IPC_HANDLE_SIZE = 64
-    _fields_ = [('reserved', ctypes.c_char * IPC_HANDLE_SIZE)]
+class Event(torch._C._CudaEventBase):
+    r"""Wrapper around a CUDA event.
 
+    CUDA events are synchronization markers that can be used to monitor the
+    device's progress, to accurately measure timing, and to synchronize CUDA
+    streams.
 
-class Event(object):
-    r"""Wrapper around CUDA event.
+    The underlying CUDA events are lazily initialized when the event is first
+    recorded or exported to another process. After creation, only streams on the
+    same device may record the event. However, streams on any device can wait on
+    the event.
 
     Arguments:
-        enable_timing (bool): indicates if the event should measure time
+        enable_timing (bool, optional): indicates if the event should measure time
             (default: ``False``)
-        blocking (bool): if ``True``, :meth:`wait` will be blocking (default: ``False``)
+        blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
         interprocess (bool): if ``True``, the event can be shared between processes
             (default: ``False``)
+
+       .. _CUDA documentation:
+       https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
     """
 
-    DEFAULT = 0x0
-    BLOCKING_SYNC = 0x1
-    DISABLE_TIMING = 0x2
-    INTERPROCESS = 0x4
-
-    def __init__(self, enable_timing=False, blocking=False, interprocess=False,
-                 _handle=None):
-        flags = Event.DEFAULT
-        if not enable_timing:
-            flags |= Event.DISABLE_TIMING
-        if blocking:
-            flags |= Event.BLOCKING_SYNC
-        if interprocess:
-            flags |= Event.INTERPROCESS
-
-        ptr = ctypes.c_void_p()
-        self._cudart = cudart()
-        if _handle:
-            check_error(self._cudart.cudaIpcOpenEventHandle(ctypes.byref(ptr), _handle))
-        else:
-            check_error(self._cudart.cudaEventCreateWithFlags(ctypes.byref(ptr), ctypes.c_uint(flags)))
-        self._as_parameter_ = ptr
-
-    def __del__(self):
-        if hasattr(self, '_as_parameter_'):
-            check_error(self._cudart.cudaEventDestroy(self._as_parameter_))
-            del self._as_parameter_
+    def __new__(cls, enable_timing=False, blocking=False, interprocess=False):
+        return super(Event, cls).__new__(
+            cls,
+            enable_timing=enable_timing, blocking=blocking, interprocess=interprocess)
+
+    @classmethod
+    def from_ipc_handle(cls, device, handle):
+        r"""Reconstruct an event from an IPC handle on the given device."""
+        return super(Event, cls).from_ipc_handle(device, handle)
 
     def record(self, stream=None):
-        r"""Records the event in a given stream."""
+        r"""Records the event in a given stream.
+
+        Uses ``torch.cuda.current_stream()`` if no stream is specified. The
+        stream's device must match the event's device."""
         if stream is None:
             stream = torch.cuda.current_stream()
-        stream.record_event(self)
+        super(Event, self).record(stream)
 
     def wait(self, stream=None):
-        r"""Makes a given stream wait for the event."""
+        r"""Makes all future work submitted to the given stream wait for this
+        event.
+
+        Use ``torch.cuda.current_stream()`` if no stream is specified."""
         if stream is None:
             stream = torch.cuda.current_stream()
-        stream.wait_event(self)
+        super(Event, self).wait(stream)
 
     def query(self):
-        r"""Checks if the event has been recorded.
+        r"""Checks if all work currently captured by event has completed.
 
         Returns:
-            A boolean indicating if the event has been recorded.
+            A boolean indicating if all work currently captured by event has
+            completed.
         """
-        res = cudart().cudaEventQuery(self)
-        if res == cudaStatus.ERROR_NOT_READY:
-            return False
-        check_error(res)
-        return True
+        return super(Event, self).query()
 
     def elapsed_time(self, end_event):
-        r"""Returns the time elapsed in milliseconds before the event was recorded."""
-        time_ms = ctypes.c_float()
-        check_error(cudart().cudaEventElapsedTime(
-            ctypes.byref(time_ms), self, end_event))
-        return time_ms.value
+        r"""Returns the time elapsed in milliseconds after the event was
+        recorded and before the end_event was recorded.
+        """
+        return super(Event, self).elapsed_time(end_event)
 
     def synchronize(self):
-        r"""Synchronizes with the event."""
-        check_error(cudart().cudaEventSynchronize(self))
+        r"""Waits for the event to complete.
+
+        Waits until the completion of all work currently captured in this event.
+        This prevents the CPU thread from proceeding until the event completes.
+
+         .. note:: This is a wrapper around ``cudaEventSynchronize()``: see `CUDA
+           documentation`_ for more info.
+
+        .. _CUDA documentation:
+           https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
+        """
+        super(Event, self).synchronize()
 
     def ipc_handle(self):
-        r"""Returns an IPC handle of this event."""
-        handle = EventHandle()
-        check_error(cudart().cudaIpcGetEventHandle(ctypes.byref(handle), self))
-        return handle
+        r"""Returns an IPC handle of this event. If not recorded yet, the event
+        will use the current device. """
+        return super(Event, self).ipc_handle()
+
+    @property
+    def _as_parameter_(self):
+        return ctypes.c_void_p(self.cuda_event)
 
     def __repr__(self):
         return '<torch.cuda.Event {0:#x}>'.format(self._as_parameter_.value)
index b55e4ae..03d51ce 100644 (file)
@@ -67,12 +67,13 @@ class SharedCache(dict):
 shared_cache = SharedCache()
 
 
-def rebuild_event(handle):
-    return torch.cuda.Event(_handle=handle)
+def rebuild_event(device, handle):
+    return torch.cuda.Event.from_ipc_handle(device, handle)
 
 
 def reduce_event(event):
-    return (rebuild_event, (event.ipc_handle(),))
+    handle = event.ipc_handle()
+    return (rebuild_event, (event.device, handle))
 
 
 def rebuild_tensor(cls, storage, metadata):