ac/nir: add ac_nir_lower_ngg_options
authorRhys Perry <pendingchaos02@gmail.com>
Fri, 14 Oct 2022 16:15:39 +0000 (17:15 +0100)
committerMarge Bot <emma+marge@anholt.net>
Thu, 27 Oct 2022 13:31:40 +0000 (13:31 +0000)
These signatures were getting ridiculous.

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19340>

src/amd/common/ac_nir.h
src/amd/common/ac_nir_lower_ngg.c
src/amd/vulkan/radv_shader.c

index 0f9808a..b7e5e45 100644 (file)
@@ -119,32 +119,36 @@ bool
 ac_nir_lower_indirect_derefs(nir_shader *shader,
                              enum amd_gfx_level gfx_level);
 
+typedef struct {
+   enum radeon_family family;
+   enum amd_gfx_level gfx_level;
+
+   unsigned max_workgroup_size;
+   unsigned wave_size;
+   bool can_cull;
+   bool disable_streamout;
+
+   /* VS */
+   unsigned num_vertices_per_primitive;
+   bool early_prim_export;
+   bool passthrough;
+   bool use_edgeflags;
+   bool has_prim_query;
+   int primitive_id_location;
+   uint32_t instance_rate_inputs;
+   uint32_t clipdist_enable_mask;
+   uint32_t user_clip_plane_enable_mask;
+
+   /* GS */
+   unsigned gs_out_vtx_bytes;
+   bool has_xfb_query;
+} ac_nir_lower_ngg_options;
+
 void
-ac_nir_lower_ngg_nogs(nir_shader *shader,
-                      enum radeon_family family,
-                      unsigned num_vertices_per_primitive,
-                      unsigned max_workgroup_size,
-                      unsigned wave_size,
-                      bool can_cull,
-                      bool early_prim_export,
-                      bool passthrough,
-                      bool use_edgeflags,
-                      bool has_prim_query,
-                      bool disable_streamout,
-                      int primitive_id_location,
-                      uint32_t instance_rate_inputs,
-                      uint32_t clipdist_enable_mask,
-                      uint32_t user_clip_plane_enable_mask);
+ac_nir_lower_ngg_nogs(nir_shader *shader, const ac_nir_lower_ngg_options *options);
 
 void
-ac_nir_lower_ngg_gs(nir_shader *shader,
-                    enum amd_gfx_level gfx_level,
-                    unsigned wave_size,
-                    unsigned max_workgroup_size,
-                    unsigned gs_out_vtx_bytes,
-                    bool has_xfb_query,
-                    bool can_cull,
-                    bool disable_streamout);
+ac_nir_lower_ngg_gs(nir_shader *shader, const ac_nir_lower_ngg_options *options);
 
 void
 ac_nir_lower_ngg_ms(nir_shader *shader,
index 523a951..94f6a7a 100644 (file)
@@ -47,6 +47,8 @@ typedef struct
 
 typedef struct
 {
+   const ac_nir_lower_ngg_options *options;
+
    nir_variable *position_value_var;
    nir_variable *prim_exp_arg_var;
    nir_variable *es_accepted_var;
@@ -57,24 +59,17 @@ typedef struct
 
    struct u_vector saved_uniforms;
 
-   bool passthrough;
    bool early_prim_export;
-   bool use_edgeflags;
-   bool has_prim_query;
    bool streamout_enabled;
    bool has_user_edgeflags;
-   unsigned wave_size;
    unsigned max_num_waves;
-   unsigned num_vertices_per_primitives;
    unsigned position_store_base;
-   int primitive_id_location;
 
    /* LDS params */
    unsigned pervertex_lds_bytes;
 
    uint64_t inputs_needed_by_pos;
    uint64_t inputs_needed_by_others;
-   uint32_t instance_rate_inputs;
 
    nir_instr *compact_arg_stores[4];
    nir_intrinsic_instr *overwrite_args;
@@ -82,8 +77,6 @@ typedef struct
    /* clip distance */
    nir_variable *clip_vertex_var;
    nir_variable *clipdist_neg_mask_var;
-   unsigned clipdist_enable_mask;
-   unsigned user_clip_plane_enable_mask;
    bool has_clipdist;
 } lower_ngg_nogs_state;
 
@@ -102,23 +95,21 @@ typedef struct
 
 typedef struct
 {
+   const ac_nir_lower_ngg_options *options;
+
    nir_function_impl *impl;
-   enum amd_gfx_level gfx_level;
    nir_variable *output_vars[VARYING_SLOT_MAX][4];
    nir_variable *current_clear_primflag_idx_var;
    int const_out_vtxcnt[4];
    int const_out_prmcnt[4];
-   unsigned wave_size;
    unsigned max_num_waves;
    unsigned num_vertices_per_primitive;
    nir_ssa_def *lds_addr_gs_out_vtx;
    nir_ssa_def *lds_addr_gs_scratch;
    unsigned lds_bytes_per_gs_out_vertex;
    unsigned lds_offs_primflags;
-   bool has_xfb_query;
    bool found_out_vtxcnt[4];
    bool output_compile_time_known;
-   bool can_cull;
    bool streamout_enabled;
    gs_output_info output_info[VARYING_SLOT_MAX];
 } lower_ngg_gs_state;
@@ -421,10 +412,10 @@ emit_pack_ngg_prim_exp_arg(nir_builder *b, unsigned num_vertices_per_primitives,
 static void
 ngg_nogs_init_vertex_indices_vars(nir_builder *b, nir_function_impl *impl, lower_ngg_nogs_state *st)
 {
-   for (unsigned v = 0; v < st->num_vertices_per_primitives; ++v) {
+   for (unsigned v = 0; v < st->options->num_vertices_per_primitive; ++v) {
       st->gs_vtx_indices_vars[v] = nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx_addr");
 
-      nir_ssa_def *vtx = st->passthrough ?
+      nir_ssa_def *vtx = st->options->passthrough ?
          nir_ubfe(b, nir_load_packed_passthrough_primitive_amd(b),
                   nir_imm_int(b, 10 * v), nir_imm_int(b, 9)) :
          nir_ubfe(b, nir_load_gs_vertex_offset_amd(b, .base = v / 2u),
@@ -437,15 +428,16 @@ ngg_nogs_init_vertex_indices_vars(nir_builder *b, nir_function_impl *impl, lower
 static nir_ssa_def *
 emit_ngg_nogs_prim_exp_arg(nir_builder *b, lower_ngg_nogs_state *st)
 {
-   if (st->passthrough) {
+   if (st->options->passthrough) {
       return nir_load_packed_passthrough_primitive_amd(b);
    } else {
       nir_ssa_def *vtx_idx[3] = {0};
 
-      for (unsigned v = 0; v < st->num_vertices_per_primitives; ++v)
+      for (unsigned v = 0; v < st->options->num_vertices_per_primitive; ++v)
          vtx_idx[v] = nir_load_var(b, st->gs_vtx_indices_vars[v]);
 
-      return emit_pack_ngg_prim_exp_arg(b, st->num_vertices_per_primitives, vtx_idx, NULL, st->use_edgeflags);
+      return emit_pack_ngg_prim_exp_arg(b, st->options->num_vertices_per_primitive, vtx_idx, NULL,
+                                        st->options->use_edgeflags);
    }
 }
 
@@ -480,7 +472,7 @@ emit_ngg_nogs_prim_export(nir_builder *b, lower_ngg_nogs_state *st, nir_ssa_def
             edge_flag_offset = packed_location * 16;
          }
 
-         for (int i = 0; i < st->num_vertices_per_primitives; i++) {
+         for (int i = 0; i < st->options->num_vertices_per_primitive; i++) {
             nir_ssa_def *vtx_idx = nir_load_var(b, st->gs_vtx_indices_vars[i]);
             nir_ssa_def *addr = pervertex_lds_addr(b, vtx_idx, st->pervertex_lds_bytes);
             nir_ssa_def *edge = nir_load_shared(b, 1, 32, addr, .base = edge_flag_offset);
@@ -489,11 +481,12 @@ emit_ngg_nogs_prim_export(nir_builder *b, lower_ngg_nogs_state *st, nir_ssa_def
          arg = nir_iand(b, arg, mask);
       }
 
-      if (st->has_prim_query) {
+      if (st->options->has_prim_query) {
          nir_if *if_shader_query = nir_push_if(b, nir_load_prim_gen_query_enabled_amd(b));
          {
             /* Number of active GS threads. Each has 1 output primitive. */
-            nir_ssa_def *num_gs_threads = nir_bit_count(b, nir_ballot(b, 1, st->wave_size, nir_imm_bool(b, true)));
+            nir_ssa_def *num_gs_threads =
+               nir_bit_count(b, nir_ballot(b, 1, st->options->wave_size, nir_imm_bool(b, true)));
             /* Activate only 1 lane and add the number of primitives to query result. */
             nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
             {
@@ -523,12 +516,12 @@ emit_ngg_nogs_prim_id_store_shared(nir_builder *b, lower_ngg_nogs_state *st)
        * It will be exported as a per-vertex attribute.
        */
       nir_ssa_def *gs_vtx_indices[3];
-      for (unsigned i = 0; i < st->num_vertices_per_primitives; i++)
+      for (unsigned i = 0; i < st->options->num_vertices_per_primitive; i++)
          gs_vtx_indices[i] = nir_load_var(b, st->gs_vtx_indices_vars[i]);
 
       nir_ssa_def *provoking_vertex = nir_load_provoking_vtx_in_prim_amd(b);
       nir_ssa_def *provoking_vtx_idx = nir_select_from_ssa_def_array(
-         b, gs_vtx_indices, st->num_vertices_per_primitives, provoking_vertex);
+         b, gs_vtx_indices, st->options->num_vertices_per_primitive, provoking_vertex);
 
       nir_ssa_def *prim_id = nir_load_primitive_id(b);
       nir_ssa_def *addr = pervertex_lds_addr(b, provoking_vtx_idx, st->pervertex_lds_bytes);
@@ -563,7 +556,7 @@ emit_store_ngg_nogs_es_primitive_id(nir_builder *b, lower_ngg_nogs_state *st)
    };
 
    nir_store_output(b, prim_id, nir_imm_zero(b, 1, 32),
-                    .base = st->primitive_id_location,
+                    .base = st->options->primitive_id_location,
                     .src_type = nir_type_uint32, .io_semantics = io_sem);
 }
 
@@ -640,7 +633,7 @@ remove_culling_shader_output(nir_builder *b, nir_instr *instr, void *state)
       base += component;
 
       /* valid clipdist component mask */
-      unsigned mask = (s->clipdist_enable_mask >> base) & writemask;
+      unsigned mask = (s->options->clipdist_enable_mask >> base) & writemask;
       u_foreach_bit(i, mask) {
          add_clipdist_bit(b, nir_channel(b, store_val, i), base + i,
                           s->clipdist_neg_mask_var);
@@ -838,7 +831,7 @@ cleanup_culling_shader_after_dce(nir_shader *shader,
             uses_vs_instance_id = true;
             break;
          case nir_intrinsic_load_input:
-            if (state->instance_rate_inputs &
+            if (state->options->instance_rate_inputs &
                 (1u << (nir_intrinsic_base(intrin) - VERT_ATTRIB_GENERIC0)))
                uses_vs_instance_id = true;
             else
@@ -978,7 +971,7 @@ compact_vertices_after_culling(nir_builder *b,
       nir_ssa_def *exporter_vtx_indices[3] = {0};
 
       /* Load the index of the ES threads that will export the current GS thread's vertices */
-      for (unsigned v = 0; v < nogs_state->num_vertices_per_primitives; ++v) {
+      for (unsigned v = 0; v < nogs_state->options->num_vertices_per_primitive; ++v) {
          nir_ssa_def *vtx_addr = nir_load_var(b, gs_vtxaddr_vars[v]);
          nir_ssa_def *exporter_vtx_idx = nir_load_shared(b, 1, 8, vtx_addr, .base = lds_es_exporter_tid);
          exporter_vtx_indices[v] = nir_u2u32(b, exporter_vtx_idx);
@@ -986,8 +979,8 @@ compact_vertices_after_culling(nir_builder *b,
       }
 
       nir_ssa_def *prim_exp_arg =
-         emit_pack_ngg_prim_exp_arg(b, nogs_state->num_vertices_per_primitives,
-                                    exporter_vtx_indices, NULL, nogs_state->use_edgeflags);
+         emit_pack_ngg_prim_exp_arg(b, nogs_state->options->num_vertices_per_primitive,
+                                    exporter_vtx_indices, NULL, nogs_state->options->use_edgeflags);
       nir_store_var(b, prim_exp_arg_var, prim_exp_arg, 0x1u);
    }
    nir_pop_if(b, if_gs_accepted);
@@ -1260,7 +1253,7 @@ cull_primitive_accepted(nir_builder *b, void *state)
    nir_store_var(b, s->gs_accepted_var, nir_imm_true(b), 0x1u);
 
    /* Store the accepted state to LDS for ES threads */
-   for (unsigned vtx = 0; vtx < s->num_vertices_per_primitives; ++vtx)
+   for (unsigned vtx = 0; vtx < s->options->num_vertices_per_primitive; ++vtx)
       nir_store_shared(b, nir_imm_intN_t(b, 1, 8), s->vtx_addr[vtx], .base = lds_es_vertex_accepted);
 }
 
@@ -1269,7 +1262,7 @@ clipdist_culling_es_part(nir_builder *b, lower_ngg_nogs_state *nogs_state,
                          nir_ssa_def *es_vertex_lds_addr)
 {
    /* no gl_ClipDistance used but we have user defined clip plane */
-   if (nogs_state->user_clip_plane_enable_mask && !nogs_state->has_clipdist) {
+   if (nogs_state->options->user_clip_plane_enable_mask && !nogs_state->has_clipdist) {
       /* use gl_ClipVertex if defined */
       nir_variable *clip_vertex_var =
          b->shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_CLIP_VERTEX) ?
@@ -1278,7 +1271,7 @@ clipdist_culling_es_part(nir_builder *b, lower_ngg_nogs_state *nogs_state,
 
       /* clip against user defined clip planes */
       for (unsigned i = 0; i < 8; i++) {
-         if (!(nogs_state->user_clip_plane_enable_mask & BITFIELD_BIT(i)))
+         if (!(nogs_state->options->user_clip_plane_enable_mask & BITFIELD_BIT(i)))
             continue;
 
          nir_ssa_def *plane = nir_load_user_clip_plane(b, .ucp_id = i);
@@ -1349,7 +1342,8 @@ add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_c
       nir_local_variable_create(impl, glsl_uint_type(), "repacked_arg_3"),
    };
 
-   if (nogs_state->clipdist_enable_mask || nogs_state->user_clip_plane_enable_mask) {
+   if (nogs_state->options->clipdist_enable_mask ||
+       nogs_state->options->user_clip_plane_enable_mask) {
       nogs_state->clip_vertex_var =
          nir_local_variable_create(impl, glsl_vec4_type(), "clip_vertex");
       nogs_state->clipdist_neg_mask_var =
@@ -1450,20 +1444,21 @@ add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_c
       {
          /* Load vertex indices from input VGPRs */
          nir_ssa_def *vtx_idx[3] = {0};
-         for (unsigned vertex = 0; vertex < nogs_state->num_vertices_per_primitives; ++vertex)
+         for (unsigned vertex = 0; vertex < nogs_state->options->num_vertices_per_primitive;
+              ++vertex)
             vtx_idx[vertex] = nir_load_var(b, nogs_state->gs_vtx_indices_vars[vertex]);
 
          nir_ssa_def *pos[3][4] = {0};
 
          /* Load W positions of vertices first because the culling code will use these first */
-         for (unsigned vtx = 0; vtx < nogs_state->num_vertices_per_primitives; ++vtx) {
+         for (unsigned vtx = 0; vtx < nogs_state->options->num_vertices_per_primitive; ++vtx) {
             nogs_state->vtx_addr[vtx] = pervertex_lds_addr(b, vtx_idx[vtx], pervertex_lds_bytes);
             pos[vtx][3] = nir_load_shared(b, 1, 32, nogs_state->vtx_addr[vtx], .base = lds_es_pos_w);
             nir_store_var(b, gs_vtxaddr_vars[vtx], nogs_state->vtx_addr[vtx], 0x1u);
          }
 
          /* Load the X/W, Y/W positions of vertices */
-         for (unsigned vtx = 0; vtx < nogs_state->num_vertices_per_primitives; ++vtx) {
+         for (unsigned vtx = 0; vtx < nogs_state->options->num_vertices_per_primitive; ++vtx) {
             nir_ssa_def *xy = nir_load_shared(b, 2, 32, nogs_state->vtx_addr[vtx], .base = lds_es_pos_x);
             pos[vtx][0] = nir_channel(b, xy, 0);
             pos[vtx][1] = nir_channel(b, xy, 1);
@@ -1472,7 +1467,7 @@ add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_c
          nir_ssa_def *accepted_by_clipdist;
          if (nogs_state->has_clipdist) {
             nir_ssa_def *clipdist_neg_mask = nir_imm_intN_t(b, 0xff, 8);
-            for (unsigned vtx = 0; vtx < nogs_state->num_vertices_per_primitives; ++vtx) {
+            for (unsigned vtx = 0; vtx < nogs_state->options->num_vertices_per_primitive; ++vtx) {
                nir_ssa_def *mask =
                   nir_load_shared(b, 1, 8, nogs_state->vtx_addr[vtx],
                                   .base = lds_es_clipdist_neg_mask);
@@ -1486,7 +1481,7 @@ add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_c
 
          /* See if the current primitive is accepted */
          ac_nir_cull_primitive(b, accepted_by_clipdist, pos,
-                               nogs_state->num_vertices_per_primitives,
+                               nogs_state->options->num_vertices_per_primitive,
                                cull_primitive_accepted, nogs_state);
       }
       nir_pop_if(b, if_gs_thread);
@@ -1511,7 +1506,7 @@ add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_c
       /* Repack the vertices that survived the culling. */
       wg_repack_result rep = repack_invocations_in_workgroup(b, es_accepted, lds_scratch_base,
                                                              nogs_state->max_num_waves,
-                                                             nogs_state->wave_size);
+                                                             nogs_state->options->wave_size);
       nir_ssa_def *num_live_vertices_in_workgroup = rep.num_repacked_invocations;
       nir_ssa_def *es_exporter_tid = rep.repacked_invocation_index;
 
@@ -1828,7 +1823,7 @@ ngg_nogs_build_streamout(nir_builder *b, lower_ngg_nogs_state *s)
    nir_ssa_def *so_buffer[4] = {0};
    nir_ssa_def *prim_stride[4] = {0};
    nir_ssa_def *tid_in_tg = nir_load_local_invocation_index(b);
-   ngg_build_streamout_buffer_info(b, info, s->has_prim_query,
+   ngg_build_streamout_buffer_info(b, info, s->options->has_prim_query,
                                    lds_scratch_base, tid_in_tg,
                                    gen_prim_per_stream, prim_stride,
                                    so_buffer, buffer_offsets,
@@ -1841,7 +1836,7 @@ ngg_nogs_build_streamout(nir_builder *b, lower_ngg_nogs_state *s)
       nir_ssa_def *num_vert_per_prim = nir_load_num_vertices_per_primitive_amd(b);
       nir_ssa_def *vtx_buffer_idx = nir_imul(b, tid_in_tg, num_vert_per_prim);
 
-      for (unsigned i = 0; i < s->num_vertices_per_primitives; i++) {
+      for (unsigned i = 0; i < s->options->num_vertices_per_primitive; i++) {
          nir_if *if_valid_vertex =
             nir_push_if(b, nir_ilt(b, nir_imm_int(b, i), num_vert_per_prim));
          {
@@ -1908,34 +1903,23 @@ ngg_nogs_get_pervertex_lds_size(gl_shader_stage stage,
 }
 
 void
-ac_nir_lower_ngg_nogs(nir_shader *shader,
-                      enum radeon_family family,
-                      unsigned num_vertices_per_primitives,
-                      unsigned max_workgroup_size,
-                      unsigned wave_size,
-                      bool can_cull,
-                      bool early_prim_export,
-                      bool passthrough,
-                      bool use_edgeflags,
-                      bool has_prim_query,
-                      bool disable_streamout,
-                      int primitive_id_location,
-                      uint32_t instance_rate_inputs,
-                      uint32_t clipdist_enable_mask,
-                      uint32_t user_clip_plane_enable_mask)
+ac_nir_lower_ngg_nogs(nir_shader *shader, const ac_nir_lower_ngg_options *options)
 {
    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
    assert(impl);
-   assert(max_workgroup_size && wave_size);
-   assert(!(can_cull && passthrough));
+   assert(options->max_workgroup_size && options->wave_size);
+   assert(!(options->can_cull && options->passthrough));
 
    nir_variable *position_value_var = nir_local_variable_create(impl, glsl_vec4_type(), "position_value");
    nir_variable *prim_exp_arg_var = nir_local_variable_create(impl, glsl_uint_type(), "prim_exp_arg");
-   nir_variable *es_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "es_accepted") : NULL;
-   nir_variable *gs_accepted_var = can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "gs_accepted") : NULL;
-
-   bool streamout_enabled = shader->xfb_info && !disable_streamout;
-   bool has_user_edgeflags = use_edgeflags && (shader->info.outputs_written & VARYING_BIT_EDGE);
+   nir_variable *es_accepted_var =
+      options->can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "es_accepted") : NULL;
+   nir_variable *gs_accepted_var =
+      options->can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "gs_accepted") : NULL;
+
+   bool streamout_enabled = shader->xfb_info && !options->disable_streamout;
+   bool has_user_edgeflags =
+      options->use_edgeflags && (shader->info.outputs_written & VARYING_BIT_EDGE);
    /* streamout need to be done before either prim or vertex export. Because when no
     * param export, rasterization can start right after prim and vertex export,
     * which left streamout buffer writes un-finished.
@@ -1944,33 +1928,25 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
     * This is because edge flags are written by ES threads but they
     * are exported by GS threads as part of th primitive export.
     */
-   if (streamout_enabled || has_user_edgeflags)
-      early_prim_export = false;
+   bool early_prim_export =
+      options->early_prim_export && !(streamout_enabled || has_user_edgeflags);
 
    lower_ngg_nogs_state state = {
-      .passthrough = passthrough,
-      .primitive_id_location = primitive_id_location,
+      .options = options,
       .early_prim_export = early_prim_export,
-      .use_edgeflags = use_edgeflags,
-      .has_prim_query = has_prim_query,
       .streamout_enabled = streamout_enabled,
-      .num_vertices_per_primitives = num_vertices_per_primitives,
       .position_value_var = position_value_var,
       .prim_exp_arg_var = prim_exp_arg_var,
       .es_accepted_var = es_accepted_var,
       .gs_accepted_var = gs_accepted_var,
-      .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
-      .wave_size = wave_size,
-      .instance_rate_inputs = instance_rate_inputs,
-      .clipdist_enable_mask = clipdist_enable_mask,
-      .user_clip_plane_enable_mask = user_clip_plane_enable_mask,
+      .max_num_waves = DIV_ROUND_UP(options->max_workgroup_size, options->wave_size),
       .has_user_edgeflags = has_user_edgeflags,
    };
 
    const bool need_prim_id_store_shared =
-      primitive_id_location >= 0 && shader->info.stage == MESA_SHADER_VERTEX;
+      options->primitive_id_location >= 0 && shader->info.stage == MESA_SHADER_VERTEX;
 
-   if (primitive_id_location >= 0) {
+   if (options->primitive_id_location >= 0) {
       nir_variable *prim_id_var = nir_variable_create(shader, nir_var_shader_out, glsl_uint_type(), "ngg_prim_id");
       prim_id_var->data.location = VARYING_SLOT_PRIMITIVE_ID;
       prim_id_var->data.driver_location = VARYING_SLOT_PRIMITIVE_ID;
@@ -1982,7 +1958,7 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
    nir_builder *b = &builder; /* This is to avoid the & */
    nir_builder_init(b, impl);
 
-   if (can_cull) {
+   if (options->can_cull) {
       /* We need divergence info for culling shaders. */
       nir_divergence_analysis(shader);
       analyze_shader_before_culling(shader, &state);
@@ -1995,9 +1971,9 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
 
    ngg_nogs_init_vertex_indices_vars(b, impl, &state);
 
-   if (!can_cull) {
+   if (!options->can_cull) {
       /* Newer chips can use PRIMGEN_PASSTHRU_NO_MSG to skip gs_alloc_req for NGG passthrough. */
-      if (!(passthrough && family >= CHIP_NAVI23)) {
+      if (!(options->passthrough && options->family >= CHIP_NAVI23)) {
          /* Allocate export space on wave 0 - confirm to the HW that we want to use all possible space */
          nir_if *if_wave_0 = nir_push_if(b, nir_ieq(b, nir_load_subgroup_id(b), nir_imm_int(b, 0)));
          {
@@ -2034,7 +2010,7 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
       ngg_nogs_get_pervertex_lds_size(shader->info.stage,
                                       shader->num_outputs,
                                       state.streamout_enabled,
-                                      primitive_id_location >= 0,
+                                      options->primitive_id_location >= 0,
                                       state.has_user_edgeflags);
 
    if (need_prim_id_store_shared) {
@@ -2046,7 +2022,8 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
    }
 
    nir_intrinsic_instr *export_vertex_instr;
-   nir_ssa_def *es_thread = can_cull ? nir_load_var(b, es_accepted_var) : nir_has_input_vertex_amd(b);
+   nir_ssa_def *es_thread =
+      options->can_cull ? nir_load_var(b, es_accepted_var) : nir_has_input_vertex_amd(b);
 
    nir_if *if_es_thread = nir_push_if(b, es_thread);
    {
@@ -2054,7 +2031,7 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
       nir_cf_reinsert(&extracted, b->cursor);
       b->cursor = nir_after_cf_list(&if_es_thread->then_list);
 
-      if (state.primitive_id_location >= 0)
+      if (options->primitive_id_location >= 0)
          emit_store_ngg_nogs_es_primitive_id(b, &state);
 
       /* Export all vertex attributes (including the primitive ID) */
@@ -2064,7 +2041,7 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
 
    if (state.streamout_enabled) {
       /* TODO: support culling after streamout. */
-      assert(!can_cull);
+      assert(!options->can_cull);
 
       ngg_nogs_build_streamout(b, &state);
    }
@@ -2080,7 +2057,7 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
       emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, prim_exp_arg_var));
    }
 
-   if (can_cull) {
+   if (options->can_cull) {
       /* Replace uniforms. */
       apply_reusable_variables(b, &state);
 
@@ -2109,7 +2086,7 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
    nir_lower_alu_to_scalar(shader, NULL, NULL);
    nir_lower_phis_to_scalar(shader, true);
 
-   if (can_cull) {
+   if (options->can_cull) {
       /* It's beneficial to redo these opts after splitting the shader. */
       nir_opt_sink(shader, nir_move_load_input | nir_move_const_undef | nir_move_copies);
       nir_opt_move(shader, nir_move_load_input | nir_move_copies | nir_move_const_undef);
@@ -2122,7 +2099,7 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
       NIR_PASS(progress, shader, nir_opt_dce);
       NIR_PASS(progress, shader, nir_opt_dead_cf);
 
-      if (can_cull)
+      if (options->can_cull)
          progress |= cleanup_culling_shader_after_dce(shader, b->impl, &state);
    } while (progress);
 }
@@ -2211,8 +2188,8 @@ ngg_gs_clear_primflags(nir_builder *b, nir_ssa_def *num_vertices, unsigned strea
 static void
 ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
 {
-   bool has_xfb_query = s->has_xfb_query;
-   bool has_pipeline_stats_query = s->gfx_level < GFX11;
+   bool has_xfb_query = s->options->has_xfb_query;
+   bool has_pipeline_stats_query = s->options->gfx_level < GFX11;
 
    nir_ssa_def *pipeline_query_enabled = NULL;
    nir_ssa_def *prim_gen_query_enabled = NULL;
@@ -2244,7 +2221,8 @@ ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_st
       unsigned gs_vtx_cnt = nir_src_as_uint(intrin->src[0]);
       unsigned gs_prm_cnt = nir_src_as_uint(intrin->src[1]);
       unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * (s->num_vertices_per_primitive - 1u);
-      nir_ssa_def *num_threads = nir_bit_count(b, nir_ballot(b, 1, s->wave_size, nir_imm_bool(b, true)));
+      nir_ssa_def *num_threads =
+         nir_bit_count(b, nir_ballot(b, 1, s->options->wave_size, nir_imm_bool(b, true)));
       num_prims_in_wave = nir_imul_imm(b, num_threads, total_prm_cnt);
    } else {
       nir_ssa_def *gs_vtx_cnt = intrin->src[0].ssa;
@@ -2427,9 +2405,10 @@ lower_ngg_gs_emit_vertex_with_counter(nir_builder *b, nir_intrinsic_instr *intri
     * - bit 2: whether vertex is live (if culling is enabled: set after culling, otherwise always 1)
     */
 
-   nir_ssa_def *vertex_live_flag = !stream && s->can_cull ?
-      nir_ishl_imm(b, nir_b2i32(b, nir_inot(b, nir_load_cull_any_enabled_amd(b))), 2) :
-      nir_imm_int(b, 0b100);
+   nir_ssa_def *vertex_live_flag =
+      !stream && s->options->can_cull
+         ? nir_ishl_imm(b, nir_b2i32(b, nir_inot(b, nir_load_cull_any_enabled_amd(b))), 2)
+         : nir_imm_int(b, 0b100);
 
    nir_ssa_def *completes_prim = nir_ige(b, current_vtx_per_prim, nir_imm_int(b, s->num_vertices_per_primitive - 1));
    nir_ssa_def *complete_flag = nir_b2i32(b, completes_prim);
@@ -2816,7 +2795,7 @@ ngg_gs_build_streamout(nir_builder *b, lower_ngg_gs_state *st)
        */
       wg_repack_result rep =
          repack_invocations_in_workgroup(b, prim_live[stream], scratch_base,
-                                         st->max_num_waves, st->wave_size);
+                                         st->max_num_waves, st->options->wave_size);
 
       /* nir_intrinsic_set_vertex_and_primitive_count can also get primitive count of
        * current wave, but still need LDS to sum all wave's count to get workgroup count.
@@ -2837,7 +2816,7 @@ ngg_gs_build_streamout(nir_builder *b, lower_ngg_gs_state *st)
    nir_ssa_def *buffer_offsets[4] = {0};
    nir_ssa_def *so_buffer[4] = {0};
    nir_ssa_def *prim_stride[4] = {0};
-   ngg_build_streamout_buffer_info(b, info, st->has_xfb_query,
+   ngg_build_streamout_buffer_info(b, info, st->options->has_xfb_query,
                                    st->lds_addr_gs_scratch, tid_in_tg, gen_prim,
                                    prim_stride, so_buffer, buffer_offsets, emit_prim);
 
@@ -2906,7 +2885,7 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
    }
 
    /* cull primitives */
-   if (s->can_cull) {
+   if (s->options->can_cull) {
       nir_if *if_cull_en = nir_push_if(b, nir_load_cull_any_enabled_amd(b));
 
       /* culling code will update the primflag */
@@ -2925,7 +2904,8 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
     * To ensure this, we need to repack invocations that have a live vertex.
     */
    nir_ssa_def *vertex_live = nir_ine(b, out_vtx_primflag_0, nir_imm_zero(b, 1, out_vtx_primflag_0->bit_size));
-   wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch, s->max_num_waves, s->wave_size);
+   wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch,
+                                                          s->max_num_waves, s->options->wave_size);
 
    nir_ssa_def *workgroup_num_vertices = rep.num_repacked_invocations;
    nir_ssa_def *exporter_tid_in_tg = rep.repacked_invocation_index;
@@ -2951,31 +2931,21 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
 }
 
 void
-ac_nir_lower_ngg_gs(nir_shader *shader,
-                    enum amd_gfx_level gfx_level,
-                    unsigned wave_size,
-                    unsigned max_workgroup_size,
-                    unsigned gs_out_vtx_bytes,
-                    bool has_xfb_query,
-                    bool can_cull,
-                    bool disable_streamout)
+ac_nir_lower_ngg_gs(nir_shader *shader, const ac_nir_lower_ngg_options *options)
 {
    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
    assert(impl);
 
    lower_ngg_gs_state state = {
+      .options = options,
       .impl = impl,
-      .gfx_level = gfx_level,
-      .max_num_waves = DIV_ROUND_UP(max_workgroup_size, wave_size),
-      .wave_size = wave_size,
-      .lds_offs_primflags = gs_out_vtx_bytes,
-      .lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u,
-      .can_cull = can_cull,
-      .streamout_enabled = shader->xfb_info && !disable_streamout,
-      .has_xfb_query = has_xfb_query,
+      .max_num_waves = DIV_ROUND_UP(options->max_workgroup_size, options->wave_size),
+      .lds_offs_primflags = options->gs_out_vtx_bytes,
+      .lds_bytes_per_gs_out_vertex = options->gs_out_vtx_bytes + 4u,
+      .streamout_enabled = shader->xfb_info && !options->disable_streamout,
    };
 
-   if (!can_cull) {
+   if (!options->can_cull) {
       nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt,
                                            state.const_out_prmcnt, 4u);
       state.output_compile_time_known =
index fe852fd..e31c128 100644 (file)
@@ -1426,31 +1426,39 @@ void radv_lower_ngg(struct radv_device *device, struct radv_pipeline_stage *ngg_
 
    setup_ngg_lds_layout(nir, &ngg_stage->info, max_vtx_in);
 
+   ac_nir_lower_ngg_options options = {0};
+   options.family = device->physical_device->rad_info.family;
+   options.gfx_level = device->physical_device->rad_info.gfx_level;
+   options.max_workgroup_size = info->workgroup_size;
+   options.wave_size = info->wave_size;
+   options.can_cull = nir->info.stage != MESA_SHADER_GEOMETRY && info->has_ngg_culling;
+   options.disable_streamout = true;
+
    if (nir->info.stage == MESA_SHADER_VERTEX ||
        nir->info.stage == MESA_SHADER_TESS_EVAL) {
-      bool export_prim_id = info->outinfo.export_prim_id;
-
       assert(info->is_ngg);
 
       if (info->has_ngg_culling)
          radv_optimize_nir_algebraic(nir, false);
 
-      NIR_PASS_V(nir, ac_nir_lower_ngg_nogs,
-                 device->physical_device->rad_info.family, num_vertices_per_prim,
-                 info->workgroup_size, info->wave_size, info->has_ngg_culling,
-                 info->has_ngg_early_prim_export, info->is_ngg_passthrough,
-                 false, pl_key->primitives_generated_query,
-                 true, export_prim_id ? VARYING_SLOT_PRIMITIVE_ID : -1,
-                 pl_key->vs.instance_rate_inputs, 0, 0);
+      options.num_vertices_per_primitive = num_vertices_per_prim;
+      options.early_prim_export = info->has_ngg_early_prim_export;
+      options.passthrough = info->is_ngg_passthrough;
+      options.has_prim_query = pl_key->primitives_generated_query;
+      options.primitive_id_location = info->outinfo.export_prim_id ? VARYING_SLOT_PRIMITIVE_ID : -1;
+      options.instance_rate_inputs = pl_key->vs.instance_rate_inputs;
+
+      NIR_PASS_V(nir, ac_nir_lower_ngg_nogs, &options);
 
       /* Increase ESGS ring size so the LLVM binary contains the correct LDS size. */
       ngg_stage->info.ngg_info.esgs_ring_size = nir->info.shared_size;
    } else if (nir->info.stage == MESA_SHADER_GEOMETRY) {
       assert(info->is_ngg);
-      NIR_PASS_V(nir, ac_nir_lower_ngg_gs,
-                 device->physical_device->rad_info.gfx_level,
-                 info->wave_size, info->workgroup_size,
-                 info->gs.gsvs_vertex_size, true, false, true);
+
+      options.gs_out_vtx_bytes = info->gs.gsvs_vertex_size;
+      options.has_xfb_query = true;
+
+      NIR_PASS_V(nir, ac_nir_lower_ngg_gs, &options);
    } else if (nir->info.stage == MESA_SHADER_MESH) {
       bool scratch_ring = false;
       NIR_PASS_V(nir, ac_nir_lower_ngg_ms, &scratch_ring, info->wave_size, pl_key->has_multiview_view_index);