From dae1629778de81ecb24f3790f8404dd2c24dd338 Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Fri, 13 May 2022 12:01:03 +0100 Subject: [PATCH] aco: disable sdwa on gfx11 MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Instead of SDWA v_mov_b32/v_xor_b32, we can use a combination of v_add_u16/v_sub_u16 (add/sub swap, similar to xor swap) and v_perm_b32 with a literal. I don't know yet if GFX11 adds any new instructions which makes this easier, but this approach should have full functionality. Signed-off-by: Rhys Perry Reviewed-by: Timur Kristóf Part-of: --- src/amd/compiler/aco_assembler.cpp | 1 + src/amd/compiler/aco_builder_h.py | 9 ++ src/amd/compiler/aco_instruction_selection.cpp | 2 +- src/amd/compiler/aco_ir.cpp | 2 +- src/amd/compiler/aco_lower_to_hw_instr.cpp | 162 ++++++++++++++++++++++--- src/amd/compiler/aco_validate.cpp | 3 +- src/amd/compiler/tests/test_sdwa.cpp | 2 +- 7 files changed, 157 insertions(+), 24 deletions(-) diff --git a/src/amd/compiler/aco_assembler.cpp b/src/amd/compiler/aco_assembler.cpp index a370d85..05754540 100644 --- a/src/amd/compiler/aco_assembler.cpp +++ b/src/amd/compiler/aco_assembler.cpp @@ -705,6 +705,7 @@ emit_instruction(asm_context& ctx, std::vector& out, Instruction* inst out.push_back(encoding); return; } else if (instr->isSDWA()) { + assert(ctx.gfx_level >= GFX8 && ctx.gfx_level < GFX11); SDWA_instruction& sdwa = instr->sdwa(); /* first emit the instruction without the SDWA operand */ diff --git a/src/amd/compiler/aco_builder_h.py b/src/amd/compiler/aco_builder_h.py index 5801d3c..8177c36 100644 --- a/src/amd/compiler/aco_builder_h.py +++ b/src/amd/compiler/aco_builder_h.py @@ -110,6 +110,15 @@ sendmsg_gs_done(bool cut, bool emit, unsigned stream) return (sendmsg)((unsigned)_sendmsg_gs_done | (cut << 4) | (emit << 5) | (stream << 8)); } +enum bperm_swiz { + bperm_b1_sign = 8, + bperm_b3_sign = 9, + bperm_b5_sign = 10, + bperm_b7_sign = 11, + bperm_0 = 12, + bperm_255 = 13, +}; + class Builder { public: struct Result { diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index 51285cf..63ac939 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -11711,7 +11711,7 @@ calc_nontrivial_instance_id(Builder& bld, const struct radv_shader_args* args, u Operand fetch_index(tmp_vgpr0, v1); Operand div_info(tmp_sgpr, s1); - if (bld.program->gfx_level >= GFX8) { + if (bld.program->gfx_level >= GFX8 && bld.program->gfx_level < GFX11) { /* use SDWA */ if (bld.program->gfx_level < GFX9) { bld.vop1(aco_opcode::v_mov_b32, Definition(tmp_vgpr1, v1), div_info); diff --git a/src/amd/compiler/aco_ir.cpp b/src/amd/compiler/aco_ir.cpp index 0e0f76c..88b028d 100644 --- a/src/amd/compiler/aco_ir.cpp +++ b/src/amd/compiler/aco_ir.cpp @@ -193,7 +193,7 @@ can_use_SDWA(amd_gfx_level gfx_level, const aco_ptr& instr, bool pr if (!instr->isVALU()) return false; - if (gfx_level < GFX8 || instr->isDPP() || instr->isVOP3P()) + if (gfx_level < GFX8 || gfx_level >= GFX11 || instr->isDPP() || instr->isVOP3P()) return false; if (instr->isSDWA()) diff --git a/src/amd/compiler/aco_lower_to_hw_instr.cpp b/src/amd/compiler/aco_lower_to_hw_instr.cpp index 9c76913..2dc32f6 100644 --- a/src/amd/compiler/aco_lower_to_hw_instr.cpp +++ b/src/amd/compiler/aco_lower_to_hw_instr.cpp @@ -511,7 +511,7 @@ emit_reduction(lower_context* ctx, aco_opcode op, ReduceOp reduce_op, unsigned c } if (src.regClass() == v1b) { - if (ctx->program->gfx_level >= GFX8) { + if (ctx->program->gfx_level >= GFX8 && ctx->program->gfx_level < GFX11) { aco_ptr sdwa{create_instruction( aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; sdwa->operands[0] = Operand(PhysReg{tmp}, v1); @@ -532,9 +532,9 @@ emit_reduction(lower_context* ctx, aco_opcode op, ReduceOp reduce_op, unsigned c Operand::c32(8u)); } } else if (src.regClass() == v2b) { - if (ctx->program->gfx_level >= GFX10 && - (reduce_op == iadd16 || reduce_op == imax16 || reduce_op == imin16 || - reduce_op == umin16 || reduce_op == umax16)) { + bool is_add_cmp = reduce_op == iadd16 || reduce_op == imax16 || reduce_op == imin16 || + reduce_op == umin16 || reduce_op == umax16; + if (ctx->program->gfx_level >= GFX10 && ctx->program->gfx_level < GFX11 && is_add_cmp) { aco_ptr sdwa{create_instruction( aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; sdwa->operands[0] = Operand(PhysReg{tmp}, v1); @@ -543,7 +543,8 @@ emit_reduction(lower_context* ctx, aco_opcode op, ReduceOp reduce_op, unsigned c sdwa->sel[0] = SubdwordSel(2, 0, sext); sdwa->dst_sel = SubdwordSel::dword; bld.insert(std::move(sdwa)); - } else if (ctx->program->gfx_level == GFX6 || ctx->program->gfx_level == GFX7) { + } else if (ctx->program->gfx_level <= GFX7 || + (ctx->program->gfx_level >= GFX11 && is_add_cmp)) { aco_opcode opcode; if (reduce_op == imin16 || reduce_op == imax16 || reduce_op == iadd16) @@ -1019,6 +1020,23 @@ get_intersection_mask(int a_start, int a_size, int b_start, int b_size) } void +create_bperm(Builder& bld, uint8_t swiz[4], Definition dst, Operand src0, + Operand src1 = Operand(v1)) +{ + uint32_t swiz_packed = + swiz[0] | ((uint32_t)swiz[1] << 8) | ((uint32_t)swiz[2] << 16) | ((uint32_t)swiz[3] << 24); + + dst = Definition(PhysReg(dst.physReg().reg()), v1); + if (!src0.isConstant()) + src0 = Operand(PhysReg(src0.physReg().reg()), v1); + if (src1.isUndefined()) + src1 = Operand(dst.physReg(), v1); + else if (!src1.isConstant()) + src1 = Operand(PhysReg(src1.physReg().reg()), v1); + bld.vop3(aco_opcode::v_perm_b32, dst, src0, src1, Operand::c32(swiz_packed)); +} + +void copy_constant(lower_context* ctx, Builder& bld, Definition dst, Operand op) { assert(op.bytes() == dst.bytes()); @@ -1066,7 +1084,8 @@ copy_constant(lower_context* ctx, Builder& bld, Definition dst, Operand op) } else { assert(dst.regClass() == v1b || dst.regClass() == v2b); - if (dst.regClass() == v1b && ctx->program->gfx_level >= GFX9) { + bool use_sdwa = ctx->program->gfx_level >= GFX9 && ctx->program->gfx_level < GFX11; + if (dst.regClass() == v1b && use_sdwa) { uint8_t val = op.constantValue(); Operand op32 = Operand::c32((uint32_t)val | (val & 0x80u ? 0xffffff00u : 0u)); if (op32.isLiteral()) { @@ -1078,7 +1097,7 @@ copy_constant(lower_context* ctx, Builder& bld, Definition dst, Operand op) } else { bld.vop1_sdwa(aco_opcode::v_mov_b32, dst, op32); } - } else if (dst.regClass() == v2b && ctx->program->gfx_level >= GFX9 && !op.isLiteral()) { + } else if (dst.regClass() == v2b && use_sdwa && !op.isLiteral()) { if (op.constantValue() >= 0xfff0 || op.constantValue() <= 64) { /* use v_mov_b32 to avoid possible issues with denormal flushing or * NaN. v_add_f16 is still needed for float constants. */ @@ -1161,6 +1180,19 @@ swap_linear_vgpr(Builder& bld, Definition def, Operand op, bool preserve_scc, Ph Operand::zero()); } +void +addsub_subdword_gfx11(Builder& bld, Definition dst, Operand src0, Operand src1, bool sub) +{ + Instruction* instr = + bld.vop3(sub ? aco_opcode::v_sub_u16_e64 : aco_opcode::v_add_u16_e64, dst, src0, src1).instr; + if (src0.physReg().byte() == 2) + instr->vop3().opsel |= 0x1; + if (src1.physReg().byte() == 2) + instr->vop3().opsel |= 0x2; + if (dst.physReg().byte() == 2) + instr->vop3().opsel |= 0x8; +} + bool do_copy(lower_context* ctx, Builder& bld, const copy_operation& copy, bool* preserve_scc, PhysReg scratch_sgpr) @@ -1228,6 +1260,12 @@ do_copy(lower_context* ctx, Builder& bld, const copy_operation& copy, bool* pres } else { bld.vop1(aco_opcode::v_mov_b32, def, op); } + } else if (def.regClass() == v1b && ctx->program->gfx_level >= GFX11) { + uint8_t swiz[] = {4, 5, 6, 7}; + swiz[def.physReg().byte()] = op.physReg().byte(); + create_bperm(bld, swiz, def, op); + } else if (def.regClass() == v2b && ctx->program->gfx_level >= GFX11) { + addsub_subdword_gfx11(bld, def, op, Operand::zero(), false); } else if (def.regClass().is_subdword()) { bld.vop1_sdwa(aco_opcode::v_mov_b32, def, op); } else { @@ -1241,6 +1279,40 @@ do_copy(lower_context* ctx, Builder& bld, const copy_operation& copy, bool* pres } void +swap_subdword_gfx11(Builder& bld, Definition def, Operand op) +{ + if (def.physReg().reg() == op.physReg().reg()) { + assert(def.bytes() != 2); /* handled by caller */ + uint8_t swiz[] = {4, 5, 6, 7}; + std::swap(swiz[def.physReg().byte()], swiz[op.physReg().byte()]); + create_bperm(bld, swiz, def, Operand::zero()); + return; + } + + if (def.bytes() == 2) { + Operand def_as_op = Operand(def.physReg(), def.regClass()); + Definition op_as_def = Definition(op.physReg(), op.regClass()); + addsub_subdword_gfx11(bld, def, def_as_op, op, false); + addsub_subdword_gfx11(bld, op_as_def, def_as_op, op, true); + addsub_subdword_gfx11(bld, def, def_as_op, op, true); + } else { + PhysReg op_half = op.physReg(); + op_half.reg_b &= ~1; + + PhysReg def_other_half = def.physReg(); + def_other_half.reg_b &= ~1; + def_other_half.reg_b ^= 2; + + /* We can only swap individual bytes within a single VGPR, so temporarily move both bytes + * into the same VGPR. + */ + swap_subdword_gfx11(bld, Definition(def_other_half, v2b), Operand(op_half, v2b)); + swap_subdword_gfx11(bld, def, Operand(def_other_half.advance(op.physReg().byte() & 1), v1b)); + swap_subdword_gfx11(bld, Definition(def_other_half, v2b), Operand(op_half, v2b)); + } +} + +void do_swap(lower_context* ctx, Builder& bld, const copy_operation& copy, bool preserve_scc, Pseudo_instruction* pi) { @@ -1325,9 +1397,13 @@ do_swap(lower_context* ctx, Builder& bld, const copy_operation& copy, bool prese Operand::c32(2u)); } else { assert(def.regClass().is_subdword()); - bld.vop2_sdwa(aco_opcode::v_xor_b32, op_as_def, op, def_as_op); - bld.vop2_sdwa(aco_opcode::v_xor_b32, def, op, def_as_op); - bld.vop2_sdwa(aco_opcode::v_xor_b32, op_as_def, op, def_as_op); + if (ctx->program->gfx_level >= GFX11) { + swap_subdword_gfx11(bld, def, op); + } else { + bld.vop2_sdwa(aco_opcode::v_xor_b32, op_as_def, op, def_as_op); + bld.vop2_sdwa(aco_opcode::v_xor_b32, def, op, def_as_op); + bld.vop2_sdwa(aco_opcode::v_xor_b32, op_as_def, op, def_as_op); + } } offset += def.bytes(); @@ -1415,8 +1491,14 @@ do_pack_2x16(lower_context* ctx, Builder& bld, Definition def, Operand lo, Opera op.setFixed(reg); } - if (ctx->program->gfx_level >= GFX8) { - /* either hi or lo are already placed correctly */ + /* either hi or lo are already placed correctly */ + if (ctx->program->gfx_level >= GFX11) { + if (lo.physReg().reg() == def.physReg().reg()) + addsub_subdword_gfx11(bld, def_hi, hi, Operand::zero(), false); + else + addsub_subdword_gfx11(bld, def_lo, lo, Operand::zero(), false); + return; + } else if (ctx->program->gfx_level >= GFX8) { if (lo.physReg().reg() == def.physReg().reg()) bld.vop1_sdwa(aco_opcode::v_mov_b32, def_hi, hi); else @@ -2142,9 +2224,41 @@ lower_to_hw_instr(Program* program) } else { assert(dst.regClass() == v2b || dst.regClass() == v1b || op.regClass() == v2b || op.regClass() == v1b); - SDWA_instruction& sdwa = - bld.vop1_sdwa(aco_opcode::v_mov_b32, dst, op).instr->sdwa(); - sdwa.sel[0] = SubdwordSel(bits / 8, offset / 8, signext); + if (ctx.program->gfx_level >= GFX11) { + unsigned op_vgpr_byte = op.physReg().byte() + offset / 8; + unsigned sign_byte = op_vgpr_byte + bits / 8 - 1; + + uint8_t swiz[4] = {4, 5, 6, 7}; + swiz[dst.physReg().byte()] = op_vgpr_byte; + if (bits == 16) + swiz[dst.physReg().byte() + 1] = op_vgpr_byte + 1; + for (unsigned i = bits / 8; i < dst.bytes(); i++) { + uint8_t ext = bperm_0; + if (signext) { + if (sign_byte == 1) + ext = bperm_b1_sign; + else if (sign_byte == 3) + ext = bperm_b3_sign; + else /* replicate so sign-extension can be done later */ + ext = sign_byte; + } + swiz[dst.physReg().byte() + i] = ext; + } + create_bperm(bld, swiz, dst, op); + + if (signext && sign_byte != 3 && sign_byte != 1) { + assert(bits == 8); + assert(dst.regClass() == v2b || dst.regClass() == v1); + uint8_t ext_swiz[4] = {4, 5, 6, 7}; + uint8_t ext = dst.physReg().byte() == 2 ? bperm_b7_sign : bperm_b5_sign; + memset(ext_swiz + dst.physReg().byte() + 1, ext, dst.bytes() - 1); + create_bperm(bld, ext_swiz, dst, Operand::zero()); + } + } else { + SDWA_instruction& sdwa = + bld.vop1_sdwa(aco_opcode::v_mov_b32, dst, op).instr->sdwa(); + sdwa.sel[0] = SubdwordSel(bits / 8, offset / 8, signext); + } } break; } @@ -2159,6 +2273,7 @@ lower_to_hw_instr(Program* program) unsigned index = instr->operands[1].constantValue(); unsigned offset = index * bits; + bool has_sdwa = program->gfx_level >= GFX8 && program->gfx_level < GFX11; if (dst.regClass() == s1) { if (offset == (32 - bits)) { bld.sop2(aco_opcode::s_lshl_b32, dst, bld.def(s1, scc), op, @@ -2172,15 +2287,22 @@ lower_to_hw_instr(Program* program) bld.sop2(aco_opcode::s_lshl_b32, dst, bld.def(s1, scc), Operand(dst.physReg(), s1), Operand::c32(offset)); } - } else if (dst.regClass() == v1 || ctx.program->gfx_level <= GFX7) { - if (offset == (dst.bytes() * 8u - bits)) { + } else if (dst.regClass() == v1 || !has_sdwa) { + if (offset == (dst.bytes() * 8u - bits) && + (dst.regClass() == v1 || program->gfx_level <= GFX7)) { bld.vop2(aco_opcode::v_lshlrev_b32, dst, Operand::c32(offset), op); - } else if (offset == 0) { + } else if (offset == 0 && (dst.regClass() == v1 || program->gfx_level <= GFX7)) { bld.vop3(aco_opcode::v_bfe_u32, dst, op, Operand::zero(), Operand::c32(bits)); - } else if (program->gfx_level >= GFX9 || - (op.regClass() != s1 && program->gfx_level >= GFX8)) { + } else if (has_sdwa && (op.regClass() != s1 || program->gfx_level >= GFX9)) { bld.vop1_sdwa(aco_opcode::v_mov_b32, dst, op).instr->sdwa().dst_sel = SubdwordSel(bits / 8, offset / 8, false); + } else if (program->gfx_level >= GFX11) { + uint8_t swiz[] = {4, 5, 6, 7}; + for (unsigned i = 0; i < dst.bytes(); i++) + swiz[dst.physReg().byte() + i] = bperm_0; + for (unsigned i = 0; i < bits / 8; i++) + swiz[dst.physReg().byte() + i + offset / 8] = op.physReg().byte() + i; + create_bperm(bld, swiz, dst, op); } else { bld.vop3(aco_opcode::v_bfe_u32, dst, op, Operand::zero(), Operand::c32(bits)); bld.vop2(aco_opcode::v_lshlrev_b32, dst, Operand::c32(offset), diff --git a/src/amd/compiler/aco_validate.cpp b/src/amd/compiler/aco_validate.cpp index 87bb217..f0da221 100644 --- a/src/amd/compiler/aco_validate.cpp +++ b/src/amd/compiler/aco_validate.cpp @@ -153,7 +153,8 @@ validate_ir(Program* program) base_format == Format::VOPC, "Format cannot have SDWA applied", instr.get()); - check(program->gfx_level >= GFX8, "SDWA is GFX8+ only", instr.get()); + check(program->gfx_level >= GFX8, "SDWA is GFX8 to GFX10.3 only", instr.get()); + check(program->gfx_level < GFX11, "SDWA is GFX8 to GFX10.3 only", instr.get()); SDWA_instruction& sdwa = instr->sdwa(); check(sdwa.omod == 0 || program->gfx_level >= GFX9, "SDWA omod only supported on GFX9+", diff --git a/src/amd/compiler/tests/test_sdwa.cpp b/src/amd/compiler/tests/test_sdwa.cpp index 8ed5bbb..83d5206 100644 --- a/src/amd/compiler/tests/test_sdwa.cpp +++ b/src/amd/compiler/tests/test_sdwa.cpp @@ -54,7 +54,7 @@ BEGIN_TEST(validate.sdwa.support) continue; //>> Validation results: - //~gfx7! SDWA is GFX8+ only: v1: %t0 = v_mul_f32 %a, %b dst_sel:dword src0_sel:dword src1_sel:dword + //~gfx7! SDWA is GFX8 to GFX10.3 only: v1: %t0 = v_mul_f32 %a, %b dst_sel:dword src0_sel:dword src1_sel:dword //~gfx7! Validation failed //~gfx([89]|10)! Validation passed bld.vop2_sdwa(aco_opcode::v_mul_f32, bld.def(v1), inputs[0], inputs[1]); -- 2.7.4