radv/rt: Use nir_shader_instructions_pass for lower_rt_instructions
author Konstantin Seurer <konstantin.seurer@gmail.com>
Tue, 12 Sep 2023 18:43:34 +0000 (20:43 +0200)
committer Marge Bot <emma+marge@anholt.net>
Wed, 18 Oct 2023 08:18:50 +0000 (08:18 +0000)
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25187>

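For context, nir_shader_instructions_pass() wraps the per-block, per-instruction iteration that the old lower_rt_instructions() did by hand: it walks the shader's instructions with a shared nir_builder, invokes the callback once per instruction, and decides how much metadata to preserve based on whether any callback reported progress. Below is a minimal sketch of that contract, reusing the halt-to-return rewrite from this commit as the callback body; the surrounding function names (halt_to_return, lower_halt_to_return) are illustrative only and not part of the change.

   /* Sketch of the nir_shader_instructions_pass pattern adopted by this commit.
    * The callback returns true only when it changed something, so metadata is
    * invalidated only when progress was actually made. */
   static bool
   halt_to_return(nir_builder *b, nir_instr *instr, void *_data)
   {
      if (instr->type != nir_instr_type_jump)
         return false;

      nir_jump_instr *jump = nir_instr_as_jump(instr);
      if (jump->type != nir_jump_halt)
         return false;

      /* In-place rewrite: nothing new is emitted, so the builder cursor does
       * not need to be positioned here. Callbacks that build replacement code
       * set b->cursor themselves (see nir_before_instr() uses below). */
      jump->type = nir_jump_return;
      return true;
   }

   static bool
   lower_halt_to_return(nir_shader *shader)
   {
      /* With nir_metadata_none, any reported progress invalidates all metadata,
       * matching the old unconditional nir_metadata_preserve() call. */
      return nir_shader_instructions_pass(shader, halt_to_return, nir_metadata_none, NULL);
   }

Per-pass state such as rt_variables is passed through the final void * argument, which is why the commit introduces struct radv_lower_rt_instruction_data instead of extra function parameters.
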
src/amd/vulkan/radv_rt_shader.c

index 9ef5a02..3dea00d 100644
@@ -270,316 +270,312 @@ load_sbt_entry(nir_builder *b, const struct rt_variables *vars, nir_def *idx, en
    nir_store_var(b, vars->shader_record_ptr, record_addr, 1);
 }
 
-/* This lowers all the RT instructions that we do not want to pass on to the combined shader and
- * that we can implement using the variables from the shader we are going to inline into. */
-static void
-lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, bool apply_stack_ptr)
+struct radv_lower_rt_instruction_data {
+   struct rt_variables *vars;
+   bool apply_stack_ptr;
+};
+
+static bool
+radv_lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_data)
 {
-   nir_builder b_shader = nir_builder_create(nir_shader_get_entrypoint(shader));
+   if (instr->type == nir_instr_type_jump) {
+      nir_jump_instr *jump = nir_instr_as_jump(instr);
+      if (jump->type == nir_jump_halt) {
+         jump->type = nir_jump_return;
+         return true;
+      }
+      return false;
+   } else if (instr->type != nir_instr_type_intrinsic) {
+      return false;
+   }
 
-   nir_foreach_block (block, nir_shader_get_entrypoint(shader)) {
-      nir_foreach_instr_safe (instr, block) {
-         switch (instr->type) {
-         case nir_instr_type_intrinsic: {
-            b_shader.cursor = nir_before_instr(instr);
-            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
-            nir_def *ret = NULL;
+   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
 
-            switch (intr->intrinsic) {
-            case nir_intrinsic_rt_execute_callable: {
-               uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
-               nir_def *ret_ptr = nir_load_resume_shader_address_amd(&b_shader, nir_intrinsic_call_idx(intr));
-               ret_ptr = nir_ior_imm(&b_shader, ret_ptr, radv_get_rt_priority(shader->info.stage));
+   struct radv_lower_rt_instruction_data *data = _data;
+   struct rt_variables *vars = data->vars;
+   bool apply_stack_ptr = data->apply_stack_ptr;
 
-               nir_store_var(&b_shader, vars->stack_ptr,
-                             nir_iadd_imm_nuw(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), size), 1);
-               nir_store_scratch(&b_shader, ret_ptr, nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16);
+   b->cursor = nir_before_instr(&intr->instr);
 
-               nir_store_var(&b_shader, vars->stack_ptr,
-                             nir_iadd_imm_nuw(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), 16), 1);
-               load_sbt_entry(&b_shader, vars, intr->src[0].ssa, SBT_CALLABLE, SBT_RECURSIVE_PTR);
+   nir_def *ret = NULL;
+   switch (intr->intrinsic) {
+   case nir_intrinsic_rt_execute_callable: {
+      uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
+      nir_def *ret_ptr = nir_load_resume_shader_address_amd(b, nir_intrinsic_call_idx(intr));
+      ret_ptr = nir_ior_imm(b, ret_ptr, radv_get_rt_priority(b->shader->info.stage));
 
-               nir_store_var(&b_shader, vars->arg, nir_iadd_imm(&b_shader, intr->src[1].ssa, -size - 16), 1);
+      nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), size), 1);
+      nir_store_scratch(b, ret_ptr, nir_load_var(b, vars->stack_ptr), .align_mul = 16);
 
-               vars->stack_size = MAX2(vars->stack_size, size + 16);
-               break;
-            }
-            case nir_intrinsic_rt_trace_ray: {
-               uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
-               nir_def *ret_ptr = nir_load_resume_shader_address_amd(&b_shader, nir_intrinsic_call_idx(intr));
-               ret_ptr = nir_ior_imm(&b_shader, ret_ptr, radv_get_rt_priority(shader->info.stage));
-
-               nir_store_var(&b_shader, vars->stack_ptr,
-                             nir_iadd_imm_nuw(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), size), 1);
-               nir_store_scratch(&b_shader, ret_ptr, nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16);
-
-               nir_store_var(&b_shader, vars->stack_ptr,
-                             nir_iadd_imm_nuw(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), 16), 1);
-
-               nir_store_var(&b_shader, vars->shader_addr, nir_load_var(&b_shader, vars->traversal_addr), 1);
-               nir_store_var(&b_shader, vars->arg, nir_iadd_imm(&b_shader, intr->src[10].ssa, -size - 16), 1);
-
-               vars->stack_size = MAX2(vars->stack_size, size + 16);
-
-               /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. */
-               nir_store_var(&b_shader, vars->accel_struct, intr->src[0].ssa, 0x1);
-               nir_store_var(&b_shader, vars->cull_mask_and_flags,
-                             nir_ior(&b_shader, nir_ishl_imm(&b_shader, intr->src[2].ssa, 24), intr->src[1].ssa), 0x1);
-               nir_store_var(&b_shader, vars->sbt_offset, nir_iand_imm(&b_shader, intr->src[3].ssa, 0xf), 0x1);
-               nir_store_var(&b_shader, vars->sbt_stride, nir_iand_imm(&b_shader, intr->src[4].ssa, 0xf), 0x1);
-               nir_store_var(&b_shader, vars->miss_index, nir_iand_imm(&b_shader, intr->src[5].ssa, 0xffff), 0x1);
-               nir_store_var(&b_shader, vars->origin, intr->src[6].ssa, 0x7);
-               nir_store_var(&b_shader, vars->tmin, intr->src[7].ssa, 0x1);
-               nir_store_var(&b_shader, vars->direction, intr->src[8].ssa, 0x7);
-               nir_store_var(&b_shader, vars->tmax, intr->src[9].ssa, 0x1);
-               break;
-            }
-            case nir_intrinsic_rt_resume: {
-               uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
+      nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), 16), 1);
+      load_sbt_entry(b, vars, intr->src[0].ssa, SBT_CALLABLE, SBT_RECURSIVE_PTR);
 
-               nir_store_var(&b_shader, vars->stack_ptr,
-                             nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), -size), 1);
-               break;
-            }
-            case nir_intrinsic_rt_return_amd: {
-               if (shader->info.stage == MESA_SHADER_RAYGEN) {
-                  nir_terminate(&b_shader);
-                  break;
-               }
-               insert_rt_return(&b_shader, vars);
-               break;
-            }
-            case nir_intrinsic_load_scratch: {
-               if (apply_stack_ptr)
-                  nir_src_rewrite(&intr->src[0],
-                                  nir_iadd_nuw(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[0].ssa));
-               continue;
-            }
-            case nir_intrinsic_store_scratch: {
-               if (apply_stack_ptr)
-                  nir_src_rewrite(&intr->src[1],
-                                  nir_iadd_nuw(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[1].ssa));
-               continue;
-            }
-            case nir_intrinsic_load_rt_arg_scratch_offset_amd: {
-               ret = nir_load_var(&b_shader, vars->arg);
-               break;
-            }
-            case nir_intrinsic_load_shader_record_ptr: {
-               ret = nir_load_var(&b_shader, vars->shader_record_ptr);
-               break;
-            }
-            case nir_intrinsic_load_ray_t_min: {
-               ret = nir_load_var(&b_shader, vars->tmin);
-               break;
-            }
-            case nir_intrinsic_load_ray_t_max: {
-               ret = nir_load_var(&b_shader, vars->tmax);
-               break;
-            }
-            case nir_intrinsic_load_ray_world_origin: {
-               ret = nir_load_var(&b_shader, vars->origin);
-               break;
-            }
-            case nir_intrinsic_load_ray_world_direction: {
-               ret = nir_load_var(&b_shader, vars->direction);
-               break;
-            }
-            case nir_intrinsic_load_ray_instance_custom_index: {
-               nir_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
-               nir_def *custom_instance_and_mask = nir_build_load_global(
-                  &b_shader, 1, 32,
-                  nir_iadd_imm(&b_shader, instance_node_addr,
-                               offsetof(struct radv_bvh_instance_node, custom_instance_and_mask)));
-               ret = nir_iand_imm(&b_shader, custom_instance_and_mask, 0xFFFFFF);
-               break;
-            }
-            case nir_intrinsic_load_primitive_id: {
-               ret = nir_load_var(&b_shader, vars->primitive_id);
-               break;
-            }
-            case nir_intrinsic_load_ray_geometry_index: {
-               ret = nir_load_var(&b_shader, vars->geometry_id_and_flags);
-               ret = nir_iand_imm(&b_shader, ret, 0xFFFFFFF);
-               break;
-            }
-            case nir_intrinsic_load_instance_id: {
-               nir_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
-               ret = nir_build_load_global(
-                  &b_shader, 1, 32,
-                  nir_iadd_imm(&b_shader, instance_node_addr, offsetof(struct radv_bvh_instance_node, instance_id)));
-               break;
-            }
-            case nir_intrinsic_load_ray_flags: {
-               ret = nir_iand_imm(&b_shader, nir_load_var(&b_shader, vars->cull_mask_and_flags), 0xFFFFFF);
-               break;
-            }
-            case nir_intrinsic_load_ray_hit_kind: {
-               ret = nir_load_var(&b_shader, vars->hit_kind);
-               break;
-            }
-            case nir_intrinsic_load_ray_world_to_object: {
-               unsigned c = nir_intrinsic_column(intr);
-               nir_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
-               nir_def *wto_matrix[3];
-               nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
+      nir_store_var(b, vars->arg, nir_iadd_imm(b, intr->src[1].ssa, -size - 16), 1);
 
-               nir_def *vals[3];
-               for (unsigned i = 0; i < 3; ++i)
-                  vals[i] = nir_channel(&b_shader, wto_matrix[i], c);
+      vars->stack_size = MAX2(vars->stack_size, size + 16);
+      break;
+   }
+   case nir_intrinsic_rt_trace_ray: {
+      uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
+      nir_def *ret_ptr = nir_load_resume_shader_address_amd(b, nir_intrinsic_call_idx(intr));
+      ret_ptr = nir_ior_imm(b, ret_ptr, radv_get_rt_priority(b->shader->info.stage));
 
-               ret = nir_vec(&b_shader, vals, 3);
-               break;
-            }
-            case nir_intrinsic_load_ray_object_to_world: {
-               unsigned c = nir_intrinsic_column(intr);
-               nir_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
-               nir_def *rows[3];
-               for (unsigned r = 0; r < 3; ++r)
-                  rows[r] =
-                     nir_build_load_global(&b_shader, 4, 32,
-                                           nir_iadd_imm(&b_shader, instance_node_addr,
-                                                        offsetof(struct radv_bvh_instance_node, otw_matrix) + r * 16));
-               ret = nir_vec3(&b_shader, nir_channel(&b_shader, rows[0], c), nir_channel(&b_shader, rows[1], c),
-                              nir_channel(&b_shader, rows[2], c));
-               break;
-            }
-            case nir_intrinsic_load_ray_object_origin: {
-               nir_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
-               nir_def *wto_matrix[3];
-               nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
-               ret = nir_build_vec3_mat_mult(&b_shader, nir_load_var(&b_shader, vars->origin), wto_matrix, true);
-               break;
-            }
-            case nir_intrinsic_load_ray_object_direction: {
-               nir_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
-               nir_def *wto_matrix[3];
-               nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
-               ret = nir_build_vec3_mat_mult(&b_shader, nir_load_var(&b_shader, vars->direction), wto_matrix, false);
-               break;
-            }
-            case nir_intrinsic_load_intersection_opaque_amd: {
-               ret = nir_load_var(&b_shader, vars->opaque);
-               break;
-            }
-            case nir_intrinsic_load_cull_mask: {
-               ret = nir_ushr_imm(&b_shader, nir_load_var(&b_shader, vars->cull_mask_and_flags), 24);
-               break;
-            }
-            case nir_intrinsic_ignore_ray_intersection: {
-               nir_store_var(&b_shader, vars->ahit_accept, nir_imm_false(&b_shader), 0x1);
+      nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), size), 1);
+      nir_store_scratch(b, ret_ptr, nir_load_var(b, vars->stack_ptr), .align_mul = 16);
 
-               /* The if is a workaround to avoid having to fix up control flow manually */
-               nir_push_if(&b_shader, nir_imm_true(&b_shader));
-               nir_jump(&b_shader, nir_jump_return);
-               nir_pop_if(&b_shader, NULL);
-               break;
-            }
-            case nir_intrinsic_terminate_ray: {
-               nir_store_var(&b_shader, vars->ahit_accept, nir_imm_true(&b_shader), 0x1);
-               nir_store_var(&b_shader, vars->ahit_terminate, nir_imm_true(&b_shader), 0x1);
-
-               /* The if is a workaround to avoid having to fix up control flow manually */
-               nir_push_if(&b_shader, nir_imm_true(&b_shader));
-               nir_jump(&b_shader, nir_jump_return);
-               nir_pop_if(&b_shader, NULL);
-               break;
-            }
-            case nir_intrinsic_report_ray_intersection: {
-               nir_push_if(
-                  &b_shader,
-                  nir_iand(&b_shader, nir_fge(&b_shader, nir_load_var(&b_shader, vars->tmax), intr->src[0].ssa),
-                           nir_fge(&b_shader, intr->src[0].ssa, nir_load_var(&b_shader, vars->tmin))));
-               {
-                  nir_store_var(&b_shader, vars->ahit_accept, nir_imm_true(&b_shader), 0x1);
-                  nir_store_var(&b_shader, vars->tmax, intr->src[0].ssa, 1);
-                  nir_store_var(&b_shader, vars->hit_kind, intr->src[1].ssa, 1);
-               }
-               nir_pop_if(&b_shader, NULL);
-               break;
-            }
-            case nir_intrinsic_load_sbt_offset_amd: {
-               ret = nir_load_var(&b_shader, vars->sbt_offset);
-               break;
-            }
-            case nir_intrinsic_load_sbt_stride_amd: {
-               ret = nir_load_var(&b_shader, vars->sbt_stride);
-               break;
-            }
-            case nir_intrinsic_load_accel_struct_amd: {
-               ret = nir_load_var(&b_shader, vars->accel_struct);
-               break;
-            }
-            case nir_intrinsic_load_cull_mask_and_flags_amd: {
-               ret = nir_load_var(&b_shader, vars->cull_mask_and_flags);
-               break;
-            }
-            case nir_intrinsic_execute_closest_hit_amd: {
-               nir_store_var(&b_shader, vars->tmax, intr->src[1].ssa, 0x1);
-               nir_store_var(&b_shader, vars->primitive_id, intr->src[2].ssa, 0x1);
-               nir_store_var(&b_shader, vars->instance_addr, intr->src[3].ssa, 0x1);
-               nir_store_var(&b_shader, vars->geometry_id_and_flags, intr->src[4].ssa, 0x1);
-               nir_store_var(&b_shader, vars->hit_kind, intr->src[5].ssa, 0x1);
-               load_sbt_entry(&b_shader, vars, intr->src[0].ssa, SBT_HIT, SBT_RECURSIVE_PTR);
-
-               nir_def *should_return = nir_test_mask(&b_shader, nir_load_var(&b_shader, vars->cull_mask_and_flags),
-                                                      SpvRayFlagsSkipClosestHitShaderKHRMask);
-
-               if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR)) {
-                  should_return = nir_ior(&b_shader, should_return,
-                                          nir_ieq_imm(&b_shader, nir_load_var(&b_shader, vars->shader_addr), 0));
-               }
+      nir_store_var(b, vars->stack_ptr, nir_iadd_imm_nuw(b, nir_load_var(b, vars->stack_ptr), 16), 1);
 
-               /* should_return is set if we had a hit but we won't be calling the closest hit
-                * shader and hence need to return immediately to the calling shader. */
-               nir_push_if(&b_shader, should_return);
-               insert_rt_return(&b_shader, vars);
-               nir_pop_if(&b_shader, NULL);
-               break;
-            }
-            case nir_intrinsic_execute_miss_amd: {
-               nir_store_var(&b_shader, vars->tmax, intr->src[0].ssa, 0x1);
-               nir_def *undef = nir_undef(&b_shader, 1, 32);
-               nir_store_var(&b_shader, vars->primitive_id, undef, 0x1);
-               nir_store_var(&b_shader, vars->instance_addr, nir_undef(&b_shader, 1, 64), 0x1);
-               nir_store_var(&b_shader, vars->geometry_id_and_flags, undef, 0x1);
-               nir_store_var(&b_shader, vars->hit_kind, undef, 0x1);
-               nir_def *miss_index = nir_load_var(&b_shader, vars->miss_index);
-               load_sbt_entry(&b_shader, vars, miss_index, SBT_MISS, SBT_RECURSIVE_PTR);
-
-               if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR)) {
-                  /* In case of a NULL miss shader, do nothing and just return. */
-                  nir_push_if(&b_shader, nir_ieq_imm(&b_shader, nir_load_var(&b_shader, vars->shader_addr), 0));
-                  insert_rt_return(&b_shader, vars);
-                  nir_pop_if(&b_shader, NULL);
-               }
+      nir_store_var(b, vars->shader_addr, nir_load_var(b, vars->traversal_addr), 1);
+      nir_store_var(b, vars->arg, nir_iadd_imm(b, intr->src[10].ssa, -size - 16), 1);
 
-               break;
-            }
-            default:
-               continue;
-            }
+      vars->stack_size = MAX2(vars->stack_size, size + 16);
 
-            if (ret)
-               nir_def_rewrite_uses(&intr->def, ret);
-            nir_instr_remove(instr);
-            break;
-         }
-         case nir_instr_type_jump: {
-            nir_jump_instr *jump = nir_instr_as_jump(instr);
-            if (jump->type == nir_jump_halt) {
-               b_shader.cursor = nir_instr_remove(instr);
-               nir_jump(&b_shader, nir_jump_return);
-            }
-            break;
-         }
-         default:
-            break;
-         }
+      /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. */
+      nir_store_var(b, vars->accel_struct, intr->src[0].ssa, 0x1);
+      nir_store_var(b, vars->cull_mask_and_flags, nir_ior(b, nir_ishl_imm(b, intr->src[2].ssa, 24), intr->src[1].ssa),
+                    0x1);
+      nir_store_var(b, vars->sbt_offset, nir_iand_imm(b, intr->src[3].ssa, 0xf), 0x1);
+      nir_store_var(b, vars->sbt_stride, nir_iand_imm(b, intr->src[4].ssa, 0xf), 0x1);
+      nir_store_var(b, vars->miss_index, nir_iand_imm(b, intr->src[5].ssa, 0xffff), 0x1);
+      nir_store_var(b, vars->origin, intr->src[6].ssa, 0x7);
+      nir_store_var(b, vars->tmin, intr->src[7].ssa, 0x1);
+      nir_store_var(b, vars->direction, intr->src[8].ssa, 0x7);
+      nir_store_var(b, vars->tmax, intr->src[9].ssa, 0x1);
+      break;
+   }
+   case nir_intrinsic_rt_resume: {
+      uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
+
+      nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, nir_load_var(b, vars->stack_ptr), -size), 1);
+      break;
+   }
+   case nir_intrinsic_rt_return_amd: {
+      if (b->shader->info.stage == MESA_SHADER_RAYGEN) {
+         nir_terminate(b);
+         break;
       }
+      insert_rt_return(b, vars);
+      break;
+   }
+   case nir_intrinsic_load_scratch: {
+      if (apply_stack_ptr)
+         nir_src_rewrite(&intr->src[0], nir_iadd_nuw(b, nir_load_var(b, vars->stack_ptr), intr->src[0].ssa));
+      return true;
+   }
+   case nir_intrinsic_store_scratch: {
+      if (apply_stack_ptr)
+         nir_src_rewrite(&intr->src[1], nir_iadd_nuw(b, nir_load_var(b, vars->stack_ptr), intr->src[1].ssa));
+      return true;
+   }
+   case nir_intrinsic_load_rt_arg_scratch_offset_amd: {
+      ret = nir_load_var(b, vars->arg);
+      break;
+   }
+   case nir_intrinsic_load_shader_record_ptr: {
+      ret = nir_load_var(b, vars->shader_record_ptr);
+      break;
+   }
+   case nir_intrinsic_load_ray_t_min: {
+      ret = nir_load_var(b, vars->tmin);
+      break;
+   }
+   case nir_intrinsic_load_ray_t_max: {
+      ret = nir_load_var(b, vars->tmax);
+      break;
+   }
+   case nir_intrinsic_load_ray_world_origin: {
+      ret = nir_load_var(b, vars->origin);
+      break;
+   }
+   case nir_intrinsic_load_ray_world_direction: {
+      ret = nir_load_var(b, vars->direction);
+      break;
+   }
+   case nir_intrinsic_load_ray_instance_custom_index: {
+      nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
+      nir_def *custom_instance_and_mask = nir_build_load_global(
+         b, 1, 32,
+         nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, custom_instance_and_mask)));
+      ret = nir_iand_imm(b, custom_instance_and_mask, 0xFFFFFF);
+      break;
+   }
+   case nir_intrinsic_load_primitive_id: {
+      ret = nir_load_var(b, vars->primitive_id);
+      break;
+   }
+   case nir_intrinsic_load_ray_geometry_index: {
+      ret = nir_load_var(b, vars->geometry_id_and_flags);
+      ret = nir_iand_imm(b, ret, 0xFFFFFFF);
+      break;
+   }
+   case nir_intrinsic_load_instance_id: {
+      nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
+      ret = nir_build_load_global(
+         b, 1, 32, nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, instance_id)));
+      break;
+   }
+   case nir_intrinsic_load_ray_flags: {
+      ret = nir_iand_imm(b, nir_load_var(b, vars->cull_mask_and_flags), 0xFFFFFF);
+      break;
    }
+   case nir_intrinsic_load_ray_hit_kind: {
+      ret = nir_load_var(b, vars->hit_kind);
+      break;
+   }
+   case nir_intrinsic_load_ray_world_to_object: {
+      unsigned c = nir_intrinsic_column(intr);
+      nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
+      nir_def *wto_matrix[3];
+      nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
+
+      nir_def *vals[3];
+      for (unsigned i = 0; i < 3; ++i)
+         vals[i] = nir_channel(b, wto_matrix[i], c);
+
+      ret = nir_vec(b, vals, 3);
+      break;
+   }
+   case nir_intrinsic_load_ray_object_to_world: {
+      unsigned c = nir_intrinsic_column(intr);
+      nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
+      nir_def *rows[3];
+      for (unsigned r = 0; r < 3; ++r)
+         rows[r] = nir_build_load_global(
+            b, 4, 32,
+            nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, otw_matrix) + r * 16));
+      ret = nir_vec3(b, nir_channel(b, rows[0], c), nir_channel(b, rows[1], c), nir_channel(b, rows[2], c));
+      break;
+   }
+   case nir_intrinsic_load_ray_object_origin: {
+      nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
+      nir_def *wto_matrix[3];
+      nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
+      ret = nir_build_vec3_mat_mult(b, nir_load_var(b, vars->origin), wto_matrix, true);
+      break;
+   }
+   case nir_intrinsic_load_ray_object_direction: {
+      nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
+      nir_def *wto_matrix[3];
+      nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
+      ret = nir_build_vec3_mat_mult(b, nir_load_var(b, vars->direction), wto_matrix, false);
+      break;
+   }
+   case nir_intrinsic_load_intersection_opaque_amd: {
+      ret = nir_load_var(b, vars->opaque);
+      break;
+   }
+   case nir_intrinsic_load_cull_mask: {
+      ret = nir_ushr_imm(b, nir_load_var(b, vars->cull_mask_and_flags), 24);
+      break;
+   }
+   case nir_intrinsic_ignore_ray_intersection: {
+      nir_store_var(b, vars->ahit_accept, nir_imm_false(b), 0x1);
+
+      /* The if is a workaround to avoid having to fix up control flow manually */
+      nir_push_if(b, nir_imm_true(b));
+      nir_jump(b, nir_jump_return);
+      nir_pop_if(b, NULL);
+      break;
+   }
+   case nir_intrinsic_terminate_ray: {
+      nir_store_var(b, vars->ahit_accept, nir_imm_true(b), 0x1);
+      nir_store_var(b, vars->ahit_terminate, nir_imm_true(b), 0x1);
+
+      /* The if is a workaround to avoid having to fix up control flow manually */
+      nir_push_if(b, nir_imm_true(b));
+      nir_jump(b, nir_jump_return);
+      nir_pop_if(b, NULL);
+      break;
+   }
+   case nir_intrinsic_report_ray_intersection: {
+      nir_push_if(b, nir_iand(b, nir_fge(b, nir_load_var(b, vars->tmax), intr->src[0].ssa),
+                              nir_fge(b, intr->src[0].ssa, nir_load_var(b, vars->tmin))));
+      {
+         nir_store_var(b, vars->ahit_accept, nir_imm_true(b), 0x1);
+         nir_store_var(b, vars->tmax, intr->src[0].ssa, 1);
+         nir_store_var(b, vars->hit_kind, intr->src[1].ssa, 1);
+      }
+      nir_pop_if(b, NULL);
+      break;
+   }
+   case nir_intrinsic_load_sbt_offset_amd: {
+      ret = nir_load_var(b, vars->sbt_offset);
+      break;
+   }
+   case nir_intrinsic_load_sbt_stride_amd: {
+      ret = nir_load_var(b, vars->sbt_stride);
+      break;
+   }
+   case nir_intrinsic_load_accel_struct_amd: {
+      ret = nir_load_var(b, vars->accel_struct);
+      break;
+   }
+   case nir_intrinsic_load_cull_mask_and_flags_amd: {
+      ret = nir_load_var(b, vars->cull_mask_and_flags);
+      break;
+   }
+   case nir_intrinsic_execute_closest_hit_amd: {
+      nir_store_var(b, vars->tmax, intr->src[1].ssa, 0x1);
+      nir_store_var(b, vars->primitive_id, intr->src[2].ssa, 0x1);
+      nir_store_var(b, vars->instance_addr, intr->src[3].ssa, 0x1);
+      nir_store_var(b, vars->geometry_id_and_flags, intr->src[4].ssa, 0x1);
+      nir_store_var(b, vars->hit_kind, intr->src[5].ssa, 0x1);
+      load_sbt_entry(b, vars, intr->src[0].ssa, SBT_HIT, SBT_RECURSIVE_PTR);
+
+      nir_def *should_return =
+         nir_test_mask(b, nir_load_var(b, vars->cull_mask_and_flags), SpvRayFlagsSkipClosestHitShaderKHRMask);
+
+      if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_CLOSEST_HIT_SHADERS_BIT_KHR)) {
+         should_return = nir_ior(b, should_return, nir_ieq_imm(b, nir_load_var(b, vars->shader_addr), 0));
+      }
 
-   nir_metadata_preserve(nir_shader_get_entrypoint(shader), nir_metadata_none);
+      /* should_return is set if we had a hit but we won't be calling the closest hit
+       * shader and hence need to return immediately to the calling shader. */
+      nir_push_if(b, should_return);
+      insert_rt_return(b, vars);
+      nir_pop_if(b, NULL);
+      break;
+   }
+   case nir_intrinsic_execute_miss_amd: {
+      nir_store_var(b, vars->tmax, intr->src[0].ssa, 0x1);
+      nir_def *undef = nir_undef(b, 1, 32);
+      nir_store_var(b, vars->primitive_id, undef, 0x1);
+      nir_store_var(b, vars->instance_addr, nir_undef(b, 1, 64), 0x1);
+      nir_store_var(b, vars->geometry_id_and_flags, undef, 0x1);
+      nir_store_var(b, vars->hit_kind, undef, 0x1);
+      nir_def *miss_index = nir_load_var(b, vars->miss_index);
+      load_sbt_entry(b, vars, miss_index, SBT_MISS, SBT_RECURSIVE_PTR);
+
+      if (!(vars->flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_MISS_SHADERS_BIT_KHR)) {
+         /* In case of a NULL miss shader, do nothing and just return. */
+         nir_push_if(b, nir_ieq_imm(b, nir_load_var(b, vars->shader_addr), 0));
+         insert_rt_return(b, vars);
+         nir_pop_if(b, NULL);
+      }
+
+      break;
+   }
+   default:
+      return false;
+   }
+
+   if (ret)
+      nir_def_rewrite_uses(&intr->def, ret);
+   nir_instr_remove(&intr->instr);
+
+   return true;
+}
+
+/* This lowers all the RT instructions that we do not want to pass on to the combined shader and
+ * that we can implement using the variables from the shader we are going to inline into. */
+static void
+lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, bool apply_stack_ptr)
+{
+   struct radv_lower_rt_instruction_data data = {
+      .vars = vars,
+      .apply_stack_ptr = apply_stack_ptr,
+   };
+   nir_shader_instructions_pass(shader, radv_lower_rt_instruction, nir_metadata_none, &data);
 }
 
 static bool