radv/rt: move radv_pipeline_key from rt_variables to traversal_data
authorDaniel Schürmann <daniel@schuermann.dev>
Fri, 3 Mar 2023 18:49:43 +0000 (19:49 +0100)
committerMarge Bot <emma+marge@anholt.net>
Tue, 7 Mar 2023 17:00:50 +0000 (17:00 +0000)
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21764>

src/amd/vulkan/radv_rt_shader.c

index 0a82915..bb55116 100644 (file)
@@ -82,7 +82,6 @@ lower_rt_derefs(nir_shader *shader)
  */
 struct rt_variables {
    const VkRayTracingPipelineCreateInfoKHR *create_info;
-   const struct radv_pipeline_key *key;
 
    /* idx of the next shader to run in the next iteration of the main loop.
     * During traversal, idx is used to store the SBT index and will contain
@@ -127,12 +126,10 @@ struct rt_variables {
 };
 
 static struct rt_variables
-create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *create_info,
-                    const struct radv_pipeline_key *key)
+create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *create_info)
 {
    struct rt_variables vars = {
       .create_info = create_info,
-      .key = key,
    };
    vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx");
    vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg");
@@ -805,7 +802,7 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni
 
    nir_opt_dead_cf(shader);
 
-   struct rt_variables src_vars = create_rt_variables(shader, vars->create_info, vars->key);
+   struct rt_variables src_vars = create_rt_variables(shader, vars->create_info);
    map_rt_variables(var_remap, &src_vars, vars);
 
    NIR_PASS_V(shader, lower_rt_instructions, &src_vars, call_idx_base);
@@ -1138,6 +1135,7 @@ struct traversal_data {
    nir_variable *barycentrics;
 
    struct radv_ray_tracing_module *groups;
+   const struct radv_pipeline_key *key;
 };
 
 static void
@@ -1174,7 +1172,7 @@ visit_any_hit_shaders(struct radv_device *device,
          continue;
 
       const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
-      nir_shader *nir_stage = parse_rt_stage(device, stage, vars->key);
+      nir_shader *nir_stage = parse_rt_stage(device, stage, data->key);
 
       insert_rt_case(b, nir_stage, vars, sbt_idx, 0, data->groups[i].handle.any_hit_index,
                      shader_id, data->groups);
@@ -1310,12 +1308,12 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
          continue;
 
       const VkPipelineShaderStageCreateInfo *stage = &data->createInfo->pStages[shader_id];
-      nir_shader *nir_stage = parse_rt_stage(data->device, stage, data->vars->key);
+      nir_shader *nir_stage = parse_rt_stage(data->device, stage, data->key);
 
       nir_shader *any_hit_stage = NULL;
       if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) {
          stage = &data->createInfo->pStages[any_hit_shader_id];
-         any_hit_stage = parse_rt_stage(data->device, stage, data->vars->key);
+         any_hit_stage = parse_rt_stage(data->device, stage, data->key);
 
          nir_lower_intersection_shader(nir_stage, any_hit_stage);
          ralloc_free(any_hit_stage);
@@ -1378,7 +1376,7 @@ build_traversal_shader(struct radv_device *device,
    b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
    b.shader->info.shared_size =
       device->physical_device->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
-   struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, key);
+   struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo);
 
    /* Register storage for hit attributes */
    nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_SIZE / sizeof(uint32_t)];
@@ -1462,6 +1460,7 @@ build_traversal_shader(struct radv_device *device,
       .trav_vars = &trav_vars,
       .barycentrics = barycentrics,
       .groups = groups,
+      .key = key,
    };
 
    struct radv_ray_traversal_args args = {
@@ -1573,7 +1572,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
    b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
    b.shader->info.shared_size = device->physical_device->rt_wave_size * RADV_MAX_HIT_ATTRIB_SIZE;
 
-   struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, key);
+   struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo);
    load_sbt_entry(&b, &vars, nir_imm_int(&b, 0), SBT_RAYGEN, SBT_GENERAL_IDX);
    nir_store_var(&b, vars.stack_ptr, nir_load_rt_dynamic_callable_stack_base_amd(&b), 0x1);