aco: Add s_delay_alu support for GFX11+
authorBas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Sun, 13 Nov 2022 18:15:28 +0000 (18:15 +0000)
committerMarge Bot <emma+marge@anholt.net>
Wed, 7 Dec 2022 22:05:25 +0000 (22:05 +0000)
Roughly copied from LLVM. This facilitates better ALU usage by
switching between waves when there is an ALU stall, which isn't
automatic anymore on GFX11.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19743>

src/amd/compiler/aco_insert_waitcnt.cpp

index bb3e362..c3ff61a 100644 (file)
@@ -69,7 +69,10 @@ enum wait_event : uint16_t {
    event_vmem_gpr_lock = 1 << 10,
    event_sendmsg = 1 << 11,
    event_ldsdir = 1 << 12,
-   num_events = 13,
+   event_valu = 1 << 13,
+   event_trans = 1 << 14,
+   event_salu = 1 << 15,
+   num_events = 16,
 };
 
 enum counter_type : uint8_t {
@@ -77,7 +80,8 @@ enum counter_type : uint8_t {
    counter_lgkm = 1 << 1,
    counter_vm = 1 << 2,
    counter_vs = 1 << 3,
-   num_counters = 4,
+   counter_alu = 1 << 4,
+   num_counters = 5,
 };
 
 enum vmem_type : uint8_t {
@@ -93,6 +97,91 @@ static const uint16_t lgkm_events = event_smem | event_lds | event_gds | event_f
 static const uint16_t vm_events = event_vmem | event_flat;
 static const uint16_t vs_events = event_vmem_store;
 
+/* On GFX11+ the SIMD frontend doesn't switch to issuing instructions from a different
+ * wave if there is an ALU stall. Hence we have an instruction (s_delay_alu) to signal
+ * that we should switch to a different wave and contains info on dependencies as to
+ * when we can switch back.
+ *
+ * This seems to apply only for ALU->ALU dependencies as other instructions have better
+ * integration with the frontend.
+ *
+ * Note that if we do not emit s_delay_alu things will still be correct, but the wave
+ * will stall in the ALU (and the ALU will be doing nothing else). We'll use this as
+ * I'm pretty sure our cycle info is wrong at times (necessarily so, e.g. wave64 VALU
+ * instructions can take a different number of cycles based on the exec mask)
+ */
+struct alu_delay_info {
+   /* These are the values directly above the max representable value, i.e. the wait
+    * would turn into a no-op when we try to wait for something further back than
+    * this.
+    */
+   static constexpr int8_t valu_nop = 5;
+   static constexpr int8_t trans_nop = 4;
+
+   /* How many VALU instructions ago this value was written */
+   int8_t valu_instrs = valu_nop;
+   /* Cycles until the writing VALU instruction is finished */
+   int8_t valu_cycles = 0;
+
+   /* How many Transcedent instructions ago this value was written */
+   int8_t trans_instrs = trans_nop;
+   /* Cycles until the writing Transcendent instruction is finished */
+   int8_t trans_cycles = 0;
+
+   /* Cycles until the writing SALU instruction is finished*/
+   int8_t salu_cycles = 0;
+
+   bool combine(const alu_delay_info& other)
+   {
+      bool changed = other.valu_instrs < valu_instrs || other.trans_instrs < trans_instrs ||
+                     other.salu_cycles > salu_cycles || other.valu_cycles > valu_cycles ||
+                     other.trans_cycles > trans_cycles;
+      valu_instrs = std::min(valu_instrs, other.valu_instrs);
+      trans_instrs = std::min(trans_instrs, other.trans_instrs);
+      salu_cycles = std::max(salu_cycles, other.salu_cycles);
+      valu_cycles = std::max(valu_cycles, other.valu_cycles);
+      trans_cycles = std::max(trans_cycles, other.trans_cycles);
+      return changed;
+   }
+
+   /* Needs to be called after any change to keep the data consistent. */
+   void fixup()
+   {
+      if (valu_instrs >= valu_nop || valu_cycles <= 0) {
+         valu_instrs = valu_nop;
+         valu_cycles = 0;
+      }
+
+      if (trans_instrs >= trans_nop || trans_cycles <= 0) {
+         trans_instrs = trans_nop;
+         trans_cycles = 0;
+      }
+
+      salu_cycles = std::max<int8_t>(salu_cycles, 0);
+   }
+
+   /* Returns true if a wait would be a no-op */
+   bool empty() const
+   {
+      return valu_instrs == valu_nop && trans_instrs == trans_nop && salu_cycles == 0;
+   }
+};
+
+enum class alu_delay_wait {
+   NO_DEP,
+   VALU_DEP_1,
+   VALU_DEP_2,
+   VALU_DEP_3,
+   VALU_DEP_4,
+   TRANS32_DEP_1,
+   TRANS32_DEP_2,
+   TRANS32_DEP_3,
+   FMA_ACCUM_CYCLE_1,
+   SALU_CYCLE_1,
+   SALU_CYCLE_2,
+   SALU_CYCLE_3
+};
+
 uint8_t
 get_counters_for_event(wait_event ev)
 {
@@ -110,20 +199,25 @@ get_counters_for_event(wait_event ev)
    case event_gds_gpr_lock:
    case event_vmem_gpr_lock:
    case event_ldsdir: return counter_exp;
+   case event_valu:
+   case event_trans:
+   case event_salu: return counter_alu;
    default: return 0;
    }
 }
 
 struct wait_entry {
    wait_imm imm;
+   alu_delay_info delay;
    uint16_t events;  /* use wait_event notion */
    uint8_t counters; /* use counter_type notion */
    bool wait_on_read : 1;
    bool logical : 1;
    uint8_t vmem_types : 4;
 
-   wait_entry(wait_event event_, wait_imm imm_, bool logical_, bool wait_on_read_)
-       : imm(imm_), events(event_), counters(get_counters_for_event(event_)),
+   wait_entry(wait_event event_, wait_imm imm_, alu_delay_info delay_, bool logical_,
+              bool wait_on_read_)
+       : imm(imm_), delay(delay_), events(event_), counters(get_counters_for_event(event_)),
          wait_on_read(wait_on_read_), logical(logical_), vmem_types(0)
    {}
 
@@ -134,6 +228,7 @@ struct wait_entry {
       events |= other.events;
       counters |= other.counters;
       changed |= imm.combine(other.imm);
+      changed |= delay.combine(other.delay);
       wait_on_read |= other.wait_on_read;
       vmem_types |= other.vmem_types;
       assert(logical == other.logical);
@@ -167,6 +262,11 @@ struct wait_entry {
 
       if (!(counters & counter_lgkm) && !(counters & counter_vm))
          events &= ~event_flat;
+
+      if (counter == counter_alu) {
+         delay = alu_delay_info();
+         events &= ~(event_valu | event_trans | event_salu);
+      }
    }
 };
 
@@ -258,7 +358,7 @@ get_vmem_type(Instruction* instr)
 }
 
 void
-check_instr(wait_ctx& ctx, wait_imm& wait, Instruction* instr)
+check_instr(wait_ctx& ctx, wait_imm& wait, alu_delay_info& delay, Instruction* instr)
 {
    for (const Operand op : instr->operands) {
       if (op.isConstant() || op.isUndefined())
@@ -272,6 +372,8 @@ check_instr(wait_ctx& ctx, wait_imm& wait, Instruction* instr)
             continue;
 
          wait.combine(it->second.imm);
+         if (instr->isVALU() || instr->isSALU() || instr->isVINTERP_INREG())
+            delay.combine(it->second.delay);
       }
    }
 
@@ -314,6 +416,25 @@ parse_wait_instr(wait_ctx& ctx, wait_imm& imm, Instruction* instr)
    return false;
 }
 
+bool
+parse_delay_alu(wait_ctx& ctx, alu_delay_info& delay, Instruction* instr)
+{
+   if (instr->opcode != aco_opcode::s_delay_alu)
+      return false;
+
+   unsigned imm[2] = {instr->sopp().imm & 0xf, (instr->sopp().imm >> 7) & 0xf};
+   for (unsigned i = 0; i < 2; ++i) {
+      alu_delay_wait wait = (alu_delay_wait)imm[i];
+      if (wait >= alu_delay_wait::VALU_DEP_1 && wait <= alu_delay_wait::VALU_DEP_4)
+         delay.valu_instrs = imm[i] - (uint32_t)alu_delay_wait::VALU_DEP_1 + 1;
+      else if (wait >= alu_delay_wait::TRANS32_DEP_1 && wait <= alu_delay_wait::TRANS32_DEP_3)
+         delay.trans_instrs = imm[i] - (uint32_t)alu_delay_wait::TRANS32_DEP_1 + 1;
+      else if (wait >= alu_delay_wait::SALU_CYCLE_1)
+         delay.salu_cycles = imm[i] - (uint32_t)alu_delay_wait::SALU_CYCLE_1 + 1;
+   }
+   return true;
+}
+
 void
 perform_barrier(wait_ctx& ctx, wait_imm& imm, memory_sync_info sync, unsigned semantics)
 {
@@ -359,7 +480,28 @@ force_waitcnt(wait_ctx& ctx, wait_imm& imm)
 }
 
 void
-kill(wait_imm& imm, Instruction* instr, wait_ctx& ctx, memory_sync_info sync_info)
+update_alu(wait_ctx& ctx, bool is_valu, bool is_trans, bool clear, int cycles)
+{
+   for (std::pair<const PhysReg, wait_entry>& e : ctx.gpr_map) {
+      wait_entry& entry = e.second;
+
+      if (clear) {
+         entry.delay = alu_delay_info();
+      } else {
+         entry.delay.valu_instrs += is_valu ? 1 : 0;
+         entry.delay.trans_instrs += is_trans ? 1 : 0;
+         entry.delay.salu_cycles -= cycles;
+         entry.delay.valu_cycles -= cycles;
+         entry.delay.trans_cycles -= cycles;
+
+         entry.delay.fixup();
+      }
+   }
+}
+
+void
+kill(wait_imm& imm, alu_delay_info& delay, Instruction* instr, wait_ctx& ctx,
+     memory_sync_info sync_info)
 {
    if (instr->opcode == aco_opcode::s_setpc_b64 || (debug_flags & DEBUG_FORCE_WAITCNT)) {
       /* Force emitting waitcnt states right after the instruction if there is
@@ -369,8 +511,7 @@ kill(wait_imm& imm, Instruction* instr, wait_ctx& ctx, memory_sync_info sync_inf
       force_waitcnt(ctx, imm);
    }
 
-   if (ctx.exp_cnt || ctx.vm_cnt || ctx.lgkm_cnt)
-      check_instr(ctx, imm, instr);
+   check_instr(ctx, imm, delay, instr);
 
    /* It's required to wait for scalar stores before "writing back" data.
     * It shouldn't cost anything anyways since we're about to do s_endpgm.
@@ -418,7 +559,7 @@ kill(wait_imm& imm, Instruction* instr, wait_ctx& ctx, memory_sync_info sync_inf
    else
       perform_barrier(ctx, imm, sync_info, semantic_release);
 
-   if (!imm.empty()) {
+   if (!imm.empty() || !delay.empty()) {
       if (ctx.pending_flat_vm && imm.vm != wait_imm::unset_counter)
          imm.vm = 0;
       if (ctx.pending_flat_lgkm && imm.lgkm != wait_imm::unset_counter)
@@ -454,6 +595,10 @@ kill(wait_imm& imm, Instruction* instr, wait_ctx& ctx, memory_sync_info sync_inf
             bar_ev &= ~event_flat;
       }
 
+      if (ctx.program->gfx_level >= GFX11) {
+         update_alu(ctx, false, false, false, MAX3(delay.salu_cycles, delay.valu_cycles, delay.trans_cycles));
+      }
+
       /* remove all gprs with higher counter from map */
       std::map<PhysReg, wait_entry>::iterator it = ctx.gpr_map.begin();
       while (it != ctx.gpr_map.end()) {
@@ -465,6 +610,13 @@ kill(wait_imm& imm, Instruction* instr, wait_ctx& ctx, memory_sync_info sync_inf
             ctx.wait_and_remove_from_entry(it->first, it->second, counter_lgkm);
          if (imm.vs != wait_imm::unset_counter && imm.vs <= it->second.imm.vs)
             ctx.wait_and_remove_from_entry(it->first, it->second, counter_vs);
+         if (delay.valu_instrs <= it->second.delay.valu_instrs)
+            it->second.delay.valu_instrs = alu_delay_info::valu_nop;
+         if (delay.trans_instrs <= it->second.delay.trans_instrs)
+            it->second.delay.trans_instrs = alu_delay_info::trans_nop;
+         it->second.delay.fixup();
+         if (it->second.delay.empty())
+            ctx.wait_and_remove_from_entry(it->first, it->second, counter_alu);
          if (!it->second.counters)
             it = ctx.gpr_map.erase(it);
          else
@@ -587,7 +739,7 @@ update_counters_for_flat_load(wait_ctx& ctx, memory_sync_info sync = memory_sync
 
 void
 insert_wait_entry(wait_ctx& ctx, PhysReg reg, RegClass rc, wait_event event, bool wait_on_read,
-                  uint8_t vmem_types = 0)
+                  uint8_t vmem_types = 0, unsigned cycles = 0)
 {
    uint16_t counters = get_counters_for_event(event);
    wait_imm imm;
@@ -600,7 +752,18 @@ insert_wait_entry(wait_ctx& ctx, PhysReg reg, RegClass rc, wait_event event, boo
    if (counters & counter_vs)
       imm.vs = 0;
 
-   wait_entry new_entry(event, imm, !rc.is_linear(), wait_on_read);
+   alu_delay_info delay;
+   if (event == event_valu) {
+      delay.valu_instrs = 0;
+      delay.valu_cycles = cycles;
+   } else if (event == event_trans) {
+      delay.trans_instrs = 0;
+      delay.trans_cycles = cycles;
+   } else if (event == event_salu) {
+      delay.salu_cycles = cycles;
+   }
+
+   wait_entry new_entry(event, imm, delay, !rc.is_linear(), wait_on_read);
    new_entry.vmem_types |= vmem_types;
 
    for (unsigned i = 0; i < rc.size(); i++) {
@@ -614,13 +777,38 @@ void
 insert_wait_entry(wait_ctx& ctx, Operand op, wait_event event, uint8_t vmem_types = 0)
 {
    if (!op.isConstant() && !op.isUndefined())
-      insert_wait_entry(ctx, op.physReg(), op.regClass(), event, false, vmem_types);
+      insert_wait_entry(ctx, op.physReg(), op.regClass(), event, false, vmem_types, 0);
 }
 
 void
-insert_wait_entry(wait_ctx& ctx, Definition def, wait_event event, uint8_t vmem_types = 0)
+insert_wait_entry(wait_ctx& ctx, Definition def, wait_event event, uint8_t vmem_types = 0,
+                  unsigned cycles = 0)
 {
-   insert_wait_entry(ctx, def.physReg(), def.regClass(), event, true, vmem_types);
+   insert_wait_entry(ctx, def.physReg(), def.regClass(), event, true, vmem_types, cycles);
+}
+
+void
+gen_alu(Instruction* instr, wait_ctx& ctx)
+{
+   Instruction_cycle_info cycle_info = get_cycle_info(*ctx.program, *instr);
+   bool is_valu = instr->isVALU() || instr->isVINTERP_INREG();
+   bool is_trans = instr->isTrans();
+   bool clear = instr->isEXP() || instr->isDS() || instr->isMIMG() || instr->isFlatLike() ||
+                instr->isMUBUF() || instr->isMTBUF();
+
+   wait_event event = (wait_event)0;
+   if (is_trans)
+      event = event_trans;
+   else if (is_valu)
+      event = event_valu;
+   else if (instr->isSALU())
+      event = event_salu;
+
+   if (event != (wait_event)0) {
+      for (const Definition& def : instr->definitions)
+         insert_wait_entry(ctx, def, event, 0, cycle_info.latency);
+   }
+   update_alu(ctx, is_valu, is_trans, clear, cycle_info.issue_cycles);
 }
 
 void
@@ -756,21 +944,54 @@ emit_waitcnt(wait_ctx& ctx, std::vector<aco_ptr<Instruction>>& instructions, wai
 }
 
 void
+emit_delay_alu(wait_ctx& ctx, std::vector<aco_ptr<Instruction>>& instructions,
+               alu_delay_info& delay)
+{
+   uint32_t imm = 0;
+   if (delay.trans_instrs != delay.trans_nop) {
+      imm |= (uint32_t)alu_delay_wait::TRANS32_DEP_1 + delay.trans_instrs - 1;
+   }
+
+   if (delay.valu_instrs != delay.valu_nop) {
+      imm |= ((uint32_t)alu_delay_wait::VALU_DEP_1 + delay.valu_instrs - 1) << (imm ? 7 : 0);
+   }
+
+   /* Note that we can only put 2 wait conditions in the instruction, so if we have all 3 we just
+    * drop the SALU one. Here we use that this doesn't really affect correctness so occasionally
+    * getting this wrong isn't an issue. */
+   if (delay.salu_cycles && imm <= 0xf) {
+      unsigned cycles = std::min<uint8_t>(3, delay.salu_cycles);
+      imm |= ((uint32_t)alu_delay_wait::SALU_CYCLE_1 + cycles - 1) << (imm ? 7 : 0);
+   }
+
+   SOPP_instruction* inst =
+      create_instruction<SOPP_instruction>(aco_opcode::s_delay_alu, Format::SOPP, 0, 0);
+   inst->imm = imm;
+   inst->block = -1;
+   instructions.emplace_back(inst);
+   delay = alu_delay_info();
+}
+
+void
 handle_block(Program* program, Block& block, wait_ctx& ctx)
 {
    std::vector<aco_ptr<Instruction>> new_instructions;
 
    wait_imm queued_imm;
+   alu_delay_info queued_delay;
 
    for (aco_ptr<Instruction>& instr : block.instructions) {
       bool is_wait = parse_wait_instr(ctx, queued_imm, instr.get());
+      bool is_delay_alu = parse_delay_alu(ctx, queued_delay, instr.get());
 
       memory_sync_info sync_info = get_sync_info(instr.get());
-      kill(queued_imm, instr.get(), ctx, sync_info);
+      kill(queued_imm, queued_delay, instr.get(), ctx, sync_info);
 
       gen(instr.get(), ctx);
+      if (program->gfx_level >= GFX11)
+         gen_alu(instr.get(), ctx);
 
-      if (instr->format != Format::PSEUDO_BARRIER && !is_wait) {
+      if (instr->format != Format::PSEUDO_BARRIER && !is_wait && !is_delay_alu) {
          if (instr->isVINTERP_INREG() && queued_imm.exp != wait_imm::unset_counter) {
             instr->vinterp_inreg().wait_exp = MIN2(instr->vinterp_inreg().wait_exp, queued_imm.exp);
             queued_imm.exp = wait_imm::unset_counter;
@@ -778,6 +999,8 @@ handle_block(Program* program, Block& block, wait_ctx& ctx)
 
          if (!queued_imm.empty())
             emit_waitcnt(ctx, new_instructions, queued_imm);
+         if (!queued_delay.empty())
+            emit_delay_alu(ctx, new_instructions, queued_delay);
 
          bool is_ordered_count_acquire =
             instr->opcode == aco_opcode::ds_ordered_count &&
@@ -793,6 +1016,8 @@ handle_block(Program* program, Block& block, wait_ctx& ctx)
 
    if (!queued_imm.empty())
       emit_waitcnt(ctx, new_instructions, queued_imm);
+   if (!queued_delay.empty())
+      emit_delay_alu(ctx, new_instructions, queued_delay);
 
    block.instructions.swap(new_instructions);
 }