zink: handle COMPUTE setup in ntv
authorMike Blumenkrantz <michael.blumenkrantz@gmail.com>
Wed, 12 Aug 2020 19:54:48 +0000 (15:54 -0400)
committerMarge Bot <eric+marge@anholt.net>
Wed, 10 Feb 2021 00:19:38 +0000 (00:19 +0000)
addressing mode, shared block, and execution modes all need to be handled
here

Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/8781>

src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c

index ca65944..4cbb09a 100644 (file)
@@ -82,7 +82,8 @@ struct ntv_context {
          primitive_id_var, invocation_id_var, // geometry
          sample_mask_type, sample_id_var, sample_pos_var, sample_mask_in_var,
          tess_patch_vertices_in, tess_coord_var, // tess
-         push_const_var;
+         push_const_var,
+         shared_block_var;
 };
 
 static SpvId
@@ -361,6 +362,20 @@ get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
    unreachable("we shouldn't get here, I think...");
 }
 
+static void
+create_shared_block(struct ntv_context *ctx, unsigned shared_size)
+{
+   SpvId type = spirv_builder_type_uint(&ctx->builder, 32);
+   SpvId array = spirv_builder_type_array(&ctx->builder, type, emit_uint_const(ctx, 32, shared_size / 4));
+   spirv_builder_emit_array_stride(&ctx->builder, array, 4);
+   SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
+                                               SpvStorageClassWorkgroup,
+                                               array);
+   ctx->shared_block_var = spirv_builder_emit_var(&ctx->builder, ptr_type, SpvStorageClassWorkgroup);
+   assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
+   ctx->entry_ifaces[ctx->num_entry_ifaces++] = ctx->shared_block_var;
+}
+
 static inline unsigned char
 reserve_slot(struct ntv_context *ctx)
 {
@@ -3399,8 +3414,11 @@ nir_to_spirv(struct nir_shader *s, const struct zink_so_info *so_info,
    ctx.stage = s->info.stage;
    ctx.so_info = so_info;
    ctx.num_ssbos = s->info.num_ssbos;
-   ctx.shader_slot_map = shader_slot_map;
-   ctx.shader_slots_reserved = *shader_slots_reserved;
+   if (shader_slot_map) {
+      /* COMPUTE doesn't have this */
+      ctx.shader_slot_map = shader_slot_map;
+      ctx.shader_slots_reserved = *shader_slots_reserved;
+   }
    ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
    spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageUnknown, 0);
 
@@ -3411,6 +3429,16 @@ nir_to_spirv(struct nir_shader *s, const struct zink_so_info *so_info,
       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityVulkanMemoryModelDeviceScope);
       spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
                                    SpvMemoryModelVulkan);
+   } else if (s->info.stage == MESA_SHADER_COMPUTE) {
+      SpvAddressingModel model;
+      if (s->info.cs.ptr_size == 32)
+         model = SpvAddressingModelPhysical32;
+      else if (s->info.cs.ptr_size == 64)
+         model = SpvAddressingModelPhysical64;
+      else
+         model = SpvAddressingModelLogical;
+      spirv_builder_emit_mem_model(&ctx.builder, model,
+                                   SpvMemoryModelGLSL450);
    } else
       spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
                                    SpvMemoryModelGLSL450);
@@ -3512,6 +3540,14 @@ nir_to_spirv(struct nir_shader *s, const struct zink_so_info *so_info,
                                            SpvExecutionModeOutputVertices,
                                            s->info.gs.vertices_out);
       break;
+   case MESA_SHADER_COMPUTE:
+      if (s->info.cs.local_size[0] || s->info.cs.local_size[1] || s->info.cs.local_size[2])
+         spirv_builder_emit_exec_mode_literal3(&ctx.builder, entry_point, SpvExecutionModeLocalSize,
+                                               (uint32_t[3]){(uint32_t)s->info.cs.local_size[0], (uint32_t)s->info.cs.local_size[1],
+                                               (uint32_t)s->info.cs.local_size[2]});
+      if (s->info.cs.shared_size)
+         create_shared_block(&ctx, s->info.cs.shared_size);
+      break;
    default:
       break;
    }
@@ -3591,7 +3627,8 @@ nir_to_spirv(struct nir_shader *s, const struct zink_so_info *so_info,
    assert(ret->num_words == num_words);
 
    ralloc_free(ctx.mem_ctx);
-   *shader_slots_reserved = ctx.shader_slots_reserved;
+   if (shader_slots_reserved)
+      *shader_slots_reserved = ctx.shader_slots_reserved;
 
    return ret;