nir_split_struct_vars: Support more modes and constant initializers
author: Jesse Natalie <jenatali@microsoft.com>
Thu, 18 May 2023 23:16:50 +0000 (16:16 -0700)
committer: Marge Bot <emma+marge@anholt.net>
Tue, 13 Jun 2023 00:43:36 +0000 (00:43 +0000)
Idiomatic DXIL has constants contained within global variables rather
than a big blob of data. Doing this allows us to have 16-bit and 64-bit
data as well, where normally bitcasts would be disallowed on variable
GEP chains.

Unfortunately, DXIL validation requires arrays of structs (AoS) to be
turned into structs of arrays (SoA), which means we need to split structs.
We want to be able to run this pass on nir_var_mem_constant variables which
have constant initializers, so add a bit of logic to handle that case, and
relax the mode validation. There's nothing special about the modes the pass
was previously set up to handle.

Reviewed-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23173>

src/compiler/nir/nir_split_vars.c
src/compiler/nir/tests/vars_tests.cpp

index 343d9d8..f1800f6 100644 (file)
@@ -75,6 +75,9 @@ struct field {
    unsigned num_fields;
    struct field *fields;
 
+   /* The field currently being recursed */
+   unsigned current_index;
+
    nir_variable *var;
 };
 
@@ -95,6 +98,33 @@ num_array_levels_in_array_of_vector_type(const struct glsl_type *type)
    }
 }
 
+static nir_constant *
+gather_constant_initializers(nir_constant *src,
+                             nir_variable *var,
+                             const struct glsl_type *type,
+                             struct field *field,
+                             struct split_var_state *state)
+{
+   if (!src)
+      return NULL;
+   if (glsl_type_is_array(type)) {
+      const struct glsl_type *element = glsl_get_array_element(type);
+      assert(src->num_elements == glsl_get_length(type));
+      nir_constant *dst = rzalloc(var, nir_constant);
+      dst->num_elements = src->num_elements;
+      dst->elements = rzalloc_array(var, nir_constant *, src->num_elements);
+      for (unsigned i = 0; i < src->num_elements; ++i) {
+         dst->elements[i] = gather_constant_initializers(src->elements[i], var, element, field, state);
+      }
+      return dst;
+   } else if (glsl_type_is_struct(type)) {
+      const struct glsl_type *element = glsl_get_struct_field(type, field->current_index);
+      return gather_constant_initializers(src->elements[field->current_index], var, element, &field->fields[field->current_index], state);
+   } else {
+      return nir_constant_clone(src, var);
+   }
+}
+
 static void
 init_field_for_type(struct field *field, struct field *parent,
                     const struct glsl_type *type,
@@ -121,14 +151,18 @@ init_field_for_type(struct field *field, struct field *parent,
                                          glsl_get_type_name(struct_type),
                                          glsl_get_struct_elem_name(struct_type, i));
          }
+         field->current_index = i;
          init_field_for_type(&field->fields[i], field,
                              glsl_get_struct_field(struct_type, i),
                              field_name, state);
       }
    } else {
       const struct glsl_type *var_type = type;
-      for (struct field *f = field->parent; f; f = f->parent)
+      struct field *root = field;
+      for (struct field *f = field->parent; f; f = f->parent) {
          var_type = glsl_type_wrap_in_arrays(var_type, f->type);
+         root = f;
+      }
 
       nir_variable_mode mode = state->base_var->data.mode;
       if (mode == nir_var_function_temp) {
@@ -137,6 +171,9 @@ init_field_for_type(struct field *field, struct field *parent,
          field->var = nir_variable_create(state->shader, mode, var_type, name);
       }
       field->var->data.ray_query = state->base_var->data.ray_query;
+      field->var->constant_initializer = gather_constant_initializers(state->base_var->constant_initializer,
+                                                                      field->var, state->base_var->type,
+                                                                      root, state);
    }
 }
 
@@ -299,10 +336,8 @@ nir_split_struct_vars(nir_shader *shader, nir_variable_mode modes)
       _mesa_pointer_hash_table_create(mem_ctx);
    struct set *complex_vars = NULL;
 
-   assert((modes & (nir_var_shader_temp | nir_var_ray_hit_attrib | nir_var_function_temp)) == modes);
-
    bool has_global_splits = false;
-   nir_variable_mode global_modes = modes & (nir_var_shader_temp | nir_var_ray_hit_attrib);
+   nir_variable_mode global_modes = modes & ~nir_var_function_temp;
    if (global_modes) {
       has_global_splits = split_var_list_structs(shader, NULL,
                                                  &shader->variables,
index 0a6c6a9..205933a 100644 (file)
@@ -25,6 +25,7 @@
 
 #include "nir.h"
 #include "nir_builder.h"
+#include "nir_deref.h"
 
 namespace {
 
@@ -2408,6 +2409,76 @@ TEST_F(nir_split_vars_test, split_wildcard_copy)
    ASSERT_EQ(count_intrinsics(nir_intrinsic_copy_deref), 4);
 }
 
+TEST_F(nir_split_vars_test, split_nested_struct_const_init)
+{
+   const struct glsl_struct_field inner_struct_types[] = {
+      { glsl_int_type(), "a"},
+      { glsl_int_type(), "b"},
+   };
+   const struct glsl_type *inner_struct = glsl_struct_type(inner_struct_types, 2, "inner", false);
+   const struct glsl_struct_field outer_struct_types[] = {
+      { glsl_array_type(inner_struct, 2, 0), "as" },
+      { glsl_array_type(inner_struct, 2, 0), "bs" },
+   };
+   const struct glsl_type *outer_struct = glsl_struct_type(outer_struct_types, 2, "outer", false);
+   nir_variable *var = create_var(nir_var_mem_constant, glsl_array_type(outer_struct, 2, 0), "consts");
+
+   uint32_t literal_val = 0;
+   auto get_inner_struct_val = [&]() {
+      nir_constant ret = {};
+      ret.values[0].u32 = literal_val++;
+      return ret;
+   };
+   auto get_nested_constant = [&](auto &get_inner_val) {
+      nir_constant *arr = ralloc_array(b->shader, nir_constant, 2);
+      arr[0] = get_inner_val();
+      arr[1] = get_inner_val();
+      nir_constant **arr2 = ralloc_array(b->shader, nir_constant *, 2);
+      arr2[0] = &arr[0];
+      arr2[1] = &arr[1];
+      nir_constant ret = {};
+      ret.num_elements = 2;
+      ret.elements = arr2;
+      return ret;
+   };
+   auto get_inner_struct_constant = [&]() { return get_nested_constant(get_inner_struct_val); };
+   auto get_inner_array_constant = [&]() { return get_nested_constant(get_inner_struct_constant); };
+   auto get_outer_struct_constant = [&]() { return get_nested_constant(get_inner_array_constant); };
+   auto get_outer_array_constant = [&]() { return get_nested_constant(get_outer_struct_constant); };
+   nir_constant var_constant = get_outer_array_constant();
+   var->constant_initializer = &var_constant;
+
+   nir_variable *out = create_int(nir_var_shader_out, "out");
+   nir_store_var(b, out,
+      nir_load_deref(b,
+         nir_build_deref_struct(b,
+            nir_build_deref_array_imm(b,
+               nir_build_deref_struct(b,
+                  nir_build_deref_array_imm(b, nir_build_deref_var(b, var), 1),
+                                      0),
+                                      1),
+                                1)
+                     ),
+                 0xff);
+
+   nir_validate_shader(b->shader, NULL);
+
+   bool progress = nir_split_struct_vars(b->shader, nir_var_mem_constant);
+   EXPECT_TRUE(progress);
+
+   nir_validate_shader(b->shader, NULL);
+   
+   unsigned count = 0;
+   nir_foreach_variable_with_modes(var, b->shader, nir_var_mem_constant) {
+      EXPECT_EQ(glsl_get_aoa_size(var->type), 4);
+      EXPECT_EQ(glsl_get_length(var->type), 2);
+      EXPECT_EQ(glsl_without_array(var->type), glsl_int_type());
+      count++;
+   }
+
+   ASSERT_EQ(count, 4);
+}
+
 TEST_F(nir_remove_dead_variables_test, pointer_initializer_used)
 {
    nir_variable *x = create_int(nir_var_shader_temp, "x");