[RUNTIME][VULKAN] Seg fault in WorkspacePool's destructor (#5632) (#5636)
authorYi Wang <samwyi@yahoo.com>
Thu, 21 May 2020 19:07:54 +0000 (12:07 -0700)
committerGitHub <noreply@github.com>
Thu, 21 May 2020 19:07:54 +0000 (12:07 -0700)
* [RUNTIME][VULKAN] Seg fault in WorkspacePool's destructor (#5632)
* fixed this issue by changing WorkspacePool's destruction order

* make line < 100 charactors long

src/runtime/vulkan/vulkan.cc

index 207a86a..ef4b9b0 100644 (file)
@@ -56,6 +56,8 @@ class VulkanThreadEntry {
     // the instance and device get destroyed.
     // The destruction need to be manually called
     // to ensure the destruction order.
+
+    pool.reset();
     streams_.clear();
     for (const auto& kv : staging_buffers_) {
       if (!kv.second) {
@@ -75,7 +77,7 @@ class VulkanThreadEntry {
   }
 
   TVMContext ctx;
-  WorkspacePool pool;
+  std::unique_ptr<WorkspacePool> pool;
   VulkanStream* Stream(size_t device_id);
   VulkanStagingBuffer* StagingBuffer(int device_id, size_t size);
 
@@ -331,11 +333,11 @@ class VulkanDeviceAPI final : public DeviceAPI {
   }
 
   void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final {
-    return VulkanThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
+    return VulkanThreadEntry::ThreadLocal()->pool->AllocWorkspace(ctx, size);
   }
 
   void FreeWorkspace(TVMContext ctx, void* data) final {
-    VulkanThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
+    VulkanThreadEntry::ThreadLocal()->pool->FreeWorkspace(ctx, data);
   }
 
   static const std::shared_ptr<VulkanDeviceAPI>& Global() {
@@ -999,7 +1001,8 @@ VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size
 }
 
 VulkanThreadEntry::VulkanThreadEntry()
-    : pool(static_cast<DLDeviceType>(kDLVulkan), VulkanDeviceAPI::Global()) {
+    : pool(std::make_unique<WorkspacePool>(static_cast<DLDeviceType>(kDLVulkan),
+                                           VulkanDeviceAPI::Global())) {
   ctx.device_id = 0;
   ctx.device_type = static_cast<DLDeviceType>(kDLVulkan);
 }