nir,ac/nir,aco,radv: replace has_input_*_amd with more general intrinsics
authorRhys Perry <pendingchaos02@gmail.com>
Tue, 18 Oct 2022 19:52:53 +0000 (20:52 +0100)
committerMarge Bot <emma+marge@anholt.net>
Mon, 31 Oct 2022 14:33:43 +0000 (14:33 +0000)
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Qiang Yu <yuq825@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19228>

src/amd/common/ac_nir_lower_ngg.c
src/amd/compiler/aco_instruction_selection.cpp
src/amd/compiler/aco_instruction_selection_setup.cpp
src/amd/llvm/ac_nir_to_llvm.c
src/amd/vulkan/radv_nir_lower_abi.c
src/compiler/nir/nir_divergence_analysis.c
src/compiler/nir/nir_intrinsics.py

index 7782de9..424715b 100644 (file)
@@ -441,12 +441,24 @@ emit_ngg_nogs_prim_exp_arg(nir_builder *b, lower_ngg_nogs_state *st)
    }
 }
 
+static nir_ssa_def *
+has_input_vertex(nir_builder *b)
+{
+   return nir_is_subgroup_invocation_lt_amd(b, nir_load_merged_wave_info_amd(b));
+}
+
+static nir_ssa_def *
+has_input_primitive(nir_builder *b)
+{
+   return nir_is_subgroup_invocation_lt_amd(b,
+                                            nir_ushr_imm(b, nir_load_merged_wave_info_amd(b), 8));
+}
+
 static void
 emit_ngg_nogs_prim_export(nir_builder *b, lower_ngg_nogs_state *st, nir_ssa_def *arg)
 {
-   nir_ssa_def *gs_thread = st->gs_accepted_var
-                            ? nir_load_var(b, st->gs_accepted_var)
-                            : nir_has_input_primitive_amd(b);
+   nir_ssa_def *gs_thread =
+      st->gs_accepted_var ? nir_load_var(b, st->gs_accepted_var) : has_input_primitive(b);
 
    nir_if *if_gs_thread = nir_push_if(b, gs_thread);
    {
@@ -506,8 +518,8 @@ emit_ngg_nogs_prim_export(nir_builder *b, lower_ngg_nogs_state *st, nir_ssa_def
 static void
 emit_ngg_nogs_prim_id_store_shared(nir_builder *b, lower_ngg_nogs_state *st)
 {
-   nir_ssa_def *gs_thread = st->gs_accepted_var ?
-      nir_load_var(b, st->gs_accepted_var) : nir_has_input_primitive_amd(b);
+   nir_ssa_def *gs_thread =
+      st->gs_accepted_var ? nir_load_var(b, st->gs_accepted_var) : has_input_primitive(b);
 
    nir_if *if_gs_thread = nir_push_if(b, gs_thread);
    {
@@ -986,8 +998,8 @@ compact_vertices_after_culling(nir_builder *b,
    nir_pop_if(b, if_gs_accepted);
 
    nir_store_var(b, es_accepted_var, es_survived, 0x1u);
-   nir_store_var(b, gs_accepted_var,
-                 nir_iand(b, nir_inot(b, fully_culled), nir_has_input_primitive_amd(b)), 0x1u);
+   nir_store_var(b, gs_accepted_var, nir_iand(b, nir_inot(b, fully_culled), has_input_primitive(b)),
+                 0x1u);
 }
 
 static void
@@ -1359,7 +1371,7 @@ add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_c
 
    b->cursor = nir_before_cf_list(&impl->body);
 
-   nir_ssa_def *es_thread = nir_has_input_vertex_amd(b);
+   nir_ssa_def *es_thread = has_input_vertex(b);
    nir_if *if_es_thread = nir_push_if(b, es_thread);
    {
       /* Initialize the position output variable to zeroes, in case not all VS/TES invocations store the output.
@@ -1392,7 +1404,8 @@ add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_c
    nir_pop_if(b, if_es_thread);
 
    nir_store_var(b, es_accepted_var, es_thread, 0x1u);
-   nir_store_var(b, gs_accepted_var, nir_has_input_primitive_amd(b), 0x1u);
+   nir_ssa_def *gs_thread = has_input_primitive(b);
+   nir_store_var(b, gs_accepted_var, gs_thread, 0x1u);
 
    /* Remove all non-position outputs, and put the position output into the variable. */
    nir_metadata_preserve(impl, nir_metadata_none);
@@ -1414,7 +1427,7 @@ add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_c
       nir_ssa_def *es_vertex_lds_addr = pervertex_lds_addr(b, invocation_index, pervertex_lds_bytes);
 
       /* ES invocations store their vertex data to LDS for GS threads to read. */
-      if_es_thread = nir_push_if(b, nir_has_input_vertex_amd(b));
+      if_es_thread = nir_push_if(b, es_thread);
       if_es_thread->control = nir_selection_control_divergent_always_taken;
       {
          /* Store position components that are relevant to culling in LDS */
@@ -1440,7 +1453,7 @@ add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_c
       nir_store_var(b, prim_exp_arg_var, nir_imm_int(b, 1u << 31), 0x1u);
 
       /* GS invocations load the vertex data and perform the culling. */
-      nir_if *if_gs_thread = nir_push_if(b, nir_has_input_primitive_amd(b));
+      nir_if *if_gs_thread = nir_push_if(b, gs_thread);
       {
          /* Load vertex indices from input VGPRs */
          nir_ssa_def *vtx_idx[3] = {0};
@@ -1492,7 +1505,7 @@ add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_c
       nir_store_var(b, es_accepted_var, nir_imm_bool(b, false), 0x1u);
 
       /* ES invocations load their accepted flag from LDS. */
-      if_es_thread = nir_push_if(b, nir_has_input_vertex_amd(b));
+      if_es_thread = nir_push_if(b, es_thread);
       if_es_thread->control = nir_selection_control_divergent_always_taken;
       {
          nir_ssa_def *accepted = nir_load_shared(b, 1, 8u, es_vertex_lds_addr, .base = lds_es_vertex_accepted, .align_mul = 4u);
@@ -2021,7 +2034,7 @@ ac_nir_lower_ngg_nogs(nir_shader *shader, const ac_nir_lower_ngg_options *option
 
    nir_intrinsic_instr *export_vertex_instr;
    nir_ssa_def *es_thread =
-      options->can_cull ? nir_load_var(b, es_accepted_var) : nir_has_input_vertex_amd(b);
+      options->can_cull ? nir_load_var(b, es_accepted_var) : has_input_vertex(b);
 
    nir_if *if_es_thread = nir_push_if(b, es_thread);
    {
@@ -2972,7 +2985,7 @@ ac_nir_lower_ngg_gs(nir_shader *shader, const ac_nir_lower_ngg_options *options)
    state.lds_addr_gs_scratch = nir_load_lds_ngg_scratch_base_amd(b);
 
    /* Wrap the GS control flow. */
-   nir_if *if_gs_thread = nir_push_if(b, nir_has_input_primitive_amd(b));
+   nir_if *if_gs_thread = nir_push_if(b, has_input_primitive(b));
 
    nir_cf_reinsert(&extracted, b->cursor);
    b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
index 5948aa0..91d8c5c 100644 (file)
@@ -8250,6 +8250,7 @@ emit_interp_center(isel_context* ctx, Temp dst, Temp bary, Temp pos1, Temp pos2)
 }
 
 Temp merged_wave_info_to_mask(isel_context* ctx, unsigned i);
+Temp lanecount_to_mask(isel_context* ctx, Temp count);
 void ngg_emit_sendmsg_gs_alloc_req(isel_context* ctx, Temp vtx_cnt, Temp prm_cnt);
 static void create_primitive_exports(isel_context *ctx, Temp prim_ch1);
 static void create_vs_exports(isel_context* ctx);
@@ -9140,11 +9141,9 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
       /* unused in the legacy pipeline, the HW keeps track of this for us */
       break;
    }
-   case nir_intrinsic_has_input_vertex_amd:
-   case nir_intrinsic_has_input_primitive_amd: {
-      assert(ctx->stage.hw == HWStage::NGG);
-      unsigned i = instr->intrinsic == nir_intrinsic_has_input_vertex_amd ? 0 : 1;
-      bld.copy(Definition(get_ssa_temp(ctx, &instr->dest.ssa)), merged_wave_info_to_mask(ctx, i));
+   case nir_intrinsic_is_subgroup_invocation_lt_amd: {
+      Temp src = bld.as_uniform(get_ssa_temp(ctx, instr->src[0].ssa));
+      bld.copy(Definition(get_ssa_temp(ctx, &instr->dest.ssa)), lanecount_to_mask(ctx, src));
       break;
    }
    case nir_intrinsic_export_vertex_amd: {
@@ -11777,7 +11776,7 @@ cleanup_cfg(Program* program)
 }
 
 Temp
-lanecount_to_mask(isel_context* ctx, Temp count, bool allow64 = true)
+lanecount_to_mask(isel_context* ctx, Temp count)
 {
    assert(count.regClass() == s1);
 
@@ -11786,10 +11785,6 @@ lanecount_to_mask(isel_context* ctx, Temp count, bool allow64 = true)
    Temp cond;
 
    if (ctx->program->wave_size == 64) {
-      /* If we know that all 64 threads can't be active at a time, we just use the mask as-is */
-      if (!allow64)
-         return mask;
-
       /* Special case for 64 active invocations, because 64 doesn't work with s_bfm */
       Temp active_64 = bld.sopc(aco_opcode::s_bitcmp1_b32, bld.def(s1, scc), count,
                                 Operand::c32(6u /* log2(64) */));
index eda0d7f..6106462 100644 (file)
@@ -597,8 +597,6 @@ init_context(isel_context* ctx, nir_shader* shader)
                case nir_intrinsic_first_invocation:
                case nir_intrinsic_ballot:
                case nir_intrinsic_bindless_image_samples:
-               case nir_intrinsic_has_input_vertex_amd:
-               case nir_intrinsic_has_input_primitive_amd:
                case nir_intrinsic_load_force_vrs_rates_amd:
                case nir_intrinsic_load_scalar_arg_amd:
                case nir_intrinsic_load_smem_amd: type = RegType::sgpr; break;
index 05a1aee..6ad9388 100644 (file)
@@ -4283,16 +4283,10 @@ static bool visit_intrinsic(struct ac_nir_context *ctx, nir_intrinsic_instr *ins
       else
          result = ctx->ac.i32_0;
       break;
-   case nir_intrinsic_has_input_vertex_amd: {
-      LLVMValueRef num =
-         ac_unpack_param(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args->merged_wave_info), 0, 8);
-      result = LLVMBuildICmp(ctx->ac.builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), num, "");
-      break;
-   }
-   case nir_intrinsic_has_input_primitive_amd: {
-      LLVMValueRef num =
-         ac_unpack_param(&ctx->ac, ac_get_arg(&ctx->ac, ctx->args->merged_wave_info), 8, 8);
-      result = LLVMBuildICmp(ctx->ac.builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), num, "");
+   case nir_intrinsic_is_subgroup_invocation_lt_amd: {
+      LLVMValueRef count = LLVMBuildAnd(ctx->ac.builder, get_src(ctx, instr->src[0]),
+                                        LLVMConstInt(ctx->ac.i32, 0xff, 0), "");
+      result = LLVMBuildICmp(ctx->ac.builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), count, "");
       break;
    }
    case nir_intrinsic_load_workgroup_num_input_vertices_amd:
index 4787967..5ae3e97 100644 (file)
@@ -208,6 +208,9 @@ lower_abi_instr(nir_builder *b, nir_instr *instr, void *state)
    case nir_intrinsic_load_prim_xfb_query_enabled_amd:
       replacement = ngg_query_bool_setting(b, radv_ngg_query_prim_xfb, s);
       break;
+   case nir_intrinsic_load_merged_wave_info_amd:
+      replacement = ac_nir_load_arg(b, &s->args->ac, s->args->ac.merged_wave_info);
+      break;
    case nir_intrinsic_load_cull_any_enabled_amd:
       replacement = nggc_bool_setting(
          b, radv_nggc_front_face | radv_nggc_back_face | radv_nggc_small_primitives, s);
index b916e63..df4f4e7 100644 (file)
@@ -171,6 +171,7 @@ visit_intrinsic(nir_shader *shader, nir_intrinsic_instr *instr)
    case nir_intrinsic_load_pipeline_stat_query_enabled_amd:
    case nir_intrinsic_load_prim_gen_query_enabled_amd:
    case nir_intrinsic_load_prim_xfb_query_enabled_amd:
+   case nir_intrinsic_load_merged_wave_info_amd:
    case nir_intrinsic_load_cull_front_face_enabled_amd:
    case nir_intrinsic_load_cull_back_face_enabled_amd:
    case nir_intrinsic_load_cull_ccw_amd:
@@ -642,8 +643,7 @@ visit_intrinsic(nir_shader *shader, nir_intrinsic_instr *instr)
    case nir_intrinsic_load_tlb_color_v3d:
    case nir_intrinsic_load_tess_rel_patch_id_amd:
    case nir_intrinsic_load_gs_vertex_offset_amd:
-   case nir_intrinsic_has_input_vertex_amd:
-   case nir_intrinsic_has_input_primitive_amd:
+   case nir_intrinsic_is_subgroup_invocation_lt_amd:
    case nir_intrinsic_load_packed_passthrough_primitive_amd:
    case nir_intrinsic_load_initial_edgeflags_amd:
    case nir_intrinsic_gds_atomic_add_amd:
index 8cea094..4bc16b5 100644 (file)
@@ -1375,11 +1375,9 @@ system_value("streamout_offset_amd", 1, indices=[BASE])
 
 # AMD merged shader intrinsics
 
-# Whether the current invocation has an input vertex / primitive to process (also known as "ES thread" or "GS thread").
-# Not safe to reorder because it changes after overwrite_subgroup_num_vertices_and_primitives_amd.
-# Also, the generated code is more optimal if they are not CSE'd.
-intrinsic("has_input_vertex_amd", src_comp=[], dest_comp=1, bit_sizes=[1], indices=[])
-intrinsic("has_input_primitive_amd", src_comp=[], dest_comp=1, bit_sizes=[1], indices=[])
+# Whether the current invocation index in the subgroup is less than the source. The source must be
+# subgroup uniform and bits 0-7 must be less than or equal to the wave size.
+intrinsic("is_subgroup_invocation_lt_amd", src_comp=[1], dest_comp=1, bit_sizes=[1], flags=[CAN_ELIMINATE])
 
 # AMD NGG intrinsics
 
@@ -1395,6 +1393,9 @@ system_value("pipeline_stat_query_enabled_amd", dest_comp=1, bit_sizes=[1])
 system_value("prim_gen_query_enabled_amd", dest_comp=1, bit_sizes=[1])
 # Whether NGG should execute shader query for primitive streamouted.
 system_value("prim_xfb_query_enabled_amd", dest_comp=1, bit_sizes=[1])
+# Merged wave info. Bits 0-7 are the ES thread count, 8-15 are the GS thread count, 16-24 is the
+# GS Wave ID, 24-27 is the wave index in the workgroup, and 28-31 is the workgroup size in waves.
+system_value("merged_wave_info_amd", dest_comp=1)
 # Whether the shader should cull front facing triangles.
 intrinsic("load_cull_front_face_enabled_amd", dest_comp=1, bit_sizes=[1], flags=[CAN_ELIMINATE])
 # Whether the shader should cull back facing triangles.