Vulkan2 Runtime API (#3849)
author Andrew Tulloch <andrew@tullo.ch>
Fri, 13 Sep 2019 20:40:43 +0000 (13:40 -0700)
committer Tianqi Chen <tqchen@users.noreply.github.com>
Fri, 13 Sep 2019 20:40:43 +0000 (13:40 -0700)
13 files changed:
apps/android_rpc/app/src/main/jni/tvm_runtime.h
cmake/modules/Vulkan.cmake
src/codegen/spirv/build_vulkan.cc
src/codegen/spirv/ir_builder.cc
src/runtime/vulkan/README.md [new file with mode: 0644]
src/runtime/vulkan/vulkan.cc [new file with mode: 0644]
src/runtime/vulkan/vulkan_common.h
src/runtime/vulkan/vulkan_device_api.cc [deleted file]
src/runtime/vulkan/vulkan_module.cc [deleted file]
src/runtime/vulkan/vulkan_module.h
src/runtime/vulkan/vulkan_shader.h [new file with mode: 0644]
src/runtime/vulkan/vulkan_stream.h [new file with mode: 0644]
tests/python/unittest/test_codegen_vulkan.py

index aadc4d1..bcbda61 100644
@@ -62,8 +62,7 @@
 #endif
 
 #ifdef TVM_VULKAN_RUNTIME
-#include "../src/runtime/vulkan/vulkan_device_api.cc"
-#include "../src/runtime/vulkan/vulkan_module.cc"
+#include "../src/runtime/vulkan/vulkan.cc"
 #endif
 
 #ifdef USE_SORT
index 346d135..127195a 100644
 # Be compatible with older version of CMake
 find_vulkan(${USE_VULKAN})
 
+# Extra Vulkan runtime options, exposed for advanced users.
+tvm_option(USE_VULKAN_IMMEDIATE_MODE "Use Vulkan immediate mode (VK_KHR_push_descriptor extension)" ON IF USE_VULKAN)
+tvm_option(USE_VULKAN_DEDICATED_ALLOCATION "Use Vulkan dedicated allocations" ON IF USE_VULKAN)
+tvm_option(USE_VULKAN_VALIDATION "Enable Vulkan API validation layers" OFF IF USE_VULKAN)
+
 if(Vulkan_FOUND)
   # always set the includedir
   # avoid global retrigger of cmake
@@ -28,12 +36,24 @@ if(USE_VULKAN)
   if(NOT Vulkan_FOUND)
     message(FATAL_ERROR "Cannot find Vulkan, USE_VULKAN=" ${USE_VULKAN})
   endif()
-  message(STATUS "Build with VULKAN support")
-  file(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/*.cc)
+  message(STATUS "Build with Vulkan support")
+  file(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/vulkan.cc)
   file(GLOB COMPILER_VULKAN_SRCS src/codegen/spirv/*.cc)
   list(APPEND RUNTIME_SRCS ${RUNTIME_VULKAN_SRCS})
   list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS})
-
   list(APPEND TVM_LINKER_LIBS ${Vulkan_SPIRV_TOOLS_LIBRARY})
   list(APPEND TVM_RUNTIME_LINKER_LIBS ${Vulkan_LIBRARY})
+
+  if(USE_VULKAN_IMMEDIATE_MODE)
+    message(STATUS "Build with Vulkan immediate mode")
+    add_definitions(-DUSE_VULKAN_IMMEDIATE_MODE=1)
+  endif()
+  if(USE_VULKAN_DEDICATED_ALLOCATION)
+    message(STATUS "Build with Vulkan dedicated allocation")
+    add_definitions(-DUSE_VULKAN_DEDICATED_ALLOCATION=1)
+  endif()
+  if(USE_VULKAN_VALIDATION)
+    message(STATUS "Build with Vulkan API validation")
+    add_definitions(-DUSE_VULKAN_VALIDATION=1)
+  endif()
 endif(USE_VULKAN)
index 18ffad1..ca0bb65 100644
@@ -29,6 +29,8 @@
 
 #include "codegen_spirv.h"
 #include "../build_common.h"
+
+#include "../../runtime/vulkan/vulkan_shader.h"
 #include "../../runtime/vulkan/vulkan_module.h"
 
 namespace tvm {
index 6afd311..b7bc42c 100644
@@ -33,7 +33,10 @@ namespace spirv {
 void IRBuilder::InitHeader() {
   CHECK_EQ(header_.size(), 0U);
   header_.push_back(spv::MagicNumber);
-  header_.push_back(spv::Version);
+  // Use SPIR-V v1.0: 0x10000 encodes version 1.0 in the SPIR-V version word
+  // (major in bits 16-23, minor in bits 8-15). This needs to be kept in sync
+  // with (or at least not exceed) `VkApplicationInfo.apiVersion` in
+  // `vulkan.cc` to ensure Vulkan API validation passes.
+  header_.push_back(0x10000);
   // generator: set to 0, unknown
   header_.push_back(0U);
   // Bound: set during Finalize
diff --git a/src/runtime/vulkan/README.md b/src/runtime/vulkan/README.md
new file mode 100644
index 0000000..93bdb30
--- /dev/null
@@ -0,0 +1,48 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+
+## Components
+
+### VulkanDeviceAPI
+
+Implements the TVM DeviceAPI interface. Owns the core Vulkan data structures
+and is responsible for initializing the Vulkan instance and devices, and for
+querying the available extensions.
+
+### VulkanThreadEntry
+
+Thread-local state for the Vulkan runtime. Maintains a staging buffer (for
+copies) and a VulkanStream per device.
+
+### VulkanWrappedFunc
+
+Responsible for launching compute kernels. Obtains a VulkanPipeline instance
+(from the VulkanModuleNode) and launches the kernel, via either immediate or
+deferred mode, on the active VulkanStream instance.
+
+## Stream execution in the Vulkan programming model
+
+The natural model for a TVM DeviceAPI implementation and runtime follows the
+CUDA API model. That is, we launch "kernels" onto an (implicit or explicit)
+"stream", which executes them asynchronously with respect to the host but
+ordered with respect to the stream, and we explicitly synchronize the stream
+with respect to the host. We simulate this behaviour in Vulkan by maintaining
+a thread-local `vkCommandBuffer` instance and queueing kernel invocations up
+in it (or eagerly executing them, depending on the availability of the
+`VK_KHR_push_descriptor` extension). When we synchronize the stream, we end
+the command-buffer recording, submit it to the device queue, and wait on the
+corresponding fence, as in the sketch below.
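+
+A minimal sketch of the calling pattern, condensed from the copy and
+synchronization paths in `vulkan.cc` (it assumes a `device_id` in scope and
+omits error handling):
+
+```cpp
+// Obtain the thread-local stream for this device (created lazily).
+VulkanStream* stream = VulkanThreadEntry::ThreadLocal()->Stream(device_id);
+// Launch() records commands into the thread-local command buffer (deferred
+// mode) or executes them eagerly (immediate mode).
+stream->Launch([&](VulkanStreamState* state) {
+  // Record work into state->cmd_buffer_, e.g. vkCmdCopyBuffer/vkCmdDispatch.
+});
+// End recording, submit to the device queue, and wait on the fence.
+stream->Synchronize();
+```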
diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc
new file mode 100644
index 0000000..e3b2ac8
--- /dev/null
@@ -0,0 +1,1150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <vulkan/vulkan.h>
+#include <dmlc/memory_io.h>
+#include <dmlc/thread_local.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include <array>
+#include <cstring>
+
+
+#include "../file_util.h"
+#include "../pack_args.h"
+#include "../thread_storage_scope.h"
+#include "../workspace_pool.h"
+
+#include "vulkan_common.h"
+#include "vulkan_module.h"
+#include "vulkan_shader.h"
+#include "vulkan_stream.h"
+
+namespace tvm {
+namespace runtime {
+namespace vulkan {
+
+/*! \brief Maximum number of GPUs supported by VulkanModule. */
+static constexpr const int kVulkanMaxNumDevice = 8;
+
+/*! \brief TVM Vulkan binary pack magic number */
+static constexpr const int kVulkanModuleMagic = 0x02700027;
+
+class VulkanThreadEntry {
+ public:
+  VulkanThreadEntry();
+  static VulkanThreadEntry* ThreadLocal();
+
+  ~VulkanThreadEntry() {
+    // Because the thread entry refers to the device API, the command buffers
+    // must always be destroyed before the instance and device are. Clear the
+    // streams explicitly here to enforce that destruction order.
+    streams_.clear();
+    for (const auto& kv : staging_buffers_) {
+      if (!kv.second) {
+        continue;
+      }
+      auto& buf = *(kv.second);
+      if (buf.host_addr != nullptr) {
+        vkUnmapMemory(buf.device, buf.memory);
+      }
+      if (buf.memory != VK_NULL_HANDLE) {
+        vkFreeMemory(buf.device, buf.memory, nullptr);
+      }
+      if (buf.buffer != VK_NULL_HANDLE) {
+        vkDestroyBuffer(buf.device, buf.buffer, nullptr);
+      }
+    }
+  }
+
+  TVMContext ctx;
+  WorkspacePool pool;
+  VulkanStream* Stream(size_t device_id);
+  VulkanStagingBuffer* StagingBuffer(int device_id, size_t size);
+
+ private:
+  std::unordered_map<size_t, std::unique_ptr<VulkanStream>> streams_;
+  std::unordered_map<size_t, std::unique_ptr<VulkanStagingBuffer>> staging_buffers_;
+};
+
+struct VulkanBuffer {
+  VkBuffer buffer{VK_NULL_HANDLE};
+  VkDeviceMemory memory{VK_NULL_HANDLE};
+};
+
+struct VulkanPipeline {
+  VulkanContext* vctx_{nullptr};
+  VkShaderModule shader{VK_NULL_HANDLE};
+  VkDescriptorSetLayout descriptor_set_layout{VK_NULL_HANDLE};
+  VkDescriptorPool descriptor_pool{VK_NULL_HANDLE};
+  VkDescriptorSet descriptor_set{VK_NULL_HANDLE};
+  VkPipelineLayout pipeline_layout{VK_NULL_HANDLE};
+  VkPipeline pipeline{VK_NULL_HANDLE};
+  VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE};
+};
+
+typedef dmlc::ThreadLocalStore<VulkanThreadEntry> VulkanThreadStore;
+
+class VulkanDeviceAPI final : public DeviceAPI {
+ public:
+  VulkanDeviceAPI();
+  ~VulkanDeviceAPI() {
+    for (auto& vctx : context_) {
+      vkDestroyDevice(vctx.device, nullptr);
+    }
+    if (instance_) {
+      vkDestroyInstance(instance_, nullptr);
+    }
+  }
+  void SetDevice(TVMContext ctx) final { VulkanThreadEntry::ThreadLocal()->ctx = ctx; }
+  void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
+  void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) final {
+    const auto& vctx = context(ctx.device_id);
+    VkBufferCreateInfo info;
+    info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
+    info.pNext = nullptr;
+    info.flags = 0;
+    info.size = nbytes;
+    info.queueFamilyIndexCount = 1;
+    info.pQueueFamilyIndices = &(vctx.queue_family_index);
+    info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
+    info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
+                 VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
+    // create buffer
+    VkBuffer buffer;
+    VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
+    // query memory requirements (including dedicated-allocation info, if supported)
+    VkBufferMemoryRequirementsInfo2KHR req_info2;
+    req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
+    req_info2.pNext = 0;
+    req_info2.buffer = buffer;
+
+    VkMemoryRequirements2KHR req2;
+    req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
+    req2.pNext = 0;
+
+    VkMemoryDedicatedRequirementsKHR dedicated_req;
+    dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
+    dedicated_req.pNext = 0;
+    req2.pNext = &dedicated_req;
+
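+    // With VK_KHR_get_memory_requirements2 available, ask the driver whether
+    // this buffer requires (or prefers) a dedicated allocation.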
+    bool dedicated_allocation = false;
+    if (vctx.get_buffer_memory_requirements_2_functions) {
+      vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
+          vctx.device, &req_info2, &req2);
+      dedicated_allocation =
+          dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
+    }
+
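+    // Two allocation paths: a plain vkAllocateMemory of `nbytes`, or a
+    // dedicated allocation of the driver-reported size with the buffer
+    // attached via VkMemoryDedicatedAllocateInfoKHR.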
+    VkDeviceMemory memory;
+    if (!dedicated_allocation) {
+      VkMemoryAllocateInfo minfo;
+      minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+      minfo.pNext = nullptr;
+      minfo.allocationSize = nbytes;
+      minfo.memoryTypeIndex = vctx.compute_mtype_index;
+      VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
+    } else {
+      VkMemoryAllocateInfo minfo;
+      minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+      minfo.pNext = nullptr;
+      minfo.allocationSize = req2.memoryRequirements.size;
+      minfo.memoryTypeIndex = vctx.compute_mtype_index;
+
+      VkMemoryDedicatedAllocateInfoKHR mdinfo;
+      mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
+      mdinfo.pNext = 0;
+      mdinfo.image = 0;
+      mdinfo.buffer = buffer;
+      minfo.pNext = &mdinfo;
+      VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
+    }
+    VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
+    VulkanBuffer* pbuf = new VulkanBuffer();
+    pbuf->memory = memory;
+    pbuf->buffer = buffer;
+    return pbuf;
+  }
+
+  void FreeDataSpace(TVMContext ctx, void* ptr) final {
+    const auto& vctx = context(ctx.device_id);
+    auto* pbuf = static_cast<VulkanBuffer*>(ptr);
+    vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr);
+    vkFreeMemory(vctx.device, pbuf->memory, nullptr);
+    delete pbuf;
+  }
+
+  void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
+                      TVMContext ctx_from, TVMContext ctx_to, TVMType type_hint,
+                      TVMStreamHandle stream) final {
+    CHECK(stream == nullptr);
+    // Pick the Vulkan-side context: for copies involving the CPU, ctx_from
+    // may be the CPU context, so fall back to ctx_to in that case.
+    TVMContext ctx = ctx_from;
+    if (ctx_from.device_type == kDLCPU) {
+      ctx = ctx_to;
+    }
+
+    int from_dev_type = static_cast<int>(ctx_from.device_type);
+    int to_dev_type = static_cast<int>(ctx_to.device_type);
+    if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) {
+      VulkanThreadEntry::ThreadLocal()
+          ->Stream(ctx_from.device_id)
+          ->Launch([=](VulkanStreamState* state) {
+            // 1: copy
+            const auto* from_buf = static_cast<const VulkanBuffer*>(from);
+            auto* to_buf = static_cast<VulkanBuffer*>(to);
+            VkBufferCopy copy_info;
+            copy_info.srcOffset = from_offset;
+            copy_info.dstOffset = to_offset;
+            copy_info.size = size;
+            vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, to_buf->buffer, 1, &copy_info);
+            // 2: barrier(transfer-> compute|transfer)
+            CHECK_EQ(ctx_from.device_id, ctx_to.device_id) << "Vulkan disallows cross-device copy.";
+            VkMemoryBarrier barrier_info;
+            barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
+            barrier_info.pNext = nullptr;
+            barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
+            barrier_info.dstAccessMask =
+                (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
+                 VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
+            vkCmdPipelineBarrier(
+                state->cmd_buffer_, VK_PIPELINE_STAGE_TRANSFER_BIT,
+                VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, 1,
+                &barrier_info, 0, nullptr, 0, nullptr);
+          });
+
+    } else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) {
+      const auto* from_buf = static_cast<const VulkanBuffer*>(from);
+      const auto& vctx = context(ctx_from.device_id);
+      auto* temp = VulkanThreadEntry::ThreadLocal()->StagingBuffer(ctx_from.device_id, size);
+      VulkanThreadEntry::ThreadLocal()
+          ->Stream(ctx_from.device_id)
+          ->Launch([&](VulkanStreamState* state) {
+            VkBufferCopy copy_info;
+            copy_info.srcOffset = from_offset;
+            copy_info.dstOffset = 0;
+            copy_info.size = size;
+            vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, temp->buffer, 1, &copy_info);
+          });
+      VulkanThreadEntry::ThreadLocal()->Stream(ctx_from.device_id)->Synchronize();
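+      // Invalidate the mapped range if it is not host-coherent, so that the
+      // GPU's writes to the staging buffer become visible to the host.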
+      if (!vctx.coherent_staging) {
+        VkMappedMemoryRange mrange;
+        mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
+        mrange.pNext = nullptr;
+        mrange.memory = temp->memory;
+        mrange.offset = 0;
+        mrange.size = VK_WHOLE_SIZE;
+        VULKAN_CALL(vkInvalidateMappedMemoryRanges(vctx.device, 1, &mrange));
+      }
+      memcpy(static_cast<char*>(to) + to_offset, static_cast<char*>(temp->host_addr), size);
+    } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) {
+      const auto& vctx = context(ctx_to.device_id);
+      const auto* to_buf = static_cast<const VulkanBuffer*>(to);
+      VulkanStagingBuffer* temp =
+          VulkanThreadEntry::ThreadLocal()->StagingBuffer(ctx_to.device_id, size);
+      memcpy(temp->host_addr, static_cast<const char*>(from) + from_offset, size);
+      // host-side flush if the mapping is not coherent,
+      // so that writes from the CPU become visible to the GPU
+      if (!vctx.coherent_staging) {
+        VkMappedMemoryRange mrange;
+        mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
+        mrange.pNext = nullptr;
+        mrange.memory = temp->memory;
+        mrange.offset = 0;
+        mrange.size = VK_WHOLE_SIZE;
+        VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange));
+      }
+
+      VulkanThreadEntry::ThreadLocal()
+          ->Stream(ctx.device_id)
+          ->Launch([&](VulkanStreamState* state) {
+            // 0: barrier(host->transfer)
+            VkMemoryBarrier barrier_info;
+            barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
+            barrier_info.pNext = nullptr;
+            barrier_info.srcAccessMask = 0;
+            barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
+            vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_HOST_BIT,
+                                 VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &barrier_info, 0, nullptr, 0,
+                                 nullptr);
+            // 1: copy
+            VkBufferCopy copy_info;
+            copy_info.srcOffset = 0;
+            copy_info.dstOffset = to_offset;
+            copy_info.size = size;
+            vkCmdCopyBuffer(state->cmd_buffer_, temp->buffer, to_buf->buffer, 1, &copy_info);
+          });
+      // TODO(tulloch): should we instead make the staging buffer a property of the
+      // Stream? This would allow us to elide synchronizations here.
+      VulkanThreadEntry::ThreadLocal()->Stream(ctx.device_id)->Synchronize();
+    } else {
+      LOG(FATAL) << "Expect copy from/to Vulkan or between Vulkan"
+                 << ", from=" << from_dev_type << ", to=" << to_dev_type;
+    }
+  }
+
+  // Always use the default stream
+  TVMStreamHandle CreateStream(TVMContext ctx) {
+    LOG(FATAL) << "Not implemented";
+    return nullptr;
+  }
+
+  void FreeStream(TVMContext ctx, TVMStreamHandle stream) {
+    LOG(FATAL) << "Not implemented";
+    return;
+  }
+
+  void SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst) {
+    LOG(FATAL) << "Not implemented";
+    return;
+  }
+
+  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
+    CHECK(stream == nullptr);
+    VulkanThreadEntry::ThreadLocal()->Stream(ctx.device_id)->Synchronize();
+  }
+
+  void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
+    LOG(FATAL) << "Not implemented";
+    return;
+  }
+
+  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
+    return VulkanThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
+  }
+
+  void FreeWorkspace(TVMContext ctx, void* data) final {
+    VulkanThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
+  }
+
+  static const std::shared_ptr<VulkanDeviceAPI>& Global() {
+    static std::shared_ptr<VulkanDeviceAPI> inst = std::make_shared<VulkanDeviceAPI>();
+    return inst;
+  }
+
+  const VulkanContext& context(size_t device_id) const {
+    CHECK_LT(device_id, context_.size());
+    return context_[device_id];
+  }
+
+ private:
+  VkInstance instance_{nullptr};
+  // The per-device contexts; one-to-one mapping to physical devices.
+  std::vector<VulkanContext> context_;
+};
+
+void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
+  size_t index = static_cast<size_t>(ctx.device_id);
+  if (kind == kExist) {
+    *rv = static_cast<int>(index < context_.size());
+    return;
+  }
+  CHECK_LT(index, context_.size()) << "Invalid device id " << index;
+  const auto& vctx = context(index);
+  switch (kind) {
+    case kMaxThreadsPerBlock: {
+      VkPhysicalDeviceProperties phy_prop;
+      vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
+      int64_t value = phy_prop.limits.maxComputeWorkGroupSize[0];
+      *rv = value;
+      break;
+    }
+    case kMaxSharedMemoryPerBlock: {
+      VkPhysicalDeviceProperties phy_prop;
+      vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
+      int64_t value = phy_prop.limits.maxComputeSharedMemorySize;
+      *rv = value;
+      break;
+    }
+    case kWarpSize: {
+      *rv = 1;
+      break;
+    }
+    case kComputeVersion: {
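+      // Report the device's Vulkan API version (packed by VK_MAKE_VERSION)
+      // as a "major.minor.patch" string.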
+      VkPhysicalDeviceProperties phy_prop;
+      vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
+      int64_t value = phy_prop.apiVersion;
+      std::ostringstream os;
+      os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "."
+         << VK_VERSION_PATCH(value);
+      *rv = os.str();
+      break;
+    }
+    case kDeviceName:
+      return;
+    case kMaxClockRate:
+      return;
+    case kMultiProcessorCount:
+      return;
+    case kExist:
+      break;
+    case kMaxThreadDimensions:
+      break;
+  }
+}
+
+VulkanDeviceAPI::VulkanDeviceAPI() {
+  VkApplicationInfo app_info;
+  app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
+  app_info.pNext = nullptr;
+  app_info.pApplicationName = "TVM";
+  app_info.applicationVersion = 0;
+  app_info.pEngineName = "";
+  app_info.engineVersion = 0;
+  app_info.apiVersion = VK_MAKE_VERSION(1, 0, 0);
+
+  VkInstanceCreateInfo inst_info;
+  inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
+  inst_info.pNext = nullptr;
+  inst_info.flags = 0;
+
+  const auto layers = []() -> std::vector<const char*> {
+    uint32_t inst_layer_prop_count;
+    VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, nullptr));
+    std::vector<VkLayerProperties> inst_layer_prop(inst_layer_prop_count);
+    VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, inst_layer_prop.data()));
+    std::vector<const char*> l;
+    for (const auto& lp : inst_layer_prop) {
+      // TODO(tulloch): add CMAKE options.
+      (void)lp;  // suppress unused variable warning.
+#ifdef USE_VULKAN_VALIDATION
+      if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_standard_validation") == 0) {
+        l.push_back("VK_LAYER_LUNARG_standard_validation");
+      }
+      if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_parameter_validation") == 0) {
+        l.push_back("VK_LAYER_LUNARG_parameter_validation");
+      }
+      if (std::strcmp(lp.layerName, "VK_LAYER_KHRONOS_validation") == 0) {
+        l.push_back("VK_LAYER_KHRONOS_validation");
+      }
+#endif
+    }
+    return l;
+  }();
+
+  const auto instance_extensions = []() -> std::vector<const char*> {
+    uint32_t inst_extension_prop_count;
+    VULKAN_CALL(
+        vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, nullptr));
+    std::vector<VkExtensionProperties> inst_extension_prop(inst_extension_prop_count);
+    VULKAN_CALL(vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count,
+                                                       inst_extension_prop.data()));
+    std::vector<const char*> extensions;
+    for (const auto& ip : inst_extension_prop) {
+      if (std::strcmp(ip.extensionName, "VK_KHR_get_physical_device_properties2") == 0) {
+        extensions.push_back("VK_KHR_get_physical_device_properties2");
+      }
+    }
+    return extensions;
+  }();
+
+  inst_info.pApplicationInfo = &app_info;
+  inst_info.enabledLayerCount = layers.size();
+  inst_info.ppEnabledLayerNames = layers.data();
+  inst_info.enabledExtensionCount = instance_extensions.size();
+  inst_info.ppEnabledExtensionNames = instance_extensions.data();
+
+  VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &instance_));
+
+  uint32_t phy_dev_count = 0;
+  VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, nullptr));
+  std::vector<VkPhysicalDevice> all_phy_devs(phy_dev_count);
+  VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, dmlc::BeginPtr(all_phy_devs)));
+  for (VkPhysicalDevice phy_dev : all_phy_devs) {
+    uint32_t queue_prop_count = 0;
+    vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, nullptr);
+    std::vector<VkQueueFamilyProperties> queue_props(queue_prop_count);
+    vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count,
+                                             dmlc::BeginPtr(queue_props));
+    uint32_t queue_family_index = 0;
+    std::vector<VkDeviceQueueCreateInfo> queue_create_info;
+    float priority = 1.0f;
+    for (uint32_t i = 0; i < queue_props.size(); i++) {
+      // find queues that support compute
+      if (VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) {
+        VkDeviceQueueCreateInfo info;
+        info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
+        info.pNext = nullptr;
+        info.flags = 0;
+        info.queueFamilyIndex = i;
+        info.queueCount = 1;
+        info.pQueuePriorities = &priority;
+
+        queue_create_info.push_back(info);
+        // only use the first available compute queue family for now
+        if (queue_create_info.size() == 1) {
+          queue_family_index = i;
+        }
+      }
+    }
+    if (queue_create_info.size() == 0) continue;
+
+    VulkanContext ctx;
+    // setup context
+    ctx.phy_device = phy_dev;
+    vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop));
+
+    const auto extensions = [&]() {
+      uint32_t device_extension_prop_count;
+      VULKAN_CALL(vkEnumerateDeviceExtensionProperties(ctx.phy_device, nullptr,
+                                                       &device_extension_prop_count, nullptr));
+      std::vector<VkExtensionProperties> device_extension_prop(device_extension_prop_count);
+      VULKAN_CALL(vkEnumerateDeviceExtensionProperties(
+          ctx.phy_device, nullptr, &device_extension_prop_count, device_extension_prop.data()));
+      std::vector<const char*> extensions;
+      for (const auto& dp : device_extension_prop) {
+        if ((std::strcmp(dp.extensionName, "VK_KHR_push_descriptor") == 0) && dp.specVersion > 0) {
+          extensions.push_back("VK_KHR_push_descriptor");
+        }
+        if ((std::strcmp(dp.extensionName, "VK_KHR_descriptor_update_template") == 0) &&
+            dp.specVersion > 0) {
+          extensions.push_back("VK_KHR_descriptor_update_template");
+        }
+        if ((std::strcmp(dp.extensionName, "VK_KHR_get_memory_requirements2") == 0) &&
+            dp.specVersion > 0) {
+          extensions.push_back("VK_KHR_get_memory_requirements2");
+        }
+        if ((std::strcmp(dp.extensionName, "VK_KHR_dedicated_allocation") == 0) &&
+            dp.specVersion > 0) {
+          extensions.push_back("VK_KHR_dedicated_allocation");
+        }
+      }
+      return extensions;
+    }();
+    VkDeviceCreateInfo device_create_info;
+    device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
+    device_create_info.pNext = nullptr;
+    device_create_info.flags = 0;
+    device_create_info.queueCreateInfoCount = static_cast<uint32_t>(queue_create_info.size());
+    device_create_info.pQueueCreateInfos = queue_create_info.data();
+    device_create_info.enabledLayerCount = 0;
+    device_create_info.ppEnabledLayerNames = nullptr;
+    device_create_info.enabledExtensionCount = extensions.size();
+    device_create_info.ppEnabledExtensionNames = extensions.data();
+    device_create_info.pEnabledFeatures = nullptr;
+    VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(ctx.device)));
+    ctx.queue_mutex.reset(new std::mutex());
+    vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue));
+    ctx.queue_family_index = queue_family_index;
+    // Find suitable memory type for staging and compute
+    // Find suitable compute index.
+    VkBuffer buffer;
+    VkMemoryRequirements req_staging, req_compute;
+    VkBufferCreateInfo info;
+    info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
+    info.pNext = nullptr;
+    info.flags = 0;
+    info.size = 1024;
+    info.queueFamilyIndexCount = 1;
+    info.pQueueFamilyIndices = &(ctx.queue_family_index);
+    info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
+
+    // get staging requirement
+    info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
+    VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer));
+    vkGetBufferMemoryRequirements(ctx.device, buffer, &req_staging);
+    vkDestroyBuffer(ctx.device, buffer, nullptr);
+    // get compute requirement
+    info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
+                 VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
+    VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer));
+    vkGetBufferMemoryRequirements(ctx.device, buffer, &req_compute);
+    vkDestroyBuffer(ctx.device, buffer, nullptr);
+
+    // Query physical device memory properties and
+    // find a memory type that is host visible (it need not be host coherent).
+    int win_rank = -1;
+    VkPhysicalDeviceMemoryProperties prop;
+    vkGetPhysicalDeviceMemoryProperties(ctx.phy_device, &prop);
+
+    for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
+      VkMemoryType ty = prop.memoryTypes[k];
+      size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
+      // host visible
+      if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
+      // match the staging copy requirement
+      if (!(req_staging.memoryTypeBits & (1 << k))) continue;
+      if (heap_size < 1024) continue;
+      int rank = 0;
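+      // Prefer host-cached memory: readback through a cached mapping is
+      // typically faster.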
+      rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
+      if (rank > win_rank) {
+        win_rank = rank;
+        ctx.staging_mtype_index = k;
+        ctx.coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
+      }
+    }
+    CHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";
+
+    win_rank = -1;
+    for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
+      VkMemoryType ty = prop.memoryTypes[k];
+      size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
+      // device local
+      if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
+      // match the compute buffer requirement
+      if (!(req_compute.memoryTypeBits & (1 << k))) continue;
+      if (heap_size < 1024) continue;
+      int rank = 0;
+      // prefer not host visible
+      rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
+      if (rank > win_rank) {
+        win_rank = rank;
+        ctx.compute_mtype_index = k;
+      }
+    }
+    CHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device.";
+    auto has_extension = [&extensions](const char* query) {
+      return std::any_of(extensions.begin(), extensions.end(),
+                         [&](const char* extension) { return std::strcmp(query, extension) == 0; });
+    };
+
+#ifdef USE_VULKAN_IMMEDIATE_MODE
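+    // These are extension entry points, so they are resolved at runtime via
+    // vkGetDeviceProcAddr rather than linked directly.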
+    if (has_extension("VK_KHR_push_descriptor") &&
+        has_extension("VK_KHR_descriptor_update_template")) {
+      ctx.descriptor_template_khr_functions =
+          std::unique_ptr<VulkanDescriptorTemplateKHRFunctions>(
+              new VulkanDescriptorTemplateKHRFunctions());
+      ctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR =
+          CHECK_NOTNULL((PFN_vkCreateDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr(
+              ctx.device, "vkCreateDescriptorUpdateTemplateKHR"));
+      ctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR =
+          CHECK_NOTNULL((PFN_vkDestroyDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr(
+              ctx.device, "vkDestroyDescriptorUpdateTemplateKHR"));
+      ctx.descriptor_template_khr_functions->vkUpdateDescriptorSetWithTemplateKHR =
+          CHECK_NOTNULL((PFN_vkUpdateDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr(
+              ctx.device, "vkUpdateDescriptorSetWithTemplateKHR"));
+      ctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR =
+          CHECK_NOTNULL((PFN_vkCmdPushDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr(
+              ctx.device, "vkCmdPushDescriptorSetWithTemplateKHR"));
+    }
+#endif
+
+#ifdef USE_VULKAN_DEDICATED_ALLOCATION
+    if (has_extension("VK_KHR_get_memory_requirements2") &&
+        has_extension("VK_KHR_dedicated_allocation")) {
+      ctx.get_buffer_memory_requirements_2_functions =
+          std::unique_ptr<VulkanGetBufferMemoryRequirements2Functions>(
+              new VulkanGetBufferMemoryRequirements2Functions());
+      ctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR =
+          CHECK_NOTNULL((PFN_vkGetBufferMemoryRequirements2KHR)vkGetDeviceProcAddr(
+              ctx.device, "vkGetBufferMemoryRequirements2KHR"));
+    }
+#endif
+    context_.push_back(std::move(ctx));
+  }
+
+  LOG(INFO) << "Initialize Vulkan with " << context_.size() << " devices..";
+  for (size_t i = 0; i < context_.size(); ++i) {
+    LOG(INFO) << "vulkan(" << i << ")=\'" << context_[i].phy_device_prop.deviceName
+              << "\' phy_dev_id=" << context_[i].phy_device
+              << " use_immediate=" << context_[i].UseImmediate();
+  }
+}
+class VulkanModuleNode;
+
+// a wrapped function class to get packed func.
+class VulkanWrappedFunc {
+ public:
+  void Init(VulkanModuleNode* m, std::shared_ptr<ModuleNode> sptr, const std::string& func_name,
+            size_t num_buffer_args, size_t num_pack_args,
+            const std::vector<std::string>& thread_axis_tags) {
+    m_ = m;
+    sptr_ = sptr;
+    func_name_ = func_name;
+    num_buffer_args_ = num_buffer_args;
+    num_pack_args_ = num_pack_args;
+    thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
+  }
+
+  void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const;
+
+ private:
+  // internal module
+  VulkanModuleNode* m_;
+  // the resource holder
+  std::shared_ptr<ModuleNode> sptr_;
+  // The name of the function.
+  std::string func_name_;
+  // Number of buffer arguments
+  size_t num_buffer_args_;
+  // number of packed arguments.
+  size_t num_pack_args_;
+  // thread axis configuration
+  ThreadAxisConfig thread_axis_cfg_;
+
+  // Per-device pipeline cache; marked mutable to enable lazy initialization.
+  mutable std::array<std::shared_ptr<VulkanPipeline>, kVulkanMaxNumDevice> scache_;
+};
+
+// Multi-device enabled module.
+class VulkanModuleNode final : public runtime::ModuleNode {
+ public:
+  explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap,
+                             std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
+      : smap_(smap), fmap_(fmap), source_(source) {}
+
+  const char* type_key() const final { return "vulkan"; }
+
+  PackedFunc GetFunction(const std::string& name,
+                         const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+    CHECK_EQ(sptr_to_self.get(), this);
+    CHECK_NE(name, symbol::tvm_module_main) << "Device functions do not have a main";
+    auto it = fmap_.find(name);
+    if (it == fmap_.end()) return PackedFunc();
+    const FunctionInfo& info = it->second;
+    VulkanWrappedFunc f;
+    size_t num_buffer_args = NumBufferArgs(info.arg_types);
+    f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args,
+           info.thread_axis_tags);
+    return PackFuncNonBufferArg(std::move(f), info.arg_types);
+  }
+
+  ~VulkanModuleNode() {
+    // cleanup vulkan related caches.
+    for (size_t device_id = 0; device_id < ecache_.size(); ++device_id) {
+      for (auto& kv : ecache_[device_id]) {
+        auto& pe = kv.second;
+        CHECK(pe);
+        const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
+
+        if (pe->descriptor_update_template != VK_NULL_HANDLE) {
+          vctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR(
+              vctx.device, pe->descriptor_update_template, nullptr);
+        }
+        vkDestroyPipeline(vctx.device, pe->pipeline, nullptr);
+        vkDestroyPipelineLayout(vctx.device, pe->pipeline_layout, nullptr);
+        vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr);
+        vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr);
+        vkDestroyShaderModule(vctx.device, pe->shader, nullptr);
+      }
+    }
+  }
+
+  std::shared_ptr<VulkanPipeline> GetPipeline(size_t device_id, const std::string& func_name,
+                                               size_t num_pack_args) {
+    const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
+    std::lock_guard<std::mutex> lock(mutex_);
+    const auto& cp = ecache_[device_id][func_name];
+    if (cp) {
+      return cp;
+    }
+    // Create new pipeline
+    auto pe = std::shared_ptr<VulkanPipeline>(new VulkanPipeline());
+    {
+      // create shader
+      auto sit = smap_.find(func_name);
+      CHECK(sit != smap_.end());
+      const std::vector<uint32_t>& data = sit->second.data;
+      VkShaderModuleCreateInfo shader_cinfo;
+      shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
+      shader_cinfo.pNext = nullptr;
+      shader_cinfo.flags = 0;
+      shader_cinfo.codeSize = data.size() * sizeof(uint32_t);
+      shader_cinfo.pCode = data.data();
+      VULKAN_CALL(vkCreateShaderModule(vctx.device, &shader_cinfo, nullptr, &(pe->shader)));
+    }
+    std::vector<VkDescriptorSetLayoutBinding> arg_binding;
+    std::vector<VkDescriptorUpdateTemplateEntryKHR> arg_template;
+    uint32_t num_pod = 0, num_buffer = 0;
+    {
+      auto fit = fmap_.find(func_name);
+      CHECK(fit != fmap_.end());
+      for (TVMType arg_type : fit->second.arg_types) {
+        if (arg_type.code == kHandle) {
+          {
+            VkDescriptorSetLayoutBinding bd;
+            bd.binding = num_buffer;
+            bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+            bd.descriptorCount = 1;
+            bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
+            bd.pImmutableSamplers = nullptr;
+            arg_binding.push_back(bd);
+          }
+          {
+            VkDescriptorUpdateTemplateEntryKHR tpl;
+            tpl.dstBinding = num_buffer;
+            tpl.dstArrayElement = 0;
+            tpl.descriptorCount = 1;
+            tpl.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+            tpl.offset = num_buffer * sizeof(VkDescriptorBufferInfo);
+            tpl.stride = sizeof(VkDescriptorBufferInfo);
+            arg_template.push_back(tpl);
+          }
+          ++num_buffer;
+        } else {
+          ++num_pod;
+        }
+      }
+    }
+
+    {
+      VkDescriptorSetLayoutCreateInfo descrip_cinfo;
+      descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
+      descrip_cinfo.pNext = nullptr;
+      descrip_cinfo.flags = 0;
+      if (vctx.UseImmediate()) {
+        descrip_cinfo.flags |= VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR;
+      }
+      descrip_cinfo.bindingCount = arg_binding.size();
+      descrip_cinfo.pBindings = arg_binding.data();
+      VULKAN_CALL(vkCreateDescriptorSetLayout(vctx.device, &descrip_cinfo, nullptr,
+                                              &(pe->descriptor_set_layout)));
+    }
+
+    {
+      VkDescriptorPoolSize pool_size;
+      pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+      pool_size.descriptorCount = arg_binding.size();
+      VkDescriptorPoolCreateInfo descrip_pool_cinfo;
+      descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
+      descrip_pool_cinfo.pNext = nullptr;
+      descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
+      descrip_pool_cinfo.maxSets = 1;
+      descrip_pool_cinfo.poolSizeCount = 1;
+      descrip_pool_cinfo.pPoolSizes = &pool_size;
+      VULKAN_CALL(vkCreateDescriptorPool(vctx.device, &descrip_pool_cinfo, nullptr,
+                                         &(pe->descriptor_pool)));
+    }
+
+    if (!vctx.UseImmediate()) {
+      VkDescriptorSetAllocateInfo alloc_info;
+      alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
+      alloc_info.pNext = nullptr;
+      alloc_info.descriptorPool = pe->descriptor_pool;
+      alloc_info.descriptorSetCount = 1;
+      alloc_info.pSetLayouts = &(pe->descriptor_set_layout);
+      VULKAN_CALL(vkAllocateDescriptorSets(vctx.device, &alloc_info, &(pe->descriptor_set)));
+    }
+
+    VkPushConstantRange crange;
+    crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
+    crange.offset = 0;
+    crange.size = sizeof(ArgUnion) * num_pack_args;
+
+    VkPipelineLayoutCreateInfo playout_cinfo;
+    playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
+    playout_cinfo.pNext = nullptr;
+    playout_cinfo.flags = 0;
+    playout_cinfo.setLayoutCount = 1;
+    playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout);
+
+    if (num_pack_args != 0) {
+      playout_cinfo.pushConstantRangeCount = 1;
+      playout_cinfo.pPushConstantRanges = &crange;
+      CHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize);
+    } else {
+      playout_cinfo.pushConstantRangeCount = 0;
+      playout_cinfo.pPushConstantRanges = nullptr;
+    }
+
+    VULKAN_CALL(
+        vkCreatePipelineLayout(vctx.device, &playout_cinfo, nullptr, &(pe->pipeline_layout)));
+
+    VkComputePipelineCreateInfo pipeline_cinfo;
+    pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
+    pipeline_cinfo.pNext = nullptr;
+    pipeline_cinfo.flags = 0;
+    pipeline_cinfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
+    pipeline_cinfo.stage.pNext = nullptr;
+    pipeline_cinfo.stage.flags = 0;
+    pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
+    pipeline_cinfo.stage.module = pe->shader;
+    pipeline_cinfo.stage.pName = func_name.c_str();
+    pipeline_cinfo.stage.pSpecializationInfo = nullptr;
+    pipeline_cinfo.layout = pe->pipeline_layout;
+    pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE;
+    pipeline_cinfo.basePipelineIndex = 0;
+    VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr,
+                                         &(pe->pipeline)));
+
+    if (vctx.UseImmediate()) {
+      VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo;
+      descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR;
+      descrip_template_cinfo.pNext = 0;
+      descrip_template_cinfo.flags = 0;
+      descrip_template_cinfo.descriptorUpdateEntryCount = arg_template.size();
+      descrip_template_cinfo.pDescriptorUpdateEntries = arg_template.data();
+      descrip_template_cinfo.templateType = VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR;
+      descrip_template_cinfo.descriptorSetLayout = pe->descriptor_set_layout;
+      descrip_template_cinfo.pipelineBindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;
+      descrip_template_cinfo.pipelineLayout = pe->pipeline_layout;
+      descrip_template_cinfo.set = 0;
+      VULKAN_CALL(vctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR(
+          vctx.device, &descrip_template_cinfo, 0, &(pe->descriptor_update_template)));
+    }
+    ecache_[device_id][func_name] = pe;
+    return pe;
+  }
+
+  void SaveToFile(const std::string& file_name, const std::string& format) final {
+    std::string fmt = GetFileFormat(file_name, format);
+    CHECK_EQ(fmt, fmt_) << "Can only save to the customized \"vulkan\" format";
+    std::string meta_file = GetMetaFilePath(file_name);
+    SaveMetaDataToFile(meta_file, fmap_);
+    std::string data_bin;
+    dmlc::MemoryStringStream fs(&data_bin);
+    dmlc::Stream* stream = &fs;
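+    // On-disk layout: a uint32 magic word followed by the serialized shader
+    // map; function metadata is saved separately via SaveMetaDataToFile above.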
+    uint32_t magic = kVulkanModuleMagic;
+    stream->Write(magic);
+    stream->Write(smap_);
+    SaveBinaryToFile(file_name, data_bin);
+  }
+
+  void SaveToBinary(dmlc::Stream* stream) final {
+    stream->Write(fmt_);
+    stream->Write(fmap_);
+    stream->Write(smap_);
+  }
+  std::string GetSource(const std::string& format) final {
+    // can only return source code.
+    return source_;
+  }
+
+ private:
+  // the binary data
+  std::vector<uint32_t> data_;
+  // shader table: function name -> SPIR-V binary.
+  std::unordered_map<std::string, VulkanShader> smap_;
+  // function information table.
+  std::unordered_map<std::string, FunctionInfo> fmap_;
+  // The format
+  std::string fmt_{"vulkan"};
+  // The source
+  std::string source_;
+
+  // Guards accesses to `ecache_`
+  std::mutex mutex_;
+  std::array<std::unordered_map<std::string, std::shared_ptr<VulkanPipeline>>, kVulkanMaxNumDevice>
+      ecache_;
+};
+
+Module VulkanModuleCreate(std::unordered_map<std::string, VulkanShader> smap,
+                          std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
+  std::shared_ptr<VulkanModuleNode> n = std::make_shared<VulkanModuleNode>(smap, fmap, source);
+  return Module(n);
+}
+
+VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() { return VulkanThreadStore::Get(); }
+
+VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size) {
+  if (!staging_buffers_[device_id]) {
+    staging_buffers_[device_id] = std::unique_ptr<VulkanStagingBuffer>(new VulkanStagingBuffer());
+  }
+  auto& buf = *(staging_buffers_[device_id]);
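+  // Grow-only policy: if the cached staging buffer is too small, release it
+  // here and reallocate it at the requested size below.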
+  if (buf.device != nullptr && buf.size < size) {
+    // free previous buffer
+    if (buf.host_addr != nullptr) {
+      vkUnmapMemory(buf.device, buf.memory);
+    }
+    if (buf.memory != VK_NULL_HANDLE) {
+      vkFreeMemory(buf.device, buf.memory, nullptr);
+    }
+    if (buf.buffer != VK_NULL_HANDLE) {
+      vkDestroyBuffer(buf.device, buf.buffer, nullptr);
+    }
+    buf.host_addr = nullptr;
+    buf.memory = VK_NULL_HANDLE;
+    buf.buffer = VK_NULL_HANDLE;
+  }
+  const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
+
+  if (buf.device == nullptr) {
+    buf.device = vctx.device;
+  }
+  if (buf.memory == VK_NULL_HANDLE) {
+    // allocate the staging buffer memory if necessary
+    VkBufferCreateInfo info;
+    info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
+    info.pNext = nullptr;
+    info.flags = 0;
+    info.size = size;
+    info.queueFamilyIndexCount = 1;
+    info.pQueueFamilyIndices = &(vctx.queue_family_index);
+    info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
+    info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
+    VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &(buf.buffer)));
+    VkMemoryAllocateInfo minfo;
+    minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+    minfo.pNext = nullptr;
+    minfo.allocationSize = size;
+    minfo.memoryTypeIndex = vctx.staging_mtype_index;
+    VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &(buf.memory)));
+    VULKAN_CALL(vkBindBufferMemory(vctx.device, (buf.buffer), buf.memory, 0));
+    VULKAN_CALL(vkMapMemory(vctx.device, buf.memory, 0, size, 0, &(buf.host_addr)));
+    buf.size = size;
+  }
+  memset(buf.host_addr, 0, size);
+  return &buf;
+}
+
+VulkanThreadEntry::VulkanThreadEntry()
+    : pool(static_cast<DLDeviceType>(kDLVulkan), VulkanDeviceAPI::Global()) {
+  ctx.device_id = 0;
+  ctx.device_type = static_cast<DLDeviceType>(kDLVulkan);
+}
+
+VulkanStream* VulkanThreadEntry::Stream(size_t device_id) {
+  if (!streams_[device_id]) {
+    streams_[device_id] = std::unique_ptr<VulkanStream>(
+        new VulkanStream(&VulkanDeviceAPI::Global()->context(device_id)));
+  }
+  return streams_[device_id].get();
+}
+
+void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
+                                    const ArgUnion* pack_args) const {
+  int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id;
+  CHECK_LT(device_id, kVulkanMaxNumDevice);
+  const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
+  if (!scache_[device_id]) {
+    scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_);
+  }
+  const auto& pipeline = scache_[device_id];
+  ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+  std::vector<VkDescriptorBufferInfo> descriptor_buffers;
+  descriptor_buffers.resize(num_buffer_args_);
+  for (size_t i = 0; i < num_buffer_args_; ++i) {
+    void* buf = args[static_cast<int>(i)];
+    VkDescriptorBufferInfo binfo;
+    binfo.buffer = static_cast<VulkanBuffer*>(buf)->buffer;
+    binfo.offset = 0;
+    binfo.range = VK_WHOLE_SIZE;
+    descriptor_buffers[i] = binfo;
+  }
+  if (vctx.UseImmediate()) {
+    // Can safely capture by reference as this lambda is immediately executed on the calling thread.
+    VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) {
+      vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
+      CHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE);
+      vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR(
+          state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0,
+          descriptor_buffers.data());
+      if (num_pack_args_ != 0) {
+        vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
+                           VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion),
+                           pack_args);
+      }
+      vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
+      VkMemoryBarrier barrier_info;
+      barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
+      barrier_info.pNext = nullptr;
+      barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
+      barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
+                                    VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
+      vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
+                           VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
+                           1, &barrier_info, 0, nullptr, 0, nullptr);
+    });
+    return;
+  }
+
+  // Otherwise, the more expensive deferred path.
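+  // Copy the packed scalar arguments: the deferred closures below may outlive
+  // the caller's pack_args pointer.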
+  std::vector<ArgUnion> pack_args_storage(pack_args, pack_args + num_pack_args_);
+  const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() {
+    std::vector<VkWriteDescriptorSet> write_descriptor_sets;
+    write_descriptor_sets.resize(descriptor_buffers.size());
+    for (size_t i = 0; i < write_descriptor_sets.size(); ++i) {
+      write_descriptor_sets[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
+      write_descriptor_sets[i].pNext = 0;
+      write_descriptor_sets[i].dstSet = pipeline->descriptor_set;
+      write_descriptor_sets[i].dstBinding = i;
+      write_descriptor_sets[i].dstArrayElement = 0;
+      write_descriptor_sets[i].descriptorCount = 1;
+      write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+      write_descriptor_sets[i].pImageInfo = 0;
+      write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]);
+      write_descriptor_sets[i].pTexelBufferView = 0;
+    }
+    vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(),
+                           0, 0);
+  };
+  const auto& deferred_kernel = [pipeline, wl, pack_args_storage](VulkanStreamState* state) {
+    vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
+    vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
+                            pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0,
+                            nullptr);
+    if (pack_args_storage.size() != 0) {
+      vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
+                         0, pack_args_storage.size() * sizeof(ArgUnion), pack_args_storage.data());
+    }
+    vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
+    VkMemoryBarrier barrier_info;
+    barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
+    barrier_info.pNext = nullptr;
+    barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
+    barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
+                                  VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
+    vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
+                         VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
+                         1, &barrier_info, 0, nullptr, 0, nullptr);
+  };
+  VulkanStreamToken deferred_token;
+  deferred_token.descriptor_set_ = pipeline->descriptor_set;
+  deferred_token.buffers_.resize(descriptor_buffers.size());
+  for (size_t i = 0; i < descriptor_buffers.size(); ++i) {
+    deferred_token.buffers_[i] = descriptor_buffers[i].buffer;
+  }
+  VulkanThreadEntry::ThreadLocal()->Stream(device_id)->LaunchDeferred(
+      deferred_initializer, deferred_kernel, deferred_token);
+}
+
+Module VulkanModuleLoadFile(const std::string& file_name, const std::string& format) {
+  std::string data;
+  std::unordered_map<std::string, VulkanShader> smap;
+  std::unordered_map<std::string, FunctionInfo> fmap;
+  std::string fmt = GetFileFormat(file_name, format);
+  std::string meta_file = GetMetaFilePath(file_name);
+  LoadBinaryFromFile(file_name, &data);
+  LoadMetaDataFromFile(meta_file, &fmap);
+  dmlc::MemoryStringStream fs(&data);
+  dmlc::Stream* stream = &fs;
+  uint32_t magic;
+  stream->Read(&magic);
+  CHECK_EQ(magic, kVulkanModuleMagic) << "VulkanModule Magic mismatch";
+  stream->Read(&smap);
+  return VulkanModuleCreate(smap, fmap, "");
+}
+
+Module VulkanModuleLoadBinary(void* strm) {
+  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+  std::unordered_map<std::string, VulkanShader> smap;
+  std::unordered_map<std::string, FunctionInfo> fmap;
+
+  std::string fmt;
+  stream->Read(&fmt);
+  stream->Read(&fmap);
+  stream->Read(&smap);
+  return VulkanModuleCreate(smap, fmap, "");
+}
+
+TVM_REGISTER_GLOBAL("module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile);
+
+TVM_REGISTER_GLOBAL("module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary);
+
+TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) {
+  DeviceAPI* ptr = VulkanDeviceAPI::Global().get();
+  *rv = static_cast<void*>(ptr);
+});
+
+}  // namespace vulkan
+}  // namespace runtime
+}  // namespace tvm
index ad6100e..9242d3d 100644
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  * specific language governing permissions and limitations
  * under the License.
  */
+#pragma once
 
-/*!
- *  Copyright (c) 2017 by Contributors
- * \file vulkan_common.h
- * \brief Vulkan common header
- */
-#ifndef TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
-#define TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
-
+#include <dmlc/logging.h>
 #include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/device_api.h>
-#include <dmlc/logging.h>
+#include <tvm/runtime/packed_func.h>
 
 #include <vulkan/vulkan.h>
+#include <memory>
 #include <mutex>
 #include <string>
 #include <vector>
-#include <memory>
-#include "../workspace_pool.h"
 
 namespace tvm {
 namespace runtime {
@@ -43,25 +35,44 @@ namespace vulkan {
 
 inline const char* VKGetErrorString(VkResult error) {
   switch (error) {
-    case VK_SUCCESS: return "VK_SUCCESS";
-    case VK_NOT_READY: return "VK_NOT_READY";
-    case VK_TIMEOUT: return "VK_TIMEOUT";
-    case VK_EVENT_SET: return "VK_EVENT_SET";
-    case VK_EVENT_RESET: return "VK_EVENT_RESET";
-    case VK_INCOMPLETE: return "VK_INCOMPLETE";
-    case VK_ERROR_OUT_OF_HOST_MEMORY: return "VK_ERROR_OUT_OF_HOST_MEMORY";
-    case VK_ERROR_OUT_OF_DEVICE_MEMORY: return "VK_ERROR_OUT_OF_DEVICE_MEMORY";
-    case VK_ERROR_INITIALIZATION_FAILED: return "VK_ERROR_INITIALIZATION_FAILED";
-    case VK_ERROR_DEVICE_LOST: return "VK_ERROR_DEVICE_LOST";
-    case VK_ERROR_MEMORY_MAP_FAILED: return "VK_ERROR_MEMORY_MAP_FAILED";
-    case VK_ERROR_LAYER_NOT_PRESENT: return "VK_ERROR_LAYER_NOT_PRESENT";
-    case VK_ERROR_EXTENSION_NOT_PRESENT: return "VK_ERROR_EXTENSION_NOT_PRESENT";
-    case VK_ERROR_FEATURE_NOT_PRESENT: return "VK_ERROR_FEATURE_NOT_PRESENT";
-    case VK_ERROR_INCOMPATIBLE_DRIVER: return "VK_ERROR_INCOMPATIBLE_DRIVER";
-    case VK_ERROR_TOO_MANY_OBJECTS: return "VK_ERROR_TOO_MANY_OBJECTS";
-    case VK_ERROR_FORMAT_NOT_SUPPORTED: return "VK_ERROR_FORMAT_NOT_SUPPORTED";
-    case VK_ERROR_FRAGMENTED_POOL: return "VK_ERROR_FRAGMENTED_POOL";
-    default: return "Unknown Vulkan error code";
+    case VK_SUCCESS:
+      return "VK_SUCCESS";
+    case VK_NOT_READY:
+      return "VK_NOT_READY";
+    case VK_TIMEOUT:
+      return "VK_TIMEOUT";
+    case VK_EVENT_SET:
+      return "VK_EVENT_SET";
+    case VK_EVENT_RESET:
+      return "VK_EVENT_RESET";
+    case VK_INCOMPLETE:
+      return "VK_INCOMPLETE";
+    case VK_ERROR_OUT_OF_HOST_MEMORY:
+      return "VK_ERROR_OUT_OF_HOST_MEMORY";
+    case VK_ERROR_OUT_OF_DEVICE_MEMORY:
+      return "VK_ERROR_OUT_OF_DEVICE_MEMORY";
+    case VK_ERROR_INITIALIZATION_FAILED:
+      return "VK_ERROR_INITIALIZATION_FAILED";
+    case VK_ERROR_DEVICE_LOST:
+      return "VK_ERROR_DEVICE_LOST";
+    case VK_ERROR_MEMORY_MAP_FAILED:
+      return "VK_ERROR_MEMORY_MAP_FAILED";
+    case VK_ERROR_LAYER_NOT_PRESENT:
+      return "VK_ERROR_LAYER_NOT_PRESENT";
+    case VK_ERROR_EXTENSION_NOT_PRESENT:
+      return "VK_ERROR_EXTENSION_NOT_PRESENT";
+    case VK_ERROR_FEATURE_NOT_PRESENT:
+      return "VK_ERROR_FEATURE_NOT_PRESENT";
+    case VK_ERROR_INCOMPATIBLE_DRIVER:
+      return "VK_ERROR_INCOMPATIBLE_DRIVER";
+    case VK_ERROR_TOO_MANY_OBJECTS:
+      return "VK_ERROR_TOO_MANY_OBJECTS";
+    case VK_ERROR_FORMAT_NOT_SUPPORTED:
+      return "VK_ERROR_FORMAT_NOT_SUPPORTED";
+    case VK_ERROR_FRAGMENTED_POOL:
+      return "VK_ERROR_FRAGMENTED_POOL";
+    default:
+      return "Unknown Vulkan error code";
   }
 }
 
@@ -69,19 +80,37 @@ inline const char* VKGetErrorString(VkResult error) {
  * \brief Protected Vulkan call
  * \param func Expression to call.
  */
-#define VULKAN_CHECK_ERROR(__e)                                         \
-  {                                                                     \
-    CHECK(__e == VK_SUCCESS)                                            \
-        << "Vulan Error, code=" << __e << ": " << vulkan::VKGetErrorString(__e); \
+#define VULKAN_CHECK_ERROR(__e)                                       \
+  {                                                                   \
+    CHECK(__e == VK_SUCCESS) << "Vulkan Error, code=" << __e << ": "  \
+                             << vulkan::VKGetErrorString(__e);        \
   }
 
-#define VULKAN_CALL(func)                                             \
-  {                                                                   \
-    VkResult __e = (func);                                            \
-    VULKAN_CHECK_ERROR(__e);                                          \
+#define VULKAN_CALL(func)    \
+  {                          \
+    VkResult __e = (func);   \
+    VULKAN_CHECK_ERROR(__e); \
   }
 
-/*! \brief Auxiliary context structure for vulkan */
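+// Extension entry points for the push-descriptor / descriptor-update-template
+// path (typically resolved via vkGetDeviceProcAddr when VK_KHR_push_descriptor
+// is available; see USE_VULKAN_IMMEDIATE_MODE).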
+struct VulkanDescriptorTemplateKHRFunctions {
+  PFN_vkCreateDescriptorUpdateTemplateKHR vkCreateDescriptorUpdateTemplateKHR{nullptr};
+  PFN_vkDestroyDescriptorUpdateTemplateKHR vkDestroyDescriptorUpdateTemplateKHR{nullptr};
+  PFN_vkUpdateDescriptorSetWithTemplateKHR vkUpdateDescriptorSetWithTemplateKHR{nullptr};
+  PFN_vkCmdPushDescriptorSetWithTemplateKHR vkCmdPushDescriptorSetWithTemplateKHR{nullptr};
+};
+
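+// Entry point for VK_KHR_get_memory_requirements2, used by the
+// dedicated-allocation path (see USE_VULKAN_DEDICATED_ALLOCATION).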
+struct VulkanGetBufferMemoryRequirements2Functions {
+  PFN_vkGetBufferMemoryRequirements2KHR vkGetBufferMemoryRequirements2KHR{nullptr};
+};
+
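+// Host-visible staging buffer; host_addr points at its persistent mapping.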
+struct VulkanStagingBuffer {
+  VkDevice device{nullptr};
+  VkBuffer buffer{VK_NULL_HANDLE};
+  VkDeviceMemory memory{VK_NULL_HANDLE};
+  void* host_addr{nullptr};
+  size_t size{0};
+};
+
 struct VulkanContext {
   // physical device
   VkPhysicalDevice phy_device{nullptr};
@@ -91,211 +120,27 @@ struct VulkanContext {
   uint32_t staging_mtype_index{0};
   // whether staging is coherent
   bool coherent_staging{false};
+
+  std::unique_ptr<VulkanDescriptorTemplateKHRFunctions> descriptor_template_khr_functions{nullptr};
+  std::unique_ptr<VulkanGetBufferMemoryRequirements2Functions>
+      get_buffer_memory_requirements_2_functions{nullptr};
   // Memory type index for compute
   uint32_t compute_mtype_index{0};
   // The logical device
   VkDevice device{nullptr};
   // command queue
+
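+  // Guards submissions to `queue`; VkQueue access requires external
+  // synchronization.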
+  std::unique_ptr<std::mutex> queue_mutex;
   VkQueue queue{nullptr};
   // Queue family index.
   uint32_t queue_family_index{0};
   // Queue family properties.
   VkQueueFamilyProperties queue_prop;
-};
 
-/*! \brief The buffer object */
-struct VulkanBuffer {
-  /*! \brief underlying buffer */
-  VkBuffer buffer{VK_NULL_HANDLE};
-  /*! \brief underlying buffer */
-  VkDeviceMemory memory{VK_NULL_HANDLE};
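+  // Whether the KHR push-descriptor ("immediate mode") launch path is usable
+  // on this device.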
+  bool UseImmediate() const { return descriptor_template_khr_functions.get() != nullptr; }
 };
 
-/*! \brief Buffer only used for stagging */
-struct VulkanStagingBuffer {
-  /*! \brief the corresponding device */
-  VkDevice device{nullptr};
-  /*! \brief underlying buffer */
-  VkBuffer buffer{VK_NULL_HANDLE};
-  /*! \brief underlying buffer */
-  VkDeviceMemory memory{VK_NULL_HANDLE};
-  /*! \brief host address */
-  void* host_addr{nullptr};
-  /*! \brief size of the memory */
-  size_t size{0};
-};
-
-/*!
- * \brief Process global Vulkan workspace.
- */
-class VulkanWorkspace final : public DeviceAPI {
- public:
-  // global mutex
-  std::mutex mu;
-  // whether the workspace it initialized.
-  bool initialized_{false};
-  // vulkan instance
-  VkInstance instance_{nullptr};
-  // The physical devices, have 1 to 1 mapping to devices
-  std::vector<VulkanContext> context_;
-  // Destructor
-  ~VulkanWorkspace();
-  // Initialize workspace
-  // Return false if already initialized, otherwise return true.
-  void Init();
-  // override device API
-  void SetDevice(TVMContext ctx) final;
-  void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
-  void* AllocDataSpace(TVMContext ctx,
-                       size_t nbytes,
-                       size_t alignment,
-                       TVMType type_hint) final;
-  void FreeDataSpace(TVMContext ctx, void* ptr) final;
-  void CopyDataFromTo(const void* from,
-                      size_t from_size,
-                      void* to,
-                      size_t to_size,
-                      size_t size,
-                      TVMContext ctx_from,
-                      TVMContext ctx_to,
-                      TVMType type_hint,
-                      TVMStreamHandle stream) final;
-  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
-  void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
-  void FreeWorkspace(TVMContext ctx, void* data) final;
-  // get the global workspace
-  static const std::shared_ptr<VulkanWorkspace>& Global();
-};
-
-/*! \brief Helper command buffer resource */
-struct VulkanCommandBuffer {
-  /*! \brief fence to signal the resource is ready to use */
-  VkFence fence{VK_NULL_HANDLE};
-  /*! \brief The internal command buffer */
-  VkCommandBuffer cmd_buffer{nullptr};
-  /*! \brief Descriptor set used to bind arguments */
-  VkDescriptorSet descriptor_set{VK_NULL_HANDLE};
-  /*! \brief Internal utilities for write command */
-  VkWriteDescriptorSet write_descriptor_set;
-
-  VulkanCommandBuffer() {
-    write_descriptor_set.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
-    write_descriptor_set.pNext = nullptr;
-    write_descriptor_set.dstSet = VK_NULL_HANDLE;
-    write_descriptor_set.dstBinding = 0;
-    write_descriptor_set.dstArrayElement = 0;
-    write_descriptor_set.descriptorCount = 1;
-    write_descriptor_set.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
-    write_descriptor_set.pImageInfo = nullptr;
-    write_descriptor_set.pBufferInfo = nullptr;
-    write_descriptor_set.pTexelBufferView = nullptr;
-  }
-};
-
-/*!
- * \brief Command pool backed by a fixed size ring buffer.
- *
- *  Vulkan requires us not to reuse command buffer until
- *  All its corresponding jobs have finished.
- *
- *  This class to faciliate automatic management
- *  of the command buffers. A fence is created
- *  for each launch of command buffer jobs
- *  and when we try to reuse the same entry
- *  in the ring, we need to make sure that
- *  the previous pending job already finishes.
- *
- */
-class VulkanCommandPool {
- public:
-  /*! \brief Maximum number of pending jobs in the pool */
-  static constexpr const int kMaxPending = 4;
-  /*! \brief Maximum number of pending jobs in the pool */
-  static constexpr const int kMaxNumArgs = 16;
-  /*!
-   * \brief constructor
-   * \param vctx The corresponding vulkan context.
-   */
-  explicit VulkanCommandPool(const VulkanContext& vctx);
-  /*! \brief destructor */
-  ~VulkanCommandPool();
-  /*!
-   * \brief Allocate a new command buffer entry
-   *
-   *  The caller must only submit the entry once
-   *  with the given fence in the entry,
-   *  before calling next Alloc.
-   *
-   *  This function may block to wait for a
-   *  previously unfinished command when
-   *  there is more than kMaxPending jobs.
-   *
-   * \returns The allocated entry.
-   */
-  VulkanCommandBuffer* Alloc();
-
-  /*!
-   * \brief Allocate a new command buffer entry
-   * \param dlayout the descriptor layout.
-   *
-   * \returns The allocated entry.
-   */
-  VulkanCommandBuffer* Alloc(const VkDescriptorSetLayout* dlayout);
-
- private:
-  /*! \brief Local ring buffer */
-  std::vector<VulkanCommandBuffer> ring_;
-  /*! \brief clock pointer */
-  size_t clock_ptr_{0};
-  /*! \brief the corresponding device*/
-  VkDevice device_{nullptr};
-  /*! \brief internal command buffer pool */
-  VkCommandPool cmd_pool_{VK_NULL_HANDLE};
-  /*! \brief Descriptor pool */
-  VkDescriptorPool descriptor_pool_{VK_NULL_HANDLE};
-};
-
-/*! \brief Thread local workspace */
-class VulkanThreadEntry {
- public:
-  /*! \brief The current context */
-  TVMContext context;
-  /*! \brief workspace pool */
-  WorkspacePool pool;
-  /*! \brief The staging buffers */
-  std::vector<VulkanStagingBuffer> staging_buffer_;
-  /*!
-   * \brief Get the command pool of corresponding device;
-   * \param device_id The device id
-   * \return The corresponding command buffer.
-   */
-  VulkanCommandPool* CommandPool(int device_id);
-  /*!
-   * \brief Get the stagging buffer.
-   * \param device_id The device id
-   * \return The corresponding stagging buffer.
-   */
-  VulkanStagingBuffer* StagingBuffer(int device_id, size_t size);
-
-  // constructor
-  VulkanThreadEntry()
-      : pool(static_cast<DLDeviceType>(kDLVulkan), VulkanWorkspace::Global()) {
-    context.device_id = 0;
-    context.device_type = static_cast<DLDeviceType>(kDLVulkan);
-  }
-  ~VulkanThreadEntry();
-  // get the global workspace
-  static VulkanThreadEntry* ThreadLocal();
-
- private:
-  /*! \brief the command pools */
-  std::vector<std::unique_ptr<VulkanCommandPool> > pool_;
-};
-
-// inline implementation
-
 
 }  // namespace vulkan
 }  // namespace runtime
 }  // namespace tvm
-#endif  // TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc
deleted file mode 100644 (file)
index da04acd..0000000
+++ /dev/null
@@ -1,711 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- * 
- *   http://www.apache.org/licenses/LICENSE-2.0
- * 
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *  Copyright (c) 2017 by Contributors
- * \file vulkan_device_api.cc
- */
-#include <tvm/runtime/registry.h>
-#include <dmlc/thread_local.h>
-#include <cstring>
-#include "vulkan_common.h"
-
-namespace tvm {
-namespace runtime {
-namespace vulkan {
-
-VulkanWorkspace::~VulkanWorkspace() {
-  for (VulkanContext& ctx : context_) {
-    vkDestroyDevice(ctx.device, nullptr);
-  }
-  if (instance_ != nullptr) {
-    vkDestroyInstance(instance_, nullptr);
-  }
-}
-
-const std::shared_ptr<VulkanWorkspace>& VulkanWorkspace::Global() {
-  static std::shared_ptr<VulkanWorkspace> inst = std::make_shared<VulkanWorkspace>();
-  return inst;
-}
-
-void VulkanWorkspace::SetDevice(TVMContext ctx) {
-  VulkanThreadEntry::ThreadLocal()->context.device_id = ctx.device_id;
-}
-
-void VulkanWorkspace::GetAttr(
-    TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
-  this->Init();
-  size_t index = static_cast<size_t>(ctx.device_id);
-  if (kind == kExist) {
-    *rv = static_cast<int>(index< context_.size());
-    return;
-  }
-  CHECK_LT(index, context_.size())
-      << "Invalid device id " << index;
-  switch (kind) {
-    case kMaxThreadsPerBlock: {
-      VkPhysicalDeviceProperties phy_prop;
-      vkGetPhysicalDeviceProperties(context_[ctx.device_id].phy_device, &phy_prop);
-      int64_t value = phy_prop.limits.maxComputeWorkGroupSize[0];
-      *rv = value;
-      break;
-    }
-    case kMaxSharedMemoryPerBlock: {
-      VkPhysicalDeviceProperties phy_prop;
-      vkGetPhysicalDeviceProperties(context_[ctx.device_id].phy_device, &phy_prop);
-      int64_t value = phy_prop.limits.maxComputeSharedMemorySize;
-      *rv = value;
-      break;
-    }
-    case kWarpSize: {
-      *rv = 1;
-      break;
-    }
-    case kComputeVersion: {
-      VkPhysicalDeviceProperties phy_prop;
-      vkGetPhysicalDeviceProperties(context_[ctx.device_id].phy_device, &phy_prop);
-      int64_t value = phy_prop.apiVersion;
-      std::ostringstream os;
-      os << VK_VERSION_MAJOR(value)
-         << "." << VK_VERSION_MINOR(value)
-         << "." << VK_VERSION_PATCH(value);
-      *rv = os.str();
-      break;
-    }
-    case kDeviceName: return;
-    case kMaxClockRate: return;
-    case kMultiProcessorCount: return;
-    case kExist: break;
-    case kMaxThreadDimensions: break;
-  }
-}
-
-void* VulkanWorkspace::AllocDataSpace(
-    TVMContext ctx, size_t size, size_t alignment, TVMType type_hint) {
-  this->Init();
-
-  VulkanContext& vctx = context_[ctx.device_id];
-
-  VkBufferCreateInfo info;
-  info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
-  info.pNext = nullptr;
-  info.flags = 0;
-  info.size = size;
-  info.queueFamilyIndexCount = 1;
-  info.pQueueFamilyIndices = &(vctx.queue_family_index);
-  info.usage =
-      VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
-      VK_BUFFER_USAGE_TRANSFER_DST_BIT |
-      VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
-  // create buffer
-  VkBuffer buffer;
-  VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
-  // bind to memory
-  VkMemoryAllocateInfo minfo;
-  minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
-  minfo.pNext = nullptr;
-  minfo.allocationSize = size;
-  minfo.memoryTypeIndex = vctx.compute_mtype_index;
-  VkDeviceMemory memory;
-  VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
-  VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
-
-  VulkanBuffer* pbuf = new VulkanBuffer();
-  pbuf->memory = memory;
-  pbuf->buffer = buffer;
-  return pbuf;
-}
-
-void VulkanWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
-  VulkanContext& vctx = context_[ctx.device_id];
-  VulkanBuffer* pbuf = static_cast<VulkanBuffer*>(ptr);
-  vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr);
-  vkFreeMemory(vctx.device, pbuf->memory, nullptr);
-  delete pbuf;
-}
-
-void VulkanWorkspace::CopyDataFromTo(const void* from,
-                                     size_t from_offset,
-                                     void* to,
-                                     size_t to_offset,
-                                     size_t size,
-                                     TVMContext ctx_from,
-                                     TVMContext ctx_to,
-                                     TVMType type_hint,
-                                     TVMStreamHandle stream) {
-  this->Init();
-  CHECK(stream == nullptr);
-  TVMContext ctx = ctx_from;
-  if (ctx_from.device_type == kDLCPU) ctx = ctx_to;
-  VulkanThreadEntry* tls = VulkanThreadEntry::ThreadLocal();
-  VulkanCommandBuffer* cmd = tls->CommandPool(ctx.device_id)->Alloc();
-
-  VkCommandBufferBeginInfo cb_begin;
-  cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
-  cb_begin.pNext = nullptr;
-  cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
-  cb_begin.pInheritanceInfo = 0;
-
-  VkSubmitInfo cb_submit;
-  cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
-  cb_submit.pNext = nullptr;
-  cb_submit.waitSemaphoreCount = 0;
-  cb_submit.pWaitSemaphores = nullptr;
-  cb_submit.pWaitDstStageMask = 0;
-  cb_submit.commandBufferCount = 1;
-  cb_submit.pCommandBuffers = &(cmd->cmd_buffer);
-  cb_submit.signalSemaphoreCount = 0;
-  cb_submit.pSignalSemaphores = nullptr;
-
-
-  int from_dev_type = static_cast<int>(ctx_from.device_type);
-  int to_dev_type = static_cast<int>(ctx_to.device_type);
-
-  if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) {
-    CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
-        << "Vulkan disallow cross device copy.";
-    const VulkanContext& vctx = context_[ctx_from.device_id];
-    const VulkanBuffer* from_buf = static_cast<const VulkanBuffer*>(from);
-    VulkanBuffer* to_buf = static_cast<VulkanBuffer*>(to);
-    // The assumption is that subsequence ops only perform compute/transfer
-    // 0: begin
-    VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
-    // 1: copy
-    VkBufferCopy copy_info;
-    copy_info.srcOffset = from_offset;
-    copy_info.dstOffset = to_offset;
-    copy_info.size = size;
-    vkCmdCopyBuffer(cmd->cmd_buffer, from_buf->buffer, to_buf->buffer, 1, &copy_info);
-    // 2: barrier(transfer-> compute|transfer)
-    VkMemoryBarrier barrier_info;
-    barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
-    barrier_info.pNext = nullptr;
-    barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
-    barrier_info.dstAccessMask =
-        (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
-         VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
-    vkCmdPipelineBarrier(
-        cmd->cmd_buffer,
-        VK_PIPELINE_STAGE_TRANSFER_BIT,
-        VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
-        0, 1, &barrier_info, 0, nullptr, 0, nullptr);
-    // 3: end
-    VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
-    // 4: submit with cmd->fence
-    VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
-  } else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) {
-    const VulkanContext& vctx = context_[ctx_from.device_id];
-    const VulkanBuffer* from_buf = static_cast<const VulkanBuffer*>(from);
-    VulkanStagingBuffer* temp = tls->StagingBuffer(ctx_from.device_id, size);
-    // 0: begin
-    VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
-    // 1: copy
-    VkBufferCopy copy_info;
-    copy_info.srcOffset = from_offset;
-    copy_info.dstOffset = 0;
-    copy_info.size = size;
-    vkCmdCopyBuffer(cmd->cmd_buffer,
-                    from_buf->buffer,
-                    temp->buffer,
-                    1, &copy_info);
-    // 2: end
-    VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
-    // 4: submit with cmd->fence
-    VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
-    // Block until done, to make sure temp can be reused later.
-    VULKAN_CALL(vkQueueWaitIdle(vctx.queue));
-    // host side invalidation if access is not coherent.
-    // so writes from GPU is visible to CPU
-    if (!vctx.coherent_staging) {
-      VkMappedMemoryRange mrange;
-      mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
-      mrange.pNext = nullptr;
-      mrange.memory = temp->memory;
-      mrange.offset = 0;
-      mrange.size = size;
-      VULKAN_CALL(vkInvalidateMappedMemoryRanges(
-          vctx.device, 1, &mrange));
-    }
-    memcpy(static_cast<char*>(to) + to_offset,
-           static_cast<char*>(temp->host_addr),
-           size);
-  } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) {
-    const VulkanContext& vctx = context_[ctx_to.device_id];
-    const VulkanBuffer* to_buf = static_cast<const VulkanBuffer*>(to);
-    VulkanStagingBuffer* temp = tls->StagingBuffer(ctx_to.device_id, size);
-    memcpy(temp->host_addr,
-           static_cast<const char*>(from) + from_offset,
-           size);
-    // host side flush if access is not coherent.
-    // so writes from CPU is visible to GPU
-    if (!vctx.coherent_staging) {
-      VkMappedMemoryRange mrange;
-      mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
-      mrange.pNext = nullptr;
-      mrange.memory = temp->memory;
-      mrange.offset = 0;
-      mrange.size = size;
-      VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange));
-    }
-    VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
-    // 0: barrier(host->transfer)
-    VkMemoryBarrier barrier_info;
-    barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
-    barrier_info.pNext = nullptr;
-    barrier_info.srcAccessMask = 0;
-    barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
-    vkCmdPipelineBarrier(cmd->cmd_buffer,
-                         VK_PIPELINE_STAGE_HOST_BIT,
-                         VK_PIPELINE_STAGE_TRANSFER_BIT,
-                         0, 1, &barrier_info,
-                         0, nullptr, 0, nullptr);
-    // 1: copy
-    VkBufferCopy copy_info;
-    copy_info.srcOffset = 0;
-    copy_info.dstOffset = to_offset;
-    copy_info.size = size;
-    vkCmdCopyBuffer(cmd->cmd_buffer,
-                    temp->buffer,
-                    to_buf->buffer,
-                    1, &copy_info);
-    // 2: end
-    VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
-    // 4: submit with cmd->fence
-    VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
-    // wait until copy finishes, so we can reuse temp next time.
-    VULKAN_CALL(vkQueueWaitIdle(vctx.queue));
-  } else {
-    LOG(FATAL) << "Expect copy from/to Metal or between Metal"
-               << ", from=" << from_dev_type
-               << ", to=" << to_dev_type;
-  }
-}
-
-void VulkanWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
-  CHECK(stream == nullptr);
-  VulkanContext& vctx = context_[ctx.device_id];
-  VULKAN_CALL(vkQueueWaitIdle(vctx.queue));
-}
-
-void* VulkanWorkspace::AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) {
-  return VulkanThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
-}
-
-void VulkanWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
-  VulkanThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
-}
-
-// VulkanCommandPool
-VulkanCommandPool::VulkanCommandPool(const VulkanContext& vctx) {
-  ring_.resize(kMaxPending, VulkanCommandBuffer());
-  device_ = vctx.device;
-  {
-    // create command pool
-    VkCommandPoolCreateInfo cmd_pool_cinfo;
-    cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
-    cmd_pool_cinfo.pNext = nullptr;
-    cmd_pool_cinfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
-    cmd_pool_cinfo.queueFamilyIndex = vctx.queue_family_index;
-    VULKAN_CALL(vkCreateCommandPool(device_, &cmd_pool_cinfo, nullptr, &cmd_pool_));
-  }
-  {
-    // create descriptor pool
-    VkDescriptorPoolSize pool_size;
-    pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
-    pool_size.descriptorCount = kMaxPending * kMaxNumArgs;
-    VkDescriptorPoolCreateInfo descrip_pool_cinfo;
-    descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
-    descrip_pool_cinfo.pNext = nullptr;
-    descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
-    descrip_pool_cinfo.maxSets = kMaxPending + 2;
-    descrip_pool_cinfo.poolSizeCount = 1;
-    descrip_pool_cinfo.pPoolSizes = &pool_size;
-    VULKAN_CALL(vkCreateDescriptorPool(
-        device_, &descrip_pool_cinfo, nullptr, &descriptor_pool_));
-  }
-  VkCommandBufferAllocateInfo buffer_alloc_info;
-  buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
-  buffer_alloc_info.pNext = nullptr;
-  buffer_alloc_info.commandPool = cmd_pool_;
-  buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
-  buffer_alloc_info.commandBufferCount = 1;
-
-  VkFenceCreateInfo fence_cinfo;
-  fence_cinfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
-  fence_cinfo.pNext = nullptr;
-  fence_cinfo.flags = VK_FENCE_CREATE_SIGNALED_BIT;
-
-  for (size_t i = 0; i < ring_.size(); ++i) {
-    VULKAN_CALL(vkAllocateCommandBuffers(
-        device_, &buffer_alloc_info, &(ring_[i].cmd_buffer)));
-    VULKAN_CALL(vkCreateFence(
-        device_, &fence_cinfo, nullptr, &(ring_[i].fence)));
-  }
-}
-
-VulkanCommandPool::~VulkanCommandPool() {
-  // wait device to be idle so we know we can recycle buffers
-  VULKAN_CALL(vkDeviceWaitIdle(device_));
-  // start recycling.
-  for (size_t i = 0; i < ring_.size(); ++i) {
-    if (ring_[i].cmd_buffer != nullptr) {
-      vkFreeCommandBuffers(device_, cmd_pool_, 1, &(ring_[i].cmd_buffer));
-      ring_[i].cmd_buffer = nullptr;
-    }
-    if (ring_[i].fence != VK_NULL_HANDLE) {
-      vkDestroyFence(device_, ring_[i].fence, nullptr);
-    }
-  }
-  // delete cmd_pool and descriptor pool
-  vkDestroyCommandPool(device_, cmd_pool_, nullptr);
-  vkDestroyDescriptorPool(device_, descriptor_pool_, nullptr);
-}
-
-VulkanCommandBuffer* VulkanCommandPool::Alloc() {
-  return Alloc(nullptr);
-}
-
-VulkanCommandBuffer* VulkanCommandPool::Alloc(
-    const VkDescriptorSetLayout* dlayout) {
-  // always allocate resource in round robin manner
-  VulkanCommandBuffer* e = &(ring_[clock_ptr_]);
-  clock_ptr_ = (clock_ptr_ + 1) % ring_.size();
-  // Wait until previous usage of commad buffer is finished.
-  uint64_t timeout = 1UL << 30UL;
-  VkResult res;
-  res = vkWaitForFences(device_, 1, &(e->fence), 0, timeout);
-  while (res == VK_TIMEOUT) {
-    res = vkWaitForFences(device_, 1, &(e->fence), 0, timeout);
-  }
-  VULKAN_CHECK_ERROR(res);
-  vkResetFences(device_, 1, (&e->fence));
-  if (e->descriptor_set != VK_NULL_HANDLE) {
-    VULKAN_CALL(vkFreeDescriptorSets(
-        device_, descriptor_pool_, 1, &(e->descriptor_set)));
-    e->descriptor_set = VK_NULL_HANDLE;
-  }
-  if (dlayout != nullptr) {
-    VkDescriptorSetAllocateInfo alloc_info;
-    alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
-    alloc_info.pNext = nullptr;
-    alloc_info.descriptorPool = descriptor_pool_;
-    alloc_info.descriptorSetCount = 1;
-    alloc_info.pSetLayouts = dlayout;
-    VULKAN_CALL(vkAllocateDescriptorSets(
-        device_, &alloc_info, &(e->descriptor_set)));
-  }
-  return e;
-}
-
-// VulkanThreadEntry
-typedef dmlc::ThreadLocalStore<VulkanThreadEntry> VulkanThreadStore;
-
-VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() {
-  return VulkanThreadStore::Get();
-}
-
-VulkanCommandPool* VulkanThreadEntry::CommandPool(int device_id) {
-  while (pool_.size() <= static_cast<size_t>(device_id)) {
-    pool_.emplace_back(std::unique_ptr<VulkanCommandPool>());
-  }
-  if (pool_[device_id] == nullptr) {
-    const VulkanContext& vctx =
-        VulkanWorkspace::Global()->context_[device_id];
-    pool_[device_id].reset(new VulkanCommandPool(vctx));
-  }
-  return pool_[device_id].get();
-}
-
-VulkanStagingBuffer*
-VulkanThreadEntry::StagingBuffer(int device_id, size_t size) {
-  if (staging_buffer_.size() <= static_cast<size_t>(device_id)) {
-    staging_buffer_.resize(device_id + 1, VulkanStagingBuffer());
-  }
-  VulkanStagingBuffer& buf = staging_buffer_[device_id];
-
-  if (buf.device != nullptr && buf.size < size) {
-    // free previous buffer
-    if (buf.host_addr != nullptr) {
-      vkUnmapMemory(buf.device, buf.memory);
-    }
-    if (buf.memory != VK_NULL_HANDLE) {
-      vkFreeMemory(buf.device, buf.memory, nullptr);
-    }
-    if (buf.buffer != VK_NULL_HANDLE) {
-      vkDestroyBuffer(buf.device, buf.buffer, nullptr);
-    }
-    buf.host_addr = nullptr;
-    buf.memory = VK_NULL_HANDLE;
-    buf.buffer = VK_NULL_HANDLE;
-  }
-  const VulkanContext& vctx =
-      VulkanWorkspace::Global()->context_[device_id];
-
-  if (buf.device == nullptr) {
-    buf.device = vctx.device;
-  }
-  if (buf.memory == VK_NULL_HANDLE) {
-    // allocate the stagging buffer memory if necessary
-    VkBufferCreateInfo info;
-    info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
-    info.pNext = nullptr;
-    info.flags = 0;
-    info.size = size;
-    info.queueFamilyIndexCount = 1;
-    info.pQueueFamilyIndices = &(vctx.queue_family_index);
-    info.usage =
-        VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
-        VK_BUFFER_USAGE_TRANSFER_DST_BIT;
-    VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &(buf.buffer)));
-    VkMemoryAllocateInfo minfo;
-    minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
-    minfo.pNext = nullptr;
-    minfo.allocationSize = size;
-    minfo.memoryTypeIndex = vctx.staging_mtype_index;
-    VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &(buf.memory)));
-    VULKAN_CALL(vkBindBufferMemory(vctx.device, (buf.buffer), buf.memory, 0));
-    VULKAN_CALL(vkMapMemory(vctx.device, buf.memory, 0, size, 0, &(buf.host_addr)));
-    buf.size = size;
-  }
-  memset(buf.host_addr, 0, size);
-  return &buf;
-}
-
-VulkanThreadEntry::~VulkanThreadEntry() {
-  // Because the thread entry refers to Device API
-  // The command buffer always will be destroyed before
-  // the instance and device get destroyed.
-  // The destruction need to be manually called
-  // to ensure the destruction order.
-  pool_.clear();
-  for (VulkanStagingBuffer buf : staging_buffer_) {
-    if (buf.host_addr != nullptr) {
-      vkUnmapMemory(buf.device, buf.memory);
-    }
-    if (buf.memory != VK_NULL_HANDLE) {
-      vkFreeMemory(buf.device, buf.memory, nullptr);
-    }
-    if (buf.buffer != VK_NULL_HANDLE) {
-      vkDestroyBuffer(buf.device, buf.buffer, nullptr);
-    }
-  }
-}
-
-VkInstance CreateInstance() {
-  VkApplicationInfo app_info;
-  app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
-  app_info.pNext = nullptr;
-  app_info.pApplicationName = "TVM";
-  app_info.applicationVersion = 0;
-  app_info.pEngineName = "";
-  app_info.engineVersion = 0;
-  app_info.apiVersion = VK_MAKE_VERSION(1, 0, 65);
-
-  VkInstanceCreateInfo inst_info;
-  inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
-  inst_info.pNext = nullptr;
-  inst_info.flags = 0;
-  inst_info.pApplicationInfo = &app_info;
-  inst_info.enabledLayerCount = 0;
-  inst_info.ppEnabledLayerNames = nullptr;
-  inst_info.enabledExtensionCount = 0;
-  inst_info.ppEnabledExtensionNames = nullptr;
-
-  VkInstance inst;
-  VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &inst));
-  return inst;
-}
-
-// find suitable mem_type_index for staging and compute
-void FindMemoryTypeIndex(VulkanContext* vctx) {
-  // Find suitable compute index.
-  VkBuffer buffer;
-  VkMemoryRequirements req_staging, req_compute;
-  VkBufferCreateInfo info;
-  info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
-  info.pNext = nullptr;
-  info.flags = 0;
-  info.size = 1024;
-  info.queueFamilyIndexCount = 1;
-  info.pQueueFamilyIndices = &(vctx->queue_family_index);
-
-  // get staging requirement
-  info.usage =
-      VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
-      VK_BUFFER_USAGE_TRANSFER_DST_BIT;
-  VULKAN_CALL(vkCreateBuffer(vctx->device, &info, nullptr, &buffer));
-  vkGetBufferMemoryRequirements(vctx->device, buffer, &req_staging);
-  vkDestroyBuffer(vctx->device, buffer, nullptr);
-  // get compute requirement
-  info.usage =
-      VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
-      VK_BUFFER_USAGE_TRANSFER_DST_BIT |
-      VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
-  VULKAN_CALL(vkCreateBuffer(vctx->device, &info, nullptr, &buffer));
-  vkGetBufferMemoryRequirements(vctx->device, buffer, &req_compute);
-  vkDestroyBuffer(vctx->device, buffer, nullptr);
-
-  // Query phyiscal device property
-  // find a memory that is host visible, no need to be consistent
-  int win_rank = -1;
-  VkPhysicalDeviceMemoryProperties prop;
-  vkGetPhysicalDeviceMemoryProperties(vctx->phy_device, &prop);
-
-  for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
-    VkMemoryType ty = prop.memoryTypes[k];
-    size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
-    // host visible
-    if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
-    // match copy requirment
-    if (!(req_staging.memoryTypeBits & (1 << k))) continue;
-    if (heap_size < 1024) continue;
-    int rank = 0;
-    rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
-    if (rank > win_rank) {
-      win_rank = rank;
-      vctx->staging_mtype_index = k;
-      vctx->coherent_staging =
-          ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
-    }
-  }
-  CHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";
-
-  win_rank = -1;
-  for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
-    VkMemoryType ty = prop.memoryTypes[k];
-    size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
-    // host visible
-    if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
-    // match copy requirment
-    if (!(req_staging.memoryTypeBits & (1 << k))) continue;
-    if (heap_size < 1024) continue;
-    int rank = 0;
-    // prefer not host visible
-    rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
-    if (rank > win_rank) {
-      win_rank = rank;
-      vctx->compute_mtype_index = k;
-    }
-  }
-  CHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";
-}
-
-// Get all logic devices that support compute
-std::vector<VulkanContext> GetContext(VkInstance instance) {
-  std::vector<VulkanContext> result;
-  uint32_t phy_dev_count = 0;
-  VULKAN_CALL(vkEnumeratePhysicalDevices(
-      instance, &phy_dev_count, nullptr));
-  std::vector<VkPhysicalDevice> all_phy_devs(phy_dev_count);
-  VULKAN_CALL(vkEnumeratePhysicalDevices(
-      instance, &phy_dev_count, dmlc::BeginPtr(all_phy_devs)));
-  for (VkPhysicalDevice phy_dev : all_phy_devs) {
-    uint32_t queue_prop_count = 0;
-    vkGetPhysicalDeviceQueueFamilyProperties(
-        phy_dev, &queue_prop_count, nullptr);
-    std::vector<VkQueueFamilyProperties> queue_props(queue_prop_count);
-    vkGetPhysicalDeviceQueueFamilyProperties(
-        phy_dev, &queue_prop_count, dmlc::BeginPtr(queue_props));
-    uint32_t queue_family_index = 0;
-    std::vector<VkDeviceQueueCreateInfo> queue_create_info;
-
-    for (uint32_t i = 0; i < queue_props.size(); i++) {
-      // find queues that support compute
-      if (VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) {
-        float priority = 1.0f;
-
-        VkDeviceQueueCreateInfo info;
-        info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
-        info.pNext = nullptr;
-        info.flags = 0;
-        info.queueFamilyIndex = i;
-        info.queueCount = 1;
-        info.pQueuePriorities = &priority;
-
-        queue_create_info.push_back(info);
-        // only use the first available queue for now
-        if (queue_create_info.size() == 0) {
-          queue_family_index = i;
-        }
-      }
-    }
-    if (queue_create_info.size() == 0) continue;
-
-    VkDeviceCreateInfo device_create_info;
-    device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
-    device_create_info.pNext = nullptr;
-    device_create_info.flags = 0;
-    device_create_info.queueCreateInfoCount
-        = static_cast<uint32_t>(queue_create_info.size());
-    device_create_info.pQueueCreateInfos = queue_create_info.data();
-    device_create_info.enabledLayerCount = 0;
-    device_create_info.ppEnabledLayerNames = nullptr;
-    device_create_info.enabledExtensionCount = 0;
-    device_create_info.ppEnabledExtensionNames = nullptr;
-    device_create_info.pEnabledFeatures = nullptr;
-
-    VulkanContext ctx;
-    // setup context
-    ctx.phy_device = phy_dev;
-    vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop));
-    VULKAN_CALL(vkCreateDevice(
-        phy_dev, &device_create_info, nullptr, &(ctx.device)));
-    vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue));
-    ctx.queue_family_index = queue_family_index;
-    FindMemoryTypeIndex(&ctx);
-    // Find suitable memory type for staging and compute
-    result.push_back(ctx);
-  }
-  return result;
-}
-
-void VulkanWorkspace::Init() {
-  if (initialized_) return;
-  std::lock_guard<std::mutex> lock(this->mu);
-  if (initialized_) return;
-  initialized_ = true;
-  try {
-    instance_ = CreateInstance();
-    context_ = GetContext(instance_);
-    LOG(INFO) << "Initialize Vulkan with " << context_.size() << " devices..";
-    for (size_t i = 0; i < context_.size(); ++i) {
-      LOG(INFO) << "vulkan(" << i
-                <<  ")=\'" << context_[i].phy_device_prop.deviceName
-                << "\' phy_dev_id=" << context_[i].phy_device;
-    }
-  } catch (const dmlc::Error& err) {
-    LOG(INFO) << "Cannot initialize vulkan: " << err.what() << "\n"
-              << "You can still compile vulkan module but cannot run locally";
-  }
-}
-
-bool InitVulkan(TVMArgs args, TVMRetValue* rv) {
-  vulkan::VulkanWorkspace::Global()->Init();
-  return true;
-}
-
-TVM_REGISTER_GLOBAL("device_api.vulkan")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
-    DeviceAPI* ptr = VulkanWorkspace::Global().get();
-    *rv = static_cast<void*>(ptr);
-  });
-
-}  // namespace vulkan
-}  // namespace runtime
-}  // namespace tvm
diff --git a/src/runtime/vulkan/vulkan_module.cc b/src/runtime/vulkan/vulkan_module.cc
deleted file mode 100644 (file)
index c1db14d..0000000
+++ /dev/null
@@ -1,435 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- * 
- *   http://www.apache.org/licenses/LICENSE-2.0
- * 
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *  Copyright (c) 2018 by Contributors
- * \file vulkan_module.cc
- */
-#include <dmlc/memory_io.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/module.h>
-#include <array>
-#include <string>
-#include <mutex>
-#include "vulkan_common.h"
-#include "vulkan_module.h"
-#include "../pack_args.h"
-#include "../thread_storage_scope.h"
-#include "../meta_data.h"
-#include "../file_util.h"
-
-
-namespace tvm {
-namespace runtime {
-
-void VulkanShader::Save(dmlc::Stream* writer) const {
-  writer->Write(flag);
-  writer->Write(data);
-}
-
-bool VulkanShader::Load(dmlc::Stream* reader) {
-  if (!reader->Read(&flag)) return false;
-  if (!reader->Read(&data)) return false;
-  return true;
-}
-
-// Multi-device enabled module.
-class VulkanModuleNode final :public runtime::ModuleNode {
- public:
-  // Pipeline cache states
-  struct PipelineEntry {
-    VkShaderModule shader{VK_NULL_HANDLE};
-    VkPipelineLayout pipeline_layout{VK_NULL_HANDLE};
-    VkDescriptorSetLayout descriptor_layout{VK_NULL_HANDLE};
-    VkPipeline pipeline{VK_NULL_HANDLE};
-  };
-  // constructor
-  explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap,
-                            std::unordered_map<std::string, FunctionInfo> fmap,
-                            std::string source)
-      : smap_(smap), fmap_(fmap), source_(source) {
-  }
-
-  ~VulkanModuleNode() {
-    // cleanup vulkan related caches.
-    for (DeviceEntry& e : finfo_) {
-      if (e.device == nullptr) continue;
-      for (auto &kv : e.smap) {
-        PipelineEntry& pe = kv.second;
-        vkDestroyShaderModule(e.device, pe.shader, nullptr);
-        vkDestroyDescriptorSetLayout(e.device, pe.descriptor_layout, nullptr);
-        vkDestroyPipelineLayout(e.device, pe.pipeline_layout, nullptr);
-        vkDestroyPipeline(e.device, pe.pipeline, nullptr);
-      }
-    }
-  }
-  const char* type_key() const final {
-    return "vulkan";
-  }
-
-  PackedFunc GetFunction(
-      const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) final;
-
-  void SaveToFile(const std::string& file_name,
-                  const std::string& format) final {
-    std::string fmt = GetFileFormat(file_name, format);
-    CHECK_EQ(fmt, fmt_)
-        << "Can only save to customized format vulkan";
-    std::string meta_file = GetMetaFilePath(file_name);
-    SaveMetaDataToFile(meta_file, fmap_);
-    std::string data_bin;
-    dmlc::MemoryStringStream fs(&data_bin);
-    dmlc::Stream* stream = &fs;
-    uint32_t magic = kVulkanModuleMagic;
-    stream->Write(magic);
-    stream->Write(smap_);
-    SaveBinaryToFile(file_name, data_bin);
-  }
-
-  void SaveToBinary(dmlc::Stream* stream) final {
-    stream->Write(fmt_);
-    stream->Write(fmap_);
-    stream->Write(smap_);
-  }
-  std::string GetSource(const std::string& format) final {
-    // can only return source code.
-    return source_;
-  }
-
-  // get a from primary context in device_id
-  PipelineEntry GetPipeline(size_t device_id,
-                            const std::string& func_name,
-                            size_t num_pack_args) {
-    vulkan::VulkanWorkspace* w = vulkan::VulkanWorkspace::Global().get();
-    CHECK_LT(device_id, w->context_.size());
-    // start lock scope.
-    std::lock_guard<std::mutex> lock(mutex_);
-    if (finfo_.size() <= device_id) {
-      finfo_.resize(device_id + 1, DeviceEntry());
-    }
-    DeviceEntry& e = finfo_[device_id];
-    auto it = e.smap.find(func_name);
-    if (it != e.smap.end()) return it->second;
-    PipelineEntry pe;
-    if (e.device == nullptr) {
-      e.device = w->context_[device_id].device;
-    }
-    {
-      // create shader
-      auto sit = smap_.find(func_name);
-      CHECK(sit != smap_.end());
-      const std::vector<uint32_t>& data = sit->second.data;
-      VkShaderModuleCreateInfo shader_cinfo;
-      shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
-      shader_cinfo.pNext = nullptr;
-      shader_cinfo.flags = 0;
-      shader_cinfo.codeSize = data.size() * sizeof(uint32_t);
-      shader_cinfo.pCode = data.data();
-      VULKAN_CALL(vkCreateShaderModule(
-          e.device, &shader_cinfo, nullptr, &(pe.shader)));
-    }
-    std::vector<VkDescriptorSetLayoutBinding> arg_binding;
-    uint32_t num_pod = 0, num_buffer = 0;
-    {
-      auto fit = fmap_.find(func_name);
-      CHECK(fit != fmap_.end());
-      for (TVMType arg_type : fit->second.arg_types) {
-        if (arg_type.code == kHandle) {
-          VkDescriptorSetLayoutBinding bd;
-          bd.binding = num_buffer;
-          bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
-          bd.descriptorCount = 1;
-        bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
-        bd.pImmutableSamplers = nullptr;
-        arg_binding.push_back(bd);
-        ++num_buffer;
-        } else {
-          ++num_pod;
-        }
-      }
-    }
-
-    VkDescriptorSetLayoutCreateInfo descrip_cinfo;
-    descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
-    descrip_cinfo.pNext = nullptr;
-    descrip_cinfo.flags = 0;
-    descrip_cinfo.bindingCount = arg_binding.size();
-    descrip_cinfo.pBindings = arg_binding.data();
-    VULKAN_CALL(vkCreateDescriptorSetLayout(
-        e.device, &descrip_cinfo, nullptr, &(pe.descriptor_layout)));
-
-    VkPushConstantRange crange;
-    crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
-    crange.offset = 0;
-    crange.size = sizeof(ArgUnion) * num_pack_args;
-
-    VkPipelineLayoutCreateInfo playout_cinfo;
-    playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
-    playout_cinfo.pNext = nullptr;
-    playout_cinfo.flags = 0;
-    playout_cinfo.setLayoutCount = 1;
-    playout_cinfo.pSetLayouts = &(pe.descriptor_layout);
-
-    if (num_pack_args != 0) {
-      playout_cinfo.pushConstantRangeCount = 1;
-      playout_cinfo.pPushConstantRanges = &crange;
-      CHECK_LE(crange.size,
-               w->context_[device_id].phy_device_prop.limits.maxPushConstantsSize);
-    } else {
-      playout_cinfo.pushConstantRangeCount = 0;
-      playout_cinfo.pPushConstantRanges = nullptr;
-    }
-
-    VULKAN_CALL(vkCreatePipelineLayout(
-        e.device, &playout_cinfo, nullptr, &(pe.pipeline_layout)));
-    VkComputePipelineCreateInfo pipeline_cinfo;
-    pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
-    pipeline_cinfo.pNext = nullptr;
-    pipeline_cinfo.flags = 0;
-    pipeline_cinfo.stage.sType =
-        VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
-    pipeline_cinfo.stage.pNext = nullptr;
-    pipeline_cinfo.stage.flags = 0;
-    pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
-    pipeline_cinfo.stage.module = pe.shader;
-    pipeline_cinfo.stage.pName = func_name.c_str();
-    pipeline_cinfo.stage.pSpecializationInfo = nullptr;
-    pipeline_cinfo.layout = pe.pipeline_layout;
-    pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE;
-    pipeline_cinfo.basePipelineIndex = 0;
-    VULKAN_CALL(vkCreateComputePipelines(
-        e.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr, &(pe.pipeline)));
-    e.smap[func_name] = pe;
-    return pe;
-  }
-
- private:
-  // device specific entry
-  struct DeviceEntry {
-    VkDevice device{nullptr};
-    std::unordered_map<std::string, PipelineEntry> smap;
-  };
-  // the binary data
-  std::vector<uint32_t> data_;
-  // function information table.
-  std::unordered_map<std::string, VulkanShader> smap_;
-  // function information table.
-  std::unordered_map<std::string, FunctionInfo> fmap_;
-  // The format
-  std::string fmt_{"vulkan"};
-  // The source
-  std::string source_;
-  // device local pipeline information.
-  std::vector<DeviceEntry> finfo_;
-  // internal mutex when updating the module
-  std::mutex mutex_;
-};
-
-// a wrapped function class to get packed func.
-class VulkanWrappedFunc {
- public:
-  // initialize the VULKAN function.
-  void Init(VulkanModuleNode* m,
-            std::shared_ptr<ModuleNode> sptr,
-            const std::string& func_name,
-            size_t num_buffer_args,
-            size_t num_pack_args,
-            const std::vector<std::string>& thread_axis_tags) {
-    w_ = vulkan::VulkanWorkspace::Global().get();
-    m_ = m;
-    sptr_ = sptr;
-    func_name_ = func_name;
-    num_buffer_args_ = num_buffer_args;
-    num_pack_args_ = num_pack_args;
-    thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
-  }
-  // invoke the function with void arguments
-  void operator()(TVMArgs args,
-                  TVMRetValue* rv,
-                  const ArgUnion* pack_args) const {
-    vulkan::VulkanThreadEntry* tls = vulkan::VulkanThreadEntry::ThreadLocal();
-    int device_id = tls->context.device_id;
-    CHECK_LT(device_id, kVulkanMaxNumDevice);
-    const vulkan::VulkanContext& vctx = w_->context_[device_id];
-    VulkanModuleNode::PipelineEntry& pe = scache_[device_id];
-    if (pe.pipeline == VK_NULL_HANDLE) {
-      pe = m_->GetPipeline(device_id, func_name_, num_pack_args_);
-    }
-    ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
-    vulkan::VulkanCommandBuffer* cmd = tls->CommandPool(device_id)->Alloc(
-        &(pe.descriptor_layout));
-
-    cmd->write_descriptor_set.dstSet = cmd->descriptor_set;
-
-    // setup descriptors
-    for (uint32_t i = 0; i < num_buffer_args_; ++i) {
-      void* buf = args[static_cast<int>(i)];
-      VkDescriptorBufferInfo binfo;
-      binfo.buffer = static_cast<vulkan::VulkanBuffer*>(buf)->buffer;
-      binfo.offset = 0;
-      binfo.range = VK_WHOLE_SIZE;
-      cmd->write_descriptor_set.dstBinding = i;
-      cmd->write_descriptor_set.pBufferInfo = &binfo;
-      vkUpdateDescriptorSets(
-          vctx.device, 1, &(cmd->write_descriptor_set), 0, nullptr);
-    }
-
-    // dispatch
-    VkCommandBufferBeginInfo cb_begin;
-    cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
-    cb_begin.pNext = nullptr;
-    cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
-    cb_begin.pInheritanceInfo = 0;
-
-    VkSubmitInfo cb_submit;
-    cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
-    cb_submit.pNext = nullptr;
-    cb_submit.waitSemaphoreCount = 0;
-    cb_submit.pWaitSemaphores = nullptr;
-    cb_submit.pWaitDstStageMask = 0;
-    cb_submit.commandBufferCount = 1;
-    cb_submit.pCommandBuffers = &(cmd->cmd_buffer);
-    cb_submit.signalSemaphoreCount = 0;
-    cb_submit.pSignalSemaphores = nullptr;
-    // 0: begin
-    VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
-    // 1: dispatch
-    vkCmdBindPipeline(
-        cmd->cmd_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, pe.pipeline);
-    vkCmdBindDescriptorSets(
-        cmd->cmd_buffer, VK_PIPELINE_BIND_POINT_COMPUTE,
-        pe.pipeline_layout, 0, 1, &(cmd->descriptor_set), 0, nullptr);
-    // bind push constant if necessary
-    if (num_pack_args_ != 0) {
-      vkCmdPushConstants(
-          cmd->cmd_buffer,
-          pe.pipeline_layout,
-          VK_SHADER_STAGE_COMPUTE_BIT,
-          0, num_pack_args_ * sizeof(ArgUnion),
-          pack_args);
-    }
-    vkCmdDispatch(
-        cmd->cmd_buffer, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
-    // 2: barrier(compute->compute|transfer)
-    VkMemoryBarrier barrier_info;
-    barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
-    barrier_info.pNext = nullptr;
-    barrier_info.srcAccessMask =
-        VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
-    barrier_info.dstAccessMask =
-        (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
-         VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
-    vkCmdPipelineBarrier(
-        cmd->cmd_buffer,
-        VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
-        VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
-        0, 1, &barrier_info, 0, nullptr, 0, nullptr);
-    // 3: end
-    VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
-    // 4: submit with cmd->fence
-    VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
-  }
-
- private:
-  // Reference to global workspace.
-  vulkan::VulkanWorkspace* w_;
-  // internal module
-  VulkanModuleNode* m_;
-  // the resource holder
-  std::shared_ptr<ModuleNode> sptr_;
-  // The name of the function.
-  std::string func_name_;
-  // Number of buffer arguments
-  size_t num_buffer_args_;
-  // number of packed arguments.
-  size_t num_pack_args_;
-  // Device state cache per device.
-  // mark as mutable, to enable lazy initialization
-  mutable std::array<VulkanModuleNode::PipelineEntry, kVulkanMaxNumDevice> scache_;
-  // thread axis configuration
-  ThreadAxisConfig thread_axis_cfg_;
-};
-
-PackedFunc VulkanModuleNode::GetFunction(
-      const std::string& name,
-      const std::shared_ptr<ModuleNode>& sptr_to_self) {
-  CHECK_EQ(sptr_to_self.get(), this);
-  CHECK_NE(name, symbol::tvm_module_main)
-      << "Device function do not have main";
-  auto it = fmap_.find(name);
-  if (it == fmap_.end()) return PackedFunc();
-  const FunctionInfo& info = it->second;
-  VulkanWrappedFunc f;
-  size_t num_buffer_args = NumBufferArgs(info.arg_types);
-  f.Init(this, sptr_to_self, name,
-         num_buffer_args, info.arg_types.size() - num_buffer_args,
-         info.thread_axis_tags);
-  return PackFuncNonBufferArg(f, info.arg_types);
-}
-
-Module VulkanModuleCreate(
-    std::unordered_map<std::string, VulkanShader> smap,
-    std::unordered_map<std::string, FunctionInfo> fmap,
-    std::string source) {
-  vulkan::VulkanWorkspace::Global()->Init();
-  std::shared_ptr<VulkanModuleNode> n =
-      std::make_shared<VulkanModuleNode>(smap, fmap, source);
-  return Module(n);
-}
-
-// Load module from module.
-Module VulkanModuleLoadFile(const std::string& file_name,
-                            const std::string& format) {
-  std::string data;
-  std::unordered_map<std::string, VulkanShader> smap;
-  std::unordered_map<std::string, FunctionInfo> fmap;
-  std::string fmt = GetFileFormat(file_name, format);
-  std::string meta_file = GetMetaFilePath(file_name);
-  LoadBinaryFromFile(file_name, &data);
-  LoadMetaDataFromFile(meta_file, &fmap);
-  dmlc::MemoryStringStream fs(&data);
-  dmlc::Stream* stream = &fs;
-  uint32_t magic;
-  stream->Read(&magic);
-  CHECK_EQ(magic, kVulkanModuleMagic)
-      << "VulkanModule Magic mismatch";
-  stream->Read(&smap);
-  return VulkanModuleCreate(smap, fmap, "");
-}
-
-Module VulkanModuleLoadBinary(void* strm) {
-  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
-  std::unordered_map<std::string, VulkanShader> smap;
-  std::unordered_map<std::string, FunctionInfo> fmap;
-
-  std::string fmt;
-  stream->Read(&fmt);
-  stream->Read(&fmap);
-  stream->Read(&smap);
-  return VulkanModuleCreate(smap, fmap, "");
-}
-
-TVM_REGISTER_GLOBAL("module.loadfile_vulkan")
-.set_body_typed(VulkanModuleLoadFile);
-
-TVM_REGISTER_GLOBAL("module.loadbinary_vulkan")
-.set_body_typed(VulkanModuleLoadBinary);
-}  // namespace runtime
-}  // namespace tvm
diff --git a/src/runtime/vulkan/vulkan_module.h b/src/runtime/vulkan/vulkan_module.h
index 13e5ce6..15c9ec3 100644 (file)
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  * specific language governing permissions and limitations
  * under the License.
  */
+#pragma once
 
-/*!
- *  Copyright (c) 2017 by Contributors
- * \file metal_module.h
- * \brief Execution handling of Metal kernels
- */
-#ifndef TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
-#define TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
-
-#include <tvm/runtime/packed_func.h>
-#include <dmlc/type_traits.h>
-#include <memory>
-#include <vector>
 #include <string>
 #include <unordered_map>
+
 #include "../meta_data.h"
+#include "vulkan_shader.h"
 
 namespace tvm {
 namespace runtime {
-/*! \brief Maximum number of GPU supported in VulkanModule. */
-static constexpr const int kVulkanMaxNumDevice = 8;
-
-/*! \brief TVM Vulkan binary pack magic number */
-static constexpr const int kVulkanModuleMagic = 0x02700027;
-
-/*!
- * \brief A single VK shader program
- *
- *  Due to the global resource declaration.
- *  Current SPIRV only allows one entry program per shader,
- *  making it less useful for a Module like system.
- *
- *  Instead we pass in map of str->VulkanShader until
- *  there is a native solution available.
- */
-struct VulkanShader {
-  /*! \brief header flag */
-  uint32_t flag{0};
-  /*! \brief Data segment */
-  std::vector<uint32_t> data;
+namespace vulkan {
+Module VulkanModuleCreate(std::unordered_map<std::string, VulkanShader> smap,
+                          std::unordered_map<std::string, FunctionInfo> fmap, std::string source);
 
-  void Save(dmlc::Stream *writer) const;
-  bool Load(dmlc::Stream *reader);
-};
+}  // namespace vulkan
 
-/*!
- * \brief create a metal module from data.
- *
- * \param pmap The program map.
- * \param fmap The function information map.
- * \param source Optional, source code.
- */
-Module VulkanModuleCreate(
-    std::unordered_map<std::string, VulkanShader> smap,
-    std::unordered_map<std::string, FunctionInfo> fmap,
-    std::string source);
+using vulkan::VulkanModuleCreate;
 }  // namespace runtime
 }  // namespace tvm
-
-namespace dmlc {
-DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::VulkanShader, true);
-}  // namespace dmlc
-
-#endif  // TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
diff --git a/src/runtime/vulkan/vulkan_shader.h b/src/runtime/vulkan/vulkan_shader.h
new file mode 100644 (file)
index 0000000..1b2e454
--- /dev/null
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#pragma once
+
+#include <dmlc/io.h>
+#include <dmlc/logging.h>
+#include <dmlc/type_traits.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/packed_func.h>
+
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+namespace vulkan {
+
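+/*!
+ * \brief A single Vulkan shader program.
+ *
+ *  Due to the global resource declaration, current SPIR-V allows only one
+ *  entry point per shader, which makes it less useful for a Module-like
+ *  system. Instead we pass a map of name -> VulkanShader until a native
+ *  solution is available.
+ */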
+struct VulkanShader {
+  /*! \brief header flag */
+  uint32_t flag{0};
+  /*! \brief Data segment */
+  std::vector<uint32_t> data;
+
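+  // Serialization must stay symmetric: Save writes (flag, data) and Load
+  // reads them back in the same order.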
+  void Save(dmlc::Stream* writer) const {
+    writer->Write(flag);
+    writer->Write(data);
+  }
+  bool Load(dmlc::Stream* reader) {
+    if (!reader->Read(&flag)) return false;
+    if (!reader->Read(&data)) return false;
+    return true;
+  }
+};
+
+}  // namespace vulkan
+
+using vulkan::VulkanShader;
+}  // namespace runtime
+}  // namespace tvm
+
+namespace dmlc {
+DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::vulkan::VulkanShader, true);
+}  // namespace dmlc
diff --git a/src/runtime/vulkan/vulkan_stream.h b/src/runtime/vulkan/vulkan_stream.h
new file mode 100644 (file)
index 0000000..512f44f
--- /dev/null
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#pragma once
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <vector>
+
+#include "vulkan_common.h"
+
+namespace tvm {
+namespace runtime {
+namespace vulkan {
+
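+// Per-stream recording state: a primary command buffer plus the fence used to
+// signal completion of its submission back to the host.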
+class VulkanStreamState {
+ public:
+  VkCommandBuffer cmd_buffer_;
+  VkFence fence_;
+};
+
+// Tracks state that should be used at most once per stream between
+// synchronizations: a descriptor set together with the buffers bound to it,
+// so the stream can detect conflicting reuse of a descriptor set.
+struct VulkanStreamToken {
+  VkDescriptorSet descriptor_set_{VK_NULL_HANDLE};
+  std::vector<VkBuffer> buffers_;
+};
+
+class VulkanStream {
+ public:
+  explicit VulkanStream(const VulkanContext* vctx)
+      : vctx_(vctx), state_(new VulkanStreamState()) {
+    // create command pool
+    VkCommandPoolCreateInfo cmd_pool_cinfo;
+    cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
+    cmd_pool_cinfo.pNext = nullptr;
+    cmd_pool_cinfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
+    cmd_pool_cinfo.queueFamilyIndex = vctx_->queue_family_index;
+    VULKAN_CALL(vkCreateCommandPool(vctx_->device, &cmd_pool_cinfo, nullptr, &cmd_pool_));
+
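+    // Allocate a single primary command buffer; it is reset and re-recorded
+    // across submissions on this stream.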
+    VkCommandBufferAllocateInfo buffer_alloc_info;
+    buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
+    buffer_alloc_info.pNext = nullptr;
+    buffer_alloc_info.commandPool = cmd_pool_;
+    buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
+    buffer_alloc_info.commandBufferCount = 1;
+    VULKAN_CALL(
+        vkAllocateCommandBuffers(vctx_->device, &buffer_alloc_info, &(state_->cmd_buffer_)));
+
+    VkFenceCreateInfo fence_cinfo;
+    fence_cinfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
+    fence_cinfo.pNext = nullptr;
+    fence_cinfo.flags = 0;  // start unsignaled (no VK_FENCE_CREATE_SIGNALED_BIT)
+    VULKAN_CALL(vkCreateFence(vctx_->device, &fence_cinfo, nullptr, &(state_->fence_)));
+
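+    // Begin recording right away so that immediate-mode launches can record
+    // straight into the live command buffer.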
+    VkCommandBufferBeginInfo cb_begin;
+    cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
+    cb_begin.pNext = nullptr;
+    cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
+    cb_begin.pInheritanceInfo = nullptr;
+    VULKAN_CALL(vkBeginCommandBuffer(state_->cmd_buffer_, &cb_begin));
+  }
+
+  ~VulkanStream() {
+    vkDestroyFence(vctx_->device, state_->fence_, nullptr);
+    vkDestroyCommandPool(vctx_->device, cmd_pool_, nullptr);
+  }
+
+  // Launch the kernel on the current stream. In immediate mode the kernel
+  // records into the live command buffer right away; otherwise it is queued
+  // and recorded at the next Synchronize().
+  void Launch(const std::function<void(VulkanStreamState*)>& kernel) {
+    if (vctx_->UseImmediate()) {
+      kernel(state_.get());
+    } else {
+      deferred_kernels_.push_back(kernel);
+    }
+  }
+
+  // Launch a deferred kernel on the current stream. The token describes the
+  // descriptor set and buffer bindings the kernel uses, so that reusing a
+  // descriptor set with different buffers forces a Synchronize() first.
+  void LaunchDeferred(const std::function<void()>& deferred_initializer,
+                      const std::function<void(VulkanStreamState*)>& deferred_kernel,
+                      const VulkanStreamToken& deferred_token) {
+    CHECK(!vctx_->UseImmediate());
+
+    // We cannot defer this kernel if an already-deferred kernel uses the same
+    // descriptor set with different buffers; flush the stream first so the
+    // earlier bindings are not clobbered.
+    if (std::any_of(deferred_tokens_.begin(), deferred_tokens_.end(),
+                    [&](const VulkanStreamToken& token) {
+                      return token.descriptor_set_ == deferred_token.descriptor_set_ &&
+                             token.buffers_ != deferred_token.buffers_;
+                    })) {
+      Synchronize();
+    }
+
+    // It is unnecessary to invoke our initializer if we have a matching token.
+    if (!std::any_of(deferred_tokens_.begin(), deferred_tokens_.end(),
+                     [&](const VulkanStreamToken& token) {
+                       // A matching descriptor set with matching buffers
+                       // means the bindings are already initialized.
+                       return token.descriptor_set_ == deferred_token.descriptor_set_ &&
+                              token.buffers_ == deferred_token.buffers_;
+                     })) {
+      deferred_initializer();
+    }
+
+    deferred_kernels_.push_back(deferred_kernel);
+    deferred_tokens_.push_back(deferred_token);
+  }
+
+  // Synchronize the current stream `state_` with respect to the host: record
+  // any deferred kernels, submit the command buffer, and block until its
+  // fence signals.
+  void Synchronize() {
+    if (!vctx_->UseImmediate()) {
+      for (const auto& deferred_kernel : deferred_kernels_) {
+        deferred_kernel(state_.get());
+      }
+      deferred_kernels_.clear();
+      deferred_tokens_.clear();
+    } else {
+      DCHECK_EQ(deferred_kernels_.size(), 0);
+      DCHECK_EQ(deferred_tokens_.size(), 0);
+    }
+
+    VULKAN_CALL(vkEndCommandBuffer(state_->cmd_buffer_));
+    VkSubmitInfo cb_submit;
+    cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
+    cb_submit.pNext = nullptr;
+    cb_submit.waitSemaphoreCount = 0;
+    cb_submit.pWaitSemaphores = nullptr;
+    cb_submit.pWaitDstStageMask = nullptr;
+    cb_submit.commandBufferCount = 1;
+    cb_submit.pCommandBuffers = &(state_->cmd_buffer_);
+    cb_submit.signalSemaphoreCount = 0;
+    cb_submit.pSignalSemaphores = nullptr;
+
+    {
+      // Multiple streams (on different threads) use the same VulkanContext
+      // instance, so we need to externally synchronize accesses.
+      std::lock_guard<std::mutex> g(*(vctx_->queue_mutex));
+      VULKAN_CALL(vkQueueSubmit(vctx_->queue, 1, &cb_submit, state_->fence_));
+    }
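+    // Wait in ~1s slices (2^30 ns), looping on VK_TIMEOUT so long-running
+    // kernels do not trip the timeout.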
+    uint64_t timeout = 1UL << 30UL;
+    VkResult res;
+    do {
+      res = vkWaitForFences(vctx_->device, 1, &(state_->fence_), 0, timeout);
+    } while (res == VK_TIMEOUT);
+    VULKAN_CHECK_ERROR(res);
+    VULKAN_CALL(vkResetCommandBuffer(state_->cmd_buffer_, 0));
+    VULKAN_CALL(vkResetFences(vctx_->device, 1, &(state_->fence_)));
+
+    // Re-initialize the command buffer
+    VkCommandBufferBeginInfo cb_begin;
+    cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
+    cb_begin.pNext = nullptr;
+    cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
+    cb_begin.pInheritanceInfo = nullptr;
+    VULKAN_CALL(vkBeginCommandBuffer(state_->cmd_buffer_, &cb_begin));
+  }
+
+ private:
+  const VulkanContext* vctx_;
+  std::unique_ptr<VulkanStreamState> state_;
+  std::vector<VulkanStreamToken> deferred_tokens_;
+  std::vector<std::function<void(VulkanStreamState*)>> deferred_kernels_;
+  VkCommandPool cmd_pool_;
+};
+
+}  // namespace vulkan
+}  // namespace runtime
+}  // namespace tvm
diff --git a/tests/python/unittest/test_codegen_vulkan.py b/tests/python/unittest/test_codegen_vulkan.py
index 2d7edff..7c6de2e 100644 (file)
@@ -16,6 +16,7 @@
 # under the License.
 import tvm
 import re
+import numpy as np
 
 
 def test_vector_comparison():
@@ -54,5 +55,119 @@ def test_vector_comparison():
     check_correct_assembly('float16')
 
 
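+# Thread/block binding axes shared by the schedules in the tests below.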
+tx = tvm.thread_axis("threadIdx.x")
+bx = tvm.thread_axis("blockIdx.x")
+
+
+def test_vulkan_copy():
+    def check_vulkan(dtype, n):
+        if not tvm.vulkan(0).exist or not tvm.module.enabled("vulkan"):
+            print("skip because Vulkan is not enabled...")
+            return
+        A = tvm.placeholder((n,), name='A', dtype=dtype)
+        ctx = tvm.vulkan(0)
+        a_np = np.random.uniform(size=(n,)).astype(A.dtype)
+        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(a_np)
+        b_np = a.asnumpy()
+        tvm.testing.assert_allclose(a_np, b_np)
+        tvm.testing.assert_allclose(a_np, a.asnumpy())
+
+    for _ in range(100):
+        dtype = np.random.choice(["float32", "float16", "int8", "int32"])
+        logN = np.random.randint(1, 15)
+        perturb = np.random.uniform(low=0.5, high=1.5)
+        check_vulkan(dtype, int(perturb * (2 ** logN)))
+
+
+def test_vulkan_vectorize_add():
+    num_thread = 8
+
+    def check_vulkan(dtype, n, lanes):
+        if not tvm.vulkan(0).exist or not tvm.module.enabled("vulkan"):
+            print("skip because Vulkan is not enabled...")
+            return
+        A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
+        B = tvm.compute((n,), lambda i: A[i] + tvm.const(1, A.dtype), name='B')
+        s = tvm.create_schedule(B.op)
+        xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
+        s[B].bind(xo, bx)
+        s[B].bind(xi, tx)
+        fun = tvm.build(s, [A, B], "vulkan")
+        ctx = tvm.vulkan(0)
+        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(
+            np.random.uniform(size=(n, lanes)))
+        c = tvm.nd.empty((n,), B.dtype, ctx)
+        fun(a, c)
+        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)
+
+    check_vulkan("float32", 64, 2)
+    check_vulkan("float16", 64, 2)
+
+
+def test_vulkan_stress():
+    """
+    Launch a randomized test with multiple kernels per stream, multiple uses of
+    kernels per stream, over multiple threads.
+    """
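+    # Stresses per-thread VulkanStream state; when immediate mode is
+    # unavailable this also exercises the deferred-token path.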
+    import random
+    import threading
+    n = 1024
+    num_thread = 64
+
+    def run_stress():
+        def worker():
+            if not tvm.vulkan(0).exist or not tvm.module.enabled("vulkan"):
+                print("skip because Vulkan is not enabled...")
+                return
+            A = tvm.placeholder((n,), name='A', dtype="float32")
+            B = tvm.placeholder((n,), name='B', dtype="float32")
+            functions = [
+                (lambda: tvm.compute((n,), lambda i: 2 * A[i] + 3 * B[i]),
+                 lambda a, b: 2 * a + 3 * b),
+                (lambda: tvm.compute((n,), lambda i: A[i] + B[i]),
+                 lambda a, b: a + b),
+                (lambda: tvm.compute((n,), lambda i: A[i] + 2 * B[i]),
+                 lambda a, b: a + 2 * b),
+            ]
+
+            def build_f(f_ref):
+                (C_f, ref) = f_ref
+                C = C_f()
+                s = tvm.create_schedule(C.op)
+                xo, xi = s[C].split(C.op.axis[0], factor=num_thread)
+                s[C].bind(xo, bx)
+                s[C].bind(xi, tx)
+                fun = tvm.build(s, [A, B, C], "vulkan")
+                return (fun, ref)
+
+            fs = [build_f(random.choice(functions))
+                  for _ in range(np.random.randint(low=1, high=10))]
+            ctx = tvm.vulkan(0)
+            a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(
+                np.random.uniform(size=(n,)))
+            b = tvm.nd.empty((n,), B.dtype, ctx).copyfrom(
+                np.random.uniform(size=(n,)))
+            cs = [tvm.nd.empty((n,), A.dtype, ctx) for _ in fs]
+            for ((f, _), c) in zip(fs, cs):
+                f(a, b, c)
+
+            for ((_, ref), c) in zip(fs, cs):
+                tvm.testing.assert_allclose(
+                    c.asnumpy(), ref(a.asnumpy(), b.asnumpy()))
+
+        ts = [threading.Thread(target=worker)
+              for _ in range(np.random.randint(1, 10))]
+        for t in ts:
+            t.start()
+        for t in ts:
+            t.join()
+
+    run_stress()
+
+
 if __name__ == "__main__":
     test_vector_comparison()
+    test_vulkan_copy()
+    test_vulkan_vectorize_add()
+    test_vulkan_stress()