ac/nir/ngg: add gs culling
authorQiang Yu <yuq825@gmail.com>
Thu, 9 Jun 2022 01:11:10 +0000 (09:11 +0800)
committerMarge Bot <emma+marge@anholt.net>
Fri, 26 Aug 2022 05:50:30 +0000 (05:50 +0000)
Port from radeonsi.

Cull primitive after GS thread and before final vertex/primitive
export. GS culling is like VS/TES culling which read out saved
vertex positions of a primitive from LDS then call the primitive
culling algorithm to check whether it's visiable or not, only
passed primitives will be exported.

Unlike the VS/TES culling that read vertex index of a primitive
from VGPRs as shader args, GS will set a primitive complete flag
for each last vertex of a primitive in LDS, so that vertex thread
know the previous 1/2/3 vertex can form a primitive and do primitive
culling.

Acked-by: Marek Olšák <marek.olsak@amd.com>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Signed-off-by: Qiang Yu <yuq825@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/17651>

src/amd/common/ac_nir.h
src/amd/common/ac_nir_lower_ngg.c
src/amd/vulkan/radv_shader.c

index c2fb8c5..6c0d86a 100644 (file)
@@ -140,7 +140,8 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
                     unsigned esgs_ring_lds_bytes,
                     unsigned gs_out_vtx_bytes,
                     unsigned gs_total_out_vtx_bytes,
-                    bool provoking_vtx_last);
+                    bool provoking_vtx_last,
+                    bool can_cull);
 
 void
 ac_nir_lower_ngg_ms(nir_shader *shader,
index 551a20a..ac87580 100644 (file)
@@ -100,6 +100,7 @@ typedef struct
    bool found_out_vtxcnt[4];
    bool output_compile_time_known;
    bool provoking_vertex_last;
+   bool can_cull;
    gs_output_info output_info[VARYING_SLOT_MAX];
 } lower_ngg_gs_state;
 
@@ -1782,12 +1783,17 @@ lower_ngg_gs_emit_vertex_with_counter(nir_builder *b, nir_intrinsic_instr *intri
    /* Calculate and store per-vertex primitive flags based on vertex counts:
     * - bit 0: whether this vertex finishes a primitive (a real primitive, not the strip)
     * - bit 1: whether the primitive index is odd (if we are emitting triangle strips, otherwise always 0)
-    * - bit 2: always 1 (so that we can use it for determining vertex liveness)
+    * - bit 2: whether vertex is live (if culling is enabled: set after culling, otherwise always 1)
     */
 
+   nir_ssa_def *vertex_live_flag = !stream && s->can_cull ?
+      nir_ishl_imm(b, nir_b2i32(b, nir_inot(b, nir_load_cull_any_enabled_amd(b))), 2) :
+      nir_imm_int(b, 0b100);
+
    nir_ssa_def *completes_prim = nir_ige(b, current_vtx_per_prim, nir_imm_int(b, s->num_vertices_per_primitive - 1));
-   nir_ssa_def *prim_flag = nir_bcsel(b, completes_prim, nir_imm_int(b, 0b101u), nir_imm_int(b, 0b100u));
+   nir_ssa_def *complete_flag = nir_b2i32(b, completes_prim);
 
+   nir_ssa_def *prim_flag = nir_ior(b, vertex_live_flag, complete_flag);
    if (s->num_vertices_per_primitive == 3) {
       nir_ssa_def *odd = nir_iand_imm(b, current_vtx_per_prim, 1);
       prim_flag = nir_iadd_nuw(b, prim_flag, nir_ishl(b, odd, nir_imm_int(b, 1)));
@@ -1988,6 +1994,124 @@ ngg_gs_load_out_vtx_primflag_0(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_d
 }
 
 static void
+ngg_gs_out_prim_all_vtxptr(nir_builder *b, nir_ssa_def *last_vtxidx, nir_ssa_def *last_vtxptr,
+                           nir_ssa_def *last_vtx_primflag, lower_ngg_gs_state *s,
+                           nir_ssa_def *vtxptr[3])
+{
+   unsigned last_vtx = s->num_vertices_per_primitive - 1;
+   vtxptr[last_vtx]= last_vtxptr;
+
+   bool primitive_is_triangle = s->num_vertices_per_primitive == 3;
+   nir_ssa_def *is_odd = primitive_is_triangle ?
+      nir_ubfe(b, last_vtx_primflag, nir_imm_int(b, 1), nir_imm_int(b, 1)) : NULL;
+
+   for (unsigned i = 0; i < s->num_vertices_per_primitive - 1; i++) {
+      nir_ssa_def *vtxidx = nir_iadd_imm(b, last_vtxidx, -(last_vtx - i));
+
+      /* Need to swap vertex 0 and vertex 1 when vertex 2 index is odd to keep
+       * CW/CCW order for correct front/back face culling.
+       */
+      if (primitive_is_triangle)
+         vtxidx = i == 0 ? nir_iadd(b, vtxidx, is_odd) : nir_isub(b, vtxidx, is_odd);
+
+      vtxptr[i] = ngg_gs_out_vertex_addr(b, vtxidx, s);
+   }
+}
+
+static nir_ssa_def *
+ngg_gs_cull_primitive(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_def *max_vtxcnt,
+                      nir_ssa_def *out_vtx_lds_addr, nir_ssa_def *out_vtx_primflag_0,
+                      lower_ngg_gs_state *s)
+{
+   /* we haven't enabled point culling, if enabled this function could be further optimized */
+   assert(s->num_vertices_per_primitive > 1);
+
+   /* save the primflag so that we don't need to load it from LDS again */
+   nir_variable *primflag_var = nir_local_variable_create(s->impl, glsl_uint_type(), "primflag");
+   nir_store_var(b, primflag_var, out_vtx_primflag_0, 1);
+
+   /* last bit of primflag indicate if this is the final vertex of a primitive */
+   nir_ssa_def *is_end_prim_vtx = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag_0, 1));
+   nir_ssa_def *has_output_vertex = nir_ilt(b, tid_in_tg, max_vtxcnt);
+   nir_ssa_def *prim_enable = nir_iand(b, is_end_prim_vtx, has_output_vertex);
+
+   nir_if *if_prim_enable = nir_push_if(b, prim_enable);
+   {
+      /* Calculate the LDS address of every vertex in the current primitive. */
+      nir_ssa_def *vtxptr[3];
+      ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr, out_vtx_primflag_0, s, vtxptr);
+
+      /* Load the positions from LDS. */
+      nir_ssa_def *pos[3][4];
+      for (unsigned i = 0; i < s->num_vertices_per_primitive; i++) {
+         /* VARYING_SLOT_POS == 0, so base won't count packed location */
+         pos[i][3] = nir_load_shared(b, 1, 32, vtxptr[i], .base = 12); /* W */
+         nir_ssa_def *xy = nir_load_shared(b, 2, 32, vtxptr[i], .base = 0, .align_mul = 4);
+         pos[i][0] = nir_channel(b, xy, 0);
+         pos[i][1] = nir_channel(b, xy, 1);
+
+         pos[i][0] = nir_fdiv(b, pos[i][0], pos[i][3]);
+         pos[i][1] = nir_fdiv(b, pos[i][1], pos[i][3]);
+      }
+
+      nir_ssa_def *accepted = ac_nir_cull_primitive(
+         b, nir_imm_bool(b, true), pos, s->num_vertices_per_primitive, NULL, NULL);
+
+      nir_if *if_rejected = nir_push_if(b, nir_inot(b, accepted));
+      {
+         /* clear the primflag if rejected */
+         nir_store_shared(b, nir_imm_zero(b, 1, 8), out_vtx_lds_addr,
+                          .base = s->lds_offs_primflags);
+
+         nir_store_var(b, primflag_var, nir_imm_int(b, 0), 1);
+      }
+      nir_pop_if(b, if_rejected);
+   }
+   nir_pop_if(b, if_prim_enable);
+
+   /* Wait for LDS primflag access done. */
+   nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
+                         .memory_scope = NIR_SCOPE_WORKGROUP,
+                         .memory_semantics = NIR_MEMORY_ACQ_REL,
+                         .memory_modes = nir_var_mem_shared);
+
+   /* only dead vertex need a chance to relive */
+   nir_ssa_def *vtx_is_dead = nir_ieq_imm(b, nir_load_var(b, primflag_var), 0);
+   nir_ssa_def *vtx_update_primflag = nir_iand(b, vtx_is_dead, has_output_vertex);
+   nir_if *if_update_primflag = nir_push_if(b, vtx_update_primflag);
+   {
+      /* get succeeding vertices' primflag to detect this vertex's liveness */
+      for (unsigned i = 1; i < s->num_vertices_per_primitive; i++) {
+         nir_ssa_def *vtxidx = nir_iadd_imm(b, tid_in_tg, i);
+         nir_ssa_def *not_overflow = nir_ilt(b, vtxidx, max_vtxcnt);
+         nir_if *if_not_overflow = nir_push_if(b, not_overflow);
+         {
+            nir_ssa_def *vtxptr = ngg_gs_out_vertex_addr(b, vtxidx, s);
+            nir_ssa_def *vtx_primflag =
+               nir_load_shared(b, 1, 8, vtxptr, .base = s->lds_offs_primflags);
+            vtx_primflag = nir_u2u32(b, vtx_primflag);
+
+            /* if succeeding vertex is alive end of primitive vertex, need to set current
+             * thread vertex's liveness flag (bit 2)
+             */
+            nir_ssa_def *has_prim = nir_i2b(b, nir_iand_imm(b, vtx_primflag, 1));
+            nir_ssa_def *vtx_live_flag =
+               nir_bcsel(b, has_prim, nir_imm_int(b, 0b100), nir_imm_int(b, 0));
+
+            /* update this vertex's primflag */
+            nir_ssa_def *primflag = nir_load_var(b, primflag_var);
+            primflag = nir_ior(b, primflag, vtx_live_flag);
+            nir_store_var(b, primflag_var, primflag, 1);
+         }
+         nir_pop_if(b, if_not_overflow);
+      }
+   }
+   nir_pop_if(b, if_update_primflag);
+
+   return nir_load_var(b, primflag_var);
+}
+
+static void
 ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
 {
    nir_ssa_def *tid_in_tg = nir_load_local_invocation_index(b);
@@ -2016,6 +2140,20 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
       return;
    }
 
+   /* cull primitives */
+   if (s->can_cull) {
+      nir_if *if_cull_en = nir_push_if(b, nir_load_cull_any_enabled_amd(b));
+
+      /* culling code will update the primflag */
+      nir_ssa_def *updated_primflag =
+         ngg_gs_cull_primitive(b, tid_in_tg, max_vtxcnt, out_vtx_lds_addr,
+                               out_vtx_primflag_0, s);
+
+      nir_pop_if(b, if_cull_en);
+
+      out_vtx_primflag_0 = nir_if_phi(b, updated_primflag, out_vtx_primflag_0);
+   }
+
    /* When the output vertex count is not known at compile time:
     * There may be gaps between invocations that have live vertices, but NGG hardware
     * requires that the invocations that export vertices are packed (ie. compact).
@@ -2054,7 +2192,8 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
                     unsigned esgs_ring_lds_bytes,
                     unsigned gs_out_vtx_bytes,
                     unsigned gs_total_out_vtx_bytes,
-                    bool provoking_vertex_last)
+                    bool provoking_vertex_last,
+                    bool can_cull)
 {
    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
    assert(impl);
@@ -2068,15 +2207,20 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
       .lds_offs_primflags = gs_out_vtx_bytes,
       .lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u,
       .provoking_vertex_last = provoking_vertex_last,
+      .can_cull = can_cull,
    };
 
    unsigned lds_scratch_bytes = DIV_ROUND_UP(state.max_num_waves, 4u) * 4u;
    unsigned total_lds_bytes = state.lds_addr_gs_scratch + lds_scratch_bytes;
    shader->info.shared_size = total_lds_bytes;
 
-   nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt, state.const_out_prmcnt, 4u);
-   state.output_compile_time_known = state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
-                                     state.const_out_prmcnt[0] != -1;
+   if (!can_cull) {
+      nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt,
+                                           state.const_out_prmcnt, 4u);
+      state.output_compile_time_known =
+         state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
+         state.const_out_prmcnt[0] != -1;
+   }
 
    if (!state.output_compile_time_known)
       state.current_clear_primflag_idx_var = nir_local_variable_create(impl, glsl_uint_type(), "current_clear_primflag_idx");
index 2259865..3044adc 100644 (file)
@@ -1342,7 +1342,7 @@ void radv_lower_ngg(struct radv_device *device, struct radv_pipeline_stage *ngg_
       assert(info->is_ngg);
       NIR_PASS_V(nir, ac_nir_lower_ngg_gs, info->wave_size, info->workgroup_size,
                  info->ngg_info.esgs_ring_size, info->gs.gsvs_vertex_size,
-                 info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last);
+                 info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last, false);
    } else if (nir->info.stage == MESA_SHADER_MESH) {
       bool scratch_ring = false;
       NIR_PASS_V(nir, ac_nir_lower_ngg_ms, &scratch_ring, info->wave_size, pl_key->has_multiview_view_index);