microsoft/compiler: Enable packed dot product intrinsics for SM6.4+
authorJesse Natalie <jenatali@microsoft.com>
Thu, 11 May 2023 02:11:59 +0000 (19:11 -0700)
committerMarge Bot <emma+marge@anholt.net>
Thu, 11 May 2023 21:56:31 +0000 (21:56 +0000)
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22952>

src/microsoft/compiler/dxil_function.c
src/microsoft/compiler/dxil_nir.h
src/microsoft/compiler/dxil_nir_algebraic.py
src/microsoft/compiler/nir_to_dxil.c

index f4704ad..f72ef1f 100644 (file)
@@ -110,6 +110,7 @@ static struct  predefined_func_descr predefined_funcs[] = {
 {"dx.op.wavePrefixOp", "O", "iOcc", DXIL_ATTR_KIND_NO_UNWIND},
 {"dx.op.quadReadLaneAt", "O", "iOi", DXIL_ATTR_KIND_NO_UNWIND},
 {"dx.op.quadOp", "O", "iOc", DXIL_ATTR_KIND_NO_UNWIND},
+{"dx.op.dot4AddPacked", "i", "iiii", DXIL_ATTR_KIND_READ_NONE},
 };
 
 struct func_descr {
index ff404af..9c9d340 100644 (file)
@@ -34,7 +34,7 @@ extern "C" {
 
 bool dxil_nir_lower_8bit_conv(nir_shader *shader);
 bool dxil_nir_lower_16bit_conv(nir_shader *shader);
-bool dxil_nir_lower_x2b(nir_shader *shader);
+bool dxil_nir_algebraic(nir_shader *shader);
 bool dxil_nir_lower_fquantize2f16(nir_shader *shader);
 bool dxil_nir_lower_ubo_to_temp(nir_shader *shader);
 struct dxil_nir_lower_loads_stores_options {
index 9fd9ca5..868f799 100644 (file)
@@ -29,6 +29,8 @@ import sys
 import math
 
 a = 'a'
+b = 'b'
+c = 'c'
 
 # The nir_lower_bit_size() pass gets rid of all 8bit ALUs but insert new u2u8
 # and i2i8 operations to convert the result back to the original type after the
@@ -91,9 +93,13 @@ def remove_unsupported_casts(arr, bit_size, mask, max_unsigned_float, min_signed
 remove_unsupported_casts(no_8bit_conv, 8, 0xff, 255.0, -128.0, 127.0)
 remove_unsupported_casts(no_16bit_conv, 16, 0xffff, 65535.0, -32768.0, 32767.0)
 
-lower_x2b = [
+algebraic_ops = [
   (('b2b32', 'a'), ('b2i32', 'a')),
   (('b2b1', 'a'), ('ine', ('b2i32', a), 0)),
+
+  # We don't support the saturating versions of these
+  (('sdot_4x8_iadd_sat', a, b, c), ('iadd_sat', ('sdot_4x8_iadd', a, b, 0), c)),
+  (('udot_4x8_uadd_sat', a, b, c), ('uadd_sat', ('udot_4x8_uadd', a, b, 0), c)),
 ]
 
 no_16bit_conv += [
@@ -118,8 +124,8 @@ def run():
                                       no_8bit_conv).render())
     print(nir_algebraic.AlgebraicPass("dxil_nir_lower_16bit_conv",
                                       no_16bit_conv).render())
-    print(nir_algebraic.AlgebraicPass("dxil_nir_lower_x2b",
-                                      lower_x2b).render())
+    print(nir_algebraic.AlgebraicPass("dxil_nir_algebraic",
+                                      algebraic_ops).render())
 
 if __name__ == '__main__':
     main()
index 8b8e718..4f902d2 100644 (file)
@@ -179,6 +179,10 @@ dxil_get_nir_compiler_options(nir_shader_compiler_options *options,
       options->lower_doubles_options = ~0;
    if ((supported_int_sizes & 16) && (supported_float_sizes & 16))
       options->support_16bit_alu = true;
+   if (shader_model_max >= SHADER_MODEL_6_4) {
+      options->has_sdot_4x8 = true;
+      options->has_udot_4x8 = true;
+   }
 }
 
 static bool
@@ -373,6 +377,9 @@ enum dxil_intr {
    DXIL_INTR_RAW_BUFFER_LOAD = 139,
    DXIL_INTR_RAW_BUFFER_STORE = 140,
 
+   DXIL_INTR_DOT4_ADD_I8_PACKED = 163,
+   DXIL_INTR_DOT4_ADD_U8_PACKED = 164,
+
    DXIL_INTR_ANNOTATE_HANDLE = 216,
    DXIL_INTR_CREATE_HANDLE_FROM_BINDING = 217,
    DXIL_INTR_CREATE_HANDLE_FROM_HEAP = 218,
@@ -2499,6 +2506,25 @@ emit_bitfield_insert(struct ntd_context *ctx, nir_alu_instr *alu,
    return true;
 }
 
+static bool
+emit_dot4add_packed(struct ntd_context *ctx, nir_alu_instr *alu,
+                    enum dxil_intr intr,
+                    const struct dxil_value *src0,
+                    const struct dxil_value *src1,
+                    const struct dxil_value *accum)
+{
+   const struct dxil_func *f = dxil_get_function(&ctx->mod, "dx.op.dot4AddPacked", DXIL_I32);
+   if (!f)
+      return false;
+   const struct dxil_value *srcs[] = { dxil_module_get_int32_const(&ctx->mod, intr), accum, src0, src1 };
+   const struct dxil_value *v = dxil_emit_call(&ctx->mod, f, srcs, ARRAY_SIZE(srcs));
+   if (!v)
+      return false;
+
+   store_alu_dest(ctx, alu, 0, v);
+   return true;
+}
+
 static bool emit_select(struct ntd_context *ctx, nir_alu_instr *alu,
                         const struct dxil_value *sel,
                         const struct dxil_value *val_true,
@@ -2868,6 +2894,9 @@ emit_alu(struct ntd_context *ctx, nir_alu_instr *alu)
    case nir_op_unpack_half_2x16_split_y: return emit_f16tof32(ctx, alu, src[0], true);
    case nir_op_pack_half_2x16_split: return emit_f32tof16(ctx, alu, src[0], src[1]);
 
+   case nir_op_sdot_4x8_iadd: return emit_dot4add_packed(ctx, alu, DXIL_INTR_DOT4_ADD_I8_PACKED, src[0], src[1], src[2]);
+   case nir_op_udot_4x8_uadd: return emit_dot4add_packed(ctx, alu, DXIL_INTR_DOT4_ADD_U8_PACKED, src[0], src[1], src[2]);
+
    case nir_op_i2i1:
    case nir_op_u2u1:
    case nir_op_b2i16:
@@ -6454,7 +6483,7 @@ optimize_nir(struct nir_shader *s, const struct nir_to_dxil_options *opts)
       NIR_PASS(progress, s, nir_opt_cse);
       NIR_PASS(progress, s, nir_opt_peephole_select, 8, true, true);
       NIR_PASS(progress, s, nir_opt_algebraic);
-      NIR_PASS(progress, s, dxil_nir_lower_x2b);
+      NIR_PASS(progress, s, dxil_nir_algebraic);
       if (s->options->lower_int64_options)
          NIR_PASS(progress, s, nir_lower_int64);
       NIR_PASS(progress, s, nir_lower_alu);