From c1346e5c2202902a2359d928634e41dff4d2eb64 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Timur=20Krist=C3=B3f?= Date: Sat, 10 Apr 2021 14:52:55 +0200 Subject: [PATCH] aco: Optimize workgroup exclusive scan to better avoid bank conflicts. MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Previously, every wave had multiple active lanes read the LDS, and the data was processed by VALU DPP instructions. Now, only the first lane reads the LDS in order to avoid bank conflicts, and the results are processed by SALU. Signed-off-by: Timur Kristóf Reviewed-by: Daniel Schürmann Part-of: --- src/amd/compiler/aco_instruction_selection.cpp | 70 +++++++++++++++----------- 1 file changed, 42 insertions(+), 28 deletions(-) diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index 3472aa3..75b9fff 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -11400,50 +11400,64 @@ std::pair ngg_gs_workgroup_reduce_and_scan(isel_context *ctx, Temp s Temp wave_id_in_tg_lds_addr = bld.vop2_e64(aco_opcode::v_lshlrev_b32, bld.def(v1), Operand(2u), wave_id_in_tg); store_lds(ctx, 4u, as_vgpr(ctx, sg_reduction), 0x1u, wave_id_in_tg_lds_addr, ctx->ngg_gs_scratch_addr, 4u); - begin_divergent_if_else(ctx, &ic); - end_divergent_if(ctx, &ic); - bld.reset(ctx->block); - /* Wait for all waves to write to LDS. */ create_workgroup_barrier(bld); - /* Activate one lane per wave. */ - Temp wave_count = wave_count_in_threadgroup(ctx); - Temp wave_count_mask = lanecount_to_mask(ctx, wave_count, false); - begin_divergent_if_then(ctx, &ic, wave_count_mask); - bld.reset(ctx->block); - - /* Each lane loads the reduction result from the corresponding wave. */ - Temp thread_id_in_wave = emit_mbcnt(ctx, bld.tmp(v1)); - Temp loaded_wave_id_lds_addr = bld.v_mul24_imm(bld.def(v1), thread_id_in_wave, 4u); - Temp red_per_w = load_lds(ctx, 4u, bld.tmp(v1), loaded_wave_id_lds_addr, ctx->ngg_gs_scratch_addr, 4u); + /* Number of LDS dwords written by all waves (if there is only 1, that is already handled above) */ + unsigned num_lds_dwords = DIV_ROUND_UP(MIN2(ctx->program->workgroup_size, 256), ctx->program->wave_size); + assert(num_lds_dwords >= 2 && num_lds_dwords <= 8); - /* Inclusive scan on the per-wave reduction results, only care about the first 8 lanes. */ - Temp sgincl = bld.vop2_dpp(aco_opcode::v_add_u32, bld.def(v1), red_per_w, red_per_w, dpp_row_sr(1), 0b0001, 0b0111, true); - sgincl = bld.vop2_dpp(aco_opcode::v_add_u32, bld.def(v1), sgincl, sgincl, dpp_row_sr(2), 0x1, 0xf, true); - sgincl = bld.vop2_dpp(aco_opcode::v_add_u32, bld.def(v1), sgincl, sgincl, dpp_row_sr(4), 0x1, 0xf, true); + /* The first lane of each wave loads every wave's results from LDS, to avoid bank conflicts */ + Temp reduction_per_wave_vector = load_lds(ctx, 4u * num_lds_dwords, bld.tmp(RegClass(RegType::vgpr, num_lds_dwords)), + bld.copy(bld.def(v1), Operand(0u)), ctx->ngg_gs_scratch_addr, 4u); begin_divergent_if_else(ctx, &ic); end_divergent_if(ctx, &ic); + bld.reset(ctx->block); - /* Create phi which gets us the above reduction results, or undef. */ + /* Create phis which get us the above reduction results, or undef. */ bld.reset(&ctx->block->instructions, ctx->block->instructions.begin()); - sgincl = bld.pseudo(aco_opcode::p_phi, bld.def(sgincl.regClass()), sgincl, Operand(v1)); + reduction_per_wave_vector = bld.pseudo(aco_opcode::p_phi, bld.def(reduction_per_wave_vector.regClass()), reduction_per_wave_vector, Operand(reduction_per_wave_vector.regClass())); bld.reset(ctx->block); - /* Make it an exclusive scan by shifting the results right by one lane. */ - Temp per_wave_excl = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), sgincl, dpp_row_sr(1), 0x1, 0xf, true); + emit_split_vector(ctx, reduction_per_wave_vector, num_lds_dwords); + Temp reduction_per_wave[8]; + + for (unsigned i = 0; i < num_lds_dwords; ++i) { + Temp reduction_current_wave = emit_extract_vector(ctx, reduction_per_wave_vector, i, v1); + reduction_per_wave[i] = bld.readlane(bld.def(s1), reduction_current_wave, Operand(0u)); + } + + Temp wave_count = wave_count_in_threadgroup(ctx); + Temp reduction_result = reduction_per_wave[0]; + Temp excl_base; - /* WG reduction result: the last lane of the above exclusive scan. */ - Temp wg_reduction = bld.readlane(bld.def(s1), per_wave_excl, wave_count); + for (unsigned i = 0; i < num_lds_dwords; ++i) { + /* Workgroup reduction: + * Add the reduction results from all waves (up to and including wave_count). + */ + if (i != 0) { + Temp should_add = bld.sopc(aco_opcode::s_cmp_ge_u32, bld.def(s1, scc), wave_count, Operand(i + 1u)); + Temp addition = bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1), reduction_per_wave[i], Operand(0u), bld.scc(should_add)); + reduction_result = bld.sop2(aco_opcode::s_add_u32, bld.def(s1), bld.def(s1, scc), reduction_result, addition); + } + + /* Base of workgroup exclusive scan: + * Add the reduction results from waves up to and excluding wave_id_in_tg. + */ + if (i != (num_lds_dwords - 1)) { + Temp should_add = bld.sopc(aco_opcode::s_cmp_ge_u32, bld.def(s1, scc), wave_id_in_tg, Operand(i + 1u)); + Temp addition = bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1), reduction_per_wave[i], Operand(0u), bld.scc(should_add)); + excl_base = !excl_base.id() ? addition : bld.sop2(aco_opcode::s_add_u32, bld.def(s1), bld.def(s1, scc), excl_base, addition); + } + } - /* Base of the exclusive WG scan: the above exclusive result corresponding to the current wave. */ - Temp wg_excl_base = bld.readlane(bld.def(s1), per_wave_excl, wave_id_in_tg); + assert(excl_base.id()); /* WG exclusive scan result: base + subgroup exclusive result. */ - Temp wg_excl = bld.vadd32(bld.def(v1), Operand(wg_excl_base), Operand(sg_excl)); + Temp wg_excl = bld.vadd32(bld.def(v1), Operand(excl_base), Operand(sg_excl)); - return std::make_pair(wg_reduction, wg_excl); + return std::make_pair(reduction_result, wg_excl); } void ngg_gs_clear_primflags(isel_context *ctx, Temp vtx_cnt, unsigned stream) -- 2.7.4