spirv: use sdot_2x16 and udot_2x16 opcodes
authorRhys Perry <pendingchaos02@gmail.com>
Mon, 30 Aug 2021 12:56:17 +0000 (13:56 +0100)
committerMarge Bot <eric+marge@anholt.net>
Fri, 3 Sep 2021 13:21:27 +0000 (13:21 +0000)
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12617>

src/compiler/spirv/vtn_alu.c

index ed73118..443878c 100644 (file)
@@ -892,6 +892,7 @@ vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
                   spirv_op_to_string(opcode));
    }
 
+   unsigned packed_bit_size = 8;
    if (glsl_type_is_vector(vtn_src[0]->type)) {
       /* FINISHME: Is this actually as good or better for platforms that don't
        * have the special instructions (i.e., one or both of has_dot_4x8 or
@@ -902,6 +903,14 @@ vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
           glsl_get_bit_size(dest_type) <= 32) {
          src[0] = nir_pack_32_4x8(&b->nb, src[0]);
          src[1] = nir_pack_32_4x8(&b->nb, src[1]);
+      } else if (glsl_get_vector_elements(vtn_src[0]->type) == 2 &&
+                 glsl_get_bit_size(vtn_src[0]->type) == 16 &&
+                 glsl_get_bit_size(dest_type) <= 32 &&
+                 opcode != SpvOpSUDotKHR &&
+                 opcode != SpvOpSUDotAccSatKHR) {
+         src[0] = nir_pack_32_2x16(&b->nb, src[0]);
+         src[1] = nir_pack_32_2x16(&b->nb, src[1]);
+         packed_bit_size = 16;
       }
    } else if (glsl_type_is_scalar(vtn_src[0]->type) &&
               glsl_type_is_32bit(vtn_src[0]->type)) {
@@ -1012,53 +1021,64 @@ vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
       assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);
 
       nir_ssa_def *const zero = nir_imm_zero(&b->nb, 1, 32);
-      bool is_signed;
-
-      switch (opcode) {
-      case SpvOpSDotKHR:
-         dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
-         is_signed = true;
-         break;
-
-      case SpvOpUDotKHR:
-         dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
-         is_signed = false;
-         break;
-
-      case SpvOpSUDotKHR:
-         dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
-         is_signed = true;
-         break;
-
-      case SpvOpSDotAccSatKHR:
-         if (dest_size == 32)
-            dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
-         else
+      bool is_signed = opcode == SpvOpSDotKHR || opcode == SpvOpSUDotKHR ||
+                       opcode == SpvOpSDotAccSatKHR || opcode == SpvOpSUDotAccSatKHR;
+
+      if (packed_bit_size == 16) {
+         switch (opcode) {
+         case SpvOpSDotKHR:
+            dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
+            break;
+         case SpvOpUDotKHR:
+            dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
+            break;
+         case SpvOpSDotAccSatKHR:
+            if (dest_size == 32)
+               dest = nir_sdot_2x16_iadd_sat(&b->nb, src[0], src[1], src[2]);
+            else
+               dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
+            break;
+         case SpvOpUDotAccSatKHR:
+            if (dest_size == 32)
+               dest = nir_udot_2x16_uadd_sat(&b->nb, src[0], src[1], src[2]);
+            else
+               dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
+            break;
+         default:
+            unreachable("Invalid opcode.");
+         }
+      } else {
+         switch (opcode) {
+         case SpvOpSDotKHR:
             dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
-
-         is_signed = true;
-         break;
-
-      case SpvOpUDotAccSatKHR:
-         if (dest_size == 32)
-            dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
-         else
+            break;
+         case SpvOpUDotKHR:
             dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
-
-         is_signed = false;
-         break;
-
-      case SpvOpSUDotAccSatKHR:
-         if (dest_size == 32)
-            dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
-         else
+            break;
+         case SpvOpSUDotKHR:
             dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
-
-         is_signed = true;
-         break;
-
-      default:
-         unreachable("Invalid opcode.");
+            break;
+         case SpvOpSDotAccSatKHR:
+            if (dest_size == 32)
+               dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
+            else
+               dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
+            break;
+         case SpvOpUDotAccSatKHR:
+            if (dest_size == 32)
+               dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
+            else
+               dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
+            break;
+         case SpvOpSUDotAccSatKHR:
+            if (dest_size == 32)
+               dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
+            else
+               dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
+            break;
+         default:
+            unreachable("Invalid opcode.");
+         }
       }
 
       if (dest_size != 32) {