From b69ec8bde3397551d901a76764af71483967480d Mon Sep 17 00:00:00 2001 From: Konstantin Seurer Date: Tue, 21 Feb 2023 11:45:09 +0100 Subject: [PATCH] radv/rt: Refactor rq_load lowering This just gets rid of all the bcsel emissions. Part-of: --- src/amd/vulkan/radv_nir_lower_ray_queries.c | 92 +++++++++++++---------------- 1 file changed, 40 insertions(+), 52 deletions(-) diff --git a/src/amd/vulkan/radv_nir_lower_ray_queries.c b/src/amd/vulkan/radv_nir_lower_ray_queries.c index 831ab4a..1bf4fde 100644 --- a/src/amd/vulkan/radv_nir_lower_ray_queries.c +++ b/src/amd/vulkan/radv_nir_lower_ray_queries.c @@ -428,32 +428,31 @@ lower_rq_initialize(nir_builder *b, nir_ssa_def *index, nir_intrinsic_instr *ins } static nir_ssa_def * -lower_rq_load(nir_builder *b, nir_ssa_def *index, struct ray_query_vars *vars, - nir_ssa_def *committed, nir_ray_query_value value, unsigned column) +lower_rq_load(nir_builder *b, nir_ssa_def *index, nir_intrinsic_instr *instr, + struct ray_query_vars *vars) { + assert(nir_src_is_const(instr->src[1])); + bool closest = nir_src_as_bool(instr->src[1]); + struct ray_query_intersection_vars *intersection = closest ? &vars->closest : &vars->candidate; + + uint32_t column = nir_intrinsic_column(instr); + + nir_ray_query_value value = nir_intrinsic_ray_query_value(instr); switch (value) { case nir_ray_query_value_flags: return rq_load_var(b, index, vars->flags); case nir_ray_query_value_intersection_barycentrics: - return nir_bcsel(b, committed, rq_load_var(b, index, vars->closest.barycentrics), - rq_load_var(b, index, vars->candidate.barycentrics)); + return rq_load_var(b, index, intersection->barycentrics); case nir_ray_query_value_intersection_candidate_aabb_opaque: return nir_iand(b, rq_load_var(b, index, vars->candidate.opaque), nir_ieq_imm(b, rq_load_var(b, index, vars->candidate.intersection_type), intersection_type_aabb)); case nir_ray_query_value_intersection_front_face: - return nir_bcsel(b, committed, rq_load_var(b, index, vars->closest.frontface), - rq_load_var(b, index, vars->candidate.frontface)); + return rq_load_var(b, index, intersection->frontface); case nir_ray_query_value_intersection_geometry_index: - return nir_iand_imm( - b, - nir_bcsel(b, committed, rq_load_var(b, index, vars->closest.geometry_id_and_flags), - rq_load_var(b, index, vars->candidate.geometry_id_and_flags)), - 0xFFFFFF); + return nir_iand_imm(b, rq_load_var(b, index, intersection->geometry_id_and_flags), 0xFFFFFF); case nir_ray_query_value_intersection_instance_custom_index: { - nir_ssa_def *instance_node_addr = - nir_bcsel(b, committed, rq_load_var(b, index, vars->closest.instance_addr), - rq_load_var(b, index, vars->candidate.instance_addr)); + nir_ssa_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr); return nir_iand_imm(b, nir_build_load_global(b, 1, 32, nir_iadd_imm(b, instance_node_addr, @@ -462,39 +461,27 @@ lower_rq_load(nir_builder *b, nir_ssa_def *index, struct ray_query_vars *vars, 0xFFFFFF); } case nir_ray_query_value_intersection_instance_id: { - nir_ssa_def *instance_node_addr = - nir_bcsel(b, committed, rq_load_var(b, index, vars->closest.instance_addr), - rq_load_var(b, index, vars->candidate.instance_addr)); + nir_ssa_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr); return nir_build_load_global( b, 1, 32, nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, instance_id))); } case nir_ray_query_value_intersection_instance_sbt_index: - return nir_iand_imm( - b, - nir_bcsel(b, committed, rq_load_var(b, index, vars->closest.sbt_offset_and_flags), - rq_load_var(b, index, vars->candidate.sbt_offset_and_flags)), - 0xFFFFFF); + return nir_iand_imm(b, rq_load_var(b, index, intersection->sbt_offset_and_flags), 0xFFFFFF); case nir_ray_query_value_intersection_object_ray_direction: { - nir_ssa_def *instance_node_addr = - nir_bcsel(b, committed, rq_load_var(b, index, vars->closest.instance_addr), - rq_load_var(b, index, vars->candidate.instance_addr)); + nir_ssa_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr); nir_ssa_def *wto_matrix[3]; nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix); return nir_build_vec3_mat_mult(b, rq_load_var(b, index, vars->direction), wto_matrix, false); } case nir_ray_query_value_intersection_object_ray_origin: { - nir_ssa_def *instance_node_addr = - nir_bcsel(b, committed, rq_load_var(b, index, vars->closest.instance_addr), - rq_load_var(b, index, vars->candidate.instance_addr)); + nir_ssa_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr); nir_ssa_def *wto_matrix[3]; nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix); return nir_build_vec3_mat_mult(b, rq_load_var(b, index, vars->origin), wto_matrix, true); } case nir_ray_query_value_intersection_object_to_world: { - nir_ssa_def *instance_node_addr = - nir_bcsel(b, committed, rq_load_var(b, index, vars->closest.instance_addr), - rq_load_var(b, index, vars->candidate.instance_addr)); + nir_ssa_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr); nir_ssa_def *rows[3]; for (unsigned r = 0; r < 3; ++r) rows[r] = nir_build_load_global( @@ -506,19 +493,18 @@ lower_rq_load(nir_builder *b, nir_ssa_def *index, struct ray_query_vars *vars, nir_channel(b, rows[2], column)); } case nir_ray_query_value_intersection_primitive_index: - return nir_bcsel(b, committed, rq_load_var(b, index, vars->closest.primitive_id), - rq_load_var(b, index, vars->candidate.primitive_id)); + return rq_load_var(b, index, intersection->primitive_id); case nir_ray_query_value_intersection_t: - return nir_bcsel(b, committed, rq_load_var(b, index, vars->closest.t), - rq_load_var(b, index, vars->candidate.t)); - case nir_ray_query_value_intersection_type: - return nir_bcsel( - b, committed, rq_load_var(b, index, vars->closest.intersection_type), - nir_iadd_imm(b, rq_load_var(b, index, vars->candidate.intersection_type), -1)); + return rq_load_var(b, index, intersection->t); + case nir_ray_query_value_intersection_type: { + nir_ssa_def *intersection_type = rq_load_var(b, index, intersection->intersection_type); + if (!closest) + intersection_type = nir_iadd_imm(b, intersection_type, -1); + + return intersection_type; + } case nir_ray_query_value_intersection_world_to_object: { - nir_ssa_def *instance_node_addr = - nir_bcsel(b, committed, rq_load_var(b, index, vars->closest.instance_addr), - rq_load_var(b, index, vars->candidate.instance_addr)); + nir_ssa_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr); nir_ssa_def *wto_matrix[3]; nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix); @@ -691,15 +677,19 @@ lower_rq_terminate(nir_builder *b, nir_ssa_def *index, nir_intrinsic_instr *inst bool radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device) { - bool contains_ray_query = false; + bool progress = false; struct hash_table *query_ht = _mesa_pointer_hash_table_create(NULL); + /* Run constant folding to collapse expressions that are required to be constant by the spec. */ + NIR_PASS(progress, shader, nir_opt_constant_folding); + nir_foreach_variable_in_list (var, &shader->variables) { if (!var->data.ray_query) continue; lower_ray_query(shader, var, query_ht, device->physical_device->max_shared_size); - contains_ray_query = true; + + progress = true; } nir_foreach_function (function, shader) { @@ -714,11 +704,9 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device continue; lower_ray_query(shader, var, query_ht, device->physical_device->max_shared_size); - contains_ray_query = true; - } - if (!contains_ray_query) - continue; + progress = true; + } nir_foreach_block (block, function->impl) { nir_foreach_instr_safe (instr, block) { @@ -760,9 +748,7 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device lower_rq_initialize(&builder, index, intrinsic, vars); break; case nir_intrinsic_rq_load: - new_dest = lower_rq_load(&builder, index, vars, intrinsic->src[1].ssa, - nir_intrinsic_ray_query_value(intrinsic), - nir_intrinsic_column(intrinsic)); + new_dest = lower_rq_load(&builder, index, intrinsic, vars); break; case nir_intrinsic_rq_proceed: new_dest = lower_rq_proceed(&builder, index, vars, device); @@ -779,6 +765,8 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device nir_instr_remove(instr); nir_instr_free(instr); + + progress = true; } } @@ -787,5 +775,5 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device ralloc_free(query_ht); - return contains_ray_query; + return progress; } -- 2.7.4