anv: implement EXT_mesh_shader
authorMarcin Ślusarz <marcin.slusarz@intel.com>
Sat, 30 Apr 2022 11:07:57 +0000 (13:07 +0200)
committerMarge Bot <emma+marge@anholt.net>
Fri, 2 Sep 2022 17:40:47 +0000 (17:40 +0000)
Reviewed-by: Caio Oliveira <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/18371>

src/intel/vulkan/anv_pipeline.c

index 469236f..dac2a44 100644 (file)
 /* Needed for SWIZZLE macros */
 #include "program/prog_instruction.h"
 
+struct lower_mesh_ext_state {
+   nir_variable *primitive_count;
+   nir_variable *primitive_indices;
+};
+
+static bool
+anv_nir_lower_mesh_ext_instr(nir_builder *b, nir_instr *instr, void *data)
+{
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+
+   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+   struct lower_mesh_ext_state *state = data;
+
+   switch (intrin->intrinsic) {
+   case nir_intrinsic_set_vertex_and_primitive_count: {
+      /* this intrinsic should show up only once */
+      assert(state->primitive_count == NULL);
+
+      state->primitive_count =
+            nir_variable_create(b->shader,
+                                nir_var_shader_out,
+                                glsl_uint_type(),
+                                "gl_PrimitiveCountNV");
+      state->primitive_count->data.location = VARYING_SLOT_PRIMITIVE_COUNT;
+      state->primitive_count->data.interpolation = INTERP_MODE_NONE;
+
+      b->cursor = nir_before_instr(&intrin->instr);
+
+      nir_ssa_def *local_invocation_index = nir_build_load_local_invocation_index(b);
+
+      nir_ssa_def *cmp = nir_ieq(b, local_invocation_index,
+                                     nir_imm_int(b, 0));
+      nir_if *if_stmt = nir_push_if(b, cmp);
+      {
+         nir_deref_instr *prim_count_deref = nir_build_deref_var(b, state->primitive_count);
+         nir_store_deref(b, prim_count_deref, intrin->src[1].ssa, 1);
+      }
+      nir_pop_if(b, if_stmt);
+
+      nir_instr_remove(instr);
+
+      return true;
+   }
+
+   case nir_intrinsic_store_deref: {
+      /* Replace:
+       * gl_PrimitiveTriangleIndicesEXT[N] := vec3(X,Y,Z)
+       * by:
+       * gl_PrimitiveIndicesNV[N*3+0] := X
+       * gl_PrimitiveIndicesNV[N*3+1] := Y
+       * gl_PrimitiveIndicesNV[N*3+2] := Z
+       */
+
+      nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
+      if (deref->deref_type != nir_deref_type_array)
+         break;
+
+      nir_deref_instr *deref2 = nir_src_as_deref(deref->parent);
+      if (deref2->deref_type != nir_deref_type_var)
+         break;
+
+      const nir_variable *var = deref2->var;
+      if (var->data.location != VARYING_SLOT_PRIMITIVE_INDICES)
+         break;
+
+      if (state->primitive_count == NULL)
+         assert(!"primitive count must be set before indices");
+
+      b->cursor = nir_before_instr(instr);
+
+      if (!state->primitive_indices) {
+         const struct glsl_type *type =
+               glsl_array_type(glsl_uint_type(),
+                               glsl_get_length(var->type),
+                               0);
+
+         state->primitive_indices =
+               nir_variable_create(b->shader,
+                                   nir_var_shader_out,
+                                   type,
+                                   "gl_PrimitiveIndicesNV");
+         state->primitive_indices->data.location = var->data.location;
+         state->primitive_indices->data.interpolation = var->data.interpolation;
+      }
+
+      nir_deref_instr *primitive_indices_deref =
+            nir_build_deref_var(b, state->primitive_indices);
+
+      assert(intrin->src[1].is_ssa);
+      uint8_t components = intrin->src[1].ssa->num_components;
+
+      unsigned vertices_per_primitive =
+            num_mesh_vertices_per_primitive(b->shader->info.mesh.primitive_type);
+      assert(vertices_per_primitive == components);
+      assert(nir_intrinsic_write_mask(intrin) == (1u << components) - 1);
+
+      nir_src ind = deref->arr.index;
+      assert(ind.is_ssa);
+      nir_ssa_def *new_base = nir_imul_imm(b, ind.ssa, components);
+
+      for (unsigned i = 0; i < components; ++i) {
+         nir_ssa_def *new_idx = nir_iadd_imm(b, new_base, i);
+
+         nir_deref_instr *reindexed_deref =
+               nir_build_deref_array(b, primitive_indices_deref, new_idx);
+
+         nir_store_deref(b, reindexed_deref, nir_channel(b, intrin->src[1].ssa, i), 1);
+      }
+
+      nir_instr_remove(instr);
+
+      return true;
+   }
+
+   default:
+      break;
+   }
+   return false;
+}
+
+static bool
+anv_nir_lower_mesh_ext(nir_shader *nir)
+{
+   struct lower_mesh_ext_state state = { NULL, };
+
+   return nir_shader_instructions_pass(nir, anv_nir_lower_mesh_ext_instr,
+                                       nir_metadata_none,
+                                       &state);
+}
+
 /* Eventually, this will become part of anv_CreateShader.  Unfortunately,
  * we can't do that yet because we don't have the ability to copy nir.
  */
@@ -88,6 +219,7 @@ anv_shader_stage_to_nir(struct anv_device *device,
          .int64 = true,
          .int64_atomics = true,
          .integer_functions2 = true,
+         .mesh_shading = pdevice->vk.supported_extensions.EXT_mesh_shader,
          .mesh_shading_nv = pdevice->vk.supported_extensions.NV_mesh_shader,
          .min_lod = true,
          .multiview = true,
@@ -167,6 +299,15 @@ anv_shader_stage_to_nir(struct anv_device *device,
 
    brw_preprocess_nir(compiler, nir, NULL);
 
+   if (nir->info.stage == MESA_SHADER_MESH && !nir->info.mesh.nv) {
+      bool progress = false;
+      NIR_PASS(progress, nir, anv_nir_lower_mesh_ext);
+      if (progress) {
+         NIR_PASS(_, nir, nir_opt_dce);
+         NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_shader_out, NULL);
+      }
+   }
+
    return nir;
 }
 
@@ -703,6 +844,22 @@ anv_pipeline_lower_nir(struct anv_pipeline *pipeline,
                });
    }
 
+   if ((nir->info.stage == MESA_SHADER_MESH ||
+         nir->info.stage == MESA_SHADER_TASK) && !nir->info.mesh.nv) {
+      /* We can't/shouldn't lower id to index for NV_mesh_shader, because:
+       * 3DMESH_1D doesn't expose registers needed for
+       * nir_intrinsic_load_num_workgroups (generated by this pass)
+       * and we can't unify NV with EXT, because 3DMESH_3D doesn't support
+       * vkCmdDrawMeshTasksNV.firstTask.
+       */
+      nir_lower_compute_system_values_options options = {
+            .lower_cs_local_id_to_index = true,
+            .lower_workgroup_id_to_index = true,
+      };
+
+      NIR_PASS(_, nir, nir_lower_compute_system_values, &options);
+   }
+
    NIR_PASS(_, nir, anv_nir_lower_ycbcr_textures, layout);
 
    if (pipeline->type == ANV_PIPELINE_GRAPHICS) {