aco: implement 64-bit VGPR {u,i}find_msb
authorRhys Perry <pendingchaos02@gmail.com>
Mon, 15 Mar 2021 13:28:17 +0000 (13:28 +0000)
committerRhys Perry <pendingchaos02@gmail.com>
Wed, 17 Mar 2021 15:33:22 +0000 (15:33 +0000)
This can be created by subgroupBallotFindMSB().

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Closes: https://gitlab.freedesktop.org/mesa/mesa/-/issues/4458
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/9598>

src/amd/compiler/aco_instruction_selection.cpp

index 5540c12..6f7d9b3 100644 (file)
@@ -1228,6 +1228,18 @@ Temp emit_floor_f64(isel_context *ctx, Builder& bld, Definition dst, Temp val)
    return add->definitions[0].getTemp();
 }
 
+Temp uadd32_sat(Builder& bld, Definition dst, Temp src0, Temp src1)
+{
+   if (bld.program->chip_class >= GFX9) {
+      Builder::Result add = bld.vop2_e64(aco_opcode::v_add_u32, dst, src0, src1);
+      add.instr->vop3().clamp = 1;
+   } else {
+      Builder::Result add = bld.vadd32(bld.def(v1), src0, src1, true);
+      bld.vop2_e64(aco_opcode::v_cndmask_b32, dst, add.def(0).getTemp(), Operand((uint32_t) -1), add.def(1).getTemp());
+   }
+   return dst.getTemp();
+}
+
 void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
 {
    if (!instr->dest.dest.is_ssa) {
@@ -1614,6 +1626,22 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
          Temp msb = bld.tmp(v1);
          Temp carry = bld.vsub32(Definition(msb), Operand(31u), Operand(msb_rev), true).def(1).getTemp();
          bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst), msb, Operand((uint32_t)-1), carry);
+      } else if (src.regClass() == v2) {
+         aco_opcode op = instr->op == nir_op_ufind_msb ? aco_opcode::v_ffbh_u32 : aco_opcode::v_ffbh_i32;
+
+         Temp lo = bld.tmp(v1), hi = bld.tmp(v1);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(lo), Definition(hi), src);
+
+         lo = uadd32_sat(bld, bld.def(v1), bld.copy(bld.def(s1), Operand(32u)),
+                                           bld.vop1(op, bld.def(v1), lo));
+         hi = bld.vop1(op, bld.def(v1), hi);
+         Temp found_hi = bld.vopc(aco_opcode::v_cmp_lg_u32, bld.def(bld.lm), Operand((uint32_t)-1), hi);
+
+         Temp msb_rev = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), lo, hi, found_hi);
+
+         Temp msb = bld.tmp(v1);
+         Temp carry = bld.vsub32(Definition(msb), Operand(63u), Operand(msb_rev), true).def(1).getTemp();
+         bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst), msb, Operand((uint32_t)-1), carry);
       } else {
          isel_err(&instr->instr, "Unimplemented NIR instr bit size");
       }