[IE CLDNN] dp4a check that works both with old and new drivers (#1766)
authorMikhail Letavin <mikhail.letavin@intel.com>
Fri, 14 Aug 2020 11:50:33 +0000 (14:50 +0300)
committerGitHub <noreply@github.com>
Fri, 14 Aug 2020 11:50:33 +0000 (14:50 +0300)
inference-engine/thirdparty/clDNN/src/gpu/device_info.cpp

index 8787788..1fc851d 100644 (file)
 #include <iostream>
 #include <utility>
 
+#ifdef _WIN32
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#include <SetupAPI.h>
+#include <devguid.h>
+#include <cstring>
+#else
+#include <unistd.h>
+#include <limits.h>
+#include <link.h>
+#include <dlfcn.h>
+#endif
 
 namespace cldnn {
 namespace gpu {
 
 namespace {
+int driver_dev_id()
+{
+    const std::vector<int> unused_ids = {
+        0x4905, 0x4906, 0x4907, 0x4908
+    };
+    std::vector<int> result;
+
+#ifdef _WIN32
+    {
+        HDEVINFO device_info_set = SetupDiGetClassDevsA(&GUID_DEVCLASS_DISPLAY, NULL, NULL, DIGCF_PRESENT);
+        if (device_info_set == INVALID_HANDLE_VALUE)
+            return 0;
+
+        SP_DEVINFO_DATA devinfo_data;
+        std::memset(&devinfo_data, 0, sizeof(devinfo_data));
+        devinfo_data.cbSize = sizeof(devinfo_data);
+
+        for (DWORD dev_idx = 0; SetupDiEnumDeviceInfo(device_info_set, dev_idx, &devinfo_data); dev_idx++) {
+            const size_t buf_size = 512;
+            char buf[buf_size];
+            if (!SetupDiGetDeviceInstanceIdA(device_info_set, &devinfo_data, buf, buf_size, NULL)) {
+                continue;
+            }
+
+            char* vendor_pos = std::strstr(buf, "VEN_");
+            if (vendor_pos != NULL && std::stoi(vendor_pos + 4, NULL, 16) == 0x8086) {
+                char* device_pos = strstr(vendor_pos, "DEV_");
+                if (device_pos != NULL) {
+                    result.push_back(std::stoi(device_pos + 4, NULL, 16));
+                }
+            }
+        }
+
+        if (device_info_set) {
+            SetupDiDestroyDeviceInfoList(device_info_set);
+        }
+    }
+#elif defined(__linux__)
+    {
+        std::string dev_base{ "/sys/devices/pci0000:00/0000:00:02.0/" };
+        std::ifstream ifs(dev_base + "vendor");
+        if (ifs.good())
+        {
+            int ven_id;
+            ifs >> std::hex >> ven_id;
+            ifs.close();
+            if (ven_id == 0x8086)
+            {
+                ifs.open(dev_base + "device");
+                if (ifs.good())
+                {
+                    int res = 0;
+                    ifs >> std::hex >> res;
+                    result.push_back(res);
+                }
+            }
+        }
+    }
+#endif
+
+    auto id_itr = result.begin();
+    while (id_itr != result.end()) {
+        if (std::find(unused_ids.begin(), unused_ids.end(), *id_itr) != unused_ids.end())
+            result.erase(id_itr);
+        else
+            id_itr++;
+    }
+
+    if (result.empty())
+        return 0;
+    else
+        return result.back();
+}
+
+bool get_imad_support(const cl::Device& device) {
+    std::string dev_name = device.getInfo<CL_DEVICE_NAME>();
+
+    if (dev_name.find("Gen12") != std::string::npos ||
+        dev_name.find("Xe") != std::string::npos)
+        return true;
+
+    auto flag = device.getInfo<CL_DEVICE_HOST_UNIFIED_MEMORY>();
+    if (flag != 0) {
+        const std::vector<int> imad_ids = {
+            0x9A40, 0x9A49, 0x9A59, 0x9AD9,
+            0x9A60, 0x9A68, 0x9A70, 0x9A78,
+            0x9A7F, 0x9AF8, 0x9AC0, 0x9AC9
+        };
+        int dev_id = driver_dev_id();
+        if (dev_id == 0)
+            return false;
+
+        if (std::find(imad_ids.begin(), imad_ids.end(), dev_id) != imad_ids.end())
+            return true;
+    } else {
+        return true;
+    }
+
+    return false;
+}
 
 bool is_local_block_io_supported(const cl::Device& device) {
     try {
@@ -105,7 +217,7 @@ device_info_internal::device_info_internal(const cl::Device& device) {
 
     supports_subgroups_short = extensions.find("cl_intel_subgroups_short") != std::string::npos;
 
-    supports_imad = dev_name.find("Gen12") != std::string::npos;
+    supports_imad = get_imad_support(device);
     supports_immad = false;
 
     dev_type = static_cast<uint32_t>(device.getInfo<CL_DEVICE_TYPE>());