radv: fix binding raytracing/compute pipelines
author Samuel Pitoiset <samuel.pitoiset@gmail.com>
Fri, 31 Mar 2023 11:32:59 +0000 (13:32 +0200)
committer Marge Bot <emma+marge@anholt.net>
Fri, 31 Mar 2023 18:29:05 +0000 (18:29 +0000)
If a compute pipeline is bound after a raytracing pipeline, the
compute shader slot (aka RT prolog) will be overwritten.

To fix this, move the RT prolog outside of the compute shader slot.

Fixes: d109362a3da ("radv: copy bound shaders to the cmdbuf state")
Signed-off-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22235>

src/amd/vulkan/radv_cmd_buffer.c
src/amd/vulkan/radv_private.h

index 31b68c1..ad71cd6 100644 (file)
@@ -4559,8 +4559,11 @@ radv_flush_indirect_descriptor_sets(struct radv_cmd_buffer *cmd_buffer,
                                     cmd_buffer->state.shaders[MESA_SHADER_TASK]->info.user_data_0,
                                     AC_UD_INDIRECT_DESCRIPTOR_SETS, va);
    } else {
-      radv_emit_userdata_address(device, cs, cmd_buffer->state.shaders[MESA_SHADER_COMPUTE],
-                                 cmd_buffer->state.shaders[MESA_SHADER_COMPUTE]->info.user_data_0,
+      struct radv_shader *compute_shader = bind_point == VK_PIPELINE_BIND_POINT_COMPUTE
+                                              ? cmd_buffer->state.shaders[MESA_SHADER_COMPUTE]
+                                              : cmd_buffer->state.rt_prolog;
+
+      radv_emit_userdata_address(device, cs, compute_shader, compute_shader->info.user_data_0,
                                  AC_UD_INDIRECT_DESCRIPTOR_SETS, va);
    }
 }
@@ -4587,8 +4590,11 @@ radv_flush_descriptors(struct radv_cmd_buffer *cmd_buffer, VkShaderStageFlags st
       radeon_check_space(device->ws, cs, MAX_SETS * MESA_VULKAN_SHADER_STAGES * 4);
 
    if (stages & VK_SHADER_STAGE_COMPUTE_BIT) {
-      radv_emit_descriptor_pointers(device, cs, cmd_buffer->state.shaders[MESA_SHADER_COMPUTE],
-                                    cmd_buffer->state.shaders[MESA_SHADER_COMPUTE]->info.user_data_0,
+      struct radv_shader *compute_shader = bind_point == VK_PIPELINE_BIND_POINT_COMPUTE
+                                              ? cmd_buffer->state.shaders[MESA_SHADER_COMPUTE]
+                                              : cmd_buffer->state.rt_prolog;
+
+      radv_emit_descriptor_pointers(device, cs, compute_shader, compute_shader->info.user_data_0,
                                     descriptors_state);
    } else {
       radv_foreach_stage(stage, stages & ~VK_SHADER_STAGE_TASK_BIT_EXT)
@@ -4691,10 +4697,12 @@ radv_flush_constants(struct radv_cmd_buffer *cmd_buffer, VkShaderStageFlags stag
    }
 
    if (internal_stages & VK_SHADER_STAGE_COMPUTE_BIT) {
-      radv_emit_all_inline_push_consts(device, cs, cmd_buffer->state.shaders[MESA_SHADER_COMPUTE],
-                                       cmd_buffer->state.shaders[MESA_SHADER_COMPUTE]->info.user_data_0,
-                                       (uint32_t *)cmd_buffer->push_constants, &need_push_constants);
+      struct radv_shader *compute_shader = bind_point == VK_PIPELINE_BIND_POINT_COMPUTE
+                                              ? cmd_buffer->state.shaders[MESA_SHADER_COMPUTE]
+                                              : cmd_buffer->state.rt_prolog;
 
+      radv_emit_all_inline_push_consts(device, cs, compute_shader, compute_shader->info.user_data_0,
+                                       (uint32_t *)cmd_buffer->push_constants, &need_push_constants);
    } else {
       radv_foreach_stage(stage, internal_stages & ~VK_SHADER_STAGE_TASK_BIT_EXT) {
          shader = radv_get_shader(cmd_buffer->state.shaders, stage);
@@ -4733,8 +4741,11 @@ radv_flush_constants(struct radv_cmd_buffer *cmd_buffer, VkShaderStageFlags stag
          radeon_check_space(cmd_buffer->device->ws, cmd_buffer->cs, MESA_VULKAN_SHADER_STAGES * 4);
 
       if (internal_stages & VK_SHADER_STAGE_COMPUTE_BIT) {
-         radv_emit_userdata_address(device, cs, cmd_buffer->state.shaders[MESA_SHADER_COMPUTE],
-                                    cmd_buffer->state.shaders[MESA_SHADER_COMPUTE]->info.user_data_0,
+         struct radv_shader *compute_shader = bind_point == VK_PIPELINE_BIND_POINT_COMPUTE
+                                                 ? cmd_buffer->state.shaders[MESA_SHADER_COMPUTE]
+                                                 : cmd_buffer->state.rt_prolog;
+
+         radv_emit_userdata_address(device, cs, compute_shader, compute_shader->info.user_data_0,
                                     AC_UD_PUSH_CONSTANTS, va);
       } else {
          prev_shader = NULL;
@@ -6267,12 +6278,13 @@ radv_emit_compute_pipeline(struct radv_cmd_buffer *cmd_buffer,
    cmd_buffer->compute_scratch_waves_wanted =
       MAX2(cmd_buffer->compute_scratch_waves_wanted, pipeline->base.max_waves);
 
-   radv_cs_add_buffer(cmd_buffer->device->ws, cmd_buffer->cs,
-                      cmd_buffer->state.shaders[MESA_SHADER_COMPUTE]->bo);
-
-   if (pipeline->base.type == RADV_PIPELINE_RAY_TRACING) {
+   if (pipeline->base.type == RADV_PIPELINE_COMPUTE) {
+      radv_cs_add_buffer(cmd_buffer->device->ws, cmd_buffer->cs,
+                         cmd_buffer->state.shaders[MESA_SHADER_COMPUTE]->bo);
+   } else {
+      radv_cs_add_buffer(cmd_buffer->device->ws, cmd_buffer->cs, cmd_buffer->state.rt_prolog->bo);
       radv_cs_add_buffer(cmd_buffer->device->ws, cmd_buffer->cs,
-                      cmd_buffer->state.shaders[MESA_SHADER_RAYGEN]->bo);
+                         cmd_buffer->state.shaders[MESA_SHADER_RAYGEN]->bo);
    }
 
    if (unlikely(cmd_buffer->device->trace_bo))
@@ -6558,10 +6570,9 @@ radv_CmdBindPipeline(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipeline
          return;
       radv_mark_descriptor_sets_dirty(cmd_buffer, pipelineBindPoint);
 
-      radv_bind_shader(cmd_buffer, rt_pipeline->base.base.shaders[MESA_SHADER_COMPUTE],
-                       MESA_SHADER_COMPUTE);
       radv_bind_shader(cmd_buffer, rt_pipeline->base.base.shaders[MESA_SHADER_RAYGEN],
                        MESA_SHADER_RAYGEN);
+      cmd_buffer->state.rt_prolog = rt_pipeline->base.base.shaders[MESA_SHADER_COMPUTE];
 
       cmd_buffer->state.rt_pipeline = rt_pipeline;
       cmd_buffer->push_constant_stages |= RADV_RT_STAGE_BITS;
@@ -9684,9 +9695,9 @@ radv_upload_compute_shader_descriptors(struct radv_cmd_buffer *cmd_buffer,
 
 static void
 radv_dispatch(struct radv_cmd_buffer *cmd_buffer, const struct radv_dispatch_info *info,
-              struct radv_compute_pipeline *pipeline, VkPipelineBindPoint bind_point)
+              struct radv_compute_pipeline *pipeline, struct radv_shader *compute_shader,
+              VkPipelineBindPoint bind_point)
 {
-   struct radv_shader *compute_shader = cmd_buffer->state.shaders[MESA_SHADER_COMPUTE];
    bool has_prefetch = cmd_buffer->device->physical_device->rad_info.gfx_level >= GFX7;
    bool pipeline_is_dirty = pipeline != cmd_buffer->state.emitted_compute_pipeline;
 
@@ -9759,6 +9770,7 @@ void
 radv_compute_dispatch(struct radv_cmd_buffer *cmd_buffer, const struct radv_dispatch_info *info)
 {
    radv_dispatch(cmd_buffer, info, cmd_buffer->state.compute_pipeline,
+                 cmd_buffer->state.shaders[MESA_SHADER_COMPUTE],
                  VK_PIPELINE_BIND_POINT_COMPUTE);
 }
 
@@ -9827,12 +9839,12 @@ radv_trace_rays(struct radv_cmd_buffer *cmd_buffer, const VkTraceRaysIndirectCom
                 uint64_t indirect_va, enum radv_rt_mode mode)
 {
    struct radv_compute_pipeline *pipeline = &cmd_buffer->state.rt_pipeline->base;
-   const struct radv_shader *compute_shader = cmd_buffer->state.shaders[MESA_SHADER_COMPUTE];
-   uint32_t base_reg = compute_shader->info.user_data_0;
+   struct radv_shader *rt_prolog = cmd_buffer->state.rt_prolog;
+   uint32_t base_reg = rt_prolog->info.user_data_0;
 
    /* Reserve scratch for stacks manually since it is not handled by the compute path. */
    uint32_t scratch_bytes_per_wave = pipeline->base.scratch_bytes_per_wave;
-   uint32_t wave_size = compute_shader->info.wave_size;
+   uint32_t wave_size = rt_prolog->info.wave_size;
 
    /* The hardware register is specified as a multiple of 256 DWORDS. */
    scratch_bytes_per_wave += align(cmd_buffer->state.rt_stack_size * wave_size, 1024);
@@ -9876,29 +9888,29 @@ radv_trace_rays(struct radv_cmd_buffer *cmd_buffer, const VkTraceRaysIndirectCom
    ASSERTED unsigned cdw_max = radeon_check_space(cmd_buffer->device->ws, cmd_buffer->cs, 15);
 
    const struct radv_userdata_info *desc_loc =
-      radv_get_user_sgpr(compute_shader, AC_UD_CS_SBT_DESCRIPTORS);
+      radv_get_user_sgpr(rt_prolog, AC_UD_CS_SBT_DESCRIPTORS);
    if (desc_loc->sgpr_idx != -1) {
       radv_emit_shader_pointer(cmd_buffer->device, cmd_buffer->cs,
                                base_reg + desc_loc->sgpr_idx * 4, sbt_va, true);
    }
 
    const struct radv_userdata_info *size_loc =
-      radv_get_user_sgpr(compute_shader, AC_UD_CS_RAY_LAUNCH_SIZE_ADDR);
+      radv_get_user_sgpr(rt_prolog, AC_UD_CS_RAY_LAUNCH_SIZE_ADDR);
    if (size_loc->sgpr_idx != -1) {
       radv_emit_shader_pointer(cmd_buffer->device, cmd_buffer->cs,
                                base_reg + size_loc->sgpr_idx * 4, launch_size_va, true);
    }
 
    const struct radv_userdata_info *base_loc =
-      radv_get_user_sgpr(compute_shader, AC_UD_CS_RAY_DYNAMIC_CALLABLE_STACK_BASE);
+      radv_get_user_sgpr(rt_prolog, AC_UD_CS_RAY_DYNAMIC_CALLABLE_STACK_BASE);
    if (base_loc->sgpr_idx != -1) {
-      const struct radv_shader_info *cs_info = &compute_shader->info;
+      const struct radv_shader_info *cs_info = &rt_prolog->info;
       radeon_set_sh_reg(cmd_buffer->cs, R_00B900_COMPUTE_USER_DATA_0 + base_loc->sgpr_idx * 4,
                         pipeline->base.scratch_bytes_per_wave / cs_info->wave_size);
    }
 
    const struct radv_userdata_info *shader_loc =
-      radv_get_user_sgpr(compute_shader, AC_UD_CS_TRAVERSAL_SHADER_ADDR);
+      radv_get_user_sgpr(rt_prolog, AC_UD_CS_TRAVERSAL_SHADER_ADDR);
    if (shader_loc->sgpr_idx != -1) {
       uint64_t raygen_va = cmd_buffer->state.shaders[MESA_SHADER_RAYGEN]->va;
       radv_emit_shader_pointer(cmd_buffer->device, cmd_buffer->cs,
@@ -9907,7 +9919,7 @@ radv_trace_rays(struct radv_cmd_buffer *cmd_buffer, const VkTraceRaysIndirectCom
 
    assert(cmd_buffer->cs->cdw <= cdw_max);
 
-   radv_dispatch(cmd_buffer, &info, pipeline, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR);
+   radv_dispatch(cmd_buffer, &info, pipeline, rt_prolog, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR);
 }
 
 VKAPI_ATTR void VKAPI_CALL
index ad03861..5a822a4 100644 (file)
@@ -1606,6 +1606,7 @@ struct radv_cmd_state {
    struct radv_shader *shaders[MESA_VULKAN_SHADER_STAGES];
    struct radv_shader *gs_copy_shader;
    struct radv_shader *last_vgt_shader;
+   struct radv_shader *rt_prolog;
 
    uint32_t prefetch_L2_mask;