From 71d30bcede5df1450bf4a1c018d241a94482ba16 Mon Sep 17 00:00:00 2001 From: Georg Lehmann Date: Thu, 9 Mar 2023 14:51:50 +0100 Subject: [PATCH] aco: combine scalar mul+pk_add to pk_fma MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Foz-DB Navi21: Totals from 12 (0.01% of 134913) affected shaders: CodeSize: 37860 -> 37668 (-0.51%) Instrs: 6757 -> 6733 (-0.36%) Latency: 25632 -> 25589 (-0.17%) InvThroughput: 2637 -> 2622 (-0.57%) Reviewed-by: Daniel Schürmann Part-of: --- src/amd/compiler/aco_optimizer.cpp | 95 +++++++++++++++++++++++++------------- 1 file changed, 64 insertions(+), 31 deletions(-) diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 411c63f..ad7dc02 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -3821,56 +3821,89 @@ combine_vop3p(opt_ctx& ctx, aco_ptr& instr) Instruction* mul_instr = nullptr; unsigned add_op_idx = 0; - bool opsel_lo = false, opsel_hi = false; + bitarray8 mul_neg_lo = 0, mul_neg_hi = 0, mul_opsel_lo = 0, mul_opsel_hi = 0; uint32_t uses = UINT32_MAX; /* find the 'best' mul instruction to combine with the add */ for (unsigned i = 0; i < 2; i++) { - if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_vop3p()) + Instruction* op_instr = follow_operand(ctx, instr->operands[i], true); + if (!op_instr) continue; - ssa_info& info = ctx.info[instr->operands[i].tempId()]; - if (fadd) { - if (info.instr->opcode != aco_opcode::v_pk_mul_f16 || - info.instr->definitions[0].isPrecise()) + + if (ctx.info[instr->operands[i].tempId()].is_vop3p()) { + if (fadd) { + if (op_instr->opcode != aco_opcode::v_pk_mul_f16 || + op_instr->definitions[0].isPrecise()) + continue; + } else { + if (op_instr->opcode != aco_opcode::v_pk_mul_lo_u16) + continue; + } + + Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]}; + if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op)) continue; - } else { - if (info.instr->opcode != aco_opcode::v_pk_mul_lo_u16) + + /* no clamp allowed between mul and add */ + if (op_instr->valu().clamp) continue; - } - Operand op[3] = {info.instr->operands[0], info.instr->operands[1], instr->operands[1 - i]}; - if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op)) - continue; + mul_instr = op_instr; + add_op_idx = 1 - i; + uses = ctx.uses[instr->operands[i].tempId()]; + mul_neg_lo = mul_instr->valu().neg_lo; + mul_neg_hi = mul_instr->valu().neg_hi; + mul_opsel_lo = mul_instr->valu().opsel_lo; + mul_opsel_hi = mul_instr->valu().opsel_hi; + } else if (instr->operands[i].bytes() == 2) { + if ((fadd && (op_instr->opcode != aco_opcode::v_mul_f16 || + op_instr->definitions[0].isPrecise())) || + (!fadd && op_instr->opcode != aco_opcode::v_mul_lo_u16 && + op_instr->opcode != aco_opcode::v_mul_lo_u16_e64)) + continue; - /* no clamp allowed between mul and add */ - if (info.instr->valu().clamp) - continue; + if (op_instr->valu().clamp || op_instr->valu().omod || op_instr->valu().abs) + continue; - mul_instr = info.instr; - add_op_idx = 1 - i; - opsel_lo = vop3p->opsel_lo[i]; - opsel_hi = vop3p->opsel_hi[i]; - uses = ctx.uses[instr->operands[i].tempId()]; + if (op_instr->isDPP() || (op_instr->isSDWA() && (op_instr->sdwa().sel[0].size() < 2 || + op_instr->sdwa().sel[1].size() < 2))) + continue; + + Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]}; + if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op)) + continue; + + mul_instr = op_instr; + add_op_idx = 1 - i; + uses = ctx.uses[instr->operands[i].tempId()]; + mul_neg_lo = mul_instr->valu().neg; + mul_neg_hi = mul_instr->valu().neg; + if (mul_instr->isSDWA()) { + for (unsigned j = 0; j < 2; j++) + mul_opsel_lo[j] = mul_instr->sdwa().sel[j].offset(); + } else { + mul_opsel_lo = mul_instr->valu().opsel; + } + mul_opsel_hi = mul_opsel_lo; + } } if (!mul_instr) return; - /* turn packed mul+add into v_pk_fma_f16 */ - assert(mul_instr->isVOP3P()); + /* turn mul + packed add into v_pk_fma_f16 */ aco_opcode mad = fadd ? aco_opcode::v_pk_fma_f16 : aco_opcode::v_pk_mad_u16; aco_ptr fma{create_instruction(mad, Format::VOP3P, 3, 1)}; - VALU_instruction* mul = &mul_instr->valu(); - for (unsigned i = 0; i < 2; i++) { - fma->operands[i] = copy_operand(ctx, mul_instr->operands[i]); - fma->neg_lo[i] = mul->neg_lo[i]; - fma->neg_hi[i] = mul->neg_hi[i]; - } + fma->operands[0] = copy_operand(ctx, mul_instr->operands[0]); + fma->operands[1] = copy_operand(ctx, mul_instr->operands[1]); fma->operands[2] = instr->operands[add_op_idx]; fma->clamp = vop3p->clamp; - fma->opsel_lo = mul->opsel_lo; - fma->opsel_hi = mul->opsel_hi; - propagate_swizzles(fma.get(), opsel_lo, opsel_hi); + fma->neg_lo = mul_neg_lo; + fma->neg_hi = mul_neg_hi; + fma->opsel_lo = mul_opsel_lo; + fma->opsel_hi = mul_opsel_hi; + propagate_swizzles(fma.get(), vop3p->opsel_lo[1 - add_op_idx], + vop3p->opsel_hi[1 - add_op_idx]); fma->opsel_lo[2] = vop3p->opsel_lo[add_op_idx]; fma->opsel_hi[2] = vop3p->opsel_hi[add_op_idx]; fma->neg_lo[2] = vop3p->neg_lo[add_op_idx]; -- 2.7.4