From: Daniel Schürmann Date: Mon, 6 Mar 2023 19:03:49 +0000 (+0100) Subject: radv/rt: implement radv_nir_lower_rt_abi to lower RT shaders for separate compilation X-Git-Tag: upstream/23.3.3~7383 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=99466ca18515e7f7c4576f16578e66d3dd6d6df8;p=platform%2Fupstream%2Fmesa.git radv/rt: implement radv_nir_lower_rt_abi to lower RT shaders for separate compilation Part-of: --- diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c index e97ba2e..2fd4e02 100644 --- a/src/amd/vulkan/radv_rt_shader.c +++ b/src/amd/vulkan/radv_rt_shader.c @@ -26,6 +26,7 @@ #include "bvh/bvh.h" #include "meta/radv_meta.h" +#include "ac_nir.h" #include "radv_private.h" #include "radv_rt_common.h" #include "radv_shader.h" @@ -88,6 +89,8 @@ struct rt_variables { * the correct resume index upon returning. */ nir_variable *idx; + nir_variable *shader_va; + nir_variable *traversal_addr; /* scratch offset of the argument area relative to stack_ptr */ nir_variable *arg; @@ -129,6 +132,10 @@ create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR .create_info = create_info, }; vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx"); + vars.shader_va = + nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_va"); + vars.traversal_addr = + nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "traversal_addr"); vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg"); vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr"); vars.shader_record_ptr = @@ -177,6 +184,8 @@ map_rt_variables(struct hash_table *var_remap, struct rt_variables *src, src->create_info = dst->create_info; _mesa_hash_table_insert(var_remap, src->idx, dst->idx); + _mesa_hash_table_insert(var_remap, src->shader_va, dst->shader_va); + _mesa_hash_table_insert(var_remap, src->traversal_addr, dst->traversal_addr); _mesa_hash_table_insert(var_remap, src->arg, dst->arg); _mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr); _mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr); @@ -1702,3 +1711,109 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf return b.shader; } + +void +radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, + const struct radv_shader_args *args, const struct radv_pipeline_key *key, + uint32_t *stack_size) +{ + nir_builder b; + nir_function_impl *impl = nir_shader_get_entrypoint(shader); + nir_builder_init(&b, impl); + + struct rt_variables vars = create_rt_variables(shader, pCreateInfo); + lower_rt_instructions(shader, &vars, 0); + + if (stack_size) { + vars.stack_size = MAX2(vars.stack_size, shader->scratch_size); + *stack_size = MAX2(*stack_size, vars.stack_size); + } + shader->scratch_size = 0; + + NIR_PASS(_, shader, nir_lower_returns); + + nir_cf_list list; + nir_cf_extract(&list, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body)); + + /* initialize variables */ + b.cursor = nir_before_cf_list(&impl->body); + + nir_ssa_def *traversal_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.traversal_shader); + nir_store_var(&b, vars.traversal_addr, nir_pack_64_2x32(&b, traversal_addr), 1); + nir_ssa_def *shader_va = ac_nir_load_arg(&b, &args->ac, args->ac.rt.next_shader); + shader_va = nir_pack_64_2x32(&b, shader_va); + nir_store_var(&b, vars.shader_va, shader_va, 1); + nir_store_var(&b, vars.stack_ptr, + ac_nir_load_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base), 1); + nir_ssa_def *record_ptr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_record); + nir_store_var(&b, vars.shader_record_ptr, nir_pack_64_2x32(&b, record_ptr), 1); + nir_store_var(&b, vars.arg, ac_nir_load_arg(&b, &args->ac, args->ac.rt.payload_offset), 1); + + nir_ssa_def *accel_struct = ac_nir_load_arg(&b, &args->ac, args->ac.rt.accel_struct); + nir_store_var(&b, vars.accel_struct, nir_pack_64_2x32(&b, accel_struct), 1); + nir_store_var(&b, vars.cull_mask_and_flags, + ac_nir_load_arg(&b, &args->ac, args->ac.rt.cull_mask_and_flags), 1); + nir_store_var(&b, vars.sbt_offset, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_offset), 1); + nir_store_var(&b, vars.sbt_stride, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_stride), 1); + nir_store_var(&b, vars.miss_index, ac_nir_load_arg(&b, &args->ac, args->ac.rt.miss_index), 1); + nir_store_var(&b, vars.origin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_origin), 0x7); + nir_store_var(&b, vars.tmin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmin), 1); + nir_store_var(&b, vars.direction, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_direction), + 0x7); + nir_store_var(&b, vars.tmax, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmax), 1); + + nir_store_var(&b, vars.primitive_id, ac_nir_load_arg(&b, &args->ac, args->ac.rt.primitive_id), + 1); + nir_ssa_def *instance_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.instance_addr); + nir_store_var(&b, vars.instance_addr, nir_pack_64_2x32(&b, instance_addr), 1); + nir_store_var(&b, vars.geometry_id_and_flags, + ac_nir_load_arg(&b, &args->ac, args->ac.rt.geometry_id_and_flags), 1); + nir_store_var(&b, vars.hit_kind, ac_nir_load_arg(&b, &args->ac, args->ac.rt.hit_kind), 1); + + /* guard the shader, so that only the correct invocations execute it */ + nir_ssa_def *shader_pc = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_pc); + shader_pc = nir_pack_64_2x32(&b, shader_pc); + nir_ssa_def *cond = nir_ieq(&b, shader_pc, shader_va); + nir_if *shader_guard = nir_push_if(&b, cond); + shader_guard->control = nir_selection_control_divergent_always_taken; + nir_cf_reinsert(&list, b.cursor); + nir_pop_if(&b, shader_guard); + + /* select next shader */ + // TODO: use a priority-based selection + b.cursor = nir_after_cf_list(&impl->body); + shader_va = nir_load_var(&b, vars.shader_va); + nir_ssa_def *next = nir_read_first_invocation(&b, shader_va); + ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_pc, next); + + /* store back all variables to registers */ + ac_nir_store_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base, + nir_load_var(&b, vars.stack_ptr)); + ac_nir_store_arg(&b, &args->ac, args->ac.rt.next_shader, nir_load_var(&b, vars.shader_va)); + ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_record, + nir_load_var(&b, vars.shader_record_ptr)); + ac_nir_store_arg(&b, &args->ac, args->ac.rt.payload_offset, nir_load_var(&b, vars.arg)); + ac_nir_store_arg(&b, &args->ac, args->ac.rt.accel_struct, nir_load_var(&b, vars.accel_struct)); + ac_nir_store_arg(&b, &args->ac, args->ac.rt.cull_mask_and_flags, + nir_load_var(&b, vars.cull_mask_and_flags)); + ac_nir_store_arg(&b, &args->ac, args->ac.rt.sbt_offset, nir_load_var(&b, vars.sbt_offset)); + ac_nir_store_arg(&b, &args->ac, args->ac.rt.sbt_stride, nir_load_var(&b, vars.sbt_stride)); + ac_nir_store_arg(&b, &args->ac, args->ac.rt.miss_index, nir_load_var(&b, vars.miss_index)); + ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_origin, nir_load_var(&b, vars.origin)); + ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_tmin, nir_load_var(&b, vars.tmin)); + ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_direction, nir_load_var(&b, vars.direction)); + ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_tmax, nir_load_var(&b, vars.tmax)); + + ac_nir_store_arg(&b, &args->ac, args->ac.rt.primitive_id, nir_load_var(&b, vars.primitive_id)); + ac_nir_store_arg(&b, &args->ac, args->ac.rt.instance_addr, nir_load_var(&b, vars.instance_addr)); + ac_nir_store_arg(&b, &args->ac, args->ac.rt.geometry_id_and_flags, + nir_load_var(&b, vars.geometry_id_and_flags)); + ac_nir_store_arg(&b, &args->ac, args->ac.rt.hit_kind, nir_load_var(&b, vars.hit_kind)); + + /* cleanup passes */ + NIR_PASS_V(shader, nir_lower_global_vars_to_local); + NIR_PASS_V(shader, nir_lower_vars_to_ssa); + if (shader->info.stage == MESA_SHADER_CLOSEST_HIT || + shader->info.stage == MESA_SHADER_INTERSECTION) + NIR_PASS_V(shader, lower_hit_attribs, NULL, key->cs.compute_subgroup_size); +} diff --git a/src/amd/vulkan/radv_shader.h b/src/amd/vulkan/radv_shader.h index 956c7f3..2381036 100644 --- a/src/amd/vulkan/radv_shader.h +++ b/src/amd/vulkan/radv_shader.h @@ -583,6 +583,10 @@ nir_shader *radv_parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreateInfo *sinfo, const struct radv_pipeline_key *key); +void radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, + const struct radv_shader_args *args, const struct radv_pipeline_key *key, + uint32_t *stack_size); + struct radv_pipeline_stage; nir_shader *radv_shader_spirv_to_nir(struct radv_device *device,