radv/rt: replace uses of pGroups with radv_ray_tracing_group
authorDaniel Schürmann <daniel@schuermann.dev>
Tue, 25 Apr 2023 11:37:29 +0000 (13:37 +0200)
committerMarge Bot <emma+marge@anholt.net>
Wed, 26 Apr 2023 02:48:29 +0000 (02:48 +0000)
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22686>

src/amd/vulkan/radv_pipeline_rt.c
src/amd/vulkan/radv_rt_shader.c

index f9d1327..fa8c138 100644 (file)
@@ -475,21 +475,9 @@ compute_rt_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
    for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
       non_recursive_size = MAX2(groups[i].stack_size.non_recursive_size, non_recursive_size);
 
-      const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
-      uint32_t shader_id = VK_SHADER_UNUSED_KHR;
+      uint32_t shader_id = groups[i].recursive_shader;
       unsigned size = groups[i].stack_size.recursive_size;
 
-      switch (group_info->type) {
-      case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
-         shader_id = group_info->generalShader;
-         break;
-      case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
-      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
-         shader_id = group_info->closestHitShader;
-         break;
-      default:
-         break;
-      }
       if (shader_id == VK_SHADER_UNUSED_KHR)
          continue;
 
index 1e553bd..40b3339 100644 (file)
@@ -852,15 +852,15 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni
 
    /* reserve stack sizes */
    for (uint32_t group_idx = 0; group_idx < vars->create_info->groupCount; group_idx++) {
-      const VkRayTracingShaderGroupCreateInfoKHR *group = vars->create_info->pGroups + group_idx;
+      struct radv_ray_tracing_group *group = groups + group_idx;
 
-      if (stage_idx == group->generalShader || stage_idx == group->closestHitShader)
-         groups[group_idx].stack_size.recursive_size =
-            MAX2(groups[group_idx].stack_size.recursive_size, src_vars.stack_size);
+      if (stage_idx == group->recursive_shader)
+         group->stack_size.recursive_size =
+            MAX2(group->stack_size.recursive_size, src_vars.stack_size);
 
-      if (stage_idx == group->anyHitShader || stage_idx == group->intersectionShader)
-         groups[group_idx].stack_size.non_recursive_size =
-            MAX2(groups[group_idx].stack_size.non_recursive_size, src_vars.stack_size);
+      if (stage_idx == group->any_hit_shader || stage_idx == group->intersection_shader)
+         group->stack_size.non_recursive_size =
+            MAX2(group->stack_size.non_recursive_size, src_vars.stack_size);
    }
 }
 
@@ -1204,12 +1204,12 @@ visit_any_hit_shaders(struct radv_device *device,
       nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
 
    for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
-      const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
+      struct radv_ray_tracing_group *group = &data->groups[i];
       uint32_t shader_id = VK_SHADER_UNUSED_KHR;
 
-      switch (group_info->type) {
+      switch (group->type) {
       case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
-         shader_id = group_info->anyHitShader;
+         shader_id = group->any_hit_shader;
          break;
       default:
          break;
@@ -1339,14 +1339,14 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
       nir_push_if(b, nir_ine_imm(b, nir_load_var(b, inner_vars.idx), 0));
 
    for (unsigned i = 0; i < data->createInfo->groupCount; ++i) {
-      const VkRayTracingShaderGroupCreateInfoKHR *group_info = &data->createInfo->pGroups[i];
+      struct radv_ray_tracing_group *group = &data->groups[i];
       uint32_t shader_id = VK_SHADER_UNUSED_KHR;
       uint32_t any_hit_shader_id = VK_SHADER_UNUSED_KHR;
 
-      switch (group_info->type) {
+      switch (group->type) {
       case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
-         shader_id = group_info->intersectionShader;
-         any_hit_shader_id = group_info->anyHitShader;
+         shader_id = group->intersection_shader;
+         any_hit_shader_id = group->any_hit_shader;
          break;
       default:
          break;
@@ -1643,11 +1643,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
 
    unsigned call_idx_base = 1;
    for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
-      unsigned stage_idx = VK_SHADER_UNUSED_KHR;
-      if (pCreateInfo->pGroups[i].type == VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR)
-         stage_idx = pCreateInfo->pGroups[i].generalShader;
-      else
-         stage_idx = pCreateInfo->pGroups[i].closestHitShader;
+      unsigned stage_idx = groups[i].recursive_shader;
 
       if (stage_idx == VK_SHADER_UNUSED_KHR)
          continue;