aco: handle SDWA in the optimizer
authorRhys Perry <pendingchaos02@gmail.com>
Thu, 5 Dec 2019 14:12:39 +0000 (14:12 +0000)
committerMarge Bot <eric+marge@anholt.net>
Thu, 29 Oct 2020 18:08:31 +0000 (18:08 +0000)
Apply SGPRs/modifiers when possible and try not to break when SDWA
instructions are encountered.

No shader-db changes.

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/7349>

src/amd/compiler/aco_optimizer.cpp

index 48cf5e3..9a8d6d9 100644 (file)
@@ -619,8 +619,10 @@ bool can_use_VOP3(opt_ctx& ctx, const aco_ptr<Instruction>& instr)
           instr->opcode != aco_opcode::v_readfirstlane_b32;
 }
 
-bool can_apply_sgprs(aco_ptr<Instruction>& instr)
+bool can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
 {
+   if (instr->isSDWA() && ctx.program->chip_class < GFX9)
+      return false;
    return instr->opcode != aco_opcode::v_readfirstlane_b32 &&
           instr->opcode != aco_opcode::v_readlane_b32 &&
           instr->opcode != aco_opcode::v_readlane_b32_e64 &&
@@ -891,7 +893,7 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
             info = ctx.info[info.temp.id()];
          }
          /* applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier */
-         if (info.is_temp() && info.temp.type() == RegType::sgpr && can_apply_sgprs(instr) && instr->operands.size() == 1) {
+         if (info.is_temp() && info.temp.type() == RegType::sgpr && can_apply_sgprs(ctx, instr) && instr->operands.size() == 1) {
             instr->operands[i].setTemp(info.temp);
             info = ctx.info[info.temp.id()];
          }
@@ -900,12 +902,19 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
          unsigned can_use_mod = instr->opcode != aco_opcode::v_cndmask_b32 || instr->operands[i].getTemp().bytes() == 4;
          can_use_mod = can_use_mod && instr_info.can_use_input_modifiers[(int)instr->opcode];
 
-         if (info.is_abs() && (can_use_VOP3(ctx, instr) || instr->isDPP()) && can_use_mod) {
-            if (!instr->isDPP())
+         if (instr->isSDWA())
+            can_use_mod = can_use_mod && (static_cast<SDWA_instruction*>(instr.get())->sel[i] & sdwa_asuint) == sdwa_udword;
+         else
+            can_use_mod = can_use_mod && (instr->isDPP() || can_use_VOP3(ctx, instr));
+
+         if (info.is_abs() && can_use_mod) {
+            if (!instr->isDPP() && !instr->isSDWA())
                to_VOP3(ctx, instr);
             instr->operands[i] = Operand(info.temp);
             if (instr->isDPP())
                static_cast<DPP_instruction*>(instr.get())->abs[i] = true;
+            else if (instr->isSDWA())
+               static_cast<SDWA_instruction*>(instr.get())->abs[i] = true;
             else
                static_cast<VOP3A_instruction*>(instr.get())->abs[i] = true;
          }
@@ -917,12 +926,14 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
             instr->opcode = i ? aco_opcode::v_sub_f16 : aco_opcode::v_subrev_f16;
             instr->operands[i].setTemp(info.temp);
             continue;
-         } else if (info.is_neg() && (can_use_VOP3(ctx, instr) || instr->isDPP()) && can_use_mod) {
-            if (!instr->isDPP())
+         } else if (info.is_neg() && can_use_mod) {
+            if (!instr->isDPP() && !instr->isSDWA())
                to_VOP3(ctx, instr);
             instr->operands[i].setTemp(info.temp);
             if (instr->isDPP())
                static_cast<DPP_instruction*>(instr.get())->neg[i] = true;
+            else if (instr->isSDWA())
+               static_cast<SDWA_instruction*>(instr.get())->neg[i] = true;
             else
                static_cast<VOP3A_instruction*>(instr.get())->neg[i] = true;
             continue;
@@ -932,7 +943,8 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
              (!instr->isSDWA() || ctx.program->chip_class >= GFX9)) {
             Operand op = get_constant_op(ctx, info, bits);
             perfwarn(ctx.program, instr->opcode == aco_opcode::v_cndmask_b32 && i == 2, "v_cndmask_b32 with a constant selector", instr.get());
-            if (i == 0 || instr->opcode == aco_opcode::v_readlane_b32 || instr->opcode == aco_opcode::v_writelane_b32) {
+            if (i == 0 || instr->isSDWA() || instr->opcode == aco_opcode::v_readlane_b32 ||
+                instr->opcode == aco_opcode::v_writelane_b32) {
                instr->operands[i] = op;
                continue;
             } else if (!instr->isVOP3() && can_swap_operands(instr)) {
@@ -1641,6 +1653,8 @@ bool combine_ordering_test(opt_ctx &ctx, aco_ptr<Instruction>& instr)
          neg[i] = vop3->neg[0];
          abs[i] = vop3->abs[0];
          opsel |= (vop3->opsel & 1) << i;
+      } else if (op_instr[i]->isSDWA()) {
+         return false;
       }
 
       Temp op0 = op_instr[i]->operands[0].getTemp();
@@ -1715,6 +1729,8 @@ bool combine_comparison_ordering(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    Instruction *cmp = follow_operand(ctx, instr->operands[1], true);
    if (!nan_test || !cmp)
       return false;
+   if (nan_test->isSDWA() || cmp->isSDWA())
+      return false;
 
    if (get_f32_cmp(cmp->opcode) == expected_nan_test)
       std::swap(nan_test, cmp);
@@ -1785,6 +1801,8 @@ bool combine_constant_comparison_ordering(opt_ctx &ctx, aco_ptr<Instruction>& in
 
    if (!nan_test || !cmp)
       return false;
+   if (nan_test->isSDWA() || cmp->isSDWA())
+      return false;
 
    aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
    if (get_f32_cmp(cmp->opcode) == expected_nan_test)
@@ -1906,6 +1924,18 @@ bool combine_inverse_comparison(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       new_vop3->omod = cmp_vop3->omod;
       new_vop3->opsel = cmp_vop3->opsel;
       new_instr = new_vop3;
+   } else if (cmp->isSDWA()) {
+      SDWA_instruction *new_sdwa = create_instruction<SDWA_instruction>(
+         new_opcode, (Format)((uint16_t)Format::SDWA | (uint16_t)Format::VOPC), 2, 1);
+      SDWA_instruction *cmp_sdwa = static_cast<SDWA_instruction*>(cmp);
+      memcpy(new_sdwa->abs, cmp_sdwa->abs, sizeof(new_sdwa->abs));
+      memcpy(new_sdwa->sel, cmp_sdwa->sel, sizeof(new_sdwa->sel));
+      memcpy(new_sdwa->neg, cmp_sdwa->neg, sizeof(new_sdwa->neg));
+      new_sdwa->dst_sel = cmp_sdwa->dst_sel;
+      new_sdwa->dst_preserve = cmp_sdwa->dst_preserve;
+      new_sdwa->clamp = cmp_sdwa->clamp;
+      new_sdwa->omod = cmp_sdwa->omod;
+      new_instr = new_sdwa;
    } else {
       new_instr = create_instruction<VOPC_instruction>(new_opcode, Format::VOPC, 2, 1);
    }
@@ -1942,6 +1972,9 @@ bool match_op3_for_vop3(opt_ctx &ctx, aco_opcode op1, aco_opcode op2,
    VOP3A_instruction *op1_vop3 = op1_instr->isVOP3() ? static_cast<VOP3A_instruction *>(op1_instr) : NULL;
    VOP3A_instruction *op2_vop3 = op2_instr->isVOP3() ? static_cast<VOP3A_instruction *>(op2_instr) : NULL;
 
+   if (op1_instr->isSDWA() || op2_instr->isSDWA())
+      return false;
+
    /* don't support inbetween clamp/omod */
    if (op2_vop3 && (op2_vop3->clamp || op2_vop3->omod))
       return false;
@@ -2431,7 +2464,7 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       /* Applying two sgprs require making it VOP3, so don't do it unless it's
        * definitively beneficial.
        * TODO: this is too conservative because later the use count could be reduced to 1 */
-      if (num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3())
+      if (num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3() && !instr->isSDWA())
          break;
 
       Temp sgpr = ctx.info[sgpr_info_id].temp;
@@ -2439,7 +2472,7 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       if (new_sgpr && num_sgprs >= max_sgprs)
          continue;
 
-      if (sgpr_idx == 0 || instr->isVOP3()) {
+      if (sgpr_idx == 0 || instr->isVOP3() || instr->isSDWA()) {
          instr->operands[sgpr_idx] = Operand(sgpr);
       } else if (can_swap_operands(instr)) {
          instr->operands[sgpr_idx] = instr->operands[0];
@@ -2461,22 +2494,20 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    }
 }
 
-bool apply_omod_clamp_helper(opt_ctx &ctx, aco_ptr<Instruction>& instr, ssa_info& def_info)
+template <typename T>
+bool apply_omod_clamp_helper(opt_ctx &ctx, T *instr, ssa_info& def_info)
 {
-   to_VOP3(ctx, instr);
-
-   if (!def_info.is_clamp() && (static_cast<VOP3A_instruction*>(instr.get())->clamp ||
-                                static_cast<VOP3A_instruction*>(instr.get())->omod))
+   if (!def_info.is_clamp() && (instr->clamp || instr->omod))
       return false;
 
    if (def_info.is_omod2())
-      static_cast<VOP3A_instruction*>(instr.get())->omod = 1;
+      instr->omod = 1;
    else if (def_info.is_omod4())
-      static_cast<VOP3A_instruction*>(instr.get())->omod = 2;
+      instr->omod = 2;
    else if (def_info.is_omod5())
-      static_cast<VOP3A_instruction*>(instr.get())->omod = 3;
+      instr->omod = 3;
    else if (def_info.is_clamp())
-      static_cast<VOP3A_instruction*>(instr.get())->clamp = true;
+      instr->clamp = true;
 
    return true;
 }
@@ -2488,11 +2519,14 @@ bool apply_omod_clamp(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
        !instr_info.can_use_output_modifiers[(int)instr->opcode])
       return false;
 
-   if (!can_use_VOP3(ctx, instr))
+   bool can_vop3 = can_use_VOP3(ctx, instr);
+   if (!instr->isSDWA() && !can_vop3)
       return false;
 
    /* omod has no effect if denormals are enabled */
    bool can_use_omod = (instr->definitions[0].bytes() == 4 ? block.fp_mode.denorm32 : block.fp_mode.denorm16_64) == 0;
+   can_use_omod = can_use_omod && (can_vop3 || ctx.program->chip_class >= GFX9); /* SDWA omod is GFX9+ */
+
    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
 
    uint64_t omod_labels = label_omod2 | label_omod4 | label_omod5;
@@ -2506,8 +2540,14 @@ bool apply_omod_clamp(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
    /* MADs/FMAs are created later, so we don't have to update the original add */
    assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
 
-   if (!apply_omod_clamp_helper(ctx, instr, def_info))
-      return false;
+   if (instr->isSDWA()) {
+      if (!apply_omod_clamp_helper(ctx, static_cast<SDWA_instruction *>(instr.get()), def_info))
+         return false;
+   } else {
+      to_VOP3(ctx, instr);
+      if (!apply_omod_clamp_helper(ctx, static_cast<VOP3A_instruction *>(instr.get()), def_info))
+         return false;
+   }
 
    std::swap(instr->definitions[0], def_info.instr->definitions[0]);
    ctx.info[instr->definitions[0].tempId()].label &= label_clamp;
@@ -2525,7 +2565,7 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
       return;
 
    if (instr->isVALU()) {
-      if (can_apply_sgprs(instr))
+      if (can_apply_sgprs(ctx, instr))
          apply_sgprs(ctx, instr);
       while (apply_omod_clamp(ctx, block, instr)) ;
    }
@@ -2534,6 +2574,9 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
       instr->definitions[0].setHint(vcc);
    }
 
+   if (instr->isSDWA())
+      return;
+
    /* TODO: There are still some peephole optimizations that could be done:
     * - abs(a - b) -> s_absdiff_i32
     * - various patterns for s_bitcmp{0,1}_b32 and s_bitset{0,1}_b32
@@ -2557,6 +2600,8 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
          return;
       if (mul_instr->isVOP3() && static_cast<VOP3A_instruction*>(mul_instr)->clamp)
          return;
+      if (mul_instr->isSDWA())
+         return;
 
       /* convert to mul(neg(a), b) */
       ctx.uses[mul_instr->definitions[0].tempId()]--;
@@ -2639,6 +2684,8 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
          // TODO: would be better to check this before selecting a mul instr?
          if (!check_vop3_operands(ctx, 3, op))
             return;
+         if (mul_instr->isSDWA())
+            return;
 
          if (mul_instr->isVOP3()) {
             VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*> (mul_instr);