Instruction* mul_instr = nullptr;
unsigned add_op_idx = 0;
- bool opsel_lo = false, opsel_hi = false;
+ bitarray8 mul_neg_lo = 0, mul_neg_hi = 0, mul_opsel_lo = 0, mul_opsel_hi = 0;
uint32_t uses = UINT32_MAX;
/* find the 'best' mul instruction to combine with the add */
for (unsigned i = 0; i < 2; i++) {
- if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_vop3p())
+ Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
+ if (!op_instr)
continue;
- ssa_info& info = ctx.info[instr->operands[i].tempId()];
- if (fadd) {
- if (info.instr->opcode != aco_opcode::v_pk_mul_f16 ||
- info.instr->definitions[0].isPrecise())
+
+ if (ctx.info[instr->operands[i].tempId()].is_vop3p()) {
+ if (fadd) {
+ if (op_instr->opcode != aco_opcode::v_pk_mul_f16 ||
+ op_instr->definitions[0].isPrecise())
+ continue;
+ } else {
+ if (op_instr->opcode != aco_opcode::v_pk_mul_lo_u16)
+ continue;
+ }
+
+ Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]};
+ if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
continue;
- } else {
- if (info.instr->opcode != aco_opcode::v_pk_mul_lo_u16)
+
+ /* no clamp allowed between mul and add */
+ if (op_instr->valu().clamp)
continue;
- }
- Operand op[3] = {info.instr->operands[0], info.instr->operands[1], instr->operands[1 - i]};
- if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
- continue;
+ mul_instr = op_instr;
+ add_op_idx = 1 - i;
+ uses = ctx.uses[instr->operands[i].tempId()];
+ mul_neg_lo = mul_instr->valu().neg_lo;
+ mul_neg_hi = mul_instr->valu().neg_hi;
+ mul_opsel_lo = mul_instr->valu().opsel_lo;
+ mul_opsel_hi = mul_instr->valu().opsel_hi;
+ } else if (instr->operands[i].bytes() == 2) {
+ if ((fadd && (op_instr->opcode != aco_opcode::v_mul_f16 ||
+ op_instr->definitions[0].isPrecise())) ||
+ (!fadd && op_instr->opcode != aco_opcode::v_mul_lo_u16 &&
+ op_instr->opcode != aco_opcode::v_mul_lo_u16_e64))
+ continue;
- /* no clamp allowed between mul and add */
- if (info.instr->valu().clamp)
- continue;
+ if (op_instr->valu().clamp || op_instr->valu().omod || op_instr->valu().abs)
+ continue;
- mul_instr = info.instr;
- add_op_idx = 1 - i;
- opsel_lo = vop3p->opsel_lo[i];
- opsel_hi = vop3p->opsel_hi[i];
- uses = ctx.uses[instr->operands[i].tempId()];
+ if (op_instr->isDPP() || (op_instr->isSDWA() && (op_instr->sdwa().sel[0].size() < 2 ||
+ op_instr->sdwa().sel[1].size() < 2)))
+ continue;
+
+ Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]};
+ if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
+ continue;
+
+ mul_instr = op_instr;
+ add_op_idx = 1 - i;
+ uses = ctx.uses[instr->operands[i].tempId()];
+ mul_neg_lo = mul_instr->valu().neg;
+ mul_neg_hi = mul_instr->valu().neg;
+ if (mul_instr->isSDWA()) {
+ for (unsigned j = 0; j < 2; j++)
+ mul_opsel_lo[j] = mul_instr->sdwa().sel[j].offset();
+ } else {
+ mul_opsel_lo = mul_instr->valu().opsel;
+ }
+ mul_opsel_hi = mul_opsel_lo;
+ }
}
if (!mul_instr)
return;
- /* turn packed mul+add into v_pk_fma_f16 */
- assert(mul_instr->isVOP3P());
+ /* turn mul + packed add into v_pk_fma_f16 */
aco_opcode mad = fadd ? aco_opcode::v_pk_fma_f16 : aco_opcode::v_pk_mad_u16;
aco_ptr<VALU_instruction> fma{create_instruction<VALU_instruction>(mad, Format::VOP3P, 3, 1)};
- VALU_instruction* mul = &mul_instr->valu();
- for (unsigned i = 0; i < 2; i++) {
- fma->operands[i] = copy_operand(ctx, mul_instr->operands[i]);
- fma->neg_lo[i] = mul->neg_lo[i];
- fma->neg_hi[i] = mul->neg_hi[i];
- }
+ fma->operands[0] = copy_operand(ctx, mul_instr->operands[0]);
+ fma->operands[1] = copy_operand(ctx, mul_instr->operands[1]);
fma->operands[2] = instr->operands[add_op_idx];
fma->clamp = vop3p->clamp;
- fma->opsel_lo = mul->opsel_lo;
- fma->opsel_hi = mul->opsel_hi;
- propagate_swizzles(fma.get(), opsel_lo, opsel_hi);
+ fma->neg_lo = mul_neg_lo;
+ fma->neg_hi = mul_neg_hi;
+ fma->opsel_lo = mul_opsel_lo;
+ fma->opsel_hi = mul_opsel_hi;
+ propagate_swizzles(fma.get(), vop3p->opsel_lo[1 - add_op_idx],
+ vop3p->opsel_hi[1 - add_op_idx]);
fma->opsel_lo[2] = vop3p->opsel_lo[add_op_idx];
fma->opsel_hi[2] = vop3p->opsel_hi[add_op_idx];
fma->neg_lo[2] = vop3p->neg_lo[add_op_idx];