ac/nir: Store mesh shader API and HW workgroup size in lowering state.
authorTimur Kristóf <timur.kristof@gmail.com>
Sun, 27 Feb 2022 17:39:01 +0000 (18:39 +0100)
committerMarge Bot <emma+marge@anholt.net>
Tue, 1 Mar 2022 15:37:12 +0000 (15:37 +0000)
Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/15199>

src/amd/common/ac_nir_lower_ngg.c

index 7ba1978..1fb68bd 100644 (file)
@@ -107,6 +107,8 @@ typedef struct
    unsigned prim_vtx_indices_addr;
    unsigned numprims_lds_addr;
    unsigned wave_size;
+   unsigned api_workgroup_size;
+   unsigned hw_workgroup_size;
 
    struct {
       /* Bitmask of components used: 4 bits per slot, 1 bit per component. */
@@ -2457,12 +2459,12 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
 }
 
 static void
-handle_smaller_ms_api_workgroup(nir_function_impl *impl,
-                                nir_builder *b,
-                                unsigned api_workgroup_size,
-                                unsigned hw_workgroup_size,
+handle_smaller_ms_api_workgroup(nir_builder *b,
                                 lower_ngg_ms_state *s)
 {
+   if (s->api_workgroup_size >= s->hw_workgroup_size)
+      return;
+
    /* Handle barriers manually when the API workgroup
     * size is less than the HW workgroup size.
     *
@@ -2478,19 +2480,19 @@ handle_smaller_ms_api_workgroup(nir_function_impl *impl,
     *    all. In this case, we emit code that consumes every
     *    barrier on the extra waves.
     */
-   assert(hw_workgroup_size % s->wave_size == 0);
-   bool scan_barriers = ALIGN(api_workgroup_size, s->wave_size) < hw_workgroup_size;
-   bool can_shrink_barriers = api_workgroup_size <= s->wave_size;
+   assert(s->hw_workgroup_size % s->wave_size == 0);
+   bool scan_barriers = ALIGN(s->api_workgroup_size, s->wave_size) < s->hw_workgroup_size;
+   bool can_shrink_barriers = s->api_workgroup_size <= s->wave_size;
    bool need_additional_barriers = scan_barriers && !can_shrink_barriers;
 
    unsigned api_waves_in_flight_addr = s->numprims_lds_addr + 12;
-   unsigned num_api_waves = DIV_ROUND_UP(api_workgroup_size, s->wave_size);
+   unsigned num_api_waves = DIV_ROUND_UP(s->api_workgroup_size, s->wave_size);
 
    /* Scan the shader for workgroup barriers. */
    if (scan_barriers) {
       bool has_any_workgroup_barriers = false;
 
-      nir_foreach_block(block, impl) {
+      nir_foreach_block(block, b->impl) {
          nir_foreach_instr_safe(instr, block) {
             if (instr->type != nir_instr_type_intrinsic)
                continue;
@@ -2521,8 +2523,8 @@ handle_smaller_ms_api_workgroup(nir_function_impl *impl,
 
    /* Extract the full control flow of the shader. */
    nir_cf_list extracted;
-   nir_cf_extract(&extracted, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
-   b->cursor = nir_before_cf_list(&impl->body);
+   nir_cf_extract(&extracted, nir_before_cf_list(&b->impl->body), nir_after_cf_list(&b->impl->body));
+   b->cursor = nir_before_cf_list(&b->impl->body);
 
    /* Wrap the shader in an if to ensure that only the necessary amount of lanes run it. */
    nir_ssa_def *invocation_index = nir_load_local_invocation_index(b);
@@ -2542,7 +2544,7 @@ handle_smaller_ms_api_workgroup(nir_function_impl *impl,
                             .memory_modes = nir_var_shader_out | nir_var_mem_shared);
    }
 
-   nir_ssa_def *has_api_ms_invocation = nir_ult(b, invocation_index, nir_imm_int(b, api_workgroup_size));
+   nir_ssa_def *has_api_ms_invocation = nir_ult(b, invocation_index, nir_imm_int(b, s->api_workgroup_size));
    nir_if *if_has_api_ms_invocation = nir_push_if(b, has_api_ms_invocation);
    {
       nir_cf_reinsert(&extracted, b->cursor);
@@ -2638,19 +2640,6 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
 
    shader->info.shared_size = prim_vtx_indices_addr + prim_vtx_indices_size;
 
-   lower_ngg_ms_state state = {
-      .wave_size = wave_size,
-      .per_vertex_outputs = per_vertex_outputs,
-      .per_primitive_outputs = per_primitive_outputs,
-      .num_per_vertex_outputs = num_per_vertex_outputs,
-      .num_per_primitive_outputs = num_per_primitive_outputs,
-      .vertices_per_prim = vertices_per_prim,
-      .vertex_attr_lds_addr = vertex_attr_lds_addr,
-      .prim_attr_lds_addr = prim_attr_lds_addr,
-      .prim_vtx_indices_addr = prim_vtx_indices_addr,
-      .numprims_lds_addr = numprims_lds_addr,
-   };
-
    /* The workgroup size that is specified by the API shader may be different
     * from the size of the workgroup that actually runs on the HW, due to the
     * limitations of NGG: max 0/1 vertex and 0/1 primitive per lane is allowed.
@@ -2665,6 +2654,21 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
    unsigned hw_workgroup_size =
       ALIGN(MAX3(api_workgroup_size, max_primitives, max_vertices), wave_size);
 
+   lower_ngg_ms_state state = {
+      .wave_size = wave_size,
+      .per_vertex_outputs = per_vertex_outputs,
+      .per_primitive_outputs = per_primitive_outputs,
+      .num_per_vertex_outputs = num_per_vertex_outputs,
+      .num_per_primitive_outputs = num_per_primitive_outputs,
+      .vertices_per_prim = vertices_per_prim,
+      .vertex_attr_lds_addr = vertex_attr_lds_addr,
+      .prim_attr_lds_addr = prim_attr_lds_addr,
+      .prim_vtx_indices_addr = prim_vtx_indices_addr,
+      .numprims_lds_addr = numprims_lds_addr,
+      .api_workgroup_size = api_workgroup_size,
+      .hw_workgroup_size = hw_workgroup_size,
+   };
+
    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
    assert(impl);
 
@@ -2673,9 +2677,7 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
    nir_builder_init(b, impl);
    b->cursor = nir_before_cf_list(&impl->body);
 
-   if (api_workgroup_size < hw_workgroup_size) {
-      handle_smaller_ms_api_workgroup(impl, b, api_workgroup_size, hw_workgroup_size, &state);
-   }
+   handle_smaller_ms_api_workgroup(b, &state);
 
    lower_ms_intrinsics(shader, &state);
    emit_ms_finale(b, &state);