ac/nir/ngg: Refactor mesh shader primitive export.
author Timur Kristóf <timur.kristof@gmail.com>
Tue, 22 Aug 2023 20:28:43 +0000 (22:28 +0200)
committer Marge Bot <emma+marge@anholt.net>
Sun, 3 Sep 2023 11:04:35 +0000 (11:04 +0000)
Clean up the code that generates the two channels of the
primitive export instruction, and move the attribute ring stores
of the built-in per-primitive outputs out of the primitive export
function, to match how per-vertex attributes are handled.

Prepares the mesh shader lowering for a workaround that
affects export instructions.

Cc: mesa-stable
Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24574>

src/amd/common/ac_nir_lower_ngg.c

index 3126487..6d59ad5 100644 (file)
     BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | \
     BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
 
+#define MS_PRIM_ARG_EXP_MASK \
+   (VARYING_BIT_LAYER | \
+    VARYING_BIT_VIEWPORT | \
+    VARYING_BIT_PRIMITIVE_SHADING_RATE)
+
 enum {
    nggc_passflag_used_by_pos = 1,
    nggc_passflag_used_by_other = 2,
@@ -4267,6 +4272,9 @@ static void
 ms_emit_attribute_ring_output_stores(nir_builder *b, const uint64_t outputs_mask,
                                      lower_ngg_ms_state *s)
 {
+   if (!outputs_mask)
+      return;
+
    nir_def *idx = nir_load_local_invocation_index(b);
    nir_def *ring = nir_load_ring_attr_amd(b);
    nir_def *off = nir_load_ring_attr_offset_amd(b);
@@ -4293,19 +4301,51 @@ ms_emit_attribute_ring_output_stores(nir_builder *b, const uint64_t outputs_mask
    }
 }
 
-static void
-ms_emit_primitive_export(nir_builder *b,
-                         nir_def *prim_exp_arg_ch1,
-                         uint64_t per_primitive_outputs,
-                         lower_ngg_ms_state *s)
+static nir_def *
+ms_prim_exp_arg_ch1(nir_builder *b, nir_def *invocation_index, nir_def *num_vtx, lower_ngg_ms_state *s)
+{
+   /* Primitive connectivity data: describes which vertices the primitive uses. */
+   nir_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
+   nir_def *indices_loaded = NULL;
+   nir_def *cull_flag = NULL;
+
+   if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
+      nir_def *indices[3] = {0};
+      for (unsigned c = 0; c < s->vertices_per_prim; ++c)
+         indices[c] = nir_load_var(b, s->out_variables[VARYING_SLOT_PRIMITIVE_INDICES * 4 + c]);
+      indices_loaded = nir_vec(b, indices, s->vertices_per_prim);
+   } else {
+      indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
+      indices_loaded = nir_u2u32(b, indices_loaded);
+   }
+
+   if (s->uses_cull_flags) {
+      nir_def *loaded_cull_flag = NULL;
+      if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
+         loaded_cull_flag = nir_load_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4]);
+      else
+         loaded_cull_flag = nir_u2u32(b, nir_load_shared(b, 1, 8, prim_idx_addr, .base = s->layout.lds.cull_flags_addr));
+
+      cull_flag = nir_i2b(b, loaded_cull_flag);
+   }
+
+   nir_def *indices[3];
+   nir_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u);
+
+   for (unsigned i = 0; i < s->vertices_per_prim; ++i) {
+      indices[i] = nir_channel(b, indices_loaded, i);
+      indices[i] = nir_umin(b, indices[i], max_vtx_idx);
+   }
+
+   return emit_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices, cull_flag);
+}
+
+static nir_def *
+ms_prim_exp_arg_ch2(nir_builder *b, uint64_t outputs_mask, lower_ngg_ms_state *s)
 {
    nir_def *prim_exp_arg_ch2 = NULL;
 
-   uint64_t export_as_prim_arg_slots =
-      VARYING_BIT_LAYER |
-      VARYING_BIT_VIEWPORT |
-      VARYING_BIT_PRIMITIVE_SHADING_RATE;
-   if (per_primitive_outputs & export_as_prim_arg_slots) {
+   if (outputs_mask) {
       /* When layer, viewport etc. are per-primitive, they need to be encoded in
        * the primitive export instruction's second channel. The encoding is:
        *
@@ -4322,29 +4362,37 @@ ms_emit_primitive_export(nir_builder *b,
        */
       prim_exp_arg_ch2 = nir_imm_int(b, 0);
 
-      if (per_primitive_outputs & VARYING_BIT_LAYER) {
+      if (outputs_mask & VARYING_BIT_LAYER) {
          nir_def *layer =
             nir_ishl_imm(b, s->outputs[VARYING_SLOT_LAYER][0], s->gfx_level >= GFX11 ? 0 : 17);
          prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, layer);
       }
 
-      if (per_primitive_outputs & VARYING_BIT_VIEWPORT) {
+      if (outputs_mask & VARYING_BIT_VIEWPORT) {
          nir_def *view = nir_ishl_imm(b, s->outputs[VARYING_SLOT_VIEWPORT][0], 20);
          prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, view);
       }
 
-      if (per_primitive_outputs & VARYING_BIT_PRIMITIVE_SHADING_RATE) {
+      if (outputs_mask & VARYING_BIT_PRIMITIVE_SHADING_RATE) {
          nir_def *rate = s->outputs[VARYING_SLOT_PRIMITIVE_SHADING_RATE][0];
          prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, rate);
       }
-
-      /* GFX11: also store these to the attribute ring so PS can load them. */
-      if (s->gfx_level >= GFX11) {
-         ms_emit_attribute_ring_output_stores(b, per_primitive_outputs & export_as_prim_arg_slots,
-                                              s);
-      }
    }
 
+   return prim_exp_arg_ch2;
+}
+
+static void
+ms_emit_primitive_export(nir_builder *b,
+                         nir_def *invocation_index,
+                         nir_def *num_vtx,
+                         uint64_t per_primitive_outputs,
+                         lower_ngg_ms_state *s)
+{
+   const uint64_t outputs_mask = per_primitive_outputs & MS_PRIM_ARG_EXP_MASK;
+   nir_def *prim_exp_arg_ch1 = ms_prim_exp_arg_ch1(b, invocation_index, num_vtx, s);
+   nir_def *prim_exp_arg_ch2 = ms_prim_exp_arg_ch2(b, outputs_mask, s);
+
    nir_def *prim_exp_arg = prim_exp_arg_ch2 ?
       nir_vec2(b, prim_exp_arg_ch1, prim_exp_arg_ch2) : prim_exp_arg_ch1;
 
@@ -4400,7 +4448,7 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
                                           VARYING_BIT_CLIP_DIST0 | VARYING_BIT_CLIP_DIST1 |
                                           VARYING_BIT_PSIZ;
 
-      /* GFX11: also store special outputs to the attribute ring so PS can load them. */
+      /* GFX11+: also store special outputs to the attribute ring so PS can load them. */
       if (s->gfx_level >= GFX11 && (per_vertex_outputs & per_vertex_special)) {
          ms_emit_attribute_ring_output_stores(b, per_vertex_outputs & per_vertex_special, s);
       }
@@ -4423,44 +4471,7 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
          per_primitive_outputs |= VARYING_BIT_LAYER;
       }
 
-      /* Primitive connectivity data: describes which vertices the primitive uses. */
-      nir_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
-      nir_def *indices_loaded = NULL;
-      nir_def *cull_flag = NULL;
-
-      if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
-         nir_def *indices[3] = {0};
-         for (unsigned c = 0; c < s->vertices_per_prim; ++c)
-            indices[c] = nir_load_var(b, s->out_variables[VARYING_SLOT_PRIMITIVE_INDICES * 4 + c]);
-         indices_loaded = nir_vec(b, indices, s->vertices_per_prim);
-      }
-      else {
-         indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
-         indices_loaded = nir_u2u32(b, indices_loaded);
-      }
-
-      if (s->uses_cull_flags) {
-         nir_def *loaded_cull_flag = NULL;
-         if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
-            loaded_cull_flag = nir_load_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4]);
-         else
-            loaded_cull_flag = nir_u2u32(b, nir_load_shared(b, 1, 8, prim_idx_addr, .base = s->layout.lds.cull_flags_addr));
-
-         cull_flag = nir_i2b(b, loaded_cull_flag);
-      }
-
-      nir_def *indices[3];
-      nir_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u);
-
-      for (unsigned i = 0; i < s->vertices_per_prim; ++i) {
-         indices[i] = nir_channel(b, indices_loaded, i);
-         indices[i] = nir_umin(b, indices[i], max_vtx_idx);
-      }
-
-      nir_def *prim_exp_arg = emit_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices,
-                                                             cull_flag);
-
-      ms_emit_primitive_export(b, prim_exp_arg, per_primitive_outputs, s);
+      ms_emit_primitive_export(b, invocation_index, num_vtx, per_primitive_outputs, s);
 
       /* Export generic attributes on GFX10.3
        * (On GFX11 they are already stored in the attribute ring.)
@@ -4469,6 +4480,11 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
          ac_nir_export_parameters(b, s->vs_output_param_offset, per_primitive_outputs, 0,
                                   s->outputs, NULL, NULL);
       }
+
+      /* GFX11+: also store special per-primitive outputs to the attribute ring so PS can load them. */
+      if (s->gfx_level >= GFX11) {
+         ms_emit_attribute_ring_output_stores(b, per_primitive_outputs & MS_PRIM_ARG_EXP_MASK, s);
+      }
    }
    nir_pop_if(b, if_has_output_primitive);
 }