* performance. */
#define MAX_STACK_ENTRY_COUNT 16
-/* The hit attributes are stored on the stack. This is the offset compared to the current stack
- * pointer of where the hit attrib is stored. */
-const uint32_t RADV_HIT_ATTRIB_OFFSET = -(16 + RADV_MAX_HIT_ATTRIB_SIZE);
-
static bool
lower_rt_derefs(nir_shader *shader)
{
continue;
nir_deref_instr *deref = nir_instr_as_deref(instr);
- b.cursor = nir_before_instr(&deref->instr);
-
- nir_deref_instr *replacement = NULL;
- if (nir_deref_mode_is(deref, nir_var_shader_call_data)) {
- deref->modes = nir_var_function_temp;
- progress = true;
-
- if (deref->deref_type == nir_deref_type_var)
- replacement =
- nir_build_deref_cast(&b, arg_offset, nir_var_function_temp, deref->var->type, 0);
- } else if (nir_deref_mode_is(deref, nir_var_ray_hit_attrib)) {
- deref->modes = nir_var_function_temp;
- progress = true;
-
- if (deref->deref_type == nir_deref_type_var)
- replacement = nir_build_deref_cast(&b, nir_imm_int(&b, RADV_HIT_ATTRIB_OFFSET),
- nir_var_function_temp, deref->type, 0);
- }
+ if (!nir_deref_mode_is(deref, nir_var_shader_call_data))
+ continue;
+
+ deref->modes = nir_var_function_temp;
+ progress = true;
- if (replacement != NULL) {
+ if (deref->deref_type == nir_deref_type_var) {
+ b.cursor = nir_before_instr(&deref->instr);
+ nir_deref_instr *replacement =
+ nir_build_deref_cast(&b, arg_offset, nir_var_function_temp, deref->var->type, 0);
nir_ssa_def_rewrite_uses(&deref->dest.ssa, &replacement->dest.ssa);
nir_instr_remove(&deref->instr);
}
switch (intr->intrinsic) {
case nir_intrinsic_rt_execute_callable: {
- uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
+ uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
uint32_t ret_idx = call_idx_base + nir_intrinsic_call_idx(intr) + 1;
nir_store_var(
break;
}
case nir_intrinsic_rt_trace_ray: {
- uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
+ uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
uint32_t ret_idx = call_idx_base + nir_intrinsic_call_idx(intr) + 1;
nir_store_var(
break;
}
case nir_intrinsic_rt_resume: {
- uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
+ uint32_t size = align(nir_intrinsic_stack_size(intr), 16);
nir_store_var(
&b_shader, vars->stack_ptr,
nir_metadata_preserve(nir_shader_get_entrypoint(shader), nir_metadata_none);
}
+/* Lowers load_deref/store_deref of nir_var_ray_hit_attrib variables into
+ * load/store_hit_attrib_amd intrinsics.  Hit attributes are addressed as an
+ * array of 32-bit dwords: the intrinsic's .base is the dword index, derived
+ * from the variable's driver_location (a byte offset) plus the component
+ * offset.  8/16-bit accesses are widened to a dword read-modify-write and
+ * 64-bit accesses are split into two dword accesses.
+ * Returns true when the instruction was lowered (and removed). */
+static bool
+lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data)
+{
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+
+   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+   if (intrin->intrinsic != nir_intrinsic_load_deref &&
+       intrin->intrinsic != nir_intrinsic_store_deref)
+      return false;
+
+   /* src[0] is the deref for both load_deref and store_deref. */
+   nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
+   if (!nir_deref_mode_is(deref, nir_var_ray_hit_attrib))
+      return false;
+
+   /* Structs/arrays must already have been split into plain variables
+    * (nir_split_struct_vars, nir_lower_indirect_derefs, nir_split_array_vars
+    * run before this pass), so only variable derefs remain. */
+   assert(deref->deref_type == nir_deref_type_var);
+
+   b->cursor = nir_after_instr(instr);
+
+   if (intrin->intrinsic == nir_intrinsic_load_deref) {
+      uint32_t num_components = intrin->dest.ssa.num_components;
+      uint32_t bit_size = intrin->dest.ssa.bit_size;
+
+      nir_ssa_def *components[NIR_MAX_VEC_COMPONENTS];
+
+      for (uint32_t comp = 0; comp < num_components; comp++) {
+         /* Byte offset of this component, split into the dword index (base)
+          * and the byte offset within that dword. */
+         uint32_t offset = deref->var->data.driver_location + comp * bit_size / 8;
+         uint32_t base = offset / 4;
+         uint32_t comp_offset = offset % 4;
+
+         if (bit_size == 64) {
+            components[comp] = nir_pack_64_2x32_split(b, nir_load_hit_attrib_amd(b, .base = base),
+                                                      nir_load_hit_attrib_amd(b, .base = base + 1));
+         } else if (bit_size == 32) {
+            components[comp] = nir_load_hit_attrib_amd(b, .base = base);
+         } else if (bit_size == 16) {
+            /* Load the containing dword and select the 16-bit half. */
+            components[comp] = nir_channel(
+               b, nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base)), comp_offset / 2);
+         } else if (bit_size == 8) {
+            /* Load the containing dword and select the byte. */
+            components[comp] = nir_channel(
+               b, nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8), comp_offset);
+         } else {
+            unreachable("Invalid bit_size");
+         }
+      }
+
+      nir_ssa_def_rewrite_uses(&intrin->dest.ssa, nir_vec(b, components, num_components));
+   } else {
+      /* NOTE(review): this store path does not consult the store_deref
+       * writemask, i.e. it assumes a full writemask — confirm partial
+       * writemasks cannot reach this pass. */
+      nir_ssa_def *value = intrin->src[1].ssa;
+      uint32_t num_components = value->num_components;
+      uint32_t bit_size = value->bit_size;
+
+      for (uint32_t comp = 0; comp < num_components; comp++) {
+         uint32_t offset = deref->var->data.driver_location + comp * bit_size / 8;
+         uint32_t base = offset / 4;
+         uint32_t comp_offset = offset % 4;
+
+         nir_ssa_def *component = nir_channel(b, value, comp);
+
+         if (bit_size == 64) {
+            /* Split the 64-bit value into two dword stores. */
+            nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_x(b, component), .base = base);
+            nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_y(b, component), .base = base + 1);
+         } else if (bit_size == 32) {
+            nir_store_hit_attrib_amd(b, component, .base = base);
+         } else if (bit_size == 16) {
+            /* Read-modify-write: replace one 16-bit half of the dword,
+             * keeping the other half from the previous contents. */
+            nir_ssa_def *prev = nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base));
+            nir_ssa_def *components[2];
+            for (uint32_t word = 0; word < 2; word++)
+               components[word] = (word == comp_offset / 2) ? nir_channel(b, value, comp)
+                                                            : nir_channel(b, prev, word);
+            nir_store_hit_attrib_amd(b, nir_pack_32_2x16(b, nir_vec(b, components, 2)),
+                                     .base = base);
+         } else if (bit_size == 8) {
+            /* Read-modify-write: replace one byte of the dword, keeping the
+             * other three bytes from the previous contents. */
+            nir_ssa_def *prev = nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8);
+            nir_ssa_def *components[4];
+            for (uint32_t byte = 0; byte < 4; byte++)
+               components[byte] =
+                  (byte == comp_offset) ? nir_channel(b, value, comp) : nir_channel(b, prev, byte);
+            nir_store_hit_attrib_amd(b, nir_pack_32_4x8(b, nir_vec(b, components, 4)),
+                                     .base = base);
+         } else {
+            unreachable("Invalid bit_size");
+         }
+      }
+   }
+
+   nir_instr_remove(instr);
+   return true;
+}
+
+/* Shader-wide wrapper around lower_hit_attrib_deref: lowers every
+ * hit-attribute deref access and, if anything changed, removes the now-dead
+ * derefs and nir_var_ray_hit_attrib variables.
+ * Returns whether any progress was made. */
+static bool
+lower_hit_attrib_derefs(nir_shader *shader)
+{
+   bool progress = nir_shader_instructions_pass(
+      shader, lower_hit_attrib_deref, nir_metadata_block_index | nir_metadata_dominance, NULL);
+   if (progress) {
+      nir_remove_dead_derefs(shader);
+      nir_remove_dead_variables(shader, nir_var_ray_hit_attrib, NULL);
+   }
+
+   return progress;
+}
+
static void
insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, nir_ssa_def *idx,
uint32_t call_idx_base, uint32_t call_idx)
nir_rt_return_amd(&b_inner);
}
+ NIR_PASS(_, shader, nir_split_struct_vars, nir_var_ray_hit_attrib);
+ NIR_PASS(_, shader, nir_lower_indirect_derefs, nir_var_ray_hit_attrib, UINT32_MAX);
+ NIR_PASS(_, shader, nir_split_array_vars, nir_var_ray_hit_attrib);
+
NIR_PASS(_, shader, nir_lower_vars_to_explicit_types,
nir_var_function_temp | nir_var_shader_call_data | nir_var_ray_hit_attrib,
glsl_get_natural_size_align_bytes);
NIR_PASS(_, shader, lower_rt_derefs);
+ NIR_PASS(_, shader, lower_hit_attrib_derefs);
NIR_PASS(_, shader, nir_lower_explicit_io, nir_var_function_temp,
nir_address_format_32bit_offset);
const VkRayTracingPipelineCreateInfoKHR *createInfo;
struct rt_variables *vars;
struct rt_traversal_vars *trav_vars;
+ nir_variable *barycentrics;
};
static void
nir_ssa_def *hit_kind =
nir_bcsel(b, intersection->frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF));
- nir_ssa_def *barycentrics_addr =
- nir_iadd_imm_nuw(b, nir_load_var(b, data->vars->stack_ptr), RADV_HIT_ATTRIB_OFFSET);
- nir_ssa_def *prev_barycentrics = nir_load_scratch(b, 2, 32, barycentrics_addr, .align_mul = 16);
- nir_store_scratch(b, intersection->barycentrics, barycentrics_addr, .align_mul = 16);
+ nir_ssa_def *prev_barycentrics = nir_load_var(b, data->barycentrics);
+ nir_store_var(b, data->barycentrics, intersection->barycentrics, 0x3);
nir_store_var(b, data->vars->ahit_accept, nir_imm_true(b), 0x1);
nir_store_var(b, data->vars->ahit_terminate, nir_imm_false(b), 0x1);
nir_push_if(b, nir_inot(b, nir_load_var(b, data->vars->ahit_accept)));
{
- nir_store_scratch(b, prev_barycentrics, barycentrics_addr, .align_mul = 16);
+ nir_store_var(b, data->barycentrics, prev_barycentrics, 0x3);
nir_jump(b, nir_jump_continue);
}
nir_pop_if(b, NULL);
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
struct radv_pipeline_shader_stack_size *stack_sizes)
{
- nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_traversal");
+ /* Create the traversal shader as an intersection shader to prevent validation failures due to
+ * invalid variable modes.*/
+ nir_builder b = radv_meta_init_shader(device, MESA_SHADER_INTERSECTION, "rt_traversal");
b.shader->info.internal = false;
b.shader->info.workgroup_size[0] = 8;
b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
device->physical_device->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes);
+ nir_variable *barycentrics = nir_variable_create(
+ b.shader, nir_var_ray_hit_attrib, glsl_vector_type(GLSL_TYPE_FLOAT, 2), "barycentrics");
+ barycentrics->data.driver_location = 0;
+
/* initialize trace_ray arguments */
nir_ssa_def *accel_struct = nir_load_accel_struct_amd(&b);
nir_store_var(&b, vars.flags, nir_load_ray_flags(&b), 0x1);
.createInfo = pCreateInfo,
.vars = &vars,
.trav_vars = &trav_vars,
+ .barycentrics = barycentrics,
};
struct radv_ray_traversal_args args = {
NIR_PASS_V(b.shader, nir_lower_global_vars_to_local);
NIR_PASS_V(b.shader, nir_lower_vars_to_ssa);
+ lower_hit_attrib_derefs(b.shader);
+
return b.shader;
}
nir_metadata_all & (~nir_metadata_instr_index));
}
+/* Replaces load/store_hit_attrib_amd intrinsics with accesses to the caller
+ * provided variables: hit_attribs[i] backs dword i (the intrinsic's .base).
+ * Remaining nir_var_ray_hit_attrib variables are demoted to shader_temp so
+ * the trailing globals-to-locals / vars-to-ssa passes can eliminate them. */
+static void
+lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs)
+{
+   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
+
+   nir_foreach_variable_with_modes (attrib, shader, nir_var_ray_hit_attrib)
+      attrib->data.mode = nir_var_shader_temp;
+
+   nir_builder b;
+   nir_builder_init(&b, impl);
+
+   nir_foreach_block (block, impl) {
+      /* _safe iteration because lowered intrinsics are removed in place. */
+      nir_foreach_instr_safe (instr, block) {
+         if (instr->type != nir_instr_type_intrinsic)
+            continue;
+
+         b.cursor = nir_after_instr(instr);
+         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+         if (intrin->intrinsic == nir_intrinsic_load_hit_attrib_amd) {
+            /* Replace the load with a read of the backing variable. */
+            nir_ssa_def *ret = nir_load_var(&b, hit_attribs[nir_intrinsic_base(intrin)]);
+            nir_ssa_def_rewrite_uses(nir_instr_ssa_def(instr), ret);
+            nir_instr_remove(instr);
+         } else if (intrin->intrinsic == nir_intrinsic_store_hit_attrib_amd) {
+            /* src[0] of store_hit_attrib_amd is the 32-bit value to store. */
+            nir_store_var(&b, hit_attribs[nir_intrinsic_base(intrin)], intrin->src->ssa, 0x1);
+            nir_instr_remove(instr);
+         }
+      }
+   }
+
+   nir_metadata_preserve(impl, nir_metadata_block_index | nir_metadata_dominance);
+
+   nir_lower_global_vars_to_local(shader);
+   nir_lower_vars_to_ssa(shader);
+}
+
nir_shader *
create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
struct radv_pipeline_shader_stack_size *stack_sizes)
else
nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
+ nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_SIZE / sizeof(uint32_t)];
+ for (uint32_t i = 0; i < ARRAY_SIZE(hit_attribs); i++)
+ hit_attribs[i] = nir_local_variable_create(nir_shader_get_entrypoint(b.shader),
+ glsl_uint_type(), "attribute");
+
nir_loop *loop = nir_push_loop(&b);
nir_push_if(&b, nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 0));
nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));
nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);
+ lower_hit_attribs(b.shader, hit_attribs);
+
return b.shader;
}