#include "cuda.h"
+#ifdef _WIN32
+#define MLIR_CUDA_WRAPPERS_EXPORT __declspec(dllexport)
+#else
+#define MLIR_CUDA_WRAPPERS_EXPORT
+#endif // _WIN32
+
#define CUDA_REPORT_IF_ERROR(expr) \
[](CUresult result) { \
if (!result) \
CUcontext previous;
};
-extern "C" CUmodule mgpuModuleLoad(void *data) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoad(void *data) {
ScopedContext scopedContext;
CUmodule module = nullptr;
CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
return module;
}
-extern "C" void mgpuModuleUnload(CUmodule module) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuModuleUnload(CUmodule module) {
CUDA_REPORT_IF_ERROR(cuModuleUnload(module));
}
-extern "C" CUfunction mgpuModuleGetFunction(CUmodule module, const char *name) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUfunction
+mgpuModuleGetFunction(CUmodule module, const char *name) {
CUfunction function = nullptr;
CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name));
return function;
// The wrapper uses intptr_t instead of CUDA's unsigned int to match
// the type of MLIR's index type. This avoids the need for casts in the
// generated MLIR code.
-extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX,
- intptr_t gridY, intptr_t gridZ,
- intptr_t blockX, intptr_t blockY,
- intptr_t blockZ, int32_t smem, CUstream stream,
- void **params, void **extra) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY,
+ intptr_t gridZ, intptr_t blockX, intptr_t blockY,
+ intptr_t blockZ, int32_t smem, CUstream stream, void **params,
+ void **extra) {
ScopedContext scopedContext;
CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX,
blockY, blockZ, smem, stream, params,
extra));
}
-extern "C" CUstream mgpuStreamCreate() {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUstream mgpuStreamCreate() {
ScopedContext scopedContext;
CUstream stream = nullptr;
CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
return stream;
}
-extern "C" void mgpuStreamDestroy(CUstream stream) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamDestroy(CUstream stream) {
CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream));
}
-extern "C" void mgpuStreamSynchronize(CUstream stream) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuStreamSynchronize(CUstream stream) {
CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream));
}
-extern "C" void mgpuStreamWaitEvent(CUstream stream, CUevent event) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamWaitEvent(CUstream stream,
+ CUevent event) {
CUDA_REPORT_IF_ERROR(cuStreamWaitEvent(stream, event, /*flags=*/0));
}
-extern "C" CUevent mgpuEventCreate() {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUevent mgpuEventCreate() {
ScopedContext scopedContext;
CUevent event = nullptr;
CUDA_REPORT_IF_ERROR(cuEventCreate(&event, CU_EVENT_DISABLE_TIMING));
return event;
}
-extern "C" void mgpuEventDestroy(CUevent event) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventDestroy(CUevent event) {
CUDA_REPORT_IF_ERROR(cuEventDestroy(event));
}
-extern "C" void mgpuEventSynchronize(CUevent event) {
+extern MLIR_CUDA_WRAPPERS_EXPORT "C" void mgpuEventSynchronize(CUevent event) {
CUDA_REPORT_IF_ERROR(cuEventSynchronize(event));
}
-extern "C" void mgpuEventRecord(CUevent event, CUstream stream) {
+extern MLIR_CUDA_WRAPPERS_EXPORT "C" void mgpuEventRecord(CUevent event,
+ CUstream stream) {
CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream));
}
// Allows to register byte array with the CUDA runtime. Helpful until we have
// transfer functions implemented.
-extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
ScopedContext scopedContext;
CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
}
// Allows to register a MemRef with the CUDA runtime. Helpful until we have
// transfer functions implemented.
-extern "C" void
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
int64_t elementSizeBytes) {