radv: cleanup helpers that compute NGG info and GS info on GFX9+
authorSamuel Pitoiset <samuel.pitoiset@gmail.com>
Fri, 26 Aug 2022 09:30:57 +0000 (11:30 +0200)
committerMarge Bot <emma+marge@anholt.net>
Thu, 1 Sep 2022 17:02:17 +0000 (17:02 +0000)
Before moving them to the shader info link step.

Signed-off-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/18278>

src/amd/vulkan/radv_pipeline.c

index a459882..51b1d17 100644 (file)
@@ -1894,19 +1894,17 @@ radv_pipeline_init_depth_stencil_state(struct radv_graphics_pipeline *pipeline,
 }
 
 static void
-gfx9_get_gs_info(const struct radv_pipeline *pipeline, struct radv_pipeline_stage *stages,
-                 struct gfx9_gs_info *out)
+gfx9_get_gs_info(const struct radv_device *device, struct radv_pipeline_stage *es_stage,
+                 struct radv_pipeline_stage *gs_stage)
 {
-   const struct radv_physical_device *pdevice = pipeline->device->physical_device;
-   struct radv_shader_info *gs_info = &stages[MESA_SHADER_GEOMETRY].info;
-   struct radv_shader_info *es_info;
-   bool has_tess = !!stages[MESA_SHADER_TESS_CTRL].nir;
-
-   es_info = has_tess ? &stages[MESA_SHADER_TESS_EVAL].info : &stages[MESA_SHADER_VERTEX].info;
+   const enum amd_gfx_level gfx_level = device->physical_device->rad_info.gfx_level;
+   struct radv_shader_info *gs_info = &gs_stage->info;
+   struct radv_shader_info *es_info = &es_stage->info;
+   struct gfx9_gs_info *out = &gs_stage->info.gs_ring_info;
 
-   unsigned gs_num_invocations = MAX2(gs_info->gs.invocations, 1);
-   bool uses_adjacency = gs_info->gs.input_prim == SHADER_PRIM_LINES_ADJACENCY ||
-                         gs_info->gs.input_prim == SHADER_PRIM_TRIANGLES_ADJACENCY;
+   const unsigned gs_num_invocations = MAX2(gs_info->gs.invocations, 1);
+   const bool uses_adjacency = gs_info->gs.input_prim == SHADER_PRIM_LINES_ADJACENCY ||
+                               gs_info->gs.input_prim == SHADER_PRIM_TRIANGLES_ADJACENCY;
 
    /* All these are in dwords: */
    /* We can't allow using the whole LDS, because GS waves compete with
@@ -1984,10 +1982,10 @@ gfx9_get_gs_info(const struct radv_pipeline *pipeline, struct radv_pipeline_stag
     */
    es_verts -= min_es_verts - 1;
 
-   uint32_t es_verts_per_subgroup = es_verts;
-   uint32_t gs_prims_per_subgroup = gs_prims;
-   uint32_t gs_inst_prims_in_subgroup = gs_prims * gs_num_invocations;
-   uint32_t max_prims_per_subgroup = gs_inst_prims_in_subgroup * gs_info->gs.vertices_out;
+   const uint32_t es_verts_per_subgroup = es_verts;
+   const uint32_t gs_prims_per_subgroup = gs_prims;
+   const uint32_t gs_inst_prims_in_subgroup = gs_prims * gs_num_invocations;
+   const uint32_t max_prims_per_subgroup = gs_inst_prims_in_subgroup * gs_info->gs.vertices_out;
    out->lds_size = align(esgs_lds_size, 128) / 128;
    out->vgt_gs_onchip_cntl = S_028A44_ES_VERTS_PER_SUBGRP(es_verts_per_subgroup) |
                              S_028A44_GS_PRIMS_PER_SUBGRP(gs_prims_per_subgroup) |
@@ -1996,12 +1994,10 @@ gfx9_get_gs_info(const struct radv_pipeline *pipeline, struct radv_pipeline_stag
    out->vgt_esgs_ring_itemsize = esgs_itemsize;
    assert(max_prims_per_subgroup <= max_out_prims);
 
-   gl_shader_stage es_stage = has_tess ? MESA_SHADER_TESS_EVAL : MESA_SHADER_VERTEX;
-   unsigned workgroup_size = ac_compute_esgs_workgroup_size(
-      pdevice->rad_info.gfx_level, stages[es_stage].info.wave_size,
+   unsigned workgroup_size = ac_compute_esgs_workgroup_size(gfx_level, es_info->wave_size,
       es_verts_per_subgroup, gs_inst_prims_in_subgroup);
-   stages[es_stage].info.workgroup_size = workgroup_size;
-   stages[MESA_SHADER_GEOMETRY].info.workgroup_size = workgroup_size;
+   es_info->workgroup_size = workgroup_size;
+   gs_info->workgroup_size = workgroup_size;
 }
 
 static void
@@ -2015,20 +2011,17 @@ clamp_gsprims_to_esverts(unsigned *max_gsprims, unsigned max_esverts, unsigned m
 }
 
 static unsigned
-radv_get_num_input_vertices(const struct radv_pipeline_stage *stages)
+radv_get_num_input_vertices(const struct radv_pipeline_stage *es_stage,
+                            const struct radv_pipeline_stage *gs_stage)
 {
-   if (stages[MESA_SHADER_GEOMETRY].nir) {
-      nir_shader *gs = stages[MESA_SHADER_GEOMETRY].nir;
-
-      return gs->info.gs.vertices_in;
+   if (gs_stage) {
+      return gs_stage->nir->info.gs.vertices_in;
    }
 
-   if (stages[MESA_SHADER_TESS_CTRL].nir) {
-      nir_shader *tes = stages[MESA_SHADER_TESS_EVAL].nir;
-
-      if (tes->info.tess.point_mode)
+   if (es_stage->stage == MESA_SHADER_TESS_EVAL) {
+      if (es_stage->nir->info.tess.point_mode)
          return 1;
-      if (tes->info.tess._primitive_mode == TESS_PRIMITIVE_ISOLINES)
+      if (es_stage->nir->info.tess._primitive_mode == TESS_PRIMITIVE_ISOLINES)
          return 2;
       return 3;
    }
@@ -2046,20 +2039,17 @@ gfx10_emit_ge_pc_alloc(struct radeon_cmdbuf *cs, enum amd_gfx_level gfx_level,
 }
 
 static unsigned
-radv_get_pre_rast_input_topology(struct radv_pipeline_stage *stages)
+radv_get_pre_rast_input_topology(const struct radv_pipeline_stage *es_stage,
+                                 const struct radv_pipeline_stage *gs_stage)
 {
-   if (stages[MESA_SHADER_GEOMETRY].nir) {
-      struct radv_shader_info *gs_info = &stages[MESA_SHADER_GEOMETRY].info;
-
-      return gs_info->gs.input_prim;
+   if (gs_stage) {
+      return gs_stage->nir->info.gs.input_primitive;
    }
 
-   if (stages[MESA_SHADER_TESS_CTRL].nir) {
-      struct radv_shader_info *tes_info = &stages[MESA_SHADER_TESS_EVAL].info;
-
-      if (tes_info->tes.point_mode)
+   if (es_stage->stage == MESA_SHADER_TESS_EVAL) {
+      if (es_stage->nir->info.tess.point_mode)
          return SHADER_PRIM_POINTS;
-      if (tes_info->tes._primitive_mode == TESS_PRIMITIVE_ISOLINES)
+      if (es_stage->nir->info.tess._primitive_mode == TESS_PRIMITIVE_ISOLINES)
          return SHADER_PRIM_LINES;
       return SHADER_PRIM_TRIANGLES;
    }
@@ -2068,20 +2058,19 @@ radv_get_pre_rast_input_topology(struct radv_pipeline_stage *stages)
 }
 
 static void
-gfx10_get_ngg_info(struct radv_pipeline *pipeline, struct radv_pipeline_stage *stages,
-                   struct gfx10_ngg_info *ngg)
+gfx10_get_ngg_info(const struct radv_device *device, struct radv_pipeline_stage *es_stage,
+                   struct radv_pipeline_stage *gs_stage)
 {
-   const struct radv_physical_device *pdevice = pipeline->device->physical_device;
-   struct radv_shader_info *gs_info = &stages[MESA_SHADER_GEOMETRY].info;
-   struct radv_shader_info *es_info =
-      stages[MESA_SHADER_TESS_CTRL].nir ? &stages[MESA_SHADER_TESS_EVAL].info
-                                        : &stages[MESA_SHADER_VERTEX].info;
-   unsigned gs_type = stages[MESA_SHADER_GEOMETRY].nir ? MESA_SHADER_GEOMETRY : MESA_SHADER_VERTEX;
-   unsigned max_verts_per_prim = radv_get_num_input_vertices(stages);
-   unsigned min_verts_per_prim = gs_type == MESA_SHADER_GEOMETRY ? max_verts_per_prim : 1;
-   unsigned gs_num_invocations = stages[MESA_SHADER_GEOMETRY].nir ? MAX2(gs_info->gs.invocations, 1) : 1;
-
-   const unsigned input_prim = radv_get_pre_rast_input_topology(stages);
+   const enum amd_gfx_level gfx_level = device->physical_device->rad_info.gfx_level;
+   struct radv_shader_info *gs_info = gs_stage ? &gs_stage->info : NULL;
+   struct radv_shader_info *es_info = &es_stage->info;
+   const unsigned max_verts_per_prim = radv_get_num_input_vertices(es_stage, gs_stage);
+   const unsigned min_verts_per_prim = gs_stage ? max_verts_per_prim : 1;
+   struct gfx10_ngg_info *out = gs_stage ? &gs_info->ngg_info : &es_info->ngg_info;
+
+   const unsigned gs_num_invocations = gs_stage ? MAX2(gs_info->gs.invocations, 1) : 1;
+
+   const unsigned input_prim = radv_get_pre_rast_input_topology(es_stage, gs_stage);
    const bool uses_adjacency = input_prim == SHADER_PRIM_LINES_ADJACENCY ||
                                input_prim == SHADER_PRIM_TRIANGLES_ADJACENCY;
 
@@ -2099,7 +2088,7 @@ gfx10_get_ngg_info(struct radv_pipeline *pipeline, struct radv_pipeline_stage *s
    unsigned gsprim_lds_size = 0;
 
    /* All these are per subgroup: */
-   const unsigned min_esverts = pdevice->rad_info.gfx_level >= GFX10_3 ? 29 : 24;
+   const unsigned min_esverts = gfx_level >= GFX10_3 ? 29 : 24;
    bool max_vert_out_per_gs_instance = false;
    unsigned max_esverts_base = 128;
    unsigned max_gsprims_base = 128; /* default prim group size clamp */
@@ -2114,7 +2103,7 @@ gfx10_get_ngg_info(struct radv_pipeline *pipeline, struct radv_pipeline_stage *s
     */
    max_esverts_base = MIN2(max_esverts_base, 251 + max_verts_per_prim - 1);
 
-   if (gs_type == MESA_SHADER_GEOMETRY) {
+   if (gs_stage) {
       unsigned max_out_verts_per_gsprim = gs_info->gs.vertices_out * gs_num_invocations;
 
       if (max_out_verts_per_gsprim <= 256) {
@@ -2135,9 +2124,7 @@ gfx10_get_ngg_info(struct radv_pipeline *pipeline, struct radv_pipeline_stage *s
    } else {
       /* VS and TES. */
       /* LDS size for passing data from GS to ES. */
-      struct radv_streamout_info *so_info = stages[MESA_SHADER_TESS_CTRL].nir
-                                               ? &stages[MESA_SHADER_TESS_EVAL].info.so
-                                               : &stages[MESA_SHADER_VERTEX].info.so;
+      struct radv_streamout_info *so_info = &es_info->so;
 
       if (so_info->num_outputs)
          esvert_lds_size = 4 * so_info->num_outputs + 1;
@@ -2146,7 +2133,7 @@ gfx10_get_ngg_info(struct radv_pipeline *pipeline, struct radv_pipeline_stage *s
        * corresponding to the ES thread of the provoking vertex. All
        * ES threads load and export PrimitiveID for their thread.
        */
-      if (!stages[MESA_SHADER_TESS_CTRL].nir && stages[MESA_SHADER_VERTEX].info.outinfo.export_prim_id)
+      if (es_stage->stage == MESA_SHADER_VERTEX && es_stage->info.outinfo.export_prim_id)
          esvert_lds_size = MAX2(esvert_lds_size, 1);
    }
 
@@ -2187,11 +2174,10 @@ gfx10_get_ngg_info(struct radv_pipeline *pipeline, struct radv_pipeline_stage *s
       unsigned orig_max_gsprims;
       unsigned wavesize;
 
-      if (gs_type == MESA_SHADER_GEOMETRY) {
+      if (gs_stage) {
          wavesize = gs_info->wave_size;
       } else {
-         wavesize = stages[MESA_SHADER_TESS_CTRL].nir ? stages[MESA_SHADER_TESS_EVAL].info.wave_size
-                                                      : stages[MESA_SHADER_VERTEX].info.wave_size;
+         wavesize = es_info->wave_size;
       }
 
       do {
@@ -2206,7 +2192,7 @@ gfx10_get_ngg_info(struct radv_pipeline *pipeline, struct radv_pipeline_stage *s
          max_esverts = MIN2(max_esverts, max_gsprims * max_verts_per_prim);
 
          /* Hardware restriction: minimum value of max_esverts */
-         if (pdevice->rad_info.gfx_level == GFX10)
+         if (gfx_level == GFX10)
             max_esverts = MAX2(max_esverts, min_esverts - 1 + max_verts_per_prim);
          else
             max_esverts = MAX2(max_esverts, min_esverts);
@@ -2229,26 +2215,26 @@ gfx10_get_ngg_info(struct radv_pipeline *pipeline, struct radv_pipeline_stage *s
       } while (orig_max_esverts != max_esverts || orig_max_gsprims != max_gsprims);
 
       /* Verify the restriction. */
-      if (pdevice->rad_info.gfx_level == GFX10)
+      if (gfx_level == GFX10)
          assert(max_esverts >= min_esverts - 1 + max_verts_per_prim);
       else
          assert(max_esverts >= min_esverts);
    } else {
       /* Hardware restriction: minimum value of max_esverts */
-      if (pdevice->rad_info.gfx_level == GFX10)
+      if (gfx_level == GFX10)
          max_esverts = MAX2(max_esverts, min_esverts - 1 + max_verts_per_prim);
       else
          max_esverts = MAX2(max_esverts, min_esverts);
    }
 
    unsigned max_out_vertices = max_vert_out_per_gs_instance ? gs_info->gs.vertices_out
-                               : gs_type == MESA_SHADER_GEOMETRY
+                               : gs_stage
                                   ? max_gsprims * gs_num_invocations * gs_info->gs.vertices_out
                                   : max_esverts;
    assert(max_out_vertices <= 256);
 
    unsigned prim_amp_factor = 1;
-   if (gs_type == MESA_SHADER_GEOMETRY) {
+   if (gs_stage) {
       /* Number of output primitives per GS input primitive after
        * GS instancing. */
       prim_amp_factor = gs_info->gs.vertices_out;
@@ -2259,35 +2245,36 @@ gfx10_get_ngg_info(struct radv_pipeline *pipeline, struct radv_pipeline_stage *s
     * whenever this check passes, there is enough space for a full
     * primitive without vertex reuse.
     */
-   if (pdevice->rad_info.gfx_level == GFX10)
-      ngg->hw_max_esverts = max_esverts - max_verts_per_prim + 1;
+   if (gfx_level == GFX10)
+      out->hw_max_esverts = max_esverts - max_verts_per_prim + 1;
    else
-      ngg->hw_max_esverts = max_esverts;
+      out->hw_max_esverts = max_esverts;
 
-   ngg->max_gsprims = max_gsprims;
-   ngg->max_out_verts = max_out_vertices;
-   ngg->prim_amp_factor = prim_amp_factor;
-   ngg->max_vert_out_per_gs_instance = max_vert_out_per_gs_instance;
-   ngg->ngg_emit_size = max_gsprims * gsprim_lds_size;
-   ngg->enable_vertex_grouping = true;
+   out->max_gsprims = max_gsprims;
+   out->max_out_verts = max_out_vertices;
+   out->prim_amp_factor = prim_amp_factor;
+   out->max_vert_out_per_gs_instance = max_vert_out_per_gs_instance;
+   out->ngg_emit_size = max_gsprims * gsprim_lds_size;
+   out->enable_vertex_grouping = true;
 
    /* Don't count unusable vertices. */
-   ngg->esgs_ring_size = MIN2(max_esverts, max_gsprims * max_verts_per_prim) * esvert_lds_size * 4;
+   out->esgs_ring_size = MIN2(max_esverts, max_gsprims * max_verts_per_prim) * esvert_lds_size * 4;
 
-   if (gs_type == MESA_SHADER_GEOMETRY) {
-      ngg->vgt_esgs_ring_itemsize = es_info->esgs_itemsize / 4;
+   if (gs_stage) {
+      out->vgt_esgs_ring_itemsize = es_info->esgs_itemsize / 4;
    } else {
-      ngg->vgt_esgs_ring_itemsize = 1;
+      out->vgt_esgs_ring_itemsize = 1;
    }
 
-   assert(ngg->hw_max_esverts >= min_esverts); /* HW limitation */
+   assert(out->hw_max_esverts >= min_esverts); /* HW limitation */
 
-   gl_shader_stage es_stage = stages[MESA_SHADER_TESS_CTRL].nir ? MESA_SHADER_TESS_EVAL : MESA_SHADER_VERTEX;
    unsigned workgroup_size =
       ac_compute_ngg_workgroup_size(
          max_esverts, max_gsprims * gs_num_invocations, max_out_vertices, prim_amp_factor);
-   stages[MESA_SHADER_GEOMETRY].info.workgroup_size = workgroup_size;
-   stages[es_stage].info.workgroup_size = workgroup_size;
+   if (gs_stage) {
+      gs_info->workgroup_size = workgroup_size;
+   }
+   es_info->workgroup_size = workgroup_size;
 }
 
 static void
@@ -3405,25 +3392,19 @@ radv_fill_shader_info(struct radv_pipeline *pipeline,
    }
 
    if (pipeline_has_ngg) {
-      struct gfx10_ngg_info *ngg_info;
-
-      if (stages[MESA_SHADER_GEOMETRY].nir)
-         ngg_info = &stages[MESA_SHADER_GEOMETRY].info.ngg_info;
-      else if (stages[MESA_SHADER_TESS_CTRL].nir)
-         ngg_info = &stages[MESA_SHADER_TESS_EVAL].info.ngg_info;
-      else if (stages[MESA_SHADER_VERTEX].nir)
-         ngg_info = &stages[MESA_SHADER_VERTEX].info.ngg_info;
-      else if (stages[MESA_SHADER_MESH].nir)
-         ngg_info = &stages[MESA_SHADER_MESH].info.ngg_info;
-      else
-         unreachable("Missing NGG shader stage.");
+      if (last_vgt_api_stage != MESA_SHADER_MESH) {
+         struct radv_pipeline_stage *es_stage =
+            stages[MESA_SHADER_TESS_EVAL].nir ? &stages[MESA_SHADER_TESS_EVAL] : &stages[MESA_SHADER_VERTEX];
+         struct radv_pipeline_stage *gs_stage =
+            stages[MESA_SHADER_GEOMETRY].nir ? &stages[MESA_SHADER_GEOMETRY] : NULL;
 
-      if (last_vgt_api_stage != MESA_SHADER_MESH)
-         gfx10_get_ngg_info(pipeline, stages, ngg_info);
+         gfx10_get_ngg_info(device, es_stage, gs_stage);
+      }
    } else if (stages[MESA_SHADER_GEOMETRY].nir) {
-      struct gfx9_gs_info *gs_info = &stages[MESA_SHADER_GEOMETRY].info.gs_ring_info;
+      struct radv_pipeline_stage *es_stage =
+         stages[MESA_SHADER_TESS_EVAL].nir ? &stages[MESA_SHADER_TESS_EVAL] : &stages[MESA_SHADER_VERTEX];
 
-      gfx9_get_gs_info(pipeline, stages, gs_info);
+      gfx9_get_gs_info(device, es_stage, &stages[MESA_SHADER_GEOMETRY]);
    } else {
       gl_shader_stage hw_vs_api_stage =
          stages[MESA_SHADER_TESS_EVAL].nir ? MESA_SHADER_TESS_EVAL : MESA_SHADER_VERTEX;