[RUNTIME][REFACTOR] Use new to avoid exit-time de-allocation order problem in DeviceA...
authorTianqi Chen <tqchen@users.noreply.github.com>
Tue, 18 Aug 2020 23:08:03 +0000 (16:08 -0700)
committerGitHub <noreply@github.com>
Tue, 18 Aug 2020 23:08:03 +0000 (16:08 -0700)
23 files changed:
src/runtime/c_runtime_api.cc
src/runtime/cpu_device_api.cc
src/runtime/cuda/cuda_device_api.cc
src/runtime/hexagon/hexagon_device_api.cc
src/runtime/metal/metal_common.h
src/runtime/metal/metal_device_api.mm
src/runtime/metal/metal_module.mm
src/runtime/micro/micro_device_api.cc
src/runtime/opencl/aocl/aocl_common.h
src/runtime/opencl/aocl/aocl_device_api.cc
src/runtime/opencl/aocl/aocl_module.cc
src/runtime/opencl/opencl_common.h
src/runtime/opencl/opencl_device_api.cc
src/runtime/opencl/opencl_module.cc
src/runtime/opencl/sdaccel/sdaccel_common.h
src/runtime/opencl/sdaccel/sdaccel_device_api.cc
src/runtime/opencl/sdaccel/sdaccel_module.cc
src/runtime/rocm/rocm_device_api.cc
src/runtime/vulkan/vulkan.cc
src/runtime/workspace_pool.cc
src/runtime/workspace_pool.h
vta/runtime/device_api.cc
web/emcc/webgpu_runtime.cc

index 0164b1b..1c860b8 100644 (file)
@@ -105,8 +105,8 @@ class DeviceAPIManager {
   DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); }
   // Global static variable.
   static DeviceAPIManager* Global() {
-    static DeviceAPIManager inst;
-    return &inst;
+    static DeviceAPIManager* inst = new DeviceAPIManager();
+    return inst;
   }
   // Get or initialize API.
   DeviceAPI* GetAPI(int type, bool allow_missing) {
index c70a4f2..5474b75 100644 (file)
@@ -80,8 +80,10 @@ class CPUDeviceAPI final : public DeviceAPI {
   void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final;
   void FreeWorkspace(TVMContext ctx, void* data) final;
 
-  static const std::shared_ptr<CPUDeviceAPI>& Global() {
-    static std::shared_ptr<CPUDeviceAPI> inst = std::make_shared<CPUDeviceAPI>();
+  static CPUDeviceAPI* Global() {
+    // NOTE: explicitly use new to avoid exit-time destruction of global state
+    // Global state will be recycled by OS as the process exits.
+    static auto* inst = new CPUDeviceAPI();
     return inst;
   }
 };
@@ -99,7 +101,7 @@ void CPUDeviceAPI::FreeWorkspace(TVMContext ctx, void* data) {
 }
 
 TVM_REGISTER_GLOBAL("device_api.cpu").set_body([](TVMArgs args, TVMRetValue* rv) {
-  DeviceAPI* ptr = CPUDeviceAPI::Global().get();
+  DeviceAPI* ptr = CPUDeviceAPI::Global();
   *rv = static_cast<void*>(ptr);
 });
 }  // namespace runtime
index 14444c9..b69ecf2 100644 (file)
@@ -207,8 +207,10 @@ class CUDADeviceAPI final : public DeviceAPI {
     CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
   }
 
-  static const std::shared_ptr<CUDADeviceAPI>& Global() {
-    static std::shared_ptr<CUDADeviceAPI> inst = std::make_shared<CUDADeviceAPI>();
+  static CUDADeviceAPI* Global() {
+    // NOTE: explicitly use new to avoid exit-time destruction of global state
+    // Global state will be recycled by OS as the process exits.
+    static auto* inst = new CUDADeviceAPI();
     return inst;
   }
 
@@ -230,12 +232,12 @@ CUDAThreadEntry::CUDAThreadEntry() : pool(kDLGPU, CUDADeviceAPI::Global()) {}
 CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { return CUDAThreadStore::Get(); }
 
 TVM_REGISTER_GLOBAL("device_api.gpu").set_body([](TVMArgs args, TVMRetValue* rv) {
-  DeviceAPI* ptr = CUDADeviceAPI::Global().get();
+  DeviceAPI* ptr = CUDADeviceAPI::Global();
   *rv = static_cast<void*>(ptr);
 });
 
 TVM_REGISTER_GLOBAL("device_api.cpu_pinned").set_body([](TVMArgs args, TVMRetValue* rv) {
-  DeviceAPI* ptr = CUDADeviceAPI::Global().get();
+  DeviceAPI* ptr = CUDADeviceAPI::Global();
   *rv = static_cast<void*>(ptr);
 });
 
index fd6f323..a890157 100644 (file)
@@ -42,8 +42,10 @@ class HexagonDeviceAPI : public DeviceAPI {
   void* AllocWorkspace(TVMContext ctx, size_t nbytes, DLDataType type_hint = {}) final;
   void FreeWorkspace(TVMContext ctx, void* ptr) final;
 
-  static const std::shared_ptr<HexagonDeviceAPI>& Global() {
-    static std::shared_ptr<HexagonDeviceAPI> inst = std::make_shared<HexagonDeviceAPI>();
+  static HexagonDeviceAPI* Global() {
+    // NOTE: explicitly use new to avoid destruction of global state
+    // Global state will be recycled by OS as the process exits.
+    static HexagonDeviceAPI* inst = new HexagonDeviceAPI();
     return inst;
   }
 };
@@ -121,7 +123,7 @@ inline void HexagonDeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) {
 }
 
 TVM_REGISTER_GLOBAL("device_api.hexagon").set_body([](TVMArgs args, TVMRetValue* rv) {
-  DeviceAPI* ptr = HexagonDeviceAPI::Global().get();
+  DeviceAPI* ptr = HexagonDeviceAPI::Global();
   *rv = ptr;
 });
 }  // namespace runtime
index ca369d4..634ee30 100644 (file)
@@ -91,7 +91,7 @@ class MetalWorkspace final : public DeviceAPI {
   void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final;
   void FreeWorkspace(TVMContext ctx, void* data) final;
   // get the global workspace
-  static const std::shared_ptr<MetalWorkspace>& Global();
+  static MetalWorkspace* Global();
 };
 
 /*! \brief Thread local workspace */
index f2a2930..fddeadf 100644 (file)
@@ -28,8 +28,10 @@ namespace tvm {
 namespace runtime {
 namespace metal {
 
-const std::shared_ptr<MetalWorkspace>& MetalWorkspace::Global() {
-  static std::shared_ptr<MetalWorkspace> inst = std::make_shared<MetalWorkspace>();
+MetalWorkspace* MetalWorkspace::Global() {
+  // NOTE: explicitly use new to avoid exit-time destruction of global state
+  // Global state will be recycled by OS as the process exits.
+  static MetalWorkspace* inst = new MetalWorkspace();
   return inst;
 }
 
@@ -273,7 +275,7 @@ typedef dmlc::ThreadLocalStore<MetalThreadEntry> MetalThreadStore;
 MetalThreadEntry* MetalThreadEntry::ThreadLocal() { return MetalThreadStore::Get(); }
 
 TVM_REGISTER_GLOBAL("device_api.metal").set_body([](TVMArgs args, TVMRetValue* rv) {
-  DeviceAPI* ptr = MetalWorkspace::Global().get();
+  DeviceAPI* ptr = MetalWorkspace::Global();
   *rv = static_cast<void*>(ptr);
 });
 
index 9bdebf3..8d10ff2 100644 (file)
@@ -73,7 +73,7 @@ class MetalModuleNode final : public runtime::ModuleNode {
   }
   // get a from primary context in device_id
   id<MTLComputePipelineState> GetPipelineState(size_t device_id, const std::string& func_name) {
-    metal::MetalWorkspace* w = metal::MetalWorkspace::Global().get();
+    metal::MetalWorkspace* w = metal::MetalWorkspace::Global();
     CHECK_LT(device_id, w->devices.size());
     // start lock scope.
     std::lock_guard<std::mutex> lock(mutex_);
@@ -168,7 +168,7 @@ class MetalWrappedFunc {
   void Init(MetalModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
             size_t num_buffer_args, size_t num_pack_args,
             const std::vector<std::string>& thread_axis_tags) {
-    w_ = metal::MetalWorkspace::Global().get();
+    w_ = metal::MetalWorkspace::Global();
     m_ = m;
     sptr_ = sptr;
     func_name_ = func_name;
index 6848078..3812ec0 100644 (file)
@@ -140,8 +140,8 @@ class MicroDeviceAPI final : public DeviceAPI {
    * \brief obtain a global singleton of MicroDeviceAPI
    * \return global shared pointer to MicroDeviceAPI
    */
-  static const std::shared_ptr<MicroDeviceAPI>& Global() {
-    static std::shared_ptr<MicroDeviceAPI> inst = std::make_shared<MicroDeviceAPI>();
+  static MicroDeviceAPI* Global() {
+    static MicroDeviceAPI* inst = new MicroDeviceAPI();
     return inst;
   }
 
@@ -155,7 +155,7 @@ class MicroDeviceAPI final : public DeviceAPI {
 
 // register device that can be obtained from Python frontend
 TVM_REGISTER_GLOBAL("device_api.micro_dev").set_body([](TVMArgs args, TVMRetValue* rv) {
-  DeviceAPI* ptr = MicroDeviceAPI::Global().get();
+  DeviceAPI* ptr = MicroDeviceAPI::Global();
   *rv = static_cast<void*>(ptr);
 });
 }  // namespace runtime
index 1b98d4b..ae1a4a8 100644 (file)
@@ -42,7 +42,7 @@ class AOCLWorkspace final : public OpenCLWorkspace {
   bool IsOpenCLDevice(TVMContext ctx) final;
   OpenCLThreadEntry* GetThreadEntry() final;
   // get the global workspace
-  static const std::shared_ptr<OpenCLWorkspace>& Global();
+  static OpenCLWorkspace* Global();
 };
 
 /*! \brief Thread local workspace for AOCL */
index 07057ff..5432507 100644 (file)
@@ -31,8 +31,8 @@ namespace cl {
 
 OpenCLThreadEntry* AOCLWorkspace::GetThreadEntry() { return AOCLThreadEntry::ThreadLocal(); }
 
-const std::shared_ptr<OpenCLWorkspace>& AOCLWorkspace::Global() {
-  static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<AOCLWorkspace>();
+OpenCLWorkspace* AOCLWorkspace::Global() {
+  static OpenCLWorkspace* inst = new AOCLWorkspace();
   return inst;
 }
 
@@ -49,7 +49,7 @@ typedef dmlc::ThreadLocalStore<AOCLThreadEntry> AOCLThreadStore;
 AOCLThreadEntry* AOCLThreadEntry::ThreadLocal() { return AOCLThreadStore::Get(); }
 
 TVM_REGISTER_GLOBAL("device_api.aocl").set_body([](TVMArgs args, TVMRetValue* rv) {
-  DeviceAPI* ptr = AOCLWorkspace::Global().get();
+  DeviceAPI* ptr = AOCLWorkspace::Global();
   *rv = static_cast<void*>(ptr);
 });
 
index 747188c..cb86533 100644 (file)
@@ -39,12 +39,10 @@ class AOCLModuleNode : public OpenCLModuleNode {
   explicit AOCLModuleNode(std::string data, std::string fmt,
                           std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
       : OpenCLModuleNode(data, fmt, fmap, source) {}
-  const std::shared_ptr<cl::OpenCLWorkspace>& GetGlobalWorkspace() final;
+  cl::OpenCLWorkspace* GetGlobalWorkspace() final;
 };
 
-const std::shared_ptr<cl::OpenCLWorkspace>& AOCLModuleNode::GetGlobalWorkspace() {
-  return cl::AOCLWorkspace::Global();
-}
+cl::OpenCLWorkspace* AOCLModuleNode::GetGlobalWorkspace() { return cl::AOCLWorkspace::Global(); }
 
 Module AOCLModuleCreate(std::string data, std::string fmt,
                         std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
index a892bff..aab0c27 100644 (file)
@@ -245,7 +245,7 @@ class OpenCLWorkspace : public DeviceAPI {
   virtual OpenCLThreadEntry* GetThreadEntry();
 
   // get the global workspace
-  static const std::shared_ptr<OpenCLWorkspace>& Global();
+  static OpenCLWorkspace* Global();
 };
 
 /*! \brief Thread local workspace */
@@ -265,8 +265,7 @@ class OpenCLThreadEntry {
   /*! \brief workspace pool */
   WorkspacePool pool;
   // constructor
-  OpenCLThreadEntry(DLDeviceType device_type, std::shared_ptr<DeviceAPI> device)
-      : pool(device_type, device) {
+  OpenCLThreadEntry(DLDeviceType device_type, DeviceAPI* device) : pool(device_type, device) {
     context.device_id = 0;
     context.device_type = device_type;
   }
@@ -298,7 +297,7 @@ class OpenCLModuleNode : public ModuleNode {
   /*!
    * \brief Get the global workspace
    */
-  virtual const std::shared_ptr<cl::OpenCLWorkspace>& GetGlobalWorkspace();
+  virtual cl::OpenCLWorkspace* GetGlobalWorkspace();
 
   const char* type_key() const final { return workspace_->type_key.c_str(); }
 
@@ -315,7 +314,7 @@ class OpenCLModuleNode : public ModuleNode {
  private:
   // The workspace, need to keep reference to use it in destructor.
   // In case of static destruction order problem.
-  std::shared_ptr<cl::OpenCLWorkspace> workspace_;
+  cl::OpenCLWorkspace* workspace_;
   // the binary data
   std::string data_;
   // The format
index 5753c1d..83944cd 100644 (file)
@@ -31,8 +31,8 @@ namespace cl {
 
 OpenCLThreadEntry* OpenCLWorkspace::GetThreadEntry() { return OpenCLThreadEntry::ThreadLocal(); }
 
-const std::shared_ptr<OpenCLWorkspace>& OpenCLWorkspace::Global() {
-  static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<OpenCLWorkspace>();
+OpenCLWorkspace* OpenCLWorkspace::Global() {
+  static OpenCLWorkspace* inst = new OpenCLWorkspace();
   return inst;
 }
 
@@ -276,7 +276,7 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic
 }
 
 TVM_REGISTER_GLOBAL("device_api.opencl").set_body([](TVMArgs args, TVMRetValue* rv) {
-  DeviceAPI* ptr = OpenCLWorkspace::Global().get();
+  DeviceAPI* ptr = OpenCLWorkspace::Global();
   *rv = static_cast<void*>(ptr);
 });
 
index 95d0481..590a446 100644 (file)
@@ -40,7 +40,7 @@ class OpenCLWrappedFunc {
   void Init(OpenCLModuleNode* m, ObjectPtr<Object> sptr, OpenCLModuleNode::KTRefEntry entry,
             std::string func_name, std::vector<size_t> arg_size,
             const std::vector<std::string>& thread_axis_tags) {
-    w_ = m->GetGlobalWorkspace().get();
+    w_ = m->GetGlobalWorkspace();
     m_ = m;
     sptr_ = sptr;
     entry_ = entry;
@@ -110,7 +110,7 @@ OpenCLModuleNode::~OpenCLModuleNode() {
   }
 }
 
-const std::shared_ptr<cl::OpenCLWorkspace>& OpenCLModuleNode::GetGlobalWorkspace() {
+cl::OpenCLWorkspace* OpenCLModuleNode::GetGlobalWorkspace() {
   return cl::OpenCLWorkspace::Global();
 }
 
index 803cbe6..feeab0b 100644 (file)
@@ -42,7 +42,7 @@ class SDAccelWorkspace final : public OpenCLWorkspace {
   bool IsOpenCLDevice(TVMContext ctx) final;
   OpenCLThreadEntry* GetThreadEntry() final;
   // get the global workspace
-  static const std::shared_ptr<OpenCLWorkspace>& Global();
+  static OpenCLWorkspace* Global();
 };
 
 /*! \brief Thread local workspace for SDAccel*/
index 6bac0c9..ebe387b 100644 (file)
@@ -31,8 +31,8 @@ namespace cl {
 
 OpenCLThreadEntry* SDAccelWorkspace::GetThreadEntry() { return SDAccelThreadEntry::ThreadLocal(); }
 
-const std::shared_ptr<OpenCLWorkspace>& SDAccelWorkspace::Global() {
-  static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<SDAccelWorkspace>();
+OpenCLWorkspace* SDAccelWorkspace::Global() {
+  static OpenCLWorkspace* inst = new SDAccelWorkspace();
   return inst;
 }
 
@@ -47,7 +47,7 @@ typedef dmlc::ThreadLocalStore<SDAccelThreadEntry> SDAccelThreadStore;
 SDAccelThreadEntry* SDAccelThreadEntry::ThreadLocal() { return SDAccelThreadStore::Get(); }
 
 TVM_REGISTER_GLOBAL("device_api.sdaccel").set_body([](TVMArgs args, TVMRetValue* rv) {
-  DeviceAPI* ptr = SDAccelWorkspace::Global().get();
+  DeviceAPI* ptr = SDAccelWorkspace::Global();
   *rv = static_cast<void*>(ptr);
 });
 
index b4edca3..36dabd1 100644 (file)
@@ -39,10 +39,10 @@ class SDAccelModuleNode : public OpenCLModuleNode {
   explicit SDAccelModuleNode(std::string data, std::string fmt,
                              std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
       : OpenCLModuleNode(data, fmt, fmap, source) {}
-  const std::shared_ptr<cl::OpenCLWorkspace>& GetGlobalWorkspace() final;
+  cl::OpenCLWorkspace* GetGlobalWorkspace() final;
 };
 
-const std::shared_ptr<cl::OpenCLWorkspace>& SDAccelModuleNode::GetGlobalWorkspace() {
+cl::OpenCLWorkspace* SDAccelModuleNode::GetGlobalWorkspace() {
   return cl::SDAccelWorkspace::Global();
 }
 
index e1a14c7..7f5bc99 100644 (file)
@@ -174,8 +174,8 @@ class ROCMDeviceAPI final : public DeviceAPI {
     ROCMThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
   }
 
-  static const std::shared_ptr<ROCMDeviceAPI>& Global() {
-    static std::shared_ptr<ROCMDeviceAPI> inst = std::make_shared<ROCMDeviceAPI>();
+  static ROCMDeviceAPI* Global() {
+    static ROCMDeviceAPI* inst = new ROCMDeviceAPI();
     return inst;
   }
 
@@ -197,7 +197,7 @@ ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {}
 ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { return ROCMThreadStore::Get(); }
 
 TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv) {
-  DeviceAPI* ptr = ROCMDeviceAPI::Global().get();
+  DeviceAPI* ptr = ROCMDeviceAPI::Global();
   *rv = static_cast<void*>(ptr);
 });
 }  // namespace runtime
index 9e730b7..5686725 100644 (file)
@@ -340,8 +340,8 @@ class VulkanDeviceAPI final : public DeviceAPI {
     VulkanThreadEntry::ThreadLocal()->pool->FreeWorkspace(ctx, data);
   }
 
-  static const std::shared_ptr<VulkanDeviceAPI>& Global() {
-    static std::shared_ptr<VulkanDeviceAPI> inst = std::make_shared<VulkanDeviceAPI>();
+  static VulkanDeviceAPI* Global() {
+    static VulkanDeviceAPI* inst = new VulkanDeviceAPI();
     return inst;
   }
 
@@ -1159,7 +1159,7 @@ TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModul
 TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary);
 
 TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) {
-  DeviceAPI* ptr = VulkanDeviceAPI::Global().get();
+  DeviceAPI* ptr = VulkanDeviceAPI::Global();
   *rv = static_cast<void*>(ptr);
 });
 
index 8ee905e..49a4c96 100644 (file)
@@ -134,7 +134,7 @@ class WorkspacePool::Pool {
   std::vector<Entry> allocated_;
 };
 
-WorkspacePool::WorkspacePool(DLDeviceType device_type, std::shared_ptr<DeviceAPI> device)
+WorkspacePool::WorkspacePool(DLDeviceType device_type, DeviceAPI* device)
     : device_type_(device_type), device_(device) {}
 
 WorkspacePool::~WorkspacePool() {
@@ -143,7 +143,7 @@ WorkspacePool::~WorkspacePool() {
       TVMContext ctx;
       ctx.device_type = device_type_;
       ctx.device_id = static_cast<int>(i);
-      array_[i]->Release(ctx, device_.get());
+      array_[i]->Release(ctx, device_);
       delete array_[i];
     }
   }
@@ -156,7 +156,7 @@ void* WorkspacePool::AllocWorkspace(TVMContext ctx, size_t size) {
   if (array_[ctx.device_id] == nullptr) {
     array_[ctx.device_id] = new Pool();
   }
-  return array_[ctx.device_id]->Alloc(ctx, device_.get(), size);
+  return array_[ctx.device_id]->Alloc(ctx, device_, size);
 }
 
 void WorkspacePool::FreeWorkspace(TVMContext ctx, void* ptr) {
index 288da7d..887afc5 100644 (file)
@@ -47,9 +47,9 @@ class TVM_DLL WorkspacePool {
   /*!
    * \brief Create pool with specific device type and device.
    * \param device_type The device type.
-   * \param device The device API.
+   * \param device_api The device API.
    */
-  WorkspacePool(DLDeviceType device_type, std::shared_ptr<DeviceAPI> device);
+  WorkspacePool(DLDeviceType device_type, DeviceAPI* device_api);
   /*! \brief destructor */
   ~WorkspacePool();
   /*!
@@ -73,7 +73,7 @@ class TVM_DLL WorkspacePool {
   /*! \brief device type this pool support */
   DLDeviceType device_type_;
   /*! \brief The device API */
-  std::shared_ptr<DeviceAPI> device_;
+  DeviceAPI* device_;
 };
 
 }  // namespace runtime
index 298403c..0fea7ba 100644 (file)
@@ -66,8 +66,8 @@ class VTADeviceAPI final : public DeviceAPI {
 
   void FreeWorkspace(TVMContext ctx, void* data) final;
 
-  static const std::shared_ptr<VTADeviceAPI>& Global() {
-    static std::shared_ptr<VTADeviceAPI> inst = std::make_shared<VTADeviceAPI>();
+  static VTADeviceAPI* Global() {
+    static VTADeviceAPI* inst = new VTADeviceAPI();
     return inst;
   }
 };
@@ -88,7 +88,7 @@ void VTADeviceAPI::FreeWorkspace(TVMContext ctx, void* data) {
 static TVM_ATTRIBUTE_UNUSED auto& __register_dev__ =
     ::tvm::runtime::Registry::Register("device_api.ext_dev", true)
         .set_body([](TVMArgs args, TVMRetValue* rv) {
-          DeviceAPI* ptr = VTADeviceAPI::Global().get();
+          DeviceAPI* ptr = VTADeviceAPI::Global();
           *rv = static_cast<void*>(ptr);
         });
 }  // namespace runtime
index 7f0b0d9..54601e3 100644 (file)
@@ -132,8 +132,8 @@ class WebGPUDeviceAPI : public DeviceAPI {
     WebGPUThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
   }
 
-  static const std::shared_ptr<WebGPUDeviceAPI>& Global() {
-    static std::shared_ptr<WebGPUDeviceAPI> inst = std::make_shared<WebGPUDeviceAPI>();
+  static WebGPUDeviceAPI* Global() {
+    static WebGPUDeviceAPI* inst = new WebGPUDeviceAPI();
     return inst;
   }
 
@@ -222,7 +222,7 @@ Module WebGPUModuleLoadBinary(void* strm) {
 TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(WebGPUModuleLoadBinary);
 
 TVM_REGISTER_GLOBAL("device_api.webgpu").set_body([](TVMArgs args, TVMRetValue* rv) {
-  DeviceAPI* ptr = WebGPUDeviceAPI::Global().get();
+  DeviceAPI* ptr = WebGPUDeviceAPI::Global();
   *rv = static_cast<void*>(ptr);
 });