aco: Eliminate SALU comparison when SCC can be used instead.
authorTimur Kristóf <timur.kristof@gmail.com>
Thu, 15 Apr 2021 18:00:21 +0000 (20:00 +0200)
committerMarge Bot <eric+marge@anholt.net>
Fri, 28 May 2021 12:14:53 +0000 (12:14 +0000)
For example:

s0, scc = s_and_u32 ...
scc = s_cmp_eq_u32 s0, 0
p_cbranch_sccz

is turned into:

s0, scc = s_and_u32 ...
p_cbranch_sccnz

Fossil DB results on Sienna Cichlid:

Totals from 85267 (56.91% of 149839) affected shaders:
CodeSize: 202539256 -> 202237268 (-0.15%)
Instrs: 38964493 -> 38888996 (-0.19%)
Latency: 750062328 -> 749913450 (-0.02%); split: -0.02%, +0.00%
InvThroughput: 167408952 -> 167405157 (-0.00%)

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/7779>

src/amd/compiler/aco_optimizer_postRA.cpp
src/amd/compiler/tests/test_optimizer_postRA.cpp

index 12a143c..64b8f7e 100644 (file)
@@ -162,12 +162,150 @@ void try_apply_branch_vcc(pr_opt_ctx &ctx, aco_ptr<Instruction> &instr)
    instr->operands[0] = op0_instr->operands[0];
 }
 
+void try_optimize_scc_nocompare(pr_opt_ctx &ctx, aco_ptr<Instruction> &instr)
+{
+   /* We are looking for the following pattern:
+    *
+    * s_bfe_u32 s0, s3, 0x40018  ; outputs SGPR and SCC if the SGPR != 0
+    * s_cmp_eq_i32 s0, 0         ; comparison between the SGPR and 0
+    * s_cbranch_scc0 BB3         ; use the result of the comparison, eg. branch or cselect
+    *
+    * If possible, the above is optimized into:
+    *
+    * s_bfe_u32 s0, s3, 0x40018  ; original instruction
+    * s_cbranch_scc1 BB3         ; modified to use SCC directly rather than the SGPR with comparison
+    *
+    */
+
+   if (!instr->isSALU() && !instr->isBranch())
+      return;
+
+   if (instr->isSOPC() &&
+       (instr->opcode == aco_opcode::s_cmp_eq_u32 || instr->opcode == aco_opcode::s_cmp_eq_i32 ||
+        instr->opcode == aco_opcode::s_cmp_lg_u32 || instr->opcode == aco_opcode::s_cmp_lg_i32 ||
+        instr->opcode == aco_opcode::s_cmp_eq_u64 ||
+        instr->opcode == aco_opcode::s_cmp_lg_u64) &&
+       (instr->operands[0].constantEquals(0) || instr->operands[1].constantEquals(0)) &&
+       (instr->operands[0].isTemp() || instr->operands[1].isTemp())) {
+      /* Make sure the constant is always in operand 1 */
+      if (instr->operands[0].isConstant())
+         std::swap(instr->operands[0], instr->operands[1]);
+
+      if (ctx.uses[instr->operands[0].tempId()] > 1)
+         return;
+
+      /* Make sure both SCC and Operand 0 are written by the same instruction. */
+      int wr_idx = last_writer_idx(ctx, instr->operands[0]);
+      int sccwr_idx = last_writer_idx(ctx, scc, s1);
+      if (wr_idx < 0 || wr_idx != sccwr_idx)
+         return;
+
+      aco_ptr<Instruction> &wr_instr = ctx.current_block->instructions[wr_idx];
+      if (!wr_instr->isSALU() || wr_instr->definitions.size() < 2 || wr_instr->definitions[1].physReg() != scc)
+         return;
+
+      /* Look for instructions which set SCC := (D != 0) */
+      switch (wr_instr->opcode) {
+      case aco_opcode::s_bfe_i32:
+      case aco_opcode::s_bfe_i64:
+      case aco_opcode::s_bfe_u32:
+      case aco_opcode::s_bfe_u64:
+      case aco_opcode::s_and_b32:
+      case aco_opcode::s_and_b64:
+      case aco_opcode::s_andn2_b32:
+      case aco_opcode::s_andn2_b64:
+      case aco_opcode::s_or_b32:
+      case aco_opcode::s_or_b64:
+      case aco_opcode::s_orn2_b32:
+      case aco_opcode::s_orn2_b64:
+      case aco_opcode::s_xor_b32:
+      case aco_opcode::s_xor_b64:
+      case aco_opcode::s_not_b32:
+      case aco_opcode::s_not_b64:
+      case aco_opcode::s_nor_b32:
+      case aco_opcode::s_nor_b64:
+      case aco_opcode::s_xnor_b32:
+      case aco_opcode::s_xnor_b64:
+      case aco_opcode::s_nand_b32:
+      case aco_opcode::s_nand_b64:
+      case aco_opcode::s_lshl_b32:
+      case aco_opcode::s_lshl_b64:
+      case aco_opcode::s_lshr_b32:
+      case aco_opcode::s_lshr_b64:
+      case aco_opcode::s_ashr_i32:
+      case aco_opcode::s_ashr_i64:
+      case aco_opcode::s_abs_i32:
+      case aco_opcode::s_absdiff_i32:
+         break;
+      default:
+         return;
+      }
+
+      /* Use the SCC def from wr_instr */
+      ctx.uses[instr->operands[0].tempId()]--;
+      instr->operands[0] = Operand(wr_instr->definitions[1].getTemp(), scc);
+      ctx.uses[instr->operands[0].tempId()]++;
+
+      /* Set the opcode and operand to 32-bit */
+      instr->operands[1] = Operand(0u);
+      instr->opcode = (instr->opcode == aco_opcode::s_cmp_eq_u32 ||
+                       instr->opcode == aco_opcode::s_cmp_eq_i32 ||
+                       instr->opcode == aco_opcode::s_cmp_eq_u64)
+                      ? aco_opcode::s_cmp_eq_u32
+                      : aco_opcode::s_cmp_lg_u32;
+   } else if ((instr->format == Format::PSEUDO_BRANCH &&
+               instr->operands.size() == 1 &&
+               instr->operands[0].physReg() == scc) ||
+              instr->opcode == aco_opcode::s_cselect_b32) {
+
+      /* For cselect, operand 2 is the SCC condition */
+      unsigned scc_op_idx = 0;
+      if (instr->opcode == aco_opcode::s_cselect_b32) {
+         scc_op_idx = 2;
+      }
+
+      int wr_idx = last_writer_idx(ctx, instr->operands[scc_op_idx]);
+      if (wr_idx < 0)
+         return;
+
+      aco_ptr<Instruction> &wr_instr = ctx.current_block->instructions[wr_idx];
+
+      /* Check if we found the pattern above. */
+      if (wr_instr->opcode != aco_opcode::s_cmp_eq_u32 && wr_instr->opcode != aco_opcode::s_cmp_lg_u32)
+         return;
+      if (wr_instr->operands[0].physReg() != scc)
+         return;
+      if (!wr_instr->operands[1].constantEquals(0))
+         return;
+
+      /* The optimization can be unsafe when there are other users. */
+      if (ctx.uses[instr->operands[scc_op_idx].tempId()] > 1)
+         return;
+
+      if (wr_instr->opcode == aco_opcode::s_cmp_eq_u32) {
+         /* Flip the meaning of the instruction to correctly use the SCC. */
+         if (instr->format == Format::PSEUDO_BRANCH)
+            instr->opcode = instr->opcode == aco_opcode::p_cbranch_z ? aco_opcode::p_cbranch_nz : aco_opcode::p_cbranch_z;
+         else if (instr->opcode == aco_opcode::s_cselect_b32)
+            std::swap(instr->operands[0], instr->operands[1]);
+         else
+            unreachable("scc_nocompare optimization is only implemented for p_cbranch and s_cselect");
+      }
+
+      /* Use the SCC def from the original instruction, not the comparison */
+      ctx.uses[instr->operands[scc_op_idx].tempId()]--;
+      instr->operands[scc_op_idx] = wr_instr->operands[0];
+   }
+}
+
 void process_instruction(pr_opt_ctx &ctx, aco_ptr<Instruction> &instr)
 {
    ctx.current_instr_idx++;
 
    try_apply_branch_vcc(ctx, instr);
 
+   try_optimize_scc_nocompare(ctx, instr);
+
    if (instr)
       save_reg_writes(ctx, instr);
 }
index 16a427a..dec5e49 100644 (file)
@@ -122,3 +122,134 @@ BEGIN_TEST(optimizer_postRA.vcmp)
 
     finish_optimizer_postRA_test();
 END_TEST
+
+BEGIN_TEST(optimizer_postRA.scc_nocmp_opt)
+    //>> s1: %a, s2: %y, s1: %z = p_startpgm
+    ASSERTED bool setup_ok = setup_cs("s1 s2 s1", GFX6);
+    assert(setup_ok);
+
+    PhysReg reg_s0{0};
+    PhysReg reg_s1{1};
+    PhysReg reg_s2{2};
+    PhysReg reg_s3{3};
+    PhysReg reg_s4{4};
+    PhysReg reg_s6{6};
+
+    Temp in_0 = inputs[0];
+    Temp in_1 = inputs[1];
+    Temp in_2 = inputs[2];
+    Operand op_in_0(in_0);
+    op_in_0.setFixed(reg_s0);
+    Operand op_in_1(in_1);
+    op_in_1.setFixed(reg_s4);
+    Operand op_in_2(in_2);
+    op_in_2.setFixed(reg_s6);
+
+    {
+        //! s1: %d:s[2], s1: %e:scc = s_bfe_u32 %a:s[0], 0x40018
+        //! s2: %f:vcc = p_cbranch_nz %e:scc
+        //! p_unit_test 0, %f:vcc
+        auto salu = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1, reg_s2), bld.def(s1, scc), op_in_0, Operand(0x40018u));
+        auto scmp = bld.sopc(aco_opcode::s_cmp_eq_u32, bld.def(s1, scc), Operand(salu, reg_s2), Operand(0u));
+        auto br = bld.branch(aco_opcode::p_cbranch_z, bld.def(s2, vcc), bld.scc(scmp));
+        writeout(0, Operand(br, vcc));
+    }
+
+    //; del d, e, f
+
+    {
+        //! s1: %d:s[2], s1: %e:scc = s_bfe_u32 %a:s[0], 0x40018
+        //! s2: %f:vcc = p_cbranch_z %e:scc
+        //! p_unit_test 1, %f:vcc
+        auto salu = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1, reg_s2), bld.def(s1, scc), op_in_0, Operand(0x40018u));
+        auto scmp = bld.sopc(aco_opcode::s_cmp_lg_u32, bld.def(s1, scc), Operand(salu, reg_s2), Operand(0u));
+        auto br = bld.branch(aco_opcode::p_cbranch_z, bld.def(s2, vcc), bld.scc(scmp));
+        writeout(1, Operand(br, vcc));
+    }
+
+    //; del d, e, f
+
+    {
+        //! s1: %d:s[2], s1: %e:scc = s_bfe_u32 %a:s[0], 0x40018
+        //! s2: %f:vcc = p_cbranch_z %e:scc
+        //! p_unit_test 2, %f:vcc
+        auto salu = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1, reg_s2), bld.def(s1, scc), op_in_0, Operand(0x40018u));
+        auto scmp = bld.sopc(aco_opcode::s_cmp_eq_u32, bld.def(s1, scc), Operand(salu, reg_s2), Operand(0u));
+        auto br = bld.branch(aco_opcode::p_cbranch_nz, bld.def(s2, vcc), bld.scc(scmp));
+        writeout(2, Operand(br, vcc));
+    }
+
+    //; del d, e, f
+
+    {
+        //! s1: %d:s[2], s1: %e:scc = s_bfe_u32 %a:s[0], 0x40018
+        //! s2: %f:vcc = p_cbranch_nz %e:scc
+        //! p_unit_test 3, %f:vcc
+        auto salu = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1, reg_s2), bld.def(s1, scc), op_in_0, Operand(0x40018u));
+        auto scmp = bld.sopc(aco_opcode::s_cmp_lg_u32, bld.def(s1, scc), Operand(salu, reg_s2), Operand(0u));
+        auto br = bld.branch(aco_opcode::p_cbranch_nz, bld.def(s2, vcc), bld.scc(scmp));
+        writeout(3, Operand(br, vcc));
+    }
+
+    //; del d, e, f
+
+    {
+        //! s2: %d:s[2-3], s1: %e:scc = s_and_b64 %y:s[4-5], 0x12345
+        //! s2: %f:vcc = p_cbranch_z %e:scc
+        //! p_unit_test 4, %f:vcc
+        auto salu = bld.sop2(aco_opcode::s_and_b64, bld.def(s2, reg_s2), bld.def(s1, scc), op_in_1, Operand(0x12345u));
+        auto scmp = bld.sopc(aco_opcode::s_cmp_eq_u64, bld.def(s1, scc), Operand(salu, reg_s2), Operand(0UL));
+        auto br = bld.branch(aco_opcode::p_cbranch_nz, bld.def(s2, vcc), bld.scc(scmp));
+        writeout(4, Operand(br, vcc));
+    }
+
+    //; del d, e, f
+
+    {
+        /* SCC is overwritten in between, don't optimize */
+
+        //! s1: %d:s[2], s1: %e:scc = s_bfe_u32 %a:s[0], 0x40018
+        //! s1: %h:s[3], s1: %x:scc = s_add_u32 %a:s[0], 1
+        //! s1: %g:scc = s_cmp_eq_u32 %d:s[2], 0
+        //! s2: %f:vcc = p_cbranch_z %g:scc
+        //! p_unit_test 5, %f:vcc, %h:s[3]
+        auto salu = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1, reg_s2), bld.def(s1, scc), op_in_0, Operand(0x40018u));
+        auto ovrw = bld.sop2(aco_opcode::s_add_u32, bld.def(s1, reg_s3), bld.def(s1, scc), op_in_0, Operand(1u));
+        auto scmp = bld.sopc(aco_opcode::s_cmp_eq_u32, bld.def(s1, scc), Operand(salu, reg_s2), Operand(0u));
+        auto br = bld.branch(aco_opcode::p_cbranch_z, bld.def(s2, vcc), bld.scc(scmp));
+        writeout(5, Operand(br, vcc), Operand(ovrw, reg_s3));
+    }
+
+    //; del d, e, f, g, h, x
+
+    {
+        //! s1: %d:s[2], s1: %e:scc = s_bfe_u32 %a:s[0], 0x40018
+        //! s1: %f:s[4] = s_cselect_b32 %z:s[6], %a:s[0], %e:scc
+        //! p_unit_test 6, %f:s[4]
+        auto salu = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1, reg_s2), bld.def(s1, scc), op_in_0, Operand(0x40018u));
+        auto scmp = bld.sopc(aco_opcode::s_cmp_eq_u32, bld.def(s1, scc), Operand(salu, reg_s2), Operand(0u));
+        auto br = bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1, reg_s4), Operand(op_in_0), Operand(op_in_2), bld.scc(scmp));
+        writeout(6, Operand(br, reg_s4));
+    }
+
+    //; del d, e, f
+
+    {
+        /* SCC is overwritten in between, don't optimize */
+
+        //! s1: %d:s[2], s1: %e:scc = s_bfe_u32 %a:s[0], 0x40018
+        //! s1: %h:s[3], s1: %x:scc = s_add_u32 %a:s[0], 1
+        //! s1: %g:scc = s_cmp_eq_u32 %d:s[2], 0
+        //! s1: %f:s[4] = s_cselect_b32 %a:s[0], %z:s[6], %g:scc
+        //! p_unit_test 7, %f:s[4], %h:s[3]
+        auto salu = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1, reg_s2), bld.def(s1, scc), op_in_0, Operand(0x40018u));
+        auto ovrw = bld.sop2(aco_opcode::s_add_u32, bld.def(s1, reg_s3), bld.def(s1, scc), op_in_0, Operand(1u));
+        auto scmp = bld.sopc(aco_opcode::s_cmp_eq_u32, bld.def(s1, scc), Operand(salu, reg_s2), Operand(0u));
+        auto br = bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1, reg_s4), Operand(op_in_0), Operand(op_in_2), bld.scc(scmp));
+        writeout(7, Operand(br, reg_s4), Operand(ovrw, reg_s3));
+    }
+
+    //; del d, e, f, g, h, x
+
+    finish_optimizer_postRA_test();
+END_TEST