From 037bbabcb968b8a911e90ce61c202c76d3cc7a67 Mon Sep 17 00:00:00 2001 From: Mike Blumenkrantz Date: Mon, 17 Oct 2022 10:11:08 -0400 Subject: [PATCH] zink: pass KERNEL shaders through successfully basically just merging with COMPUTE cases Part-of: --- src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c | 6 ++++-- src/gallium/drivers/zink/zink_compiler.c | 17 +++++++++-------- src/gallium/drivers/zink/zink_compiler.h | 5 +++++ src/gallium/drivers/zink/zink_descriptors.c | 5 +++-- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index a0a93ab..d640ed7 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -3335,7 +3335,7 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) break; case nir_intrinsic_control_barrier: - if (ctx->stage == MESA_SHADER_COMPUTE) + if (gl_shader_stage_is_compute(ctx->stage)) spirv_builder_emit_control_barrier(&ctx->builder, SpvScopeWorkgroup, SpvScopeWorkgroup, SpvMemorySemanticsWorkgroupMemoryMask | SpvMemorySemanticsAcquireReleaseMask); @@ -4428,7 +4428,7 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_ ctx.explicit_lod = true; spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageUnknown, 0); - if (s->info.stage == MESA_SHADER_COMPUTE) { + if (gl_shader_stage_is_compute(s->info.stage)) { SpvAddressingModel model; if (s->info.cs.ptr_size == 32) model = SpvAddressingModelPhysical32; @@ -4474,6 +4474,7 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_ exec_model = SpvExecutionModelFragment; break; case MESA_SHADER_COMPUTE: + case MESA_SHADER_KERNEL: exec_model = SpvExecutionModelGLCompute; break; default: @@ -4597,6 +4598,7 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_ SpvExecutionModeOutputVertices, MAX2(s->info.gs.vertices_out, 1)); break; + case MESA_SHADER_KERNEL: case MESA_SHADER_COMPUTE: if (s->info.workgroup_size[0] || s->info.workgroup_size[1] || s->info.workgroup_size[2]) spirv_builder_emit_exec_mode_literal3(&ctx.builder, entry_point, SpvExecutionModeLocalSize, diff --git a/src/gallium/drivers/zink/zink_compiler.c b/src/gallium/drivers/zink/zink_compiler.c index ddbe583..ddd3683 100644 --- a/src/gallium/drivers/zink/zink_compiler.c +++ b/src/gallium/drivers/zink/zink_compiler.c @@ -2245,7 +2245,7 @@ zink_shader_spirv_compile(struct zink_screen *screen, struct zink_shader *zs, st } nir_shader *nir = spirv_to_nir(spirv->words, spirv->num_words, spec_entries, num_spec_entries, - zs->nir->info.stage, "main", &spirv_options, &screen->nir_options); + clamp_stage(zs->nir), "main", &spirv_options, &screen->nir_options); assert(nir); ralloc_free(nir); free(spec_entries); @@ -2791,7 +2791,7 @@ zink_binding(gl_shader_stage stage, VkDescriptorType type, int index, bool compa } else { unsigned base = stage; /* clamp compute bindings for better driver efficiency */ - if (stage == MESA_SHADER_COMPUTE) + if (gl_shader_stage_is_compute(stage)) base = 0; switch (type) { case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER: @@ -3263,7 +3263,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir, subgroup_options.ballot_bit_size = 32; subgroup_options.ballot_components = 4; subgroup_options.lower_subgroup_masks = true; - if (!(screen->info.subgroup.supportedStages & mesa_to_vk_shader_stage(nir->info.stage))) { + if (!(screen->info.subgroup.supportedStages & mesa_to_vk_shader_stage(clamp_stage(nir)))) { subgroup_options.subgroup_size = 1; subgroup_options.lower_vote_trivial = true; } @@ -3325,8 +3325,8 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir, ztype = ZINK_DESCRIPTOR_TYPE_UBO; /* buffer 0 is a push descriptor */ var->data.descriptor_set = !!var->data.driver_location; - var->data.binding = !var->data.driver_location ? nir->info.stage : - zink_binding(nir->info.stage, + var->data.binding = !var->data.driver_location ? clamp_stage(nir) : + zink_binding(clamp_stage(nir), VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, var->data.driver_location, screen->compact_descriptors); @@ -3347,7 +3347,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir, } else if (var->data.mode == nir_var_mem_ssbo) { ztype = ZINK_DESCRIPTOR_TYPE_SSBO; var->data.descriptor_set = screen->desc_set_id[ztype]; - var->data.binding = zink_binding(nir->info.stage, + var->data.binding = zink_binding(clamp_stage(nir), VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, var->data.driver_location, screen->compact_descriptors); @@ -3370,7 +3370,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir, ret->num_texel_buffers++; var->data.driver_location = var->data.binding; var->data.descriptor_set = screen->desc_set_id[ztype]; - var->data.binding = zink_binding(nir->info.stage, vktype, var->data.driver_location, screen->compact_descriptors); + var->data.binding = zink_binding(clamp_stage(nir), vktype, var->data.driver_location, screen->compact_descriptors); ret->bindings[ztype][ret->num_bindings[ztype]].index = var->data.driver_location; ret->bindings[ztype][ret->num_bindings[ztype]].binding = var->data.binding; ret->bindings[ztype][ret->num_bindings[ztype]].type = vktype; @@ -3389,7 +3389,8 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir, if (!screen->info.feats.features.shaderInt64 || !screen->info.feats.features.shaderFloat64) NIR_PASS_V(nir, lower_64bit_vars, screen->info.feats.features.shaderInt64); - NIR_PASS_V(nir, match_tex_dests); + if (nir->info.stage != MESA_SHADER_KERNEL) + NIR_PASS_V(nir, match_tex_dests); ret->nir = nir; nir_foreach_shader_out_variable(var, nir) diff --git a/src/gallium/drivers/zink/zink_compiler.h b/src/gallium/drivers/zink/zink_compiler.h index 21f6bab..1572aa3 100644 --- a/src/gallium/drivers/zink/zink_compiler.h +++ b/src/gallium/drivers/zink/zink_compiler.h @@ -40,6 +40,11 @@ struct spirv_shader; struct tgsi_token; +static inline gl_shader_stage +clamp_stage(nir_shader *nir) +{ + return nir->info.stage == MESA_SHADER_KERNEL ? MESA_SHADER_COMPUTE : nir->info.stage; +} const void * zink_get_compiler_options(struct pipe_screen *screen, diff --git a/src/gallium/drivers/zink/zink_descriptors.c b/src/gallium/drivers/zink/zink_descriptors.c index e9caa26..00b0746 100644 --- a/src/gallium/drivers/zink/zink_descriptors.c +++ b/src/gallium/drivers/zink/zink_descriptors.c @@ -26,6 +26,7 @@ */ #include "zink_context.h" +#include "zink_compiler.h" #include "zink_descriptors.h" #include "zink_program.h" #include "zink_render_pass.h" @@ -308,7 +309,7 @@ init_template_entry(struct zink_shader *shader, enum zink_descriptor_type type, unsigned idx, VkDescriptorUpdateTemplateEntry *entry, unsigned *entry_idx) { int index = shader->bindings[type][idx].index; - gl_shader_stage stage = shader->nir->info.stage; + gl_shader_stage stage = clamp_stage(shader->nir); entry->dstArrayElement = 0; entry->dstBinding = shader->bindings[type][idx].binding; entry->descriptorCount = shader->bindings[type][idx].size; @@ -423,7 +424,7 @@ zink_descriptor_program_init(struct zink_context *ctx, struct zink_program *pg) if (!shader) continue; - gl_shader_stage stage = shader->nir->info.stage; + gl_shader_stage stage = clamp_stage(shader->nir); VkShaderStageFlagBits stage_flags = mesa_to_vk_shader_stage(stage); /* uniform ubos handled in push */ if (shader->has_uniforms) { -- 2.7.4