radv/rt: refactor compute_rt_stack_size() to use radv_ray_tracing_stage information
authorDaniel Schürmann <daniel@schuermann.dev>
Tue, 25 Apr 2023 13:56:57 +0000 (15:56 +0200)
committerMarge Bot <emma+marge@anholt.net>
Tue, 2 May 2023 19:15:10 +0000 (19:15 +0000)
instead of pStages.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22100>

src/amd/vulkan/radv_pipeline_rt.c

index b4bea12..8a527cf 100644 (file)
@@ -506,7 +506,8 @@ radv_rt_pipeline_has_dynamic_stack_size(const VkRayTracingPipelineCreateInfoKHR
 
 static unsigned
 compute_rt_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
-                      const struct radv_ray_tracing_group *groups)
+                      const struct radv_ray_tracing_group *groups,
+                      const struct radv_ray_tracing_stage *stages)
 {
    if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo))
       return -1u;
@@ -526,18 +527,17 @@ compute_rt_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
       if (shader_id == VK_SHADER_UNUSED_KHR)
          continue;
 
-      const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
-      switch (stage->stage) {
-      case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
+      switch (stages[shader_id].stage) {
+      case MESA_SHADER_RAYGEN:
          raygen_size = MAX2(raygen_size, size);
          break;
-      case VK_SHADER_STAGE_MISS_BIT_KHR:
+      case MESA_SHADER_MISS:
          miss_size = MAX2(miss_size, size);
          break;
-      case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
+      case MESA_SHADER_CLOSEST_HIT:
          chit_size = MAX2(chit_size, size);
          break;
-      case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
+      case MESA_SHADER_CALLABLE:
          callable_size = MAX2(callable_size, size);
          break;
       default:
@@ -684,7 +684,7 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
          goto shader_fail;
    }
 
-   rt_pipeline->stack_size = compute_rt_stack_size(&local_create_info, rt_pipeline->groups);
+   rt_pipeline->stack_size = compute_rt_stack_size(&local_create_info, rt_pipeline->groups, stages);
    rt_pipeline->base.base.shaders[MESA_SHADER_COMPUTE] = radv_create_rt_prolog(device);
 
    combine_config(&rt_pipeline->base.base.shaders[MESA_SHADER_COMPUTE]->config,