static unsigned
compute_rt_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
- const struct radv_ray_tracing_group *groups)
+ const struct radv_ray_tracing_group *groups,
+ const struct radv_ray_tracing_stage *stages)
{
if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo))
return -1u;
if (shader_id == VK_SHADER_UNUSED_KHR)
continue;
- const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
- switch (stage->stage) {
- case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
+ switch (stages[shader_id].stage) {
+ case MESA_SHADER_RAYGEN:
raygen_size = MAX2(raygen_size, size);
break;
- case VK_SHADER_STAGE_MISS_BIT_KHR:
+ case MESA_SHADER_MISS:
miss_size = MAX2(miss_size, size);
break;
- case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
+ case MESA_SHADER_CLOSEST_HIT:
chit_size = MAX2(chit_size, size);
break;
- case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
+ case MESA_SHADER_CALLABLE:
callable_size = MAX2(callable_size, size);
break;
default:
goto shader_fail;
}
- rt_pipeline->stack_size = compute_rt_stack_size(&local_create_info, rt_pipeline->groups);
+ rt_pipeline->stack_size = compute_rt_stack_size(&local_create_info, rt_pipeline->groups, stages);
rt_pipeline->base.base.shaders[MESA_SHADER_COMPUTE] = radv_create_rt_prolog(device);
combine_config(&rt_pipeline->base.base.shaders[MESA_SHADER_COMPUTE]->config,