updated optimized dsymv_U kernel for bulldozer
authorwernsaar <wernsaar@googlemail.com>
Wed, 20 Aug 2014 07:00:56 +0000 (09:00 +0200)
committerwernsaar <wernsaar@googlemail.com>
Wed, 20 Aug 2014 07:00:56 +0000 (09:00 +0200)
kernel/x86_64/dsymv_U.c
kernel/x86_64/dsymv_U_microk_bulldozer-2.c

index 1f22abe..267755c 100644 (file)
@@ -28,43 +28,97 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include "common.h"
 
+
 #if defined(BULLDOZER)
 #include "dsymv_U_microk_bulldozer-2.c"
+#elif defined(NEHALEM)
+#include "dsymv_U_microk_nehalem-2.c"
 #endif
 
+#ifndef HAVE_KERNEL_4x4
+
+static void dsymv_kernel_4x4(BLASLONG n, FLOAT *a0, FLOAT *a1, FLOAT *a2, FLOAT *a3, FLOAT *xp, FLOAT *yp, FLOAT *temp1, FLOAT *temp2)
+{
+       FLOAT at0,at1,at2,at3;
+       FLOAT x;
+       FLOAT tmp2[4] = { 0.0, 0.0, 0.0, 0.0 };
+       FLOAT tp0;
+       FLOAT tp1;
+       FLOAT tp2;
+       FLOAT tp3;
+       BLASLONG i;
 
-#ifndef HAVE_KERNEL_8x2
+       tp0 = temp1[0];
+       tp1 = temp1[1];
+       tp2 = temp1[2];
+       tp3 = temp1[3];
+       
+       for (i=0; i<n; i++)
+       {
+               at0     = a0[i];
+               at1     = a1[i];
+               at2     = a2[i];
+               at3     = a3[i];
+               x       = xp[i];
+               yp[i]   += tp0 * at0 + tp1 *at1 + tp2 * at2 + tp3 * at3;
+               tmp2[0] += at0 * x;
+               tmp2[1] += at1 * x;
+               tmp2[2] += at2 * x;
+               tmp2[3] += at3 * x;
+
+       }
 
-static void dsymv_kernel_8x2(BLASLONG n, FLOAT *a0, FLOAT *a1, FLOAT *xp, FLOAT *yp, FLOAT *temp1, FLOAT *temp2)
+       temp2[0] += tmp2[0];
+       temp2[1] += tmp2[1];
+       temp2[2] += tmp2[2];
+       temp2[3] += tmp2[3];
+}
+
+#endif
+
+
+#ifndef HAVE_KERNEL_1x4
+
+static void dsymv_kernel_1x4(BLASLONG from, BLASLONG to, FLOAT *a0, FLOAT *a1, FLOAT *a2, FLOAT *a3, FLOAT *xp, FLOAT *yp, FLOAT *temp1, FLOAT *temp2)
 {
        FLOAT at0,at1,at2,at3;
-       FLOAT tmp2[2] = { 0.0, 0.0 };
+       FLOAT x;
+       FLOAT tmp2[4] = { 0.0, 0.0, 0.0, 0.0 };
        FLOAT tp0;
        FLOAT tp1;
+       FLOAT tp2;
+       FLOAT tp3;
        BLASLONG i;
 
        tp0 = temp1[0];
        tp1 = temp1[1];
+       tp2 = temp1[2];
+       tp3 = temp1[3];
        
-       for (i=0; i<n; i+=2)
+       for (i=from; i<to; i++)
        {
                at0     = a0[i];
                at1     = a1[i];
-               at2     = a0[i+1];
-               at3     = a1[i+1];
-               yp[i]   += tp0 * at0 + tp1 *at1;
-               yp[i+1] += tp0 * at2 + tp1 *at3;
-               tmp2[0] += at0 * xp[i] + at2 * xp[i+1];
-               tmp2[1] += at1 * xp[i] + at3 * xp[i+1];
+               at2     = a2[i];
+               at3     = a3[i];
+               x       = xp[i];
+               yp[i]   += tp0 * at0 + tp1 *at1 + tp2 * at2 + tp3 * at3;
+               tmp2[0] += at0 * x;
+               tmp2[1] += at1 * x;
+               tmp2[2] += at2 * x;
+               tmp2[3] += at3 * x;
 
        }
-       temp2[0] = tmp2[0];
-       temp2[1] = tmp2[1];
 
+       temp2[0] += tmp2[0];
+       temp2[1] += tmp2[1];
+       temp2[2] += tmp2[2];
+       temp2[3] += tmp2[3];
 }
 
 #endif
 
+
 static void dsymv_kernel_8x1(BLASLONG n, FLOAT *a0, FLOAT *xp, FLOAT *yp, FLOAT *temp1, FLOAT *temp2)
 {
        FLOAT at0,at1,at2,at3;
@@ -99,13 +153,16 @@ int CNAME(BLASLONG m, BLASLONG offset, FLOAT alpha, FLOAT *a, BLASLONG lda, FLOA
        BLASLONG ix,iy;
        BLASLONG jx,jy;
        BLASLONG j;
+       BLASLONG j1;
+       BLASLONG j2;
        BLASLONG m2;
        FLOAT temp1;
        FLOAT temp2;
        FLOAT *xp, *yp;
-       FLOAT *a0,*a1;
-       FLOAT tmp1[2];
-       FLOAT tmp2[2];
+       FLOAT *a0,*a1,*a2,*a3;
+       FLOAT at0,at1,at2,at3;
+       FLOAT tmp1[4];
+       FLOAT tmp2[4];
 
 #if 0
        if( m != offset )
@@ -145,37 +202,45 @@ int CNAME(BLASLONG m, BLASLONG offset, FLOAT alpha, FLOAT *a, BLASLONG lda, FLOA
        xp = x;
        yp = y;
 
-       m2 = m - ( mrange % 2 );
+       m2 = m - ( mrange % 4 );
 
-       for (j=m1; j<m2; j+=2)
+       for (j=m1; j<m2; j+=4)
        {
                tmp1[0] = alpha * xp[j];
                tmp1[1] = alpha * xp[j+1];
+               tmp1[2] = alpha * xp[j+2];
+               tmp1[3] = alpha * xp[j+3];
                tmp2[0] = 0.0;
                tmp2[1] = 0.0;
+               tmp2[2] = 0.0;
+               tmp2[3] = 0.0;
                a0    = &a[j*lda];
                a1    = a0+lda;
-               FLOAT at0,at1;
-               BLASLONG j1 = (j/8)*8;          
+               a2    = a1+lda;
+               a3    = a2+lda;
+               j1 = (j/8)*8;           
                if ( j1 )
-                       dsymv_kernel_8x2(j1, a0, a1, xp, yp, tmp1, tmp2);
+                       dsymv_kernel_4x4(j1, a0, a1, a2, a3, xp, yp, tmp1, tmp2);
+               if ( j1 < j )
+                       dsymv_kernel_1x4(j1, j,  a0, a1, a2, a3, xp, yp, tmp1, tmp2);
 
-               for (i=j1; i<j; i++)
+               j2 = 0;
+               for ( j1 = j ; j1 < j+4 ; j1++ )
                {
-                       at0     = a0[i];
-                       at1     = a1[i];
-                       yp[i]   += tmp1[0] * at0 + tmp1[1] *at1;
-                       tmp2[0] += at0 * xp[i];
-                       tmp2[1] += at1 * xp[i];
-                       
-               }
+                       temp1 = tmp1[j2];
+                       temp2 = tmp2[j2];
+                       a0    = &a[j1*lda];
+                       for ( i=j ; i<j1; i++ )
+                       {
+                               yp[i] += temp1 * a0[i]; 
+                               temp2 += a0[i] * xp[i];
+                               
+                       }
+                       y[j1] += temp1 * a0[j1] + alpha * temp2;
+                       j2++;
 
-               at1     = a1[j];
-               yp[j]   += tmp1[1] * at1;
-               tmp2[1] += at1 * xp[j];
+               }
 
-               yp[j]   += tmp1[0] * a0[j]   + alpha * tmp2[0];
-               yp[j+1] += tmp1[1] * a1[j+1] + alpha * tmp2[1];
        }
 
        for ( ; j<m; j++)
@@ -184,7 +249,7 @@ int CNAME(BLASLONG m, BLASLONG offset, FLOAT alpha, FLOAT *a, BLASLONG lda, FLOA
                temp2 = 0.0;
                a0    = &a[j*lda];
                FLOAT at0;
-               BLASLONG j1 = (j/8)*8;          
+               j1 = (j/8)*8;           
 
                if ( j1 )
                        dsymv_kernel_8x1(j1, a0, xp, yp, &temp1, &temp2);
index 3b03522..4929202 100644 (file)
@@ -25,10 +25,10 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
 
-#define HAVE_KERNEL_8x2 1
-static void dsymv_kernel_8x2( BLASLONG n, FLOAT *a1, FLOAT *a2, FLOAT *x, FLOAT *y, FLOAT *temp1, FLOAT *temp2) __attribute__ ((noinline));
+#define HAVE_KERNEL_4x4 1
+static void dsymv_kernel_4x4( BLASLONG n, FLOAT *a0, FLOAT *a1, FLOAT *a2, FLOAT *a3, FLOAT *x, FLOAT *y, FLOAT *temp1, FLOAT *temp2) __attribute__ ((noinline));
 
-static void dsymv_kernel_8x2(BLASLONG n, FLOAT *a0, FLOAT *a1, FLOAT *x, FLOAT *y, FLOAT *temp1, FLOAT *temp2)
+static void dsymv_kernel_4x4(BLASLONG n, FLOAT *a0, FLOAT *a1, FLOAT *a2, FLOAT *a3, FLOAT *x, FLOAT *y, FLOAT *temp1, FLOAT *temp2)
 {
 
        BLASLONG register i = 0;
@@ -37,62 +37,73 @@ static void dsymv_kernel_8x2(BLASLONG n, FLOAT *a0, FLOAT *a1, FLOAT *x, FLOAT *
        (
        "vxorpd         %%xmm0 , %%xmm0 , %%xmm0     \n\t"      // temp2[0]
        "vxorpd         %%xmm1 , %%xmm1 , %%xmm1     \n\t"      // temp2[1]
-       "vmovddup       (%6),    %%xmm2              \n\t"      // temp1[0]
-       "vmovddup      8(%6),    %%xmm3              \n\t"      // temp1[1]
+       "vxorpd         %%xmm2 , %%xmm2 , %%xmm2     \n\t"      // temp2[2]
+       "vxorpd         %%xmm3 , %%xmm3 , %%xmm3     \n\t"      // temp2[3]
+       "vmovddup       (%8),    %%xmm4              \n\t"      // temp1[0]
+       "vmovddup  8(%8),    %%xmm5                  \n\t"      // temp1[1]
+       "vmovddup 16(%8),    %%xmm6                  \n\t"      // temp1[1]
+       "vmovddup 24(%8),    %%xmm7                  \n\t"      // temp1[1]
 
        "xorq           %0,%0                        \n\t"
 
-       ".align 16                               \n\t"
-       ".L01LOOP%=:                             \n\t"
-
-       "prefetcht0      192(%4,%0,8)                  \n\t"
-       "vmovups        (%4,%0,8), %%xmm4        \n\t"  // 2 * a0
-       "vmovups      16(%4,%0,8), %%xmm5        \n\t"  // 2 * a0
-       "prefetcht0      192(%2,%0,8)                  \n\t"
-       "vmovups        (%2,%0,8), %%xmm8        \n\t"  // 2 * x
-       "vmovups      16(%2,%0,8), %%xmm9        \n\t"  // 2 * x
-       "prefetcht0      192(%3,%0,8)                  \n\t"
-       "vmovups      32(%4,%0,8), %%xmm6        \n\t"  // 2 * a0
-       "vmovups      48(%4,%0,8), %%xmm7        \n\t"  // 2 * a0
-       "vmovups      32(%2,%0,8), %%xmm10       \n\t"  // 2 * x
-       "vmovups      48(%2,%0,8), %%xmm11       \n\t"  // 2 * x
-
-       "prefetcht0      192(%5,%0,8)                       \n\t"
-       "vfmaddpd    (%3,%0,8), %%xmm2 , %%xmm4 , %%xmm12   \n\t" // y += temp1 * a0
-       "vfmaddpd      %%xmm0 , %%xmm8 , %%xmm4 , %%xmm0    \n\t" // temp2 += a0 * x
-       "vfmaddpd  16(%3,%0,8), %%xmm2 , %%xmm5 , %%xmm13   \n\t" // y += temp1 * a0
-       "vmovups        (%5,%0,8), %%xmm4                   \n\t"       // 2 * a1
-       "vfmaddpd      %%xmm0 , %%xmm9 , %%xmm5 , %%xmm0    \n\t" // temp2 += a0 * x
-       "vfmaddpd  32(%3,%0,8), %%xmm2 , %%xmm6 , %%xmm14   \n\t" // y += temp1 * a0
-       "vmovups      16(%5,%0,8), %%xmm5                   \n\t"       // 2 * a1
-       "vfmaddpd      %%xmm0 , %%xmm10, %%xmm6 , %%xmm0    \n\t" // temp2 += a0 * x
-       "vfmaddpd  48(%3,%0,8), %%xmm2 , %%xmm7 , %%xmm15   \n\t" // y += temp1 * a0
-       "vmovups      32(%5,%0,8), %%xmm6                   \n\t"       // 2 * a1
-       "vfmaddpd      %%xmm0 , %%xmm11, %%xmm7 , %%xmm0    \n\t" // temp2 += a0 * x
-       "vmovups      48(%5,%0,8), %%xmm7                   \n\t"       // 2 * a1
-
-       "vfmaddpd %%xmm12, %%xmm3 , %%xmm4 , %%xmm12   \n\t" // y += temp1 * a1
-       "vfmaddpd %%xmm13, %%xmm3 , %%xmm5 , %%xmm13   \n\t" // y += temp1 * a1
-       "vmovups  %%xmm12,   (%3,%0,8)                 \n\t"    // 2 * y
-       "vfmaddpd %%xmm14, %%xmm3 , %%xmm6 , %%xmm14   \n\t" // y += temp1 * a1
-       "vmovups  %%xmm13, 16(%3,%0,8)                 \n\t"    // 2 * y
-       "vfmaddpd %%xmm15, %%xmm3 , %%xmm7 , %%xmm15   \n\t" // y += temp1 * a1
-       "vmovups  %%xmm14, 32(%3,%0,8)                 \n\t"    // 2 * y
-
-       "vfmaddpd %%xmm1 , %%xmm8 , %%xmm4 , %%xmm1   \n\t" // temp2 += a1 * x
-       "vfmaddpd %%xmm1 , %%xmm9 , %%xmm5 , %%xmm1   \n\t" // temp2 += a1 * x
-       "vmovups  %%xmm15, 48(%3,%0,8)                \n\t"     // 2 * y
-       "vfmaddpd %%xmm1 , %%xmm10, %%xmm6 , %%xmm1   \n\t" // temp2 += a1 * x
-       "vfmaddpd %%xmm1 , %%xmm11, %%xmm7 , %%xmm1   \n\t" // temp2 += a1 * x
-
-        "addq          $8, %0                        \n\t"
-       "subq           $8, %1                        \n\t"             
+       ".align 16                                   \n\t"
+       ".L01LOOP%=:                                 \n\t"
+
+       "vmovups        (%4,%0,8), %%xmm12                 \n\t"  // 2 * a
+       "vmovups        (%2,%0,8), %%xmm8                  \n\t"  // 2 * x
+       "vmovups        (%3,%0,8), %%xmm9                  \n\t"  // 2 * y
+
+       "vmovups        (%5,%0,8), %%xmm13                 \n\t"  // 2 * a
+
+       "vfmaddpd       %%xmm0 , %%xmm8, %%xmm12 , %%xmm0  \n\t"  // temp2 += x * a
+       "vfmaddpd       %%xmm9 , %%xmm4, %%xmm12 , %%xmm9  \n\t"  // y     += temp1 * a
+       "vmovups        (%6,%0,8), %%xmm14                 \n\t"  // 2 * a
+
+       "vfmaddpd       %%xmm1 , %%xmm8, %%xmm13 , %%xmm1  \n\t"  // temp2 += x * a
+       "vfmaddpd       %%xmm9 , %%xmm5, %%xmm13 , %%xmm9  \n\t"  // y     += temp1 * a
+       "vmovups        (%7,%0,8), %%xmm15                 \n\t"  // 2 * a
+
+       "vmovups        16(%3,%0,8), %%xmm11               \n\t"  // 2 * y
+       "vfmaddpd       %%xmm2 , %%xmm8, %%xmm14 , %%xmm2  \n\t"  // temp2 += x * a
+       "vmovups        16(%4,%0,8), %%xmm12               \n\t"  // 2 * a
+       "vfmaddpd       %%xmm9 , %%xmm6, %%xmm14 , %%xmm9  \n\t"  // y     += temp1 * a
+       "vmovups        16(%2,%0,8), %%xmm10               \n\t"  // 2 * x
+
+       "vfmaddpd       %%xmm3 , %%xmm8, %%xmm15 , %%xmm3  \n\t"  // temp2 += x * a
+       "vfmaddpd       %%xmm9 , %%xmm7, %%xmm15 , %%xmm9  \n\t"  // y     += temp1 * a
+
+       "vmovups        16(%5,%0,8), %%xmm13               \n\t"  // 2 * a
+       "vmovups        16(%6,%0,8), %%xmm14               \n\t"  // 2 * a
+
+       "vfmaddpd       %%xmm0 , %%xmm10, %%xmm12 , %%xmm0  \n\t"  // temp2 += x * a
+       "vfmaddpd       %%xmm11 , %%xmm4, %%xmm12 , %%xmm11  \n\t"  // y     += temp1 * a
+
+       "vmovups        16(%7,%0,8), %%xmm15               \n\t"  // 2 * a
+       "vfmaddpd       %%xmm1 , %%xmm10, %%xmm13 , %%xmm1  \n\t"  // temp2 += x * a
+       "vfmaddpd       %%xmm11 , %%xmm5, %%xmm13 , %%xmm11  \n\t"  // y     += temp1 * a
+
+       "vfmaddpd       %%xmm2 , %%xmm10, %%xmm14 , %%xmm2  \n\t"  // temp2 += x * a
+       "addq           $4 , %0                       \n\t"
+       "vfmaddpd       %%xmm11 , %%xmm6, %%xmm14 , %%xmm11  \n\t"  // y     += temp1 * a
+
+       "vfmaddpd       %%xmm3 , %%xmm10, %%xmm15 , %%xmm3  \n\t"  // temp2 += x * a
+       "vfmaddpd       %%xmm11 , %%xmm7, %%xmm15 , %%xmm11  \n\t"  // y     += temp1 * a
+       "subq           $4 , %1                       \n\t"             
+
+       "vmovups        %%xmm9 ,  -32(%3,%0,8)             \n\t"
+       "vmovups        %%xmm11 , -16(%3,%0,8)             \n\t"
+
        "jnz            .L01LOOP%=                    \n\t"
 
        "vhaddpd        %%xmm0, %%xmm0, %%xmm0  \n\t"
        "vhaddpd        %%xmm1, %%xmm1, %%xmm1  \n\t"
-       "vmovsd         %%xmm0 , (%7)           \n\t"   // save temp2
-       "vmovsd         %%xmm1 ,8(%7)           \n\t"   // save temp2
+       "vhaddpd        %%xmm2, %%xmm2, %%xmm2  \n\t"
+       "vhaddpd        %%xmm3, %%xmm3, %%xmm3  \n\t"
+
+       "vmovsd         %%xmm0 ,  (%9)          \n\t"   // save temp2
+       "vmovsd         %%xmm1 , 8(%9)          \n\t"   // save temp2
+       "vmovsd         %%xmm2 ,16(%9)          \n\t"   // save temp2
+       "vmovsd         %%xmm3 ,24(%9)          \n\t"   // save temp2
 
        :
         : 
@@ -100,10 +111,12 @@ static void dsymv_kernel_8x2(BLASLONG n, FLOAT *a0, FLOAT *a1, FLOAT *x, FLOAT *
          "r" (n),      // 1
           "r" (x),      // 2
           "r" (y),      // 3
-          "r" (a0),  // 4
-          "r" (a1),  // 5
-          "r" (temp1),  // 6
-          "r" (temp2)   // 7
+          "r" (a0),     // 4
+          "r" (a1),     // 5
+          "r" (a2),     // 6
+          "r" (a3),     // 7
+          "r" (temp1),  // 8
+          "r" (temp2)   // 9
        : "cc", 
          "%xmm0", "%xmm1", "%xmm2", "%xmm3", 
          "%xmm4", "%xmm5", "%xmm6", "%xmm7",