skylake dgemm: Add a 16x8 kernel
author    Arjan van de Ven <arjan@linux.intel.com>
          Fri, 5 Oct 2018 11:49:43 +0000 (11:49 +0000)
committer Arjan van de Ven <arjan@linux.intel.com>
          Fri, 5 Oct 2018 13:11:35 +0000 (13:11 +0000)
The next step for the AVX512 dgemm code is adding a 16x8 kernel.
In the 8x8 kernel, each FMA has a matching load (the broadcast of
a B element); in the 16x8 kernel we can reuse this load for 2 FMAs,
which in turn reduces pressure on the load ports of the CPU and
gives a nice performance boost (in the 25% range).
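
The load-reuse idea, sketched here in rough AVX512 intrinsics rather than
the inline asm actually committed (the function name and panel layout are
illustrative assumptions, not code from this change):

    #include <immintrin.h>

    /* Accumulate a 16x8 block of C from packed panels: a0 points at
     * rows 0-7 of A, a1 at rows 8-15 (8 doubles per k step each), and
     * b holds the 8 doubles of B for each k step. */
    static void dgemm_16x8_sketch(const double *a0, const double *a1,
                                  const double *b, int k,
                                  __m512d c_lo[8], __m512d c_hi[8])
    {
        for (int l = 0; l < k; l++) {
            __m512d alo = _mm512_loadu_pd(a0 + 8 * l);   /* rows 0-7  */
            __m512d ahi = _mm512_loadu_pd(a1 + 8 * l);   /* rows 8-15 */
            for (int j = 0; j < 8; j++) {
                /* one broadcast of a B element ... */
                __m512d bj = _mm512_set1_pd(b[8 * l + j]);
                /* ... feeds two FMAs, one per 8-row half of C */
                c_lo[j] = _mm512_fmadd_pd(bj, alo, c_lo[j]);
                c_hi[j] = _mm512_fmadd_pd(bj, ahi, c_hi[j]);
            }
        }
    }

Per k step the 8x8 kernel needs 8 broadcasts for 8 FMAs; here the same
8 broadcasts feed 16 FMAs, halving the broadcast-load traffic per FMA.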

kernel/x86_64/dgemm_kernel_4x8_skylakex.c

index 8d0205c..09d48f9 100644
@@ -849,11 +849,13 @@ CNAME(BLASLONG m, BLASLONG n, BLASLONG k, double alpha, double * __restrict__ A,
 
                i = m;
 
-               while (i >= 8) {
+               while (i >= 16) {
                        double *BO;
+                       double *A1;
                        int kloop = K;
 
                        BO = B + 12;
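+                       /* A1 walks the second 8-row panel of A (rows 8-15), 8*K doubles past AO */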
+                       A1 = AO + 8 * K;
                        /*
                        *  This is the inner loop for the hot path
                         *  Written in inline asm because compilers like GCC 8 and earlier
@@ -861,6 +863,157 @@ CNAME(BLASLONG m, BLASLONG n, BLASLONG k, double alpha, double * __restrict__ A,
                         *  the AVX512 built in broadcast ability (1to8)
                         */
                        asm(
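+                       /* zero the accumulators: zmm1-zmm8 hold rows 0-7 of C, zmm11-zmm18 rows 8-15 */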
+                       "vxorpd  %%zmm1, %%zmm1, %%zmm1\n"
+                       "vmovapd %%zmm1, %%zmm2\n"
+                       "vmovapd %%zmm1, %%zmm3\n"
+                       "vmovapd %%zmm1, %%zmm4\n"
+                       "vmovapd %%zmm1, %%zmm5\n"
+                       "vmovapd %%zmm1, %%zmm6\n"
+                       "vmovapd %%zmm1, %%zmm7\n"
+                       "vmovapd %%zmm1, %%zmm8\n"
+                       "vmovapd %%zmm1, %%zmm11\n"
+                       "vmovapd %%zmm1, %%zmm12\n"
+                       "vmovapd %%zmm1, %%zmm13\n"
+                       "vmovapd %%zmm1, %%zmm14\n"
+                       "vmovapd %%zmm1, %%zmm15\n"
+                       "vmovapd %%zmm1, %%zmm16\n"
+                       "vmovapd %%zmm1, %%zmm17\n"
+                       "vmovapd %%zmm1, %%zmm18\n"
+                       "jmp .label16\n"
+                       ".align 32\n"
+                       /* Inner math loop */
+                       ".label16:\n"
+                       "vmovupd     -128(%[AO]),%%zmm0\n"
+                       "vmovupd     -128(%[A1]),%%zmm10\n"
+
+                       "vbroadcastsd       -96(%[BO]),  %%zmm9\n"
+                       "vfmadd231pd    %%zmm9, %%zmm0,  %%zmm1\n"
+                       "vfmadd231pd    %%zmm9, %%zmm10, %%zmm11\n"
+
+                       "vbroadcastsd       -88(%[BO]),  %%zmm9\n"
+                       "vfmadd231pd    %%zmm9, %%zmm0,  %%zmm2\n"
+                       "vfmadd231pd    %%zmm9, %%zmm10, %%zmm12\n"
+
+                       "vbroadcastsd       -80(%[BO]),  %%zmm9\n"
+                       "vfmadd231pd    %%zmm9, %%zmm0,  %%zmm3\n"
+                       "vfmadd231pd    %%zmm9, %%zmm10, %%zmm13\n"
+
+                       "vbroadcastsd       -72(%[BO]),  %%zmm9\n"
+                       "vfmadd231pd    %%zmm9, %%zmm0,  %%zmm4\n"
+                       "vfmadd231pd    %%zmm9, %%zmm10, %%zmm14\n"
+
+                       "vbroadcastsd       -64(%[BO]),  %%zmm9\n"
+                       "vfmadd231pd    %%zmm9, %%zmm0,  %%zmm5\n"
+                       "vfmadd231pd    %%zmm9, %%zmm10, %%zmm15\n"
+
+                       "vbroadcastsd       -56(%[BO]),  %%zmm9\n"
+                       "vfmadd231pd    %%zmm9, %%zmm0,  %%zmm6\n"
+                       "vfmadd231pd    %%zmm9, %%zmm10, %%zmm16\n"
+
+                       "vbroadcastsd       -48(%[BO]),  %%zmm9\n"
+                       "vfmadd231pd    %%zmm9, %%zmm0,  %%zmm7\n"
+                       "vfmadd231pd    %%zmm9, %%zmm10, %%zmm17\n"
+
+                       "vbroadcastsd       -40(%[BO]),  %%zmm9\n"
+                       "vfmadd231pd    %%zmm9, %%zmm0,  %%zmm8\n"
+                       "vfmadd231pd    %%zmm9, %%zmm10, %%zmm18\n"
+                       "add $64, %[AO]\n"
+                       "add $64, %[A1]\n"
+                       "add $64, %[BO]\n"
+                       "prefetch 512(%[AO])\n"
+                       "prefetch 512(%[A1])\n"
+                       "prefetch 512(%[BO])\n"
+                       "subl $1, %[kloop]\n"
+                       "jg .label16\n"
+                       /* multiply the result by alpha */
+                       "vbroadcastsd (%[alpha]), %%zmm9\n"
+                       "vmulpd %%zmm9, %%zmm1,  %%zmm1\n"
+                       "vmulpd %%zmm9, %%zmm2,  %%zmm2\n"
+                       "vmulpd %%zmm9, %%zmm3,  %%zmm3\n"
+                       "vmulpd %%zmm9, %%zmm4,  %%zmm4\n"
+                       "vmulpd %%zmm9, %%zmm5,  %%zmm5\n"
+                       "vmulpd %%zmm9, %%zmm6,  %%zmm6\n"
+                       "vmulpd %%zmm9, %%zmm7,  %%zmm7\n"
+                       "vmulpd %%zmm9, %%zmm8,  %%zmm8\n"
+                       "vmulpd %%zmm9, %%zmm11, %%zmm11\n"
+                       "vmulpd %%zmm9, %%zmm12, %%zmm12\n"
+                       "vmulpd %%zmm9, %%zmm13, %%zmm13\n"
+                       "vmulpd %%zmm9, %%zmm14, %%zmm14\n"
+                       "vmulpd %%zmm9, %%zmm15, %%zmm15\n"
+                       "vmulpd %%zmm9, %%zmm16, %%zmm16\n"
+                       "vmulpd %%zmm9, %%zmm17, %%zmm17\n"
+                       "vmulpd %%zmm9, %%zmm18, %%zmm18\n"
+                       /* And store additively in C */
+                       "vaddpd (%[C0]), %%zmm1, %%zmm1\n"
+                       "vaddpd (%[C1]), %%zmm2, %%zmm2\n"
+                       "vaddpd (%[C2]), %%zmm3, %%zmm3\n"
+                       "vaddpd (%[C3]), %%zmm4, %%zmm4\n"
+                       "vaddpd (%[C4]), %%zmm5, %%zmm5\n"
+                       "vaddpd (%[C5]), %%zmm6, %%zmm6\n"
+                       "vaddpd (%[C6]), %%zmm7, %%zmm7\n"
+                       "vaddpd (%[C7]), %%zmm8, %%zmm8\n"
+                       "vmovupd %%zmm1, (%[C0])\n"
+                       "vmovupd %%zmm2, (%[C1])\n"
+                       "vmovupd %%zmm3, (%[C2])\n"
+                       "vmovupd %%zmm4, (%[C3])\n"
+                       "vmovupd %%zmm5, (%[C4])\n"
+                       "vmovupd %%zmm6, (%[C5])\n"
+                       "vmovupd %%zmm7, (%[C6])\n"
+                       "vmovupd %%zmm8, (%[C7])\n"
+
+                       "vaddpd 64(%[C0]), %%zmm11, %%zmm11\n"
+                       "vaddpd 64(%[C1]), %%zmm12, %%zmm12\n"
+                       "vaddpd 64(%[C2]), %%zmm13, %%zmm13\n"
+                       "vaddpd 64(%[C3]), %%zmm14, %%zmm14\n"
+                       "vaddpd 64(%[C4]), %%zmm15, %%zmm15\n"
+                       "vaddpd 64(%[C5]), %%zmm16, %%zmm16\n"
+                       "vaddpd 64(%[C6]), %%zmm17, %%zmm17\n"
+                       "vaddpd 64(%[C7]), %%zmm18, %%zmm18\n"
+                       "vmovupd %%zmm11, 64(%[C0])\n"
+                       "vmovupd %%zmm12, 64(%[C1])\n"
+                       "vmovupd %%zmm13, 64(%[C2])\n"
+                       "vmovupd %%zmm14, 64(%[C3])\n"
+                       "vmovupd %%zmm15, 64(%[C4])\n"
+                       "vmovupd %%zmm16, 64(%[C5])\n"
+                       "vmovupd %%zmm17, 64(%[C6])\n"
+                       "vmovupd %%zmm18, 64(%[C7])\n"
+
+                          :
+                               [AO]    "+r" (AO),
+                               [A1]    "+r" (A1),
+                               [BO]    "+r" (BO),
+                               [C0]    "+r" (CO1),
+                               [kloop] "+r" (kloop)
+                          :
+                               [alpha]         "r" (&alpha),
+                               [C1]    "r" (CO1 + 1 * ldc),
+                               [C2]    "r" (CO1 + 2 * ldc),
+                               [C3]    "r" (CO1 + 3 * ldc),
+                               [C4]    "r" (CO1 + 4 * ldc),
+                               [C5]    "r" (CO1 + 5 * ldc),
+                               [C6]    "r" (CO1 + 6 * ldc),
+                               [C7]    "r" (CO1 + 7 * ldc)
+
+                            :  "memory", "zmm0",  "zmm1",  "zmm2",  "zmm3",  "zmm4",  "zmm5",  "zmm6",  "zmm7",  "zmm8", "zmm9",
+                                         "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15", "zmm16", "zmm17", "zmm18"
+                       );
+                       CO1 += 16;
+                       AO += 8 * K;
+                       i -= 16;
+               }
+
+               while (i >= 8) {
+                       double *BO;
+                       int kloop = K;
+
+                       BO = B + 12;
+                       /*
+                        *  This is the inner loop for the hot path
+                        *  Written in inline asm because compilers like GCC 8 and earlier
+                        *  struggle with register allocation and are not good at using
+                        *  the AVX512 built in broadcast ability (1to8)
+                        */
+                       asm(
                        "vxorpd  %%zmm1, %%zmm1, %%zmm1\n" 
                        "vmovapd %%zmm1, %%zmm2\n"
                        "vmovapd %%zmm1, %%zmm3\n"