radv/rt: pass radv_ray_tracing_pipeline to RT shader creation
authorDaniel Schürmann <daniel@schuermann.dev>
Thu, 1 Jun 2023 10:20:31 +0000 (12:20 +0200)
committerMarge Bot <emma+marge@anholt.net>
Thu, 8 Jun 2023 00:37:03 +0000 (00:37 +0000)
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22096>

src/amd/vulkan/radv_pipeline_rt.c
src/amd/vulkan/radv_private.h
src/amd/vulkan/radv_rt_shader.c
src/amd/vulkan/radv_shader.h

index a1288f7..97a2a40 100644 (file)
@@ -467,11 +467,7 @@ radv_rt_pipeline_compile(struct radv_ray_tracing_pipeline *pipeline,
    if (result != VK_SUCCESS)
       return result;
 
-   VkRayTracingPipelineCreateInfoKHR local_create_info =
-      radv_create_merged_rt_create_info(pCreateInfo);
-
-   rt_stage.internal_nir = create_rt_shader(device, &local_create_info, pipeline->stages,
-                                            pipeline->groups, pipeline_key);
+   rt_stage.internal_nir = create_rt_shader(device, pipeline, pCreateInfo, pipeline_key);
 
    /* Compile SPIR-V shader to NIR. */
    rt_stage.nir =
index d8db5cc..54e087d 100644 (file)
@@ -2144,6 +2144,7 @@ struct radv_event {
 #define RADV_HASH_SHADER_NGG_STREAMOUT         (1 << 20)
 
 struct radv_pipeline_key;
+struct radv_ray_tracing_group;
 
 void radv_pipeline_stage_init(const VkPipelineShaderStageCreateInfo *sinfo,
                               struct radv_pipeline_stage *out_stage, gl_shader_stage stage);
index 7041dc6..e54869c 100644 (file)
@@ -1186,28 +1186,25 @@ init_traversal_vars(nir_builder *b)
 
 struct traversal_data {
    struct radv_device *device;
-   const VkRayTracingPipelineCreateInfoKHR *createInfo;
    struct rt_variables *vars;
    struct rt_traversal_vars *trav_vars;
    nir_variable *barycentrics;
 
-   struct radv_ray_tracing_group *groups;
-   struct radv_ray_tracing_stage *stages;
+   struct radv_ray_tracing_pipeline *pipeline;
    const struct radv_pipeline_key *key;
 };
 
 static void
-visit_any_hit_shaders(struct radv_device *device,
-                      const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
-                      struct traversal_data *data, struct rt_variables *vars)
+visit_any_hit_shaders(struct radv_device *device, nir_builder *b, struct traversal_data *data,
+                      struct rt_variables *vars)
 {
    nir_ssa_def *sbt_idx = nir_load_var(b, vars->idx);
 
    if (!(vars->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR))
       nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
 
-   for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
-      struct radv_ray_tracing_group *group = &data->groups[i];
+   for (unsigned i = 0; i < data->pipeline->group_count; ++i) {
+      struct radv_ray_tracing_group *group = &data->pipeline->groups[i];
       uint32_t shader_id = VK_SHADER_UNUSED_KHR;
 
       switch (group->type) {
@@ -1223,18 +1220,19 @@ visit_any_hit_shaders(struct radv_device *device,
       /* Avoid emitting stages with the same shaders/handles multiple times. */
       bool is_dup = false;
       for (unsigned j = 0; j < i; ++j)
-         if (data->groups[j].handle.any_hit_index == data->groups[i].handle.any_hit_index)
+         if (data->pipeline->groups[j].handle.any_hit_index ==
+             data->pipeline->groups[i].handle.any_hit_index)
             is_dup = true;
 
       if (is_dup)
          continue;
 
       nir_shader *nir_stage =
-         radv_pipeline_cache_handle_to_nir(device, data->stages[shader_id].shader);
+         radv_pipeline_cache_handle_to_nir(device, data->pipeline->stages[shader_id].shader);
       assert(nir_stage);
 
-      insert_rt_case(b, nir_stage, vars, sbt_idx, 0, data->groups[i].handle.any_hit_index,
-                     shader_id, data->stages);
+      insert_rt_case(b, nir_stage, vars, sbt_idx, 0, data->pipeline->groups[i].handle.any_hit_index,
+                     shader_id, data->pipeline->stages);
       ralloc_free(nir_stage);
    }
 
@@ -1279,7 +1277,7 @@ handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *int
 
       load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, SBT_ANY_HIT_IDX);
 
-      visit_any_hit_shaders(data->device, data->createInfo, b, args->data, &inner_vars);
+      visit_any_hit_shaders(data->device, b, args->data, &inner_vars);
 
       nir_push_if(b, nir_inot(b, nir_load_var(b, data->vars->ahit_accept)));
       {
@@ -1341,8 +1339,8 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
    if (!(data->vars->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR))
       nir_push_if(b, nir_ine_imm(b, nir_load_var(b, inner_vars.idx), 0));
 
-   for (unsigned i = 0; i < data->createInfo->groupCount; ++i) {
-      struct radv_ray_tracing_group *group = &data->groups[i];
+   for (unsigned i = 0; i < data->pipeline->group_count; ++i) {
+      struct radv_ray_tracing_group *group = &data->pipeline->groups[i];
       uint32_t shader_id = VK_SHADER_UNUSED_KHR;
       uint32_t any_hit_shader_id = VK_SHADER_UNUSED_KHR;
 
@@ -1360,31 +1358,33 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
       /* Avoid emitting stages with the same shaders/handles multiple times. */
       bool is_dup = false;
       for (unsigned j = 0; j < i; ++j)
-         if (data->groups[j].handle.intersection_index == data->groups[i].handle.intersection_index)
+         if (data->pipeline->groups[j].handle.intersection_index ==
+             data->pipeline->groups[i].handle.intersection_index)
             is_dup = true;
 
       if (is_dup)
          continue;
 
       nir_shader *nir_stage =
-         radv_pipeline_cache_handle_to_nir(data->device, data->stages[shader_id].shader);
+         radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[shader_id].shader);
       assert(nir_stage);
 
       nir_shader *any_hit_stage = NULL;
       if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) {
-         any_hit_stage =
-            radv_pipeline_cache_handle_to_nir(data->device, data->stages[any_hit_shader_id].shader);
+         any_hit_stage = radv_pipeline_cache_handle_to_nir(
+            data->device, data->pipeline->stages[any_hit_shader_id].shader);
          assert(any_hit_stage);
 
          /* reserve stack size for any_hit before it is inlined */
-         data->stages[any_hit_shader_id].stack_size = any_hit_stage->scratch_size;
+         data->pipeline->stages[any_hit_shader_id].stack_size = any_hit_stage->scratch_size;
 
          nir_lower_intersection_shader(nir_stage, any_hit_stage);
          ralloc_free(any_hit_stage);
       }
 
       insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0,
-                     data->groups[i].handle.intersection_index, shader_id, data->stages);
+                     data->pipeline->groups[i].handle.intersection_index, shader_id,
+                     data->pipeline->stages);
       ralloc_free(nir_stage);
    }
 
@@ -1428,9 +1428,9 @@ load_stack_entry(nir_builder *b, nir_ssa_def *index, const struct radv_ray_trave
 }
 
 static nir_shader *
-build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_stage *stages,
-                       const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
-                       struct radv_ray_tracing_group *groups, const struct radv_pipeline_key *key)
+radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
+                            const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
+                            const struct radv_pipeline_key *key)
 {
    /* Create the traversal shader as an intersection shader to prevent validation failures due to
     * invalid variable modes.*/
@@ -1517,12 +1517,10 @@ build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_stage
 
    struct traversal_data data = {
       .device = device,
-      .createInfo = pCreateInfo,
       .vars = &vars,
       .trav_vars = &trav_vars,
       .barycentrics = barycentrics,
-      .groups = groups,
-      .stages = stages,
+      .pipeline = pipeline,
       .key = key,
    };
 
@@ -1626,8 +1624,8 @@ move_rt_instructions(nir_shader *shader)
 }
 
 nir_shader *
-create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
-                 struct radv_ray_tracing_stage *stages, struct radv_ray_tracing_group *groups,
+create_rt_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
+                 const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                  const struct radv_pipeline_key *key)
 {
    nir_builder b = radv_meta_init_shader(device, MESA_SHADER_RAYGEN, "rt_combined");
@@ -1644,12 +1642,14 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
    nir_ssa_def *idx = nir_load_var(&b, vars.idx);
 
    /* Insert traversal shader */
-   nir_shader *traversal = build_traversal_shader(device, stages, pCreateInfo, groups, key);
+   nir_shader *traversal = radv_build_traversal_shader(device, pipeline, pCreateInfo, key);
    b.shader->info.shared_size = MAX2(b.shader->info.shared_size, traversal->info.shared_size);
    assert(b.shader->info.shared_size <= 32768);
    insert_rt_case(&b, traversal, &vars, idx, 0, 1, -1u, NULL);
    ralloc_free(traversal);
 
+   struct radv_ray_tracing_group *groups = pipeline->groups;
+   struct radv_ray_tracing_stage *stages = pipeline->stages;
    unsigned call_idx_base = 1;
    for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
       unsigned stage_idx = groups[i].recursive_shader;
index 2381036..f3ced77 100644 (file)
@@ -43,8 +43,7 @@
 struct radv_physical_device;
 struct radv_device;
 struct radv_pipeline;
-struct radv_ray_tracing_stage;
-struct radv_ray_tracing_group;
+struct radv_ray_tracing_pipeline;
 struct radv_pipeline_key;
 struct radv_shader_args;
 struct radv_vs_input_state;
@@ -787,10 +786,8 @@ bool radv_consider_culling(const struct radv_physical_device *pdevice, struct ni
 
 void radv_get_nir_options(struct radv_physical_device *device);
 
-nir_shader *create_rt_shader(struct radv_device *device,
+nir_shader *create_rt_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
                              const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
-                             struct radv_ray_tracing_stage *stages,
-                             struct radv_ray_tracing_group *groups,
                              const struct radv_pipeline_key *key);
 
 #endif