Improve the performance of rot by using AVX512 and AVX2 intrinsic
authorGengxin Xie <gengxin.xie@intel.com>
Sun, 27 Sep 2020 02:38:19 +0000 (10:38 +0800)
committerGengxin Xie <gengxin.xie@intel.com>
Thu, 5 Nov 2020 07:12:36 +0000 (15:12 +0800)
driver/others/blas_l1_thread.c
driver/others/blas_server_win32.c
kernel/x86_64/KERNEL.HASWELL
kernel/x86_64/drot.c [new file with mode: 0644]
kernel/x86_64/drot_microk_haswell-2.c [new file with mode: 0644]
kernel/x86_64/drot_microk_skylakex-2.c [new file with mode: 0644]
kernel/x86_64/srot.c [new file with mode: 0644]
kernel/x86_64/srot_microk_haswell-2.c [new file with mode: 0644]
kernel/x86_64/srot_microk_skylakex-2.c [new file with mode: 0644]

index 04acbcc..06039c9 100644 (file)
@@ -80,7 +80,7 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha
     break;
   }
 
-  mode |= BLAS_LEGACY;
+  if(!(mode & BLAS_PTHREAD)) mode |= BLAS_LEGACY;
 
   for (i = 0; i < nthreads; i++) blas_queue_init(&queue[i]);
 
index d2cc917..f47908c 100644 (file)
@@ -476,12 +476,15 @@ int exec_blas(BLASLONG num, blas_queue_t *queue){
 
   routine = queue -> routine;
 
-    if (!(queue -> mode & BLAS_LEGACY)) {
+  if (queue -> mode & BLAS_LEGACY) {
+    legacy_exec(routine, queue -> mode, queue -> args, queue -> sb);
+  } else
+    if (queue -> mode & BLAS_PTHREAD) {
+      void (*pthreadcompat)(void *) = queue -> routine;
+      (pthreadcompat)(queue -> args);
+    } else
       (routine)(queue -> args, queue -> range_m, queue -> range_n,
                queue -> sa, queue -> sb, 0);
-    } else {
-      legacy_exec(routine, queue -> mode, queue -> args, queue -> sb);
-    }
 
   if ((num > 1) && queue -> next) exec_blas_async_wait(num - 1, queue -> next);
 
index b979fc0..81eaf96 100644 (file)
@@ -102,3 +102,6 @@ ZGEMM3MKERNEL    =  zgemm3m_kernel_4x4_haswell.c
 
 SASUMKERNEL = sasum.c
 DASUMKERNEL = dasum.c
+
+SROTKERNEL = srot.c
+DROTKERNEL = drot.c
diff --git a/kernel/x86_64/drot.c b/kernel/x86_64/drot.c
new file mode 100644 (file)
index 0000000..a312b7f
--- /dev/null
@@ -0,0 +1,139 @@
+#include "common.h"
+
+#if defined(SKYLAKEX)
+#include "drot_microk_skylakex-2.c"
+#elif defined(HASWELL)
+#include "drot_microk_haswell-2.c"
+#endif
+
+#ifndef HAVE_DROT_KERNEL
+
+static void drot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
+{
+    BLASLONG i = 0;
+    FLOAT f0, f1, f2, f3;
+    FLOAT x0, x1, x2, x3;
+    FLOAT g0, g1, g2, g3;
+    FLOAT y0, y1, y2, y3;
+
+    FLOAT* xp = x;
+    FLOAT* yp = y;
+
+    BLASLONG n1 = n & (~7);
+
+    while (i < n1) {
+        x0 = xp[0];
+        y0 = yp[0];
+        x1 = xp[1];
+        y1 = yp[1];
+        x2 = xp[2];
+        y2 = yp[2];
+        x3 = xp[3];
+        y3 = yp[3];
+
+        f0 = c*x0 + s*y0;
+        g0 = c*y0 - s*x0;
+        f1 = c*x1 + s*y1;
+        g1 = c*y1 - s*x1;
+        f2 = c*x2 + s*y2;
+        g2 = c*y2 - s*x2;
+        f3 = c*x3 + s*y3;
+        g3 = c*y3 - s*x3;
+
+        xp[0] = f0;
+        yp[0] = g0;
+        xp[1] = f1;
+        yp[1] = g1;
+        xp[2] = f2;
+        yp[2] = g2;
+        xp[3] = f3;
+        yp[3] = g3;
+
+        xp += 4;
+        yp += 4;
+        i += 4;
+    }
+
+    while (i < n) {
+        FLOAT temp = c*x[i] + s*y[i];
+        y[i] = c*y[i] - s*x[i];
+        x[i] = temp;
+
+        i++;
+    }
+}
+
+#endif
+static void rot_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT c, FLOAT s)
+{
+    BLASLONG i = 0;
+    BLASLONG ix = 0, iy = 0;
+
+    FLOAT temp;
+    
+    if (n <= 0)
+        return;
+    if ((inc_x == 1) && (inc_y == 1)) {
+            drot_kernel(n, x, y, c, s);
+    }
+    else {
+        while (i < n) {
+            temp = c * x[ix] + s * y[iy];
+            y[iy] = c * y[iy] - s * x[ix];
+            x[ix] = temp;
+
+            ix += inc_x;
+            iy += inc_y;
+            i++;
+        }
+    }
+    return;
+}
+
+
+#if defined(SMP)
+static int rot_thread_function(blas_arg_t *args)
+{
+
+    rot_compute(args->m, 
+            args->a, args->lda, 
+            args->b, args->ldb, 
+            ((FLOAT *)args->alpha)[0], 
+            ((FLOAT *)args->alpha)[1]);
+    return 0;
+}
+
+extern int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha, void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc, int (*function)(), int nthreads);
+#endif
+int CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT c, FLOAT s)
+{
+#if defined(SMP)
+    int nthreads;
+    FLOAT alpha[2]={c, s};
+    FLOAT dummy_c;
+#endif
+
+#if defined(SMP)
+    if (inc_x == 0 || inc_y == 0 || n <= 100000) {
+        nthreads = 1;
+    }
+    else {
+        nthreads = num_cpu_avail(1);
+    }
+
+    if (nthreads == 1) {
+        rot_compute(n, x, inc_x, y, inc_y, c, s);
+    }
+    else {
+#if defined(DOUBLE)
+           int mode = BLAS_DOUBLE | BLAS_REAL | BLAS_PTHREAD;
+#else
+           int mode = BLAS_SINGLE | BLAS_REAL | BLAS_PTHREAD;
+#endif
+           blas_level1_thread(mode, n, 0, 0, alpha, x, inc_x, y, inc_y, &dummy_c, 0, (void *)rot_thread_function, nthreads);
+    }
+#else  
+    rot_compute(n, x, inc_x, y, inc_y, c, s);
+#endif
+    return 0;
+}
diff --git a/kernel/x86_64/drot_microk_haswell-2.c b/kernel/x86_64/drot_microk_haswell-2.c
new file mode 100644 (file)
index 0000000..72a8769
--- /dev/null
@@ -0,0 +1,87 @@
+/* need a new enough GCC for avx512 support */
+#if (( defined(__GNUC__)  && __GNUC__   > 6 && defined(__AVX512CD__)) || (defined(__clang__) && __clang_major__ >= 9))
+
+#define HAVE_DROT_KERNEL 1
+
+#include <immintrin.h>
+#include <stdint.h>
+
+static void drot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
+{
+    BLASLONG i = 0;
+
+    BLASLONG tail_index_4 = n&(~3);
+    BLASLONG tail_index_16 = n&(~15);
+
+    __m256d c_256, s_256;
+    if (n >= 4) {
+        c_256 = _mm256_set1_pd(c);
+        s_256 = _mm256_set1_pd(s);
+    }
+
+    __m256d x0, x1, x2, x3;
+    __m256d y0, y1, y2, y3;
+    __m256d t0, t1, t2, t3;
+
+    for (i = 0; i < tail_index_16; i += 16) {
+        x0 = _mm256_loadu_pd(&x[i + 0]);
+        x1 = _mm256_loadu_pd(&x[i + 4]);
+        x2 = _mm256_loadu_pd(&x[i + 8]);
+        x3 = _mm256_loadu_pd(&x[i +12]);
+        y0 = _mm256_loadu_pd(&y[i + 0]);
+        y1 = _mm256_loadu_pd(&y[i + 4]);
+        y2 = _mm256_loadu_pd(&y[i + 8]);
+        y3 = _mm256_loadu_pd(&y[i +12]);
+
+        t0 = _mm256_mul_pd(s_256, y0);
+        t1 = _mm256_mul_pd(s_256, y1);
+        t2 = _mm256_mul_pd(s_256, y2);
+        t3 = _mm256_mul_pd(s_256, y3);
+
+        t0 = _mm256_fmadd_pd(c_256, x0, t0);
+        t1 = _mm256_fmadd_pd(c_256, x1, t1);
+        t2 = _mm256_fmadd_pd(c_256, x2, t2);
+        t3 = _mm256_fmadd_pd(c_256, x3, t3);
+
+        _mm256_storeu_pd(&x[i + 0], t0);
+        _mm256_storeu_pd(&x[i + 4], t1);
+        _mm256_storeu_pd(&x[i + 8], t2);
+        _mm256_storeu_pd(&x[i +12], t3);
+
+        t0 = _mm256_mul_pd(s_256, x0);
+        t1 = _mm256_mul_pd(s_256, x1);
+        t2 = _mm256_mul_pd(s_256, x2);
+        t3 = _mm256_mul_pd(s_256, x3);
+
+        t0 = _mm256_fmsub_pd(c_256, y0, t0);
+        t1 = _mm256_fmsub_pd(c_256, y1, t1);
+        t2 = _mm256_fmsub_pd(c_256, y2, t2);
+        t3 = _mm256_fmsub_pd(c_256, y3, t3);
+
+        _mm256_storeu_pd(&y[i + 0], t0);
+        _mm256_storeu_pd(&y[i + 4], t1);
+        _mm256_storeu_pd(&y[i + 8], t2);
+        _mm256_storeu_pd(&y[i +12], t3);
+
+    }
+
+    for (i = tail_index_16; i < tail_index_4; i += 4) {
+        x0 = _mm256_loadu_pd(&x[i]);
+        y0 = _mm256_loadu_pd(&y[i]);
+
+        t0 = _mm256_mul_pd(s_256, y0);
+        t0 = _mm256_fmadd_pd(c_256, x0, t0);
+        _mm256_storeu_pd(&x[i], t0);
+        
+        t0 = _mm256_mul_pd(s_256, x0);
+        t0 = _mm256_fmsub_pd(c_256, y0, t0);
+        _mm256_storeu_pd(&y[i], t0);
+    }
+
+    for (i = tail_index_4; i < n; ++i) {
+        FLOAT temp = c * x[i] + s * y[i];
+        y[i] = c * y[i] - s * x[i];
+        x[i] = temp;
+    }
+}
+#endif
diff --git a/kernel/x86_64/drot_microk_skylakex-2.c b/kernel/x86_64/drot_microk_skylakex-2.c
new file mode 100644 (file)
index 0000000..4e862e6
--- /dev/null
@@ -0,0 +1,94 @@
+/* need a new enough GCC for avx512 support */
+#if (( defined(__GNUC__)  && __GNUC__   > 6 && defined(__AVX512CD__)) || (defined(__clang__) && __clang_major__ >= 9))
+
+#define HAVE_DROT_KERNEL 1
+
+#include <immintrin.h>
+#include <stdint.h>
+
+static void drot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
+{
+    BLASLONG i = 0;
+    BLASLONG n1 = n;
+    
+    BLASLONG tail_index_8 = 0;
+    BLASLONG tail_index_32 = 0;
+
+    __m512d c_512 = _mm512_set1_pd(c);
+    __m512d s_512 = _mm512_set1_pd(s);
+
+    tail_index_8 = n1 & (~7);
+    tail_index_32 = n1 & (~31);
+
+
+    __m512d x0, x1, x2, x3;
+    __m512d y0, y1, y2, y3;
+    __m512d t0, t1, t2, t3;
+
+    for (i = 0; i < tail_index_32; i += 32) {
+        x0 = _mm512_loadu_pd(&x[i + 0]);
+        x1 = _mm512_loadu_pd(&x[i + 8]);
+        x2 = _mm512_loadu_pd(&x[i +16]);
+        x3 = _mm512_loadu_pd(&x[i +24]);
+        y0 = _mm512_loadu_pd(&y[i + 0]);
+        y1 = _mm512_loadu_pd(&y[i + 8]);
+        y2 = _mm512_loadu_pd(&y[i +16]);
+        y3 = _mm512_loadu_pd(&y[i +24]);
+
+        t0 = _mm512_mul_pd(s_512, y0);
+        t1 = _mm512_mul_pd(s_512, y1);
+        t2 = _mm512_mul_pd(s_512, y2);
+        t3 = _mm512_mul_pd(s_512, y3);
+
+        t0 = _mm512_fmadd_pd(c_512, x0, t0);
+        t1 = _mm512_fmadd_pd(c_512, x1, t1);
+        t2 = _mm512_fmadd_pd(c_512, x2, t2);
+        t3 = _mm512_fmadd_pd(c_512, x3, t3);
+
+        _mm512_storeu_pd(&x[i + 0], t0);
+        _mm512_storeu_pd(&x[i + 8], t1);
+        _mm512_storeu_pd(&x[i +16], t2);
+        _mm512_storeu_pd(&x[i +24], t3);
+
+        t0 = _mm512_mul_pd(s_512, x0);
+        t1 = _mm512_mul_pd(s_512, x1);
+        t2 = _mm512_mul_pd(s_512, x2);
+        t3 = _mm512_mul_pd(s_512, x3);
+
+        t0 = _mm512_fmsub_pd(c_512, y0, t0);
+        t1 = _mm512_fmsub_pd(c_512, y1, t1);
+        t2 = _mm512_fmsub_pd(c_512, y2, t2);
+        t3 = _mm512_fmsub_pd(c_512, y3, t3);
+
+        _mm512_storeu_pd(&y[i + 0], t0);
+        _mm512_storeu_pd(&y[i + 8], t1);
+        _mm512_storeu_pd(&y[i +16], t2);
+        _mm512_storeu_pd(&y[i +24], t3);
+    }
+
+    for (i = tail_index_32; i < tail_index_8; i += 8) {
+        x0 = _mm512_loadu_pd(&x[i]);
+        y0 = _mm512_loadu_pd(&y[i]);
+
+        t0 = _mm512_mul_pd(s_512, y0);
+        t0 = _mm512_fmadd_pd(c_512, x0, t0);
+        _mm512_storeu_pd(&x[i], t0);
+
+        t0 = _mm512_mul_pd(s_512, x0);
+        t0 = _mm512_fmsub_pd(c_512, y0, t0);
+        _mm512_storeu_pd(&y[i], t0);
+    }
+
+    if ((n1&7) > 0) {
+        unsigned char tail_mask8 = (((unsigned char) 0xff) >> (8 -(n1&7)));
+       __m512d tail_x = _mm512_maskz_loadu_pd(*((__mmask8*) &tail_mask8), &x[tail_index_8]);
+       __m512d tail_y = _mm512_maskz_loadu_pd(*((__mmask8*) &tail_mask8), &y[tail_index_8]);
+       __m512d temp = _mm512_mul_pd(s_512, tail_y);
+       temp = _mm512_fmadd_pd(c_512, tail_x, temp);
+       _mm512_mask_storeu_pd(&x[tail_index_8],*((__mmask8*)&tail_mask8), temp);
+        temp = _mm512_mul_pd(s_512, tail_x);
+        temp = _mm512_fmsub_pd(c_512, tail_y, temp);
+        _mm512_mask_storeu_pd(&y[tail_index_8], *((__mmask8*)&tail_mask8), temp);      
+    }
+}
+#endif
diff --git a/kernel/x86_64/srot.c b/kernel/x86_64/srot.c
new file mode 100644 (file)
index 0000000..021c20d
--- /dev/null
@@ -0,0 +1,139 @@
+#include "common.h"
+
+#if defined(SKYLAKEX)
+#include "srot_microk_skylakex-2.c"
+#elif defined(HASWELL)
+#include "srot_microk_haswell-2.c"
+#endif
+
+#ifndef HAVE_SROT_KERNEL
+
+static void srot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
+{
+    BLASLONG i = 0;
+    FLOAT f0, f1, f2, f3;
+    FLOAT x0, x1, x2, x3;
+    FLOAT g0, g1, g2, g3;
+    FLOAT y0, y1, y2, y3;
+
+    FLOAT* xp = x;
+    FLOAT* yp = y;
+
+    BLASLONG n1 = n & (~7);
+
+    while (i < n1) {
+        x0 = xp[0];
+        y0 = yp[0];
+        x1 = xp[1];
+        y1 = yp[1];
+        x2 = xp[2];
+        y2 = yp[2];
+        x3 = xp[3];
+        y3 = yp[3];
+
+        f0 = c*x0 + s*y0;
+        g0 = c*y0 - s*x0;
+        f1 = c*x1 + s*y1;
+        g1 = c*y1 - s*x1;
+        f2 = c*x2 + s*y2;
+        g2 = c*y2 - s*x2;
+        f3 = c*x3 + s*y3;
+        g3 = c*y3 - s*x3;
+
+        xp[0] = f0;
+        yp[0] = g0;
+        xp[1] = f1;
+        yp[1] = g1;
+        xp[2] = f2;
+        yp[2] = g2;
+        xp[3] = f3;
+        yp[3] = g3;
+
+        xp += 4;
+        yp += 4;
+        i += 4;
+    }
+
+    while (i < n) {
+        FLOAT temp = c*x[i] + s*y[i];
+        y[i] = c*y[i] - s*x[i];
+        x[i] = temp;
+
+        i++;
+    }
+}
+
+#endif
+static void rot_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT c, FLOAT s)
+{
+    BLASLONG i = 0;
+    BLASLONG ix = 0, iy = 0;
+
+    FLOAT temp;
+    
+    if (n <= 0)
+        return;
+    if ((inc_x == 1) && (inc_y == 1)) {
+            srot_kernel(n, x, y, c, s);
+    }
+    else {
+        while (i < n) {
+            temp = c * x[ix] + s * y[iy];
+            y[iy] = c * y[iy] - s * x[ix];
+            x[ix] = temp;
+
+            ix += inc_x;
+            iy += inc_y;
+            i++;
+        }
+    }
+    return;
+}
+
+
+#if defined(SMP)
+static int rot_thread_function(blas_arg_t *args)
+{
+
+    rot_compute(args->m, 
+            args->a, args->lda, 
+            args->b, args->ldb, 
+            ((float *)args->alpha)[0], 
+            ((float *)args->alpha)[1]);
+    return 0;
+}
+
+extern int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha, void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc, int (*function)(), int nthreads);
+#endif
+int CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT c, FLOAT s)
+{
+#if defined(SMP)
+    int nthreads;
+    FLOAT alpha[2]={c, s};
+    FLOAT dummy_c;
+#endif
+
+#if defined(SMP)
+    if (inc_x == 0 || inc_y == 0 || n <= 100000) {
+        nthreads = 1;
+    }
+    else {
+        nthreads = num_cpu_avail(1);
+    }
+
+    if (nthreads == 1) {
+        rot_compute(n, x, inc_x, y, inc_y, c, s);
+    }
+    else {
+#if defined(DOUBLE)
+           int mode = BLAS_DOUBLE | BLAS_REAL | BLAS_PTHREAD;
+#else
+           int mode = BLAS_SINGLE | BLAS_REAL | BLAS_PTHREAD;
+#endif
+           blas_level1_thread(mode, n, 0, 0, alpha, x, inc_x, y, inc_y, &dummy_c, 0, (void *)rot_thread_function, nthreads);
+    }
+#else  
+    rot_compute(n, x, inc_x, y, inc_y, c, s);
+#endif
+    return 0;
+}
diff --git a/kernel/x86_64/srot_microk_haswell-2.c b/kernel/x86_64/srot_microk_haswell-2.c
new file mode 100644 (file)
index 0000000..cba9620
--- /dev/null
@@ -0,0 +1,87 @@
+/* need a new enough GCC for avx512 support */
+#if (( defined(__GNUC__)  && __GNUC__   > 6 && defined(__AVX512CD__)) || (defined(__clang__) && __clang_major__ >= 9))
+
+#define HAVE_SROT_KERNEL 1
+
+#include <immintrin.h>
+#include <stdint.h>
+
+static void srot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
+{
+    BLASLONG i = 0;
+
+    BLASLONG tail_index_8 = n&(~7);
+    BLASLONG tail_index_32 = n&(~31);
+
+    __m256 c_256, s_256;
+    if (n >= 8) {
+        c_256 = _mm256_set1_ps(c);
+        s_256 = _mm256_set1_ps(s);
+    }
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+    __m256 t0, t1, t2, t3;
+
+    for (i = 0; i < tail_index_32; i += 32) {
+        x0 = _mm256_loadu_ps(&x[i + 0]);
+        x1 = _mm256_loadu_ps(&x[i + 8]);
+        x2 = _mm256_loadu_ps(&x[i +16]);
+        x3 = _mm256_loadu_ps(&x[i +24]);
+        y0 = _mm256_loadu_ps(&y[i + 0]);
+        y1 = _mm256_loadu_ps(&y[i + 8]);
+        y2 = _mm256_loadu_ps(&y[i +16]);
+        y3 = _mm256_loadu_ps(&y[i +24]);
+
+        t0 = _mm256_mul_ps(s_256, y0);
+        t1 = _mm256_mul_ps(s_256, y1);
+        t2 = _mm256_mul_ps(s_256, y2);
+        t3 = _mm256_mul_ps(s_256, y3);
+
+        t0 = _mm256_fmadd_ps(c_256, x0, t0);
+        t1 = _mm256_fmadd_ps(c_256, x1, t1);
+        t2 = _mm256_fmadd_ps(c_256, x2, t2);
+        t3 = _mm256_fmadd_ps(c_256, x3, t3);
+
+        _mm256_storeu_ps(&x[i + 0], t0);
+        _mm256_storeu_ps(&x[i + 8], t1);
+        _mm256_storeu_ps(&x[i +16], t2);
+        _mm256_storeu_ps(&x[i +24], t3);
+
+        t0 = _mm256_mul_ps(s_256, x0);
+        t1 = _mm256_mul_ps(s_256, x1);
+        t2 = _mm256_mul_ps(s_256, x2);
+        t3 = _mm256_mul_ps(s_256, x3);
+
+        t0 = _mm256_fmsub_ps(c_256, y0, t0);
+        t1 = _mm256_fmsub_ps(c_256, y1, t1);
+        t2 = _mm256_fmsub_ps(c_256, y2, t2);
+        t3 = _mm256_fmsub_ps(c_256, y3, t3);
+
+        _mm256_storeu_ps(&y[i + 0], t0);
+        _mm256_storeu_ps(&y[i + 8], t1);
+        _mm256_storeu_ps(&y[i +16], t2);
+        _mm256_storeu_ps(&y[i +24], t3);
+
+    }
+
+    for (i = tail_index_32; i < tail_index_8; i += 8) {
+        x0 = _mm256_loadu_ps(&x[i]);
+        y0 = _mm256_loadu_ps(&y[i]);
+
+        t0 = _mm256_mul_ps(s_256, y0);
+        t0 = _mm256_fmadd_ps(c_256, s0, t0);
+        _mm256_storeu_ps(&x[i], t0);
+
+        t0 = _mm256_mul_ps(s_256, x0);
+        t0 = _mm256_fmsub_ps(c_256, y0, t0);
+        _mm256_storeu_ps(&y[i], t0);
+    }
+
+    for (i = tail_index_8; i < n; ++i) {
+        FLOAT temp = c * x[i] + s * y[i];
+        y[i] = c * y[i] - s * x[i];
+        x[i] = temp;
+    }
+}
+#endif
diff --git a/kernel/x86_64/srot_microk_skylakex-2.c b/kernel/x86_64/srot_microk_skylakex-2.c
new file mode 100644 (file)
index 0000000..a21d1cf
--- /dev/null
@@ -0,0 +1,91 @@
+/* need a new enough GCC for avx512 support */
+#if (( defined(__GNUC__)  && __GNUC__   > 6 && defined(__AVX512CD__)) || (defined(__clang__) && __clang_major__ >= 9))
+
+#define HAVE_SROT_KERNEL 1
+
+#include <immintrin.h>
+#include <stdint.h>
+
+static void srot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
+{
+    BLASLONG i = 0;
+    __m512 c_512, s_512;
+    c_512 = _mm512_set1_ps(c);
+    s_512 = _mm512_set1_ps(s);
+
+    BLASLONG tail_index_16 = n&(~15);
+    BLASLONG tail_index_64 = n&(~63);
+
+
+    __m512 x0, x1, x2, x3;
+    __m512 y0, y1, y2, y3;
+    __m512 t0, t1, t2, t3;
+
+    for (i = 0; i < tail_index_64; i += 64) {
+        x0 = _mm512_loadu_ps(&x[i + 0]);
+        x1 = _mm512_loadu_ps(&x[i +16]);
+        x2 = _mm512_loadu_ps(&x[i +32]);
+        x3 = _mm512_loadu_ps(&x[i +48]);
+        y0 = _mm512_loadu_ps(&y[i + 0]);
+        y1 = _mm512_loadu_ps(&y[i +16]);
+        y2 = _mm512_loadu_ps(&y[i +32]);
+        y3 = _mm512_loadu_ps(&y[i +48]);
+
+        t0 = _mm512_mul_ps(s_512, y0);
+        t1 = _mm512_mul_ps(s_512, y1);
+        t2 = _mm512_mul_ps(s_512, y2);
+        t3 = _mm512_mul_ps(s_512, y3);
+
+        t0 = _mm512_fmadd_ps(c_512, x0, t0);
+        t1 = _mm512_fmadd_ps(c_512, x1, t1);
+        t2 = _mm512_fmadd_ps(c_512, x2, t2);
+        t3 = _mm512_fmadd_ps(c_512, x3, t3);
+
+        _mm512_storeu_ps(&x[i + 0], t0);
+        _mm512_storeu_ps(&x[i +16], t1);
+        _mm512_storeu_ps(&x[i +32], t2);
+        _mm512_storeu_ps(&x[i +48], t3);
+
+        t0 = _mm512_mul_ps(s_512, x0);
+        t1 = _mm512_mul_ps(s_512, x1);
+        t2 = _mm512_mul_ps(s_512, x2);
+        t3 = _mm512_mul_ps(s_512, x3);
+
+        t0 = _mm512_fmsub_ps(c_512, y0, t0);
+        t1 = _mm512_fmsub_ps(c_512, y1, t1);
+        t2 = _mm512_fmsub_ps(c_512, y2, t2);
+        t3 = _mm512_fmsub_ps(c_512, y3, t3);
+
+        _mm512_storeu_ps(&y[i + 0], t0);
+        _mm512_storeu_ps(&y[i +16], t1);
+        _mm512_storeu_ps(&y[i +32], t2);
+        _mm512_storeu_ps(&y[i +48], t3);
+    }
+
+    for (i = tail_index_64; i < tail_index_16; i += 16) {
+        x0 = _mm512_loadu_ps(&x[i]);
+        y0 = _mm512_loadu_ps(&y[i]);
+
+        t0 = _mm512_mul_ps(s_512, y0);
+        t0 = _mm512_fmadd_ps(c_512, x0, t0);
+        _mm512_storeu_ps(&x[i], t0);
+
+        t0 = _mm512_mul_ps(s_512, x0);
+        t0 = _mm512_fmsub_ps(c_512, y0, t0);
+        _mm512_storeu_ps(&y[i], t0);
+    }
+
+
+    if ((n & 15) > 0) {
+        uint16_t tail_mask16 = (((uint16_t) 0xffff) >> (16-(n&15)));
+        __m512 tail_x = _mm512_maskz_loadu_ps(*((__mmask16*)&tail_mask16), &x[tail_index_16]);
+        __m512 tail_y = _mm512_maskz_loadu_ps(*((__mmask16*)&tail_mask16), &y[tail_index_16]);
+           __m512 temp = _mm512_mul_ps(s_512, tail_y);
+           temp = _mm512_fmadd_ps(c_512, tail_x, temp);
+           _mm512_mask_storeu_ps(&x[tail_index_16], *((__mmask16*)&tail_mask16), temp);
+           temp = _mm512_mul_ps(s_512, tail_x);
+           temp = _mm512_fmsub_ps(c_512, tail_y, temp);
+           _mm512_mask_storeu_ps(&y[tail_index_16], *((__mmask16*)&tail_mask16), temp);        
+    }
+}
+#endif