From 60f9dbeb2ba086baccb74321b7a9b0547e6b9263 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Daniel=20Sch=C3=BCrmann?= Date: Wed, 22 Mar 2023 02:12:35 +0100 Subject: [PATCH] radv/rt: use priorities to select the next shader Part-of: --- src/amd/vulkan/radv_cmd_buffer.c | 3 ++- src/amd/vulkan/radv_pipeline_rt.c | 3 ++- src/amd/vulkan/radv_rt_shader.c | 46 +++++++++++++++++++++++++++++++++++++-- src/amd/vulkan/radv_shader.h | 28 ++++++++++++++++++++++++ 4 files changed, 76 insertions(+), 4 deletions(-) diff --git a/src/amd/vulkan/radv_cmd_buffer.c b/src/amd/vulkan/radv_cmd_buffer.c index 40fbc9a..d8ac494 100644 --- a/src/amd/vulkan/radv_cmd_buffer.c +++ b/src/amd/vulkan/radv_cmd_buffer.c @@ -10132,7 +10132,8 @@ radv_trace_rays(struct radv_cmd_buffer *cmd_buffer, const VkTraceRaysIndirectCom const struct radv_userdata_info *shader_loc = radv_get_user_sgpr(rt_prolog, AC_UD_CS_TRAVERSAL_SHADER_ADDR); if (shader_loc->sgpr_idx != -1) { - uint64_t traversal_va = cmd_buffer->state.shaders[MESA_SHADER_INTERSECTION]->va; + uint64_t traversal_va = + cmd_buffer->state.shaders[MESA_SHADER_INTERSECTION]->va | radv_rt_priority_traversal; radv_emit_shader_pointer(cmd_buffer->device, cmd_buffer->cs, base_reg + shader_loc->sgpr_idx * 4, traversal_va, true); } diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c index 25a6b18..745ef8a 100644 --- a/src/amd/vulkan/radv_pipeline_rt.c +++ b/src/amd/vulkan/radv_pipeline_rt.c @@ -634,7 +634,8 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache, struct radv_shader *shader = container_of(pipeline->stages[pipeline->groups[i].recursive_shader].shader, struct radv_shader, base); - pipeline->groups[i].handle.recursive_shader_ptr = shader->va; + pipeline->groups[i].handle.recursive_shader_ptr = + shader->va | radv_get_rt_priority(shader->info.stage); } } diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c index 3d67f91..8bc0e10 100644 --- a/src/amd/vulkan/radv_rt_shader.c +++ 
b/src/amd/vulkan/radv_rt_shader.c @@ -311,6 +311,7 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca uint32_t size = align(nir_intrinsic_stack_size(intr), 16); nir_ssa_def *ret_ptr = nir_load_resume_shader_address_amd(&b_shader, nir_intrinsic_call_idx(intr)); + ret_ptr = nir_ior_imm(&b_shader, ret_ptr, radv_get_rt_priority(shader->info.stage)); nir_store_var( &b_shader, vars->stack_ptr, @@ -333,6 +334,7 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca uint32_t size = align(nir_intrinsic_stack_size(intr), 16); nir_ssa_def *ret_ptr = nir_load_resume_shader_address_amd(&b_shader, nir_intrinsic_call_idx(intr)); + ret_ptr = nir_ior_imm(&b_shader, ret_ptr, radv_get_rt_priority(shader->info.stage)); nir_store_var( &b_shader, vars->stack_ptr, @@ -1712,6 +1714,46 @@ create_rt_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *p return b.shader; } +/** Select the next shader based on priorities: + * Returns shader_va as read from the lowest invocation of the winning class, with the priority bits cleared. + * Detect the priority of the shader stage by the lowest bits in the address (low to high): + * - Raygen - idx 0 + * - Traversal - idx 1 + * - Closest Hit / Miss - idx 2 + * - Callable - idx 3 + * + * Each row is the stage of the shader being compiled; an empty slot means that class gets no precedence there, and the last column corresponds to the fallback ballot of all active invocations. + * This gives us the following priorities: + * Raygen : Callable > > Traversal > Raygen + * Traversal : > CHit / Miss > > Raygen + * CHit / Miss : Callable > CHit / Miss > Traversal > Raygen + * Callable : Callable > CHit / Miss > > Raygen + */ +static nir_ssa_def * +select_next_shader(nir_builder *b, nir_ssa_def *shader_va, unsigned wave_size) +{ + gl_shader_stage stage = b->shader->info.stage; + nir_ssa_def *prio = nir_iand_imm(b, shader_va, radv_rt_priority_mask); /* priority class taken from the low bits of each invocation's shader_va */ + nir_ssa_def *ballot = nir_ballot(b, 1, wave_size, nir_imm_bool(b, true)); /* fallback: ballot over a constant-true condition */ + nir_ssa_def *ballot_traversal = + nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_traversal)); + nir_ssa_def *ballot_hit_miss = + nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_hit_miss)); + nir_ssa_def *ballot_callable 
= + nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_callable)); + /* Each select below overrides the previous one, so the last non-empty class tested wins. */ + if (stage != MESA_SHADER_CALLABLE && stage != MESA_SHADER_INTERSECTION) + ballot = nir_bcsel(b, nir_ine_imm(b, ballot_traversal, 0), ballot_traversal, ballot); + if (stage != MESA_SHADER_RAYGEN) + ballot = nir_bcsel(b, nir_ine_imm(b, ballot_hit_miss, 0), ballot_hit_miss, ballot); + if (stage != MESA_SHADER_INTERSECTION) + ballot = nir_bcsel(b, nir_ine_imm(b, ballot_callable, 0), ballot_callable, ballot); + /* Read shader_va from the invocation at the lowest set bit of the chosen ballot, then clear its priority bits. */ + nir_ssa_def *lsb = nir_find_lsb(b, ballot); + nir_ssa_def *next = nir_read_invocation(b, shader_va, lsb); + return nir_iand_imm(b, next, ~radv_rt_priority_mask); +} + void radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, const struct radv_shader_args *args, const struct radv_pipeline_key *key, @@ -1773,6 +1815,7 @@ radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKH /* guard the shader, so that only the correct invocations execute it */ nir_ssa_def *shader_pc = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_pc); shader_pc = nir_pack_64_2x32(&b, shader_pc); + shader_pc = nir_ior_imm(&b, shader_pc, radv_get_rt_priority(shader->info.stage)); nir_ssa_def *cond = nir_ieq(&b, shader_pc, shader_va); nir_if *shader_guard = nir_push_if(&b, cond); shader_guard->control = nir_selection_control_divergent_always_taken; @@ -1780,10 +1823,9 @@ radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKH nir_pop_if(&b, shader_guard); /* select next shader */ - // TODO: use a priority-based selection b.cursor = nir_after_cf_list(&impl->body); shader_va = nir_load_var(&b, vars.shader_va); - nir_ssa_def *next = nir_read_first_invocation(&b, shader_va); + nir_ssa_def *next = select_next_shader(&b, shader_va, key->cs.compute_subgroup_size); ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_pc, next); /* store back all variables to registers */ diff --git a/src/amd/vulkan/radv_shader.h 
b/src/amd/vulkan/radv_shader.h index 28ea75c..0fc89fa 100644 --- a/src/amd/vulkan/radv_shader.h +++ b/src/amd/vulkan/radv_shader.h @@ -794,4 +794,32 @@ nir_shader *radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, const struct radv_pipeline_key *key); +/* Priority class of a ray tracing shader; this patch ORs the value into the low bits of the shader address. */ +enum radv_rt_priority { + radv_rt_priority_raygen = 0, + radv_rt_priority_traversal = 1, + radv_rt_priority_hit_miss = 2, + radv_rt_priority_callable = 3, + radv_rt_priority_mask = 0x3, /* the two low bits that hold the class */ +}; +/* Map a shader stage to its priority class: intersection and any-hit share the traversal class, closest-hit and miss share the hit_miss class; any other stage reaches unreachable(). */ +static inline enum radv_rt_priority +radv_get_rt_priority(gl_shader_stage stage) +{ + switch (stage) { + case MESA_SHADER_RAYGEN: + return radv_rt_priority_raygen; + case MESA_SHADER_INTERSECTION: + case MESA_SHADER_ANY_HIT: + return radv_rt_priority_traversal; + case MESA_SHADER_CLOSEST_HIT: + case MESA_SHADER_MISS: + return radv_rt_priority_hit_miss; + case MESA_SHADER_CALLABLE: + return radv_rt_priority_callable; + default: + unreachable("Unimplemented RT shader stage."); + } +} + #endif -- 2.7.4