radv/rt: move Ray Tracing shader creation into separate file
author Daniel Schürmann <daniel@schuermann.dev>
Fri, 4 Nov 2022 12:23:45 +0000 (13:23 +0100)
committer Daniel Schürmann <daniel@schuermann.dev>
Wed, 16 Nov 2022 21:43:49 +0000 (22:43 +0100)
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19525>

src/amd/vulkan/meson.build
src/amd/vulkan/radv_pipeline_rt.c
src/amd/vulkan/radv_rt_shader.c [new file with mode: 0644]
src/amd/vulkan/radv_shader.h

diff --git a/src/amd/vulkan/meson.build b/src/amd/vulkan/meson.build
index ad75d68..e629aad 100644
@@ -85,6 +85,7 @@ libradv_files = files(
   'radv_radeon_winsys.h',
   'radv_rra.c',
   'radv_rt_common.c',
+  'radv_rt_shader.c',
   'radv_sdma_copy_image.c',
   'radv_shader.c',
   'radv_shader.h',
diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c
index 49ef06f..3db0f90 100644
  * IN THE SOFTWARE.
  */
 
-#include "radv_acceleration_structure.h"
+#include "nir/nir.h"
+
 #include "radv_debug.h"
-#include "radv_meta.h"
 #include "radv_private.h"
-#include "radv_rt_common.h"
 #include "radv_shader.h"
 
-#include "nir/nir.h"
-#include "nir/nir_builder.h"
-#include "nir/nir_builtin_builder.h"
-
-/* Traversal stack size. This stack is put in LDS and, experimentally, 16 entries give the best
- * performance. */
-#define MAX_STACK_ENTRY_COUNT 16
-
 static VkRayTracingPipelineCreateInfoKHR
 radv_create_merged_rt_create_info(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo)
 {
@@ -230,1331 +221,6 @@ fail:
    return VK_ERROR_OUT_OF_HOST_MEMORY;
 }
 
-/*
- * Global variables for an RT pipeline
- */
-struct rt_variables {
-   const VkRayTracingPipelineCreateInfoKHR *create_info;
-
-   /* idx of the next shader to run in the next iteration of the main loop.
-    * During traversal, idx is used to store the SBT index and will contain
-    * the correct resume index upon returning.
-    */
-   nir_variable *idx;
-
-   /* scratch offset of the argument area relative to stack_ptr */
-   nir_variable *arg;
-
-   nir_variable *stack_ptr;
-
-   /* global address of the SBT entry used for the shader */
-   nir_variable *shader_record_ptr;
-
-   /* trace_ray arguments */
-   nir_variable *accel_struct;
-   nir_variable *flags;
-   nir_variable *cull_mask;
-   nir_variable *sbt_offset;
-   nir_variable *sbt_stride;
-   nir_variable *miss_index;
-   nir_variable *origin;
-   nir_variable *tmin;
-   nir_variable *direction;
-   nir_variable *tmax;
-
-   /* Properties of the primitive currently being visited. */
-   nir_variable *primitive_id;
-   nir_variable *geometry_id_and_flags;
-   nir_variable *instance_addr;
-   nir_variable *hit_kind;
-   nir_variable *opaque;
-
-   /* Output variables for intersection & anyhit shaders. */
-   nir_variable *ahit_accept;
-   nir_variable *ahit_terminate;
-
-   /* Array of stack size struct for recording the max stack size for each group. */
-   struct radv_pipeline_shader_stack_size *stack_sizes;
-   unsigned stage_idx;
-};
-
-static void
-reserve_stack_size(struct rt_variables *vars, uint32_t size)
-{
-   for (uint32_t group_idx = 0; group_idx < vars->create_info->groupCount; group_idx++) {
-      const VkRayTracingShaderGroupCreateInfoKHR *group = vars->create_info->pGroups + group_idx;
-
-      if (vars->stage_idx == group->generalShader || vars->stage_idx == group->closestHitShader)
-         vars->stack_sizes[group_idx].recursive_size =
-            MAX2(vars->stack_sizes[group_idx].recursive_size, size);
-
-      if (vars->stage_idx == group->anyHitShader || vars->stage_idx == group->intersectionShader)
-         vars->stack_sizes[group_idx].non_recursive_size =
-            MAX2(vars->stack_sizes[group_idx].non_recursive_size, size);
-   }
-}
-
-static struct rt_variables
-create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *create_info,
-                    struct radv_pipeline_shader_stack_size *stack_sizes)
-{
-   struct rt_variables vars = {
-      .create_info = create_info,
-   };
-   vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx");
-   vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg");
-   vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr");
-   vars.shader_record_ptr =
-      nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_record_ptr");
-
-   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
-   vars.accel_struct =
-      nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "accel_struct");
-   vars.flags = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "ray_flags");
-   vars.cull_mask = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "cull_mask");
-   vars.sbt_offset =
-      nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_offset");
-   vars.sbt_stride =
-      nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_stride");
-   vars.miss_index =
-      nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "miss_index");
-   vars.origin = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_origin");
-   vars.tmin = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmin");
-   vars.direction = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_direction");
-   vars.tmax = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmax");
-
-   vars.primitive_id =
-      nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "primitive_id");
-   vars.geometry_id_and_flags =
-      nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "geometry_id_and_flags");
-   vars.instance_addr =
-      nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
-   vars.hit_kind = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "hit_kind");
-   vars.opaque = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "opaque");
-
-   vars.ahit_accept =
-      nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "ahit_accept");
-   vars.ahit_terminate =
-      nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "ahit_terminate");
-
-   vars.stack_sizes = stack_sizes;
-   return vars;
-}
-
-/*
- * Remap all the variables between the two rt_variables structs for inlining.
- */
-static void
-map_rt_variables(struct hash_table *var_remap, struct rt_variables *src,
-                 const struct rt_variables *dst)
-{
-   src->create_info = dst->create_info;
-
-   _mesa_hash_table_insert(var_remap, src->idx, dst->idx);
-   _mesa_hash_table_insert(var_remap, src->arg, dst->arg);
-   _mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr);
-   _mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr);
-
-   _mesa_hash_table_insert(var_remap, src->accel_struct, dst->accel_struct);
-   _mesa_hash_table_insert(var_remap, src->flags, dst->flags);
-   _mesa_hash_table_insert(var_remap, src->cull_mask, dst->cull_mask);
-   _mesa_hash_table_insert(var_remap, src->sbt_offset, dst->sbt_offset);
-   _mesa_hash_table_insert(var_remap, src->sbt_stride, dst->sbt_stride);
-   _mesa_hash_table_insert(var_remap, src->miss_index, dst->miss_index);
-   _mesa_hash_table_insert(var_remap, src->origin, dst->origin);
-   _mesa_hash_table_insert(var_remap, src->tmin, dst->tmin);
-   _mesa_hash_table_insert(var_remap, src->direction, dst->direction);
-   _mesa_hash_table_insert(var_remap, src->tmax, dst->tmax);
-
-   _mesa_hash_table_insert(var_remap, src->primitive_id, dst->primitive_id);
-   _mesa_hash_table_insert(var_remap, src->geometry_id_and_flags, dst->geometry_id_and_flags);
-   _mesa_hash_table_insert(var_remap, src->instance_addr, dst->instance_addr);
-   _mesa_hash_table_insert(var_remap, src->hit_kind, dst->hit_kind);
-   _mesa_hash_table_insert(var_remap, src->opaque, dst->opaque);
-   _mesa_hash_table_insert(var_remap, src->ahit_accept, dst->ahit_accept);
-   _mesa_hash_table_insert(var_remap, src->ahit_terminate, dst->ahit_terminate);
-
-   src->stack_sizes = dst->stack_sizes;
-   src->stage_idx = dst->stage_idx;
-}
-
-/*
- * Create a copy of the global rt variables where the primitive/instance related variables are
- * independent. This is needed because we need to keep the old values of the global variables
- * around in case e.g. an anyhit shader rejects the collision. So there are inner variables that
- * get copied to the outer variables once we commit to a better hit.
- */
-static struct rt_variables
-create_inner_vars(nir_builder *b, const struct rt_variables *vars)
-{
-   struct rt_variables inner_vars = *vars;
-   inner_vars.idx =
-      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_idx");
-   inner_vars.shader_record_ptr = nir_variable_create(
-      b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "inner_shader_record_ptr");
-   inner_vars.primitive_id =
-      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_primitive_id");
-   inner_vars.geometry_id_and_flags = nir_variable_create(
-      b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_geometry_id_and_flags");
-   inner_vars.tmax =
-      nir_variable_create(b->shader, nir_var_shader_temp, glsl_float_type(), "inner_tmax");
-   inner_vars.instance_addr = nir_variable_create(b->shader, nir_var_shader_temp,
-                                                  glsl_uint64_t_type(), "inner_instance_addr");
-   inner_vars.hit_kind =
-      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_hit_kind");
-
-   return inner_vars;
-}
-
-/* The hit attributes are stored on the stack. This is the offset compared to the current stack
- * pointer of where the hit attrib is stored. */
-const uint32_t RADV_HIT_ATTRIB_OFFSET = -(16 + RADV_MAX_HIT_ATTRIB_SIZE);
-
-static void
-insert_rt_return(nir_builder *b, const struct rt_variables *vars)
-{
-   nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, nir_load_var(b, vars->stack_ptr), -16), 1);
-   nir_store_var(b, vars->idx,
-                 nir_load_scratch(b, 1, 32, nir_load_var(b, vars->stack_ptr), .align_mul = 16), 1);
-}
-
-enum sbt_type {
-   SBT_RAYGEN = offsetof(VkTraceRaysIndirectCommand2KHR, raygenShaderRecordAddress),
-   SBT_MISS = offsetof(VkTraceRaysIndirectCommand2KHR, missShaderBindingTableAddress),
-   SBT_HIT = offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
-   SBT_CALLABLE = offsetof(VkTraceRaysIndirectCommand2KHR, callableShaderBindingTableAddress),
-};
-
-static nir_ssa_def *
-get_sbt_ptr(nir_builder *b, nir_ssa_def *idx, enum sbt_type binding)
-{
-   nir_ssa_def *desc_base_addr = nir_load_sbt_base_amd(b);
-
-   nir_ssa_def *desc =
-      nir_pack_64_2x32(b, nir_build_load_smem_amd(b, 2, desc_base_addr, nir_imm_int(b, binding)));
-
-   nir_ssa_def *stride_offset = nir_imm_int(b, binding + (binding == SBT_RAYGEN ? 8 : 16));
-   nir_ssa_def *stride =
-      nir_pack_64_2x32(b, nir_build_load_smem_amd(b, 2, desc_base_addr, stride_offset));
-
-   return nir_iadd(b, desc, nir_imul(b, nir_u2u64(b, idx), stride));
-}
-
-static void
-load_sbt_entry(nir_builder *b, const struct rt_variables *vars, nir_ssa_def *idx,
-               enum sbt_type binding, unsigned offset)
-{
-   nir_ssa_def *addr = get_sbt_ptr(b, idx, binding);
-
-   nir_ssa_def *load_addr = nir_iadd_imm(b, addr, offset);
-   nir_ssa_def *v_idx = nir_build_load_global(b, 1, 32, load_addr);
-
-   nir_store_var(b, vars->idx, v_idx, 1);
-
-   nir_ssa_def *record_addr = nir_iadd_imm(b, addr, RADV_RT_HANDLE_SIZE);
-   nir_store_var(b, vars->shader_record_ptr, record_addr, 1);
-}
-
-/* This lowers all the RT instructions that we do not want to pass on to the combined shader and
- * that we can implement using the variables from the shader we are going to inline into. */
-static void
-lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned call_idx_base)
-{
-   nir_builder b_shader;
-   nir_builder_init(&b_shader, nir_shader_get_entrypoint(shader));
-
-   nir_foreach_block (block, nir_shader_get_entrypoint(shader)) {
-      nir_foreach_instr_safe (instr, block) {
-         switch (instr->type) {
-         case nir_instr_type_intrinsic: {
-            b_shader.cursor = nir_before_instr(instr);
-            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
-            nir_ssa_def *ret = NULL;
-
-            switch (intr->intrinsic) {
-            case nir_intrinsic_rt_execute_callable: {
-               uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
-               uint32_t ret_idx = call_idx_base + nir_intrinsic_call_idx(intr) + 1;
-
-               nir_store_var(
-                  &b_shader, vars->stack_ptr,
-                  nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), size), 1);
-               nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret_idx),
-                                 nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16);
-
-               nir_store_var(&b_shader, vars->stack_ptr,
-                             nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), 16),
-                             1);
-               load_sbt_entry(&b_shader, vars, intr->src[0].ssa, SBT_CALLABLE, 0);
-
-               nir_store_var(&b_shader, vars->arg,
-                             nir_iadd_imm(&b_shader, intr->src[1].ssa, -size - 16), 1);
-
-               reserve_stack_size(vars, size + 16);
-               break;
-            }
-            case nir_intrinsic_rt_trace_ray: {
-               uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
-               uint32_t ret_idx = call_idx_base + nir_intrinsic_call_idx(intr) + 1;
-
-               nir_store_var(
-                  &b_shader, vars->stack_ptr,
-                  nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), size), 1);
-               nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret_idx),
-                                 nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16);
-
-               nir_store_var(&b_shader, vars->stack_ptr,
-                             nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), 16),
-                             1);
-
-               nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 1), 1);
-               nir_store_var(&b_shader, vars->arg,
-                             nir_iadd_imm(&b_shader, intr->src[10].ssa, -size - 16), 1);
-
-               reserve_stack_size(vars, size + 16);
-
-               /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. */
-               nir_store_var(&b_shader, vars->accel_struct, intr->src[0].ssa, 0x1);
-               nir_store_var(&b_shader, vars->flags, intr->src[1].ssa, 0x1);
-               nir_store_var(&b_shader, vars->cull_mask,
-                             nir_iand_imm(&b_shader, intr->src[2].ssa, 0xff), 0x1);
-               nir_store_var(&b_shader, vars->sbt_offset,
-                             nir_iand_imm(&b_shader, intr->src[3].ssa, 0xf), 0x1);
-               nir_store_var(&b_shader, vars->sbt_stride,
-                             nir_iand_imm(&b_shader, intr->src[4].ssa, 0xf), 0x1);
-               nir_store_var(&b_shader, vars->miss_index,
-                             nir_iand_imm(&b_shader, intr->src[5].ssa, 0xffff), 0x1);
-               nir_store_var(&b_shader, vars->origin, intr->src[6].ssa, 0x7);
-               nir_store_var(&b_shader, vars->tmin, intr->src[7].ssa, 0x1);
-               nir_store_var(&b_shader, vars->direction, intr->src[8].ssa, 0x7);
-               nir_store_var(&b_shader, vars->tmax, intr->src[9].ssa, 0x1);
-               break;
-            }
-            case nir_intrinsic_rt_resume: {
-               uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
-
-               nir_store_var(
-                  &b_shader, vars->stack_ptr,
-                  nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), -size), 1);
-               break;
-            }
-            case nir_intrinsic_rt_return_amd: {
-               if (shader->info.stage == MESA_SHADER_RAYGEN) {
-                  nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 0), 1);
-                  break;
-               }
-               insert_rt_return(&b_shader, vars);
-               break;
-            }
-            case nir_intrinsic_load_scratch: {
-               nir_instr_rewrite_src_ssa(
-                  instr, &intr->src[0],
-                  nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[0].ssa));
-               continue;
-            }
-            case nir_intrinsic_store_scratch: {
-               nir_instr_rewrite_src_ssa(
-                  instr, &intr->src[1],
-                  nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[1].ssa));
-               continue;
-            }
-            case nir_intrinsic_load_rt_arg_scratch_offset_amd: {
-               ret = nir_load_var(&b_shader, vars->arg);
-               break;
-            }
-            case nir_intrinsic_load_shader_record_ptr: {
-               ret = nir_load_var(&b_shader, vars->shader_record_ptr);
-               break;
-            }
-            case nir_intrinsic_load_ray_launch_id: {
-               ret = nir_load_global_invocation_id(&b_shader, 32);
-               break;
-            }
-            case nir_intrinsic_load_ray_launch_size: {
-               nir_ssa_def *launch_size_addr =
-                  nir_load_ray_launch_size_addr_amd(&b_shader);
-
-               nir_ssa_def * xy = nir_build_load_smem_amd(
-                  &b_shader, 2, launch_size_addr, nir_imm_int(&b_shader, 0));
-               nir_ssa_def * z = nir_build_load_smem_amd(
-                  &b_shader, 1, launch_size_addr, nir_imm_int(&b_shader, 8));
-
-               nir_ssa_def *xyz[3] = {
-                  nir_channel(&b_shader, xy, 0),
-                  nir_channel(&b_shader, xy, 1),
-                  z,
-               };
-               ret = nir_vec(&b_shader, xyz, 3);
-               break;
-            }
-            case nir_intrinsic_load_ray_t_min: {
-               ret = nir_load_var(&b_shader, vars->tmin);
-               break;
-            }
-            case nir_intrinsic_load_ray_t_max: {
-               ret = nir_load_var(&b_shader, vars->tmax);
-               break;
-            }
-            case nir_intrinsic_load_ray_world_origin: {
-               ret = nir_load_var(&b_shader, vars->origin);
-               break;
-            }
-            case nir_intrinsic_load_ray_world_direction: {
-               ret = nir_load_var(&b_shader, vars->direction);
-               break;
-            }
-            case nir_intrinsic_load_ray_instance_custom_index: {
-               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
-               nir_ssa_def *custom_instance_and_mask = nir_build_load_global(
-                  &b_shader, 1, 32,
-                  nir_iadd_imm(&b_shader, instance_node_addr,
-                               offsetof(struct radv_bvh_instance_node, custom_instance_and_mask)));
-               ret = nir_iand_imm(&b_shader, custom_instance_and_mask, 0xFFFFFF);
-               break;
-            }
-            case nir_intrinsic_load_primitive_id: {
-               ret = nir_load_var(&b_shader, vars->primitive_id);
-               break;
-            }
-            case nir_intrinsic_load_ray_geometry_index: {
-               ret = nir_load_var(&b_shader, vars->geometry_id_and_flags);
-               ret = nir_iand_imm(&b_shader, ret, 0xFFFFFFF);
-               break;
-            }
-            case nir_intrinsic_load_instance_id: {
-               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
-               ret = nir_build_load_global(
-                  &b_shader, 1, 32,
-                  nir_iadd_imm(&b_shader, instance_node_addr,
-                               offsetof(struct radv_bvh_instance_node, instance_id)));
-               break;
-            }
-            case nir_intrinsic_load_ray_flags: {
-               ret = nir_load_var(&b_shader, vars->flags);
-               break;
-            }
-            case nir_intrinsic_load_ray_hit_kind: {
-               ret = nir_load_var(&b_shader, vars->hit_kind);
-               break;
-            }
-            case nir_intrinsic_load_ray_world_to_object: {
-               unsigned c = nir_intrinsic_column(intr);
-               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
-               nir_ssa_def *wto_matrix[3];
-               nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
-
-               nir_ssa_def *vals[3];
-               for (unsigned i = 0; i < 3; ++i)
-                  vals[i] = nir_channel(&b_shader, wto_matrix[i], c);
-
-               ret = nir_vec(&b_shader, vals, 3);
-               break;
-            }
-            case nir_intrinsic_load_ray_object_to_world: {
-               unsigned c = nir_intrinsic_column(intr);
-               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
-               nir_ssa_def *rows[3];
-               for (unsigned r = 0; r < 3; ++r)
-                  rows[r] = nir_build_load_global(
-                     &b_shader, 4, 32,
-                     nir_iadd_imm(&b_shader, instance_node_addr,
-                                  offsetof(struct radv_bvh_instance_node, otw_matrix) + r * 16));
-               ret =
-                  nir_vec3(&b_shader, nir_channel(&b_shader, rows[0], c),
-                           nir_channel(&b_shader, rows[1], c), nir_channel(&b_shader, rows[2], c));
-               break;
-            }
-            case nir_intrinsic_load_ray_object_origin: {
-               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
-               nir_ssa_def *wto_matrix[3];
-               nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
-               ret = nir_build_vec3_mat_mult(&b_shader, nir_load_var(&b_shader, vars->origin),
-                                             wto_matrix, true);
-               break;
-            }
-            case nir_intrinsic_load_ray_object_direction: {
-               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
-               nir_ssa_def *wto_matrix[3];
-               nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
-               ret = nir_build_vec3_mat_mult(
-                  &b_shader, nir_load_var(&b_shader, vars->direction), wto_matrix, false);
-               break;
-            }
-            case nir_intrinsic_load_intersection_opaque_amd: {
-               ret = nir_load_var(&b_shader, vars->opaque);
-               break;
-            }
-            case nir_intrinsic_load_cull_mask: {
-               ret = nir_load_var(&b_shader, vars->cull_mask);
-               break;
-            }
-            case nir_intrinsic_ignore_ray_intersection: {
-               nir_store_var(&b_shader, vars->ahit_accept, nir_imm_false(&b_shader), 0x1);
-
-               /* The if is a workaround to avoid having to fix up control flow manually */
-               nir_push_if(&b_shader, nir_imm_true(&b_shader));
-               nir_jump(&b_shader, nir_jump_return);
-               nir_pop_if(&b_shader, NULL);
-               break;
-            }
-            case nir_intrinsic_terminate_ray: {
-               nir_store_var(&b_shader, vars->ahit_accept, nir_imm_true(&b_shader), 0x1);
-               nir_store_var(&b_shader, vars->ahit_terminate, nir_imm_true(&b_shader), 0x1);
-
-               /* The if is a workaround to avoid having to fix up control flow manually */
-               nir_push_if(&b_shader, nir_imm_true(&b_shader));
-               nir_jump(&b_shader, nir_jump_return);
-               nir_pop_if(&b_shader, NULL);
-               break;
-            }
-            case nir_intrinsic_report_ray_intersection: {
-               nir_push_if(
-                  &b_shader,
-                  nir_iand(
-                     &b_shader,
-                     nir_fge(&b_shader, nir_load_var(&b_shader, vars->tmax), intr->src[0].ssa),
-                     nir_fge(&b_shader, intr->src[0].ssa, nir_load_var(&b_shader, vars->tmin))));
-               {
-                  nir_store_var(&b_shader, vars->ahit_accept, nir_imm_true(&b_shader), 0x1);
-                  nir_store_var(&b_shader, vars->tmax, intr->src[0].ssa, 1);
-                  nir_store_var(&b_shader, vars->hit_kind, intr->src[1].ssa, 1);
-               }
-               nir_pop_if(&b_shader, NULL);
-               break;
-            }
-            case nir_intrinsic_load_sbt_offset_amd: {
-               ret = nir_load_var(&b_shader, vars->sbt_offset);
-               break;
-            }
-            case nir_intrinsic_load_sbt_stride_amd: {
-               ret = nir_load_var(&b_shader, vars->sbt_stride);
-               break;
-            }
-            case nir_intrinsic_load_accel_struct_amd: {
-               ret = nir_load_var(&b_shader, vars->accel_struct);
-               break;
-            }
-            case nir_intrinsic_execute_closest_hit_amd: {
-               nir_store_var(&b_shader, vars->tmax, intr->src[1].ssa, 0x1);
-               nir_store_var(&b_shader, vars->primitive_id, intr->src[2].ssa, 0x1);
-               nir_store_var(&b_shader, vars->instance_addr, intr->src[3].ssa, 0x1);
-               nir_store_var(&b_shader, vars->geometry_id_and_flags, intr->src[4].ssa, 0x1);
-               nir_store_var(&b_shader, vars->hit_kind, intr->src[5].ssa, 0x1);
-               load_sbt_entry(&b_shader, vars, intr->src[0].ssa, SBT_HIT, 0);
-
-               nir_ssa_def *should_return =
-                  nir_ior(&b_shader,
-                          nir_test_mask(&b_shader, nir_load_var(&b_shader, vars->flags),
-                                        SpvRayFlagsSkipClosestHitShaderKHRMask),
-                          nir_ieq_imm(&b_shader, nir_load_var(&b_shader, vars->idx), 0));
-
-               /* should_return is set if we had a hit but we won't be calling the closest hit
-                * shader and hence need to return immediately to the calling shader. */
-               nir_push_if(&b_shader, should_return);
-               insert_rt_return(&b_shader, vars);
-               nir_pop_if(&b_shader, NULL);
-               break;
-            }
-            case nir_intrinsic_execute_miss_amd: {
-               nir_store_var(&b_shader, vars->tmax, intr->src[0].ssa, 0x1);
-               nir_ssa_def *undef = nir_ssa_undef(&b_shader, 1, 32);
-               nir_store_var(&b_shader, vars->primitive_id, undef, 0x1);
-               nir_store_var(&b_shader, vars->instance_addr, nir_ssa_undef(&b_shader, 1, 64), 0x1);
-               nir_store_var(&b_shader, vars->geometry_id_and_flags, undef, 0x1);
-               nir_store_var(&b_shader, vars->hit_kind, undef, 0x1);
-               nir_ssa_def *miss_index = nir_load_var(&b_shader, vars->miss_index);
-               load_sbt_entry(&b_shader, vars, miss_index, SBT_MISS, 0);
-               break;
-            }
-            default:
-               continue;
-            }
-
-            if (ret)
-               nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
-            nir_instr_remove(instr);
-            break;
-         }
-         case nir_instr_type_jump: {
-            nir_jump_instr *jump = nir_instr_as_jump(instr);
-            if (jump->type == nir_jump_halt) {
-               b_shader.cursor = nir_instr_remove(instr);
-               nir_jump(&b_shader, nir_jump_return);
-            }
-            break;
-         }
-         default:
-            break;
-         }
-      }
-   }
-
-   nir_metadata_preserve(nir_shader_get_entrypoint(shader), nir_metadata_none);
-}
-
-static void
-insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, nir_ssa_def *idx,
-               uint32_t call_idx_base, uint32_t call_idx)
-{
-   struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
-
-   nir_opt_dead_cf(shader);
-
-   struct rt_variables src_vars = create_rt_variables(shader, vars->create_info, vars->stack_sizes);
-   map_rt_variables(var_remap, &src_vars, vars);
-
-   NIR_PASS_V(shader, lower_rt_instructions, &src_vars, call_idx_base);
-
-   NIR_PASS(_, shader, nir_opt_remove_phis);
-   NIR_PASS(_, shader, nir_lower_returns);
-   NIR_PASS(_, shader, nir_opt_dce);
-
-   reserve_stack_size(vars, shader->scratch_size);
-
-   nir_push_if(b, nir_ieq_imm(b, idx, call_idx));
-   nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
-   nir_pop_if(b, NULL);
-
-   /* Adopt the instructions from the source shader, since they are merely moved, not cloned. */
-   ralloc_adopt(ralloc_context(b->shader), ralloc_context(shader));
-
-   ralloc_free(var_remap);
-}
-
-static bool
-lower_rt_derefs(nir_shader *shader)
-{
-   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
-
-   bool progress = false;
-
-   nir_builder b;
-   nir_builder_init(&b, impl);
-
-   b.cursor = nir_before_cf_list(&impl->body);
-   nir_ssa_def *arg_offset = nir_load_rt_arg_scratch_offset_amd(&b);
-
-   nir_foreach_block (block, impl) {
-      nir_foreach_instr_safe (instr, block) {
-         if (instr->type != nir_instr_type_deref)
-            continue;
-
-         nir_deref_instr *deref = nir_instr_as_deref(instr);
-         b.cursor = nir_before_instr(&deref->instr);
-
-         nir_deref_instr *replacement = NULL;
-         if (nir_deref_mode_is(deref, nir_var_shader_call_data)) {
-            deref->modes = nir_var_function_temp;
-            progress = true;
-
-            if (deref->deref_type == nir_deref_type_var)
-               replacement =
-                  nir_build_deref_cast(&b, arg_offset, nir_var_function_temp, deref->var->type, 0);
-         } else if (nir_deref_mode_is(deref, nir_var_ray_hit_attrib)) {
-            deref->modes = nir_var_function_temp;
-            progress = true;
-
-            if (deref->deref_type == nir_deref_type_var)
-               replacement = nir_build_deref_cast(&b, nir_imm_int(&b, RADV_HIT_ATTRIB_OFFSET),
-                                                  nir_var_function_temp, deref->type, 0);
-         }
-
-         if (replacement != NULL) {
-            nir_ssa_def_rewrite_uses(&deref->dest.ssa, &replacement->dest.ssa);
-            nir_instr_remove(&deref->instr);
-         }
-      }
-   }
-
-   if (progress)
-      nir_metadata_preserve(impl, nir_metadata_block_index | nir_metadata_dominance);
-   else
-      nir_metadata_preserve(impl, nir_metadata_all);
-
-   return progress;
-}
-
-static nir_shader *
-parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreateInfo *sinfo)
-{
-   struct radv_pipeline_key key;
-   memset(&key, 0, sizeof(key));
-
-   struct radv_pipeline_stage rt_stage;
-
-   radv_pipeline_stage_init(sinfo, &rt_stage, vk_to_mesa_shader_stage(sinfo->stage));
-
-   nir_shader *shader = radv_shader_spirv_to_nir(device, &rt_stage, &key);
-
-   if (shader->info.stage == MESA_SHADER_RAYGEN || shader->info.stage == MESA_SHADER_CLOSEST_HIT ||
-       shader->info.stage == MESA_SHADER_CALLABLE || shader->info.stage == MESA_SHADER_MISS) {
-      nir_block *last_block = nir_impl_last_block(nir_shader_get_entrypoint(shader));
-      nir_builder b_inner;
-      nir_builder_init(&b_inner, nir_shader_get_entrypoint(shader));
-      b_inner.cursor = nir_after_block(last_block);
-      nir_rt_return_amd(&b_inner);
-   }
-
-   NIR_PASS(_, shader, nir_lower_vars_to_explicit_types,
-            nir_var_function_temp | nir_var_shader_call_data | nir_var_ray_hit_attrib,
-            glsl_get_natural_size_align_bytes);
-
-   NIR_PASS(_, shader, lower_rt_derefs);
-
-   NIR_PASS(_, shader, nir_lower_explicit_io, nir_var_function_temp,
-            nir_address_format_32bit_offset);
-
-   return shader;
-}
-
-static nir_function_impl *
-lower_any_hit_for_intersection(nir_shader *any_hit)
-{
-   nir_function_impl *impl = nir_shader_get_entrypoint(any_hit);
-
-   /* Any-hit shaders need three parameters */
-   assert(impl->function->num_params == 0);
-   nir_parameter params[] = {
-      {
-         /* A pointer to a boolean value for whether or not the hit was
-          * accepted.
-          */
-         .num_components = 1,
-         .bit_size = 32,
-      },
-      {
-         /* The hit T value */
-         .num_components = 1,
-         .bit_size = 32,
-      },
-      {
-         /* The hit kind */
-         .num_components = 1,
-         .bit_size = 32,
-      },
-   };
-   impl->function->num_params = ARRAY_SIZE(params);
-   impl->function->params = ralloc_array(any_hit, nir_parameter, ARRAY_SIZE(params));
-   memcpy(impl->function->params, params, sizeof(params));
-
-   nir_builder build;
-   nir_builder_init(&build, impl);
-   nir_builder *b = &build;
-
-   b->cursor = nir_before_cf_list(&impl->body);
-
-   nir_ssa_def *commit_ptr = nir_load_param(b, 0);
-   nir_ssa_def *hit_t = nir_load_param(b, 1);
-   nir_ssa_def *hit_kind = nir_load_param(b, 2);
-
-   nir_deref_instr *commit =
-      nir_build_deref_cast(b, commit_ptr, nir_var_function_temp, glsl_bool_type(), 0);
-
-   nir_foreach_block_safe (block, impl) {
-      nir_foreach_instr_safe (instr, block) {
-         switch (instr->type) {
-         case nir_instr_type_intrinsic: {
-            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
-            switch (intrin->intrinsic) {
-            case nir_intrinsic_ignore_ray_intersection:
-               b->cursor = nir_instr_remove(&intrin->instr);
-               /* We put the newly emitted code inside a dummy if because it's
-                * going to contain a jump instruction and we don't want to
-                * deal with that mess here.  It'll get dealt with by our
-                * control-flow optimization passes.
-                */
-               nir_store_deref(b, commit, nir_imm_false(b), 0x1);
-               nir_push_if(b, nir_imm_true(b));
-               nir_jump(b, nir_jump_return);
-               nir_pop_if(b, NULL);
-               break;
-
-            case nir_intrinsic_terminate_ray:
-               /* The "normal" handling of terminateRay works fine in
-                * intersection shaders.
-                */
-               break;
-
-            case nir_intrinsic_load_ray_t_max:
-               nir_ssa_def_rewrite_uses(&intrin->dest.ssa, hit_t);
-               nir_instr_remove(&intrin->instr);
-               break;
-
-            case nir_intrinsic_load_ray_hit_kind:
-               nir_ssa_def_rewrite_uses(&intrin->dest.ssa, hit_kind);
-               nir_instr_remove(&intrin->instr);
-               break;
-
-            default:
-               break;
-            }
-            break;
-         }
-         case nir_instr_type_jump: {
-            nir_jump_instr *jump = nir_instr_as_jump(instr);
-            if (jump->type == nir_jump_halt) {
-               b->cursor = nir_instr_remove(instr);
-               nir_jump(b, nir_jump_return);
-            }
-            break;
-         }
-
-         default:
-            break;
-         }
-      }
-   }
-
-   nir_validate_shader(any_hit, "after initial any-hit lowering");
-
-   nir_lower_returns_impl(impl);
-
-   nir_validate_shader(any_hit, "after lowering returns");
-
-   return impl;
-}
-
-/* Inline the any_hit shader into the intersection shader so we don't have
- * to implement yet another shader call interface here. It also avoids any recursion.
- */
-static void
-nir_lower_intersection_shader(nir_shader *intersection, nir_shader *any_hit)
-{
-   void *dead_ctx = ralloc_context(intersection);
-
-   nir_function_impl *any_hit_impl = NULL;
-   struct hash_table *any_hit_var_remap = NULL;
-   if (any_hit) {
-      any_hit = nir_shader_clone(dead_ctx, any_hit);
-      NIR_PASS(_, any_hit, nir_opt_dce);
-      any_hit_impl = lower_any_hit_for_intersection(any_hit);
-      any_hit_var_remap = _mesa_pointer_hash_table_create(dead_ctx);
-   }
-
-   nir_function_impl *impl = nir_shader_get_entrypoint(intersection);
-
-   nir_builder build;
-   nir_builder_init(&build, impl);
-   nir_builder *b = &build;
-
-   b->cursor = nir_before_cf_list(&impl->body);
-
-   nir_variable *commit = nir_local_variable_create(impl, glsl_bool_type(), "ray_commit");
-   nir_store_var(b, commit, nir_imm_false(b), 0x1);
-
-   nir_foreach_block_safe (block, impl) {
-      nir_foreach_instr_safe (instr, block) {
-         if (instr->type != nir_instr_type_intrinsic)
-            continue;
-
-         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
-         if (intrin->intrinsic != nir_intrinsic_report_ray_intersection)
-            continue;
-
-         b->cursor = nir_instr_remove(&intrin->instr);
-         nir_ssa_def *hit_t = nir_ssa_for_src(b, intrin->src[0], 1);
-         nir_ssa_def *hit_kind = nir_ssa_for_src(b, intrin->src[1], 1);
-         nir_ssa_def *min_t = nir_load_ray_t_min(b);
-         nir_ssa_def *max_t = nir_load_ray_t_max(b);
-
-         /* bool commit_tmp = false; */
-         nir_variable *commit_tmp = nir_local_variable_create(impl, glsl_bool_type(), "commit_tmp");
-         nir_store_var(b, commit_tmp, nir_imm_false(b), 0x1);
-
-         nir_push_if(b, nir_iand(b, nir_fge(b, hit_t, min_t), nir_fge(b, max_t, hit_t)));
-         {
-            /* Any-hit defaults to commit */
-            nir_store_var(b, commit_tmp, nir_imm_true(b), 0x1);
-
-            if (any_hit_impl != NULL) {
-               nir_push_if(b, nir_inot(b, nir_load_intersection_opaque_amd(b)));
-               {
-                  nir_ssa_def *params[] = {
-                     &nir_build_deref_var(b, commit_tmp)->dest.ssa,
-                     hit_t,
-                     hit_kind,
-                  };
-                  nir_inline_function_impl(b, any_hit_impl, params, any_hit_var_remap);
-               }
-               nir_pop_if(b, NULL);
-            }
-
-            nir_push_if(b, nir_load_var(b, commit_tmp));
-            {
-               nir_report_ray_intersection(b, 1, hit_t, hit_kind);
-            }
-            nir_pop_if(b, NULL);
-         }
-         nir_pop_if(b, NULL);
-
-         nir_ssa_def *accepted = nir_load_var(b, commit_tmp);
-         nir_ssa_def_rewrite_uses(&intrin->dest.ssa, accepted);
-      }
-   }
-
-   /* We did some inlining; have to re-index SSA defs */
-   nir_index_ssa_defs(impl);
-
-   /* Eliminate the casts introduced for the commit return of the any-hit shader. */
-   NIR_PASS(_, intersection, nir_opt_deref);
-
-   ralloc_free(dead_ctx);
-}
-
-/* Variables only used internally to ray traversal. This is data that describes
- * the current state of the traversal vs. what we'd give to a shader.  e.g. what
- * is the instance we're currently visiting vs. what is the instance of the
- * closest hit. */
-struct rt_traversal_vars {
-   nir_variable *origin;
-   nir_variable *dir;
-   nir_variable *inv_dir;
-   nir_variable *sbt_offset_and_flags;
-   nir_variable *instance_addr;
-   nir_variable *hit;
-   nir_variable *bvh_base;
-   nir_variable *stack;
-   nir_variable *top_stack;
-   nir_variable *stack_base;
-   nir_variable *current_node;
-   nir_variable *previous_node;
-   nir_variable *instance_top_node;
-   nir_variable *instance_bottom_node;
-};
-
-static struct rt_traversal_vars
-init_traversal_vars(nir_builder *b)
-{
-   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
-   struct rt_traversal_vars ret;
-
-   ret.origin = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_origin");
-   ret.dir = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_dir");
-   ret.inv_dir =
-      nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_inv_dir");
-   ret.sbt_offset_and_flags = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
-                                                  "traversal_sbt_offset_and_flags");
-   ret.instance_addr =
-      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
-   ret.hit = nir_variable_create(b->shader, nir_var_shader_temp, glsl_bool_type(), "traversal_hit");
-   ret.bvh_base = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(),
-                                      "traversal_bvh_base");
-   ret.stack =
-      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_stack_ptr");
-   ret.top_stack = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
-                                       "traversal_top_stack_ptr");
-   ret.stack_base =
-      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_stack_base");
-   ret.current_node =
-      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "current_node;");
-   ret.previous_node =
-      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "previous_node");
-   ret.instance_top_node =
-      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "instance_top_node");
-   ret.instance_bottom_node =
-      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "instance_bottom_node");
-   return ret;
-}
-
-static void
-visit_any_hit_shaders(struct radv_device *device,
-                      const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
-                      struct rt_variables *vars)
-{
-   nir_ssa_def *sbt_idx = nir_load_var(b, vars->idx);
-
-   nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
-   for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
-      const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
-      uint32_t shader_id = VK_SHADER_UNUSED_KHR;
-
-      switch (group_info->type) {
-      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
-         shader_id = group_info->anyHitShader;
-         break;
-      default:
-         break;
-      }
-      if (shader_id == VK_SHADER_UNUSED_KHR)
-         continue;
-
-      const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
-      nir_shader *nir_stage = parse_rt_stage(device, stage);
-
-      vars->stage_idx = shader_id;
-      insert_rt_case(b, nir_stage, vars, sbt_idx, 0, i + 2);
-   }
-   nir_pop_if(b, NULL);
-}
-
-struct traversal_data {
-   struct radv_device *device;
-   const VkRayTracingPipelineCreateInfoKHR *createInfo;
-   struct rt_variables *vars;
-   struct rt_traversal_vars *trav_vars;
-};
-
-static void
-handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *intersection,
-                          const struct radv_ray_traversal_args *args)
-{
-   struct traversal_data *data = args->data;
-
-   nir_ssa_def *geometry_id = nir_iand_imm(b, intersection->base.geometry_id_and_flags, 0xfffffff);
-   nir_ssa_def *sbt_idx = nir_iadd(
-      b,
-      nir_iadd(b, nir_load_var(b, data->vars->sbt_offset),
-               nir_iand_imm(b, nir_load_var(b, data->trav_vars->sbt_offset_and_flags), 0xffffff)),
-      nir_imul(b, nir_load_var(b, data->vars->sbt_stride), geometry_id));
-
-   nir_ssa_def *hit_kind =
-      nir_bcsel(b, intersection->frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF));
-
-   nir_ssa_def *barycentrics_addr =
-      nir_iadd_imm(b, nir_load_var(b, data->vars->stack_ptr), RADV_HIT_ATTRIB_OFFSET);
-   nir_ssa_def *prev_barycentrics = nir_load_scratch(b, 2, 32, barycentrics_addr, .align_mul = 16);
-   nir_store_scratch(b, intersection->barycentrics, barycentrics_addr, .align_mul = 16);
-
-   nir_store_var(b, data->vars->ahit_accept, nir_imm_true(b), 0x1);
-   nir_store_var(b, data->vars->ahit_terminate, nir_imm_false(b), 0x1);
-
-   nir_push_if(b, nir_inot(b, intersection->base.opaque));
-   {
-      struct rt_variables inner_vars = create_inner_vars(b, data->vars);
-
-      nir_store_var(b, inner_vars.primitive_id, intersection->base.primitive_id, 1);
-      nir_store_var(b, inner_vars.geometry_id_and_flags, intersection->base.geometry_id_and_flags,
-                    1);
-      nir_store_var(b, inner_vars.tmax, intersection->t, 0x1);
-      nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, data->trav_vars->instance_addr),
-                    0x1);
-      nir_store_var(b, inner_vars.hit_kind, hit_kind, 0x1);
-
-      load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4);
-
-      visit_any_hit_shaders(data->device, data->createInfo, b, &inner_vars);
-
-      nir_push_if(b, nir_inot(b, nir_load_var(b, data->vars->ahit_accept)));
-      {
-         nir_store_scratch(b, prev_barycentrics, barycentrics_addr, .align_mul = 16);
-         nir_jump(b, nir_jump_continue);
-      }
-      nir_pop_if(b, NULL);
-   }
-   nir_pop_if(b, NULL);
-
-   nir_store_var(b, data->vars->primitive_id, intersection->base.primitive_id, 1);
-   nir_store_var(b, data->vars->geometry_id_and_flags, intersection->base.geometry_id_and_flags, 1);
-   nir_store_var(b, data->vars->tmax, intersection->t, 0x1);
-   nir_store_var(b, data->vars->instance_addr, nir_load_var(b, data->trav_vars->instance_addr),
-                 0x1);
-   nir_store_var(b, data->vars->hit_kind, hit_kind, 0x1);
-
-   nir_store_var(b, data->vars->idx, sbt_idx, 1);
-   nir_store_var(b, data->trav_vars->hit, nir_imm_true(b), 1);
-
-   nir_ssa_def *terminate_on_first_hit =
-      nir_test_mask(b, args->flags, SpvRayFlagsTerminateOnFirstHitKHRMask);
-   nir_ssa_def *ray_terminated = nir_load_var(b, data->vars->ahit_terminate);
-   nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
-   {
-      nir_jump(b, nir_jump_break);
-   }
-   nir_pop_if(b, NULL);
-}
-
-static void
-handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersection,
-                      const struct radv_ray_traversal_args *args)
-{
-   struct traversal_data *data = args->data;
-
-   nir_ssa_def *geometry_id = nir_iand_imm(b, intersection->geometry_id_and_flags, 0xfffffff);
-   nir_ssa_def *sbt_idx = nir_iadd(
-      b,
-      nir_iadd(b, nir_load_var(b, data->vars->sbt_offset),
-               nir_iand_imm(b, nir_load_var(b, data->trav_vars->sbt_offset_and_flags), 0xffffff)),
-      nir_imul(b, nir_load_var(b, data->vars->sbt_stride), geometry_id));
-
-   struct rt_variables inner_vars = create_inner_vars(b, data->vars);
-
-   /* For AABBs the intersection shader writes the hit kind, and only does it if it is the
-    * next closest hit candidate. */
-   inner_vars.hit_kind = data->vars->hit_kind;
-
-   nir_store_var(b, inner_vars.primitive_id, intersection->primitive_id, 1);
-   nir_store_var(b, inner_vars.geometry_id_and_flags, intersection->geometry_id_and_flags, 1);
-   nir_store_var(b, inner_vars.tmax, nir_load_var(b, data->vars->tmax), 0x1);
-   nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, data->trav_vars->instance_addr), 0x1);
-   nir_store_var(b, inner_vars.opaque, intersection->opaque, 1);
-
-   load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4);
-
-   nir_store_var(b, data->vars->ahit_accept, nir_imm_false(b), 0x1);
-   nir_store_var(b, data->vars->ahit_terminate, nir_imm_false(b), 0x1);
-
-   nir_push_if(b, nir_ine_imm(b, nir_load_var(b, inner_vars.idx), 0));
-   for (unsigned i = 0; i < data->createInfo->groupCount; ++i) {
-      const VkRayTracingShaderGroupCreateInfoKHR *group_info = &data->createInfo->pGroups[i];
-      uint32_t shader_id = VK_SHADER_UNUSED_KHR;
-      uint32_t any_hit_shader_id = VK_SHADER_UNUSED_KHR;
-
-      switch (group_info->type) {
-      case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
-         shader_id = group_info->intersectionShader;
-         any_hit_shader_id = group_info->anyHitShader;
-         break;
-      default:
-         break;
-      }
-      if (shader_id == VK_SHADER_UNUSED_KHR)
-         continue;
-
-      const VkPipelineShaderStageCreateInfo *stage = &data->createInfo->pStages[shader_id];
-      nir_shader *nir_stage = parse_rt_stage(data->device, stage);
-
-      nir_shader *any_hit_stage = NULL;
-      if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) {
-         stage = &data->createInfo->pStages[any_hit_shader_id];
-         any_hit_stage = parse_rt_stage(data->device, stage);
-
-         nir_lower_intersection_shader(nir_stage, any_hit_stage);
-         ralloc_free(any_hit_stage);
-      }
-
-      inner_vars.stage_idx = shader_id;
-      insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0, i + 2);
-   }
-   nir_pop_if(b, NULL);
-
-   nir_push_if(b, nir_load_var(b, data->vars->ahit_accept));
-   {
-      nir_store_var(b, data->vars->primitive_id, intersection->primitive_id, 1);
-      nir_store_var(b, data->vars->geometry_id_and_flags, intersection->geometry_id_and_flags, 1);
-      nir_store_var(b, data->vars->tmax, nir_load_var(b, inner_vars.tmax), 0x1);
-      nir_store_var(b, data->vars->instance_addr, nir_load_var(b, data->trav_vars->instance_addr),
-                    0x1);
-
-      nir_store_var(b, data->vars->idx, sbt_idx, 1);
-      nir_store_var(b, data->trav_vars->hit, nir_imm_true(b), 1);
-
-      nir_ssa_def *terminate_on_first_hit =
-         nir_test_mask(b, args->flags, SpvRayFlagsTerminateOnFirstHitKHRMask);
-      nir_ssa_def *ray_terminated = nir_load_var(b, data->vars->ahit_terminate);
-      nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
-      {
-         nir_jump(b, nir_jump_break);
-      }
-      nir_pop_if(b, NULL);
-   }
-   nir_pop_if(b, NULL);
-}
-
-static void
-store_stack_entry(nir_builder *b, nir_ssa_def *index, nir_ssa_def *value,
-                  const struct radv_ray_traversal_args *args)
-{
-   nir_store_shared(b, value, index, .base = 0, .align_mul = 4);
-}
-
-static nir_ssa_def *
-load_stack_entry(nir_builder *b, nir_ssa_def *index, const struct radv_ray_traversal_args *args)
-{
-   return nir_load_shared(b, 1, 32, index, .base = 0, .align_mul = 4);
-}
-
-static nir_shader *
-build_traversal_shader(struct radv_device *device,
-                       const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
-                       struct radv_pipeline_shader_stack_size *stack_sizes)
-{
-   nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_traversal");
-   b.shader->info.internal = false;
-   b.shader->info.workgroup_size[0] = 8;
-   b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
-   b.shader->info.shared_size =
-      device->physical_device->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
-   struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes);
-
-   /* initialize trace_ray arguments */
-   nir_ssa_def *accel_struct = nir_load_accel_struct_amd(&b);
-   nir_store_var(&b, vars.flags, nir_load_ray_flags(&b), 0x1);
-   nir_store_var(&b, vars.cull_mask, nir_load_cull_mask(&b), 0x1);
-   nir_store_var(&b, vars.sbt_offset, nir_load_sbt_offset_amd(&b), 0x1);
-   nir_store_var(&b, vars.sbt_stride, nir_load_sbt_stride_amd(&b), 0x1);
-   nir_store_var(&b, vars.origin, nir_load_ray_world_origin(&b), 0x7);
-   nir_store_var(&b, vars.tmin, nir_load_ray_t_min(&b), 0x1);
-   nir_store_var(&b, vars.direction, nir_load_ray_world_direction(&b), 0x7);
-   nir_store_var(&b, vars.tmax, nir_load_ray_t_max(&b), 0x1);
-   nir_store_var(&b, vars.arg, nir_load_rt_arg_scratch_offset_amd(&b), 0x1);
-   nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
-
-   struct rt_traversal_vars trav_vars = init_traversal_vars(&b);
-
-   nir_store_var(&b, trav_vars.hit, nir_imm_false(&b), 1);
-
-   nir_push_if(&b, nir_ine_imm(&b, accel_struct, 0));
-   {
-      nir_ssa_def *bvh_offset = nir_build_load_global(
-         &b, 1, 32,
-         nir_iadd_imm(&b, accel_struct, offsetof(struct radv_accel_struct_header, bvh_offset)),
-         .access = ACCESS_NON_WRITEABLE);
-      nir_ssa_def *root_bvh_base = nir_iadd(&b, accel_struct, nir_u2u64(&b, bvh_offset));
-      root_bvh_base = build_addr_to_node(&b, root_bvh_base);
-
-      nir_store_var(&b, trav_vars.bvh_base, root_bvh_base, 1);
-
-      nir_ssa_def *vec3ones = nir_channels(&b, nir_imm_vec4(&b, 1.0, 1.0, 1.0, 1.0), 0x7);
-
-      nir_store_var(&b, trav_vars.origin, nir_load_var(&b, vars.origin), 7);
-      nir_store_var(&b, trav_vars.dir, nir_load_var(&b, vars.direction), 7);
-      nir_store_var(&b, trav_vars.inv_dir, nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)),
-                    7);
-      nir_store_var(&b, trav_vars.sbt_offset_and_flags, nir_imm_int(&b, 0), 1);
-      nir_store_var(&b, trav_vars.instance_addr, nir_imm_int64(&b, 0), 1);
-
-      nir_store_var(&b, trav_vars.stack, nir_imul_imm(&b, nir_load_local_invocation_index(&b), sizeof(uint32_t)), 1);
-      nir_store_var(&b, trav_vars.stack_base, nir_load_var(&b, trav_vars.stack), 1);
-      nir_store_var(&b, trav_vars.current_node, nir_imm_int(&b, RADV_BVH_ROOT_NODE), 0x1);
-      nir_store_var(&b, trav_vars.previous_node, nir_imm_int(&b, RADV_BVH_INVALID_NODE), 0x1);
-      nir_store_var(&b, trav_vars.instance_top_node, nir_imm_int(&b, RADV_BVH_INVALID_NODE), 0x1);
-      nir_store_var(&b, trav_vars.instance_bottom_node, nir_imm_int(&b, RADV_BVH_NO_INSTANCE_ROOT), 0x1);
-
-      nir_store_var(&b, trav_vars.top_stack, nir_imm_int(&b, -1), 1);
-
-      struct radv_ray_traversal_vars trav_vars_args = {
-         .tmax = nir_build_deref_var(&b, vars.tmax),
-         .origin = nir_build_deref_var(&b, trav_vars.origin),
-         .dir = nir_build_deref_var(&b, trav_vars.dir),
-         .inv_dir = nir_build_deref_var(&b, trav_vars.inv_dir),
-         .bvh_base = nir_build_deref_var(&b, trav_vars.bvh_base),
-         .stack = nir_build_deref_var(&b, trav_vars.stack),
-         .top_stack = nir_build_deref_var(&b, trav_vars.top_stack),
-         .stack_base = nir_build_deref_var(&b, trav_vars.stack_base),
-         .current_node = nir_build_deref_var(&b, trav_vars.current_node),
-         .previous_node = nir_build_deref_var(&b, trav_vars.previous_node),
-         .instance_top_node = nir_build_deref_var(&b, trav_vars.instance_top_node),
-         .instance_bottom_node = nir_build_deref_var(&b, trav_vars.instance_bottom_node),
-         .instance_addr = nir_build_deref_var(&b, trav_vars.instance_addr),
-         .sbt_offset_and_flags = nir_build_deref_var(&b, trav_vars.sbt_offset_and_flags),
-      };
-
-      struct traversal_data data = {
-         .device = device,
-         .createInfo = pCreateInfo,
-         .vars = &vars,
-         .trav_vars = &trav_vars,
-      };
-
-      struct radv_ray_traversal_args args = {
-         .root_bvh_base = root_bvh_base,
-         .flags = nir_load_var(&b, vars.flags),
-         .cull_mask = nir_load_var(&b, vars.cull_mask),
-         .origin = nir_load_var(&b, vars.origin),
-         .tmin = nir_load_var(&b, vars.tmin),
-         .dir = nir_load_var(&b, vars.direction),
-         .vars = trav_vars_args,
-         .stack_stride = device->physical_device->rt_wave_size * sizeof(uint32_t),
-         .stack_entries = MAX_STACK_ENTRY_COUNT,
-         .stack_store_cb = store_stack_entry,
-         .stack_load_cb = load_stack_entry,
-         .aabb_cb = (pCreateInfo->flags & VK_PIPELINE_CREATE_RAY_TRACING_SKIP_AABBS_BIT_KHR)
-                       ? NULL
-                       : handle_candidate_aabb,
-         .triangle_cb = (pCreateInfo->flags & VK_PIPELINE_CREATE_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR)
-                           ? NULL
-                           : handle_candidate_triangle,
-         .data = &data,
-      };
-
-      radv_build_ray_traversal(device, &b, &args);
-   }
-   nir_pop_if(&b, NULL);
-
-   /* Initialize follow-up shader. */
-   nir_push_if(&b, nir_load_var(&b, trav_vars.hit));
-   {
-      nir_execute_closest_hit_amd(
-         &b, nir_load_var(&b, vars.idx), nir_load_var(&b, vars.tmax),
-         nir_load_var(&b, vars.primitive_id), nir_load_var(&b, vars.instance_addr),
-         nir_load_var(&b, vars.geometry_id_and_flags), nir_load_var(&b, vars.hit_kind));
-   }
-   nir_push_else(&b, NULL);
-   {
-      /* Only load the miss shader if we actually miss. It is valid to not specify an SBT pointer
-       * for miss shaders if none of the rays miss. */
-      nir_execute_miss_amd(&b, nir_load_var(&b, vars.tmax));
-   }
-   nir_pop_if(&b, NULL);
-
-   /* Deal with all the inline functions. */
-   nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));
-   nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);
-
-   /* Lower and cleanup variables */
-   NIR_PASS_V(b.shader, nir_lower_global_vars_to_local);
-   NIR_PASS_V(b.shader, nir_lower_vars_to_ssa);
-
-   return b.shader;
-}
-
-static unsigned
-compute_rt_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
-                      const struct radv_pipeline_shader_stack_size *stack_sizes)
-{
-   unsigned raygen_size = 0;
-   unsigned callable_size = 0;
-   unsigned chit_size = 0;
-   unsigned miss_size = 0;
-   unsigned non_recursive_size = 0;
-
-   for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
-      non_recursive_size = MAX2(stack_sizes[i].non_recursive_size, non_recursive_size);
-
-      const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
-      uint32_t shader_id = VK_SHADER_UNUSED_KHR;
-      unsigned size = stack_sizes[i].recursive_size;
-
-      switch (group_info->type) {
-      case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
-         shader_id = group_info->generalShader;
-         break;
-      case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
-      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
-         shader_id = group_info->closestHitShader;
-         break;
-      default:
-         break;
-      }
-      if (shader_id == VK_SHADER_UNUSED_KHR)
-         continue;
-
-      const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
-      switch (stage->stage) {
-      case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
-         raygen_size = MAX2(raygen_size, size);
-         break;
-      case VK_SHADER_STAGE_MISS_BIT_KHR:
-         miss_size = MAX2(miss_size, size);
-         break;
-      case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
-         chit_size = MAX2(chit_size, size);
-         break;
-      case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
-         callable_size = MAX2(callable_size, size);
-         break;
-      default:
-         unreachable("Invalid stage type in RT shader");
-      }
-   }
-   return raygen_size +
-          MIN2(pCreateInfo->maxPipelineRayRecursionDepth, 1) *
-             MAX2(MAX2(chit_size, miss_size), non_recursive_size) +
-          MAX2(0, (int)(pCreateInfo->maxPipelineRayRecursionDepth) - 1) *
-             MAX2(chit_size, miss_size) +
-          2 * callable_size;
-}
-
 bool
 radv_rt_pipeline_has_dynamic_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo)
 {
@@ -1570,131 +236,6 @@ radv_rt_pipeline_has_dynamic_stack_size(const VkRayTracingPipelineCreateInfoKHR
    return false;
 }
 
-static bool
-should_move_rt_instruction(nir_intrinsic_op intrinsic)
-{
-   switch (intrinsic) {
-   case nir_intrinsic_load_rt_arg_scratch_offset_amd:
-   case nir_intrinsic_load_ray_flags:
-   case nir_intrinsic_load_ray_object_origin:
-   case nir_intrinsic_load_ray_world_origin:
-   case nir_intrinsic_load_ray_t_min:
-   case nir_intrinsic_load_ray_object_direction:
-   case nir_intrinsic_load_ray_world_direction:
-   case nir_intrinsic_load_ray_t_max:
-      return true;
-   default:
-      return false;
-   }
-}
-
-static void
-move_rt_instructions(nir_shader *shader)
-{
-   nir_cursor target = nir_before_cf_list(&nir_shader_get_entrypoint(shader)->body);
-
-   nir_foreach_block (block, nir_shader_get_entrypoint(shader)) {
-      nir_foreach_instr_safe (instr, block) {
-         if (instr->type != nir_instr_type_intrinsic)
-            continue;
-
-         nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
-
-         if (!should_move_rt_instruction(intrinsic->intrinsic))
-            continue;
-
-         nir_instr_move(target, instr);
-      }
-   }
-
-   nir_metadata_preserve(nir_shader_get_entrypoint(shader),
-                         nir_metadata_all & (~nir_metadata_instr_index));
-}
-
-static nir_shader *
-create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
-                 struct radv_pipeline_shader_stack_size *stack_sizes)
-{
-   struct radv_pipeline_key key;
-   memset(&key, 0, sizeof(key));
-
-   nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_combined");
-   b.shader->info.internal = false;
-   b.shader->info.workgroup_size[0] = 8;
-   b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
-
-   struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes);
-   load_sbt_entry(&b, &vars, nir_imm_int(&b, 0), SBT_RAYGEN, 0);
-   if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo))
-      nir_store_var(&b, vars.stack_ptr, nir_load_rt_dynamic_callable_stack_base_amd(&b), 0x1);
-   else
-      nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
-
-   nir_loop *loop = nir_push_loop(&b);
-
-   nir_push_if(&b, nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 0));
-   nir_jump(&b, nir_jump_break);
-   nir_pop_if(&b, NULL);
-
-   nir_ssa_def *idx = nir_load_var(&b, vars.idx);
-
-   /* Insert traversal shader */
-   nir_shader *traversal = build_traversal_shader(device, pCreateInfo, stack_sizes);
-   assert(b.shader->info.shared_size == 0);
-   b.shader->info.shared_size = traversal->info.shared_size;
-   assert(b.shader->info.shared_size <= 32768);
-   insert_rt_case(&b, traversal, &vars, idx, 0, 1);
-
-   /* We do a trick with the indexing of the resume shaders so that the first
-    * shader of stage x always gets id x and the resume shader ids then come after
-    * stageCount. This makes the shadergroup handles independent of compilation. */
-   unsigned call_idx_base = pCreateInfo->stageCount + 1;
-   for (unsigned i = 0; i < pCreateInfo->stageCount; ++i) {
-      const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[i];
-      gl_shader_stage type = vk_to_mesa_shader_stage(stage->stage);
-      if (type != MESA_SHADER_RAYGEN && type != MESA_SHADER_CALLABLE &&
-          type != MESA_SHADER_CLOSEST_HIT && type != MESA_SHADER_MISS)
-         continue;
-
-      nir_shader *nir_stage = parse_rt_stage(device, stage);
-
-      /* Move ray tracing system values to the top that are set by rt_trace_ray
-       * to prevent them from being overwritten by other rt_trace_ray calls.
-       */
-      NIR_PASS_V(nir_stage, move_rt_instructions);
-
-      const nir_lower_shader_calls_options opts = {
-         .address_format = nir_address_format_32bit_offset,
-         .stack_alignment = 16,
-         .localized_loads = true
-      };
-      uint32_t num_resume_shaders = 0;
-      nir_shader **resume_shaders = NULL;
-      nir_lower_shader_calls(nir_stage, &opts, &resume_shaders,
-                             &num_resume_shaders, nir_stage);
-
-      vars.stage_idx = i;
-      insert_rt_case(&b, nir_stage, &vars, idx, call_idx_base, i + 2);
-      for (unsigned j = 0; j < num_resume_shaders; ++j) {
-         insert_rt_case(&b, resume_shaders[j], &vars, idx, call_idx_base, call_idx_base + 1 + j);
-      }
-      call_idx_base += num_resume_shaders;
-   }
-
-   nir_pop_loop(&b, loop);
-
-   if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo))
-      b.shader->scratch_size = 16; /* To enable scratch. */
-   else
-      b.shader->scratch_size += compute_rt_stack_size(pCreateInfo, stack_sizes);
-
-   /* Deal with all the inline functions. */
-   nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));
-   nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);
-
-   return b.shader;
-}
-
 static struct radv_pipeline_key
 radv_generate_rt_pipeline_key(const struct radv_ray_tracing_pipeline *pipeline,
                               VkPipelineCreateFlags flags)
diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c
new file mode 100644 (file)
index 0000000..150c272
--- /dev/null
@@ -0,0 +1,1484 @@
+/*
+ * Copyright © 2021 Google
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the "Software"),
+ * to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ * and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice (including the next
+ * paragraph) shall be included in all copies or substantial portions of the
+ * Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
+ * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#include "nir/nir.h"
+#include "nir/nir_builder.h"
+
+#include "radv_acceleration_structure.h"
+#include "radv_meta.h"
+#include "radv_private.h"
+#include "radv_rt_common.h"
+#include "radv_shader.h"
+
+/* Traversal stack size. This stack is put in LDS and, experimentally, 16 entries result in the
+ * best performance. */
+#define MAX_STACK_ENTRY_COUNT 16
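+/* For example, with a wave size of 64 this reserves
+ * 64 * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t) = 4096 bytes of LDS per workgroup
+ * (see the shared_size computation in build_traversal_shader below). */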
+
+/* The hit attributes are stored on the stack. This is the offset, relative to the current stack
+ * pointer, at which the hit attributes are stored. */
+const uint32_t RADV_HIT_ATTRIB_OFFSET = -(16 + RADV_MAX_HIT_ATTRIB_SIZE);
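+/* The hit attributes occupy the RADV_MAX_HIT_ATTRIB_SIZE bytes directly below the 16-byte
+ * return-index slot that sits right under the stack pointer. E.g. assuming the Vulkan-minimum
+ * 32-byte hit attribute size, the attributes would start 48 bytes below the stack pointer
+ * (illustrative; the actual value comes from radv_shader.h). */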
+
+static bool
+lower_rt_derefs(nir_shader *shader)
+{
+   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
+
+   bool progress = false;
+
+   nir_builder b;
+   nir_builder_init(&b, impl);
+
+   b.cursor = nir_before_cf_list(&impl->body);
+   nir_ssa_def *arg_offset = nir_load_rt_arg_scratch_offset_amd(&b);
+
+   nir_foreach_block (block, impl) {
+      nir_foreach_instr_safe (instr, block) {
+         if (instr->type != nir_instr_type_deref)
+            continue;
+
+         nir_deref_instr *deref = nir_instr_as_deref(instr);
+         b.cursor = nir_before_instr(&deref->instr);
+
+         nir_deref_instr *replacement = NULL;
+         if (nir_deref_mode_is(deref, nir_var_shader_call_data)) {
+            deref->modes = nir_var_function_temp;
+            progress = true;
+
+            if (deref->deref_type == nir_deref_type_var)
+               replacement =
+                  nir_build_deref_cast(&b, arg_offset, nir_var_function_temp, deref->var->type, 0);
+         } else if (nir_deref_mode_is(deref, nir_var_ray_hit_attrib)) {
+            deref->modes = nir_var_function_temp;
+            progress = true;
+
+            if (deref->deref_type == nir_deref_type_var)
+               replacement = nir_build_deref_cast(&b, nir_imm_int(&b, RADV_HIT_ATTRIB_OFFSET),
+                                                  nir_var_function_temp, deref->type, 0);
+         }
+
+         if (replacement != NULL) {
+            nir_ssa_def_rewrite_uses(&deref->dest.ssa, &replacement->dest.ssa);
+            nir_instr_remove(&deref->instr);
+         }
+      }
+   }
+
+   if (progress)
+      nir_metadata_preserve(impl, nir_metadata_block_index | nir_metadata_dominance);
+   else
+      nir_metadata_preserve(impl, nir_metadata_all);
+
+   return progress;
+}
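+
+/* After lower_rt_derefs, shader_call_data derefs are rebased onto the caller-provided argument
+ * offset (load_rt_arg_scratch_offset_amd) and ray_hit_attrib derefs onto the fixed
+ * RADV_HIT_ATTRIB_OFFSET slot. Both become function_temp accesses that nir_lower_explicit_io
+ * turns into plain scratch offsets, which lower_rt_instructions later rebases onto the
+ * per-invocation stack pointer. */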
+
+/*
+ * Global variables for an RT pipeline
+ */
+struct rt_variables {
+   const VkRayTracingPipelineCreateInfoKHR *create_info;
+
+   /* idx of the next shader to run in the next iteration of the main loop.
+    * During traversal, idx is used to store the SBT index and will contain
+    * the correct resume index upon returning.
+    */
+   nir_variable *idx;
+
+   /* scratch offset of the argument area relative to stack_ptr */
+   nir_variable *arg;
+
+   nir_variable *stack_ptr;
+
+   /* global address of the SBT entry used for the shader */
+   nir_variable *shader_record_ptr;
+
+   /* trace_ray arguments */
+   nir_variable *accel_struct;
+   nir_variable *flags;
+   nir_variable *cull_mask;
+   nir_variable *sbt_offset;
+   nir_variable *sbt_stride;
+   nir_variable *miss_index;
+   nir_variable *origin;
+   nir_variable *tmin;
+   nir_variable *direction;
+   nir_variable *tmax;
+
+   /* Properties of the primitive currently being visited. */
+   nir_variable *primitive_id;
+   nir_variable *geometry_id_and_flags;
+   nir_variable *instance_addr;
+   nir_variable *hit_kind;
+   nir_variable *opaque;
+
+   /* Output variables for intersection & anyhit shaders. */
+   nir_variable *ahit_accept;
+   nir_variable *ahit_terminate;
+
+   /* Array of stack size struct for recording the max stack size for each group. */
+   struct radv_pipeline_shader_stack_size *stack_sizes;
+   unsigned stage_idx;
+};
+
+static void
+reserve_stack_size(struct rt_variables *vars, uint32_t size)
+{
+   for (uint32_t group_idx = 0; group_idx < vars->create_info->groupCount; group_idx++) {
+      const VkRayTracingShaderGroupCreateInfoKHR *group = vars->create_info->pGroups + group_idx;
+
+      if (vars->stage_idx == group->generalShader || vars->stage_idx == group->closestHitShader)
+         vars->stack_sizes[group_idx].recursive_size =
+            MAX2(vars->stack_sizes[group_idx].recursive_size, size);
+
+      if (vars->stage_idx == group->anyHitShader || vars->stage_idx == group->intersectionShader)
+         vars->stack_sizes[group_idx].non_recursive_size =
+            MAX2(vars->stack_sizes[group_idx].non_recursive_size, size);
+   }
+}
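+
+/* In reserve_stack_size above, the recursive size accounts for stages that can themselves call
+ * traceRay or callable shaders (raygen, closest-hit, miss, callable via generalShader and
+ * closestHitShader), while the non-recursive size covers any-hit and intersection shaders that
+ * run in the middle of traversal on top of the caller's stack. */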
+
+static struct rt_variables
+create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *create_info,
+                    struct radv_pipeline_shader_stack_size *stack_sizes)
+{
+   struct rt_variables vars = {
+      .create_info = create_info,
+   };
+   vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx");
+   vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg");
+   vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr");
+   vars.shader_record_ptr =
+      nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_record_ptr");
+
+   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
+   vars.accel_struct =
+      nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "accel_struct");
+   vars.flags = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "ray_flags");
+   vars.cull_mask = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "cull_mask");
+   vars.sbt_offset =
+      nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_offset");
+   vars.sbt_stride =
+      nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_stride");
+   vars.miss_index =
+      nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "miss_index");
+   vars.origin = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_origin");
+   vars.tmin = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmin");
+   vars.direction = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_direction");
+   vars.tmax = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmax");
+
+   vars.primitive_id =
+      nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "primitive_id");
+   vars.geometry_id_and_flags =
+      nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "geometry_id_and_flags");
+   vars.instance_addr =
+      nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
+   vars.hit_kind = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "hit_kind");
+   vars.opaque = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "opaque");
+
+   vars.ahit_accept =
+      nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "ahit_accept");
+   vars.ahit_terminate =
+      nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "ahit_terminate");
+
+   vars.stack_sizes = stack_sizes;
+   return vars;
+}
+
+/*
+ * Remap all the variables between the two rt_variables structs for inlining.
+ */
+static void
+map_rt_variables(struct hash_table *var_remap, struct rt_variables *src,
+                 const struct rt_variables *dst)
+{
+   src->create_info = dst->create_info;
+
+   _mesa_hash_table_insert(var_remap, src->idx, dst->idx);
+   _mesa_hash_table_insert(var_remap, src->arg, dst->arg);
+   _mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr);
+   _mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr);
+
+   _mesa_hash_table_insert(var_remap, src->accel_struct, dst->accel_struct);
+   _mesa_hash_table_insert(var_remap, src->flags, dst->flags);
+   _mesa_hash_table_insert(var_remap, src->cull_mask, dst->cull_mask);
+   _mesa_hash_table_insert(var_remap, src->sbt_offset, dst->sbt_offset);
+   _mesa_hash_table_insert(var_remap, src->sbt_stride, dst->sbt_stride);
+   _mesa_hash_table_insert(var_remap, src->miss_index, dst->miss_index);
+   _mesa_hash_table_insert(var_remap, src->origin, dst->origin);
+   _mesa_hash_table_insert(var_remap, src->tmin, dst->tmin);
+   _mesa_hash_table_insert(var_remap, src->direction, dst->direction);
+   _mesa_hash_table_insert(var_remap, src->tmax, dst->tmax);
+
+   _mesa_hash_table_insert(var_remap, src->primitive_id, dst->primitive_id);
+   _mesa_hash_table_insert(var_remap, src->geometry_id_and_flags, dst->geometry_id_and_flags);
+   _mesa_hash_table_insert(var_remap, src->instance_addr, dst->instance_addr);
+   _mesa_hash_table_insert(var_remap, src->hit_kind, dst->hit_kind);
+   _mesa_hash_table_insert(var_remap, src->opaque, dst->opaque);
+   _mesa_hash_table_insert(var_remap, src->ahit_accept, dst->ahit_accept);
+   _mesa_hash_table_insert(var_remap, src->ahit_terminate, dst->ahit_terminate);
+
+   src->stack_sizes = dst->stack_sizes;
+   src->stage_idx = dst->stage_idx;
+}
+
+/*
+ * Create a copy of the global rt variables where the primitive/instance related variables are
+ * independent. This is needed because we have to keep the old values of the global variables
+ * around in case e.g. an any-hit shader rejects the intersection. So there are inner variables
+ * that get copied to the outer variables once we commit to a better hit.
+ */
+static struct rt_variables
+create_inner_vars(nir_builder *b, const struct rt_variables *vars)
+{
+   struct rt_variables inner_vars = *vars;
+   inner_vars.idx =
+      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_idx");
+   inner_vars.shader_record_ptr = nir_variable_create(
+      b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "inner_shader_record_ptr");
+   inner_vars.primitive_id =
+      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_primitive_id");
+   inner_vars.geometry_id_and_flags = nir_variable_create(
+      b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_geometry_id_and_flags");
+   inner_vars.tmax =
+      nir_variable_create(b->shader, nir_var_shader_temp, glsl_float_type(), "inner_tmax");
+   inner_vars.instance_addr = nir_variable_create(b->shader, nir_var_shader_temp,
+                                                  glsl_uint64_t_type(), "inner_instance_addr");
+   inner_vars.hit_kind =
+      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_hit_kind");
+
+   return inner_vars;
+}
+
+static void
+insert_rt_return(nir_builder *b, const struct rt_variables *vars)
+{
+   nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, nir_load_var(b, vars->stack_ptr), -16), 1);
+   nir_store_var(b, vars->idx,
+                 nir_load_scratch(b, 1, 32, nir_load_var(b, vars->stack_ptr), .align_mul = 16), 1);
+}
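+
+/* insert_rt_return pops the 16-byte slot directly below the stack pointer, which holds the
+ * resume index pushed by the rt_trace_ray/rt_execute_callable lowering below, and stores it in
+ * vars->idx so the main loop jumps back to the resume shader. */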
+
+enum sbt_type {
+   SBT_RAYGEN = offsetof(VkTraceRaysIndirectCommand2KHR, raygenShaderRecordAddress),
+   SBT_MISS = offsetof(VkTraceRaysIndirectCommand2KHR, missShaderBindingTableAddress),
+   SBT_HIT = offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
+   SBT_CALLABLE = offsetof(VkTraceRaysIndirectCommand2KHR, callableShaderBindingTableAddress),
+};
+
+static nir_ssa_def *
+get_sbt_ptr(nir_builder *b, nir_ssa_def *idx, enum sbt_type binding)
+{
+   nir_ssa_def *desc_base_addr = nir_load_sbt_base_amd(b);
+
+   nir_ssa_def *desc =
+      nir_pack_64_2x32(b, nir_build_load_smem_amd(b, 2, desc_base_addr, nir_imm_int(b, binding)));
+
+   nir_ssa_def *stride_offset = nir_imm_int(b, binding + (binding == SBT_RAYGEN ? 8 : 16));
+   nir_ssa_def *stride =
+      nir_pack_64_2x32(b, nir_build_load_smem_amd(b, 2, desc_base_addr, stride_offset));
+
+   return nir_iadd(b, desc, nir_imul(b, nir_u2u64(b, idx), stride));
+}
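+
+/* get_sbt_ptr: per the VkTraceRaysIndirectCommand2KHR layout, the table stride follows the table
+ * address at +16 bytes (e.g. missShaderBindingTableStride); the raygen table only has a record
+ * size at +8. The returned record pointer is simply tableAddress + idx * stride. */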
+
+static void
+load_sbt_entry(nir_builder *b, const struct rt_variables *vars, nir_ssa_def *idx,
+               enum sbt_type binding, unsigned offset)
+{
+   nir_ssa_def *addr = get_sbt_ptr(b, idx, binding);
+
+   nir_ssa_def *load_addr = nir_iadd_imm(b, addr, offset);
+   nir_ssa_def *v_idx = nir_build_load_global(b, 1, 32, load_addr);
+
+   nir_store_var(b, vars->idx, v_idx, 1);
+
+   nir_ssa_def *record_addr = nir_iadd_imm(b, addr, RADV_RT_HANDLE_SIZE);
+   nir_store_var(b, vars->shader_record_ptr, record_addr, 1);
+}
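+
+/* load_sbt_entry: the dword at `offset` within the group handle selects the shader to run inside
+ * the combined shader (dword 0 for raygen/miss/callable/closest-hit; dword 1 is used below for
+ * the any-hit/intersection shader of a hit group), and the application-visible shader record
+ * data starts RADV_RT_HANDLE_SIZE bytes into the record. */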
+
+/* This lowers all the RT instructions that we do not want to pass on to the combined shader and
+ * that we can implement using the variables from the shader we are going to inline into. */
+static void
+lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned call_idx_base)
+{
+   nir_builder b_shader;
+   nir_builder_init(&b_shader, nir_shader_get_entrypoint(shader));
+
+   nir_foreach_block (block, nir_shader_get_entrypoint(shader)) {
+      nir_foreach_instr_safe (instr, block) {
+         switch (instr->type) {
+         case nir_instr_type_intrinsic: {
+            b_shader.cursor = nir_before_instr(instr);
+            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
+            nir_ssa_def *ret = NULL;
+
+            switch (intr->intrinsic) {
+            case nir_intrinsic_rt_execute_callable: {
+               uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
+               uint32_t ret_idx = call_idx_base + nir_intrinsic_call_idx(intr) + 1;
+
+               nir_store_var(
+                  &b_shader, vars->stack_ptr,
+                  nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), size), 1);
+               nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret_idx),
+                                 nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16);
+
+               nir_store_var(&b_shader, vars->stack_ptr,
+                             nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), 16),
+                             1);
+               load_sbt_entry(&b_shader, vars, intr->src[0].ssa, SBT_CALLABLE, 0);
+
+               nir_store_var(&b_shader, vars->arg,
+                             nir_iadd_imm(&b_shader, intr->src[1].ssa, -size - 16), 1);
+
+               reserve_stack_size(vars, size + 16);
+               break;
+            }
+            case nir_intrinsic_rt_trace_ray: {
+               uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
+               uint32_t ret_idx = call_idx_base + nir_intrinsic_call_idx(intr) + 1;
+
+               nir_store_var(
+                  &b_shader, vars->stack_ptr,
+                  nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), size), 1);
+               nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret_idx),
+                                 nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16);
+
+               nir_store_var(&b_shader, vars->stack_ptr,
+                             nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), 16),
+                             1);
+
+               nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 1), 1);
+               nir_store_var(&b_shader, vars->arg,
+                             nir_iadd_imm(&b_shader, intr->src[10].ssa, -size - 16), 1);
+
+               reserve_stack_size(vars, size + 16);
+
+               /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. */
+               nir_store_var(&b_shader, vars->accel_struct, intr->src[0].ssa, 0x1);
+               nir_store_var(&b_shader, vars->flags, intr->src[1].ssa, 0x1);
+               nir_store_var(&b_shader, vars->cull_mask,
+                             nir_iand_imm(&b_shader, intr->src[2].ssa, 0xff), 0x1);
+               nir_store_var(&b_shader, vars->sbt_offset,
+                             nir_iand_imm(&b_shader, intr->src[3].ssa, 0xf), 0x1);
+               nir_store_var(&b_shader, vars->sbt_stride,
+                             nir_iand_imm(&b_shader, intr->src[4].ssa, 0xf), 0x1);
+               nir_store_var(&b_shader, vars->miss_index,
+                             nir_iand_imm(&b_shader, intr->src[5].ssa, 0xffff), 0x1);
+               nir_store_var(&b_shader, vars->origin, intr->src[6].ssa, 0x7);
+               nir_store_var(&b_shader, vars->tmin, intr->src[7].ssa, 0x1);
+               nir_store_var(&b_shader, vars->direction, intr->src[8].ssa, 0x7);
+               nir_store_var(&b_shader, vars->tmax, intr->src[9].ssa, 0x1);
+               break;
+            }
+            case nir_intrinsic_rt_resume: {
+               uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
+
+               nir_store_var(
+                  &b_shader, vars->stack_ptr,
+                  nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), -size), 1);
+               break;
+            }
+            case nir_intrinsic_rt_return_amd: {
+               if (shader->info.stage == MESA_SHADER_RAYGEN) {
+                  nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 0), 1);
+                  break;
+               }
+               insert_rt_return(&b_shader, vars);
+               break;
+            }
+            case nir_intrinsic_load_scratch: {
+               nir_instr_rewrite_src_ssa(
+                  instr, &intr->src[0],
+                  nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[0].ssa));
+               continue;
+            }
+            case nir_intrinsic_store_scratch: {
+               nir_instr_rewrite_src_ssa(
+                  instr, &intr->src[1],
+                  nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[1].ssa));
+               continue;
+            }
+            case nir_intrinsic_load_rt_arg_scratch_offset_amd: {
+               ret = nir_load_var(&b_shader, vars->arg);
+               break;
+            }
+            case nir_intrinsic_load_shader_record_ptr: {
+               ret = nir_load_var(&b_shader, vars->shader_record_ptr);
+               break;
+            }
+            case nir_intrinsic_load_ray_launch_id: {
+               ret = nir_load_global_invocation_id(&b_shader, 32);
+               break;
+            }
+            case nir_intrinsic_load_ray_launch_size: {
+               nir_ssa_def *launch_size_addr = nir_load_ray_launch_size_addr_amd(&b_shader);
+
+               nir_ssa_def *xy = nir_build_load_smem_amd(&b_shader, 2, launch_size_addr,
+                                                         nir_imm_int(&b_shader, 0));
+               nir_ssa_def *z = nir_build_load_smem_amd(&b_shader, 1, launch_size_addr,
+                                                        nir_imm_int(&b_shader, 8));
+
+               nir_ssa_def *xyz[3] = {
+                  nir_channel(&b_shader, xy, 0),
+                  nir_channel(&b_shader, xy, 1),
+                  z,
+               };
+               ret = nir_vec(&b_shader, xyz, 3);
+               break;
+            }
+            case nir_intrinsic_load_ray_t_min: {
+               ret = nir_load_var(&b_shader, vars->tmin);
+               break;
+            }
+            case nir_intrinsic_load_ray_t_max: {
+               ret = nir_load_var(&b_shader, vars->tmax);
+               break;
+            }
+            case nir_intrinsic_load_ray_world_origin: {
+               ret = nir_load_var(&b_shader, vars->origin);
+               break;
+            }
+            case nir_intrinsic_load_ray_world_direction: {
+               ret = nir_load_var(&b_shader, vars->direction);
+               break;
+            }
+            case nir_intrinsic_load_ray_instance_custom_index: {
+               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
+               nir_ssa_def *custom_instance_and_mask = nir_build_load_global(
+                  &b_shader, 1, 32,
+                  nir_iadd_imm(&b_shader, instance_node_addr,
+                               offsetof(struct radv_bvh_instance_node, custom_instance_and_mask)));
+               ret = nir_iand_imm(&b_shader, custom_instance_and_mask, 0xFFFFFF);
+               break;
+            }
+            case nir_intrinsic_load_primitive_id: {
+               ret = nir_load_var(&b_shader, vars->primitive_id);
+               break;
+            }
+            case nir_intrinsic_load_ray_geometry_index: {
+               ret = nir_load_var(&b_shader, vars->geometry_id_and_flags);
+               ret = nir_iand_imm(&b_shader, ret, 0xFFFFFFF);
+               break;
+            }
+            case nir_intrinsic_load_instance_id: {
+               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
+               ret = nir_build_load_global(
+                  &b_shader, 1, 32,
+                  nir_iadd_imm(&b_shader, instance_node_addr,
+                               offsetof(struct radv_bvh_instance_node, instance_id)));
+               break;
+            }
+            case nir_intrinsic_load_ray_flags: {
+               ret = nir_load_var(&b_shader, vars->flags);
+               break;
+            }
+            case nir_intrinsic_load_ray_hit_kind: {
+               ret = nir_load_var(&b_shader, vars->hit_kind);
+               break;
+            }
+            case nir_intrinsic_load_ray_world_to_object: {
+               unsigned c = nir_intrinsic_column(intr);
+               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
+               nir_ssa_def *wto_matrix[3];
+               nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
+
+               nir_ssa_def *vals[3];
+               for (unsigned i = 0; i < 3; ++i)
+                  vals[i] = nir_channel(&b_shader, wto_matrix[i], c);
+
+               ret = nir_vec(&b_shader, vals, 3);
+               break;
+            }
+            case nir_intrinsic_load_ray_object_to_world: {
+               unsigned c = nir_intrinsic_column(intr);
+               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
+               nir_ssa_def *rows[3];
+               for (unsigned r = 0; r < 3; ++r)
+                  rows[r] = nir_build_load_global(
+                     &b_shader, 4, 32,
+                     nir_iadd_imm(&b_shader, instance_node_addr,
+                                  offsetof(struct radv_bvh_instance_node, otw_matrix) + r * 16));
+               ret =
+                  nir_vec3(&b_shader, nir_channel(&b_shader, rows[0], c),
+                           nir_channel(&b_shader, rows[1], c), nir_channel(&b_shader, rows[2], c));
+               break;
+            }
+            case nir_intrinsic_load_ray_object_origin: {
+               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
+               nir_ssa_def *wto_matrix[3];
+               nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
+               ret = nir_build_vec3_mat_mult(&b_shader, nir_load_var(&b_shader, vars->origin),
+                                             wto_matrix, true);
+               break;
+            }
+            case nir_intrinsic_load_ray_object_direction: {
+               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
+               nir_ssa_def *wto_matrix[3];
+               nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
+               ret = nir_build_vec3_mat_mult(&b_shader, nir_load_var(&b_shader, vars->direction),
+                                             wto_matrix, false);
+               break;
+            }
+            case nir_intrinsic_load_intersection_opaque_amd: {
+               ret = nir_load_var(&b_shader, vars->opaque);
+               break;
+            }
+            case nir_intrinsic_load_cull_mask: {
+               ret = nir_load_var(&b_shader, vars->cull_mask);
+               break;
+            }
+            case nir_intrinsic_ignore_ray_intersection: {
+               nir_store_var(&b_shader, vars->ahit_accept, nir_imm_false(&b_shader), 0x1);
+
+               /* The if is a workaround to avoid having to fix up control flow manually */
+               nir_push_if(&b_shader, nir_imm_true(&b_shader));
+               nir_jump(&b_shader, nir_jump_return);
+               nir_pop_if(&b_shader, NULL);
+               break;
+            }
+            case nir_intrinsic_terminate_ray: {
+               nir_store_var(&b_shader, vars->ahit_accept, nir_imm_true(&b_shader), 0x1);
+               nir_store_var(&b_shader, vars->ahit_terminate, nir_imm_true(&b_shader), 0x1);
+
+               /* The if is a workaround to avoid having to fix up control flow manually */
+               nir_push_if(&b_shader, nir_imm_true(&b_shader));
+               nir_jump(&b_shader, nir_jump_return);
+               nir_pop_if(&b_shader, NULL);
+               break;
+            }
+            case nir_intrinsic_report_ray_intersection: {
+               nir_push_if(
+                  &b_shader,
+                  nir_iand(
+                     &b_shader,
+                     nir_fge(&b_shader, nir_load_var(&b_shader, vars->tmax), intr->src[0].ssa),
+                     nir_fge(&b_shader, intr->src[0].ssa, nir_load_var(&b_shader, vars->tmin))));
+               {
+                  nir_store_var(&b_shader, vars->ahit_accept, nir_imm_true(&b_shader), 0x1);
+                  nir_store_var(&b_shader, vars->tmax, intr->src[0].ssa, 1);
+                  nir_store_var(&b_shader, vars->hit_kind, intr->src[1].ssa, 1);
+               }
+               nir_pop_if(&b_shader, NULL);
+               break;
+            }
+            case nir_intrinsic_load_sbt_offset_amd: {
+               ret = nir_load_var(&b_shader, vars->sbt_offset);
+               break;
+            }
+            case nir_intrinsic_load_sbt_stride_amd: {
+               ret = nir_load_var(&b_shader, vars->sbt_stride);
+               break;
+            }
+            case nir_intrinsic_load_accel_struct_amd: {
+               ret = nir_load_var(&b_shader, vars->accel_struct);
+               break;
+            }
+            case nir_intrinsic_execute_closest_hit_amd: {
+               nir_store_var(&b_shader, vars->tmax, intr->src[1].ssa, 0x1);
+               nir_store_var(&b_shader, vars->primitive_id, intr->src[2].ssa, 0x1);
+               nir_store_var(&b_shader, vars->instance_addr, intr->src[3].ssa, 0x1);
+               nir_store_var(&b_shader, vars->geometry_id_and_flags, intr->src[4].ssa, 0x1);
+               nir_store_var(&b_shader, vars->hit_kind, intr->src[5].ssa, 0x1);
+               load_sbt_entry(&b_shader, vars, intr->src[0].ssa, SBT_HIT, 0);
+
+               nir_ssa_def *should_return =
+                  nir_ior(&b_shader,
+                          nir_test_mask(&b_shader, nir_load_var(&b_shader, vars->flags),
+                                        SpvRayFlagsSkipClosestHitShaderKHRMask),
+                          nir_ieq_imm(&b_shader, nir_load_var(&b_shader, vars->idx), 0));
+
+               /* should_return is set if we had a hit but we won't be calling the closest hit
+                * shader and hence need to return immediately to the calling shader. */
+               nir_push_if(&b_shader, should_return);
+               insert_rt_return(&b_shader, vars);
+               nir_pop_if(&b_shader, NULL);
+               break;
+            }
+            case nir_intrinsic_execute_miss_amd: {
+               nir_store_var(&b_shader, vars->tmax, intr->src[0].ssa, 0x1);
+               nir_ssa_def *undef = nir_ssa_undef(&b_shader, 1, 32);
+               nir_store_var(&b_shader, vars->primitive_id, undef, 0x1);
+               nir_store_var(&b_shader, vars->instance_addr, nir_ssa_undef(&b_shader, 1, 64), 0x1);
+               nir_store_var(&b_shader, vars->geometry_id_and_flags, undef, 0x1);
+               nir_store_var(&b_shader, vars->hit_kind, undef, 0x1);
+               nir_ssa_def *miss_index = nir_load_var(&b_shader, vars->miss_index);
+               load_sbt_entry(&b_shader, vars, miss_index, SBT_MISS, 0);
+               break;
+            }
+            default:
+               continue;
+            }
+
+            if (ret)
+               nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
+            nir_instr_remove(instr);
+            break;
+         }
+         case nir_instr_type_jump: {
+            nir_jump_instr *jump = nir_instr_as_jump(instr);
+            if (jump->type == nir_jump_halt) {
+               b_shader.cursor = nir_instr_remove(instr);
+               nir_jump(&b_shader, nir_jump_return);
+            }
+            break;
+         }
+         default:
+            break;
+         }
+      }
+   }
+
+   nir_metadata_preserve(nir_shader_get_entrypoint(shader), nir_metadata_none);
+}
+
+static void
+insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, nir_ssa_def *idx,
+               uint32_t call_idx_base, uint32_t call_idx)
+{
+   struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
+
+   nir_opt_dead_cf(shader);
+
+   struct rt_variables src_vars = create_rt_variables(shader, vars->create_info, vars->stack_sizes);
+   map_rt_variables(var_remap, &src_vars, vars);
+
+   NIR_PASS_V(shader, lower_rt_instructions, &src_vars, call_idx_base);
+
+   NIR_PASS(_, shader, nir_opt_remove_phis);
+   NIR_PASS(_, shader, nir_lower_returns);
+   NIR_PASS(_, shader, nir_opt_dce);
+
+   reserve_stack_size(vars, shader->scratch_size);
+
+   nir_push_if(b, nir_ieq_imm(b, idx, call_idx));
+   nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
+   nir_pop_if(b, NULL);
+
+   /* Adopt the instructions from the source shader, since they are merely moved, not cloned. */
+   ralloc_adopt(ralloc_context(b->shader), ralloc_context(shader));
+
+   ralloc_free(var_remap);
+}
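+
+/* Conceptually, once insert_rt_case has added all cases, the combined shader's main loop looks
+ * roughly like this (illustrative pseudocode, not literal NIR):
+ *
+ *    while (true) {
+ *       if (idx == 0) break;                // raygen returned
+ *       if (idx == 1) { ...traversal... }   // traceRay
+ *       if (idx == i + 2) { ...stage i... }
+ *       if (idx == resume_id) { ...resume point after a traceRay/callable call... }
+ *    }
+ */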
+
+static nir_shader *
+parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreateInfo *sinfo)
+{
+   struct radv_pipeline_key key;
+   memset(&key, 0, sizeof(key));
+
+   struct radv_pipeline_stage rt_stage;
+
+   radv_pipeline_stage_init(sinfo, &rt_stage, vk_to_mesa_shader_stage(sinfo->stage));
+
+   nir_shader *shader = radv_shader_spirv_to_nir(device, &rt_stage, &key);
+
+   if (shader->info.stage == MESA_SHADER_RAYGEN || shader->info.stage == MESA_SHADER_CLOSEST_HIT ||
+       shader->info.stage == MESA_SHADER_CALLABLE || shader->info.stage == MESA_SHADER_MISS) {
+      nir_block *last_block = nir_impl_last_block(nir_shader_get_entrypoint(shader));
+      nir_builder b_inner;
+      nir_builder_init(&b_inner, nir_shader_get_entrypoint(shader));
+      b_inner.cursor = nir_after_block(last_block);
+      nir_rt_return_amd(&b_inner);
+   }
+
+   NIR_PASS(_, shader, nir_lower_vars_to_explicit_types,
+            nir_var_function_temp | nir_var_shader_call_data | nir_var_ray_hit_attrib,
+            glsl_get_natural_size_align_bytes);
+
+   NIR_PASS(_, shader, lower_rt_derefs);
+
+   NIR_PASS(_, shader, nir_lower_explicit_io, nir_var_function_temp,
+            nir_address_format_32bit_offset);
+
+   return shader;
+}
+
+static nir_function_impl *
+lower_any_hit_for_intersection(nir_shader *any_hit)
+{
+   nir_function_impl *impl = nir_shader_get_entrypoint(any_hit);
+
+   /* Any-hit shaders need three parameters */
+   assert(impl->function->num_params == 0);
+   nir_parameter params[] = {
+      {
+         /* A pointer to a boolean value for whether or not the hit was
+          * accepted.
+          */
+         .num_components = 1,
+         .bit_size = 32,
+      },
+      {
+         /* The hit T value */
+         .num_components = 1,
+         .bit_size = 32,
+      },
+      {
+         /* The hit kind */
+         .num_components = 1,
+         .bit_size = 32,
+      },
+   };
+   impl->function->num_params = ARRAY_SIZE(params);
+   impl->function->params = ralloc_array(any_hit, nir_parameter, ARRAY_SIZE(params));
+   memcpy(impl->function->params, params, sizeof(params));
+
+   nir_builder build;
+   nir_builder_init(&build, impl);
+   nir_builder *b = &build;
+
+   b->cursor = nir_before_cf_list(&impl->body);
+
+   nir_ssa_def *commit_ptr = nir_load_param(b, 0);
+   nir_ssa_def *hit_t = nir_load_param(b, 1);
+   nir_ssa_def *hit_kind = nir_load_param(b, 2);
+
+   nir_deref_instr *commit =
+      nir_build_deref_cast(b, commit_ptr, nir_var_function_temp, glsl_bool_type(), 0);
+
+   nir_foreach_block_safe (block, impl) {
+      nir_foreach_instr_safe (instr, block) {
+         switch (instr->type) {
+         case nir_instr_type_intrinsic: {
+            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+            switch (intrin->intrinsic) {
+            case nir_intrinsic_ignore_ray_intersection:
+               b->cursor = nir_instr_remove(&intrin->instr);
+               /* We put the newly emitted code inside a dummy if because it's
+                * going to contain a jump instruction and we don't want to
+                * deal with that mess here.  It'll get dealt with by our
+                * control-flow optimization passes.
+                */
+               nir_store_deref(b, commit, nir_imm_false(b), 0x1);
+               nir_push_if(b, nir_imm_true(b));
+               nir_jump(b, nir_jump_return);
+               nir_pop_if(b, NULL);
+               break;
+
+            case nir_intrinsic_terminate_ray:
+               /* The "normal" handling of terminateRay works fine in
+                * intersection shaders.
+                */
+               break;
+
+            case nir_intrinsic_load_ray_t_max:
+               nir_ssa_def_rewrite_uses(&intrin->dest.ssa, hit_t);
+               nir_instr_remove(&intrin->instr);
+               break;
+
+            case nir_intrinsic_load_ray_hit_kind:
+               nir_ssa_def_rewrite_uses(&intrin->dest.ssa, hit_kind);
+               nir_instr_remove(&intrin->instr);
+               break;
+
+            default:
+               break;
+            }
+            break;
+         }
+         case nir_instr_type_jump: {
+            nir_jump_instr *jump = nir_instr_as_jump(instr);
+            if (jump->type == nir_jump_halt) {
+               b->cursor = nir_instr_remove(instr);
+               nir_jump(b, nir_jump_return);
+            }
+            break;
+         }
+
+         default:
+            break;
+         }
+      }
+   }
+
+   nir_validate_shader(any_hit, "after initial any-hit lowering");
+
+   nir_lower_returns_impl(impl);
+
+   nir_validate_shader(any_hit, "after lowering returns");
+
+   return impl;
+}
+
+/* Inline the any_hit shader into the intersection shader so we don't have
+ * to implement yet another shader call interface here, nor do any recursion.
+ */
+static void
+nir_lower_intersection_shader(nir_shader *intersection, nir_shader *any_hit)
+{
+   void *dead_ctx = ralloc_context(intersection);
+
+   nir_function_impl *any_hit_impl = NULL;
+   struct hash_table *any_hit_var_remap = NULL;
+   if (any_hit) {
+      any_hit = nir_shader_clone(dead_ctx, any_hit);
+      NIR_PASS(_, any_hit, nir_opt_dce);
+      any_hit_impl = lower_any_hit_for_intersection(any_hit);
+      any_hit_var_remap = _mesa_pointer_hash_table_create(dead_ctx);
+   }
+
+   nir_function_impl *impl = nir_shader_get_entrypoint(intersection);
+
+   nir_builder build;
+   nir_builder_init(&build, impl);
+   nir_builder *b = &build;
+
+   b->cursor = nir_before_cf_list(&impl->body);
+
+   nir_variable *commit = nir_local_variable_create(impl, glsl_bool_type(), "ray_commit");
+   nir_store_var(b, commit, nir_imm_false(b), 0x1);
+
+   nir_foreach_block_safe (block, impl) {
+      nir_foreach_instr_safe (instr, block) {
+         if (instr->type != nir_instr_type_intrinsic)
+            continue;
+
+         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+         if (intrin->intrinsic != nir_intrinsic_report_ray_intersection)
+            continue;
+
+         b->cursor = nir_instr_remove(&intrin->instr);
+         nir_ssa_def *hit_t = nir_ssa_for_src(b, intrin->src[0], 1);
+         nir_ssa_def *hit_kind = nir_ssa_for_src(b, intrin->src[1], 1);
+         nir_ssa_def *min_t = nir_load_ray_t_min(b);
+         nir_ssa_def *max_t = nir_load_ray_t_max(b);
+
+         /* bool commit_tmp = false; */
+         nir_variable *commit_tmp = nir_local_variable_create(impl, glsl_bool_type(), "commit_tmp");
+         nir_store_var(b, commit_tmp, nir_imm_false(b), 0x1);
+
+         nir_push_if(b, nir_iand(b, nir_fge(b, hit_t, min_t), nir_fge(b, max_t, hit_t)));
+         {
+            /* Any-hit defaults to commit */
+            nir_store_var(b, commit_tmp, nir_imm_true(b), 0x1);
+
+            if (any_hit_impl != NULL) {
+               nir_push_if(b, nir_inot(b, nir_load_intersection_opaque_amd(b)));
+               {
+                  nir_ssa_def *params[] = {
+                     &nir_build_deref_var(b, commit_tmp)->dest.ssa,
+                     hit_t,
+                     hit_kind,
+                  };
+                  nir_inline_function_impl(b, any_hit_impl, params, any_hit_var_remap);
+               }
+               nir_pop_if(b, NULL);
+            }
+
+            nir_push_if(b, nir_load_var(b, commit_tmp));
+            {
+               nir_report_ray_intersection(b, 1, hit_t, hit_kind);
+            }
+            nir_pop_if(b, NULL);
+         }
+         nir_pop_if(b, NULL);
+
+         nir_ssa_def *accepted = nir_load_var(b, commit_tmp);
+         nir_ssa_def_rewrite_uses(&intrin->dest.ssa, accepted);
+      }
+   }
+
+   /* We did some inlining; have to re-index SSA defs */
+   nir_index_ssa_defs(impl);
+
+   /* Eliminate the casts introduced for the commit return of the any-hit shader. */
+   NIR_PASS(_, intersection, nir_opt_deref);
+
+   ralloc_free(dead_ctx);
+}
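+
+/* Sketch of the lowering done above: each reportIntersectionEXT(t, kind) call ends up roughly as
+ *
+ *    bool commit_tmp = false;
+ *    if (t >= tmin && t <= tmax) {
+ *       commit_tmp = true;
+ *       if (!opaque)
+ *          any_hit(&commit_tmp, t, kind);   // inlined any-hit body
+ *       if (commit_tmp)
+ *          report_ray_intersection(t, kind);
+ *    }
+ *    accepted = commit_tmp;
+ */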
+
+/* Variables used only internally by ray traversal. This is data that describes
+ * the current state of the traversal vs. what we'd give to a shader, e.g. the
+ * instance we're currently visiting vs. the instance of the closest hit. */
+struct rt_traversal_vars {
+   nir_variable *origin;
+   nir_variable *dir;
+   nir_variable *inv_dir;
+   nir_variable *sbt_offset_and_flags;
+   nir_variable *instance_addr;
+   nir_variable *hit;
+   nir_variable *bvh_base;
+   nir_variable *stack;
+   nir_variable *top_stack;
+   nir_variable *stack_base;
+   nir_variable *current_node;
+   nir_variable *previous_node;
+   nir_variable *instance_top_node;
+   nir_variable *instance_bottom_node;
+};
+
+static struct rt_traversal_vars
+init_traversal_vars(nir_builder *b)
+{
+   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
+   struct rt_traversal_vars ret;
+
+   ret.origin = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_origin");
+   ret.dir = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_dir");
+   ret.inv_dir =
+      nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_inv_dir");
+   ret.sbt_offset_and_flags = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
+                                                  "traversal_sbt_offset_and_flags");
+   ret.instance_addr =
+      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
+   ret.hit = nir_variable_create(b->shader, nir_var_shader_temp, glsl_bool_type(), "traversal_hit");
+   ret.bvh_base = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(),
+                                      "traversal_bvh_base");
+   ret.stack =
+      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_stack_ptr");
+   ret.top_stack = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
+                                       "traversal_top_stack_ptr");
+   ret.stack_base =
+      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_stack_base");
+   ret.current_node =
+      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "current_node");
+   ret.previous_node =
+      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "previous_node");
+   ret.instance_top_node =
+      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "instance_top_node");
+   ret.instance_bottom_node =
+      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "instance_bottom_node");
+   return ret;
+}
+
+static void
+visit_any_hit_shaders(struct radv_device *device,
+                      const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
+                      struct rt_variables *vars)
+{
+   nir_ssa_def *sbt_idx = nir_load_var(b, vars->idx);
+
+   nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
+   for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
+      const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
+      uint32_t shader_id = VK_SHADER_UNUSED_KHR;
+
+      switch (group_info->type) {
+      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
+         shader_id = group_info->anyHitShader;
+         break;
+      default:
+         break;
+      }
+      if (shader_id == VK_SHADER_UNUSED_KHR)
+         continue;
+
+      const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
+      nir_shader *nir_stage = parse_rt_stage(device, stage);
+
+      vars->stage_idx = shader_id;
+      insert_rt_case(b, nir_stage, vars, sbt_idx, 0, i + 2);
+   }
+   nir_pop_if(b, NULL);
+}
+
+struct traversal_data {
+   struct radv_device *device;
+   const VkRayTracingPipelineCreateInfoKHR *createInfo;
+   struct rt_variables *vars;
+   struct rt_traversal_vars *trav_vars;
+};
+
+static void
+handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *intersection,
+                          const struct radv_ray_traversal_args *args)
+{
+   struct traversal_data *data = args->data;
+
+   nir_ssa_def *geometry_id = nir_iand_imm(b, intersection->base.geometry_id_and_flags, 0xfffffff);
+   nir_ssa_def *sbt_idx = nir_iadd(
+      b,
+      nir_iadd(b, nir_load_var(b, data->vars->sbt_offset),
+               nir_iand_imm(b, nir_load_var(b, data->trav_vars->sbt_offset_and_flags), 0xffffff)),
+      nir_imul(b, nir_load_var(b, data->vars->sbt_stride), geometry_id));
+
+   nir_ssa_def *hit_kind =
+      nir_bcsel(b, intersection->frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF));
+
+   nir_ssa_def *barycentrics_addr =
+      nir_iadd_imm(b, nir_load_var(b, data->vars->stack_ptr), RADV_HIT_ATTRIB_OFFSET);
+   nir_ssa_def *prev_barycentrics = nir_load_scratch(b, 2, 32, barycentrics_addr, .align_mul = 16);
+   nir_store_scratch(b, intersection->barycentrics, barycentrics_addr, .align_mul = 16);
+
+   nir_store_var(b, data->vars->ahit_accept, nir_imm_true(b), 0x1);
+   nir_store_var(b, data->vars->ahit_terminate, nir_imm_false(b), 0x1);
+
+   nir_push_if(b, nir_inot(b, intersection->base.opaque));
+   {
+      struct rt_variables inner_vars = create_inner_vars(b, data->vars);
+
+      nir_store_var(b, inner_vars.primitive_id, intersection->base.primitive_id, 1);
+      nir_store_var(b, inner_vars.geometry_id_and_flags, intersection->base.geometry_id_and_flags,
+                    1);
+      nir_store_var(b, inner_vars.tmax, intersection->t, 0x1);
+      nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, data->trav_vars->instance_addr),
+                    0x1);
+      nir_store_var(b, inner_vars.hit_kind, hit_kind, 0x1);
+
+      load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4);
+
+      visit_any_hit_shaders(data->device, data->createInfo, b, &inner_vars);
+
+      nir_push_if(b, nir_inot(b, nir_load_var(b, data->vars->ahit_accept)));
+      {
+         nir_store_scratch(b, prev_barycentrics, barycentrics_addr, .align_mul = 16);
+         nir_jump(b, nir_jump_continue);
+      }
+      nir_pop_if(b, NULL);
+   }
+   nir_pop_if(b, NULL);
+
+   nir_store_var(b, data->vars->primitive_id, intersection->base.primitive_id, 1);
+   nir_store_var(b, data->vars->geometry_id_and_flags, intersection->base.geometry_id_and_flags, 1);
+   nir_store_var(b, data->vars->tmax, intersection->t, 0x1);
+   nir_store_var(b, data->vars->instance_addr, nir_load_var(b, data->trav_vars->instance_addr),
+                 0x1);
+   nir_store_var(b, data->vars->hit_kind, hit_kind, 0x1);
+
+   nir_store_var(b, data->vars->idx, sbt_idx, 1);
+   nir_store_var(b, data->trav_vars->hit, nir_imm_true(b), 1);
+
+   nir_ssa_def *terminate_on_first_hit =
+      nir_test_mask(b, args->flags, SpvRayFlagsTerminateOnFirstHitKHRMask);
+   nir_ssa_def *ray_terminated = nir_load_var(b, data->vars->ahit_terminate);
+   nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
+   {
+      nir_jump(b, nir_jump_break);
+   }
+   nir_pop_if(b, NULL);
+}
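+
+/* Note that handle_candidate_triangle saves and restores the committed barycentrics around the
+ * any-hit call above, so a candidate that the any-hit shader rejects does not clobber the hit
+ * attributes of the currently committed hit. */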
+
+static void
+handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersection,
+                      const struct radv_ray_traversal_args *args)
+{
+   struct traversal_data *data = args->data;
+
+   nir_ssa_def *geometry_id = nir_iand_imm(b, intersection->geometry_id_and_flags, 0xfffffff);
+   nir_ssa_def *sbt_idx = nir_iadd(
+      b,
+      nir_iadd(b, nir_load_var(b, data->vars->sbt_offset),
+               nir_iand_imm(b, nir_load_var(b, data->trav_vars->sbt_offset_and_flags), 0xffffff)),
+      nir_imul(b, nir_load_var(b, data->vars->sbt_stride), geometry_id));
+
+   struct rt_variables inner_vars = create_inner_vars(b, data->vars);
+
+   /* For AABBs the intersection shader writes the hit kind, and only does so if it is the
+    * next closest hit candidate. */
+   inner_vars.hit_kind = data->vars->hit_kind;
+
+   nir_store_var(b, inner_vars.primitive_id, intersection->primitive_id, 1);
+   nir_store_var(b, inner_vars.geometry_id_and_flags, intersection->geometry_id_and_flags, 1);
+   nir_store_var(b, inner_vars.tmax, nir_load_var(b, data->vars->tmax), 0x1);
+   nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, data->trav_vars->instance_addr), 0x1);
+   nir_store_var(b, inner_vars.opaque, intersection->opaque, 1);
+
+   load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4);
+
+   nir_store_var(b, data->vars->ahit_accept, nir_imm_false(b), 0x1);
+   nir_store_var(b, data->vars->ahit_terminate, nir_imm_false(b), 0x1);
+
+   nir_push_if(b, nir_ine_imm(b, nir_load_var(b, inner_vars.idx), 0));
+   for (unsigned i = 0; i < data->createInfo->groupCount; ++i) {
+      const VkRayTracingShaderGroupCreateInfoKHR *group_info = &data->createInfo->pGroups[i];
+      uint32_t shader_id = VK_SHADER_UNUSED_KHR;
+      uint32_t any_hit_shader_id = VK_SHADER_UNUSED_KHR;
+
+      switch (group_info->type) {
+      case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
+         shader_id = group_info->intersectionShader;
+         any_hit_shader_id = group_info->anyHitShader;
+         break;
+      default:
+         break;
+      }
+      if (shader_id == VK_SHADER_UNUSED_KHR)
+         continue;
+
+      const VkPipelineShaderStageCreateInfo *stage = &data->createInfo->pStages[shader_id];
+      nir_shader *nir_stage = parse_rt_stage(data->device, stage);
+
+      nir_shader *any_hit_stage = NULL;
+      if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) {
+         stage = &data->createInfo->pStages[any_hit_shader_id];
+         any_hit_stage = parse_rt_stage(data->device, stage);
+
+         nir_lower_intersection_shader(nir_stage, any_hit_stage);
+         ralloc_free(any_hit_stage);
+      }
+
+      inner_vars.stage_idx = shader_id;
+      insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0, i + 2);
+   }
+   nir_pop_if(b, NULL);
+
+   nir_push_if(b, nir_load_var(b, data->vars->ahit_accept));
+   {
+      nir_store_var(b, data->vars->primitive_id, intersection->primitive_id, 1);
+      nir_store_var(b, data->vars->geometry_id_and_flags, intersection->geometry_id_and_flags, 1);
+      nir_store_var(b, data->vars->tmax, nir_load_var(b, inner_vars.tmax), 0x1);
+      nir_store_var(b, data->vars->instance_addr, nir_load_var(b, data->trav_vars->instance_addr),
+                    0x1);
+
+      nir_store_var(b, data->vars->idx, sbt_idx, 1);
+      nir_store_var(b, data->trav_vars->hit, nir_imm_true(b), 1);
+
+      nir_ssa_def *terminate_on_first_hit =
+         nir_test_mask(b, args->flags, SpvRayFlagsTerminateOnFirstHitKHRMask);
+      nir_ssa_def *ray_terminated = nir_load_var(b, data->vars->ahit_terminate);
+      nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
+      {
+         nir_jump(b, nir_jump_break);
+      }
+      nir_pop_if(b, NULL);
+   }
+   nir_pop_if(b, NULL);
+}
+
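+/* Callbacks for the short traversal stack, which lives in LDS and is addressed in bytes. */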
+static void
+store_stack_entry(nir_builder *b, nir_ssa_def *index, nir_ssa_def *value,
+                  const struct radv_ray_traversal_args *args)
+{
+   nir_store_shared(b, value, index, .base = 0, .align_mul = 4);
+}
+
+static nir_ssa_def *
+load_stack_entry(nir_builder *b, nir_ssa_def *index, const struct radv_ray_traversal_args *args)
+{
+   return nir_load_shared(b, 1, 32, index, .base = 0, .align_mul = 4);
+}
+
+static nir_shader *
+build_traversal_shader(struct radv_device *device,
+                       const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
+                       struct radv_pipeline_shader_stack_size *stack_sizes)
+{
+   nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_traversal");
+   b.shader->info.internal = false;
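+   /* One wave per workgroup: 8x8 invocations with wave64, 8x4 with wave32. */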
+   b.shader->info.workgroup_size[0] = 8;
+   b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
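+   /* Reserve LDS for each lane's short traversal stack. */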
+   b.shader->info.shared_size =
+      device->physical_device->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
+   struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes);
+
+   /* Initialize the trace_ray arguments. */
+   nir_ssa_def *accel_struct = nir_load_accel_struct_amd(&b);
+   nir_store_var(&b, vars.flags, nir_load_ray_flags(&b), 0x1);
+   nir_store_var(&b, vars.cull_mask, nir_load_cull_mask(&b), 0x1);
+   nir_store_var(&b, vars.sbt_offset, nir_load_sbt_offset_amd(&b), 0x1);
+   nir_store_var(&b, vars.sbt_stride, nir_load_sbt_stride_amd(&b), 0x1);
+   nir_store_var(&b, vars.origin, nir_load_ray_world_origin(&b), 0x7);
+   nir_store_var(&b, vars.tmin, nir_load_ray_t_min(&b), 0x1);
+   nir_store_var(&b, vars.direction, nir_load_ray_world_direction(&b), 0x7);
+   nir_store_var(&b, vars.tmax, nir_load_ray_t_max(&b), 0x1);
+   nir_store_var(&b, vars.arg, nir_load_rt_arg_scratch_offset_amd(&b), 0x1);
+   nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
+
+   struct rt_traversal_vars trav_vars = init_traversal_vars(&b);
+
+   nir_store_var(&b, trav_vars.hit, nir_imm_false(&b), 1);
+
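+   /* A NULL acceleration structure means nothing can be hit; skip traversal so the miss
+    * shader runs below. */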
+   nir_push_if(&b, nir_ine_imm(&b, accel_struct, 0));
+   {
+      nir_ssa_def *bvh_offset = nir_build_load_global(
+         &b, 1, 32,
+         nir_iadd_imm(&b, accel_struct, offsetof(struct radv_accel_struct_header, bvh_offset)),
+         .access = ACCESS_NON_WRITEABLE);
+      nir_ssa_def *root_bvh_base = nir_iadd(&b, accel_struct, nir_u2u64(&b, bvh_offset));
+      root_bvh_base = build_addr_to_node(&b, root_bvh_base);
+
+      nir_store_var(&b, trav_vars.bvh_base, root_bvh_base, 1);
+
+      nir_ssa_def *vec3ones = nir_channels(&b, nir_imm_vec4(&b, 1.0, 1.0, 1.0, 1.0), 0x7);
+
+      nir_store_var(&b, trav_vars.origin, nir_load_var(&b, vars.origin), 7);
+      nir_store_var(&b, trav_vars.dir, nir_load_var(&b, vars.direction), 7);
+      nir_store_var(&b, trav_vars.inv_dir, nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)),
+                    7);
+      nir_store_var(&b, trav_vars.sbt_offset_and_flags, nir_imm_int(&b, 0), 1);
+      nir_store_var(&b, trav_vars.instance_addr, nir_imm_int64(&b, 0), 1);
+
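+      /* The LDS stack is interleaved across lanes: lane i starts at byte offset i * 4 and
+       * consecutive entries are stack_stride (wave_size * 4) bytes apart. */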
+      nir_store_var(&b, trav_vars.stack,
+                    nir_imul_imm(&b, nir_load_local_invocation_index(&b), sizeof(uint32_t)), 1);
+      nir_store_var(&b, trav_vars.stack_base, nir_load_var(&b, trav_vars.stack), 1);
+      nir_store_var(&b, trav_vars.current_node, nir_imm_int(&b, RADV_BVH_ROOT_NODE), 0x1);
+      nir_store_var(&b, trav_vars.previous_node, nir_imm_int(&b, RADV_BVH_INVALID_NODE), 0x1);
+      nir_store_var(&b, trav_vars.instance_top_node, nir_imm_int(&b, RADV_BVH_INVALID_NODE), 0x1);
+      nir_store_var(&b, trav_vars.instance_bottom_node, nir_imm_int(&b, RADV_BVH_NO_INSTANCE_ROOT),
+                    0x1);
+
+      nir_store_var(&b, trav_vars.top_stack, nir_imm_int(&b, -1), 1);
+
+      struct radv_ray_traversal_vars trav_vars_args = {
+         .tmax = nir_build_deref_var(&b, vars.tmax),
+         .origin = nir_build_deref_var(&b, trav_vars.origin),
+         .dir = nir_build_deref_var(&b, trav_vars.dir),
+         .inv_dir = nir_build_deref_var(&b, trav_vars.inv_dir),
+         .bvh_base = nir_build_deref_var(&b, trav_vars.bvh_base),
+         .stack = nir_build_deref_var(&b, trav_vars.stack),
+         .top_stack = nir_build_deref_var(&b, trav_vars.top_stack),
+         .stack_base = nir_build_deref_var(&b, trav_vars.stack_base),
+         .current_node = nir_build_deref_var(&b, trav_vars.current_node),
+         .previous_node = nir_build_deref_var(&b, trav_vars.previous_node),
+         .instance_top_node = nir_build_deref_var(&b, trav_vars.instance_top_node),
+         .instance_bottom_node = nir_build_deref_var(&b, trav_vars.instance_bottom_node),
+         .instance_addr = nir_build_deref_var(&b, trav_vars.instance_addr),
+         .sbt_offset_and_flags = nir_build_deref_var(&b, trav_vars.sbt_offset_and_flags),
+      };
+
+      struct traversal_data data = {
+         .device = device,
+         .createInfo = pCreateInfo,
+         .vars = &vars,
+         .trav_vars = &trav_vars,
+      };
+
+      struct radv_ray_traversal_args args = {
+         .root_bvh_base = root_bvh_base,
+         .flags = nir_load_var(&b, vars.flags),
+         .cull_mask = nir_load_var(&b, vars.cull_mask),
+         .origin = nir_load_var(&b, vars.origin),
+         .tmin = nir_load_var(&b, vars.tmin),
+         .dir = nir_load_var(&b, vars.direction),
+         .vars = trav_vars_args,
+         .stack_stride = device->physical_device->rt_wave_size * sizeof(uint32_t),
+         .stack_entries = MAX_STACK_ENTRY_COUNT,
+         .stack_store_cb = store_stack_entry,
+         .stack_load_cb = load_stack_entry,
+         .aabb_cb = (pCreateInfo->flags & VK_PIPELINE_CREATE_RAY_TRACING_SKIP_AABBS_BIT_KHR)
+                       ? NULL
+                       : handle_candidate_aabb,
+         .triangle_cb = (pCreateInfo->flags & VK_PIPELINE_CREATE_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR)
+                           ? NULL
+                           : handle_candidate_triangle,
+         .data = &data,
+      };
+
+      radv_build_ray_traversal(device, &b, &args);
+   }
+   nir_pop_if(&b, NULL);
+
+   /* Hand off to the follow-up shader: closest-hit on a hit, miss otherwise. */
+   nir_push_if(&b, nir_load_var(&b, trav_vars.hit));
+   {
+      nir_execute_closest_hit_amd(
+         &b, nir_load_var(&b, vars.idx), nir_load_var(&b, vars.tmax),
+         nir_load_var(&b, vars.primitive_id), nir_load_var(&b, vars.instance_addr),
+         nir_load_var(&b, vars.geometry_id_and_flags), nir_load_var(&b, vars.hit_kind));
+   }
+   nir_push_else(&b, NULL);
+   {
+      /* Only load the miss shader if we actually miss. It is valid to not specify an SBT pointer
+       * for miss shaders if none of the rays miss. */
+      nir_execute_miss_amd(&b, nir_load_var(&b, vars.tmax));
+   }
+   nir_pop_if(&b, NULL);
+
+   /* Deal with all the inline functions. */
+   nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));
+   nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);
+
+   /* Lower and clean up variables. */
+   NIR_PASS_V(b.shader, nir_lower_global_vars_to_local);
+   NIR_PASS_V(b.shader, nir_lower_vars_to_ssa);
+
+   return b.shader;
+}
+
+static unsigned
+compute_rt_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
+                      const struct radv_pipeline_shader_stack_size *stack_sizes)
+{
+   unsigned raygen_size = 0;
+   unsigned callable_size = 0;
+   unsigned chit_size = 0;
+   unsigned miss_size = 0;
+   unsigned non_recursive_size = 0;
+
+   for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
+      non_recursive_size = MAX2(stack_sizes[i].non_recursive_size, non_recursive_size);
+
+      const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
+      uint32_t shader_id = VK_SHADER_UNUSED_KHR;
+      unsigned size = stack_sizes[i].recursive_size;
+
+      switch (group_info->type) {
+      case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
+         shader_id = group_info->generalShader;
+         break;
+      case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
+      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
+         shader_id = group_info->closestHitShader;
+         break;
+      default:
+         break;
+      }
+      if (shader_id == VK_SHADER_UNUSED_KHR)
+         continue;
+
+      const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
+      switch (stage->stage) {
+      case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
+         raygen_size = MAX2(raygen_size, size);
+         break;
+      case VK_SHADER_STAGE_MISS_BIT_KHR:
+         miss_size = MAX2(miss_size, size);
+         break;
+      case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
+         chit_size = MAX2(chit_size, size);
+         break;
+      case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
+         callable_size = MAX2(callable_size, size);
+         break;
+      default:
+         unreachable("Invalid stage type in RT shader");
+      }
+   }
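+   /* Worst case: one raygen frame, plus one recursion level that may also need the
+    * non-recursive (any-hit/intersection) scratch, plus the remaining recursion levels of
+    * closest-hit/miss frames, plus two callable frames. */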
+   return raygen_size +
+          MIN2(pCreateInfo->maxPipelineRayRecursionDepth, 1) *
+             MAX2(MAX2(chit_size, miss_size), non_recursive_size) +
+          MAX2(0, (int)(pCreateInfo->maxPipelineRayRecursionDepth) - 1) *
+             MAX2(chit_size, miss_size) +
+          2 * callable_size;
+}
+
+static bool
+should_move_rt_instruction(nir_intrinsic_op intrinsic)
+{
+   switch (intrinsic) {
+   case nir_intrinsic_load_rt_arg_scratch_offset_amd:
+   case nir_intrinsic_load_ray_flags:
+   case nir_intrinsic_load_ray_object_origin:
+   case nir_intrinsic_load_ray_world_origin:
+   case nir_intrinsic_load_ray_t_min:
+   case nir_intrinsic_load_ray_object_direction:
+   case nir_intrinsic_load_ray_world_direction:
+   case nir_intrinsic_load_ray_t_max:
+      return true;
+   default:
+      return false;
+   }
+}
+
+static void
+move_rt_instructions(nir_shader *shader)
+{
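+   /* Place the hoisted loads at the very start of the entry point so they read the incoming
+    * system values before any nested rt_trace_ray can overwrite them. */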
+   nir_cursor target = nir_before_cf_list(&nir_shader_get_entrypoint(shader)->body);
+
+   nir_foreach_block (block, nir_shader_get_entrypoint(shader)) {
+      nir_foreach_instr_safe (instr, block) {
+         if (instr->type != nir_instr_type_intrinsic)
+            continue;
+
+         nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
+
+         if (!should_move_rt_instruction(intrinsic->intrinsic))
+            continue;
+
+         nir_instr_move(target, instr);
+      }
+   }
+
+   nir_metadata_preserve(nir_shader_get_entrypoint(shader),
+                         nir_metadata_all & (~nir_metadata_instr_index));
+}
+
+nir_shader *
+create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
+                 struct radv_pipeline_shader_stack_size *stack_sizes)
+{
+   struct radv_pipeline_key key;
+   memset(&key, 0, sizeof(key));
+
+   nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_combined");
+   b.shader->info.internal = false;
+   b.shader->info.workgroup_size[0] = 8;
+   b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
+
+   struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes);
+   load_sbt_entry(&b, &vars, nir_imm_int(&b, 0), SBT_RAYGEN, 0);
+   if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo))
+      nir_store_var(&b, vars.stack_ptr, nir_load_rt_dynamic_callable_stack_base_amd(&b), 0x1);
+   else
+      nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
+
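+   /* Main loop: keep dispatching the shader selected by vars.idx until it reaches 0. */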
+   nir_loop *loop = nir_push_loop(&b);
+
+   nir_push_if(&b, nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 0));
+   nir_jump(&b, nir_jump_break);
+   nir_pop_if(&b, NULL);
+
+   nir_ssa_def *idx = nir_load_var(&b, vars.idx);
+
+   /* Insert traversal shader */
+   nir_shader *traversal = build_traversal_shader(device, pCreateInfo, stack_sizes);
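+   /* The traversal shader is the only LDS user; adopt its shared size for the combined shader. */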
+   assert(b.shader->info.shared_size == 0);
+   b.shader->info.shared_size = traversal->info.shared_size;
+   assert(b.shader->info.shared_size <= 32768);
+   insert_rt_case(&b, traversal, &vars, idx, 0, 1);
+
+   /* We use a trick with the indexing of the resume shaders so that the first
+    * shader of stage x always gets id x and the resume shader ids then come after
+    * stageCount. This makes the shader group handles independent of compilation. */
+   unsigned call_idx_base = pCreateInfo->stageCount + 1;
+   for (unsigned i = 0; i < pCreateInfo->stageCount; ++i) {
+      const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[i];
+      gl_shader_stage type = vk_to_mesa_shader_stage(stage->stage);
+      if (type != MESA_SHADER_RAYGEN && type != MESA_SHADER_CALLABLE &&
+          type != MESA_SHADER_CLOSEST_HIT && type != MESA_SHADER_MISS)
+         continue;
+
+      nir_shader *nir_stage = parse_rt_stage(device, stage);
+
+      /* Move the ray tracing system values that are set by rt_trace_ray to the top,
+       * so they are not overwritten by subsequent rt_trace_ray calls.
+       */
+      NIR_PASS_V(nir_stage, move_rt_instructions);
+
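+      /* Split the stage at each shader call so that execution can resume in a separate
+       * shader once the call returns. */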
+      const nir_lower_shader_calls_options opts = {
+         .address_format = nir_address_format_32bit_offset,
+         .stack_alignment = 16,
+         .localized_loads = true};
+      uint32_t num_resume_shaders = 0;
+      nir_shader **resume_shaders = NULL;
+      nir_lower_shader_calls(nir_stage, &opts, &resume_shaders, &num_resume_shaders, nir_stage);
+
+      vars.stage_idx = i;
+      insert_rt_case(&b, nir_stage, &vars, idx, call_idx_base, i + 2);
+      for (unsigned j = 0; j < num_resume_shaders; ++j) {
+         insert_rt_case(&b, resume_shaders[j], &vars, idx, call_idx_base, call_idx_base + 1 + j);
+      }
+      call_idx_base += num_resume_shaders;
+   }
+
+   nir_pop_loop(&b, loop);
+
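+   /* With a dynamic stack size the stack pointer was already initialized from
+    * nir_load_rt_dynamic_callable_stack_base_amd above; otherwise reserve the statically
+    * computed worst case in scratch. */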
+   if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo))
+      b.shader->scratch_size = 16; /* To enable scratch. */
+   else
+      b.shader->scratch_size += compute_rt_stack_size(pCreateInfo, stack_sizes);
+
+   /* Deal with all the inline functions. */
+   nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));
+   nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);
+
+   return b.shader;
+}
index ae27d0d..0efe7e3 100644 (file)
@@ -763,4 +763,8 @@ bool radv_force_primitive_shading_rate(nir_shader *nir, struct radv_device *devi
 bool radv_lower_fs_intrinsics(nir_shader *nir, const struct radv_pipeline_stage *fs_stage,
                               const struct radv_pipeline_key *key);
 
+nir_shader *create_rt_shader(struct radv_device *device,
+                             const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
+                             struct radv_pipeline_shader_stack_size *stack_sizes);
+
 #endif