ac/nir: lower gfx11 vertex parameter exports
authorRhys Perry <pendingchaos02@gmail.com>
Tue, 11 Oct 2022 13:00:14 +0000 (14:00 +0100)
committerMarge Bot <emma+marge@anholt.net>
Mon, 31 Oct 2022 14:33:43 +0000 (14:33 +0000)
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Qiang Yu <yuq825@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19228>

src/amd/common/ac_nir.h
src/amd/common/ac_nir_lower_ngg.c
src/amd/vulkan/radv_nir_lower_abi.c
src/amd/vulkan/radv_nir_to_llvm.c
src/amd/vulkan/radv_shader.c

index b7e5e45..ee8d025 100644 (file)
@@ -125,6 +125,7 @@ typedef struct {
 
    unsigned max_workgroup_size;
    unsigned wave_size;
+   const uint8_t *vs_output_param_offset; /* GFX11+ */
    bool can_cull;
    bool disable_streamout;
 
index 424715b..4d10ff6 100644 (file)
@@ -23,6 +23,7 @@
  */
 
 #include "ac_nir.h"
+#include "amdgfxregs.h"
 #include "nir_builder.h"
 #include "nir_xfb_info.h"
 #include "u_math.h"
     BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | \
     BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
 
+#define POS_EXPORT_MASK \
+   (VARYING_BIT_POS | VARYING_BIT_PSIZ | VARYING_BIT_LAYER | VARYING_BIT_VIEWPORT | \
+    VARYING_BIT_PRIMITIVE_SHADING_RATE | VARYING_BIT_CLIP_DIST0 | VARYING_BIT_CLIP_DIST1 | \
+    VARYING_BIT_EDGE | VARYING_BIT_CLIP_VERTEX)
+
 enum {
    nggc_passflag_used_by_pos = 1,
    nggc_passflag_used_by_other = 2,
@@ -47,12 +53,19 @@ typedef struct
 
 typedef struct
 {
+   gl_varying_slot slot;
+   nir_ssa_def *chan[4];
+} vs_output;
+
+typedef struct
+{
    const ac_nir_lower_ngg_options *options;
 
    nir_variable *position_value_var;
    nir_variable *prim_exp_arg_var;
    nir_variable *es_accepted_var;
    nir_variable *gs_accepted_var;
+   nir_variable *num_es_threads_var;
    nir_variable *gs_vtx_indices_vars[3];
 
    nir_ssa_def *vtx_addr[3];
@@ -1523,6 +1536,18 @@ add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_c
       nir_ssa_def *num_live_vertices_in_workgroup = rep.num_repacked_invocations;
       nir_ssa_def *es_exporter_tid = rep.repacked_invocation_index;
 
+      if (nogs_state->num_es_threads_var) {
+         nir_ssa_def *num_live_vertices_in_wave = num_live_vertices_in_workgroup;
+         if (nogs_state->max_num_waves > 1) {
+            num_live_vertices_in_wave =
+               nir_isub(b, num_live_vertices_in_wave,
+                        nir_imul_imm(b, nir_load_subgroup_id(b), nogs_state->options->wave_size));
+            num_live_vertices_in_wave = nir_umin(b, num_live_vertices_in_wave,
+                                                 nir_imm_int(b, nogs_state->options->wave_size));
+         }
+         nir_store_var(b, nogs_state->num_es_threads_var, num_live_vertices_in_wave, 0x1);
+      }
+
       /* If all vertices are culled, set primitive count to 0 as well. */
       nir_ssa_def *num_exported_prims = nir_load_workgroup_num_input_primitives_amd(b);
       nir_ssa_def *fully_culled = nir_ieq_imm(b, num_live_vertices_in_workgroup, 0u);
@@ -1553,6 +1578,9 @@ add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_c
       }
       nir_pop_if(b, if_wave_0);
       nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, nogs_state), 0x1u);
+
+      if (nogs_state->num_es_threads_var)
+         nir_store_var(b, nogs_state->num_es_threads_var, nir_load_merged_wave_info_amd(b), 0x1);
    }
    nir_pop_if(b, if_cull_en);
 
@@ -1913,6 +1941,147 @@ ngg_nogs_get_pervertex_lds_size(gl_shader_stage stage,
    return pervertex_lds_bytes;
 }
 
+static unsigned
+gather_vs_outputs(nir_builder *b, struct exec_list *cf_list, vs_output *outputs,
+                  const uint8_t *vs_output_param_offset)
+{
+   uint64_t output_mask32 = 0;
+   nir_ssa_def *outputs32[64][4] = {0};
+
+   uint64_t output_mask_lo = 0;
+   uint64_t output_mask_hi = 0;
+   nir_ssa_def *outputs_lo[64][4];
+   nir_ssa_def *outputs_hi[64][4];
+
+   /* Assume:
+    * - the shader used nir_lower_io_to_temporaries
+    * - 64-bit outputs are lowered
+    * - no indirect indexing is present
+    */
+   struct nir_cf_node *first_node = exec_node_data(nir_cf_node, exec_list_get_head(cf_list), node);
+   for (nir_block *block = nir_cf_node_cf_tree_first(first_node); block != NULL;
+        block = nir_block_cf_tree_next(block)) {
+      nir_foreach_instr_safe (instr, block) {
+         if (instr->type != nir_instr_type_intrinsic)
+            continue;
+
+         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+         if (intrin->intrinsic != nir_intrinsic_store_output)
+            continue;
+
+         assert(nir_src_is_const(intrin->src[1]) && !nir_src_as_uint(intrin->src[1]));
+
+         unsigned slot = nir_intrinsic_io_semantics(intrin).location;
+         if (vs_output_param_offset[slot] > AC_EXP_PARAM_OFFSET_31)
+            continue;
+
+         bool is_hi = nir_intrinsic_io_semantics(intrin).high_16bits;
+         bool is_16bit = slot >= VARYING_SLOT_VAR0_16BIT;
+
+         u_foreach_bit (i, nir_intrinsic_write_mask(intrin)) {
+            unsigned comp = nir_intrinsic_component(intrin) + i;
+            nir_ssa_def *chan = nir_channel(b, intrin->src[0].ssa, i);
+            if (is_16bit && is_hi)
+               outputs_hi[slot - VARYING_SLOT_VAR0_16BIT][comp] = chan;
+            else if (is_16bit)
+               outputs_lo[slot - VARYING_SLOT_VAR0_16BIT][comp] = chan;
+            else
+               outputs32[slot][comp] = chan;
+         }
+
+         if (is_16bit && is_hi)
+            output_mask_hi |= BITFIELD64_BIT(slot - VARYING_SLOT_VAR0_16BIT);
+         else if (is_16bit)
+            output_mask_lo |= BITFIELD64_BIT(slot - VARYING_SLOT_VAR0_16BIT);
+         else
+            output_mask32 |= BITFIELD64_BIT(slot);
+
+         if (slot >= VARYING_SLOT_VAR0 || !(BITFIELD64_BIT(slot) & POS_EXPORT_MASK))
+            nir_instr_remove(&intrin->instr);
+      }
+   }
+
+   unsigned num_outputs = 0;
+   u_foreach_bit64 (i, output_mask32) {
+      outputs[num_outputs].slot = i;
+      for (unsigned j = 0; j < 4; j++) {
+         nir_ssa_def *chan = outputs32[i][j];
+         /* RADV implements 16-bit outputs as 32-bit with VARYING_SLOT_VAR0-31. */
+         outputs[num_outputs].chan[j] = chan && chan->bit_size == 16 ? nir_u2u32(b, chan) : chan;
+      }
+      num_outputs++;
+   }
+
+   if (output_mask_lo | output_mask_hi) {
+      nir_ssa_def *undef = nir_ssa_undef(b, 1, 16);
+      u_foreach_bit64 (i, output_mask_lo | output_mask_hi) {
+         vs_output *output = &outputs[num_outputs++];
+
+         output->slot = i + VARYING_SLOT_VAR0_16BIT;
+         for (unsigned j = 0; j < 4; j++) {
+            nir_ssa_def *lo = output_mask_lo & BITFIELD64_BIT(i) ? outputs_lo[i][j] : NULL;
+            nir_ssa_def *hi = output_mask_hi & BITFIELD64_BIT(i) ? outputs_hi[i][j] : NULL;
+            if (lo || hi)
+               output->chan[j] = nir_pack_32_2x16_split(b, lo ? lo : undef, hi ? hi : undef);
+            else
+               output->chan[j] = NULL;
+         }
+      }
+   }
+
+   return num_outputs;
+}
+
+static void
+create_vertex_param_phis(nir_builder *b, unsigned num_outputs, vs_output *outputs)
+{
+   nir_ssa_def *undef = nir_ssa_undef(b, 1, 32); /* inserted at the start of the shader */
+
+   for (unsigned i = 0; i < num_outputs; i++) {
+      for (unsigned j = 0; j < 4; j++) {
+         if (outputs[i].chan[j])
+            outputs[i].chan[j] = nir_if_phi(b, outputs[i].chan[j], undef);
+      }
+   }
+}
+
+static void
+export_vertex_params_gfx11(nir_builder *b, nir_ssa_def *export_tid, nir_ssa_def *num_export_threads,
+                           unsigned num_outputs, vs_output *outputs,
+                           const uint8_t *vs_output_param_offset)
+{
+   nir_ssa_def *attr_rsrc = nir_load_ring_attr_amd(b);
+
+   /* We should always store full vec4s in groups of 8 lanes for the best performance even if
+    * some of them are garbage or have unused components, so align the number of export threads
+    * to 8.
+    */
+   num_export_threads = nir_iand_imm(b, nir_iadd_imm(b, num_export_threads, 7), ~7);
+   if (!export_tid)
+      nir_push_if(b, nir_is_subgroup_invocation_lt_amd(b, num_export_threads));
+   else
+      nir_push_if(b, nir_ult(b, export_tid, num_export_threads));
+
+   nir_ssa_def *attr_offset = nir_load_ring_attr_offset_amd(b);
+   nir_ssa_def *vindex = nir_load_local_invocation_index(b);
+   nir_ssa_def *voffset = nir_imm_int(b, 0);
+   nir_ssa_def *undef = nir_ssa_undef(b, 1, 32);
+
+   for (unsigned i = 0; i < num_outputs; i++) {
+      gl_varying_slot slot = outputs[i].slot;
+      nir_ssa_def *soffset = nir_iadd_imm(b, attr_offset, vs_output_param_offset[slot] * 16 * 32);
+
+      nir_ssa_def *comp[4];
+      for (unsigned j = 0; j < 4; j++)
+         comp[j] = outputs[i].chan[j] ? outputs[i].chan[j] : undef;
+      nir_store_buffer_amd(b, nir_vec(b, comp, 4), attr_rsrc, voffset, soffset, vindex,
+                           .is_swizzled = true, .memory_modes = nir_var_shader_out,
+                           .access = ACCESS_COHERENT);
+   }
+
+   nir_pop_if(b, NULL);
+}
+
 void
 ac_nir_lower_ngg_nogs(nir_shader *shader, const ac_nir_lower_ngg_options *options)
 {
@@ -1927,6 +2096,10 @@ ac_nir_lower_ngg_nogs(nir_shader *shader, const ac_nir_lower_ngg_options *option
       options->can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "es_accepted") : NULL;
    nir_variable *gs_accepted_var =
       options->can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "gs_accepted") : NULL;
+   nir_variable *num_es_threads_var =
+      options->can_cull && options->gfx_level >= GFX11
+         ? nir_local_variable_create(impl, glsl_uint_type(), "num_es_threads")
+         : NULL;
 
    bool streamout_enabled = shader->xfb_info && !options->disable_streamout;
    bool has_user_edgeflags =
@@ -1950,6 +2123,7 @@ ac_nir_lower_ngg_nogs(nir_shader *shader, const ac_nir_lower_ngg_options *option
       .prim_exp_arg_var = prim_exp_arg_var,
       .es_accepted_var = es_accepted_var,
       .gs_accepted_var = gs_accepted_var,
+      .num_es_threads_var = num_es_threads_var,
       .max_num_waves = DIV_ROUND_UP(options->max_workgroup_size, options->wave_size),
       .has_user_edgeflags = has_user_edgeflags,
    };
@@ -2068,6 +2242,27 @@ ac_nir_lower_ngg_nogs(nir_shader *shader, const ac_nir_lower_ngg_options *option
       emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, prim_exp_arg_var));
    }
 
+   /* Export varyings for GFX11+ */
+   if (state.options->gfx_level >= GFX11) {
+      vs_output outputs[64];
+
+      b->cursor = nir_after_cf_list(&if_es_thread->then_list);
+      unsigned num_outputs =
+         gather_vs_outputs(b, &if_es_thread->then_list, outputs, options->vs_output_param_offset);
+
+      if (num_outputs) {
+         b->cursor = nir_after_cf_node(&if_es_thread->cf_node);
+         create_vertex_param_phis(b, num_outputs, outputs);
+
+         b->cursor = nir_after_cf_list(&impl->body);
+
+         nir_ssa_def *num_threads = options->can_cull ? nir_load_var(b, num_es_threads_var)
+                                                      : nir_load_merged_wave_info_amd(b);
+         export_vertex_params_gfx11(b, NULL, num_threads, num_outputs, outputs,
+                                    options->vs_output_param_offset);
+      }
+   }
+
    if (options->can_cull) {
       /* Replace uniforms. */
       apply_reusable_variables(b, &state);
@@ -2552,6 +2747,9 @@ ngg_gs_export_vertices(nir_builder *b, nir_ssa_def *max_num_out_vtx, nir_ssa_def
       exported_out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, nir_u2u32(b, exported_vtx_idx), s);
    }
 
+   unsigned num_outputs = 0;
+   vs_output outputs[64];
+
    for (unsigned slot = 0; slot < VARYING_SLOT_MAX; ++slot) {
       if (!(b->shader->info.outputs_written & BITFIELD64_BIT(slot)))
          continue;
@@ -2576,6 +2774,16 @@ ngg_gs_export_vertices(nir_builder *b, nir_ssa_def *max_num_out_vtx, nir_ssa_def
          .no_sysval_output = info->no_sysval_output,
       };
 
+      bool is_pos = BITFIELD64_BIT(slot) & POS_EXPORT_MASK;
+
+      vs_output *output = NULL;
+      if (s->options->gfx_level >= GFX11 &&
+          s->options->vs_output_param_offset[slot] <= AC_EXP_PARAM_OFFSET_31) {
+         output = &outputs[num_outputs++];
+         output->slot = slot;
+         memset(output->chan, 0, sizeof(output->chan));
+      }
+
       while (mask) {
          int start, count;
          u_bit_scan_consecutive_range(&mask, &start, &count);
@@ -2595,15 +2803,25 @@ ngg_gs_export_vertices(nir_builder *b, nir_ssa_def *max_num_out_vtx, nir_ssa_def
             if (bit_size != 32)
                val = nir_u2u(b, val, bit_size);
 
-            nir_store_output(b, val, nir_imm_int(b, 0), .base = info->base,
-                             .io_semantics = io_sem, .component = start + i,
-                             .write_mask = 1);
+            if (s->options->gfx_level < GFX11 || is_pos) {
+               nir_store_output(b, val, nir_imm_int(b, 0), .base = info->base,
+                                .io_semantics = io_sem, .component = start + i, .write_mask = 1);
+            }
+            if (output)
+               output->chan[start + i] = val;
          }
       }
    }
 
    nir_export_vertex_amd(b);
    nir_pop_if(b, if_vtx_export_thread);
+
+   if (num_outputs) {
+      create_vertex_param_phis(b, num_outputs, outputs);
+
+      export_vertex_params_gfx11(b, tid_in_tg, max_num_out_vtx, num_outputs, outputs,
+                                 s->options->vs_output_param_offset);
+   }
 }
 
 static void
index 5ae3e97..215795f 100644 (file)
@@ -143,6 +143,10 @@ lower_abi_instr(nir_builder *b, nir_instr *instr, void *state)
       }
 
       replacement = load_ring(b, RING_PS_ATTR, s);
+
+      nir_ssa_def *dword1 = nir_channel(b, replacement, 1);
+      dword1 = nir_ior_imm(b, dword1, S_008F04_STRIDE(16 * s->info->outinfo.param_exports));
+      replacement = nir_vector_insert_imm(b, replacement, dword1, 1);
       break;
 
    case nir_intrinsic_load_ring_attr_offset_amd: {
index 1433e16..1ca2a31 100644 (file)
@@ -1279,6 +1279,11 @@ ac_setup_rings(struct radv_shader_context *ctx)
         (ctx->stage == MESA_SHADER_GEOMETRY))) {
       ctx->attr_ring = ac_build_load_to_sgpr(&ctx->ac, ring_offsets,
                                              LLVMConstInt(ctx->ac.i32, RING_PS_ATTR, false));
+
+      LLVMValueRef tmp = LLVMBuildExtractElement(ctx->ac.builder, ctx->attr_ring, ctx->ac.i32_1, "");
+      uint32_t stride = S_008F04_STRIDE(16 * ctx->shader_info->outinfo.param_exports);
+      tmp = LLVMBuildOr(ctx->ac.builder, tmp, LLVMConstInt(ctx->ac.i32, stride, false), "");
+      ctx->attr_ring = LLVMBuildInsertElement(ctx->ac.builder, ctx->attr_ring, tmp, ctx->ac.i32_1, "");
    }
 }
 
index 4991703..f83ea05 100644 (file)
@@ -1435,6 +1435,7 @@ void radv_lower_ngg(struct radv_device *device, struct radv_pipeline_stage *ngg_
    options.gfx_level = device->physical_device->rad_info.gfx_level;
    options.max_workgroup_size = info->workgroup_size;
    options.wave_size = info->wave_size;
+   options.vs_output_param_offset = info->outinfo.vs_output_param_offset;
    options.can_cull = nir->info.stage != MESA_SHADER_GEOMETRY && info->has_ngg_culling;
    options.disable_streamout = !device->physical_device->use_ngg_streamout;