radv: add radv_bind_shader() helper
authorSamuel Pitoiset <samuel.pitoiset@gmail.com>
Fri, 24 Mar 2023 15:12:16 +0000 (16:12 +0100)
committerMarge Bot <emma+marge@anholt.net>
Wed, 29 Mar 2023 10:18:24 +0000 (10:18 +0000)
Signed-off-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22128>

src/amd/vulkan/radv_cmd_buffer.c

index 202e0b7..c1c12c5 100644 (file)
@@ -6367,6 +6367,48 @@ radv_bind_task_shader(struct radv_cmd_buffer *cmd_buffer, const struct radv_shad
    cmd_buffer->task_rings_needed = true;
 }
 
+/* This function binds/unbinds a shader to the cmdbuffer state. */
+static void
+radv_bind_shader(struct radv_cmd_buffer *cmd_buffer, struct radv_shader *shader,
+                 gl_shader_stage stage)
+{
+   if (!shader)
+      return;
+
+   switch (stage) {
+   case MESA_SHADER_VERTEX:
+      radv_bind_vertex_shader(cmd_buffer, shader);
+      break;
+   case MESA_SHADER_TESS_CTRL:
+      radv_bind_tess_ctrl_shader(cmd_buffer, shader);
+      break;
+   case MESA_SHADER_TESS_EVAL:
+      radv_bind_tess_eval_shader(cmd_buffer, shader);
+      break;
+   case MESA_SHADER_GEOMETRY:
+      radv_bind_geometry_shader(cmd_buffer, shader);
+      break;
+   case MESA_SHADER_FRAGMENT:
+      radv_bind_fragment_shader(cmd_buffer, shader);
+      break;
+   case MESA_SHADER_MESH:
+      radv_bind_mesh_shader(cmd_buffer, shader);
+      break;
+   case MESA_SHADER_TASK:
+      radv_bind_task_shader(cmd_buffer, shader);
+      break;
+   case MESA_SHADER_COMPUTE:
+   case MESA_SHADER_RAYGEN:
+      /* no-op */
+      break;
+   default:
+      unreachable("invalid shader stage");
+   }
+}
+
+#define RADV_GRAPHICS_STAGES \
+   (VK_SHADER_STAGE_ALL_GRAPHICS | VK_SHADER_STAGE_MESH_BIT_EXT | VK_SHADER_STAGE_TASK_BIT_EXT)
+
 VKAPI_ATTR void VKAPI_CALL
 radv_CmdBindPipeline(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipelineBindPoint,
                      VkPipeline _pipeline)
@@ -6382,6 +6424,9 @@ radv_CmdBindPipeline(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipeline
          return;
       radv_mark_descriptor_sets_dirty(cmd_buffer, pipelineBindPoint);
 
+      radv_bind_shader(cmd_buffer, compute_pipeline->base.shaders[MESA_SHADER_COMPUTE],
+                       MESA_SHADER_COMPUTE);
+
       cmd_buffer->state.compute_pipeline = compute_pipeline;
       cmd_buffer->push_constant_stages |= VK_SHADER_STAGE_COMPUTE_BIT;
       break;
@@ -6393,6 +6438,11 @@ radv_CmdBindPipeline(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipeline
          return;
       radv_mark_descriptor_sets_dirty(cmd_buffer, pipelineBindPoint);
 
+      radv_bind_shader(cmd_buffer, rt_pipeline->base.base.shaders[MESA_SHADER_COMPUTE],
+                       MESA_SHADER_COMPUTE);
+      radv_bind_shader(cmd_buffer, rt_pipeline->base.base.shaders[MESA_SHADER_RAYGEN],
+                       MESA_SHADER_RAYGEN);
+
       cmd_buffer->state.rt_pipeline = rt_pipeline;
       cmd_buffer->push_constant_stages |= RADV_RT_STAGE_BITS;
 
@@ -6408,6 +6458,10 @@ radv_CmdBindPipeline(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipeline
          return;
       radv_mark_descriptor_sets_dirty(cmd_buffer, pipelineBindPoint);
 
+      radv_foreach_stage(stage, RADV_GRAPHICS_STAGES) {
+         radv_bind_shader(cmd_buffer, graphics_pipeline->base.shaders[stage], stage);
+      }
+
       bool vtx_emit_count_changed =
          !cmd_buffer->state.graphics_pipeline ||
          cmd_buffer->state.graphics_pipeline->vtx_emit_num != graphics_pipeline->vtx_emit_num ||
@@ -6480,38 +6534,6 @@ radv_CmdBindPipeline(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipeline
          MAX2(cmd_buffer->scratch_size_per_wave_needed, pipeline->scratch_bytes_per_wave);
       cmd_buffer->scratch_waves_wanted = MAX2(cmd_buffer->scratch_waves_wanted, pipeline->max_waves);
 
-      for (uint32_t s = 0; s < MESA_SHADER_COMPUTE; s++) {
-         const struct radv_shader *shader = graphics_pipeline->base.shaders[s];
-
-         if (!shader)
-            continue;
-
-         switch (s) {
-         case MESA_SHADER_VERTEX:
-            radv_bind_vertex_shader(cmd_buffer, shader);
-            break;
-         case MESA_SHADER_TESS_CTRL:
-            radv_bind_tess_ctrl_shader(cmd_buffer, shader);
-            break;
-         case MESA_SHADER_TESS_EVAL:
-            radv_bind_tess_eval_shader(cmd_buffer, shader);
-            break;
-         case MESA_SHADER_GEOMETRY:
-            radv_bind_geometry_shader(cmd_buffer, shader);
-            break;
-         case MESA_SHADER_FRAGMENT:
-            radv_bind_fragment_shader(cmd_buffer, shader);
-            break;
-         default:
-            unreachable("invalid graphics shader stage");
-         }
-      }
-
-      if (graphics_pipeline->base.shaders[MESA_SHADER_MESH])
-         radv_bind_mesh_shader(cmd_buffer, graphics_pipeline->base.shaders[MESA_SHADER_MESH]);
-      if (graphics_pipeline->base.shaders[MESA_SHADER_TASK])
-         radv_bind_task_shader(cmd_buffer, graphics_pipeline->base.shaders[MESA_SHADER_TASK]);
-
       radv_bind_multisample_state(cmd_buffer, &graphics_pipeline->ms);
       break;
    }