return ctx->allocated[def->index];
}
-Temp emit_mbcnt(isel_context *ctx, Temp dst, Temp mask = Temp(), Operand base = Operand(0u))
+Temp emit_mbcnt(isel_context *ctx, Temp dst, Operand mask = Operand(), Operand base = Operand(0u))
{
Builder bld(ctx->program, ctx->block);
- assert(mask.id() == 0 || mask.regClass() == bld.lm);
+ assert(mask.isUndefined() || mask.isTemp() || (mask.isFixed() && mask.physReg() == exec));
+ assert(mask.isUndefined() || mask.regClass() == bld.lm);
if (ctx->program->wave_size == 32) {
- Operand mask_lo = mask.id() ? Operand(mask) : Operand(-1u);
+ Operand mask_lo = mask.isUndefined() ? Operand(-1u) : mask;
return bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, Definition(dst), mask_lo, base);
}
Operand mask_lo(-1u);
Operand mask_hi(-1u);
- if (mask.id()) {
+ if (mask.isTemp()) {
Builder::Result mask_split = bld.pseudo(aco_opcode::p_split_vector, bld.def(s1), bld.def(s1), mask);
mask_lo = Operand(mask_split.def(0).getTemp());
mask_hi = Operand(mask_split.def(1).getTemp());
+ } else if (mask.physReg() == exec) {
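+ /* The mask is the exec mask itself, so read its low and high halves
+ * directly instead of splitting a temporary. */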
+ mask_lo = Operand(exec_lo, s1);
+ mask_hi = Operand(exec_hi, s1);
}
Temp mbcnt_lo = bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), mask_lo, base);
else
tmp = bld.sop2(Builder::s_and, bld.def(bld.lm), bld.def(s1, scc), src, Operand(exec, bld.lm));
- Temp mbcnt = emit_mbcnt(ctx, bld.tmp(v1), tmp);
+ Temp mbcnt = emit_mbcnt(ctx, bld.tmp(v1), Operand(tmp));
Definition cmp_def = Definition();
if (op == nir_op_iand)
return Temp();
}
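+/* Translate a NIR reduction opcode and bit size into the matching ACO ReduceOp. */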
+ReduceOp get_reduce_op(nir_op op, unsigned bit_size)
+{
+ switch (op) {
+ #define CASEI(name) case nir_op_##name: return (bit_size == 32) ? name##32 : (bit_size == 16) ? name##16 : (bit_size == 8) ? name##8 : name##64;
+ #define CASEF(name) case nir_op_##name: return (bit_size == 32) ? name##32 : (bit_size == 16) ? name##16 : name##64;
+ CASEI(iadd)
+ CASEI(imul)
+ CASEI(imin)
+ CASEI(umin)
+ CASEI(imax)
+ CASEI(umax)
+ CASEI(iand)
+ CASEI(ior)
+ CASEI(ixor)
+ CASEF(fadd)
+ CASEF(fmul)
+ CASEF(fmin)
+ CASEF(fmax)
+ default:
+ unreachable("unknown reduction op");
+ #undef CASEI
+ #undef CASEF
+ }
+}
+
void emit_uniform_subgroup(isel_context *ctx, nir_intrinsic_instr *instr, Temp src)
{
Builder bld(ctx->program, ctx->block);
Definition dst(get_ssa_temp(ctx, &instr->dest.ssa));
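+ /* A uniform subgroup operation produces a uniform result, so the
+ * destination should never have been assigned a VGPR. */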
+ assert(dst.regClass().type() != RegType::vgpr);
if (src.regClass().type() == RegType::vgpr) {
bld.pseudo(aco_opcode::p_as_uniform, dst, src);
} else if (src.regClass() == s1) {
}
}
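+/* Lower an iadd/fadd/ixor reduction or scan of a subgroup-uniform value: the
+ * result is the source multiplied by "count" (the lane count or per-lane
+ * prefix count supplied by the caller), or by "count & 1" for ixor. */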
+void emit_addition_uniform_reduce(isel_context *ctx, nir_op op, Definition dst, nir_src src, Temp count)
+{
+ Builder bld(ctx->program, ctx->block);
+ Temp src_tmp = get_ssa_temp(ctx, src.ssa);
+
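+ /* For fadd, convert the lane count to float and multiply: adding the same
+ * value "count" times is the same as count * src. */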
+ if (op == nir_op_fadd) {
+ src_tmp = as_vgpr(ctx, src_tmp);
+ Temp tmp = dst.regClass() == s1 ? bld.tmp(src_tmp.regClass()) : dst.getTemp();
+
+ if (src.ssa->bit_size == 16) {
+ count = bld.vop1(aco_opcode::v_cvt_f16_u16, bld.def(v2b), count);
+ bld.vop2(aco_opcode::v_mul_f16, Definition(tmp), count, src_tmp);
+ } else {
+ assert(src.ssa->bit_size == 32);
+ count = bld.vop1(aco_opcode::v_cvt_f32_u32, bld.def(v1), count);
+ bld.vop2(aco_opcode::v_mul_f32, Definition(tmp), count, src_tmp);
+ }
+
+ if (tmp != dst.getTemp())
+ bld.pseudo(aco_opcode::p_as_uniform, dst, tmp);
+
+ return;
+ }
+
+ if (dst.regClass() == s1)
+ src_tmp = bld.as_uniform(src_tmp);
+
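+ /* For ixor only the parity of the lane count matters: xor-ing a value with
+ * itself an even number of times gives 0, an odd number of times gives the
+ * value back. */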
+ if (op == nir_op_ixor && count.type() == RegType::sgpr)
+ count = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc),
+ count, Operand(1u));
+ else if (op == nir_op_ixor)
+ count = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(1u), count);
+
+ assert(dst.getTemp().type() == count.type());
+
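+ /* With a constant source the multiply simplifies: by 1 the result is just
+ * the count, by 0 it is zero, and any other immediate becomes a multiply by
+ * constant. */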
+ if (nir_src_is_const(src)) {
+ if (nir_src_as_uint(src) == 1 && dst.bytes() <= 2)
+ bld.pseudo(aco_opcode::p_extract_vector, dst, count, Operand(0u));
+ else if (nir_src_as_uint(src) == 1)
+ bld.copy(dst, count);
+ else if (nir_src_as_uint(src) == 0 && dst.bytes() <= 2)
+ bld.vop1(aco_opcode::v_mov_b32, dst, Operand(0u)); /* RA will use SDWA if possible */
+ else if (nir_src_as_uint(src) == 0)
+ bld.copy(dst, Operand(0u));
+ else if (count.type() == RegType::vgpr)
+ bld.v_mul_imm(dst, count, nir_src_as_uint(src));
+ else
+ bld.sop2(aco_opcode::s_mul_i32, dst, src_tmp, count);
+ } else if (dst.bytes() <= 2 && ctx->program->chip_class >= GFX10) {
+ bld.vop3(aco_opcode::v_mul_lo_u16_e64, dst, src_tmp, count);
+ } else if (dst.bytes() <= 2 && ctx->program->chip_class >= GFX8) {
+ bld.vop2(aco_opcode::v_mul_lo_u16, dst, src_tmp, count);
+ } else if (dst.getTemp().type() == RegType::vgpr) {
+ bld.vop3(aco_opcode::v_mul_lo_u32, dst, src_tmp, count);
+ } else {
+ bld.sop2(aco_opcode::s_mul_i32, dst, src_tmp, count);
+ }
+}
+
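+/* Emit a subgroup reduction of a uniform value without the reduction
+ * pseudo-instruction. Returns false if the generic lowering is required. */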
+bool emit_uniform_reduce(isel_context *ctx, nir_intrinsic_instr *instr)
+{
+ nir_op op = (nir_op)nir_intrinsic_reduction_op(instr);
+ if (op == nir_op_imul || op == nir_op_fmul)
+ return false;
+
+ if (op == nir_op_iadd || op == nir_op_ixor || op == nir_op_fadd) {
+ Builder bld(ctx->program, ctx->block);
+ Definition dst(get_ssa_temp(ctx, &instr->dest.ssa));
+ unsigned bit_size = instr->src[0].ssa->bit_size;
+ if (bit_size > 32)
+ return false;
+
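+ /* Number of active invocations in the subgroup. */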
+ Temp thread_count = bld.sop1(
+ Builder::s_bcnt1_i32, bld.def(s1), bld.def(s1, scc), Operand(exec, bld.lm));
+
+ emit_addition_uniform_reduce(ctx, op, dst, instr->src[0], thread_count);
+ } else {
+ emit_uniform_subgroup(ctx, instr, get_ssa_temp(ctx, instr->src[0].ssa));
+ }
+
+ return true;
+}
+
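+/* Emit an inclusive or exclusive subgroup scan of a uniform value without the
+ * reduction pseudo-instruction. Returns false if the generic lowering is
+ * required. */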
+bool emit_uniform_scan(isel_context *ctx, nir_intrinsic_instr *instr)
+{
+ Builder bld(ctx->program, ctx->block);
+ Definition dst(get_ssa_temp(ctx, &instr->dest.ssa));
+ nir_op op = (nir_op)nir_intrinsic_reduction_op(instr);
+ bool inc = instr->intrinsic == nir_intrinsic_inclusive_scan;
+
+ if (op == nir_op_imul || op == nir_op_fmul)
+ return false;
+
+ if (op == nir_op_iadd || op == nir_op_ixor || op == nir_op_fadd) {
+ if (instr->src[0].ssa->bit_size > 32)
+ return false;
+
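+ /* mbcnt of exec gives each lane the number of active lanes with a lower
+ * index; a base of 1 turns this exclusive count into an inclusive one. */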
+ Temp packed_tid;
+ if (inc)
+ packed_tid = emit_mbcnt(ctx, bld.tmp(v1), Operand(exec, bld.lm), Operand(1u));
+ else
+ packed_tid = emit_mbcnt(ctx, bld.tmp(v1), Operand(exec, bld.lm));
+
+ emit_addition_uniform_reduce(ctx, op, dst, instr->src[0], packed_tid);
+ return true;
+ }
+
+ assert(op == nir_op_imin || op == nir_op_umin ||
+ op == nir_op_imax || op == nir_op_umax ||
+ op == nir_op_iand || op == nir_op_ior ||
+ op == nir_op_fmin || op == nir_op_fmax);
+
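+ /* For these operations, combining a uniform value with itself changes
+ * nothing, so an inclusive scan is simply the source in every active lane. */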
+ if (inc) {
+ emit_uniform_subgroup(ctx, instr, get_ssa_temp(ctx, instr->src[0].ssa));
+ return true;
+ }
+
+ /* Copy the source and write the reduction operation's identity to the first
+ * active lane: an exclusive scan of a uniform value is the identity in the
+ * first active lane and the value itself in every later active lane. */
+ Temp lane = bld.sop1(Builder::s_ff1_i32, bld.def(s1), Operand(exec, bld.lm));
+ Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
+ ReduceOp reduce_op = get_reduce_op(op, instr->src[0].ssa->bit_size);
+ if (dst.bytes() == 8) {
+ Temp lo = bld.tmp(v1), hi = bld.tmp(v1);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(lo), Definition(hi), src);
+ uint32_t identity_lo = get_reduction_identity(reduce_op, 0);
+ uint32_t identity_hi = get_reduction_identity(reduce_op, 1);
+
+ lo = bld.writelane(bld.def(v1), bld.copy(bld.hint_m0(s1), Operand(identity_lo)), lane, lo);
+ hi = bld.writelane(bld.def(v1), bld.copy(bld.hint_m0(s1), Operand(identity_hi)), lane, hi);
+ bld.pseudo(aco_opcode::p_create_vector, dst, lo, hi);
+ } else {
+ uint32_t identity = get_reduction_identity(reduce_op, 0);
+ bld.writelane(dst, bld.copy(bld.hint_m0(s1), Operand(identity)), lane, as_vgpr(ctx, src));
+ }
+
+ return true;
+}
+
Pseudo_reduction_instruction *create_reduction_instr(isel_context *ctx, aco_opcode aco_op, ReduceOp op, Definition dst, Temp src)
{
assert(src.bytes() <= 8);
nir_intrinsic_cluster_size(instr) : 0;
cluster_size = util_next_power_of_two(MIN2(cluster_size ? cluster_size : ctx->program->wave_size, ctx->program->wave_size));
- if (!nir_src_is_divergent(instr->src[0]) && (op == nir_op_ior || op == nir_op_iand)) {
- emit_uniform_subgroup(ctx, instr, src);
- } else if (instr->dest.ssa.bit_size == 1) {
+ if (!nir_src_is_divergent(instr->src[0]) &&
+ cluster_size == ctx->program->wave_size && instr->dest.ssa.bit_size != 1) {
+ /* We use divergence analysis to assign the regclass, so check that it is
+ * working as expected. */
+ ASSERTED bool expected_divergent = instr->intrinsic == nir_intrinsic_exclusive_scan;
+ if (instr->intrinsic == nir_intrinsic_inclusive_scan)
+ expected_divergent = op == nir_op_iadd || op == nir_op_fadd || op == nir_op_ixor;
+ assert(nir_dest_is_divergent(instr->dest) == expected_divergent);
+
+ if (instr->intrinsic == nir_intrinsic_reduce) {
+ if (emit_uniform_reduce(ctx, instr))
+ break;
+ } else if (emit_uniform_scan(ctx, instr)) {
+ break;
+ }
+ }
+
+ if (instr->dest.ssa.bit_size == 1) {
if (op == nir_op_imul || op == nir_op_umin || op == nir_op_imin)
op = nir_op_iand;
else if (op == nir_op_iadd)
src = emit_extract_vector(ctx, src, 0, RegClass::get(RegType::vgpr, bit_size / 8));
- ReduceOp reduce_op;
- switch (op) {
- #define CASEI(name) case nir_op_##name: reduce_op = (bit_size == 32) ? name##32 : (bit_size == 16) ? name##16 : (bit_size == 8) ? name##8 : name##64; break;
- #define CASEF(name) case nir_op_##name: reduce_op = (bit_size == 32) ? name##32 : (bit_size == 16) ? name##16 : name##64; break;
- CASEI(iadd)
- CASEI(imul)
- CASEI(imin)
- CASEI(umin)
- CASEI(imax)
- CASEI(umax)
- CASEI(iand)
- CASEI(ior)
- CASEI(ixor)
- CASEF(fadd)
- CASEF(fmul)
- CASEF(fmin)
- CASEF(fmax)
- default:
- unreachable("unknown reduction op");
- #undef CASEI
- #undef CASEF
- }
+ ReduceOp reduce_op = get_reduce_op(op, bit_size);
aco_opcode aco_op;
switch (instr->intrinsic) {
case nir_intrinsic_mbcnt_amd: {
Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
- Temp wqm_tmp = emit_mbcnt(ctx, bld.tmp(v1), src);
+ Temp wqm_tmp = emit_mbcnt(ctx, bld.tmp(v1), Operand(src));
emit_wqm(ctx, wqm_tmp, dst);
break;
}
/* Subgroup reduction and exclusive scan on the per-lane boolean. */
Temp sg_reduction = bld.sop1(Builder::s_bcnt1_i32, bld.def(s1), bld.def(s1, scc), src_mask);
- Temp sg_excl = emit_mbcnt(ctx, bld.tmp(v1), src_mask);
+ Temp sg_excl = emit_mbcnt(ctx, bld.tmp(v1), Operand(src_mask));
if (ctx->program->workgroup_size <= ctx->program->wave_size)
return std::make_pair(sg_reduction, sg_excl);