#endif
#ifdef TVM_VULKAN_RUNTIME
-#include "../src/runtime/vulkan/vulkan_device_api.cc"
-#include "../src/runtime/vulkan/vulkan_module.cc"
+#include "../src/runtime/vulkan/vulkan.cc"
#endif
#ifdef USE_SORT
# Be compatible with older versions of CMake
find_vulkan(${USE_VULKAN})
+# Extra Vulkan runtime options, exposed for advanced users.
+tvm_option(USE_VULKAN_IMMEDIATE_MODE "Use Vulkan immediate mode (VK_KHR_push_descriptor extension)" ON IF USE_VULKAN)
+tvm_option(USE_VULKAN_DEDICATED_ALLOCATION "Use Vulkan dedicated allocations" ON IF USE_VULKAN)
+tvm_option(USE_VULKAN_VALIDATION "Enable Vulkan API validation layers" OFF IF USE_VULKAN)
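+# Example (illustrative invocation, not part of this change):
+#   cmake -DUSE_VULKAN=ON -DUSE_VULKAN_VALIDATION=ON ..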
+
if(Vulkan_FOUND)
# always set the includedir
# avoid global retrigger of cmake
if(NOT Vulkan_FOUND)
message(FATAL_ERROR "Cannot find Vulkan, USE_VULKAN=" ${USE_VULKAN})
endif()
- message(STATUS "Build with VULKAN support")
- file(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/*.cc)
+ message(STATUS "Build with Vulkan support")
+ file(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/vulkan.cc)
file(GLOB COMPILER_VULKAN_SRCS src/codegen/spirv/*.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_VULKAN_SRCS})
list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS})
-
list(APPEND TVM_LINKER_LIBS ${Vulkan_SPIRV_TOOLS_LIBRARY})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${Vulkan_LIBRARY})
+
+ if(USE_VULKAN_IMMEDIATE_MODE)
+ message(STATUS "Build with Vulkan immediate mode")
+ add_definitions(-DUSE_VULKAN_IMMEDIATE_MODE=1)
+ endif()
+ if(USE_VULKAN_DEDICATED_ALLOCATION)
+ message(STATUS "Build with Vulkan dedicated allocation")
+ add_definitions(-DUSE_VULKAN_DEDICATED_ALLOCATION=1)
+ endif()
+ if(USE_VULKAN_VALIDATION)
+ message(STATUS "Build with Vulkan API validation")
+ add_definitions(-DUSE_VULKAN_VALIDATION=1)
+ endif()
endif(USE_VULKAN)
#include "codegen_spirv.h"
#include "../build_common.h"
+
+#include "../../runtime/vulkan/vulkan_shader.h"
#include "../../runtime/vulkan/vulkan_module.h"
namespace tvm {
void IRBuilder::InitHeader() {
CHECK_EQ(header_.size(), 0U);
header_.push_back(spv::MagicNumber);
- header_.push_back(spv::Version);
+  // Use SPIR-V v1.0. This needs to be kept in sync with (or at least not
+  // ahead of) `VkApplicationInfo.apiVersion` in `vulkan.cc` to ensure Vulkan
+  // API validation passes.
+ header_.push_back(0x10000);
// generator: set to 0, unknown
header_.push_back(0U);
// Bound: set during Finalize
--- /dev/null
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements. See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership. The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License. You may obtain a copy of the License at -->
+
+<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied. See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+
+## Components
+
+### VulkanDeviceAPI
+
+Implements the TVM DeviceAPI interface and owns the core Vulkan data
+structures. It is responsible for initializing the Vulkan instance and
+devices and for querying the available extensions.
+
+### VulkanThreadEntry
+
+Thread-local state for the Vulkan runtime. Maintains a per-device staging
+buffer (used for host-device copies) and a per-device VulkanStream.
+
+### VulkanWrappedFunc
+
+Responsible for launching compute kernels. It obtains a VulkanPipeline
+instance from the VulkanModuleNode and launches the kernel (via immediate or
+deferred mode) on the active VulkanStream instance.
+
+## Stream execution in the Vulkan programming model
+
+The natural model for a TVM DeviceAPI implementation follows the CUDA API
+model: we launch "kernels" onto an (implicit or explicit) "stream", which
+executes them asynchronously with respect to the host but in order with
+respect to the stream, and we explicitly synchronize the stream with respect
+to the host. We simulate this behaviour in Vulkan by maintaining a
+thread-local `vkCommandBuffer` instance and queueing up kernel invocations in
+it (or eagerly executing them, depending on the availability of the
+`VK_KHR_push_descriptor` extension). When we synchronize the stream, we end
+the command buffer recording, submit it to the device queue, and wait on the
+corresponding fence.
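+
+A minimal sketch of this model using the runtime's `VulkanStream` interface
+(the dispatch arguments and surrounding setup are illustrative, not literal
+runtime code):
+
+```c++
+// Obtain the thread-local stream for this device.
+VulkanStream* stream = VulkanThreadEntry::ThreadLocal()->Stream(device_id);
+
+// Launch records commands into the thread-local vkCommandBuffer (or, when
+// VK_KHR_push_descriptor is available, executes them eagerly).
+stream->Launch([&](VulkanStreamState* state) {
+  vkCmdDispatch(state->cmd_buffer_, grid_x, grid_y, grid_z);
+});
+
+// Synchronize ends recording, submits the command buffer to the device
+// queue, and waits on the corresponding fence.
+stream->Synchronize();
+```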
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <vulkan/vulkan.h>
+#include <dmlc/memory_io.h>
+#include <dmlc/thread_local.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include <array>
+#include <cstring>
+
+
+#include "../file_util.h"
+#include "../pack_args.h"
+#include "../thread_storage_scope.h"
+#include "../workspace_pool.h"
+
+#include "vulkan_common.h"
+#include "vulkan_module.h"
+#include "vulkan_shader.h"
+#include "vulkan_stream.h"
+
+namespace tvm {
+namespace runtime {
+namespace vulkan {
+
+/*! \brief Maximum number of GPUs supported by VulkanModule. */
+static constexpr const int kVulkanMaxNumDevice = 8;
+
+/*! \brief TVM Vulkan binary pack magic number */
+static constexpr const int kVulkanModuleMagic = 0x02700027;
+
+class VulkanThreadEntry {
+ public:
+ VulkanThreadEntry();
+ static VulkanThreadEntry* ThreadLocal();
+
+ ~VulkanThreadEntry() {
+    // Because the thread entry refers to the device API, the streams and
+    // staging buffers must be destroyed before the Vulkan instance and
+    // devices are torn down; release them explicitly here to enforce that
+    // destruction order.
+ streams_.clear();
+ for (const auto& kv : staging_buffers_) {
+ if (!kv.second) {
+ continue;
+ }
+ auto& buf = *(kv.second);
+ if (buf.host_addr != nullptr) {
+ vkUnmapMemory(buf.device, buf.memory);
+ }
+ if (buf.memory != VK_NULL_HANDLE) {
+ vkFreeMemory(buf.device, buf.memory, nullptr);
+ }
+ if (buf.buffer != VK_NULL_HANDLE) {
+ vkDestroyBuffer(buf.device, buf.buffer, nullptr);
+ }
+ }
+ }
+
+ TVMContext ctx;
+ WorkspacePool pool;
+ VulkanStream* Stream(size_t device_id);
+ VulkanStagingBuffer* StagingBuffer(int device_id, size_t size);
+
+ private:
+ std::unordered_map<size_t, std::unique_ptr<VulkanStream>> streams_;
+ std::unordered_map<size_t, std::unique_ptr<VulkanStagingBuffer>> staging_buffers_;
+};
+
+struct VulkanBuffer {
+ VkBuffer buffer{VK_NULL_HANDLE};
+ VkDeviceMemory memory{VK_NULL_HANDLE};
+};
+
+struct VulkanPipeline {
+ VulkanContext* vctx_{nullptr};
+ VkShaderModule shader{VK_NULL_HANDLE};
+ VkDescriptorSetLayout descriptor_set_layout{VK_NULL_HANDLE};
+ VkDescriptorPool descriptor_pool{VK_NULL_HANDLE};
+ VkDescriptorSet descriptor_set{VK_NULL_HANDLE};
+ VkPipelineLayout pipeline_layout{VK_NULL_HANDLE};
+ VkPipeline pipeline{VK_NULL_HANDLE};
+ VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE};
+};
+
+typedef dmlc::ThreadLocalStore<VulkanThreadEntry> VulkanThreadStore;
+
+class VulkanDeviceAPI final : public DeviceAPI {
+ public:
+ VulkanDeviceAPI();
+ ~VulkanDeviceAPI() {
+ for (auto& vctx : context_) {
+ vkDestroyDevice(vctx.device, nullptr);
+ }
+ if (instance_) {
+ vkDestroyInstance(instance_, nullptr);
+ }
+ }
+ void SetDevice(TVMContext ctx) final { VulkanThreadEntry::ThreadLocal()->ctx = ctx; }
+ void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
+ void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) final {
+ const auto& vctx = context(ctx.device_id);
+ VkBufferCreateInfo info;
+ info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
+ info.pNext = nullptr;
+ info.flags = 0;
+ info.size = nbytes;
+ info.queueFamilyIndexCount = 1;
+ info.pQueueFamilyIndices = &(vctx.queue_family_index);
+ info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
+ info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
+ VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
+ // create buffer
+ VkBuffer buffer;
+ VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
+ // bind to memory
+ VkBufferMemoryRequirementsInfo2KHR req_info2;
+ req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
+ req_info2.pNext = 0;
+ req_info2.buffer = buffer;
+
+ VkMemoryRequirements2KHR req2;
+ req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
+ req2.pNext = 0;
+
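+    // Chain the dedicated-allocation query into req2's pNext chain so that
+    // vkGetBufferMemoryRequirements2KHR also reports whether this buffer
+    // prefers or requires a dedicated allocation (VK_KHR_dedicated_allocation).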
+ VkMemoryDedicatedRequirementsKHR dedicated_req;
+ dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
+ dedicated_req.pNext = 0;
+ req2.pNext = &dedicated_req;
+
+ bool dedicated_allocation = false;
+ if (vctx.get_buffer_memory_requirements_2_functions) {
+ vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
+ vctx.device, &req_info2, &req2);
+ dedicated_allocation =
+ dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
+ }
+
+ VkDeviceMemory memory;
+ if (!dedicated_allocation) {
+ VkMemoryAllocateInfo minfo;
+ minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+ minfo.pNext = nullptr;
+ minfo.allocationSize = nbytes;
+ minfo.memoryTypeIndex = vctx.compute_mtype_index;
+ VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
+ } else {
+ VkMemoryAllocateInfo minfo;
+ minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+ minfo.pNext = nullptr;
+ minfo.allocationSize = req2.memoryRequirements.size;
+ minfo.memoryTypeIndex = vctx.compute_mtype_index;
+
+ VkMemoryDedicatedAllocateInfoKHR mdinfo;
+ mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
+ mdinfo.pNext = 0;
+ mdinfo.image = 0;
+ mdinfo.buffer = buffer;
+ minfo.pNext = &mdinfo;
+ VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
+ }
+ VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
+ VulkanBuffer* pbuf = new VulkanBuffer();
+ pbuf->memory = memory;
+ pbuf->buffer = buffer;
+ return pbuf;
+ }
+
+ void FreeDataSpace(TVMContext ctx, void* ptr) final {
+ const auto& vctx = context(ctx.device_id);
+ auto* pbuf = static_cast<VulkanBuffer*>(ptr);
+ vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr);
+ vkFreeMemory(vctx.device, pbuf->memory, nullptr);
+ delete pbuf;
+ }
+
+ void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
+ TVMContext ctx_from, TVMContext ctx_to, TVMType type_hint,
+ TVMStreamHandle stream) final {
+ CHECK(stream == nullptr);
+ TVMContext ctx = ctx_from;
+ if (ctx_from.device_type == kDLCPU) {
+ ctx = ctx_to;
+ }
+
+ int from_dev_type = static_cast<int>(ctx_from.device_type);
+ int to_dev_type = static_cast<int>(ctx_to.device_type);
+ if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) {
+ VulkanThreadEntry::ThreadLocal()
+ ->Stream(ctx_from.device_id)
+ ->Launch([=](VulkanStreamState* state) {
+ // 1: copy
+ const auto* from_buf = static_cast<const VulkanBuffer*>(from);
+ auto* to_buf = static_cast<VulkanBuffer*>(to);
+ VkBufferCopy copy_info;
+ copy_info.srcOffset = from_offset;
+ copy_info.dstOffset = to_offset;
+ copy_info.size = size;
+ vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, to_buf->buffer, 1, ©_info);
+ // 2: barrier(transfer-> compute|transfer)
+          CHECK_EQ(ctx_from.device_id, ctx_to.device_id) << "Vulkan disallows cross-device copies.";
+ VkMemoryBarrier barrier_info;
+ barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
+ barrier_info.pNext = nullptr;
+ barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
+ barrier_info.dstAccessMask =
+ (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
+ VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
+ vkCmdPipelineBarrier(
+ state->cmd_buffer_, VK_PIPELINE_STAGE_TRANSFER_BIT,
+ VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, 1,
+ &barrier_info, 0, nullptr, 0, nullptr);
+ });
+
+ } else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) {
+ const auto* from_buf = static_cast<const VulkanBuffer*>(from);
+ const auto& vctx = context(ctx_from.device_id);
+ auto* temp = VulkanThreadEntry::ThreadLocal()->StagingBuffer(ctx_from.device_id, size);
+ VulkanThreadEntry::ThreadLocal()
+ ->Stream(ctx_from.device_id)
+ ->Launch([&](VulkanStreamState* state) {
+ VkBufferCopy copy_info;
+ copy_info.srcOffset = from_offset;
+ copy_info.dstOffset = 0;
+ copy_info.size = size;
+ vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, temp->buffer, 1, ©_info);
+ });
+ VulkanThreadEntry::ThreadLocal()->Stream(ctx_from.device_id)->Synchronize();
+ if (!vctx.coherent_staging) {
+ VkMappedMemoryRange mrange;
+ mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
+ mrange.pNext = nullptr;
+ mrange.memory = temp->memory;
+ mrange.offset = 0;
+      mrange.size = VK_WHOLE_SIZE;
+ VULKAN_CALL(vkInvalidateMappedMemoryRanges(vctx.device, 1, &mrange));
+ }
+ memcpy(static_cast<char*>(to) + to_offset, static_cast<char*>(temp->host_addr), size);
+ } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) {
+ const auto& vctx = context(ctx_to.device_id);
+ const auto* to_buf = static_cast<const VulkanBuffer*>(to);
+ VulkanStagingBuffer* temp =
+ VulkanThreadEntry::ThreadLocal()->StagingBuffer(ctx_to.device_id, size);
+ memcpy(temp->host_addr, static_cast<const char*>(from) + from_offset, size);
+    // Flush on the host side if access is not coherent,
+    // so that CPU writes become visible to the GPU.
+ if (!vctx.coherent_staging) {
+ VkMappedMemoryRange mrange;
+ mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
+ mrange.pNext = nullptr;
+ mrange.memory = temp->memory;
+ mrange.offset = 0;
+      mrange.size = VK_WHOLE_SIZE;
+ VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange));
+ }
+
+ VulkanThreadEntry::ThreadLocal()
+        ->Stream(ctx_to.device_id)
+ ->Launch([&](VulkanStreamState* state) {
+ // 0: barrier(host->transfer)
+ VkMemoryBarrier barrier_info;
+ barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
+ barrier_info.pNext = nullptr;
+ barrier_info.srcAccessMask = 0;
+ barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
+ vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_HOST_BIT,
+ VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &barrier_info, 0, nullptr, 0,
+ nullptr);
+ // 1: copy
+ VkBufferCopy copy_info;
+ copy_info.srcOffset = 0;
+ copy_info.dstOffset = to_offset;
+ copy_info.size = size;
+ vkCmdCopyBuffer(state->cmd_buffer_, temp->buffer, to_buf->buffer, 1, ©_info);
+ });
+ // TODO(tulloch): should we instead make the staging buffer a property of the
+ // Stream? This would allow us to elide synchronizations here.
+    VulkanThreadEntry::ThreadLocal()->Stream(ctx_to.device_id)->Synchronize();
+ } else {
+ LOG(FATAL) << "Expect copy from/to Vulkan or between Vulkan"
+ << ", from=" << from_dev_type << ", to=" << to_dev_type;
+ }
+ }
+
+ // Always use the default stream
+ TVMStreamHandle CreateStream(TVMContext ctx) {
+ LOG(FATAL) << "Not implemented";
+ return nullptr;
+ }
+
+ void FreeStream(TVMContext ctx, TVMStreamHandle stream) {
+ LOG(FATAL) << "Not implemented";
+ return;
+ }
+
+ void SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst) {
+ LOG(FATAL) << "Not implemented";
+ return;
+ }
+
+ void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
+ CHECK(stream == nullptr);
+ VulkanThreadEntry::ThreadLocal()->Stream(ctx.device_id)->Synchronize();
+ }
+
+ void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
+ LOG(FATAL) << "Not implemented";
+ return;
+ }
+
+ void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
+ return VulkanThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
+ }
+
+ void FreeWorkspace(TVMContext ctx, void* data) final {
+ VulkanThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
+ }
+
+ static const std::shared_ptr<VulkanDeviceAPI>& Global() {
+ static std::shared_ptr<VulkanDeviceAPI> inst = std::make_shared<VulkanDeviceAPI>();
+ return inst;
+ }
+
+ const VulkanContext& context(size_t device_id) const {
+ CHECK_LT(device_id, context_.size());
+ return context_[device_id];
+ }
+
+ private:
+ VkInstance instance_{nullptr};
+  // The physical devices; these have a one-to-one mapping to device ids.
+ std::vector<VulkanContext> context_;
+};
+
+void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
+ size_t index = static_cast<size_t>(ctx.device_id);
+ if (kind == kExist) {
+ *rv = static_cast<int>(index < context_.size());
+ return;
+ }
+ CHECK_LT(index, context_.size()) << "Invalid device id " << index;
+ const auto& vctx = context(index);
+ switch (kind) {
+ case kMaxThreadsPerBlock: {
+ VkPhysicalDeviceProperties phy_prop;
+ vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
+ int64_t value = phy_prop.limits.maxComputeWorkGroupSize[0];
+ *rv = value;
+ break;
+ }
+ case kMaxSharedMemoryPerBlock: {
+ VkPhysicalDeviceProperties phy_prop;
+ vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
+ int64_t value = phy_prop.limits.maxComputeSharedMemorySize;
+ *rv = value;
+ break;
+ }
+ case kWarpSize: {
+ *rv = 1;
+ break;
+ }
+ case kComputeVersion: {
+ VkPhysicalDeviceProperties phy_prop;
+ vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop);
+ int64_t value = phy_prop.apiVersion;
+ std::ostringstream os;
+ os << VK_VERSION_MAJOR(value) << "." << VK_VERSION_MINOR(value) << "."
+ << VK_VERSION_PATCH(value);
+ *rv = os.str();
+ break;
+ }
+ case kDeviceName:
+ return;
+ case kMaxClockRate:
+ return;
+ case kMultiProcessorCount:
+ return;
+ case kExist:
+ break;
+ case kMaxThreadDimensions:
+ break;
+ }
+}
+
+VulkanDeviceAPI::VulkanDeviceAPI() {
+ VkApplicationInfo app_info;
+ app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
+ app_info.pNext = nullptr;
+ app_info.pApplicationName = "TVM";
+ app_info.applicationVersion = 0;
+ app_info.pEngineName = "";
+ app_info.engineVersion = 0;
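+  // NB: keep this in sync with the SPIR-V version emitted by the codegen
+  // (IRBuilder::InitHeader); the emitted SPIR-V version must not be ahead of
+  // this apiVersion, or API validation will fail.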
+ app_info.apiVersion = VK_MAKE_VERSION(1, 0, 0);
+
+ VkInstanceCreateInfo inst_info;
+ inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
+ inst_info.pNext = nullptr;
+ inst_info.flags = 0;
+
+ const auto layers = []() -> std::vector<const char*> {
+ uint32_t inst_layer_prop_count;
+ VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, nullptr));
+ std::vector<VkLayerProperties> inst_layer_prop(inst_layer_prop_count);
+ VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, inst_layer_prop.data()));
+ std::vector<const char*> l;
+ for (const auto& lp : inst_layer_prop) {
+      (void)lp;  // suppress unused variable warning when validation is off.
+#ifdef USE_VULKAN_VALIDATION
+ if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_standard_validation") == 0) {
+ l.push_back("VK_LAYER_LUNARG_standard_validation");
+ }
+ if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_parameter_validation") == 0) {
+ l.push_back("VK_LAYER_LUNARG_parameter_validation");
+ }
+ if (std::strcmp(lp.layerName, "VK_LAYER_KHRONOS_validation") == 0) {
+ l.push_back("VK_LAYER_KHRONOS_validation");
+ }
+#endif
+ }
+ return l;
+ }();
+
+ const auto instance_extensions = []() -> std::vector<const char*> {
+ uint32_t inst_extension_prop_count;
+ VULKAN_CALL(
+ vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, nullptr));
+ std::vector<VkExtensionProperties> inst_extension_prop(inst_extension_prop_count);
+ VULKAN_CALL(vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count,
+ inst_extension_prop.data()));
+ std::vector<const char*> extensions;
+ for (const auto& ip : inst_extension_prop) {
+ if (std::strcmp(ip.extensionName, "VK_KHR_get_physical_device_properties2") == 0) {
+ extensions.push_back("VK_KHR_get_physical_device_properties2");
+ }
+ }
+ return extensions;
+ }();
+
+ inst_info.pApplicationInfo = &app_info;
+ inst_info.enabledLayerCount = layers.size();
+ inst_info.ppEnabledLayerNames = layers.data();
+ inst_info.enabledExtensionCount = instance_extensions.size();
+ inst_info.ppEnabledExtensionNames = instance_extensions.data();
+
+ VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &instance_));
+
+ uint32_t phy_dev_count = 0;
+ VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, nullptr));
+ std::vector<VkPhysicalDevice> all_phy_devs(phy_dev_count);
+ VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, dmlc::BeginPtr(all_phy_devs)));
+ for (VkPhysicalDevice phy_dev : all_phy_devs) {
+ uint32_t queue_prop_count = 0;
+ vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, nullptr);
+ std::vector<VkQueueFamilyProperties> queue_props(queue_prop_count);
+ vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count,
+ dmlc::BeginPtr(queue_props));
+ uint32_t queue_family_index = 0;
+ std::vector<VkDeviceQueueCreateInfo> queue_create_info;
+ float priority = 1.0f;
+ for (uint32_t i = 0; i < queue_props.size(); i++) {
+ // find queues that support compute
+ if (VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) {
+ VkDeviceQueueCreateInfo info;
+ info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
+ info.pNext = nullptr;
+ info.flags = 0;
+ info.queueFamilyIndex = i;
+ info.queueCount = 1;
+ info.pQueuePriorities = &priority;
+
+        // only use the first available compute queue for now
+        if (queue_create_info.empty()) {
+          queue_family_index = i;
+        }
+        queue_create_info.push_back(info);
+ }
+ }
+ if (queue_create_info.size() == 0) continue;
+
+ VulkanContext ctx;
+ // setup context
+ ctx.phy_device = phy_dev;
+ vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop));
+
+ const auto extensions = [&]() {
+ uint32_t device_extension_prop_count;
+ VULKAN_CALL(vkEnumerateDeviceExtensionProperties(ctx.phy_device, nullptr,
+ &device_extension_prop_count, nullptr));
+ std::vector<VkExtensionProperties> device_extension_prop(device_extension_prop_count);
+ VULKAN_CALL(vkEnumerateDeviceExtensionProperties(
+ ctx.phy_device, nullptr, &device_extension_prop_count, device_extension_prop.data()));
+ std::vector<const char*> extensions;
+ for (const auto& dp : device_extension_prop) {
+ if ((std::strcmp(dp.extensionName, "VK_KHR_push_descriptor") == 0) && dp.specVersion > 0) {
+ extensions.push_back("VK_KHR_push_descriptor");
+ }
+ if ((std::strcmp(dp.extensionName, "VK_KHR_descriptor_update_template") == 0) &&
+ dp.specVersion > 0) {
+ extensions.push_back("VK_KHR_descriptor_update_template");
+ }
+ if ((std::strcmp(dp.extensionName, "VK_KHR_get_memory_requirements2") == 0) &&
+ dp.specVersion > 0) {
+ extensions.push_back("VK_KHR_get_memory_requirements2");
+ }
+ if ((std::strcmp(dp.extensionName, "VK_KHR_dedicated_allocation") == 0) &&
+ dp.specVersion > 0) {
+ extensions.push_back("VK_KHR_dedicated_allocation");
+ }
+ }
+ return extensions;
+ }();
+ VkDeviceCreateInfo device_create_info;
+ device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
+ device_create_info.pNext = nullptr;
+ device_create_info.flags = 0;
+ device_create_info.queueCreateInfoCount = static_cast<uint32_t>(queue_create_info.size());
+ device_create_info.pQueueCreateInfos = queue_create_info.data();
+ device_create_info.enabledLayerCount = 0;
+ device_create_info.ppEnabledLayerNames = nullptr;
+ device_create_info.enabledExtensionCount = extensions.size();
+ device_create_info.ppEnabledExtensionNames = extensions.data();
+ device_create_info.pEnabledFeatures = nullptr;
+ VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(ctx.device)));
+ ctx.queue_mutex.reset(new std::mutex());
+ vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue));
+ ctx.queue_family_index = queue_family_index;
+    // Find memory types suitable for staging and for compute buffers.
+ VkBuffer buffer;
+ VkMemoryRequirements req_staging, req_compute;
+ VkBufferCreateInfo info;
+ info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
+ info.pNext = nullptr;
+ info.flags = 0;
+ info.size = 1024;
+ info.queueFamilyIndexCount = 1;
+ info.pQueueFamilyIndices = &(ctx.queue_family_index);
+ info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
+
+ // get staging requirement
+ info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
+ VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer));
+ vkGetBufferMemoryRequirements(ctx.device, buffer, &req_staging);
+ vkDestroyBuffer(ctx.device, buffer, nullptr);
+ // get compute requirement
+ info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
+ VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
+ VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer));
+ vkGetBufferMemoryRequirements(ctx.device, buffer, &req_compute);
+ vkDestroyBuffer(ctx.device, buffer, nullptr);
+
+    // Query physical device memory properties and find a staging memory
+    // type that is host visible (it need not be coherent).
+ int win_rank = -1;
+ VkPhysicalDeviceMemoryProperties prop;
+ vkGetPhysicalDeviceMemoryProperties(ctx.phy_device, &prop);
+
+ for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
+ VkMemoryType ty = prop.memoryTypes[k];
+ size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
+ // host visible
+ if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
+      // match staging copy requirement
+ if (!(req_staging.memoryTypeBits & (1 << k))) continue;
+ if (heap_size < 1024) continue;
+ int rank = 0;
+ rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
+ if (rank > win_rank) {
+ win_rank = rank;
+ ctx.staging_mtype_index = k;
+ ctx.coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
+ }
+ }
+ CHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";
+
+ win_rank = -1;
+ for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
+ VkMemoryType ty = prop.memoryTypes[k];
+ size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
+      // device local
+      if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
+      // match compute requirement
+      if (!(req_compute.memoryTypeBits & (1 << k))) continue;
+ if (heap_size < 1024) continue;
+ int rank = 0;
+ // prefer not host visible
+ rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
+ if (rank > win_rank) {
+ win_rank = rank;
+ ctx.compute_mtype_index = k;
+ }
+ }
+ CHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device.";
+ auto has_extension = [&extensions](const char* query) {
+ return std::any_of(extensions.begin(), extensions.end(),
+ [&](const char* extension) { return std::strcmp(query, extension) == 0; });
+ };
+
+#ifdef USE_VULKAN_IMMEDIATE_MODE
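+  // These extension entry points are not exported by the Vulkan loader
+  // directly, so resolve them for this device via vkGetDeviceProcAddr.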
+ if (has_extension("VK_KHR_push_descriptor") &&
+ has_extension("VK_KHR_descriptor_update_template")) {
+ ctx.descriptor_template_khr_functions =
+ std::unique_ptr<VulkanDescriptorTemplateKHRFunctions>(
+ new VulkanDescriptorTemplateKHRFunctions());
+ ctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR =
+ CHECK_NOTNULL((PFN_vkCreateDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr(
+ ctx.device, "vkCreateDescriptorUpdateTemplateKHR"));
+ ctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR =
+ CHECK_NOTNULL((PFN_vkDestroyDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr(
+ ctx.device, "vkDestroyDescriptorUpdateTemplateKHR"));
+ ctx.descriptor_template_khr_functions->vkUpdateDescriptorSetWithTemplateKHR =
+ CHECK_NOTNULL((PFN_vkUpdateDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr(
+ ctx.device, "vkUpdateDescriptorSetWithTemplateKHR"));
+ ctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR =
+ CHECK_NOTNULL((PFN_vkCmdPushDescriptorSetWithTemplateKHR)vkGetDeviceProcAddr(
+ ctx.device, "vkCmdPushDescriptorSetWithTemplateKHR"));
+ }
+#endif
+
+#ifdef USE_VULKAN_DEDICATED_ALLOCATION
+ if (has_extension("VK_KHR_get_memory_requirements2") &&
+ has_extension("VK_KHR_dedicated_allocation")) {
+ ctx.get_buffer_memory_requirements_2_functions =
+ std::unique_ptr<VulkanGetBufferMemoryRequirements2Functions>(
+ new VulkanGetBufferMemoryRequirements2Functions());
+ ctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR =
+ CHECK_NOTNULL((PFN_vkGetBufferMemoryRequirements2KHR)vkGetDeviceProcAddr(
+ ctx.device, "vkGetBufferMemoryRequirements2KHR"));
+ }
+#endif
+ context_.push_back(std::move(ctx));
+ }
+
+  LOG(INFO) << "Initialized Vulkan with " << context_.size() << " devices.";
+ for (size_t i = 0; i < context_.size(); ++i) {
+ LOG(INFO) << "vulkan(" << i << ")=\'" << context_[i].phy_device_prop.deviceName
+ << "\' phy_dev_id=" << context_[i].phy_device
+ << " use_immediate=" << context_[i].UseImmediate();
+ }
+}
+class VulkanModuleNode;
+
+// a wrapped function class to get packed func.
+class VulkanWrappedFunc {
+ public:
+ void Init(VulkanModuleNode* m, std::shared_ptr<ModuleNode> sptr, const std::string& func_name,
+ size_t num_buffer_args, size_t num_pack_args,
+ const std::vector<std::string>& thread_axis_tags) {
+ m_ = m;
+ sptr_ = sptr;
+ func_name_ = func_name;
+ num_buffer_args_ = num_buffer_args;
+ num_pack_args_ = num_pack_args;
+ thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
+ }
+
+ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const;
+
+ private:
+ // internal module
+ VulkanModuleNode* m_;
+ // the resource holder
+ std::shared_ptr<ModuleNode> sptr_;
+  // The name of the function.
+ std::string func_name_;
+ // Number of buffer arguments
+ size_t num_buffer_args_;
+ // number of packed arguments.
+ size_t num_pack_args_;
+  // thread axis configuration
+  ThreadAxisConfig thread_axis_cfg_;
+
+  // Device state cache per device.
+  // Marked mutable to enable lazy initialization.
+  mutable std::array<std::shared_ptr<VulkanPipeline>, kVulkanMaxNumDevice> scache_;
+};
+
+// Multi-device enabled module.
+class VulkanModuleNode final : public runtime::ModuleNode {
+ public:
+ explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
+ : smap_(smap), fmap_(fmap), source_(source) {}
+
+ const char* type_key() const final { return "vulkan"; }
+
+ PackedFunc GetFunction(const std::string& name,
+ const std::shared_ptr<ModuleNode>& sptr_to_self) final {
+ CHECK_EQ(sptr_to_self.get(), this);
+    CHECK_NE(name, symbol::tvm_module_main) << "Device functions do not have main";
+ auto it = fmap_.find(name);
+ if (it == fmap_.end()) return PackedFunc();
+ const FunctionInfo& info = it->second;
+ VulkanWrappedFunc f;
+ size_t num_buffer_args = NumBufferArgs(info.arg_types);
+ f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args,
+ info.thread_axis_tags);
+ return PackFuncNonBufferArg(std::move(f), info.arg_types);
+ }
+
+ ~VulkanModuleNode() {
+ // cleanup vulkan related caches.
+ for (int device_id = 0; device_id < ecache_.size(); ++device_id) {
+ for (auto& kv : ecache_[device_id]) {
+ auto& pe = kv.second;
+ CHECK(pe);
+ const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
+
+ if (pe->descriptor_update_template != VK_NULL_HANDLE) {
+ vctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR(
+ vctx.device, pe->descriptor_update_template, nullptr);
+ }
+ vkDestroyPipeline(vctx.device, pe->pipeline, nullptr);
+ vkDestroyPipelineLayout(vctx.device, pe->pipeline_layout, nullptr);
+ vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr);
+ vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr);
+ vkDestroyShaderModule(vctx.device, pe->shader, nullptr);
+ }
+ }
+ }
+
+ std::shared_ptr<VulkanPipeline> GetPipeline(size_t device_id, const std::string& func_name,
+ size_t num_pack_args) {
+ const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
+ std::lock_guard<std::mutex> lock(mutex_);
+ const auto& cp = ecache_[device_id][func_name];
+ if (cp) {
+ return cp;
+ }
+ // Create new pipeline
+ auto pe = std::shared_ptr<VulkanPipeline>(new VulkanPipeline());
+ {
+ // create shader
+ auto sit = smap_.find(func_name);
+ CHECK(sit != smap_.end());
+ const std::vector<uint32_t>& data = sit->second.data;
+ VkShaderModuleCreateInfo shader_cinfo;
+ shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
+ shader_cinfo.pNext = nullptr;
+ shader_cinfo.flags = 0;
+ shader_cinfo.codeSize = data.size() * sizeof(uint32_t);
+ shader_cinfo.pCode = data.data();
+ VULKAN_CALL(vkCreateShaderModule(vctx.device, &shader_cinfo, nullptr, &(pe->shader)));
+ }
+ std::vector<VkDescriptorSetLayoutBinding> arg_binding;
+ std::vector<VkDescriptorUpdateTemplateEntryKHR> arg_template;
+ uint32_t num_pod = 0, num_buffer = 0;
+ {
+ auto fit = fmap_.find(func_name);
+ CHECK(fit != fmap_.end());
+ for (TVMType arg_type : fit->second.arg_types) {
+ if (arg_type.code == kHandle) {
+ {
+ VkDescriptorSetLayoutBinding bd;
+ bd.binding = num_buffer;
+ bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+ bd.descriptorCount = 1;
+ bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
+ bd.pImmutableSamplers = nullptr;
+ arg_binding.push_back(bd);
+ }
+ {
+ VkDescriptorUpdateTemplateEntryKHR tpl;
+ tpl.dstBinding = num_buffer;
+ tpl.dstArrayElement = 0;
+ tpl.descriptorCount = 1;
+ tpl.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+ tpl.offset = num_buffer * sizeof(VkDescriptorBufferInfo);
+ tpl.stride = sizeof(VkDescriptorBufferInfo);
+ arg_template.push_back(tpl);
+ }
+ ++num_buffer;
+ } else {
+ ++num_pod;
+ }
+ }
+ }
+
+ {
+ VkDescriptorSetLayoutCreateInfo descrip_cinfo;
+ descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
+ descrip_cinfo.pNext = nullptr;
+ descrip_cinfo.flags = 0;
+ if (vctx.UseImmediate()) {
+ descrip_cinfo.flags |= VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR;
+ }
+ descrip_cinfo.bindingCount = arg_binding.size();
+ descrip_cinfo.pBindings = arg_binding.data();
+ VULKAN_CALL(vkCreateDescriptorSetLayout(vctx.device, &descrip_cinfo, nullptr,
+ &(pe->descriptor_set_layout)));
+ }
+
+ {
+ VkDescriptorPoolSize pool_size;
+ pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+ pool_size.descriptorCount = arg_binding.size();
+ VkDescriptorPoolCreateInfo descrip_pool_cinfo;
+ descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
+ descrip_pool_cinfo.pNext = nullptr;
+ descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
+ descrip_pool_cinfo.maxSets = 1;
+ descrip_pool_cinfo.poolSizeCount = 1;
+ descrip_pool_cinfo.pPoolSizes = &pool_size;
+ VULKAN_CALL(vkCreateDescriptorPool(vctx.device, &descrip_pool_cinfo, nullptr,
+ &(pe->descriptor_pool)));
+ }
+
+ if (!vctx.UseImmediate()) {
+ VkDescriptorSetAllocateInfo alloc_info;
+ alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
+ alloc_info.pNext = nullptr;
+ alloc_info.descriptorPool = pe->descriptor_pool;
+ alloc_info.descriptorSetCount = 1;
+ alloc_info.pSetLayouts = &(pe->descriptor_set_layout);
+ VULKAN_CALL(vkAllocateDescriptorSets(vctx.device, &alloc_info, &(pe->descriptor_set)));
+ }
+
+ VkPushConstantRange crange;
+ crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
+ crange.offset = 0;
+ crange.size = sizeof(ArgUnion) * num_pack_args;
+
+ VkPipelineLayoutCreateInfo playout_cinfo;
+ playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
+ playout_cinfo.pNext = nullptr;
+ playout_cinfo.flags = 0;
+ playout_cinfo.setLayoutCount = 1;
+ playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout);
+
+ if (num_pack_args != 0) {
+ playout_cinfo.pushConstantRangeCount = 1;
+ playout_cinfo.pPushConstantRanges = &crange;
+ CHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize);
+ } else {
+ playout_cinfo.pushConstantRangeCount = 0;
+ playout_cinfo.pPushConstantRanges = nullptr;
+ }
+
+ VULKAN_CALL(
+ vkCreatePipelineLayout(vctx.device, &playout_cinfo, nullptr, &(pe->pipeline_layout)));
+
+ VkComputePipelineCreateInfo pipeline_cinfo;
+ pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
+ pipeline_cinfo.pNext = nullptr;
+ pipeline_cinfo.flags = 0;
+ pipeline_cinfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
+ pipeline_cinfo.stage.pNext = nullptr;
+ pipeline_cinfo.stage.flags = 0;
+ pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
+ pipeline_cinfo.stage.module = pe->shader;
+ pipeline_cinfo.stage.pName = func_name.c_str();
+ pipeline_cinfo.stage.pSpecializationInfo = nullptr;
+ pipeline_cinfo.layout = pe->pipeline_layout;
+ pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE;
+ pipeline_cinfo.basePipelineIndex = 0;
+ VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr,
+ &(pe->pipeline)));
+
+ if (vctx.UseImmediate()) {
+ VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo;
+ descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR;
+ descrip_template_cinfo.pNext = 0;
+ descrip_template_cinfo.flags = 0;
+ descrip_template_cinfo.descriptorUpdateEntryCount = arg_template.size();
+ descrip_template_cinfo.pDescriptorUpdateEntries = arg_template.data();
+ descrip_template_cinfo.templateType = VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR;
+ descrip_template_cinfo.descriptorSetLayout = pe->descriptor_set_layout;
+ descrip_template_cinfo.pipelineBindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;
+ descrip_template_cinfo.pipelineLayout = pe->pipeline_layout;
+ descrip_template_cinfo.set = 0;
+ VULKAN_CALL(vctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR(
+ vctx.device, &descrip_template_cinfo, 0, &(pe->descriptor_update_template)));
+ }
+ ecache_[device_id][func_name] = pe;
+ return pe;
+ }
+
+ void SaveToFile(const std::string& file_name, const std::string& format) final {
+ std::string fmt = GetFileFormat(file_name, format);
+    CHECK_EQ(fmt, fmt_) << "Can only save to the customized vulkan format";
+ std::string meta_file = GetMetaFilePath(file_name);
+ SaveMetaDataToFile(meta_file, fmap_);
+ std::string data_bin;
+ dmlc::MemoryStringStream fs(&data_bin);
+ dmlc::Stream* stream = &fs;
+ uint32_t magic = kVulkanModuleMagic;
+ stream->Write(magic);
+ stream->Write(smap_);
+ SaveBinaryToFile(file_name, data_bin);
+ }
+
+ void SaveToBinary(dmlc::Stream* stream) final {
+ stream->Write(fmt_);
+ stream->Write(fmap_);
+ stream->Write(smap_);
+ }
+ std::string GetSource(const std::string& format) final {
+ // can only return source code.
+ return source_;
+ }
+
+ private:
+ // the binary data
+ std::vector<uint32_t> data_;
+  // shader map: function name -> SPIR-V shader.
+  std::unordered_map<std::string, VulkanShader> smap_;
+ // function information table.
+ std::unordered_map<std::string, FunctionInfo> fmap_;
+ // The format
+ std::string fmt_{"vulkan"};
+ // The source
+ std::string source_;
+
+ // Guards accesses to `ecache_`
+ std::mutex mutex_;
+ std::array<std::unordered_map<std::string, std::shared_ptr<VulkanPipeline>>, kVulkanMaxNumDevice>
+ ecache_;
+};
+
+Module VulkanModuleCreate(std::unordered_map<std::string, VulkanShader> smap,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
+ std::shared_ptr<VulkanModuleNode> n = std::make_shared<VulkanModuleNode>(smap, fmap, source);
+ return Module(n);
+}
+
+VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() { return VulkanThreadStore::Get(); }
+
+VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size) {
+ if (!staging_buffers_[device_id]) {
+ staging_buffers_[device_id] = std::unique_ptr<VulkanStagingBuffer>(new VulkanStagingBuffer());
+ }
+ auto& buf = *(staging_buffers_[device_id]);
+ if (buf.device != nullptr && buf.size < size) {
+ // free previous buffer
+ if (buf.host_addr != nullptr) {
+ vkUnmapMemory(buf.device, buf.memory);
+ }
+ if (buf.memory != VK_NULL_HANDLE) {
+ vkFreeMemory(buf.device, buf.memory, nullptr);
+ }
+ if (buf.buffer != VK_NULL_HANDLE) {
+ vkDestroyBuffer(buf.device, buf.buffer, nullptr);
+ }
+ buf.host_addr = nullptr;
+ buf.memory = VK_NULL_HANDLE;
+ buf.buffer = VK_NULL_HANDLE;
+ }
+ const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
+
+ if (buf.device == nullptr) {
+ buf.device = vctx.device;
+ }
+ if (buf.memory == VK_NULL_HANDLE) {
+    // allocate the staging buffer memory if necessary
+ VkBufferCreateInfo info;
+ info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
+ info.pNext = nullptr;
+ info.flags = 0;
+ info.size = size;
+ info.queueFamilyIndexCount = 1;
+ info.pQueueFamilyIndices = &(vctx.queue_family_index);
+ info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
+ info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
+ VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &(buf.buffer)));
+ VkMemoryAllocateInfo minfo;
+ minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+ minfo.pNext = nullptr;
+ minfo.allocationSize = size;
+ minfo.memoryTypeIndex = vctx.staging_mtype_index;
+ VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &(buf.memory)));
+ VULKAN_CALL(vkBindBufferMemory(vctx.device, (buf.buffer), buf.memory, 0));
+ VULKAN_CALL(vkMapMemory(vctx.device, buf.memory, 0, size, 0, &(buf.host_addr)));
+ buf.size = size;
+ }
+ memset(buf.host_addr, 0, size);
+ return &buf;
+}
+
+VulkanThreadEntry::VulkanThreadEntry()
+ : pool(static_cast<DLDeviceType>(kDLVulkan), VulkanDeviceAPI::Global()) {
+ ctx.device_id = 0;
+ ctx.device_type = static_cast<DLDeviceType>(kDLVulkan);
+}
+
+VulkanStream* VulkanThreadEntry::Stream(size_t device_id) {
+ if (!streams_[device_id]) {
+ streams_[device_id] = std::unique_ptr<VulkanStream>(
+ new VulkanStream(&VulkanDeviceAPI::Global()->context(device_id)));
+ }
+ return streams_[device_id].get();
+}
+
+void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
+ const ArgUnion* pack_args) const {
+ int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id;
+ CHECK_LT(device_id, kVulkanMaxNumDevice);
+ const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
+ if (!scache_[device_id]) {
+ scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_);
+ }
+ const auto& pipeline = scache_[device_id];
+ ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+ std::vector<VkDescriptorBufferInfo> descriptor_buffers;
+ descriptor_buffers.resize(num_buffer_args_);
+ for (int i = 0; i < num_buffer_args_; ++i) {
+ void* buf = args[static_cast<int>(i)];
+ VkDescriptorBufferInfo binfo;
+ binfo.buffer = static_cast<VulkanBuffer*>(buf)->buffer;
+ binfo.offset = 0;
+ binfo.range = VK_WHOLE_SIZE;
+ descriptor_buffers[i] = binfo;
+ }
+ if (vctx.UseImmediate()) {
+ // Can safely capture by reference as this lambda is immediately executed on the calling thread.
+ VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) {
+ vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
+ CHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE);
+ vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR(
+ state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0,
+ descriptor_buffers.data());
+ if (num_pack_args_ != 0) {
+ vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
+ VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion),
+ pack_args);
+ }
+ vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
+ VkMemoryBarrier barrier_info;
+ barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
+ barrier_info.pNext = nullptr;
+ barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
+ barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
+ VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
+ vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
+ VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
+ 1, &barrier_info, 0, nullptr, 0, nullptr);
+ });
+ return;
+ }
+
+ // Otherwise, the more expensive deferred path.
+ std::vector<ArgUnion> pack_args_storage(pack_args, pack_args + num_pack_args_);
+ const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() {
+ std::vector<VkWriteDescriptorSet> write_descriptor_sets;
+ write_descriptor_sets.resize(descriptor_buffers.size());
+ for (int i = 0; i < write_descriptor_sets.size(); i++) {
+ write_descriptor_sets[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
+ write_descriptor_sets[i].pNext = 0;
+ write_descriptor_sets[i].dstSet = pipeline->descriptor_set;
+ write_descriptor_sets[i].dstBinding = i;
+ write_descriptor_sets[i].dstArrayElement = 0;
+ write_descriptor_sets[i].descriptorCount = 1;
+ write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+ write_descriptor_sets[i].pImageInfo = 0;
+ write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]);
+ write_descriptor_sets[i].pTexelBufferView = 0;
+ }
+ vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(),
+ 0, 0);
+ };
+ const auto& deferred_kernel = [pipeline, wl, pack_args_storage](VulkanStreamState* state) {
+ vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
+ vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
+ pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0,
+ nullptr);
+ if (pack_args_storage.size() != 0) {
+ vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
+ 0, pack_args_storage.size() * sizeof(ArgUnion), pack_args_storage.data());
+ }
+ vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
+ VkMemoryBarrier barrier_info;
+ barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
+ barrier_info.pNext = nullptr;
+ barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
+ barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
+ VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
+ vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
+ VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
+ 1, &barrier_info, 0, nullptr, 0, nullptr);
+ };
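+  // The deferred token records the descriptor set and buffers this launch
+  // touches, so the stream can detect conflicting reuse of the descriptor
+  // set and synchronize first (see VulkanStream::LaunchDeferred for the
+  // exact policy).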
+ VulkanStreamToken deferred_token;
+ deferred_token.descriptor_set_ = pipeline->descriptor_set;
+ deferred_token.buffers_.resize(descriptor_buffers.size());
+ for (int i = 0; i < descriptor_buffers.size(); ++i) {
+ deferred_token.buffers_[i] = descriptor_buffers[i].buffer;
+ }
+ VulkanThreadEntry::ThreadLocal()->Stream(device_id)->LaunchDeferred(
+ deferred_initializer, deferred_kernel, deferred_token);
+}
+
+Module VulkanModuleLoadFile(const std::string& file_name, const std::string& format) {
+ std::string data;
+ std::unordered_map<std::string, VulkanShader> smap;
+ std::unordered_map<std::string, FunctionInfo> fmap;
+ std::string fmt = GetFileFormat(file_name, format);
+ std::string meta_file = GetMetaFilePath(file_name);
+ LoadBinaryFromFile(file_name, &data);
+ LoadMetaDataFromFile(meta_file, &fmap);
+ dmlc::MemoryStringStream fs(&data);
+ dmlc::Stream* stream = &fs;
+ uint32_t magic;
+ stream->Read(&magic);
+ CHECK_EQ(magic, kVulkanModuleMagic) << "VulkanModule Magic mismatch";
+ stream->Read(&smap);
+ return VulkanModuleCreate(smap, fmap, "");
+}
+
+Module VulkanModuleLoadBinary(void* strm) {
+ dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+ std::unordered_map<std::string, VulkanShader> smap;
+ std::unordered_map<std::string, FunctionInfo> fmap;
+
+ std::string fmt;
+ stream->Read(&fmt);
+ stream->Read(&fmap);
+ stream->Read(&smap);
+ return VulkanModuleCreate(smap, fmap, "");
+}
+
+TVM_REGISTER_GLOBAL("module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile);
+
+TVM_REGISTER_GLOBAL("module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary);
+
+TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) {
+ DeviceAPI* ptr = VulkanDeviceAPI::Global().get();
+ *rv = static_cast<void*>(ptr);
+});
+
+} // namespace vulkan
+} // namespace runtime
+} // namespace tvm
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* specific language governing permissions and limitations
* under the License.
*/
+#pragma once
-/*!
- * Copyright (c) 2017 by Contributors
- * \file vulkan_common.h
- * \brief Vulkan common header
- */
-#ifndef TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
-#define TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
-
+#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
-#include <dmlc/logging.h>
+#include <tvm/runtime/packed_func.h>
#include <vulkan/vulkan.h>
+#include <memory>
#include <mutex>
#include <string>
#include <vector>
-#include <memory>
-#include "../workspace_pool.h"
namespace tvm {
namespace runtime {
inline const char* VKGetErrorString(VkResult error) {
switch (error) {
- case VK_SUCCESS: return "VK_SUCCESS";
- case VK_NOT_READY: return "VK_NOT_READY";
- case VK_TIMEOUT: return "VK_TIMEOUT";
- case VK_EVENT_SET: return "VK_EVENT_SET";
- case VK_EVENT_RESET: return "VK_EVENT_RESET";
- case VK_INCOMPLETE: return "VK_INCOMPLETE";
- case VK_ERROR_OUT_OF_HOST_MEMORY: return "VK_ERROR_OUT_OF_HOST_MEMORY";
- case VK_ERROR_OUT_OF_DEVICE_MEMORY: return "VK_ERROR_OUT_OF_DEVICE_MEMORY";
- case VK_ERROR_INITIALIZATION_FAILED: return "VK_ERROR_INITIALIZATION_FAILED";
- case VK_ERROR_DEVICE_LOST: return "VK_ERROR_DEVICE_LOST";
- case VK_ERROR_MEMORY_MAP_FAILED: return "VK_ERROR_MEMORY_MAP_FAILED";
- case VK_ERROR_LAYER_NOT_PRESENT: return "VK_ERROR_LAYER_NOT_PRESENT";
- case VK_ERROR_EXTENSION_NOT_PRESENT: return "VK_ERROR_EXTENSION_NOT_PRESENT";
- case VK_ERROR_FEATURE_NOT_PRESENT: return "VK_ERROR_FEATURE_NOT_PRESENT";
- case VK_ERROR_INCOMPATIBLE_DRIVER: return "VK_ERROR_INCOMPATIBLE_DRIVER";
- case VK_ERROR_TOO_MANY_OBJECTS: return "VK_ERROR_TOO_MANY_OBJECTS";
- case VK_ERROR_FORMAT_NOT_SUPPORTED: return "VK_ERROR_FORMAT_NOT_SUPPORTED";
- case VK_ERROR_FRAGMENTED_POOL: return "VK_ERROR_FRAGMENTED_POOL";
- default: return "Unknown Vulkan error code";
+ case VK_SUCCESS:
+ return "VK_SUCCESS";
+ case VK_NOT_READY:
+ return "VK_NOT_READY";
+ case VK_TIMEOUT:
+ return "VK_TIMEOUT";
+ case VK_EVENT_SET:
+ return "VK_EVENT_SET";
+ case VK_EVENT_RESET:
+ return "VK_EVENT_RESET";
+ case VK_INCOMPLETE:
+ return "VK_INCOMPLETE";
+ case VK_ERROR_OUT_OF_HOST_MEMORY:
+ return "VK_ERROR_OUT_OF_HOST_MEMORY";
+ case VK_ERROR_OUT_OF_DEVICE_MEMORY:
+ return "VK_ERROR_OUT_OF_DEVICE_MEMORY";
+ case VK_ERROR_INITIALIZATION_FAILED:
+ return "VK_ERROR_INITIALIZATION_FAILED";
+ case VK_ERROR_DEVICE_LOST:
+ return "VK_ERROR_DEVICE_LOST";
+ case VK_ERROR_MEMORY_MAP_FAILED:
+ return "VK_ERROR_MEMORY_MAP_FAILED";
+ case VK_ERROR_LAYER_NOT_PRESENT:
+ return "VK_ERROR_LAYER_NOT_PRESENT";
+ case VK_ERROR_EXTENSION_NOT_PRESENT:
+ return "VK_ERROR_EXTENSION_NOT_PRESENT";
+ case VK_ERROR_FEATURE_NOT_PRESENT:
+ return "VK_ERROR_FEATURE_NOT_PRESENT";
+ case VK_ERROR_INCOMPATIBLE_DRIVER:
+ return "VK_ERROR_INCOMPATIBLE_DRIVER";
+ case VK_ERROR_TOO_MANY_OBJECTS:
+ return "VK_ERROR_TOO_MANY_OBJECTS";
+ case VK_ERROR_FORMAT_NOT_SUPPORTED:
+ return "VK_ERROR_FORMAT_NOT_SUPPORTED";
+ case VK_ERROR_FRAGMENTED_POOL:
+ return "VK_ERROR_FRAGMENTED_POOL";
+ default:
+ return "Unknown Vulkan error code";
}
}
* \brief Protected Vulkan call
* \param func Expression to call.
*/
-#define VULKAN_CHECK_ERROR(__e) \
- { \
- CHECK(__e == VK_SUCCESS) \
- << "Vulan Error, code=" << __e << ": " << vulkan::VKGetErrorString(__e); \
+#define VULKAN_CHECK_ERROR(__e) \
+ { \
+    CHECK(__e == VK_SUCCESS) << "Vulkan Error, code=" << __e << ": " \
+ << vulkan::VKGetErrorString(__e); \
}
-#define VULKAN_CALL(func) \
- { \
- VkResult __e = (func); \
- VULKAN_CHECK_ERROR(__e); \
+#define VULKAN_CALL(func) \
+ { \
+ VkResult __e = (func); \
+ VULKAN_CHECK_ERROR(__e); \
}
-/*! \brief Auxiliary context structure for vulkan */
+struct VulkanDescriptorTemplateKHRFunctions {
+ PFN_vkCreateDescriptorUpdateTemplateKHR vkCreateDescriptorUpdateTemplateKHR{nullptr};
+ PFN_vkDestroyDescriptorUpdateTemplateKHR vkDestroyDescriptorUpdateTemplateKHR{nullptr};
+ PFN_vkUpdateDescriptorSetWithTemplateKHR vkUpdateDescriptorSetWithTemplateKHR{nullptr};
+ PFN_vkCmdPushDescriptorSetWithTemplateKHR vkCmdPushDescriptorSetWithTemplateKHR{nullptr};
+};
+
+struct VulkanGetBufferMemoryRequirements2Functions {
+ PFN_vkGetBufferMemoryRequirements2KHR vkGetBufferMemoryRequirements2KHR{nullptr};
+};
+
+struct VulkanStagingBuffer {
+ VkDevice device{nullptr};
+ VkBuffer buffer{VK_NULL_HANDLE};
+ VkDeviceMemory memory{VK_NULL_HANDLE};
+ void* host_addr{nullptr};
+ size_t size{0};
+};
+
struct VulkanContext {
// physical device
VkPhysicalDevice phy_device{nullptr};
uint32_t staging_mtype_index{0};
// whether staging is coherent
bool coherent_staging{false};
+
+ std::unique_ptr<VulkanDescriptorTemplateKHRFunctions> descriptor_template_khr_functions{nullptr};
+ std::unique_ptr<VulkanGetBufferMemoryRequirements2Functions>
+ get_buffer_memory_requirements_2_functions{nullptr};
// Memory type index for compute
uint32_t compute_mtype_index{0};
// The logical device
VkDevice device{nullptr};
// command queue
+  std::unique_ptr<std::mutex> queue_mutex;
VkQueue queue{nullptr};
// queue family index
uint32_t queue_family_index{0};
// queue family properties
VkQueueFamilyProperties queue_prop;
-};
-/*! \brief The buffer object */
-struct VulkanBuffer {
- /*! \brief underlying buffer */
- VkBuffer buffer{VK_NULL_HANDLE};
- /*! \brief underlying buffer */
- VkDeviceMemory memory{VK_NULL_HANDLE};
+ bool UseImmediate() const { return descriptor_template_khr_functions.get() != nullptr; }
};
-/*! \brief Buffer only used for stagging */
-struct VulkanStagingBuffer {
- /*! \brief the corresponding device */
- VkDevice device{nullptr};
- /*! \brief underlying buffer */
- VkBuffer buffer{VK_NULL_HANDLE};
- /*! \brief underlying buffer */
- VkDeviceMemory memory{VK_NULL_HANDLE};
- /*! \brief host address */
- void* host_addr{nullptr};
- /*! \brief size of the memory */
- size_t size{0};
-};
-
-/*!
- * \brief Process global Vulkan workspace.
- */
-class VulkanWorkspace final : public DeviceAPI {
- public:
- // global mutex
- std::mutex mu;
- // whether the workspace it initialized.
- bool initialized_{false};
- // vulkan instance
- VkInstance instance_{nullptr};
- // The physical devices, have 1 to 1 mapping to devices
- std::vector<VulkanContext> context_;
- // Destructor
- ~VulkanWorkspace();
- // Initialize workspace
- // Return false if already initialized, otherwise return true.
- void Init();
- // override device API
- void SetDevice(TVMContext ctx) final;
- void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
- void* AllocDataSpace(TVMContext ctx,
- size_t nbytes,
- size_t alignment,
- TVMType type_hint) final;
- void FreeDataSpace(TVMContext ctx, void* ptr) final;
- void CopyDataFromTo(const void* from,
- size_t from_size,
- void* to,
- size_t to_size,
- size_t size,
- TVMContext ctx_from,
- TVMContext ctx_to,
- TVMType type_hint,
- TVMStreamHandle stream) final;
- void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
- void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final;
- void FreeWorkspace(TVMContext ctx, void* data) final;
- // get the global workspace
- static const std::shared_ptr<VulkanWorkspace>& Global();
-};
-
-/*! \brief Helper command buffer resource */
-struct VulkanCommandBuffer {
- /*! \brief fence to signal the resource is ready to use */
- VkFence fence{VK_NULL_HANDLE};
- /*! \brief The internal command buffer */
- VkCommandBuffer cmd_buffer{nullptr};
- /*! \brief Descriptor set used to bind arguments */
- VkDescriptorSet descriptor_set{VK_NULL_HANDLE};
- /*! \brief Internal utilities for write command */
- VkWriteDescriptorSet write_descriptor_set;
-
- VulkanCommandBuffer() {
- write_descriptor_set.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
- write_descriptor_set.pNext = nullptr;
- write_descriptor_set.dstSet = VK_NULL_HANDLE;
- write_descriptor_set.dstBinding = 0;
- write_descriptor_set.dstArrayElement = 0;
- write_descriptor_set.descriptorCount = 1;
- write_descriptor_set.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
- write_descriptor_set.pImageInfo = nullptr;
- write_descriptor_set.pBufferInfo = nullptr;
- write_descriptor_set.pTexelBufferView = nullptr;
- }
-};
-
-/*!
- * \brief Command pool backed by a fixed size ring buffer.
- *
- * Vulkan requires us not to reuse command buffer until
- * All its corresponding jobs have finished.
- *
- * This class to faciliate automatic management
- * of the command buffers. A fence is created
- * for each launch of command buffer jobs
- * and when we try to reuse the same entry
- * in the ring, we need to make sure that
- * the previous pending job already finishes.
- *
- */
-class VulkanCommandPool {
- public:
- /*! \brief Maximum number of pending jobs in the pool */
- static constexpr const int kMaxPending = 4;
- /*! \brief Maximum number of pending jobs in the pool */
- static constexpr const int kMaxNumArgs = 16;
- /*!
- * \brief constructor
- * \param vctx The corresponding vulkan context.
- */
- explicit VulkanCommandPool(const VulkanContext& vctx);
- /*! \brief destructor */
- ~VulkanCommandPool();
- /*!
- * \brief Allocate a new command buffer entry
- *
- * The caller must only submit the entry once
- * with the given fence in the entry,
- * before calling next Alloc.
- *
- * This function may block to wait for a
- * previously unfinished command when
- * there is more than kMaxPending jobs.
- *
- * \returns The allocated entry.
- */
- VulkanCommandBuffer* Alloc();
-
- /*!
- * \brief Allocate a new command buffer entry
- * \param dlayout the descriptor layout.
- *
- * \returns The allocated entry.
- */
- VulkanCommandBuffer* Alloc(const VkDescriptorSetLayout* dlayout);
-
- private:
- /*! \brief Local ring buffer */
- std::vector<VulkanCommandBuffer> ring_;
- /*! \brief clock pointer */
- size_t clock_ptr_{0};
- /*! \brief the corresponding device*/
- VkDevice device_{nullptr};
- /*! \brief internal command buffer pool */
- VkCommandPool cmd_pool_{VK_NULL_HANDLE};
- /*! \brief Descriptor pool */
- VkDescriptorPool descriptor_pool_{VK_NULL_HANDLE};
-};
-
-/*! \brief Thread local workspace */
-class VulkanThreadEntry {
- public:
- /*! \brief The current context */
- TVMContext context;
- /*! \brief workspace pool */
- WorkspacePool pool;
- /*! \brief The staging buffers */
- std::vector<VulkanStagingBuffer> staging_buffer_;
- /*!
- * \brief Get the command pool of corresponding device;
- * \param device_id The device id
- * \return The corresponding command buffer.
- */
- VulkanCommandPool* CommandPool(int device_id);
- /*!
- * \brief Get the stagging buffer.
- * \param device_id The device id
- * \return The corresponding stagging buffer.
- */
- VulkanStagingBuffer* StagingBuffer(int device_id, size_t size);
-
- // constructor
- VulkanThreadEntry()
- : pool(static_cast<DLDeviceType>(kDLVulkan), VulkanWorkspace::Global()) {
- context.device_id = 0;
- context.device_type = static_cast<DLDeviceType>(kDLVulkan);
- }
- ~VulkanThreadEntry();
- // get the global workspace
- static VulkanThreadEntry* ThreadLocal();
-
- private:
- /*! \brief the command pools */
- std::vector<std::unique_ptr<VulkanCommandPool> > pool_;
-};
-
-// inline implementation
-
} // namespace vulkan
} // namespace runtime
} // namespace tvm
-#endif // TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2017 by Contributors
- * \file vulkan_device_api.cc
- */
-#include <tvm/runtime/registry.h>
-#include <dmlc/thread_local.h>
-#include <cstring>
-#include "vulkan_common.h"
-
-namespace tvm {
-namespace runtime {
-namespace vulkan {
-
-VulkanWorkspace::~VulkanWorkspace() {
- for (VulkanContext& ctx : context_) {
- vkDestroyDevice(ctx.device, nullptr);
- }
- if (instance_ != nullptr) {
- vkDestroyInstance(instance_, nullptr);
- }
-}
-
-const std::shared_ptr<VulkanWorkspace>& VulkanWorkspace::Global() {
- static std::shared_ptr<VulkanWorkspace> inst = std::make_shared<VulkanWorkspace>();
- return inst;
-}
-
-void VulkanWorkspace::SetDevice(TVMContext ctx) {
- VulkanThreadEntry::ThreadLocal()->context.device_id = ctx.device_id;
-}
-
-void VulkanWorkspace::GetAttr(
- TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) {
- this->Init();
- size_t index = static_cast<size_t>(ctx.device_id);
- if (kind == kExist) {
- *rv = static_cast<int>(index< context_.size());
- return;
- }
- CHECK_LT(index, context_.size())
- << "Invalid device id " << index;
- switch (kind) {
- case kMaxThreadsPerBlock: {
- VkPhysicalDeviceProperties phy_prop;
- vkGetPhysicalDeviceProperties(context_[ctx.device_id].phy_device, &phy_prop);
- int64_t value = phy_prop.limits.maxComputeWorkGroupSize[0];
- *rv = value;
- break;
- }
- case kMaxSharedMemoryPerBlock: {
- VkPhysicalDeviceProperties phy_prop;
- vkGetPhysicalDeviceProperties(context_[ctx.device_id].phy_device, &phy_prop);
- int64_t value = phy_prop.limits.maxComputeSharedMemorySize;
- *rv = value;
- break;
- }
- case kWarpSize: {
- *rv = 1;
- break;
- }
- case kComputeVersion: {
- VkPhysicalDeviceProperties phy_prop;
- vkGetPhysicalDeviceProperties(context_[ctx.device_id].phy_device, &phy_prop);
- int64_t value = phy_prop.apiVersion;
- std::ostringstream os;
- os << VK_VERSION_MAJOR(value)
- << "." << VK_VERSION_MINOR(value)
- << "." << VK_VERSION_PATCH(value);
- *rv = os.str();
- break;
- }
- case kDeviceName: return;
- case kMaxClockRate: return;
- case kMultiProcessorCount: return;
- case kExist: break;
- case kMaxThreadDimensions: break;
- }
-}
-
-void* VulkanWorkspace::AllocDataSpace(
- TVMContext ctx, size_t size, size_t alignment, TVMType type_hint) {
- this->Init();
-
- VulkanContext& vctx = context_[ctx.device_id];
-
- VkBufferCreateInfo info;
- info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
- info.pNext = nullptr;
- info.flags = 0;
- info.size = size;
- info.queueFamilyIndexCount = 1;
- info.pQueueFamilyIndices = &(vctx.queue_family_index);
- info.usage =
- VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
- VK_BUFFER_USAGE_TRANSFER_DST_BIT |
- VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
- // create buffer
- VkBuffer buffer;
- VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
- // bind to memory
- VkMemoryAllocateInfo minfo;
- minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
- minfo.pNext = nullptr;
- minfo.allocationSize = size;
- minfo.memoryTypeIndex = vctx.compute_mtype_index;
- VkDeviceMemory memory;
- VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
- VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
-
- VulkanBuffer* pbuf = new VulkanBuffer();
- pbuf->memory = memory;
- pbuf->buffer = buffer;
- return pbuf;
-}
-
-void VulkanWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
- VulkanContext& vctx = context_[ctx.device_id];
- VulkanBuffer* pbuf = static_cast<VulkanBuffer*>(ptr);
- vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr);
- vkFreeMemory(vctx.device, pbuf->memory, nullptr);
- delete pbuf;
-}
-
-void VulkanWorkspace::CopyDataFromTo(const void* from,
- size_t from_offset,
- void* to,
- size_t to_offset,
- size_t size,
- TVMContext ctx_from,
- TVMContext ctx_to,
- TVMType type_hint,
- TVMStreamHandle stream) {
- this->Init();
- CHECK(stream == nullptr);
- TVMContext ctx = ctx_from;
- if (ctx_from.device_type == kDLCPU) ctx = ctx_to;
- VulkanThreadEntry* tls = VulkanThreadEntry::ThreadLocal();
- VulkanCommandBuffer* cmd = tls->CommandPool(ctx.device_id)->Alloc();
-
- VkCommandBufferBeginInfo cb_begin;
- cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
- cb_begin.pNext = nullptr;
- cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
- cb_begin.pInheritanceInfo = 0;
-
- VkSubmitInfo cb_submit;
- cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
- cb_submit.pNext = nullptr;
- cb_submit.waitSemaphoreCount = 0;
- cb_submit.pWaitSemaphores = nullptr;
- cb_submit.pWaitDstStageMask = 0;
- cb_submit.commandBufferCount = 1;
- cb_submit.pCommandBuffers = &(cmd->cmd_buffer);
- cb_submit.signalSemaphoreCount = 0;
- cb_submit.pSignalSemaphores = nullptr;
-
-
- int from_dev_type = static_cast<int>(ctx_from.device_type);
- int to_dev_type = static_cast<int>(ctx_to.device_type);
-
- if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) {
- CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
- << "Vulkan disallow cross device copy.";
- const VulkanContext& vctx = context_[ctx_from.device_id];
- const VulkanBuffer* from_buf = static_cast<const VulkanBuffer*>(from);
- VulkanBuffer* to_buf = static_cast<VulkanBuffer*>(to);
- // The assumption is that subsequence ops only perform compute/transfer
- // 0: begin
- VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
- // 1: copy
- VkBufferCopy copy_info;
- copy_info.srcOffset = from_offset;
- copy_info.dstOffset = to_offset;
- copy_info.size = size;
- vkCmdCopyBuffer(cmd->cmd_buffer, from_buf->buffer, to_buf->buffer, 1, ©_info);
- // 2: barrier(transfer-> compute|transfer)
- VkMemoryBarrier barrier_info;
- barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
- barrier_info.pNext = nullptr;
- barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
- barrier_info.dstAccessMask =
- (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
- VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
- vkCmdPipelineBarrier(
- cmd->cmd_buffer,
- VK_PIPELINE_STAGE_TRANSFER_BIT,
- VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
- 0, 1, &barrier_info, 0, nullptr, 0, nullptr);
- // 3: end
- VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
- // 4: submit with cmd->fence
- VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
- } else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) {
- const VulkanContext& vctx = context_[ctx_from.device_id];
- const VulkanBuffer* from_buf = static_cast<const VulkanBuffer*>(from);
- VulkanStagingBuffer* temp = tls->StagingBuffer(ctx_from.device_id, size);
- // 0: begin
- VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
- // 1: copy
- VkBufferCopy copy_info;
- copy_info.srcOffset = from_offset;
- copy_info.dstOffset = 0;
- copy_info.size = size;
- vkCmdCopyBuffer(cmd->cmd_buffer,
- from_buf->buffer,
- temp->buffer,
- 1, ©_info);
- // 2: end
- VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
- // 4: submit with cmd->fence
- VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
- // Block until done, to make sure temp can be reused later.
- VULKAN_CALL(vkQueueWaitIdle(vctx.queue));
- // host side invalidation if access is not coherent.
- // so writes from GPU is visible to CPU
- if (!vctx.coherent_staging) {
- VkMappedMemoryRange mrange;
- mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
- mrange.pNext = nullptr;
- mrange.memory = temp->memory;
- mrange.offset = 0;
- mrange.size = size;
- VULKAN_CALL(vkInvalidateMappedMemoryRanges(
- vctx.device, 1, &mrange));
- }
- memcpy(static_cast<char*>(to) + to_offset,
- static_cast<char*>(temp->host_addr),
- size);
- } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) {
- const VulkanContext& vctx = context_[ctx_to.device_id];
- const VulkanBuffer* to_buf = static_cast<const VulkanBuffer*>(to);
- VulkanStagingBuffer* temp = tls->StagingBuffer(ctx_to.device_id, size);
- memcpy(temp->host_addr,
- static_cast<const char*>(from) + from_offset,
- size);
- // host side flush if access is not coherent.
- // so writes from CPU is visible to GPU
- if (!vctx.coherent_staging) {
- VkMappedMemoryRange mrange;
- mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
- mrange.pNext = nullptr;
- mrange.memory = temp->memory;
- mrange.offset = 0;
- mrange.size = size;
- VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange));
- }
- VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
- // 0: barrier(host->transfer)
- VkMemoryBarrier barrier_info;
- barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
- barrier_info.pNext = nullptr;
- barrier_info.srcAccessMask = 0;
- barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
- vkCmdPipelineBarrier(cmd->cmd_buffer,
- VK_PIPELINE_STAGE_HOST_BIT,
- VK_PIPELINE_STAGE_TRANSFER_BIT,
- 0, 1, &barrier_info,
- 0, nullptr, 0, nullptr);
- // 1: copy
- VkBufferCopy copy_info;
- copy_info.srcOffset = 0;
- copy_info.dstOffset = to_offset;
- copy_info.size = size;
- vkCmdCopyBuffer(cmd->cmd_buffer,
- temp->buffer,
- to_buf->buffer,
- 1, ©_info);
- // 2: end
- VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
- // 4: submit with cmd->fence
- VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
- // wait until copy finishes, so we can reuse temp next time.
- VULKAN_CALL(vkQueueWaitIdle(vctx.queue));
- } else {
- LOG(FATAL) << "Expect copy from/to Metal or between Metal"
- << ", from=" << from_dev_type
- << ", to=" << to_dev_type;
- }
-}
-
-void VulkanWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
- CHECK(stream == nullptr);
- VulkanContext& vctx = context_[ctx.device_id];
- VULKAN_CALL(vkQueueWaitIdle(vctx.queue));
-}
-
-void* VulkanWorkspace::AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) {
- return VulkanThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
-}
-
-void VulkanWorkspace::FreeWorkspace(TVMContext ctx, void* data) {
- VulkanThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
-}
-
-// VulkanCommandPool
-VulkanCommandPool::VulkanCommandPool(const VulkanContext& vctx) {
- ring_.resize(kMaxPending, VulkanCommandBuffer());
- device_ = vctx.device;
- {
- // create command pool
- VkCommandPoolCreateInfo cmd_pool_cinfo;
- cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
- cmd_pool_cinfo.pNext = nullptr;
- cmd_pool_cinfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
- cmd_pool_cinfo.queueFamilyIndex = vctx.queue_family_index;
- VULKAN_CALL(vkCreateCommandPool(device_, &cmd_pool_cinfo, nullptr, &cmd_pool_));
- }
- {
- // create descriptor pool
- VkDescriptorPoolSize pool_size;
- pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
- pool_size.descriptorCount = kMaxPending * kMaxNumArgs;
- VkDescriptorPoolCreateInfo descrip_pool_cinfo;
- descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
- descrip_pool_cinfo.pNext = nullptr;
- descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
- descrip_pool_cinfo.maxSets = kMaxPending + 2;
- descrip_pool_cinfo.poolSizeCount = 1;
- descrip_pool_cinfo.pPoolSizes = &pool_size;
- VULKAN_CALL(vkCreateDescriptorPool(
- device_, &descrip_pool_cinfo, nullptr, &descriptor_pool_));
- }
- VkCommandBufferAllocateInfo buffer_alloc_info;
- buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
- buffer_alloc_info.pNext = nullptr;
- buffer_alloc_info.commandPool = cmd_pool_;
- buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
- buffer_alloc_info.commandBufferCount = 1;
-
- VkFenceCreateInfo fence_cinfo;
- fence_cinfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
- fence_cinfo.pNext = nullptr;
- fence_cinfo.flags = VK_FENCE_CREATE_SIGNALED_BIT;
-
- for (size_t i = 0; i < ring_.size(); ++i) {
- VULKAN_CALL(vkAllocateCommandBuffers(
- device_, &buffer_alloc_info, &(ring_[i].cmd_buffer)));
- VULKAN_CALL(vkCreateFence(
- device_, &fence_cinfo, nullptr, &(ring_[i].fence)));
- }
-}
-
-VulkanCommandPool::~VulkanCommandPool() {
- // wait device to be idle so we know we can recycle buffers
- VULKAN_CALL(vkDeviceWaitIdle(device_));
- // start recycling.
- for (size_t i = 0; i < ring_.size(); ++i) {
- if (ring_[i].cmd_buffer != nullptr) {
- vkFreeCommandBuffers(device_, cmd_pool_, 1, &(ring_[i].cmd_buffer));
- ring_[i].cmd_buffer = nullptr;
- }
- if (ring_[i].fence != VK_NULL_HANDLE) {
- vkDestroyFence(device_, ring_[i].fence, nullptr);
- }
- }
- // delete cmd_pool and descriptor pool
- vkDestroyCommandPool(device_, cmd_pool_, nullptr);
- vkDestroyDescriptorPool(device_, descriptor_pool_, nullptr);
-}
-
-VulkanCommandBuffer* VulkanCommandPool::Alloc() {
- return Alloc(nullptr);
-}
-
-VulkanCommandBuffer* VulkanCommandPool::Alloc(
- const VkDescriptorSetLayout* dlayout) {
- // always allocate resource in round robin manner
- VulkanCommandBuffer* e = &(ring_[clock_ptr_]);
- clock_ptr_ = (clock_ptr_ + 1) % ring_.size();
- // Wait until previous usage of commad buffer is finished.
- uint64_t timeout = 1UL << 30UL;
- VkResult res;
- res = vkWaitForFences(device_, 1, &(e->fence), 0, timeout);
- while (res == VK_TIMEOUT) {
- res = vkWaitForFences(device_, 1, &(e->fence), 0, timeout);
- }
- VULKAN_CHECK_ERROR(res);
- vkResetFences(device_, 1, (&e->fence));
- if (e->descriptor_set != VK_NULL_HANDLE) {
- VULKAN_CALL(vkFreeDescriptorSets(
- device_, descriptor_pool_, 1, &(e->descriptor_set)));
- e->descriptor_set = VK_NULL_HANDLE;
- }
- if (dlayout != nullptr) {
- VkDescriptorSetAllocateInfo alloc_info;
- alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
- alloc_info.pNext = nullptr;
- alloc_info.descriptorPool = descriptor_pool_;
- alloc_info.descriptorSetCount = 1;
- alloc_info.pSetLayouts = dlayout;
- VULKAN_CALL(vkAllocateDescriptorSets(
- device_, &alloc_info, &(e->descriptor_set)));
- }
- return e;
-}
-
-// VulkanThreadEntry
-typedef dmlc::ThreadLocalStore<VulkanThreadEntry> VulkanThreadStore;
-
-VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() {
- return VulkanThreadStore::Get();
-}
-
-VulkanCommandPool* VulkanThreadEntry::CommandPool(int device_id) {
- while (pool_.size() <= static_cast<size_t>(device_id)) {
- pool_.emplace_back(std::unique_ptr<VulkanCommandPool>());
- }
- if (pool_[device_id] == nullptr) {
- const VulkanContext& vctx =
- VulkanWorkspace::Global()->context_[device_id];
- pool_[device_id].reset(new VulkanCommandPool(vctx));
- }
- return pool_[device_id].get();
-}
-
-VulkanStagingBuffer*
-VulkanThreadEntry::StagingBuffer(int device_id, size_t size) {
- if (staging_buffer_.size() <= static_cast<size_t>(device_id)) {
- staging_buffer_.resize(device_id + 1, VulkanStagingBuffer());
- }
- VulkanStagingBuffer& buf = staging_buffer_[device_id];
-
- if (buf.device != nullptr && buf.size < size) {
- // free previous buffer
- if (buf.host_addr != nullptr) {
- vkUnmapMemory(buf.device, buf.memory);
- }
- if (buf.memory != VK_NULL_HANDLE) {
- vkFreeMemory(buf.device, buf.memory, nullptr);
- }
- if (buf.buffer != VK_NULL_HANDLE) {
- vkDestroyBuffer(buf.device, buf.buffer, nullptr);
- }
- buf.host_addr = nullptr;
- buf.memory = VK_NULL_HANDLE;
- buf.buffer = VK_NULL_HANDLE;
- }
- const VulkanContext& vctx =
- VulkanWorkspace::Global()->context_[device_id];
-
- if (buf.device == nullptr) {
- buf.device = vctx.device;
- }
- if (buf.memory == VK_NULL_HANDLE) {
- // allocate the stagging buffer memory if necessary
- VkBufferCreateInfo info;
- info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
- info.pNext = nullptr;
- info.flags = 0;
- info.size = size;
- info.queueFamilyIndexCount = 1;
- info.pQueueFamilyIndices = &(vctx.queue_family_index);
- info.usage =
- VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
- VK_BUFFER_USAGE_TRANSFER_DST_BIT;
- VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &(buf.buffer)));
- VkMemoryAllocateInfo minfo;
- minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
- minfo.pNext = nullptr;
- minfo.allocationSize = size;
- minfo.memoryTypeIndex = vctx.staging_mtype_index;
- VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &(buf.memory)));
- VULKAN_CALL(vkBindBufferMemory(vctx.device, (buf.buffer), buf.memory, 0));
- VULKAN_CALL(vkMapMemory(vctx.device, buf.memory, 0, size, 0, &(buf.host_addr)));
- buf.size = size;
- }
- memset(buf.host_addr, 0, size);
- return &buf;
-}
-
-VulkanThreadEntry::~VulkanThreadEntry() {
- // Because the thread entry refers to Device API
- // The command buffer always will be destroyed before
- // the instance and device get destroyed.
- // The destruction need to be manually called
- // to ensure the destruction order.
- pool_.clear();
- for (VulkanStagingBuffer buf : staging_buffer_) {
- if (buf.host_addr != nullptr) {
- vkUnmapMemory(buf.device, buf.memory);
- }
- if (buf.memory != VK_NULL_HANDLE) {
- vkFreeMemory(buf.device, buf.memory, nullptr);
- }
- if (buf.buffer != VK_NULL_HANDLE) {
- vkDestroyBuffer(buf.device, buf.buffer, nullptr);
- }
- }
-}
-
-VkInstance CreateInstance() {
- VkApplicationInfo app_info;
- app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
- app_info.pNext = nullptr;
- app_info.pApplicationName = "TVM";
- app_info.applicationVersion = 0;
- app_info.pEngineName = "";
- app_info.engineVersion = 0;
- app_info.apiVersion = VK_MAKE_VERSION(1, 0, 65);
-
- VkInstanceCreateInfo inst_info;
- inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
- inst_info.pNext = nullptr;
- inst_info.flags = 0;
- inst_info.pApplicationInfo = &app_info;
- inst_info.enabledLayerCount = 0;
- inst_info.ppEnabledLayerNames = nullptr;
- inst_info.enabledExtensionCount = 0;
- inst_info.ppEnabledExtensionNames = nullptr;
-
- VkInstance inst;
- VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &inst));
- return inst;
-}
-
-// find suitable mem_type_index for staging and compute
-void FindMemoryTypeIndex(VulkanContext* vctx) {
- // Find suitable compute index.
- VkBuffer buffer;
- VkMemoryRequirements req_staging, req_compute;
- VkBufferCreateInfo info;
- info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
- info.pNext = nullptr;
- info.flags = 0;
- info.size = 1024;
- info.queueFamilyIndexCount = 1;
- info.pQueueFamilyIndices = &(vctx->queue_family_index);
-
- // get staging requirement
- info.usage =
- VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
- VK_BUFFER_USAGE_TRANSFER_DST_BIT;
- VULKAN_CALL(vkCreateBuffer(vctx->device, &info, nullptr, &buffer));
- vkGetBufferMemoryRequirements(vctx->device, buffer, &req_staging);
- vkDestroyBuffer(vctx->device, buffer, nullptr);
- // get compute requirement
- info.usage =
- VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
- VK_BUFFER_USAGE_TRANSFER_DST_BIT |
- VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
- VULKAN_CALL(vkCreateBuffer(vctx->device, &info, nullptr, &buffer));
- vkGetBufferMemoryRequirements(vctx->device, buffer, &req_compute);
- vkDestroyBuffer(vctx->device, buffer, nullptr);
-
- // Query phyiscal device property
- // find a memory that is host visible, no need to be consistent
- int win_rank = -1;
- VkPhysicalDeviceMemoryProperties prop;
- vkGetPhysicalDeviceMemoryProperties(vctx->phy_device, &prop);
-
- for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
- VkMemoryType ty = prop.memoryTypes[k];
- size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
- // host visible
- if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue;
- // match copy requirment
- if (!(req_staging.memoryTypeBits & (1 << k))) continue;
- if (heap_size < 1024) continue;
- int rank = 0;
- rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
- if (rank > win_rank) {
- win_rank = rank;
- vctx->staging_mtype_index = k;
- vctx->coherent_staging =
- ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
- }
- }
- CHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";
-
- win_rank = -1;
- for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) {
- VkMemoryType ty = prop.memoryTypes[k];
- size_t heap_size = prop.memoryHeaps[ty.heapIndex].size;
- // host visible
- if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue;
- // match copy requirment
- if (!(req_staging.memoryTypeBits & (1 << k))) continue;
- if (heap_size < 1024) continue;
- int rank = 0;
- // prefer not host visible
- rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
- if (rank > win_rank) {
- win_rank = rank;
- vctx->compute_mtype_index = k;
- }
- }
- CHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device.";
-}
-
-// Get all logic devices that support compute
-std::vector<VulkanContext> GetContext(VkInstance instance) {
- std::vector<VulkanContext> result;
- uint32_t phy_dev_count = 0;
- VULKAN_CALL(vkEnumeratePhysicalDevices(
- instance, &phy_dev_count, nullptr));
- std::vector<VkPhysicalDevice> all_phy_devs(phy_dev_count);
- VULKAN_CALL(vkEnumeratePhysicalDevices(
- instance, &phy_dev_count, dmlc::BeginPtr(all_phy_devs)));
- for (VkPhysicalDevice phy_dev : all_phy_devs) {
- uint32_t queue_prop_count = 0;
- vkGetPhysicalDeviceQueueFamilyProperties(
- phy_dev, &queue_prop_count, nullptr);
- std::vector<VkQueueFamilyProperties> queue_props(queue_prop_count);
- vkGetPhysicalDeviceQueueFamilyProperties(
- phy_dev, &queue_prop_count, dmlc::BeginPtr(queue_props));
- uint32_t queue_family_index = 0;
- std::vector<VkDeviceQueueCreateInfo> queue_create_info;
-
- for (uint32_t i = 0; i < queue_props.size(); i++) {
- // find queues that support compute
- if (VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) {
- float priority = 1.0f;
-
- VkDeviceQueueCreateInfo info;
- info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
- info.pNext = nullptr;
- info.flags = 0;
- info.queueFamilyIndex = i;
- info.queueCount = 1;
- info.pQueuePriorities = &priority;
-
- queue_create_info.push_back(info);
- // only use the first available queue for now
- if (queue_create_info.size() == 0) {
- queue_family_index = i;
- }
- }
- }
- if (queue_create_info.size() == 0) continue;
-
- VkDeviceCreateInfo device_create_info;
- device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
- device_create_info.pNext = nullptr;
- device_create_info.flags = 0;
- device_create_info.queueCreateInfoCount
- = static_cast<uint32_t>(queue_create_info.size());
- device_create_info.pQueueCreateInfos = queue_create_info.data();
- device_create_info.enabledLayerCount = 0;
- device_create_info.ppEnabledLayerNames = nullptr;
- device_create_info.enabledExtensionCount = 0;
- device_create_info.ppEnabledExtensionNames = nullptr;
- device_create_info.pEnabledFeatures = nullptr;
-
- VulkanContext ctx;
- // setup context
- ctx.phy_device = phy_dev;
- vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop));
- VULKAN_CALL(vkCreateDevice(
- phy_dev, &device_create_info, nullptr, &(ctx.device)));
- vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue));
- ctx.queue_family_index = queue_family_index;
- FindMemoryTypeIndex(&ctx);
- // Find suitable memory type for staging and compute
- result.push_back(ctx);
- }
- return result;
-}
-
-void VulkanWorkspace::Init() {
- if (initialized_) return;
- std::lock_guard<std::mutex> lock(this->mu);
- if (initialized_) return;
- initialized_ = true;
- try {
- instance_ = CreateInstance();
- context_ = GetContext(instance_);
- LOG(INFO) << "Initialize Vulkan with " << context_.size() << " devices..";
- for (size_t i = 0; i < context_.size(); ++i) {
- LOG(INFO) << "vulkan(" << i
- << ")=\'" << context_[i].phy_device_prop.deviceName
- << "\' phy_dev_id=" << context_[i].phy_device;
- }
- } catch (const dmlc::Error& err) {
- LOG(INFO) << "Cannot initialize vulkan: " << err.what() << "\n"
- << "You can still compile vulkan module but cannot run locally";
- }
-}
-
-bool InitVulkan(TVMArgs args, TVMRetValue* rv) {
- vulkan::VulkanWorkspace::Global()->Init();
- return true;
-}
-
-TVM_REGISTER_GLOBAL("device_api.vulkan")
-.set_body([](TVMArgs args, TVMRetValue* rv) {
- DeviceAPI* ptr = VulkanWorkspace::Global().get();
- *rv = static_cast<void*>(ptr);
- });
-
-} // namespace vulkan
-} // namespace runtime
-} // namespace tvm
+++ /dev/null
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2018 by Contributors
- * \file vulkan_module.cc
- */
-#include <dmlc/memory_io.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/module.h>
-#include <array>
-#include <string>
-#include <mutex>
-#include "vulkan_common.h"
-#include "vulkan_module.h"
-#include "../pack_args.h"
-#include "../thread_storage_scope.h"
-#include "../meta_data.h"
-#include "../file_util.h"
-
-
-namespace tvm {
-namespace runtime {
-
-void VulkanShader::Save(dmlc::Stream* writer) const {
- writer->Write(flag);
- writer->Write(data);
-}
-
-bool VulkanShader::Load(dmlc::Stream* reader) {
- if (!reader->Read(&flag)) return false;
- if (!reader->Read(&data)) return false;
- return true;
-}
-
-// Multi-device enabled module.
-class VulkanModuleNode final :public runtime::ModuleNode {
- public:
- // Pipeline cache states
- struct PipelineEntry {
- VkShaderModule shader{VK_NULL_HANDLE};
- VkPipelineLayout pipeline_layout{VK_NULL_HANDLE};
- VkDescriptorSetLayout descriptor_layout{VK_NULL_HANDLE};
- VkPipeline pipeline{VK_NULL_HANDLE};
- };
- // constructor
- explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source)
- : smap_(smap), fmap_(fmap), source_(source) {
- }
-
- ~VulkanModuleNode() {
- // cleanup vulkan related caches.
- for (DeviceEntry& e : finfo_) {
- if (e.device == nullptr) continue;
- for (auto &kv : e.smap) {
- PipelineEntry& pe = kv.second;
- vkDestroyShaderModule(e.device, pe.shader, nullptr);
- vkDestroyDescriptorSetLayout(e.device, pe.descriptor_layout, nullptr);
- vkDestroyPipelineLayout(e.device, pe.pipeline_layout, nullptr);
- vkDestroyPipeline(e.device, pe.pipeline, nullptr);
- }
- }
- }
- const char* type_key() const final {
- return "vulkan";
- }
-
- PackedFunc GetFunction(
- const std::string& name,
- const std::shared_ptr<ModuleNode>& sptr_to_self) final;
-
- void SaveToFile(const std::string& file_name,
- const std::string& format) final {
- std::string fmt = GetFileFormat(file_name, format);
- CHECK_EQ(fmt, fmt_)
- << "Can only save to customized format vulkan";
- std::string meta_file = GetMetaFilePath(file_name);
- SaveMetaDataToFile(meta_file, fmap_);
- std::string data_bin;
- dmlc::MemoryStringStream fs(&data_bin);
- dmlc::Stream* stream = &fs;
- uint32_t magic = kVulkanModuleMagic;
- stream->Write(magic);
- stream->Write(smap_);
- SaveBinaryToFile(file_name, data_bin);
- }
-
- void SaveToBinary(dmlc::Stream* stream) final {
- stream->Write(fmt_);
- stream->Write(fmap_);
- stream->Write(smap_);
- }
- std::string GetSource(const std::string& format) final {
- // can only return source code.
- return source_;
- }
-
- // get a from primary context in device_id
- PipelineEntry GetPipeline(size_t device_id,
- const std::string& func_name,
- size_t num_pack_args) {
- vulkan::VulkanWorkspace* w = vulkan::VulkanWorkspace::Global().get();
- CHECK_LT(device_id, w->context_.size());
- // start lock scope.
- std::lock_guard<std::mutex> lock(mutex_);
- if (finfo_.size() <= device_id) {
- finfo_.resize(device_id + 1, DeviceEntry());
- }
- DeviceEntry& e = finfo_[device_id];
- auto it = e.smap.find(func_name);
- if (it != e.smap.end()) return it->second;
- PipelineEntry pe;
- if (e.device == nullptr) {
- e.device = w->context_[device_id].device;
- }
- {
- // create shader
- auto sit = smap_.find(func_name);
- CHECK(sit != smap_.end());
- const std::vector<uint32_t>& data = sit->second.data;
- VkShaderModuleCreateInfo shader_cinfo;
- shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
- shader_cinfo.pNext = nullptr;
- shader_cinfo.flags = 0;
- shader_cinfo.codeSize = data.size() * sizeof(uint32_t);
- shader_cinfo.pCode = data.data();
- VULKAN_CALL(vkCreateShaderModule(
- e.device, &shader_cinfo, nullptr, &(pe.shader)));
- }
- std::vector<VkDescriptorSetLayoutBinding> arg_binding;
- uint32_t num_pod = 0, num_buffer = 0;
- {
- auto fit = fmap_.find(func_name);
- CHECK(fit != fmap_.end());
- for (TVMType arg_type : fit->second.arg_types) {
- if (arg_type.code == kHandle) {
- VkDescriptorSetLayoutBinding bd;
- bd.binding = num_buffer;
- bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
- bd.descriptorCount = 1;
- bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
- bd.pImmutableSamplers = nullptr;
- arg_binding.push_back(bd);
- ++num_buffer;
- } else {
- ++num_pod;
- }
- }
- }
-
- VkDescriptorSetLayoutCreateInfo descrip_cinfo;
- descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
- descrip_cinfo.pNext = nullptr;
- descrip_cinfo.flags = 0;
- descrip_cinfo.bindingCount = arg_binding.size();
- descrip_cinfo.pBindings = arg_binding.data();
- VULKAN_CALL(vkCreateDescriptorSetLayout(
- e.device, &descrip_cinfo, nullptr, &(pe.descriptor_layout)));
-
- VkPushConstantRange crange;
- crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
- crange.offset = 0;
- crange.size = sizeof(ArgUnion) * num_pack_args;
-
- VkPipelineLayoutCreateInfo playout_cinfo;
- playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
- playout_cinfo.pNext = nullptr;
- playout_cinfo.flags = 0;
- playout_cinfo.setLayoutCount = 1;
- playout_cinfo.pSetLayouts = &(pe.descriptor_layout);
-
- if (num_pack_args != 0) {
- playout_cinfo.pushConstantRangeCount = 1;
- playout_cinfo.pPushConstantRanges = &crange;
- CHECK_LE(crange.size,
- w->context_[device_id].phy_device_prop.limits.maxPushConstantsSize);
- } else {
- playout_cinfo.pushConstantRangeCount = 0;
- playout_cinfo.pPushConstantRanges = nullptr;
- }
-
- VULKAN_CALL(vkCreatePipelineLayout(
- e.device, &playout_cinfo, nullptr, &(pe.pipeline_layout)));
- VkComputePipelineCreateInfo pipeline_cinfo;
- pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
- pipeline_cinfo.pNext = nullptr;
- pipeline_cinfo.flags = 0;
- pipeline_cinfo.stage.sType =
- VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
- pipeline_cinfo.stage.pNext = nullptr;
- pipeline_cinfo.stage.flags = 0;
- pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
- pipeline_cinfo.stage.module = pe.shader;
- pipeline_cinfo.stage.pName = func_name.c_str();
- pipeline_cinfo.stage.pSpecializationInfo = nullptr;
- pipeline_cinfo.layout = pe.pipeline_layout;
- pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE;
- pipeline_cinfo.basePipelineIndex = 0;
- VULKAN_CALL(vkCreateComputePipelines(
- e.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr, &(pe.pipeline)));
- e.smap[func_name] = pe;
- return pe;
- }
-
- private:
- // device specific entry
- struct DeviceEntry {
- VkDevice device{nullptr};
- std::unordered_map<std::string, PipelineEntry> smap;
- };
- // the binary data
- std::vector<uint32_t> data_;
- // function information table.
- std::unordered_map<std::string, VulkanShader> smap_;
- // function information table.
- std::unordered_map<std::string, FunctionInfo> fmap_;
- // The format
- std::string fmt_{"vulkan"};
- // The source
- std::string source_;
- // device local pipeline information.
- std::vector<DeviceEntry> finfo_;
- // internal mutex when updating the module
- std::mutex mutex_;
-};
-
-// a wrapped function class to get packed func.
-class VulkanWrappedFunc {
- public:
- // initialize the VULKAN function.
- void Init(VulkanModuleNode* m,
- std::shared_ptr<ModuleNode> sptr,
- const std::string& func_name,
- size_t num_buffer_args,
- size_t num_pack_args,
- const std::vector<std::string>& thread_axis_tags) {
- w_ = vulkan::VulkanWorkspace::Global().get();
- m_ = m;
- sptr_ = sptr;
- func_name_ = func_name;
- num_buffer_args_ = num_buffer_args;
- num_pack_args_ = num_pack_args;
- thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
- }
- // invoke the function with void arguments
- void operator()(TVMArgs args,
- TVMRetValue* rv,
- const ArgUnion* pack_args) const {
- vulkan::VulkanThreadEntry* tls = vulkan::VulkanThreadEntry::ThreadLocal();
- int device_id = tls->context.device_id;
- CHECK_LT(device_id, kVulkanMaxNumDevice);
- const vulkan::VulkanContext& vctx = w_->context_[device_id];
- VulkanModuleNode::PipelineEntry& pe = scache_[device_id];
- if (pe.pipeline == VK_NULL_HANDLE) {
- pe = m_->GetPipeline(device_id, func_name_, num_pack_args_);
- }
- ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
- vulkan::VulkanCommandBuffer* cmd = tls->CommandPool(device_id)->Alloc(
- &(pe.descriptor_layout));
-
- cmd->write_descriptor_set.dstSet = cmd->descriptor_set;
-
- // setup descriptors
- for (uint32_t i = 0; i < num_buffer_args_; ++i) {
- void* buf = args[static_cast<int>(i)];
- VkDescriptorBufferInfo binfo;
- binfo.buffer = static_cast<vulkan::VulkanBuffer*>(buf)->buffer;
- binfo.offset = 0;
- binfo.range = VK_WHOLE_SIZE;
- cmd->write_descriptor_set.dstBinding = i;
- cmd->write_descriptor_set.pBufferInfo = &binfo;
- vkUpdateDescriptorSets(
- vctx.device, 1, &(cmd->write_descriptor_set), 0, nullptr);
- }
-
- // dispatch
- VkCommandBufferBeginInfo cb_begin;
- cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
- cb_begin.pNext = nullptr;
- cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
- cb_begin.pInheritanceInfo = 0;
-
- VkSubmitInfo cb_submit;
- cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
- cb_submit.pNext = nullptr;
- cb_submit.waitSemaphoreCount = 0;
- cb_submit.pWaitSemaphores = nullptr;
- cb_submit.pWaitDstStageMask = 0;
- cb_submit.commandBufferCount = 1;
- cb_submit.pCommandBuffers = &(cmd->cmd_buffer);
- cb_submit.signalSemaphoreCount = 0;
- cb_submit.pSignalSemaphores = nullptr;
- // 0: begin
- VULKAN_CALL(vkBeginCommandBuffer(cmd->cmd_buffer, &cb_begin));
- // 1: dispatch
- vkCmdBindPipeline(
- cmd->cmd_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, pe.pipeline);
- vkCmdBindDescriptorSets(
- cmd->cmd_buffer, VK_PIPELINE_BIND_POINT_COMPUTE,
- pe.pipeline_layout, 0, 1, &(cmd->descriptor_set), 0, nullptr);
- // bind push constant if necessary
- if (num_pack_args_ != 0) {
- vkCmdPushConstants(
- cmd->cmd_buffer,
- pe.pipeline_layout,
- VK_SHADER_STAGE_COMPUTE_BIT,
- 0, num_pack_args_ * sizeof(ArgUnion),
- pack_args);
- }
- vkCmdDispatch(
- cmd->cmd_buffer, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
- // 2: barrier(compute->compute|transfer)
- VkMemoryBarrier barrier_info;
- barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
- barrier_info.pNext = nullptr;
- barrier_info.srcAccessMask =
- VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
- barrier_info.dstAccessMask =
- (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
- VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
- vkCmdPipelineBarrier(
- cmd->cmd_buffer,
- VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
- VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
- 0, 1, &barrier_info, 0, nullptr, 0, nullptr);
- // 3: end
- VULKAN_CALL(vkEndCommandBuffer(cmd->cmd_buffer));
- // 4: submit with cmd->fence
- VULKAN_CALL(vkQueueSubmit(vctx.queue, 1, &cb_submit, cmd->fence));
- }
-
- private:
- // Reference to global workspace.
- vulkan::VulkanWorkspace* w_;
- // internal module
- VulkanModuleNode* m_;
- // the resource holder
- std::shared_ptr<ModuleNode> sptr_;
- // The name of the function.
- std::string func_name_;
- // Number of buffer arguments
- size_t num_buffer_args_;
- // number of packed arguments.
- size_t num_pack_args_;
- // Device state cache per device.
- // mark as mutable, to enable lazy initialization
- mutable std::array<VulkanModuleNode::PipelineEntry, kVulkanMaxNumDevice> scache_;
- // thread axis configuration
- ThreadAxisConfig thread_axis_cfg_;
-};
-
-PackedFunc VulkanModuleNode::GetFunction(
- const std::string& name,
- const std::shared_ptr<ModuleNode>& sptr_to_self) {
- CHECK_EQ(sptr_to_self.get(), this);
- CHECK_NE(name, symbol::tvm_module_main)
- << "Device function do not have main";
- auto it = fmap_.find(name);
- if (it == fmap_.end()) return PackedFunc();
- const FunctionInfo& info = it->second;
- VulkanWrappedFunc f;
- size_t num_buffer_args = NumBufferArgs(info.arg_types);
- f.Init(this, sptr_to_self, name,
- num_buffer_args, info.arg_types.size() - num_buffer_args,
- info.thread_axis_tags);
- return PackFuncNonBufferArg(f, info.arg_types);
-}
-
-Module VulkanModuleCreate(
- std::unordered_map<std::string, VulkanShader> smap,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source) {
- vulkan::VulkanWorkspace::Global()->Init();
- std::shared_ptr<VulkanModuleNode> n =
- std::make_shared<VulkanModuleNode>(smap, fmap, source);
- return Module(n);
-}
-
-// Load module from module.
-Module VulkanModuleLoadFile(const std::string& file_name,
- const std::string& format) {
- std::string data;
- std::unordered_map<std::string, VulkanShader> smap;
- std::unordered_map<std::string, FunctionInfo> fmap;
- std::string fmt = GetFileFormat(file_name, format);
- std::string meta_file = GetMetaFilePath(file_name);
- LoadBinaryFromFile(file_name, &data);
- LoadMetaDataFromFile(meta_file, &fmap);
- dmlc::MemoryStringStream fs(&data);
- dmlc::Stream* stream = &fs;
- uint32_t magic;
- stream->Read(&magic);
- CHECK_EQ(magic, kVulkanModuleMagic)
- << "VulkanModule Magic mismatch";
- stream->Read(&smap);
- return VulkanModuleCreate(smap, fmap, "");
-}
-
-Module VulkanModuleLoadBinary(void* strm) {
- dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
- std::unordered_map<std::string, VulkanShader> smap;
- std::unordered_map<std::string, FunctionInfo> fmap;
-
- std::string fmt;
- stream->Read(&fmt);
- stream->Read(&fmap);
- stream->Read(&smap);
- return VulkanModuleCreate(smap, fmap, "");
-}
-
-TVM_REGISTER_GLOBAL("module.loadfile_vulkan")
-.set_body_typed(VulkanModuleLoadFile);
-
-TVM_REGISTER_GLOBAL("module.loadbinary_vulkan")
-.set_body_typed(VulkanModuleLoadBinary);
-} // namespace runtime
-} // namespace tvm
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* specific language governing permissions and limitations
* under the License.
*/
+#pragma once
-/*!
- * Copyright (c) 2017 by Contributors
- * \file metal_module.h
- * \brief Execution handling of Metal kernels
- */
-#ifndef TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
-#define TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
-
-#include <tvm/runtime/packed_func.h>
-#include <dmlc/type_traits.h>
-#include <memory>
-#include <vector>
#include <string>
#include <unordered_map>
+
#include "../meta_data.h"
+#include "vulkan_shader.h"
namespace tvm {
namespace runtime {
-/*! \brief Maximum number of GPU supported in VulkanModule. */
-static constexpr const int kVulkanMaxNumDevice = 8;
-
-/*! \brief TVM Vulkan binary pack magic number */
-static constexpr const int kVulkanModuleMagic = 0x02700027;
-
-/*!
- * \brief A single VK shader program
- *
- * Due to the global resource declaration.
- * Current SPIRV only allows one entry program per shader,
- * making it less useful for a Module like system.
- *
- * Instead we pass in map of str->VulkanShader until
- * there is a native solution available.
- */
-struct VulkanShader {
- /*! \brief header flag */
- uint32_t flag{0};
- /*! \brief Data segment */
- std::vector<uint32_t> data;
+namespace vulkan {
+Module VulkanModuleCreate(std::unordered_map<std::string, VulkanShader> smap,
+ std::unordered_map<std::string, FunctionInfo> fmap, std::string source);
- void Save(dmlc::Stream *writer) const;
- bool Load(dmlc::Stream *reader);
-};
+} // namespace vulkan
-/*!
- * \brief create a metal module from data.
- *
- * \param pmap The program map.
- * \param fmap The function information map.
- * \param source Optional, source code.
- */
-Module VulkanModuleCreate(
- std::unordered_map<std::string, VulkanShader> smap,
- std::unordered_map<std::string, FunctionInfo> fmap,
- std::string source);
+using vulkan::VulkanModuleCreate;
} // namespace runtime
} // namespace tvm
-
-namespace dmlc {
-DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::VulkanShader, true);
-} // namespace dmlc
-
-#endif // TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#pragma once
+
+
+#include <dmlc/io.h>
+#include <dmlc/logging.h>
+#include <dmlc/type_traits.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/packed_func.h>
+
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+namespace vulkan {
+
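+/*!
+ * \brief A single SPIR-V shader program.
+ *
+ * SPIR-V as used here allows only one entry point per shader program, so
+ * modules carry a map of function name -> VulkanShader instead.
+ */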
+struct VulkanShader {
+ /*! \brief header flag */
+ uint32_t flag{0};
+ /*! \brief Data segment */
+ std::vector<uint32_t> data;
+
+ void Save(dmlc::Stream* writer) const {
+ writer->Write(flag);
+ writer->Write(data);
+ }
+ bool Load(dmlc::Stream* reader) {
+ if (!reader->Read(&flag)) return false;
+ if (!reader->Read(&data)) return false;
+ return true;
+ }
+};
+
+} // namespace vulkan
+
+using vulkan::VulkanShader;
+} // namespace runtime
+} // namespace tvm
+
+namespace dmlc {
+DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::vulkan::VulkanShader, true);
+} // namespace dmlc
--- /dev/null
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#pragma once
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <vector>
+
+#include "vulkan_common.h"
+
+
+namespace tvm {
+namespace runtime {
+namespace vulkan {
+
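+// Mutable per-stream state: the command buffer currently being recorded and
+// the fence signaled when its last submission completes.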
+class VulkanStreamState {
+ public:
+ VkCommandBuffer cmd_buffer_;
+ VkFence fence_;
+};
+
+// Identifies a deferred launch: the descriptor set it uses and the buffers
+// bound into that set.
+struct VulkanStreamToken {
+ VkDescriptorSet descriptor_set_{VK_NULL_HANDLE};
+ std::vector<VkBuffer> buffers_;
+};
+
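+/*!
+ * \brief A stream of Vulkan commands owned by a single thread.
+ *
+ * In immediate mode, kernels are recorded into the command buffer as soon as
+ * Launch() is called; otherwise they are queued and recorded at the next
+ * Synchronize(). A minimal usage sketch (assuming a valid VulkanContext*):
+ *
+ *   VulkanStream stream(vctx);
+ *   stream.Launch([](VulkanStreamState* s) {
+ *     // record vkCmd* calls into s->cmd_buffer_
+ *   });
+ *   stream.Synchronize();  // submit and block on s->fence_
+ */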
+class VulkanStream {
+ public:
+ explicit VulkanStream(const VulkanContext* vctx)
+ : vctx_(vctx), state_(new VulkanStreamState()) {
+ // create command pool
+ VkCommandPoolCreateInfo cmd_pool_cinfo;
+ cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
+ cmd_pool_cinfo.pNext = nullptr;
+ cmd_pool_cinfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
+ cmd_pool_cinfo.queueFamilyIndex = vctx_->queue_family_index;
+ VULKAN_CALL(vkCreateCommandPool(vctx_->device, &cmd_pool_cinfo, nullptr, &cmd_pool_));
+
+ VkCommandBufferAllocateInfo buffer_alloc_info;
+ buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
+ buffer_alloc_info.pNext = nullptr;
+ buffer_alloc_info.commandPool = cmd_pool_;
+ buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
+ buffer_alloc_info.commandBufferCount = 1;
+ VULKAN_CALL(
+ vkAllocateCommandBuffers(vctx_->device, &buffer_alloc_info, &(state_->cmd_buffer_)));
+
+ VkFenceCreateInfo fence_cinfo;
+ fence_cinfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
+ fence_cinfo.pNext = nullptr;
+    // Start unsignaled; Synchronize() submits the buffer and then waits on it.
+    fence_cinfo.flags = 0;
+ VULKAN_CALL(vkCreateFence(vctx_->device, &fence_cinfo, nullptr, &(state_->fence_)));
+
+ VkCommandBufferBeginInfo cb_begin;
+ cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
+ cb_begin.pNext = nullptr;
+ cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
+    cb_begin.pInheritanceInfo = nullptr;
+ VULKAN_CALL(vkBeginCommandBuffer(state_->cmd_buffer_, &cb_begin));
+ }
+
+ ~VulkanStream() {
+ vkDestroyFence(vctx_->device, state_->fence_, nullptr);
+ vkDestroyCommandPool(vctx_->device, cmd_pool_, nullptr);
+ }
+
+  // Launch the kernel on the current stream: immediately in immediate mode,
+  // otherwise deferred until the next Synchronize().
+ void Launch(const std::function<void(VulkanStreamState*)>& kernel) {
+ if (vctx_->UseImmediate()) {
+ kernel(state_.get());
+ } else {
+ deferred_kernels_.push_back(kernel);
+ }
+ }
+
+  // Record a deferred kernel launch on the current stream; the kernel runs
+  // at the next Synchronize().
+ void LaunchDeferred(const std::function<void()>& deferred_initializer,
+ const std::function<void(VulkanStreamState*)>& deferred_kernel,
+ const VulkanStreamToken& deferred_token) {
+ CHECK(!vctx_->UseImmediate());
+
+    // If an already-deferred launch uses the same descriptor set but a
+    // different buffer set, this launch cannot be batched with it: re-binding
+    // the descriptors would corrupt the earlier launch, so flush the stream.
+ if (std::any_of(deferred_tokens_.begin(), deferred_tokens_.end(),
+ [&](const VulkanStreamToken& token) {
+ return token.descriptor_set_ == deferred_token.descriptor_set_ &&
+ token.buffers_ != deferred_token.buffers_;
+ })) {
+ Synchronize();
+ }
+
+    // Skip the initializer when an identical token (same descriptor set and
+    // buffers) is already deferred; the descriptor bindings are already set.
+ if (!std::any_of(deferred_tokens_.begin(), deferred_tokens_.end(),
+ [&](const VulkanStreamToken& token) {
+                       // Same descriptor set with identical buffer bindings.
+ return token.descriptor_set_ == deferred_token.descriptor_set_ &&
+ token.buffers_ == deferred_token.buffers_;
+ })) {
+ deferred_initializer();
+ }
+
+ deferred_kernels_.push_back(deferred_kernel);
+ deferred_tokens_.push_back(deferred_token);
+ }
+
+ // Synchronize the current stream `state_` with respect to the host.
+ void Synchronize() {
+ if (!vctx_->UseImmediate()) {
+ for (const auto& deferred_kernel : deferred_kernels_) {
+ deferred_kernel(state_.get());
+ }
+ deferred_kernels_.clear();
+ deferred_tokens_.clear();
+ } else {
+ DCHECK_EQ(deferred_kernels_.size(), 0);
+ DCHECK_EQ(deferred_tokens_.size(), 0);
+ }
+
+ VULKAN_CALL(vkEndCommandBuffer(state_->cmd_buffer_));
+ VkSubmitInfo cb_submit;
+ cb_submit.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
+ cb_submit.pNext = nullptr;
+ cb_submit.waitSemaphoreCount = 0;
+ cb_submit.pWaitSemaphores = nullptr;
+    cb_submit.pWaitDstStageMask = nullptr;
+ cb_submit.commandBufferCount = 1;
+ cb_submit.pCommandBuffers = &(state_->cmd_buffer_);
+ cb_submit.signalSemaphoreCount = 0;
+ cb_submit.pSignalSemaphores = nullptr;
+
+ {
+ // Multiple streams (on different threads) use the same VulkanContext
+ // instance, so we need to externally synchronize accesses.
+ std::lock_guard<std::mutex> g(*(vctx_->queue_mutex));
+ VULKAN_CALL(vkQueueSubmit(vctx_->queue, 1, &cb_submit, state_->fence_));
+ }
+ uint64_t timeout = 1UL << 30UL;
+ VkResult res;
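+    // Wait on the fence in 2^30 ns (~1 s) slices, retrying on VK_TIMEOUT.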
+ do {
+ res = vkWaitForFences(vctx_->device, 1, &(state_->fence_), 0, timeout);
+ } while (res == VK_TIMEOUT);
+ VULKAN_CHECK_ERROR(res);
+ VULKAN_CALL(vkResetCommandBuffer(state_->cmd_buffer_, 0));
+ VULKAN_CALL(vkResetFences(vctx_->device, 1, &(state_->fence_)));
+
+ // Re-initialize the command buffer
+ VkCommandBufferBeginInfo cb_begin;
+ cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
+ cb_begin.pNext = nullptr;
+ cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
+    cb_begin.pInheritanceInfo = nullptr;
+ VULKAN_CALL(vkBeginCommandBuffer(state_->cmd_buffer_, &cb_begin));
+ }
+
+ private:
+ const VulkanContext* vctx_;
+ std::unique_ptr<VulkanStreamState> state_;
+ std::vector<VulkanStreamToken> deferred_tokens_;
+ std::vector<std::function<void(VulkanStreamState*)>> deferred_kernels_;
+ VkCommandPool cmd_pool_;
+};
+
+} // namespace vulkan
+} // namespace runtime
+} // namespace tvm
# under the License.
import tvm
import re
+import numpy as np
def test_vector_comparison():
check_correct_assembly('float16')
+tx = tvm.thread_axis("threadIdx.x")
+bx = tvm.thread_axis("blockIdx.x")
+
+
+def test_vulkan_copy():
+
+ def check_vulkan(dtype, n):
+ if not tvm.vulkan(0).exist or not tvm.module.enabled("vulkan"):
+            print("skip because vulkan is not enabled...")
+ return
+ A = tvm.placeholder((n,), name='A', dtype=dtype)
+ ctx = tvm.vulkan(0)
+ a_np = np.random.uniform(size=(n,)).astype(A.dtype)
+ a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(a_np)
+ b_np = a.asnumpy()
+ tvm.testing.assert_allclose(a_np, b_np)
+ tvm.testing.assert_allclose(a_np, a.asnumpy())
+
+ for _ in range(100):
+ dtype = np.random.choice(["float32", "float16", "int8", "int32"])
+ logN = np.random.randint(1, 15)
+        perturb = np.random.uniform(low=0.5, high=1.5)
+        check_vulkan(dtype, int(perturb * (2 ** logN)))
+
+
+def test_vulkan_vectorize_add():
+ num_thread = 8
+
+ def check_vulkan(dtype, n, lanes):
+ if not tvm.vulkan(0).exist or not tvm.module.enabled("vulkan"):
+            print("skip because vulkan is not enabled...")
+ return
+ A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
+        B = tvm.compute((n,), lambda i: A[i] + tvm.const(1, A.dtype), name='B')
+ s = tvm.create_schedule(B.op)
+ xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
+ s[B].bind(xo, bx)
+ s[B].bind(xi, tx)
+ fun = tvm.build(s, [A, B], "vulkan")
+ ctx = tvm.vulkan(0)
+ a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(
+ np.random.uniform(size=(n, lanes)))
+ c = tvm.nd.empty((n,), B.dtype, ctx)
+ fun(a, c)
+ tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 1)
+
+ check_vulkan("float32", 64, 2)
+ check_vulkan("float16", 64, 2)
+
+
+def test_vulkan_stress():
+ """
+ Launch a randomized test with multiple kernels per stream, multiple uses of
+ kernels per stream, over multiple threads.
+ """
+ import random
+ import threading
+ n = 1024
+ num_thread = 64
+
+ def run_stress():
+ def worker():
+ if not tvm.vulkan(0).exist or not tvm.module.enabled("vulkan"):
+                print("skip because vulkan is not enabled...")
+ return
+ A = tvm.placeholder((n,), name='A', dtype="float32")
+ B = tvm.placeholder((n,), name='B', dtype="float32")
+ functions = [
+ (lambda: tvm.compute((n,), lambda i: 2 * A[i] + 3 * B[i]),
+ lambda a, b: 2 * a + 3 * b),
+                (lambda: tvm.compute((n,), lambda i: A[i] + B[i]),
+                 lambda a, b: a + b),
+                (lambda: tvm.compute((n,), lambda i: A[i] + 2 * B[i]),
+                 lambda a, b: a + 2 * b),
+ ]
+
+ def build_f(f_ref):
+ (C_f, ref) = f_ref
+ C = C_f()
+ s = tvm.create_schedule(C.op)
+ xo, xi = s[C].split(C.op.axis[0], factor=num_thread)
+ s[C].bind(xo, bx)
+ s[C].bind(xi, tx)
+ fun = tvm.build(s, [A, B, C], "vulkan")
+ return (fun, ref)
+
+ fs = [build_f(random.choice(functions))
+ for _ in range(np.random.randint(low=1, high=10))]
+ ctx = tvm.vulkan(0)
+ a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(
+ np.random.uniform(size=(n,)))
+ b = tvm.nd.empty((n,), B.dtype, ctx).copyfrom(
+ np.random.uniform(size=(n,)))
+ cs = [tvm.nd.empty((n,), A.dtype, ctx) for _ in fs]
+ for ((f, _), c) in zip(fs, cs):
+ f(a, b, c)
+
+ for ((_, ref), c) in zip(fs, cs):
+ tvm.testing.assert_allclose(
+ c.asnumpy(), ref(a.asnumpy(), b.asnumpy()))
+
+ ts = [threading.Thread(target=worker)
+ for _ in range(np.random.randint(1, 10))]
+ for t in ts:
+ t.start()
+ for t in ts:
+ t.join()
+
+ run_stress()
+
+
if __name__ == "__main__":
test_vector_comparison()
+ test_vulkan_copy()
+ test_vulkan_vectorize_add()
+ test_vulkan_stress()