ac/nir,radv: pass workgroup size to ac_nir_lower_ngg_ms
author Rhys Perry <pendingchaos02@gmail.com>
Thu, 19 Oct 2023 18:27:07 +0000 (19:27 +0100)
committer Marge Bot <emma+marge@anholt.net>
Tue, 24 Oct 2023 21:36:06 +0000 (21:36 +0000)
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25040>

src/amd/common/ac_nir.h
src/amd/common/ac_nir_lower_ngg.c
src/amd/vulkan/radv_shader.c

index 0d91a7e..8c7eac5 100644 (file)
@@ -196,6 +196,7 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
                     bool has_param_exports,
                     bool *out_needs_scratch_ring,
                     unsigned wave_size,
+                    unsigned workgroup_size,
                     bool multiview,
                     bool has_query,
                     bool fast_launch_2);
index e481c6d..deedb4c 100644 (file)
@@ -4868,6 +4868,7 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
                     bool has_param_exports,
                     bool *out_needs_scratch_ring,
                     unsigned wave_size,
+                    unsigned hw_workgroup_size,
                     bool multiview,
                     bool has_query,
                     bool fast_launch_2)
@@ -4907,9 +4908,6 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
                                  shader->info.workgroup_size[1] *
                                  shader->info.workgroup_size[2];
 
-   unsigned hw_workgroup_size =
-      ALIGN(MAX3(api_workgroup_size, max_primitives, max_vertices), wave_size);
-
    lower_ngg_ms_state state = {
       .layout = layout,
       .wave_size = wave_size,
index 11bfd5f..6262527 100644 (file)
@@ -913,10 +913,13 @@ radv_lower_ngg(struct radv_device *device, struct radv_shader_stage *ngg_stage,
 
       NIR_PASS_V(nir, ac_nir_lower_ngg_gs, &options);
    } else if (nir->info.stage == MESA_SHADER_MESH) {
+      /* ACO aligns the workgroup size to the wave size. */
+      unsigned hw_workgroup_size = ALIGN(info->workgroup_size, info->wave_size);
+
       bool scratch_ring = false;
       NIR_PASS_V(nir, ac_nir_lower_ngg_ms, options.gfx_level, options.clipdist_enable_mask,
                  options.vs_output_param_offset, options.has_param_exports, &scratch_ring, info->wave_size,
-                 pl_key->has_multiview_view_index, info->ms.has_query, device->mesh_fast_launch_2);
+                 hw_workgroup_size, pl_key->has_multiview_view_index, info->ms.has_query, device->mesh_fast_launch_2);
       ngg_stage->info.ms.needs_ms_scratch_ring = scratch_ring;
    } else {
       unreachable("invalid SW stage passed to radv_lower_ngg");