[mlir] Initialize CUDA context lazily.
authorChristian Sigg <csigg@google.com>
Wed, 3 Mar 2021 16:35:02 +0000 (17:35 +0100)
committerChristian Sigg <csigg@google.com>
Thu, 4 Mar 2021 12:07:56 +0000 (13:07 +0100)
So we can remove the ignore-warning pragma again.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D97864

mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp

index 8afb80e..9122e34 100644 (file)
     fprintf(stderr, "'%s' failed with '%s'\n", #expr, name);                   \
   }(expr)
 
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wglobal-constructors"
-// Static reference to CUDA primary context for device ordinal 0.
-static CUcontext Context = [] {
-  CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
-  CUdevice device;
-  CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0));
-  CUcontext context;
-  CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&context, device));
-  return context;
-}();
-#pragma clang diagnostic pop
-
-// Sets the `Context` for the duration of the instance and restores the previous
-// context on destruction.
+// Make the primary context of device 0 current for the duration of the instance
+// and restore the previous context on destruction.
 class ScopedContext {
 public:
   ScopedContext() {
-    CUDA_REPORT_IF_ERROR(cuCtxGetCurrent(&previous));
-    CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(Context));
+    // Static reference to CUDA primary context for device ordinal 0.
+    static CUcontext context = [] {
+      CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
+      CUdevice device;
+      CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0));
+      CUcontext ctx;
+      // Note: this does not affect the current context.
+      CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&ctx, device));
+      return ctx;
+    }();
+
+    CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context));
   }
 
-  ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(previous)); }
-
-private:
-  CUcontext previous;
+  ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); }
 };
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoad(void *data) {