#include <c10/cuda/CUDAStream.h>
-#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAFunctions.h>
+#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Exception.h>
-#include <mutex>
+#include <array>
#include <atomic>
#include <cstdint>
-#include <deque>
+#include <mutex>
#include <vector>
-#include <array>
+
+#include <iostream>
namespace c10 {
namespace cuda {
namespace {
-// Internal implementation is entirely hidden
-struct CUDAStreamInternals {
- CUDAStreamInternals() = default;
+// Internal implementation that leaks the stream. It's not intended to be used
+// outside of this file.
+struct LeakyStreamInternals {
+ LeakyStreamInternals() = default;
+ C10_DISABLE_COPY_AND_ASSIGN(LeakyStreamInternals);
+
+ ~LeakyStreamInternals() {
+ // NB: this code is invoked only in the destruction of global variables
+ // (since we never shrink the corresponding vectors). At this point the CUDA
+ // runtime might be already destroyed and invoking cudaStreamDestroy leads
+ // to a crash. It's likely an issue in CUDA, but to be safe - let's just
+ // "forget" the destruction.
- ~CUDAStreamInternals() {
- if (stream) cudaStreamDestroy(stream);
+ // if (stream) cudaStreamDestroy(stream);
}
DeviceIndex device_index = -1;
// Note: stream priority is not supported by HIP
// Note: lower numbers are higher priorities, zero is default priority
#ifndef __HIP_PLATFORM_HCC__
- static int kHighPriority = -1;
- static int kLowPriority = 0;
+static int kHighPriority = -1;
+static int kLowPriority = 0;
#endif // __HIP_PLATFORM_HCC__
// Default streams
static std::once_flag init_flag;
-static std::vector<CUDAStreamInternals> default_streams;
+static LeakyStreamInternals default_streams[C10_COMPILE_TIME_MAX_GPUS];
// Non-default streams
// Note: the number of CUDA devices is determined at run time,
// the low and high priority counters track, for each device, the next stream
// in the pool to be returned when a stream is requested (round-robin fashion
// , see the note in CUDAStream.h).
-static std::deque<std::once_flag> device_flags;
-static std::deque<std::atomic<uint32_t>> low_priority_counters;
-static std::deque<std::atomic<uint32_t>> high_priority_counters;
-static std::vector<std::array<CUDAStreamInternals, kStreamsPerPool>> low_priority_streams;
-static std::vector<std::array<CUDAStreamInternals, kStreamsPerPool>> high_priority_streams;
+//
+// Fixed-size C-style arrays (sized by C10_COMPILE_TIME_MAX_GPUS) are used
+// instead of vector<T>/deque<T> because T might be non-moveable and
+// non-copyable, and the capacity is known at compile time.
+static std::once_flag device_flags[C10_COMPILE_TIME_MAX_GPUS];
+static std::atomic<uint32_t> low_priority_counters[C10_COMPILE_TIME_MAX_GPUS];
+static std::atomic<uint32_t> high_priority_counters[C10_COMPILE_TIME_MAX_GPUS];
+static std::array<LeakyStreamInternals, kStreamsPerPool>
+ low_priority_streams[C10_COMPILE_TIME_MAX_GPUS];
+static std::array<LeakyStreamInternals, kStreamsPerPool>
+ high_priority_streams[C10_COMPILE_TIME_MAX_GPUS];
// Note [StreamId assignment]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
enum class StreamIdType : uint8_t {
DEFAULT = 0x0,
- LOW = 0x1,
- HIGH = 0x2,
+ LOW = 0x1,
+ HIGH = 0x2,
};
std::ostream& operator<<(std::ostream& stream, StreamIdType s) {
}
StreamId makeStreamId(StreamIdType st, size_t si) {
- return (static_cast<StreamId>(st) << kStreamsPerPoolBits) | static_cast<StreamId>(si);
+ return (static_cast<StreamId>(st) << kStreamsPerPoolBits) |
+ static_cast<StreamId>(si);
}
template <typename T, typename A>
static bool pointer_within(const T* ptr, const A& arr) {
- return std::greater_equal<const T*>()(ptr, arr.data()) && std::less<const T*>()(ptr, arr.data() + arr.size());
+ return std::greater_equal<const T*>()(ptr, arr.data()) &&
+ std::less<const T*>()(ptr, arr.data() + arr.size());
}
-static StreamId CUDAStream_getStreamId(const CUDAStreamInternals* ptr) {
+static StreamId CUDAStream_getStreamId(const LeakyStreamInternals* ptr) {
// Hypothetically, we could store the stream ID in the stream. But that
// introduces a degree of freedom which could lead to bugs (where we
// misnumber streams in the pool, or overwrite the number). Better
// NB: Because ptr may not necessarily lie within the array, we must use
// std::less and similar templates to avoid UB that arises when
// doing an operator< comparison.
- if (pointer_within<CUDAStreamInternals>(ptr, low_priority_streams[device_index])) {
- return makeStreamId(StreamIdType::LOW, ptr - low_priority_streams[device_index].data());
+ if (pointer_within<LeakyStreamInternals>(
+ ptr, low_priority_streams[device_index])) {
+ return makeStreamId(
+ StreamIdType::LOW, ptr - low_priority_streams[device_index].data());
}
// Check if it's a high priority stream
- if (pointer_within<CUDAStreamInternals>(ptr, high_priority_streams[device_index])) {
- return makeStreamId(StreamIdType::HIGH, ptr - high_priority_streams[device_index].data());
+ if (pointer_within<LeakyStreamInternals>(
+ ptr, high_priority_streams[device_index])) {
+ return makeStreamId(
+ StreamIdType::HIGH, ptr - high_priority_streams[device_index].data());
}
- AT_ASSERTM(0, "Could not compute stream ID for ", ptr, " on device ", device_index,
- " (something has gone horribly wrong!)");
+ AT_ASSERTM(
+ 0,
+ "Could not compute stream ID for ",
+ ptr,
+ " on device ",
+ device_index,
+ " (something has gone horribly wrong!)");
}
// Thread-local current streams
-static thread_local CUDAStreamInternals** current_streams = nullptr;
+static thread_local LeakyStreamInternals** current_streams = nullptr;
// Populates global values and creates a default stream for each device.
// Note: the default stream on each device is signified by a nullptr,
// Warning: this function must only be called once!
static void initGlobalStreamState() {
num_gpus = device_count();
-
- // Resizes deques and vectors
- default_streams.resize(num_gpus);
- device_flags.resize(num_gpus);
- low_priority_counters.resize(num_gpus);
- high_priority_counters.resize(num_gpus);
- low_priority_streams.resize(num_gpus);
- high_priority_streams.resize(num_gpus);
+ // Check if the number of GPUs matches the expected compile-time max number
+ // of GPUs.
+ AT_ASSERTM(
+ num_gpus <= C10_COMPILE_TIME_MAX_GPUS,
+ "Number of CUDA devices on the machine is larger than the compiled "
+ "max number of gpus expected (",
+ C10_COMPILE_TIME_MAX_GPUS,
+ "). Increase that and recompile.");
// Initializes default streams
for (auto i = decltype(num_gpus){0}; i < num_gpus; ++i) {
lowpri_stream.device_index = device_index;
hipri_stream.device_index = device_index;
- #ifndef __HIP_PLATFORM_HCC__
- C10_CUDA_CHECK(cudaStreamCreateWithPriority(
- &lowpri_stream.stream
- , kDefaultFlags
- , kLowPriority));
- C10_CUDA_CHECK(cudaStreamCreateWithPriority(
- &hipri_stream.stream
- , kDefaultFlags
- , kHighPriority));
- #else
- C10_CUDA_CHECK(cudaStreamCreateWithFlags(
- &lowpri_stream.stream
- , kDefaultFlags));
- C10_CUDA_CHECK(cudaStreamCreateWithFlags(
- &hipri_stream.stream
- , kDefaultFlags));
- #endif // __HIP_PLATFORM_HCC__
+#ifndef __HIP_PLATFORM_HCC__
+ C10_CUDA_CHECK(cudaStreamCreateWithPriority(
+ &lowpri_stream.stream, kDefaultFlags, kLowPriority));
+ C10_CUDA_CHECK(cudaStreamCreateWithPriority(
+ &hipri_stream.stream, kDefaultFlags, kHighPriority));
+#else
+ C10_CUDA_CHECK(
+ cudaStreamCreateWithFlags(&lowpri_stream.stream, kDefaultFlags));
+ C10_CUDA_CHECK(
+ cudaStreamCreateWithFlags(&hipri_stream.stream, kDefaultFlags));
+#endif // __HIP_PLATFORM_HCC__
}
}
// Inits default streams (once, globally)
std::call_once(init_flag, initGlobalStreamState);
- if (current_streams) return;
+ if (current_streams) {
+ return;
+ }
// Inits current streams (thread local) to default streams
- current_streams = (CUDAStreamInternals**) malloc(num_gpus * sizeof(CUDAStreamInternals*));
+ current_streams =
+ (LeakyStreamInternals**)malloc(num_gpus * sizeof(LeakyStreamInternals*));
for (auto i = decltype(num_gpus){0}; i < num_gpus; ++i) {
current_streams[i] = &default_streams[i];
}
// Helper to determine the index of the stream to return
// Note: Streams are returned round-robin (see note in CUDAStream.h)
-static uint32_t get_idx(std::atomic<uint32_t> &counter) {
+static uint32_t get_idx(std::atomic<uint32_t>& counter) {
auto raw_idx = counter++;
return raw_idx % kStreamsPerPool;
}
// See Note [StreamId assignment]
-CUDAStreamInternals* CUDAStream_internals(CUDAStream s) {
+LeakyStreamInternals* CUDAStream_internals(CUDAStream s) {
c10::DeviceIndex device_index = s.device_index();
StreamIdType st = streamIdType(s.unwrap().id());
size_t si = streamIdIndex(s.unwrap().id());
switch (st) {
case StreamIdType::DEFAULT:
- AT_ASSERTM(si == 0, "Unrecognized stream ", s.unwrap(),
- " (I think this should be the default stream, but I got a non-zero index ", si, ").",
- " Did you manufacture the StreamId yourself? Don't do that; use the",
- " official API like c10::cuda::getStreamFromPool() to get a new stream.");
+ AT_ASSERTM(
+ si == 0,
+ "Unrecognized stream ",
+ s.unwrap(),
+ " (I think this should be the default stream, but I got a non-zero index ",
+ si,
+ ").",
+ " Did you manufacture the StreamId yourself? Don't do that; use the",
+ " official API like c10::cuda::getStreamFromPool() to get a new stream.");
return &default_streams[device_index];
case StreamIdType::LOW:
return &low_priority_streams[device_index][si];
case StreamIdType::HIGH:
return &high_priority_streams[device_index][si];
default:
- AT_ASSERTM(0, "Unrecognized stream ", s.unwrap(), " (I didn't recognize the stream type, ", st, ")");
+ AT_ASSERTM(
+ 0,
+ "Unrecognized stream ",
+ s.unwrap(),
+ " (I didn't recognize the stream type, ",
+ st,
+ ")");
}
}
-CUDAStream CUDAStream_fromInternals(const CUDAStreamInternals* ptr) {
- return CUDAStream(CUDAStream::UNCHECKED,
- Stream(Stream::UNSAFE,
- c10::Device(DeviceType::CUDA, ptr->device_index),
- CUDAStream_getStreamId(ptr)));
+CUDAStream CUDAStream_fromInternals(const LeakyStreamInternals* ptr) {
+ return CUDAStream(
+ CUDAStream::UNCHECKED,
+ Stream(
+ Stream::UNSAFE,
+ c10::Device(DeviceType::CUDA, ptr->device_index),
+ CUDAStream_getStreamId(ptr)));
}
} // anonymous namespace
// Note: when called the first time on a device, this will create the
// stream pools for that device.
CUDAStream getStreamFromPool(
- const bool isHighPriority
-, DeviceIndex device_index) {
+ const bool isHighPriority,
+ DeviceIndex device_index) {
initCUDAStreamsOnce();
- if (device_index == -1) device_index = current_device();
+ if (device_index == -1)
+ device_index = current_device();
check_gpu(device_index);
// Initializes the stream pools (once)
- std::call_once(device_flags[device_index], initDeviceStreamState, device_index);
+ std::call_once(
+ device_flags[device_index], initDeviceStreamState, device_index);
if (isHighPriority) {
const auto idx = get_idx(high_priority_counters[device_index]);
CUDAStream getDefaultCUDAStream(DeviceIndex device_index) {
initCUDAStreamsOnce();
- if (device_index == -1) device_index = current_device();
+ if (device_index == -1) {
+ device_index = current_device();
+ }
check_gpu(device_index);
return CUDAStream_fromInternals(&default_streams[device_index]);
}
CUDAStream getCurrentCUDAStream(DeviceIndex device_index) {
initCUDAStreamsOnce();
- if (device_index == -1) device_index = current_device();
+ if (device_index == -1) {
+ device_index = current_device();
+ }
check_gpu(device_index);
return CUDAStream_fromInternals(current_streams[device_index]);
}
}
} // namespace cuda
-} // namespace at
+} // namespace c10