zink: support emitting 16-bit int types
authorErik Faye-Lund <erik.faye-lund@collabora.com>
Tue, 6 Apr 2021 16:51:06 +0000 (18:51 +0200)
committerMarge Bot <eric+marge@anholt.net>
Fri, 30 Apr 2021 12:02:04 +0000 (12:02 +0000)
This prepares us for being able to support using 16-bit int types in
shaders, which might help performance in some cases.

Reviewed-By: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/10101>

src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
src/gallium/drivers/zink/nir_to_spirv/spirv_builder.c

index 772076d..75d75ba 100644 (file)
@@ -227,14 +227,14 @@ emit_float_const(struct ntv_context *ctx, int bit_size, double value)
 static SpvId
 emit_uint_const(struct ntv_context *ctx, int bit_size, uint64_t value)
 {
-   assert(bit_size == 32 || bit_size == 64);
+   assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
    return spirv_builder_const_uint(&ctx->builder, bit_size, value);
 }
 
 static SpvId
 emit_int_const(struct ntv_context *ctx, int bit_size, int64_t value)
 {
-   assert(bit_size == 32 || bit_size == 64);
+   assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
    return spirv_builder_const_int(&ctx->builder, bit_size, value);
 }
 
@@ -255,7 +255,7 @@ get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_component
 static SpvId
 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
 {
-   assert(bit_size == 32 || bit_size == 64);
+   assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
 
    SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
    if (num_components > 1)
@@ -269,7 +269,7 @@ get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_component
 static SpvId
 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
 {
-   assert(bit_size == 32 || bit_size == 64);
+   assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
 
    SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
    if (num_components > 1)
@@ -949,7 +949,7 @@ get_vec_from_bit_size(struct ntv_context *ctx, uint32_t bit_size, uint32_t num_c
 {
    if (bit_size == 1)
       return get_bvec_type(ctx, num_components);
-   if (bit_size == 32 || bit_size == 64)
+   if (bit_size == 16 || bit_size == 32 || bit_size == 64)
       return get_uvec_type(ctx, bit_size, num_components);
    unreachable("unhandled register bit size");
    return 0;
@@ -1021,7 +1021,7 @@ get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
       return def;
 
    int bit_size = nir_src_bit_size(alu->src[src].src);
-   assert(bit_size == 1 || bit_size == 32 || bit_size == 64);
+   assert(bit_size == 1 || bit_size == 16 || bit_size == 32 || bit_size == 64);
 
    SpvId raw_type = bit_size == 1 ? spirv_builder_type_bool(&ctx->builder) :
                                     spirv_builder_type_uint(&ctx->builder, bit_size);
@@ -1429,7 +1429,7 @@ static SpvId
 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
                   unsigned num_components, int64_t value)
 {
-   assert(bit_size == 32 || bit_size == 64);
+   assert(bit_size == 16 || bit_size == 32 || bit_size == 64);
 
    SpvId result = emit_int_const(ctx, bit_size, value);
    if (num_components == 1)
@@ -1574,11 +1574,15 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
    UNOP(nir_op_fddy, SpvOpDPdy)
    UNOP(nir_op_fddy_coarse, SpvOpDPdyCoarse)
    UNOP(nir_op_fddy_fine, SpvOpDPdyFine)
+   UNOP(nir_op_f2i16, SpvOpConvertFToS)
+   UNOP(nir_op_f2u16, SpvOpConvertFToU)
    UNOP(nir_op_f2i32, SpvOpConvertFToS)
    UNOP(nir_op_f2u32, SpvOpConvertFToU)
    UNOP(nir_op_i2f32, SpvOpConvertSToF)
    UNOP(nir_op_u2f32, SpvOpConvertUToF)
+   UNOP(nir_op_i2i16, SpvOpSConvert)
    UNOP(nir_op_i2i32, SpvOpSConvert)
+   UNOP(nir_op_u2u16, SpvOpUConvert)
    UNOP(nir_op_u2u32, SpvOpUConvert)
    UNOP(nir_op_f2f32, SpvOpFConvert)
    UNOP(nir_op_f2i64, SpvOpConvertFToS)
@@ -1599,6 +1603,7 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
          result = emit_unop(ctx, SpvOpNot, dest_type, src[0]);
       break;
 
+   case nir_op_b2i16:
    case nir_op_b2i32:
    case nir_op_b2i64:
       assert(nir_op_infos[alu->op].num_inputs == 1);
@@ -1816,6 +1821,8 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
       if (bit_size == 1)
          constant = spirv_builder_const_bool(&ctx->builder,
                                              load_const->value[0].b);
+      else if (bit_size == 16)
+         constant = emit_uint_const(ctx, bit_size, load_const->value[0].u16);
       else if (bit_size == 32)
          constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
       else if (bit_size == 64)
@@ -3576,6 +3583,8 @@ nir_to_spirv(struct nir_shader *s, const struct zink_so_info *so_info, bool spir
       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery);
    }
 
+   if (s->info.bit_sizes_int & 16)
+      spirv_builder_emit_cap(&ctx.builder, SpvCapabilityInt16);
    if (s->info.bit_sizes_int & 64)
       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityInt64);
    if (s->info.bit_sizes_float & 64)
index dfb3323..a6fe423 100644 (file)
@@ -1419,7 +1419,7 @@ spirv_builder_const_int(struct spirv_builder *b, int width, int64_t val)
 SpvId
 spirv_builder_const_uint(struct spirv_builder *b, int width, uint64_t val)
 {
-   assert(width >= 32);
+   assert(width >= 16);
    SpvId type = spirv_builder_type_uint(b, width);
    if (width <= 32)
       return emit_constant_32(b, type, val);