#include "bvh/bvh.h"
#include "meta/radv_meta.h"
+#include "ac_nir.h"
#include "radv_private.h"
#include "radv_rt_common.h"
#include "radv_shader.h"
* the correct resume index upon returning.
*/
nir_variable *idx;
+ nir_variable *shader_va;
+ nir_variable *traversal_addr;
/* scratch offset of the argument area relative to stack_ptr */
nir_variable *arg;
.create_info = create_info,
};
vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx");
+ vars.shader_va =
+ nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_va");
+ vars.traversal_addr =
+ nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "traversal_addr");
vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg");
vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr");
vars.shader_record_ptr =
src->create_info = dst->create_info;
_mesa_hash_table_insert(var_remap, src->idx, dst->idx);
+ _mesa_hash_table_insert(var_remap, src->shader_va, dst->shader_va);
+ _mesa_hash_table_insert(var_remap, src->traversal_addr, dst->traversal_addr);
_mesa_hash_table_insert(var_remap, src->arg, dst->arg);
_mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr);
_mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr);
return b.shader;
}
+/*
+ * Lower a ray tracing shader so that its state is passed through the shader
+ * arguments described by args->ac.rt.
+ *
+ * After lower_rt_instructions() has run, the original entrypoint body is
+ * wrapped as follows:
+ *   1. a prologue loads every RT argument into the matching rt_variables
+ *      variable,
+ *   2. the original body is placed inside an if that only executes where the
+ *      shader_pc argument equals the next_shader argument,
+ *   3. an epilogue picks the next shader_pc and stores the variables back to
+ *      the arguments.
+ *
+ * shader      - shader to lower; modified in place.
+ * pCreateInfo - forwarded to create_rt_variables().
+ * args        - supplies the argument descriptors that are loaded and stored.
+ * key         - only key->cs.compute_subgroup_size is read here, for
+ *               lower_hit_attribs.
+ * stack_size  - optional (may be NULL) in/out running maximum of the required
+ *               stack size.
+ */
+void
+radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
+ const struct radv_shader_args *args, const struct radv_pipeline_key *key,
+ uint32_t *stack_size)
+{
+ nir_builder b;
+ nir_function_impl *impl = nir_shader_get_entrypoint(shader);
+ nir_builder_init(&b, impl);
+
+ struct rt_variables vars = create_rt_variables(shader, pCreateInfo);
+ lower_rt_instructions(shader, &vars, 0);
+ /* When stack_size is given, it is raised to at least the larger of
+  * vars.stack_size and shader->scratch_size; it is never lowered here. */
+ if (stack_size) {
+ vars.stack_size = MAX2(vars.stack_size, shader->scratch_size);
+ *stack_size = MAX2(*stack_size, vars.stack_size);
+ }
+ shader->scratch_size = 0; /* reset unconditionally, even when stack_size is NULL */
+ /* NOTE(review): presumably returns are lowered first so that the body can
+  * be moved under the guard below as plain structured control flow - confirm. */
+ NIR_PASS(_, shader, nir_lower_returns);
+ /* Pull the whole entrypoint body out so the variable setup below can be
+  * emitted in front of it; the body is reinserted inside the guard. */
+ nir_cf_list list;
+ nir_cf_extract(&list, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
+
+ /* initialize variables from the shader arguments; arguments that go through
+  * nir_pack_64_2x32 are stored as a single packed 64-bit value */
+ b.cursor = nir_before_cf_list(&impl->body);
+
+ nir_ssa_def *traversal_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.traversal_shader);
+ nir_store_var(&b, vars.traversal_addr, nir_pack_64_2x32(&b, traversal_addr), 1);
+ nir_ssa_def *shader_va = ac_nir_load_arg(&b, &args->ac, args->ac.rt.next_shader);
+ shader_va = nir_pack_64_2x32(&b, shader_va);
+ nir_store_var(&b, vars.shader_va, shader_va, 1);
+ nir_store_var(&b, vars.stack_ptr,
+ ac_nir_load_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base), 1);
+ nir_ssa_def *record_ptr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_record);
+ nir_store_var(&b, vars.shader_record_ptr, nir_pack_64_2x32(&b, record_ptr), 1);
+ nir_store_var(&b, vars.arg, ac_nir_load_arg(&b, &args->ac, args->ac.rt.payload_offset), 1);
+
+ nir_ssa_def *accel_struct = ac_nir_load_arg(&b, &args->ac, args->ac.rt.accel_struct);
+ nir_store_var(&b, vars.accel_struct, nir_pack_64_2x32(&b, accel_struct), 1);
+ nir_store_var(&b, vars.cull_mask_and_flags,
+ ac_nir_load_arg(&b, &args->ac, args->ac.rt.cull_mask_and_flags), 1);
+ nir_store_var(&b, vars.sbt_offset, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_offset), 1);
+ nir_store_var(&b, vars.sbt_stride, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_stride), 1);
+ nir_store_var(&b, vars.miss_index, ac_nir_load_arg(&b, &args->ac, args->ac.rt.miss_index), 1);
+ nir_store_var(&b, vars.origin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_origin), 0x7);
+ nir_store_var(&b, vars.tmin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmin), 1);
+ nir_store_var(&b, vars.direction, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_direction),
+ 0x7);
+ nir_store_var(&b, vars.tmax, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmax), 1);
+
+ nir_store_var(&b, vars.primitive_id, ac_nir_load_arg(&b, &args->ac, args->ac.rt.primitive_id),
+ 1);
+ nir_ssa_def *instance_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.instance_addr);
+ nir_store_var(&b, vars.instance_addr, nir_pack_64_2x32(&b, instance_addr), 1);
+ nir_store_var(&b, vars.geometry_id_and_flags,
+ ac_nir_load_arg(&b, &args->ac, args->ac.rt.geometry_id_and_flags), 1);
+ nir_store_var(&b, vars.hit_kind, ac_nir_load_arg(&b, &args->ac, args->ac.rt.hit_kind), 1);
+
+ /* guard the shader, so that only the correct invocations execute it:
+  * the extracted body is reinserted inside an if whose condition is that the
+  * shader_pc argument equals the next_shader value loaded above */
+ nir_ssa_def *shader_pc = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_pc);
+ shader_pc = nir_pack_64_2x32(&b, shader_pc);
+ nir_ssa_def *cond = nir_ieq(&b, shader_pc, shader_va);
+ nir_if *shader_guard = nir_push_if(&b, cond);
+ shader_guard->control = nir_selection_control_divergent_always_taken;
+ nir_cf_reinsert(&list, b.cursor);
+ nir_pop_if(&b, shader_guard);
+
+ /* select next shader: the current shader_va variable is passed through
+  * nir_read_first_invocation and the result is written to the shader_pc
+  * argument */
+ // TODO: use a priority-based selection
+ b.cursor = nir_after_cf_list(&impl->body);
+ shader_va = nir_load_var(&b, vars.shader_va);
+ nir_ssa_def *next = nir_read_first_invocation(&b, shader_va);
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_pc, next);
+
+ /* store back all variables to registers; this mirrors the loads at the top,
+  * except that traversal_shader is loaded but never stored back.
+  * NOTE(review): values packed with nir_pack_64_2x32 on load are handed to
+  * ac_nir_store_arg without a visible unpack - presumably the helper accepts
+  * them as-is; confirm. */
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base,
+ nir_load_var(&b, vars.stack_ptr));
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.next_shader, nir_load_var(&b, vars.shader_va));
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_record,
+ nir_load_var(&b, vars.shader_record_ptr));
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.payload_offset, nir_load_var(&b, vars.arg));
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.accel_struct, nir_load_var(&b, vars.accel_struct));
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.cull_mask_and_flags,
+ nir_load_var(&b, vars.cull_mask_and_flags));
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.sbt_offset, nir_load_var(&b, vars.sbt_offset));
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.sbt_stride, nir_load_var(&b, vars.sbt_stride));
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.miss_index, nir_load_var(&b, vars.miss_index));
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_origin, nir_load_var(&b, vars.origin));
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_tmin, nir_load_var(&b, vars.tmin));
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_direction, nir_load_var(&b, vars.direction));
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_tmax, nir_load_var(&b, vars.tmax));
+
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.primitive_id, nir_load_var(&b, vars.primitive_id));
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.instance_addr, nir_load_var(&b, vars.instance_addr));
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.geometry_id_and_flags,
+ nir_load_var(&b, vars.geometry_id_and_flags));
+ ac_nir_store_arg(&b, &args->ac, args->ac.rt.hit_kind, nir_load_var(&b, vars.hit_kind));
+
+ /* cleanup passes; hit attribute lowering is only run for closest-hit and
+  * intersection shaders */
+ NIR_PASS_V(shader, nir_lower_global_vars_to_local);
+ NIR_PASS_V(shader, nir_lower_vars_to_ssa);
+ if (shader->info.stage == MESA_SHADER_CLOSEST_HIT ||
+ shader->info.stage == MESA_SHADER_INTERSECTION)
+ NIR_PASS_V(shader, lower_hit_attribs, NULL, key->cs.compute_subgroup_size);
+}