aco: combine scalar mul+pk_add to pk_fma
authorGeorg Lehmann <dadschoorse@gmail.com>
Thu, 9 Mar 2023 13:51:50 +0000 (14:51 +0100)
committerMarge Bot <emma+marge@anholt.net>
Tue, 20 Jun 2023 14:29:24 +0000 (14:29 +0000)
Foz-DB Navi21:
Totals from 12 (0.01% of 134913) affected shaders:
CodeSize: 37860 -> 37668 (-0.51%)
Instrs: 6757 -> 6733 (-0.36%)
Latency: 25632 -> 25589 (-0.17%)
InvThroughput: 2637 -> 2622 (-0.57%)

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21819>

src/amd/compiler/aco_optimizer.cpp

index 411c63f..ad7dc02 100644 (file)
@@ -3821,56 +3821,89 @@ combine_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr)
 
       Instruction* mul_instr = nullptr;
       unsigned add_op_idx = 0;
-      bool opsel_lo = false, opsel_hi = false;
+      bitarray8 mul_neg_lo = 0, mul_neg_hi = 0, mul_opsel_lo = 0, mul_opsel_hi = 0;
       uint32_t uses = UINT32_MAX;
 
       /* find the 'best' mul instruction to combine with the add */
       for (unsigned i = 0; i < 2; i++) {
-         if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_vop3p())
+         Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
+         if (!op_instr)
             continue;
-         ssa_info& info = ctx.info[instr->operands[i].tempId()];
-         if (fadd) {
-            if (info.instr->opcode != aco_opcode::v_pk_mul_f16 ||
-                info.instr->definitions[0].isPrecise())
+
+         if (ctx.info[instr->operands[i].tempId()].is_vop3p()) {
+            if (fadd) {
+               if (op_instr->opcode != aco_opcode::v_pk_mul_f16 ||
+                   op_instr->definitions[0].isPrecise())
+                  continue;
+            } else {
+               if (op_instr->opcode != aco_opcode::v_pk_mul_lo_u16)
+                  continue;
+            }
+
+            Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]};
+            if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
                continue;
-         } else {
-            if (info.instr->opcode != aco_opcode::v_pk_mul_lo_u16)
+
+            /* no clamp allowed between mul and add */
+            if (op_instr->valu().clamp)
                continue;
-         }
 
-         Operand op[3] = {info.instr->operands[0], info.instr->operands[1], instr->operands[1 - i]};
-         if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
-            continue;
+            mul_instr = op_instr;
+            add_op_idx = 1 - i;
+            uses = ctx.uses[instr->operands[i].tempId()];
+            mul_neg_lo = mul_instr->valu().neg_lo;
+            mul_neg_hi = mul_instr->valu().neg_hi;
+            mul_opsel_lo = mul_instr->valu().opsel_lo;
+            mul_opsel_hi = mul_instr->valu().opsel_hi;
+         } else if (instr->operands[i].bytes() == 2) {
+            if ((fadd && (op_instr->opcode != aco_opcode::v_mul_f16 ||
+                          op_instr->definitions[0].isPrecise())) ||
+                (!fadd && op_instr->opcode != aco_opcode::v_mul_lo_u16 &&
+                 op_instr->opcode != aco_opcode::v_mul_lo_u16_e64))
+               continue;
 
-         /* no clamp allowed between mul and add */
-         if (info.instr->valu().clamp)
-            continue;
+            if (op_instr->valu().clamp || op_instr->valu().omod || op_instr->valu().abs)
+               continue;
 
-         mul_instr = info.instr;
-         add_op_idx = 1 - i;
-         opsel_lo = vop3p->opsel_lo[i];
-         opsel_hi = vop3p->opsel_hi[i];
-         uses = ctx.uses[instr->operands[i].tempId()];
+            if (op_instr->isDPP() || (op_instr->isSDWA() && (op_instr->sdwa().sel[0].size() < 2 ||
+                                                             op_instr->sdwa().sel[1].size() < 2)))
+               continue;
+
+            Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]};
+            if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
+               continue;
+
+            mul_instr = op_instr;
+            add_op_idx = 1 - i;
+            uses = ctx.uses[instr->operands[i].tempId()];
+            mul_neg_lo = mul_instr->valu().neg;
+            mul_neg_hi = mul_instr->valu().neg;
+            if (mul_instr->isSDWA()) {
+               for (unsigned j = 0; j < 2; j++)
+                  mul_opsel_lo[j] = mul_instr->sdwa().sel[j].offset();
+            } else {
+               mul_opsel_lo = mul_instr->valu().opsel;
+            }
+            mul_opsel_hi = mul_opsel_lo;
+         }
       }
 
       if (!mul_instr)
          return;
 
-      /* turn packed mul+add into v_pk_fma_f16 */
-      assert(mul_instr->isVOP3P());
+      /* turn mul + packed add into v_pk_fma_f16 */
       aco_opcode mad = fadd ? aco_opcode::v_pk_fma_f16 : aco_opcode::v_pk_mad_u16;
       aco_ptr<VALU_instruction> fma{create_instruction<VALU_instruction>(mad, Format::VOP3P, 3, 1)};
-      VALU_instruction* mul = &mul_instr->valu();
-      for (unsigned i = 0; i < 2; i++) {
-         fma->operands[i] = copy_operand(ctx, mul_instr->operands[i]);
-         fma->neg_lo[i] = mul->neg_lo[i];
-         fma->neg_hi[i] = mul->neg_hi[i];
-      }
+      fma->operands[0] = copy_operand(ctx, mul_instr->operands[0]);
+      fma->operands[1] = copy_operand(ctx, mul_instr->operands[1]);
       fma->operands[2] = instr->operands[add_op_idx];
       fma->clamp = vop3p->clamp;
-      fma->opsel_lo = mul->opsel_lo;
-      fma->opsel_hi = mul->opsel_hi;
-      propagate_swizzles(fma.get(), opsel_lo, opsel_hi);
+      fma->neg_lo = mul_neg_lo;
+      fma->neg_hi = mul_neg_hi;
+      fma->opsel_lo = mul_opsel_lo;
+      fma->opsel_hi = mul_opsel_hi;
+      propagate_swizzles(fma.get(), vop3p->opsel_lo[1 - add_op_idx],
+                         vop3p->opsel_hi[1 - add_op_idx]);
       fma->opsel_lo[2] = vop3p->opsel_lo[add_op_idx];
       fma->opsel_hi[2] = vop3p->opsel_hi[add_op_idx];
       fma->neg_lo[2] = vop3p->neg_lo[add_op_idx];