Small Matrix: skylakex: sgemm nn: add n6 to improve performance
author: Wangyang Guo <wangyang.guo@intel.com>
Thu, 13 May 2021 10:16:54 +0000 (10:16 +0000)
committer: Wangyang Guo <wangyang.guo@intel.com>
Mon, 2 Aug 2021 07:06:54 +0000 (07:06 +0000)
kernel/x86_64/sgemm_small_kernel_nn_skylakex.c

index c9f43f9..a675411 100644 (file)
@@ -110,6 +110,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
        BLASLONG m4 = M & ~3;
        BLASLONG m2 = M & ~1;
 
+       BLASLONG n6 = N - (N % 6);
        BLASLONG n4 = N & ~3;
        BLASLONG n2 = N & ~1;
 
@@ -165,7 +166,34 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
                }
        }
        for (; i < m32; i += 32) {
-               for (j = 0; j < n4; j += 4) {
+               for (j = 0; j < n6; j += 6) {
+                       DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
+                       DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1);
+                       DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2);
+                       DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3);
+                       DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4);
+                       DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5);
+                       for (k = 0; k < K; k++) {
+                               LOAD_A_512(0, x); LOAD_A_512(1, x);
+                               BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1);
+                               BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3);
+                               BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5);
+
+                               MATMUL_512(0, 0); MATMUL_512(1, 0);
+                               MATMUL_512(0, 1); MATMUL_512(1, 1);
+                               MATMUL_512(0, 2); MATMUL_512(1, 2);
+                               MATMUL_512(0, 3); MATMUL_512(1, 3);
+                               MATMUL_512(0, 4); MATMUL_512(1, 4);
+                               MATMUL_512(0, 5); MATMUL_512(1, 5);
+                       }
+                       STORE_512(0, 0); STORE_512(1, 0);
+                       STORE_512(0, 1); STORE_512(1, 1);
+                       STORE_512(0, 2); STORE_512(1, 2);
+                       STORE_512(0, 3); STORE_512(1, 3);
+                       STORE_512(0, 4); STORE_512(1, 4);
+                       STORE_512(0, 5); STORE_512(1, 5);
+               }
+               for (;j < n4; j += 4) {
                        DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
                        DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1);
                        DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2);
@@ -208,7 +236,34 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
                }
        }
        for (; i < m16; i += 16) {
-               for (j = 0; j < n4; j += 4) {
+               for (j = 0; j < n6; j += 6) {
+                       DECLARE_RESULT_512(0, 0);
+                       DECLARE_RESULT_512(0, 1);
+                       DECLARE_RESULT_512(0, 2);
+                       DECLARE_RESULT_512(0, 3);
+                       DECLARE_RESULT_512(0, 4);
+                       DECLARE_RESULT_512(0, 5);
+                       for (k = 0; k < K; k++) {
+                               LOAD_A_512(0, x);
+                               BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1);
+                               BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3);
+                               BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5);
+
+                               MATMUL_512(0, 0);
+                               MATMUL_512(0, 1);
+                               MATMUL_512(0, 2);
+                               MATMUL_512(0, 3);
+                               MATMUL_512(0, 4);
+                               MATMUL_512(0, 5);
+                       }
+                       STORE_512(0, 0);
+                       STORE_512(0, 1);
+                       STORE_512(0, 2);
+                       STORE_512(0, 3);
+                       STORE_512(0, 4);
+                       STORE_512(0, 5);
+               }
+               for (; j < n4; j += 4) {
                        DECLARE_RESULT_512(0, 0);
                        DECLARE_RESULT_512(0, 1);
                        DECLARE_RESULT_512(0, 2);
@@ -228,6 +283,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
                        STORE_512(0, 2);
                        STORE_512(0, 3);
                }
+
                for (; j < n2; j += 2) {
                        DECLARE_RESULT_512(0, 0);
                        DECLARE_RESULT_512(0, 1);
@@ -254,26 +310,54 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
        if (!mm) return 0;
        if (mm > 8 || K < 32) {
                register __mmask16 mask asm("k1") = (1UL << mm) - 1;
-               for (j = 0; j < n4; j += 4) {
+               for (j = 0; j < n6; j += 6) {
                        DECLARE_RESULT_512(0, 0);
                        DECLARE_RESULT_512(0, 1);
                        DECLARE_RESULT_512(0, 2);
                        DECLARE_RESULT_512(0, 3);
+                       DECLARE_RESULT_512(0, 4);
+                       DECLARE_RESULT_512(0, 5);
                        for (k = 0; k < K; k++) {
                                MASK_LOAD_A_512(0, x);
                                BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1);
                                BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3);
+                               BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5);
 
                                MATMUL_512(0, 0);
                                MATMUL_512(0, 1);
                                MATMUL_512(0, 2);
                                MATMUL_512(0, 3);
+                               MATMUL_512(0, 4);
+                               MATMUL_512(0, 5);
                        }
                        MASK_STORE_512(0, 0);
                        MASK_STORE_512(0, 1);
                        MASK_STORE_512(0, 2);
                        MASK_STORE_512(0, 3);
+                       MASK_STORE_512(0, 4);
+                       MASK_STORE_512(0, 5);
                }
+               for (; j < n4; j += 4) {
+                       DECLARE_RESULT_512(0, 0);
+                       DECLARE_RESULT_512(0, 1);
+                       DECLARE_RESULT_512(0, 2);
+                       DECLARE_RESULT_512(0, 3);
+                       for (k = 0; k < K; k++) {
+                               MASK_LOAD_A_512(0, x);
+                               BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1);
+                               BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3);
+
+                               MATMUL_512(0, 0);
+                               MATMUL_512(0, 1);
+                               MATMUL_512(0, 2);
+                               MATMUL_512(0, 3);
+                       }
+                       MASK_STORE_512(0, 0);
+                       MASK_STORE_512(0, 1);
+                       MASK_STORE_512(0, 2);
+                       MASK_STORE_512(0, 3);
+               }
+
                for (; j < n2; j += 2) {
                        DECLARE_RESULT_512(0, 0);
                        DECLARE_RESULT_512(0, 1);