From 208c31b25fb963ea40977c708837fd8464285255 Mon Sep 17 00:00:00 2001
From: Mike Blumenkrantz
Date: Wed, 10 May 2023 08:51:10 -0400
Subject: [PATCH] zink: infer types from load_const instrs to avoid more bitcasts

this walks the uses list for the ssa def to infer a type from one of
the uses to reduce the need to bitcast

Part-of:
---
 .../drivers/zink/nir_to_spirv/nir_to_spirv.c       | 169 ++++++++++++++++++++-
 1 file changed, 162 insertions(+), 7 deletions(-)

diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
index 32c3017..2544584 100644
--- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
+++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
@@ -164,6 +164,146 @@ get_nir_alu_type(const struct glsl_type *type)
    return nir_alu_type_get_base_type(nir_get_nir_type_for_glsl_base_type(glsl_get_base_type(glsl_without_array_or_matrix(type))));
 }
 
+static nir_alu_type
+infer_nir_alu_type_from_uses_ssa(nir_ssa_def *ssa, unsigned depth);
+static nir_alu_type
+infer_nir_alu_type_from_uses_reg(nir_register *reg, unsigned depth);
+
+/* Infer the base ALU type that the instruction consuming 'src' expects
+ * for that operand.
+ *
+ * If this use gives no hint, atype stays nir_type_invalid and the
+ * callers move on to the next use. 'depth' bounds the recursion through
+ * typeless ALU ops that write non-SSA registers.
+ */
+static nir_alu_type
+infer_nir_alu_type_from_use(nir_src *src, unsigned depth)
+{
+   nir_instr *instr = src->parent_instr;
+   nir_alu_type atype = nir_type_invalid;
+   switch (instr->type) {
+   case nir_instr_type_alu: {
+      nir_alu_instr *alu = nir_instr_as_alu(instr);
+      if (alu->op == nir_op_bcsel) {
+         if (nir_srcs_equal(alu->src[0].src, *src)) {
+            /* special case: the first src in bcsel is always bool */
+            return nir_type_bool;
+         }
+      }
+      /* ignore typeless ops */
+      if (alu_op_is_typeless(alu->op)) {
+         if (alu->dest.dest.is_ssa) {
+            atype = infer_nir_alu_type_from_uses_ssa(&alu->dest.dest.ssa, depth);
+         } else {
+            /* avoid infinite recursion */
+            if (depth > 10)
+               break;
+            if (!src->is_ssa && src->reg.reg == alu->dest.dest.reg.reg)
+               break;
+            atype = infer_nir_alu_type_from_uses_reg(alu->dest.dest.reg.reg, ++depth);
+         }
+         break;
+      }
+      for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
+         if (!nir_srcs_equal(alu->src[i].src, *src))
+            continue;
+         atype = nir_op_infos[alu->op].input_types[i];
+         break;
+      }
+      break;
+   }
+   case nir_instr_type_tex: {
+      nir_tex_instr *tex = nir_instr_as_tex(instr);
+      for (unsigned i = 0; i < tex->num_srcs; i++) {
+         if (!nir_srcs_equal(tex->src[i].src, *src))
+            continue;
+         switch (tex->src[i].src_type) {
+         case nir_tex_src_coord:
+         case nir_tex_src_lod:
+            if (tex->op == nir_texop_txf ||
+                tex->op == nir_texop_txf_ms ||
+                tex->op == nir_texop_txs)
+               atype = nir_type_int;
+            else
+               atype = nir_type_float;
+            break;
+         case nir_tex_src_projector:
+         case nir_tex_src_bias:
+         case nir_tex_src_min_lod:
+         case nir_tex_src_comparator:
+         case nir_tex_src_ddx:
+         case nir_tex_src_ddy:
+            atype = nir_type_float;
+            break;
+         case nir_tex_src_offset:
+         case nir_tex_src_ms_index:
+         case nir_tex_src_texture_offset:
+         case nir_tex_src_sampler_offset:
+         case nir_tex_src_sampler_handle:
+         case nir_tex_src_texture_handle:
+            atype = nir_type_int;
+            break;
+         default:
+            break;
+         }
+         break;
+      }
+      break;
+   }
+   case nir_instr_type_intrinsic: {
+      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
+      if (intr->intrinsic == nir_intrinsic_load_deref ||
+          intr->intrinsic == nir_intrinsic_store_deref) {
+         /* instr is an intrinsic here, so it must not be cast with
+          * nir_instr_as_deref(); for both intrinsics the deref is src[0]
+          */
+         atype = get_nir_alu_type(nir_src_as_deref(intr->src[0])->type);
+      }
+      break;
+   }
+   default:
+      break;
+   }
+   return nir_alu_type_get_base_type(atype);
+}
+
+/* Walk the uses of an SSA def and return the first type that can be
+ * inferred from one of them; an if-condition use yields nir_type_bool.
+ * Falls back to nir_type_uint when no use gives a hint.
+ */
+static nir_alu_type
+infer_nir_alu_type_from_uses_ssa(nir_ssa_def *ssa, unsigned depth)
+{
+   nir_alu_type atype = nir_type_invalid;
+   /* try to infer a type: if it's wrong then whatever, but at least we tried */
+   nir_foreach_use_including_if(src, ssa) {
+      if (src->is_if)
+         return nir_type_bool;
+      atype = infer_nir_alu_type_from_use(src, depth);
+      if (atype)
+         break;
+   }
+   return atype ? atype : nir_type_uint;
+}
+
+/* Same as infer_nir_alu_type_from_uses_ssa(), but for the uses of a
+ * nir_register.
+ */
+static nir_alu_type
+infer_nir_alu_type_from_uses_reg(nir_register *reg, unsigned depth)
+{
+   nir_alu_type atype = nir_type_invalid;
+   /* try to infer a type: if it's wrong then whatever, but at least we tried */
+   nir_foreach_use_including_if(src, reg) {
+      if (src->is_if)
+         return nir_type_bool;
+      atype = infer_nir_alu_type_from_use(src, depth);
+      if (atype)
+         break;
+   }
+   return atype ? atype : nir_type_uint;
+}
+
 static SpvId
 get_bvec_type(struct ntv_context *ctx, int num_components)
 {
@@ -2488,22 +2628,37 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
    SpvId components[NIR_MAX_VEC_COMPONENTS];
    nir_alu_type atype;
    if (bit_size == 1) {
+      atype = nir_type_bool;
       for (int i = 0; i < num_components; i++)
          components[i] = spirv_builder_const_bool(&ctx->builder,
                                                   load_const->value[i].b);
-      atype = nir_type_bool;
    } else {
+      atype = infer_nir_alu_type_from_uses_ssa(&load_const->def, 0);
       for (int i = 0; i < num_components; i++) {
-         uint64_t tmp = nir_const_value_as_uint(load_const->value[i],
-                                                bit_size);
-         components[i] = emit_uint_const(ctx, bit_size, tmp);
+         switch (atype) {
+         case nir_type_uint: {
+            uint64_t tmp = nir_const_value_as_uint(load_const->value[i], bit_size);
+            components[i] = emit_uint_const(ctx, bit_size, tmp);
+            break;
+         }
+         case nir_type_int: {
+            int64_t tmp = nir_const_value_as_int(load_const->value[i], bit_size);
+            components[i] = emit_int_const(ctx, bit_size, tmp);
+            break;
+         }
+         case nir_type_float: {
+            double tmp = nir_const_value_as_float(load_const->value[i], bit_size);
+            components[i] = emit_float_const(ctx, bit_size, tmp);
+            break;
+         }
+         default:
+            unreachable("this shouldn't happen!");
+         }
       }
-      atype = nir_type_uint;
    }
 
    if (num_components > 1) {
-      SpvId type = get_vec_from_bit_size(ctx, bit_size,
-                                         num_components);
+      SpvId type = get_alu_type(ctx, atype, num_components, bit_size);
       SpvId value = spirv_builder_const_composite(&ctx->builder,
                                                   type, components,
                                                   num_components);
-- 
2.7.4