nir: Add new intrinsics for Cooperative Matrix
authorCaio Oliveira <caio.oliveira@intel.com>
Tue, 8 Aug 2023 18:02:14 +0000 (11:02 -0700)
committerMarge Bot <emma+marge@anholt.net>
Thu, 28 Sep 2023 07:35:02 +0000 (07:35 +0000)
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23825>

src/compiler/nir/nir.h
src/compiler/nir/nir_intrinsics.py
src/compiler/nir/nir_print.c

index 71ece84..a21fe8c 100644 (file)
@@ -269,6 +269,17 @@ typedef enum {
    nir_resource_intel_non_uniform = 1u << 3,
 } nir_resource_data_intel;
 
+/**
+ * Which components to interpret as signed in cmat_muladd.
+ * See 'Cooperative Matrix Operands' in SPV_KHR_cooperative_matrix.
+ */
+typedef enum {
+   NIR_CMAT_A_SIGNED = 1u << 0,
+   NIR_CMAT_B_SIGNED = 1u << 1,
+   NIR_CMAT_C_SIGNED = 1u << 2,
+   NIR_CMAT_RESULT_SIGNED = 1u << 3,
+} nir_cmat_signed;
+
 typedef union {
    bool b;
    float f32;
index 65ad47a..50b25ed 100644 (file)
@@ -312,6 +312,12 @@ index("bool", "legacy_fneg")
 # On a register store, floating-point saturate the stored value.
 index("bool", "legacy_fsat")
 
+# For Cooperative Matrix intrinsics.
+index("struct glsl_cmat_description", "cmat_desc")
+index("enum glsl_matrix_layout", "matrix_layout")
+index("nir_cmat_signed", "cmat_signed_mask")
+index("nir_op", "alu_op")
+
 intrinsic("nop", flags=[CAN_ELIMINATE])
 
 intrinsic("convert_alu_types", dest_comp=0, src_comp=[0],
@@ -1196,6 +1202,29 @@ system_value("flat_mask", 1)
 # Whether provoking vertex mode is last
 system_value("provoking_last", 1)
 
+# SPV_KHR_cooperative_matrix.
+#
+# Cooperative matrices are referred through derefs to variables,
+# the destination of the operations appears as the first source,
+# ordering follows SPIR-V operation.
+#
+# Load/Store include an extra source for stride, since that
+# can be a _dynamically_ uniform value.
+#
+# Length takes a type not a value, that's encoded as a MATRIX_DESC.
+intrinsic("cmat_construct", src_comp=[-1, 1])
+intrinsic("cmat_load", src_comp=[-1, -1, 1], indices=[MATRIX_LAYOUT])
+intrinsic("cmat_store", src_comp=[-1, -1, 1], indices=[MATRIX_LAYOUT])
+intrinsic("cmat_length", src_comp=[], dest_comp=1, indices=[CMAT_DESC], bit_sizes=[32])
+intrinsic("cmat_muladd", src_comp=[-1, -1, -1, -1], indices=[SATURATE, CMAT_SIGNED_MASK])
+intrinsic("cmat_unary_op", src_comp=[-1, -1], indices=[ALU_OP])
+intrinsic("cmat_binary_op", src_comp=[-1, -1, -1], indices=[ALU_OP])
+intrinsic("cmat_scalar_op", src_comp=[-1, -1, -1], indices=[ALU_OP])
+intrinsic("cmat_bitcast", src_comp=[-1, -1])
+intrinsic("cmat_extract", src_comp=[-1, 1], dest_comp=1)
+intrinsic("cmat_insert", src_comp=[-1, 1, -1, 1])
+intrinsic("cmat_copy", src_comp=[-1, -1])
+
 # IR3-specific version of most SSBO intrinsics. The only different
 # compare to the originals is that they add an extra source to hold
 # the dword-offset, which is needed by the backend code apart from
index 48fa157..e87e8b5 100644 (file)
@@ -1523,6 +1523,64 @@ print_intrinsic_instr(nir_intrinsic_instr *instr, print_state *state)
          break;
       }
 
+      case NIR_INTRINSIC_MATRIX_LAYOUT: {
+         fprintf(fp, "matrix_layout=");
+         switch (nir_intrinsic_matrix_layout(instr)) {
+         case GLSL_MATRIX_LAYOUT_ROW_MAJOR:
+            fprintf(fp, "row_major");
+            break;
+         case GLSL_MATRIX_LAYOUT_COLUMN_MAJOR:
+            fprintf(fp, "col_major");
+            break;
+         default:
+            fprintf(fp, "unknown");
+            break;
+         }
+         break;
+      }
+
+      case NIR_INTRINSIC_CMAT_DESC: {
+         struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(instr);
+         const struct glsl_type *t = glsl_cmat_type(&desc);
+         fprintf(fp, "%s", glsl_get_type_name(t));
+         break;
+      }
+
+      case NIR_INTRINSIC_CMAT_SIGNED_MASK: {
+         fprintf(fp, "cmat_signed=");
+         unsigned int mask = nir_intrinsic_cmat_signed_mask(instr);
+         if (mask == 0)
+            fputc('0', fp);
+         while (mask) {
+            nir_cmat_signed i = 1u << u_bit_scan(&mask);
+            switch (i) {
+            case NIR_CMAT_A_SIGNED:
+               fputc('A', fp);
+               break;
+            case NIR_CMAT_B_SIGNED:
+               fputc('B', fp);
+               break;
+            case NIR_CMAT_C_SIGNED:
+               fputc('C', fp);
+               break;
+            case NIR_CMAT_RESULT_SIGNED:
+               fprintf(fp, "Result");
+               break;
+            default:
+               fprintf(fp, "unknown");
+               break;
+            }
+            fprintf(fp, "%s", mask ? "|" : "");
+         }
+         break;
+      }
+
+      case NIR_INTRINSIC_ALU_OP: {
+         nir_op alu_op = nir_intrinsic_alu_op(instr);
+         fprintf(fp, "alu_op=%s", nir_op_infos[alu_op].name);
+         break;
+      }
+
       default: {
          unsigned off = info->index_map[idx] - 1;
          fprintf(fp, "%s=%d", nir_intrinsic_index_names[idx], instr->const_index[off]);