zink: variable shared mem support
author: Karol Herbst <git@karolherbst.de>
Tue, 19 Sep 2023 12:44:26 +0000 (14:44 +0200)
committer: Marge Bot <emma+marge@anholt.net>
Sat, 14 Oct 2023 01:01:16 +0000 (01:01 +0000)
Signed-off-by: Karol Herbst <kherbst@redhat.com>
Reviewed-by: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24839>

src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
src/gallium/drivers/zink/zink_compiler.h
src/gallium/drivers/zink/zink_pipeline.c
src/gallium/drivers/zink/zink_program.c
src/gallium/drivers/zink/zink_types.h

index 37059f4..d1f26cb 100644 (file)
@@ -101,6 +101,8 @@ struct ntv_context {
          local_group_size_var,
          base_vertex_var, base_instance_var, draw_id_var;
 
+   SpvId shared_mem_size;
+
    SpvId subgroup_eq_mask_var,
          subgroup_ge_mask_var,
          subgroup_gt_mask_var,
@@ -663,13 +665,25 @@ get_scratch_block(struct ntv_context *ctx, unsigned bit_size)
 }
 
 static void
-create_shared_block(struct ntv_context *ctx, unsigned shared_size, unsigned bit_size)
+create_shared_block(struct ntv_context *ctx, unsigned bit_size)
 {
    unsigned idx = bit_size >> 4;
    SpvId type = spirv_builder_type_uint(&ctx->builder, bit_size);
-   unsigned block_size = shared_size / (bit_size / 8);
-   assert(block_size);
-   SpvId array = spirv_builder_type_array(&ctx->builder, type, emit_uint_const(ctx, 32, block_size));
+   SpvId array;
+
+   assert(gl_shader_stage_is_compute(ctx->nir->info.stage));
+   if (ctx->nir->info.cs.has_variable_shared_mem) {
+      assert(ctx->shared_mem_size);
+      SpvId const_shared_size = emit_uint_const(ctx, 32, ctx->nir->info.shared_size);
+      SpvId shared_mem_size = spirv_builder_emit_triop(&ctx->builder, SpvOpSpecConstantOp, spirv_builder_type_uint(&ctx->builder, 32), SpvOpIAdd, const_shared_size, ctx->shared_mem_size);
+      shared_mem_size = spirv_builder_emit_triop(&ctx->builder, SpvOpSpecConstantOp, spirv_builder_type_uint(&ctx->builder, 32), SpvOpUDiv, shared_mem_size, emit_uint_const(ctx, 32, bit_size / 8));
+      array = spirv_builder_type_array(&ctx->builder, type, shared_mem_size);
+   } else {
+      unsigned block_size = ctx->nir->info.shared_size / (bit_size / 8);
+      assert(block_size);
+      array = spirv_builder_type_array(&ctx->builder, type, emit_uint_const(ctx, 32, block_size));
+   }
+
    spirv_builder_emit_array_stride(&ctx->builder, array, bit_size / 8);
    SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
                                                SpvStorageClassWorkgroup,
@@ -686,7 +700,7 @@ get_shared_block(struct ntv_context *ctx, unsigned bit_size)
 {
    unsigned idx = bit_size >> 4;
    if (!ctx->shared_block_var[idx])
-      create_shared_block(ctx, ctx->nir->info.shared_size, bit_size);
+      create_shared_block(ctx, bit_size);
    if (ctx->sinfo->have_workgroup_memory_explicit_layout) {
       spirv_builder_emit_extension(&ctx->builder, "SPV_KHR_workgroup_memory_explicit_layout");
       spirv_builder_emit_cap(&ctx->builder, SpvCapabilityWorkgroupMemoryExplicitLayoutKHR);
@@ -4591,6 +4605,11 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_
             spirv_builder_emit_builtin(&ctx.builder, ctx.local_group_size_var, SpvBuiltInWorkgroupSize);
          }
       }
+      if (s->info.cs.has_variable_shared_mem) {
+         ctx.shared_mem_size = spirv_builder_spec_const_uint(&ctx.builder, 32);
+         spirv_builder_emit_specid(&ctx.builder, ctx.shared_mem_size, ZINK_VARIABLE_SHARED_MEM);
+         spirv_builder_emit_name(&ctx.builder, ctx.shared_mem_size, "variable_shared_mem");
+      }
       if (s->info.cs.derivative_group) {
          SpvCapability caps[] = { 0, SpvCapabilityComputeDerivativeGroupQuadsNV, SpvCapabilityComputeDerivativeGroupLinearNV };
          SpvExecutionMode modes[] = { 0, SpvExecutionModeDerivativeGroupQuadsNV, SpvExecutionModeDerivativeGroupLinearNV };
index 834084d..1319193 100644 (file)
@@ -29,6 +29,7 @@
 #define ZINK_WORKGROUP_SIZE_X 1
 #define ZINK_WORKGROUP_SIZE_Y 2
 #define ZINK_WORKGROUP_SIZE_Z 3
+#define ZINK_VARIABLE_SHARED_MEM 4
 #define ZINK_INLINE_VAL_FLAT_MASK 0
 #define ZINK_INLINE_VAL_PV_LAST_VERT 1
 
index cceb54c..063fdd9 100644 (file)
@@ -457,8 +457,8 @@ zink_create_compute_pipeline(struct zink_screen *screen, struct zink_compute_pro
    stage.pName = "main";
 
    VkSpecializationInfo sinfo = {0};
-   VkSpecializationMapEntry me[3];
-   uint32_t data[3];
+   VkSpecializationMapEntry me[4];
+   uint32_t data[4];
    if (state)  {
       int i = 0;
 
@@ -475,6 +475,16 @@ zink_create_compute_pipeline(struct zink_screen *screen, struct zink_compute_pro
          }
       }
 
+      if (comp->has_variable_shared_mem) {
+         sinfo.mapEntryCount += 1;
+         sinfo.dataSize += sizeof(uint32_t);
+         data[i] = state->variable_shared_mem;
+         me[i].size = sizeof(uint32_t);
+         me[i].constantID = ZINK_VARIABLE_SHARED_MEM;
+         me[i].offset = i * sizeof(uint32_t);
+         i++;
+      }
+
       if (sinfo.dataSize) {
          stage.pSpecializationInfo = &sinfo;
          sinfo.pData = data;
index 9683586..796cc25 100644 (file)
@@ -1304,6 +1304,10 @@ zink_program_update_compute_pipeline_state(struct zink_context *ctx, struct zink
          ctx->compute_pipeline_state.local_size[i] = info->block[i];
       }
    }
+   if (ctx->compute_pipeline_state.variable_shared_mem != info->variable_shared_mem) {
+      ctx->compute_pipeline_state.dirty = true;
+      ctx->compute_pipeline_state.variable_shared_mem = info->variable_shared_mem;
+   }
 }
 
 static bool
index 0dbb4d3..7a2c79c 100644 (file)
@@ -937,6 +937,7 @@ struct zink_compute_pipeline_state {
    uint32_t final_hash;
    bool dirty;
    uint32_t local_size[3];
+   uint32_t variable_shared_mem;
 
    uint32_t module_hash;
    VkShaderModule module;