From 3a50dcb3f478488cb588948a78325c6c5cdca139 Mon Sep 17 00:00:00 2001 From: Konstantin Seurer Date: Wed, 20 Jul 2022 19:19:16 +0200 Subject: [PATCH] radv: Always create ray query vars as shader temp Avoid the whole "is this function or shader scope" code and fix some memory leaks in the process. Signed-off-by: Konstantin Seurer Reviewed-by: Bas Nieuwenhuizen Part-of: --- src/amd/vulkan/radv_nir_lower_ray_queries.c | 97 +++++++++++++---------------- 1 file changed, 45 insertions(+), 52 deletions(-) diff --git a/src/amd/vulkan/radv_nir_lower_ray_queries.c b/src/amd/vulkan/radv_nir_lower_ray_queries.c index e7fe2d5..e6786b0 100644 --- a/src/amd/vulkan/radv_nir_lower_ray_queries.c +++ b/src/amd/vulkan/radv_nir_lower_ray_queries.c @@ -42,21 +42,17 @@ typedef struct { } rq_variable; static rq_variable * -rq_variable_create(nir_shader *shader, nir_function_impl *impl, unsigned array_length, +rq_variable_create(void *ctx, nir_shader *shader, unsigned array_length, const struct glsl_type *type, const char *name) { - rq_variable *result = ralloc(shader ? (void *)shader : (void *)impl, rq_variable); + rq_variable *result = ralloc(ctx, rq_variable); result->array_length = array_length; const struct glsl_type *variable_type = type; if (array_length != 1) variable_type = glsl_array_type(type, array_length, glsl_get_explicit_stride(type)); - if (shader) { - result->variable = nir_variable_create(shader, nir_var_shader_temp, variable_type, name); - } else { - result->variable = nir_local_variable_create(impl, variable_type, name); - } + result->variable = nir_variable_create(shader, nir_var_shader_temp, variable_type, name); return result; } @@ -183,42 +179,42 @@ struct ray_query_vars { }; #define VAR_NAME(name) \ - strcat(strcpy(ralloc_size(impl, strlen(base_name) + strlen(name) + 1), base_name), name) + strcat(strcpy(ralloc_size(ctx, strlen(base_name) + strlen(name) + 1), base_name), name) static struct ray_query_traversal_vars -init_ray_query_traversal_vars(nir_shader *shader, nir_function_impl *impl, unsigned array_length, +init_ray_query_traversal_vars(void *ctx, nir_shader *shader, unsigned array_length, const char *base_name) { struct ray_query_traversal_vars result; const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3); - result.origin = rq_variable_create(shader, impl, array_length, vec3_type, VAR_NAME("_origin")); + result.origin = rq_variable_create(ctx, shader, array_length, vec3_type, VAR_NAME("_origin")); result.direction = - rq_variable_create(shader, impl, array_length, vec3_type, VAR_NAME("_direction")); + rq_variable_create(ctx, shader, array_length, vec3_type, VAR_NAME("_direction")); - result.inv_dir = rq_variable_create(shader, impl, array_length, vec3_type, VAR_NAME("_inv_dir")); + result.inv_dir = rq_variable_create(ctx, shader, array_length, vec3_type, VAR_NAME("_inv_dir")); result.bvh_base = - rq_variable_create(shader, impl, array_length, glsl_uint64_t_type(), VAR_NAME("_bvh_base")); + rq_variable_create(ctx, shader, array_length, glsl_uint64_t_type(), VAR_NAME("_bvh_base")); result.stack = - rq_variable_create(shader, impl, array_length, glsl_uint_type(), VAR_NAME("_stack")); + rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_stack")); result.top_stack = - rq_variable_create(shader, impl, array_length, glsl_uint_type(), VAR_NAME("_top_stack")); + rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_top_stack")); result.stack_base = - rq_variable_create(shader, impl, array_length, glsl_uint_type(), VAR_NAME("_stack_base")); + rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_stack_base")); result.current_node = - rq_variable_create(shader, impl, array_length, glsl_uint_type(), VAR_NAME("_current_node")); + rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_current_node")); result.previous_node = - rq_variable_create(shader, impl, array_length, glsl_uint_type(), VAR_NAME("_previous_node")); - result.instance_top_node = rq_variable_create(shader, impl, array_length, glsl_uint_type(), + rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_previous_node")); + result.instance_top_node = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_instance_top_node")); - result.instance_bottom_node = rq_variable_create(shader, impl, array_length, glsl_uint_type(), + result.instance_bottom_node = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_instance_bottom_node")); return result; } static struct ray_query_intersection_vars -init_ray_query_intersection_vars(nir_shader *shader, nir_function_impl *impl, unsigned array_length, +init_ray_query_intersection_vars(void *ctx, nir_shader *shader, unsigned array_length, const char *base_name) { struct ray_query_intersection_vars result; @@ -226,54 +222,53 @@ init_ray_query_intersection_vars(nir_shader *shader, nir_function_impl *impl, un const struct glsl_type *vec2_type = glsl_vector_type(GLSL_TYPE_FLOAT, 2); result.primitive_id = - rq_variable_create(shader, impl, array_length, glsl_uint_type(), VAR_NAME("_primitive_id")); - result.geometry_id_and_flags = rq_variable_create(shader, impl, array_length, glsl_uint_type(), + rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_primitive_id")); + result.geometry_id_and_flags = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_geometry_id_and_flags")); - result.instance_addr = rq_variable_create(shader, impl, array_length, glsl_uint64_t_type(), + result.instance_addr = rq_variable_create(ctx, shader, array_length, glsl_uint64_t_type(), VAR_NAME("_instance_addr")); - result.intersection_type = rq_variable_create(shader, impl, array_length, glsl_uint_type(), + result.intersection_type = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_intersection_type")); result.opaque = - rq_variable_create(shader, impl, array_length, glsl_bool_type(), VAR_NAME("_opaque")); + rq_variable_create(ctx, shader, array_length, glsl_bool_type(), VAR_NAME("_opaque")); result.frontface = - rq_variable_create(shader, impl, array_length, glsl_bool_type(), VAR_NAME("_frontface")); - result.sbt_offset_and_flags = rq_variable_create(shader, impl, array_length, glsl_uint_type(), + rq_variable_create(ctx, shader, array_length, glsl_bool_type(), VAR_NAME("_frontface")); + result.sbt_offset_and_flags = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_sbt_offset_and_flags")); result.barycentrics = - rq_variable_create(shader, impl, array_length, vec2_type, VAR_NAME("_barycentrics")); - result.t = rq_variable_create(shader, impl, array_length, glsl_float_type(), VAR_NAME("_t")); + rq_variable_create(ctx, shader, array_length, vec2_type, VAR_NAME("_barycentrics")); + result.t = rq_variable_create(ctx, shader, array_length, glsl_float_type(), VAR_NAME("_t")); return result; } static void -init_ray_query_vars(nir_shader *shader, nir_function_impl *impl, unsigned array_length, - struct ray_query_vars *dst, const char *base_name) +init_ray_query_vars(nir_shader *shader, unsigned array_length, struct ray_query_vars *dst, + const char *base_name) { + void *ctx = dst; const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3); - dst->root_bvh_base = rq_variable_create(shader, impl, array_length, glsl_uint64_t_type(), + dst->root_bvh_base = rq_variable_create(dst, shader, array_length, glsl_uint64_t_type(), VAR_NAME("_root_bvh_base")); - dst->flags = - rq_variable_create(shader, impl, array_length, glsl_uint_type(), VAR_NAME("_flags")); + dst->flags = rq_variable_create(dst, shader, array_length, glsl_uint_type(), VAR_NAME("_flags")); dst->cull_mask = - rq_variable_create(shader, impl, array_length, glsl_uint_type(), VAR_NAME("_cull_mask")); - dst->origin = rq_variable_create(shader, impl, array_length, vec3_type, VAR_NAME("_origin")); - dst->tmin = rq_variable_create(shader, impl, array_length, glsl_float_type(), VAR_NAME("_tmin")); + rq_variable_create(dst, shader, array_length, glsl_uint_type(), VAR_NAME("_cull_mask")); + dst->origin = rq_variable_create(dst, shader, array_length, vec3_type, VAR_NAME("_origin")); + dst->tmin = rq_variable_create(dst, shader, array_length, glsl_float_type(), VAR_NAME("_tmin")); dst->direction = - rq_variable_create(shader, impl, array_length, vec3_type, VAR_NAME("_direction")); + rq_variable_create(dst, shader, array_length, vec3_type, VAR_NAME("_direction")); dst->incomplete = - rq_variable_create(shader, impl, array_length, glsl_bool_type(), VAR_NAME("_incomplete")); + rq_variable_create(dst, shader, array_length, glsl_bool_type(), VAR_NAME("_incomplete")); - dst->closest = - init_ray_query_intersection_vars(shader, impl, array_length, VAR_NAME("_closest")); + dst->closest = init_ray_query_intersection_vars(dst, shader, array_length, VAR_NAME("_closest")); dst->candidate = - init_ray_query_intersection_vars(shader, impl, array_length, VAR_NAME("_candidate")); + init_ray_query_intersection_vars(dst, shader, array_length, VAR_NAME("_candidate")); - dst->trav = init_ray_query_traversal_vars(shader, impl, array_length, VAR_NAME("_top")); + dst->trav = init_ray_query_traversal_vars(dst, shader, array_length, VAR_NAME("_top")); - dst->stack = rq_variable_create(shader, impl, array_length, + dst->stack = rq_variable_create(dst, shader, array_length, glsl_array_type(glsl_uint_type(), MAX_STACK_ENTRY_COUNT, glsl_get_explicit_stride(glsl_uint_type())), VAR_NAME("_stack")); @@ -282,17 +277,15 @@ init_ray_query_vars(nir_shader *shader, nir_function_impl *impl, unsigned array_ #undef VAR_NAME static void -lower_ray_query(nir_shader *shader, nir_function_impl *impl, nir_variable *ray_query, - struct hash_table *ht) +lower_ray_query(nir_shader *shader, nir_variable *ray_query, struct hash_table *ht) { - struct ray_query_vars *vars = ralloc(impl, struct ray_query_vars); + struct ray_query_vars *vars = ralloc(ht, struct ray_query_vars); unsigned array_length = 1; if (glsl_type_is_array(ray_query->type)) array_length = glsl_get_length(ray_query->type); - init_ray_query_vars(shader, impl, array_length, vars, - ray_query->name == NULL ? "" : ray_query->name); + init_ray_query_vars(shader, array_length, vars, ray_query->name == NULL ? "" : ray_query->name); _mesa_hash_table_insert(ht, ray_query, vars); } @@ -702,7 +695,7 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device if (!var->data.ray_query) continue; - lower_ray_query(shader, NULL, var, query_ht); + lower_ray_query(shader, var, query_ht); contains_ray_query = true; } @@ -717,7 +710,7 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device if (!var->data.ray_query) continue; - lower_ray_query(NULL, function->impl, var, query_ht); + lower_ray_query(shader, var, query_ht); contains_ray_query = true; } -- 2.7.4