Optimize M < 16 using AVX512 mask
authorWangyang Guo <wangyang.guo@intel.com>
Sat, 8 May 2021 15:59:14 +0000 (15:59 +0000)
committerWangyang Guo <wangyang.guo@intel.com>
Mon, 2 Aug 2021 07:06:54 +0000 (07:06 +0000)
kernel/x86_64/sgemm_small_kernel_nn_skylakex.c

index f2c7987..f0b6d63 100644 (file)
@@ -31,17 +31,25 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps()
 #define LOAD_A_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&A[lda * k + i + (M*16)])
+#define MASK_LOAD_A_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &A[lda * k + i + (M*16)])
 #define BROADCAST_LOAD_B_512(M, N) __m512 Bval##N = _mm512_broadcastss_ps(_mm_load_ss(&B[k + ldb * (j+N)]))
 #define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N)
 #if defined(B0)
 #define STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
                        _mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N)
+#define MASK_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
+                       _mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N)
 #else
 #define STORE_512(M, N) \
        BLASLONG offset##M##N = (j+N)*ldc + i + (M*16); \
        result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
        asm("vfmadd231ps (%1, %2, 4), %3, %0": "+v"(result##M##N):"r"(&C), "r"(offset##M##N), "v"(beta_512)); \
        _mm512_storeu_ps(&C[offset##M##N], result##M##N)
+#define MASK_STORE_512(M, N) \
+       BLASLONG offset##M##N = (j+N)*ldc + i + (M*16); \
+       result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
+       asm("vfmadd231ps (%1, %2, 4), %3, %0 %{%4%}": "+v"(result##M##N):"r"(&C), "r"(offset##M##N), "v"(beta_512), "k"(mask)); \
+       _mm512_mask_storeu_ps(&C[offset##M##N], mask, result##M##N)
 #endif
 
 #define DECLARE_RESULT_256(M, N) __m256 result##M##N = _mm256_setzero_ps()
@@ -241,6 +249,51 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
                        STORE_512(0, 0);
                }
        }
+       if (M - i > 0) {
+               register __mmask16 mask asm("k1") = (1UL << (M - i)) - 1;
+               for (j = 0; j < n4; j += 4) {
+                       DECLARE_RESULT_512(0, 0);
+                       DECLARE_RESULT_512(0, 1);
+                       DECLARE_RESULT_512(0, 2);
+                       DECLARE_RESULT_512(0, 3);
+                       for (k = 0; k < K; k++) {
+                               MASK_LOAD_A_512(0, x);
+                               BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1);
+                               BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3);
+
+                               MATMUL_512(0, 0);
+                               MATMUL_512(0, 1);
+                               MATMUL_512(0, 2);
+                               MATMUL_512(0, 3);
+                       }
+                       MASK_STORE_512(0, 0);
+                       MASK_STORE_512(0, 1);
+                       MASK_STORE_512(0, 2);
+                       MASK_STORE_512(0, 3);
+               }
+               for (; j < n2; j += 2) {
+                       DECLARE_RESULT_512(0, 0);
+                       DECLARE_RESULT_512(0, 1);
+                       for (k = 0; k < K; k++) {
+                               MASK_LOAD_A_512(0, x);
+                               BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1);
+                               MATMUL_512(0, 0);
+                               MATMUL_512(0, 1);
+                       }
+                       MASK_STORE_512(0, 0);
+                       MASK_STORE_512(0, 1);
+               }
+               for (; j < N; j++) {
+                       DECLARE_RESULT_512(0, 0);
+                       for (k = 0; k < K; k++) {
+                               MASK_LOAD_A_512(0, x);
+                               BROADCAST_LOAD_B_512(x, 0);
+                               MATMUL_512(0, 0);
+                       }
+                       MASK_STORE_512(0, 0);
+               }
+               return;
+       }
        __m256 alpha_256 = _mm256_broadcastss_ps(_mm_load_ss(&alpha));
 #if !defined(B0)
        __m256 beta_256 = _mm256_broadcastss_ps(_mm_load_ss(&beta));