From 637063ffc63e0ff2f95c14b015c343cbe34a73e7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Marcin=20=C5=9Alusarz?= Date: Sat, 30 Apr 2022 13:07:57 +0200 Subject: [PATCH] anv: implement EXT_mesh_shader Reviewed-by: Caio Oliveira Part-of: --- src/intel/vulkan/anv_pipeline.c | 157 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) diff --git a/src/intel/vulkan/anv_pipeline.c b/src/intel/vulkan/anv_pipeline.c index 469236f..dac2a44 100644 --- a/src/intel/vulkan/anv_pipeline.c +++ b/src/intel/vulkan/anv_pipeline.c @@ -45,6 +45,137 @@ /* Needed for SWIZZLE macros */ #include "program/prog_instruction.h" +struct lower_mesh_ext_state { + nir_variable *primitive_count; + nir_variable *primitive_indices; +}; + +static bool +anv_nir_lower_mesh_ext_instr(nir_builder *b, nir_instr *instr, void *data) +{ + if (instr->type != nir_instr_type_intrinsic) + return false; + + nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); + struct lower_mesh_ext_state *state = data; + + switch (intrin->intrinsic) { + case nir_intrinsic_set_vertex_and_primitive_count: { + /* this intrinsic should show up only once */ + assert(state->primitive_count == NULL); + + state->primitive_count = + nir_variable_create(b->shader, + nir_var_shader_out, + glsl_uint_type(), + "gl_PrimitiveCountNV"); + state->primitive_count->data.location = VARYING_SLOT_PRIMITIVE_COUNT; + state->primitive_count->data.interpolation = INTERP_MODE_NONE; + + b->cursor = nir_before_instr(&intrin->instr); + + nir_ssa_def *local_invocation_index = nir_build_load_local_invocation_index(b); + + nir_ssa_def *cmp = nir_ieq(b, local_invocation_index, + nir_imm_int(b, 0)); + nir_if *if_stmt = nir_push_if(b, cmp); + { + nir_deref_instr *prim_count_deref = nir_build_deref_var(b, state->primitive_count); + nir_store_deref(b, prim_count_deref, intrin->src[1].ssa, 1); + } + nir_pop_if(b, if_stmt); + + nir_instr_remove(instr); + + return true; + } + + case nir_intrinsic_store_deref: { + /* Replace: + * gl_PrimitiveTriangleIndicesEXT[N] := vec3(X,Y,Z) + * by: + * gl_PrimitiveIndicesNV[N*3+0] := X + * gl_PrimitiveIndicesNV[N*3+1] := Y + * gl_PrimitiveIndicesNV[N*3+2] := Z + */ + + nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]); + if (deref->deref_type != nir_deref_type_array) + break; + + nir_deref_instr *deref2 = nir_src_as_deref(deref->parent); + if (deref2->deref_type != nir_deref_type_var) + break; + + const nir_variable *var = deref2->var; + if (var->data.location != VARYING_SLOT_PRIMITIVE_INDICES) + break; + + if (state->primitive_count == NULL) + assert(!"primitive count must be set before indices"); + + b->cursor = nir_before_instr(instr); + + if (!state->primitive_indices) { + const struct glsl_type *type = + glsl_array_type(glsl_uint_type(), + glsl_get_length(var->type), + 0); + + state->primitive_indices = + nir_variable_create(b->shader, + nir_var_shader_out, + type, + "gl_PrimitiveIndicesNV"); + state->primitive_indices->data.location = var->data.location; + state->primitive_indices->data.interpolation = var->data.interpolation; + } + + nir_deref_instr *primitive_indices_deref = + nir_build_deref_var(b, state->primitive_indices); + + assert(intrin->src[1].is_ssa); + uint8_t components = intrin->src[1].ssa->num_components; + + unsigned vertices_per_primitive = + num_mesh_vertices_per_primitive(b->shader->info.mesh.primitive_type); + assert(vertices_per_primitive == components); + assert(nir_intrinsic_write_mask(intrin) == (1u << components) - 1); + + nir_src ind = deref->arr.index; + assert(ind.is_ssa); + nir_ssa_def *new_base = nir_imul_imm(b, ind.ssa, components); + + for (unsigned i = 0; i < components; ++i) { + nir_ssa_def *new_idx = nir_iadd_imm(b, new_base, i); + + nir_deref_instr *reindexed_deref = + nir_build_deref_array(b, primitive_indices_deref, new_idx); + + nir_store_deref(b, reindexed_deref, nir_channel(b, intrin->src[1].ssa, i), 1); + } + + nir_instr_remove(instr); + + return true; + } + + default: + break; + } + return false; +} + +static bool +anv_nir_lower_mesh_ext(nir_shader *nir) +{ + struct lower_mesh_ext_state state = { NULL, }; + + return nir_shader_instructions_pass(nir, anv_nir_lower_mesh_ext_instr, + nir_metadata_none, + &state); +} + /* Eventually, this will become part of anv_CreateShader. Unfortunately, * we can't do that yet because we don't have the ability to copy nir. */ @@ -88,6 +219,7 @@ anv_shader_stage_to_nir(struct anv_device *device, .int64 = true, .int64_atomics = true, .integer_functions2 = true, + .mesh_shading = pdevice->vk.supported_extensions.EXT_mesh_shader, .mesh_shading_nv = pdevice->vk.supported_extensions.NV_mesh_shader, .min_lod = true, .multiview = true, @@ -167,6 +299,15 @@ anv_shader_stage_to_nir(struct anv_device *device, brw_preprocess_nir(compiler, nir, NULL); + if (nir->info.stage == MESA_SHADER_MESH && !nir->info.mesh.nv) { + bool progress = false; + NIR_PASS(progress, nir, anv_nir_lower_mesh_ext); + if (progress) { + NIR_PASS(_, nir, nir_opt_dce); + NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_shader_out, NULL); + } + } + return nir; } @@ -703,6 +844,22 @@ anv_pipeline_lower_nir(struct anv_pipeline *pipeline, }); } + if ((nir->info.stage == MESA_SHADER_MESH || + nir->info.stage == MESA_SHADER_TASK) && !nir->info.mesh.nv) { + /* We can't/shouldn't lower id to index for NV_mesh_shader, because: + * 3DMESH_1D doesn't expose registers needed for + * nir_intrinsic_load_num_workgroups (generated by this pass) + * and we can't unify NV with EXT, because 3DMESH_3D doesn't support + * vkCmdDrawMeshTasksNV.firstTask. + */ + nir_lower_compute_system_values_options options = { + .lower_cs_local_id_to_index = true, + .lower_workgroup_id_to_index = true, + }; + + NIR_PASS(_, nir, nir_lower_compute_system_values, &options); + } + NIR_PASS(_, nir, anv_nir_lower_ycbcr_textures, layout); if (pipeline->type == ANV_PIPELINE_GRAPHICS) { -- 2.7.4