ac/nir: move ngg_gs_shader_query to a common function
author Qiang Yu <yuq825@gmail.com>
Wed, 30 Nov 2022 08:49:11 +0000 (16:49 +0800)
committer Qiang Yu <yuq825@gmail.com>
Tue, 13 Dec 2022 03:43:49 +0000 (11:43 +0800)
To be shared by NGG GS and legacy GS. Legacy GS needs this on
GFX10, which mixes NGG and legacy GS. For example, when streamout
is enabled it uses legacy GS, otherwise it uses NGG GS. So legacy
GS also needs to update the query emulation, which is a sum of the
NGG and legacy GS results.

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Marek Olšák <marek.olsak@amd.com>
Signed-off-by: Qiang Yu <yuq825@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/20158>

src/amd/common/ac_nir.c
src/amd/common/ac_nir.h
src/amd/common/ac_nir_lower_ngg.c

index 569bb9e..8e738d5 100644 (file)
@@ -333,3 +333,108 @@ ac_nir_lower_legacy_vs(nir_shader *nir, int primitive_id_location, bool disable_
    nir_export_vertex_amd(&b);
    nir_metadata_preserve(impl, preserved);
 }
+
+bool
+ac_nir_gs_shader_query(nir_builder *b,
+                       bool has_gen_prim_query,
+                       bool has_pipeline_stats_query,
+                       unsigned num_vertices_per_primitive,
+                       unsigned wave_size,
+                       nir_ssa_def *vertex_count[4],
+                       nir_ssa_def *primitive_count[4])
+{
+   nir_ssa_def *pipeline_query_enabled = NULL;
+   nir_ssa_def *prim_gen_query_enabled = NULL;
+   nir_ssa_def *shader_query_enabled = NULL;
+   if (has_gen_prim_query) {
+      prim_gen_query_enabled = nir_load_prim_gen_query_enabled_amd(b);
+      if (has_pipeline_stats_query) {
+         pipeline_query_enabled = nir_load_pipeline_stat_query_enabled_amd(b);
+         shader_query_enabled = nir_ior(b, pipeline_query_enabled, prim_gen_query_enabled);
+      } else {
+         shader_query_enabled = prim_gen_query_enabled;
+      }
+   } else if (has_pipeline_stats_query) {
+      pipeline_query_enabled = nir_load_pipeline_stat_query_enabled_amd(b);
+      shader_query_enabled = pipeline_query_enabled;
+   } else {
+      /* Neither query kind is requested: emit no code and tell the caller so. */
+      return false;
+   }
+
+   nir_if *if_shader_query = nir_push_if(b, shader_query_enabled);
+
+   nir_ssa_def *active_threads_mask = nir_ballot(b, 1, wave_size, nir_imm_bool(b, true));
+   nir_ssa_def *num_active_threads = nir_bit_count(b, active_threads_mask);
+
+   /* Calculate the "real" number of emitted primitives for each active stream from the emitted
+    * GS vertices and primitives. GS emits points, line strips or triangle strips; real primitives
+    * are points, lines or triangles: vertex_count - primitive_count * (num_vertices_per_primitive - 1).
+    */
+   nir_ssa_def *num_prims_in_wave[4] = {0};
+   u_foreach_bit (i, b->shader->info.gs.active_stream_mask) {
+      assert(vertex_count[i] && primitive_count[i]);
+
+      nir_ssa_scalar vtx_cnt = nir_get_ssa_scalar(vertex_count[i], 0);
+      nir_ssa_scalar prm_cnt = nir_get_ssa_scalar(primitive_count[i], 0);
+
+      if (nir_ssa_scalar_is_const(vtx_cnt) && nir_ssa_scalar_is_const(prm_cnt)) {
+         unsigned gs_vtx_cnt = nir_ssa_scalar_as_uint(vtx_cnt);
+         unsigned gs_prm_cnt = nir_ssa_scalar_as_uint(prm_cnt);
+         unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * (num_vertices_per_primitive - 1u);
+         if (total_prm_cnt == 0)
+            continue; /* entry stays NULL, so no atomic add is emitted for this stream */
+
+         num_prims_in_wave[i] = nir_imul_imm(b, num_active_threads, total_prm_cnt);
+      } else {
+         nir_ssa_def *gs_vtx_cnt = vtx_cnt.def;
+         nir_ssa_def *gs_prm_cnt = prm_cnt.def;
+         if (num_vertices_per_primitive > 1)
+            gs_prm_cnt = nir_iadd(b, nir_imul_imm(b, gs_prm_cnt, -1u * (num_vertices_per_primitive - 1)), gs_vtx_cnt);
+         num_prims_in_wave[i] = nir_reduce(b, gs_prm_cnt, .reduction_op = nir_op_iadd);
+      }
+   }
+
+   /* Add the per-wave results to the query counters with atomic adds, from one elected lane. */
+   nir_if *if_first_lane = nir_push_if(b, nir_elect(b, 1));
+   {
+      if (has_pipeline_stats_query) {
+         nir_if *if_pipeline_query = nir_push_if(b, pipeline_query_enabled);
+         {
+            nir_ssa_def *count = NULL;
+
+            /* Sum the primitive counts of all streams into a single counter. */
+            for (int i = 0; i < 4; i++) {
+               if (num_prims_in_wave[i]) {
+                  if (count)
+                     count = nir_iadd(b, count, num_prims_in_wave[i]);
+                  else
+                     count = num_prims_in_wave[i];
+               }
+            }
+
+            if (count)
+               nir_atomic_add_gs_emit_prim_count_amd(b, count);
+
+            nir_atomic_add_gs_invocation_count_amd(b, num_active_threads);
+         }
+         nir_pop_if(b, if_pipeline_query);
+      }
+
+      if (has_gen_prim_query) {
+         nir_if *if_prim_gen_query = nir_push_if(b, prim_gen_query_enabled);
+         {
+            /* Each stream adds to its own counter, selected by stream_id. */
+            for (int i = 0; i < 4; i++) {
+               if (num_prims_in_wave[i])
+                  nir_atomic_add_gen_prim_count_amd(b, num_prims_in_wave[i], .stream_id = i);
+            }
+         }
+         nir_pop_if(b, if_prim_gen_query);
+      }
+   }
+   nir_pop_if(b, if_first_lane);
+
+   nir_pop_if(b, if_shader_query);
+   return true; /* query code was emitted */
+}
index 39794a4..e32b694 100644 (file)
@@ -198,6 +198,15 @@ ac_nir_create_gs_copy_shader(const nir_shader *gs_nir,
 void
 ac_nir_lower_legacy_vs(nir_shader *nir, int primitive_id_location, bool disable_streamout);
 
+bool
+ac_nir_gs_shader_query(nir_builder *b,
+                       bool has_gen_prim_query,
+                       bool has_pipeline_stats_query,
+                       unsigned num_vertices_per_primitive,
+                       unsigned wave_size,
+                       nir_ssa_def *vertex_count[4],
+                       nir_ssa_def *primitive_count[4]);
+
 #ifdef __cplusplus
 }
 #endif
index bf075a3..7fbc763 100644 (file)
@@ -2391,107 +2391,6 @@ ngg_gs_clear_primflags(nir_builder *b, nir_ssa_def *num_vertices, unsigned strea
    nir_pop_loop(b, loop);
 }
 
-static void
-ngg_gs_shader_query(nir_builder *b, lower_ngg_gs_state *s)
-{
-   bool has_gen_prim_query = s->options->has_gen_prim_query;
-   bool has_pipeline_stats_query = s->options->gfx_level < GFX11;
-
-   nir_ssa_def *pipeline_query_enabled = NULL;
-   nir_ssa_def *prim_gen_query_enabled = NULL;
-   nir_ssa_def *shader_query_enabled = NULL;
-   if (has_gen_prim_query) {
-      prim_gen_query_enabled = nir_load_prim_gen_query_enabled_amd(b);
-      if (has_pipeline_stats_query) {
-         pipeline_query_enabled = nir_load_pipeline_stat_query_enabled_amd(b);
-         shader_query_enabled = nir_ior(b, pipeline_query_enabled, prim_gen_query_enabled);
-      } else {
-         shader_query_enabled = prim_gen_query_enabled;
-      }
-   } else if (has_pipeline_stats_query) {
-      pipeline_query_enabled = nir_load_pipeline_stat_query_enabled_amd(b);
-      shader_query_enabled = pipeline_query_enabled;
-   } else {
-      /* has no query */
-      return;
-   }
-
-   nir_if *if_shader_query = nir_push_if(b, shader_query_enabled);
-
-   nir_ssa_def *active_threads_mask = nir_ballot(b, 1, s->options->wave_size, nir_imm_bool(b, true));
-   nir_ssa_def *num_active_threads = nir_bit_count(b, active_threads_mask);
-
-   /* Calculate the "real" number of emitted primitives from the emitted GS vertices and primitives.
-    * GS emits points, line strips or triangle strips.
-    * Real primitives are points, lines or triangles.
-    */
-   nir_ssa_def *num_prims_in_wave[4] = {0};
-   u_foreach_bit (i, b->shader->info.gs.active_stream_mask) {
-      assert(s->vertex_count[i] && s->primitive_count[i]);
-
-      nir_ssa_scalar vtx_cnt = nir_get_ssa_scalar(s->vertex_count[i], 0);
-      nir_ssa_scalar prm_cnt = nir_get_ssa_scalar(s->primitive_count[i], 0);
-
-      if (nir_ssa_scalar_is_const(vtx_cnt) && nir_ssa_scalar_is_const(prm_cnt)) {
-         unsigned gs_vtx_cnt = nir_ssa_scalar_as_uint(vtx_cnt);
-         unsigned gs_prm_cnt = nir_ssa_scalar_as_uint(prm_cnt);
-         unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * (s->num_vertices_per_primitive - 1u);
-         if (total_prm_cnt == 0)
-            continue;
-
-         num_prims_in_wave[i] = nir_imul_imm(b, num_active_threads, total_prm_cnt);
-      } else {
-         nir_ssa_def *gs_vtx_cnt = vtx_cnt.def;
-         nir_ssa_def *gs_prm_cnt = prm_cnt.def;
-         if (s->num_vertices_per_primitive > 1)
-            gs_prm_cnt = nir_iadd(b, nir_imul_imm(b, gs_prm_cnt, -1u * (s->num_vertices_per_primitive - 1)), gs_vtx_cnt);
-         num_prims_in_wave[i] = nir_reduce(b, gs_prm_cnt, .reduction_op = nir_op_iadd);
-      }
-   }
-
-   /* Store the query result to query result using an atomic add. */
-   nir_if *if_first_lane = nir_push_if(b, nir_elect(b, 1));
-   {
-      if (has_pipeline_stats_query) {
-         nir_if *if_pipeline_query = nir_push_if(b, pipeline_query_enabled);
-         {
-            nir_ssa_def *count = NULL;
-
-            /* Add all streams' number to the same counter. */
-            for (int i = 0; i < 4; i++) {
-               if (num_prims_in_wave[i]) {
-                  if (count)
-                     count = nir_iadd(b, count, num_prims_in_wave[i]);
-                  else
-                     count = num_prims_in_wave[i];
-               }
-            }
-
-            if (count)
-               nir_atomic_add_gs_emit_prim_count_amd(b, count);
-
-            nir_atomic_add_gs_invocation_count_amd(b, num_active_threads);
-         }
-         nir_pop_if(b, if_pipeline_query);
-      }
-
-      if (has_gen_prim_query) {
-         nir_if *if_prim_gen_query = nir_push_if(b, prim_gen_query_enabled);
-         {
-            /* Add to the counter for this stream. */
-            for (int i = 0; i < 4; i++) {
-               if (num_prims_in_wave[i])
-                  nir_atomic_add_gen_prim_count_amd(b, num_prims_in_wave[i], .stream_id = i);
-            }
-         }
-         nir_pop_if(b, if_prim_gen_query);
-      }
-   }
-   nir_pop_if(b, if_first_lane);
-
-   nir_pop_if(b, if_shader_query);
-}
-
 static bool
 lower_ngg_gs_store_output(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
 {
@@ -3381,7 +3280,13 @@ ac_nir_lower_ngg_gs(nir_shader *shader, const ac_nir_lower_ngg_options *options)
 
    /* Emit shader queries */
    b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
-   ngg_gs_shader_query(b, &state);
+   ac_nir_gs_shader_query(b,
+                          state.options->has_gen_prim_query,
+                          state.options->gfx_level < GFX11,
+                          state.num_vertices_per_primitive,
+                          state.options->wave_size,
+                          state.vertex_count,
+                          state.primitive_count);
 
    b->cursor = nir_after_cf_list(&impl->body);