radv/rt: implement radv_nir_lower_rt_abi to lower RT shaders for separate compilation
authorDaniel Schürmann <daniel@schuermann.dev>
Mon, 6 Mar 2023 19:03:49 +0000 (20:03 +0100)
committerMarge Bot <emma+marge@anholt.net>
Thu, 8 Jun 2023 00:37:03 +0000 (00:37 +0000)
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22096>

src/amd/vulkan/radv_rt_shader.c
src/amd/vulkan/radv_shader.h

index e97ba2e..2fd4e02 100644 (file)
@@ -26,6 +26,7 @@
 
 #include "bvh/bvh.h"
 #include "meta/radv_meta.h"
+#include "ac_nir.h"
 #include "radv_private.h"
 #include "radv_rt_common.h"
 #include "radv_shader.h"
@@ -88,6 +89,8 @@ struct rt_variables {
     * the correct resume index upon returning.
     */
    nir_variable *idx;
+   nir_variable *shader_va;
+   nir_variable *traversal_addr;
 
    /* scratch offset of the argument area relative to stack_ptr */
    nir_variable *arg;
@@ -129,6 +132,10 @@ create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR
       .create_info = create_info,
    };
    vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx");
+   vars.shader_va =
+      nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_va");
+   vars.traversal_addr =
+      nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "traversal_addr");
    vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg");
    vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr");
    vars.shader_record_ptr =
@@ -177,6 +184,8 @@ map_rt_variables(struct hash_table *var_remap, struct rt_variables *src,
    src->create_info = dst->create_info;
 
    _mesa_hash_table_insert(var_remap, src->idx, dst->idx);
+   _mesa_hash_table_insert(var_remap, src->shader_va, dst->shader_va);
+   _mesa_hash_table_insert(var_remap, src->traversal_addr, dst->traversal_addr);
    _mesa_hash_table_insert(var_remap, src->arg, dst->arg);
    _mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr);
    _mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr);
@@ -1702,3 +1711,109 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
 
    return b.shader;
 }
+
+void
+radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
+                      const struct radv_shader_args *args, const struct radv_pipeline_key *key,
+                      uint32_t *stack_size)
+{
+   nir_builder b;
+   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
+   nir_builder_init(&b, impl);
+
+   struct rt_variables vars = create_rt_variables(shader, pCreateInfo);
+   lower_rt_instructions(shader, &vars, 0);
+
+   if (stack_size) {
+      vars.stack_size = MAX2(vars.stack_size, shader->scratch_size);
+      *stack_size = MAX2(*stack_size, vars.stack_size);
+   }
+   shader->scratch_size = 0;
+
+   NIR_PASS(_, shader, nir_lower_returns);
+
+   nir_cf_list list;
+   nir_cf_extract(&list, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
+
+   /* initialize variables */
+   b.cursor = nir_before_cf_list(&impl->body);
+
+   nir_ssa_def *traversal_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.traversal_shader);
+   nir_store_var(&b, vars.traversal_addr, nir_pack_64_2x32(&b, traversal_addr), 1);
+   nir_ssa_def *shader_va = ac_nir_load_arg(&b, &args->ac, args->ac.rt.next_shader);
+   shader_va = nir_pack_64_2x32(&b, shader_va);
+   nir_store_var(&b, vars.shader_va, shader_va, 1);
+   nir_store_var(&b, vars.stack_ptr,
+                 ac_nir_load_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base), 1);
+   nir_ssa_def *record_ptr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_record);
+   nir_store_var(&b, vars.shader_record_ptr, nir_pack_64_2x32(&b, record_ptr), 1);
+   nir_store_var(&b, vars.arg, ac_nir_load_arg(&b, &args->ac, args->ac.rt.payload_offset), 1);
+
+   nir_ssa_def *accel_struct = ac_nir_load_arg(&b, &args->ac, args->ac.rt.accel_struct);
+   nir_store_var(&b, vars.accel_struct, nir_pack_64_2x32(&b, accel_struct), 1);
+   nir_store_var(&b, vars.cull_mask_and_flags,
+                 ac_nir_load_arg(&b, &args->ac, args->ac.rt.cull_mask_and_flags), 1);
+   nir_store_var(&b, vars.sbt_offset, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_offset), 1);
+   nir_store_var(&b, vars.sbt_stride, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_stride), 1);
+   nir_store_var(&b, vars.miss_index, ac_nir_load_arg(&b, &args->ac, args->ac.rt.miss_index), 1);
+   nir_store_var(&b, vars.origin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_origin), 0x7);
+   nir_store_var(&b, vars.tmin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmin), 1);
+   nir_store_var(&b, vars.direction, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_direction),
+                 0x7);
+   nir_store_var(&b, vars.tmax, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmax), 1);
+
+   nir_store_var(&b, vars.primitive_id, ac_nir_load_arg(&b, &args->ac, args->ac.rt.primitive_id),
+                 1);
+   nir_ssa_def *instance_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.instance_addr);
+   nir_store_var(&b, vars.instance_addr, nir_pack_64_2x32(&b, instance_addr), 1);
+   nir_store_var(&b, vars.geometry_id_and_flags,
+                 ac_nir_load_arg(&b, &args->ac, args->ac.rt.geometry_id_and_flags), 1);
+   nir_store_var(&b, vars.hit_kind, ac_nir_load_arg(&b, &args->ac, args->ac.rt.hit_kind), 1);
+
+   /* guard the shader, so that only the correct invocations execute it */
+   nir_ssa_def *shader_pc = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_pc);
+   shader_pc = nir_pack_64_2x32(&b, shader_pc);
+   nir_ssa_def *cond = nir_ieq(&b, shader_pc, shader_va);
+   nir_if *shader_guard = nir_push_if(&b, cond);
+   shader_guard->control = nir_selection_control_divergent_always_taken;
+   nir_cf_reinsert(&list, b.cursor);
+   nir_pop_if(&b, shader_guard);
+
+   /* select next shader */
+   // TODO: use a priority-based selection
+   b.cursor = nir_after_cf_list(&impl->body);
+   shader_va = nir_load_var(&b, vars.shader_va);
+   nir_ssa_def *next = nir_read_first_invocation(&b, shader_va);
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_pc, next);
+
+   /* store back all variables to registers */
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base,
+                    nir_load_var(&b, vars.stack_ptr));
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.next_shader, nir_load_var(&b, vars.shader_va));
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_record,
+                    nir_load_var(&b, vars.shader_record_ptr));
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.payload_offset, nir_load_var(&b, vars.arg));
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.accel_struct, nir_load_var(&b, vars.accel_struct));
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.cull_mask_and_flags,
+                    nir_load_var(&b, vars.cull_mask_and_flags));
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.sbt_offset, nir_load_var(&b, vars.sbt_offset));
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.sbt_stride, nir_load_var(&b, vars.sbt_stride));
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.miss_index, nir_load_var(&b, vars.miss_index));
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_origin, nir_load_var(&b, vars.origin));
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_tmin, nir_load_var(&b, vars.tmin));
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_direction, nir_load_var(&b, vars.direction));
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_tmax, nir_load_var(&b, vars.tmax));
+
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.primitive_id, nir_load_var(&b, vars.primitive_id));
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.instance_addr, nir_load_var(&b, vars.instance_addr));
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.geometry_id_and_flags,
+                    nir_load_var(&b, vars.geometry_id_and_flags));
+   ac_nir_store_arg(&b, &args->ac, args->ac.rt.hit_kind, nir_load_var(&b, vars.hit_kind));
+
+   /* cleanup passes */
+   NIR_PASS_V(shader, nir_lower_global_vars_to_local);
+   NIR_PASS_V(shader, nir_lower_vars_to_ssa);
+   if (shader->info.stage == MESA_SHADER_CLOSEST_HIT ||
+       shader->info.stage == MESA_SHADER_INTERSECTION)
+      NIR_PASS_V(shader, lower_hit_attribs, NULL, key->cs.compute_subgroup_size);
+}
index 956c7f3..2381036 100644 (file)
@@ -583,6 +583,10 @@ nir_shader *radv_parse_rt_stage(struct radv_device *device,
                                 const VkPipelineShaderStageCreateInfo *sinfo,
                                 const struct radv_pipeline_key *key);
 
+void radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
+                           const struct radv_shader_args *args, const struct radv_pipeline_key *key,
+                           uint32_t *stack_size);
+
 struct radv_pipeline_stage;
 
 nir_shader *radv_shader_spirv_to_nir(struct radv_device *device,