From 1bdeb961bd8c953713646ae0ea28cc5f21f4a232 Mon Sep 17 00:00:00 2001
From: Qiang Yu
Date: Thu, 9 Jun 2022 09:11:10 +0800
Subject: [PATCH] ac/nir/ngg: add gs culling
MIME-Version: 1.0
Content-Type: text/plain; charset=utf8
Content-Transfer-Encoding: 8bit

Port from radeonsi. Cull primitives after the GS threads and before the
final vertex/primitive export.

GS culling works like VS/TES culling: it reads the saved vertex
positions of a primitive back from LDS, then calls the primitive
culling algorithm to check whether the primitive is visible; only
primitives that pass are exported.

Unlike VS/TES culling, which reads the vertex indices of a primitive
from VGPRs as shader args, GS sets a primitive complete flag on the
last vertex of each primitive in LDS, so that the vertex threads know
the previous 1/2/3 vertices form a primitive and can run primitive
culling on it.

Acked-by: Marek Olšák
Reviewed-by: Timur Kristóf
Signed-off-by: Qiang Yu
Part-of:
---
 src/amd/common/ac_nir.h           |   3 +-
 src/amd/common/ac_nir_lower_ngg.c | 156 ++++++++++++++++++++++++++++++++++++--
 src/amd/vulkan/radv_shader.c      |   2 +-
 3 files changed, 153 insertions(+), 8 deletions(-)

diff --git a/src/amd/common/ac_nir.h b/src/amd/common/ac_nir.h
index c2fb8c5..6c0d86a 100644
--- a/src/amd/common/ac_nir.h
+++ b/src/amd/common/ac_nir.h
@@ -140,7 +140,8 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
                     unsigned esgs_ring_lds_bytes,
                     unsigned gs_out_vtx_bytes,
                     unsigned gs_total_out_vtx_bytes,
-                    bool provoking_vtx_last);
+                    bool provoking_vtx_last,
+                    bool can_cull);
 
 void
 ac_nir_lower_ngg_ms(nir_shader *shader,
diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index 551a20a..ac87580 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -100,6 +100,7 @@ typedef struct
    bool found_out_vtxcnt[4];
    bool output_compile_time_known;
    bool provoking_vertex_last;
+   bool can_cull;
 
    gs_output_info output_info[VARYING_SLOT_MAX];
 } lower_ngg_gs_state;
@@ -1782,12 +1783,17 @@ lower_ngg_gs_emit_vertex_with_counter(nir_builder *b, nir_intrinsic_instr *intri
    /* Calculate and store per-vertex primitive flags based on vertex counts:
     * - bit 0: whether this vertex finishes a primitive (a real primitive, not the strip)
     * - bit 1: whether the primitive index is odd (if we are emitting triangle strips, otherwise always 0)
-    * - bit 2: always 1 (so that we can use it for determining vertex liveness)
+    * - bit 2: whether vertex is live (if culling is enabled: set after culling, otherwise always 1)
     */
 
+   nir_ssa_def *vertex_live_flag = !stream && s->can_cull ?
+      nir_ishl_imm(b, nir_b2i32(b, nir_inot(b, nir_load_cull_any_enabled_amd(b))), 2) :
+      nir_imm_int(b, 0b100);
+
    nir_ssa_def *completes_prim = nir_ige(b, current_vtx_per_prim, nir_imm_int(b, s->num_vertices_per_primitive - 1));
-   nir_ssa_def *prim_flag = nir_bcsel(b, completes_prim, nir_imm_int(b, 0b101u), nir_imm_int(b, 0b100u));
+   nir_ssa_def *complete_flag = nir_b2i32(b, completes_prim);
+   nir_ssa_def *prim_flag = nir_ior(b, vertex_live_flag, complete_flag);
 
    if (s->num_vertices_per_primitive == 3) {
       nir_ssa_def *odd = nir_iand_imm(b, current_vtx_per_prim, 1);
       prim_flag = nir_iadd_nuw(b, prim_flag, nir_ishl(b, odd, nir_imm_int(b, 1)));
@@ -1988,6 +1994,124 @@ ngg_gs_load_out_vtx_primflag_0(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_d
 }
 
 static void
+ngg_gs_out_prim_all_vtxptr(nir_builder *b, nir_ssa_def *last_vtxidx, nir_ssa_def *last_vtxptr,
+                           nir_ssa_def *last_vtx_primflag, lower_ngg_gs_state *s,
+                           nir_ssa_def *vtxptr[3])
+{
+   unsigned last_vtx = s->num_vertices_per_primitive - 1;
+   vtxptr[last_vtx] = last_vtxptr;
+
+   bool primitive_is_triangle = s->num_vertices_per_primitive == 3;
+   nir_ssa_def *is_odd = primitive_is_triangle ?
+      nir_ubfe(b, last_vtx_primflag, nir_imm_int(b, 1), nir_imm_int(b, 1)) : NULL;
+
+   for (unsigned i = 0; i < s->num_vertices_per_primitive - 1; i++) {
+      nir_ssa_def *vtxidx = nir_iadd_imm(b, last_vtxidx, -(last_vtx - i));
+
+      /* Need to swap vertex 0 and vertex 1 when vertex 2 index is odd to keep
+       * CW/CCW order for correct front/back face culling.
+       */
+      if (primitive_is_triangle)
+         vtxidx = i == 0 ? nir_iadd(b, vtxidx, is_odd) : nir_isub(b, vtxidx, is_odd);
+
+      vtxptr[i] = ngg_gs_out_vertex_addr(b, vtxidx, s);
+   }
+}
+
+static nir_ssa_def *
+ngg_gs_cull_primitive(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_def *max_vtxcnt,
+                      nir_ssa_def *out_vtx_lds_addr, nir_ssa_def *out_vtx_primflag_0,
+                      lower_ngg_gs_state *s)
+{
+   /* Point culling is not enabled; if it were, this function could be further optimized. */
+   assert(s->num_vertices_per_primitive > 1);
+
+   /* save the primflag so that we don't need to load it from LDS again */
+   nir_variable *primflag_var = nir_local_variable_create(s->impl, glsl_uint_type(), "primflag");
+   nir_store_var(b, primflag_var, out_vtx_primflag_0, 1);
+
+   /* The last bit of primflag indicates whether this is the final vertex of a primitive. */
+   nir_ssa_def *is_end_prim_vtx = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag_0, 1));
+   nir_ssa_def *has_output_vertex = nir_ilt(b, tid_in_tg, max_vtxcnt);
+   nir_ssa_def *prim_enable = nir_iand(b, is_end_prim_vtx, has_output_vertex);
+
+   nir_if *if_prim_enable = nir_push_if(b, prim_enable);
+   {
+      /* Calculate the LDS address of every vertex in the current primitive. */
+      nir_ssa_def *vtxptr[3];
+      ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr, out_vtx_primflag_0, s, vtxptr);
+
+      /* Load the positions from LDS. */
+      nir_ssa_def *pos[3][4];
+      for (unsigned i = 0; i < s->num_vertices_per_primitive; i++) {
+         /* VARYING_SLOT_POS == 0, so base won't count packed location */
+         pos[i][3] = nir_load_shared(b, 1, 32, vtxptr[i], .base = 12); /* W */
+         nir_ssa_def *xy = nir_load_shared(b, 2, 32, vtxptr[i], .base = 0, .align_mul = 4);
+         pos[i][0] = nir_channel(b, xy, 0);
+         pos[i][1] = nir_channel(b, xy, 1);
+
+         pos[i][0] = nir_fdiv(b, pos[i][0], pos[i][3]);
+         pos[i][1] = nir_fdiv(b, pos[i][1], pos[i][3]);
+      }
+
+      nir_ssa_def *accepted = ac_nir_cull_primitive(
+         b, nir_imm_bool(b, true), pos, s->num_vertices_per_primitive, NULL, NULL);
+
+      nir_if *if_rejected = nir_push_if(b, nir_inot(b, accepted));
+      {
+         /* clear the primflag if rejected */
+         nir_store_shared(b, nir_imm_zero(b, 1, 8), out_vtx_lds_addr,
+                          .base = s->lds_offs_primflags);
+
+         nir_store_var(b, primflag_var, nir_imm_int(b, 0), 1);
+      }
+      nir_pop_if(b, if_rejected);
+   }
+   nir_pop_if(b, if_prim_enable);
+
+   /* Wait for LDS primflag access done. */
+   nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
+                      .memory_scope = NIR_SCOPE_WORKGROUP,
+                      .memory_semantics = NIR_MEMORY_ACQ_REL,
+                      .memory_modes = nir_var_mem_shared);
+
+   /* Only dead vertices need a chance to be revived. */
+   nir_ssa_def *vtx_is_dead = nir_ieq_imm(b, nir_load_var(b, primflag_var), 0);
+   nir_ssa_def *vtx_update_primflag = nir_iand(b, vtx_is_dead, has_output_vertex);
+   nir_if *if_update_primflag = nir_push_if(b, vtx_update_primflag);
+   {
+      /* get succeeding vertices' primflag to detect this vertex's liveness */
+      for (unsigned i = 1; i < s->num_vertices_per_primitive; i++) {
+         nir_ssa_def *vtxidx = nir_iadd_imm(b, tid_in_tg, i);
+         nir_ssa_def *not_overflow = nir_ilt(b, vtxidx, max_vtxcnt);
+         nir_if *if_not_overflow = nir_push_if(b, not_overflow);
+         {
+            nir_ssa_def *vtxptr = ngg_gs_out_vertex_addr(b, vtxidx, s);
+            nir_ssa_def *vtx_primflag =
+               nir_load_shared(b, 1, 8, vtxptr, .base = s->lds_offs_primflags);
+            vtx_primflag = nir_u2u32(b, vtx_primflag);
+
+            /* If the succeeding vertex is a live end-of-primitive vertex, set the
+             * current thread's vertex liveness flag (bit 2).
+             */
+            nir_ssa_def *has_prim = nir_i2b(b, nir_iand_imm(b, vtx_primflag, 1));
+            nir_ssa_def *vtx_live_flag =
+               nir_bcsel(b, has_prim, nir_imm_int(b, 0b100), nir_imm_int(b, 0));
+
+            /* update this vertex's primflag */
+            nir_ssa_def *primflag = nir_load_var(b, primflag_var);
+            primflag = nir_ior(b, primflag, vtx_live_flag);
+            nir_store_var(b, primflag_var, primflag, 1);
+         }
+         nir_pop_if(b, if_not_overflow);
+      }
+   }
+   nir_pop_if(b, if_update_primflag);
+
+   return nir_load_var(b, primflag_var);
+}
+
+static void
 ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
 {
    nir_ssa_def *tid_in_tg = nir_load_local_invocation_index(b);
@@ -2016,6 +2140,20 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
       return;
    }
 
+   /* cull primitives */
+   if (s->can_cull) {
+      nir_if *if_cull_en = nir_push_if(b, nir_load_cull_any_enabled_amd(b));
+
+      /* culling code will update the primflag */
+      nir_ssa_def *updated_primflag =
+         ngg_gs_cull_primitive(b, tid_in_tg, max_vtxcnt, out_vtx_lds_addr,
+                               out_vtx_primflag_0, s);
+
+      nir_pop_if(b, if_cull_en);
+
+      out_vtx_primflag_0 = nir_if_phi(b, updated_primflag, out_vtx_primflag_0);
+   }
+
    /* When the output vertex count is not known at compile time:
     * There may be gaps between invocations that have live vertices, but NGG hardware
     * requires that the invocations that export vertices are packed (ie. compact).
@@ -2054,7 +2192,8 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
                     unsigned esgs_ring_lds_bytes,
                     unsigned gs_out_vtx_bytes,
                     unsigned gs_total_out_vtx_bytes,
-                    bool provoking_vertex_last)
+                    bool provoking_vertex_last,
+                    bool can_cull)
 {
    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
    assert(impl);
@@ -2068,15 +2207,20 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
       .lds_offs_primflags = gs_out_vtx_bytes,
       .lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u,
       .provoking_vertex_last = provoking_vertex_last,
+      .can_cull = can_cull,
    };
 
    unsigned lds_scratch_bytes = DIV_ROUND_UP(state.max_num_waves, 4u) * 4u;
    unsigned total_lds_bytes = state.lds_addr_gs_scratch + lds_scratch_bytes;
    shader->info.shared_size = total_lds_bytes;
 
-   nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt, state.const_out_prmcnt, 4u);
-   state.output_compile_time_known = state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
-                                     state.const_out_prmcnt[0] != -1;
+   if (!can_cull) {
+      nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt,
+                                           state.const_out_prmcnt, 4u);
+      state.output_compile_time_known =
+         state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
+         state.const_out_prmcnt[0] != -1;
+   }
 
    if (!state.output_compile_time_known)
       state.current_clear_primflag_idx_var = nir_local_variable_create(impl, glsl_uint_type(), "current_clear_primflag_idx");
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index 2259865..3044adc 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -1342,7 +1342,7 @@ void radv_lower_ngg(struct radv_device *device, struct radv_pipeline_stage *ngg_
       assert(info->is_ngg);
       NIR_PASS_V(nir, ac_nir_lower_ngg_gs, info->wave_size, info->workgroup_size,
                  info->ngg_info.esgs_ring_size, info->gs.gsvs_vertex_size,
-                 info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last);
+                 info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last, false);
    } else if (nir->info.stage == MESA_SHADER_MESH) {
       bool scratch_ring = false;
       NIR_PASS_V(nir, ac_nir_lower_ngg_ms, &scratch_ring, info->wave_size, pl_key->has_multiview_view_index);
-- 
2.7.4
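Not part of the patch: the standalone C sketch below emulates, on the CPU, the two
LDS passes that ngg_gs_cull_primitive() performs on the per-vertex primflags, so the
"cull, then revive dead vertices from succeeding primflags" idea in the commit message
can be followed in isolation. The cull_accepts() stub, the toy 5-vertex strip, and the
plain sequential loops are illustrative assumptions only; the real pass runs one
invocation per output vertex, reads positions back from LDS, and calls
ac_nir_cull_primitive(). Bit 1 (odd strip triangle) is ignored here for brevity.

/* Standalone illustration; names and data are hypothetical, not from the patch. */
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define PRIMFLAG_COMPLETES_PRIM 0x1 /* bit 0: this vertex ends a primitive */
#define PRIMFLAG_LIVE           0x4 /* bit 2: vertex is live after culling */

/* Hypothetical stand-in for the real culling test (ac_nir_cull_primitive). */
static bool cull_accepts(unsigned prim_index)
{
   return prim_index != 1; /* pretend primitive 1 gets culled */
}

int main(void)
{
   /* primflags as the GS threads would store them for a 5-vertex triangle strip
    * when culling is enabled at runtime (bit 2 starts cleared, per the
    * emit-vertex hunk above). */
   uint8_t primflag[5] = {
      0,                       /* v0 */
      0,                       /* v1 */
      PRIMFLAG_COMPLETES_PRIM, /* v2: ends primitive 0 */
      PRIMFLAG_COMPLETES_PRIM, /* v3: ends primitive 1 */
      PRIMFLAG_COMPLETES_PRIM, /* v4: ends primitive 2 */
   };
   const unsigned num_vtx = 5, verts_per_prim = 3;

   /* Pass 1: each end-of-primitive vertex runs the cull test and clears its
    * whole primflag when its primitive is rejected (mirrors the store of zero
    * in ngg_gs_cull_primitive). */
   unsigned prim_index = 0;
   for (unsigned v = 0; v < num_vtx; v++) {
      if (!(primflag[v] & PRIMFLAG_COMPLETES_PRIM))
         continue;
      if (!cull_accepts(prim_index))
         primflag[v] = 0;
      prim_index++;
   }

   /* Pass 2: a vertex whose primflag is 0 ("dead") gets the live bit back if
    * any of the next verts_per_prim-1 vertices still ends a surviving
    * primitive, because that primitive reuses this vertex. */
   for (unsigned v = 0; v < num_vtx; v++) {
      if (primflag[v] != 0)
         continue;
      for (unsigned i = 1; i < verts_per_prim && v + i < num_vtx; i++) {
         if (primflag[v + i] & PRIMFLAG_COMPLETES_PRIM) {
            primflag[v] |= PRIMFLAG_LIVE;
            break;
         }
      }
   }

   for (unsigned v = 0; v < num_vtx; v++)
      printf("v%u: primflag = 0x%x\n", v, primflag[v]);
   return 0;
}

With the toy data above, primitive 1 is rejected: v3 loses its end-of-primitive bit but
comes back as live (0x4) because v4 still ends the surviving primitive 2, while v0 and
v1 become live through v2, which keeps its flag since primitive 0 passes the cull test.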