From 913de78731aa58c27861e62d16f50dc9249be58f Mon Sep 17 00:00:00 2001 From: Bas Nieuwenhuizen Date: Wed, 11 Jan 2023 01:30:24 +0100 Subject: [PATCH] radv: Use provided handles for switch cases in RT shaders. Part-of: --- src/amd/vulkan/radv_pipeline_rt.c | 3 +- src/amd/vulkan/radv_rt_shader.c | 86 ++++++++++++++++++++++++++++----------- src/amd/vulkan/radv_shader.h | 2 + 3 files changed, 66 insertions(+), 25 deletions(-) diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c index 02c7832..197dcd9 100644 --- a/src/amd/vulkan/radv_pipeline_rt.c +++ b/src/amd/vulkan/radv_pipeline_rt.c @@ -435,7 +435,8 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache, goto pipeline_fail; } - shader = create_rt_shader(device, &local_create_info, rt_pipeline->stack_sizes, &key); + shader = create_rt_shader(device, &local_create_info, rt_pipeline->stack_sizes, + rt_pipeline->group_handles, &key); module.nir = shader; result = radv_compute_pipeline_compile( &rt_pipeline->base, pipeline_layout, device, cache, &key, &stage, pCreateInfo->flags, diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c index 212ac0e..a301dd5 100644 --- a/src/amd/vulkan/radv_rt_shader.c +++ b/src/amd/vulkan/radv_rt_shader.c @@ -1078,10 +1078,20 @@ init_traversal_vars(nir_builder *b) return ret; } +struct traversal_data { + struct radv_device *device; + const VkRayTracingPipelineCreateInfoKHR *createInfo; + struct rt_variables *vars; + struct rt_traversal_vars *trav_vars; + nir_variable *barycentrics; + + const struct radv_pipeline_group_handle *handles; +}; + static void visit_any_hit_shaders(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b, - struct rt_variables *vars) + struct traversal_data *data, struct rt_variables *vars) { nir_ssa_def *sbt_idx = nir_load_var(b, vars->idx); @@ -1102,25 +1112,26 @@ visit_any_hit_shaders(struct radv_device *device, if (shader_id == VK_SHADER_UNUSED_KHR) continue; + /* Avoid emitting stages with the same shaders/handles multiple times. */ + bool is_dup = false; + for (unsigned j = 0; j < i; ++j) + if (data->handles[j].any_hit_index == data->handles[i].any_hit_index) + is_dup = true; + + if (is_dup) + continue; + const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id]; nir_shader *nir_stage = parse_rt_stage(device, stage, vars->key); vars->stage_idx = shader_id; - insert_rt_case(b, nir_stage, vars, sbt_idx, 0, i + 2); + insert_rt_case(b, nir_stage, vars, sbt_idx, 0, data->handles[i].any_hit_index); } if (!(vars->create_info->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR)) nir_pop_if(b, NULL); } -struct traversal_data { - struct radv_device *device; - const VkRayTracingPipelineCreateInfoKHR *createInfo; - struct rt_variables *vars; - struct rt_traversal_vars *trav_vars; - nir_variable *barycentrics; -}; - static void handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *intersection, const struct radv_ray_traversal_args *args, @@ -1158,7 +1169,7 @@ handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *int load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, SBT_ANY_HIT_IDX); - visit_any_hit_shaders(data->device, data->createInfo, b, &inner_vars); + visit_any_hit_shaders(data->device, data->createInfo, b, args->data, &inner_vars); nir_push_if(b, nir_inot(b, nir_load_var(b, data->vars->ahit_accept))); { @@ -1237,6 +1248,15 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio if (shader_id == VK_SHADER_UNUSED_KHR) continue; + /* Avoid emitting stages with the same shaders/handles multiple times. */ + bool is_dup = false; + for (unsigned j = 0; j < i; ++j) + if (data->handles[j].intersection_index == data->handles[i].intersection_index) + is_dup = true; + + if (is_dup) + continue; + const VkPipelineShaderStageCreateInfo *stage = &data->createInfo->pStages[shader_id]; nir_shader *nir_stage = parse_rt_stage(data->device, stage, data->vars->key); @@ -1250,7 +1270,8 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio } inner_vars.stage_idx = shader_id; - insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0, i + 2); + insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0, + data->handles[i].intersection_index); } if (!(data->vars->create_info->flags & @@ -1297,6 +1318,7 @@ static nir_shader * build_traversal_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, struct radv_pipeline_shader_stack_size *stack_sizes, + const struct radv_pipeline_group_handle *handles, const struct radv_pipeline_key *key) { /* Create the traversal shader as an intersection shader to prevent validation failures due to @@ -1383,6 +1405,7 @@ build_traversal_shader(struct radv_device *device, .vars = &vars, .trav_vars = &trav_vars, .barycentrics = barycentrics, + .handles = handles, }; struct radv_ray_traversal_args args = { @@ -1518,6 +1541,7 @@ lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs) nir_shader * create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, struct radv_pipeline_shader_stack_size *stack_sizes, + const struct radv_pipeline_group_handle *handles, const struct radv_pipeline_key *key) { nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_combined"); @@ -1554,23 +1578,37 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf nir_ssa_def *idx = nir_load_var(&b, vars.idx); /* Insert traversal shader */ - nir_shader *traversal = build_traversal_shader(device, pCreateInfo, stack_sizes, key); + nir_shader *traversal = build_traversal_shader(device, pCreateInfo, stack_sizes, handles, key); assert(b.shader->info.shared_size == 0); b.shader->info.shared_size = traversal->info.shared_size; assert(b.shader->info.shared_size <= 32768); insert_rt_case(&b, traversal, &vars, idx, 0, 1); - /* We do a trick with the indexing of the resume shaders so that the first - * shader of stage x always gets id x and the resume shader ids then come after - * stageCount. This makes the shadergroup handles independent of compilation. */ - unsigned call_idx_base = pCreateInfo->stageCount + 1; - for (unsigned i = 0; i < pCreateInfo->stageCount; ++i) { - const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[i]; - gl_shader_stage type = vk_to_mesa_shader_stage(stage->stage); - if (type != MESA_SHADER_RAYGEN && type != MESA_SHADER_CALLABLE && - type != MESA_SHADER_CLOSEST_HIT && type != MESA_SHADER_MISS) + unsigned call_idx_base = 1; + for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) { + unsigned stage_idx = VK_SHADER_UNUSED_KHR; + if (pCreateInfo->pGroups[i].type == VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR) + stage_idx = pCreateInfo->pGroups[i].generalShader; + else + stage_idx = pCreateInfo->pGroups[i].closestHitShader; + + if (stage_idx == VK_SHADER_UNUSED_KHR) continue; + /* Avoid emitting stages with the same shaders/handles multiple times. */ + bool is_dup = false; + for (unsigned j = 0; j < i; ++j) + if (handles[j].general_index == handles[i].general_index) + is_dup = true; + + if (is_dup) + continue; + + const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[stage_idx]; + ASSERTED gl_shader_stage type = vk_to_mesa_shader_stage(stage->stage); + assert(type == MESA_SHADER_RAYGEN || type == MESA_SHADER_CALLABLE || + type == MESA_SHADER_CLOSEST_HIT || type == MESA_SHADER_MISS); + nir_shader *nir_stage = parse_rt_stage(device, stage, key); /* Move ray tracing system values to the top that are set by rt_trace_ray @@ -1588,8 +1626,8 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf nir_shader **resume_shaders = NULL; nir_lower_shader_calls(nir_stage, &opts, &resume_shaders, &num_resume_shaders, nir_stage); - vars.stage_idx = i; - insert_rt_case(&b, nir_stage, &vars, idx, call_idx_base, i + 2); + vars.stage_idx = stage_idx; + insert_rt_case(&b, nir_stage, &vars, idx, call_idx_base, handles[i].general_index); for (unsigned j = 0; j < num_resume_shaders; ++j) { insert_rt_case(&b, resume_shaders[j], &vars, idx, call_idx_base, call_idx_base + 1 + j); } diff --git a/src/amd/vulkan/radv_shader.h b/src/amd/vulkan/radv_shader.h index f928e3d..d9d6970 100644 --- a/src/amd/vulkan/radv_shader.h +++ b/src/amd/vulkan/radv_shader.h @@ -47,6 +47,7 @@ struct radv_physical_device; struct radv_device; struct radv_pipeline; struct radv_pipeline_cache; +struct radv_pipeline_group_handle; struct radv_pipeline_key; struct radv_shader_args; struct radv_vs_input_state; @@ -755,6 +756,7 @@ bool radv_lower_fs_intrinsics(nir_shader *nir, const struct radv_pipeline_stage nir_shader *create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, struct radv_pipeline_shader_stack_size *stack_sizes, + const struct radv_pipeline_group_handle *handles, const struct radv_pipeline_key *key); #endif -- 2.7.4