Merge pull request #2796 from Guobing-Chen/BF16_dot_coversion_apis
authorMartin Kroeker <martin@ruby.chemie.uni-freiburg.de>
Mon, 14 Sep 2020 13:00:19 +0000 (15:00 +0200)
committerGitHub <noreply@github.com>
Mon, 14 Sep 2020 13:00:19 +0000 (15:00 +0200)
Add bfloat16 based dot and conversion with single/double

31 files changed:
Makefile.tail
cblas.h
cmake/kernel.cmake
common.h
common_interface.h
common_level1.h
common_macro.h
common_param.h
common_sh.h
common_thread.h
common_x86_64.h
driver/others/blas_l1_thread.c
driver/others/blas_server.c
driver/others/blas_server_omp.c
driver/others/blas_server_win32.c
driver/others/dynamic.c
exports/gensymbol
interface/Makefile
interface/bf16dot.c [new file with mode: 0644]
interface/bf16to.c [new file with mode: 0644]
interface/tobf16.c [new file with mode: 0644]
kernel/Makefile.L1
kernel/setparam-ref.c
kernel/x86_64/KERNEL
kernel/x86_64/bf16to.c [new file with mode: 0644]
kernel/x86_64/dtobf16_microk_cooperlake.c [new file with mode: 0644]
kernel/x86_64/shdot.c [new file with mode: 0644]
kernel/x86_64/shdot_microk_cooperlake.c [new file with mode: 0644]
kernel/x86_64/stobf16_microk_cooperlake.c [new file with mode: 0644]
kernel/x86_64/tobf16.c [new file with mode: 0644]
openblas_config_template.h

index 3990298..cfc4a36 100644 (file)
@@ -5,13 +5,14 @@ QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
 CBLASOBJS_P = $(CBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
 ZBLASOBJS_P = $(ZBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
 XBLASOBJS_P = $(XBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
+SHEXTOBJS_P = $(SHEXTOBJS:.$(SUFFIX)=.$(PSUFFIX))
 
 COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX))
 
 HPLOBJS_P   = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))
 
-BLASOBJS    = $(SHBLASOBJS)  $(SBLASOBJS)   $(DBLASOBJS)   $(CBLASOBJS)   $(ZBLASOBJS)
-BLASOBJS_P  = $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P)
+BLASOBJS    = $(SHEXTOBJS) $(SHBLASOBJS)  $(SBLASOBJS)   $(DBLASOBJS)   $(CBLASOBJS)   $(ZBLASOBJS)
+BLASOBJS_P  = $(SHEXTOBJS_P) $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P)
 
 ifdef EXPRECISION
 BLASOBJS   += $(QBLASOBJS)   $(XBLASOBJS)
@@ -30,6 +31,7 @@ $(QBLASOBJS) $(QBLASOBJS_P) : override CFLAGS += -DXDOUBLE -UCOMPLEX
 $(CBLASOBJS) $(CBLASOBJS_P) : override CFLAGS += -UDOUBLE  -DCOMPLEX
 $(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE  -DCOMPLEX
 $(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX
+$(SHEXTOBJS) $(SHEXTOBJS_P) : override CFLAGS += -DHALF -UDOUBLE  -UCOMPLEX
 
 $(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
 $(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
@@ -38,6 +40,7 @@ $(QBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
 $(CBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
 $(ZBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
 $(XBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
+$(SHEXTOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
 
 libs    :: $(BLASOBJS) $(COMMONOBJS)
        $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
diff --git a/cblas.h b/cblas.h
index 4bc5588..21f3958 100644 (file)
--- a/cblas.h
+++ b/cblas.h
@@ -382,6 +382,17 @@ void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint
 void cblas_zgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint crows, OPENBLAS_CONST blasint ccols, OPENBLAS_CONST double *calpha, double *a, OPENBLAS_CONST blasint clda, OPENBLAS_CONST double *cbeta, 
                  double *c, OPENBLAS_CONST blasint cldc); 
 
+/*** BFLOAT16 and INT8 extensions ***/
+/* convert float array to BFLOAT16 array by rounding */
+void   cblas_shstobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST float  *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout);
+/* convert double array to BFLOAT16 array by rounding */
+void   cblas_shdtobf16(OPENBLAS_CONST blasint n, OPENBLAS_CONST double *in, OPENBLAS_CONST blasint incin, bfloat16 *out, OPENBLAS_CONST blasint incout);
+/* convert BFLOAT16 array to float array */
+void   cblas_sbf16tos(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, float  *out, OPENBLAS_CONST blasint incout);
+/* convert BFLOAT16 array to double array */
+void   cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, double *out, OPENBLAS_CONST blasint incout);
+/* dot production of BFLOAT16 input arrays, and output as float */
+float  cblas_shdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy);
 
 #ifdef __cplusplus
 }
index 4b505a1..79eeaae 100644 (file)
@@ -126,12 +126,14 @@ if (BUILD_HALF)
   set(SHAXPYKERNEL ../arm/axpy.c)
   set(SHAXPBYKERNEL ../arm/axpby.c)
   set(SHCOPYKERNEL ../arm/copy.c)
-  set(SHDOTKERNEL ../arm/dot.c)
+  set(SHDOTKERNEL ../x86_64/shdot.c)
   set(SHROTKERNEL ../arm/rot.c)
   set(SHSCALKERNEL ../arm/scal.c)
   set(SHNRM2KERNEL ../arm/nrm2.c)
   set(SHSUMKERNEL ../arm/sum.c)
   set(SHSWAPKERNEL ../arm/swap.c)
+  set(TOBF16KERNEL ../x86_64/tobf16.c)
+  set(BF16TOKERNEL ../x86_64/bf16to.c)
 endif ()
 endmacro ()
 
index d6637ab..adc1625 100644 (file)
--- a/common.h
+++ b/common.h
@@ -258,7 +258,8 @@ typedef unsigned long BLASULONG;
 #endif
 
 #ifndef BFLOAT16
-typedef unsigned short bfloat16;
+#include <stdint.h>
+typedef uint16_t bfloat16;
 #define HALFCONVERSION 1
 #endif
 
index 78f5be6..35a957a 100644 (file)
@@ -54,6 +54,11 @@ double BLASFUNC(dsdot) (blasint *, float  *, blasint *, float  *, blasint *);
 double BLASFUNC(ddot)  (blasint *, double *, blasint *, double *, blasint *);
 xdouble BLASFUNC(qdot)  (blasint *, xdouble *, blasint *, xdouble *, blasint *);
 
+float  BLASFUNC(shdot)     (blasint *, bfloat16 *, blasint *, bfloat16 *, blasint *);
+void   BLASFUNC(shstobf16) (blasint *, float *,    blasint *, bfloat16 *, blasint *);
+void   BLASFUNC(shdtobf16) (blasint *, double *,   blasint *, bfloat16 *, blasint *);
+void   BLASFUNC(sbf16tos)  (blasint *, bfloat16 *, blasint *, float *,    blasint *);
+void   BLASFUNC(dbf16tod)  (blasint *, bfloat16 *, blasint *, double *,   blasint *);
 
 #ifdef RETURN_BY_STRUCT
 typedef struct {
index 74cafb6..88aa275 100644 (file)
@@ -46,6 +46,12 @@ float   sdot_k(BLASLONG, float   *, BLASLONG, float   *, BLASLONG);
 double dsdot_k(BLASLONG, float   *, BLASLONG, float *, BLASLONG);
 double  ddot_k(BLASLONG, double  *, BLASLONG, double  *, BLASLONG);
 xdouble qdot_k(BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG);
+float  shdot_k(BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
+
+void   shstobf16_k(BLASLONG, float    *, BLASLONG, bfloat16 *, BLASLONG);
+void   shdtobf16_k(BLASLONG, double   *, BLASLONG, bfloat16 *, BLASLONG);
+void   sbf16tos_k (BLASLONG, bfloat16 *, BLASLONG, float    *, BLASLONG);
+void   dbf16tod_k (BLASLONG, bfloat16 *, BLASLONG, double   *, BLASLONG);
 
 openblas_complex_float cdotc_k (BLASLONG, float  *, BLASLONG, float  *, BLASLONG);
 openblas_complex_float cdotu_k (BLASLONG, float  *, BLASLONG, float  *, BLASLONG);
index 8fe1f15..3d6bcd9 100644 (file)
 
 #elif defined(HALF)
 
+#define  D_TO_BF16_K    SHDTOBF16_K
+#define  D_BF16_TO_K    DBF16TOD_K
+#define  S_TO_BF16_K    SHSTOBF16_K
+#define  S_BF16_TO_K    SBF16TOS_K
+
 #define        AMAX_K                  SAMAX_K
 #define        AMIN_K                  SAMIN_K
 #define        MAX_K                   SMAX_K
 #define        ASUM_K                  SASUM_K
 #define        DOTU_K                  SDOTU_K
 #define        DOTC_K                  SDOTC_K
+#define BF16_DOT_K      SHDOT_K
 #define        AXPYU_K                 SAXPYU_K
 #define        AXPYC_K                 SAXPYC_K
 #define AXPBY_K     SAXPBY_K
index 0437482..a52de98 100644 (file)
@@ -51,6 +51,11 @@ typedef struct {
   int shgemm_p, shgemm_q, shgemm_r;
   int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn;
 
+  void   (*shstobf16_k) (BLASLONG, float    *, BLASLONG, bfloat16 *, BLASLONG);
+  void   (*shdtobf16_k) (BLASLONG, double   *, BLASLONG, bfloat16 *, BLASLONG);
+  void   (*sbf16tos_k)  (BLASLONG, bfloat16 *, BLASLONG, float    *, BLASLONG);
+  void   (*dbf16tod_k)  (BLASLONG, bfloat16 *, BLASLONG, double   *, BLASLONG);
+
   float  (*shamax_k) (BLASLONG, float *, BLASLONG);
   float  (*shamin_k) (BLASLONG, float *, BLASLONG);
   float  (*shmax_k)  (BLASLONG, float *, BLASLONG);
@@ -64,7 +69,7 @@ BLASLONG (*ishmin_k) (BLASLONG, float *, BLASLONG);
   float  (*shasum_k) (BLASLONG, float *, BLASLONG);
   float  (*shsum_k)  (BLASLONG, float *, BLASLONG);
   int    (*shcopy_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
-  float  (*shdot_k)  (BLASLONG, float *, BLASLONG, float *, BLASLONG);
+  float  (*shdot_k)  (BLASLONG, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG);
   double (*dshdot_k) (BLASLONG, float *, BLASLONG, float *, BLASLONG);
 
   int    (*shrot_k)  (BLASLONG, float *, BLASLONG, float *, BLASLONG, float, float);
index 7a00457..5dc99b3 100644 (file)
@@ -3,6 +3,12 @@
 
 #ifndef DYNAMIC_ARCH
 
+#define SHDOT_K             shdot_k
+#define SHSTOBF16_K         shstobf16_k
+#define SHDTOBF16_K         shdtobf16_k
+#define SBF16TOS_K          sbf16tos_k
+#define DBF16TOD_K          dbf16tod_k
+
 #define        SHGEMM_ONCOPY           shgemm_oncopy
 #define        SHGEMM_OTCOPY           shgemm_otcopy
 
 
 #else
 
+#define SHDOT_K             gotoblas -> shdot_k
+#define SHSTOBF16_K         gotoblas -> shstobf16_k
+#define SHDTOBF16_K         gotoblas -> shdtobf16_k
+#define SBF16TOS_K          gotoblas -> sbf16tos_k
+#define DBF16TOD_K          gotoblas -> dbf16tod_k
+
 #define        SHGEMM_ONCOPY           gotoblas -> shgemm_oncopy
 #define        SHGEMM_OTCOPY           gotoblas -> shgemm_otcopy
 #define        SHGEMM_INCOPY           gotoblas -> shgemm_incopy
index ec0c65b..a18df0d 100644 (file)
@@ -59,12 +59,19 @@ extern int blas_omp_linked;
 #define BLAS_PTHREAD   0x4000U
 #define BLAS_NODE      0x2000U
 
-#define BLAS_PREC      0x0003U
-#define BLAS_SINGLE    0x0000U
-#define BLAS_DOUBLE    0x0001U
-#define BLAS_XDOUBLE   0x0002U
-#define BLAS_REAL      0x0000U
-#define BLAS_COMPLEX   0x0004U
+#define BLAS_PREC       0x000FU
+#define BLAS_INT8       0x0000U
+#define BLAS_BFLOAT16   0x0001U
+#define BLAS_SINGLE     0x0002U
+#define BLAS_DOUBLE     0x0003U
+#define BLAS_XDOUBLE    0x0004U
+#define BLAS_STOBF16    0x0008U
+#define BLAS_DTOBF16    0x0009U
+#define BLAS_BF16TOS    0x000AU
+#define BLAS_BF16TOD    0x000BU
+
+#define BLAS_REAL       0x0000U
+#define BLAS_COMPLEX    0x1000U
 
 #define BLAS_TRANSA    0x0030U /* 2bit */
 #define BLAS_TRANSA_N  0x0000U
index bee7e8c..b813336 100644 (file)
@@ -142,6 +142,29 @@ static __inline void cpuid(int op, int *eax, int *ebx, int *ecx, int *edx){
 #endif
 }
 
+static __inline void cpuid_count(int op, int count, int *eax, int *ebx, int *ecx, int *edx)
+{
+#ifdef C_MSVC
+  int cpuInfo[4] = {-1};
+  __cpuidex(cpuInfo, op, count);
+  *eax = cpuInfo[0];
+  *ebx = cpuInfo[1];
+  *ecx = cpuInfo[2];
+  *edx = cpuInfo[3];
+#else
+#if defined(__i386__) && defined(__PIC__)
+  __asm__ __volatile__
+    ("mov %%ebx, %%edi;"
+      "cpuid;"
+      "xchgl %%ebx, %%edi;"
+      : "=a" (*eax), "=D" (*ebx), "=c" (*ecx), "=d" (*edx) : "0" (op), "2" (count) : "cc");
+#else
+  __asm__ __volatile__
+    ("cpuid": "=a" (*eax), "=b" (*ebx), "=c" (*ecx), "=d" (*edx) : "0" (op), "2" (count) : "cc");
+#endif
+#endif
+}
+
 /*
 #define WHEREAMI
 */
index e405c74..04acbcc 100644 (file)
@@ -49,9 +49,36 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha
   blas_arg_t   args [MAX_CPU_NUMBER];
 
   BLASLONG i, width, astride, bstride;
-  int num_cpu, calc_type;
-
-  calc_type = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0) + 2;
+  int num_cpu, calc_type_a, calc_type_b;
+
+  switch (mode & BLAS_PREC) {
+  case BLAS_INT8    :
+  case BLAS_BFLOAT16:
+  case BLAS_SINGLE  :
+  case BLAS_DOUBLE  :
+  case BLAS_XDOUBLE :
+    calc_type_a = calc_type_b = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0);
+    break;
+  case BLAS_STOBF16 :
+    calc_type_a = 2 + ((mode & BLAS_COMPLEX) != 0);
+    calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
+    break;
+  case BLAS_DTOBF16 :
+    calc_type_a = 3 + ((mode & BLAS_COMPLEX) != 0);
+    calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
+    break;
+  case BLAS_BF16TOS :
+    calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
+    calc_type_b = 2 + ((mode & BLAS_COMPLEX) != 0);
+    break;
+  case BLAS_BF16TOD :
+    calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
+    calc_type_b = 3 + ((mode & BLAS_COMPLEX) != 0);
+    break;
+  default:
+    calc_type_a = calc_type_b = 0;
+    break;
+  }
 
   mode |= BLAS_LEGACY;
 
@@ -77,8 +104,8 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha
       bstride = width;
     }
 
-    astride <<= calc_type;
-    bstride <<= calc_type;
+    astride <<= calc_type_a;
+    bstride <<= calc_type_b;
 
     args[num_cpu].m = width;
     args[num_cpu].n = n;
@@ -120,9 +147,36 @@ int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASL
   blas_arg_t   args [MAX_CPU_NUMBER];
 
   BLASLONG i, width, astride, bstride;
-  int num_cpu, calc_type;
-
-  calc_type = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0) + 2;
+  int num_cpu, calc_type_a, calc_type_b;
+
+  switch (mode & BLAS_PREC) {
+  case BLAS_INT8    :
+  case BLAS_BFLOAT16:
+  case BLAS_SINGLE  :
+  case BLAS_DOUBLE  :
+  case BLAS_XDOUBLE :
+    calc_type_a = calc_type_b = (mode & BLAS_PREC) + ((mode & BLAS_COMPLEX) != 0);
+    break;
+  case BLAS_STOBF16 :
+    calc_type_a = 2 + ((mode & BLAS_COMPLEX) != 0);
+    calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
+    break;
+  case BLAS_DTOBF16 :
+    calc_type_a = 3 + ((mode & BLAS_COMPLEX) != 0);
+    calc_type_b = 1 + ((mode & BLAS_COMPLEX) != 0);
+    break;
+  case BLAS_BF16TOS :
+    calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
+    calc_type_b = 2 + ((mode & BLAS_COMPLEX) != 0);
+    break;
+  case BLAS_BF16TOD :
+    calc_type_a = 1 + ((mode & BLAS_COMPLEX) != 0);
+    calc_type_b = 3 + ((mode & BLAS_COMPLEX) != 0);
+    break;
+  default:
+    calc_type_a = calc_type_b = 0;
+    break;
+  }
 
   mode |= BLAS_LEGACY;
 
@@ -148,8 +202,8 @@ int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASL
       bstride = width;
     }
 
-    astride <<= calc_type;
-    bstride <<= calc_type;
+    astride <<= calc_type_a;
+    bstride <<= calc_type_b;
 
     args[num_cpu].m = width;
     args[num_cpu].n = n;
index 756e51b..8d3dda3 100644 (file)
@@ -192,7 +192,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
 
       if (!(mode & BLAS_COMPLEX)){
 #ifdef EXPRECISION
-       if (mode & BLAS_XDOUBLE){
+       if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
          /* REAL / Extended Double */
          void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble,
                        xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -205,7 +205,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
                args -> c, args -> ldc, sb);
        } else
 #endif
-         if (mode & BLAS_DOUBLE){
+         if ((mode & BLAS_PREC) == BLAS_DOUBLE){
            /* REAL / Double */
            void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
                          double *, BLASLONG, double *, BLASLONG,
@@ -216,21 +216,58 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
                  args -> a, args -> lda,
                  args -> b, args -> ldb,
                  args -> c, args -> ldc, sb);
-         } else {
-           /* REAL / Single */
-           void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
-                         float *, BLASLONG, float *, BLASLONG,
-                         float *, BLASLONG, void *) = func;
-
-           afunc(args -> m, args -> n, args -> k,
-                 ((float *)args -> alpha)[0],
-                 args -> a, args -> lda,
-                 args -> b, args -> ldb,
-                 args -> c, args -> ldc, sb);
+         } else if ((mode & BLAS_PREC) == BLAS_SINGLE){
+            /* REAL / Single */
+            void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
+                          float *, BLASLONG, float *, BLASLONG,
+                          float *, BLASLONG, void *) = func;
+
+            afunc(args -> m, args -> n, args -> k,
+                  ((float *)args -> alpha)[0],
+                  args -> a, args -> lda,
+                  args -> b, args -> ldb,
+                  args -> c, args -> ldc, sb);
+#ifdef BUILD_HALF
+         } else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){
+            /* REAL / BFLOAT16 */
+            void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16,
+                          bfloat16 *, BLASLONG, bfloat16 *, BLASLONG,
+                          bfloat16 *, BLASLONG, void *) = func;
+
+            afunc(args -> m, args -> n, args -> k,
+                  ((bfloat16 *)args -> alpha)[0],
+                  args -> a, args -> lda,
+                  args -> b, args -> ldb,
+                  args -> c, args -> ldc, sb);
+          } else if ((mode & BLAS_PREC) == BLAS_STOBF16){
+            /* REAL / BLAS_STOBF16 */
+            void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
+                          float *, BLASLONG, bfloat16 *, BLASLONG,
+                          float *, BLASLONG, void *) = func;
+
+            afunc(args -> m, args -> n, args -> k,
+                  ((float *)args -> alpha)[0],
+                  args -> a, args -> lda,
+                  args -> b, args -> ldb,
+                  args -> c, args -> ldc, sb);
+          } else if ((mode & BLAS_PREC) == BLAS_DTOBF16){
+            /* REAL / BLAS_DTOBF16 */
+            void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
+                          double *, BLASLONG, bfloat16 *, BLASLONG,
+                          double *, BLASLONG, void *) = func;
+
+            afunc(args -> m, args -> n, args -> k,
+                  ((double *)args -> alpha)[0],
+                  args -> a, args -> lda,
+                  args -> b, args -> ldb,
+                  args -> c, args -> ldc, sb);
+#endif
+      } else {
+           /* REAL / Other types in future */
          }
       } else {
 #ifdef EXPRECISION
-       if (mode & BLAS_XDOUBLE){
+       if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
          /* COMPLEX / Extended Double */
          void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble,
                        xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -244,7 +281,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
                args -> c, args -> ldc, sb);
        } else
 #endif
-         if (mode & BLAS_DOUBLE){
+         if ((mode & BLAS_PREC) == BLAS_DOUBLE) {
            /* COMPLEX / Double */
          void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double,
                        double *, BLASLONG, double *, BLASLONG,
@@ -256,7 +293,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
                args -> a, args -> lda,
                args -> b, args -> ldb,
                args -> c, args -> ldc, sb);
-         } else {
+         } else if ((mode & BLAS_PREC) == BLAS_SINGLE) {
            /* COMPLEX / Single */
          void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float,
                        float *, BLASLONG, float *, BLASLONG,
@@ -268,7 +305,9 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
                args -> a, args -> lda,
                args -> b, args -> ldb,
                args -> c, args -> ldc, sb);
-         }
+      } else {
+          /* COMPLEX / Other types in future */
+      }
       }
 }
 
@@ -414,33 +453,37 @@ blas_queue_t *tscq;
       if (sb == NULL) {
        if (!(queue -> mode & BLAS_COMPLEX)){
 #ifdef EXPRECISION
-         if (queue -> mode & BLAS_XDOUBLE){
+         if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
            sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble)
                                        + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
          } else
 #endif
-         if (queue -> mode & BLAS_DOUBLE){
+         if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE) {
            sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double)
                                        + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
 
-         } else {
+         } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
            sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float)
                                        + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
-         }
+      } else {
+          /* Other types in future */
+      }
        } else {
 #ifdef EXPRECISION
-         if (queue -> mode & BLAS_XDOUBLE){
+         if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
            sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble)
                                        + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
          } else
 #endif
-         if (queue -> mode & BLAS_DOUBLE){
+         if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
            sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double)
                                        + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
-         } else {
+         } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
            sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float)
                                        + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
-         }
+      } else {
+          /* Other types in future */
+      }
        }
        queue->sb=sb;
       }
index d9969b5..d126955 100644 (file)
@@ -142,7 +142,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
 
       if (!(mode & BLAS_COMPLEX)){
 #ifdef EXPRECISION
-       if (mode & BLAS_XDOUBLE){
+       if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
          /* REAL / Extended Double */
          void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble,
                        xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -155,7 +155,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
                args -> c, args -> ldc, sb);
        } else
 #endif
-         if (mode & BLAS_DOUBLE){
+         if ((mode & BLAS_PREC) == BLAS_DOUBLE){
            /* REAL / Double */
            void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
                          double *, BLASLONG, double *, BLASLONG,
@@ -166,7 +166,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
                  args -> a, args -> lda,
                  args -> b, args -> ldb,
                  args -> c, args -> ldc, sb);
-         } else {
+         } else if ((mode & BLAS_PREC) == BLAS_SINGLE){
            /* REAL / Single */
            void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
                          float *, BLASLONG, float *, BLASLONG,
@@ -177,10 +177,47 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
                  args -> a, args -> lda,
                  args -> b, args -> ldb,
                  args -> c, args -> ldc, sb);
+#ifdef BUILD_HALF
+          } else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){
+            /* REAL / BFLOAT16 */
+            void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16,
+                          bfloat16 *, BLASLONG, bfloat16 *, BLASLONG,
+                          bfloat16 *, BLASLONG, void *) = func;
+
+            afunc(args -> m, args -> n, args -> k,
+                  ((bfloat16 *)args -> alpha)[0],
+                  args -> a, args -> lda,
+                  args -> b, args -> ldb,
+                  args -> c, args -> ldc, sb);
+          } else if ((mode & BLAS_PREC) == BLAS_STOBF16){
+            /* REAL / BLAS_STOBF16 */
+            void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
+                          float *, BLASLONG, bfloat16 *, BLASLONG,
+                          float *, BLASLONG, void *) = func;
+
+            afunc(args -> m, args -> n, args -> k,
+                  ((float *)args -> alpha)[0],
+                  args -> a, args -> lda,
+                  args -> b, args -> ldb,
+                  args -> c, args -> ldc, sb);
+          } else if ((mode & BLAS_PREC) == BLAS_DTOBF16){
+            /* REAL / BLAS_DTOBF16 */
+            void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
+                          double *, BLASLONG, bfloat16 *, BLASLONG,
+                          double *, BLASLONG, void *) = func;
+
+            afunc(args -> m, args -> n, args -> k,
+                  ((double *)args -> alpha)[0],
+                  args -> a, args -> lda,
+                  args -> b, args -> ldb,
+                  args -> c, args -> ldc, sb);
+#endif
+          } else {
+            /* REAL / Other types in future */
          }
       } else {
 #ifdef EXPRECISION
-       if (mode & BLAS_XDOUBLE){
+       if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
          /* COMPLEX / Extended Double */
          void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble,
                        xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -194,7 +231,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
                args -> c, args -> ldc, sb);
        } else
 #endif
-         if (mode & BLAS_DOUBLE){
+         if ((mode & BLAS_PREC) == BLAS_DOUBLE){
            /* COMPLEX / Double */
          void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double,
                        double *, BLASLONG, double *, BLASLONG,
@@ -206,7 +243,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
                args -> a, args -> lda,
                args -> b, args -> ldb,
                args -> c, args -> ldc, sb);
-         } else {
+         } else if ((mode & BLAS_PREC) == BLAS_SINGLE){
            /* COMPLEX / Single */
          void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float,
                        float *, BLASLONG, float *, BLASLONG,
@@ -218,8 +255,10 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
                args -> a, args -> lda,
                args -> b, args -> ldb,
                args -> c, args -> ldc, sb);
-         }
-      }
+      } else {
+            /* COMPLEX / Other types in future */
+        }
+   }
 }
 
 static void exec_threads(blas_queue_t *queue, int buf_index){
@@ -255,32 +294,36 @@ static void exec_threads(blas_queue_t *queue, int buf_index){
     if (sb == NULL) {
       if (!(queue -> mode & BLAS_COMPLEX)){
 #ifdef EXPRECISION
-       if (queue -> mode & BLAS_XDOUBLE){
+       if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
          sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble)
                                          + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
        } else
 #endif
-         if (queue -> mode & BLAS_DOUBLE){
+         if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
            sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double)
                                            + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
 
-         } else {
+         } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE){
            sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float)
                                            + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
+         } else {
+          /* Other types in future */
          }
       } else {
 #ifdef EXPRECISION
-       if (queue -> mode & BLAS_XDOUBLE){
+       if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
          sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble)
                                          + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
        } else
 #endif
-         if (queue -> mode & BLAS_DOUBLE){
+         if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
            sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double)
                                            + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
-         } else {
+         } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
            sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float)
                                            + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
+         } else {
+          /* Other types in future */
          }
       }
       queue->sb=sb;
index 5ecc442..d2cc917 100644 (file)
@@ -77,7 +77,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
 
       if (!(mode & BLAS_COMPLEX)){
 #ifdef EXPRECISION
-       if (mode & BLAS_XDOUBLE){
+       if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
          /* REAL / Extended Double */
          void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble,
                        xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -90,7 +90,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
                args -> c, args -> ldc, sb);
        } else
 #endif
-         if (mode & BLAS_DOUBLE){
+         if ((mode & BLAS_PREC) == BLAS_DOUBLE){
            /* REAL / Double */
            void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
                          double *, BLASLONG, double *, BLASLONG,
@@ -101,7 +101,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
                  args -> a, args -> lda,
                  args -> b, args -> ldb,
                  args -> c, args -> ldc, sb);
-         } else {
+         } else if ((mode & BLAS_PREC) == BLAS_SINGLE){
            /* REAL / Single */
            void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
                          float *, BLASLONG, float *, BLASLONG,
@@ -112,10 +112,47 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
                  args -> a, args -> lda,
                  args -> b, args -> ldb,
                  args -> c, args -> ldc, sb);
+#ifdef BUILD_HALF
+          } else if ((mode & BLAS_PREC) == BLAS_BFLOAT16){
+            /* REAL / BFLOAT16 */
+            void (*afunc)(BLASLONG, BLASLONG, BLASLONG, bfloat16,
+                          bfloat16 *, BLASLONG, bfloat16 *, BLASLONG,
+                          bfloat16 *, BLASLONG, void *) = func;
+
+            afunc(args -> m, args -> n, args -> k,
+                  ((bfloat16 *)args -> alpha)[0],
+                  args -> a, args -> lda,
+                  args -> b, args -> ldb,
+                  args -> c, args -> ldc, sb);
+          } else if ((mode & BLAS_PREC) == BLAS_STOBF16){
+            /* REAL / BLAS_STOBF16 */
+            void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float,
+                          float *, BLASLONG, bfloat16 *, BLASLONG,
+                          float *, BLASLONG, void *) = func;
+
+            afunc(args -> m, args -> n, args -> k,
+                  ((float *)args -> alpha)[0],
+                  args -> a, args -> lda,
+                  args -> b, args -> ldb,
+                  args -> c, args -> ldc, sb);
+          } else if ((mode & BLAS_PREC) == BLAS_DTOBF16){
+            /* REAL / BLAS_DTOBF16 */
+            void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double,
+                          double *, BLASLONG, bfloat16 *, BLASLONG,
+                          double *, BLASLONG, void *) = func;
+
+            afunc(args -> m, args -> n, args -> k,
+                  ((double *)args -> alpha)[0],
+                  args -> a, args -> lda,
+                  args -> b, args -> ldb,
+                  args -> c, args -> ldc, sb);
+#endif
+          } else {
+            /* REAL / Other types in future */
          }
       } else {
 #ifdef EXPRECISION
-       if (mode & BLAS_XDOUBLE){
+       if ((mode & BLAS_PREC) == BLAS_XDOUBLE){
          /* COMPLEX / Extended Double */
          void (*afunc)(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble,
                        xdouble *, BLASLONG, xdouble *, BLASLONG,
@@ -129,7 +166,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
                args -> c, args -> ldc, sb);
        } else
 #endif
-         if (mode & BLAS_DOUBLE){
+         if ((mode & BLAS_PREC) == BLAS_DOUBLE){
            /* COMPLEX / Double */
          void (*afunc)(BLASLONG, BLASLONG, BLASLONG, double, double,
                        double *, BLASLONG, double *, BLASLONG,
@@ -141,7 +178,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
                args -> a, args -> lda,
                args -> b, args -> ldb,
                args -> c, args -> ldc, sb);
-         } else {
+         } else if ((mode & BLAS_PREC) == BLAS_SINGLE) {
            /* COMPLEX / Single */
          void (*afunc)(BLASLONG, BLASLONG, BLASLONG, float, float,
                        float *, BLASLONG, float *, BLASLONG,
@@ -153,7 +190,9 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
                args -> a, args -> lda,
                args -> b, args -> ldb,
                args -> c, args -> ldc, sb);
-         }
+        } else {
+          /* COMPLEX / Other types in future */
+        }
       }
 }
 
@@ -233,32 +272,36 @@ static DWORD WINAPI blas_thread_server(void *arg){
       if (sb == NULL) {
        if (!(queue -> mode & BLAS_COMPLEX)){
 #ifdef EXPRECISION
-         if (queue -> mode & BLAS_XDOUBLE){
+         if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
            sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * sizeof(xdouble)
                                        + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
          } else
 #endif
-           if (queue -> mode & BLAS_DOUBLE){
+           if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
              sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double)
                                          + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
 
-           } else {
+           } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
              sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float)
                                          + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
+           } else {
+            /* Other types in future */
            }
        } else {
 #ifdef EXPRECISION
-         if (queue -> mode & BLAS_XDOUBLE){
+         if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
            sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble)
                                        + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
          } else
 #endif
-           if (queue -> mode & BLAS_DOUBLE){
+           if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
              sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double)
                                          + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
-           } else {
+           } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
              sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float)
                                          + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
+           } else {
+            /* Other types in future */
            }
        }
        queue->sb=sb;
index 5d71b1b..21d2c79 100644 (file)
@@ -207,6 +207,19 @@ extern gotoblas_t gotoblas_SKYLAKEX;
 #else
 #define gotoblas_SKYLAKEX gotoblas_PRESCOTT
 #endif
+#ifdef DYN_COOPERLAKE
+extern gotoblas_t gotoblas_COOPERLAKE;
+#elif defined(DYN_SKYLAKEX)
+#define gotoblas_COOPERLAKE gotoblas_SKYLAKEX
+#elif defined(DYN_HASWELL)
+#define gotoblas_COOPERLAKE gotoblas_HASWELL
+#elif defined(DYN_SANDYBRIDGE)
+#define gotoblas_COOPERLAKE gotoblas_SANDYBRIDGE
+#elif defined(DYN_NEHALEM)
+#define gotoblas_COOPERLAKE gotoblas_NEHALEM
+#else
+#define gotoblas_COOPERLAKE gotoblas_PRESCOTT
+#endif
 
 
 #else // not DYNAMIC_LIST
@@ -247,14 +260,17 @@ extern gotoblas_t  gotoblas_EXCAVATOR;
 #ifdef NO_AVX2
 #define gotoblas_HASWELL gotoblas_SANDYBRIDGE
 #define gotoblas_SKYLAKEX gotoblas_SANDYBRIDGE
+#define gotoblas_COOPERLAKE gotoblas_SANDYBRIDGE
 #define gotoblas_ZEN gotoblas_SANDYBRIDGE
 #else
 extern gotoblas_t  gotoblas_HASWELL;
 extern gotoblas_t  gotoblas_ZEN;
 #ifndef NO_AVX512
 extern gotoblas_t  gotoblas_SKYLAKEX;
+extern gotoblas_t  gotoblas_COOPERLAKE;
 #else
 #define gotoblas_SKYLAKEX gotoblas_HASWELL
+#define gotoblas_COOPERLAKE gotoblas_HASWELL
 #endif
 #endif
 #else
@@ -262,6 +278,7 @@ extern gotoblas_t  gotoblas_SKYLAKEX;
 #define gotoblas_SANDYBRIDGE gotoblas_NEHALEM
 #define gotoblas_HASWELL gotoblas_NEHALEM
 #define gotoblas_SKYLAKEX gotoblas_NEHALEM
+#define gotoblas_COOPERLAKE gotoblas_NEHALEM
 #define gotoblas_BULLDOZER gotoblas_BARCELONA
 #define gotoblas_PILEDRIVER gotoblas_BARCELONA
 #define gotoblas_STEAMROLLER gotoblas_BARCELONA
@@ -343,6 +360,23 @@ int support_avx512(){
 #endif
 }
 
+int support_avx512_bf16(){
+#if !defined(NO_AVX) && !defined(NO_AVX512)
+  int eax, ebx, ecx, edx;
+  int ret=0;
+
+  if (!support_avx512())
+    return 0;
+  cpuid_count(7, 1, &eax, &ebx, &ecx, &edx);
+  if((eax & 32) == 32){
+      ret=1;  // CPUID.7.1:EAX[bit 5] indicates whether avx512_bf16 supported or not
+  }
+  return ret;
+#else
+  return 0;
+#endif
+}
+
 extern void openblas_warning(int verbose, const char * msg);
 #define FALLBACK_VERBOSE 1
 #define NEHALEM_FALLBACK "OpenBLAS : Your OS does not support AVX instructions. OpenBLAS is using Nehalem kernels as a fallback, which may give poorer performance.\n"
@@ -524,7 +558,10 @@ static gotoblas_t *get_coretype(void){
            return &gotoblas_NEHALEM; //OS doesn't support AVX. Use old kernels.
          }
        }
-       if (model == 5) {       
+       if (model == 5) {
+       // Intel Cooperlake
+          if(support_avx512_bf16())
+             return &gotoblas_COOPERLAKE;
        // Intel Skylake X
           if (support_avx512()) 
            return &gotoblas_SKYLAKEX;
@@ -774,7 +811,8 @@ static char *corename[] = {
     "Steamroller",
     "Excavator",
     "Zen",
-    "SkylakeX" 
+    "SkylakeX",
+    "Cooperlake"
 };
 
 char *gotoblas_corename(void) {
@@ -838,6 +876,7 @@ char *gotoblas_corename(void) {
   if (gotoblas == &gotoblas_EXCAVATOR)    return corename[22];
   if (gotoblas == &gotoblas_ZEN)          return corename[23];
   if (gotoblas == &gotoblas_SKYLAKEX)     return corename[24];
+  if (gotoblas == &gotoblas_COOPERLAKE)   return corename[25];
   return corename[0];
 }
 
@@ -868,6 +907,7 @@ static gotoblas_t *force_coretype(char *coretype){
 
        switch (found)
        {
+               case 25: return (&gotoblas_COOPERLAKE);
                case 24: return (&gotoblas_SKYLAKEX);   
                case 23: return (&gotoblas_ZEN);
                case 22: return (&gotoblas_EXCAVATOR);
index 73b4be2..ce4d9bb 100644 (file)
@@ -46,7 +46,7 @@
     ssum, dsum, scsum, dzsum
 );
 
-@halfblasobjs = (shgemm);
+@halfblasobjs = (shgemm, shdot, shstobf16, shdtobf16, sbf16tos, dbf16tod);
 @cblasobjs = (
     cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv,
     cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k,
@@ -84,7 +84,7 @@
     cblas_xerbla
 );
 
-@halfcblasobjs = (cblas_shgemm);
+@halfcblasobjs = (cblas_shgemm, cblas_shdot, cblas_shstobf16, cblas_shdtobf16, cblas_sbf16tos, cblas_dbf16tod);
 
 @exblasobjs = (
     qamax,qamin,qasum,qaxpy,qcabs1,qcopy,qdot,qgbmv,qgemm,
index 2dbd600..fde6227 100644 (file)
@@ -47,7 +47,9 @@ SBLAS3OBJS    = \
                sgeadd.$(SUFFIX)
 
 ifeq ($(BUILD_HALF),1)
+SHBLAS1OBJS    = shdot.$(SUFFIX)
 SHBLAS3OBJS    = shgemm.$(SUFFIX)
+SHEXTOBJS      = shstobf16.$(SUFFIX) shdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX)
 endif
 
 DBLAS1OBJS    = \
@@ -281,7 +283,9 @@ CSBLAS3OBJS   = \
        cblas_sgeadd.$(SUFFIX)
 
 ifeq ($(BUILD_HALF),1)
+CSHBLAS1OBJS = cblas_shdot.$(SUFFIX)
 CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX)
+CSHEXTOBJS   = cblas_shstobf16.$(SUFFIX) cblas_shdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX)
 endif
 
 CDBLAS1OBJS   = \
@@ -374,6 +378,7 @@ override CFLAGS += -I.
 SBLAS1OBJS   += $(CSBLAS1OBJS)
 SBLAS2OBJS   += $(CSBLAS2OBJS)
 SBLAS3OBJS   += $(CSBLAS3OBJS)
+SHBLAS1OBJS  += $(CSHBLAS1OBJS)
 SHBLAS3OBJS  += $(CSHBLAS3OBJS)
 DBLAS1OBJS   += $(CDBLAS1OBJS)
 DBLAS2OBJS   += $(CDBLAS2OBJS)
@@ -385,10 +390,11 @@ ZBLAS1OBJS   += $(CZBLAS1OBJS)
 ZBLAS2OBJS   += $(CZBLAS2OBJS)
 ZBLAS3OBJS   += $(CZBLAS3OBJS)
 
+SHEXTOBJS     += $(CSHEXTOBJS)
 endif
 
 SBLASOBJS    = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS)
-SHBLASOBJS   = $(SHBLAS3OBJS)
+SHBLASOBJS   = $(SHBLAS1OBJS) $(SHBLAS3OBJS)
 DBLASOBJS    = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS)
 QBLASOBJS    = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS)
 CBLASOBJS    = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS)
@@ -463,7 +469,7 @@ ZBLASOBJS += $(ZLAPACKOBJS)
 
 endif
 
-FUNCOBJS    = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
+FUNCOBJS    = $(SHEXTOBJS) $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
 
 ifdef EXPRECISION
 FUNCOBJS   += $(QBLASOBJS) $(XBLASOBJS)
@@ -491,7 +497,7 @@ endif
 clean ::
        @rm -f functable.h
 
-level1 : $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS)
+level1 : $(BEXTOBJS) $(SHBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS)
        $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
 
 level2 : $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS)
@@ -725,6 +731,19 @@ sdsdot.$(SUFFIX) sdsdot.$(PSUFFIX) : sdsdot.c
 dsdot.$(SUFFIX) dsdot.$(PSUFFIX) : dsdot.c
        $(CC) $(CFLAGS) -c $< -o $(@F)
 
+ifeq ($(BUILD_HALF),1)
+shdot.$(SUFFIX) shdot.$(PSUFFIX) : bf16dot.c
+       $(CC) $(CFLAGS) -c $< -o $(@F)
+shstobf16.$(SUFFIX) shstobf16.$(PSUFFIX) : tobf16.c
+       $(CC) $(CFLAGS) -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F)
+shdtobf16.$(SUFFIX) shdtobf16.$(PSUFFIX) : tobf16.c
+       $(CC) $(CFLAGS) -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F)
+sbf16tos.$(SUFFIX)  sbf16tos.$(PSUFFIX) : bf16to.c
+       $(CC) $(CFLAGS) -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F)
+dbf16tod.$(SUFFIX)  dbf16tod.$(PSUFFIX) : bf16to.c
+       $(CC) $(CFLAGS) -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F)
+endif
+
 sdot.$(SUFFIX) sdot.$(PSUFFIX) : dot.c
        $(CC) $(CFLAGS) -c $< -o $(@F)
 
@@ -1463,6 +1482,19 @@ cblas_sdsdot.$(SUFFIX) cblas_sdsdot.$(PSUFFIX) : sdsdot.c
 cblas_dsdot.$(SUFFIX) cblas_dsdot.$(PSUFFIX) : dsdot.c
        $(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F)
 
+ifeq ($(BUILD_HALF),1)
+cblas_shdot.$(SUFFIX) cblas_shdot.$(PSUFFIX) : bf16dot.c
+       $(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F)
+cblas_shstobf16.$(SUFFIX) cblas_shstobf16.$(PSUFFIX) : tobf16.c
+       $(CC) $(CFLAGS) -DCBLAS -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F)
+cblas_shdtobf16.$(SUFFIX) cblas_shdtobf16.$(PSUFFIX) : tobf16.c
+       $(CC) $(CFLAGS) -DCBLAS -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F)
+cblas_sbf16tos.$(SUFFIX)  cblas_sbf16tos.$(PSUFFIX) : bf16to.c
+       $(CC) $(CFLAGS) -DCBLAS -DSINGLE_PREC -UDOUBLE_PREC -c $< -o $(@F)
+cblas_dbf16tod.$(SUFFIX)  cblas_dbf16tod.$(PSUFFIX) : bf16to.c
+       $(CC) $(CFLAGS) -DCBLAS -USINGLE_PREC -DDOUBLE_PREC -c $< -o $(@F)
+endif
+
 cblas_sdot.$(SUFFIX) cblas_sdot.$(PSUFFIX) : dot.c
                $(CC) $(CFLAGS) -DCBLAS -c $< -o $(@F)
 
diff --git a/interface/bf16dot.c b/interface/bf16dot.c
new file mode 100644 (file)
index 0000000..33717e3
--- /dev/null
@@ -0,0 +1,52 @@
+#include <stdio.h>
+#include "common.h"
+#ifdef FUNCTION_PROFILE
+#include "functable.h"
+#endif
+
+#ifndef CBLAS
+float NAME(blasint *N, bfloat16 *x, blasint *INCX, bfloat16 *y, blasint *INCY){
+   BLASLONG n    = *N;
+   BLASLONG incx = *INCX;
+   BLASLONG incy = *INCY;
+   float ret;
+   PRINT_DEBUG_NAME;
+
+   if (n <= 0) return 0.;
+
+   IDEBUG_START;
+   FUNCTION_PROFILE_START();
+
+   if (incx < 0) x -= (n - 1) * incx;
+   if (incy < 0) y -= (n - 1) * incy;
+   ret = BF16_DOT_K(n, x, incx, y, incy);
+
+   FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
+   IDEBUG_END;
+
+   return ret;
+ }
+
+#else
+
+float CNAME(blasint n, bfloat16 *x, blasint incx, bfloat16 *y, blasint incy){
+
+  float ret;
+  PRINT_DEBUG_CNAME;
+
+  if (n <= 0) return 0.;
+
+  IDEBUG_START;
+  FUNCTION_PROFILE_START();
+
+  if (incx < 0) x -= (n - 1) * incx;
+  if (incy < 0) y -= (n - 1) * incy;
+  ret = BF16_DOT_K(n, x, incx, y, incy);
+
+  FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
+  IDEBUG_END;
+
+  return ret;
+}
+
+#endif
diff --git a/interface/bf16to.c b/interface/bf16to.c
new file mode 100644 (file)
index 0000000..036c0b1
--- /dev/null
@@ -0,0 +1,62 @@
+#include <stdio.h>
+#include "common.h"
+#ifdef FUNCTION_PROFILE
+#include "functable.h"
+#endif
+
+#if defined(DOUBLE_PREC)
+#define FLOAT_TYPE double
+#elif defined(SINGLE_PREC)
+#define FLOAT_TYPE float
+#else
+#endif
+
+#ifndef CBLAS
+void NAME(blasint *N, bfloat16 *in, blasint *INC_IN, FLOAT_TYPE *out, blasint *INC_OUT){
+  BLASLONG n    = *N;
+  BLASLONG inc_in = *INC_IN;
+  BLASLONG inc_out = *INC_OUT;
+
+  PRINT_DEBUG_NAME;
+
+  if (n <= 0) return;
+
+  IDEBUG_START;
+  FUNCTION_PROFILE_START();
+
+  if (inc_in < 0)   in -= (n - 1) * inc_in;
+  if (inc_out < 0) out -= (n - 1) * inc_out;
+
+#if defined(DOUBLE_PREC)
+  D_BF16_TO_K(n, in, inc_in, out, inc_out);
+#elif defined(SINGLE_PREC)
+  S_BF16_TO_K(n, in, inc_in, out, inc_out);
+#else
+#endif
+
+  FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
+  IDEBUG_END;
+}
+#else
+void CNAME(blasint n, bfloat16 * in, blasint inc_in, FLOAT_TYPE * out, blasint inc_out){
+  PRINT_DEBUG_CNAME;
+
+  if (n <= 0) return;
+
+  IDEBUG_START;
+  FUNCTION_PROFILE_START();
+
+  if (inc_in < 0)   in -= (n - 1) * inc_in;
+  if (inc_out < 0) out -= (n - 1) * inc_out;
+
+#if defined(DOUBLE_PREC)
+  D_BF16_TO_K(n, in, inc_in, out, inc_out);
+#elif defined(SINGLE_PREC)
+  S_BF16_TO_K(n, in, inc_in, out, inc_out);
+#else
+#endif
+
+  FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
+  IDEBUG_END;
+}
+#endif
diff --git a/interface/tobf16.c b/interface/tobf16.c
new file mode 100644 (file)
index 0000000..787d9d6
--- /dev/null
@@ -0,0 +1,61 @@
+#include <stdio.h>
+#include "common.h"
+#ifdef FUNCTION_PROFILE
+#include "functable.h"
+#endif
+
+#if defined(DOUBLE_PREC)
+#define FLOAT_TYPE double
+#elif defined(SINGLE_PREC)
+#define FLOAT_TYPE float
+#else
+#endif
+
+#ifndef CBLAS
+void NAME(blasint *N, FLOAT_TYPE *in, blasint *INC_IN, bfloat16 *out, blasint *INC_OUT){
+   BLASLONG n    = *N;
+   BLASLONG inc_in = *INC_IN;
+   BLASLONG inc_out = *INC_OUT;
+
+   PRINT_DEBUG_NAME;
+
+   if (n <= 0) return;
+
+   IDEBUG_START;
+   FUNCTION_PROFILE_START();
+
+   if (inc_in < 0)   in -= (n - 1) * inc_in;
+   if (inc_out < 0) out -= (n - 1) * inc_out;
+
+#if defined(DOUBLE_PREC)
+   D_TO_BF16_K(n, in, inc_in, out, inc_out);
+#elif defined(SINGLE_PREC)
+   S_TO_BF16_K(n, in, inc_in, out, inc_out);
+#else
+#endif
+
+   FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
+   IDEBUG_END;
+}
+#else
+void CNAME(blasint n, FLOAT_TYPE *in, blasint inc_in, bfloat16 *out, blasint inc_out){
+  PRINT_DEBUG_CNAME;
+
+  if (n <= 0) return;
+
+  IDEBUG_START;
+  FUNCTION_PROFILE_START();
+
+  if (inc_in < 0)   in -= (n - 1) * inc_in;
+  if (inc_out < 0) out -= (n - 1) * inc_out;
+
+#if defined(DOUBLE_PREC)
+  D_TO_BF16_K(n, in, inc_in, out, inc_out);
+#elif defined(SINGLE_PREC)
+  S_TO_BF16_K(n, in, inc_in, out, inc_out);
+#endif
+
+  FUNCTION_PROFILE_END(1, 2 * n, 2 * n);
+  IDEBUG_END;
+}
+#endif
index 9707032..c6576ee 100644 (file)
@@ -262,6 +262,20 @@ ifndef XDOTKERNEL
 XDOTKERNEL = zdot.S
 endif
 
+ifeq ($(BUILD_HALF),1)
+ifndef SHDOTKERNEL
+SHDOTKERNEL = ../x86_64/shdot.c
+endif
+
+ifndef TOBF16KERNEL
+TOBF16KERNEL = ../x86_64/tobf16.c
+endif
+
+ifndef BF16TOKERNEL
+BF16TOKERNEL = ../x86_64/bf16to.c
+endif
+endif
+
 ### NRM2 ###
 
 ifndef SNRM2KERNEL
@@ -516,6 +530,15 @@ XBLASOBJS  += \
        xdotc_k$(TSUFFIX).$(SUFFIX) xdotu_k$(TSUFFIX).$(SUFFIX) xnrm2_k$(TSUFFIX).$(SUFFIX) xqrot_k$(TSUFFIX).$(SUFFIX) \
        xscal_k$(TSUFFIX).$(SUFFIX) xswap_k$(TSUFFIX).$(SUFFIX) xsum_k$(TSUFFIX).$(SUFFIX)
 
+ifeq ($(BUILD_HALF),1)
+SHBLASOBJS     += \
+        shdot_k$(TSUFFIX).$(SUFFIX)
+SHEXTOBJS        += \
+        shstobf16_k$(TSUFFIX).$(SUFFIX) shdtobf16_k$(TSUFFIX).$(SUFFIX)
+SHEXTOBJS        += \
+        sbf16tos_k$(TSUFFIX).$(SUFFIX) dbf16tod_k$(TSUFFIX).$(SUFFIX)
+endif
+
 ### AMAX ###
 
 
@@ -734,6 +757,19 @@ $(KDIR)ddot_k$(TSUFFIX).$(SUFFIX) $(KDIR)ddot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNEL
 $(KDIR)qdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)qdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(QDOTKERNEL)
        $(CC) -c $(CFLAGS) -UCOMPLEX -DXDOUBLE $< -o $@
 
+ifeq ($(BUILD_HALF),1)
+$(KDIR)shdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)shdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHDOTKERNEL)
+       $(CC) -c $(CFLAGS) -UCOMPLEX $< -o $@
+$(KDIR)shstobf16_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(TOBF16KERNEL)
+       $(CC) -c $(CFLAGS) -UDOUBLE -DSINGLE $< -o $@
+$(KDIR)shdtobf16_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(TOBF16KERNEL)
+       $(CC) -c $(CFLAGS) -DDOUBLE -USINGLE $< -o $@
+$(KDIR)sbf16tos_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BF16TOKERNEL)
+       $(CC) -c $(CFLAGS) -UDOUBLE -DSINGLE $< -o $@
+$(KDIR)dbf16tod_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(BF16TOKERNEL)
+       $(CC) -c $(CFLAGS) -DDOUBLE -USINGLE $< -o $@
+endif
+
 $(KDIR)sdot_k$(TSUFFIX).$(SUFFIX) $(KDIR)sdot_k$(TPSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SDOTKERNEL)
        $(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE $< -o $@
 
index 582a1dc..c435203 100644 (file)
@@ -62,9 +62,11 @@ gotoblas_t TABLE_NAME = {
  MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N),
 #endif
 
+  shstobf16_kTS, shdtobf16_kTS, sbf16tos_kTS, dbf16tod_kTS,
+
   samax_kTS,  samin_kTS,  smax_kTS,  smin_kTS,
   isamax_kTS, isamin_kTS, ismax_kTS, ismin_kTS,
-  snrm2_kTS,  sasum_kTS, ssum_kTS, scopy_kTS, sdot_kTS,
+  snrm2_kTS,  sasum_kTS, ssum_kTS, scopy_kTS, shdot_kTS,
   dsdot_kTS,
   srot_kTS,   saxpy_kTS,  sscal_kTS, sswap_kTS,
   sgemv_nTS,  sgemv_tTS, sger_kTS,
index 4874711..4a2e13b 100644 (file)
@@ -146,6 +146,18 @@ ifndef XDOTKERNEL
 XDOTKERNEL = zdot.S
 endif
 
+ifndef SHDOTKERNEL
+SHDOTKERNEL = shdot.c
+endif
+
+ifndef TOBF16KERNEL
+TOBF16KERNEL = tobf16.c
+endif
+
+ifndef BF16TOKERNEL
+BF16TOKERNEL = bf16to.c
+endif
+
 ifndef ISAMAXKERNEL
 ISAMAXKERNEL = iamax_sse.S
 endif
diff --git a/kernel/x86_64/bf16to.c b/kernel/x86_64/bf16to.c
new file mode 100644 (file)
index 0000000..fc6b5a5
--- /dev/null
@@ -0,0 +1,114 @@
+/***************************************************************************
+Copyright (c) 2014, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#include <stddef.h>
+#include "common.h"
+
+#if defined(DOUBLE)
+#define FLOAT_TYPE double
+#elif defined(SINGLE)
+#define FLOAT_TYPE float
+#else
+#endif
+
+/* Notes for algorithm:
+ * - Input denormal treated as zero
+ * - Force to be QNAN
+ */
+static void bf16to_kernel_1(BLASLONG n, const bfloat16 * in, BLASLONG inc_in, FLOAT_TYPE * out, BLASLONG inc_out)
+{
+    BLASLONG register index_in  = 0;
+    BLASLONG register index_out = 0;
+    BLASLONG register index     = 0;
+    uint16_t * tmp = NULL;
+#if defined(DOUBLE)
+    float             float_out = 0.0;
+#endif
+
+    while(index<n) {
+#if defined(DOUBLE)
+        float_out = 0.0;
+        tmp = (uint16_t*)(&float_out);
+#else
+        *(out+index_out) = 0;
+        tmp = (uint16_t*)(out+index_out);
+#endif
+
+        switch((*(in+index_in)) & 0xff80u) {
+            case (0x0000u):   /* Type 1: Positive denormal */
+                tmp[1] = 0x0000u;
+                tmp[0] = 0x0000u;
+                break;
+            case (0x8000u):   /* Type 2: Negative denormal */
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+                tmp[1] = 0x8000u;
+                tmp[0] = 0x0000u;
+#else
+                tmp[1] = 0x0000u;
+                tmp[0] = 0x8000u;
+#endif
+                break;
+            case (0x7f80u):   /* Type 3: Positive infinity or NAN */
+            case (0xff80u):   /* Type 4: Negative infinity or NAN */
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+                tmp[1] = *(in+index_in);
+#else
+                tmp[0] = *(in+index_in);
+#endif
+                /* Specific for NAN */
+                if (((*(in+index_in)) & 0x007fu) != 0) {
+                    /* Force to be QNAN */
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+                    tmp[1] |= 0x0040u;
+#else
+                    tmp[0] |= 0x0040u;
+#endif
+                }
+                break;
+            default:              /* Type 5: Normal case */
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+                tmp[1] = *(in+index_in);
+#else
+                tmp[0] = *(in+index_in);
+#endif
+                break;
+        }
+#if defined(DOUBLE)
+       *(out+index_out) = (double)float_out;
+#endif
+        index_in  += inc_in;
+        index_out += inc_out;
+        index++;
+    }
+}
+
+void CNAME(BLASLONG n, bfloat16 * in, BLASLONG inc_in, FLOAT_TYPE * out, BLASLONG inc_out)
+{
+    if (n <= 0)  return;
+
+    bf16to_kernel_1(n, in, inc_in, out, inc_out);
+}
diff --git a/kernel/x86_64/dtobf16_microk_cooperlake.c b/kernel/x86_64/dtobf16_microk_cooperlake.c
new file mode 100644 (file)
index 0000000..9b8ac47
--- /dev/null
@@ -0,0 +1,104 @@
+/***************************************************************************
+Copyright (c) 2014, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+/* need a new enough GCC for avx512 support */
+#if (( defined(__GNUC__)  && __GNUC__   >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9))
+
+#define HAVE_TOBF16_ACCL_KERNEL 1
+#include "common.h"
+#include <immintrin.h>
+
+static void tobf16_accl_kernel(BLASLONG n, const double * in, bfloat16 * out)
+{
+    /* Get the 64-bytes unaligned header number targeting for avx512
+     * processing (Assume input float array is natural aligned) */
+    int align_header = ((64 - ((uintptr_t)in & (uintptr_t)0x3f)) >> 3) & 0x7;
+
+    if (n < align_header) {align_header = n;}
+
+    if (align_header != 0) {
+        unsigned char align_mask8 = (((unsigned char)0xff) >> (8-align_header));
+        __m512d a = _mm512_maskz_loadu_pd(*((__mmask8*) &align_mask8), &in[0]);
+        _mm_mask_storeu_epi16(&out[0], *((__mmask8*) &align_mask8), (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(a)));
+    }
+
+    if (n == align_header) {
+        return;
+    } else {
+        n -= align_header;
+        in += align_header;
+        out += align_header;
+    }
+
+    int tail_index_8   = n&(~7);
+    int tail_index_32  = n&(~31);
+    int tail_index_128 = n&(~127);
+    unsigned char tail_mask8 = (((unsigned char) 0xff) >> (8 -(n&7)));
+
+    /* Processing the main chunk with 128-elements per round */
+    for (int i = 0; i < tail_index_128; i += 128) {
+       // Fold 1
+        __m512 data1_512_low  = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+ 0]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+ 8])), 1);
+        __m512 data1_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+16]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+24])), 1);
+        _mm512_storeu_si512(&out[i+ 0], (__m512i) _mm512_cvtne2ps_pbh(data1_512_high, data1_512_low));
+
+       // Fold 2
+        __m512 data2_512_low  = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+32]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+40])), 1);
+        __m512 data2_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+48]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+56])), 1);
+        _mm512_storeu_si512(&out[i+32], (__m512i) _mm512_cvtne2ps_pbh(data2_512_high, data2_512_low));
+
+       // Fold 3
+        __m512 data3_512_low  = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+64]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+72])), 1);
+        __m512 data3_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+80]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+88])), 1);
+        _mm512_storeu_si512(&out[i+64], (__m512i) _mm512_cvtne2ps_pbh(data3_512_high, data3_512_low));
+
+       // Fold 4
+        __m512 data4_512_low  = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+96]))),  _mm512_cvtpd_ps(_mm512_load_pd(&in[i+104])), 1);
+        __m512 data4_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[i+112]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[i+120])), 1);
+        _mm512_storeu_si512(&out[i+96], (__m512i) _mm512_cvtne2ps_pbh(data4_512_high, data4_512_low));
+    }
+
+    /* Processing the remaining <128 chunk with 32-elements per round */
+    for (int j = tail_index_128; j < tail_index_32; j += 32) {
+        __m512 data1_512_low  = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[j+ 0]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[j+ 8])), 1);
+        __m512 data1_512_high = _mm512_insertf32x8(_mm512_castps256_ps512(_mm512_cvtpd_ps(_mm512_load_pd(&in[j+16]))), _mm512_cvtpd_ps(_mm512_load_pd(&in[j+24])), 1);
+        _mm512_storeu_si512(&out[j], (__m512i) _mm512_cvtne2ps_pbh(data1_512_high, data1_512_low));
+    }
+
+    /* Processing the remaining <32 chunk with 8-elements per round */
+    for (int j = tail_index_32; j < tail_index_8; j += 8) {
+        _mm_storeu_si128((__m128i *)&out[j], (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(_mm512_load_pd(&in[j]))));
+    }
+
+    /* Processing the remaining <8 chunk with masked processing */
+    if ((n&7) > 0) {
+        __m512d data_512 = _mm512_maskz_load_pd(*((__mmask8*) &tail_mask8), &in[tail_index_8]);
+        _mm_mask_storeu_epi16(&out[tail_index_8], *((__mmask8*) &tail_mask8), (__m128i) _mm256_cvtneps_pbh(_mm512_cvtpd_ps(data_512)));
+    }
+}
+
+#endif
diff --git a/kernel/x86_64/shdot.c b/kernel/x86_64/shdot.c
new file mode 100644 (file)
index 0000000..5073fda
--- /dev/null
@@ -0,0 +1,115 @@
+/***************************************************************************
+Copyright (c) 2014, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#include "common.h"
+
+#if defined(COOPERLAKE)
+#include "shdot_microk_cooperlake.c"
+#endif
+
+static float shdot_compute(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y)
+{
+    float d = 0.0;
+
+#ifdef HAVE_SHDOT_ACCL_KERNEL
+    if ((inc_x == 1) && (inc_y == 1)) {
+        return shdot_accl_kernel(n, x, y);
+    }
+#endif
+
+    float * x_fp32 = malloc(sizeof(float)*n);
+    float * y_fp32 = malloc(sizeof(float)*n);
+
+    SBF16TOS_K(n, x, inc_x, x_fp32, 1);
+    SBF16TOS_K(n, y, inc_y, y_fp32, 1);
+
+    d = SDOTU_K(n, x_fp32, 1, y_fp32, 1);
+
+    free(x_fp32);
+    free(y_fp32);
+
+    return d;
+}
+
+#if defined(SMP)
+static int shdot_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, bfloat16 dummy2,
+                           bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y,
+                           float *result, BLASLONG dummy3)
+{
+    *(float *)result = shdot_compute(n, x, inc_x, y, inc_y);
+    return 0;
+}
+
+extern int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha,
+                            void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc,
+                           int (*function)(), int nthreads);
+#endif
+
+float CNAME(BLASLONG n, bfloat16 *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y)
+{
+    float dot_result = 0.0;
+
+    if (n <= 0)  return 0.0;
+
+#if defined(SMP)
+    int nthreads;
+    int thread_thres = 40960;
+    bfloat16 dummy_alpha;
+#endif
+
+#if defined(SMP)
+    if (inc_x == 0 || inc_y == 0 || n <= thread_thres)
+        nthreads = 1;
+    else
+        nthreads = num_cpu_avail(1);
+
+    int best_threads = (int) (n/(float)thread_thres + 0.5);
+
+    if (best_threads < nthreads) {
+        nthreads = best_threads;
+    }
+
+    if (nthreads <= 1) {
+        dot_result = shdot_compute(n, x, inc_x, y, inc_y);
+    } else {
+        char thread_result[MAX_CPU_NUMBER * sizeof(double) * 2];
+        int mode = BLAS_BFLOAT16 | BLAS_REAL;
+        blas_level1_thread_with_return_value(mode, n, 0, 0, &dummy_alpha,
+                                             x, inc_x, y, inc_y, thread_result, 0,
+                                             (void *)shdot_thread_func, nthreads);
+        float * ptr = (float *)thread_result;
+        for (int i = 0; i < nthreads; i++) {
+            dot_result += (*ptr);
+            ptr = (float *)(((char *)ptr) + sizeof(double) * 2);
+        }
+    }
+#else
+    dot_result = shdot_compute(n, x, inc_x, y, inc_y);
+#endif
+
+    return dot_result;
+}
diff --git a/kernel/x86_64/shdot_microk_cooperlake.c b/kernel/x86_64/shdot_microk_cooperlake.c
new file mode 100644 (file)
index 0000000..e645296
--- /dev/null
@@ -0,0 +1,159 @@
+/***************************************************************************
+Copyright (c) 2014, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+/* need a new enough GCC for avx512 support */
+#if (( defined(__GNUC__)  && __GNUC__   >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9))
+
+#define HAVE_SHDOT_ACCL_KERNEL 1
+#include "common.h"
+#include <immintrin.h>
+
+static float shdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
+{
+    __m128 accum128   = _mm_setzero_ps();
+    if (n> 127) { /* n range from 128 to inf. */
+        long tail_index_32  = n&(~31);
+        long tail_index_128 = n&(~127);
+        unsigned int tail_mask_uint = (((unsigned int)0xffffffff) >> (32-(n&31)));
+        __mmask32 tail_mask = *((__mmask32*) &tail_mask_uint);
+
+        __m512 accum512_0 = _mm512_setzero_ps();
+        __m512 accum512_1 = _mm512_setzero_ps();
+        __m512 accum512_2 = _mm512_setzero_ps();
+        __m512 accum512_3 = _mm512_setzero_ps();
+
+        /* Processing the main chunk with 128-elements per round */
+        for (long i = 0; i < tail_index_128; i += 128) {
+            accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) _mm512_loadu_si512(&x[i+ 0]), (__m512bh) _mm512_loadu_si512(&y[i+ 0]));
+            accum512_1 = _mm512_dpbf16_ps(accum512_1, (__m512bh) _mm512_loadu_si512(&x[i+32]), (__m512bh) _mm512_loadu_si512(&y[i+32]));
+            accum512_2 = _mm512_dpbf16_ps(accum512_2, (__m512bh) _mm512_loadu_si512(&x[i+64]), (__m512bh) _mm512_loadu_si512(&y[i+64]));
+            accum512_3 = _mm512_dpbf16_ps(accum512_3, (__m512bh) _mm512_loadu_si512(&x[i+96]), (__m512bh) _mm512_loadu_si512(&y[i+96]));
+        }
+
+        /* Processing the remaining <128 chunk with 32-elements per round */
+        for (long j = tail_index_128; j < tail_index_32; j += 32) {
+            accum512_0 = _mm512_dpbf16_ps(accum512_0, (__m512bh) _mm512_loadu_si512(&x[j]), (__m512bh) _mm512_loadu_si512(&y[j]));
+        }
+
+        /* Processing the remaining <32 chunk with masked 32-elements processing */
+        if ((n&31) != 0) {
+            accum512_2 = _mm512_dpbf16_ps(accum512_2,
+                                          (__m512bh) _mm512_maskz_loadu_epi16(tail_mask, &x[tail_index_32]),
+                                          (__m512bh) _mm512_maskz_loadu_epi16(tail_mask, &y[tail_index_32]));
+        }
+
+        /* Accumulate the 4 registers into 1 register */
+        accum512_0 = _mm512_add_ps(accum512_0, accum512_1);
+        accum512_2 = _mm512_add_ps(accum512_2, accum512_3);
+        accum512_0 = _mm512_add_ps(accum512_0, accum512_2);
+
+        __m256 accum256 = _mm256_add_ps(_mm512_castps512_ps256(accum512_0), _mm512_extractf32x8_ps(accum512_0, 1));
+        accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));
+    } else if (n > 31) { /* n range from 32 to 127 */
+        /* Processing <128 chunk with 32-elements per round */
+        __m256 accum256   = _mm256_setzero_ps();
+        __m256 accum256_1 = _mm256_setzero_ps();
+        int tail_index_32  = n&(~31);
+        for (int j = 0; j < tail_index_32; j += 32) {
+            accum256   = _mm256_dpbf16_ps(accum256,   (__m256bh) _mm256_loadu_si256(&x[j+ 0]), (__m256bh) _mm256_loadu_si256(&y[j+ 0]));
+            accum256_1 = _mm256_dpbf16_ps(accum256_1, (__m256bh) _mm256_loadu_si256(&x[j+16]), (__m256bh) _mm256_loadu_si256(&y[j+16]));
+        }
+        accum256 = _mm256_add_ps(accum256, accum256_1);
+
+        /* Processing the remaining <32 chunk with 16-elements processing */
+        if ((n&16) != 0) {
+            accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[tail_index_32]), (__m256bh) _mm256_loadu_si256(&y[tail_index_32]));
+        }
+        accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));
+
+        /* Processing the remaining <16 chunk with 8-elements processing */
+        if ((n&8) != 0) {
+            int tail_index_16  = n&(~15);
+            accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16]));
+        }
+
+        /* Processing the remaining <8 chunk with masked 8-elements processing */
+        if ((n&7) != 0) {
+            unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7)));
+            __mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
+            int tail_index_8   = n&(~7);
+            accum128 = _mm_dpbf16_ps(accum128,
+                                     (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]),
+                                     (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8]));
+        }
+    } else if (n > 15) { /* n range from 16 to 31 */
+        /* Processing <32 chunk with 16-elements processing */
+        __m256 accum256   = _mm256_setzero_ps();
+        accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[0]), (__m256bh) _mm256_loadu_si256(&y[0]));
+        accum128 += _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1));
+
+        /* Processing the remaining <16 chunk with 8-elements processing */
+        if ((n&8) != 0) {
+            int tail_index_16  = n&(~15);
+            accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16]));
+        }
+
+        /* Processing the remaining <8 chunk with masked 8-elements processing */
+        if ((n&7) != 0) {
+            unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7)));
+            __mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
+            int tail_index_8   = n&(~7);
+            accum128 = _mm_dpbf16_ps(accum128,
+                                     (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]),
+                                     (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8]));
+        }
+    } else if (n > 7) { /* n range from 8 to 15 */
+        /* Processing <16 chunk with 8-elements processing */
+        accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[0]), (__m128bh) _mm_loadu_si128(&y[0]));
+
+        /* Processing the remaining <8 chunk with masked 8-elements processing */
+        if ((n&7) != 0) {
+            unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7)));
+            __mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
+            int tail_index_8   = n&(~7);
+            accum128 = _mm_dpbf16_ps(accum128,
+                                     (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[tail_index_8]),
+                                     (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[tail_index_8]));
+        }
+    } else { /* n range from 1 to 7 */
+        unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(n&7)));
+        __mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
+        accum128 = _mm_dpbf16_ps(accum128,
+                                 (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &x[0]),
+                                 (__m128bh) _mm_maskz_loadu_epi16(tail_mask, &y[0]));
+    }
+
+    /* Add up the 4 elements into lowest entry */
+    __m128 accum128_1 = _mm_shuffle_ps(accum128, accum128, 14);
+    accum128 = _mm_add_ps(accum128, accum128_1);
+    accum128_1 = _mm_shuffle_ps(accum128, accum128, 1);
+    accum128 = _mm_add_ps(accum128, accum128_1);
+
+    return accum128[0];
+}
+
+#endif
diff --git a/kernel/x86_64/stobf16_microk_cooperlake.c b/kernel/x86_64/stobf16_microk_cooperlake.c
new file mode 100644 (file)
index 0000000..2756a69
--- /dev/null
@@ -0,0 +1,86 @@
+/***************************************************************************
+Copyright (c) 2014, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+/* need a new enough GCC for avx512 support */
+#if (( defined(__GNUC__)  && __GNUC__   >= 10 && defined(__AVX512BF16__)) || (defined(__clang__) && __clang_major__ >= 9))
+
+#define HAVE_TOBF16_ACCL_KERNEL 1
+#include "common.h"
+#include <immintrin.h>
+
+static void tobf16_accl_kernel(BLASLONG n, const float * in, bfloat16 * out)
+{
+    /* Get the 64-bytes unaligned header number targeting for avx512
+     * processing (Assume input float array is natural aligned) */
+    int align_header = ((64 - ((uintptr_t)in & (uintptr_t)0x3f)) >> 2) & 0xf;
+
+    if (n < align_header) {align_header = n;}
+
+    if (align_header != 0) {
+        uint16_t align_mask16 = (((uint16_t)0xffff) >> (16-align_header));
+        __m512 a = _mm512_maskz_loadu_ps(*((__mmask16*) &align_mask16), &in[0]);
+        _mm256_mask_storeu_epi16(&out[0], *((__mmask16*) &align_mask16), (__m256i) _mm512_cvtneps_pbh(a));
+    }
+
+    if (n == align_header) {
+        return;
+    } else {
+        n -= align_header;
+        in += align_header;
+        out += align_header;
+    }
+
+    int tail_index_32  = n&(~31);
+    int tail_index_128 = n&(~127);
+    uint32_t tail_mask32 = (((uint32_t) 0xffffffff) >> (32-(n&31)));
+    uint16_t tail_mask16 = (((uint16_t) 0xffff)     >> (16-(n&15)));
+
+    /* Processing the main chunk with 128-elements per round */
+    for (int i = 0; i < tail_index_128; i += 128) {
+        _mm512_storeu_si512(&out[i+ 0], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 16]), _mm512_load_ps(&in[i+ 0])));
+        _mm512_storeu_si512(&out[i+32], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 48]), _mm512_load_ps(&in[i+32])));
+        _mm512_storeu_si512(&out[i+64], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+ 80]), _mm512_load_ps(&in[i+64])));
+        _mm512_storeu_si512(&out[i+96], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[i+112]), _mm512_load_ps(&in[i+96])));
+    }
+
+    /* Processing the remaining <128 chunk with 32-elements per round */
+    for (int j = tail_index_128; j < tail_index_32; j += 32) {
+        _mm512_storeu_si512(&out[j], (__m512i) _mm512_cvtne2ps_pbh(_mm512_load_ps(&in[j+ 16]), _mm512_load_ps(&in[j])));
+    }
+
+    /* Processing the remaining <32 chunk with masked processing */
+    if ((n&31) > 15) {
+        __m512 b = _mm512_load_ps(&in[tail_index_32]);
+        __m512 a = _mm512_maskz_load_ps(*((__mmask16*) &tail_mask16), &in[tail_index_32+16]);
+        _mm512_mask_storeu_epi16(&out[tail_index_32], *((__mmask32*) &tail_mask32), (__m512i) _mm512_cvtne2ps_pbh(a, b));
+    } else if ((n&31) > 0) {
+        __m512 a = _mm512_maskz_load_ps(*((__mmask16*) &tail_mask16), &in[tail_index_32]);
+        _mm256_mask_storeu_epi16(&out[tail_index_32], *((__mmask16*) &tail_mask16), (__m256i) _mm512_cvtneps_pbh(a));
+    }
+}
+
+#endif
diff --git a/kernel/x86_64/tobf16.c b/kernel/x86_64/tobf16.c
new file mode 100644 (file)
index 0000000..3d17966
--- /dev/null
@@ -0,0 +1,170 @@
+/***************************************************************************
+Copyright (c) 2014, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#include <stddef.h>
+#include "common.h"
+
+#if defined(DOUBLE)
+#define FLOAT_TYPE double
+#elif defined(SINGLE)
+#define FLOAT_TYPE float
+#else
+#endif
+
+#if defined(COOPERLAKE)
+#if defined(DOUBLE)
+#include "dtobf16_microk_cooperlake.c"
+#elif defined(SINGLE)
+#include "stobf16_microk_cooperlake.c"
+#endif
+#endif
+
+/* Notes for algorithm:
+ * - Round to Nearest Even used generally
+ * - QNAN for NAN case
+ * - Input denormals are treated as zero
+ */
+static void tobf16_generic_kernel(BLASLONG n, const FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out)
+{
+    BLASLONG register index_in  = 0;
+    BLASLONG register index_out = 0;
+    BLASLONG register index     = 0;
+    float             float_in  = 0.0;
+    uint32_t *        uint32_in = (uint32_t *)(&float_in);
+    uint16_t *        uint16_in = (uint16_t *)(&float_in);
+
+    while(index<n) {
+#if defined(DOUBLE)
+        float_in = (float)(*(in+index_in));
+#else
+        float_in = *(in+index_in);
+#endif
+
+        switch((*uint32_in) & 0xff800000u) {
+            case (0x00000000u):   /* Type 1: Positive denormal */
+                *(out+index_out) = 0x0000u;
+                break;
+            case (0x80000000u):   /* Type 2: Negative denormal */
+                *(out+index_out) = 0x8000u;
+                break;
+            case (0x7f800000u):   /* Type 3: Positive infinity or NAN */
+            case (0xff800000u):   /* Type 4: Negative infinity or NAN */
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+                *(out+index_out) = uint16_in[1];
+#else
+                *(out+index_out) = uint16_in[0];
+#endif
+                /* Specific for NAN */
+                if (((*uint32_in) & 0x007fffffu) != 0) {
+                    /* Force to be QNAN */
+                    *(out+index_out) |= 0x0040u;
+                }
+                break;
+            default:              /* Type 5: Normal case */
+                (*uint32_in) += ((((*uint32_in) >> 16) & 0x1u) + 0x7fffu);
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+                *(out+index_out) = uint16_in[1];
+#else
+                *(out+index_out) = uint16_in[0];
+#endif
+                break;
+        }
+
+        index_in  += inc_in;
+        index_out += inc_out;
+        index++;
+    }
+}
+
+#ifndef HAVE_TOBF16_ACCL_KERNEL
+static void tobf16_accl_kernel(BLASLONG n, const FLOAT_TYPE * in, bfloat16 * out)
+{
+    tobf16_generic_kernel(n, in, 1, out, 1);
+}
+#endif
+
+static void tobf16_compute(BLASLONG n, FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out)
+{
+    if ((inc_in == 1) && (inc_out == 1)) {
+        tobf16_accl_kernel(n, in, out);
+    } else {
+        tobf16_generic_kernel(n, in, inc_in, out, inc_out);
+    }
+}
+
+#if defined(SMP)
+static int tobf16_thread_func(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT_TYPE dummy2,
+                            FLOAT_TYPE *x, BLASLONG inc_x, bfloat16 *y, BLASLONG inc_y,
+                            FLOAT_TYPE *dummy3, BLASLONG dummy4)
+{
+        tobf16_compute(n, x, inc_x, y, inc_y);
+        return 0;
+}
+
+extern int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha,
+                              void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc,
+                              int (*function)(), int nthreads);
+#endif
+
+void CNAME(BLASLONG n, FLOAT_TYPE * in, BLASLONG inc_in, bfloat16 * out, BLASLONG inc_out)
+{
+    if (n <= 0)  return;
+
+#if defined(SMP)
+    int nthreads;
+    FLOAT_TYPE dummy_alpha;
+    FLOAT_TYPE dummy_c;
+#endif
+
+#if defined(SMP)
+    if (inc_in == 0 || inc_out == 0 || n <= 100000) {
+        nthreads = 1;
+    } else {
+        if (n/100000 < 100) {
+            nthreads = 4;
+        } else {
+            nthreads = 16;
+        }
+    }
+
+    if (nthreads == 1) {
+        tobf16_compute(n, in, inc_in, out, inc_out);
+    } else {
+#if defined(DOUBLE)
+        int mode = BLAS_REAL | BLAS_DTOBF16;
+#elif defined(SINGLE)
+        int mode = BLAS_REAL | BLAS_STOBF16;
+#endif
+        blas_level1_thread(mode, n, 0, 0, &dummy_alpha,
+                           in, inc_in, out, inc_out, &dummy_c, 0,
+                           (void *)tobf16_thread_func, nthreads);
+    }
+#else
+    tobf16_compute(n, in, inc_in, out, inc_out);
+#endif
+
+}
index 9955e5c..858b8c5 100644 (file)
@@ -35,7 +35,8 @@ typedef unsigned long BLASULONG;
 #endif
 
 #ifndef BFLOAT16
-typedef unsigned short bfloat16;
+#include <stdint.h>
+typedef uint16_t bfloat16;
 #endif
 
 #ifdef OPENBLAS_USE64BITINT