&ctx->ac, ctx->gs_ngg_scratch, LLVMConstInt(ctx->ac.i32, 12 + 8 * stream, false));
primemit_scan[stream].waveidx = get_wave_id_in_tg(ctx);
primemit_scan[stream].numwaves = get_tgsize(ctx);
- primemit_scan[stream].maxwaves = 8;
+ if (ctx->stage == MESA_SHADER_GEOMETRY) {
+ /* ngg_subgroup_size is only the input size. GS can always generate up to 256 vertices. */
+ primemit_scan[stream].maxwaves = DIV_ROUND_UP(256, ctx->ac.wave_size);
+ } else {
+ primemit_scan[stream].maxwaves = DIV_ROUND_UP(ctx->screen->ngg_subgroup_size,
+ ctx->ac.wave_size);
+ }
ac_build_wg_scan_top(&ctx->ac, &primemit_scan[stream]);
}
}
struct si_shader_selector *sel = shader->selector;
struct si_shader_info *info = &sel->info;
LLVMBuilderRef builder = ctx->ac.builder;
- unsigned subgroup_size = ctx->screen->ngg_subgroup_size;
- unsigned max_waves = ctx->ac.wave_size == 64 ? DIV_ROUND_UP(subgroup_size, 64) :
- DIV_ROUND_UP(subgroup_size, 32);
+ unsigned max_waves = DIV_ROUND_UP(ctx->screen->ngg_subgroup_size, ctx->ac.wave_size);
assert(shader->key.opt.ngg_culling);
assert(shader->key.as_ngg);
vertlive_scan.scratch = ac_build_gep0(&ctx->ac, ctx->gs_ngg_scratch, ctx->ac.i32_0);
vertlive_scan.waveidx = get_wave_id_in_tg(ctx);
vertlive_scan.numwaves = get_tgsize(ctx);
- vertlive_scan.maxwaves = 8;
+ vertlive_scan.maxwaves = DIV_ROUND_UP(256, ctx->ac.wave_size);
ac_build_wg_scan(&ctx->ac, &vertlive_scan);