aco: optimize swizzled SALU 8/16-bit conversions
authorRhys Perry <pendingchaos02@gmail.com>
Fri, 29 May 2020 13:26:50 +0000 (14:26 +0100)
committerMarge Bot <eric+marge@anholt.net>
Thu, 30 Jul 2020 17:34:51 +0000 (17:34 +0000)
We only need one s_bfe for a conversion with a swizzled source.

shader-db (parallel-rdp, Navi):
Totals from 487 (71.30% of 683) affected shaders:
SpillSGPRs: 3284 -> 3233 (-1.55%); split: -2.71%, +1.16%
SpillVGPRs: 2174 -> 2150 (-1.10%); split: -1.24%, +0.14%
CodeSize: 2497864 -> 2445544 (-2.09%); split: -2.11%, +0.01%
Instrs: 450613 -> 445104 (-1.22%); split: -1.27%, +0.05%

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5259>

src/amd/compiler/aco_instruction_selection.cpp

index f0dccd2..3e27759 100644 (file)
@@ -537,6 +537,104 @@ Temp bool_to_scalar_condition(isel_context *ctx, Temp val, Temp dst = Temp(0, s1
    return emit_wqm(ctx, tmp, dst);
 }
 
+Temp convert_int(isel_context *ctx, Builder& bld, Temp src, unsigned src_bits, unsigned dst_bits, bool is_signed, Temp dst=Temp())
+{
+   if (!dst.id()) {
+      if (dst_bits % 32 == 0 || src.type() == RegType::sgpr)
+         dst = bld.tmp(src.type(), DIV_ROUND_UP(dst_bits, 32u));
+      else
+         dst = bld.tmp(RegClass(RegType::vgpr, dst_bits / 8u).as_subdword());
+   }
+
+   if (dst.bytes() == src.bytes() && dst_bits < src_bits)
+      return bld.copy(Definition(dst), src);
+   else if (dst.bytes() < src.bytes())
+      return bld.pseudo(aco_opcode::p_extract_vector, Definition(dst), src, Operand(0u));
+
+   Temp tmp = dst;
+   if (dst_bits == 64)
+      tmp = src_bits == 32 ? src : bld.tmp(src.type(), 1);
+
+   if (tmp == src) {
+   } else if (src.regClass() == s1) {
+      if (is_signed)
+         bld.sop1(src_bits == 8 ? aco_opcode::s_sext_i32_i8 : aco_opcode::s_sext_i32_i16, Definition(tmp), src);
+      else
+         bld.sop2(aco_opcode::s_and_b32, Definition(tmp), bld.def(s1, scc), Operand(src_bits == 8 ? 0xFFu : 0xFFFFu), src);
+   } else if (ctx->options->chip_class >= GFX8) {
+      assert(src_bits != 8 || src.regClass() == v1b);
+      assert(src_bits != 16 || src.regClass() == v2b);
+      aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)};
+      sdwa->operands[0] = Operand(src);
+      sdwa->definitions[0] = Definition(tmp);
+      if (is_signed)
+         sdwa->sel[0] = src_bits == 8 ? sdwa_sbyte : sdwa_sword;
+      else
+         sdwa->sel[0] = src_bits == 8 ? sdwa_ubyte : sdwa_uword;
+      sdwa->dst_sel = tmp.bytes() == 2 ? sdwa_uword : sdwa_udword;
+      bld.insert(std::move(sdwa));
+   } else {
+      assert(ctx->options->chip_class == GFX6 || ctx->options->chip_class == GFX7);
+      aco_opcode opcode = is_signed ? aco_opcode::v_bfe_i32 : aco_opcode::v_bfe_u32;
+      bld.vop3(opcode, Definition(tmp), src, Operand(0u), Operand(src_bits == 8 ? 8u : 16u));
+   }
+
+   if (dst_bits == 64) {
+      if (is_signed && dst.regClass() == s2) {
+         Temp high = bld.sop2(aco_opcode::s_ashr_i32, bld.def(s1), bld.def(s1, scc), tmp, Operand(31u));
+         bld.pseudo(aco_opcode::p_create_vector, Definition(dst), tmp, high);
+      } else if (is_signed && dst.regClass() == v2) {
+         Temp high = bld.vop2(aco_opcode::v_ashrrev_i32, bld.def(v1), Operand(31u), tmp);
+         bld.pseudo(aco_opcode::p_create_vector, Definition(dst), tmp, high);
+      } else {
+         bld.pseudo(aco_opcode::p_create_vector, Definition(dst), tmp, Operand(0u));
+      }
+   }
+
+   return dst;
+}
+
+enum sgpr_extract_mode {
+   sgpr_extract_sext,
+   sgpr_extract_zext,
+   sgpr_extract_undef,
+};
+
+Temp extract_8_16_bit_sgpr_element(isel_context *ctx, Temp dst, nir_alu_src *src, sgpr_extract_mode mode)
+{
+   Temp vec = get_ssa_temp(ctx, src->src.ssa);
+   unsigned src_size = src->src.ssa->bit_size;
+   unsigned swizzle = src->swizzle[0];
+
+   if (vec.size() > 1) {
+      assert(src_size == 16);
+      vec = emit_extract_vector(ctx, vec, swizzle / 2, s1);
+      swizzle = swizzle & 1;
+   }
+
+   Builder bld(ctx->program, ctx->block);
+   unsigned offset = src_size * swizzle;
+   Temp tmp = dst.regClass() == s2 ? bld.tmp(s1) : dst;
+
+   if (mode == sgpr_extract_undef && swizzle == 0) {
+      bld.copy(Definition(tmp), vec);
+   } else if (mode == sgpr_extract_undef || (offset == 24 && mode == sgpr_extract_zext)) {
+      bld.sop2(aco_opcode::s_lshr_b32, Definition(tmp), bld.def(s1, scc), vec, Operand(offset));
+   } else if (src_size == 8 && swizzle == 0 && mode == sgpr_extract_sext) {
+      bld.sop1(aco_opcode::s_sext_i32_i8, Definition(tmp), vec);
+   } else if (src_size == 16 && swizzle == 0 && mode == sgpr_extract_sext) {
+      bld.sop1(aco_opcode::s_sext_i32_i16, Definition(tmp), vec);
+   } else {
+      aco_opcode op = mode == sgpr_extract_zext ? aco_opcode::s_bfe_u32 : aco_opcode::s_bfe_i32;
+      bld.sop2(op, Definition(tmp), bld.def(s1, scc), vec, Operand((src_size << 16) | offset));
+   }
+
+   if (dst.regClass() == s2)
+      convert_int(ctx, bld, tmp, 32, 64, mode == sgpr_extract_sext, dst);
+
+   return dst;
+}
+
 Temp get_alu_src(struct isel_context *ctx, nir_alu_src src, unsigned size=1)
 {
    if (src.src.ssa->num_components == 1 && src.swizzle[0] == 0 && size == 1)
@@ -560,23 +658,8 @@ Temp get_alu_src(struct isel_context *ctx, nir_alu_src src, unsigned size=1)
    if (elem_size < 4 && vec.type() == RegType::sgpr) {
       assert(src.src.ssa->bit_size == 8 || src.src.ssa->bit_size == 16);
       assert(size == 1);
-      unsigned swizzle = src.swizzle[0];
-      if (vec.size() > 1) {
-         assert(src.src.ssa->bit_size == 16);
-         vec = emit_extract_vector(ctx, vec, swizzle / 2, s1);
-         swizzle = swizzle & 1;
-      }
-      if (swizzle == 0)
-         return vec;
-
-      Temp dst{ctx->program->allocateId(), s1};
-      aco_ptr<SOP2_instruction> bfe{create_instruction<SOP2_instruction>(aco_opcode::s_bfe_u32, Format::SOP2, 2, 2)};
-      bfe->operands[0] = Operand(vec);
-      bfe->operands[1] = Operand(uint32_t((src.src.ssa->bit_size << 16) | (src.src.ssa->bit_size * swizzle)));
-      bfe->definitions[0] = Definition(dst);
-      bfe->definitions[1] = Definition(ctx->program->allocateId(), scc, s1);
-      ctx->block->instructions.emplace_back(std::move(bfe));
-      return dst;
+      return extract_8_16_bit_sgpr_element(
+         ctx, Temp(ctx->program->allocateId(), s1), &src, sgpr_extract_undef);
    }
 
    RegClass elem_rc = elem_size < 4 ? RegClass(vec.type(), elem_size).as_subdword() : RegClass(vec.type(), elem_size / 4);
@@ -1042,62 +1125,6 @@ Temp emit_floor_f64(isel_context *ctx, Builder& bld, Definition dst, Temp val)
    return add->definitions[0].getTemp();
 }
 
-Temp convert_int(isel_context *ctx, Builder& bld, Temp src, unsigned src_bits, unsigned dst_bits, bool is_signed, Temp dst=Temp()) {
-   if (!dst.id()) {
-      if (dst_bits % 32 == 0 || src.type() == RegType::sgpr)
-         dst = bld.tmp(src.type(), DIV_ROUND_UP(dst_bits, 32u));
-      else
-         dst = bld.tmp(RegClass(RegType::vgpr, dst_bits / 8u).as_subdword());
-   }
-
-   if (dst.bytes() == src.bytes() && dst_bits < src_bits)
-      return bld.copy(Definition(dst), src);
-   else if (dst.bytes() < src.bytes())
-      return bld.pseudo(aco_opcode::p_extract_vector, Definition(dst), src, Operand(0u));
-
-   Temp tmp = dst;
-   if (dst_bits == 64)
-      tmp = src_bits == 32 ? src : bld.tmp(src.type(), 1);
-
-   if (tmp == src) {
-   } else if (src.regClass() == s1) {
-      if (is_signed)
-         bld.sop1(src_bits == 8 ? aco_opcode::s_sext_i32_i8 : aco_opcode::s_sext_i32_i16, Definition(tmp), src);
-      else
-         bld.sop2(aco_opcode::s_and_b32, Definition(tmp), bld.def(s1, scc), Operand(src_bits == 8 ? 0xFFu : 0xFFFFu), src);
-   } else if (ctx->options->chip_class >= GFX8) {
-      assert(src_bits != 8 || src.regClass() == v1b);
-      assert(src_bits != 16 || src.regClass() == v2b);
-      aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)};
-      sdwa->operands[0] = Operand(src);
-      sdwa->definitions[0] = Definition(tmp);
-      if (is_signed)
-         sdwa->sel[0] = src_bits == 8 ? sdwa_sbyte : sdwa_sword;
-      else
-         sdwa->sel[0] = src_bits == 8 ? sdwa_ubyte : sdwa_uword;
-      sdwa->dst_sel = tmp.bytes() == 2 ? sdwa_uword : sdwa_udword;
-      bld.insert(std::move(sdwa));
-   } else {
-      assert(ctx->options->chip_class == GFX6 || ctx->options->chip_class == GFX7);
-      aco_opcode opcode = is_signed ? aco_opcode::v_bfe_i32 : aco_opcode::v_bfe_u32;
-      bld.vop3(opcode, Definition(tmp), src, Operand(0u), Operand(src_bits == 8 ? 8u : 16u));
-   }
-
-   if (dst_bits == 64) {
-      if (is_signed && dst.regClass() == s2) {
-         Temp high = bld.sop2(aco_opcode::s_ashr_i32, bld.def(s1), bld.def(s1, scc), tmp, Operand(31u));
-         bld.pseudo(aco_opcode::p_create_vector, Definition(dst), tmp, high);
-      } else if (is_signed && dst.regClass() == v2) {
-         Temp high = bld.vop2(aco_opcode::v_ashrrev_i32, bld.def(v1), Operand(31u), tmp);
-         bld.pseudo(aco_opcode::p_create_vector, Definition(dst), tmp, high);
-      } else {
-         bld.pseudo(aco_opcode::p_create_vector, Definition(dst), tmp, Operand(0u));
-      }
-   }
-
-   return dst;
-}
-
 void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
 {
    if (!instr->dest.dest.is_ssa) {
@@ -2683,16 +2710,30 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    case nir_op_i2i16:
    case nir_op_i2i32:
    case nir_op_i2i64: {
-      convert_int(ctx, bld, get_alu_src(ctx, instr->src[0]),
-                  instr->src[0].src.ssa->bit_size, instr->dest.dest.ssa.bit_size, true, dst);
+      if (dst.type() == RegType::sgpr && instr->src[0].src.ssa->bit_size < 32) {
+         /* no need to do the extract in get_alu_src() */
+         sgpr_extract_mode mode = instr->dest.dest.ssa.bit_size > instr->src[0].src.ssa->bit_size ?
+                                  sgpr_extract_sext : sgpr_extract_undef;
+         extract_8_16_bit_sgpr_element(ctx, dst, &instr->src[0], mode);
+      } else {
+         convert_int(ctx, bld, get_alu_src(ctx, instr->src[0]),
+                     instr->src[0].src.ssa->bit_size, instr->dest.dest.ssa.bit_size, true, dst);
+      }
       break;
    }
    case nir_op_u2u8:
    case nir_op_u2u16:
    case nir_op_u2u32:
    case nir_op_u2u64: {
-      convert_int(ctx, bld, get_alu_src(ctx, instr->src[0]),
-                  instr->src[0].src.ssa->bit_size, instr->dest.dest.ssa.bit_size, false, dst);
+      if (dst.type() == RegType::sgpr && instr->src[0].src.ssa->bit_size < 32) {
+         /* no need to do the extract in get_alu_src() */
+         sgpr_extract_mode mode = instr->dest.dest.ssa.bit_size > instr->src[0].src.ssa->bit_size ?
+                                  sgpr_extract_zext : sgpr_extract_undef;
+         extract_8_16_bit_sgpr_element(ctx, dst, &instr->src[0], mode);
+      } else {
+         convert_int(ctx, bld, get_alu_src(ctx, instr->src[0]),
+                     instr->src[0].src.ssa->bit_size, instr->dest.dest.ssa.bit_size, false, dst);
+      }
       break;
    }
    case nir_op_b2b32: