ac/nir/ngg: use nir_load_provoking_vtx_in_prim_amd in ngg lower
authorQiang Yu <yuq825@gmail.com>
Sun, 12 Jun 2022 10:27:21 +0000 (18:27 +0800)
committerMarge Bot <emma+marge@anholt.net>
Thu, 20 Oct 2022 06:53:56 +0000 (06:53 +0000)
Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Reviewed-by: Marek Olšák <marek.olsak@amd.com>
Signed-off-by: Qiang Yu <yuq825@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19166>

src/amd/common/ac_nir.h
src/amd/common/ac_nir_lower_ngg.c
src/amd/vulkan/radv_shader.c

index ac4351f..63c9ed4 100644 (file)
@@ -128,7 +128,6 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
                       bool early_prim_export,
                       bool passthrough,
                       bool export_prim_id,
-                      bool provoking_vtx_last,
                       bool use_edgeflags,
                       bool has_prim_query,
                       bool disable_streamout,
@@ -143,7 +142,6 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
                     unsigned esgs_ring_lds_bytes,
                     unsigned gs_out_vtx_bytes,
                     unsigned gs_total_out_vtx_bytes,
-                    bool provoking_vtx_last,
                     bool can_cull,
                     bool disable_streamout);
 
index fbb4a86..4f9e160 100644 (file)
@@ -66,7 +66,6 @@ typedef struct
    unsigned wave_size;
    unsigned max_num_waves;
    unsigned num_vertices_per_primitives;
-   unsigned provoking_vtx_idx;
    unsigned max_es_num_vertices;
    unsigned position_store_base;
 
@@ -115,7 +114,6 @@ typedef struct
    unsigned lds_offs_primflags;
    bool found_out_vtxcnt[4];
    bool output_compile_time_known;
-   bool provoking_vertex_last;
    bool can_cull;
    bool streamout_enabled;
    gs_output_info output_info[VARYING_SLOT_MAX];
@@ -488,8 +486,15 @@ emit_ngg_nogs_prim_id_store_shared(nir_builder *b, lower_ngg_nogs_state *st)
        * corresponding to the ES thread of the provoking vertex.
        * It will be exported as a per-vertex attribute.
        */
+      nir_ssa_def *gs_vtx_indices[3];
+      for (unsigned i = 0; i < st->num_vertices_per_primitives; i++)
+         gs_vtx_indices[i] = nir_load_var(b, st->gs_vtx_indices_vars[i]);
+
+      nir_ssa_def *provoking_vertex = nir_load_provoking_vtx_in_prim_amd(b);
+      nir_ssa_def *provoking_vtx_idx = nir_select_from_ssa_def_array(
+         b, gs_vtx_indices, st->num_vertices_per_primitives, provoking_vertex);
+
       nir_ssa_def *prim_id = nir_load_primitive_id(b);
-      nir_ssa_def *provoking_vtx_idx = nir_load_var(b, st->gs_vtx_indices_vars[st->provoking_vtx_idx]);
       nir_ssa_def *addr = pervertex_lds_addr(b, provoking_vtx_idx, st->pervertex_lds_bytes);
 
       /* primitive id is always at last of a vertex */
@@ -1783,7 +1788,6 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
                       bool early_prim_export,
                       bool passthrough,
                       bool export_prim_id,
-                      bool provoking_vtx_last,
                       bool use_edgeflags,
                       bool has_prim_query,
                       bool disable_streamout,
@@ -1817,7 +1821,6 @@ ac_nir_lower_ngg_nogs(nir_shader *shader,
       .has_prim_query = has_prim_query,
       .streamout_enabled = streamout_enabled,
       .num_vertices_per_primitives = num_vertices_per_primitives,
-      .provoking_vtx_idx = provoking_vtx_last ? (num_vertices_per_primitives - 1) : 0,
       .position_value_var = position_value_var,
       .prim_exp_arg_var = prim_exp_arg_var,
       .es_accepted_var = es_accepted_var,
@@ -2352,13 +2355,16 @@ ngg_gs_export_primitives(nir_builder *b, nir_ssa_def *max_num_out_prims, nir_ssa
        */
 
       nir_ssa_def *is_odd = nir_ubfe(b, primflag_0, nir_imm_int(b, 1), nir_imm_int(b, 1));
-      if (!s->provoking_vertex_last) {
-         vtx_indices[1] = nir_iadd(b, vtx_indices[1], is_odd);
-         vtx_indices[2] = nir_isub(b, vtx_indices[2], is_odd);
-      } else {
-         vtx_indices[0] = nir_iadd(b, vtx_indices[0], is_odd);
-         vtx_indices[1] = nir_isub(b, vtx_indices[1], is_odd);
-      }
+      nir_ssa_def *provoking_vertex_index = nir_load_provoking_vtx_in_prim_amd(b);
+      nir_ssa_def *provoking_vertex_first = nir_ieq_imm(b, provoking_vertex_index, 0);
+
+      vtx_indices[0] = nir_bcsel(b, provoking_vertex_first, vtx_indices[0],
+                                 nir_iadd(b, vtx_indices[0], is_odd));
+      vtx_indices[1] = nir_bcsel(b, provoking_vertex_first,
+                                 nir_iadd(b, vtx_indices[1], is_odd),
+                                 nir_isub(b, vtx_indices[1], is_odd));
+      vtx_indices[2] = nir_bcsel(b, provoking_vertex_first,
+                                 nir_isub(b, vtx_indices[2], is_odd), vtx_indices[2]);
    }
 
    nir_ssa_def *arg = emit_pack_ngg_prim_exp_arg(b, s->num_vertices_per_primitive, vtx_indices, is_null_prim, false);
@@ -2763,7 +2769,6 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
                     unsigned esgs_ring_lds_bytes,
                     unsigned gs_out_vtx_bytes,
                     unsigned gs_total_out_vtx_bytes,
-                    bool provoking_vertex_last,
                     bool can_cull,
                     bool disable_streamout)
 {
@@ -2778,7 +2783,6 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
       .lds_addr_gs_scratch = ALIGN(esgs_ring_lds_bytes + gs_total_out_vtx_bytes, 8u /* for the repacking code */),
       .lds_offs_primflags = gs_out_vtx_bytes,
       .lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u,
-      .provoking_vertex_last = provoking_vertex_last,
       .can_cull = can_cull,
       .streamout_enabled = shader->xfb_info && !disable_streamout,
    };
index 58f5d1e..21f465d 100644 (file)
@@ -1383,7 +1383,7 @@ void radv_lower_ngg(struct radv_device *device, struct radv_pipeline_stage *ngg_
                  max_vtx_in, num_vertices_per_prim,
                  info->workgroup_size, info->wave_size, info->has_ngg_culling,
                  info->has_ngg_early_prim_export, info->is_ngg_passthrough, export_prim_id,
-                 pl_key->vs.provoking_vtx_last, false, pl_key->primitives_generated_query,
+                 false, pl_key->primitives_generated_query,
                  true, pl_key->vs.instance_rate_inputs, 0, 0);
 
       /* Increase ESGS ring size so the LLVM binary contains the correct LDS size. */
@@ -1392,8 +1392,7 @@ void radv_lower_ngg(struct radv_device *device, struct radv_pipeline_stage *ngg_
       assert(info->is_ngg);
       NIR_PASS_V(nir, ac_nir_lower_ngg_gs, info->wave_size, info->workgroup_size,
                  info->ngg_info.esgs_ring_size, info->gs.gsvs_vertex_size,
-                 info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last,
-                 false, true);
+                 info->ngg_info.ngg_emit_size * 4u, false, true);
    } else if (nir->info.stage == MESA_SHADER_MESH) {
       bool scratch_ring = false;
       NIR_PASS_V(nir, ac_nir_lower_ngg_ms, &scratch_ring, info->wave_size, pl_key->has_multiview_view_index);