This patch adds a matrix multiplication kernel (shgemm) for the bfloat16 data
type. On architectures that do not support bfloat16, the type is defined as
unsigned short (2 bytes). The default unroll sizes can be changed per
architecture, as is done for SGEMM; for now, 8 and 4 are used for M and N. The
ncopy/tcopy size can also be changed per architecture; for now, size 2 is used.
shgemm has been added to kernel/power/KERNEL.POWER9 and tested on powerpc64le
and powerpc64. For reference, a small test, compare_sgemm_shgemm.c, has been
added to compare the output of sgemm and shgemm.
This patch does not cover the OpenBLAS test, benchmark, and LAPACK tests for
shgemm. A complex-type implementation can be discussed and added once this
patch is approved.
export TARGET_CORE
export NO_AVX512
+export SHGEMM_UNROLL_M
+export SHGEMM_UNROLL_N
export SGEMM_UNROLL_M
export SGEMM_UNROLL_N
export DGEMM_UNROLL_M
+SHBLASOBJS_P = $(SHBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
SBLASOBJS_P = $(SBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
DBLASOBJS_P = $(DBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
QBLASOBJS_P = $(QBLASOBJS:.$(SUFFIX)=.$(PSUFFIX))
HPLOBJS_P = $(HPLOBJS:.$(SUFFIX)=.$(PSUFFIX))
-BLASOBJS = $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
-BLASOBJS_P = $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P)
+BLASOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
+BLASOBJS_P = $(SHBLASOBJS_P) $(SBLASOBJS_P) $(DBLASOBJS_P) $(CBLASOBJS_P) $(ZBLASOBJS_P)
ifdef EXPRECISION
BLASOBJS += $(QBLASOBJS) $(XBLASOBJS)
BLASOBJS_P += $(QBLASOBJS_P) $(XBLASOBJS_P)
endif
+$(SHBLASOBJS) $(SHBLASOBJS_P) : override CFLAGS += -DHALF -UDOUBLE -UCOMPLEX
$(SBLASOBJS) $(SBLASOBJS_P) : override CFLAGS += -UDOUBLE -UCOMPLEX
$(DBLASOBJS) $(DBLASOBJS_P) : override CFLAGS += -DDOUBLE -UCOMPLEX
$(QBLASOBJS) $(QBLASOBJS_P) : override CFLAGS += -DXDOUBLE -UCOMPLEX
$(ZBLASOBJS) $(ZBLASOBJS_P) : override CFLAGS += -DDOUBLE -DCOMPLEX
$(XBLASOBJS) $(XBLASOBJS_P) : override CFLAGS += -DXDOUBLE -DCOMPLEX
+$(SHBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(SBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(DBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
$(QBLASOBJS_P) : override CFLAGS += -DPROFILE $(COMMON_PROF)
# HAVE_SSE2
# HAVE_SSE3
# MAKE
+# SHGEMM_UNROLL_M
+# SHGEMM_UNROLL_N
# SGEMM_UNROLL_M
# SGEMM_UNROLL_N
# DGEMM_UNROLL_M
set(ZGEMM_UNROLL_N 2)
set(SYMV_P 8)
endif()
+ set(SHGEMM_UNROLL_M 8)
+ set(SHGEMM_UNROLL_N 4)
# Or should this actually be NUM_CORES?
if (${NUM_THREADS} GREATER 0)
#export FUNCTION_PROFILE
#export TARGET_CORE
#
+#export SHGEMM_UNROLL_M
+#export SHGEMM_UNROLL_N
#export SGEMM_UNROLL_M
#export SGEMM_UNROLL_N
#export DGEMM_UNROLL_M
#define SIZE 8
#define BASE_SHIFT 3
#define ZBASE_SHIFT 4
+#elif defined(HALF)
+#ifndef BFLOAT16
+typedef unsigned short bfloat16;
+#define HALFCONVERSION 1
+#endif
+#define IFLOAT bfloat16
+#define XFLOAT IFLOAT
+#define FLOAT float
+#define SIZE 2
+#define BASE_SHIFT 1
+#define ZBASE_SHIFT 2
#else
#define FLOAT float
#define SIZE 4
#define XFLOAT FLOAT
#endif
+#ifndef IFLOAT
+#define IFLOAT FLOAT
+#endif
+
#ifndef COMPLEX
#define COMPSIZE 1
#else
/*********************************************************************/
#ifndef ASSEMBLER
+#ifndef BFLOAT16
+typedef unsigned short bfloat16;
+#endif
#ifdef __cplusplus
extern "C" {
/* Level 3 routines */
+void BLASFUNC(shgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
+ bfloat16 *, blasint *, bfloat16 *, blasint *, float *, float *, blasint *);
void BLASFUNC(sgemm)(char *, char *, blasint *, blasint *, blasint *, float *,
float *, blasint *, float *, blasint *, float *, float *, blasint *);
void BLASFUNC(dgemm)(char *, char *, blasint *, blasint *, blasint *, double *,
/*********************************************************************/
#ifndef ASSEMBLER
+#ifndef BFLOAT16
+typedef unsigned short bfloat16;
+#endif
#ifdef __CUDACC__
__global__ void cuda_sgemm_kernel(int, int, int, float *, float *, float *);
extern int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
+int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
+ bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
int sgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
int dgemm_beta(BLASLONG, BLASLONG, BLASLONG, double,
xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG);
#endif
+int shgemm_incopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
+int shgemm_itcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
+int shgemm_oncopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
+int shgemm_otcopy(BLASLONG m, BLASLONG n, bfloat16 *a, BLASLONG lda, bfloat16 *b);
int sgemm_incopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b);
int sgemm_itcopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b);
int sgemm_oncopy(BLASLONG m, BLASLONG n, float *a, BLASLONG lda, float *b);
int xher2k_kernel_LN(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
int xher2k_kernel_LC(BLASLONG m, BLASLONG n, BLASLONG k, xdouble alpha_r, xdouble alpha_i, xdouble *a, xdouble *b, xdouble *c, BLASLONG ldc, BLASLONG offset, int flag);
+int shgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG);
int sgemm_kernel(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
int dgemm_kernel(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG);
int zgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG);
int xgemm3m_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG);
+int shgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
+int shgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
+int shgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
+int shgemm_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
+
int sgemm_nn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
int sgemm_nt(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
int sgemm_tn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
int xgemm_cc(blas_arg_t *, BLASLONG *, BLASLONG *, xdouble *, xdouble *, BLASLONG);
#endif
+int shgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
+int shgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
+int shgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
+int shgemm_thread_tt(blas_arg_t *, BLASLONG *, BLASLONG *, bfloat16 *, bfloat16 *, BLASLONG);
+
int sgemm_thread_nn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
int sgemm_thread_nt(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
int sgemm_thread_tn(blas_arg_t *, BLASLONG *, BLASLONG *, float *, float *, BLASLONG);
#ifndef COMMON_MACRO
#define COMMON_MACRO
+#include "common_sh.h"
#include "common_s.h"
#include "common_d.h"
#include "common_q.h"
#define IMATCOPY_K_RT DIMATCOPY_K_RT
#define GEADD_K DGEADD_K
+
+#elif defined(HALF)
+
+#define GEMM_BETA SHGEMM_BETA
+#define GEMM_KERNEL_N SHGEMM_KERNEL
+#define GEMM_KERNEL_L SHGEMM_KERNEL
+#define GEMM_KERNEL_R SHGEMM_KERNEL
+#define GEMM_KERNEL_B SHGEMM_KERNEL
+
+#define GEMM_NN SHGEMM_NN
+#define GEMM_CN SHGEMM_TN
+#define GEMM_TN SHGEMM_TN
+#define GEMM_NC SHGEMM_NT
+#define GEMM_NT SHGEMM_NT
+#define GEMM_CC SHGEMM_TT
+#define GEMM_CT SHGEMM_TT
+#define GEMM_TC SHGEMM_TT
+#define GEMM_TT SHGEMM_TT
+#define GEMM_NR SHGEMM_NN
+#define GEMM_TR SHGEMM_TN
+#define GEMM_CR SHGEMM_TN
+#define GEMM_RN SHGEMM_NN
+#define GEMM_RT SHGEMM_NT
+#define GEMM_RC SHGEMM_NT
+#define GEMM_RR SHGEMM_NN
+#define GEMM_ONCOPY SHGEMM_ONCOPY
+#define GEMM_OTCOPY SHGEMM_OTCOPY
+#define GEMM_INCOPY SHGEMM_INCOPY
+#define GEMM_ITCOPY SHGEMM_ITCOPY
+
+#define GEMM_THREAD_NN SHGEMM_THREAD_NN
+#define GEMM_THREAD_CN SHGEMM_THREAD_TN
+#define GEMM_THREAD_TN SHGEMM_THREAD_TN
+#define GEMM_THREAD_NC SHGEMM_THREAD_NT
+#define GEMM_THREAD_NT SHGEMM_THREAD_NT
+#define GEMM_THREAD_CC SHGEMM_THREAD_TT
+#define GEMM_THREAD_CT SHGEMM_THREAD_TT
+#define GEMM_THREAD_TC SHGEMM_THREAD_TT
+#define GEMM_THREAD_TT SHGEMM_THREAD_TT
+#define GEMM_THREAD_NR SHGEMM_THREAD_NN
+#define GEMM_THREAD_TR SHGEMM_THREAD_TN
+#define GEMM_THREAD_CR SHGEMM_THREAD_TN
+#define GEMM_THREAD_RN SHGEMM_THREAD_NN
+#define GEMM_THREAD_RT SHGEMM_THREAD_NT
+#define GEMM_THREAD_RC SHGEMM_THREAD_NT
+#define GEMM_THREAD_RR SHGEMM_THREAD_NN
+
#else
#define AMAX_K SAMAX_K
#if defined(ARCH_X86) || defined(ARCH_X86_64) || defined(ARCH_IA64) || defined(ARCH_MIPS64) || defined(ARCH_ARM64)
extern BLASLONG gemm_offset_a;
extern BLASLONG gemm_offset_b;
+extern BLASLONG shgemm_p;
+extern BLASLONG shgemm_q;
+extern BLASLONG shgemm_r;
extern BLASLONG sgemm_p;
extern BLASLONG sgemm_q;
extern BLASLONG sgemm_r;
int (*sgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
int (*sgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
+ int shgemm_p, shgemm_q, shgemm_r;
+ int shgemm_unroll_m, shgemm_unroll_n, shgemm_unroll_mn;
+ int (*shgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, bfloat16 *, float *, BLASLONG);
+ int (*shgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float *, BLASLONG);
+
+ int (*shgemm_incopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);
+ int (*shgemm_itcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);
+ int (*shgemm_oncopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);
+ int (*shgemm_otcopy )(BLASLONG, BLASLONG, bfloat16 *, BLASLONG, bfloat16 *);
+
int (*sgemm_incopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *);
int (*sgemm_itcopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *);
int (*sgemm_oncopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *);
#define HAVE_EX_L2 gotoblas -> exclusive_cache
+#define SHGEMM_P gotoblas -> shgemm_p
+#define SHGEMM_Q gotoblas -> shgemm_q
+#define SHGEMM_R gotoblas -> shgemm_r
+#define SHGEMM_UNROLL_M gotoblas -> shgemm_unroll_m
+#define SHGEMM_UNROLL_N gotoblas -> shgemm_unroll_n
+#define SHGEMM_UNROLL_MN gotoblas -> shgemm_unroll_mn
+
#define SGEMM_P gotoblas -> sgemm_p
#define SGEMM_Q gotoblas -> sgemm_q
#define SGEMM_R gotoblas -> sgemm_r
#define HAVE_EX_L2 0
#endif
+#define SHGEMM_P SHGEMM_DEFAULT_P
+#define SHGEMM_Q SHGEMM_DEFAULT_Q
+#define SHGEMM_R SHGEMM_DEFAULT_R
+#define SHGEMM_UNROLL_M SHGEMM_DEFAULT_UNROLL_M
+#define SHGEMM_UNROLL_N SHGEMM_DEFAULT_UNROLL_N
+#ifdef SHGEMM_DEFAULT_UNROLL_MN
+#define SHGEMM_UNROLL_MN SHGEMM_DEFAULT_UNROLL_MN
+#else
+#define SHGEMM_UNROLL_MN MAX((SHGEMM_UNROLL_M), (SHGEMM_UNROLL_N))
+#endif
+
#define SGEMM_P SGEMM_DEFAULT_P
#define SGEMM_Q SGEMM_DEFAULT_Q
#define SGEMM_R SGEMM_DEFAULT_R
#define GEMM_DEFAULT_R DGEMM_DEFAULT_R
#define GEMM_DEFAULT_UNROLL_M DGEMM_DEFAULT_UNROLL_M
#define GEMM_DEFAULT_UNROLL_N DGEMM_DEFAULT_UNROLL_N
+#elif defined(HALF)
+#define GEMM_P SHGEMM_P
+#define GEMM_Q SHGEMM_Q
+#define GEMM_R SHGEMM_R
+#define GEMM_UNROLL_M SHGEMM_UNROLL_M
+#define GEMM_UNROLL_N SHGEMM_UNROLL_N
+#define GEMM_UNROLL_MN SHGEMM_UNROLL_MN
+#define GEMM_DEFAULT_P SHGEMM_DEFAULT_P
+#define GEMM_DEFAULT_Q SHGEMM_DEFAULT_Q
+#define GEMM_DEFAULT_R SHGEMM_DEFAULT_R
+#define GEMM_DEFAULT_UNROLL_M SHGEMM_DEFAULT_UNROLL_M
+#define GEMM_DEFAULT_UNROLL_N SHGEMM_DEFAULT_UNROLL_N
#else
#define GEMM_P SGEMM_P
#define GEMM_Q SGEMM_Q
#define GEMM_THREAD gemm_thread_n
#endif
+#ifndef SHGEMM_DEFAULT_R
+#define SHGEMM_DEFAULT_R (((BUFFER_SIZE - ((SHGEMM_DEFAULT_P * SHGEMM_DEFAULT_Q * 4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SHGEMM_DEFAULT_Q * 4) - 15) & ~15)
+#endif
+
#ifndef SGEMM_DEFAULT_R
#define SGEMM_DEFAULT_R (((BUFFER_SIZE - ((SGEMM_DEFAULT_P * SGEMM_DEFAULT_Q * 4 + GEMM_DEFAULT_OFFSET_A + GEMM_DEFAULT_ALIGN) & ~GEMM_DEFAULT_ALIGN)) / (SGEMM_DEFAULT_Q * 4) - 15) & ~15)
#endif
--- /dev/null
+#ifndef COMMON_H_H
+#define COMMON_H_H
+/* Maps the upper-case SHGEMM_* names onto the lower-case shgemm_* routines. */
+#ifndef DYNAMIC_ARCH
+/* Without DYNAMIC_ARCH the names expand directly to the shgemm_* symbols. */
+#define SHGEMM_ONCOPY shgemm_oncopy
+#define SHGEMM_OTCOPY shgemm_otcopy
+/* When the default M and N unroll factors are equal, INCOPY/ITCOPY expand to the oncopy/otcopy routines. */
+#if SHGEMM_DEFAULT_UNROLL_M == SHGEMM_DEFAULT_UNROLL_N
+#define SHGEMM_INCOPY shgemm_oncopy
+#define SHGEMM_ITCOPY shgemm_otcopy
+#else
+#define SHGEMM_INCOPY shgemm_incopy
+#define SHGEMM_ITCOPY shgemm_itcopy
+#endif
+#define SHGEMM_BETA shgemm_beta
+#define SHGEMM_KERNEL shgemm_kernel
+
+#else
+/* With DYNAMIC_ARCH the same names expand to members reached through gotoblas. */
+#define SHGEMM_ONCOPY gotoblas -> shgemm_oncopy
+#define SHGEMM_OTCOPY gotoblas -> shgemm_otcopy
+#define SHGEMM_INCOPY gotoblas -> shgemm_incopy
+#define SHGEMM_ITCOPY gotoblas -> shgemm_itcopy
+#define SHGEMM_BETA gotoblas -> shgemm_beta
+#define SHGEMM_KERNEL gotoblas -> shgemm_kernel
+
+#endif
+/* Driver names: every C variant expands to the T routine and every R variant to the N routine. */
+#define SHGEMM_NN shgemm_nn
+#define SHGEMM_CN shgemm_tn
+#define SHGEMM_TN shgemm_tn
+#define SHGEMM_NC shgemm_nt
+#define SHGEMM_NT shgemm_nt
+#define SHGEMM_CC shgemm_tt
+#define SHGEMM_CT shgemm_tt
+#define SHGEMM_TC shgemm_tt
+#define SHGEMM_TT shgemm_tt
+#define SHGEMM_NR shgemm_nn
+#define SHGEMM_TR shgemm_tn
+#define SHGEMM_CR shgemm_tn
+#define SHGEMM_RN shgemm_nn
+#define SHGEMM_RT shgemm_nt
+#define SHGEMM_RC shgemm_nt
+#define SHGEMM_RR shgemm_nn
+/* Threaded driver names follow the same C-to-T and R-to-N pattern. */
+#define SHGEMM_THREAD_NN shgemm_thread_nn
+#define SHGEMM_THREAD_CN shgemm_thread_tn
+#define SHGEMM_THREAD_TN shgemm_thread_tn
+#define SHGEMM_THREAD_NC shgemm_thread_nt
+#define SHGEMM_THREAD_NT shgemm_thread_nt
+#define SHGEMM_THREAD_CC shgemm_thread_tt
+#define SHGEMM_THREAD_CT shgemm_thread_tt
+#define SHGEMM_THREAD_TC shgemm_thread_tt
+#define SHGEMM_THREAD_TT shgemm_thread_tt
+#define SHGEMM_THREAD_NR shgemm_thread_nn
+#define SHGEMM_THREAD_TR shgemm_thread_tn
+#define SHGEMM_THREAD_CR shgemm_thread_tn
+#define SHGEMM_THREAD_RN shgemm_thread_nn
+#define SHGEMM_THREAD_RT shgemm_thread_nt
+#define SHGEMM_THREAD_RC shgemm_thread_nt
+#define SHGEMM_THREAD_RR shgemm_thread_nn
+
+#endif
+
USE_GEMM3M = 1
endif
+SHBLASOBJS += shgemm_nn.$(SUFFIX) shgemm_nt.$(SUFFIX) shgemm_tn.$(SUFFIX) shgemm_tt.$(SUFFIX)
SBLASOBJS += \
sgemm_nn.$(SUFFIX) sgemm_nt.$(SUFFIX) sgemm_tn.$(SUFFIX) sgemm_tt.$(SUFFIX) \
strmm_LNUU.$(SUFFIX) strmm_LNUN.$(SUFFIX) strmm_LNLU.$(SUFFIX) strmm_LNLN.$(SUFFIX) \
ifndef USE_SIMPLE_THREADED_LEVEL3
+SHBLASOBJS += shgemm_thread_nn.$(SUFFIX) shgemm_thread_nt.$(SUFFIX) shgemm_thread_tn.$(SUFFIX) shgemm_thread_tt.$(SUFFIX)
SBLASOBJS += sgemm_thread_nn.$(SUFFIX) sgemm_thread_nt.$(SUFFIX) sgemm_thread_tn.$(SUFFIX) sgemm_thread_tt.$(SUFFIX)
DBLASOBJS += dgemm_thread_nn.$(SUFFIX) dgemm_thread_nt.$(SUFFIX) dgemm_thread_tn.$(SUFFIX) dgemm_thread_tt.$(SUFFIX)
QBLASOBJS += qgemm_thread_nn.$(SUFFIX) qgemm_thread_nt.$(SUFFIX) qgemm_thread_tn.$(SUFFIX) qgemm_thread_tt.$(SUFFIX)
all ::
+shgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
+ $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
+
+shgemm_nt.$(SUFFIX) : gemm.c level3.c ../../param.h
+ $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
+
+shgemm_tn.$(SUFFIX) : gemm.c level3.c ../../param.h
+ $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
+
+shgemm_tt.$(SUFFIX) : gemm.c level3.c ../../param.h
+ $(CC) $(CFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
+
sgemm_nn.$(SUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
beta_thread.$(SUFFIX) : beta_thread.c ../../common.h
$(CC) -c $(CFLAGS) $< -o $(@F)
+shgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
+ $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
+
+shgemm_thread_nt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
+ $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
+
+shgemm_thread_tn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
+ $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
+
+shgemm_thread_tt.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
+ $(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
sgemm_thread_nn.$(SUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(CFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
xtrsm_RCLN.$(SUFFIX) : trsm_R.c
$(CC) -c $(CFLAGS) -DCOMPLEX -DXDOUBLE -DTRANSA -UUPPER -UUNIT -DCONJ $< -o $(@F)
+shgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h
+ $(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
+
+shgemm_nt.$(PSUFFIX) : gemm.c level3.c ../../param.h
+ $(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
+
+shgemm_tn.$(PSUFFIX) : gemm.c level3.c ../../param.h
+ $(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
+
+shgemm_tt.$(PSUFFIX) : gemm.c level3.c ../../param.h
+ $(CC) $(PFLAGS) $(BLOCKS) -c -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
+
sgemm_nn.$(PSUFFIX) : gemm.c level3.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
$(CC) -c $(PFLAGS) $< -o $(@F)
+shgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
+ $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
+
+shgemm_thread_nt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
+ $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DNT $< -o $(@F)
+
+shgemm_thread_tn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
+ $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTN $< -o $(@F)
+
+shgemm_thread_tt.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
+ $(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -DHALF -UDOUBLE -UCOMPLEX -DTT $< -o $(@F)
+
sgemm_thread_nn.$(PSUFFIX) : gemm.c level3_thread.c ../../param.h
$(CC) $(PFLAGS) $(BLOCKS) -c -DTHREADED_LEVEL3 -UDOUBLE -UCOMPLEX -DNN $< -o $(@F)
#ifndef ICOPY_OPERATION
#if defined(NN) || defined(NT) || defined(NC) || defined(NR) || \
defined(RN) || defined(RT) || defined(RC) || defined(RR)
-#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
+#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
#else
-#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
+#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
#endif
#endif
#ifndef OCOPY_OPERATION
#if defined(NN) || defined(TN) || defined(CN) || defined(RN) || \
defined(NR) || defined(TR) || defined(CR) || defined(RR)
-#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
+#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
#else
-#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
+#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
#endif
#endif
XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){
BLASLONG k, lda, ldb, ldc;
FLOAT *alpha, *beta;
- FLOAT *a, *b, *c;
+ IFLOAT *a, *b;
+ FLOAT *c;
BLASLONG m_from, m_to, n_from, n_to;
BLASLONG ls, is, js;
k = K;
- a = (FLOAT *)A;
- b = (FLOAT *)B;
+ a = (IFLOAT *)A;
+ b = (IFLOAT *)B;
c = (FLOAT *)C;
lda = LDA;
#ifndef ICOPY_OPERATION
#if defined(NN) || defined(NT) || defined(NC) || defined(NR) || \
defined(RN) || defined(RT) || defined(RC) || defined(RR)
-#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
+#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
#else
-#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
+#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
#endif
#endif
#ifndef OCOPY_OPERATION
#if defined(NN) || defined(TN) || defined(CN) || defined(RN) || \
defined(NR) || defined(TR) || defined(CR) || defined(RR)
-#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
+#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (IFLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
#else
-#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
+#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (IFLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
#endif
#endif
#define STOP_RPCC(COUNTER)
#endif
-static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){
+static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){
- FLOAT *buffer[DIVIDE_RATE];
+ IFLOAT *buffer[DIVIDE_RATE];
BLASLONG k, lda, ldb, ldc;
BLASLONG m_from, m_to, n_from, n_to;
FLOAT *alpha, *beta;
- FLOAT *a, *b, *c;
+ IFLOAT *a, *b;
+ FLOAT *c;
job_t *job = (job_t *)args -> common;
BLASLONG nthreads_m;
k = K;
- a = (FLOAT *)A;
- b = (FLOAT *)B;
+ a = (IFLOAT *)A;
+ b = (IFLOAT *)B;
c = (FLOAT *)C;
lda = LDA;
/* Apply kernel with local region of A and part of other region of B */
START_RPCC();
KERNEL_OPERATION(min_i, MIN(range_n[current + 1] - js, div_n), min_l, alpha,
- sa, (FLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside],
+ sa, (IFLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside],
c, ldc, m_from, js);
STOP_RPCC(kernel);
/* Apply kernel with local region of A and part of region of B */
START_RPCC();
KERNEL_OPERATION(min_i, MIN(range_n[current + 1] - js, div_n), min_l, alpha,
- sa, (FLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside],
+ sa, (IFLOAT *)job[current].working[mypos][CACHE_LINE_SIZE * bufferside],
c, ldc, is, js);
STOP_RPCC(kernel);
static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
- *range_n, FLOAT *sa, FLOAT *sb,
+ *range_n, IFLOAT *sa, IFLOAT *sb,
BLASLONG nthreads_m, BLASLONG nthreads_n) {
#ifndef USE_OPENMP
return 0;
}
-int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){
+int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, IFLOAT *sa, IFLOAT *sb, BLASLONG mypos){
BLASLONG m = args -> m;
BLASLONG n = args -> n;
BLASLONG gemm_offset_b = GEMM_OFFSET_B;
#endif
+#if SHGEMM_P == shgemm_p
+BLASLONG shgemm_p = DEFAULT_GEMM_P;
+#else
+BLASLONG shgemm_p = SHGEMM_P;
+#endif
#if SGEMM_P == sgemm_p
BLASLONG sgemm_p = DEFAULT_GEMM_P;
#else
BLASLONG zgemm_p = ZGEMM_P;
#endif
+#if SHGEMM_Q == shgemm_q
+BLASLONG shgemm_q = DEFAULT_GEMM_Q;
+#else
+BLASLONG shgemm_q = SHGEMM_Q;
+#endif
#if SGEMM_Q == sgemm_q
BLASLONG sgemm_q = DEFAULT_GEMM_Q;
#else
BLASLONG zgemm_q = ZGEMM_Q;
#endif
+#if SHGEMM_R == shgemm_r
+BLASLONG shgemm_r = DEFAULT_GEMM_R;
+#else
+BLASLONG shgemm_r = SHGEMM_R;
+#endif
#if SGEMM_R == sgemm_r
BLASLONG sgemm_r = DEFAULT_GEMM_R;
#else
size = BITMASK(cpuid3, 16, 0xff);
+ shgemm_p = 192 * (size + 1);
sgemm_p = 192 * (size + 1);
dgemm_p = 96 * (size + 1);
cgemm_p = 96 * (size + 1);
xgemm_p = 16 * (size + 1);
#endif
+ shgemm_r = (((BUFFER_SIZE - ((SHGEMM_P * SHGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SHGEMM_Q * 4)) - 15) & ~15;
sgemm_r = (((BUFFER_SIZE - ((SGEMM_P * SGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SGEMM_Q * 4)) - 15) & ~15;
dgemm_r = (((BUFFER_SIZE - ((DGEMM_P * DGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (DGEMM_Q * 8)) - 15) & ~15;
cgemm_r = (((BUFFER_SIZE - ((CGEMM_P * CGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (CGEMM_Q * 8)) - 15) & ~15;
int main(int argc, char **argv) {
if ( (argc <= 1) || ((argc >= 2) && (*argv[1] == '0'))) {
+ printf("SHGEMM_UNROLL_M=%d\n", SHGEMM_DEFAULT_UNROLL_M);
+ printf("SHGEMM_UNROLL_N=%d\n", SHGEMM_DEFAULT_UNROLL_N);
printf("SGEMM_UNROLL_M=%d\n", SGEMM_DEFAULT_UNROLL_M);
printf("SGEMM_UNROLL_N=%d\n", SGEMM_DEFAULT_UNROLL_N);
printf("DGEMM_UNROLL_M=%d\n", DGEMM_DEFAULT_UNROLL_M);
somatcopy.$(SUFFIX) simatcopy.$(SUFFIX)\
sgeadd.$(SUFFIX)
+SHBLAS3OBJS = shgemm.$(SUFFIX)
DBLAS1OBJS = \
daxpy.$(SUFFIX) dswap.$(SUFFIX) \
cblas_ssyrk.$(SUFFIX) cblas_ssyr2k.$(SUFFIX) cblas_somatcopy.$(SUFFIX) cblas_simatcopy.$(SUFFIX)\
cblas_sgeadd.$(SUFFIX)
+CSHBLAS3OBJS = cblas_shgemm.$(SUFFIX)
+
CDBLAS1OBJS = \
cblas_idamax.$(SUFFIX) cblas_idamin.$(SUFFIX) cblas_dasum.$(SUFFIX) cblas_daxpy.$(SUFFIX) \
cblas_dcopy.$(SUFFIX) cblas_ddot.$(SUFFIX) \
SBLAS1OBJS += $(CSBLAS1OBJS)
SBLAS2OBJS += $(CSBLAS2OBJS)
SBLAS3OBJS += $(CSBLAS3OBJS)
+SHBLAS3OBJS += $(CSHBLAS3OBJS)
DBLAS1OBJS += $(CDBLAS1OBJS)
DBLAS2OBJS += $(CDBLAS2OBJS)
DBLAS3OBJS += $(CDBLAS3OBJS)
endif
SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS)
+SHBLASOBJS = $(SHBLAS3OBJS)
DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS)
QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS)
CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS)
endif
-FUNCOBJS = $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
+FUNCOBJS = $(SHBLASOBJS) $(SBLASOBJS) $(DBLASOBJS) $(CBLASOBJS) $(ZBLASOBJS)
ifdef EXPRECISION
FUNCOBJS += $(QBLASOBJS) $(XBLASOBJS)
level2 : $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS)
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
-level3 : $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS)
+level3 : $(SHBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS)
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^
-$(CSBLASOBJS) $(CSBLASOBJS_P) $(CDBLASOBJS) $(CDBLASOBJS_P) $(CQBLASOBJS) $(CQBLASOBJS_P) \
+$(CSHBLASOBJS) $(CSHBLASOBJS_P) $(CSBLASOBJS) $(CSBLASOBJS_P) $(CDBLASOBJS) $(CDBLASOBJS_P) $(CQBLASOBJS) $(CQBLASOBJS_P) \
$(CCBLASOBJS) $(CCBLASOBJS_P) $(CZBLASOBJS) $(CZBLASOBJS_P) $(CXBLASOBJS) $(CXBLASOBJS_P) : override CFLAGS += -DCBLAS
srot.$(SUFFIX) srot.$(PSUFFIX) : rot.c
xhpr2.$(SUFFIX) xhpr2.$(PSUFFIX) : zhpr2.c
$(CC) -c $(CFLAGS) $< -o $(@F)
+shgemm.$(SUFFIX) shgemm.$(PSUFFIX) : gemm.c ../param.h
+ $(CC) -c $(CFLAGS) $< -o $(@F)
+
sgemm.$(SUFFIX) sgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -c $(CFLAGS) $< -o $(@F)
cblas_sgemm.$(SUFFIX) cblas_sgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
+cblas_shgemm.$(SUFFIX) cblas_shgemm.$(PSUFFIX) : gemm.c ../param.h
+ $(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
+
cblas_dgemm.$(SUFFIX) cblas_dgemm.$(PSUFFIX) : gemm.c ../param.h
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
#define GEMM_MULTITHREAD_THRESHOLD 4
#endif
-static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLASLONG) = {
+static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = {
#ifndef GEMM3M
GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN,
GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT,
void NAME(char *TRANSA, char *TRANSB,
blasint *M, blasint *N, blasint *K,
FLOAT *alpha,
- FLOAT *a, blasint *ldA,
- FLOAT *b, blasint *ldB,
+ IFLOAT *a, blasint *ldA,
+ IFLOAT *b, blasint *ldB,
FLOAT *beta,
FLOAT *c, blasint *ldC){
blasint info;
char transA, transB;
- FLOAT *buffer;
- FLOAT *sa, *sb;
+ IFLOAT *buffer;
+ IFLOAT *sa, *sb;
#ifdef SMP
double MNK;
USE_TRMM = 1
endif
+SHKERNELOBJS += \
+ shgemm_kernel$(TSUFFIX).$(SUFFIX) \
+ $(SHGEMMINCOPYOBJ) $(SHGEMMITCOPYOBJ) \
+ $(SHGEMMONCOPYOBJ) $(SHGEMMOTCOPYOBJ)
SKERNELOBJS += \
sgemm_kernel$(TSUFFIX).$(SUFFIX) \
$(XGEMMINCOPYOBJ) $(XGEMMITCOPYOBJ) \
$(XGEMMONCOPYOBJ) $(XGEMMOTCOPYOBJ)
+SHBLASOBJS += $(SHKERNELOBJS)
SBLASOBJS += $(SKERNELOBJS)
DBLASOBJS += $(DKERNELOBJS)
QBLASOBJS += $(QKERNELOBJS)
ZBLASOBJS += $(ZKERNELOBJS)
XBLASOBJS += $(XKERNELOBJS)
+SHBLASOBJS += shgemm_beta$(TSUFFIX).$(SUFFIX)
SBLASOBJS += \
sgemm_beta$(TSUFFIX).$(SUFFIX) \
strmm_kernel_LN$(TSUFFIX).$(SUFFIX) strmm_kernel_LT$(TSUFFIX).$(SUFFIX) \
zgeadd_k$(TSUFFIX).$(SUFFIX)
+SHGEMMINCOPYOBJ_P = $(SHGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
+SHGEMMITCOPYOBJ_P = $(SHGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
+SHGEMMONCOPYOBJ_P = $(SHGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
+SHGEMMOTCOPYOBJ_P = $(SHGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SGEMMINCOPYOBJ_P = $(SGEMMINCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SGEMMITCOPYOBJ_P = $(SGEMMITCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
SGEMMONCOPYOBJ_P = $(SGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
XGEMMONCOPYOBJ_P = $(XGEMMONCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
XGEMMOTCOPYOBJ_P = $(XGEMMOTCOPYOBJ:.$(SUFFIX)=.$(PSUFFIX))
+$(KDIR)shgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA)
+ $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
$(KDIR)sgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_BETA)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
$(KDIR)xgemm_beta$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XGEMM_BETA)
$(CC) $(CFLAGS) -c -DXDOUBLE -DCOMPLEX $< -o $@
+$(KDIR)$(SHGEMMONCOPYOBJ) : $(KERNELDIR)/$(SHGEMMONCOPY)
+ $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
+$(KDIR)$(SHGEMMOTCOPYOBJ) : $(KERNELDIR)/$(SHGEMMOTCOPY)
+ifeq ($(OS), AIX)
+ $(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX $< -o shgemmotcopy.s
+ m4 shgemmotcopy.s > shgemmotcopy_nomacros.s
+ $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemmotcopy_nomacros.s -o $@
+ rm shgemmotcopy.s shgemmotcopy_nomacros.s
+else
+ $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+endif
+
+ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N))
+
+$(KDIR)$(SHGEMMINCOPYOBJ) : $(KERNELDIR)/$(SHGEMMINCOPY)
+ $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
+$(KDIR)$(SHGEMMITCOPYOBJ) : $(KERNELDIR)/$(SHGEMMITCOPY)
+ifeq ($(OS), AIX)
+ $(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX $< -o shgemmitcopy.s
+ m4 shgemmitcopy.s > shgemmitcopy_nomacros.s
+ $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemmitcopy_nomacros.s -o $@
+ rm shgemmitcopy.s shgemmitcopy_nomacros.s
+else
+ $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+endif
+
+endif
+
$(KDIR)$(SGEMMONCOPYOBJ) : $(KERNELDIR)/$(SGEMMONCOPY)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
endif
+$(KDIR)shgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND)
+ifeq ($(OS), AIX)
+ $(CC) $(CFLAGS) -E -DHALF -UDOUBLE -UCOMPLEX $< -o shgemm_kernel$(TSUFFIX).s
+ m4 shgemm_kernel$(TSUFFIX).s > shgemm_kernel$(TSUFFIX)_nomacros.s
+ $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX shgemm_kernel$(TSUFFIX)_nomacros.s -o $@
+ rm shgemm_kernel$(TSUFFIX).s shgemm_kernel$(TSUFFIX)_nomacros.s
+else
+ $(CC) $(CFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+endif
+
$(KDIR)dgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMMKERNEL) $(DGEMMDEPEND)
ifeq ($(OS), AIX)
$(CC) $(CFLAGS) -E -DDOUBLE -UCOMPLEX $< -o dgemm_kernel$(TSUFFIX).s
$(KDIR)sgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMM_BETA)
$(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
+$(KDIR)shgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMM_BETA)
+ $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
$(KDIR)dgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(DGEMM_BETA)
$(CC) $(PFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@
$(KDIR)xgemm_beta$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(XGEMM_BETA)
$(CC) $(PFLAGS) -c -DXDOUBLE -DCOMPLEX $< -o $@
+$(SHGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMONCOPY)
+ $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
+$(SHGEMMOTCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMOTCOPY)
+ $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
+ifneq ($(SHGEMM_UNROLL_M), $(SHGEMM_UNROLL_N))
+$(SHGEMMINCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMINCOPY)
+ $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
+$(SHGEMMITCOPYOBJ_P) : $(KERNELDIR)/$(SHGEMMITCOPY)
+ $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
+endif
$(SGEMMONCOPYOBJ_P) : $(KERNELDIR)/$(SGEMMONCOPY)
$(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
endif
+$(KDIR)shgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND)
+ $(CC) $(PFLAGS) -c -DHALF -UDOUBLE -UCOMPLEX $< -o $@
+
$(KDIR)sgemm_kernel$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(SGEMMKERNEL) $(SGEMMDEPEND)
$(CC) $(PFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
#include "common.h"
int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta,
- FLOAT *dummy2, BLASLONG dummy3, FLOAT *dummy4, BLASLONG dummy5,
+ IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5,
FLOAT *c, BLASLONG ldc){
#include <stdio.h>
#include "common.h"
-int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){
+int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){
BLASLONG i, j;
- FLOAT *a_offset, *a_offset1, *a_offset2;
- FLOAT *b_offset;
+ IFLOAT *a_offset, *a_offset1, *a_offset2;
+ IFLOAT *b_offset;
a_offset = a;
b_offset = b;
#include <stdio.h>
#include "common.h"
-int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b){
+int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){
BLASLONG i, j;
- FLOAT *a_offset, *a_offset1, *a_offset2;
- FLOAT *b_offset, *b_offset1, *b_offset2;
+ IFLOAT *a_offset, *a_offset1, *a_offset2;
+ IFLOAT *b_offset, *b_offset1, *b_offset2;
a_offset = a;
b_offset = b;
#include "common.h"
-int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FLOAT* C,BLASLONG ldc
+#if defined(HALF) && defined(HALFCONVERSION)
+/*
+ * Widen a bfloat16 value, held as a 16-bit integer, to a float.
+ *
+ * The 16 input bits become the most significant 16 bits of the float's
+ * representation and the low 16 bits are zero.
+ *
+ * The bits are moved with a shift on an integer and reinterpreted through a
+ * union. This avoids writing the float through an `unsigned short *`
+ * (a strict-aliasing violation) and needs no byte-order test, so it does not
+ * depend on the compiler defining __BYTE_ORDER__.
+ *
+ * Assumes `unsigned int` and `float` are both 32 bits wide and share the
+ * same byte order.
+ *
+ * NOTE(review): this helper has external linkage; if the file is compiled
+ * into more than one object of the same library the symbol would be
+ * duplicated -- consider `static` once it is confirmed nothing else
+ * references it.
+ */
+float
+bfloat16tof32 (bfloat16 f16)
+{
+  union { float f; unsigned int u; } cvt;
+
+  cvt.u = (unsigned int) f16 << 16;
+  return cvt.f;
+}
+#define BF16TOF32(x) (bfloat16tof32(x))
+#else
+/* No conversion needed: operands are used as they are. */
+#define BF16TOF32(x) (x)
+#endif
+int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,IFLOAT* ba,IFLOAT* bb,FLOAT* C,BLASLONG ldc
#ifdef TRMMKERNEL
,BLASLONG offset
#endif
)
{
BLASLONG i,j,k;
- FLOAT *C0,*C1,*ptrba,*ptrbb;
- FLOAT res0,res1,res2,res3,load0,load1,load2,load3,load4,load5,load6,load7;
+ FLOAT *C0,*C1;
+ IFLOAT *ptrba,*ptrbb;
+ FLOAT res0,res1,res2,res3;
+ IFLOAT load0,load1,load2,load3,load4,load5,load6,load7;
for (j=0; j<bn/2; j+=1)
{
C0 = C;
{
load0 = ptrba[2*0+0];
load1 = ptrbb[2*0+0];
- res0 = res0+load0*load1;
+ res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
load2 = ptrba[2*0+1];
- res1 = res1+load2*load1;
+ res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
load3 = ptrbb[2*0+1];
- res2 = res2+load0*load3;
- res3 = res3+load2*load3;
+ res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
+ res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
load4 = ptrba[2*1+0];
load5 = ptrbb[2*1+0];
- res0 = res0+load4*load5;
+ res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
load6 = ptrba[2*1+1];
- res1 = res1+load6*load5;
+ res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
load7 = ptrbb[2*1+1];
- res2 = res2+load4*load7;
- res3 = res3+load6*load7;
+ res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
+ res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
load0 = ptrba[2*2+0];
load1 = ptrbb[2*2+0];
- res0 = res0+load0*load1;
+ res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
load2 = ptrba[2*2+1];
- res1 = res1+load2*load1;
+ res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
load3 = ptrbb[2*2+1];
- res2 = res2+load0*load3;
- res3 = res3+load2*load3;
+ res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
+ res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
load4 = ptrba[2*3+0];
load5 = ptrbb[2*3+0];
- res0 = res0+load4*load5;
+ res0 = res0+BF16TOF32(load4)*BF16TOF32(load5);
load6 = ptrba[2*3+1];
- res1 = res1+load6*load5;
+ res1 = res1+BF16TOF32(load6)*BF16TOF32(load5);
load7 = ptrbb[2*3+1];
- res2 = res2+load4*load7;
- res3 = res3+load6*load7;
+ res2 = res2+BF16TOF32(load4)*BF16TOF32(load7);
+ res3 = res3+BF16TOF32(load6)*BF16TOF32(load7);
ptrba = ptrba+8;
ptrbb = ptrbb+8;
}
{
load0 = ptrba[2*0+0];
load1 = ptrbb[2*0+0];
- res0 = res0+load0*load1;
+ res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
load2 = ptrba[2*0+1];
- res1 = res1+load2*load1;
+ res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
load3 = ptrbb[2*0+1];
- res2 = res2+load0*load3;
- res3 = res3+load2*load3;
+ res2 = res2+BF16TOF32(load0)*BF16TOF32(load3);
+ res3 = res3+BF16TOF32(load2)*BF16TOF32(load3);
ptrba = ptrba+2;
ptrbb = ptrbb+2;
}
{
load0 = ptrba[0+0];
load1 = ptrbb[2*0+0];
- res0 = res0+load0*load1;
+ res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
load2 = ptrbb[2*0+1];
- res1 = res1+load0*load2;
+ res1 = res1+BF16TOF32(load0)*BF16TOF32(load2);
ptrba = ptrba+1;
ptrbb = ptrbb+2;
}
{
load0 = ptrba[2*0+0];
load1 = ptrbb[0+0];
- res0 = res0+load0*load1;
+ res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
load2 = ptrba[2*0+1];
- res1 = res1+load2*load1;
+ res1 = res1+BF16TOF32(load2)*BF16TOF32(load1);
ptrba = ptrba+2;
ptrbb = ptrbb+1;
}
{
load0 = ptrba[0+0];
load1 = ptrbb[0+0];
- res0 = res0+load0*load1;
+ res0 = res0+BF16TOF32(load0)*BF16TOF32(load1);
ptrba = ptrba+1;
ptrbb = ptrbb+1;
}
CTRMMKERNEL = cgemm_kernel_power9.S
ZTRMMKERNEL = zgemm_kernel_power9.S
+# SHGEMM (bfloat16 GEMM) sources for this target: everything comes from the
+# generic C implementations. The IN and ON copies both use gemm_ncopy_2.c and
+# the IT and OT copies both use gemm_tcopy_2.c; only the object names differ.
+SHGEMM_BETA = ../generic/gemm_beta.c
+SHGEMMKERNEL = ../generic/gemmkernel_2x2.c
+SHGEMMINCOPY = ../generic/gemm_ncopy_2.c
+SHGEMMITCOPY = ../generic/gemm_tcopy_2.c
+SHGEMMONCOPY = ../generic/gemm_ncopy_2.c
+SHGEMMOTCOPY = ../generic/gemm_tcopy_2.c
+SHGEMMINCOPYOBJ = shgemm_incopy$(TSUFFIX).$(SUFFIX)
+SHGEMMITCOPYOBJ = shgemm_itcopy$(TSUFFIX).$(SUFFIX)
+SHGEMMONCOPYOBJ = shgemm_oncopy$(TSUFFIX).$(SUFFIX)
+SHGEMMOTCOPYOBJ = shgemm_otcopy$(TSUFFIX).$(SUFFIX)
+
SGEMMKERNEL = sgemm_kernel_power9.S
SGEMMINCOPY = ../generic/gemm_ncopy_16.c
SGEMMITCOPY = sgemm_tcopy_16_power8.S
GEMM_DEFAULT_OFFSET_A, GEMM_DEFAULT_OFFSET_B, GEMM_DEFAULT_ALIGN,
0, 0, 0,
+ SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N,
+#ifdef SHGEMM_DEFAULT_UNROLL_MN
+ SHGEMM_DEFAULT_UNROLL_MN,
+#else
+ MAX(SHGEMM_DEFAULT_UNROLL_M, SHGEMM_DEFAULT_UNROLL_N),
+#endif
+ shgemm_kernelTS, shgemm_betaTS,
+#if SHGEMM_DEFAULT_UNROLL_M != SHGEMM_DEFAULT_UNROLL_N
+ shgemm_incopyTS, shgemm_itcopyTS,
+#else
+ shgemm_oncopyTS, shgemm_otcopyTS,
+#endif
+ shgemm_oncopyTS, shgemm_otcopyTS,
+ sgemm_kernelTS, sgemm_betaTS,
SGEMM_DEFAULT_UNROLL_M, SGEMM_DEFAULT_UNROLL_N,
#ifdef SGEMM_DEFAULT_UNROLL_MN
SGEMM_DEFAULT_UNROLL_MN,
#if defined(ARCH_ARM64)
static void init_parameter(void) {
+ TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P;
TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P;
TABLE_NAME.cgemm_p = CGEMM_DEFAULT_P;
TABLE_NAME.zgemm_p = ZGEMM_DEFAULT_P;
+ TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q;
TABLE_NAME.zgemm_q = ZGEMM_DEFAULT_Q;
+ TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R;
TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R;
TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R;
TABLE_NAME.cgemm_r = CGEMM_DEFAULT_R;
#if defined(ARCH_POWER)
static void init_parameter(void) {
+ TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P;
TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P;
TABLE_NAME.cgemm_p = CGEMM_DEFAULT_P;
TABLE_NAME.zgemm_p = ZGEMM_DEFAULT_P;
+ TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R;
TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R;
TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R;
TABLE_NAME.cgemm_r = CGEMM_DEFAULT_R;
TABLE_NAME.zgemm_r = ZGEMM_DEFAULT_R;
+ TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q;
#if defined(ARCH_ZARCH)
static void init_parameter(void) {
+ TABLE_NAME.shgemm_p = SHGEMM_DEFAULT_P;
TABLE_NAME.sgemm_p = SGEMM_DEFAULT_P;
TABLE_NAME.dgemm_p = DGEMM_DEFAULT_P;
TABLE_NAME.cgemm_p = CGEMM_DEFAULT_P;
TABLE_NAME.zgemm_p = ZGEMM_DEFAULT_P;
+ TABLE_NAME.shgemm_r = SHGEMM_DEFAULT_R;
TABLE_NAME.sgemm_r = SGEMM_DEFAULT_R;
TABLE_NAME.dgemm_r = DGEMM_DEFAULT_R;
TABLE_NAME.cgemm_r = CGEMM_DEFAULT_R;
TABLE_NAME.zgemm_r = ZGEMM_DEFAULT_R;
+ TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q;
(void) l2; /* dirty trick to suppress unused variable warning for targets */
/* where the GEMM unrolling parameters do not depend on l2 */
+ TABLE_NAME.shgemm_q = SHGEMM_DEFAULT_Q;
TABLE_NAME.sgemm_q = SGEMM_DEFAULT_Q;
TABLE_NAME.dgemm_q = DGEMM_DEFAULT_Q;
TABLE_NAME.cgemm_q = CGEMM_DEFAULT_Q;
+ TABLE_NAME.shgemm_p = ((TABLE_NAME.shgemm_p + SHGEMM_DEFAULT_UNROLL_M - 1)/SHGEMM_DEFAULT_UNROLL_M) * SHGEMM_DEFAULT_UNROLL_M;
TABLE_NAME.sgemm_p = ((TABLE_NAME.sgemm_p + SGEMM_DEFAULT_UNROLL_M - 1)/SGEMM_DEFAULT_UNROLL_M) * SGEMM_DEFAULT_UNROLL_M;
TABLE_NAME.dgemm_p = ((TABLE_NAME.dgemm_p + DGEMM_DEFAULT_UNROLL_M - 1)/DGEMM_DEFAULT_UNROLL_M) * DGEMM_DEFAULT_UNROLL_M;
TABLE_NAME.cgemm_p = ((TABLE_NAME.cgemm_p + CGEMM_DEFAULT_UNROLL_M - 1)/CGEMM_DEFAULT_UNROLL_M) * CGEMM_DEFAULT_UNROLL_M;
fprintf(stderr, "L2 = %8d DGEMM_P .. %d\n", l2, TABLE_NAME.dgemm_p);
#endif
+ TABLE_NAME.shgemm_r = (((BUFFER_SIZE -
+ ((TABLE_NAME.shgemm_p * TABLE_NAME.shgemm_q * 4 + TABLE_NAME.offsetA
+ + TABLE_NAME.align) & ~TABLE_NAME.align)
+ ) / (TABLE_NAME.shgemm_q * 4) - 15) & ~15);
+
TABLE_NAME.sgemm_r = (((BUFFER_SIZE -
((TABLE_NAME.sgemm_p * TABLE_NAME.sgemm_q * 4 + TABLE_NAME.offsetA
+ TABLE_NAME.align) & ~TABLE_NAME.align)
#elif defined(DOUBLE)
mode = BLAS_DOUBLE | BLAS_REAL;
mask = MAX(DGEMM_UNROLL_M, DGEMM_UNROLL_N) - 1;
+#elif defined(HALF)
+ mode = BLAS_HALF | BLAS_REAL;
+ mask = MAX(SHGEMM_UNROLL_M, SHGEMM_UNROLL_N) - 1;
#else
mode = BLAS_SINGLE | BLAS_REAL;
mask = MAX(SGEMM_UNROLL_M, SGEMM_UNROLL_N) - 1;
#ifndef PARAM_H
#define PARAM_H
+/* Default SHGEMM (bfloat16 GEMM) unroll factors and P/Q/R parameters.
+ * NOTE(review): these are defined unconditionally, ahead of the per-target
+ * sections (the first one visible here is OPTERON), so a target that wants
+ * different values would presumably have to #undef them first -- confirm. */
+#define SHGEMM_DEFAULT_UNROLL_N 4
+#define SHGEMM_DEFAULT_UNROLL_M 8
+#define SHGEMM_DEFAULT_UNROLL_MN 32
+#define SHGEMM_DEFAULT_P 256
+#define SHGEMM_DEFAULT_R 256
+#define SHGEMM_DEFAULT_Q 256
#ifdef OPTERON
#define SNUMOPT 4
--- /dev/null
+/***************************************************************************
+Copyright (c) 2020, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+#include <stdio.h>
+#include <stdint.h>
+#include "common.h"
+#define SGEMM BLASFUNC(sgemm)
+#define SHGEMM BLASFUNC(shgemm)
+/*
+ * Two views of one bfloat16-sized value.
+ *
+ *   v    - the whole value as a single unsigned short.
+ *   bits - the same storage split into bit-fields of width 1, 8 and 7,
+ *          named s, e and m (presumably sign, exponent and mantissa).
+ *
+ * The bit-fields are declared in opposite order depending on
+ * __BYTE_ORDER__. How bit-fields are actually laid out inside the storage
+ * unit is implementation-defined, so `bits` is only meaningful on compilers
+ * whose allocation order follows the byte order as assumed here.
+ */
+typedef union
+{
+ unsigned short v;
+ struct
+ {
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+ unsigned short s:1;
+ unsigned short e:8;
+ unsigned short m:7;
+#else
+ unsigned short m:7;
+ unsigned short e:8;
+ unsigned short s:1;
+#endif
+ } bits;
+} bfloat16_bits;
+
+/*
+ * Compare SGEMM against SHGEMM on square problems of size 1..20.
+ *
+ * For every size the float inputs A and B are filled with small values,
+ * the bfloat16 inputs AA and BB are made by keeping the upper 16 bits of
+ * each float, both routines are called with alpha = 1 and beta = 0, and the
+ * two float results are compared element by element for exact equality.
+ *
+ * The number of mismatching elements is printed to stderr.
+ * Returns 0 when every element matched and 1 otherwise.
+ */
+int
+main (int argc, char *argv[])
+{
+  /* Dimensions are passed by address to the BLAS interface, so they must
+     have the interface's integer type rather than plain int. */
+  blasint m, n, k;
+  blasint i, j;
+  int ret = 0;
+  int loop = 20;
+  char transA = 'N', transB = 'N';
+  float alpha = 1.0, beta = 0.0;
+  /* Used to read a float's bit pattern without an aliasing-violating cast. */
+  union { float f; uint32_t u; } pun;
+
+  (void) argc;
+  (void) argv;
+
+  /* Start at 1: size 0 would create zero-length variable-length arrays and
+     pass a leading dimension of 0 to GEMM, both of which are invalid. */
+  for (int x = 1; x <= loop; x++)
+    {
+      m = k = n = x;
+      float A[m * k];
+      float B[k * n];
+      float C[m * n];
+      bfloat16_bits AA[m * k], BB[k * n];
+      float CC[m * n];
+
+      for (j = 0; j < m; j++)
+        {
+          for (i = 0; i < m; i++)
+            {
+              A[j * k + i] = j * 9.0;
+              B[j * k + i] = i * 2.0;
+              C[j * k + i] = 0;
+              /* bfloat16 input = upper 16 bits of the float's bit pattern */
+              pun.f = A[j * k + i];
+              AA[j * k + i].v = pun.u >> 16;
+              pun.f = B[j * k + i];
+              BB[j * k + i].v = pun.u >> 16;
+              CC[j * k + i] = 0;
+            }
+        }
+      SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
+             &m, B, &k, &beta, C, &m);
+      SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA,
+              &m, BB, &k, &beta, CC, &m);
+
+      /* Count each differing output element exactly once. */
+      for (i = 0; i < n; i++)
+        for (j = 0; j < m; j++)
+          if (CC[i * m + j] != C[i * m + j])
+            ret++;
+    }
+  fprintf (stderr, "Return code: %d\n", ret);
+  /* Collapse the count to 0/1: an exit status only keeps the low 8 bits,
+     so a raw count that is a multiple of 256 would read as success. */
+  return ret != 0;
+}