radv/rt: introduce struct radv_ray_tracing_module
authorDaniel Schürmann <daniel@schuermann.dev>
Thu, 2 Mar 2023 17:03:17 +0000 (18:03 +0100)
committerMarge Bot <emma+marge@anholt.net>
Mon, 6 Mar 2023 13:58:54 +0000 (13:58 +0000)
This is preliminary work for separate shader functions.
The ray_tracing_module is eventually intended as self-contained
pipeline struct per RT group.

For now, these modules only contain the group handles.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21667>

src/amd/vulkan/radv_pipeline.c
src/amd/vulkan/radv_pipeline_cache.c
src/amd/vulkan/radv_pipeline_rt.c
src/amd/vulkan/radv_private.h
src/amd/vulkan/radv_rt_shader.c
src/amd/vulkan/radv_shader.h

index 98215b4..13db7b4 100644 (file)
@@ -136,13 +136,11 @@ radv_pipeline_destroy(struct radv_device *device, struct radv_pipeline *pipeline
    } else if (pipeline->type == RADV_PIPELINE_RAY_TRACING) {
       struct radv_ray_tracing_pipeline *rt_pipeline = radv_pipeline_to_ray_tracing(pipeline);
 
-      free(rt_pipeline->group_handles);
       free(rt_pipeline->stack_sizes);
    } else if (pipeline->type == RADV_PIPELINE_LIBRARY) {
       struct radv_library_pipeline *library_pipeline = radv_pipeline_to_library(pipeline);
 
       ralloc_free(library_pipeline->ctx);
-      free(library_pipeline->group_handles);
    } else if (pipeline->type == RADV_PIPELINE_GRAPHICS_LIB) {
       struct radv_graphics_lib_pipeline *gfx_pipeline_lib =
          radv_pipeline_to_graphics_lib(pipeline);
index e3e179d..03067c2 100644 (file)
@@ -178,7 +178,7 @@ radv_hash_rt_stages(struct mesa_sha1 *ctx, const VkPipelineShaderStageCreateInfo
 void
 radv_hash_rt_shaders(unsigned char *hash, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                      const struct radv_pipeline_key *key,
-                     const struct radv_pipeline_group_handle *group_handles, uint32_t flags)
+                     const struct radv_ray_tracing_module *groups, uint32_t flags)
 {
    RADV_FROM_HANDLE(radv_pipeline_layout, layout, pCreateInfo->layout);
    struct mesa_sha1 ctx;
@@ -191,9 +191,6 @@ radv_hash_rt_shaders(unsigned char *hash, const VkRayTracingPipelineCreateInfoKH
 
    radv_hash_rt_stages(&ctx, pCreateInfo->pStages, pCreateInfo->stageCount);
 
-   _mesa_sha1_update(&ctx, group_handles,
-                     sizeof(struct radv_pipeline_group_handle) * pCreateInfo->groupCount);
-
    for (uint32_t i = 0; i < pCreateInfo->groupCount; i++) {
       _mesa_sha1_update(&ctx, &pCreateInfo->pGroups[i].type,
                         sizeof(pCreateInfo->pGroups[i].type));
@@ -205,6 +202,7 @@ radv_hash_rt_shaders(unsigned char *hash, const VkRayTracingPipelineCreateInfoKH
                         sizeof(pCreateInfo->pGroups[i].closestHitShader));
       _mesa_sha1_update(&ctx, &pCreateInfo->pGroups[i].intersectionShader,
                         sizeof(pCreateInfo->pGroups[i].intersectionShader));
+      _mesa_sha1_update(&ctx, &groups[i].handle, sizeof(struct radv_pipeline_group_handle));
    }
 
    const uint32_t pipeline_flags =
index db840c1..3bfd136 100644 (file)
@@ -82,13 +82,8 @@ handle_from_stages(struct radv_device *device, const VkPipelineShaderStageCreate
 static VkResult
 radv_create_group_handles(struct radv_device *device,
                           const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
-                          struct radv_pipeline_group_handle **out_handles)
+                          struct radv_ray_tracing_module *groups)
 {
-   struct radv_pipeline_group_handle *handles = calloc(sizeof(*handles), pCreateInfo->groupCount);
-   if (!handles) {
-      return VK_ERROR_OUT_OF_HOST_MEMORY;
-   }
-
    bool capture_replay = pCreateInfo->flags &
                          VK_PIPELINE_CREATE_RAY_TRACING_SHADER_GROUP_HANDLE_CAPTURE_REPLAY_BIT_KHR;
    for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
@@ -96,12 +91,12 @@ radv_create_group_handles(struct radv_device *device,
       switch (group_info->type) {
       case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
          if (group_info->generalShader != VK_SHADER_UNUSED_KHR)
-            handles[i].general_index = handle_from_stages(
+            groups[i].handle.general_index = handle_from_stages(
                device, &pCreateInfo->pStages[group_info->generalShader], 1, capture_replay);
          break;
       case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
          if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR)
-            handles[i].closest_hit_index = handle_from_stages(
+            groups[i].handle.closest_hit_index = handle_from_stages(
                device, &pCreateInfo->pStages[group_info->closestHitShader], 1, capture_replay);
          if (group_info->intersectionShader != VK_SHADER_UNUSED_KHR) {
             VkPipelineShaderStageCreateInfo stages[2];
@@ -109,15 +104,16 @@ radv_create_group_handles(struct radv_device *device,
             stages[cnt++] = pCreateInfo->pStages[group_info->intersectionShader];
             if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR)
                stages[cnt++] = pCreateInfo->pStages[group_info->anyHitShader];
-            handles[i].intersection_index = handle_from_stages(device, stages, cnt, capture_replay);
+            groups[i].handle.intersection_index =
+               handle_from_stages(device, stages, cnt, capture_replay);
          }
          break;
       case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
          if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR)
-            handles[i].closest_hit_index = handle_from_stages(
+            groups[i].handle.closest_hit_index = handle_from_stages(
                device, &pCreateInfo->pStages[group_info->closestHitShader], 1, capture_replay);
          if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR)
-            handles[i].any_hit_index = handle_from_stages(
+            groups[i].handle.any_hit_index = handle_from_stages(
                device, &pCreateInfo->pStages[group_info->anyHitShader], 1, capture_replay);
          break;
       case VK_SHADER_GROUP_SHADER_MAX_ENUM_KHR:
@@ -126,15 +122,13 @@ radv_create_group_handles(struct radv_device *device,
 
       if (capture_replay) {
          if (group_info->pShaderGroupCaptureReplayHandle &&
-             memcmp(group_info->pShaderGroupCaptureReplayHandle, &handles[i], sizeof(handles[i])) !=
-                0) {
-            free(handles);
+             memcmp(group_info->pShaderGroupCaptureReplayHandle, &groups[i].handle,
+                    sizeof(groups[i].handle)) != 0) {
             return VK_ERROR_INVALID_OPAQUE_CAPTURE_ADDRESS;
          }
       }
    }
 
-   *out_handles = handles;
    return VK_SUCCESS;
 }
 
@@ -222,7 +216,9 @@ radv_rt_pipeline_library_create(VkDevice _device, VkPipelineCache _cache,
    if (!local_create_info.pStages || !local_create_info.pGroups)
       return VK_ERROR_OUT_OF_HOST_MEMORY;
 
-   pipeline = vk_zalloc2(&device->vk.alloc, pAllocator, sizeof(*pipeline), 8,
+   size_t pipeline_size =
+      sizeof(*pipeline) + local_create_info.groupCount * sizeof(struct radv_ray_tracing_module);
+   pipeline = vk_zalloc2(&device->vk.alloc, pAllocator, pipeline_size, 8,
                          VK_SYSTEM_ALLOCATION_SCOPE_OBJECT);
    if (pipeline == NULL)
       return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
@@ -231,8 +227,7 @@ radv_rt_pipeline_library_create(VkDevice _device, VkPipelineCache _cache,
 
    pipeline->ctx = ralloc_context(NULL);
 
-   VkResult result =
-      radv_create_group_handles(device, &local_create_info, &pipeline->group_handles);
+   VkResult result = radv_create_group_handles(device, &local_create_info, pipeline->groups);
    if (result != VK_SUCCESS)
       goto fail;
 
@@ -461,7 +456,9 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
    VkPipelineCreateFlags flags =
       pCreateInfo->flags | VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT;
 
-   rt_pipeline = vk_zalloc2(&device->vk.alloc, pAllocator, sizeof(*rt_pipeline), 8,
+   size_t pipeline_size =
+      sizeof(*rt_pipeline) + local_create_info.groupCount * sizeof(struct radv_ray_tracing_module);
+   rt_pipeline = vk_zalloc2(&device->vk.alloc, pAllocator, pipeline_size, 8,
                             VK_SYSTEM_ALLOCATION_SCOPE_OBJECT);
    if (rt_pipeline == NULL) {
       result = VK_ERROR_OUT_OF_HOST_MEMORY;
@@ -471,7 +468,7 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
    radv_pipeline_init(device, &rt_pipeline->base.base, RADV_PIPELINE_RAY_TRACING);
    rt_pipeline->group_count = local_create_info.groupCount;
 
-   result = radv_create_group_handles(device, &local_create_info, &rt_pipeline->group_handles);
+   result = radv_create_group_handles(device, &local_create_info, rt_pipeline->groups);
    if (result != VK_SUCCESS)
       goto pipeline_fail;
 
@@ -480,7 +477,7 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
 
    struct radv_pipeline_key key = radv_generate_rt_pipeline_key(rt_pipeline, pCreateInfo->flags);
 
-   radv_hash_rt_shaders(hash, &local_create_info, &key, rt_pipeline->group_handles,
+   radv_hash_rt_shaders(hash, &local_create_info, &key, rt_pipeline->groups,
                         radv_get_hash_flags(device, keep_statistic_info));
 
    /* First check if we can get things from the cache before we take the expensive step of
@@ -504,7 +501,7 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
       }
 
       shader = create_rt_shader(device, &local_create_info, rt_pipeline->stack_sizes,
-                                rt_pipeline->group_handles, &key);
+                                rt_pipeline->groups, &key);
       module.nir = shader;
       result = radv_compute_pipeline_compile(
          &rt_pipeline->base, pipeline_layout, device, cache, &key, &stage, pCreateInfo->flags,
@@ -575,20 +572,21 @@ radv_GetRayTracingShaderGroupHandlesKHR(VkDevice device, VkPipeline _pipeline, u
                                         uint32_t groupCount, size_t dataSize, void *pData)
 {
    RADV_FROM_HANDLE(radv_pipeline, pipeline, _pipeline);
-   struct radv_pipeline_group_handle *handles;
+   struct radv_ray_tracing_module *groups;
    if (pipeline->type == RADV_PIPELINE_LIBRARY) {
-      handles = radv_pipeline_to_library(pipeline)->group_handles;
+      groups = radv_pipeline_to_library(pipeline)->groups;
    } else {
-      handles = radv_pipeline_to_ray_tracing(pipeline)->group_handles;
+      groups = radv_pipeline_to_ray_tracing(pipeline)->groups;
    }
    char *data = pData;
 
-   STATIC_ASSERT(sizeof(*handles) <= RADV_RT_HANDLE_SIZE);
+   STATIC_ASSERT(sizeof(struct radv_pipeline_group_handle) <= RADV_RT_HANDLE_SIZE);
 
    memset(data, 0, groupCount * RADV_RT_HANDLE_SIZE);
 
    for (uint32_t i = 0; i < groupCount; ++i) {
-      memcpy(data + i * RADV_RT_HANDLE_SIZE, &handles[firstGroup + i], sizeof(*handles));
+      memcpy(data + i * RADV_RT_HANDLE_SIZE, &groups[firstGroup + i].handle,
+             sizeof(struct radv_pipeline_group_handle));
    }
 
    return VK_SUCCESS;
index aa4abe6..77199f0 100644 (file)
@@ -1989,7 +1989,7 @@ struct radv_event {
 #define RADV_HASH_SHADER_NO_FMASK              (1 << 19)
 #define RADV_HASH_SHADER_NGG_STREAMOUT         (1 << 20)
 
-struct radv_pipeline_group_handle;
+struct radv_ray_tracing_module;
 struct radv_pipeline_key;
 
 void radv_pipeline_stage_init(const VkPipelineShaderStageCreateInfo *sinfo,
@@ -2004,7 +2004,7 @@ void radv_hash_rt_stages(struct mesa_sha1 *ctx, const VkPipelineShaderStageCreat
 
 void radv_hash_rt_shaders(unsigned char *hash, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                           const struct radv_pipeline_key *key,
-                          const struct radv_pipeline_group_handle *group_handles, uint32_t flags);
+                          const struct radv_ray_tracing_module *groups, uint32_t flags);
 
 uint32_t radv_get_hash_flags(const struct radv_device *device, bool stats);
 
@@ -2204,6 +2204,10 @@ struct radv_compute_pipeline {
    bool cs_regalloc_hang_bug;
 };
 
+struct radv_ray_tracing_module {
+   struct radv_pipeline_group_handle handle;
+};
+
 struct radv_library_pipeline {
    struct radv_pipeline base;
 
@@ -2219,7 +2223,7 @@ struct radv_library_pipeline {
       uint8_t sha1[SHA1_DIGEST_LENGTH];
    } *hashes;
 
-   struct radv_pipeline_group_handle *group_handles;
+   struct radv_ray_tracing_module groups[];
 };
 
 struct radv_graphics_lib_pipeline {
@@ -2235,10 +2239,10 @@ struct radv_graphics_lib_pipeline {
 struct radv_ray_tracing_pipeline {
    struct radv_compute_pipeline base;
 
-   struct radv_pipeline_group_handle *group_handles;
    struct radv_pipeline_shader_stack_size *stack_sizes;
    uint32_t group_count;
    uint32_t stack_size;
+   struct radv_ray_tracing_module groups[];
 };
 
 #define RADV_DECL_PIPELINE_DOWNCAST(pipe_type, pipe_enum)            \
index cada611..5192c8b 100644 (file)
@@ -1147,7 +1147,7 @@ struct traversal_data {
    struct rt_traversal_vars *trav_vars;
    nir_variable *barycentrics;
 
-   const struct radv_pipeline_group_handle *handles;
+   const struct radv_ray_tracing_module *groups;
 };
 
 static void
@@ -1177,7 +1177,7 @@ visit_any_hit_shaders(struct radv_device *device,
       /* Avoid emitting stages with the same shaders/handles multiple times. */
       bool is_dup = false;
       for (unsigned j = 0; j < i; ++j)
-         if (data->handles[j].any_hit_index == data->handles[i].any_hit_index)
+         if (data->groups[j].handle.any_hit_index == data->groups[i].handle.any_hit_index)
             is_dup = true;
 
       if (is_dup)
@@ -1187,7 +1187,7 @@ visit_any_hit_shaders(struct radv_device *device,
       nir_shader *nir_stage = parse_rt_stage(device, stage, vars->key);
 
       vars->stage_idx = shader_id;
-      insert_rt_case(b, nir_stage, vars, sbt_idx, 0, data->handles[i].any_hit_index);
+      insert_rt_case(b, nir_stage, vars, sbt_idx, 0, data->groups[i].handle.any_hit_index);
    }
 
    if (!(vars->create_info->flags & VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR))
@@ -1313,7 +1313,7 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
       /* Avoid emitting stages with the same shaders/handles multiple times. */
       bool is_dup = false;
       for (unsigned j = 0; j < i; ++j)
-         if (data->handles[j].intersection_index == data->handles[i].intersection_index)
+         if (data->groups[j].handle.intersection_index == data->groups[i].handle.intersection_index)
             is_dup = true;
 
       if (is_dup)
@@ -1333,7 +1333,7 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
 
       inner_vars.stage_idx = shader_id;
       insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0,
-                     data->handles[i].intersection_index);
+                     data->groups[i].handle.intersection_index);
    }
 
    if (!(data->vars->create_info->flags &
@@ -1380,7 +1380,7 @@ static nir_shader *
 build_traversal_shader(struct radv_device *device,
                        const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                        struct radv_pipeline_shader_stack_size *stack_sizes,
-                       const struct radv_pipeline_group_handle *handles,
+                       const struct radv_ray_tracing_module *groups,
                        const struct radv_pipeline_key *key)
 {
    /* Create the traversal shader as an intersection shader to prevent validation failures due to
@@ -1474,7 +1474,7 @@ build_traversal_shader(struct radv_device *device,
       .vars = &vars,
       .trav_vars = &trav_vars,
       .barycentrics = barycentrics,
-      .handles = handles,
+      .groups = groups,
    };
 
    struct radv_ray_traversal_args args = {
@@ -1579,8 +1579,7 @@ move_rt_instructions(nir_shader *shader)
 nir_shader *
 create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                  struct radv_pipeline_shader_stack_size *stack_sizes,
-                 const struct radv_pipeline_group_handle *handles,
-                 const struct radv_pipeline_key *key)
+                 const struct radv_ray_tracing_module *groups, const struct radv_pipeline_key *key)
 {
    nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_combined");
    b.shader->info.internal = false;
@@ -1612,7 +1611,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
    nir_ssa_def *idx = nir_load_var(&b, vars.idx);
 
    /* Insert traversal shader */
-   nir_shader *traversal = build_traversal_shader(device, pCreateInfo, stack_sizes, handles, key);
+   nir_shader *traversal = build_traversal_shader(device, pCreateInfo, stack_sizes, groups, key);
    b.shader->info.shared_size = MAX2(b.shader->info.shared_size, traversal->info.shared_size);
    assert(b.shader->info.shared_size <= 32768);
    insert_rt_case(&b, traversal, &vars, idx, 0, 1);
@@ -1631,7 +1630,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
       /* Avoid emitting stages with the same shaders/handles multiple times. */
       bool is_dup = false;
       for (unsigned j = 0; j < i; ++j)
-         if (handles[j].general_index == handles[i].general_index)
+         if (groups[j].handle.general_index == groups[i].handle.general_index)
             is_dup = true;
 
       if (is_dup)
@@ -1660,7 +1659,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
       nir_lower_shader_calls(nir_stage, &opts, &resume_shaders, &num_resume_shaders, nir_stage);
 
       vars.stage_idx = stage_idx;
-      insert_rt_case(&b, nir_stage, &vars, idx, call_idx_base, handles[i].general_index);
+      insert_rt_case(&b, nir_stage, &vars, idx, call_idx_base, groups[i].handle.general_index);
       for (unsigned j = 0; j < num_resume_shaders; ++j) {
          insert_rt_case(&b, resume_shaders[j], &vars, idx, call_idx_base, call_idx_base + 1 + j);
       }
index dee8a42..9676539 100644 (file)
@@ -47,7 +47,7 @@ struct radv_physical_device;
 struct radv_device;
 struct radv_pipeline;
 struct radv_pipeline_cache;
-struct radv_pipeline_group_handle;
+struct radv_ray_tracing_module;
 struct radv_pipeline_key;
 struct radv_shader_args;
 struct radv_vs_input_state;
@@ -751,7 +751,7 @@ bool radv_lower_fs_intrinsics(nir_shader *nir, const struct radv_pipeline_stage
 nir_shader *create_rt_shader(struct radv_device *device,
                              const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                              struct radv_pipeline_shader_stack_size *stack_sizes,
-                             const struct radv_pipeline_group_handle *handles,
+                             const struct radv_ray_tracing_module *groups,
                              const struct radv_pipeline_key *key);
 
 #endif