From 25c7181f1b81711150c695bb86b3826991f61199 Mon Sep 17 00:00:00 2001 From: Jesse Natalie Date: Tue, 16 May 2023 12:25:12 -0700 Subject: [PATCH] microsoft/compiler: Better and simpler bitcast reduction Using nir_gather_ssa_types works much better. There's 2 differences compared to what I was doing before: 1. Multiple passes to allow data to propagate forward and backward through the whole shader. 2. Allowing a value to have indeterminate types due to having both int and float usages. So this deletes some code and gets better results. Wish I'd known this existed last week. Part-of: --- src/microsoft/compiler/dxil_nir.c | 29 +---- src/microsoft/compiler/dxil_nir.h | 2 +- src/microsoft/compiler/nir_to_dxil.c | 206 ++++++++++------------------------- 3 files changed, 64 insertions(+), 173 deletions(-) diff --git a/src/microsoft/compiler/dxil_nir.c b/src/microsoft/compiler/dxil_nir.c index 7c9e201..54b367a 100644 --- a/src/microsoft/compiler/dxil_nir.c +++ b/src/microsoft/compiler/dxil_nir.c @@ -2399,25 +2399,10 @@ dxil_nir_forward_front_face(nir_shader *nir) } static bool -split_phi_and_const_srcs(nir_builder *b, nir_instr *instr, void *data) +move_consts(nir_builder *b, nir_instr *instr, void *data) { bool progress = false; switch (instr->type) { - case nir_instr_type_phi: { - /* Ensure each phi src is used only as a phi src and is not also a phi dest */ - nir_phi_instr *phi = nir_instr_as_phi(instr); - nir_foreach_phi_src(src, phi) { - assert(src->src.is_ssa); - if (!list_is_singular(&src->src.use_link) || - (src->src.is_ssa && src->src.parent_instr->type == nir_instr_type_phi)) { - b->cursor = nir_after_instr_and_phis(src->src.ssa->parent_instr); - nir_ssa_def *new_phi_src = nir_mov(b, src->src.ssa); - nir_src_rewrite_ssa(&src->src, new_phi_src); - progress = true; - } - } - return progress; - } case nir_instr_type_load_const: { /* Sink load_const to their uses if there's multiple */ nir_load_const_instr *load_const = nir_instr_as_load_const(instr); @@ -2440,18 
+2425,14 @@ split_phi_and_const_srcs(nir_builder *b, nir_instr *instr, void *data) } } -/* If a value is used by a phi and another instruction (e.g. another phi), - * copy the value with a mov and use that as the phi source. If the types - * of the uses are compatible, then the two phi sources will use the same - * DXIL SSA value, but if the types are not, then the mov provides an opportunity - * to insert a bitcast. Similarly, sink all consts so that they have only have - * a single use. The DXIL backend will already de-dupe the constants to the +/* Sink all consts so that they only have a single use. + * The DXIL backend will already de-dupe the constants to the * same dxil_value if they have the same type, but this allows a single constant * to have different types without bitcasts. */ bool -dxil_nir_split_phis_and_const_srcs(nir_shader *s) +dxil_nir_move_consts(nir_shader *s) { - return nir_shader_instructions_pass(s, split_phi_and_const_srcs, + return nir_shader_instructions_pass(s, move_consts, nir_metadata_block_index | nir_metadata_dominance, NULL); } diff --git a/src/microsoft/compiler/dxil_nir.h b/src/microsoft/compiler/dxil_nir.h index 5b38a87..a48c146 100644 --- a/src/microsoft/compiler/dxil_nir.h +++ b/src/microsoft/compiler/dxil_nir.h @@ -84,7 +84,7 @@ bool dxil_nir_lower_num_subgroups(nir_shader *s); bool dxil_nir_split_unaligned_loads_stores(nir_shader *shader, nir_variable_mode modes); bool dxil_nir_lower_unsupported_subgroup_scan(nir_shader *s); bool dxil_nir_forward_front_face(nir_shader *s); -bool dxil_nir_split_phis_and_const_srcs(nir_shader *s); +bool dxil_nir_move_consts(nir_shader *s); struct dxil_module; bool dxil_nir_analyze_io_dependencies(struct dxil_module *mod, nir_shader *s); diff --git a/src/microsoft/compiler/nir_to_dxil.c b/src/microsoft/compiler/nir_to_dxil.c index 136e088..3bdee60 100644 --- a/src/microsoft/compiler/nir_to_dxil.c +++ b/src/microsoft/compiler/nir_to_dxil.c @@ -595,6 +595,9 @@ struct ntd_context { struct
dxil_func_def *tess_ctrl_patch_constant_func_def; unsigned unnamed_ubo_count; + BITSET_WORD *float_types; + BITSET_WORD *int_types; + const struct dxil_logger *logger; }; @@ -2078,24 +2081,40 @@ bitcast_to_float(struct ntd_context *ctx, unsigned bit_size, return dxil_emit_cast(&ctx->mod, DXIL_CAST_BITCAST, type, value); } +static bool +is_phi_src(nir_ssa_def *ssa) +{ + nir_foreach_use(src, ssa) + if (src->parent_instr->type == nir_instr_type_phi) + return true; + return false; +} + static void store_ssa_def(struct ntd_context *ctx, nir_ssa_def *ssa, unsigned chan, - const struct dxil_value *value, bool is_type_dummy) + const struct dxil_value *value) { assert(ssa->index < ctx->num_defs); assert(chan < ssa->num_components); - /* We pre-defined the dest value type, so bitcast while storing if the base type differs */ - if (ctx->defs[ssa->index].chans[chan]) { - if (is_type_dummy) - return; - const struct dxil_type *expect_type = dxil_value_get_type(ctx->defs[ssa->index].chans[chan]); - const struct dxil_type *value_type = dxil_value_get_type(value); - if (dxil_type_to_nir_type(expect_type) != dxil_type_to_nir_type(value_type)) - value = dxil_emit_cast(&ctx->mod, DXIL_CAST_BITCAST, expect_type, value); - if (expect_type == ctx->mod.int64_type) - ctx->mod.feats.int64_ops = true; - if (expect_type == ctx->mod.float64_type) - ctx->mod.feats.doubles = true; + /* Insert bitcasts for phi srcs in the parent block */ + if (is_phi_src(ssa)) { + /* Prefer ints over floats if it could be both or if we have no type info */ + nir_alu_type expect_type = + BITSET_TEST(ctx->int_types, ssa->index) ? nir_type_int : + (BITSET_TEST(ctx->float_types, ssa->index) ? nir_type_float : + nir_type_int); + assert(ssa->bit_size != 1 || expect_type == nir_type_int); + if (ssa->bit_size != 1 && expect_type != dxil_type_to_nir_type(dxil_value_get_type(value))) + value = dxil_emit_cast(&ctx->mod, DXIL_CAST_BITCAST, + expect_type == nir_type_int ? 
+ dxil_module_get_int_type(&ctx->mod, ssa->bit_size) : + dxil_module_get_float_type(&ctx->mod, ssa->bit_size), value); + if (ssa->bit_size == 64) { + if (expect_type == nir_type_int) + ctx->mod.feats.int64_ops = true; + if (expect_type == nir_type_float) + ctx->mod.feats.doubles = true; + } } ctx->defs[ssa->index].chans[chan] = value; } @@ -2106,7 +2125,7 @@ store_dest_value(struct ntd_context *ctx, nir_dest *dest, unsigned chan, { assert(dest->is_ssa); assert(value); - store_ssa_def(ctx, &dest->ssa, chan, value, false); + store_ssa_def(ctx, &dest->ssa, chan, value); } static void @@ -2437,26 +2456,13 @@ get_overload(nir_alu_type alu_type, unsigned bit_size) } static enum overload_type -get_overload_from_dxil_value(struct dxil_module *mod, const struct dxil_value *value) -{ - const struct dxil_type *type = dxil_value_get_type(value); - if (type == mod->int1_type) return DXIL_I1; - if (type == mod->int16_type) return DXIL_I16; - if (type == mod->int32_type) return DXIL_I32; - if (type == mod->int64_type) return DXIL_I64; - if (type == mod->float16_type) return DXIL_F16; - if (type == mod->float32_type) return DXIL_F32; - if (type == mod->float64_type) return DXIL_F64; - return DXIL_NONE; -} - -static enum overload_type get_ambiguous_overload(struct ntd_context *ctx, nir_intrinsic_instr *intr, enum overload_type default_type) { - const struct dxil_value *dummy_value = ctx->defs[intr->dest.ssa.index].chans[0]; - if (dummy_value) - return get_overload_from_dxil_value(&ctx->mod, dummy_value); + if (BITSET_TEST(ctx->int_types, intr->dest.ssa.index)) + return get_overload(nir_type_int, intr->dest.ssa.bit_size); + if (BITSET_TEST(ctx->float_types, intr->dest.ssa.index)) + return get_overload(nir_type_float, intr->dest.ssa.bit_size); return default_type; } @@ -2464,10 +2470,7 @@ static enum overload_type get_ambiguous_overload_alu_type(struct ntd_context *ctx, nir_intrinsic_instr *intr, nir_alu_type alu_type) { - const struct dxil_value *dummy_value = 
ctx->defs[intr->dest.ssa.index].chans[0]; - if (dummy_value) - return get_overload_from_dxil_value(&ctx->mod, dummy_value); - return get_overload(alu_type, nir_dest_bit_size(intr->dest)); + return get_ambiguous_overload(ctx, intr, get_overload(alu_type, intr->dest.ssa.bit_size)); } static bool @@ -2801,7 +2804,7 @@ emit_alu(struct ntd_context *ctx, nir_alu_instr *alu) case nir_op_mov: { assert(nir_dest_num_components(alu->dest.dest) == 1); store_ssa_def(ctx, &alu->dest.dest.ssa, 0, get_src_ssa(ctx, - alu->src->src.ssa, alu->src->swizzle[0]), false); + alu->src->src.ssa, alu->src->swizzle[0])); return true; } case nir_op_pack_double_2x32_dxil: @@ -3479,7 +3482,7 @@ emit_store_ssbo(struct ntd_context *ctx, nir_intrinsic_instr *intr) nir_alu_type type = dxil_type_to_nir_type(dxil_value_get_type(get_src_ssa(ctx, intr->src[0].ssa, 0))); - const struct dxil_value *value[4]; + const struct dxil_value *value[4] = { 0 }; for (unsigned i = 0; i < num_components; ++i) { value[i] = get_src(ctx, &intr->src[0], i, type); if (!value[i]) @@ -3495,7 +3498,7 @@ emit_store_ssbo(struct ntd_context *ctx, nir_intrinsic_instr *intr) int32_undef }; - enum overload_type overload = get_overload_from_dxil_value(&ctx->mod, value[0]); + enum overload_type overload = get_overload(type, intr->src[0].ssa->bit_size); if (num_components < 4) { const struct dxil_value *value_undef = dxil_module_get_undef(&ctx->mod, dxil_value_get_type(value[0])); if (!value_undef) @@ -5206,14 +5209,21 @@ get_value_for_const(struct dxil_module *mod, nir_const_value *c, const struct dx unreachable("Invalid type"); } +static const struct dxil_type * +dxil_type_for_const(struct ntd_context *ctx, nir_ssa_def *def) +{ + if (BITSET_TEST(ctx->int_types, def->index) || + !BITSET_TEST(ctx->float_types, def->index)) + return dxil_module_get_int_type(&ctx->mod, def->bit_size); + return dxil_module_get_float_type(&ctx->mod, def->bit_size); +} + static bool emit_load_const(struct ntd_context *ctx, nir_load_const_instr 
*load_const) { for (uint32_t i = 0; i < load_const->def.num_components; ++i) { - const struct dxil_type *type = ctx->defs[load_const->def.index].chans[i] ? - dxil_value_get_type(ctx->defs[load_const->def.index].chans[i]) : - dxil_module_get_int_type(&ctx->mod, load_const->def.bit_size); - store_ssa_def(ctx, &load_const->def, i, get_value_for_const(&ctx->mod, &load_const->value[i], type), false); + const struct dxil_type *type = dxil_type_for_const(ctx, &load_const->def); + store_ssa_def(ctx, &load_const->def, i, get_value_for_const(&ctx->mod, &load_const->value[i], type)); } return true; } @@ -5875,7 +5885,7 @@ static bool emit_undefined(struct ntd_context *ctx, nir_ssa_undef_instr *undef) { for (unsigned i = 0; i < undef->def.num_components; ++i) - store_ssa_def(ctx, &undef->def, i, dxil_module_get_int32_const(&ctx->mod, 0), false); + store_ssa_def(ctx, &undef->def, i, dxil_module_get_int32_const(&ctx->mod, 0)); return true; } @@ -6053,108 +6063,6 @@ sort_uniforms_by_binding_and_remove_structs(nir_shader *s) exec_list_append(&s->variables, &new_list); } -static const struct dxil_value * -get_dummy_value_for_type(struct ntd_context *ctx, nir_alu_type type, unsigned bit_size) -{ - switch (nir_alu_type_get_base_type(type)) { - case nir_type_int: - case nir_type_uint: - case nir_type_bool: - return dxil_module_get_int_const(&ctx->mod, 0, bit_size); - case nir_type_float: - switch (bit_size) { - case 16: return dxil_module_get_float16_const(&ctx->mod, 0); - case 32: return dxil_module_get_float_const(&ctx->mod, 0); - case 64: return dxil_module_get_double_const(&ctx->mod, 0); - default: unreachable("Invalid float bit size"); - } - default: unreachable("Invalid type"); - } -} - -/* NIR is untyped, but DXIL is typed. ALU ops in nir have implicit - * types associated with them, and some intrinsics do too. Attempt - * to gather type information based on how values are used and propagate - * it backwards through the ops that have ambiguous type information. 
*/ -static void -prepare_types(struct ntd_context *ctx, nir_function_impl *impl) -{ - nir_foreach_block_reverse(block, impl) { - nir_foreach_instr_reverse(instr, block) { - switch (instr->type) { - case nir_instr_type_alu: { - nir_alu_instr *alu = nir_instr_as_alu(instr); - const nir_op_info *info = &nir_op_infos[alu->op]; - for (uint32_t i = 0; i < info->num_inputs; ++i) { - const struct dxil_value *typed_value = NULL; - switch (alu->op) { - case nir_op_mov: - case nir_op_vec2: - case nir_op_vec3: - case nir_op_vec4: - case nir_op_vec8: - case nir_op_vec16: - /* These are declared as uint types but they actually just copy from their sources, - * so back-propagate if we have a type. */ - typed_value = ctx->defs[alu->dest.dest.ssa.index].chans[0]; - break; - case nir_op_bcsel: - /* bcsel dest type matches srcs 1 and 2, but src0 is a bool */ - if (i > 0) { - typed_value = ctx->defs[alu->dest.dest.ssa.index].chans[0]; - if (typed_value) - break; - } - FALLTHROUGH; - default: - typed_value = get_dummy_value_for_type(ctx, info->input_types[i], nir_src_bit_size(alu->src[i].src)); - break; - } - if (typed_value) { - uint32_t num_comps = info->input_sizes[i] ? 
info->input_sizes[i] : alu->dest.dest.ssa.num_components; - for (uint32_t comp = 0; comp < num_comps; ++comp) - store_ssa_def(ctx, alu->src[i].src.ssa, alu->src[i].swizzle[comp], typed_value, true); - } - } - break; - } - case nir_instr_type_phi: { - nir_phi_instr *ir = nir_instr_as_phi(instr); - unsigned bitsize = nir_dest_bit_size(ir->dest); - /* Attempt to propagate a type from dest to sources */ - const struct dxil_value *typed_value = ctx->defs[ir->dest.ssa.index].chans[0]; - if (!typed_value) - typed_value = dxil_module_get_int_const(&ctx->mod, 0, bitsize); - nir_foreach_phi_src(src, ir) { - for (uint32_t i = 0; i < ir->dest.ssa.num_components; ++i) - store_ssa_def(ctx, src->src.ssa, i, typed_value, true); - } - break; - } - case nir_instr_type_intrinsic: { - nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr); - switch (intr->intrinsic) { - case nir_intrinsic_store_output: - case nir_intrinsic_store_per_vertex_output: { - /* Use the I/O info to set a type for the stored value source */ - const struct dxil_value *typed_value = get_dummy_value_for_type(ctx, nir_intrinsic_src_type(intr), - nir_src_bit_size(intr->src[0])); - for (uint32_t i = 0; i < intr->num_components; ++i) - store_ssa_def(ctx, intr->src[0].ssa, i, typed_value, true); - break; - } - default: - break; - } - break; - } - default: - break; - } - } - } -} - static bool emit_cbvs(struct ntd_context *ctx) { @@ -6273,7 +6181,9 @@ emit_function(struct ntd_context *ctx, nir_function *func) ctx->tess_ctrl_patch_constant_func_def = func_def; ctx->defs = rzalloc_array(ctx->ralloc_ctx, struct dxil_def, impl->ssa_alloc); - if (!ctx->defs) + ctx->float_types = rzalloc_array(ctx->ralloc_ctx, BITSET_WORD, BITSET_WORDS(impl->ssa_alloc)); + ctx->int_types = rzalloc_array(ctx->ralloc_ctx, BITSET_WORD, BITSET_WORDS(impl->ssa_alloc)); + if (!ctx->defs || !ctx->float_types || !ctx->int_types) return false; ctx->num_defs = impl->ssa_alloc; @@ -6281,7 +6191,7 @@ emit_function(struct ntd_context *ctx, 
nir_function *func) if (!ctx->phis) return false; - prepare_types(ctx, impl); + nir_gather_ssa_types(impl, ctx->float_types, ctx->int_types); if (!emit_scratch(ctx)) return false; @@ -6863,7 +6773,7 @@ nir_to_dxil(struct nir_shader *s, const struct nir_to_dxil_options *opts, NIR_PASS_V(s, nir_opt_dce); /* This needs to be after any copy prop is done to prevent these movs from being erased */ - NIR_PASS_V(s, dxil_nir_split_phis_and_const_srcs); + NIR_PASS_V(s, dxil_nir_move_consts); NIR_PASS_V(s, nir_opt_dce); if (debug_dxil & DXIL_DEBUG_VERBOSE) -- 2.7.4