aco: swap operands if necessary to create v_madak/v_fmaak
authorRhys Perry <pendingchaos02@gmail.com>
Thu, 13 May 2021 12:34:52 +0000 (13:34 +0100)
committerMarge Bot <emma+marge@anholt.net>
Mon, 13 Dec 2021 11:22:33 +0000 (11:22 +0000)
Also rewrite the check_literal logic to be more straightforward.

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/9805>

src/amd/compiler/aco_optimizer.cpp

index 548b9e0..b2569a0 100644 (file)
@@ -3731,33 +3731,46 @@ select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
          if (instr->opcode == aco_opcode::v_fma_legacy_f16)
             return;
 
-         bool sgpr_used = false;
          uint32_t literal_idx = 0;
          uint32_t literal_uses = UINT32_MAX;
-         for (unsigned i = 0; i < instr->operands.size(); i++) {
-            if (instr->operands[i].isConstant() && i > 0) {
-               literal_uses = UINT32_MAX;
-               break;
+
+         /* Try using v_madak/v_fmaak */
+         if (instr->operands[2].isTemp() &&
+             ctx.info[instr->operands[2].tempId()].is_literal(get_operand_size(instr, 2))) {
+            bool has_sgpr = false;
+            bool has_vgpr = false;
+            for (unsigned i = 0; i < 2; i++) {
+               if (!instr->operands[i].isTemp())
+                  continue;
+               has_sgpr |= instr->operands[i].getTemp().type() == RegType::sgpr;
+               has_vgpr |= instr->operands[i].getTemp().type() == RegType::vgpr;
             }
-            if (!instr->operands[i].isTemp())
-               continue;
-            unsigned bits = get_operand_size(instr, i);
-            /* if one of the operands is sgpr, we cannot add a literal somewhere else on pre-GFX10
-             * or operands other than the 1st */
-            if (instr->operands[i].getTemp().type() == RegType::sgpr &&
-                (i > 0 || ctx.program->chip_class < GFX10)) {
-               if (!sgpr_used && ctx.info[instr->operands[i].tempId()].is_literal(bits)) {
-                  literal_uses = ctx.uses[instr->operands[i].tempId()];
+            /* Encoding limitations requires a VGPR operand. The constant bus limitations before
+             * GFX10 disallows SGPRs.
+             */
+            if ((!has_sgpr || ctx.program->chip_class >= GFX10) && has_vgpr) {
+               literal_idx = 2;
+               literal_uses = ctx.uses[instr->operands[2].tempId()];
+            }
+         }
+
+         /* Try using v_madmk/v_fmamk */
+         /* Encoding limitations requires a VGPR operand. */
+         if (instr->operands[2].isTemp() && instr->operands[2].getTemp().type() == RegType::vgpr) {
+            for (unsigned i = 0; i < 2; i++) {
+               if (!instr->operands[i].isTemp())
+                  continue;
+
+               /* The constant bus limitations before GFX10 disallows SGPRs. */
+               if (ctx.program->chip_class < GFX10 && instr->operands[!i].isTemp() &&
+                   instr->operands[!i].getTemp().type() == RegType::sgpr)
+                  continue;
+
+               if (ctx.info[instr->operands[i].tempId()].is_literal(get_operand_size(instr, i)) &&
+                   ctx.uses[instr->operands[i].tempId()] < literal_uses) {
                   literal_idx = i;
-               } else {
-                  literal_uses = UINT32_MAX;
+                  literal_uses = ctx.uses[instr->operands[i].tempId()];
                }
-               sgpr_used = true;
-               /* don't break because we still need to check constants */
-            } else if (!sgpr_used && ctx.info[instr->operands[i].tempId()].is_literal(bits) &&
-                       ctx.uses[instr->operands[i].tempId()] < literal_uses) {
-               literal_uses = ctx.uses[instr->operands[i].tempId()];
-               literal_idx = i;
             }
          }
 
@@ -3953,6 +3966,9 @@ apply_literals(opt_ctx& ctx, aco_ptr<Instruction>& instr)
          if (info->literal_idx == 2) { /* add literal -> madak */
             new_mad->operands[0] = instr->operands[0];
             new_mad->operands[1] = instr->operands[1];
+            if (!new_mad->operands[1].isTemp() ||
+                new_mad->operands[1].getTemp().type() == RegType::sgpr)
+               std::swap(new_mad->operands[0], new_mad->operands[1]);
          } else { /* mul literal -> madmk */
             new_mad->operands[0] = instr->operands[1 - info->literal_idx];
             new_mad->operands[1] = instr->operands[2];