aco/optimizer: simplify using VALU instruction
authorGeorg Lehmann <dadschoorse@gmail.com>
Tue, 31 Jan 2023 17:03:01 +0000 (18:03 +0100)
committerMarge Bot <emma+marge@anholt.net>
Tue, 7 Mar 2023 11:53:23 +0000 (11:53 +0000)
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21023>

src/amd/compiler/aco_optimizer.cpp

index e576c73..eb976b9 100644 (file)
@@ -1423,24 +1423,15 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
             if (!instr->isDPP() && !instr->isSDWA())
                to_VOP3(ctx, instr);
             instr->operands[i].setTemp(info.temp);
-            if (instr->isDPP16() && !instr->dpp16().abs[i])
-               instr->dpp16().neg[i] = true;
-            else if (instr->isSDWA() && !instr->sdwa().abs[i])
-               instr->sdwa().neg[i] = true;
-            else if (instr->isVOP3() && !instr->vop3().abs[i])
-               instr->vop3().neg[i] = true;
+            if (!instr->valu().abs[i])
+               instr->valu().neg[i] = true;
          }
          if (info.is_abs() && can_use_mod && mod_bitsize_compat &&
              can_eliminate_fcanonicalize(ctx, instr, info.temp)) {
             if (!instr->isDPP() && !instr->isSDWA())
                to_VOP3(ctx, instr);
             instr->operands[i] = Operand(info.temp);
-            if (instr->isDPP16())
-               instr->dpp16().abs[i] = true;
-            else if (instr->isSDWA())
-               instr->sdwa().abs[i] = true;
-            else
-               instr->vop3().abs[i] = true;
+            instr->valu().abs[i] = true;
             continue;
          }
 
@@ -3511,25 +3502,6 @@ apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
    }
 }
 
-template <typename T>
-bool
-apply_omod_clamp_helper(opt_ctx& ctx, T* instr, ssa_info& def_info)
-{
-   if (!def_info.is_clamp() && (instr->clamp || instr->omod))
-      return false;
-
-   if (def_info.is_omod2())
-      instr->omod = 1;
-   else if (def_info.is_omod4())
-      instr->omod = 2;
-   else if (def_info.is_omod5())
-      instr->omod = 3;
-   else if (def_info.is_clamp())
-      instr->clamp = true;
-
-   return true;
-}
-
 /* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */
 bool
 apply_omod_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr)
@@ -3569,17 +3541,20 @@ apply_omod_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr)
    /* MADs/FMAs are created later, so we don't have to update the original add */
    assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
 
-   if (instr->isSDWA()) {
-      if (!apply_omod_clamp_helper(ctx, &instr->sdwa(), def_info))
-         return false;
-   } else if (instr->isVOP3P()) {
-      assert(def_info.is_clamp());
-      instr->vop3p().clamp = true;
-   } else {
+   if (!instr->isSDWA() && !instr->isVOP3P())
       to_VOP3(ctx, instr);
-      if (!apply_omod_clamp_helper(ctx, &instr->vop3(), def_info))
-         return false;
-   }
+
+   if (!def_info.is_clamp() && (instr->valu().clamp || instr->valu().omod))
+      return false;
+
+   if (def_info.is_omod2())
+      instr->valu().omod = 1;
+   else if (def_info.is_omod4())
+      instr->valu().omod = 2;
+   else if (def_info.is_omod5())
+      instr->valu().omod = 3;
+   else if (def_info.is_clamp())
+      instr->valu().clamp = true;
 
    instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
    ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_insert | label_f2f16;
@@ -3998,10 +3973,8 @@ to_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
    vop3p->opsel_hi = 0x0;
    for (unsigned i = 0; i < instr->operands.size(); i++) {
       vop3p->operands[is_add + i] = instr->operands[i];
-      vop3p->neg_lo[is_add + i] = instr->isVOP3() && instr->vop3().neg[i];
-      vop3p->neg_lo[is_add + i] |= instr->isSDWA() && instr->sdwa().neg[i];
-      vop3p->neg_hi[is_add + i] = instr->isVOP3() && instr->vop3().abs[i];
-      vop3p->neg_hi[is_add + i] |= instr->isSDWA() && instr->sdwa().abs[i];
+      vop3p->neg_lo[is_add + i] = instr->valu().neg[i];
+      vop3p->neg_hi[is_add + i] = instr->valu().abs[i];
       vop3p->opsel_lo |= (instr->isSDWA() && instr->sdwa().sel[i].offset()) << (is_add + i);
    }
    if (instr->opcode == aco_opcode::v_mul_f32) {
@@ -4105,11 +4078,11 @@ combine_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
       instr->vop3p().opsel_hi ^= 1u << i;
       if (conv->isSDWA() && conv->sdwa().sel[0].offset() == 2)
          instr->vop3p().opsel_lo |= 1u << i;
-      bool neg = (conv->isVOP3() && conv->vop3().neg[0]) || (conv->isSDWA() && conv->sdwa().neg[0]);
-      bool abs = (conv->isVOP3() && conv->vop3().abs[0]) || (conv->isSDWA() && conv->sdwa().abs[0]);
-      if (!instr->vop3p().neg_hi[i]) {
-         instr->vop3p().neg_lo[i] ^= neg;
-         instr->vop3p().neg_hi[i] = abs;
+      bool neg = conv->valu().neg[0];
+      bool abs = conv->valu().abs[0];
+      if (!instr->vop3p().abs[i]) {
+         instr->vop3p().neg[i] ^= neg;
+         instr->vop3p().abs[i] = abs;
       }
    }
 }
@@ -4375,55 +4348,32 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
          bool clamp = false;
          uint8_t opsel_lo = 0;
          uint8_t opsel_hi = 0;
-
-         if (mul_instr->isVOP3()) {
-            VOP3_instruction& vop3 = mul_instr->vop3();
-            neg[0] = vop3.neg[0];
-            neg[1] = vop3.neg[1];
-            abs[0] = vop3.abs[0];
-            abs[1] = vop3.abs[1];
-         } else if (mul_instr->isVOP3P()) {
-            VOP3P_instruction& vop3p = mul_instr->vop3p();
-            neg[0] = vop3p.neg_lo[0];
-            neg[1] = vop3p.neg_lo[1];
-            abs[0] = vop3p.neg_hi[0];
-            abs[1] = vop3p.neg_hi[1];
-            opsel_lo = vop3p.opsel_lo & 0x3;
-            opsel_hi = vop3p.opsel_hi & 0x3;
-         }
-
-         if (instr->isVOP3()) {
-            VOP3_instruction& vop3 = instr->vop3();
-            neg[2] = vop3.neg[add_op_idx];
-            abs[2] = vop3.abs[add_op_idx];
-            omod = vop3.omod;
-            clamp = vop3.clamp;
-            /* abs of the multiplication result */
-            if (vop3.abs[1 - add_op_idx]) {
-               neg[0] = false;
-               neg[1] = false;
-               abs[0] = true;
-               abs[1] = true;
-            }
-            /* neg of the multiplication result */
-            neg[1] = neg[1] ^ vop3.neg[1 - add_op_idx];
-         } else if (instr->isVOP3P()) {
-            VOP3P_instruction& vop3p = instr->vop3p();
-            neg[2] = vop3p.neg_lo[add_op_idx];
-            abs[2] = vop3p.neg_hi[add_op_idx];
-            opsel_lo |= vop3p.opsel_lo & (1 << add_op_idx) ? 0x4 : 0x0;
-            opsel_hi |= vop3p.opsel_hi & (1 << add_op_idx) ? 0x4 : 0x0;
-            clamp = vop3p.clamp;
-            /* abs of the multiplication result */
-            if (vop3p.neg_hi[3 - add_op_idx]) {
-               neg[0] = false;
-               neg[1] = false;
-               abs[0] = true;
-               abs[1] = true;
-            }
-            /* neg of the multiplication result */
-            neg[1] = neg[1] ^ vop3p.neg_lo[3 - add_op_idx];
+         unsigned mul_op_idx = (instr->isVOP3P() ? 3 : 1) - add_op_idx;
+
+         VALU_instruction& valu_mul = mul_instr->valu();
+         neg[0] = valu_mul.neg[0];
+         neg[1] = valu_mul.neg[1];
+         abs[0] = valu_mul.abs[0];
+         abs[1] = valu_mul.abs[1];
+         opsel_lo = valu_mul.opsel_lo & 0x3;
+         opsel_hi = valu_mul.opsel_hi & 0x3;
+
+         VALU_instruction& valu = instr->valu();
+         neg[2] = valu.neg[add_op_idx];
+         abs[2] = valu.abs[add_op_idx];
+         opsel_lo |= valu.opsel_lo & (1 << add_op_idx) ? 0x4 : 0x0;
+         opsel_hi |= valu.opsel_hi & (1 << add_op_idx) ? 0x4 : 0x0;
+         omod = valu.omod;
+         clamp = valu.clamp;
+         /* abs of the multiplication result */
+         if (valu.abs[mul_op_idx]) {
+            neg[0] = false;
+            neg[1] = false;
+            abs[0] = true;
+            abs[1] = true;
          }
+         /* neg of the multiplication result */
+         neg[1] ^= valu.neg[mul_op_idx];
 
          if (instr->opcode == aco_opcode::v_sub_f32 || instr->opcode == aco_opcode::v_sub_f16)
             neg[1 + add_op_idx] = neg[1 + add_op_idx] ^ true;
@@ -4432,24 +4382,17 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
             neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;
 
          aco_ptr<Instruction> add_instr = std::move(instr);
+         aco_ptr<VALU_instruction> mad;
          if (add_instr->isVOP3P() || mul_instr->isVOP3P()) {
             assert(!omod);
 
             aco_opcode mad_op = add_instr->definitions[0].bytes() == 2 ? aco_opcode::v_fma_mixlo_f16
                                                                        : aco_opcode::v_fma_mix_f32;
-            aco_ptr<VOP3P_instruction> mad{
-               create_instruction<VOP3P_instruction>(mad_op, Format::VOP3P, 3, 1)};
-            for (unsigned i = 0; i < 3; i++) {
-               mad->operands[i] = op[i];
-               mad->neg_lo[i] = neg[i];
-               mad->neg_hi[i] = abs[i];
-            }
-            mad->clamp = clamp;
-            mad->opsel_lo = opsel_lo;
-            mad->opsel_hi = opsel_hi;
-
-            instr = std::move(mad);
+            mad.reset(create_instruction<VOP3P_instruction>(mad_op, Format::VOP3P, 3, 1));
          } else {
+            assert(!opsel_lo);
+            assert(!opsel_hi);
+
             aco_opcode mad_op = emit_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
             if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32) {
                assert(emit_fma == (ctx.program->gfx_level >= GFX10_3));
@@ -4463,21 +4406,23 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
                mad_op = aco_opcode::v_fma_f64;
             }
 
-            aco_ptr<VOP3_instruction> mad{
-               create_instruction<VOP3_instruction>(mad_op, Format::VOP3, 3, 1)};
-            for (unsigned i = 0; i < 3; i++) {
-               mad->operands[i] = op[i];
-               mad->neg[i] = neg[i];
-               mad->abs[i] = abs[i];
-            }
-            mad->omod = omod;
-            mad->clamp = clamp;
+            mad.reset(create_instruction<VOP3_instruction>(mad_op, Format::VOP3, 3, 1));
+         }
 
-            instr = std::move(mad);
+         for (unsigned i = 0; i < 3; i++) {
+            mad->operands[i] = op[i];
+            mad->neg[i] = neg[i];
+            mad->abs[i] = abs[i];
          }
-         instr->definitions[0] = add_instr->definitions[0];
-         instr->definitions[0].setPrecise(add_instr->definitions[0].isPrecise() ||
-                                          mul_instr->definitions[0].isPrecise());
+         mad->omod = omod;
+         mad->clamp = clamp;
+         mad->opsel_lo = opsel_lo;
+         mad->opsel_hi = opsel_hi;
+         mad->definitions[0] = add_instr->definitions[0];
+         mad->definitions[0].setPrecise(add_instr->definitions[0].isPrecise() ||
+                                        mul_instr->definitions[0].isPrecise());
+
+         instr = std::move(mad);
 
          /* mark this ssa_def to be re-checked for profitability and literals */
          ctx.mad_infos.emplace_back(std::move(add_instr), mul_instr->definitions[0].tempId());