radv/rt: store stack_sizes per stage instead of per group
author: Daniel Schürmann <daniel@schuermann.dev>
Tue, 25 Apr 2023 17:44:17 +0000 (19:44 +0200)
committer: Marge Bot <emma+marge@anholt.net>
Wed, 10 May 2023 07:02:13 +0000 (07:02 +0000)
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22503>

src/amd/vulkan/radv_pipeline_cache.c
src/amd/vulkan/radv_pipeline_rt.c
src/amd/vulkan/radv_private.h
src/amd/vulkan/radv_rt_shader.c

index 07486a5..a68015e 100644 (file)
@@ -251,7 +251,7 @@ radv_pipeline_cache_object_create(struct vk_device *device, unsigned num_shaders
    assert(num_stack_sizes == 0 || ps_epilog_binary_size == 0);
    const size_t size = sizeof(struct radv_pipeline_cache_object) +
                        (num_shaders * sizeof(struct radv_shader *)) + ps_epilog_binary_size +
-                       (num_stack_sizes * sizeof(struct radv_pipeline_shader_stack_size));
+                       (num_stack_sizes * sizeof(uint32_t));
 
    struct radv_pipeline_cache_object *object =
       vk_alloc(&device->alloc, size, 8, VK_SYSTEM_ALLOCATION_SCOPE_CACHE);
@@ -325,8 +325,7 @@ radv_pipeline_cache_object_deserialize(struct vk_pipeline_cache *cache, const vo
       object->shaders[i] = container_of(shader, struct radv_shader, base);
    }
 
-   const size_t data_size =
-      ps_epilog_binary_size + (num_stack_sizes * sizeof(struct radv_pipeline_shader_stack_size));
+   const size_t data_size = ps_epilog_binary_size + (num_stack_sizes * sizeof(uint32_t));
    blob_copy_bytes(blob, object->data, data_size);
 
    if (ps_epilog_binary_size) {
@@ -358,8 +357,7 @@ radv_pipeline_cache_object_serialize(struct vk_pipeline_cache_object *object, st
       blob_write_bytes(blob, pipeline_obj->shaders[i]->sha1, SHA1_DIGEST_LENGTH);
 
    const size_t data_size =
-      pipeline_obj->ps_epilog_binary_size +
-      (pipeline_obj->num_stack_sizes * sizeof(struct radv_pipeline_shader_stack_size));
+      pipeline_obj->ps_epilog_binary_size + (pipeline_obj->num_stack_sizes * sizeof(uint32_t));
    blob_write_bytes(blob, pipeline_obj->data, data_size);
 
    return true;
@@ -417,12 +415,12 @@ radv_pipeline_cache_search(struct radv_device *device, struct vk_pipeline_cache
    }
 
    if (pipeline->type == RADV_PIPELINE_RAY_TRACING) {
-      unsigned num_rt_groups = radv_pipeline_to_ray_tracing(pipeline)->group_count;
-      assert(num_rt_groups == pipeline_obj->num_stack_sizes);
-      struct radv_pipeline_shader_stack_size *stack_sizes = pipeline_obj->data;
-      struct radv_ray_tracing_group *rt_groups = radv_pipeline_to_ray_tracing(pipeline)->groups;
-      for (unsigned i = 0; i < num_rt_groups; i++)
-         rt_groups[i].stack_size = stack_sizes[i];
+      unsigned num_rt_stages = radv_pipeline_to_ray_tracing(pipeline)->stage_count;
+      assert(num_rt_stages == pipeline_obj->num_stack_sizes);
+      uint32_t *stack_sizes = pipeline_obj->data;
+      struct radv_ray_tracing_stage *rt_stages = radv_pipeline_to_ray_tracing(pipeline)->stages;
+      for (unsigned i = 0; i < num_rt_stages; i++)
+         rt_stages[i].stack_size = stack_sizes[i];
    }
 
    vk_pipeline_cache_object_unref(&device->vk, object);
@@ -448,12 +446,12 @@ radv_pipeline_cache_insert(struct radv_device *device, struct vk_pipeline_cache
    num_shaders += pipeline->gs_copy_shader ? 1 : 0;
 
    unsigned ps_epilog_binary_size = ps_epilog_binary ? ps_epilog_binary->total_size : 0;
-   unsigned num_rt_groups = 0;
+   unsigned num_rt_stages = 0;
    if (pipeline->type == RADV_PIPELINE_RAY_TRACING)
-      num_rt_groups = radv_pipeline_to_ray_tracing(pipeline)->group_count;
+      num_rt_stages = radv_pipeline_to_ray_tracing(pipeline)->stage_count;
 
    struct radv_pipeline_cache_object *pipeline_obj;
-   pipeline_obj = radv_pipeline_cache_object_create(&device->vk, num_shaders, sha1, num_rt_groups,
+   pipeline_obj = radv_pipeline_cache_object_create(&device->vk, num_shaders, sha1, num_rt_stages,
                                                     ps_epilog_binary_size);
 
    if (!pipeline_obj)
@@ -482,10 +480,10 @@ radv_pipeline_cache_insert(struct radv_device *device, struct vk_pipeline_cache
    }
 
    if (pipeline->type == RADV_PIPELINE_RAY_TRACING) {
-      struct radv_pipeline_shader_stack_size *stack_sizes = pipeline_obj->data;
-      struct radv_ray_tracing_group *rt_groups = radv_pipeline_to_ray_tracing(pipeline)->groups;
-      for (unsigned i = 0; i < num_rt_groups; i++)
-         stack_sizes[i] = rt_groups[i].stack_size;
+      uint32_t *stack_sizes = pipeline_obj->data;
+      struct radv_ray_tracing_stage *rt_stages = radv_pipeline_to_ray_tracing(pipeline)->stages;
+      for (unsigned i = 0; i < num_rt_stages; i++)
+         stack_sizes[i] = rt_stages[i].stack_size;
    }
 
    /* Add the object to the cache */
index 0519877..36c01d1 100644 (file)
@@ -393,52 +393,50 @@ radv_rt_pipeline_has_dynamic_stack_size(const VkRayTracingPipelineCreateInfoKHR
    return false;
 }
 
-static unsigned
+static void
 compute_rt_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
-                      const struct radv_ray_tracing_group *groups,
-                      const struct radv_ray_tracing_stage *stages)
+                      struct radv_ray_tracing_pipeline *pipeline)
 {
-   if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo))
-      return -1u;
+   if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo)) {
+      pipeline->stack_size = -1u;
+      return;
+   }
 
    unsigned raygen_size = 0;
    unsigned callable_size = 0;
-   unsigned chit_size = 0;
-   unsigned miss_size = 0;
-   unsigned non_recursive_size = 0;
-
-   for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
-      non_recursive_size = MAX2(groups[i].stack_size.non_recursive_size, non_recursive_size);
+   unsigned chit_miss_size = 0;
+   unsigned intersection_size = 0;
+   unsigned any_hit_size = 0;
 
-      uint32_t shader_id = groups[i].recursive_shader;
-      unsigned size = groups[i].stack_size.recursive_size;
-
-      if (shader_id == VK_SHADER_UNUSED_KHR)
-         continue;
-
-      switch (stages[shader_id].stage) {
+   for (unsigned i = 0; i < pipeline->stage_count; ++i) {
+      uint32_t size = pipeline->stages[i].stack_size;
+      switch (pipeline->stages[i].stage) {
       case MESA_SHADER_RAYGEN:
          raygen_size = MAX2(raygen_size, size);
          break;
-      case MESA_SHADER_MISS:
-         miss_size = MAX2(miss_size, size);
-         break;
       case MESA_SHADER_CLOSEST_HIT:
-         chit_size = MAX2(chit_size, size);
+      case MESA_SHADER_MISS:
+         chit_miss_size = MAX2(chit_miss_size, size);
          break;
       case MESA_SHADER_CALLABLE:
          callable_size = MAX2(callable_size, size);
          break;
+      case MESA_SHADER_INTERSECTION:
+         intersection_size = MAX2(intersection_size, size);
+         break;
+      case MESA_SHADER_ANY_HIT:
+         any_hit_size = MAX2(any_hit_size, size);
+         break;
       default:
          unreachable("Invalid stage type in RT shader");
       }
    }
-   return raygen_size +
-          MIN2(pCreateInfo->maxPipelineRayRecursionDepth, 1) *
-             MAX2(MAX2(chit_size, miss_size), non_recursive_size) +
-          MAX2(0, (int)(pCreateInfo->maxPipelineRayRecursionDepth) - 1) *
-             MAX2(chit_size, miss_size) +
-          2 * callable_size;
+   pipeline->stack_size =
+      raygen_size +
+      MIN2(pCreateInfo->maxPipelineRayRecursionDepth, 1) *
+         MAX2(chit_miss_size, intersection_size + any_hit_size) +
+      MAX2(0, (int)(pCreateInfo->maxPipelineRayRecursionDepth) - 1) * chit_miss_size +
+      2 * callable_size;
 }
 
 static struct radv_pipeline_key
@@ -540,7 +538,7 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
    if (result != VK_SUCCESS)
       goto done;
 
-   pipeline->stack_size = compute_rt_stack_size(&local_create_info, pipeline->groups, pipeline->stages);
+   compute_rt_stack_size(pCreateInfo, pipeline);
    pipeline->base.base.shaders[MESA_SHADER_COMPUTE] = radv_create_rt_prolog(device);
 
    combine_config(&pipeline->base.base.shaders[MESA_SHADER_COMPUTE]->config,
@@ -641,14 +639,18 @@ radv_GetRayTracingShaderGroupStackSizeKHR(VkDevice device, VkPipeline _pipeline,
 {
    RADV_FROM_HANDLE(radv_pipeline, pipeline, _pipeline);
    struct radv_ray_tracing_pipeline *rt_pipeline = radv_pipeline_to_ray_tracing(pipeline);
-   const struct radv_pipeline_shader_stack_size *stack_size =
-      &rt_pipeline->groups[group].stack_size;
-
-   if (groupShader == VK_SHADER_GROUP_SHADER_ANY_HIT_KHR ||
-       groupShader == VK_SHADER_GROUP_SHADER_INTERSECTION_KHR)
-      return stack_size->non_recursive_size;
-   else
-      return stack_size->recursive_size;
+   struct radv_ray_tracing_group *rt_group = &rt_pipeline->groups[group];
+   switch (groupShader) {
+   case VK_SHADER_GROUP_SHADER_GENERAL_KHR:
+   case VK_SHADER_GROUP_SHADER_CLOSEST_HIT_KHR:
+      return rt_pipeline->stages[rt_group->recursive_shader].stack_size;
+   case VK_SHADER_GROUP_SHADER_ANY_HIT_KHR:
+      return rt_pipeline->stages[rt_group->any_hit_shader].stack_size;
+   case VK_SHADER_GROUP_SHADER_INTERSECTION_KHR:
+      return rt_pipeline->stages[rt_group->intersection_shader].stack_size;
+   default:
+      return 0;
+   }
 }
 
 VKAPI_ATTR VkResult VKAPI_CALL
index 272386e..4c407d1 100644 (file)
@@ -2190,12 +2190,6 @@ struct radv_pipeline_group_handle {
    };
 };
 
-struct radv_pipeline_shader_stack_size {
-   uint32_t recursive_size;
-   /* anyhit + intersection */
-   uint32_t non_recursive_size;
-};
-
 enum radv_depth_clamp_mode {
    RADV_DEPTH_CLAMP_MODE_VIEWPORT = 0,       /* Clamp to the viewport min/max depth bounds */
    RADV_DEPTH_CLAMP_MODE_ZERO_TO_ONE = 1,    /* Clamp between 0.0f and 1.0f */
@@ -2315,12 +2309,12 @@ struct radv_ray_tracing_group {
    uint32_t any_hit_shader;
    uint32_t intersection_shader;
    struct radv_pipeline_group_handle handle;
-   struct radv_pipeline_shader_stack_size stack_size;
 };
 
 struct radv_ray_tracing_stage {
    struct vk_pipeline_cache_object *shader;
    gl_shader_stage stage;
+   uint32_t stack_size;
 };
 
 struct radv_ray_tracing_pipeline {
index 19a5997..36f06b4 100644 (file)
@@ -818,7 +818,7 @@ inline_constants(nir_shader *dst, nir_shader *src)
 static void
 insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, nir_ssa_def *idx,
                uint32_t call_idx_base, uint32_t call_idx, unsigned stage_idx,
-               struct radv_ray_tracing_group *groups)
+               struct radv_ray_tracing_stage *stages)
 {
    uint32_t workgroup_size = b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] *
                              b->shader->info.workgroup_size[2];
@@ -850,18 +850,9 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni
 
    ralloc_free(var_remap);
 
-   /* reserve stack sizes */
-   for (uint32_t group_idx = 0; group_idx < vars->create_info->groupCount; group_idx++) {
-      struct radv_ray_tracing_group *group = groups + group_idx;
-
-      if (stage_idx == group->recursive_shader)
-         group->stack_size.recursive_size =
-            MAX2(group->stack_size.recursive_size, src_vars.stack_size);
-
-      if (stage_idx == group->any_hit_shader || stage_idx == group->intersection_shader)
-         group->stack_size.non_recursive_size =
-            MAX2(group->stack_size.non_recursive_size, src_vars.stack_size);
-   }
+   /* reserve stack size */
+   if (stages)
+      stages[stage_idx].stack_size = MAX2(stages[stage_idx].stack_size, src_vars.stack_size);
 }
 
 nir_shader *
@@ -987,6 +978,9 @@ lower_any_hit_for_intersection(nir_shader *any_hit)
             /* We place all any_hit scratch variables after intersection scratch variables.
              * For that reason, we increment the scratch offset by the intersection scratch
              * size. For call_data, we have to subtract the offset again.
+             *
+             * Note that we don't increase the scratch size as it is already reflected via
+             * the any_hit stack_size.
              */
             case nir_intrinsic_load_scratch:
                b->cursor = nir_before_instr(instr);
@@ -1114,8 +1108,6 @@ nir_lower_intersection_shader(nir_shader *intersection, nir_shader *any_hit)
          nir_ssa_def_rewrite_uses(&intrin->dest.ssa, accepted);
       }
    }
-   /* Any-hit scratch variables are placed after intersection scratch variables. */
-   intersection->scratch_size += any_hit->scratch_size;
    nir_metadata_preserve(impl, nir_metadata_none);
 
    /* We did some inlining; have to re-index SSA defs */
@@ -1232,7 +1224,7 @@ visit_any_hit_shaders(struct radv_device *device,
       assert(nir_stage);
 
       insert_rt_case(b, nir_stage, vars, sbt_idx, 0, data->groups[i].handle.any_hit_index,
-                     shader_id, data->groups);
+                     shader_id, data->stages);
       ralloc_free(nir_stage);
    }
 
@@ -1375,12 +1367,15 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
             radv_pipeline_cache_handle_to_nir(data->device, data->stages[any_hit_shader_id].shader);
          assert(any_hit_stage);
 
+         /* reserve stack size for any_hit before it is inlined */
+         data->stages[any_hit_shader_id].stack_size = any_hit_stage->scratch_size;
+
          nir_lower_intersection_shader(nir_stage, any_hit_stage);
          ralloc_free(any_hit_stage);
       }
 
       insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0,
-                     data->groups[i].handle.intersection_index, shader_id, data->groups);
+                     data->groups[i].handle.intersection_index, shader_id, data->stages);
       ralloc_free(nir_stage);
    }
 
@@ -1644,7 +1639,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
    nir_shader *traversal = build_traversal_shader(device, stages, pCreateInfo, groups, key);
    b.shader->info.shared_size = MAX2(b.shader->info.shared_size, traversal->info.shared_size);
    assert(b.shader->info.shared_size <= 32768);
-   insert_rt_case(&b, traversal, &vars, idx, 0, 1, -1u, groups);
+   insert_rt_case(&b, traversal, &vars, idx, 0, 1, -1u, NULL);
    ralloc_free(traversal);
 
    unsigned call_idx_base = 1;
@@ -1685,10 +1680,10 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
       nir_lower_shader_calls(nir_stage, &opts, &resume_shaders, &num_resume_shaders, nir_stage);
 
       insert_rt_case(&b, nir_stage, &vars, idx, call_idx_base, groups[i].handle.general_index,
-                     stage_idx, groups);
+                     stage_idx, stages);
       for (unsigned j = 0; j < num_resume_shaders; ++j) {
          insert_rt_case(&b, resume_shaders[j], &vars, idx, call_idx_base, call_idx_base + 1 + j,
-                        stage_idx, groups);
+                        stage_idx, stages);
       }
 
       ralloc_free(nir_stage);