aco: use v_fma_mix_f32 for v_fma_f32 with 2 fp16 representable, different literals
authorGeorg Lehmann <dadschoorse@gmail.com>
Mon, 9 Jan 2023 11:32:56 +0000 (12:32 +0100)
committerMarge Bot <emma+marge@anholt.net>
Thu, 2 Mar 2023 10:59:05 +0000 (10:59 +0000)
We can pack two fp16 literals into one 32bit literal and use opsel to select
the correct value. Note that LLVM currently disassembles these instructions
incorrectly.

Foz-DB Navi21:
Totals from 13365 (9.91% of 134913) affected shaders:
VGPRs: 840880 -> 840016 (-0.10%); split: -0.11%, +0.01%
SpillSGPRs: 724 -> 722 (-0.28%)
CodeSize: 82439364 -> 82451336 (+0.01%); split: -0.06%, +0.08%
MaxWaves: 244858 -> 244980 (+0.05%)
Instrs: 15265976 -> 15247201 (-0.12%); split: -0.13%, +0.01%
Latency: 223316180 -> 223272495 (-0.02%); split: -0.03%, +0.02%
InvThroughput: 41981375 -> 41969917 (-0.03%); split: -0.04%, +0.01%
VClause: 266775 -> 266558 (-0.08%); split: -0.14%, +0.06%
SClause: 646602 -> 645996 (-0.09%); split: -0.16%, +0.07%
Copies: 794703 -> 776075 (-2.34%); split: -2.46%, +0.12%
Branches: 296317 -> 296316 (-0.00%)
PreSGPRs: 658796 -> 656479 (-0.35%); split: -0.35%, +0.00%
PreVGPRs: 744014 -> 743679 (-0.05%)

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/20587>

src/amd/compiler/aco_optimizer.cpp

index d9ce6eb..e576c73 100644 (file)
@@ -76,9 +76,10 @@ struct mad_info {
    aco_ptr<Instruction> add_instr;
    uint32_t mul_temp_id;
    uint16_t literal_mask;
+   uint16_t fp16_mask;
 
    mad_info(aco_ptr<Instruction> instr, uint32_t id)
-       : add_instr(std::move(instr)), mul_temp_id(id), literal_mask(0)
+       : add_instr(std::move(instr)), mul_temp_id(id), literal_mask(0), fp16_mask(0)
    {}
 };
 
@@ -4755,8 +4756,7 @@ select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
          mad_info = NULL;
       }
       /* check literals */
-      else if (!instr->usesModifiers() && !instr->isVOP3P() &&
-               instr->opcode != aco_opcode::v_fma_f64 &&
+      else if (!instr->isDPP() && !instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_f64 &&
                instr->opcode != aco_opcode::v_mad_legacy_f32 &&
                instr->opcode != aco_opcode::v_fma_legacy_f32) {
          /* FMA can only take literals on GFX10+ */
@@ -4770,6 +4770,7 @@ select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
             return;
 
          uint32_t literal_mask = 0;
+         uint32_t fp16_mask = 0;
          uint32_t sgpr_mask = 0;
          uint32_t vgpr_mask = 0;
          uint32_t literal_uses = UINT32_MAX;
@@ -4782,6 +4783,13 @@ select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
                continue;
             if (ctx.info[op.tempId()].is_literal(get_operand_size(instr, i))) {
                uint32_t new_literal = ctx.info[op.tempId()].val;
+               float value = uif(new_literal);
+               uint16_t fp16_val = _mesa_float_to_half(value);
+               bool is_denorm = (fp16_val & 0x7fff) != 0 && (fp16_val & 0x7fff) <= 0x3ff;
+               if (_mesa_half_to_float(fp16_val) == value &&
+                   (!is_denorm || (ctx.fp_mode.denorm16_64 & fp_denorm_keep_in)))
+                  fp16_mask |= 1 << i;
+
                if (!literal_mask || literal_value == new_literal) {
                   literal_value = new_literal;
                   literal_uses = MIN2(literal_uses, ctx.uses[op.tempId()]);
@@ -4805,6 +4813,24 @@ select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
          if (!(literal_mask & 0b100) && !(vgpr_mask & 0b100))
             literal_mask = 0;
 
+         if (instr->usesModifiers())
+            literal_mask = 0;
+
+         /* We can't use three unique fp16 literals */
+         if (fp16_mask == 0b111)
+            fp16_mask = 0b11;
+
+         if ((instr->opcode == aco_opcode::v_fma_f32 ||
+              (instr->opcode == aco_opcode::v_mad_f32 && !instr->definitions[0].isPrecise())) &&
+             !instr->vop3().omod && ctx.program->gfx_level >= GFX10 &&
+             util_bitcount(fp16_mask) > std::max<uint32_t>(util_bitcount(literal_mask), 1)) {
+            assert(ctx.program->dev.fused_mad_mix);
+            u_foreach_bit (i, fp16_mask)
+               ctx.uses[instr->operands[i].tempId()]--;
+            mad_info->fp16_mask = fp16_mask;
+            return;
+         }
+
          /* Limit the number of literals to apply to not increase the code
           * size too much, but always apply literals for v_mad->v_madak
           * because both instructions are 64-bit and this doesn't increase
@@ -5159,8 +5185,41 @@ apply_literals(opt_ctx& ctx, aco_ptr<Instruction>& instr)
       mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].instr->pass_flags];
       const bool madak = (info->literal_mask & 0b100);
       bool has_dead_literal = false;
-      u_foreach_bit (i, info->literal_mask)
+      u_foreach_bit (i, info->literal_mask | info->fp16_mask)
          has_dead_literal |= ctx.uses[instr->operands[i].tempId()] == 0;
+
+      if (has_dead_literal && info->fp16_mask) {
+         aco_ptr<Instruction> fma_mix(
+            create_instruction<VOP3P_instruction>(aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1));
+
+         fma_mix->vop3p().clamp = instr->vop3().clamp;
+         std::copy(std::cbegin(instr->vop3().abs), std::cend(instr->vop3().abs),
+                   std::begin(fma_mix->vop3p().neg_hi));
+         std::copy(std::cbegin(instr->vop3().neg), std::cend(instr->vop3().neg),
+                   std::begin(fma_mix->vop3p().neg_lo));
+
+         uint32_t literal = 0;
+         bool second = false;
+         u_foreach_bit (i, info->fp16_mask) {
+            float value = uif(ctx.info[instr->operands[i].tempId()].val);
+            literal |= _mesa_float_to_half(value) << (second * 16);
+            fma_mix->vop3p().opsel_lo |= second << i;
+            fma_mix->vop3p().opsel_hi |= 1 << i;
+            second = true;
+         }
+
+         for (unsigned i = 0; i < 3; i++) {
+            if (info->fp16_mask & (1 << i))
+               fma_mix->operands[i] = Operand::literal32(literal);
+            else
+               fma_mix->operands[i] = instr->operands[i];
+         }
+
+         fma_mix->definitions[0] = instr->definitions[0];
+         ctx.instructions.emplace_back(std::move(fma_mix));
+         return;
+      }
+
       if (has_dead_literal || madak) {
          aco_ptr<Instruction> new_mad;