aco: implement select_rt_prolog()
authorDaniel Schürmann <daniel@schuermann.dev>
Thu, 26 Jan 2023 14:58:01 +0000 (15:58 +0100)
committerMarge Bot <emma+marge@anholt.net>
Thu, 16 Mar 2023 01:40:30 +0000 (01:40 +0000)
Co-authored-by: Friedrich Vock <friedrich.vock@gmx.de>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21780>

src/amd/compiler/aco_instruction_selection.cpp
src/amd/compiler/aco_ir.h

index ac615ec..8bab001 100644 (file)
@@ -11632,17 +11632,22 @@ select_trap_handler_shader(Program* program, struct nir_shader* shader, ac_shade
    cleanup_cfg(program);
 }
 
-Operand
-get_arg_fixed(const struct ac_shader_args* args, struct ac_arg arg)
+PhysReg
+get_arg_reg(const struct ac_shader_args* args, struct ac_arg arg)
 {
    assert(arg.used);
-
    enum ac_arg_regfile file = args->args[arg.arg_index].file;
-   unsigned size = args->args[arg.arg_index].size;
    unsigned reg = args->args[arg.arg_index].offset;
+   return PhysReg(file == AC_ARG_SGPR ? reg : reg + 256);
+}
 
-   return Operand(PhysReg(file == AC_ARG_SGPR ? reg : reg + 256),
-                  RegClass(file == AC_ARG_SGPR ? RegType::sgpr : RegType::vgpr, size));
+Operand
+get_arg_fixed(const struct ac_shader_args* args, struct ac_arg arg)
+{
+   enum ac_arg_regfile file = args->args[arg.arg_index].file;
+   unsigned size = args->args[arg.arg_index].size;
+   RegClass rc = RegClass(file == AC_ARG_SGPR ? RegType::sgpr : RegType::vgpr, size);
+   return Operand(get_arg_reg(args, arg), rc);
 }
 
 unsigned
@@ -11737,6 +11742,136 @@ calc_nontrivial_instance_id(Builder& bld, const struct ac_shader_args* args,
 }
 
 void
+select_rt_prolog(Program* program, ac_shader_config* config,
+                 const struct aco_compiler_options* options, const struct aco_shader_info* info,
+                 const struct ac_shader_args* in_args, const struct ac_shader_args* out_args)
+{
+   init_program(program, compute_cs, info, options->gfx_level, options->family, options->wgp_mode,
+                config);
+   Block* block = program->create_and_insert_block();
+   block->kind = block_kind_top_level;
+   program->workgroup_size = info->workgroup_size;
+   program->wave_size = info->workgroup_size;
+   calc_min_waves(program);
+   Builder bld(program, block);
+   block->instructions.reserve(32);
+   unsigned num_sgprs = MAX2(in_args->num_sgprs_used, out_args->num_sgprs_used);
+   unsigned num_vgprs = MAX2(in_args->num_vgprs_used, out_args->num_vgprs_used);
+
+   /* Inputs:
+    * Ring offsets:                s[0-1]
+    * Indirect descriptor sets:    s[2]
+    * Push constants pointer:      s[3]
+    * SBT descriptors:             s[4-5]
+    * Ray launch size address:     s[6-7]
+    * Traversal shader address:    s[8-9]
+    * Dynamic callable stack base: s[10]
+    * Workgroup IDs (xyz):         s[11], s[12], s[13]
+    * Scratch offset:              s[14]
+    * Local invocation IDs:        v[0-2]
+    */
+   PhysReg in_ring_offsets = get_arg_reg(in_args, in_args->ring_offsets);
+   PhysReg in_launch_size_addr = get_arg_reg(in_args, in_args->ray_launch_size_addr);
+   PhysReg in_shader_addr = get_arg_reg(in_args, in_args->rt_traversal_shader_addr);
+   PhysReg in_stack_base = get_arg_reg(in_args, in_args->rt_dynamic_callable_stack_base);
+   PhysReg in_wg_id_x = get_arg_reg(in_args, in_args->workgroup_ids[0]);
+   PhysReg in_wg_id_y = get_arg_reg(in_args, in_args->workgroup_ids[1]);
+   PhysReg in_wg_id_z = get_arg_reg(in_args, in_args->workgroup_ids[2]);
+   PhysReg in_scratch_offset = get_arg_reg(in_args, in_args->scratch_offset);
+   PhysReg in_local_ids[2] = {
+      get_arg_reg(in_args, in_args->local_invocation_ids),
+      get_arg_reg(in_args, in_args->local_invocation_ids).advance(4),
+   };
+
+   /* Outputs:
+    * Callee shader PC:            s[0-1]
+    * Indirect descriptor sets:    s[2]
+    * Push constants pointer:      s[3]
+    * SBT descriptors:             s[4-5]
+    * Ray launch sizes (xyz):      s[6], s[7], s[8]
+    * Scratch offset (<GFX9 only): s[9]
+    * Ring offsets (<GFX9 only):   s[10-11]
+    * Ray launch IDs:              v[0-2]
+    * Stack pointer:               v[3]
+    */
+   PhysReg out_shader_pc = get_arg_reg(out_args, out_args->rt_shader_pc);
+   PhysReg out_launch_size_x = get_arg_reg(out_args, out_args->ray_launch_size);
+   PhysReg out_launch_size_z = out_launch_size_x.advance(8);
+   PhysReg out_launch_ids[3];
+   for (unsigned i = 0; i < 3; i++)
+      out_launch_ids[i] = get_arg_reg(out_args, out_args->ray_launch_id).advance(i * 4);
+   PhysReg out_stack_ptr = get_arg_reg(out_args, out_args->rt_dynamic_callable_stack_base);
+
+   /* Temporaries: */
+   num_sgprs = align(num_sgprs, 2) + 2;
+   PhysReg tmp_ring_offsets = PhysReg{num_sgprs - 2};
+
+   /* Confirm some assumptions about register aliasing */
+   assert(in_ring_offsets == out_shader_pc);
+   assert(get_arg_reg(in_args, in_args->push_constants) ==
+          get_arg_reg(out_args, out_args->push_constants));
+   assert(get_arg_reg(in_args, in_args->sbt_descriptors) ==
+          get_arg_reg(out_args, out_args->sbt_descriptors));
+   assert(in_launch_size_addr == out_launch_size_x);
+   assert(in_shader_addr == out_launch_size_z);
+   assert(in_local_ids[0] == out_launch_ids[0]);
+
+   /* init scratch */
+   if (options->gfx_level >= GFX9) {
+      hw_init_scratch(bld, Definition(in_ring_offsets, s1), Operand(in_ring_offsets, s2),
+                      Operand(in_scratch_offset, s1));
+   } else {
+      /* copy ring offsets to temporary location*/
+      bld.sop1(aco_opcode::s_mov_b64, Definition(tmp_ring_offsets, s2),
+               Operand(in_ring_offsets, s2));
+   }
+
+   /* set stack ptr */
+   bld.vop1(aco_opcode::v_mov_b32, Definition(out_stack_ptr, v1), Operand(in_stack_base, s1));
+
+   /* load RT shader address */
+   /* TODO: load this from the SBT, will be possible with separate shader compilation */
+   bld.sop1(aco_opcode::s_mov_b64, Definition(out_shader_pc, s2), Operand(in_shader_addr, s2));
+
+   /* load ray launch sizes */
+   bld.smem(aco_opcode::s_load_dword, Definition(out_launch_size_z, s1),
+            Operand(in_launch_size_addr, s2), Operand::c32(8u));
+   bld.smem(aco_opcode::s_load_dwordx2, Definition(out_launch_size_x, s2),
+            Operand(in_launch_size_addr, s2), Operand::c32(0u));
+
+   /* calculate ray launch ids */
+   if (options->gfx_level >= GFX11) {
+      /* Thread IDs are packed in VGPR0, 10 bits per component. */
+      bld.vop3(aco_opcode::v_bfe_u32, Definition(in_local_ids[1], v1), Operand(in_local_ids[0], v1),
+               Operand::c32(10u), Operand::c32(3u));
+      bld.vop2(aco_opcode::v_and_b32, Definition(in_local_ids[0], v1), Operand(in_local_ids[0], v1),
+               Operand::c32(0x7));
+   }
+   /* Do this backwards to reduce some RAW hazards on GFX11+ */
+   bld.vop1(aco_opcode::v_mov_b32, Definition(out_launch_ids[2], v1), Operand(in_wg_id_z, s1));
+   bld.vop3(aco_opcode::v_mad_u32_u24, Definition(out_launch_ids[1], v1), Operand(in_wg_id_y, s1),
+            Operand::c32(program->workgroup_size == 32 ? 4 : 8), Operand(in_local_ids[1], v1));
+   bld.vop3(aco_opcode::v_mad_u32_u24, Definition(out_launch_ids[0], v1), Operand(in_wg_id_x, s1),
+            Operand::c32(8), Operand(in_local_ids[0], v1));
+
+   if (options->gfx_level < GFX9) {
+      /* write scratch/ring offsets to outputs, if needed */
+      bld.sop1(aco_opcode::s_mov_b32,
+               Definition(get_arg_reg(out_args, out_args->scratch_offset), s1),
+               Operand(in_scratch_offset, s1));
+      bld.sop1(aco_opcode::s_mov_b64, Definition(get_arg_reg(out_args, out_args->ring_offsets), s2),
+               Operand(tmp_ring_offsets, s2));
+   }
+
+   /* jump to raygen */
+   bld.sop1(aco_opcode::s_setpc_b64, Operand(out_shader_pc, s2));
+
+   program->config->float_mode = program->blocks[0].fp_mode.val;
+   program->config->num_vgprs = get_vgpr_alloc(program, num_sgprs);
+   program->config->num_sgprs = get_sgpr_alloc(program, num_vgprs);
+}
+
+void
 select_vs_prolog(Program* program, const struct aco_vs_prolog_info* pinfo, ac_shader_config* config,
                  const struct aco_compiler_options* options, const struct aco_shader_info* info,
                  const struct ac_shader_args* args, unsigned* num_preserved_sgprs)
index 5045fde..6f42f09 100644 (file)
@@ -2188,6 +2188,10 @@ void select_trap_handler_shader(Program* program, struct nir_shader* shader,
                                 const struct aco_compiler_options* options,
                                 const struct aco_shader_info* info,
                                 const struct ac_shader_args* args);
+void select_rt_prolog(Program* program, ac_shader_config* config,
+                      const struct aco_compiler_options* options,
+                      const struct aco_shader_info* info, const struct ac_shader_args* in_args,
+                      const struct ac_shader_args* out_args);
 void select_vs_prolog(Program* program, const struct aco_vs_prolog_info* pinfo,
                       ac_shader_config* config, const struct aco_compiler_options* options,
                       const struct aco_shader_info* info, const struct ac_shader_args* args,