radv/rt: use priorities to select the next shader
authorDaniel Schürmann <daniel@schuermann.dev>
Wed, 22 Mar 2023 01:12:35 +0000 (02:12 +0100)
committerMarge Bot <emma+marge@anholt.net>
Thu, 8 Jun 2023 00:37:03 +0000 (00:37 +0000)
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22096>

src/amd/vulkan/radv_cmd_buffer.c
src/amd/vulkan/radv_pipeline_rt.c
src/amd/vulkan/radv_rt_shader.c
src/amd/vulkan/radv_shader.h

index 40fbc9a..d8ac494 100644 (file)
@@ -10132,7 +10132,8 @@ radv_trace_rays(struct radv_cmd_buffer *cmd_buffer, const VkTraceRaysIndirectCom
    const struct radv_userdata_info *shader_loc =
       radv_get_user_sgpr(rt_prolog, AC_UD_CS_TRAVERSAL_SHADER_ADDR);
    if (shader_loc->sgpr_idx != -1) {
-      uint64_t traversal_va = cmd_buffer->state.shaders[MESA_SHADER_INTERSECTION]->va;
+      uint64_t traversal_va =
+         cmd_buffer->state.shaders[MESA_SHADER_INTERSECTION]->va | radv_rt_priority_traversal;
       radv_emit_shader_pointer(cmd_buffer->device, cmd_buffer->cs,
                                base_reg + shader_loc->sgpr_idx * 4, traversal_va, true);
    }
index 25a6b18..745ef8a 100644 (file)
@@ -634,7 +634,8 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
          struct radv_shader *shader =
             container_of(pipeline->stages[pipeline->groups[i].recursive_shader].shader,
                          struct radv_shader, base);
-         pipeline->groups[i].handle.recursive_shader_ptr = shader->va;
+         pipeline->groups[i].handle.recursive_shader_ptr =
+            shader->va | radv_get_rt_priority(shader->info.stage);
       }
    }
 
index 3d67f91..8bc0e10 100644 (file)
@@ -311,6 +311,7 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca
                uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
                nir_ssa_def *ret_ptr =
                   nir_load_resume_shader_address_amd(&b_shader, nir_intrinsic_call_idx(intr));
+               ret_ptr = nir_ior_imm(&b_shader, ret_ptr, radv_get_rt_priority(shader->info.stage));
 
                nir_store_var(
                   &b_shader, vars->stack_ptr,
@@ -333,6 +334,7 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca
                uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
                nir_ssa_def *ret_ptr =
                   nir_load_resume_shader_address_amd(&b_shader, nir_intrinsic_call_idx(intr));
+               ret_ptr = nir_ior_imm(&b_shader, ret_ptr, radv_get_rt_priority(shader->info.stage));
 
                nir_store_var(
                   &b_shader, vars->stack_ptr,
@@ -1712,6 +1714,46 @@ create_rt_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *p
    return b.shader;
 }
 
+/** Select the next shader based on priorities:
+ *
+ * Detect the priority of the shader stage by the lowest bits in the address (low to high):
+ *  - Raygen              - idx 0
+ *  - Traversal           - idx 1
+ *  - Closest Hit / Miss  - idx 2
+ *  - Callable            - idx 3
+ *
+ *
+ * This gives us the following priorities:
+ * Raygen       :  Callable  >               >  Traversal  >  Raygen
+ * Traversal    :            >  Chit / Miss  >             >  Raygen
+ * CHit / Miss  :  Callable  >  Chit / Miss  >  Traversal  >  Raygen
+ * Callable     :  Callable  >  Chit / Miss  >             >  Raygen
+ */
+static nir_ssa_def *
+select_next_shader(nir_builder *b, nir_ssa_def *shader_va, unsigned wave_size)
+{
+   gl_shader_stage stage = b->shader->info.stage;
+   nir_ssa_def *prio = nir_iand_imm(b, shader_va, radv_rt_priority_mask);
+   nir_ssa_def *ballot = nir_ballot(b, 1, wave_size, nir_imm_bool(b, true));
+   nir_ssa_def *ballot_traversal =
+      nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_traversal));
+   nir_ssa_def *ballot_hit_miss =
+      nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_hit_miss));
+   nir_ssa_def *ballot_callable =
+      nir_ballot(b, 1, wave_size, nir_ieq_imm(b, prio, radv_rt_priority_callable));
+
+   if (stage != MESA_SHADER_CALLABLE && stage != MESA_SHADER_INTERSECTION)
+      ballot = nir_bcsel(b, nir_ine_imm(b, ballot_traversal, 0), ballot_traversal, ballot);
+   if (stage != MESA_SHADER_RAYGEN)
+      ballot = nir_bcsel(b, nir_ine_imm(b, ballot_hit_miss, 0), ballot_hit_miss, ballot);
+   if (stage != MESA_SHADER_INTERSECTION)
+      ballot = nir_bcsel(b, nir_ine_imm(b, ballot_callable, 0), ballot_callable, ballot);
+
+   nir_ssa_def *lsb = nir_find_lsb(b, ballot);
+   nir_ssa_def *next = nir_read_invocation(b, shader_va, lsb);
+   return nir_iand_imm(b, next, ~radv_rt_priority_mask);
+}
+
 void
 radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                       const struct radv_shader_args *args, const struct radv_pipeline_key *key,
@@ -1773,6 +1815,7 @@ radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKH
    /* guard the shader, so that only the correct invocations execute it */
    nir_ssa_def *shader_pc = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_pc);
    shader_pc = nir_pack_64_2x32(&b, shader_pc);
+   shader_pc = nir_ior_imm(&b, shader_pc, radv_get_rt_priority(shader->info.stage));
    nir_ssa_def *cond = nir_ieq(&b, shader_pc, shader_va);
    nir_if *shader_guard = nir_push_if(&b, cond);
    shader_guard->control = nir_selection_control_divergent_always_taken;
@@ -1780,10 +1823,9 @@ radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKH
    nir_pop_if(&b, shader_guard);
 
    /* select next shader */
-   // TODO: use a priority-based selection
    b.cursor = nir_after_cf_list(&impl->body);
    shader_va = nir_load_var(&b, vars.shader_va);
-   nir_ssa_def *next = nir_read_first_invocation(&b, shader_va);
+   nir_ssa_def *next = select_next_shader(&b, shader_va, key->cs.compute_subgroup_size);
    ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_pc, next);
 
    /* store back all variables to registers */
index 28ea75c..0fc89fa 100644 (file)
@@ -794,4 +794,32 @@ nir_shader *radv_build_traversal_shader(struct radv_device *device,
                                         struct radv_ray_tracing_pipeline *pipeline,
                                         const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                                         const struct radv_pipeline_key *key);
+
+enum radv_rt_priority {
+   radv_rt_priority_raygen = 0,
+   radv_rt_priority_traversal = 1,
+   radv_rt_priority_hit_miss = 2,
+   radv_rt_priority_callable = 3,
+   radv_rt_priority_mask = 0x3,
+};
+
+static inline enum radv_rt_priority
+radv_get_rt_priority(gl_shader_stage stage)
+{
+   switch (stage) {
+   case MESA_SHADER_RAYGEN:
+      return radv_rt_priority_raygen;
+   case MESA_SHADER_INTERSECTION:
+   case MESA_SHADER_ANY_HIT:
+      return radv_rt_priority_traversal;
+   case MESA_SHADER_CLOSEST_HIT:
+   case MESA_SHADER_MISS:
+      return radv_rt_priority_hit_miss;
+   case MESA_SHADER_CALLABLE:
+      return radv_rt_priority_callable;
+   default:
+      unreachable("Unimplemented RT shader stage.");
+   }
+}
+
 #endif