fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \
}(expr)
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wglobal-constructors"
-// Static reference to CUDA primary context for device ordinal 0.
-static CUcontext Context = [] {
- CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
- CUdevice device;
- CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0));
- CUcontext context;
- CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&context, device));
- return context;
-}();
-#pragma clang diagnostic pop
-
-// Sets the `Context` for the duration of the instance and restores the previous
-// context on destruction.
+// Make the primary context of device 0 current for the duration of the instance
+// and restore the previous context on destruction.
class ScopedContext {
public:
ScopedContext() {
- CUDA_REPORT_IF_ERROR(cuCtxGetCurrent(&previous));
- CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(Context));
+ // Primary context of device ordinal 0, retained once by this static initializer.
+ static CUcontext context = [] {
+ CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
+ CUdevice device;
+ CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0));
+ CUcontext ctx; // NOTE(review): no initializer; presumably indeterminate if the retain below fails.
+ // Note: retaining the primary context does not make it current; the push below does.
+ CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&ctx, device));
+ return ctx;
+ }();
+
+ CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context));
}
- ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(previous)); }
-
-private:
- CUcontext previous;
+ ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); }
};
extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoad(void *data) {