zink: infer types from load_const instrs to avoid more bitcasts
authorMike Blumenkrantz <michael.blumenkrantz@gmail.com>
Wed, 10 May 2023 12:51:10 +0000 (08:51 -0400)
committerMarge Bot <emma+marge@anholt.net>
Tue, 23 May 2023 01:02:56 +0000 (01:02 +0000)
this walks to uses list for the ssa def to infer a type from one of the
uses to reduce the need to bitcast

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22934>

src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c

index 32c3017..2544584 100644 (file)
@@ -164,6 +164,129 @@ get_nir_alu_type(const struct glsl_type *type)
    return nir_alu_type_get_base_type(nir_get_nir_type_for_glsl_base_type(glsl_get_base_type(glsl_without_array_or_matrix(type))));
 }
 
+static nir_alu_type
+infer_nir_alu_type_from_uses_ssa(nir_ssa_def *ssa, unsigned depth);
+static nir_alu_type
+infer_nir_alu_type_from_uses_reg(nir_register *reg, unsigned depth);
+
+static nir_alu_type
+infer_nir_alu_type_from_use(nir_src *src, unsigned depth)
+{
+   nir_instr *instr = src->parent_instr;
+   nir_alu_type atype = nir_type_invalid;
+   switch (instr->type) {
+   case nir_instr_type_alu: {
+      nir_alu_instr *alu = nir_instr_as_alu(instr);
+      if (alu->op == nir_op_bcsel) {
+         if (nir_srcs_equal(alu->src[0].src, *src)) {
+            /* special case: the first src in bcsel is always bool */
+            return nir_type_bool;
+         }
+      }
+      /* ignore typeless ops */
+      if (alu_op_is_typeless(alu->op)) {
+         if (alu->dest.dest.is_ssa) {
+            atype = infer_nir_alu_type_from_uses_ssa(&alu->dest.dest.ssa, depth);
+         } else {
+            /* avoid infinite recursion */
+            if (depth > 10)
+               break;
+            if (!src->is_ssa && src->reg.reg == alu->dest.dest.reg.reg)
+               break;
+            atype = infer_nir_alu_type_from_uses_reg(alu->dest.dest.reg.reg, ++depth);
+         }
+         break;
+      }
+      for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
+         if (!nir_srcs_equal(alu->src[i].src, *src))
+            continue;
+         atype = nir_op_infos[alu->op].input_types[i];
+         break;
+      }
+      break;
+   }
+   case nir_instr_type_tex: {
+      nir_tex_instr *tex = nir_instr_as_tex(instr);
+      for (unsigned i = 0; i < tex->num_srcs; i++) {
+         if (!nir_srcs_equal(tex->src[i].src, *src))
+            continue;
+         switch (tex->src[i].src_type) {
+         case nir_tex_src_coord:
+         case nir_tex_src_lod:
+            if (tex->op == nir_texop_txf ||
+               tex->op == nir_texop_txf_ms ||
+               tex->op == nir_texop_txs)
+               atype = nir_type_int;
+            else
+               atype = nir_type_float;
+            break;
+         case nir_tex_src_projector:
+         case nir_tex_src_bias:
+         case nir_tex_src_min_lod:
+         case nir_tex_src_comparator:
+         case nir_tex_src_ddx:
+         case nir_tex_src_ddy:
+            atype = nir_type_float;
+            break;
+         case nir_tex_src_offset:
+         case nir_tex_src_ms_index:
+         case nir_tex_src_texture_offset:
+         case nir_tex_src_sampler_offset:
+         case nir_tex_src_sampler_handle:
+         case nir_tex_src_texture_handle:
+            atype = nir_type_int;
+            break;
+         default:
+            break;
+         }
+         break;
+      }
+      break;
+   }
+   case nir_instr_type_intrinsic: {
+      if (nir_instr_as_intrinsic(instr)->intrinsic == nir_intrinsic_load_deref) {
+         atype = get_nir_alu_type(nir_instr_as_deref(instr)->type);
+      } else if (nir_instr_as_intrinsic(instr)->intrinsic == nir_intrinsic_store_deref) {
+         atype = get_nir_alu_type(nir_src_as_deref(nir_instr_as_intrinsic(instr)->src[0])->type);
+      }
+      break;
+   }
+   default:
+      break;
+   }
+   return nir_alu_type_get_base_type(atype);
+}
+
+static nir_alu_type
+infer_nir_alu_type_from_uses_ssa(nir_ssa_def *ssa, unsigned depth)
+{
+   nir_alu_type atype = nir_type_invalid;
+   /* try to infer a type: if it's wrong then whatever, but at least we tried */
+   nir_foreach_use_including_if(src, ssa) {
+      if (src->is_if)
+         return nir_type_bool;
+      atype = infer_nir_alu_type_from_use(src, depth);
+      if (atype)
+         break;
+   }
+   return atype ? atype : nir_type_uint;
+}
+
+static nir_alu_type
+infer_nir_alu_type_from_uses_reg(nir_register *reg, unsigned depth)
+{
+   nir_alu_type atype = nir_type_invalid;
+   /* try to infer a type: if it's wrong then whatever, but at least we tried */
+   nir_foreach_use_including_if(src, reg) {
+      if (src->is_if)
+         return nir_type_bool;
+      atype = infer_nir_alu_type_from_use(src, depth);
+      if (atype)
+         break;
+   }
+   return atype ? atype : nir_type_uint;
+}
+
 static SpvId
 get_bvec_type(struct ntv_context *ctx, int num_components)
 {
@@ -2488,22 +2611,37 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
    SpvId components[NIR_MAX_VEC_COMPONENTS];
    nir_alu_type atype;
    if (bit_size == 1) {
+      atype = nir_type_bool;
       for (int i = 0; i < num_components; i++)
          components[i] = spirv_builder_const_bool(&ctx->builder,
                                                   load_const->value[i].b);
-      atype = nir_type_bool;
    } else {
+      atype = infer_nir_alu_type_from_uses_ssa(&load_const->def, 0);
       for (int i = 0; i < num_components; i++) {
-         uint64_t tmp = nir_const_value_as_uint(load_const->value[i],
-                                                bit_size);
-         components[i] = emit_uint_const(ctx, bit_size, tmp);
+         switch (atype) {
+         case nir_type_uint: {
+            uint64_t tmp = nir_const_value_as_uint(load_const->value[i], bit_size);
+            components[i] = emit_uint_const(ctx, bit_size, tmp);
+            break;
+         }
+         case nir_type_int: {
+            int64_t tmp = nir_const_value_as_int(load_const->value[i], bit_size);
+            components[i] = emit_int_const(ctx, bit_size, tmp);
+            break;
+         }
+         case nir_type_float: {
+            double tmp = nir_const_value_as_float(load_const->value[i], bit_size);
+            components[i] = emit_float_const(ctx, bit_size, tmp);
+            break;
+         }
+         default:
+            unreachable("this shouldn't happen!");
+         }
       }
-      atype = nir_type_uint;
    }
 
    if (num_components > 1) {
-      SpvId type = get_vec_from_bit_size(ctx, bit_size,
-                                         num_components);
+      SpvId type = get_alu_type(ctx, atype, num_components, bit_size);
       SpvId value = spirv_builder_const_composite(&ctx->builder,
                                                   type, components,
                                                   num_components);