zink: handle 8/16bit ssbo storage
authorMike Blumenkrantz <michael.blumenkrantz@gmail.com>
Tue, 20 Jul 2021 20:40:59 +0000 (16:40 -0400)
committerMarge Bot <eric+marge@anholt.net>
Tue, 7 Sep 2021 13:29:57 +0000 (13:29 +0000)
this is a bit gross, but basically just add an array of extra spvids
so that each bitsize can have its own variables to keep the types in sync

glsl can't do this, but (future) internal mesa shaders can

Reviewed-by: Dave Airlie <airlied@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12634>

src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c

index 2602556..a17c97d 100644 (file)
@@ -51,7 +51,7 @@ struct ntv_context {
 
    SpvId ubos[128];
 
-   SpvId ssbos[PIPE_MAX_SHADER_BUFFERS];
+   SpvId ssbos[PIPE_MAX_SHADER_BUFFERS][3]; //8, 16, 32
    nir_variable *ssbo_vars[PIPE_MAX_SHADER_BUFFERS];
    SpvId image_types[PIPE_MAX_SAMPLERS];
    SpvId images[PIPE_MAX_SAMPLERS];
@@ -877,34 +877,35 @@ emit_image(struct ntv_context *ctx, struct nir_variable *var)
 }
 
 static SpvId
-get_sized_uint_array_type(struct ntv_context *ctx, unsigned array_size)
+get_sized_uint_array_type(struct ntv_context *ctx, unsigned array_size, unsigned bitsize)
 {
    SpvId array_length = emit_uint_const(ctx, 32, array_size);
-   SpvId array_type = spirv_builder_type_array(&ctx->builder, get_uvec_type(ctx, 32, 1),
+   SpvId array_type = spirv_builder_type_array(&ctx->builder, get_uvec_type(ctx, bitsize, 1),
                                             array_length);
-   spirv_builder_emit_array_stride(&ctx->builder, array_type, 4);
+   spirv_builder_emit_array_stride(&ctx->builder, array_type, bitsize / 8);
    return array_type;
 }
 
 static SpvId
-get_bo_array_type(struct ntv_context *ctx, struct nir_variable *var)
+get_bo_array_type(struct ntv_context *ctx, struct nir_variable *var, unsigned bitsize)
 {
+   assert(bitsize);
    SpvId array_type;
-   SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
+   SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bitsize);
    if (glsl_type_is_unsized_array(var->type)) {
       array_type = spirv_builder_type_runtime_array(&ctx->builder, uint_type);
-      spirv_builder_emit_array_stride(&ctx->builder, array_type, 4);
+      spirv_builder_emit_array_stride(&ctx->builder, array_type, bitsize / 8);
    } else {
-      uint32_t array_size = glsl_get_length(glsl_get_struct_field(var->interface_type, 0));
-      array_type = get_sized_uint_array_type(ctx, array_size);
+      uint32_t array_size = glsl_get_length(glsl_get_struct_field(var->interface_type, 0)) * (bitsize / 4);
+      array_type = get_sized_uint_array_type(ctx, array_size, bitsize);
    }
    return array_type;
 }
 
 static SpvId
-get_bo_struct_type(struct ntv_context *ctx, struct nir_variable *var)
+get_bo_struct_type(struct ntv_context *ctx, struct nir_variable *var, unsigned bitsize)
 {
-   SpvId array_type = get_bo_array_type(ctx, var);
+   SpvId array_type = get_bo_array_type(ctx, var, bitsize);
    bool ssbo = var->data.mode == nir_var_mem_ssbo;
 
    // wrap UBO-array in a struct
@@ -913,7 +914,7 @@ get_bo_struct_type(struct ntv_context *ctx, struct nir_variable *var)
        const struct glsl_type *last_member = glsl_get_struct_field(var->interface_type, glsl_get_length(var->interface_type) - 1);
        if (glsl_type_is_unsized_array(last_member)) {
           bool is_64bit = glsl_type_is_64bit(glsl_without_array(last_member));
-          runtime_array = spirv_builder_type_runtime_array(&ctx->builder, get_uvec_type(ctx, is_64bit ? 64 : 32, 1));
+          runtime_array = spirv_builder_type_runtime_array(&ctx->builder, get_uvec_type(ctx, is_64bit ? 64 : bitsize, 1));
           spirv_builder_emit_array_stride(&ctx->builder, runtime_array, glsl_get_explicit_stride(last_member));
        }
    }
@@ -940,11 +941,14 @@ get_bo_struct_type(struct ntv_context *ctx, struct nir_variable *var)
 }
 
 static void
-emit_bo(struct ntv_context *ctx, struct nir_variable *var)
+emit_bo(struct ntv_context *ctx, struct nir_variable *var, unsigned force_bitsize)
 {
    bool ssbo = var->data.mode == nir_var_mem_ssbo;
+   unsigned bitsize = force_bitsize ? force_bitsize : 32;
+   unsigned idx = bitsize >> 4;
+   assert(idx < ARRAY_SIZE(ctx->ssbos[0]));
 
-   SpvId pointer_type = get_bo_struct_type(ctx, var);
+   SpvId pointer_type = get_bo_struct_type(ctx, var, bitsize);
 
    SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
                                          ssbo ? SpvStorageClassStorageBuffer : SpvStorageClassUniform);
@@ -952,8 +956,8 @@ emit_bo(struct ntv_context *ctx, struct nir_variable *var)
       spirv_builder_emit_name(&ctx->builder, var_id, var->name);
 
    if (ssbo) {
-      assert(!ctx->ssbos[var->data.driver_location]);
-      ctx->ssbos[var->data.driver_location] = var_id;
+      assert(!ctx->ssbos[var->data.driver_location][idx]);
+      ctx->ssbos[var->data.driver_location][idx] = var_id;
       ctx->ssbo_vars[var->data.driver_location] = var;
    } else {
       assert(!ctx->ubos[var->data.driver_location]);
@@ -972,7 +976,7 @@ static void
 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
 {
    if (var->data.mode == nir_var_mem_ubo || var->data.mode == nir_var_mem_ssbo)
-      emit_bo(ctx, var);
+      emit_bo(ctx, var, 0);
    else {
       assert(var->data.mode == nir_var_uniform);
       const struct glsl_type *type = glsl_without_array(var->type);
@@ -1883,8 +1887,16 @@ emit_load_bo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
    bool ssbo = intr->intrinsic == nir_intrinsic_load_ssbo;
    assert(const_block_index); // no dynamic indexing for now
 
-   SpvId bo = ssbo ? ctx->ssbos[const_block_index->u32] : ctx->ubos[const_block_index->u32];
+   unsigned ssbo_idx = 0;
    unsigned bit_size = nir_dest_bit_size(intr->dest);
+   if (ssbo) {
+      ssbo_idx = MIN2(bit_size, 32) >> 4;
+      ssbo_idx = 2;
+      assert(ssbo_idx < ARRAY_SIZE(ctx->ssbos[0]));
+      if (!ctx->ssbos[const_block_index->u32][ssbo_idx])
+         emit_bo(ctx, ctx->ssbo_vars[const_block_index->u32], nir_dest_bit_size(intr->dest));
+   }
+   SpvId bo = ssbo ? ctx->ssbos[const_block_index->u32][ssbo_idx] : ctx->ubos[const_block_index->u32];
    SpvId uint_type = get_uvec_type(ctx, 32, 1);
    SpvId one = emit_uint_const(ctx, 32, 1);
 
@@ -1899,7 +1911,7 @@ emit_load_bo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
    /* destination type for the load */
    SpvId type = get_dest_uvec_type(ctx, &intr->dest);
    /* an id of an array member in bytes */
-   SpvId uint_size = emit_uint_const(ctx, 32, sizeof(uint32_t));
+   SpvId uint_size = emit_uint_const(ctx, 32, ssbo ? MIN2(bit_size, 32) / 8 : sizeof(uint32_t));
 
    /* we grab a single array member at a time, so it's a pointer to a uint */
    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
@@ -1984,7 +1996,11 @@ emit_store_ssbo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
    nir_const_value *const_block_index = nir_src_as_const_value(intr->src[1]);
    assert(const_block_index);
 
-   SpvId bo = ctx->ssbos[const_block_index->u32];
+   unsigned idx = MIN2(nir_src_bit_size(intr->src[0]), 32) >> 4;
+   assert(idx < ARRAY_SIZE(ctx->ssbos[0]));
+   if (!ctx->ssbos[const_block_index->u32][idx])
+      emit_bo(ctx, ctx->ssbo_vars[const_block_index->u32], nir_src_bit_size(intr->src[0]));
+   SpvId bo = ctx->ssbos[const_block_index->u32][idx];
 
    unsigned bit_size = nir_src_bit_size(intr->src[0]);
    SpvId uint_type = get_uvec_type(ctx, 32, 1);
@@ -1998,11 +2014,11 @@ emit_store_ssbo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
    bool is_64bit = bit_size == 64;
 
    /* an id of an array member in bytes */
-   SpvId uint_size = emit_uint_const(ctx, 32, sizeof(uint32_t));
+   SpvId uint_size = emit_uint_const(ctx, 32, MIN2(bit_size, 32) / 8);
    /* we grab a single array member at a time, so it's a pointer to a uint */
    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
                                                    SpvStorageClassStorageBuffer,
-                                                   uint_type);
+                                                   get_uvec_type(ctx, MIN2(bit_size, 32), 1));
 
    /* our generated uniform has a memory layout like
     *
@@ -2034,7 +2050,7 @@ emit_store_ssbo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
     * (composite|vector)_extract both take literals
     */
    unsigned write_count = 0;
-   SpvId src_base_type = get_uvec_type(ctx, nir_src_bit_size(intr->src[0]), 1);
+   SpvId src_base_type = get_uvec_type(ctx, bit_size, 1);
    for (unsigned i = 0; write_count < num_components; i++) {
       if (wrmask & (1 << i)) {
          SpvId component = nir_src_num_components(intr->src[0]) > 1 ?
@@ -2418,7 +2434,12 @@ emit_ssbo_atomic_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
 
    nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
    assert(const_block_index); // no dynamic indexing for now
-   ssbo = ctx->ssbos[const_block_index->u32];
+   unsigned bit_size = MIN2(nir_src_bit_size(intr->src[0]), 32);
+   unsigned idx = bit_size >> 4;
+   assert(idx < ARRAY_SIZE(ctx->ssbos[0]));
+   if (!ctx->ssbos[const_block_index->u32][idx])
+      emit_bo(ctx, ctx->ssbo_vars[const_block_index->u32], nir_dest_bit_size(intr->dest));
+   ssbo = ctx->ssbos[const_block_index->u32][idx];
    param = get_src(ctx, &intr->src[2]);
 
    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
@@ -2426,7 +2447,7 @@ emit_ssbo_atomic_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
                                                    dest_type);
    SpvId uint_type = get_uvec_type(ctx, 32, 1);
    /* an id of the array stride in bytes */
-   SpvId uint_size = emit_uint_const(ctx, 32, sizeof(uint32_t));
+   SpvId uint_size = emit_uint_const(ctx, 32, bit_size / 8);
    SpvId member = emit_uint_const(ctx, 32, 0);
    SpvId offset = get_src(ctx, &intr->src[1]);
    SpvId vec_offset = emit_binop(ctx, SpvOpUDiv, uint_type, offset, uint_size);
@@ -2472,7 +2493,7 @@ emit_get_ssbo_size(struct ntv_context *ctx, nir_intrinsic_instr *intr)
    assert(const_block_index); // no dynamic indexing for now
    nir_variable *var = ctx->ssbo_vars[const_block_index->u32];
    SpvId result = spirv_builder_emit_binop(&ctx->builder, SpvOpArrayLength, uint_type,
-                                             ctx->ssbos[const_block_index->u32], 1);
+                                             ctx->ssbos[const_block_index->u32][2], 1);
    /* this is going to be converted by nir to:
 
       length = (buffer_size - offset) / stride