nir: Add ability to count emitted GS vertices per primitive.
authorTimur Kristóf <timur.kristof@gmail.com>
Tue, 16 Jun 2020 16:58:39 +0000 (18:58 +0200)
committerTimur Kristóf <timur.kristof@gmail.com>
Fri, 9 Oct 2020 13:26:14 +0000 (15:26 +0200)
Add an option to nir_lower_gs_intrinsics so that it can also track
the number of emitted vertices per primitive, not just the total
vertex count.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6964>

src/compiler/nir/nir.h
src/compiler/nir/nir_intrinsics.py
src/compiler/nir/nir_lower_gs_intrinsics.c

index ebbed9d..e712c6d 100644 (file)
@@ -4729,6 +4729,7 @@ bool nir_lower_to_source_mods(nir_shader *shader, nir_lower_to_source_mods_flags
 typedef enum {
    nir_lower_gs_intrinsics_per_stream = 1 << 0,
    nir_lower_gs_intrinsics_count_primitives = 1 << 1,
+   nir_lower_gs_intrinsics_count_vertices_per_primitive = 1 << 2,
 } nir_lower_gs_intrinsics_flags;
 
 bool nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags options);
index 049af6f..deafdb0 100644 (file)
@@ -348,11 +348,13 @@ intrinsic("end_primitive", indices=[STREAM_ID])
 # Alternatively, drivers may implement these intrinsics, and use
 # nir_lower_gs_intrinsics() to convert from the basic intrinsics.
 #
-# These maintain a count of the number of vertices emitted, as an additional
-# unsigned integer source.
-intrinsic("emit_vertex_with_counter", src_comp=[1], indices=[STREAM_ID])
-intrinsic("end_primitive_with_counter", src_comp=[1], indices=[STREAM_ID])
-# Contains the final total vertex and primitive counts
+# These contain two additional unsigned integer sources:
+# 1. The total number of vertices emitted so far.
+# 2. The number of vertices emitted for the current primitive
+#    so far if we're counting, otherwise undef.
+intrinsic("emit_vertex_with_counter", src_comp=[1, 1], indices=[STREAM_ID])
+intrinsic("end_primitive_with_counter", src_comp=[1, 1], indices=[STREAM_ID])
+# Contains the final total vertex and primitive counts in the current GS thread.
 intrinsic("set_vertex_and_primitive_count", src_comp=[1, 1], indices=[STREAM_ID])
 
 # Atomic counters
index 07a17de..a514551 100644 (file)
 struct state {
    nir_builder *builder;
    nir_variable *vertex_count_vars[NIR_MAX_XFB_STREAMS];
+   nir_variable *vtxcnt_per_prim_vars[NIR_MAX_XFB_STREAMS];
    nir_variable *primitive_count_vars[NIR_MAX_XFB_STREAMS];
    bool per_stream;
    bool count_prims;
+   bool count_vtx_per_prim;
    bool progress;
 };
 
@@ -67,8 +69,9 @@ struct state {
  * Replace emit_vertex intrinsics with:
  *
  * if (vertex_count < max_vertices) {
- *    emit_vertex_with_counter vertex_count ...
+ *    emit_vertex_with_counter vertex_count, vertex_count_per_primitive (optional) ...
  *    vertex_count += 1
+ *    vertex_count_per_primitive += 1
  * }
  */
 static void
@@ -81,6 +84,12 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state)
    b->cursor = nir_before_instr(&intrin->instr);
    assert(state->vertex_count_vars[stream] != NULL);
    nir_ssa_def *count = nir_load_var(b, state->vertex_count_vars[stream]);
+   nir_ssa_def *count_per_primitive;
+
+   if (state->count_vtx_per_prim)
+      count_per_primitive = nir_load_var(b, state->vtxcnt_per_prim_vars[stream]);
+   else
+      count_per_primitive = nir_ssa_undef(b, 1, 32);
 
    nir_ssa_def *max_vertices =
       nir_imm_int(b, b->shader->info.gs.vertices_out);
@@ -97,6 +106,7 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state)
                                  nir_intrinsic_emit_vertex_with_counter);
    nir_intrinsic_set_stream_id(lowered, stream);
    lowered->src[0] = nir_src_for_ssa(count);
+   lowered->src[1] = nir_src_for_ssa(count_per_primitive);
    nir_builder_instr_insert(b, &lowered->instr);
 
    /* Increment the vertex count by 1 */
@@ -104,6 +114,15 @@ rewrite_emit_vertex(nir_intrinsic_instr *intrin, struct state *state)
                  nir_iadd_imm(b, count, 1),
                  0x1); /* .x */
 
+   if (state->count_vtx_per_prim) {
+      /* Increment the per-primitive vertex count by 1 */
+      nir_variable *var = state->vtxcnt_per_prim_vars[stream];
+      nir_ssa_def *vtx_per_prim_cnt = nir_load_var(b, var);
+      nir_store_var(b, var,
+                    nir_iadd_imm(b, vtx_per_prim_cnt, 1),
+                    0x1); /* .x */
+   }
+
    nir_pop_if(b, NULL);
 
    nir_instr_remove(&intrin->instr);
@@ -123,12 +142,19 @@ rewrite_end_primitive(nir_intrinsic_instr *intrin, struct state *state)
    b->cursor = nir_before_instr(&intrin->instr);
    assert(state->vertex_count_vars[stream] != NULL);
    nir_ssa_def *count = nir_load_var(b, state->vertex_count_vars[stream]);
+   nir_ssa_def *count_per_primitive;
+
+   if (state->count_vtx_per_prim)
+      count_per_primitive = nir_load_var(b, state->vtxcnt_per_prim_vars[stream]);
+   else
+      count_per_primitive = nir_ssa_undef(b, count->num_components, count->bit_size);
 
    nir_intrinsic_instr *lowered =
       nir_intrinsic_instr_create(b->shader,
                                  nir_intrinsic_end_primitive_with_counter);
    nir_intrinsic_set_stream_id(lowered, stream);
    lowered->src[0] = nir_src_for_ssa(count);
+   lowered->src[1] = nir_src_for_ssa(count_per_primitive);
    nir_builder_instr_insert(b, &lowered->instr);
 
    if (state->count_prims) {
@@ -139,6 +165,13 @@ rewrite_end_primitive(nir_intrinsic_instr *intrin, struct state *state)
                     0x1); /* .x */
    }
 
+   if (state->count_vtx_per_prim) {
+      /* Store 0 to per-primitive vertex count */
+      nir_store_var(b, state->vtxcnt_per_prim_vars[stream],
+                    nir_imm_int(b, 0),
+                    0x1); /* .x */
+   }
+
    nir_instr_remove(&intrin->instr);
 
    state->progress = true;
@@ -218,10 +251,13 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option
 {
    bool per_stream = options & nir_lower_gs_intrinsics_per_stream;
    bool count_primitives = options & nir_lower_gs_intrinsics_count_primitives;
+   bool count_vtx_per_prim =
+      options & nir_lower_gs_intrinsics_count_vertices_per_primitive;
 
    struct state state;
    state.progress = false;
    state.count_prims = count_primitives;
+   state.count_vtx_per_prim = count_vtx_per_prim;
    state.per_stream = per_stream;
 
    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
@@ -249,6 +285,12 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option
             /* initialize to 1 */
             nir_store_var(&b, state.primitive_count_vars[i], nir_imm_int(&b, 1), 0x1);
          }
+         if (count_vtx_per_prim) {
+            state.vtxcnt_per_prim_vars[i] =
+               nir_local_variable_create(impl, glsl_uint_type(), "vertices_per_primitive");
+            /* initialize to 0 */
+            nir_store_var(&b, state.vtxcnt_per_prim_vars[i], nir_imm_int(&b, 0), 0x1);
+         }
       } else {
          /* If per_stream is false, we only have one counter of each kind which we
           * want to use for all streams. Duplicate the counter pointers so all
@@ -258,6 +300,8 @@ nir_lower_gs_intrinsics(nir_shader *shader, nir_lower_gs_intrinsics_flags option
 
          if (count_primitives)
             state.primitive_count_vars[i] = state.primitive_count_vars[0];
+         if (count_vtx_per_prim)
+            state.vtxcnt_per_prim_vars[i] = state.vtxcnt_per_prim_vars[0];
       }
    }