aco: propagate swizzles when optimizing packed clamp & fma
author Daniel Schürmann <daniel@schuermann.dev>
Thu, 7 Jan 2021 14:07:09 +0000 (15:07 +0100)
committer Marge Bot <eric+marge@anholt.net>
Wed, 13 Jan 2021 17:46:56 +0000 (17:46 +0000)
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6680>

src/amd/compiler/aco_optimizer.cpp

index 6ce2567..646f06e 100644 (file)
@@ -2726,6 +2726,28 @@ bool combine_add_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr)
    return false;
 }
 
+void propagate_swizzles(VOP3P_instruction* instr, uint8_t opsel_lo, uint8_t opsel_hi)
+{
+   /* Propagate a swizzle which applies to instr's result down to instr's own operand modifiers:
+    * result = a.xy + b.xx -> result.yx = a.yx + b.xx. opsel_lo == 0 with opsel_hi == 1 changes nothing. */
+   assert((opsel_lo & 1) == opsel_lo); /* each selector must be a single bit: 0 or 1 */
+   assert((opsel_hi & 1) == opsel_hi);
+   uint8_t tmp_lo = instr->opsel_lo; /* snapshot lo/hi state first: each branch below overwrites one half */
+   uint8_t tmp_hi = instr->opsel_hi;
+   bool neg_lo[3] = { instr->neg_lo[0], instr->neg_lo[1], instr->neg_lo[2] };
+   bool neg_hi[3] = { instr->neg_hi[0], instr->neg_hi[1], instr->neg_hi[2] };
+   if (opsel_lo == 1) { /* lo selector is 1: lo modifiers become a copy of the saved hi modifiers */
+      instr->opsel_lo = tmp_hi;
+      for (unsigned i = 0; i < 3; i++)
+         instr->neg_lo[i] = neg_hi[i];
+   }
+   if (opsel_hi == 0) { /* hi selector is 0: hi modifiers become a copy of the saved lo modifiers */
+      instr->opsel_hi = tmp_lo;
+      for (unsigned i = 0; i < 3; i++)
+         instr->neg_hi[i] = neg_lo[i];
+   }
+}
+
 void combine_vop3p(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
 {
    VOP3P_instruction* vop3p = static_cast<VOP3P_instruction*>(instr.get());
@@ -2734,15 +2756,14 @@ void combine_vop3p(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
    if (instr->opcode == aco_opcode::v_pk_mul_f16 &&
        instr->operands[1].constantEquals(0x3C00) &&
        vop3p->clamp &&
-       vop3p->opsel_lo == 0x0 &&
-       vop3p->opsel_hi == 0x1 &&
        instr->operands[0].isTemp() &&
        ctx.uses[instr->operands[0].tempId()] == 1) {
 
       ssa_info& info = ctx.info[instr->operands[0].tempId()];
       if (info.is_vop3p() && instr_info.can_use_output_modifiers[(int)info.instr->opcode]) {
-         Instruction* candidate = ctx.info[instr->operands[0].tempId()].instr;
-         static_cast<VOP3P_instruction*>(candidate)->clamp = true;
+         VOP3P_instruction* candidate = static_cast<VOP3P_instruction*>(ctx.info[instr->operands[0].tempId()].instr);
+         candidate->clamp = true;
+         propagate_swizzles(candidate, vop3p->opsel_lo, vop3p->opsel_hi);
          std::swap(instr->definitions[0], candidate->definitions[0]);
          ctx.info[candidate->definitions[0].tempId()].instr = candidate;
          ctx.uses[instr->definitions[0].tempId()]--;
@@ -2794,6 +2815,7 @@ void combine_vop3p(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
 
       Instruction* mul_instr = nullptr;
       unsigned add_op_idx = 0;
+      uint8_t opsel_lo = 0, opsel_hi = 0;
       uint32_t uses = UINT32_MAX;
 
       /* find the 'best' mul instruction to combine with the add */
@@ -2809,16 +2831,14 @@ void combine_vop3p(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
              !check_vop3_operands(ctx, 3, op))
             continue;
 
-         /* opsel of mul needs to be .xy */
-         if (static_cast<VOP3P_instruction*>(instr.get())->opsel_lo & (1 << i) ||
-             !(static_cast<VOP3P_instruction*>(instr.get())->opsel_hi & (1 << i)))
-            continue;
          /* no clamp allowed between mul and add */
          if (static_cast<VOP3P_instruction*>(info.instr)->clamp)
             continue;
 
          mul_instr = info.instr;
          add_op_idx = 1 - i;
+         opsel_lo = (vop3p->opsel_lo >> i) & 1;
+         opsel_hi = (vop3p->opsel_hi >> i) & 1;
          uses = ctx.uses[instr->operands[i].tempId()];
       }
 
@@ -2845,11 +2865,14 @@ void combine_vop3p(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
          fma->neg_hi[i] = mul->neg_hi[i];
       }
       fma->operands[2] = op[2];
+      fma->clamp = vop3p->clamp;
+      fma->opsel_lo = mul->opsel_lo;
+      fma->opsel_hi = mul->opsel_hi;
+      propagate_swizzles(fma.get(), opsel_lo, opsel_hi);
+      fma->opsel_lo |= (vop3p->opsel_lo << (2 - add_op_idx)) & 0x4;
+      fma->opsel_hi |= (vop3p->opsel_hi << (2 - add_op_idx)) & 0x4;
       fma->neg_lo[2] = vop3p->neg_lo[add_op_idx];
       fma->neg_hi[2] = vop3p->neg_hi[add_op_idx];
-      fma->clamp = vop3p->clamp;
-      fma->opsel_lo = mul->opsel_lo | ((vop3p->opsel_lo << (2 - add_op_idx)) & 0x4);
-      fma->opsel_hi = mul->opsel_hi | ((vop3p->opsel_hi << (2 - add_op_idx)) & 0x4);
       fma->neg_lo[1] = fma->neg_lo[1] ^ vop3p->neg_lo[1 - add_op_idx];
       fma->neg_hi[1] = fma->neg_hi[1] ^ vop3p->neg_hi[1 - add_op_idx];
       fma->definitions[0] = instr->definitions[0];