aco: Extract thread_id_in_threadgroup to a separate function.
author Timur Kristóf <timur.kristof@gmail.com>
Fri, 4 Sep 2020 11:35:47 +0000 (13:35 +0200)
committer Timur Kristóf <timur.kristof@gmail.com>
Fri, 9 Oct 2020 13:26:14 +0000 (15:26 +0200)
Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6964>

src/amd/compiler/aco_instruction_selection.cpp

index 6635734..721cabb 100644 (file)
@@ -3871,6 +3871,26 @@ void load_vmem_mubuf(isel_context *ctx, Temp dst, Temp descriptor, Temp voffset,
    emit_load(ctx, bld, info, mubuf_load_params);
 }
 
+Temp wave_id_in_threadgroup(isel_context *ctx)
+{  /* Emits an s_bfe_u32 on the merged_wave_info argument and returns its s1 result. */
+   Builder bld(ctx->program, ctx->block);
+   return bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), /* 2nd def: scc */
+                   get_arg(ctx, ctx->args->merged_wave_info), Operand(24u | (4u << 16))); /* NOTE(review): presumably bitfield offset 24, width 4 per the s_bfe_u32 operand encoding -- confirm against the ISA doc */
+}
+
+Temp thread_id_in_threadgroup(isel_context *ctx)
+{
+   /* tid_in_tg = wave_id * wave_size + tid_in_wave; the sum is returned as a v1 temp. */
+
+   Builder bld(ctx->program, ctx->block);
+
+   Temp wave_id_in_tg = wave_id_in_threadgroup(ctx);
+   Temp tid_in_wave = emit_mbcnt(ctx, bld.tmp(v1));
+   Temp num_pre_threads = bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), wave_id_in_tg,
+                                   Operand(ctx->program->wave_size == 64 ? 6u : 5u)); /* multiply by wave_size via shift: 6 for wave64, otherwise 5 (i.e. treated as wave32) */
+   return bld.vadd32(bld.def(v1), Operand(num_pre_threads), Operand(tid_in_wave));
+}
+
 std::pair<Temp, unsigned> offset_add_from_nir(isel_context *ctx, const std::pair<Temp, unsigned> &base_offset, nir_src *off_src, unsigned stride = 1u)
 {
    Builder bld(ctx->program, ctx->block);
@@ -10845,11 +10865,7 @@ void ngg_emit_nogs_output(isel_context *ctx)
          create_workgroup_barrier(bld);
 
          /* Calculate LDS address where the GS threads stored the primitive ID. */
-         Temp wave_id_in_tg = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc),
-                                       get_arg(ctx, ctx->args->merged_wave_info), Operand(24u | (4u << 16)));
-         Temp thread_id_in_wave = emit_mbcnt(ctx, bld.tmp(v1));
-         Temp wave_id_mul = bld.v_mul24_imm(bld.def(v1), as_vgpr(ctx, wave_id_in_tg), ctx->program->wave_size);
-         Temp thread_id_in_tg = bld.vadd32(bld.def(v1), Operand(wave_id_mul), Operand(thread_id_in_wave));
+         Temp thread_id_in_tg = thread_id_in_threadgroup(ctx);
          Temp addr = bld.v_mul24_imm(bld.def(v1), thread_id_in_tg, 4u);
 
          /* Load primitive ID from LDS. */