vulkan: Add a more direct way to use a NIR shader
authorFaith Ekstrand <faith.ekstrand@collabora.com>
Tue, 31 Jan 2023 02:11:52 +0000 (20:11 -0600)
committerMarge Bot <emma+marge@anholt.net>
Mon, 31 Jul 2023 17:01:41 +0000 (17:01 +0000)
This follows the pipeline libraries method of including SPIR-V and lets
you provide the NIR shader without wrapping it in a VkShaderModule.

Reviewed-by: Konstantin Seurer <konstantin.seurer@gmail.com>
Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24335>

src/vulkan/runtime/vk_pipeline.c
src/vulkan/runtime/vk_pipeline.h

index 812a6bf..9d3d6d4 100644 (file)
@@ -49,6 +49,32 @@ vk_pipeline_shader_stage_is_null(const VkPipelineShaderStageCreateInfo *info)
    return true;
 }
 
+static nir_shader *
+get_builtin_nir(const VkPipelineShaderStageCreateInfo *info)
+{
+   VK_FROM_HANDLE(vk_shader_module, module, info->module);
+
+   nir_shader *nir = NULL;
+   if (module != NULL) {
+      nir = module->nir;
+   } else {
+      const VkPipelineShaderStageNirCreateInfoMESA *nir_info =
+         vk_find_struct_const(info->pNext, PIPELINE_SHADER_STAGE_NIR_CREATE_INFO_MESA);
+      if (nir_info != NULL)
+         nir = nir_info->nir;
+   }
+
+   if (nir == NULL)
+      return NULL;
+
+   assert(nir->info.stage == vk_to_mesa_shader_stage(info->stage));
+   ASSERTED nir_function_impl *entrypoint = nir_shader_get_entrypoint(nir);
+   assert(strcmp(entrypoint->function->name, info->pName) == 0);
+   assert(info->pSpecializationInfo == NULL);
+
+   return nir;
+}
+
 static uint32_t
 get_required_subgroup_size(const VkPipelineShaderStageCreateInfo *info)
 {
@@ -70,16 +96,11 @@ vk_pipeline_shader_stage_to_nir(struct vk_device *device,
 
    assert(info->sType == VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO);
 
-   if (module != NULL && module->nir != NULL) {
-      assert(module->nir->info.stage == stage);
-      assert(exec_list_length(&module->nir->functions) == 1);
-      ASSERTED const char *nir_name =
-         nir_shader_get_entrypoint(module->nir)->function->name;
-      assert(strcmp(nir_name, info->pName) == 0);
-
-      nir_validate_shader(module->nir, "internal shader");
+   nir_shader *builtin_nir = get_builtin_nir(info);
+   if (builtin_nir != NULL) {
+      nir_validate_shader(builtin_nir, "internal shader");
 
-      nir_shader *clone = nir_shader_clone(mem_ctx, module->nir);
+      nir_shader *clone = nir_shader_clone(mem_ctx, builtin_nir);
       if (clone == NULL)
          return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
 
@@ -142,20 +163,16 @@ vk_pipeline_hash_shader_stage(const VkPipelineShaderStageCreateInfo *info,
 {
    VK_FROM_HANDLE(vk_shader_module, module, info->module);
 
-   if (module && module->nir) {
+   const nir_shader *builtin_nir = get_builtin_nir(info);
+   if (builtin_nir != NULL) {
       /* Internal NIR module: serialize and hash the NIR shader.
        * We don't need to hash other info fields since they should match the
        * NIR data.
        */
-      assert(module->nir->info.stage == vk_to_mesa_shader_stage(info->stage));
-      ASSERTED nir_function_impl *entrypoint = nir_shader_get_entrypoint(module->nir);
-      assert(strcmp(entrypoint->function->name, info->pName) == 0);
-      assert(info->pSpecializationInfo == NULL);
-
       struct blob blob;
 
       blob_init(&blob);
-      nir_serialize(&blob, module->nir, false);
+      nir_serialize(&blob, builtin_nir, false);
       assert(!blob.out_of_memory);
       _mesa_sha1_compute(blob.data, blob.size, stage_sha1);
       blob_finish(&blob);
index a6d8417..ce037eb 100644 (file)
@@ -37,6 +37,18 @@ struct vk_device;
 extern "C" {
 #endif
 
+#define VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_NIR_CREATE_INFO_MESA \
+   (VkStructureType)1000290001
+
+#define VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_NIR_CREATE_INFO_MESA_cast \
+   VkPipelineShaderStageNirCreateInfoMESA
+
+typedef struct VkPipelineShaderStageNirCreateInfoMESA {
+   VkStructureType sType;
+   const void *pNext;
+   struct nir_shader *nir;
+} VkPipelineShaderStageNirCreateInfoMESA;
+
 bool
 vk_pipeline_shader_stage_is_null(const VkPipelineShaderStageCreateInfo *info);