microsoft/compiler: Better and simpler bitcast reduction
authorJesse Natalie <jenatali@microsoft.com>
Tue, 16 May 2023 19:25:12 +0000 (12:25 -0700)
committerMarge Bot <emma+marge@anholt.net>
Fri, 19 May 2023 22:19:38 +0000 (22:19 +0000)
Using nir_gather_ssa_types works much better. There's 2 differences
compared to what I was doing before:
1. Multiple passes to allow data to propagate forward and backward
   through the whole shader.
2. Allowing a value to have indeterminate types due to having both
   int and float usages.

So this deletes some code and gets better results. Wish I'd known
this existed last week.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23062>

src/microsoft/compiler/dxil_nir.c
src/microsoft/compiler/dxil_nir.h
src/microsoft/compiler/nir_to_dxil.c

index 7c9e201..54b367a 100644 (file)
@@ -2399,25 +2399,10 @@ dxil_nir_forward_front_face(nir_shader *nir)
 }
 
 static bool
-split_phi_and_const_srcs(nir_builder *b, nir_instr *instr, void *data)
+move_consts(nir_builder *b, nir_instr *instr, void *data)
 {
    bool progress = false;
    switch (instr->type) {
-   case nir_instr_type_phi: {
-      /* Ensure each phi src is used only as a phi src and is not also a phi dest */
-      nir_phi_instr *phi = nir_instr_as_phi(instr);
-      nir_foreach_phi_src(src, phi) {
-         assert(src->src.is_ssa);
-         if (!list_is_singular(&src->src.use_link) ||
-             (src->src.is_ssa && src->src.parent_instr->type == nir_instr_type_phi)) {
-            b->cursor = nir_after_instr_and_phis(src->src.ssa->parent_instr);
-            nir_ssa_def *new_phi_src = nir_mov(b, src->src.ssa);
-            nir_src_rewrite_ssa(&src->src, new_phi_src);
-            progress = true;
-         }
-      }
-      return progress;
-   }
    case nir_instr_type_load_const: {
       /* Sink load_const to their uses if there's multiple */
       nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
@@ -2440,18 +2425,14 @@ split_phi_and_const_srcs(nir_builder *b, nir_instr *instr, void *data)
    }
 }
 
-/* If a value is used by a phi and another instruction (e.g. another phi),
- * copy the value with a mov and use that as the phi source. If the types
- * of the uses are compatible, then the two phi sources will use the same
- * DXIL SSA value, but if the types are not, then the mov provides an opportunity
- * to insert a bitcast. Similarly, sink all consts so that they have only have
- * a single use. The DXIL backend will already de-dupe the constants to the
+/* Sink all consts so that they have only have a single use.
+ * The DXIL backend will already de-dupe the constants to the
  * same dxil_value if they have the same type, but this allows a single constant
  * to have different types without bitcasts. */
 bool
-dxil_nir_split_phis_and_const_srcs(nir_shader *s)
+dxil_nir_move_consts(nir_shader *s)
 {
-   return nir_shader_instructions_pass(s, split_phi_and_const_srcs,
+   return nir_shader_instructions_pass(s, move_consts,
                                        nir_metadata_block_index | nir_metadata_dominance,
                                        NULL);
 }
index 5b38a87..a48c146 100644 (file)
@@ -84,7 +84,7 @@ bool dxil_nir_lower_num_subgroups(nir_shader *s);
 bool dxil_nir_split_unaligned_loads_stores(nir_shader *shader, nir_variable_mode modes);
 bool dxil_nir_lower_unsupported_subgroup_scan(nir_shader *s);
 bool dxil_nir_forward_front_face(nir_shader *s);
-bool dxil_nir_split_phis_and_const_srcs(nir_shader *s);
+bool dxil_nir_move_consts(nir_shader *s);
 
 struct dxil_module;
 bool dxil_nir_analyze_io_dependencies(struct dxil_module *mod, nir_shader *s);
index 136e088..3bdee60 100644 (file)
@@ -595,6 +595,9 @@ struct ntd_context {
    struct dxil_func_def *tess_ctrl_patch_constant_func_def;
    unsigned unnamed_ubo_count;
 
+   BITSET_WORD *float_types;
+   BITSET_WORD *int_types;
+
    const struct dxil_logger *logger;
 };
 
@@ -2078,24 +2081,40 @@ bitcast_to_float(struct ntd_context *ctx, unsigned bit_size,
    return dxil_emit_cast(&ctx->mod, DXIL_CAST_BITCAST, type, value);
 }
 
+static bool
+is_phi_src(nir_ssa_def *ssa)
+{
+   nir_foreach_use(src, ssa)
+      if (src->parent_instr->type == nir_instr_type_phi)
+         return true;
+   return false;
+}
+
 static void
 store_ssa_def(struct ntd_context *ctx, nir_ssa_def *ssa, unsigned chan,
-              const struct dxil_value *value, bool is_type_dummy)
+              const struct dxil_value *value)
 {
    assert(ssa->index < ctx->num_defs);
    assert(chan < ssa->num_components);
-   /* We pre-defined the dest value type, so bitcast while storing if the base type differs */
-   if (ctx->defs[ssa->index].chans[chan]) {
-      if (is_type_dummy)
-         return;
-      const struct dxil_type *expect_type = dxil_value_get_type(ctx->defs[ssa->index].chans[chan]);
-      const struct dxil_type *value_type = dxil_value_get_type(value);
-      if (dxil_type_to_nir_type(expect_type) != dxil_type_to_nir_type(value_type))
-         value = dxil_emit_cast(&ctx->mod, DXIL_CAST_BITCAST, expect_type, value);
-      if (expect_type == ctx->mod.int64_type)
-         ctx->mod.feats.int64_ops = true;
-      if (expect_type == ctx->mod.float64_type)
-         ctx->mod.feats.doubles = true;
+   /* Insert bitcasts for phi srcs in the parent block */
+   if (is_phi_src(ssa)) {
+      /* Prefer ints over floats if it could be both or if we have no type info */
+      nir_alu_type expect_type =
+         BITSET_TEST(ctx->int_types, ssa->index) ? nir_type_int :
+         (BITSET_TEST(ctx->float_types, ssa->index) ? nir_type_float :
+          nir_type_int);
+      assert(ssa->bit_size != 1 || expect_type == nir_type_int);
+      if (ssa->bit_size != 1 && expect_type != dxil_type_to_nir_type(dxil_value_get_type(value)))
+         value = dxil_emit_cast(&ctx->mod, DXIL_CAST_BITCAST,
+                                expect_type == nir_type_int ?
+                                 dxil_module_get_int_type(&ctx->mod, ssa->bit_size) :
+                                 dxil_module_get_float_type(&ctx->mod, ssa->bit_size), value);
+      if (ssa->bit_size == 64) {
+         if (expect_type == nir_type_int)
+            ctx->mod.feats.int64_ops = true;
+         if (expect_type == nir_type_float)
+            ctx->mod.feats.doubles = true;
+      }
    }
    ctx->defs[ssa->index].chans[chan] = value;
 }
@@ -2106,7 +2125,7 @@ store_dest_value(struct ntd_context *ctx, nir_dest *dest, unsigned chan,
 {
    assert(dest->is_ssa);
    assert(value);
-   store_ssa_def(ctx, &dest->ssa, chan, value, false);
+   store_ssa_def(ctx, &dest->ssa, chan, value);
 }
 
 static void
@@ -2437,26 +2456,13 @@ get_overload(nir_alu_type alu_type, unsigned bit_size)
 }
 
 static enum overload_type
-get_overload_from_dxil_value(struct dxil_module *mod, const struct dxil_value *value)
-{
-   const struct dxil_type *type = dxil_value_get_type(value);
-   if (type == mod->int1_type) return DXIL_I1;
-   if (type == mod->int16_type) return DXIL_I16;
-   if (type == mod->int32_type) return DXIL_I32;
-   if (type == mod->int64_type) return DXIL_I64;
-   if (type == mod->float16_type) return DXIL_F16;
-   if (type == mod->float32_type) return DXIL_F32;
-   if (type == mod->float64_type) return DXIL_F64;
-   return DXIL_NONE;
-}
-
-static enum overload_type
 get_ambiguous_overload(struct ntd_context *ctx, nir_intrinsic_instr *intr,
                        enum overload_type default_type)
 {
-   const struct dxil_value *dummy_value = ctx->defs[intr->dest.ssa.index].chans[0];
-   if (dummy_value)
-      return get_overload_from_dxil_value(&ctx->mod, dummy_value);
+   if (BITSET_TEST(ctx->int_types, intr->dest.ssa.index))
+      return get_overload(nir_type_int, intr->dest.ssa.bit_size);
+   if (BITSET_TEST(ctx->float_types, intr->dest.ssa.index))
+      return get_overload(nir_type_float, intr->dest.ssa.bit_size);
    return default_type;
 }
 
@@ -2464,10 +2470,7 @@ static enum overload_type
 get_ambiguous_overload_alu_type(struct ntd_context *ctx, nir_intrinsic_instr *intr,
                                 nir_alu_type alu_type)
 {
-   const struct dxil_value *dummy_value = ctx->defs[intr->dest.ssa.index].chans[0];
-   if (dummy_value)
-      return get_overload_from_dxil_value(&ctx->mod, dummy_value);
-   return get_overload(alu_type, nir_dest_bit_size(intr->dest));
+   return get_ambiguous_overload(ctx, intr, get_overload(alu_type, intr->dest.ssa.bit_size));
 }
 
 static bool
@@ -2801,7 +2804,7 @@ emit_alu(struct ntd_context *ctx, nir_alu_instr *alu)
    case nir_op_mov: {
          assert(nir_dest_num_components(alu->dest.dest) == 1);
          store_ssa_def(ctx, &alu->dest.dest.ssa, 0, get_src_ssa(ctx,
-                        alu->src->src.ssa, alu->src->swizzle[0]), false);
+                        alu->src->src.ssa, alu->src->swizzle[0]));
          return true;
       }
    case nir_op_pack_double_2x32_dxil:
@@ -3479,7 +3482,7 @@ emit_store_ssbo(struct ntd_context *ctx, nir_intrinsic_instr *intr)
 
    nir_alu_type type =
       dxil_type_to_nir_type(dxil_value_get_type(get_src_ssa(ctx, intr->src[0].ssa, 0)));
-   const struct dxil_value *value[4];
+   const struct dxil_value *value[4] = { 0 };
    for (unsigned i = 0; i < num_components; ++i) {
       value[i] = get_src(ctx, &intr->src[0], i, type);
       if (!value[i])
@@ -3495,7 +3498,7 @@ emit_store_ssbo(struct ntd_context *ctx, nir_intrinsic_instr *intr)
       int32_undef
    };
 
-   enum overload_type overload = get_overload_from_dxil_value(&ctx->mod, value[0]);
+   enum overload_type overload = get_overload(type, intr->src[0].ssa->bit_size);
    if (num_components < 4) {
       const struct dxil_value *value_undef = dxil_module_get_undef(&ctx->mod, dxil_value_get_type(value[0]));
       if (!value_undef)
@@ -5206,14 +5209,21 @@ get_value_for_const(struct dxil_module *mod, nir_const_value *c, const struct dx
    unreachable("Invalid type");
 }
 
+static const struct dxil_type *
+dxil_type_for_const(struct ntd_context *ctx, nir_ssa_def *def)
+{
+   if (BITSET_TEST(ctx->int_types, def->index) ||
+       !BITSET_TEST(ctx->float_types, def->index))
+      return dxil_module_get_int_type(&ctx->mod, def->bit_size);
+   return dxil_module_get_float_type(&ctx->mod, def->bit_size);
+}
+
 static bool
 emit_load_const(struct ntd_context *ctx, nir_load_const_instr *load_const)
 {
    for (uint32_t i = 0; i < load_const->def.num_components; ++i) {
-      const struct dxil_type *type = ctx->defs[load_const->def.index].chans[i] ?
-         dxil_value_get_type(ctx->defs[load_const->def.index].chans[i]) :
-         dxil_module_get_int_type(&ctx->mod, load_const->def.bit_size);
-      store_ssa_def(ctx, &load_const->def, i, get_value_for_const(&ctx->mod, &load_const->value[i], type), false);
+      const struct dxil_type *type = dxil_type_for_const(ctx, &load_const->def);
+      store_ssa_def(ctx, &load_const->def, i, get_value_for_const(&ctx->mod, &load_const->value[i], type));
    }
    return true;
 }
@@ -5875,7 +5885,7 @@ static bool
 emit_undefined(struct ntd_context *ctx, nir_ssa_undef_instr *undef)
 {
    for (unsigned i = 0; i < undef->def.num_components; ++i)
-      store_ssa_def(ctx, &undef->def, i, dxil_module_get_int32_const(&ctx->mod, 0), false);
+      store_ssa_def(ctx, &undef->def, i, dxil_module_get_int32_const(&ctx->mod, 0));
    return true;
 }
 
@@ -6053,108 +6063,6 @@ sort_uniforms_by_binding_and_remove_structs(nir_shader *s)
    exec_list_append(&s->variables, &new_list);
 }
 
-static const struct dxil_value *
-get_dummy_value_for_type(struct ntd_context *ctx, nir_alu_type type, unsigned bit_size)
-{
-   switch (nir_alu_type_get_base_type(type)) {
-   case nir_type_int:
-   case nir_type_uint:
-   case nir_type_bool:
-      return dxil_module_get_int_const(&ctx->mod, 0, bit_size);
-   case nir_type_float:
-      switch (bit_size) {
-      case 16: return dxil_module_get_float16_const(&ctx->mod, 0);
-      case 32: return dxil_module_get_float_const(&ctx->mod, 0);
-      case 64: return dxil_module_get_double_const(&ctx->mod, 0);
-      default: unreachable("Invalid float bit size");
-      }
-   default: unreachable("Invalid type");
-   }
-}
-
-/* NIR is untyped, but DXIL is typed. ALU ops in nir have implicit
- * types associated with them, and some intrinsics do too. Attempt
- * to gather type information based on how values are used and propagate
- * it backwards through the ops that have ambiguous type information. */
-static void
-prepare_types(struct ntd_context *ctx, nir_function_impl *impl)
-{
-   nir_foreach_block_reverse(block, impl) {
-      nir_foreach_instr_reverse(instr, block) {
-         switch (instr->type) {
-         case nir_instr_type_alu: {
-            nir_alu_instr *alu = nir_instr_as_alu(instr);
-            const nir_op_info *info = &nir_op_infos[alu->op];
-            for (uint32_t i = 0; i < info->num_inputs; ++i) {
-               const struct dxil_value *typed_value = NULL;
-               switch (alu->op) {
-               case nir_op_mov:
-               case nir_op_vec2:
-               case nir_op_vec3:
-               case nir_op_vec4:
-               case nir_op_vec8:
-               case nir_op_vec16:
-                  /* These are declared as uint types but they actually just copy from their sources,
-                   * so back-propagate if we have a type. */
-                  typed_value = ctx->defs[alu->dest.dest.ssa.index].chans[0];
-                  break;
-               case nir_op_bcsel:
-                  /* bcsel dest type matches srcs 1 and 2, but src0 is a bool */
-                  if (i > 0) {
-                     typed_value = ctx->defs[alu->dest.dest.ssa.index].chans[0];
-                     if (typed_value)
-                        break;
-                  }
-                  FALLTHROUGH;
-               default:
-                  typed_value = get_dummy_value_for_type(ctx, info->input_types[i], nir_src_bit_size(alu->src[i].src));
-                  break;
-               }
-               if (typed_value) {
-                  uint32_t num_comps = info->input_sizes[i] ? info->input_sizes[i] : alu->dest.dest.ssa.num_components;
-                  for (uint32_t comp = 0; comp < num_comps; ++comp)
-                     store_ssa_def(ctx, alu->src[i].src.ssa, alu->src[i].swizzle[comp], typed_value, true);
-               }
-            }
-            break;
-         }
-         case nir_instr_type_phi: {
-            nir_phi_instr *ir = nir_instr_as_phi(instr);
-            unsigned bitsize = nir_dest_bit_size(ir->dest);
-            /* Attempt to propagate a type from dest to sources */
-            const struct dxil_value *typed_value = ctx->defs[ir->dest.ssa.index].chans[0];
-            if (!typed_value)
-               typed_value = dxil_module_get_int_const(&ctx->mod, 0, bitsize);
-            nir_foreach_phi_src(src, ir) {
-               for (uint32_t i = 0; i < ir->dest.ssa.num_components; ++i)
-                  store_ssa_def(ctx, src->src.ssa, i, typed_value, true);
-            }
-            break;
-         }
-         case nir_instr_type_intrinsic: {
-            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
-            switch (intr->intrinsic) {
-            case nir_intrinsic_store_output:
-            case nir_intrinsic_store_per_vertex_output: {
-               /* Use the I/O info to set a type for the stored value source */
-               const struct dxil_value *typed_value = get_dummy_value_for_type(ctx, nir_intrinsic_src_type(intr),
-                                                                               nir_src_bit_size(intr->src[0]));
-               for (uint32_t i = 0; i < intr->num_components; ++i)
-                  store_ssa_def(ctx, intr->src[0].ssa, i, typed_value, true);
-               break;
-            }
-            default:
-               break;
-            }
-            break;
-         }
-         default:
-            break;
-         }
-      }
-   }
-}
-
 static bool
 emit_cbvs(struct ntd_context *ctx)
 {
@@ -6273,7 +6181,9 @@ emit_function(struct ntd_context *ctx, nir_function *func)
       ctx->tess_ctrl_patch_constant_func_def = func_def;
 
    ctx->defs = rzalloc_array(ctx->ralloc_ctx, struct dxil_def, impl->ssa_alloc);
-   if (!ctx->defs)
+   ctx->float_types = rzalloc_array(ctx->ralloc_ctx, BITSET_WORD, BITSET_WORDS(impl->ssa_alloc));
+   ctx->int_types = rzalloc_array(ctx->ralloc_ctx, BITSET_WORD, BITSET_WORDS(impl->ssa_alloc));
+   if (!ctx->defs || !ctx->float_types || !ctx->int_types)
       return false;
    ctx->num_defs = impl->ssa_alloc;
 
@@ -6281,7 +6191,7 @@ emit_function(struct ntd_context *ctx, nir_function *func)
    if (!ctx->phis)
       return false;
 
-   prepare_types(ctx, impl);
+   nir_gather_ssa_types(impl, ctx->float_types, ctx->int_types);
 
    if (!emit_scratch(ctx))
       return false;
@@ -6863,7 +6773,7 @@ nir_to_dxil(struct nir_shader *s, const struct nir_to_dxil_options *opts,
    NIR_PASS_V(s, nir_opt_dce);
 
    /* This needs to be after any copy prop is done to prevent these movs from being erased */
-   NIR_PASS_V(s, dxil_nir_split_phis_and_const_srcs);
+   NIR_PASS_V(s, dxil_nir_move_consts);
    NIR_PASS_V(s, nir_opt_dce);
 
    if (debug_dxil & DXIL_DEBUG_VERBOSE)