compiler/types: Add support for Cooperative Matrix types
authorCaio Oliveira <caio.oliveira@intel.com>
Wed, 31 May 2023 06:26:14 +0000 (23:26 -0700)
committerMarge Bot <emma+marge@anholt.net>
Thu, 28 Sep 2023 07:35:02 +0000 (07:35 +0000)
Reviewed-by: Jesse Natalie <jenatali@microsoft.com>
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23825>

src/compiler/glsl/ast_to_hir.cpp
src/compiler/glsl/gl_nir_link_uniform_initializers.c
src/compiler/glsl/ir_clone.cpp
src/compiler/glsl_types.cpp
src/compiler/glsl_types.h
src/compiler/nir/nir.c
src/compiler/nir_types.cpp
src/compiler/nir_types.h
src/intel/compiler/brw_shader.cpp
src/intel/compiler/brw_vec4_visitor.cpp
src/mesa/main/uniform_query.cpp

index f91d1ae..b972160 100644 (file)
@@ -1191,6 +1191,9 @@ do_comparison(void *mem_ctx, int operation, ir_rvalue *op0, ir_rvalue *op1)
        * ignores the sampler present in the type.
        */
       break;
+
+   case GLSL_TYPE_COOPERATIVE_MATRIX:
+      unreachable("unsupported base type cooperative matrix");
    }
 
    if (cmp == NULL)
index 74e52d8..80cd6a1 100644 (file)
@@ -169,6 +169,8 @@ copy_constant_to_storage(union gl_constant_value *storage,
              */
             assert(!"Should not get here.");
             break;
+         case GLSL_TYPE_COOPERATIVE_MATRIX:
+            unreachable("unsupported base type cooperative matrix");
          }
          i += dmul;
       }
index 5a38447..059ae57 100644 (file)
@@ -370,6 +370,9 @@ ir_constant::clone(void *mem_ctx, struct hash_table *ht) const
    case GLSL_TYPE_INTERFACE:
       assert(!"Should not get here.");
       break;
+
+   case GLSL_TYPE_COOPERATIVE_MATRIX:
+      unreachable("unsupported base type cooperative matrix");
    }
 
    return NULL;
index e9871d2..d563b20 100644 (file)
@@ -51,6 +51,7 @@ static struct {
 
    hash_table *explicit_matrix_types;
    hash_table *array_types;
+   hash_table *cmat_types;
    hash_table *struct_types;
    hash_table *interface_types;
    hash_table *subroutine_types;
@@ -391,6 +392,7 @@ const glsl_type *glsl_type::get_bare_type() const
       return get_array_instance(this->fields.array->get_bare_type(),
                                 this->length);
 
+   case GLSL_TYPE_COOPERATIVE_MATRIX:
    case GLSL_TYPE_SAMPLER:
    case GLSL_TYPE_TEXTURE:
    case GLSL_TYPE_IMAGE:
@@ -527,6 +529,19 @@ make_array_type(linear_ctx *lin_ctx, const glsl_type *element_type, unsigned len
    return t;
 }
 
+static const char *
+glsl_cmat_use_to_string(enum glsl_cmat_use use)
+{
+   switch (use) {
+   case GLSL_CMAT_USE_NONE:        return "NONE";
+   case GLSL_CMAT_USE_A:           return "A";
+   case GLSL_CMAT_USE_B:           return "B";
+   case GLSL_CMAT_USE_ACCUMULATOR: return "ACCUMULATOR";
+   default:
+      unreachable("invalid cooperative matrix use");
+   }
+};
+
 const glsl_type *
 glsl_type::vec(unsigned components, const glsl_type *const ts[])
 {
@@ -1250,6 +1265,68 @@ glsl_type::get_array_instance(const glsl_type *element,
    return t;
 }
 
+static const struct glsl_type *
+make_cmat_type(linear_ctx *lin_ctx, const struct glsl_cmat_description desc)
+{
+   assert(lin_ctx != NULL);
+
+   struct glsl_type *t = linear_zalloc(lin_ctx, struct glsl_type);
+   t->base_type = GLSL_TYPE_COOPERATIVE_MATRIX;
+   t->sampled_type = GLSL_TYPE_VOID;
+   t->vector_elements = 1;
+   t->cmat_desc = desc;
+
+   const struct glsl_type *element_type = glsl_type::get_instance(desc.element_type, 1, 1);
+   t->name_id = (uintptr_t ) linear_asprintf(lin_ctx, "coopmat<%s, %s, %u, %u, %s>",
+                                             glsl_get_type_name(element_type),
+                                             mesa_scope_name((mesa_scope)desc.scope),
+                                             desc.rows, desc.cols,
+                                             glsl_cmat_use_to_string((enum glsl_cmat_use)desc.use));
+
+   return t;
+}
+
+const glsl_type *
+glsl_type::get_cmat_instance(const struct glsl_cmat_description desc)
+{
+   STATIC_ASSERT(sizeof(struct glsl_cmat_description) == 4);
+
+   const uint32_t key = desc.element_type | desc.scope << 5 |
+                        desc.rows << 8 | desc.cols << 16 |
+                        desc.use << 24;
+   const uint32_t key_hash = _mesa_hash_uint(&key);
+
+   simple_mtx_lock(&glsl_type_cache_mutex);
+   assert(glsl_type_cache.users > 0);
+   void *mem_ctx = glsl_type_cache.mem_ctx;
+
+   if (glsl_type_cache.cmat_types == NULL) {
+      glsl_type_cache.cmat_types =
+         _mesa_hash_table_create_u32_keys(mem_ctx);
+   }
+   hash_table *cmat_types = glsl_type_cache.cmat_types;
+
+   const struct hash_entry *entry = _mesa_hash_table_search_pre_hashed(
+      cmat_types, key_hash, (void *) (uintptr_t) key);
+   if (entry == NULL) {
+      const struct glsl_type *t = make_cmat_type(glsl_type_cache.lin_ctx, desc);
+      entry = _mesa_hash_table_insert_pre_hashed(cmat_types, key_hash,
+                                                 (void *) (uintptr_t) key, (void *) t);
+   }
+
+   const struct glsl_type *t = (const struct glsl_type *)entry->data;
+   simple_mtx_unlock(&glsl_type_cache_mutex);
+
+   assert(t->base_type == GLSL_TYPE_COOPERATIVE_MATRIX);
+   assert(t->cmat_desc.element_type == desc.element_type);
+   assert(t->cmat_desc.scope == desc.scope);
+   assert(t->cmat_desc.rows == desc.rows);
+   assert(t->cmat_desc.cols == desc.cols);
+   assert(t->cmat_desc.use == desc.use);
+
+   return t;
+}
+
 bool
 glsl_type::compare_no_precision(const glsl_type *b) const
 {
@@ -1679,6 +1756,7 @@ glsl_type::component_slots() const
    case GLSL_TYPE_SUBROUTINE:
       return 1;
 
+   case GLSL_TYPE_COOPERATIVE_MATRIX:
    case GLSL_TYPE_ATOMIC_UINT:
    case GLSL_TYPE_VOID:
    case GLSL_TYPE_ERROR:
@@ -1745,6 +1823,7 @@ glsl_type::component_slots_aligned(unsigned offset) const
    case GLSL_TYPE_SUBROUTINE:
       return 1;
 
+   case GLSL_TYPE_COOPERATIVE_MATRIX:
    case GLSL_TYPE_ATOMIC_UINT:
    case GLSL_TYPE_VOID:
    case GLSL_TYPE_ERROR:
@@ -2599,6 +2678,10 @@ glsl_type::get_explicit_type_for_size_align(glsl_type_size_align_func type_info,
       type_info(this, size, alignment);
       assert(*alignment > 0);
       return this;
+   } else if (this->is_cmat()) {
+      *size = 0;
+      *alignment = 0;
+      return this;
    } else if (this->is_scalar()) {
       type_info(this, size, alignment);
       assert(*size == explicit_type_scalar_byte_size(this));
@@ -2822,6 +2905,7 @@ glsl_type::count_vec4_slots(bool is_gl_vertex_input, bool is_bindless) const
    case GLSL_TYPE_SUBROUTINE:
       return 1;
 
+   case GLSL_TYPE_COOPERATIVE_MATRIX:
    case GLSL_TYPE_ATOMIC_UINT:
    case GLSL_TYPE_VOID:
    case GLSL_TYPE_ERROR:
@@ -2925,6 +3009,7 @@ union packed_type {
       unsigned length:13;
       unsigned explicit_stride:14;
    } array;
+   glsl_cmat_description cmat_desc;
    struct {
       unsigned base_type:5;
       unsigned interface_packing_or_packed:2;
@@ -3039,6 +3124,10 @@ encode_type_to_blob(struct blob *blob, const glsl_type *type)
          blob_write_uint32(blob, type->explicit_stride);
       encode_type_to_blob(blob, type->fields.array);
       return;
+   case GLSL_TYPE_COOPERATIVE_MATRIX:
+      encoded.cmat_desc = type->cmat_desc;
+      blob_write_uint32(blob, encoded.u32);
+      return;
    case GLSL_TYPE_STRUCT:
    case GLSL_TYPE_INTERFACE:
       encoded.strct.length = MIN2(type->length, 0xfffff);
@@ -3145,6 +3234,9 @@ decode_type_from_blob(struct blob_reader *blob)
       return glsl_type::get_array_instance(decode_type_from_blob(blob),
                                            length, explicit_stride);
    }
+   case GLSL_TYPE_COOPERATIVE_MATRIX: {
+      return glsl_type::get_cmat_instance(encoded.cmat_desc);
+   }
    case GLSL_TYPE_STRUCT:
    case GLSL_TYPE_INTERFACE: {
       char *name = blob_read_string(blob);
index 06e1096..9d2e704 100644 (file)
@@ -76,6 +76,7 @@ enum glsl_base_type {
    GLSL_TYPE_UINT64,
    GLSL_TYPE_INT64,
    GLSL_TYPE_BOOL,
+   GLSL_TYPE_COOPERATIVE_MATRIX,
    GLSL_TYPE_SAMPLER,
    GLSL_TYPE_TEXTURE,
    GLSL_TYPE_IMAGE,
@@ -167,6 +168,7 @@ glsl_base_type_get_bit_size(const enum glsl_base_type base_type)
    case GLSL_TYPE_UINT:
    case GLSL_TYPE_FLOAT: /* TODO handle mediump */
    case GLSL_TYPE_SUBROUTINE:
+   case GLSL_TYPE_COOPERATIVE_MATRIX:
       return 32;
 
    case GLSL_TYPE_FLOAT16:
@@ -279,6 +281,24 @@ enum {
    GLSL_PRECISION_LOW
 };
 
+enum glsl_cmat_use {
+   GLSL_CMAT_USE_NONE = 0,
+   GLSL_CMAT_USE_A,
+   GLSL_CMAT_USE_B,
+   GLSL_CMAT_USE_ACCUMULATOR,
+};
+
+struct glsl_cmat_description {
+   /* MSVC can't merge bitfields of different types and also sign extend enums,
+    * so use uint8_t for those cases.
+    */
+   uint8_t element_type:5; /* enum glsl_base_type */
+   uint8_t scope:3; /* mesa_scope */
+   uint8_t rows;
+   uint8_t cols;
+   uint8_t use; /* enum glsl_cmat_use */
+};
+
 const char *glsl_get_type_name(const struct glsl_type *type);
 
 struct glsl_type {
@@ -297,6 +317,8 @@ struct glsl_type {
    unsigned interface_packing:2;
    unsigned interface_row_major:1;
 
+   struct glsl_cmat_description cmat_desc;
+
    /**
     * For \c GLSL_TYPE_STRUCT this specifies if the struct is packed or not.
     *
@@ -457,6 +479,11 @@ struct glsl_type {
                                               unsigned explicit_stride = 0);
 
    /**
+    * Get the instance of a cooperative matrix type
+    */
+   static const glsl_type *get_cmat_instance(const struct glsl_cmat_description desc);
+
+   /**
     * Get the instance of a record type
     */
    static const glsl_type *get_struct_instance(const glsl_struct_field *fields,
@@ -931,6 +958,11 @@ struct glsl_type {
       return is_array() && fields.array->is_array();
    }
 
+   bool is_cmat() const
+   {
+      return base_type == GLSL_TYPE_COOPERATIVE_MATRIX;
+   }
+
    /**
     * Query whether or not a type is a record
     */
index 26dedff..6282dea 100644 (file)
@@ -2755,6 +2755,7 @@ nir_get_nir_type_for_glsl_base_type(enum glsl_base_type base_type)
    case GLSL_TYPE_DOUBLE:  return nir_type_float64;
       /* clang-format on */
 
+   case GLSL_TYPE_COOPERATIVE_MATRIX:
    case GLSL_TYPE_SAMPLER:
    case GLSL_TYPE_TEXTURE:
    case GLSL_TYPE_IMAGE:
index e167105..b84c608 100644 (file)
@@ -336,6 +336,12 @@ glsl_type_is_array_or_matrix(const struct glsl_type *type)
 }
 
 bool
+glsl_type_is_cmat(const struct glsl_type *type)
+{
+   return type->is_cmat();
+}
+
+bool
 glsl_type_is_struct(const struct glsl_type *type)
 {
    return type->is_struct();
@@ -643,6 +649,12 @@ glsl_array_type(const glsl_type *element, unsigned array_size,
 }
 
 const glsl_type *
+glsl_cmat_type(const glsl_cmat_description *desc)
+{
+   return glsl_type::get_cmat_instance(*desc);
+}
+
+const glsl_type *
 glsl_replace_vector_type(const glsl_type *t, unsigned components)
 {
    if (glsl_type_is_array(t)) {
@@ -857,6 +869,7 @@ glsl_get_natural_size_align_bytes(const struct glsl_type *type,
 
    case GLSL_TYPE_ATOMIC_UINT:
    case GLSL_TYPE_SUBROUTINE:
+   case GLSL_TYPE_COOPERATIVE_MATRIX:
    case GLSL_TYPE_VOID:
    case GLSL_TYPE_ERROR:
       unreachable("type does not have a natural size");
@@ -910,6 +923,7 @@ glsl_get_vec4_size_align_bytes(const struct glsl_type *type,
    case GLSL_TYPE_IMAGE:
    case GLSL_TYPE_ATOMIC_UINT:
    case GLSL_TYPE_SUBROUTINE:
+   case GLSL_TYPE_COOPERATIVE_MATRIX:
    case GLSL_TYPE_VOID:
    case GLSL_TYPE_ERROR:
       unreachable("type does not make sense for glsl_get_vec4_size_align_bytes()");
@@ -1102,3 +1116,17 @@ glsl_type_replace_vec3_with_vec4(const struct glsl_type *type)
 {
    return type->replace_vec3_with_vec4();
 }
+
+const struct glsl_type *
+glsl_get_cmat_element(const struct glsl_type *type)
+{
+   assert(type->base_type == GLSL_TYPE_COOPERATIVE_MATRIX);
+   return glsl_type::get_instance(type->cmat_desc.element_type, 1, 1);
+}
+
+const struct glsl_cmat_description *
+glsl_get_cmat_description(const struct glsl_type *type)
+{
+   assert(type->base_type == GLSL_TYPE_COOPERATIVE_MATRIX);
+   return &type->cmat_desc;
+}
index 22a9ec2..ff6172a 100644 (file)
@@ -140,6 +140,7 @@ bool glsl_type_is_array(const struct glsl_type *type);
 bool glsl_type_is_unsized_array(const struct glsl_type *type);
 bool glsl_type_is_array_of_arrays(const struct glsl_type *type);
 bool glsl_type_is_array_or_matrix(const struct glsl_type *type);
+bool glsl_type_is_cmat(const struct glsl_type *type);
 bool glsl_type_is_struct(const struct glsl_type *type);
 bool glsl_type_is_interface(const struct glsl_type *type);
 bool glsl_type_is_struct_or_ifc(const struct glsl_type *type);
@@ -201,6 +202,8 @@ const struct glsl_type *glsl_array_type(const struct glsl_type *element,
                                         unsigned array_size,
                                         unsigned explicit_stride);
 
+const struct glsl_type *glsl_cmat_type(const struct glsl_cmat_description *desc);
+
 const struct glsl_type *glsl_struct_type(const struct glsl_struct_field *fields,
                                          unsigned num_fields, const char *name,
                                          bool packed);
@@ -254,6 +257,9 @@ int glsl_get_field_index(const struct glsl_type *type, const char *name);
 
 bool glsl_type_is_leaf(const struct glsl_type *type);
 
+const struct glsl_type *glsl_get_cmat_element(const struct glsl_type *type);
+const struct glsl_cmat_description *glsl_get_cmat_description(const struct glsl_type *type);
+
 #ifdef __cplusplus
 }
 #endif
index 88beeb5..423c197 100644 (file)
@@ -74,6 +74,7 @@ brw_type_for_base_type(const struct glsl_type *type)
       return BRW_REGISTER_TYPE_Q;
    case GLSL_TYPE_VOID:
    case GLSL_TYPE_ERROR:
+   case GLSL_TYPE_COOPERATIVE_MATRIX:
       unreachable("not reached");
    }
 
index 9a56f1e..54866dc 100644 (file)
@@ -622,6 +622,7 @@ type_size_xvec4(const struct glsl_type *type, bool as_vec4, bool bindless)
       return bindless ? 1 : DIV_ROUND_UP(BRW_IMAGE_PARAM_SIZE, 4);
    case GLSL_TYPE_VOID:
    case GLSL_TYPE_ERROR:
+   case GLSL_TYPE_COOPERATIVE_MATRIX:
       unreachable("not reached");
    }
 
index 245c9fa..d186369 100644 (file)
@@ -1011,6 +1011,7 @@ associate_uniform_storage(struct gl_context *ctx,
          case GLSL_TYPE_STRUCT:
          case GLSL_TYPE_ERROR:
          case GLSL_TYPE_INTERFACE:
+         case GLSL_TYPE_COOPERATIVE_MATRIX:
             assert(!"Should not get here.");
             break;
          }