aco: add VINTERP instruction format
authorRhys Perry <pendingchaos02@gmail.com>
Fri, 17 Jun 2022 12:53:08 +0000 (13:53 +0100)
committerMarge Bot <emma+marge@anholt.net>
Mon, 26 Sep 2022 14:49:56 +0000 (14:49 +0000)
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/17333>

12 files changed:
src/amd/compiler/aco_assembler.cpp
src/amd/compiler/aco_builder_h.py
src/amd/compiler/aco_insert_waitcnt.cpp
src/amd/compiler/aco_ir.cpp
src/amd/compiler/aco_ir.h
src/amd/compiler/aco_lower_to_hw_instr.cpp
src/amd/compiler/aco_opcodes.py
src/amd/compiler/aco_opt_value_numbering.cpp
src/amd/compiler/aco_print_ir.cpp
src/amd/compiler/aco_register_allocation.cpp
src/amd/compiler/aco_validate.cpp
src/amd/compiler/tests/test_regalloc.cpp

index 58c2471..f19f731 100644 (file)
@@ -374,6 +374,24 @@ emit_instruction(asm_context& ctx, std::vector<uint32_t>& out, Instruction* inst
       }
       break;
    }
+   case Format::VINTERP_INREG: {
+      VINTERP_inreg_instruction& interp = instr->vinterp_inreg();
+      uint32_t encoding = (0b11001101 << 24);
+      encoding |= reg(ctx, instr->definitions[0], 8);
+      encoding |= (uint32_t)interp.wait_exp << 8;
+      encoding |= (uint32_t)interp.opsel << 11;
+      encoding |= (uint32_t)interp.clamp << 15;
+      encoding |= opcode << 16;
+      out.push_back(encoding);
+
+      encoding = 0;
+      for (unsigned i = 0; i < instr->operands.size(); i++)
+         encoding |= reg(ctx, instr->operands[i]) << (i * 9);
+      for (unsigned i = 0; i < 3; i++)
+         encoding |= interp.neg[i] << (29 + i);
+      out.push_back(encoding);
+      break;
+   }
    case Format::DS: {
       DS_instruction& ds = instr->ds();
       uint32_t encoding = (0b110110 << 26);
index b837fb8..86f34b7 100644 (file)
@@ -531,6 +531,7 @@ formats = [("pseudo", [Format.PSEUDO], 'Pseudo_instruction', list(itertools.prod
            ("vopc_sdwa", [Format.VOPC, Format.SDWA], 'SDWA_instruction', itertools.product([1, 2], [2])),
            ("vop3", [Format.VOP3], 'VOP3_instruction', [(1, 3), (1, 2), (1, 1), (2, 2)]),
            ("vop3p", [Format.VOP3P], 'VOP3P_instruction', [(1, 2), (1, 3)]),
+           ("vinterp_inreg", [Format.VINTERP_INREG], 'VINTERP_inreg_instruction', [(1, 3)]),
            ("vintrp", [Format.VINTRP], 'VINTRP_instruction', [(1, 2), (1, 3)]),
            ("vop1_dpp", [Format.VOP1, Format.DPP16], 'DPP16_instruction', [(1, 1)]),
            ("vop2_dpp", [Format.VOP2, Format.DPP16], 'DPP16_instruction', itertools.product([1, 2], [2, 3])),
index eafea3a..6a7b65c 100644 (file)
@@ -758,6 +758,11 @@ handle_block(Program* program, Block& block, wait_ctx& ctx)
       gen(instr.get(), ctx);
 
       if (instr->format != Format::PSEUDO_BARRIER && !is_wait) {
+         if (instr->isVINTERP_INREG() && queued_imm.exp != wait_imm::unset_counter) {
+            instr->vinterp_inreg().wait_exp = MIN2(instr->vinterp_inreg().wait_exp, queued_imm.exp);
+            queued_imm.exp = wait_imm::unset_counter;
+         }
+
          if (!queued_imm.empty())
             emit_waitcnt(ctx, new_instructions, queued_imm);
 
index 96f5a3a..038d6dc 100644 (file)
@@ -441,6 +441,11 @@ can_use_opsel(amd_gfx_level gfx_level, aco_opcode op, int idx)
    case aco_opcode::v_mad_i32_i16: return idx >= 0 && idx < 2;
    case aco_opcode::v_dot2_f16_f16:
    case aco_opcode::v_dot2_bf16_bf16: return idx == -1 || idx == 2;
+   // TODO: This matches what LLVM allows. We should see if this matches what the hardware allows.
+   case aco_opcode::v_interp_p10_f16_f32_inreg:
+   case aco_opcode::v_interp_p10_rtz_f16_f32_inreg: return idx == 0 || idx == 2;
+   case aco_opcode::v_interp_p2_f16_f32_inreg:
+   case aco_opcode::v_interp_p2_rtz_f16_f32_inreg: return idx == -1 || idx == 0;
    default: return false;
    }
 }
@@ -448,6 +453,8 @@ can_use_opsel(amd_gfx_level gfx_level, aco_opcode op, int idx)
 bool
 instr_is_16bit(amd_gfx_level gfx_level, aco_opcode op)
 {
+   // TODO: VINTERP (v_interp_p2_f16_f32, v_interp_p2_rtz_f16_f32)
+
    /* partial register writes are GFX9+, only */
    if (gfx_level < GFX9)
       return false;
index 8ae5557..eeee5ac 100644 (file)
@@ -96,6 +96,7 @@ enum class Format : std::uint16_t {
 
    /* Vector ALU Formats */
    VOP3P = 20,
+   VINTERP_INREG = 21,
    VOP1 = 1 << 8,
    VOP2 = 1 << 9,
    VOPC = 1 << 10,
@@ -1010,6 +1011,7 @@ struct Pseudo_branch_instruction;
 struct Pseudo_barrier_instruction;
 struct Pseudo_reduction_instruction;
 struct VOP3P_instruction;
+struct VINTERP_inreg_instruction;
 struct VOP1_instruction;
 struct VOP2_instruction;
 struct VOPC_instruction;
@@ -1258,6 +1260,17 @@ struct Instruction {
       return *(VOP3P_instruction*)this;
    }
    constexpr bool isVOP3P() const noexcept { return format == Format::VOP3P; }
+   VINTERP_inreg_instruction& vinterp_inreg() noexcept
+   {
+      assert(isVINTERP_INREG());
+      return *(VINTERP_inreg_instruction*)this;
+   }
+   const VINTERP_inreg_instruction& vinterp_inreg() const noexcept
+   {
+      assert(isVINTERP_INREG());
+      return *(VINTERP_inreg_instruction*)this;
+   }
+   constexpr bool isVINTERP_INREG() const noexcept { return format == Format::VINTERP_INREG; }
    VOP1_instruction& vop1() noexcept
    {
       assert(isVOP1());
@@ -1446,6 +1459,14 @@ struct VOP3P_instruction : public Instruction {
 };
 static_assert(sizeof(VOP3P_instruction) == sizeof(Instruction) + 8, "Unexpected padding");
 
+struct VINTERP_inreg_instruction : public Instruction {
+   uint8_t wait_exp : 3;
+   bool clamp : 1;
+   uint8_t opsel : 4;
+   bool neg[3];
+};
+static_assert(sizeof(VINTERP_inreg_instruction) == sizeof(Instruction) + 4, "Unexpected padding");
+
 /**
  * Data Parallel Primitives Format:
  * This format can be used for VOP1, VOP2 or VOPC instructions.
index d930b1a..409f9a1 100644 (file)
@@ -2414,7 +2414,7 @@ lower_to_hw_instr(Program* program)
                         can_remove = false;
                   } else if (inst->isSALU()) {
                      num_scalar++;
-                  } else if (inst->isVALU() || inst->isVINTRP()) {
+                  } else if (inst->isVALU() || inst->isVINTRP() || instr->isVINTERP_INREG()) {
                      num_vector++;
                      /* VALU which writes SGPRs are always executed on GFX10+ */
                      if (ctx.program->gfx_level >= GFX10) {
index cf1aaa4..49f2af7 100644 (file)
@@ -70,6 +70,7 @@ class Format(Enum):
    PSEUDO_BARRIER = 18
    PSEUDO_REDUCTION = 19
    VOP3P = 20
+   VINTERP_INREG = 21
    VOP1 = 1 << 8
    VOP2 = 1 << 9
    VOPC = 1 << 10
@@ -163,6 +164,9 @@ class Format(Enum):
       elif self == Format.VOP3P:
          return [('uint8_t', 'opsel_lo', None),
                  ('uint8_t', 'opsel_hi', None)]
+      elif self == Format.VINTERP_INREG:
+         return [('unsigned', 'wait_exp', 7),
+                 ('uint8_t', 'opsel', 0)]
       elif self in [Format.FLAT, Format.GLOBAL, Format.SCRATCH]:
          return [('int16_t', 'offset', 0),
                  ('memory_sync_info', 'sync', 'memory_sync_info()'),
@@ -999,7 +1003,7 @@ opcode("v_dot2_f32_f16", -1, 0x23, 0x13, 0x13, Format.VOP3P, InstrClass.Valu32)
 opcode("v_dot2_f32_bf16", -1, -1, -1, 0x1a, Format.VOP3P, InstrClass.Valu32)
 
 
-# VINTERP instructions:
+# VINTRP (GFX6 - GFX10.3) instructions:
 VINTRP = {
    (0x00, "v_interp_p1_f32"),
    (0x01, "v_interp_p2_f32"),
@@ -1009,6 +1013,20 @@ VINTRP = {
 for (code, name) in VINTRP:
    opcode(name, code, code, code, -1, Format.VINTRP, InstrClass.Valu32)
 
+
+# VINTERP (GFX11+) instructions:
+VINTERP = {
+   (0x00, "v_interp_p10_f32_inreg"),
+   (0x01, "v_interp_p2_f32_inreg"),
+   (0x02, "v_interp_p10_f16_f32_inreg"),
+   (0x03, "v_interp_p2_f16_f32_inreg"),
+   (0x04, "v_interp_p10_rtz_f16_f32_inreg"),
+   (0x05, "v_interp_p2_rtz_f16_f32_inreg"),
+}
+for (code, name) in VINTERP:
+   opcode(name, -1, -1, -1, code, Format.VINTERP_INREG, InstrClass.Valu32)
+
+
 # VOP3 instructions: 3 inputs, 1 output
 # VOP3b instructions: have a unique scalar output, e.g. VOP2 with vcc out
 VOP3 = {
index 892bdca..e154144 100644 (file)
@@ -99,6 +99,7 @@ struct InstrHash {
       switch (instr->format) {
       case Format::SMEM: return hash_murmur_32<SMEM_instruction>(instr);
       case Format::VINTRP: return hash_murmur_32<VINTRP_instruction>(instr);
+      case Format::VINTERP_INREG: return hash_murmur_32<VINTERP_inreg_instruction>(instr);
       case Format::DS: return hash_murmur_32<DS_instruction>(instr);
       case Format::SOPP: return hash_murmur_32<SOPP_instruction>(instr);
       case Format::SOPK: return hash_murmur_32<SOPK_instruction>(instr);
@@ -235,6 +236,12 @@ struct InstrPred {
          return a3P.opsel_lo == b3P.opsel_lo && a3P.opsel_hi == b3P.opsel_hi &&
                 a3P.clamp == b3P.clamp;
       }
+      case Format::VINTERP_INREG: {
+         VINTERP_inreg_instruction& aI = a->vinterp_inreg();
+         VINTERP_inreg_instruction& bI = b->vinterp_inreg();
+         return aI.wait_exp == bI.wait_exp && aI.clamp == bI.clamp && aI.opsel == bI.opsel &&
+                aI.neg[0] == bI.neg[0] && aI.neg[1] == bI.neg[1] && aI.neg[2] == bI.neg[2];
+      }
       case Format::PSEUDO_REDUCTION: {
          Pseudo_reduction_instruction& aR = a->reduction();
          Pseudo_reduction_instruction& bR = b->reduction();
index 76e6f02..7b2dece 100644 (file)
@@ -347,6 +347,12 @@ print_instr_format_specific(const Instruction* instr, FILE* output)
       print_sync(smem.sync, output);
       break;
    }
+   case Format::VINTERP_INREG: {
+      const VINTERP_inreg_instruction& vinterp = instr->vinterp_inreg();
+      if (vinterp.wait_exp != 7)
+         fprintf(output, " wait_exp:%u", vinterp.wait_exp);
+      break;
+   }
    case Format::VINTRP: {
       const VINTRP_instruction& vintrp = instr->vintrp();
       fprintf(output, " attr%d.%c", vintrp.attribute, "xyzw"[vintrp.component]);
@@ -655,6 +661,12 @@ print_instr_format_specific(const Instruction* instr, FILE* output)
          default: break;
          }
       }
+   } else if (instr->isVINTERP_INREG()) {
+      const VINTERP_inreg_instruction& vinterp = instr->vinterp_inreg();
+      if (vinterp.clamp)
+         fprintf(output, " clamp");
+      if (vinterp.opsel & (1 << 3))
+         fprintf(output, " opsel_hi");
    }
 }
 
@@ -714,6 +726,12 @@ aco_print_instr(const Instruction* instr, FILE* output, unsigned flags)
             f2f32[i] = vop3p.opsel_hi & (1 << i);
             opsel[i] = f2f32[i] && (vop3p.opsel_lo & (1 << i));
          }
+      } else if (instr->isVINTERP_INREG()) {
+         const VINTERP_inreg_instruction& vinterp = instr->vinterp_inreg();
+         for (unsigned i = 0; i < MIN2(num_operands, 3); ++i) {
+            neg[i] = vinterp.neg[i];
+            opsel[i] = vinterp.opsel & (1 << i);
+         }
       }
       for (unsigned i = 0; i < num_operands; ++i) {
          if (i)
index 84a070b..2484607 100644 (file)
@@ -503,7 +503,7 @@ get_subdword_operand_stride(amd_gfx_level gfx_level, const aco_ptr<Instruction>&
    }
 
    assert(rc.bytes() <= 2);
-   if (instr->isVALU()) {
+   if (instr->isVALU() || instr->isVINTERP_INREG()) {
       if (can_use_SDWA(gfx_level, instr, false))
          return rc.bytes();
       if (can_use_opsel(gfx_level, instr->opcode, idx))
@@ -538,13 +538,18 @@ add_subdword_operand(ra_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, uns
       return;
 
    assert(rc.bytes() <= 2);
-   if (instr->isVALU()) {
+   if (instr->isVALU() || instr->isVINTERP_INREG()) {
       /* check if we can use opsel */
       if (instr->format == Format::VOP3) {
          assert(byte == 2);
          instr->vop3().opsel |= 1 << idx;
          return;
       }
+      if (instr->isVINTERP_INREG()) {
+         assert(byte == 2);
+         instr->vinterp_inreg().opsel |= 1 << idx;
+         return;
+      }
       if (instr->isVOP3P()) {
          assert(byte == 2 && !(instr->vop3p().opsel_lo & (1 << idx)));
          instr->vop3p().opsel_lo |= 1 << idx;
@@ -608,7 +613,7 @@ get_subdword_definition_info(Program* program, const aco_ptr<Instruction>& instr
          return std::make_pair(4, rc.size() * 4u);
    }
 
-   if (instr->isVALU() || instr->isVINTRP()) {
+   if (instr->isVALU() || instr->isVINTRP() || instr->isVINTERP_INREG()) {
       assert(rc.bytes() <= 2);
 
       if (can_use_SDWA(gfx_level, instr, false))
@@ -676,7 +681,7 @@ add_subdword_definition(Program* program, aco_ptr<Instruction>& instr, PhysReg r
    if (instr->isPseudo())
       return;
 
-   if (instr->isVALU()) {
+   if (instr->isVALU() || instr->isVINTERP_INREG()) {
       amd_gfx_level gfx_level = program->gfx_level;
       assert(instr->definitions[0].bytes() <= 2);
 
@@ -689,6 +694,11 @@ add_subdword_definition(Program* program, aco_ptr<Instruction>& instr, PhysReg r
          assert(can_use_opsel(gfx_level, instr->opcode, -1));
          instr->vop3().opsel |= (1 << 3); /* dst in high half */
          return;
+      } else if (instr->isVINTERP_INREG()) {
+         assert(reg.byte() == 2);
+         assert(can_use_opsel(gfx_level, instr->opcode, -1));
+         instr->vinterp_inreg().opsel |= (1 << 3); /* dst in high half */
+         return;
       }
 
       if (instr->opcode == aco_opcode::v_fma_mixlo_f16) {
index 52fbc4e..abfb0b2 100644 (file)
@@ -281,7 +281,7 @@ validate_ir(Program* program)
                      instr.get());
          }
 
-         if (instr->isSALU() || instr->isVALU()) {
+         if (instr->isSALU() || instr->isVALU() || instr->isVINTERP_INREG()) {
             /* check literals */
             Operand literal(s1);
             for (unsigned i = 0; i < instr->operands.size(); i++) {
@@ -303,7 +303,7 @@ validate_ir(Program* program)
             }
 
             /* check num sgprs for VALU */
-            if (instr->isVALU()) {
+            if (instr->isVALU() || instr->isVINTERP_INREG()) {
                bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64 ||
                                  instr->opcode == aco_opcode::v_lshrrev_b64 ||
                                  instr->opcode == aco_opcode::v_ashrrev_i64;
@@ -311,7 +311,8 @@ validate_ir(Program* program)
                if (program->gfx_level >= GFX10 && !is_shift64)
                   const_bus_limit = 2;
 
-               uint32_t scalar_mask = instr->isVOP3() || instr->isVOP3P() ? 0x7 : 0x5;
+               uint32_t scalar_mask =
+                  instr->isVOP3() || instr->isVOP3P() || instr->isVINTERP_INREG() ? 0x7 : 0x5;
                if (instr->isSDWA())
                   scalar_mask = program->gfx_level >= GFX9 ? 0x7 : 0x4;
                else if (instr->isDPP())
@@ -898,7 +899,7 @@ get_subdword_bytes_written(Program* program, const aco_ptr<Instruction>& instr,
 
    if (instr->isPseudo())
       return gfx_level >= GFX8 ? def.bytes() : def.size() * 4u;
-   if (instr->isVALU()) {
+   if (instr->isVALU() || instr->isVINTERP_INREG()) {
       assert(def.bytes() <= 2);
       if (instr->isSDWA())
          return instr->sdwa().dst_sel.size();
index 2886120..78f5db0 100644 (file)
@@ -379,3 +379,28 @@ BEGIN_TEST(regalloc.branch_def_phis_at_branch_block)
 
    finish_ra_test(ra_test_policy());
 END_TEST
+
+BEGIN_TEST(regalloc.vinterp_fp16)
+   //>> v1: %in0:v[0], v1: %in1:v[1], v1: %in2:v[2] = p_startpgm
+   if (!setup_cs("v1 v1 v1", GFX11))
+      return;
+
+   //! v2b: %lo:v[3][0:16], v2b: %hi:v[3][16:32] = p_split_vector %in0:v[0]
+   Temp lo = bld.tmp(v2b);
+   Temp hi = bld.tmp(v2b);
+   bld.pseudo(aco_opcode::p_split_vector, Definition(lo), Definition(hi), inputs[0]);
+
+   //! v1: %tmp0:v[1] = v_interp_p10_f16_f32_inreg %lo:v[3][0:16], %in1:v[1], hi(%hi:v[3][16:32])
+   //! p_unit_test %tmp0:v[1]
+   Temp tmp0 = bld.vinterp_inreg(aco_opcode::v_interp_p10_f16_f32_inreg, bld.def(v1), lo, inputs[1], hi);
+   bld.pseudo(aco_opcode::p_unit_test, tmp0);
+
+   //! v2b: %tmp1:v[0][16:32] = v_interp_p2_f16_f32_inreg %in0:v[0], %in2:v[2], %tmp0:v[1] opsel_hi
+   //! v1: %tmp2:v[0] = p_create_vector 0, %tmp1:v[0][16:32]
+   //! p_unit_test %tmp2:v[0]
+   Temp tmp1 = bld.vinterp_inreg(aco_opcode::v_interp_p2_f16_f32_inreg, bld.def(v2b), inputs[0], inputs[2], tmp0);
+   Temp tmp2 = bld.pseudo(aco_opcode::p_create_vector, bld.def(v1), Operand::zero(2), tmp1);
+   bld.pseudo(aco_opcode::p_unit_test, tmp2);
+
+   finish_ra_test(ra_test_policy());
+END_TEST