radv/rt: Propagate radv_pipeline_key
author: Konstantin Seurer <konstantin.seurer@gmail.com>
Sat, 10 Dec 2022 11:36:13 +0000 (12:36 +0100)
committer: Marge Bot <emma+marge@anholt.net>
Mon, 12 Dec 2022 18:18:32 +0000 (18:18 +0000)
Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/20243>

src/amd/vulkan/radv_pipeline_rt.c
src/amd/vulkan/radv_rt_shader.c
src/amd/vulkan/radv_shader.c
src/amd/vulkan/radv_shader.h

index 68a8212..f9c8cbf 100644 (file)
@@ -321,7 +321,7 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
          goto pipeline_fail;
       }
 
-      shader = create_rt_shader(device, &local_create_info, rt_pipeline->stack_sizes);
+      shader = create_rt_shader(device, &local_create_info, rt_pipeline->stack_sizes, &key);
       module.nir = shader;
       result = radv_create_shaders(&rt_pipeline->base.base, pipeline_layout, device, cache, &key,
                                    &stage, 1, pCreateInfo->flags, hash, creation_feedback,
index 52c5833..62b9c70 100644 (file)
@@ -82,6 +82,7 @@ lower_rt_derefs(nir_shader *shader)
  */
 struct rt_variables {
    const VkRayTracingPipelineCreateInfoKHR *create_info;
+   const struct radv_pipeline_key *key;
 
    /* idx of the next shader to run in the next iteration of the main loop.
     * During traversal, idx is used to store the SBT index and will contain
@@ -143,10 +144,12 @@ reserve_stack_size(struct rt_variables *vars, uint32_t size)
 
 static struct rt_variables
 create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *create_info,
-                    struct radv_pipeline_shader_stack_size *stack_sizes)
+                    struct radv_pipeline_shader_stack_size *stack_sizes,
+                    const struct radv_pipeline_key *key)
 {
    struct rt_variables vars = {
       .create_info = create_info,
+      .key = key,
    };
    vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx");
    vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg");
@@ -754,7 +757,8 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni
 
    nir_opt_dead_cf(shader);
 
-   struct rt_variables src_vars = create_rt_variables(shader, vars->create_info, vars->stack_sizes);
+   struct rt_variables src_vars =
+      create_rt_variables(shader, vars->create_info, vars->stack_sizes, vars->key);
    map_rt_variables(var_remap, &src_vars, vars);
 
    NIR_PASS_V(shader, lower_rt_instructions, &src_vars, call_idx_base);
@@ -776,16 +780,14 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni
 }
 
 static nir_shader *
-parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreateInfo *sinfo)
+parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreateInfo *sinfo,
+               const struct radv_pipeline_key *key)
 {
-   struct radv_pipeline_key key;
-   memset(&key, 0, sizeof(key));
-
    struct radv_pipeline_stage rt_stage;
 
    radv_pipeline_stage_init(sinfo, &rt_stage, vk_to_mesa_shader_stage(sinfo->stage));
 
-   nir_shader *shader = radv_shader_spirv_to_nir(device, &rt_stage, &key);
+   nir_shader *shader = radv_shader_spirv_to_nir(device, &rt_stage, key);
 
    if (shader->info.stage == MESA_SHADER_RAYGEN || shader->info.stage == MESA_SHADER_CLOSEST_HIT ||
        shader->info.stage == MESA_SHADER_CALLABLE || shader->info.stage == MESA_SHADER_MISS) {
@@ -1087,7 +1089,7 @@ visit_any_hit_shaders(struct radv_device *device,
          continue;
 
       const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
-      nir_shader *nir_stage = parse_rt_stage(device, stage);
+      nir_shader *nir_stage = parse_rt_stage(device, stage, vars->key);
 
       vars->stage_idx = shader_id;
       insert_rt_case(b, nir_stage, vars, sbt_idx, 0, i + 2);
@@ -1217,12 +1219,12 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
          continue;
 
       const VkPipelineShaderStageCreateInfo *stage = &data->createInfo->pStages[shader_id];
-      nir_shader *nir_stage = parse_rt_stage(data->device, stage);
+      nir_shader *nir_stage = parse_rt_stage(data->device, stage, data->vars->key);
 
       nir_shader *any_hit_stage = NULL;
       if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) {
          stage = &data->createInfo->pStages[any_hit_shader_id];
-         any_hit_stage = parse_rt_stage(data->device, stage);
+         any_hit_stage = parse_rt_stage(data->device, stage, data->vars->key);
 
          nir_lower_intersection_shader(nir_stage, any_hit_stage);
          ralloc_free(any_hit_stage);
@@ -1272,7 +1274,8 @@ load_stack_entry(nir_builder *b, nir_ssa_def *index, const struct radv_ray_trave
 static nir_shader *
 build_traversal_shader(struct radv_device *device,
                        const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
-                       struct radv_pipeline_shader_stack_size *stack_sizes)
+                       struct radv_pipeline_shader_stack_size *stack_sizes,
+                       const struct radv_pipeline_key *key)
 {
    /* Create the traversal shader as an intersection shader to prevent validation failures due to
     * invalid variable modes.*/
@@ -1282,7 +1285,7 @@ build_traversal_shader(struct radv_device *device,
    b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
    b.shader->info.shared_size =
       device->physical_device->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
-   struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes);
+   struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes, key);
 
    nir_variable *barycentrics = nir_variable_create(
       b.shader, nir_var_ray_hit_attrib, glsl_vector_type(GLSL_TYPE_FLOAT, 2), "barycentrics");
@@ -1552,17 +1555,15 @@ lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs)
 
 nir_shader *
 create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
-                 struct radv_pipeline_shader_stack_size *stack_sizes)
+                 struct radv_pipeline_shader_stack_size *stack_sizes,
+                 const struct radv_pipeline_key *key)
 {
-   struct radv_pipeline_key key;
-   memset(&key, 0, sizeof(key));
-
    nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_combined");
    b.shader->info.internal = false;
    b.shader->info.workgroup_size[0] = 8;
    b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
 
-   struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes);
+   struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes, key);
    load_sbt_entry(&b, &vars, nir_imm_int(&b, 0), SBT_RAYGEN, SBT_GENERAL_IDX);
    if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo))
       nir_store_var(&b, vars.stack_ptr, nir_load_rt_dynamic_callable_stack_base_amd(&b), 0x1);
@@ -1583,7 +1584,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
    nir_ssa_def *idx = nir_load_var(&b, vars.idx);
 
    /* Insert traversal shader */
-   nir_shader *traversal = build_traversal_shader(device, pCreateInfo, stack_sizes);
+   nir_shader *traversal = build_traversal_shader(device, pCreateInfo, stack_sizes, key);
    assert(b.shader->info.shared_size == 0);
    b.shader->info.shared_size = traversal->info.shared_size;
    assert(b.shader->info.shared_size <= 32768);
@@ -1600,7 +1601,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
           type != MESA_SHADER_CLOSEST_HIT && type != MESA_SHADER_MISS)
          continue;
 
-      nir_shader *nir_stage = parse_rt_stage(device, stage);
+      nir_shader *nir_stage = parse_rt_stage(device, stage, key);
 
       /* Move ray tracing system values to the top that are set by rt_trace_ray
        * to prevent them from being overwritten by other rt_trace_ray calls.
index 7a1e0e6..6ea3f36 100644 (file)
@@ -723,7 +723,7 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_pipeline_
       /* Only compute shaders currently support requiring a
        * specific subgroup size.
        */
-      assert(stage->stage == MESA_SHADER_COMPUTE);
+      assert(stage->stage >= MESA_SHADER_COMPUTE);
       subgroup_size = key->cs.compute_subgroup_size;
       ballot_bit_size = key->cs.compute_subgroup_size;
    }
index 1723ba4..062fe17 100644 (file)
@@ -754,6 +754,7 @@ bool radv_lower_fs_intrinsics(nir_shader *nir, const struct radv_pipeline_stage
 
 nir_shader *create_rt_shader(struct radv_device *device,
                              const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
-                             struct radv_pipeline_shader_stack_size *stack_sizes);
+                             struct radv_pipeline_shader_stack_size *stack_sizes,
+                             const struct radv_pipeline_key *key);
 
 #endif