label_dpp16 = 1ull << 35,
label_dpp8 = 1ull << 36,
label_f2f32 = 1ull << 37,
+ label_f2f16 = 1ull << 38,
};
static constexpr uint64_t instr_usedef_labels =
label_uniform_bitwise | label_minmax | label_vopc | label_usedef | label_extract | label_dpp16 |
label_dpp8 | label_f2f32;
static constexpr uint64_t instr_mod_labels =
- label_omod2 | label_omod4 | label_omod5 | label_clamp | label_insert;
+ label_omod2 | label_omod4 | label_omod5 | label_clamp | label_insert | label_f2f16;
static constexpr uint64_t instr_labels = instr_usedef_labels | instr_mod_labels;
static constexpr uint64_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f |
bool is_clamp() { return label & label_clamp; }
+ void set_f2f16(Instruction* conv)
+ {
+ add_label(label_f2f16);
+ instr = conv;
+ }
+
+ bool is_f2f16() { return label & label_f2f16; }
+
void set_undefined() { add_label(label_undefined); }
bool is_undefined() { return label & label_undefined; }
ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
break;
}
+ case aco_opcode::v_cvt_f16_f32: {
+ if (instr->operands[0].isTemp())
+ ctx.info[instr->operands[0].tempId()].set_f2f16(instr.get());
+ break;
+ }
case aco_opcode::v_cvt_f32_f16: {
if (instr->operands[0].isTemp())
ctx.info[instr->definitions[0].tempId()].set_f2f32(instr.get());
}
instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
- ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_insert;
+ ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_insert | label_f2f16;
ctx.uses[def_info.instr->definitions[0].tempId()]--;
return true;
vop3p->clamp = instr->isVOP3() && instr->vop3().clamp;
instr = std::move(vop3p);
- ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_mul;
+ ctx.info[instr->definitions[0].tempId()].label &= label_f2f16 | label_clamp | label_mul;
if (ctx.info[instr->definitions[0].tempId()].label & label_mul)
ctx.info[instr->definitions[0].tempId()].instr = instr.get();
}
+bool
+combine_output_conversion(opt_ctx& ctx, aco_ptr<Instruction>& instr)
+{
+ ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
+ if (!def_info.is_f2f16())
+ return false;
+ Instruction* conv = def_info.instr;
+
+ if (!can_use_mad_mix(ctx, instr) || ctx.uses[instr->definitions[0].tempId()] != 1)
+ return false;
+
+ if (!ctx.uses[conv->definitions[0].tempId()])
+ return false;
+
+ if (conv->usesModifiers())
+ return false;
+
+ if (!instr->isVOP3P())
+ to_mad_mix(ctx, instr);
+
+ instr->opcode = aco_opcode::v_fma_mixlo_f16;
+ instr->definitions[0].swapTemp(conv->definitions[0]);
+ if (conv->definitions[0].isPrecise())
+ instr->definitions[0].setPrecise(true);
+ ctx.info[instr->definitions[0].tempId()].label &= label_clamp;
+ ctx.uses[conv->definitions[0].tempId()]--;
+
+ return true;
+}
+
void
combine_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
if (can_apply_sgprs(ctx, instr))
apply_sgprs(ctx, instr);
combine_mad_mix(ctx, instr);
- while (apply_omod_clamp(ctx, instr))
+ while (apply_omod_clamp(ctx, instr) | combine_output_conversion(ctx, instr))
;
apply_insert(ctx, instr);
}