radv/rt: add shader stage indices to radv_ray_tracing_group
authorDaniel Schürmann <daniel@schuermann.dev>
Fri, 14 Apr 2023 10:00:03 +0000 (12:00 +0200)
committerMarge Bot <emma+marge@anholt.net>
Wed, 26 Apr 2023 02:48:29 +0000 (02:48 +0000)
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22686>

src/amd/vulkan/radv_pipeline_rt.c
src/amd/vulkan/radv_private.h

index 223e0b2..2bc06c9 100644 (file)
@@ -132,6 +132,50 @@ radv_create_group_handles(struct radv_device *device,
    return VK_SUCCESS;
 }
 
+static VkResult
+radv_rt_fill_group_info(struct radv_device *device,
+                        const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
+                        struct radv_ray_tracing_group *groups)
+{
+   VkResult result = radv_create_group_handles(device, pCreateInfo, groups);
+
+   uint32_t idx;
+   for (idx = 0; idx < pCreateInfo->groupCount; idx++) {
+      groups[idx].type = pCreateInfo->pGroups[idx].type;
+      if (groups[idx].type == VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR)
+         groups[idx].recursive_shader = pCreateInfo->pGroups[idx].generalShader;
+      else
+         groups[idx].recursive_shader = pCreateInfo->pGroups[idx].closestHitShader;
+      groups[idx].any_hit_shader = pCreateInfo->pGroups[idx].anyHitShader;
+      groups[idx].intersection_shader = pCreateInfo->pGroups[idx].intersectionShader;
+   }
+
+   /* copy and adjust library groups (incl. handles) */
+   if (pCreateInfo->pLibraryInfo) {
+      unsigned stage_count = pCreateInfo->stageCount;
+      for (unsigned i = 0; i < pCreateInfo->pLibraryInfo->libraryCount; ++i) {
+         RADV_FROM_HANDLE(radv_pipeline, pipeline, pCreateInfo->pLibraryInfo->pLibraries[i]);
+         struct radv_ray_tracing_lib_pipeline *library_pipeline =
+            radv_pipeline_to_ray_tracing_lib(pipeline);
+
+         for (unsigned j = 0; j < library_pipeline->group_count; ++j) {
+            struct radv_ray_tracing_group *dst = &groups[idx + j];
+            *dst = library_pipeline->groups[j];
+            if (dst->recursive_shader != VK_SHADER_UNUSED_KHR)
+               dst->recursive_shader += stage_count;
+            if (dst->any_hit_shader != VK_SHADER_UNUSED_KHR)
+               dst->any_hit_shader += stage_count;
+            if (dst->intersection_shader != VK_SHADER_UNUSED_KHR)
+               dst->intersection_shader += stage_count;
+         }
+         idx += library_pipeline->group_count;
+         stage_count += library_pipeline->stage_count;
+      }
+   }
+
+   return result;
+}
+
 static VkRayTracingPipelineCreateInfoKHR
 radv_create_merged_rt_create_info(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo)
 {
@@ -349,7 +393,7 @@ radv_rt_pipeline_library_create(VkDevice _device, VkPipelineCache _cache,
 
    pipeline->ctx = ralloc_context(NULL);
 
-   result = radv_create_group_handles(device, &local_create_info, pipeline->groups);
+   result = radv_rt_fill_group_info(device, pCreateInfo, pipeline->groups);
    if (result != VK_SUCCESS)
       goto pipeline_fail;
 
@@ -555,7 +599,7 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
    radv_pipeline_init(device, &rt_pipeline->base.base, RADV_PIPELINE_RAY_TRACING);
    rt_pipeline->group_count = local_create_info.groupCount;
 
-   result = radv_create_group_handles(device, &local_create_info, rt_pipeline->groups);
+   result = radv_rt_fill_group_info(device, pCreateInfo, rt_pipeline->groups);
    if (result != VK_SUCCESS)
       goto pipeline_fail;
 
index 40ae521..35c2296 100644 (file)
@@ -2285,6 +2285,10 @@ struct radv_compute_pipeline {
 };
 
 struct radv_ray_tracing_group {
+   VkRayTracingShaderGroupTypeKHR type;
+   uint32_t recursive_shader; /* generalShader or closestHitShader */
+   uint32_t any_hit_shader;
+   uint32_t intersection_shader;
    struct radv_pipeline_group_handle handle;
    struct radv_pipeline_shader_stack_size stack_size;
 };