From d7a1916798bb8e0331223c3de9c7398d46c16bc8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Marcin=20=C5=9Alusarz?= Date: Mon, 12 Dec 2022 14:28:05 +0100 Subject: [PATCH] anv: handle mesh shaders with max primitives == 0 Reviewed-by: Caio Oliveira Part-of: --- src/intel/vulkan/anv_pipeline.c | 86 +++++++++++++++++++++++++--------------- src/intel/vulkan/genX_pipeline.c | 2 +- 2 files changed, 54 insertions(+), 34 deletions(-) diff --git a/src/intel/vulkan/anv_pipeline.c b/src/intel/vulkan/anv_pipeline.c index 70ba8fe..c8b5dc8 100644 --- a/src/intel/vulkan/anv_pipeline.c +++ b/src/intel/vulkan/anv_pipeline.c @@ -45,34 +45,20 @@ /* Needed for SWIZZLE macros */ #include "program/prog_instruction.h" -struct lower_mesh_ext_state { +struct lower_set_vtx_and_prim_count_state { nir_variable *primitive_count; - nir_variable *primitive_indices; }; -static bool -anv_nir_lower_mesh_ext_instr(nir_builder *b, nir_instr *instr, void *data) +static nir_variable * +anv_nir_prim_count_store(nir_builder *b, nir_ssa_def *val) { - if (instr->type != nir_instr_type_intrinsic) - return false; - - nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); - if (intrin->intrinsic != nir_intrinsic_set_vertex_and_primitive_count) - return false; - - struct lower_mesh_ext_state *state = data; - /* this intrinsic should show up only once */ - assert(state->primitive_count == NULL); - - state->primitive_count = + nir_variable *primitive_count = nir_variable_create(b->shader, nir_var_shader_out, glsl_uint_type(), "gl_PrimitiveCountNV"); - state->primitive_count->data.location = VARYING_SLOT_PRIMITIVE_COUNT; - state->primitive_count->data.interpolation = INTERP_MODE_NONE; - - b->cursor = nir_before_instr(&intrin->instr); + primitive_count->data.location = VARYING_SLOT_PRIMITIVE_COUNT; + primitive_count->data.interpolation = INTERP_MODE_NONE; nir_ssa_def *local_invocation_index = nir_build_load_local_invocation_index(b); @@ -80,24 +66,61 @@ anv_nir_lower_mesh_ext_instr(nir_builder *b, nir_instr *instr, void *data) nir_imm_int(b, 0)); nir_if *if_stmt = nir_push_if(b, cmp); { - nir_deref_instr *prim_count_deref = nir_build_deref_var(b, state->primitive_count); - nir_store_deref(b, prim_count_deref, intrin->src[1].ssa, 1); + nir_deref_instr *prim_count_deref = nir_build_deref_var(b, primitive_count); + nir_store_deref(b, prim_count_deref, val, 1); } nir_pop_if(b, if_stmt); + return primitive_count; +} + +static bool +anv_nir_lower_set_vtx_and_prim_count_instr(nir_builder *b, nir_instr *instr, void *data) +{ + if (instr->type != nir_instr_type_intrinsic) + return false; + + nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); + if (intrin->intrinsic != nir_intrinsic_set_vertex_and_primitive_count) + return false; + + struct lower_set_vtx_and_prim_count_state *state = data; + /* this intrinsic should show up only once */ + assert(state->primitive_count == NULL); + + b->cursor = nir_before_instr(&intrin->instr); + + state->primitive_count = anv_nir_prim_count_store(b, intrin->src[1].ssa); + nir_instr_remove(instr); return true; } static bool -anv_nir_lower_mesh_ext(nir_shader *nir) +anv_nir_lower_set_vtx_and_prim_count(nir_shader *nir) { - struct lower_mesh_ext_state state = { NULL, }; + struct lower_set_vtx_and_prim_count_state state = { NULL, }; - return nir_shader_instructions_pass(nir, anv_nir_lower_mesh_ext_instr, - nir_metadata_none, - &state); + nir_shader_instructions_pass(nir, + anv_nir_lower_set_vtx_and_prim_count_instr, + nir_metadata_none, + &state); + + /* If we didn't find set_vertex_and_primitive_count, then we have to + * insert store of value 0 to primitive_count. + */ + if (state.primitive_count == NULL) { + nir_builder b; + nir_function_impl *entrypoint = nir_shader_get_entrypoint(nir); + nir_builder_init(&b, entrypoint); + b.cursor = nir_before_block(nir_start_block(entrypoint)); + nir_ssa_def *zero = nir_imm_int(&b, 0); + state.primitive_count = anv_nir_prim_count_store(&b, zero); + } + + assert(state.primitive_count != NULL); + return true; } /* Eventually, this will become part of anv_CreateShader. Unfortunately, @@ -231,12 +254,9 @@ anv_shader_stage_to_nir(struct anv_device *device, brw_preprocess_nir(compiler, nir, &opts); if (nir->info.stage == MESA_SHADER_MESH && !nir->info.mesh.nv) { - bool progress = false; - NIR_PASS(progress, nir, anv_nir_lower_mesh_ext); - if (progress) { - NIR_PASS(_, nir, nir_opt_dce); - NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_shader_out, NULL); - } + NIR_PASS(_, nir, anv_nir_lower_set_vtx_and_prim_count); + NIR_PASS(_, nir, nir_opt_dce); + NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_shader_out, NULL); } return nir; diff --git a/src/intel/vulkan/genX_pipeline.c b/src/intel/vulkan/genX_pipeline.c index 5e530a5..2ed9193 100644 --- a/src/intel/vulkan/genX_pipeline.c +++ b/src/intel/vulkan/genX_pipeline.c @@ -1806,7 +1806,7 @@ emit_mesh_state(struct anv_graphics_pipeline *pipeline) mesh.LocalXMaximum = mesh_dispatch.group_size - 1; mesh.EmitLocalIDX = true; - mesh.MaximumPrimitiveCount = mesh_prog_data->map.max_primitives - 1; + mesh.MaximumPrimitiveCount = MAX2(mesh_prog_data->map.max_primitives, 1) - 1; mesh.OutputTopology = output_topology; mesh.PerVertexDataPitch = mesh_prog_data->map.per_vertex_pitch_dw / 8; mesh.PerPrimitiveDataPresent = mesh_prog_data->map.per_primitive_pitch_dw > 0; -- 2.7.4