anv: handle mesh shaders with max primitives == 0
authorMarcin Ślusarz <marcin.slusarz@intel.com>
Mon, 12 Dec 2022 13:28:05 +0000 (14:28 +0100)
committerMarge Bot <emma+marge@anholt.net>
Wed, 14 Dec 2022 09:55:10 +0000 (09:55 +0000)
Reviewed-by: Caio Oliveira <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/20279>

src/intel/vulkan/anv_pipeline.c
src/intel/vulkan/genX_pipeline.c

index 70ba8fe..c8b5dc8 100644 (file)
 /* Needed for SWIZZLE macros */
 #include "program/prog_instruction.h"
 
-struct lower_mesh_ext_state {
+struct lower_set_vtx_and_prim_count_state {
    nir_variable *primitive_count;
-   nir_variable *primitive_indices;
 };
 
-static bool
-anv_nir_lower_mesh_ext_instr(nir_builder *b, nir_instr *instr, void *data)
+static nir_variable *
+anv_nir_prim_count_store(nir_builder *b, nir_ssa_def *val)
 {
-   if (instr->type != nir_instr_type_intrinsic)
-      return false;
-
-   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
-   if (intrin->intrinsic != nir_intrinsic_set_vertex_and_primitive_count)
-      return false;
-
-   struct lower_mesh_ext_state *state = data;
-   /* this intrinsic should show up only once */
-   assert(state->primitive_count == NULL);
-
-   state->primitive_count =
+   nir_variable *primitive_count =
          nir_variable_create(b->shader,
                              nir_var_shader_out,
                              glsl_uint_type(),
                              "gl_PrimitiveCountNV");
-   state->primitive_count->data.location = VARYING_SLOT_PRIMITIVE_COUNT;
-   state->primitive_count->data.interpolation = INTERP_MODE_NONE;
-
-   b->cursor = nir_before_instr(&intrin->instr);
+   primitive_count->data.location = VARYING_SLOT_PRIMITIVE_COUNT;
+   primitive_count->data.interpolation = INTERP_MODE_NONE;
 
    nir_ssa_def *local_invocation_index = nir_build_load_local_invocation_index(b);
 
@@ -80,24 +66,61 @@ anv_nir_lower_mesh_ext_instr(nir_builder *b, nir_instr *instr, void *data)
                                   nir_imm_int(b, 0));
    nir_if *if_stmt = nir_push_if(b, cmp);
    {
-      nir_deref_instr *prim_count_deref = nir_build_deref_var(b, state->primitive_count);
-      nir_store_deref(b, prim_count_deref, intrin->src[1].ssa, 1);
+      nir_deref_instr *prim_count_deref = nir_build_deref_var(b, primitive_count);
+      nir_store_deref(b, prim_count_deref, val, 1);
    }
    nir_pop_if(b, if_stmt);
 
+   return primitive_count;
+}
+
+static bool
+anv_nir_lower_set_vtx_and_prim_count_instr(nir_builder *b, nir_instr *instr, void *data)
+{
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+
+   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+   if (intrin->intrinsic != nir_intrinsic_set_vertex_and_primitive_count)
+      return false;
+
+   struct lower_set_vtx_and_prim_count_state *state = data;
+   /* this intrinsic should show up only once */
+   assert(state->primitive_count == NULL);
+
+   b->cursor = nir_before_instr(&intrin->instr);
+
+   state->primitive_count = anv_nir_prim_count_store(b, intrin->src[1].ssa);
+
    nir_instr_remove(instr);
 
    return true;
 }
 
 static bool
-anv_nir_lower_mesh_ext(nir_shader *nir)
+anv_nir_lower_set_vtx_and_prim_count(nir_shader *nir)
 {
-   struct lower_mesh_ext_state state = { NULL, };
+   struct lower_set_vtx_and_prim_count_state state = { NULL, };
 
-   return nir_shader_instructions_pass(nir, anv_nir_lower_mesh_ext_instr,
-                                       nir_metadata_none,
-                                       &state);
+   nir_shader_instructions_pass(nir,
+                                anv_nir_lower_set_vtx_and_prim_count_instr,
+                                nir_metadata_none,
+                                &state);
+
+   /* If we didn't find set_vertex_and_primitive_count, then we have to
+    * insert store of value 0 to primitive_count.
+    */
+   if (state.primitive_count == NULL) {
+      nir_builder b;
+      nir_function_impl *entrypoint = nir_shader_get_entrypoint(nir);
+      nir_builder_init(&b, entrypoint);
+      b.cursor = nir_before_block(nir_start_block(entrypoint));
+      nir_ssa_def *zero = nir_imm_int(&b, 0);
+      state.primitive_count = anv_nir_prim_count_store(&b, zero);
+   }
+
+   assert(state.primitive_count != NULL);
+   return true;
 }
 
 /* Eventually, this will become part of anv_CreateShader.  Unfortunately,
@@ -231,12 +254,9 @@ anv_shader_stage_to_nir(struct anv_device *device,
    brw_preprocess_nir(compiler, nir, &opts);
 
    if (nir->info.stage == MESA_SHADER_MESH && !nir->info.mesh.nv) {
-      bool progress = false;
-      NIR_PASS(progress, nir, anv_nir_lower_mesh_ext);
-      if (progress) {
-         NIR_PASS(_, nir, nir_opt_dce);
-         NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_shader_out, NULL);
-      }
+      NIR_PASS(_, nir, anv_nir_lower_set_vtx_and_prim_count);
+      NIR_PASS(_, nir, nir_opt_dce);
+      NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_shader_out, NULL);
    }
 
    return nir;
index 5e530a5..2ed9193 100644 (file)
@@ -1806,7 +1806,7 @@ emit_mesh_state(struct anv_graphics_pipeline *pipeline)
       mesh.LocalXMaximum                     = mesh_dispatch.group_size - 1;
       mesh.EmitLocalIDX                      = true;
 
-      mesh.MaximumPrimitiveCount             = mesh_prog_data->map.max_primitives - 1;
+      mesh.MaximumPrimitiveCount             = MAX2(mesh_prog_data->map.max_primitives, 1) - 1;
       mesh.OutputTopology                    = output_topology;
       mesh.PerVertexDataPitch                = mesh_prog_data->map.per_vertex_pitch_dw / 8;
       mesh.PerPrimitiveDataPresent           = mesh_prog_data->map.per_primitive_pitch_dw > 0;