aco: support v_fma_f32_dpp as fma_mix
authorGeorg Lehmann <dadschoorse@gmail.com>
Thu, 14 Sep 2023 11:15:39 +0000 (13:15 +0200)
committerMarge Bot <emma+marge@anholt.net>
Thu, 5 Oct 2023 20:02:53 +0000 (20:02 +0000)
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25220>

src/amd/compiler/aco_optimizer.cpp

index c58f41201c37f0b21feada93dd927fef13165573..5bb52baf1770e9a382d9e353f3730ae041926277 100644 (file)
@@ -3929,28 +3929,34 @@ can_use_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
    if (ctx.program->gfx_level == GFX9 && ctx.fp_mode.denorm16_64)
       return false;
 
+   if (instr->valu().omod)
+      return false;
+
    switch (instr->opcode) {
    case aco_opcode::v_add_f32:
    case aco_opcode::v_sub_f32:
    case aco_opcode::v_subrev_f32:
-   case aco_opcode::v_mul_f32:
-   case aco_opcode::v_fma_f32: break;
+   case aco_opcode::v_mul_f32: return !instr->isSDWA() && !instr->isDPP();
+   case aco_opcode::v_fma_f32:
+      return ctx.program->dev.fused_mad_mix || !instr->definitions[0].isPrecise();
    case aco_opcode::v_fma_mix_f32:
    case aco_opcode::v_fma_mixlo_f16: return true;
    default: return false;
    }
-
-   if (instr->opcode == aco_opcode::v_fma_f32 && !ctx.program->dev.fused_mad_mix &&
-       instr->definitions[0].isPrecise())
-      return false;
-
-   return !instr->valu().omod && !instr->isSDWA() && !instr->isDPP();
 }
 
 void
 to_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
 {
-   bool is_add = instr->opcode != aco_opcode::v_mul_f32 && instr->opcode != aco_opcode::v_fma_f32;
+   ctx.info[instr->definitions[0].tempId()].label &= label_f2f16 | label_clamp | label_mul;
+
+   if (instr->opcode == aco_opcode::v_fma_f32) {
+      instr->format = (Format)((uint32_t)withoutVOP3(instr->format) | (uint32_t)(Format::VOP3P));
+      instr->opcode = aco_opcode::v_fma_mix_f32;
+      return;
+   }
+
+   bool is_add = instr->opcode != aco_opcode::v_mul_f32;
 
    aco_ptr<VALU_instruction> vop3p{
       create_instruction<VALU_instruction>(aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1)};
@@ -3975,7 +3981,6 @@ to_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
    vop3p->pass_flags = instr->pass_flags;
    instr = std::move(vop3p);
 
-   ctx.info[instr->definitions[0].tempId()].label &= label_f2f16 | label_clamp | label_mul;
    if (ctx.info[instr->definitions[0].tempId()].label & label_mul)
       ctx.info[instr->definitions[0].tempId()].instr = instr.get();
 }
@@ -4044,6 +4049,8 @@ combine_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
       op[i] = conv->operands[0];
       if (!check_vop3_operands(ctx, instr->operands.size(), op))
          continue;
+      if (!conv->operands[0].isOfType(RegType::vgpr) && instr->isDPP())
+         continue;
 
       if (!instr->isVOP3P()) {
          bool is_add =