/*
* CUDAEvents are movable not copyable wrappers around CUDA's events.
*
-* CUDAEvents are constructed lazily when recorded on streams. The events
-* have a device, and this device is acquired from the first recording stream.
-* Later streams that record to the event must share this device, but streams
-* on any device can wait on the event.
+* CUDAEvents are constructed lazily when first recorded unless it is
+* reconstructed from a cudaIpcEventHandle_t. The event has a device, and this
+* device is acquired from the first recording stream. However, if reconstructed
+* from a handle, the device should be explicitly specified; or if ipc_handle() is
+* called before the event is ever recorded, it will use the current device.
+* Later streams that record the event must match this device.
*/
struct AT_CUDA_API CUDAEvent {
// Constants
CUDAEvent(unsigned int flags = DEFAULT_FLAGS)
: flags_{flags} { }
+ CUDAEvent(
+ DeviceIndex device_index, const cudaIpcEventHandle_t* handle) {
+ #ifndef __HIP_PLATFORM_HCC__
+ device_index_ = device_index;
+ CUDAGuard guard(device_index_);
+
+ AT_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle));
+ is_created_ = true;
+ #else
+ AT_ERROR("cuIpcOpenEventHandle with HIP is not supported");
+ #endif
+ }
+
// Note: event destruction done on creating device to avoid creating a
// CUDA context on other devices.
~CUDAEvent() {
try {
if (is_created_) {
- CUDAGuard device_guard(static_cast<int16_t>(device_index_));
+ CUDAGuard guard(device_index_);
cudaEventDestroy(event_);
}
} catch (...) { /* No throw */ }
return left.event_ < right.event_;
}
+ at::Device device() const {
+ return at::Device(at::kCUDA, device_index_);
+ }
+
bool isCreated() const { return is_created_; }
- int64_t device() const { return device_index_; }
+ DeviceIndex device_index() const {return device_index_;}
cudaEvent_t event() const { return event_; }
// Note: cudaEventQuery can be safely called from any device
- bool happened() const {
- return (was_recorded_ && cudaEventQuery(event_) == cudaSuccess);
+ bool query() const {
+ if (!is_created_) {
+ return true;
+ }
+
+ cudaError_t err = cudaEventQuery(event_);
+ if (err == cudaSuccess) {
+ return true;
+ } else if (err != cudaErrorNotReady) {
+ C10_CUDA_CHECK(err);
+ }
+
+ return false;
}
void record() { record(getCurrentCUDAStream()); }
if (!was_recorded_) record(stream);
}
- // Note: cudaEventRecord must be called on the same device as the stream.
+ // Note: cudaEventRecord must be called on the same device as the event.
void record(const CUDAStream& stream) {
- CUDAGuard guard(static_cast<int16_t>(stream.device_index()));
-
- if (is_created_) {
- AT_ASSERT(device_index_ == stream.device_index());
- } else {
- AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
- is_created_ = true;
- device_index_ = stream.device_index();
+ if (!is_created_) {
+ createEvent(stream.device_index());
}
+ AT_CHECK(device_index_ == stream.device_index(), "Event device ", device_index_,
+ " does not match recording stream's device ", stream.device_index(), ".");
+ CUDAGuard guard(device_index_);
AT_CUDA_CHECK(cudaEventRecord(event_, stream));
was_recorded_ = true;
}
// The event has no actual GPU resources associated with it.
void block(const CUDAStream& stream) {
if (is_created_) {
- CUDAGuard guard(static_cast<int16_t>(stream.device_index()));
+ CUDAGuard guard(stream.device_index());
AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, 0));
}
}
+ // Note: cudaEventElapsedTime can be safely called from any device
+ float elapsed_time(const CUDAEvent& other) const {
+ AT_CHECK(is_created_ && other.isCreated(),
+ "Both events must be recorded before calculating elapsed time.");
+ float time_ms = 0;
+ // raise cudaErrorNotReady if either event is recorded but not yet completed
+ AT_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_));
+ return time_ms;
+ }
+
+ // Note: cudaEventSynchronize can be safely called from any device
+ void synchronize() const {
+ if (is_created_) {
+ AT_CUDA_CHECK(cudaEventSynchronize(event_));
+ }
+ }
+
+ // Note: cudaIpcGetEventHandle must be called on the same device as the event
+ void ipc_handle(cudaIpcEventHandle_t * handle) {
+ #ifndef __HIP_PLATFORM_HCC__
+ if (!is_created_) {
+ // this CUDAEvent object was initially constructed from flags but event_
+ // is not created yet.
+ createEvent(getCurrentCUDAStream().device_index());
+ }
+ CUDAGuard guard(device_index_);
+ AT_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_));
+ #else
+ AT_ERROR("cuIpcGetEventHandle with HIP is not supported");
+ #endif
+ }
+
private:
unsigned int flags_ = DEFAULT_FLAGS;
bool is_created_ = false;
bool was_recorded_ = false;
- int64_t device_index_ = -1;
+ DeviceIndex device_index_ = -1;
cudaEvent_t event_;
+ void createEvent(DeviceIndex device_index) {
+ device_index_ = device_index;
+ CUDAGuard guard(device_index_);
+ AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
+ is_created_ = true;
+ }
+
void moveHelper(CUDAEvent&& other) {
std::swap(flags_, other.flags_);
std::swap(is_created_, other.is_created_);
const auto stream = at::cuda::getStreamFromPool();
at::cuda::CUDAEvent event;
- ASSERT_FALSE(event.happened());
+ ASSERT_TRUE(event.query());
event.recordOnce(stream);
event.block(wait_stream1);
cudaStreamSynchronize(wait_stream0);
- ASSERT_TRUE(event.happened());
+ ASSERT_TRUE(event.query());
}
// Cross-Device Events
event0 = std::move(event1);
- ASSERT_EQ_CUDA(event0.device(), 1);
+ ASSERT_EQ_CUDA(event0.device(), at::Device(at::kCUDA, 1));
event0.block(stream0);
cudaStreamSynchronize(stream0);
- ASSERT_TRUE(event0.happened());
+ ASSERT_TRUE(event0.query());
}
StreamId id() const { return stream_.id(); }
bool query() const {
- DeviceGuard device_guard{stream_.device()};
+ DeviceGuard guard{stream_.device()};
cudaError_t err = cudaStreamQuery(stream());
- if (err == cudaErrorNotReady) {
- return false;
- } else if (err != cudaSuccess) {
+ if (err == cudaSuccess) {
+ return true;
+ } else if (err != cudaErrorNotReady) {
C10_CUDA_CHECK(err);
}
- return true;
+ return false;
+ }
+
+ void synchronize() const {
+ DeviceGuard guard{stream_.device()};
+ C10_CUDA_CHECK(cudaStreamSynchronize(stream()));
+ }
+
+ int priority() const {
+ #ifndef __HIP_PLATFORM_HCC__
+ DeviceGuard guard{stream_.device()};
+ int priority = 0;
+ C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority));
+ return priority;
+ #else
+ AT_ERROR("cuStreamGetPriority with HIP is not supported");
+ #endif
}
/// Explicit conversion to cudaStream_t.
return CUDAStream(Stream::unpack(bits));
}
+ static std::tuple<int, int> priority_range() {
+ #ifndef __HIP_PLATFORM_HCC__
+ int least_priority, greatest_priority;
+ C10_CUDA_CHECK(
+ cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority));
+ return std::make_tuple(least_priority, greatest_priority);
+ #else
+ AT_ERROR("cuDeviceGetStreamPriorityRange with HIP is not supported");
+ #endif
+ }
+
// Deleted for now; use CUDAEvent::block instead
// void synchronize_with(const CUDAEvent& event) const;
self.assertTrue(s0.query())
self.assertFalse(s1.query())
- with torch.cuda.device(d1):
+ # deliberately using a different device
+ with torch.cuda.device(d0):
s1.synchronize()
self.assertTrue(s0.query())
self.assertNotEqual(hash(s0), hash(s3))
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+ @skipIfRocm
+ def test_streams_priority(self):
+ low, high = torch.cuda.Stream.priority_range()
+ s0 = torch.cuda.Stream(device=0, priority=low)
+
+ self.assertEqual(low, s0.priority)
+ self.assertEqual(0, s0.device)
+
+ s1 = torch.cuda.Stream(device=1, priority=high)
+
+ self.assertEqual(high, s1.priority)
+ self.assertEqual(1, s1.device)
+
+ @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_tensor_device(self):
self.assertEqual(torch.cuda.FloatTensor(1).get_device(), 0)
self.assertEqual(torch.cuda.FloatTensor(1, device=1).get_device(), 1)
self.assertTrue(event.query())
self.assertGreater(start_event.elapsed_time(event), 0)
+ @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
+ @skipIfRocm
+ def test_events_wait(self):
+ d0 = torch.device('cuda:0')
+ d1 = torch.device('cuda:1')
+
+ with torch.cuda.device(d0):
+ s0 = torch.cuda.current_stream()
+ torch.cuda._sleep(50000000) # spin for about 50 ms on device1
+ e0 = torch.cuda.Event()
+ s0.record_event(e0)
+
+ with torch.cuda.device(d1):
+ s1 = torch.cuda.current_stream()
+
+ self.assertFalse(s0.query())
+ self.assertTrue(s1.query())
+
+ s1.wait_event(e0)
+ s1.synchronize()
+
+ self.assertTrue(e0.query())
+ self.assertTrue(s0.query())
+ self.assertTrue(s1.query())
+
+ @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
+ @skipIfRocm
+ def test_events_multi_gpu_query(self):
+ d0 = torch.device('cuda:0')
+ d1 = torch.device('cuda:1')
+
+ with torch.cuda.device(d0):
+ s0 = torch.cuda.current_stream()
+ e0 = s0.record_event()
+
+ with torch.cuda.device(d1):
+ s1 = torch.cuda.current_stream()
+ torch.cuda._sleep(50000000) # spin for about 50 ms on device1
+ e1 = s1.record_event()
+
+ self.assertTrue(e0.query())
+ self.assertFalse(e1.query())
+
+ with torch.cuda.device(d0):
+ self.assertTrue(e0.query())
+ self.assertFalse(e1.query())
+
+ with torch.cuda.device(d1):
+ self.assertTrue(e0.query())
+ self.assertFalse(e1.query())
+
+ # deliberately using a different device
+ with torch.cuda.device(d0):
+ e1.synchronize()
+
+ self.assertTrue(e0.query())
+ self.assertTrue(e1.query())
+
+ with torch.cuda.device(d0):
+ self.assertTrue(e0.query())
+ self.assertTrue(e1.query())
+
+ with torch.cuda.device(d1):
+ self.assertTrue(e0.query())
+ self.assertTrue(e1.query())
+
+ @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
+ @skipIfRocm
+ def test_events_multi_gpu_elapsed_time(self):
+ d0 = torch.device('cuda:0')
+ d1 = torch.device('cuda:1')
+
+ with torch.cuda.device(d0):
+ s0 = torch.cuda.current_stream()
+ e0 = torch.cuda.Event(enable_timing=True)
+ torch.cuda._sleep(10) # spin for about 50 ms on device1
+ s0.record_event(e0)
+
+ with torch.cuda.device(d1):
+ s1 = torch.cuda.current_stream()
+ e1 = torch.cuda.Event(enable_timing=True)
+ torch.cuda._sleep(30000000) # spin for about 50 ms on device1
+ s1.record_event(e1)
+
+ e0.synchronize()
+ e1.synchronize()
+ with torch.cuda.device(d0):
+ with self.assertRaises(RuntimeError):
+ self.assertGreater(e0.elapsed_time(e1), 0)
+
+ with torch.cuda.device(d1):
+ with self.assertRaises(RuntimeError):
+ self.assertGreater(e0.elapsed_time(e1), 0)
+
+ with torch.cuda.device(d0):
+ s0 = torch.cuda.current_stream()
+ e2 = torch.cuda.Event(enable_timing=True)
+ torch.cuda._sleep(30000000) # spin for about 50 ms on device1
+ s0.record_event(e2)
+ s0.synchronize()
+
+ self.assertGreater(e0.elapsed_time(e2), 0)
+
+ # deliberately calling from a different device
+ with torch.cuda.device(d1):
+ self.assertGreater(e0.elapsed_time(e2), 0)
+
@skipIfRocm
def test_record_stream(self):
cycles_per_ms = get_cycles_per_ms()
self.assertEqual(list(tensor), [4, 4, 4, 4])
p.join()
+ def _test_event_multiprocess_child(event, p2c, c2p):
+ c2p.put(0) # notify parent child is ready
+ p2c.get() # wait for record in parent
+ event.synchronize()
+ c2p.put(1) # notify parent synchronization is done
+
+ @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
+ don't support multiprocessing with spawn start method")
+ @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
+ def test_event_multiprocess(self):
+ event = torch.cuda.Event(enable_timing=False, interprocess=True)
+ self.assertTrue(event.query())
+
+ ctx = mp.get_context('spawn')
+ p2c = ctx.SimpleQueue()
+ c2p = ctx.SimpleQueue()
+ p = ctx.Process(
+ target=TestMultiprocessing._test_event_multiprocess_child,
+ args=(event, p2c, c2p))
+ p.start()
+
+ c2p.get() # wait for until child process is ready
+ torch.cuda._sleep(50000000) # spin for about 50 ms
+ event.record()
+ p2c.put(0) # notify child event is recorded
+
+ self.assertFalse(event.query())
+ c2p.get() # wait for synchronization in child
+ self.assertTrue(event.query())
+ p.join()
+
+ @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
+ don't support multiprocessing with spawn start method")
+ @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
+ def test_event_handle_multi_gpu(self):
+ d0 = torch.device('cuda:0')
+ d1 = torch.device('cuda:1')
+ with torch.cuda.device(d0):
+ e0 = torch.cuda.Event(enable_timing=False, interprocess=True)
+
+ with torch.cuda.device(d1):
+ # create handle on different device from un-recorded event
+ e0.ipc_handle()
+
+ with torch.cuda.device(d0):
+ e1 = torch.cuda.Event(enable_timing=False, interprocess=True)
+ stream = torch.cuda.Stream()
+ torch.cuda._sleep(50000000) # spin for about 50 ms
+ e1.record(stream)
+
+ with torch.cuda.device(d1):
+ # create handle on different device from recorded event
+ e1.ipc_handle()
+
+ def _test_event_handle_importer_consumer(handle, p2c, c2p):
+ e1 = torch.cuda.Event.from_ipc_handle(
+ torch.cuda.current_device(), handle)
+ c2p.put(0) # notify parent child is ready
+ p2c.get() # wait for record in parent
+ e1.synchronize()
+ c2p.put(1) # nofity synchronization is done in child
+ p2c.get() # wait for parent to finish before destructing child event
+
+ @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
+ don't support multiprocessing with spawn start method")
+ @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
+ def test_event_handle_importer(self):
+ e0 = torch.cuda.Event(enable_timing=False, interprocess=True)
+ self.assertTrue(e0.query())
+
+ ctx = mp.get_context('spawn')
+ p2c = ctx.SimpleQueue()
+ c2p = ctx.SimpleQueue()
+ p = ctx.Process(
+ target=TestMultiprocessing._test_event_handle_importer_consumer,
+ args=(e0.ipc_handle(), p2c, c2p))
+ p.start()
+
+ c2p.get() # wait for child to become ready
+ torch.cuda._sleep(50000000) # spin for about 50 ms
+ e0.record()
+ p2c.put(0) # notify child event is recorded
+
+ self.assertFalse(e0.query())
+ c2p.get() # wait for synchronization in child
+ self.assertTrue(e0.query())
+ p2c.put(1) # notify child that parent is done
+ p.join()
+
+ def _test_event_handle_exporter_consumer(handle, p2c, c2p):
+ stream = torch.cuda.Stream()
+ with torch.cuda.stream(stream):
+ e1 = torch.cuda.Event.from_ipc_handle(
+ torch.cuda.current_device(), handle)
+ torch.cuda._sleep(50000000) # spin for about 50 ms
+ e1.record()
+ c2p.put(0)
+ # wait for parent process finished synchronization before
+ # destructing e1
+ p2c.get()
+
+ @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
+ don't support multiprocessing with spawn start method")
+ @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
+ def test_event_handle_exporter(self):
+ e0 = torch.cuda.Event(enable_timing=False, interprocess=True)
+
+ ctx = mp.get_context('spawn')
+ p2c = ctx.SimpleQueue()
+ c2p = ctx.SimpleQueue()
+ p = ctx.Process(
+ target=TestMultiprocessing._test_event_handle_exporter_consumer,
+ args=(e0.ipc_handle(), p2c, c2p))
+ p.start()
+ # wait for event in child process is recorded
+ c2p.get()
+
+ self.assertFalse(e0.query())
+ e0.synchronize()
+ self.assertTrue(e0.query())
+ p2c.put(0)
+ p.join()
+
def _test_empty_tensor_sharing(self, dtype, device):
q = mp.Queue()
empty = torch.tensor([], dtype=dtype, device=device)
#include "torch/csrc/jit/tracer.h"
#ifdef USE_CUDA
#include "torch/csrc/cuda/Stream.h"
+#include "torch/csrc/cuda/Event.h"
#endif
#include "torch/csrc/utils/cuda_lazy_init.h"
#include "torch/csrc/utils/object_ptr.h"
${TORCH_SRC_DIR}/csrc/cuda/Module.cpp
${TORCH_SRC_DIR}/csrc/cuda/Storage.cpp
${TORCH_SRC_DIR}/csrc/cuda/Stream.cpp
+ ${TORCH_SRC_DIR}/csrc/cuda/Event.cpp
${TORCH_SRC_DIR}/csrc/cuda/utils.cpp
${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
${TORCH_SRC_DIR}/csrc/cuda/python_comm.cpp
${TORCH_SRC_DIR}/csrc/cuda/Module.cpp
${TORCH_SRC_DIR}/csrc/cuda/Storage.cpp
${TORCH_SRC_DIR}/csrc/cuda/Stream.cpp
+ ${TORCH_SRC_DIR}/csrc/cuda/Event.cpp
${TORCH_SRC_DIR}/csrc/cuda/utils.cpp
${TORCH_SRC_DIR}/csrc/cuda/comm.cpp
${TORCH_SRC_DIR}/csrc/cuda/python_comm.cpp
bool THCPCharStorage_init(PyObject *module);
bool THCPByteStorage_init(PyObject *module);
-bool THCPStream_init(PyObject *module);
+void THCPStream_init(PyObject *module);
+void THCPEvent_init(PyObject *module);
#ifdef USE_CUDA
PyMethodDef* THCPModule_methods();
ASSERT_TRUE(THCPCharStorage_init(module));
ASSERT_TRUE(THCPByteStorage_init(module));
- ASSERT_TRUE(THCPStream_init(module));
+ THCPStream_init(module);
+ THCPEvent_init(module);
#endif
auto set_module_attr = [&](const char* name, PyObject* v, bool incref = true) {
--- /dev/null
+#include <torch/csrc/cuda/Event.h>
+#include <torch/csrc/cuda/Stream.h>
+
+#include <torch/csrc/THP.h>
+#include <torch/csrc/cuda/Module.h>
+
+#include <c10/cuda/CUDAGuard.h>
+
+#include <structmember.h>
+#include <cuda_runtime_api.h>
+
+PyObject *THCPEventClass = nullptr;
+
+static PyObject * THCPEvent_pynew(
+ PyTypeObject *type, PyObject *args, PyObject *kwargs) {
+ HANDLE_TH_ERRORS
+ unsigned char enable_timing = 0;
+ unsigned char blocking = 0;
+ unsigned char interprocess = 0;
+
+ static char *kwlist[] =
+ {"enable_timing", "blocking", "interprocess", nullptr};
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|bbb", kwlist,
+ &enable_timing, &blocking, &interprocess)) {
+ return nullptr;
+ }
+
+ THPObjectPtr ptr(type->tp_alloc(type, 0));
+ if (!ptr) {
+ return nullptr;
+ }
+
+ THCPEvent* self = (THCPEvent *)ptr.get();
+ unsigned int flags =
+ (blocking ? cudaEventBlockingSync : cudaEventDefault) |
+ (enable_timing ? cudaEventDefault : cudaEventDisableTiming) |
+ (interprocess ? cudaEventInterprocess : cudaEventDefault);
+
+ new (&self->cuda_event) at::cuda::CUDAEvent(flags);
+
+ return (PyObject *)ptr.release();
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPEvent_from_ipc_handle(
+ PyTypeObject *type, PyObject *args) {
+ HANDLE_TH_ERRORS
+ long long device_index = -1;
+ const char *handle_bytes = nullptr;
+ int handle_size = 0;
+
+ // cannot use bool 'p' and bytearray 'Y' as they are not available in Python 2
+ if (!PyArg_ParseTuple(
+ args, "Ls#", &device_index, &handle_bytes, &handle_size)) {
+ return nullptr;
+ }
+
+ AT_CHECK(handle_size == sizeof(cudaIpcEventHandle_t),
+ "cudaIpcEventHandle_t expects byte-like object of size ",
+ sizeof(cudaIpcEventHandle_t), ", but got ", handle_size);
+ AT_CHECK(device_index >= 0, "Reconstructing event from handle requires "
+ "a non-negtive device index, but got ", device_index)
+
+ // no need to release the handle byte array as it is automatically managed
+ // by the corresponding THCPEvent python object.
+ // see https://docs.python.org/3/c-api/arg.html#strings-and-buffers
+
+ THPObjectPtr ptr(type->tp_alloc(type, 0));
+ if (!ptr) {
+ return nullptr;
+ }
+ THCPEvent* self = (THCPEvent *)ptr.get();
+
+ cudaIpcEventHandle_t handle;
+ std::memcpy(&handle, handle_bytes, handle_size);
+ new (&self->cuda_event) at::cuda::CUDAEvent(device_index, &handle);
+
+ return (PyObject *)ptr.release();
+ END_HANDLE_TH_ERRORS
+}
+
+static void THCPEvent_dealloc(THCPEvent *self) {
+ self->cuda_event.~CUDAEvent();
+ Py_TYPE(self)->tp_free((PyObject*)self);
+}
+
+static PyObject * THCPEvent_get_cuda_event(THCPEvent *self) {
+ HANDLE_TH_ERRORS
+ return PyLong_FromVoidPtr(self->cuda_event.event());
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPEvent_get_device(THCPEvent *self) {
+ HANDLE_TH_ERRORS
+ return THPUtils_packInt64(self->cuda_event.device_index());
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPEvent_record(THCPEvent *self, THCPStream *stream) {
+ HANDLE_TH_ERRORS
+ self->cuda_event.record(stream->cuda_stream);
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPEvent_wait(THCPEvent *self, THCPStream *stream) {
+ HANDLE_TH_ERRORS
+ self->cuda_event.block(stream->cuda_stream);
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPEvent_query(THCPEvent *self) {
+ HANDLE_TH_ERRORS
+ return PyBool_FromLong(self->cuda_event.query());
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPEvent_elapsed_time(THCPEvent *self, THCPEvent *other) {
+ HANDLE_TH_ERRORS
+ return PyFloat_FromDouble(self->cuda_event.elapsed_time(other->cuda_event));
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPEvent_synchronize(THCPEvent *self) {
+ HANDLE_TH_ERRORS
+ self->cuda_event.synchronize();
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPEvent_ipc_handle(THCPEvent *self) {
+ HANDLE_TH_ERRORS
+ cudaIpcEventHandle_t handle;
+ self->cuda_event.ipc_handle(&handle);
+ return PyBytes_FromStringAndSize((const char *)&handle, sizeof(handle));
+ END_HANDLE_TH_ERRORS
+}
+
+static struct PyGetSetDef THCPEvent_properties[] = {
+ {"device", (getter)THCPEvent_get_device, nullptr, nullptr, nullptr},
+ {"cuda_event", (getter)THCPEvent_get_cuda_event, nullptr, nullptr, nullptr},
+ {nullptr}
+};
+
+static PyMethodDef THCPEvent_methods[] = {
+ {(char*)"from_ipc_handle", (PyCFunction)THCPEvent_from_ipc_handle,
+ METH_CLASS | METH_VARARGS, nullptr},
+ {(char*)"record", (PyCFunction)THCPEvent_record, METH_O, nullptr},
+ {(char*)"wait", (PyCFunction)THCPEvent_wait, METH_O, nullptr},
+ {(char*)"query", (PyCFunction)THCPEvent_query, METH_NOARGS, nullptr},
+ {(char*)"elapsed_time", (PyCFunction)THCPEvent_elapsed_time, METH_O, nullptr},
+ {(char*)"synchronize", (PyCFunction)THCPEvent_synchronize,
+ METH_NOARGS, nullptr},
+ {(char*)"ipc_handle", (PyCFunction)THCPEvent_ipc_handle,
+ METH_NOARGS, nullptr},
+ {nullptr}
+};
+
+PyTypeObject THCPEventType = {
+ PyVarObject_HEAD_INIT(nullptr, 0)
+ "torch._C._CudaEventBase", /* tp_name */
+ sizeof(THCPEvent), /* tp_basicsize */
+ 0, /* tp_itemsize */
+ (destructor)THCPEvent_dealloc, /* tp_dealloc */
+ 0, /* tp_print */
+ 0, /* tp_getattr */
+ 0, /* tp_setattr */
+ 0, /* tp_reserved */
+ 0, /* tp_repr */
+ 0, /* tp_as_number */
+ 0, /* tp_as_sequence */
+ 0, /* tp_as_mapping */
+ 0, /* tp_hash */
+ 0, /* tp_call */
+ 0, /* tp_str */
+ 0, /* tp_getattro */
+ 0, /* tp_setattro */
+ 0, /* tp_as_buffer */
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
+ nullptr, /* tp_doc */
+ 0, /* tp_traverse */
+ 0, /* tp_clear */
+ 0, /* tp_richcompare */
+ 0, /* tp_weaklistoffset */
+ 0, /* tp_iter */
+ 0, /* tp_iternext */
+ THCPEvent_methods, /* tp_methods */
+ 0, /* tp_members */
+ THCPEvent_properties, /* tp_getset */
+ 0, /* tp_base */
+ 0, /* tp_dict */
+ 0, /* tp_descr_get */
+ 0, /* tp_descr_set */
+ 0, /* tp_dictoffset */
+ 0, /* tp_init */
+ 0, /* tp_alloc */
+ THCPEvent_pynew, /* tp_new */
+};
+
+void THCPEvent_init(PyObject *module) {
+ THCPEventClass = (PyObject*)&THCPEventType;
+ if (PyType_Ready(&THCPEventType) < 0) {
+ throw python_error();
+ }
+ Py_INCREF(&THCPEventType);
+ if (PyModule_AddObject(
+ module, "_CudaEventBase", (PyObject *)&THCPEventType) < 0) {
+ throw python_error();
+ }
+}
--- /dev/null
+#ifndef THCP_EVENT_INC
+#define THCP_EVENT_INC
+
+#include <ATen/cuda/CUDAEvent.h>
+#include <torch/csrc/python_headers.h>
+#include <THC/THC.h>
+
+struct THCPEvent {
+ PyObject_HEAD
+ at::cuda::CUDAEvent cuda_event;
+};
+extern PyObject *THCPEventClass;
+
+void THCPEvent_init(PyObject *module);
+
+inline bool THCPEvent_Check(PyObject* obj) {
+ return THCPEventClass && PyObject_IsInstance(obj, THCPEventClass);
+}
+
+#endif // THCP_EVENT_INC
PyObject *THCPStreamClass = nullptr;
-static PyObject * THCPStream_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs)
-{
+static PyObject * THCPStream_pynew(
+ PyTypeObject *type, PyObject *args, PyObject *kwargs) {
HANDLE_TH_ERRORS
int current_device;
uint64_t cdata = 0;
static char *kwlist[] = {"priority", "_cdata", nullptr};
- if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|iK", kwlist, &priority, &cdata)) {
+ if (!PyArg_ParseTupleAndKeywords(
+ args, kwargs, "|iK", kwlist, &priority, &cdata)) {
return nullptr;
}
at::cuda::CUDAStream stream =
cdata ?
at::cuda::CUDAStream::unpack(cdata) :
- at::cuda::getStreamFromPool(/* isHighPriority */ priority < 0 ? true : false);
+ at::cuda::getStreamFromPool(
+ /* isHighPriority */ priority < 0 ? true : false);
THCPStream* self = (THCPStream *)ptr.get();
self->cdata = stream.pack();
Py_TYPE(self)->tp_free((PyObject*)self);
}
-static PyObject * THPVariable_get_device(THCPStream *self) {
+static PyObject * THCPStream_get_device(THCPStream *self) {
HANDLE_TH_ERRORS
return THPUtils_packInt64(self->cuda_stream.device_index());
END_HANDLE_TH_ERRORS
}
-static PyObject * THPVariable_get_cuda_stream(THCPStream *self) {
+static PyObject * THCPStream_get_cuda_stream(THCPStream *self) {
HANDLE_TH_ERRORS
return PyLong_FromVoidPtr(self->cuda_stream.stream());
END_HANDLE_TH_ERRORS
}
+static PyObject * THCPStream_get_priority(THCPStream *self) {
+ HANDLE_TH_ERRORS
+ return PyLong_FromLong(self->cuda_stream.priority());
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THCPStream_priority_range() {
+ HANDLE_TH_ERRORS
+ int least_priority, greatest_priority;
+ std::tie(least_priority, greatest_priority) =
+ at::cuda::CUDAStream::priority_range();
+ return Py_BuildValue("(ii)", least_priority, greatest_priority);
+ END_HANDLE_TH_ERRORS
+}
+
static PyObject * THCPStream_query(THCPStream *self) {
HANDLE_TH_ERRORS
return PyBool_FromLong(self->cuda_stream.query());
END_HANDLE_TH_ERRORS
}
+static PyObject * THCPStream_synchronize(THCPStream *self) {
+ HANDLE_TH_ERRORS
+ self->cuda_stream.synchronize();
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
static PyObject * THCPStream_eq(THCPStream *self, THCPStream *other) {
HANDLE_TH_ERRORS
return PyBool_FromLong(self->cuda_stream == other->cuda_stream);
}
static struct PyMemberDef THCPStream_members[] = {
- {(char*)"_cdata", T_ULONGLONG, offsetof(THCPStream, cdata), READONLY, nullptr},
+ {(char*)"_cdata",
+ T_ULONGLONG, offsetof(THCPStream, cdata), READONLY, nullptr},
{nullptr}
};
-static struct PyGetSetDef THPVariable_properties[] = {
- {"device", (getter)THPVariable_get_device, nullptr, nullptr, nullptr},
+static struct PyGetSetDef THCPStream_properties[] = {
+ {"device", (getter)THCPStream_get_device, nullptr, nullptr, nullptr},
{"cuda_stream",
- (getter)THPVariable_get_cuda_stream, nullptr, nullptr, nullptr},
+ (getter)THCPStream_get_cuda_stream, nullptr, nullptr, nullptr},
+ {"priority", (getter)THCPStream_get_priority, nullptr, nullptr, nullptr},
{nullptr}
};
static PyMethodDef THCPStream_methods[] = {
{(char*)"query", (PyCFunction)THCPStream_query, METH_NOARGS, nullptr},
+ {(char*)"synchronize",
+ (PyCFunction)THCPStream_synchronize, METH_NOARGS, nullptr},
+ {(char*)"priority_range",
+ (PyCFunction)THCPStream_priority_range, METH_STATIC | METH_NOARGS, nullptr},
{(char*)"__eq__", (PyCFunction)THCPStream_eq, METH_O, nullptr},
{nullptr}
};
PyTypeObject THCPStreamType = {
PyVarObject_HEAD_INIT(nullptr, 0)
- "torch._C._CudaStreamBase", /* tp_name */
+ "torch._C._CudaStreamBase", /* tp_name */
sizeof(THCPStream), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)THCPStream_dealloc, /* tp_dealloc */
0, /* tp_iternext */
THCPStream_methods, /* tp_methods */
THCPStream_members, /* tp_members */
- THPVariable_properties, /* tp_getset */
+ THCPStream_properties, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
};
-bool THCPStream_init(PyObject *module)
+void THCPStream_init(PyObject *module)
{
THCPStreamClass = (PyObject*)&THCPStreamType;
- if (PyType_Ready(&THCPStreamType) < 0)
- return false;
+ if (PyType_Ready(&THCPStreamType) < 0) {
+ throw python_error();
+ }
Py_INCREF(&THCPStreamType);
- PyModule_AddObject(module, "_CudaStreamBase", (PyObject *)&THCPStreamType);
- return true;
+ if (PyModule_AddObject(
+ module, "_CudaStreamBase", (PyObject *)&THCPStreamType) < 0) {
+ throw python_error();
+ }
}
};
extern PyObject *THCPStreamClass;
-bool THCPStream_init(PyObject *module);
+void THCPStream_init(PyObject *module);
inline bool THCPStream_Check(PyObject* obj) {
return THCPStreamClass && PyObject_IsInstance(obj, THCPStreamClass);
#include <torch/csrc/cuda/Module.h>
#include <torch/csrc/cuda/Storage.h>
#include <torch/csrc/cuda/Stream.h>
+#include <torch/csrc/cuda/Event.h>
#ifdef _THP_CORE
#include <torch/csrc/cuda/utils.h>
#endif
torch._C.__dict__[tensor_name] = _dummy_type(tensor_name)
torch._C.__dict__['_CudaStreamBase'] = _dummy_type('CudaStreamBase')
+ torch._C.__dict__['_CudaEventBase'] = _dummy_type('CudaEventBase')
@staticmethod
import ctypes
import torch
-from . import cudart, check_error, cudaStatus
-from ._utils import _get_device_index
-from torch._C import _add_docstr
class Stream(torch._C._CudaStreamBase):
.. _CUDA documentation:
http://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
"""
- check_error(cudart().cudaStreamWaitEvent(self, event, ctypes.c_int(0)))
+ event.wait(self)
def wait_stream(self, stream):
r"""Synchronizes with another stream.
"""
if event is None:
event = Event()
- check_error(cudart().cudaEventRecord(event, self))
+ event.record(self)
return event
def query(self):
.. _CUDA documentation:
http://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
"""
- check_error(cudart().cudaStreamSynchronize(self))
-
- @staticmethod
- def priority_range():
- least_priority = ctypes.c_int()
- greatest_priority = ctypes.c_int()
- check_error(cudart().cudaDeviceGetStreamPriorityRange(
- ctypes.byref(least_priority), ctypes.byref(greatest_priority)))
- return (least_priority.value, greatest_priority.value)
-
- @property
- def priority(self):
- priority = ctypes.c_int()
- check_error(cudart().cudaStreamGetPriority(self, ctypes.byref(priority)))
- return priority.value
+ super(Stream, self).synchronize()
@property
def _as_parameter_(self):
.format(self.device, self.cuda_stream))
-class EventHandle(ctypes.Structure):
- IPC_HANDLE_SIZE = 64
- _fields_ = [('reserved', ctypes.c_char * IPC_HANDLE_SIZE)]
+class Event(torch._C._CudaEventBase):
+ r"""Wrapper around a CUDA event.
+ CUDA events are synchronization markers that can be used to monitor the
+ device's progress, to accurately measure timing, and to synchronize CUDA
+ streams.
-class Event(object):
- r"""Wrapper around CUDA event.
+ The underlying CUDA events are lazily initialized when the event is first
+ recorded or exported to another process. After creation, only streams on the
+ same device may record the event. However, streams on any device can wait on
+ the event.
Arguments:
- enable_timing (bool): indicates if the event should measure time
+ enable_timing (bool, optional): indicates if the event should measure time
(default: ``False``)
- blocking (bool): if ``True``, :meth:`wait` will be blocking (default: ``False``)
+ blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
interprocess (bool): if ``True``, the event can be shared between processes
(default: ``False``)
+
+ .. _CUDA documentation:
+ https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
"""
- DEFAULT = 0x0
- BLOCKING_SYNC = 0x1
- DISABLE_TIMING = 0x2
- INTERPROCESS = 0x4
-
- def __init__(self, enable_timing=False, blocking=False, interprocess=False,
- _handle=None):
- flags = Event.DEFAULT
- if not enable_timing:
- flags |= Event.DISABLE_TIMING
- if blocking:
- flags |= Event.BLOCKING_SYNC
- if interprocess:
- flags |= Event.INTERPROCESS
-
- ptr = ctypes.c_void_p()
- self._cudart = cudart()
- if _handle:
- check_error(self._cudart.cudaIpcOpenEventHandle(ctypes.byref(ptr), _handle))
- else:
- check_error(self._cudart.cudaEventCreateWithFlags(ctypes.byref(ptr), ctypes.c_uint(flags)))
- self._as_parameter_ = ptr
-
- def __del__(self):
- if hasattr(self, '_as_parameter_'):
- check_error(self._cudart.cudaEventDestroy(self._as_parameter_))
- del self._as_parameter_
+ def __new__(cls, enable_timing=False, blocking=False, interprocess=False):
+ return super(Event, cls).__new__(
+ cls,
+ enable_timing=enable_timing, blocking=blocking, interprocess=interprocess)
+
+ @classmethod
+ def from_ipc_handle(cls, device, handle):
+ r"""Reconstruct an event from an IPC handle on the given device."""
+ return super(Event, cls).from_ipc_handle(device, handle)
def record(self, stream=None):
- r"""Records the event in a given stream."""
+ r"""Records the event in a given stream.
+
+ Uses ``torch.cuda.current_stream()`` if no stream is specified. The
+ stream's device must match the event's device."""
if stream is None:
stream = torch.cuda.current_stream()
- stream.record_event(self)
+ super(Event, self).record(stream)
def wait(self, stream=None):
- r"""Makes a given stream wait for the event."""
+ r"""Makes all future work submitted to the given stream wait for this
+ event.
+
+ Use ``torch.cuda.current_stream()`` if no stream is specified."""
if stream is None:
stream = torch.cuda.current_stream()
- stream.wait_event(self)
+ super(Event, self).wait(stream)
def query(self):
- r"""Checks if the event has been recorded.
+ r"""Checks if all work currently captured by event has completed.
Returns:
- A boolean indicating if the event has been recorded.
+ A boolean indicating if all work currently captured by event has
+ completed.
"""
- res = cudart().cudaEventQuery(self)
- if res == cudaStatus.ERROR_NOT_READY:
- return False
- check_error(res)
- return True
+ return super(Event, self).query()
def elapsed_time(self, end_event):
- r"""Returns the time elapsed in milliseconds before the event was recorded."""
- time_ms = ctypes.c_float()
- check_error(cudart().cudaEventElapsedTime(
- ctypes.byref(time_ms), self, end_event))
- return time_ms.value
+ r"""Returns the time elapsed in milliseconds after the event was
+ recorded and before the end_event was recorded.
+ """
+ return super(Event, self).elapsed_time(end_event)
def synchronize(self):
- r"""Synchronizes with the event."""
- check_error(cudart().cudaEventSynchronize(self))
+ r"""Waits for the event to complete.
+
+ Waits until the completion of all work currently captured in this event.
+ This prevents the CPU thread from proceeding until the event completes.
+
+ .. note:: This is a wrapper around ``cudaEventSynchronize()``: see `CUDA
+ documentation`_ for more info.
+
+ .. _CUDA documentation:
+ https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
+ """
+ super(Event, self).synchronize()
def ipc_handle(self):
- r"""Returns an IPC handle of this event."""
- handle = EventHandle()
- check_error(cudart().cudaIpcGetEventHandle(ctypes.byref(handle), self))
- return handle
+ r"""Returns an IPC handle of this event. If not recorded yet, the event
+ will use the current device. """
+ return super(Event, self).ipc_handle()
+
+ @property
+ def _as_parameter_(self):
+ return ctypes.c_void_p(self.cuda_event)
def __repr__(self):
return '<torch.cuda.Event {0:#x}>'.format(self._as_parameter_.value)
shared_cache = SharedCache()
-def rebuild_event(handle):
- return torch.cuda.Event(_handle=handle)
+def rebuild_event(device, handle):
+ return torch.cuda.Event.from_ipc_handle(device, handle)
def reduce_event(event):
- return (rebuild_event, (event.ipc_handle(),))
+ handle = event.ipc_handle()
+ return (rebuild_event, (event.device, handle))
def rebuild_tensor(cls, storage, metadata):