zink: change function params and asserts to permit 64bit types in ntv
authorMike Blumenkrantz <michael.blumenkrantz@gmail.com>
Thu, 5 Nov 2020 17:57:37 +0000 (12:57 -0500)
committerMarge Bot <eric+marge@anholt.net>
Fri, 18 Dec 2020 01:07:01 +0000 (01:07 +0000)
Reviewed-by: Erik Faye-Lund <kusmabite@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/7654>

src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c

index 3e7a53c..efcfab0 100644 (file)
@@ -76,15 +76,15 @@ struct ntv_context {
 
 static SpvId
 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
-                  unsigned num_components, float value);
+                  unsigned num_components, double value);
 
 static SpvId
 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
-                  unsigned num_components, uint32_t value);
+                  unsigned num_components, uint64_t value);
 
 static SpvId
 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
-                  unsigned num_components, int32_t value);
+                  unsigned num_components, int64_t value);
 
 static SpvId
 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);
@@ -117,30 +117,30 @@ block_label(struct ntv_context *ctx, nir_block *block)
 }
 
 static SpvId
-emit_float_const(struct ntv_context *ctx, int bit_size, float value)
+emit_float_const(struct ntv_context *ctx, int bit_size, double value)
 {
-   assert(bit_size == 32);
+   assert(bit_size == 32 || bit_size == 64);
    return spirv_builder_const_float(&ctx->builder, bit_size, value);
 }
 
 static SpvId
-emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
+emit_uint_const(struct ntv_context *ctx, int bit_size, uint64_t value)
 {
-   assert(bit_size == 32);
+   assert(bit_size == 32 || bit_size == 64);
    return spirv_builder_const_uint(&ctx->builder, bit_size, value);
 }
 
 static SpvId
-emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
+emit_int_const(struct ntv_context *ctx, int bit_size, int64_t value)
 {
-   assert(bit_size == 32);
+   assert(bit_size == 32 || bit_size == 64);
    return spirv_builder_const_int(&ctx->builder, bit_size, value);
 }
 
 static SpvId
 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
 {
-   assert(bit_size == 32); // only 32-bit floats supported so far
+   assert(bit_size == 32 || bit_size == 64);
 
    SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
    if (num_components > 1)
@@ -154,7 +154,7 @@ get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_component
 static SpvId
 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
 {
-   assert(bit_size == 32); // only 32-bit ints supported so far
+   assert(bit_size == 32 || bit_size == 64);
 
    SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
    if (num_components > 1)
@@ -168,7 +168,7 @@ get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_component
 static SpvId
 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
 {
-   assert(bit_size == 32); // only 32-bit uints supported so far
+   assert(bit_size == 32 || bit_size == 64);
 
    SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
    if (num_components > 1)
@@ -196,7 +196,7 @@ get_storage_class(struct nir_variable *var)
 static SpvId
 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
 {
-   unsigned bit_size = MAX2(nir_dest_bit_size(*dest), 32);
+   unsigned bit_size = nir_dest_bit_size(*dest);
    return get_uvec_type(ctx, bit_size, nir_dest_num_components(*dest));
 }
 
@@ -656,7 +656,7 @@ get_vec_from_bit_size(struct ntv_context *ctx, uint32_t bit_size, uint32_t num_c
 {
    if (bit_size == 1)
       return get_bvec_type(ctx, num_components);
-   if (bit_size == 32)
+   if (bit_size == 32 || bit_size == 64)
       return get_uvec_type(ctx, bit_size, num_components);
    unreachable("unhandled register bit size");
    return 0;
@@ -728,7 +728,7 @@ get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
       return def;
 
    int bit_size = nir_src_bit_size(alu->src[src].src);
-   assert(bit_size == 1 || bit_size == 32);
+   assert(bit_size == 1 || bit_size == 32 || bit_size == 64);
 
    SpvId raw_type = bit_size == 1 ? spirv_builder_type_bool(&ctx->builder) :
                                     spirv_builder_type_uint(&ctx->builder, bit_size);
@@ -1059,9 +1059,9 @@ emit_builtin_triop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
 
 static SpvId
 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
-                  unsigned num_components, float value)
+                  unsigned num_components, double value)
 {
-   assert(bit_size == 32);
+   assert(bit_size == 32 || bit_size == 64);
 
    SpvId result = emit_float_const(ctx, bit_size, value);
    if (num_components == 1)
@@ -1079,9 +1079,9 @@ get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
 
 static SpvId
 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
-                  unsigned num_components, uint32_t value)
+                  unsigned num_components, uint64_t value)
 {
-   assert(bit_size == 32);
+   assert(bit_size == 32 || bit_size == 64);
 
    SpvId result = emit_uint_const(ctx, bit_size, value);
    if (num_components == 1)
@@ -1099,9 +1099,9 @@ get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
 
 static SpvId
 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
-                  unsigned num_components, int32_t value)
+                  unsigned num_components, int64_t value)
 {
-   assert(bit_size == 32);
+   assert(bit_size == 32 || bit_size == 64);
 
    SpvId result = emit_int_const(ctx, bit_size, value);
    if (num_components == 1)