spirv: Implement SPV_KHR_cooperative_matrix
authorCaio Oliveira <caio.oliveira@intel.com>
Sat, 17 Jun 2023 00:02:39 +0000 (17:02 -0700)
committerMarge Bot <emma+marge@anholt.net>
Thu, 28 Sep 2023 07:35:02 +0000 (07:35 +0000)
Includes a modified version of using extract/insert for OpLoad/OpStore
from Ian.

Reviewed-by: Ian Romanick <ian.d.romanick@intel.com> (earlier version)
Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl> (earlier version)
Acked-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23825>

src/compiler/shader_info.h
src/compiler/spirv/meson.build
src/compiler/spirv/spirv_to_nir.c
src/compiler/spirv/vtn_alu.c
src/compiler/spirv/vtn_cmat.c [new file with mode: 0644]
src/compiler/spirv/vtn_private.h
src/compiler/spirv/vtn_variables.c

index 785473a..51000b2 100644 (file)
@@ -47,6 +47,7 @@ struct spirv_supported_capabilities {
    bool amd_shader_explicit_vertex_parameter;
    bool amd_trinary_minmax;
    bool atomic_storage;
+   bool cooperative_matrix;
    bool demote_to_helper_invocation;
    bool derivative_group;
    bool descriptor_array_dynamic_indexing;
index 06dc9f7..dfb53d6 100644 (file)
@@ -51,6 +51,7 @@ files_libvtn = files(
   'vtn_alu.c',
   'vtn_amd.c',
   'vtn_cfg.c',
+  'vtn_cmat.c',
   'vtn_glsl450.c',
   'vtn_opencl.c',
   'vtn_private.h',
index c876f85..5b20c5c 100644 (file)
@@ -266,7 +266,10 @@ vtn_undef_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
    struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
    val->type = glsl_get_bare_type(type);
 
-   if (glsl_type_is_vector_or_scalar(type)) {
+   if (glsl_type_is_cmat(type)) {
+      nir_deref_instr *mat = vtn_create_cmat_temporary(b, type, "cmat_undef");
+      vtn_set_ssa_value_var(b, val, mat->var);
+   } else if (glsl_type_is_vector_or_scalar(type)) {
       unsigned num_components = glsl_get_vector_elements(val->type);
       unsigned bit_size = glsl_get_bit_size(val->type);
       val->def = nir_undef(&b->nb, num_components, bit_size);
@@ -296,7 +299,15 @@ vtn_const_ssa_value(struct vtn_builder *b, nir_constant *constant,
    struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
    val->type = glsl_get_bare_type(type);
 
-   if (glsl_type_is_vector_or_scalar(type)) {
+   if (glsl_type_is_cmat(type)) {
+      const struct glsl_type *element_type = glsl_get_cmat_element(type);
+
+      nir_deref_instr *mat = vtn_create_cmat_temporary(b, type, "cmat_constant");
+      nir_cmat_construct(&b->nb, &mat->def,
+                         nir_build_imm(&b->nb, 1, glsl_get_bit_size(element_type),
+                                       constant->values));
+      vtn_set_ssa_value_var(b, val, mat->var);
+   } else if (glsl_type_is_vector_or_scalar(type)) {
       val->def = nir_build_imm(&b->nb, glsl_get_vector_elements(val->type),
                                glsl_get_bit_size(val->type),
                                constant->values);
@@ -859,6 +870,7 @@ vtn_types_compatible(struct vtn_builder *b,
    case vtn_base_type_sampler:
    case vtn_base_type_sampled_image:
    case vtn_base_type_event:
+   case vtn_base_type_cooperative_matrix:
       return t1->type == t2->type;
 
    case vtn_base_type_array:
@@ -921,6 +933,7 @@ vtn_type_copy(struct vtn_builder *b, struct vtn_type *src)
    case vtn_base_type_event:
    case vtn_base_type_accel_struct:
    case vtn_base_type_ray_query:
+   case vtn_base_type_cooperative_matrix:
       /* Nothing more to do */
       break;
 
@@ -1951,6 +1964,10 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
       break;
    }
 
+   case SpvOpTypeCooperativeMatrixKHR:
+      vtn_handle_cooperative_type(b, val, opcode, w, count);
+      break;
+
    case SpvOpTypeEvent:
       val->type->base_type = vtn_base_type_event;
       /*
@@ -2135,9 +2152,11 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
    case SpvOpSpecConstantComposite:
    case SpvOpConstantComposite: {
       unsigned elem_count = count - 3;
-      vtn_fail_if(elem_count != val->type->length,
+      unsigned expected_length = val->type->base_type == vtn_base_type_cooperative_matrix ?
+         1 : val->type->length;
+      vtn_fail_if(elem_count != expected_length,
                   "%s has %u constituents, expected %u",
-                  spirv_op_to_string(opcode), elem_count, val->type->length);
+                  spirv_op_to_string(opcode), elem_count, expected_length);
 
       nir_constant **elems = ralloc_array(b, nir_constant *, elem_count);
       val->is_undef_constant = true;
@@ -2173,6 +2192,10 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
          val->constant->elements = elems;
          break;
 
+      case vtn_base_type_cooperative_matrix:
+         val->constant->values[0] = elems[0]->values[0];
+         break;
+
       default:
          vtn_fail("Result type of %s must be a composite type",
                   spirv_op_to_string(opcode));
@@ -2685,7 +2708,7 @@ vtn_create_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
    if (!glsl_type_is_vector_or_scalar(type)) {
       unsigned elems = glsl_get_length(val->type);
       val->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
-      if (glsl_type_is_array_or_matrix(type)) {
+      if (glsl_type_is_array_or_matrix(type) || glsl_type_is_cmat(type)) {
          const struct glsl_type *elem_type = glsl_get_array_element(type);
          for (unsigned i = 0; i < elems; i++)
             val->elems[i] = vtn_create_ssa_value(b, elem_type);
@@ -4216,6 +4239,9 @@ vtn_composite_insert(struct vtn_builder *b, struct vtn_ssa_value *src,
                      struct vtn_ssa_value *insert, const uint32_t *indices,
                      unsigned num_indices)
 {
+   if (glsl_type_is_cmat(src->type))
+      return vtn_cooperative_matrix_insert(b, src, insert, indices, num_indices);
+
    struct vtn_ssa_value *dest = vtn_composite_copy(b, src);
 
    struct vtn_ssa_value *cur = dest;
@@ -4254,6 +4280,9 @@ static struct vtn_ssa_value *
 vtn_composite_extract(struct vtn_builder *b, struct vtn_ssa_value *src,
                       const uint32_t *indices, unsigned num_indices)
 {
+   if (glsl_type_is_cmat(src->type))
+      return vtn_cooperative_matrix_extract(b, src, indices, num_indices);
+
    struct vtn_ssa_value *cur = src;
    for (unsigned i = 0; i < num_indices; i++) {
       if (glsl_type_is_vector_or_scalar(cur->type)) {
@@ -4310,7 +4339,12 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
    case SpvOpCompositeConstruct: {
       unsigned elems = count - 3;
       assume(elems >= 1);
-      if (glsl_type_is_vector_or_scalar(type->type)) {
+      if (type->base_type == vtn_base_type_cooperative_matrix) {
+         vtn_assert(elems == 1);
+         nir_deref_instr *mat = vtn_create_cmat_temporary(b, type->type, "cmat_construct");
+         nir_cmat_construct(&b->nb, &mat->def, vtn_get_nir_ssa(b, w[3]));
+         vtn_set_ssa_value_var(b, ssa, mat->var);
+      } else if (glsl_type_is_vector_or_scalar(type->type)) {
          nir_def *srcs[NIR_MAX_VEC_COMPONENTS];
          for (unsigned i = 0; i < elems; i++) {
             srcs[i] = vtn_get_nir_ssa(b, w[3 + i]);
@@ -5022,6 +5056,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
          spv_check_supported(shader_enqueue, cap);
          break;
 
+      case SpvCapabilityCooperativeMatrixKHR:
+         spv_check_supported(cooperative_matrix, cap);
+         break;
+
       default:
          vtn_fail("Unhandled capability: %s (%u)",
                   spirv_capability_to_string(cap), cap);
@@ -5656,6 +5694,7 @@ vtn_handle_variable_or_type_instruction(struct vtn_builder *b, SpvOp opcode,
    case SpvOpTypePipe:
    case SpvOpTypeAccelerationStructureKHR:
    case SpvOpTypeRayQueryKHR:
+   case SpvOpTypeCooperativeMatrixKHR:
       vtn_handle_type(b, opcode, w, count);
       break;
 
@@ -6621,6 +6660,13 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
    case SpvOpFinishWritingNodePayloadAMDX:
       break;
 
+   case SpvOpCooperativeMatrixLoadKHR:
+   case SpvOpCooperativeMatrixStoreKHR:
+   case SpvOpCooperativeMatrixLengthKHR:
+   case SpvOpCooperativeMatrixMulAddKHR:
+      vtn_handle_cooperative_instruction(b, opcode, w, count);
+      break;
+
    default:
       vtn_fail_with_opcode("Unhandled opcode", opcode);
    }
index 4cba604..04d71ae 100644 (file)
@@ -597,6 +597,11 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
    const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
 
+   if (glsl_type_is_cmat(dest_type)) {
+      vtn_handle_cooperative_alu(b, dest_val, dest_type, opcode, w, count);
+      return;
+   }
+
    vtn_handle_no_contraction(b, dest_val);
    bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);
 
@@ -1297,6 +1302,11 @@ vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
     */
 
    struct vtn_type *type = vtn_get_type(b, w[1]);
+   if (type->base_type == vtn_base_type_cooperative_matrix) {
+      vtn_handle_cooperative_instruction(b, SpvOpBitcast, w, count);
+      return;
+   }
+
    struct nir_def *src = vtn_get_nir_ssa(b, w[3]);
 
    vtn_fail_if(src->num_components * src->bit_size !=
diff --git a/src/compiler/spirv/vtn_cmat.c b/src/compiler/spirv/vtn_cmat.c
new file mode 100644 (file)
index 0000000..2e28984
--- /dev/null
@@ -0,0 +1,296 @@
+/*
+ * Copyright 2023 Intel Corporation
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "glsl_types.h"
+#include "nir.h"
+#include "nir_types.h"
+#include "vtn_private.h"
+
+static enum glsl_cmat_use
+vtn_cooperative_matrix_use_to_glsl(SpvCooperativeMatrixUse use)
+{
+   switch (use) {
+   case SpvCooperativeMatrixUseMatrixAKHR:
+      return GLSL_CMAT_USE_A;
+   case SpvCooperativeMatrixUseMatrixBKHR:
+      return GLSL_CMAT_USE_B;
+   case SpvCooperativeMatrixUseMatrixAccumulatorKHR:
+      return GLSL_CMAT_USE_ACCUMULATOR;
+   default:
+      unreachable("Unexpected cooperative matrix use");
+   }
+}
+
+void
+vtn_handle_cooperative_type(struct vtn_builder *b, struct vtn_value *val,
+                            SpvOp opcode, const uint32_t *w, unsigned count)
+{
+   vtn_assert(opcode == SpvOpTypeCooperativeMatrixKHR);
+
+   struct vtn_type *component_type = vtn_get_type(b, w[2]);
+
+   const mesa_scope scope = vtn_translate_scope(b, vtn_constant_uint(b, w[3]));
+   const uint32_t rows = vtn_constant_uint(b, w[4]);
+   const uint32_t cols = vtn_constant_uint(b, w[5]);
+
+   vtn_assert(rows < 256);
+   vtn_assert(cols < 256);
+
+   enum glsl_cmat_use use = vtn_cooperative_matrix_use_to_glsl(vtn_constant_uint(b, w[6]));
+
+   val->type->base_type = vtn_base_type_cooperative_matrix;
+   vtn_fail_if(!glsl_type_is_numeric(component_type->type),
+               "OpTypeCooperativeMatrixKHR "
+               "Component Type must be a scalar numerical type.");
+
+   val->type->desc.element_type = glsl_get_base_type(component_type->type);
+   val->type->desc.scope = scope;
+   val->type->desc.rows = rows;
+   val->type->desc.cols = cols;
+   val->type->desc.use = use;
+
+   val->type->type = glsl_cmat_type(&val->type->desc);
+   val->type->component_type = component_type;
+}
+
+static enum glsl_matrix_layout
+vtn_matrix_layout_to_glsl(SpvCooperativeMatrixLayout layout)
+{
+   switch (layout) {
+   case SpvCooperativeMatrixLayoutRowMajorKHR:
+      return GLSL_MATRIX_LAYOUT_ROW_MAJOR;
+   case SpvCooperativeMatrixLayoutColumnMajorKHR:
+      return GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
+   default:
+      unreachable("Unexpected cooperative matrix layout");
+   }
+}
+
+nir_deref_instr *
+vtn_create_cmat_temporary(struct vtn_builder *b, const struct glsl_type *t, const char *name)
+{
+   nir_variable *var = nir_local_variable_create(b->nb.impl, t, name);
+   return nir_build_deref_var(&b->nb, var);
+}
+
+static nir_deref_instr *
+vtn_get_cmat_deref(struct vtn_builder *b, uint32_t value_id)
+{
+   nir_deref_instr *deref = vtn_get_deref_for_id(b, value_id);
+   vtn_assert(glsl_type_is_cmat(deref->type));
+   return deref;
+}
+
+void
+vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
+                                   const uint32_t *w, unsigned count)
+{
+   switch (opcode) {
+   case SpvOpCooperativeMatrixLoadKHR: {
+      struct vtn_value *src_val = vtn_value(b, w[3], vtn_value_type_pointer);
+      struct vtn_pointer *src = vtn_value_to_pointer(b, src_val);
+      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+
+      const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[4]);
+      nir_def *stride = count > 5 ? vtn_get_nir_ssa(b, w[5]) : nir_imm_zero(&b->nb, 1, 32);
+
+      SpvMemoryAccessMask access = SpvMemoryAccessMaskNone;
+      if (count > 6) {
+         unsigned idx = 6, alignment;
+         SpvScope scope;
+         vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, NULL, &scope);
+         vtn_emit_make_visible_barrier(b, access, scope, src->mode);
+      }
+
+      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_bitcast");
+      nir_cmat_load(&b->nb, &dst->def, vtn_pointer_to_ssa(b, src), stride,
+                    .matrix_layout = vtn_matrix_layout_to_glsl(layout));
+      vtn_push_var_ssa(b, w[2], dst->var);
+      break;
+   }
+
+   case SpvOpCooperativeMatrixStoreKHR: {
+      struct vtn_value *dest_val = vtn_value(b, w[1], vtn_value_type_pointer);
+      struct vtn_pointer *dest = vtn_value_to_pointer(b, dest_val);
+
+      const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[3]);
+      nir_def *stride = count > 4 ? vtn_get_nir_ssa(b, w[4]) : nir_imm_zero(&b->nb, 1, 32);
+
+      SpvMemoryAccessMask access = SpvMemoryAccessMaskNone;
+      if (count > 5) {
+         unsigned idx = 5, alignment;
+         SpvScope scope;
+         vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, &scope, NULL);
+         vtn_emit_make_available_barrier(b, access, scope, dest->mode);
+      }
+
+      nir_deref_instr *src = vtn_get_cmat_deref(b, w[2]);
+      nir_cmat_store(&b->nb, vtn_pointer_to_ssa(b, dest), &src->def, stride,
+                     .matrix_layout = vtn_matrix_layout_to_glsl(layout));
+      break;
+   }
+
+   case SpvOpCooperativeMatrixLengthKHR: {
+      struct vtn_type *type = vtn_get_type(b, w[3]);
+      nir_def *def = nir_cmat_length(&b->nb, .cmat_desc = type->desc);
+      vtn_push_nir_ssa(b, w[2], def);
+      break;
+   }
+
+   case SpvOpCooperativeMatrixMulAddKHR: {
+      nir_deref_instr *mat_a = vtn_get_cmat_deref(b, w[3]);
+      nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]);
+      nir_deref_instr *mat_c = vtn_get_cmat_deref(b, w[5]);
+
+      const uint32_t operands = count > 6 ? w[6] : 0;
+      const bool saturate = operands & SpvCooperativeMatrixOperandsSaturatingAccumulationKHRMask;
+      const unsigned signed_mask = operands & (SpvCooperativeMatrixOperandsMatrixASignedComponentsKHRMask |
+                                               SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask |
+                                               SpvCooperativeMatrixOperandsMatrixCSignedComponentsKHRMask |
+                                               SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask);
+
+      STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixASignedComponentsKHRMask == NIR_CMAT_A_SIGNED);
+      STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask == NIR_CMAT_B_SIGNED);
+      STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixCSignedComponentsKHRMask == NIR_CMAT_C_SIGNED);
+      STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask == NIR_CMAT_RESULT_SIGNED);
+
+      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_muladd");
+
+      nir_cmat_muladd(&b->nb, &dst->def, &mat_a->def, &mat_b->def, &mat_c->def,
+                      .saturate = saturate,
+                      .cmat_signed_mask = signed_mask);
+
+      vtn_push_var_ssa(b, w[2], dst->var);
+      break;
+   }
+
+   case SpvOpBitcast: {
+      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+      vtn_assert(dst_type->base_type == vtn_base_type_cooperative_matrix);
+      nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);
+
+      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_bitcast");
+      nir_cmat_bitcast(&b->nb, &dst->def, &src->def);
+      vtn_push_var_ssa(b, w[2], dst->var);
+      break;
+   }
+
+   default:
+      unreachable("Unexpected opcode for cooperative matrix instruction");
+   }
+}
+
+void
+vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
+                           const struct glsl_type *dest_type, SpvOp opcode,
+                           const uint32_t *w, unsigned count)
+{
+      vtn_assert(glsl_type_is_cmat(dest_type));
+
+      switch (opcode) {
+      case SpvOpConvertFToU:
+      case SpvOpConvertFToS:
+      case SpvOpConvertSToF:
+      case SpvOpConvertUToF:
+      case SpvOpUConvert:
+      case SpvOpSConvert:
+      case SpvOpFConvert:
+      case SpvOpFNegate:
+      case SpvOpSNegate: {
+         struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+         nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);
+
+         unsigned src_bit_size = glsl_get_bit_size(glsl_get_cmat_element(src->type));
+         unsigned dst_bit_size = glsl_get_bit_size(glsl_get_cmat_element(dst_type->type));
+
+         bool ignored = false;
+         nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored,
+                                                     src_bit_size, dst_bit_size);
+
+         nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_unary");
+         nir_cmat_unary_op(&b->nb, &dst->def, &src->def,
+                           .alu_op = op);
+         vtn_push_var_ssa(b, w[2], dst->var);
+         break;
+      }
+
+      case SpvOpFAdd:
+      case SpvOpFSub:
+      case SpvOpFMul:
+      case SpvOpFDiv:
+      case SpvOpIAdd:
+      case SpvOpISub:
+      case SpvOpIMul:
+      case SpvOpSDiv:
+      case SpvOpUDiv: {
+         bool ignored = false;
+         nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored, 0, 0);
+
+         struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+         nir_deref_instr *mat_a = vtn_get_cmat_deref(b, w[3]);
+         nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]);
+
+         nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_binary");
+         nir_cmat_binary_op(&b->nb, &dst->def, &mat_a->def, &mat_b->def,
+                            .alu_op = op);
+         vtn_push_var_ssa(b, w[2], dst->var);
+         break;
+      }
+
+      case SpvOpMatrixTimesScalar: {
+         struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+         nir_deref_instr *mat = vtn_get_cmat_deref(b, w[3]);
+
+         struct vtn_ssa_value *scalar_val = vtn_ssa_value(b, w[4]);
+         vtn_assert(glsl_type_is_scalar(scalar_val->type));
+         nir_op op = glsl_type_is_integer(scalar_val->type) ? nir_op_imul : nir_op_fmul;
+
+         nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_times_scalar");
+         nir_cmat_scalar_op(&b->nb, &dst->def, &mat->def, scalar_val->def,
+                            .alu_op = op);
+         vtn_push_var_ssa(b, w[2], dst->var);
+         break;
+      }
+
+      default:
+         unreachable("invalid cooperative matrix alu instruction");
+      }
+}
+
+struct vtn_ssa_value *
+vtn_cooperative_matrix_extract(struct vtn_builder *b, struct vtn_ssa_value *mat,
+                               const uint32_t *indices, unsigned num_indices)
+{
+   vtn_assert(glsl_type_is_cmat(mat->type));
+   nir_deref_instr *mat_deref = vtn_get_deref_for_ssa_value(b, mat);
+
+   vtn_assert(num_indices == 1);
+   nir_def *index = nir_imm_intN_t(&b->nb, indices[0], 32);
+
+   const struct glsl_type *element_type = glsl_get_cmat_element(mat->type);
+   struct vtn_ssa_value *ret = vtn_create_ssa_value(b, element_type);
+   ret->def = nir_cmat_extract(&b->nb, glsl_get_bit_size(element_type), &mat_deref->def, index);
+   return ret;
+}
+
+struct vtn_ssa_value *
+vtn_cooperative_matrix_insert(struct vtn_builder *b, struct vtn_ssa_value *mat,
+                              struct vtn_ssa_value *insert, const uint32_t *indices,
+                              unsigned num_indices)
+{
+   vtn_assert(glsl_type_is_cmat(mat->type));
+   nir_deref_instr *mat_deref = vtn_get_deref_for_ssa_value(b, mat);
+
+   vtn_assert(num_indices == 1);
+   nir_def *index = nir_imm_intN_t(&b->nb, indices[0], 32);
+
+   nir_deref_instr *dst = vtn_create_cmat_temporary(b, mat_deref->type, "cmat_insert");
+   nir_cmat_insert(&b->nb, &dst->def, insert->def, &mat_deref->def, index);
+
+   struct vtn_ssa_value *ret = vtn_create_ssa_value(b, dst->type);
+   vtn_set_ssa_value_var(b, ret, dst->var);
+   return ret;
+}
index 66c5cdb..02fe2f2 100644 (file)
@@ -283,6 +283,7 @@ enum vtn_base_type {
    vtn_base_type_ray_query,
    vtn_base_type_function,
    vtn_base_type_event,
+   vtn_base_type_cooperative_matrix,
 };
 
 struct vtn_type {
@@ -391,6 +392,12 @@ struct vtn_type {
          /* Return type for functions */
          struct vtn_type *return_type;
       };
+
+      /* Members for cooperative matrix types. */
+      struct {
+         struct glsl_cmat_description desc;
+         struct vtn_type *component_type;
+      };
    };
 };
 
@@ -1048,4 +1055,20 @@ void vtn_emit_make_visible_barrier(struct vtn_builder *b, SpvMemoryAccessMask ac
 void vtn_emit_make_available_barrier(struct vtn_builder *b, SpvMemoryAccessMask access,
                                      SpvScope scope, enum vtn_variable_mode mode);
 
+
+void vtn_handle_cooperative_type(struct vtn_builder *b, struct vtn_value *val,
+                                 SpvOp opcode, const uint32_t *w, unsigned count);
+void vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
+                                        const uint32_t *w, unsigned count);
+void vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
+                                const struct glsl_type *dest_type, SpvOp opcode,
+                                const uint32_t *w, unsigned count);
+struct vtn_ssa_value *vtn_cooperative_matrix_extract(struct vtn_builder *b, struct vtn_ssa_value *mat,
+                                                     const uint32_t *indices, unsigned num_indices);
+struct vtn_ssa_value *vtn_cooperative_matrix_insert(struct vtn_builder *b, struct vtn_ssa_value *mat,
+                                                    struct vtn_ssa_value *insert,
+                                                    const uint32_t *indices, unsigned num_indices);
+nir_deref_instr *vtn_create_cmat_temporary(struct vtn_builder *b,
+                                           const struct glsl_type *t, const char *name);
+
 #endif /* _VTN_PRIVATE_H_ */
index 9e14d69..d00e3d7 100644 (file)
@@ -474,8 +474,15 @@ vtn_pointer_dereference(struct vtn_builder *b,
          nir_def *arr_index =
             vtn_access_link_as_ssa(b, deref_chain->link[idx], 1,
                                    tail->def.bit_size);
+         if (type->base_type == vtn_base_type_cooperative_matrix) {
+            const struct glsl_type *element_type = glsl_get_cmat_element(type->type);
+            tail = nir_build_deref_cast(&b->nb, &tail->def, tail->modes,
+                                        glsl_array_type(element_type, 0, 0), 0);
+            type = type->component_type;
+         } else {
+            type = type->array_element;
+         }
          tail = nir_build_deref_array(&b->nb, tail, arr_index);
-         type = type->array_element;
       }
       tail->arr.in_bounds = deref_chain->in_bounds;
 
@@ -510,7 +517,16 @@ _vtn_local_load_store(struct vtn_builder *b, bool load, nir_deref_instr *deref,
                       struct vtn_ssa_value *inout,
                       enum gl_access_qualifier access)
 {
-   if (glsl_type_is_vector_or_scalar(deref->type)) {
+   if (glsl_type_is_cmat(deref->type)) {
+      if (load) {
+         nir_deref_instr *temp = vtn_create_cmat_temporary(b, deref->type, "cmat_ssa");
+         nir_cmat_copy(&b->nb, &temp->def, &deref->def);
+         vtn_set_ssa_value_var(b, inout, temp->var);
+      } else {
+         nir_deref_instr *src_deref = vtn_get_deref_for_ssa_value(b, inout);
+         nir_cmat_copy(&b->nb, &deref->def, &src_deref->def);
+      }
+   } else if (glsl_type_is_vector_or_scalar(deref->type)) {
       if (load) {
          inout->def = nir_load_deref_with_access(&b->nb, deref, access);
       } else {
@@ -555,7 +571,17 @@ get_deref_tail(nir_deref_instr *deref)
    nir_deref_instr *parent =
       nir_instr_as_deref(deref->parent.ssa->parent_instr);
 
-   if (glsl_type_is_vector(parent->type))
+   if (parent->deref_type == nir_deref_type_cast &&
+       parent->parent.ssa->parent_instr->type == nir_instr_type_deref) {
+      nir_deref_instr *grandparent =
+         nir_instr_as_deref(parent->parent.ssa->parent_instr);
+
+      if (glsl_type_is_cmat(grandparent->type))
+         return grandparent;
+   }
+
+   if (glsl_type_is_vector(parent->type) ||
+       glsl_type_is_cmat(parent->type))
       return parent;
    else
       return deref;
@@ -571,7 +597,19 @@ vtn_local_load(struct vtn_builder *b, nir_deref_instr *src,
 
    if (src_tail != src) {
       val->type = src->type;
-      val->def = nir_vector_extract(&b->nb, val->def, src->arr.index.ssa);
+
+      if (glsl_type_is_cmat(src_tail->type)) {
+         assert(val->is_variable);
+         nir_deref_instr *mat = vtn_get_deref_for_ssa_value(b, val);
+
+         /* Reset is_variable because we are repurposing val. */
+         val->is_variable = false;
+         val->def = nir_cmat_extract(&b->nb,
+                                     glsl_get_bit_size(src->type),
+                                     &mat->def, src->arr.index.ssa);
+      } else {
+         val->def = nir_vector_extract(&b->nb, val->def, src->arr.index.ssa);
+      }
    }
 
    return val;
@@ -587,8 +625,16 @@ vtn_local_store(struct vtn_builder *b, struct vtn_ssa_value *src,
       struct vtn_ssa_value *val = vtn_create_ssa_value(b, dest_tail->type);
       _vtn_local_load_store(b, true, dest_tail, val, access);
 
-      val->def = nir_vector_insert(&b->nb, val->def, src->def,
-                                   dest->arr.index.ssa);
+      if (glsl_type_is_cmat(dest_tail->type)) {
+         nir_deref_instr *mat = vtn_get_deref_for_ssa_value(b, val);
+         nir_deref_instr *dst = vtn_create_cmat_temporary(b, dest_tail->type, "cmat_insert");
+         nir_cmat_insert(&b->nb, &dst->def, src->def, &mat->def, dest->arr.index.ssa);
+         vtn_set_ssa_value_var(b, val, dst->var);
+      } else {
+         val->def = nir_vector_insert(&b->nb, val->def, src->def,
+                                      dest->arr.index.ssa);
+      }
+
       _vtn_local_load_store(b, false, dest_tail, val, access);
    } else {
       _vtn_local_load_store(b, false, dest_tail, src, access);
@@ -654,6 +700,7 @@ _vtn_variable_load_store(struct vtn_builder *b, bool load,
    case GLSL_TYPE_FLOAT16:
    case GLSL_TYPE_BOOL:
    case GLSL_TYPE_DOUBLE:
+   case GLSL_TYPE_COOPERATIVE_MATRIX:
       if (glsl_type_is_vector_or_scalar(ptr->type->type)) {
          /* We hit a vector or scalar; go ahead and emit the load[s] */
          nir_deref_instr *deref = vtn_pointer_to_deref(b, ptr);