nir: Add a pass to lower mediump temps and shared mem.
authorEmma Anholt <emma@anholt.net>
Tue, 23 Aug 2022 21:54:37 +0000 (14:54 -0700)
committerMarge Bot <emma+marge@anholt.net>
Thu, 1 Sep 2022 22:39:39 +0000 (22:39 +0000)
SPIRV and GLSL are reasonable at converting ALU ops to mediump, but
variable storage would be wrapped in a 2f32/2mp on store/load, and if
nir_vars_to_ssa doesn't make that storage go away then you'd have extra
conversions.  For compute shader shared mem, you'd waste memory too.

Reviewed-by: Marek Olšák <marek.olsak@amd.com>
Reviewed-by: Matt Turner <mattst88@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/18259>

src/compiler/nir/nir.h
src/compiler/nir/nir_lower_mediump.c
src/compiler/nir_types.cpp
src/compiler/nir_types.h

index fe0322f..02db779 100644 (file)
@@ -5345,6 +5345,7 @@ bool nir_lower_doubles(nir_shader *shader, const nir_shader *softfp64,
 bool nir_lower_pack(nir_shader *shader);
 
 bool nir_recompute_io_bases(nir_shader *nir, nir_variable_mode modes);
+bool nir_lower_mediump_vars(nir_shader *nir, nir_variable_mode modes);
 bool nir_lower_mediump_io(nir_shader *nir, nir_variable_mode modes,
                           uint64_t varying_mask, bool use_16bit_slots);
 bool nir_force_mediump_io(nir_shader *nir, nir_variable_mode modes,
index bdbb934..2e8572c 100644 (file)
@@ -381,6 +381,189 @@ nir_unpack_16bit_varying_slots(nir_shader *nir, nir_variable_mode modes)
 }
 
 static bool
+is_mediump_or_lowp(unsigned precision)
+{
+   return precision == GLSL_PRECISION_LOW || precision == GLSL_PRECISION_MEDIUM;
+}
+
+static bool
+try_lower_mediump_var(nir_variable *var, nir_variable_mode modes)
+{
+   if (!(var->data.mode & modes) || !is_mediump_or_lowp(var->data.precision))
+      return false;
+
+   const struct glsl_type *new_type = glsl_type_to_16bit(var->type);
+   if (var->type == new_type)
+      return false;
+
+   var->type = new_type;
+   return true;
+}
+
+static bool
+nir_lower_mediump_vars_impl(nir_function_impl *impl, nir_variable_mode modes,
+                            bool any_lowered)
+{
+   bool progress = false;
+
+   if (modes & nir_var_function_temp) {
+      nir_foreach_function_temp_variable(var, impl) {
+         any_lowered = try_lower_mediump_var(var, modes) || any_lowered;
+      }
+   }
+   if (!any_lowered)
+      return false;
+
+   nir_builder b;
+   nir_builder_init(&b, impl);
+
+   nir_foreach_block(block, impl) {
+      nir_foreach_instr_safe(instr, block) {
+         switch (instr->type) {
+         case nir_instr_type_deref: {
+            nir_deref_instr *deref = nir_instr_as_deref(instr);
+
+            if (deref->modes & modes) {
+               switch (deref->deref_type) {
+               case nir_deref_type_var:
+                  deref->type = deref->var->type;
+                  break;
+               case nir_deref_type_array:
+               case nir_deref_type_array_wildcard:
+                  deref->type = glsl_get_array_element(nir_deref_instr_parent(deref)->type);
+                  break;
+               case nir_deref_type_struct:
+                  deref->type = glsl_get_struct_field(nir_deref_instr_parent(deref)->type, deref->strct.index);
+                  break;
+               default:
+                  nir_print_instr(instr, stderr);
+                  unreachable("unsupported deref type");
+               }
+            }
+
+            break;
+         }
+
+         case nir_instr_type_intrinsic: {
+            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+            switch (intrin->intrinsic) {
+            case nir_intrinsic_load_deref: {
+
+               if (intrin->dest.ssa.bit_size != 32)
+                  break;
+
+               nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
+               nir_ssa_def *replace = NULL;
+
+               b.cursor = nir_after_instr(&intrin->instr);
+               switch (glsl_get_base_type(deref->type)) {
+               case GLSL_TYPE_FLOAT16:
+                  replace = nir_f2f32(&b, &intrin->dest.ssa);
+                  break;
+               case GLSL_TYPE_INT16:
+                  replace = nir_i2i32(&b, &intrin->dest.ssa);
+                  break;
+               case GLSL_TYPE_UINT16:
+                  replace = nir_u2u32(&b, &intrin->dest.ssa);
+                  break;
+               default:
+                  break;
+               }
+               if (!replace)
+                  break;
+
+               intrin->dest.ssa.bit_size = 16;
+               nir_ssa_def_rewrite_uses_after(&intrin->dest.ssa,
+                                              replace,
+                                              replace->parent_instr);
+               progress = true;
+               break;
+            }
+
+            case nir_intrinsic_store_deref: {
+               nir_ssa_def *data = intrin->src[1].ssa;
+               if (data->bit_size != 32)
+                  break;
+
+               b.cursor = nir_before_instr(&intrin->instr);
+               nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
+               nir_ssa_def *replace = NULL;
+               switch (glsl_get_base_type(deref->type)) {
+               case GLSL_TYPE_FLOAT16:
+                  replace = nir_f2fmp(&b, data);
+                  break;
+               case GLSL_TYPE_INT16:
+               case GLSL_TYPE_UINT16:
+                  replace = nir_i2imp(&b, data);
+                  break;
+               default:
+                  break;
+               }
+               if (!replace)
+                  break;
+
+               nir_instr_rewrite_src(&intrin->instr, &intrin->src[1],
+                                     nir_src_for_ssa(replace));
+               progress = true;
+               break;
+            }
+
+            case nir_intrinsic_copy_deref: {
+               nir_deref_instr *dst = nir_src_as_deref(intrin->src[0]);
+               nir_deref_instr *src = nir_src_as_deref(intrin->src[0]);
+               /* If we convert once side of a copy and not the other, that
+                * would be very bad.
+                */
+               if (nir_deref_mode_may_be(dst, modes) ||
+                   nir_deref_mode_may_be(src, modes)) {
+                  assert(nir_deref_mode_must_be(dst, modes));
+                  assert(nir_deref_mode_must_be(src, modes));
+               }
+               break;
+            }
+
+            default:
+               break;
+            }
+            break;
+         }
+
+         default:
+            break;
+         }
+      }
+   }
+
+   if (progress) {
+      nir_metadata_preserve(impl, nir_metadata_block_index |
+                                  nir_metadata_dominance);
+   } else {
+      nir_metadata_preserve(impl, nir_metadata_all);
+   }
+
+   return progress;
+}
+
+bool
+nir_lower_mediump_vars(nir_shader *shader, nir_variable_mode modes)
+{
+   bool progress = false;
+
+   if (modes & ~nir_var_function_temp) {
+      nir_foreach_variable_in_shader(var, shader) {
+         progress = try_lower_mediump_var(var, modes) || progress;
+      }
+   }
+
+   nir_foreach_function(function, shader) {
+      if (function->impl && nir_lower_mediump_vars_impl(function->impl, modes, progress))
+         progress = true;
+   }
+
+   return progress;
+}
+
+static bool
 is_n_to_m_conversion(nir_instr *instr, unsigned n, nir_op m)
 {
    if (instr->type != nir_instr_type_alu)
index 1f86a38..1391665 100644 (file)
@@ -778,6 +778,31 @@ glsl_uint16_type(const struct glsl_type *type)
    return type->get_uint16_type();
 }
 
+const struct glsl_type *
+glsl_type_to_16bit(const struct glsl_type *old_type)
+{
+   if (glsl_type_is_array(old_type)) {
+      return glsl_array_type(glsl_type_to_16bit(glsl_get_array_element(old_type)),
+                             glsl_get_length(old_type),
+                             glsl_get_explicit_stride(old_type));
+   }
+
+   if (glsl_type_is_vector_or_scalar(old_type)) {
+      switch (glsl_get_base_type(old_type)) {
+      case GLSL_TYPE_FLOAT:
+         return glsl_float16_type(old_type);
+      case GLSL_TYPE_UINT:
+         return glsl_uint16_type(old_type);
+      case GLSL_TYPE_INT:
+         return glsl_int16_type(old_type);
+      default:
+         break;
+      }
+   }
+
+   return old_type;
+}
+
 static void
 glsl_size_align_handle_array_and_structs(const struct glsl_type *type,
                                          glsl_type_size_align_func size_align,
index ddbb472..eefd8c3 100644 (file)
@@ -241,6 +241,7 @@ const struct glsl_type *glsl_channel_type(const struct glsl_type *type);
 const struct glsl_type *glsl_float16_type(const struct glsl_type *type);
 const struct glsl_type *glsl_int16_type(const struct glsl_type *type);
 const struct glsl_type *glsl_uint16_type(const struct glsl_type *type);
+const struct glsl_type *glsl_type_to_16bit(const struct glsl_type *old_type);
 
 void glsl_get_natural_size_align_bytes(const struct glsl_type *type,
                                        unsigned *size, unsigned *align);