RFC : Add half precision gemm for bfloat16 in OpenBLAS
authorRajalakshmi Srinivasaraghavan <raji@linux.ibm.com>
Tue, 14 Apr 2020 19:55:08 +0000 (14:55 -0500)
committerRajalakshmi Srinivasaraghavan <raji@linux.ibm.com>
Tue, 14 Apr 2020 19:55:08 +0000 (14:55 -0500)
This patch adds support for bfloat16 data type matrix multiplication kernel.
For architectures that don't support bfloat16, it is defined as unsigned short
(2 bytes).  Default unroll sizes can be changed as per architecture as done for
SGEMM and for now 8 and 4 are used for M and N.  Size of ncopy/tcopy can be
changed as per architecture requirement and for now, size 2 is used.

Added shgemm in kernel/power/KERNEL.POWER9 and tested in powerpc64le and
powerpc64.  For reference, added a small test compare_sgemm_shgemm.c to compare
sgemm and shgemm output.

This patch does not cover OpenBLAS test, benchmark and lapack tests for shgemm.
Complex type implementation can be discussed and added once this is approved.

27 files changed:
Makefile.system
Makefile.tail
cmake/prebuild.cmake
cmake/system.cmake
common.h
common_interface.h
common_level3.h
common_macro.h
common_param.h
common_sh.h [new file with mode: 0644]
driver/level3/Makefile
driver/level3/level3.c
driver/level3/level3_thread.c
driver/others/parameter.c
getarch_2nd.c
interface/Makefile
interface/gemm.c
kernel/Makefile.L3
kernel/generic/gemm_beta.c
kernel/generic/gemm_ncopy_2.c
kernel/generic/gemm_tcopy_2.c
kernel/generic/gemmkernel_2x2.c
kernel/power/KERNEL.POWER9
kernel/setparam-ref.c
lapack/getrf/potrf_parallel.c
param.h
test/compare_sgemm_shgemm.c [new file with mode: 0644]

index 2998c0e..0e17698 100644 (file)
@@ -1390,6 +1390,8 @@ export FUNCTION_PROFILE
 export TARGET_CORE
 export NO_AVX512
 
+export SHGEMM_UNROLL_M
+export SHGEMM_UNROLL_N
 export SGEMM_UNROLL_M
 export SGEMM_UNROLL_N
 export DGEMM_UNROLL_M
index 2adede1..3990298 100644 (file)
@@ -1,3 +1,4 @@
+SHBLASOBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
 SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
 DBLASOBJS_P = $(DBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
 QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
@@ -9,8 +10,8 @@ COMMONOBJS_P = $(COMMONOBJS:.$(SUFFIX)=.$(PSUFFIX))
 
 HPLOBJS_P   = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))
 
-BLASOBJS    = $(SBLASOBJS)   $(DBLASOBJS)   $(CBLASOBJS)   $(ZBLASOBJS)
-BLASOBJS_P  = $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P)
+BLASOBJS    = $(SHBLASOBJS)  $(SBLASOBJS)   $(DBLASOBJS)   $(CBLASOBJS)   $(ZBLASOBJS)
+BLASOBJS_P  = $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P)
 
 ifdef EXPRECISION
 BLASOBJS   += $(QBLASOBJS)   $(XBLASOBJS)
@@ -22,6 +23,7 @@ BLASOBJS   += $(QBLASOBJS)   $(XBLASOBJS)
 BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P)
 endif
 
+$(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHALF -UDOUBLE  -UCOMPLEX
 $(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE  -UCOMPLEX
 $(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE  -UCOMPLEX
 $(QBLASOBJS) $(QBLASOBJS_P) : override CFLAGS += -DXDOUBLE -UCOMPLEX
@@ -29,6 +31,7 @@ $(CBLASOBJS) $(CBLASOBJS_P) : override CFLAGS += -UDOUBLE  -DCOMPLEX
 $(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE  -DCOMPLEX
 $(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX
 
+$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
 $(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
 $(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
 $(QBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
index 44e1473..e069609 100644 (file)
@@ -16,6 +16,8 @@
 # HAVE_SSE2
 # HAVE_SSE3
 # MAKE
+# SHGEMM_UNROLL_M
+# SHGEMM_UNROLL_N
 # SGEMM_UNROLL_M
 # SGEMM_UNROLL_N
 # DGEMM_UNROLL_M
@@ -437,6 +439,8 @@ if (DEFINED CORE AND CMAKE_CROSSCOMPILING AND NOT (${HOST_OS} STREQUAL "WINDOWSS
     set(ZGEMM_UNROLL_N 2)
     set(SYMV_P 8)
   endif()
+  set(SHGEMM_UNROLL_M 8)
+  set(SHGEMM_UNROLL_N 4)
 
   # Or should this actually be NUM_CORES?
   if (${NUM_THREADS} GREATER 0)
index ce980a7..65e5aa5 100644 (file)
@@ -530,6 +530,8 @@ endif ()
 #export FUNCTION_PROFILE
 #export TARGET_CORE
 #
+#export SHGEMM_UNROLL_M
+#export SHGEMM_UNROLL_N
 #export SGEMM_UNROLL_M
 #export SGEMM_UNROLL_N
 #export DGEMM_UNROLL_M
index 762968e..1d8bf07 100644 (file)
--- a/common.h
+++ b/common.h
@@ -297,6 +297,17 @@ typedef int blasint;
 #define SIZE   8
 #define  BASE_SHIFT 3
 #define ZBASE_SHIFT 4
+#elif defined(HALF)
+#ifndef BFLOAT16
+typedef unsigned short bfloat16;
+#define HALFCONVERSION 1
+#endif
+#define IFLOAT bfloat16
+#define XFLOAT IFLOAT
+#define FLOAT  float
+#define SIZE   2
+#define BASE_SHIFT 1
+#define ZBASE_SHIFT 2
 #else
 #define FLOAT  float
 #define SIZE    4
@@ -308,6 +319,10 @@ typedef int blasint;
 #define XFLOAT FLOAT
 #endif
 
+#ifndef IFLOAT
+#define IFLOAT FLOAT
+#endif
+
 #ifndef COMPLEX
 #define COMPSIZE  1
 #else
index c350ac8..081043a 100644 (file)
@@ -37,6 +37,9 @@
 /*********************************************************************/
 
 #ifndef ASSEMBLER
+#ifndef BFLOAT16
+typedef unsigned short bfloat16;
+#endif
 
 #ifdef __cplusplus
 extern "C" {
@@ -469,6 +472,8 @@ void BLASFUNC(xhbmv)(char *, blasint *, blasint *, xdouble *, xdouble *, blasint
 
 /* Level 3 routines */
 
+void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
+          bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *);
 void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
           float  *, blasint *, float  *, blasint *, float  *, float  *, blasint *);
 void BLASFUNC(dgemm)(char *, char *, blasint *, blasint *, blasint *, double *,
index 6fa902b..8194ba6 100644 (file)
@@ -37,6 +37,9 @@
 /*********************************************************************/
 
 #ifndef ASSEMBLER
+#ifndef BFLOAT16
+typedef unsigned short bfloat16;
+#endif
 
 #ifdef __CUDACC__
 __global__ void cuda_sgemm_kernel(int, int, int, float *, float *, float *);
@@ -55,6 +58,8 @@ extern void sgemm_kernel_direct(BLASLONG M, BLASLONG N, BLASLONG K,
 extern int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
 
 
+int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
+              bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
 int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
               float  *, BLASLONG, float   *, BLASLONG, float  *, BLASLONG);
 int dgemm_beta(BLASLONG, BLASLONG, BLASLONG, double,
@@ -76,6 +81,10 @@ int xgemm_beta(BLASLONG, BLASLONG, BLASLONG, xdouble *,
               xdouble *, BLASLONG, xdouble  *, BLASLONG, xdouble *, BLASLONG);
 #endif
 
+int shgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
+int shgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
+int shgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
+int shgemm_otcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
 int sgemm_incopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b);
 int sgemm_itcopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b);
 int sgemm_oncopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b);
@@ -499,6 +508,7 @@ int xher2k_kernel_UC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdoubl
 int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
 int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
 
+int shgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float,  bfloat16 *, bfloat16 *, float *, BLASLONG);
 int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float,  float  *, float  *, float  *, BLASLONG);
 int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG);
 
@@ -527,6 +537,11 @@ int cgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, float,  float,  float  *, float
 int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG);
 int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG);
 
+int shgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
+int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
+int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
+int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
+
 int sgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
 int sgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
 int sgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
@@ -619,6 +634,11 @@ int xgemm_cr(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLON
 int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG);
 #endif
 
+int shgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
+int shgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
+int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
+int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
+
 int sgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
 int sgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
 int sgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
index 13bb857..b438c83 100644 (file)
@@ -39,6 +39,7 @@
 #ifndef COMMON_MACRO
 #define COMMON_MACRO
 
+#include "common_sh.h"
 #include "common_s.h"
 #include "common_d.h"
 #include "common_q.h"
 #define IMATCOPY_K_RT          DIMATCOPY_K_RT
 
 #define GEADD_K                 DGEADD_K
+
+#elif defined(HALF)
+
+#define GEMM_BETA               SHGEMM_BETA
+#define        GEMM_KERNEL_N           SHGEMM_KERNEL
+#define        GEMM_KERNEL_L           SHGEMM_KERNEL
+#define        GEMM_KERNEL_R           SHGEMM_KERNEL
+#define        GEMM_KERNEL_B           SHGEMM_KERNEL
+
+#define        GEMM_NN                 SHGEMM_NN
+#define        GEMM_CN                 SHGEMM_TN
+#define        GEMM_TN                 SHGEMM_TN
+#define        GEMM_NC                 SHGEMM_NT
+#define        GEMM_NT                 SHGEMM_NT
+#define        GEMM_CC                 SHGEMM_TT
+#define        GEMM_CT                 SHGEMM_TT
+#define        GEMM_TC                 SHGEMM_TT
+#define        GEMM_TT                 SHGEMM_TT
+#define        GEMM_NR                 SHGEMM_NN
+#define        GEMM_TR                 SHGEMM_TN
+#define        GEMM_CR                 SHGEMM_TN
+#define        GEMM_RN                 SHGEMM_NN
+#define        GEMM_RT                 SHGEMM_NT
+#define        GEMM_RC                 SHGEMM_NT
+#define        GEMM_RR                 SHGEMM_NN
+#define        GEMM_ONCOPY             SHGEMM_ONCOPY
+#define        GEMM_OTCOPY             SHGEMM_OTCOPY
+#define        GEMM_INCOPY             SHGEMM_INCOPY
+#define        GEMM_ITCOPY             SHGEMM_ITCOPY
+
+#define        GEMM_THREAD_NN          SHGEMM_THREAD_NN
+#define        GEMM_THREAD_CN          SHGEMM_THREAD_TN
+#define        GEMM_THREAD_TN          SHGEMM_THREAD_TN
+#define        GEMM_THREAD_NC          SHGEMM_THREAD_NT
+#define        GEMM_THREAD_NT          SHGEMM_THREAD_NT
+#define        GEMM_THREAD_CC          SHGEMM_THREAD_TT
+#define        GEMM_THREAD_CT          SHGEMM_THREAD_TT
+#define        GEMM_THREAD_TC          SHGEMM_THREAD_TT
+#define        GEMM_THREAD_TT          SHGEMM_THREAD_TT
+#define        GEMM_THREAD_NR          SHGEMM_THREAD_NN
+#define        GEMM_THREAD_TR          SHGEMM_THREAD_TN
+#define        GEMM_THREAD_CR          SHGEMM_THREAD_TN
+#define        GEMM_THREAD_RN          SHGEMM_THREAD_NN
+#define        GEMM_THREAD_RT          SHGEMM_THREAD_NT
+#define        GEMM_THREAD_RC          SHGEMM_THREAD_NT
+#define        GEMM_THREAD_RR          SHGEMM_THREAD_NN
+
 #else
 
 #define        AMAX_K                  SAMAX_K
 #if defined(ARCH_X86) || defined(ARCH_X86_64) || defined(ARCH_IA64) || defined(ARCH_MIPS64) || defined(ARCH_ARM64)
 extern BLASLONG gemm_offset_a;
 extern BLASLONG gemm_offset_b;
+extern BLASLONG shgemm_p;
+extern BLASLONG shgemm_q;
+extern BLASLONG shgemm_r;
 extern BLASLONG sgemm_p;
 extern BLASLONG sgemm_q;
 extern BLASLONG sgemm_r;
index 574d5e1..f1cac38 100644 (file)
@@ -84,6 +84,16 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
   int    (*sgemm_kernel   )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
   int    (*sgemm_beta     )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float  *, BLASLONG);
 
+  int shgemm_p, shgemm_q, shgemm_r;
+  int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn;
+  int    (*shgemm_kernel   )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG);
+  int    (*shgemm_beta     )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
+
+  int    (*shgemm_incopy   )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);
+  int    (*shgemm_itcopy   )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);
+  int    (*shgemm_oncopy   )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);
+  int    (*shgemm_otcopy   )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);
+
   int    (*sgemm_incopy   )(BLASLONG, BLASLONG, float *, BLASLONG, float *);
   int    (*sgemm_itcopy   )(BLASLONG, BLASLONG, float *, BLASLONG, float *);
   int    (*sgemm_oncopy   )(BLASLONG, BLASLONG, float *, BLASLONG, float *);
@@ -907,6 +917,13 @@ extern gotoblas_t *gotoblas;
 
 #define HAVE_EX_L2     gotoblas -> exclusive_cache
 
+#define        SHGEMM_P                gotoblas -> shgemm_p
+#define        SHGEMM_Q                gotoblas -> shgemm_q
+#define        SHGEMM_R                gotoblas -> shgemm_r
+#define        SHGEMM_UNROLL_M gotoblas -> shgemm_unroll_m
+#define        SHGEMM_UNROLL_N gotoblas -> shgemm_unroll_n
+#define        SHGEMM_UNROLL_MN        gotoblas -> shgemm_unroll_mn
+
 #define        SGEMM_P         gotoblas -> sgemm_p
 #define        SGEMM_Q         gotoblas -> sgemm_q
 #define        SGEMM_R         gotoblas -> sgemm_r
@@ -984,6 +1001,17 @@ extern gotoblas_t *gotoblas;
 #define HAVE_EX_L2     0
 #endif
 
+#define        SHGEMM_P                SHGEMM_DEFAULT_P
+#define        SHGEMM_Q                SHGEMM_DEFAULT_Q
+#define        SHGEMM_R                SHGEMM_DEFAULT_R
+#define SHGEMM_UNROLL_M        SHGEMM_DEFAULT_UNROLL_M
+#define SHGEMM_UNROLL_N        SHGEMM_DEFAULT_UNROLL_N
+#ifdef  SHGEMM_DEFAULT_UNROLL_MN
+#define SHGEMM_UNROLL_MN       SHGEMM_DEFAULT_UNROLL_MN
+#else
+#define SHGEMM_UNROLL_MN       MAX((SHGEMM_UNROLL_M), (SHGEMM_UNROLL_N))
+#endif
+
 #define        SGEMM_P         SGEMM_DEFAULT_P
 #define        SGEMM_Q         SGEMM_DEFAULT_Q
 #define        SGEMM_R         SGEMM_DEFAULT_R
@@ -1119,6 +1147,18 @@ extern gotoblas_t *gotoblas;
 #define GEMM_DEFAULT_R         DGEMM_DEFAULT_R
 #define GEMM_DEFAULT_UNROLL_M  DGEMM_DEFAULT_UNROLL_M
 #define GEMM_DEFAULT_UNROLL_N  DGEMM_DEFAULT_UNROLL_N
+#elif defined(HALF)
+#define GEMM_P                 SHGEMM_P
+#define GEMM_Q                 SHGEMM_Q
+#define GEMM_R                 SHGEMM_R
+#define GEMM_UNROLL_M          SHGEMM_UNROLL_M
+#define GEMM_UNROLL_N          SHGEMM_UNROLL_N
+#define GEMM_UNROLL_MN         SHGEMM_UNROLL_MN
+#define GEMM_DEFAULT_P         SHGEMM_DEFAULT_P
+#define GEMM_DEFAULT_Q         SHGEMM_DEFAULT_Q
+#define GEMM_DEFAULT_R         SHGEMM_DEFAULT_R
+#define GEMM_DEFAULT_UNROLL_M  SHGEMM_DEFAULT_UNROLL_M
+#define GEMM_DEFAULT_UNROLL_N  SHGEMM_DEFAULT_UNROLL_N
 #else
 #define GEMM_P                 SGEMM_P
 #define GEMM_Q                 SGEMM_Q
@@ -1204,6 +1244,10 @@ extern gotoblas_t *gotoblas;
 #define GEMM_THREAD gemm_thread_n
 #endif
 
+#ifndef SHGEMM_DEFAULT_R
+#define SHGEMM_DEFAULT_R (((BUFFER_SIZE - ((SHGEMM_DEFAULT_P * SHGEMM_DEFAULT_Q *  4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SHGEMM_DEFAULT_Q *  4) - 15) & ~15)
+#endif
+
 #ifndef SGEMM_DEFAULT_R
 #define SGEMM_DEFAULT_R (((BUFFER_SIZE - ((SGEMM_DEFAULT_P * SGEMM_DEFAULT_Q *  4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SGEMM_DEFAULT_Q *  4) - 15) & ~15)
 #endif
diff --git a/common_sh.h b/common_sh.h
new file mode 100644 (file)
index 0000000..8859694
--- /dev/null
@@ -0,0 +1,65 @@
+#ifndef COMMON_H_H
+#define COMMON_H_H
+
+#ifndef DYNAMIC_ARCH
+
+#define        SHGEMM_ONCOPY           shgemm_oncopy
+#define        SHGEMM_OTCOPY           shgemm_otcopy
+
+#if SHGEMM_DEFAULT_UNROLL_M == SHGEMM_DEFAULT_UNROLL_N
+#define        SHGEMM_INCOPY           shgemm_oncopy
+#define        SHGEMM_ITCOPY           shgemm_otcopy
+#else
+#define        SHGEMM_INCOPY           shgemm_incopy
+#define        SHGEMM_ITCOPY           shgemm_itcopy
+#endif
+#define        SHGEMM_BETA             shgemm_beta
+#define SHGEMM_KERNEL            shgemm_kernel
+
+#else
+
+#define        SHGEMM_ONCOPY           gotoblas -> shgemm_oncopy
+#define        SHGEMM_OTCOPY           gotoblas -> shgemm_otcopy
+#define        SHGEMM_INCOPY           gotoblas -> shgemm_incopy
+#define        SHGEMM_ITCOPY           gotoblas -> shgemm_itcopy
+#define        SHGEMM_BETA             gotoblas -> shgemm_beta
+#define        SHGEMM_KERNEL           gotoblas -> shgemm_kernel
+
+#endif
+
+#define        SHGEMM_NN               shgemm_nn
+#define        SHGEMM_CN               shgemm_tn
+#define        SHGEMM_TN               shgemm_tn
+#define        SHGEMM_NC               shgemm_nt
+#define        SHGEMM_NT               shgemm_nt
+#define        SHGEMM_CC               shgemm_tt
+#define        SHGEMM_CT               shgemm_tt
+#define        SHGEMM_TC               shgemm_tt
+#define        SHGEMM_TT               shgemm_tt
+#define        SHGEMM_NR               shgemm_nn
+#define        SHGEMM_TR               shgemm_tn
+#define        SHGEMM_CR               shgemm_tn
+#define        SHGEMM_RN               shgemm_nn
+#define        SHGEMM_RT               shgemm_nt
+#define        SHGEMM_RC               shgemm_nt
+#define        SHGEMM_RR               shgemm_nn
+
+#define        SHGEMM_THREAD_NN                shgemm_thread_nn
+#define        SHGEMM_THREAD_CN                shgemm_thread_tn
+#define        SHGEMM_THREAD_TN                shgemm_thread_tn
+#define        SHGEMM_THREAD_NC                shgemm_thread_nt
+#define        SHGEMM_THREAD_NT                shgemm_thread_nt
+#define        SHGEMM_THREAD_CC                shgemm_thread_tt
+#define        SHGEMM_THREAD_CT                shgemm_thread_tt
+#define        SHGEMM_THREAD_TC                shgemm_thread_tt
+#define        SHGEMM_THREAD_TT                shgemm_thread_tt
+#define        SHGEMM_THREAD_NR                shgemm_thread_nn
+#define        SHGEMM_THREAD_TR                shgemm_thread_tn
+#define        SHGEMM_THREAD_CR                shgemm_thread_tn
+#define        SHGEMM_THREAD_RN                shgemm_thread_nn
+#define        SHGEMM_THREAD_RT                shgemm_thread_nt
+#define        SHGEMM_THREAD_RC                shgemm_thread_nt
+#define        SHGEMM_THREAD_RR                shgemm_thread_nn
+
+#endif
+
index e320092..881b4ee 100644 (file)
@@ -19,6 +19,7 @@ ifeq ($(ARCH), MIPS)
 USE_GEMM3M = 1
 endif
 
+SHBLASOBJS       += shgemm_nn.$(SUFFIX) shgemm_nt.$(SUFFIX) shgemm_tn.$(SUFFIX) shgemm_tt.$(SUFFIX)
 SBLASOBJS      += \
        sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \
        strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \
@@ -204,6 +205,7 @@ COMMONOBJS  += syrk_thread.$(SUFFIX)
 
 ifndef USE_SIMPLE_THREADED_LEVEL3
 
+SHBLASOBJS    += shgemm_thread_nn.$(SUFFIX) shgemm_thread_nt.$(SUFFIX) shgemm_thread_tn.$(SUFFIX) shgemm_thread_tt.$(SUFFIX)
 SBLASOBJS    += sgemm_thread_nn.$(SUFFIX) sgemm_thread_nt.$(SUFFIX) sgemm_thread_tn.$(SUFFIX) sgemm_thread_tt.$(SUFFIX)
 DBLASOBJS    += dgemm_thread_nn.$(SUFFIX) dgemm_thread_nt.$(SUFFIX) dgemm_thread_tn.$(SUFFIX) dgemm_thread_tt.$(SUFFIX)
 QBLASOBJS    += qgemm_thread_nn.$(SUFFIX) qgemm_thread_nt.$(SUFFIX) qgemm_thread_tn.$(SUFFIX) qgemm_thread_tt.$(SUFFIX)
@@ -283,6 +285,18 @@ endif
 
 all ::
 
+shgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
+       $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
+
+shgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h
+       $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
+
+shgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h
+       $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
+
+shgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h
+       $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
+
 sgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
        $(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
 
@@ -478,6 +492,17 @@ gemm_thread_variable.$(SUFFIX) : gemm_thread_variable.c ../../common.h
 beta_thread.$(SUFFIX) : beta_thread.c ../../common.h
        $(CC) -c $(CFLAGS) $< -o $(@F)
 
+shgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
+       $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
+
+shgemm_thread_nt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
+       $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
+
+shgemm_thread_tn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
+       $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
+
+shgemm_thread_tt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
+       $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
 
 sgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
        $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
@@ -2652,6 +2677,18 @@ xtrsm_RCLU.$(SUFFIX) : trsm_R.c
 xtrsm_RCLN.$(SUFFIX) : trsm_R.c
        $(CC) -c $(CFLAGS) -DCOMPLEX -DXDOUBLE -DTRANSA -UUPPER -UUNIT -DCONJ $< -o $(@F)
 
+shgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h
+       $(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
+
+shgemm_nt.$(PSUFFIX) : gemm.c level3.c ../../param.h
+       $(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
+
+shgemm_tn.$(PSUFFIX) : gemm.c level3.c ../../param.h
+       $(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
+
+shgemm_tt.$(PSUFFIX) : gemm.c level3.c ../../param.h
+       $(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
+
 sgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h
        $(CC) $(PFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
 
@@ -2848,6 +2885,18 @@ beta_thread.$(PSUFFIX) : beta_thread.c ../../common.h
        $(CC) -c $(PFLAGS) $< -o $(@F)
 
 
+shgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
+       $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
+
+shgemm_thread_nt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
+       $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
+
+shgemm_thread_tn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
+       $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
+
+shgemm_thread_tt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
+       $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
+
 sgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
        $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
 
index 9aa6728..c6bbb9c 100644 (file)
 #ifndef ICOPY_OPERATION
 #if defined(NN) || defined(NT) || defined(NC) || defined(NR) || \
     defined(RN) || defined(RT) || defined(RC) || defined(RR)
-#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
+#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
 #else
-#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
+#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
 #endif
 #endif
 
 #ifndef OCOPY_OPERATION
 #if defined(NN) || defined(TN) || defined(CN) || defined(RN) || \
     defined(NR) || defined(TR) || defined(CR) || defined(RR)
-#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
+#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
 #else
-#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
+#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
 #endif
 #endif
 
@@ -173,7 +173,8 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
                  XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){
   BLASLONG k, lda, ldb, ldc;
   FLOAT *alpha, *beta;
-  FLOAT *a, *b, *c;
+  IFLOAT *a, *b;
+  FLOAT *c;
   BLASLONG m_from, m_to, n_from, n_to;
 
   BLASLONG ls, is, js;
@@ -198,8 +199,8 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
 
   k = K;
 
-  a = (FLOAT *)A;
-  b = (FLOAT *)B;
+  a = (IFLOAT *)A;
+  b = (IFLOAT *)B;
   c = (FLOAT *)C;
 
   lda = LDA;
index ca0085e..5a8d497 100644 (file)
@@ -117,18 +117,18 @@ typedef struct {
 #ifndef ICOPY_OPERATION
 #if defined(NN) || defined(NT) || defined(NC) || defined(NR) || \
   defined(RN) || defined(RT) || defined(RC) || defined(RR)
-#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
+#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
 #else
-#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
+#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
 #endif
 #endif
 
 #ifndef OCOPY_OPERATION
 #if defined(NN) || defined(TN) || defined(CN) || defined(RN) || \
   defined(NR) || defined(TR) || defined(CR) || defined(RR)
-#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
+#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
 #else
-#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
+#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
 #endif
 #endif
 
@@ -219,15 +219,16 @@ typedef struct {
 #define STOP_RPCC(COUNTER)
 #endif
 
-static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){
+static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){
 
-  FLOAT *buffer[DIVIDE_RATE];
+  IFLOAT *buffer[DIVIDE_RATE];
 
   BLASLONG k, lda, ldb, ldc;
   BLASLONG m_from, m_to, n_from, n_to;
 
   FLOAT *alpha, *beta;
-  FLOAT *a, *b, *c;
+  IFLOAT *a, *b;
+  FLOAT *c;
   job_t *job = (job_t *)args -> common;
 
   BLASLONG nthreads_m;
@@ -255,8 +256,8 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
 
   k = K;
 
-  a = (FLOAT *)A;
-  b = (FLOAT *)B;
+  a = (IFLOAT *)A;
+  b = (IFLOAT *)B;
   c = (FLOAT *)C;
 
   lda = LDA;
@@ -425,7 +426,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
           /* Apply kernel with local region of A and part of other region of B */
          START_RPCC();
          KERNEL_OPERATION(min_i, MIN(range_n[current + 1]  - js,  div_n), min_l, alpha,
-                          sa, (FLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside],
+                          sa, (IFLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside],
                           c, ldc, m_from, js);
           STOP_RPCC(kernel);
 
@@ -469,7 +470,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
           /* Apply kernel with local region of A and part of region of B */
          START_RPCC();
          KERNEL_OPERATION(min_i, MIN(range_n[current + 1] - js, div_n), min_l, alpha,
-                          sa, (FLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside],
+                          sa, (IFLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside],
                           c, ldc, is, js);
           STOP_RPCC(kernel);
           
@@ -532,7 +533,7 @@ static int round_up(int remainder, int width, int multiple)
 
 
 static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
-                      *range_n, FLOAT *sa, FLOAT *sb,
+                      *range_n, IFLOAT *sa, IFLOAT *sb,
                        BLASLONG nthreads_m, BLASLONG nthreads_n) {
 
 #ifndef USE_OPENMP
@@ -728,7 +729,7 @@ EnterCriticalSection((PCRITICAL_SECTION)&level3_lock);
   return 0;
 }
 
-int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){
+int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){
 
   BLASLONG m = args -> m;
   BLASLONG n = args -> n;
index 8bf7da7..b1f3bef 100644 (file)
@@ -62,6 +62,11 @@ BLASLONG gemm_offset_b = DEFAULT_GEMM_OFFSET_B;
 BLASLONG gemm_offset_b = GEMM_OFFSET_B;
 #endif
 
+#if SHGEMM_P == shgemm_p
+BLASLONG shgemm_p = DEFAULT_GEMM_P;
+#else
+BLASLONG shgemm_p = SHGEMM_P;
+#endif
 #if SGEMM_P == sgemm_p
 BLASLONG sgemm_p = DEFAULT_GEMM_P;
 #else
@@ -83,6 +88,11 @@ BLASLONG zgemm_p = DEFAULT_GEMM_P;
 BLASLONG zgemm_p = ZGEMM_P;
 #endif
 
+#if SHGEMM_Q == shgemm_q
+BLASLONG shgemm_q = DEFAULT_GEMM_Q;
+#else
+BLASLONG shgemm_q = SHGEMM_Q;
+#endif
 #if SGEMM_Q == sgemm_q
 BLASLONG sgemm_q = DEFAULT_GEMM_Q;
 #else
@@ -104,6 +114,11 @@ BLASLONG zgemm_q = DEFAULT_GEMM_Q;
 BLASLONG zgemm_q = ZGEMM_Q;
 #endif
 
+#if SHGEMM_R == shgemm_r
+BLASLONG shgemm_r = DEFAULT_GEMM_R;
+#else
+BLASLONG shgemm_r = SHGEMM_R;
+#endif
 #if SGEMM_R == sgemm_r
 BLASLONG sgemm_r = DEFAULT_GEMM_R;
 #else
@@ -597,6 +612,7 @@ void blas_set_parameter(void){
 
   size = BITMASK(cpuid3, 16, 0xff);
 
+  shgemm_p = 192 * (size + 1);
   sgemm_p = 192 * (size + 1);
   dgemm_p =  96 * (size + 1);
   cgemm_p =  96 * (size + 1);
@@ -610,6 +626,7 @@ void blas_set_parameter(void){
   xgemm_p =  16 * (size + 1);
 #endif
 
+  shgemm_r = (((BUFFER_SIZE - ((SHGEMM_P * SHGEMM_Q *  4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SHGEMM_Q *  4)) - 15) & ~15;
   sgemm_r = (((BUFFER_SIZE - ((SGEMM_P * SGEMM_Q *  4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SGEMM_Q *  4)) - 15) & ~15;
   dgemm_r = (((BUFFER_SIZE - ((DGEMM_P * DGEMM_Q *  8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (DGEMM_Q *  8)) - 15) & ~15;
   cgemm_r = (((BUFFER_SIZE - ((CGEMM_P * CGEMM_Q *  8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (CGEMM_Q *  8)) - 15) & ~15;
index cf9c578..a1d0cca 100644 (file)
@@ -9,6 +9,8 @@
 int main(int argc, char **argv) {
 
   if ( (argc <= 1) || ((argc >= 2) && (*argv[1] == '0'))) {
+    printf("SHGEMM_UNROLL_M=%d\n", SHGEMM_DEFAULT_UNROLL_M);
+    printf("SHGEMM_UNROLL_N=%d\n", SHGEMM_DEFAULT_UNROLL_N);
     printf("SGEMM_UNROLL_M=%d\n", SGEMM_DEFAULT_UNROLL_M);
     printf("SGEMM_UNROLL_N=%d\n", SGEMM_DEFAULT_UNROLL_N);
     printf("DGEMM_UNROLL_M=%d\n", DGEMM_DEFAULT_UNROLL_M);
index 3f0dcca..741f6ba 100644 (file)
@@ -46,6 +46,7 @@ SBLAS3OBJS    = \
                somatcopy.$(SUFFIX) simatcopy.$(SUFFIX)\
                sgeadd.$(SUFFIX)
 
+SHBLAS3OBJS    = shgemm.$(SUFFIX)
 
 DBLAS1OBJS    = \
                daxpy.$(SUFFIX) dswap.$(SUFFIX) \
@@ -277,6 +278,8 @@ CSBLAS3OBJS   = \
        cblas_ssyrk.$(SUFFIX) cblas_ssyr2k.$(SUFFIX) cblas_somatcopy.$(SUFFIX)  cblas_simatcopy.$(SUFFIX)\
        cblas_sgeadd.$(SUFFIX)
 
+CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX)
+
 CDBLAS1OBJS   = \
        cblas_idamax.$(SUFFIX) cblas_idamin.$(SUFFIX) cblas_dasum.$(SUFFIX) cblas_daxpy.$(SUFFIX) \
        cblas_dcopy.$(SUFFIX) cblas_ddot.$(SUFFIX) \
@@ -367,6 +370,7 @@ override CFLAGS += -I.
 SBLAS1OBJS   += $(CSBLAS1OBJS)
 SBLAS2OBJS   += $(CSBLAS2OBJS)
 SBLAS3OBJS   += $(CSBLAS3OBJS)
+SHBLAS3OBJS  += $(CSHBLAS3OBJS)
 DBLAS1OBJS   += $(CDBLAS1OBJS)
 DBLAS2OBJS   += $(CDBLAS2OBJS)
 DBLAS3OBJS   += $(CDBLAS3OBJS)
@@ -380,6 +384,7 @@ ZBLAS3OBJS   += $(CZBLAS3OBJS)
 endif
 
 SBLASOBJS    = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS)
+SHBLASOBJS   = $(SHBLAS3OBJS)
 DBLASOBJS    = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS)
 QBLASOBJS    = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS)
 CBLASOBJS    = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS)
@@ -454,7 +459,7 @@ ZBLASOBJS += $(ZLAPACKOBJS)
 
 endif
 
-FUNCOBJS    = $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
+FUNCOBJS    = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
 
 ifdef EXPRECISION
 FUNCOBJS   += $(QBLASOBJS) $(XBLASOBJS)
@@ -488,10 +493,10 @@ level1 : $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $
 level2 : $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS)
        $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
 
-level3 : $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS)
+level3 : $(SHBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS)
        $(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
 
-$(CSBLASOBJS) $(CSBLASOBJS_P) $(CDBLASOBJS) $(CDBLASOBJS_P) $(CQBLASOBJS) $(CQBLASOBJS_P) \
+$(CSHBLASOBJS) $(CSHBLASOBJS_P) $(CSBLASOBJS) $(CSBLASOBJS_P) $(CDBLASOBJS) $(CDBLASOBJS_P) $(CQBLASOBJS) $(CQBLASOBJS_P) \
 $(CCBLASOBJS) $(CCBLASOBJS_P) $(CZBLASOBJS) $(CZBLASOBJS_P) $(CXBLASOBJS) $(CXBLASOBJS_P) : override CFLAGS += -DCBLAS
 
 srot.$(SUFFIX) srot.$(PSUFFIX) : rot.c
@@ -1209,6 +1214,9 @@ zhpr2.$(SUFFIX) zhpr2.$(PSUFFIX) : zhpr2.c
 xhpr2.$(SUFFIX) xhpr2.$(PSUFFIX) : zhpr2.c
        $(CC) -c $(CFLAGS) $< -o $(@F)
 
+shgemm.$(SUFFIX) shgemm.$(PSUFFIX) : gemm.c ../param.h
+       $(CC) -c $(CFLAGS) $< -o $(@F)
+
 sgemm.$(SUFFIX) sgemm.$(PSUFFIX) : gemm.c ../param.h
        $(CC) -c $(CFLAGS) $< -o $(@F)
 
@@ -1770,6 +1778,9 @@ cblas_zhemv.$(SUFFIX) cblas_zhemv.$(PSUFFIX) : zhemv.c
 cblas_sgemm.$(SUFFIX) cblas_sgemm.$(PSUFFIX) : gemm.c ../param.h
        $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
 
+cblas_shgemm.$(SUFFIX) cblas_shgemm.$(PSUFFIX) : gemm.c ../param.h
+       $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
+
 cblas_dgemm.$(SUFFIX) cblas_dgemm.$(PSUFFIX) : gemm.c ../param.h
        $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
 
index 0b18d9a..99388e7 100644 (file)
@@ -77,7 +77,7 @@
 #define GEMM_MULTITHREAD_THRESHOLD 4
 #endif
 
-static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLASLONG) = {
+static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = {
 #ifndef GEMM3M
   GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN,
   GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT,
@@ -108,8 +108,8 @@ static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLA
 void NAME(char *TRANSA, char *TRANSB,
          blasint *M, blasint *N, blasint *K,
          FLOAT *alpha,
-         FLOAT *a, blasint *ldA,
-         FLOAT *b, blasint *ldB,
+         IFLOAT *a, blasint *ldA,
+         IFLOAT *b, blasint *ldB,
          FLOAT *beta,
          FLOAT *c, blasint *ldC){
 
@@ -119,8 +119,8 @@ void NAME(char *TRANSA, char *TRANSB,
   blasint info;
 
   char transA, transB;
-  FLOAT *buffer;
-  FLOAT *sa, *sb;
+  IFLOAT *buffer;
+  IFLOAT *sa, *sb;
 
 #ifdef SMP
   double MNK;
index 6d96abb..aee610e 100644 (file)
@@ -59,6 +59,10 @@ ifeq ($(CORE), Z14)
 USE_TRMM = 1
 endif
 
+SHKERNELOBJS   += \
+       shgemm_kernel$(TSUFFIX).$(SUFFIX) \
+       $(SHGEMMINCOPYOBJ) $(SHGEMMITCOPYOBJ) \
+       $(SHGEMMONCOPYOBJ) $(SHGEMMOTCOPYOBJ)
 
 SKERNELOBJS    += \
        sgemm_kernel$(TSUFFIX).$(SUFFIX) \
@@ -93,6 +97,7 @@ XKERNELOBJS   += \
        $(XGEMMINCOPYOBJ) $(XGEMMITCOPYOBJ) \
        $(XGEMMONCOPYOBJ) $(XGEMMOTCOPYOBJ)
 
+SHBLASOBJS      += $(SHKERNELOBJS)
 SBLASOBJS      += $(SKERNELOBJS)
 DBLASOBJS      += $(DKERNELOBJS)
 QBLASOBJS      += $(QKERNELOBJS)
@@ -100,6 +105,7 @@ CBLASOBJS   += $(CKERNELOBJS)
 ZBLASOBJS      += $(ZKERNELOBJS)
 XBLASOBJS      += $(XKERNELOBJS)
 
+SHBLASOBJS += shgemm_beta$(TSUFFIX).$(SUFFIX)
 SBLASOBJS      += \
        sgemm_beta$(TSUFFIX).$(SUFFIX) \
        strmm_kernel_LN$(TSUFFIX).$(SUFFIX) strmm_kernel_LT$(TSUFFIX).$(SUFFIX) \
@@ -390,6 +396,10 @@ ZBLASOBJS += \
        zgeadd_k$(TSUFFIX).$(SUFFIX) 
 
 
+SHGEMMINCOPYOBJ_P = $(SHGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
+SHGEMMITCOPYOBJ_P = $(SHGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
+SHGEMMONCOPYOBJ_P = $(SHGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
+SHGEMMOTCOPYOBJ_P = $(SHGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
 SGEMMINCOPYOBJ_P = $(SGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
 SGEMMITCOPYOBJ_P = $(SGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
 SGEMMONCOPYOBJ_P = $(SGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
@@ -415,6 +425,9 @@ XGEMMITCOPYOBJ_P = $(XGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
 XGEMMONCOPYOBJ_P = $(XGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
 XGEMMOTCOPYOBJ_P = $(XGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
 
+$(KDIR)shgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA)
+       $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
 $(KDIR)sgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_BETA)
        $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
 
@@ -433,6 +446,36 @@ $(KDIR)zgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_BETA)
 $(KDIR)xgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XGEMM_BETA)
        $(CC) $(CFLAGS) -c -DXDOUBLE -DCOMPLEX $< -o $@
 
+$(KDIR)$(SHGEMMONCOPYOBJ) : $(KERNELDIR)/$(SHGEMMONCOPY)
+       $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
+$(KDIR)$(SHGEMMOTCOPYOBJ) : $(KERNELDIR)/$(SHGEMMOTCOPY)
+ifeq ($(OS), AIX)
+       $(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX $< -o shgemmotcopy.s
+       m4 shgemmotcopy.s > shgemmotcopy_nomacros.s
+       $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemmotcopy_nomacros.s -o $@
+       rm shgemmotcopy.s shgemmotcopy_nomacros.s
+else
+       $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+endif
+
+ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N))
+
+$(KDIR)$(SHGEMMINCOPYOBJ) : $(KERNELDIR)/$(SHGEMMINCOPY)
+       $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
+$(KDIR)$(SHGEMMITCOPYOBJ) : $(KERNELDIR)/$(SHGEMMITCOPY)
+ifeq ($(OS), AIX)
+       $(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX $< -o shgemmitcopy.s
+       m4 shgemmitcopy.s > shgemmitcopy_nomacros.s
+       $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemmitcopy_nomacros.s -o $@
+       rm shgemmitcopy.s shgemmitcopy_nomacros.s
+else
+       $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+endif
+
+endif
+
 $(KDIR)$(SGEMMONCOPYOBJ) : $(KERNELDIR)/$(SGEMMONCOPY)
        $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
 
@@ -590,6 +633,16 @@ else
        $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
 endif
 
+$(KDIR)shgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND)
+ifeq ($(OS), AIX)
+       $(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX  $< -o shgemm_kernel$(TSUFFIX).s
+       m4 shgemm_kernel$(TSUFFIX).s > shgemm_kernel$(TSUFFIX)_nomacros.s
+       $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemm_kernel$(TSUFFIX)_nomacros.s -o $@
+       rm shgemm_kernel$(TSUFFIX).s shgemm_kernel$(TSUFFIX)_nomacros.s
+else
+       $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+endif
+
 $(KDIR)dgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMMKERNEL) $(DGEMMDEPEND)
 ifeq ($(OS), AIX)
        $(CC) $(CFLAGS) -E -DDOUBLE -UCOMPLEX $< -o dgemm_kernel$(TSUFFIX).s
@@ -2206,6 +2259,9 @@ $(KDIR)xtrsm_oltncopy$(TSUFFIX).$(SUFFIX) : generic/ztrsm_ltcopy_$(XGEMM_UNROLL_
 $(KDIR)sgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMM_BETA)
        $(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
 
+$(KDIR)shgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA)
+       $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
 $(KDIR)dgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(DGEMM_BETA)
        $(CC) $(PFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@
 
@@ -2221,6 +2277,20 @@ $(KDIR)zgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(ZGEMM_BETA)
 $(KDIR)xgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(XGEMM_BETA)
        $(CC) $(PFLAGS) -c -DXDOUBLE -DCOMPLEX $< -o $@
 
+$(SHGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMONCOPY)
+       $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
+$(SHGEMMOTCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMOTCOPY)
+       $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
+ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N))
+$(SHGEMMINCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMINCOPY)
+       $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
+$(SHGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMITCOPY)
+       $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
+endif
 $(SGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SGEMMONCOPY)
        $(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
 
@@ -2325,6 +2395,9 @@ endif
 
 endif
 
+$(KDIR)shgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND)
+       $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
 $(KDIR)sgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMMKERNEL) $(SGEMMDEPEND)
        $(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
 
index fa9d768..ccb772c 100644 (file)
@@ -39,7 +39,7 @@
 #include "common.h"
 
 int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta,
-         FLOAT *dummy2, BLASLONG dummy3, FLOAT *dummy4, BLASLONG dummy5,
+         IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5,
          FLOAT *c, BLASLONG ldc){
 
 
index b728c71..415860f 100644 (file)
 #include <stdio.h>
 #include "common.h"
 
-int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){
+int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){
   BLASLONG i, j;
-  FLOAT *a_offset, *a_offset1, *a_offset2;
-  FLOAT *b_offset;
+  IFLOAT *a_offset, *a_offset1, *a_offset2;
+  IFLOAT *b_offset;
 
   a_offset = a;
   b_offset = b;
index 5695b13..b4aa4de 100644 (file)
 #include <stdio.h>
 #include "common.h"
 
-int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){
+int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){
   BLASLONG i, j;
 
-  FLOAT *a_offset, *a_offset1, *a_offset2;
-  FLOAT *b_offset, *b_offset1, *b_offset2;
+  IFLOAT *a_offset, *a_offset1, *a_offset2;
+  IFLOAT *b_offset, *b_offset1, *b_offset2;
 
   a_offset = a;
   b_offset = b;
index 01f1c67..26a88db 100644 (file)
@@ -1,13 +1,32 @@
 #include "common.h"
-int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FLOAT* C,BLASLONG ldc
+#if defined(HALF) && defined(HALFCONVERSION)
+float
+bfloat16tof32 (bfloat16 f16)
+{
+  float result = 0;
+  unsigned short* q = (unsigned short*)(&result);
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+  q[0] = f16;
+#else
+  q[1] = f16;
+#endif
+  return result;
+}
+#define BF16TOF32(x) (bfloat16tof32(x))
+#else
+#define BF16TOF32(x) x
+#endif
+int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc
 #ifdef TRMMKERNEL
                ,BLASLONG offset
 #endif
                )
 {
    BLASLONG i,j,k;
-   FLOAT *C0,*C1,*ptrba,*ptrbb;
-   FLOAT res0,res1,res2,res3,load0,load1,load2,load3,load4,load5,load6,load7;
+   FLOAT *C0,*C1;
+   IFLOAT *ptrba,*ptrbb;
+   FLOAT res0,res1,res2,res3;
+   IFLOAT load0,load1,load2,load3,load4,load5,load6,load7;
    for (j=0; j<bn/2; j+=1)
      {
         C0 = C;
@@ -24,36 +43,36 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
                {
                   load0 = ptrba[2*0+0];
                   load1 = ptrbb[2*0+0];
-                  res0 = res0+load0*load1;
+                  res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
                   load2 = ptrba[2*0+1];
-                  res1 = res1+load2*load1;
+                  res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
                   load3 = ptrbb[2*0+1];
-                  res2 = res2+load0*load3;
-                  res3 = res3+load2*load3;
+                  res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
+                  res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
                   load4 = ptrba[2*1+0];
                   load5 = ptrbb[2*1+0];
-                  res0 = res0+load4*load5;
+                  res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
                   load6 = ptrba[2*1+1];
-                  res1 = res1+load6*load5;
+                  res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
                   load7 = ptrbb[2*1+1];
-                  res2 = res2+load4*load7;
-                  res3 = res3+load6*load7;
+                  res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
+                  res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
                   load0 = ptrba[2*2+0];
                   load1 = ptrbb[2*2+0];
-                  res0 = res0+load0*load1;
+                  res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
                   load2 = ptrba[2*2+1];
-                  res1 = res1+load2*load1;
+                  res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
                   load3 = ptrbb[2*2+1];
-                  res2 = res2+load0*load3;
-                  res3 = res3+load2*load3;
+                  res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
+                  res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
                   load4 = ptrba[2*3+0];
                   load5 = ptrbb[2*3+0];
-                  res0 = res0+load4*load5;
+                  res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
                   load6 = ptrba[2*3+1];
-                  res1 = res1+load6*load5;
+                  res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
                   load7 = ptrbb[2*3+1];
-                  res2 = res2+load4*load7;
-                  res3 = res3+load6*load7;
+                  res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
+                  res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
                   ptrba = ptrba+8;
                   ptrbb = ptrbb+8;
                }
@@ -61,12 +80,12 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
                {
                   load0 = ptrba[2*0+0];
                   load1 = ptrbb[2*0+0];
-                  res0 = res0+load0*load1;
+                  res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
                   load2 = ptrba[2*0+1];
-                  res1 = res1+load2*load1;
+                  res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
                   load3 = ptrbb[2*0+1];
-                  res2 = res2+load0*load3;
-                  res3 = res3+load2*load3;
+                  res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
+                  res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
                   ptrba = ptrba+2;
                   ptrbb = ptrbb+2;
                }
@@ -90,9 +109,9 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
                {
                   load0 = ptrba[0+0];
                   load1 = ptrbb[2*0+0];
-                  res0 = res0+load0*load1;
+                  res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
                   load2 = ptrbb[2*0+1];
-                  res1 = res1+load0*load2;
+                  res1 = res1+BF16TOF32(load0)*BF16TOF32(load2);
                   ptrba = ptrba+1;
                   ptrbb = ptrbb+2;
                }
@@ -121,9 +140,9 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
                {
                   load0 = ptrba[2*0+0];
                   load1 = ptrbb[0+0];
-                  res0 = res0+load0*load1;
+                  res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
                   load2 = ptrba[2*0+1];
-                  res1 = res1+load2*load1;
+                  res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
                   ptrba = ptrba+2;
                   ptrbb = ptrbb+1;
                }
@@ -141,7 +160,7 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
                {
                   load0 = ptrba[0+0];
                   load1 = ptrbb[0+0];
-                  res0 = res0+load0*load1;
+                  res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
                   ptrba = ptrba+1;
                   ptrbb = ptrbb+1;
                }
index aabb5d9..dedb015 100644 (file)
@@ -12,6 +12,17 @@ DTRMMKERNEL  = dgemm_kernel_power9.S
 CTRMMKERNEL    = cgemm_kernel_power9.S
 ZTRMMKERNEL    = zgemm_kernel_power9.S
 
+SHGEMM_BETA = ../generic/gemm_beta.c
+SHGEMMKERNEL    = ../generic/gemmkernel_2x2.c
+SHGEMMINCOPY    = ../generic/gemm_ncopy_2.c
+SHGEMMITCOPY    = ../generic/gemm_tcopy_2.c
+SHGEMMONCOPY    = ../generic/gemm_ncopy_2.c
+SHGEMMOTCOPY    = ../generic/gemm_tcopy_2.c
+SHGEMMINCOPYOBJ =  shgemm_incopy$(TSUFFIX).$(SUFFIX)
+SHGEMMITCOPYOBJ =  shgemm_itcopy$(TSUFFIX).$(SUFFIX)
+SHGEMMONCOPYOBJ =  shgemm_oncopy$(TSUFFIX).$(SUFFIX)
+SHGEMMOTCOPYOBJ =  shgemm_otcopy$(TSUFFIX).$(SUFFIX)
+
 SGEMMKERNEL    =  sgemm_kernel_power9.S
 SGEMMINCOPY    = ../generic/gemm_ncopy_16.c
 SGEMMITCOPY    = sgemm_tcopy_16_power8.S
index 3c71c77..12d0389 100644 (file)
@@ -54,6 +54,20 @@ gotoblas_t TABLE_NAME = {
   GEMM_DEFAULT_OFFSET_A, GEMM_DEFAULT_OFFSET_B, GEMM_DEFAULT_ALIGN,
 
   0, 0, 0,
+  SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N,
+#ifdef SHGEMM_DEFAULT_UNROLL_MN
+ SHGEMM_DEFAULT_UNROLL_MN,
+#else
+ MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N),
+#endif
+  shgemm_kernelTS, shgemm_betaTS,
+#if SHGEMM_DEFAULT_UNROLL_M != SHGEMM_DEFAULT_UNROLL_N
+  shgemm_incopyTS, shgemm_itcopyTS,
+#else
+  shgemm_oncopyTS, shgemm_otcopyTS,
+#endif
+  shgemm_oncopyTS, shgemm_otcopyTS,
+  sgemm_kernelTS, sgemm_betaTS,
   SGEMM_DEFAULT_UNROLL_M, SGEMM_DEFAULT_UNROLL_N,
 #ifdef SGEMM_DEFAULT_UNROLL_MN
  SGEMM_DEFAULT_UNROLL_MN,
@@ -648,16 +662,19 @@ gotoblas_t TABLE_NAME = {
 
 #if defined(ARCH_ARM64)
 static void init_parameter(void) {
+  TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
   TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P;
   TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P;
   TABLE_NAME.cgemm_p = CGEMM_DEFAULT_P;
   TABLE_NAME.zgemm_p = ZGEMM_DEFAULT_P;
 
+  TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
   TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
   TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
   TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q;
   TABLE_NAME.zgemm_q = ZGEMM_DEFAULT_Q;
 
+  TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R;
   TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R;
   TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R;
   TABLE_NAME.cgemm_r = CGEMM_DEFAULT_R;
@@ -721,17 +738,20 @@ static void init_parameter(void) {
 #if defined(ARCH_POWER)
 static void init_parameter(void) {
 
+  TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
   TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P;
   TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P;
   TABLE_NAME.cgemm_p = CGEMM_DEFAULT_P;
   TABLE_NAME.zgemm_p = ZGEMM_DEFAULT_P;
 
+  TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R;
   TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R;
   TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R;
   TABLE_NAME.cgemm_r = CGEMM_DEFAULT_R;
   TABLE_NAME.zgemm_r = ZGEMM_DEFAULT_R;
 
 
+  TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
   TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
   TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
   TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q;
@@ -741,17 +761,20 @@ static void init_parameter(void) {
 
 #if defined(ARCH_ZARCH)
 static void init_parameter(void) {
+       TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
        TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P;
        TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P;
        TABLE_NAME.cgemm_p = CGEMM_DEFAULT_P;
        TABLE_NAME.zgemm_p = ZGEMM_DEFAULT_P;
 
+       TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R;
        TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R;
        TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R;
        TABLE_NAME.cgemm_r = CGEMM_DEFAULT_R;
        TABLE_NAME.zgemm_r = ZGEMM_DEFAULT_R;
 
 
+       TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
        TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
        TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
        TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q;
@@ -891,6 +914,7 @@ static void init_parameter(void) {
   (void) l2; /* dirty trick to suppress unused variable warning for targets */
              /* where the GEMM unrolling parameters do not depend on l2 */
   
+  TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
   TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
   TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
   TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q;
@@ -1261,6 +1285,7 @@ static void init_parameter(void) {
 
 
 
+  TABLE_NAME.shgemm_p = ((TABLE_NAME.shgemm_p + SHGEMM_DEFAULT_UNROLL_M - 1)/SHGEMM_DEFAULT_UNROLL_M) * SHGEMM_DEFAULT_UNROLL_M;
   TABLE_NAME.sgemm_p = ((TABLE_NAME.sgemm_p + SGEMM_DEFAULT_UNROLL_M - 1)/SGEMM_DEFAULT_UNROLL_M) * SGEMM_DEFAULT_UNROLL_M;
   TABLE_NAME.dgemm_p = ((TABLE_NAME.dgemm_p + DGEMM_DEFAULT_UNROLL_M - 1)/DGEMM_DEFAULT_UNROLL_M) * DGEMM_DEFAULT_UNROLL_M;
   TABLE_NAME.cgemm_p = ((TABLE_NAME.cgemm_p + CGEMM_DEFAULT_UNROLL_M - 1)/CGEMM_DEFAULT_UNROLL_M) * CGEMM_DEFAULT_UNROLL_M;
@@ -1288,6 +1313,11 @@ static void init_parameter(void) {
   fprintf(stderr, "L2 = %8d DGEMM_P  .. %d\n", l2, TABLE_NAME.dgemm_p);
 #endif
 
+  TABLE_NAME.shgemm_r = (((BUFFER_SIZE -
+                              ((TABLE_NAME.shgemm_p * TABLE_NAME.shgemm_q *  4 + TABLE_NAME.offsetA
+                                + TABLE_NAME.align) & ~TABLE_NAME.align)
+                              ) / (TABLE_NAME.shgemm_q *  4) - 15) & ~15);
+
   TABLE_NAME.sgemm_r = (((BUFFER_SIZE -
                               ((TABLE_NAME.sgemm_p * TABLE_NAME.sgemm_q *  4 + TABLE_NAME.offsetA
                                 + TABLE_NAME.align) & ~TABLE_NAME.align)
index c2fee6b..3125096 100644 (file)
@@ -380,6 +380,9 @@ static int thread_driver(blas_arg_t *args, FLOAT *sa, FLOAT *sb){
 #elif defined(DOUBLE)
   mode  =  BLAS_DOUBLE  | BLAS_REAL;
   mask  = MAX(DGEMM_UNROLL_M, DGEMM_UNROLL_N) - 1;
+#elif defined(HALF)
+  mode  =  BLAS_HALF  | BLAS_REAL;
+  mask  = MAX(SHGEMM_UNROLL_M, SHGEMM_UNROLL_N) - 1;
 #else
   mode  =  BLAS_SINGLE  | BLAS_REAL;
   mask  = MAX(SGEMM_UNROLL_M, SGEMM_UNROLL_N) - 1;
diff --git a/param.h b/param.h
index d6cbe54..a3eb29d 100644 (file)
--- a/param.h
+++ b/param.h
@@ -72,6 +72,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #ifndef PARAM_H
 #define PARAM_H
 
+#define SHGEMM_DEFAULT_UNROLL_N 4
+#define SHGEMM_DEFAULT_UNROLL_M 8
+#define SHGEMM_DEFAULT_UNROLL_MN 32
+#define SHGEMM_DEFAULT_P 256
+#define SHGEMM_DEFAULT_R 256
+#define SHGEMM_DEFAULT_Q 256
 #ifdef OPTERON
 
 #define SNUMOPT                4
diff --git a/test/compare_sgemm_shgemm.c b/test/compare_sgemm_shgemm.c
new file mode 100644 (file)
index 0000000..978972b
--- /dev/null
@@ -0,0 +1,95 @@
+/***************************************************************************
+Copyright (c) 2020, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+#include <stdio.h>
+#include <stdint.h>
+#include "common.h"
+#define SGEMM   BLASFUNC(sgemm)
+#define SHGEMM   BLASFUNC(shgemm)
+typedef union
+{
+  unsigned short v;
+  struct
+  {
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+    unsigned short s:1;
+    unsigned short e:8;
+    unsigned short m:7;
+#else
+    unsigned short m:7;
+    unsigned short e:8;
+    unsigned short s:1;
+#endif
+  } bits;
+} bfloat16_bits;
+
+int
+main (int argc, char *argv[])
+{
+  int m, n, k;
+  int i, j, l;
+  int ret = 0;
+  int loop = 20;
+  char transA = 'N', transB = 'N';
+  float alpha = 1.0, beta = 0.0;
+  char transa = 'N';
+  char transb = 'N';
+
+  for (int x = 0; x <= loop; x++)
+    {
+      m = k = n = x;
+      float A[m * k];
+      float B[k * n];
+      float C[m * n];
+      bfloat16_bits AA[m * k], BB[k * n];
+      float CC[m * n];
+
+      for (int j = 0; j < m; j++)
+       {
+         for (int i = 0; i < m; i++)
+           {
+             A[j * k + i] = j * 9.0;
+             B[j * k + i] = i * 2.0;
+             C[j * k + i] = 0;
+             AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16;
+             BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16;
+             CC[j * k + i] = 0;
+           }
+       }
+      SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
+             &m, B, &k, &beta, C, &m);
+      SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA,
+              &m, BB, &k, &beta, CC, &m);
+
+      for (i = 0; i < n; i++)
+       for (j = 0; j < m; j++)
+         for (l = 0; l < k; l++)
+           if (CC[i * m + j] != C[i * m + j])
+             ret++;
+    }
+  fprintf (stderr, "Return code: %d\n", ret);
+  return ret;
+}