cuda: Fix runtime compiler library loading on Windows
authorSeungha Yang <seungha@centricular.com>
Wed, 7 Dec 2022 21:05:25 +0000 (06:05 +0900)
committerSeungha Yang <seungha@centricular.com>
Fri, 9 Dec 2022 10:24:08 +0000 (19:24 +0900)
The cuda is a part of GPU driver but runtime compiler is a part of
cuda toolkit, which means the version number can be different.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gstreamer/-/merge_requests/3545>

subprojects/gst-plugins-bad/gst-libs/gst/cuda/gstcudanvrtc.c

index d9b97ca..5f857d4 100644 (file)
@@ -38,7 +38,7 @@ GST_DEBUG_CATEGORY_STATIC (gst_cuda_nvrtc_debug);
 
 #define LOAD_SYMBOL(name,func) G_STMT_START { \
   if (!g_module_symbol (module, G_STRINGIFY (name), (gpointer *) &vtable->func)) { \
-    GST_ERROR ("Failed to load '%s' from %s, %s", G_STRINGIFY (name), fname, g_module_error()); \
+    GST_ERROR ("Failed to load '%s', %s", G_STRINGIFY (name), g_module_error()); \
     goto error; \
   } \
 } G_STMT_END;
@@ -64,57 +64,88 @@ typedef struct _GstCudaNvrtcVTable
 
 static GstCudaNvrtcVTable gst_cuda_nvrtc_vtable = { 0, };
 
+#ifdef G_OS_WIN32
+static GModule *
+gst_cuda_nvrtc_load_library_once_win32 (void)
+{
+  gchar *dll_name = NULL;
+  GModule *module = NULL;
+  gint cuda_version;
+  gint cuda_major_version;
+  gint cuda_minor_version;
+  gint major, minor;
+  CUresult rst;
+
+  rst = CuDriverGetVersion (&cuda_version);
+  if (rst != CUDA_SUCCESS) {
+    GST_WARNING ("Couldn't get driver version, 0x%x", (guint) rst);
+    return NULL;
+  }
+
+  cuda_major_version = cuda_version / 1000;
+  cuda_minor_version = (cuda_version % 1000) / 10;
+
+  GST_INFO ("CUDA version %d / %d", cuda_major_version, cuda_minor_version);
+
+  /* First path for searching nvrtc library using system CUDA version */
+  for (minor = cuda_minor_version; minor >= 0; minor--) {
+    g_clear_pointer (&dll_name, g_free);
+    dll_name = g_strdup_printf (NVRTC_LIBNAME, cuda_major_version, minor);
+    module = g_module_open (dll_name, G_MODULE_BIND_LAZY);
+    if (module) {
+      GST_INFO ("%s is available", dll_name);
+      g_free (dll_name);
+      return module;
+    }
+
+    GST_DEBUG ("Couldn't open library %s", dll_name);
+  }
+
+  /* CUDA is a part for driever installation, but nvrtc library is a part of
+   * CUDA-toolkit. So CUDA-toolkit version may be lower than
+   * CUDA version. Do search the dll again */
+  for (major = cuda_major_version; major >= 9; major--) {
+    for (minor = 5; minor >= 0; minor--) {
+      g_clear_pointer (&dll_name, g_free);
+      dll_name = g_strdup_printf (NVRTC_LIBNAME, major, minor);
+      module = g_module_open (dll_name, G_MODULE_BIND_LAZY);
+      if (module) {
+        GST_INFO ("%s is available", dll_name);
+        g_free (dll_name);
+        return module;
+      }
+
+      GST_DEBUG ("Couldn't open library %s", dll_name);
+    }
+  }
+
+  g_free (dll_name);
+
+  return NULL;
+}
+#endif
+
 static gboolean
 gst_cuda_nvrtc_load_library_once (void)
 {
   GModule *module = NULL;
-  gchar *filename = NULL;
   const gchar *filename_env;
-  const gchar *fname;
-  gint cuda_version;
   GstCudaNvrtcVTable *vtable;
 
-  CuDriverGetVersion (&cuda_version);
-
-  fname = filename_env = g_getenv ("GST_CUDA_NVRTC_LIBNAME");
+  filename_env = g_getenv ("GST_CUDA_NVRTC_LIBNAME");
   if (filename_env)
     module = g_module_open (filename_env, G_MODULE_BIND_LAZY);
 
   if (!module) {
 #ifndef G_OS_WIN32
-    filename = g_strdup (NVRTC_LIBNAME);
-    fname = filename;
-    module = g_module_open (filename, G_MODULE_BIND_LAZY);
+    module = g_module_open (NVRTC_LIBNAME, G_MODULE_BIND_LAZY);
 #else
-    /* XXX: On Windows, minor version of nvrtc library might not be exactly
-     * same as CUDA library */
-    {
-      gint cuda_major_version = cuda_version / 1000;
-      gint cuda_minor_version = (cuda_version % 1000) / 10;
-      gint minor_version;
-
-      for (minor_version = cuda_minor_version; minor_version >= 0;
-          minor_version--) {
-        g_free (filename);
-        filename = g_strdup_printf (NVRTC_LIBNAME, cuda_major_version,
-            minor_version);
-        fname = filename;
-
-        module = g_module_open (filename, G_MODULE_BIND_LAZY);
-        if (module) {
-          GST_INFO ("%s is available", filename);
-          break;
-        }
-
-        GST_DEBUG ("Couldn't open library %s", filename);
-      }
-    }
+    module = gst_cuda_nvrtc_load_library_once_win32 ();
 #endif
   }
 
   if (module == NULL) {
-    GST_WARNING ("Could not open library %s, %s", filename, g_module_error ());
-    g_free (filename);
+    GST_WARNING ("Could not open nvrtc library %s", g_module_error ());
     return FALSE;
   }
 
@@ -129,13 +160,11 @@ gst_cuda_nvrtc_load_library_once (void)
   LOAD_SYMBOL (nvrtcGetProgramLogSize, NvrtcGetProgramLogSize);
 
   vtable->loaded = TRUE;
-  g_free (filename);
 
   return TRUE;
 
 error:
   g_module_close (module);
-  g_free (filename);
 
   return FALSE;
 }