bool amd_shader_explicit_vertex_parameter;
bool amd_trinary_minmax;
bool atomic_storage;
+ bool cooperative_matrix;
bool demote_to_helper_invocation;
bool derivative_group;
bool descriptor_array_dynamic_indexing;
'vtn_alu.c',
'vtn_amd.c',
'vtn_cfg.c',
+ 'vtn_cmat.c',
'vtn_glsl450.c',
'vtn_opencl.c',
'vtn_private.h',
struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
val->type = glsl_get_bare_type(type);
- if (glsl_type_is_vector_or_scalar(type)) {
+ if (glsl_type_is_cmat(type)) {
+ nir_deref_instr *mat = vtn_create_cmat_temporary(b, type, "cmat_undef");
+ vtn_set_ssa_value_var(b, val, mat->var);
+ } else if (glsl_type_is_vector_or_scalar(type)) {
unsigned num_components = glsl_get_vector_elements(val->type);
unsigned bit_size = glsl_get_bit_size(val->type);
val->def = nir_undef(&b->nb, num_components, bit_size);
struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
val->type = glsl_get_bare_type(type);
- if (glsl_type_is_vector_or_scalar(type)) {
+ if (glsl_type_is_cmat(type)) {
+ const struct glsl_type *element_type = glsl_get_cmat_element(type);
+
+ nir_deref_instr *mat = vtn_create_cmat_temporary(b, type, "cmat_constant");
+ nir_cmat_construct(&b->nb, &mat->def,
+ nir_build_imm(&b->nb, 1, glsl_get_bit_size(element_type),
+ constant->values));
+ vtn_set_ssa_value_var(b, val, mat->var);
+ } else if (glsl_type_is_vector_or_scalar(type)) {
val->def = nir_build_imm(&b->nb, glsl_get_vector_elements(val->type),
glsl_get_bit_size(val->type),
constant->values);
case vtn_base_type_sampler:
case vtn_base_type_sampled_image:
case vtn_base_type_event:
+ case vtn_base_type_cooperative_matrix:
return t1->type == t2->type;
case vtn_base_type_array:
case vtn_base_type_event:
case vtn_base_type_accel_struct:
case vtn_base_type_ray_query:
+ case vtn_base_type_cooperative_matrix:
/* Nothing more to do */
break;
break;
}
+ case SpvOpTypeCooperativeMatrixKHR:
+ vtn_handle_cooperative_type(b, val, opcode, w, count);
+ break;
+
case SpvOpTypeEvent:
val->type->base_type = vtn_base_type_event;
/*
case SpvOpSpecConstantComposite:
case SpvOpConstantComposite: {
unsigned elem_count = count - 3;
- vtn_fail_if(elem_count != val->type->length,
+ unsigned expected_length = val->type->base_type == vtn_base_type_cooperative_matrix ?
+ 1 : val->type->length;
+ vtn_fail_if(elem_count != expected_length,
"%s has %u constituents, expected %u",
- spirv_op_to_string(opcode), elem_count, val->type->length);
+ spirv_op_to_string(opcode), elem_count, expected_length);
nir_constant **elems = ralloc_array(b, nir_constant *, elem_count);
val->is_undef_constant = true;
val->constant->elements = elems;
break;
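+ /* A cooperative-matrix constant is specified by a single constituent
+ * that is replicated across every component, so only that one scalar
+ * needs to be stored.
+ */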
+ case vtn_base_type_cooperative_matrix:
+ val->constant->values[0] = elems[0]->values[0];
+ break;
+
default:
vtn_fail("Result type of %s must be a composite type",
spirv_op_to_string(opcode));
if (!glsl_type_is_vector_or_scalar(type)) {
unsigned elems = glsl_get_length(val->type);
val->elems = ralloc_array(b, struct vtn_ssa_value *, elems);
- if (glsl_type_is_array_or_matrix(type)) {
+ if (glsl_type_is_array_or_matrix(type) || glsl_type_is_cmat(type)) {
const struct glsl_type *elem_type = glsl_get_array_element(type);
for (unsigned i = 0; i < elems; i++)
val->elems[i] = vtn_create_ssa_value(b, elem_type);
struct vtn_ssa_value *insert, const uint32_t *indices,
unsigned num_indices)
{
+ if (glsl_type_is_cmat(src->type))
+ return vtn_cooperative_matrix_insert(b, src, insert, indices, num_indices);
+
struct vtn_ssa_value *dest = vtn_composite_copy(b, src);
struct vtn_ssa_value *cur = dest;
vtn_composite_extract(struct vtn_builder *b, struct vtn_ssa_value *src,
const uint32_t *indices, unsigned num_indices)
{
+ if (glsl_type_is_cmat(src->type))
+ return vtn_cooperative_matrix_extract(b, src, indices, num_indices);
+
struct vtn_ssa_value *cur = src;
for (unsigned i = 0; i < num_indices; i++) {
if (glsl_type_is_vector_or_scalar(cur->type)) {
case SpvOpCompositeConstruct: {
unsigned elems = count - 3;
assume(elems >= 1);
- if (glsl_type_is_vector_or_scalar(type->type)) {
+ if (type->base_type == vtn_base_type_cooperative_matrix) {
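+ /* As with constants, a cooperative matrix is constructed from a
+ * single component value that is replicated across the whole matrix.
+ */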
+ vtn_assert(elems == 1);
+ nir_deref_instr *mat = vtn_create_cmat_temporary(b, type->type, "cmat_construct");
+ nir_cmat_construct(&b->nb, &mat->def, vtn_get_nir_ssa(b, w[3]));
+ vtn_set_ssa_value_var(b, ssa, mat->var);
+ } else if (glsl_type_is_vector_or_scalar(type->type)) {
nir_def *srcs[NIR_MAX_VEC_COMPONENTS];
for (unsigned i = 0; i < elems; i++) {
srcs[i] = vtn_get_nir_ssa(b, w[3 + i]);
spv_check_supported(shader_enqueue, cap);
break;
+ case SpvCapabilityCooperativeMatrixKHR:
+ spv_check_supported(cooperative_matrix, cap);
+ break;
+
default:
vtn_fail("Unhandled capability: %s (%u)",
spirv_capability_to_string(cap), cap);
case SpvOpTypePipe:
case SpvOpTypeAccelerationStructureKHR:
case SpvOpTypeRayQueryKHR:
+ case SpvOpTypeCooperativeMatrixKHR:
vtn_handle_type(b, opcode, w, count);
break;
case SpvOpFinishWritingNodePayloadAMDX:
break;
+ case SpvOpCooperativeMatrixLoadKHR:
+ case SpvOpCooperativeMatrixStoreKHR:
+ case SpvOpCooperativeMatrixLengthKHR:
+ case SpvOpCooperativeMatrixMulAddKHR:
+ vtn_handle_cooperative_instruction(b, opcode, w, count);
+ break;
+
default:
vtn_fail_with_opcode("Unhandled opcode", opcode);
}
struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
+ if (glsl_type_is_cmat(dest_type)) {
+ vtn_handle_cooperative_alu(b, dest_val, dest_type, opcode, w, count);
+ return;
+ }
+
vtn_handle_no_contraction(b, dest_val);
bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);
*/
struct vtn_type *type = vtn_get_type(b, w[1]);
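+ /* Bitcasts of cooperative matrices reinterpret the matrix as a
+ * whole and are handled together with the other cmat instructions.
+ */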
+ if (type->base_type == vtn_base_type_cooperative_matrix) {
+ vtn_handle_cooperative_instruction(b, SpvOpBitcast, w, count);
+ return;
+ }
+
struct nir_def *src = vtn_get_nir_ssa(b, w[3]);
vtn_fail_if(src->num_components * src->bit_size !=
--- /dev/null
+/*
+ * Copyright 2023 Intel Corporation
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "glsl_types.h"
+#include "nir.h"
+#include "nir_types.h"
+#include "vtn_private.h"
+
+static enum glsl_cmat_use
+vtn_cooperative_matrix_use_to_glsl(SpvCooperativeMatrixUse use)
+{
+ switch (use) {
+ case SpvCooperativeMatrixUseMatrixAKHR:
+ return GLSL_CMAT_USE_A;
+ case SpvCooperativeMatrixUseMatrixBKHR:
+ return GLSL_CMAT_USE_B;
+ case SpvCooperativeMatrixUseMatrixAccumulatorKHR:
+ return GLSL_CMAT_USE_ACCUMULATOR;
+ default:
+ unreachable("Unexpected cooperative matrix use");
+ }
+}
+
+void
+vtn_handle_cooperative_type(struct vtn_builder *b, struct vtn_value *val,
+ SpvOp opcode, const uint32_t *w, unsigned count)
+{
+ vtn_assert(opcode == SpvOpTypeCooperativeMatrixKHR);
+
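+ /* Operands of OpTypeCooperativeMatrixKHR after the result id are the
+ * Component Type, Scope, Rows, Columns, and Use, in that order.
+ */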
+ struct vtn_type *component_type = vtn_get_type(b, w[2]);
+
+ const mesa_scope scope = vtn_translate_scope(b, vtn_constant_uint(b, w[3]));
+ const uint32_t rows = vtn_constant_uint(b, w[4]);
+ const uint32_t cols = vtn_constant_uint(b, w[5]);
+
+ vtn_assert(rows < 256);
+ vtn_assert(cols < 256);
+
+ enum glsl_cmat_use use = vtn_cooperative_matrix_use_to_glsl(vtn_constant_uint(b, w[6]));
+
+ val->type->base_type = vtn_base_type_cooperative_matrix;
+ vtn_fail_if(!glsl_type_is_numeric(component_type->type),
+ "OpTypeCooperativeMatrixKHR "
+ "Component Type must be a scalar numerical type.");
+
+ val->type->desc.element_type = glsl_get_base_type(component_type->type);
+ val->type->desc.scope = scope;
+ val->type->desc.rows = rows;
+ val->type->desc.cols = cols;
+ val->type->desc.use = use;
+
+ val->type->type = glsl_cmat_type(&val->type->desc);
+ val->type->component_type = component_type;
+}
+
+static enum glsl_matrix_layout
+vtn_matrix_layout_to_glsl(SpvCooperativeMatrixLayout layout)
+{
+ switch (layout) {
+ case SpvCooperativeMatrixLayoutRowMajorKHR:
+ return GLSL_MATRIX_LAYOUT_ROW_MAJOR;
+ case SpvCooperativeMatrixLayoutColumnMajorKHR:
+ return GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
+ default:
+ unreachable("Unexpected cooperative matrix layout");
+ }
+}
+
+nir_deref_instr *
+vtn_create_cmat_temporary(struct vtn_builder *b, const struct glsl_type *t, const char *name)
+{
+ nir_variable *var = nir_local_variable_create(b->nb.impl, t, name);
+ return nir_build_deref_var(&b->nb, var);
+}
+
+static nir_deref_instr *
+vtn_get_cmat_deref(struct vtn_builder *b, uint32_t value_id)
+{
+ nir_deref_instr *deref = vtn_get_deref_for_id(b, value_id);
+ vtn_assert(glsl_type_is_cmat(deref->type));
+ return deref;
+}
+
+void
+vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
+ const uint32_t *w, unsigned count)
+{
+ switch (opcode) {
+ case SpvOpCooperativeMatrixLoadKHR: {
+ struct vtn_value *src_val = vtn_value(b, w[3], vtn_value_type_pointer);
+ struct vtn_pointer *src = vtn_value_to_pointer(b, src_val);
+ struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+
+ const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[4]);
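+ /* The Stride operand is optional; a stride of 0 is passed when it
+ * is absent.
+ */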
+ nir_def *stride = count > 5 ? vtn_get_nir_ssa(b, w[5]) : nir_imm_zero(&b->nb, 1, 32);
+
+ SpvMemoryAccessMask access = SpvMemoryAccessMaskNone;
+ if (count > 6) {
+ unsigned idx = 6, alignment;
+ SpvScope scope;
+ vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, NULL, &scope);
+ vtn_emit_make_visible_barrier(b, access, scope, src->mode);
+ }
+
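+ /* The loaded matrix lives in a fresh local variable, which becomes
+ * the SSA value for this result id.
+ */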
+ nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_load");
+ nir_cmat_load(&b->nb, &dst->def, vtn_pointer_to_ssa(b, src), stride,
+ .matrix_layout = vtn_matrix_layout_to_glsl(layout));
+ vtn_push_var_ssa(b, w[2], dst->var);
+ break;
+ }
+
+ case SpvOpCooperativeMatrixStoreKHR: {
+ struct vtn_value *dest_val = vtn_value(b, w[1], vtn_value_type_pointer);
+ struct vtn_pointer *dest = vtn_value_to_pointer(b, dest_val);
+
+ const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[3]);
+ nir_def *stride = count > 4 ? vtn_get_nir_ssa(b, w[4]) : nir_imm_zero(&b->nb, 1, 32);
+
+ SpvMemoryAccessMask access = SpvMemoryAccessMaskNone;
+ if (count > 5) {
+ unsigned idx = 5, alignment;
+ SpvScope scope;
+ vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, &scope, NULL);
+ vtn_emit_make_available_barrier(b, access, scope, dest->mode);
+ }
+
+ nir_deref_instr *src = vtn_get_cmat_deref(b, w[2]);
+ nir_cmat_store(&b->nb, vtn_pointer_to_ssa(b, dest), &src->def, stride,
+ .matrix_layout = vtn_matrix_layout_to_glsl(layout));
+ break;
+ }
+
+ case SpvOpCooperativeMatrixLengthKHR: {
+ struct vtn_type *type = vtn_get_type(b, w[3]);
+ nir_def *def = nir_cmat_length(&b->nb, .cmat_desc = type->desc);
+ vtn_push_nir_ssa(b, w[2], def);
+ break;
+ }
+
+ case SpvOpCooperativeMatrixMulAddKHR: {
+ nir_deref_instr *mat_a = vtn_get_cmat_deref(b, w[3]);
+ nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]);
+ nir_deref_instr *mat_c = vtn_get_cmat_deref(b, w[5]);
+
+ const uint32_t operands = count > 6 ? w[6] : 0;
+ const bool saturate = operands & SpvCooperativeMatrixOperandsSaturatingAccumulationKHRMask;
+ const unsigned signed_mask = operands & (SpvCooperativeMatrixOperandsMatrixASignedComponentsKHRMask |
+ SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask |
+ SpvCooperativeMatrixOperandsMatrixCSignedComponentsKHRMask |
+ SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask);
+
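+ /* NIR uses the same bit positions as SPIR-V for the signedness
+ * flags, so the mask can be passed through unchanged; the asserts
+ * below guarantee this stays true.
+ */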
+ STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixASignedComponentsKHRMask == NIR_CMAT_A_SIGNED);
+ STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask == NIR_CMAT_B_SIGNED);
+ STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixCSignedComponentsKHRMask == NIR_CMAT_C_SIGNED);
+ STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask == NIR_CMAT_RESULT_SIGNED);
+
+ struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+ nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_muladd");
+
+ nir_cmat_muladd(&b->nb, &dst->def, &mat_a->def, &mat_b->def, &mat_c->def,
+ .saturate = saturate,
+ .cmat_signed_mask = signed_mask);
+
+ vtn_push_var_ssa(b, w[2], dst->var);
+ break;
+ }
+
+ case SpvOpBitcast: {
+ struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+ vtn_assert(dst_type->base_type == vtn_base_type_cooperative_matrix);
+ nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);
+
+ nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_bitcast");
+ nir_cmat_bitcast(&b->nb, &dst->def, &src->def);
+ vtn_push_var_ssa(b, w[2], dst->var);
+ break;
+ }
+
+ default:
+ unreachable("Unexpected opcode for cooperative matrix instruction");
+ }
+}
+
+void
+vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
+ const struct glsl_type *dest_type, SpvOp opcode,
+ const uint32_t *w, unsigned count)
+{
+ vtn_assert(glsl_type_is_cmat(dest_type));
+
+ switch (opcode) {
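+ /* Conversions and negations lower to a single nir_cmat_unary_op that
+ * carries the equivalent NIR ALU opcode as an index.
+ */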
+ case SpvOpConvertFToU:
+ case SpvOpConvertFToS:
+ case SpvOpConvertSToF:
+ case SpvOpConvertUToF:
+ case SpvOpUConvert:
+ case SpvOpSConvert:
+ case SpvOpFConvert:
+ case SpvOpFNegate:
+ case SpvOpSNegate: {
+ struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+ nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);
+
+ unsigned src_bit_size = glsl_get_bit_size(glsl_get_cmat_element(src->type));
+ unsigned dst_bit_size = glsl_get_bit_size(glsl_get_cmat_element(dst_type->type));
+
+ bool ignored = false;
+ nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored,
+ src_bit_size, dst_bit_size);
+
+ nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_unary");
+ nir_cmat_unary_op(&b->nb, &dst->def, &src->def,
+ .alu_op = op);
+ vtn_push_var_ssa(b, w[2], dst->var);
+ break;
+ }
+
+ case SpvOpFAdd:
+ case SpvOpFSub:
+ case SpvOpFMul:
+ case SpvOpFDiv:
+ case SpvOpIAdd:
+ case SpvOpISub:
+ case SpvOpIMul:
+ case SpvOpSDiv:
+ case SpvOpUDiv: {
+ bool ignored = false;
+ nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored, 0, 0);
+
+ struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+ nir_deref_instr *mat_a = vtn_get_cmat_deref(b, w[3]);
+ nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]);
+
+ nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_binary");
+ nir_cmat_binary_op(&b->nb, &dst->def, &mat_a->def, &mat_b->def,
+ .alu_op = op);
+ vtn_push_var_ssa(b, w[2], dst->var);
+ break;
+ }
+
+ case SpvOpMatrixTimesScalar: {
+ struct vtn_type *dst_type = vtn_get_type(b, w[1]);
+ nir_deref_instr *mat = vtn_get_cmat_deref(b, w[3]);
+
+ struct vtn_ssa_value *scalar_val = vtn_ssa_value(b, w[4]);
+ vtn_assert(glsl_type_is_scalar(scalar_val->type));
+ nir_op op = glsl_type_is_integer(scalar_val->type) ? nir_op_imul : nir_op_fmul;
+
+ nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_times_scalar");
+ nir_cmat_scalar_op(&b->nb, &dst->def, &mat->def, scalar_val->def,
+ .alu_op = op);
+ vtn_push_var_ssa(b, w[2], dst->var);
+ break;
+ }
+
+ default:
+ unreachable("invalid cooperative matrix alu instruction");
+ }
+}
+
+struct vtn_ssa_value *
+vtn_cooperative_matrix_extract(struct vtn_builder *b, struct vtn_ssa_value *mat,
+ const uint32_t *indices, unsigned num_indices)
+{
+ vtn_assert(glsl_type_is_cmat(mat->type));
+ nir_deref_instr *mat_deref = vtn_get_deref_for_ssa_value(b, mat);
+
+ vtn_assert(num_indices == 1);
+ nir_def *index = nir_imm_intN_t(&b->nb, indices[0], 32);
+
+ const struct glsl_type *element_type = glsl_get_cmat_element(mat->type);
+ struct vtn_ssa_value *ret = vtn_create_ssa_value(b, element_type);
+ ret->def = nir_cmat_extract(&b->nb, glsl_get_bit_size(element_type), &mat_deref->def, index);
+ return ret;
+}
+
+struct vtn_ssa_value *
+vtn_cooperative_matrix_insert(struct vtn_builder *b, struct vtn_ssa_value *mat,
+ struct vtn_ssa_value *insert, const uint32_t *indices,
+ unsigned num_indices)
+{
+ vtn_assert(glsl_type_is_cmat(mat->type));
+ nir_deref_instr *mat_deref = vtn_get_deref_for_ssa_value(b, mat);
+
+ vtn_assert(num_indices == 1);
+ nir_def *index = nir_imm_intN_t(&b->nb, indices[0], 32);
+
+ nir_deref_instr *dst = vtn_create_cmat_temporary(b, mat_deref->type, "cmat_insert");
+ nir_cmat_insert(&b->nb, &dst->def, insert->def, &mat_deref->def, index);
+
+ struct vtn_ssa_value *ret = vtn_create_ssa_value(b, dst->type);
+ vtn_set_ssa_value_var(b, ret, dst->var);
+ return ret;
+}
vtn_base_type_ray_query,
vtn_base_type_function,
vtn_base_type_event,
+ vtn_base_type_cooperative_matrix,
};
struct vtn_type {
/* Return type for functions */
struct vtn_type *return_type;
};
+
+ /* Members for cooperative matrix types. */
+ struct {
+ struct glsl_cmat_description desc;
+ struct vtn_type *component_type;
+ };
};
};
void vtn_emit_make_available_barrier(struct vtn_builder *b, SpvMemoryAccessMask access,
SpvScope scope, enum vtn_variable_mode mode);
+
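+/* Cooperative matrix support, implemented in vtn_cmat.c. */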
+void vtn_handle_cooperative_type(struct vtn_builder *b, struct vtn_value *val,
+ SpvOp opcode, const uint32_t *w, unsigned count);
+void vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
+ const uint32_t *w, unsigned count);
+void vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
+ const struct glsl_type *dest_type, SpvOp opcode,
+ const uint32_t *w, unsigned count);
+struct vtn_ssa_value *vtn_cooperative_matrix_extract(struct vtn_builder *b, struct vtn_ssa_value *mat,
+ const uint32_t *indices, unsigned num_indices);
+struct vtn_ssa_value *vtn_cooperative_matrix_insert(struct vtn_builder *b, struct vtn_ssa_value *mat,
+ struct vtn_ssa_value *insert,
+ const uint32_t *indices, unsigned num_indices);
+nir_deref_instr *vtn_create_cmat_temporary(struct vtn_builder *b,
+ const struct glsl_type *t, const char *name);
+
#endif /* _VTN_PRIVATE_H_ */
nir_def *arr_index =
vtn_access_link_as_ssa(b, deref_chain->link[idx], 1,
tail->def.bit_size);
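+ /* A cooperative matrix is opaque to NIR derefs, so indexing an
+ * element goes through a cast to an unsized array of the element
+ * type.
+ */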
+ if (type->base_type == vtn_base_type_cooperative_matrix) {
+ const struct glsl_type *element_type = glsl_get_cmat_element(type->type);
+ tail = nir_build_deref_cast(&b->nb, &tail->def, tail->modes,
+ glsl_array_type(element_type, 0, 0), 0);
+ type = type->component_type;
+ } else {
+ type = type->array_element;
+ }
tail = nir_build_deref_array(&b->nb, tail, arr_index);
- type = type->array_element;
}
tail->arr.in_bounds = deref_chain->in_bounds;
struct vtn_ssa_value *inout,
enum gl_access_qualifier access)
{
- if (glsl_type_is_vector_or_scalar(deref->type)) {
+ if (glsl_type_is_cmat(deref->type)) {
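+ /* Cooperative matrix values are carried in variables rather than
+ * SSA defs, so loads and stores become whole-matrix copies.
+ */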
+ if (load) {
+ nir_deref_instr *temp = vtn_create_cmat_temporary(b, deref->type, "cmat_ssa");
+ nir_cmat_copy(&b->nb, &temp->def, &deref->def);
+ vtn_set_ssa_value_var(b, inout, temp->var);
+ } else {
+ nir_deref_instr *src_deref = vtn_get_deref_for_ssa_value(b, inout);
+ nir_cmat_copy(&b->nb, &deref->def, &src_deref->def);
+ }
+ } else if (glsl_type_is_vector_or_scalar(deref->type)) {
if (load) {
inout->def = nir_load_deref_with_access(&b->nb, deref, access);
} else {
nir_deref_instr *parent =
nir_instr_as_deref(deref->parent.ssa->parent_instr);
- if (glsl_type_is_vector(parent->type))
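+ /* Element access on a cooperative matrix goes through the deref
+ * cast added by vtn_pointer_dereference; look through it to find
+ * the matrix deref itself.
+ */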
+ if (parent->deref_type == nir_deref_type_cast &&
+ parent->parent.ssa->parent_instr->type == nir_instr_type_deref) {
+ nir_deref_instr *grandparent =
+ nir_instr_as_deref(parent->parent.ssa->parent_instr);
+
+ if (glsl_type_is_cmat(grandparent->type))
+ return grandparent;
+ }
+
+ if (glsl_type_is_vector(parent->type) ||
+ glsl_type_is_cmat(parent->type))
return parent;
else
return deref;
if (src_tail != src) {
val->type = src->type;
- val->def = nir_vector_extract(&b->nb, val->def, src->arr.index.ssa);
+
+ if (glsl_type_is_cmat(src_tail->type)) {
+ assert(val->is_variable);
+ nir_deref_instr *mat = vtn_get_deref_for_ssa_value(b, val);
+
+ /* Reset is_variable because we are repurposing val. */
+ val->is_variable = false;
+ val->def = nir_cmat_extract(&b->nb,
+ glsl_get_bit_size(src->type),
+ &mat->def, src->arr.index.ssa);
+ } else {
+ val->def = nir_vector_extract(&b->nb, val->def, src->arr.index.ssa);
+ }
}
return val;
struct vtn_ssa_value *val = vtn_create_ssa_value(b, dest_tail->type);
_vtn_local_load_store(b, true, dest_tail, val, access);
- val->def = nir_vector_insert(&b->nb, val->def, src->def,
- dest->arr.index.ssa);
+ if (glsl_type_is_cmat(dest_tail->type)) {
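+ /* nir_cmat_insert produces a new matrix rather than mutating the
+ * old one, so build the result in a temporary and store that back.
+ */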
+ nir_deref_instr *mat = vtn_get_deref_for_ssa_value(b, val);
+ nir_deref_instr *dst = vtn_create_cmat_temporary(b, dest_tail->type, "cmat_insert");
+ nir_cmat_insert(&b->nb, &dst->def, src->def, &mat->def, dest->arr.index.ssa);
+ vtn_set_ssa_value_var(b, val, dst->var);
+ } else {
+ val->def = nir_vector_insert(&b->nb, val->def, src->def,
+ dest->arr.index.ssa);
+ }
+
_vtn_local_load_store(b, false, dest_tail, val, access);
} else {
_vtn_local_load_store(b, false, dest_tail, src, access);
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BOOL:
case GLSL_TYPE_DOUBLE:
+ case GLSL_TYPE_COOPERATIVE_MATRIX:
if (glsl_type_is_vector_or_scalar(ptr->type->type)) {
/* We hit a vector or scalar; go ahead and emit the load[s] */
nir_deref_instr *deref = vtn_pointer_to_deref(b, ptr);