zink: add handling for CL-style discrete shader samplers
authorMike Blumenkrantz <michael.blumenkrantz@gmail.com>
Tue, 18 Oct 2022 15:45:16 +0000 (11:45 -0400)
committerMarge Bot <emma+marge@anholt.net>
Thu, 27 Oct 2022 22:01:34 +0000 (22:01 +0000)
this splits the bindings for sampler desc sets in CL like
* 32 samplers
* 128 samplerviews
* (compacted only) shader images

and then handles recombination during texop emission

it does NOT change the descriptor limits, which are still clamped to 32

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19327>

src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
src/gallium/drivers/zink/zink_compiler.c

index ad0d55b..ad3f48a 100644 (file)
@@ -992,7 +992,7 @@ static SpvId
 get_image_type(struct ntv_context *ctx, struct nir_variable *var, bool is_sampler)
 {
    SpvId image_type = get_bare_image_type(ctx, var, is_sampler);
-   return is_sampler ? spirv_builder_type_sampled_image(&ctx->builder, image_type) : image_type;
+   return is_sampler && ctx->stage != MESA_SHADER_KERNEL ? spirv_builder_type_sampled_image(&ctx->builder, image_type) : image_type;
 }
 
 static SpvId
@@ -1003,7 +1003,7 @@ emit_image(struct ntv_context *ctx, struct nir_variable *var, SpvId image_type,
    const struct glsl_type *type = glsl_without_array(var->type);
 
    bool is_sampler = glsl_type_is_sampler(type);
-   SpvId var_type = is_sampler ? spirv_builder_type_sampled_image(&ctx->builder, image_type) : image_type;
+   SpvId var_type = is_sampler && ctx->stage != MESA_SHADER_KERNEL ? spirv_builder_type_sampled_image(&ctx->builder, image_type) : image_type;
    bool mediump = (var->data.precision == GLSL_PRECISION_MEDIUM || var->data.precision == GLSL_PRECISION_LOW);
 
    int index = var->data.driver_location;
@@ -3733,7 +3733,14 @@ emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
       SpvId ptr = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassUniformConstant, sampled_type);
       sampler_id = spirv_builder_emit_access_chain(&ctx->builder, ptr, sampler_id, &tex_offset, 1);
    }
-   SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type, sampler_id);
+   SpvId load;
+   if (ctx->stage == MESA_SHADER_KERNEL) {
+      SpvId image_load = spirv_builder_emit_load(&ctx->builder, image_type, sampler_id);
+      SpvId sampler_load = spirv_builder_emit_load(&ctx->builder, spirv_builder_type_sampler(&ctx->builder), ctx->cl_samplers[tex->sampler_index]);
+      load = spirv_builder_emit_sampled_image(&ctx->builder, sampled_type, image_load, sampler_load);
+   } else {
+      load = spirv_builder_emit_load(&ctx->builder, sampled_type, sampler_id);
+   }
 
    if (tex->is_sparse)
       tex->dest.ssa.num_components--;
index 8e73b49..fd68970 100644 (file)
@@ -2798,17 +2798,33 @@ zink_binding(gl_shader_stage stage, VkDescriptorType type, int index, bool compa
       case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC:
          return base * 2 + !!index;
 
-      case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
+      case VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE:
+         assert(stage == MESA_SHADER_KERNEL);
+         FALLTHROUGH;
       case VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER:
+         if (stage == MESA_SHADER_KERNEL) {
+            assert(index < PIPE_MAX_SHADER_SAMPLER_VIEWS);
+            return index + PIPE_MAX_SAMPLERS;
+         }
+         FALLTHROUGH;
+      case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
          assert(index < PIPE_MAX_SAMPLERS);
+         assert(stage != MESA_SHADER_KERNEL);
          return (base * PIPE_MAX_SAMPLERS) + index;
 
+      case VK_DESCRIPTOR_TYPE_SAMPLER:
+         assert(index < PIPE_MAX_SAMPLERS);
+         assert(stage == MESA_SHADER_KERNEL);
+         return index;
+
       case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER:
          return base + (compact_descriptors * (ZINK_GFX_SHADER_COUNT * 2));
 
       case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:
       case VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER:
          assert(index < ZINK_MAX_SHADER_IMAGES);
+         if (stage == MESA_SHADER_KERNEL)
+            return index + (compact_descriptors ? (PIPE_MAX_SAMPLERS + PIPE_MAX_SHADER_SAMPLER_VIEWS) : 0);
          return (base * ZINK_MAX_SHADER_IMAGES) + index + (compact_descriptors * (ZINK_GFX_SHADER_COUNT * PIPE_MAX_SAMPLERS));
 
       default:
@@ -3521,6 +3537,15 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
    unsigned sampler_mask = 0;
    if (nir->info.stage == MESA_SHADER_KERNEL) {
       NIR_PASS_V(nir, type_images, &sampler_mask);
+      enum zink_descriptor_type ztype = ZINK_DESCRIPTOR_TYPE_SAMPLER_VIEW;
+      VkDescriptorType vktype = VK_DESCRIPTOR_TYPE_SAMPLER;
+      u_foreach_bit(s, sampler_mask) {
+         ret->bindings[ztype][ret->num_bindings[ztype]].index = s;
+         ret->bindings[ztype][ret->num_bindings[ztype]].binding = zink_binding(MESA_SHADER_KERNEL, vktype, s, screen->compact_descriptors);
+         ret->bindings[ztype][ret->num_bindings[ztype]].type = vktype;
+         ret->bindings[ztype][ret->num_bindings[ztype]].size = 1;
+         ret->num_bindings[ztype]++;
+      }
       ret->sinfo.sampler_mask = sampler_mask;
    }
 
@@ -3536,7 +3561,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
             /* buffer 0 is a push descriptor */
             var->data.descriptor_set = !!var->data.driver_location;
             var->data.binding = !var->data.driver_location ? clamp_stage(nir) :
-                                zink_binding(clamp_stage(nir),
+                                zink_binding(nir->info.stage,
                                              VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
                                              var->data.driver_location,
                                              screen->compact_descriptors);
@@ -3557,7 +3582,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
          } else if (var->data.mode == nir_var_mem_ssbo) {
             ztype = ZINK_DESCRIPTOR_TYPE_SSBO;
             var->data.descriptor_set = screen->desc_set_id[ztype];
-            var->data.binding = zink_binding(clamp_stage(nir),
+            var->data.binding = zink_binding(nir->info.stage,
                                              VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
                                              var->data.driver_location,
                                              screen->compact_descriptors);
@@ -3575,12 +3600,14 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
                handle_bindless_var(nir, var, type, &bindless);
             } else if (glsl_type_is_sampler(type) || glsl_type_is_image(type)) {
                VkDescriptorType vktype = glsl_type_is_image(type) ? zink_image_type(type) : zink_sampler_type(type);
+               if (nir->info.stage == MESA_SHADER_KERNEL && vktype == VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER)
+                  vktype = VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE;
                ztype = zink_desc_type_from_vktype(vktype);
                if (vktype == VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER)
                   ret->num_texel_buffers++;
                var->data.driver_location = var->data.binding;
                var->data.descriptor_set = screen->desc_set_id[ztype];
-               var->data.binding = zink_binding(clamp_stage(nir), vktype, var->data.driver_location, screen->compact_descriptors);
+               var->data.binding = zink_binding(nir->info.stage, vktype, var->data.driver_location, screen->compact_descriptors);
                ret->bindings[ztype][ret->num_bindings[ztype]].index = var->data.driver_location;
                ret->bindings[ztype][ret->num_bindings[ztype]].binding = var->data.binding;
                ret->bindings[ztype][ret->num_bindings[ztype]].type = vktype;