aco: implement nir_intrinsic_load_resume_shader_address_amd
authorDaniel Schürmann <daniel@schuermann.dev>
Thu, 12 May 2022 21:29:37 +0000 (23:29 +0200)
committerMarge Bot <emma+marge@anholt.net>
Thu, 8 Jun 2023 00:37:03 +0000 (00:37 +0000)
Similar to p_constaddr but targeting BBs.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22096>

src/amd/compiler/aco_assembler.cpp
src/amd/compiler/aco_instruction_selection.cpp
src/amd/compiler/aco_ir.h
src/amd/compiler/aco_lower_to_hw_instr.cpp
src/amd/compiler/aco_opcodes.py

index c3be64f..eda5e69 100644 (file)
@@ -46,6 +46,7 @@ struct asm_context {
    enum amd_gfx_level gfx_level;
    std::vector<std::pair<int, SOPP_instruction*>> branches;
    std::map<unsigned, constaddr_info> constaddrs;
+   std::map<unsigned, constaddr_info> resumeaddrs;
    std::vector<struct aco_symbol>* symbols;
    const int16_t* opcode;
    // TODO: keep track of branch instructions referring blocks
@@ -138,6 +139,19 @@ emit_instruction(asm_context& ctx, std::vector<uint32_t>& out, Instruction* inst
       assert(instr->operands[1].isConstant());
       /* in case it's an inline constant, make it a literal */
       instr->operands[1] = Operand::literal32(instr->operands[1].constantValue());
+   } else if (instr->opcode == aco_opcode::p_resumeaddr_getpc) {
+      ctx.resumeaddrs[instr->operands[0].constantValue()].getpc_end = out.size() + 1;
+
+      instr->opcode = aco_opcode::s_getpc_b64;
+      instr->operands.pop_back();
+   } else if (instr->opcode == aco_opcode::p_resumeaddr_addlo) {
+      ctx.resumeaddrs[instr->operands[2].constantValue()].add_literal = out.size() + 1;
+
+      instr->opcode = aco_opcode::s_add_u32;
+      instr->operands.pop_back();
+      assert(instr->operands[1].isConstant());
+      /* in case it's an inline constant, make it a literal */
+      instr->operands[1] = Operand::literal32(instr->operands[1].constantValue());
    } else if (instr->opcode == aco_opcode::p_load_symbol) {
       assert(instr->operands[0].isConstant());
       assert(ctx.symbols);
@@ -1049,6 +1063,13 @@ insert_code(asm_context& ctx, std::vector<uint32_t>& out, unsigned insert_before
       if (info.add_literal >= insert_before)
          info.add_literal += insert_count;
    }
+   for (auto& constaddr : ctx.resumeaddrs) {
+      constaddr_info& info = constaddr.second;
+      if (info.getpc_end >= insert_before)
+         info.getpc_end += insert_count;
+      if (info.add_literal >= insert_before)
+         info.add_literal += insert_count;
+   }
 
    if (ctx.symbols) {
       for (auto& symbol : *ctx.symbols) {
@@ -1188,6 +1209,12 @@ fix_constaddrs(asm_context& ctx, std::vector<uint32_t>& out)
       constaddr_info& info = constaddr.second;
       out[info.add_literal] += (out.size() - info.getpc_end) * 4u;
    }
+   for (auto& addr : ctx.resumeaddrs) {
+      constaddr_info& info = addr.second;
+      const Block& block = ctx.program->blocks[out[info.add_literal]];
+      assert(block.kind & block_kind_resume);
+      out[info.add_literal] = (block.offset - info.getpc_end) * 4u;
+   }
 }
 
 unsigned
index a2bb460..de6d82d 100644 (file)
@@ -9004,6 +9004,12 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
       bld.copy(Definition(get_ssa_temp(ctx, &instr->dest.ssa)),
                get_arg(ctx, ctx->args->rt.dynamic_callable_stack_base));
       break;
+   case nir_intrinsic_load_resume_shader_address_amd: {
+      bld.pseudo(aco_opcode::p_resume_shader_address,
+                 Definition(get_ssa_temp(ctx, &instr->dest.ssa)), bld.def(s1, scc),
+                 Operand::c32(nir_intrinsic_call_idx(instr)));
+      break;
+   }
    case nir_intrinsic_overwrite_vs_arguments_amd: {
       ctx->arg_temps[ctx->args->vertex_id.arg_index] = get_ssa_temp(ctx, instr->src[0].ssa);
       ctx->arg_temps[ctx->args->instance_id.arg_index] = get_ssa_temp(ctx, instr->src[1].ssa);
index eb58a6d..6a7100b 100644 (file)
@@ -1873,6 +1873,7 @@ enum block_kind {
    block_kind_discard_early_exit = 1 << 11,
    block_kind_uses_discard = 1 << 12,
    block_kind_needs_lowering = 1 << 13,
+   block_kind_resume = 1 << 14,
    block_kind_export_end = 1 << 15,
 };
 
@@ -1960,12 +1961,12 @@ struct Block {
    std::vector<unsigned> logical_succs;
    std::vector<unsigned> linear_succs;
    RegisterDemand register_demand = RegisterDemand();
+   uint32_t kind = 0;
+   int32_t logical_idom = -1;
+   int32_t linear_idom = -1;
    uint16_t loop_nest_depth = 0;
    uint16_t divergent_if_logical_depth = 0;
    uint16_t uniform_if_depth = 0;
-   uint16_t kind = 0;
-   int logical_idom = -1;
-   int linear_idom = -1;
 
    /* this information is needed for predecessors to blocks with phis when
     * moving out of ssa */
index 6da27e1..61fe7eb 100644 (file)
@@ -2423,6 +2423,28 @@ lower_to_hw_instr(Program* program)
                /* s_addc_u32 not needed because the program is in a 32-bit VA range */
                break;
             }
+            case aco_opcode::p_resume_shader_address: {
+               /* Find index of resume block. */
+               unsigned resume_idx = instr->operands[0].constantValue();
+               unsigned resume_block_idx = 0;
+               for (Block& resume_block : program->blocks) {
+                  if (resume_block.kind & block_kind_resume) {
+                     if (resume_idx == 0) {
+                        resume_block_idx = resume_block.index;
+                        break;
+                     }
+                     resume_idx--;
+                  }
+               }
+               assert(resume_block_idx != 0);
+               unsigned id = instr->definitions[0].tempId();
+               PhysReg reg = instr->definitions[0].physReg();
+               bld.sop1(aco_opcode::p_resumeaddr_getpc, instr->definitions[0], Operand::c32(id));
+               bld.sop2(aco_opcode::p_resumeaddr_addlo, Definition(reg, s1), bld.def(s1, scc),
+                        Operand(reg, s1), Operand::c32(resume_block_idx), Operand::c32(id));
+               /* s_addc_u32 not needed because the program is in a 32-bit VA range */
+               break;
+            }
             case aco_opcode::p_extract: {
                assert(instr->operands[1].isConstant());
                assert(instr->operands[2].isConstant());
index 5fab419..2f19268 100644 (file)
@@ -334,6 +334,7 @@ opcode("p_bpermute_gfx11w64")
 opcode("p_elect")
 
 opcode("p_constaddr")
+opcode("p_resume_shader_address")
 
 # These don't have to be pseudo-ops, but it makes optimization easier to only
 # have to consider two instructions.
@@ -414,6 +415,7 @@ SOP2 = {
    (  -1,   -1,   -1, 0x2d, 0x36, 0x2e, "s_mul_hi_i32"),
    # actually a pseudo-instruction. it's lowered to SALU during assembly though, so it's useful to identify it as a SOP2.
    (  -1,   -1,   -1,   -1,   -1,   -1, "p_constaddr_addlo"),
+   (  -1,   -1,   -1,   -1,   -1,   -1, "p_resumeaddr_addlo"),
 }
 for (gfx6, gfx7, gfx8, gfx9, gfx10, gfx11, name, cls) in default_class(SOP2, InstrClass.Salu):
     opcode(name, gfx7, gfx9, gfx10, gfx11, Format.SOP2, cls)
@@ -530,6 +532,7 @@ SOP1 = {
    (  -1,   -1,   -1,   -1,   -1, 0x4d, "s_sendmsg_rtn_b64"),
    # actually a pseudo-instruction. it's lowered to SALU during assembly though, so it's useful to identify it as a SOP1.
    (  -1,   -1,   -1,   -1,   -1,   -1, "p_constaddr_getpc"),
+   (  -1,   -1,   -1,   -1,   -1,   -1, "p_resumeaddr_getpc"),
    (  -1,   -1,   -1,   -1,   -1,   -1, "p_load_symbol"),
 }
 for (gfx6, gfx7, gfx8, gfx9, gfx10, gfx11, name, cls) in default_class(SOP1, InstrClass.Salu):