performance optimizations for sgemv_n
authorwernsaar <wernsaar@googlemail.com>
Fri, 18 Jul 2014 09:25:21 +0000 (11:25 +0200)
committerwernsaar <wernsaar@googlemail.com>
Fri, 18 Jul 2014 09:25:21 +0000 (11:25 +0200)
kernel/x86_64/sgemv_n_avx.c
kernel/x86_64/sgemv_n_microk_bulldozer.c

index 8c263543c7c6f44d3e1767a0b180074a1d3be51e..dc8d015d8cddd9e8e716ff6e091a4b390ffd630d 100644 (file)
@@ -70,12 +70,11 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
        n1 = n / 512 ;
        n2 = n % 512 ;
 
-       m1 = m / 32;
-       m2 = m % 32;
+       m1 = m / 64;
+       m2 = m % 64;
 
-       x_ptr = x;
-       a_ptr = a;
        y_ptr = y;
+       x_ptr = x;
 
        for (j=0; j<n1; j++)
        {
@@ -85,12 +84,19 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
                else
                        copy_x(512,x_ptr,xbuffer,inc_x);
 
-               x_ptr += 512 * inc_x;
-               a_ptr += j * 512;
+               a_ptr = a + j * 512 * lda;
                y_ptr = y;
 
-
                for(i = 0; i<m1; i++ )
+               {
+                       sgemv_kernel_64(512,alpha,a_ptr,lda,xbuffer,ybuffer);
+                       add_y(64,ybuffer,y_ptr,inc_y);
+                       y_ptr += 64 * inc_y;
+                       a_ptr += 64;                    
+
+               }
+
+               if ( m2 & 32 )
                {
                        sgemv_kernel_32(512,alpha,a_ptr,lda,xbuffer,ybuffer);
                        add_y(32,ybuffer,y_ptr,inc_y);
@@ -98,6 +104,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
                        a_ptr += 32;                    
 
                }
+
                if ( m2 & 16 )
                {
                        sgemv_kernel_16(512,alpha,a_ptr,lda,xbuffer,ybuffer);
@@ -131,6 +138,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
                        sgemv_kernel_1(512,alpha,a_ptr,lda,xbuffer,ybuffer);
                        add_y(1,ybuffer,y_ptr,inc_y);
                }
+               x_ptr += 512 * inc_x;
 
        }
 
@@ -142,9 +150,19 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
                else
                        copy_x(n2,x_ptr,xbuffer,inc_x);
 
+               a_ptr = a + n1 * 512 * lda;
                y_ptr = y;
 
                for(i = 0; i<m1; i++ )
+               {
+                       sgemv_kernel_64(n2,alpha,a_ptr,lda,xbuffer,ybuffer);
+                       add_y(64,ybuffer,y_ptr,inc_y);
+                       y_ptr += 64 * inc_y;
+                       a_ptr += 64;                    
+
+               }
+
+               if ( m2 & 32 )
                {
                        sgemv_kernel_32(n2,alpha,a_ptr,lda,xbuffer,ybuffer);
                        add_y(32,ybuffer,y_ptr,inc_y);
index 3dad4364395dfa679d5a17007754236f4d1c19e5..1cecd96c5ea9b8de96a260ed6f0ea8d11e8e15ce 100644 (file)
@@ -25,12 +25,11 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
 
-
-static void sgemv_kernel_32( long n, float alpha, float *a, long lda, float *x, float *y)
+static void sgemv_kernel_64( long n, float alpha, float *a, long lda, float *x, float *y)
 {
 
 
-       float *pre = a + lda*4*2;
+       float *pre = a + lda*3;
 
        __asm __volatile
        (
@@ -44,38 +43,143 @@ static void sgemv_kernel_32( long n, float alpha, float *a, long lda, float *x,
        "prefetcht0     (%%r8)\n\t"                     // Prefetch
        "prefetcht0   64(%%r8)\n\t"                     // Prefetch
 
+       "vxorps         %%ymm8 , %%ymm8 , %%ymm8 \n\t"  // set to zero
+       "vxorps         %%ymm9 , %%ymm9 , %%ymm9 \n\t"  // set to zero
+       "vxorps         %%ymm10, %%ymm10, %%ymm10\n\t"  // set to zero
+       "vxorps         %%ymm11, %%ymm11, %%ymm11\n\t"  // set to zero
        "vxorps         %%ymm12, %%ymm12, %%ymm12\n\t"  // set to zero
        "vxorps         %%ymm13, %%ymm13, %%ymm13\n\t"  // set to zero
        "vxorps         %%ymm14, %%ymm14, %%ymm14\n\t"  // set to zero
        "vxorps         %%ymm15, %%ymm15, %%ymm15\n\t"  // set to zero
-
+       ".align 16                               \n\t"
        ".L01LOOP%=:                             \n\t"
        "vbroadcastss   (%%rdi),   %%ymm0        \n\t"  // load values of c
-        "addq          $4     ,   %%rdi         \n\t"  // increment pointer of c 
-
+       "nop                                     \n\t"
        "leaq     (%%r8 , %%rcx, 4), %%r8        \n\t"  // add lda to pointer for prefetch
+
        "prefetcht0     (%%r8)\n\t"                     // Prefetch
+       "vfmaddps %%ymm8 ,   0*4(%%rsi), %%ymm0, %%ymm8 \n\t" // multiply a and c and add to temp
        "prefetcht0   64(%%r8)\n\t"                     // Prefetch
+       "vfmaddps %%ymm9 ,   8*4(%%rsi), %%ymm0, %%ymm9 \n\t" // multiply a and c and add to temp
+       "prefetcht0   128(%%r8)\n\t"                    // Prefetch
+       "vfmaddps %%ymm10,  16*4(%%rsi), %%ymm0, %%ymm10\n\t" // multiply a and c and add to temp
+       "vfmaddps %%ymm11,  24*4(%%rsi), %%ymm0, %%ymm11\n\t" // multiply a and c and add to temp
+       "prefetcht0   192(%%r8)\n\t"                    // Prefetch
+       "vfmaddps %%ymm12,  32*4(%%rsi), %%ymm0, %%ymm12\n\t" // multiply a and c and add to temp
+       "vfmaddps %%ymm13,  40*4(%%rsi), %%ymm0, %%ymm13\n\t" // multiply a and c and add to temp
+       "vfmaddps %%ymm14,  48*4(%%rsi), %%ymm0, %%ymm14\n\t" // multiply a and c and add to temp
+       "vfmaddps %%ymm15,  56*4(%%rsi), %%ymm0, %%ymm15\n\t" // multiply a and c and add to temp
 
-       "vfmaddps %%ymm12,   0*4(%%rsi), %%ymm0, %%ymm12\n\t" // multiply a and c and add to temp
-       "vfmaddps %%ymm13,   8*4(%%rsi), %%ymm0, %%ymm13\n\t" // multiply a and c and add to temp
-       "vfmaddps %%ymm14,  16*4(%%rsi), %%ymm0, %%ymm14\n\t" // multiply a and c and add to temp
-       "vfmaddps %%ymm15,  24*4(%%rsi), %%ymm0, %%ymm15\n\t" // multiply a and c and add to temp
-
+        "addq          $4     ,   %%rdi         \n\t"  // increment pointer of c 
        "leaq     (%%rsi, %%rcx, 4), %%rsi       \n\t"  // add lda to pointer of a
 
        "dec            %%rax                    \n\t"  // n = n -1
        "jnz            .L01LOOP%=               \n\t"
 
+       "vmulps         %%ymm8 , %%ymm1,  %%ymm8 \n\t"  // scale by alpha
+       "vmulps         %%ymm9 , %%ymm1,  %%ymm9 \n\t"  // scale by alpha
+       "vmulps         %%ymm10, %%ymm1,  %%ymm10\n\t"  // scale by alpha
+       "vmulps         %%ymm11, %%ymm1,  %%ymm11\n\t"  // scale by alpha
        "vmulps         %%ymm12, %%ymm1,  %%ymm12\n\t"  // scale by alpha
        "vmulps         %%ymm13, %%ymm1,  %%ymm13\n\t"  // scale by alpha
        "vmulps         %%ymm14, %%ymm1,  %%ymm14\n\t"  // scale by alpha
        "vmulps         %%ymm15, %%ymm1,  %%ymm15\n\t"  // scale by alpha
 
-       "vmovups        %%ymm12,     (%%rdx)     \n\t"  // store temp -> y
-       "vmovups        %%ymm13,  8*4(%%rdx)     \n\t"  // store temp -> y
-       "vmovups        %%ymm14, 16*4(%%rdx)     \n\t"  // store temp -> y
-       "vmovups        %%ymm15, 24*4(%%rdx)     \n\t"  // store temp -> y
+       "vmovups        %%ymm8 ,     (%%rdx)     \n\t"  // store temp -> y
+       "vmovups        %%ymm9 ,  8*4(%%rdx)     \n\t"  // store temp -> y
+       "vmovups        %%ymm10, 16*4(%%rdx)     \n\t"  // store temp -> y
+       "vmovups        %%ymm11, 24*4(%%rdx)     \n\t"  // store temp -> y
+       "vmovups        %%ymm12, 32*4(%%rdx)     \n\t"  // store temp -> y
+       "vmovups        %%ymm13, 40*4(%%rdx)     \n\t"  // store temp -> y
+       "vmovups        %%ymm14, 48*4(%%rdx)     \n\t"  // store temp -> y
+       "vmovups        %%ymm15, 56*4(%%rdx)     \n\t"  // store temp -> y
+
+       :
+        :
+          "m" (n),     // 0    
+         "m" (alpha),  // 1
+         "m" (a),      // 2
+          "m" (lda),    // 3
+          "m" (x),      // 4
+          "m" (y),      // 5
+         "m" (pre)     // 6
+       : "rax", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11",
+         "xmm0" , "xmm1", 
+         "xmm8", "xmm9", "xmm10", "xmm11",
+         "xmm12", "xmm13", "xmm14", "xmm15",
+         "memory"
+       );
+
+} 
+
+
+
+static void sgemv_kernel_32( long n, float alpha, float *a, long lda, float *x, float *y)
+{
+
+
+       float *pre = a + lda*3;
+
+       __asm __volatile
+       (
+       "movq           %0,      %%rax\n\t"             // n -> rax
+       "vbroadcastss   %1,      %%xmm1\n\t"            // alpha -> xmm1
+       "movq           %2,      %%rsi\n\t"             // adress of a -> rsi
+       "movq           %3,      %%rcx\n\t"             // value of lda > rcx
+       "movq           %4,      %%rdi\n\t"             // adress of x -> rdi
+       "movq           %5,      %%rdx\n\t"             // adress of y -> rdx
+       "movq           %6,      %%r8\n\t"              // address for prefetch
+       "prefetcht0     (%%r8)\n\t"                     // Prefetch
+       "prefetcht0   64(%%r8)\n\t"                     // Prefetch
+
+       "vxorps         %%xmm8 , %%xmm8 , %%xmm8 \n\t"  // set to zero
+       "vxorps         %%xmm9 , %%xmm9 , %%xmm9 \n\t"  // set to zero
+       "vxorps         %%xmm10, %%xmm10, %%xmm10\n\t"  // set to zero
+       "vxorps         %%xmm11, %%xmm11, %%xmm11\n\t"  // set to zero
+       "vxorps         %%xmm12, %%xmm12, %%xmm12\n\t"  // set to zero
+       "vxorps         %%xmm13, %%xmm13, %%xmm13\n\t"  // set to zero
+       "vxorps         %%xmm14, %%xmm14, %%xmm14\n\t"  // set to zero
+       "vxorps         %%xmm15, %%xmm15, %%xmm15\n\t"  // set to zero
+       ".align 16                               \n\t"
+       ".L01LOOP%=:                             \n\t"
+       "vbroadcastss   (%%rdi),   %%xmm0        \n\t"  // load values of c
+       "nop                                     \n\t"
+       "leaq     (%%r8 , %%rcx, 4), %%r8        \n\t"  // add lda to pointer for prefetch
+
+       "prefetcht0     (%%r8)\n\t"                     // Prefetch
+       "vfmaddps %%xmm8 ,   0*4(%%rsi), %%xmm0, %%xmm8 \n\t" // multiply a and c and add to temp
+       "prefetcht0   64(%%r8)\n\t"                     // Prefetch
+       "vfmaddps %%xmm9 ,   4*4(%%rsi), %%xmm0, %%xmm9 \n\t" // multiply a and c and add to temp
+       "vfmaddps %%xmm10,   8*4(%%rsi), %%xmm0, %%xmm10\n\t" // multiply a and c and add to temp
+       "vfmaddps %%xmm11,  12*4(%%rsi), %%xmm0, %%xmm11\n\t" // multiply a and c and add to temp
+       "vfmaddps %%xmm12,  16*4(%%rsi), %%xmm0, %%xmm12\n\t" // multiply a and c and add to temp
+       "vfmaddps %%xmm13,  20*4(%%rsi), %%xmm0, %%xmm13\n\t" // multiply a and c and add to temp
+       "vfmaddps %%xmm14,  24*4(%%rsi), %%xmm0, %%xmm14\n\t" // multiply a and c and add to temp
+       "vfmaddps %%xmm15,  28*4(%%rsi), %%xmm0, %%xmm15\n\t" // multiply a and c and add to temp
+
+        "addq          $4     ,   %%rdi         \n\t"  // increment pointer of c 
+       "leaq     (%%rsi, %%rcx, 4), %%rsi       \n\t"  // add lda to pointer of a
+
+       "dec            %%rax                    \n\t"  // n = n -1
+       "jnz            .L01LOOP%=               \n\t"
+
+       "vmulps         %%xmm8 , %%xmm1,  %%xmm8 \n\t"  // scale by alpha
+       "vmulps         %%xmm9 , %%xmm1,  %%xmm9 \n\t"  // scale by alpha
+       "vmulps         %%xmm10, %%xmm1,  %%xmm10\n\t"  // scale by alpha
+       "vmulps         %%xmm11, %%xmm1,  %%xmm11\n\t"  // scale by alpha
+       "vmulps         %%xmm12, %%xmm1,  %%xmm12\n\t"  // scale by alpha
+       "vmulps         %%xmm13, %%xmm1,  %%xmm13\n\t"  // scale by alpha
+       "vmulps         %%xmm14, %%xmm1,  %%xmm14\n\t"  // scale by alpha
+       "vmulps         %%xmm15, %%xmm1,  %%xmm15\n\t"  // scale by alpha
+
+       "vmovups        %%xmm8 ,     (%%rdx)     \n\t"  // store temp -> y
+       "vmovups        %%xmm9 ,  4*4(%%rdx)     \n\t"  // store temp -> y
+       "vmovups        %%xmm10,  8*4(%%rdx)     \n\t"  // store temp -> y
+       "vmovups        %%xmm11, 12*4(%%rdx)     \n\t"  // store temp -> y
+       "vmovups        %%xmm12, 16*4(%%rdx)     \n\t"  // store temp -> y
+       "vmovups        %%xmm13, 20*4(%%rdx)     \n\t"  // store temp -> y
+       "vmovups        %%xmm14, 24*4(%%rdx)     \n\t"  // store temp -> y
+       "vmovups        %%xmm15, 28*4(%%rdx)     \n\t"  // store temp -> y
 
        :
         :
@@ -88,6 +192,7 @@ static void sgemv_kernel_32( long n, float alpha, float *a, long lda, float *x,
          "m" (pre)     // 6
        : "rax", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11",
          "xmm0" , "xmm1", 
+         "xmm8", "xmm9", "xmm10", "xmm11",
          "xmm12", "xmm13", "xmm14", "xmm15",
          "memory"
        );
@@ -97,7 +202,7 @@ static void sgemv_kernel_32( long n, float alpha, float *a, long lda, float *x,
 static void sgemv_kernel_16( long n, float alpha, float *a, long lda, float *x, float *y)
 {
 
-       float *pre = a + lda*4*3;
+       float *pre = a + lda*1;
 
        __asm __volatile
        (