radv: Fix stack size calculation with stage ids
author Konstantin Seurer <konstantin.seurer@gmail.com>
Mon, 8 Aug 2022 10:37:42 +0000 (12:37 +0200)
committer Marge Bot <emma+marge@anholt.net>
Thu, 11 Aug 2022 17:59:47 +0000 (17:59 +0000)
In create_rt_shader, we were setting group_idx to the stage index, so per-group
stack sizes ended up being recorded for the wrong group. Track the stage index
explicitly and map it back to every group that references that stage when
reserving stack size (see the illustration below).
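
As an illustration (a hypothetical create info, not part of the patch): group
indices and stage indices need not line up, because a group references its
stages through generalShader/closestHitShader/anyHitShader/intersectionShader.
In the layout below, the closest-hit shader is stage 2 but belongs to group 1:

    #include <vulkan/vulkan.h>

    const VkRayTracingShaderGroupCreateInfoKHR groups[] = {
       {
          /* Group 0: general group for the raygen shader (stage 0). */
          .sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
          .type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR,
          .generalShader = 0,
          .closestHitShader = VK_SHADER_UNUSED_KHR,
          .anyHitShader = VK_SHADER_UNUSED_KHR,
          .intersectionShader = VK_SHADER_UNUSED_KHR,
       },
       {
          /* Group 1: triangle hit group with any-hit = stage 1 and
           * closest-hit = stage 2; indexing stack_sizes[] with the stage
           * index would record the size against the wrong entry here.
           */
          .sType = VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR,
          .type = VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR,
          .generalShader = VK_SHADER_UNUSED_KHR,
          .closestHitShader = 2,
          .anyHitShader = 1,
          .intersectionShader = VK_SHADER_UNUSED_KHR,
       },
    };

Since the same stage may also be referenced by several groups, the new
reserve_stack_size helper walks pGroups and applies MAX2 to every group that
references the current stage.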

Fixes the following tests:

dEQP-VK.ray_query.builtin.instancecustomindex.miss.aabbs
dEQP-VK.ray_query.builtin.objectrayorigin.miss.triangles

Fixes: c39ccce ("radv/rt: use stage ID as handle for general and closestHit shaders")
Signed-off-by: Konstantin Seurer <konstantin.seurer@gmail.com>
Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/17936>

src/amd/vulkan/radv_pipeline_rt.c

index 792b2ee..2f83367 100644
@@ -185,6 +185,8 @@ fail:
  * Global variables for an RT pipeline
  */
 struct rt_variables {
+   const VkRayTracingPipelineCreateInfoKHR *create_info;
+
    /* idx of the next shader to run in the next iteration of the main loop.
     * During traversal, idx is used to store the SBT index and will contain
     * the correct resume index upon returning.
@@ -232,14 +234,31 @@ struct rt_variables {
 
    /* Array of stack size struct for recording the max stack size for each group. */
    struct radv_pipeline_shader_stack_size *stack_sizes;
-   unsigned group_idx;
+   unsigned stage_idx;
 };
 
+static void
+reserve_stack_size(struct rt_variables *vars, uint32_t size)
+{
+   for (uint32_t group_idx = 0; group_idx < vars->create_info->groupCount; group_idx++) {
+      const VkRayTracingShaderGroupCreateInfoKHR *group = vars->create_info->pGroups + group_idx;
+
+      if (vars->stage_idx == group->generalShader || vars->stage_idx == group->closestHitShader)
+         vars->stack_sizes[group_idx].recursive_size =
+            MAX2(vars->stack_sizes[group_idx].recursive_size, size);
+
+      if (vars->stage_idx == group->anyHitShader || vars->stage_idx == group->intersectionShader)
+         vars->stack_sizes[group_idx].non_recursive_size =
+            MAX2(vars->stack_sizes[group_idx].non_recursive_size, size);
+   }
+}
+
 static struct rt_variables
-create_rt_variables(nir_shader *shader, struct radv_pipeline_shader_stack_size *stack_sizes)
+create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *create_info,
+                    struct radv_pipeline_shader_stack_size *stack_sizes)
 {
    struct rt_variables vars = {
-      NULL,
+      .create_info = create_info,
    };
    vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx");
    vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg");
@@ -294,6 +313,8 @@ static void
 map_rt_variables(struct hash_table *var_remap, struct rt_variables *src,
                  const struct rt_variables *dst)
 {
+   src->create_info = dst->create_info;
+
    _mesa_hash_table_insert(var_remap, src->idx, dst->idx);
    _mesa_hash_table_insert(var_remap, src->arg, dst->arg);
    _mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr);
@@ -321,7 +342,7 @@ map_rt_variables(struct hash_table *var_remap, struct rt_variables *src,
    _mesa_hash_table_insert(var_remap, src->ahit_terminate, dst->ahit_terminate);
 
    src->stack_sizes = dst->stack_sizes;
-   src->group_idx = dst->group_idx;
+   src->stage_idx = dst->stage_idx;
 }
 
 /*
@@ -440,8 +461,7 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca
                nir_store_var(&b_shader, vars->arg,
                              nir_iadd_imm(&b_shader, intr->src[1].ssa, -size - 16), 1);
 
-               vars->stack_sizes[vars->group_idx].recursive_size =
-                  MAX2(vars->stack_sizes[vars->group_idx].recursive_size, size + 16);
+               reserve_stack_size(vars, size + 16);
                break;
             }
             case nir_intrinsic_rt_trace_ray: {
@@ -462,8 +482,7 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca
                nir_store_var(&b_shader, vars->arg,
                              nir_iadd_imm(&b_shader, intr->src[10].ssa, -size - 16), 1);
 
-               vars->stack_sizes[vars->group_idx].recursive_size =
-                  MAX2(vars->stack_sizes[vars->group_idx].recursive_size, size + 16);
+               reserve_stack_size(vars, size + 16);
 
                /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. */
                nir_store_var(&b_shader, vars->accel_struct, intr->src[0].ssa, 0x1);
@@ -708,14 +727,14 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca
 }
 
 static void
-insert_rt_case(nir_builder *b, nir_shader *shader, const struct rt_variables *vars,
-               nir_ssa_def *idx, uint32_t call_idx_base, uint32_t call_idx)
+insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, nir_ssa_def *idx,
+               uint32_t call_idx_base, uint32_t call_idx)
 {
    struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
 
    nir_opt_dead_cf(shader);
 
-   struct rt_variables src_vars = create_rt_variables(shader, vars->stack_sizes);
+   struct rt_variables src_vars = create_rt_variables(shader, vars->create_info, vars->stack_sizes);
    map_rt_variables(var_remap, &src_vars, vars);
 
    NIR_PASS_V(shader, lower_rt_instructions, &src_vars, call_idx_base);
@@ -724,14 +743,7 @@ insert_rt_case(nir_builder *b, nir_shader *shader, const struct rt_variables *va
    NIR_PASS(_, shader, nir_lower_returns);
    NIR_PASS(_, shader, nir_opt_dce);
 
-   if (b->shader->info.stage == MESA_SHADER_ANY_HIT ||
-       b->shader->info.stage == MESA_SHADER_INTERSECTION) {
-      src_vars.stack_sizes[src_vars.group_idx].non_recursive_size =
-         MAX2(src_vars.stack_sizes[src_vars.group_idx].non_recursive_size, shader->scratch_size);
-   } else {
-      src_vars.stack_sizes[src_vars.group_idx].recursive_size =
-         MAX2(src_vars.stack_sizes[src_vars.group_idx].recursive_size, shader->scratch_size);
-   }
+   reserve_stack_size(vars, shader->scratch_size);
 
    nir_push_if(b, nir_ieq_imm(b, idx, call_idx));
    nir_store_var(b, vars->main_loop_case_visited, nir_imm_bool(b, true), 1);
@@ -1095,7 +1107,7 @@ visit_any_hit_shaders(struct radv_device *device,
       const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
       nir_shader *nir_stage = parse_rt_stage(device, stage);
 
-      vars->group_idx = i;
+      vars->stage_idx = shader_id;
       insert_rt_case(b, nir_stage, vars, sbt_idx, 0, i + 2);
    }
    nir_pop_if(b, NULL);
@@ -1304,7 +1316,7 @@ insert_traversal_aabb_case(struct radv_device *device,
             ralloc_free(any_hit_stage);
          }
 
-         inner_vars.group_idx = i;
+         inner_vars.stage_idx = shader_id;
          insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0, i + 2);
       }
       nir_push_else(b, NULL);
@@ -1379,7 +1391,7 @@ build_traversal_shader(struct radv_device *device,
    b.shader->info.internal = false;
    b.shader->info.workgroup_size[0] = 8;
    b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
-   struct rt_variables vars = create_rt_variables(b.shader, dst_vars->stack_sizes);
+   struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, dst_vars->stack_sizes);
    map_rt_variables(var_remap, &vars, dst_vars);
 
    unsigned lanes = device->physical_device->rt_wave_size;
@@ -1735,7 +1747,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
    b.shader->info.workgroup_size[0] = 8;
    b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
 
-   struct rt_variables vars = create_rt_variables(b.shader, stack_sizes);
+   struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes);
    load_sbt_entry(&b, &vars, nir_imm_int(&b, 0), SBT_RAYGEN, 0);
    nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
 
@@ -1777,7 +1789,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
       nir_lower_shader_calls(nir_stage, nir_address_format_32bit_offset, 16, &resume_shaders,
                              &num_resume_shaders, nir_stage);
 
-      vars.group_idx = i;
+      vars.stage_idx = i;
       insert_rt_case(&b, nir_stage, &vars, idx, call_idx_base, i + 2);
       for (unsigned j = 0; j < num_resume_shaders; ++j) {
          insert_rt_case(&b, resume_shaders[j], &vars, idx, call_idx_base, call_idx_base + 1 + j);