case nir_intrinsic_load_rt_dynamic_callable_stack_base_amd:
bld.copy(Definition(get_ssa_temp(ctx, &instr->dest.ssa)),
get_arg(ctx, ctx->args->ac.rt_dynamic_callable_stack_base));
+ ctx->program->rt_stack = true;
break;
case nir_intrinsic_overwrite_vs_arguments_amd: {
ctx->arg_temps[ctx->args->ac.vertex_id.arg_index] = get_ssa_temp(ctx, instr->src[0].ssa);
uint16_t min_waves = 0;
unsigned workgroup_size; /* if known; otherwise UINT_MAX */
bool wgp_mode;
+ bool rt_stack = false;
bool needs_vcc = false;
get_extra_sgprs(Program* program)
{
/* We don't use this register on GFX6-8 and it's removed on GFX10+. */
- bool needs_flat_scr = program->config->scratch_bytes_per_wave && program->gfx_level == GFX9;
+ bool needs_flat_scr =
+ (program->config->scratch_bytes_per_wave || program->rt_stack) && program->gfx_level == GFX9;
if (program->gfx_level >= GFX10) {
assert(!program->dev.xnack_enabled);
}
case aco_opcode::p_init_scratch: {
assert(program->gfx_level >= GFX8 && program->gfx_level <= GFX10_3);
- if (!program->config->scratch_bytes_per_wave)
+ if (!program->config->scratch_bytes_per_wave && !program->rt_stack)
break;
Operand scratch_addr = instr->operands[0];
{
unsigned scratch_bytes_per_wave = 0;
unsigned max_waves = 0;
+ bool is_rt = pipeline->type == RADV_PIPELINE_RAY_TRACING;
for (int i = 0; i < MESA_VULKAN_SHADER_STAGES; ++i) {
- if (pipeline->shaders[i] && pipeline->shaders[i]->config.scratch_bytes_per_wave) {
+ if (pipeline->shaders[i] && (pipeline->shaders[i]->config.scratch_bytes_per_wave || is_rt)) {
unsigned max_stage_waves = device->scratch_waves;
scratch_bytes_per_wave =
nir_pop_loop(&b, loop);
if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo))
- b.shader->scratch_size = 4; /* To enable scratch. */
+ b.shader->scratch_size = 0; /* Stack size is set by the application. */
else
b.shader->scratch_size += compute_rt_stack_size(pCreateInfo, stack_sizes);
struct ac_shader_config *config_out)
{
const struct radv_physical_device *pdevice = device->physical_device;
- bool scratch_enabled = config_in->scratch_bytes_per_wave > 0;
+ bool scratch_enabled = config_in->scratch_bytes_per_wave > 0 || info->cs.is_rt_shader;
bool trap_enabled = !!device->trap_handler_shader;
unsigned vgpr_comp_cnt = 0;
unsigned num_input_vgprs = args->ac.num_vgprs_used;