radv: Always create ray query vars as shader temp
authorKonstantin Seurer <konstantin.seurer@gmail.com>
Wed, 20 Jul 2022 17:19:16 +0000 (19:19 +0200)
committerMarge Bot <emma+marge@anholt.net>
Fri, 11 Nov 2022 19:00:17 +0000 (19:00 +0000)
Avoid the whole "is this function or shader scope" code and fix some
memory leaks in the process.

Signed-off-by: Konstantin Seurer <konstantin.seurer@gmail.com>
Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/17663>

src/amd/vulkan/radv_nir_lower_ray_queries.c

index e7fe2d5..e6786b0 100644 (file)
@@ -42,21 +42,17 @@ typedef struct {
 } rq_variable;
 
 static rq_variable *
-rq_variable_create(nir_shader *shader, nir_function_impl *impl, unsigned array_length,
+rq_variable_create(void *ctx, nir_shader *shader, unsigned array_length,
                    const struct glsl_type *type, const char *name)
 {
-   rq_variable *result = ralloc(shader ? (void *)shader : (void *)impl, rq_variable);
+   rq_variable *result = ralloc(ctx, rq_variable);
    result->array_length = array_length;
 
    const struct glsl_type *variable_type = type;
    if (array_length != 1)
       variable_type = glsl_array_type(type, array_length, glsl_get_explicit_stride(type));
 
-   if (shader) {
-      result->variable = nir_variable_create(shader, nir_var_shader_temp, variable_type, name);
-   } else {
-      result->variable = nir_local_variable_create(impl, variable_type, name);
-   }
+   result->variable = nir_variable_create(shader, nir_var_shader_temp, variable_type, name);
 
    return result;
 }
@@ -183,42 +179,42 @@ struct ray_query_vars {
 };
 
 #define VAR_NAME(name)                                                                             \
-   strcat(strcpy(ralloc_size(impl, strlen(base_name) + strlen(name) + 1), base_name), name)
+   strcat(strcpy(ralloc_size(ctx, strlen(base_name) + strlen(name) + 1), base_name), name)
 
 static struct ray_query_traversal_vars
-init_ray_query_traversal_vars(nir_shader *shader, nir_function_impl *impl, unsigned array_length,
+init_ray_query_traversal_vars(void *ctx, nir_shader *shader, unsigned array_length,
                               const char *base_name)
 {
    struct ray_query_traversal_vars result;
 
    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
 
-   result.origin = rq_variable_create(shader, impl, array_length, vec3_type, VAR_NAME("_origin"));
+   result.origin = rq_variable_create(ctx, shader, array_length, vec3_type, VAR_NAME("_origin"));
    result.direction =
-      rq_variable_create(shader, impl, array_length, vec3_type, VAR_NAME("_direction"));
+      rq_variable_create(ctx, shader, array_length, vec3_type, VAR_NAME("_direction"));
 
-   result.inv_dir = rq_variable_create(shader, impl, array_length, vec3_type, VAR_NAME("_inv_dir"));
+   result.inv_dir = rq_variable_create(ctx, shader, array_length, vec3_type, VAR_NAME("_inv_dir"));
    result.bvh_base =
-      rq_variable_create(shader, impl, array_length, glsl_uint64_t_type(), VAR_NAME("_bvh_base"));
+      rq_variable_create(ctx, shader, array_length, glsl_uint64_t_type(), VAR_NAME("_bvh_base"));
    result.stack =
-      rq_variable_create(shader, impl, array_length, glsl_uint_type(), VAR_NAME("_stack"));
+      rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_stack"));
    result.top_stack =
-      rq_variable_create(shader, impl, array_length, glsl_uint_type(), VAR_NAME("_top_stack"));
+      rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_top_stack"));
    result.stack_base =
-      rq_variable_create(shader, impl, array_length, glsl_uint_type(), VAR_NAME("_stack_base"));
+      rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_stack_base"));
    result.current_node =
-      rq_variable_create(shader, impl, array_length, glsl_uint_type(), VAR_NAME("_current_node"));
+      rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_current_node"));
    result.previous_node =
-      rq_variable_create(shader, impl, array_length, glsl_uint_type(), VAR_NAME("_previous_node"));
-   result.instance_top_node = rq_variable_create(shader, impl, array_length, glsl_uint_type(),
+      rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_previous_node"));
+   result.instance_top_node = rq_variable_create(ctx, shader, array_length, glsl_uint_type(),
                                                  VAR_NAME("_instance_top_node"));
-   result.instance_bottom_node = rq_variable_create(shader, impl, array_length, glsl_uint_type(),
+   result.instance_bottom_node = rq_variable_create(ctx, shader, array_length, glsl_uint_type(),
                                                     VAR_NAME("_instance_bottom_node"));
    return result;
 }
 
 static struct ray_query_intersection_vars
-init_ray_query_intersection_vars(nir_shader *shader, nir_function_impl *impl, unsigned array_length,
+init_ray_query_intersection_vars(void *ctx, nir_shader *shader, unsigned array_length,
                                  const char *base_name)
 {
    struct ray_query_intersection_vars result;
@@ -226,54 +222,53 @@ init_ray_query_intersection_vars(nir_shader *shader, nir_function_impl *impl, un
    const struct glsl_type *vec2_type = glsl_vector_type(GLSL_TYPE_FLOAT, 2);
 
    result.primitive_id =
-      rq_variable_create(shader, impl, array_length, glsl_uint_type(), VAR_NAME("_primitive_id"));
-   result.geometry_id_and_flags = rq_variable_create(shader, impl, array_length, glsl_uint_type(),
+      rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_primitive_id"));
+   result.geometry_id_and_flags = rq_variable_create(ctx, shader, array_length, glsl_uint_type(),
                                                      VAR_NAME("_geometry_id_and_flags"));
-   result.instance_addr = rq_variable_create(shader, impl, array_length, glsl_uint64_t_type(),
+   result.instance_addr = rq_variable_create(ctx, shader, array_length, glsl_uint64_t_type(),
                                              VAR_NAME("_instance_addr"));
-   result.intersection_type = rq_variable_create(shader, impl, array_length, glsl_uint_type(),
+   result.intersection_type = rq_variable_create(ctx, shader, array_length, glsl_uint_type(),
                                                  VAR_NAME("_intersection_type"));
    result.opaque =
-      rq_variable_create(shader, impl, array_length, glsl_bool_type(), VAR_NAME("_opaque"));
+      rq_variable_create(ctx, shader, array_length, glsl_bool_type(), VAR_NAME("_opaque"));
    result.frontface =
-      rq_variable_create(shader, impl, array_length, glsl_bool_type(), VAR_NAME("_frontface"));
-   result.sbt_offset_and_flags = rq_variable_create(shader, impl, array_length, glsl_uint_type(),
+      rq_variable_create(ctx, shader, array_length, glsl_bool_type(), VAR_NAME("_frontface"));
+   result.sbt_offset_and_flags = rq_variable_create(ctx, shader, array_length, glsl_uint_type(),
                                                     VAR_NAME("_sbt_offset_and_flags"));
    result.barycentrics =
-      rq_variable_create(shader, impl, array_length, vec2_type, VAR_NAME("_barycentrics"));
-   result.t = rq_variable_create(shader, impl, array_length, glsl_float_type(), VAR_NAME("_t"));
+      rq_variable_create(ctx, shader, array_length, vec2_type, VAR_NAME("_barycentrics"));
+   result.t = rq_variable_create(ctx, shader, array_length, glsl_float_type(), VAR_NAME("_t"));
 
    return result;
 }
 
 static void
-init_ray_query_vars(nir_shader *shader, nir_function_impl *impl, unsigned array_length,
-                    struct ray_query_vars *dst, const char *base_name)
+init_ray_query_vars(nir_shader *shader, unsigned array_length, struct ray_query_vars *dst,
+                    const char *base_name)
 {
+   void *ctx = dst;
    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
 
-   dst->root_bvh_base = rq_variable_create(shader, impl, array_length, glsl_uint64_t_type(),
+   dst->root_bvh_base = rq_variable_create(dst, shader, array_length, glsl_uint64_t_type(),
                                            VAR_NAME("_root_bvh_base"));
-   dst->flags =
-      rq_variable_create(shader, impl, array_length, glsl_uint_type(), VAR_NAME("_flags"));
+   dst->flags = rq_variable_create(dst, shader, array_length, glsl_uint_type(), VAR_NAME("_flags"));
    dst->cull_mask =
-      rq_variable_create(shader, impl, array_length, glsl_uint_type(), VAR_NAME("_cull_mask"));
-   dst->origin = rq_variable_create(shader, impl, array_length, vec3_type, VAR_NAME("_origin"));
-   dst->tmin = rq_variable_create(shader, impl, array_length, glsl_float_type(), VAR_NAME("_tmin"));
+      rq_variable_create(dst, shader, array_length, glsl_uint_type(), VAR_NAME("_cull_mask"));
+   dst->origin = rq_variable_create(dst, shader, array_length, vec3_type, VAR_NAME("_origin"));
+   dst->tmin = rq_variable_create(dst, shader, array_length, glsl_float_type(), VAR_NAME("_tmin"));
    dst->direction =
-      rq_variable_create(shader, impl, array_length, vec3_type, VAR_NAME("_direction"));
+      rq_variable_create(dst, shader, array_length, vec3_type, VAR_NAME("_direction"));
 
    dst->incomplete =
-      rq_variable_create(shader, impl, array_length, glsl_bool_type(), VAR_NAME("_incomplete"));
+      rq_variable_create(dst, shader, array_length, glsl_bool_type(), VAR_NAME("_incomplete"));
 
-   dst->closest =
-      init_ray_query_intersection_vars(shader, impl, array_length, VAR_NAME("_closest"));
+   dst->closest = init_ray_query_intersection_vars(dst, shader, array_length, VAR_NAME("_closest"));
    dst->candidate =
-      init_ray_query_intersection_vars(shader, impl, array_length, VAR_NAME("_candidate"));
+      init_ray_query_intersection_vars(dst, shader, array_length, VAR_NAME("_candidate"));
 
-   dst->trav = init_ray_query_traversal_vars(shader, impl, array_length, VAR_NAME("_top"));
+   dst->trav = init_ray_query_traversal_vars(dst, shader, array_length, VAR_NAME("_top"));
 
-   dst->stack = rq_variable_create(shader, impl, array_length,
+   dst->stack = rq_variable_create(dst, shader, array_length,
                                    glsl_array_type(glsl_uint_type(), MAX_STACK_ENTRY_COUNT,
                                                    glsl_get_explicit_stride(glsl_uint_type())),
                                    VAR_NAME("_stack"));
@@ -282,17 +277,15 @@ init_ray_query_vars(nir_shader *shader, nir_function_impl *impl, unsigned array_
 #undef VAR_NAME
 
 static void
-lower_ray_query(nir_shader *shader, nir_function_impl *impl, nir_variable *ray_query,
-                struct hash_table *ht)
+lower_ray_query(nir_shader *shader, nir_variable *ray_query, struct hash_table *ht)
 {
-   struct ray_query_vars *vars = ralloc(impl, struct ray_query_vars);
+   struct ray_query_vars *vars = ralloc(ht, struct ray_query_vars);
 
    unsigned array_length = 1;
    if (glsl_type_is_array(ray_query->type))
       array_length = glsl_get_length(ray_query->type);
 
-   init_ray_query_vars(shader, impl, array_length, vars,
-                       ray_query->name == NULL ? "" : ray_query->name);
+   init_ray_query_vars(shader, array_length, vars, ray_query->name == NULL ? "" : ray_query->name);
 
    _mesa_hash_table_insert(ht, ray_query, vars);
 }
@@ -702,7 +695,7 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device
       if (!var->data.ray_query)
          continue;
 
-      lower_ray_query(shader, NULL, var, query_ht);
+      lower_ray_query(shader, var, query_ht);
       contains_ray_query = true;
    }
 
@@ -717,7 +710,7 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device
          if (!var->data.ray_query)
             continue;
 
-         lower_ray_query(NULL, function->impl, var, query_ht);
+         lower_ray_query(shader, var, query_ht);
          contains_ray_query = true;
       }