return ctx->allocated[def->index];
}
-Temp emit_mbcnt(isel_context *ctx, Definition dst,
- Operand mask_lo = Operand((uint32_t) -1), Operand mask_hi = Operand((uint32_t) -1))
+Temp emit_mbcnt(isel_context *ctx, Temp dst, Temp mask = Temp())
{
Builder bld(ctx->program, ctx->block);
- Definition lo_def = ctx->program->wave_size == 32 ? dst : bld.def(v1);
- Temp thread_id_lo = bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, lo_def, mask_lo, Operand(0u));
+ assert(mask.id() == 0 || mask.regClass() == bld.lm);
if (ctx->program->wave_size == 32) {
- return thread_id_lo;
- } else if (ctx->program->chip_class <= GFX7) {
- Temp thread_id_hi = bld.vop2(aco_opcode::v_mbcnt_hi_u32_b32, dst, mask_hi, thread_id_lo);
- return thread_id_hi;
- } else {
- Temp thread_id_hi = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32_e64, dst, mask_hi, thread_id_lo);
- return thread_id_hi;
+ Operand mask_lo = mask.id() ? Operand(mask) : Operand(-1u);
+ return bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, Definition(dst), mask_lo, Operand(0u));
+ }
+
+ Operand mask_lo(-1u);
+ Operand mask_hi(-1u);
+
+ if (mask.id()) {
+ Builder::Result mask_split = bld.pseudo(aco_opcode::p_split_vector, bld.def(s1), bld.def(s1), mask);
+ mask_lo = Operand(mask_split.def(0).getTemp());
+ mask_hi = Operand(mask_split.def(1).getTemp());
}
+
+ Temp mbcnt_lo = bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), mask_lo, Operand(0u));
+
+ if (ctx->program->chip_class <= GFX7)
+ return bld.vop2(aco_opcode::v_mbcnt_hi_u32_b32, Definition(dst), mask_hi, mbcnt_lo);
+ else
+ return bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32_e64, Definition(dst), mask_hi, mbcnt_lo);
}
Temp emit_wqm(isel_context *ctx, Temp src, Temp dst=Temp(0, s1), bool program_needs_wqm = false)
unsigned itemsize = ctx->stage == vertex_geometry_gs
? ctx->program->info->vs.es_info.esgs_itemsize
: ctx->program->info->tes.es_info.esgs_itemsize;
- Temp thread_id = emit_mbcnt(ctx, bld.def(v1));
+ Temp thread_id = emit_mbcnt(ctx, bld.tmp(v1));
Temp wave_idx = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), get_arg(ctx, ctx->args->merged_wave_info), Operand(4u << 16 | 24));
Temp vertex_idx = bld.vop2(aco_opcode::v_or_b32, bld.def(v1), thread_id,
bld.v_mul24_imm(bld.def(v1), as_vgpr(ctx, wave_idx), ctx->program->wave_size));
// return ((val & exec) >> cluster_offset) & cluster_mask != 0
//subgroupClusteredXor():
// return v_bnt_u32_b32(((val & exec) >> cluster_offset) & cluster_mask, 0) & 1 != 0
- Temp lane_id = emit_mbcnt(ctx, bld.def(v1));
+ Temp lane_id = emit_mbcnt(ctx, bld.tmp(v1));
Temp cluster_offset = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(~uint32_t(cluster_size - 1)), lane_id);
Temp tmp;
else
tmp = bld.sop2(Builder::s_and, bld.def(bld.lm), bld.def(s1, scc), src, Operand(exec, bld.lm));
- Builder::Result lohi = bld.pseudo(aco_opcode::p_split_vector, bld.def(s1), bld.def(s1), tmp);
- Temp lo = lohi.def(0).getTemp();
- Temp hi = lohi.def(1).getTemp();
- Temp mbcnt = emit_mbcnt(ctx, bld.def(v1), Operand(lo), Operand(hi));
+ Temp mbcnt = emit_mbcnt(ctx, bld.tmp(v1), tmp);
Definition cmp_def = Definition();
if (op == nir_op_iand)
break;
}
case nir_intrinsic_load_local_invocation_index: {
- Temp id = emit_mbcnt(ctx, bld.def(v1));
+ Temp id = emit_mbcnt(ctx, bld.tmp(v1));
/* The tg_size bits [6:11] contain the subgroup id,
* we need this multiplied by the wave size, and then OR the thread id to it.
break;
}
case nir_intrinsic_load_subgroup_invocation: {
- emit_mbcnt(ctx, Definition(get_ssa_temp(ctx, &instr->dest.ssa)));
+ emit_mbcnt(ctx, get_ssa_temp(ctx, &instr->dest.ssa));
break;
}
case nir_intrinsic_load_num_subgroups: {
}
case nir_intrinsic_mbcnt_amd: {
Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
- RegClass rc = RegClass(src.type(), 1);
- Temp mask_lo = bld.tmp(rc), mask_hi = bld.tmp(rc);
- bld.pseudo(aco_opcode::p_split_vector, Definition(mask_lo), Definition(mask_hi), src);
Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
- Temp wqm_tmp = emit_mbcnt(ctx, bld.def(v1), Operand(mask_lo), Operand(mask_hi));
+ Temp wqm_tmp = emit_mbcnt(ctx, bld.tmp(v1), src);
emit_wqm(ctx, wqm_tmp, dst);
break;
}
Temp so_vtx_count = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc),
get_arg(ctx, ctx->args->streamout_config), Operand(0x70010u));
- Temp tid = emit_mbcnt(ctx, bld.def(v1));
+ Temp tid = emit_mbcnt(ctx, bld.tmp(v1));
Temp can_emit = bld.vopc(aco_opcode::v_cmp_gt_i32, bld.def(bld.lm), so_vtx_count, tid);
/* Calculate LDS address where the GS threads stored the primitive ID. */
Temp wave_id_in_tg = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc),
get_arg(ctx, ctx->args->merged_wave_info), Operand(24u | (4u << 16)));
- Temp thread_id_in_wave = emit_mbcnt(ctx, bld.def(v1));
+ Temp thread_id_in_wave = emit_mbcnt(ctx, bld.tmp(v1));
Temp wave_id_mul = bld.v_mul24_imm(bld.def(v1), as_vgpr(ctx, wave_id_in_tg), ctx->program->wave_size);
Temp thread_id_in_tg = bld.vadd32(bld.def(v1), Operand(wave_id_mul), Operand(thread_id_in_wave));
Temp addr = bld.v_mul24_imm(bld.def(v1), thread_id_in_tg, 4u);