aco: implement udot_4x8/sdot_4x8/udot_2x16/sdot_2x16 opcodes
authorRhys Perry <pendingchaos02@gmail.com>
Mon, 2 Aug 2021 18:42:44 +0000 (19:42 +0100)
committerMarge Bot <eric+marge@anholt.net>
Fri, 3 Sep 2021 13:21:28 +0000 (13:21 +0000)
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12617>

src/amd/compiler/aco_instruction_selection.cpp
src/amd/compiler/aco_instruction_selection_setup.cpp
src/amd/compiler/aco_opcodes.py
src/amd/compiler/aco_register_allocation.cpp

index 5b17f449df254f3a6553ead17bcd80252affdf5a..04c6a7216d3f8f0d7d3c5d4348f64f48410fe9ce 100644 (file)
@@ -960,6 +960,24 @@ emit_vop3p_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, T
    return res;
 }
 
+void
+emit_idot_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, Temp dst, bool clamp)
+{
+   Temp src[3] = {Temp(0, v1), Temp(0, v1), Temp(0, v1)};
+   bool has_sgpr = false;
+   for (unsigned i = 0; i < 3; i++) {
+      src[i] = get_alu_src(ctx, instr->src[i]);
+      if (has_sgpr)
+         src[i] = as_vgpr(ctx, src[i]);
+      else
+         has_sgpr = src[i].type() == RegType::sgpr;
+   }
+
+   Builder bld(ctx->program, ctx->block);
+   bld.is_precise = instr->exact;
+   bld.vop3p(op, Definition(dst), src[0], src[1], src[2], 0x0, 0x7).instr->vop3p().clamp = clamp;
+}
+
 void
 emit_vop1_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, Temp dst)
 {
@@ -2112,6 +2130,38 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
       }
       break;
    }
+   case nir_op_sdot_4x8_iadd: {
+      emit_idot_instruction(ctx, instr, aco_opcode::v_dot4_i32_i8, dst, false);
+      break;
+   }
+   case nir_op_sdot_4x8_iadd_sat: {
+      emit_idot_instruction(ctx, instr, aco_opcode::v_dot4_i32_i8, dst, true);
+      break;
+   }
+   case nir_op_udot_4x8_uadd: {
+      emit_idot_instruction(ctx, instr, aco_opcode::v_dot4_u32_u8, dst, false);
+      break;
+   }
+   case nir_op_udot_4x8_uadd_sat: {
+      emit_idot_instruction(ctx, instr, aco_opcode::v_dot4_u32_u8, dst, true);
+      break;
+   }
+   case nir_op_sdot_2x16_iadd: {
+      emit_idot_instruction(ctx, instr, aco_opcode::v_dot2_i32_i16, dst, false);
+      break;
+   }
+   case nir_op_sdot_2x16_iadd_sat: {
+      emit_idot_instruction(ctx, instr, aco_opcode::v_dot2_i32_i16, dst, true);
+      break;
+   }
+   case nir_op_udot_2x16_uadd: {
+      emit_idot_instruction(ctx, instr, aco_opcode::v_dot2_u32_u16, dst, false);
+      break;
+   }
+   case nir_op_udot_2x16_uadd_sat: {
+      emit_idot_instruction(ctx, instr, aco_opcode::v_dot2_u32_u16, dst, true);
+      break;
+   }
    case nir_op_cube_face_coord_amd: {
       Temp in = get_alu_src(ctx, instr->src[0], 3);
       Temp src[3] = {emit_extract_vector(ctx, in, 0, v1), emit_extract_vector(ctx, in, 1, v1),
index 5371628d7acd87a9d389696e0fdb40e5f19242c9..e9b333110d70321de340bf00a021fac151d97729 100644 (file)
@@ -594,7 +594,15 @@ init_context(isel_context* ctx, nir_shader* shader)
                case nir_op_cube_face_index_amd:
                case nir_op_cube_face_coord_amd:
                case nir_op_sad_u8x4:
-               case nir_op_iadd_sat: type = RegType::vgpr; break;
+               case nir_op_iadd_sat:
+               case nir_op_udot_4x8_uadd:
+               case nir_op_sdot_4x8_iadd:
+               case nir_op_udot_4x8_uadd_sat:
+               case nir_op_sdot_4x8_iadd_sat:
+               case nir_op_udot_2x16_uadd:
+               case nir_op_sdot_2x16_iadd:
+               case nir_op_udot_2x16_uadd_sat:
+               case nir_op_sdot_2x16_iadd_sat: type = RegType::vgpr; break;
                case nir_op_f2i16:
                case nir_op_f2u16:
                case nir_op_f2i32:
index 426c97c206901a5377bf29f9ca28f2cba84fc5e5..bb0271804567acdb5f8ded2ff2945ac860ef366f 100644 (file)
@@ -680,6 +680,7 @@ VOP2 = {
    (0x0a, 0x0a, 0x07, 0x07, 0x0a, "v_mul_hi_i32_i24", False),
    (0x0b, 0x0b, 0x08, 0x08, 0x0b, "v_mul_u32_u24", False),
    (0x0c, 0x0c, 0x09, 0x09, 0x0c, "v_mul_hi_u32_u24", False),
+   (  -1,   -1,   -1, 0x39, 0x0d, "v_dot4c_i32_i8", False),
    (0x0d, 0x0d,   -1,   -1,   -1, "v_min_legacy_f32", True),
    (0x0e, 0x0e,   -1,   -1,   -1, "v_max_legacy_f32", True),
    (0x0f, 0x0f, 0x0a, 0x0a, 0x0f, "v_min_f32", True),
@@ -963,6 +964,10 @@ VOPP = {
 # (gfx6, gfx7, gfx8, gfx9, gfx10, name) = (-1, -1, -1, code, code, name)
 for (code, name, modifiers) in VOPP:
    opcode(name, -1, code, code, Format.VOP3P, InstrClass.Valu32, modifiers, modifiers)
+opcode("v_dot2_i32_i16", -1, 0x26, 0x14, Format.VOP3P, InstrClass.Valu32)
+opcode("v_dot2_u32_u16", -1, 0x27, 0x15, Format.VOP3P, InstrClass.Valu32)
+opcode("v_dot4_i32_i8", -1, 0x28, 0x16, Format.VOP3P, InstrClass.Valu32)
+opcode("v_dot4_u32_u8", -1, 0x29, 0x17, Format.VOP3P, InstrClass.Valu32)
 
 
 # VINTERP instructions: 
index 8c0c5aabcd1e77d953d12a82ec7feeee389296c6..bcb70eeab48dda38f5bdb2b12b9061a97ab62ef2 100644 (file)
@@ -2477,7 +2477,8 @@ register_allocation(Program* program, std::vector<IDSet>& live_out_per_block, ra
               instr->opcode == aco_opcode::v_mad_f16 ||
               instr->opcode == aco_opcode::v_mad_legacy_f16 ||
               (instr->opcode == aco_opcode::v_fma_f16 && program->chip_class >= GFX10) ||
-              (instr->opcode == aco_opcode::v_pk_fma_f16 && program->chip_class >= GFX10)) &&
+              (instr->opcode == aco_opcode::v_pk_fma_f16 && program->chip_class >= GFX10) ||
+              (instr->opcode == aco_opcode::v_dot4_i32_i8 && program->family != CHIP_VEGA20)) &&
              instr->operands[2].isTemp() && instr->operands[2].isKillBeforeDef() &&
              instr->operands[2].getTemp().type() == RegType::vgpr && instr->operands[1].isTemp() &&
              instr->operands[1].getTemp().type() == RegType::vgpr && !instr->usesModifiers() &&
@@ -2496,6 +2497,7 @@ register_allocation(Program* program, std::vector<IDSet>& live_out_per_block, ra
                case aco_opcode::v_mad_legacy_f16: instr->opcode = aco_opcode::v_mac_f16; break;
                case aco_opcode::v_fma_f16: instr->opcode = aco_opcode::v_fmac_f16; break;
                case aco_opcode::v_pk_fma_f16: instr->opcode = aco_opcode::v_pk_fmac_f16; break;
+               case aco_opcode::v_dot4_i32_i8: instr->opcode = aco_opcode::v_dot4c_i32_i8; break;
                default: break;
                }
             }
@@ -2507,7 +2509,8 @@ register_allocation(Program* program, std::vector<IDSet>& live_out_per_block, ra
              instr->opcode == aco_opcode::v_mac_f16 || instr->opcode == aco_opcode::v_fmac_f16 ||
              instr->opcode == aco_opcode::v_pk_fmac_f16 ||
              instr->opcode == aco_opcode::v_writelane_b32 ||
-             instr->opcode == aco_opcode::v_writelane_b32_e64) {
+             instr->opcode == aco_opcode::v_writelane_b32_e64 ||
+             instr->opcode == aco_opcode::v_dot4c_i32_i8) {
             instr->definitions[0].setFixed(instr->operands[2].physReg());
          } else if (instr->opcode == aco_opcode::s_addk_i32 ||
                     instr->opcode == aco_opcode::s_mulk_i32) {