zink: pass KERNEL shaders through successfully
authorMike Blumenkrantz <michael.blumenkrantz@gmail.com>
Mon, 17 Oct 2022 14:11:08 +0000 (10:11 -0400)
committerMarge Bot <emma+marge@anholt.net>
Thu, 27 Oct 2022 22:01:34 +0000 (22:01 +0000)
basically just merging with COMPUTE cases

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19327>

src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
src/gallium/drivers/zink/zink_compiler.c
src/gallium/drivers/zink/zink_compiler.h
src/gallium/drivers/zink/zink_descriptors.c

index a0a93ab..d640ed7 100644 (file)
@@ -3335,7 +3335,7 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
       break;
 
    case nir_intrinsic_control_barrier:
-      if (ctx->stage == MESA_SHADER_COMPUTE)
+      if (gl_shader_stage_is_compute(ctx->stage))
          spirv_builder_emit_control_barrier(&ctx->builder, SpvScopeWorkgroup,
                                             SpvScopeWorkgroup,
                                             SpvMemorySemanticsWorkgroupMemoryMask | SpvMemorySemanticsAcquireReleaseMask);
@@ -4428,7 +4428,7 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_
    ctx.explicit_lod = true;
    spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageUnknown, 0);
 
-   if (s->info.stage == MESA_SHADER_COMPUTE) {
+   if (gl_shader_stage_is_compute(s->info.stage)) {
       SpvAddressingModel model;
       if (s->info.cs.ptr_size == 32)
          model = SpvAddressingModelPhysical32;
@@ -4474,6 +4474,7 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_
       exec_model = SpvExecutionModelFragment;
       break;
    case MESA_SHADER_COMPUTE:
+   case MESA_SHADER_KERNEL:
       exec_model = SpvExecutionModelGLCompute;
       break;
    default:
@@ -4597,6 +4598,7 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_
                                            SpvExecutionModeOutputVertices,
                                            MAX2(s->info.gs.vertices_out, 1));
       break;
+   case MESA_SHADER_KERNEL:
    case MESA_SHADER_COMPUTE:
       if (s->info.workgroup_size[0] || s->info.workgroup_size[1] || s->info.workgroup_size[2])
          spirv_builder_emit_exec_mode_literal3(&ctx.builder, entry_point, SpvExecutionModeLocalSize,
index ddbe583..ddd3683 100644 (file)
@@ -2245,7 +2245,7 @@ zink_shader_spirv_compile(struct zink_screen *screen, struct zink_shader *zs, st
       }
       nir_shader *nir = spirv_to_nir(spirv->words, spirv->num_words,
                          spec_entries, num_spec_entries,
-                         zs->nir->info.stage, "main", &spirv_options, &screen->nir_options);
+                         clamp_stage(zs->nir), "main", &spirv_options, &screen->nir_options);
       assert(nir);
       ralloc_free(nir);
       free(spec_entries);
@@ -2791,7 +2791,7 @@ zink_binding(gl_shader_stage stage, VkDescriptorType type, int index, bool compa
    } else {
       unsigned base = stage;
       /* clamp compute bindings for better driver efficiency */
-      if (stage == MESA_SHADER_COMPUTE)
+      if (gl_shader_stage_is_compute(stage))
          base = 0;
       switch (type) {
       case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
@@ -3263,7 +3263,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
       subgroup_options.ballot_bit_size = 32;
       subgroup_options.ballot_components = 4;
       subgroup_options.lower_subgroup_masks = true;
-      if (!(screen->info.subgroup.supportedStages & mesa_to_vk_shader_stage(nir->info.stage))) {
+      if (!(screen->info.subgroup.supportedStages & mesa_to_vk_shader_stage(clamp_stage(nir)))) {
          subgroup_options.subgroup_size = 1;
          subgroup_options.lower_vote_trivial = true;
       }
@@ -3325,8 +3325,8 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
             ztype = ZINK_DESCRIPTOR_TYPE_UBO;
             /* buffer 0 is a push descriptor */
             var->data.descriptor_set = !!var->data.driver_location;
-            var->data.binding = !var->data.driver_location ? nir->info.stage :
-                                zink_binding(nir->info.stage,
+            var->data.binding = !var->data.driver_location ? clamp_stage(nir) :
+                                zink_binding(clamp_stage(nir),
                                              VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
                                              var->data.driver_location,
                                              screen->compact_descriptors);
@@ -3347,7 +3347,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
          } else if (var->data.mode == nir_var_mem_ssbo) {
             ztype = ZINK_DESCRIPTOR_TYPE_SSBO;
             var->data.descriptor_set = screen->desc_set_id[ztype];
-            var->data.binding = zink_binding(nir->info.stage,
+            var->data.binding = zink_binding(clamp_stage(nir),
                                              VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
                                              var->data.driver_location,
                                              screen->compact_descriptors);
@@ -3370,7 +3370,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
                   ret->num_texel_buffers++;
                var->data.driver_location = var->data.binding;
                var->data.descriptor_set = screen->desc_set_id[ztype];
-               var->data.binding = zink_binding(nir->info.stage, vktype, var->data.driver_location, screen->compact_descriptors);
+               var->data.binding = zink_binding(clamp_stage(nir), vktype, var->data.driver_location, screen->compact_descriptors);
                ret->bindings[ztype][ret->num_bindings[ztype]].index = var->data.driver_location;
                ret->bindings[ztype][ret->num_bindings[ztype]].binding = var->data.binding;
                ret->bindings[ztype][ret->num_bindings[ztype]].type = vktype;
@@ -3389,7 +3389,8 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
 
    if (!screen->info.feats.features.shaderInt64 || !screen->info.feats.features.shaderFloat64)
       NIR_PASS_V(nir, lower_64bit_vars, screen->info.feats.features.shaderInt64);
-   NIR_PASS_V(nir, match_tex_dests);
+   if (nir->info.stage != MESA_SHADER_KERNEL)
+      NIR_PASS_V(nir, match_tex_dests);
 
    ret->nir = nir;
    nir_foreach_shader_out_variable(var, nir)
index 21f6bab..1572aa3 100644 (file)
@@ -40,6 +40,11 @@ struct spirv_shader;
 
 struct tgsi_token;
 
+static inline gl_shader_stage
+clamp_stage(nir_shader *nir)
+{
+   return nir->info.stage == MESA_SHADER_KERNEL ? MESA_SHADER_COMPUTE : nir->info.stage;
+}
 
 const void *
 zink_get_compiler_options(struct pipe_screen *screen,
index e9caa26..00b0746 100644 (file)
@@ -26,6 +26,7 @@
  */
 
 #include "zink_context.h"
+#include "zink_compiler.h"
 #include "zink_descriptors.h"
 #include "zink_program.h"
 #include "zink_render_pass.h"
@@ -308,7 +309,7 @@ init_template_entry(struct zink_shader *shader, enum zink_descriptor_type type,
                     unsigned idx, VkDescriptorUpdateTemplateEntry *entry, unsigned *entry_idx)
 {
     int index = shader->bindings[type][idx].index;
-    gl_shader_stage stage = shader->nir->info.stage;
+    gl_shader_stage stage = clamp_stage(shader->nir);
     entry->dstArrayElement = 0;
     entry->dstBinding = shader->bindings[type][idx].binding;
     entry->descriptorCount = shader->bindings[type][idx].size;
@@ -423,7 +424,7 @@ zink_descriptor_program_init(struct zink_context *ctx, struct zink_program *pg)
       if (!shader)
          continue;
 
-      gl_shader_stage stage = shader->nir->info.stage;
+      gl_shader_stage stage = clamp_stage(shader->nir);
       VkShaderStageFlagBits stage_flags = mesa_to_vk_shader_stage(stage);
       /* uniform ubos handled in push */
       if (shader->has_uniforms) {