[WIP] Refactor the driver code for direct SGEMM (#2782)
authorMartin Kroeker <martin@ruby.chemie.uni-freiburg.de>
Wed, 19 Aug 2020 12:51:09 +0000 (14:51 +0200)
committerGitHub <noreply@github.com>
Wed, 19 Aug 2020 12:51:09 +0000 (14:51 +0200)
Move "direct SGEMM" functionality out of the SkylakeX SGEMM kernel and make it available
(on x86_64 targets only for now) in DYNAMIC_ARCH builds
* Add  sgemm_direct targets in the kernel Makefile.L3 and CMakeLists.txt
* Add direct_sgemm functions to the gotoblas struct in common_param.h
* Move sgemm_direct_performant helper to separate file
* Update gemm.c  to macros for sgemm_direct to support dynamic_arch naming via common_s,h
* (Conditionally) add sgemm_direct functions in setparam-ref.c

common_level3.h
common_param.h
common_s.h
interface/gemm.c
kernel/CMakeLists.txt
kernel/Makefile.L3
kernel/setparam-ref.c
kernel/x86_64/sgemm_direct_performant.c [new file with mode: 0644]
kernel/x86_64/sgemm_direct_skylakex.c
kernel/x86_64/sgemm_kernel_16x4_skylakex_3.c

index 4e44a5e..671a7a0 100644 (file)
@@ -47,12 +47,12 @@ __global__ void cuda_dgemm_kernel(int, int, int, double *, double *, double *);
 extern "C" {
 #endif
 
-extern void sgemm_kernel_direct(BLASLONG M, BLASLONG N, BLASLONG K,
+void sgemm_direct(BLASLONG M, BLASLONG N, BLASLONG K,
        float * A, BLASLONG strideA,
        float * B, BLASLONG strideB,
        float * R, BLASLONG strideR);
 
-extern int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
+int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
 
 
 int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
index c92609a..0437482 100644 (file)
@@ -175,6 +175,11 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
   int    (*ssymv_L) (BLASLONG, BLASLONG, float,  float  *, BLASLONG, float  *, BLASLONG, float  *, BLASLONG, float *);
   int    (*ssymv_U) (BLASLONG, BLASLONG, float,  float  *, BLASLONG, float  *, BLASLONG, float  *, BLASLONG, float *);
 
+#ifdef ARCH_X86_64
+  void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG);
+  int  (*sgemm_direct_performant) (BLASLONG M, BLASLONG N, BLASLONG K);
+#endif
+  
   int    (*sgemm_kernel   )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
   int    (*sgemm_beta     )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float  *, BLASLONG);
 
index 23c432f..34903ec 100644 (file)
 #define SSYMV_THREAD_U         ssymv_thread_U
 #define SSYMV_THREAD_L         ssymv_thread_L
 
+
+#define SGEMM_DIRECT_PERFORMANT    sgemm_direct_performant
+#define SGEMM_DIRECT           sgemm_direct
+
 #define        SGEMM_ONCOPY            sgemm_oncopy
 #define        SGEMM_OTCOPY            sgemm_otcopy
 
 #define SSYMV_THREAD_U         ssymv_thread_U
 #define SSYMV_THREAD_L         ssymv_thread_L
 
+#ifdef ARCH_X86_64
+#define SGEMM_DIRECT_PERFORMANT gotoblas -> sgemm_direct_performant
+#define  SGEMM_DIRECT          gotoblas -> sgemm_direct
+#else
+#define SGEMM_DIRECT_PERFORMANT    sgemm_direct_performant
+#define  SGEMM_DIRECT          sgemm_direct
+#endif
+
 #define        SGEMM_ONCOPY            gotoblas -> sgemm_oncopy
 #define        SGEMM_OTCOPY            gotoblas -> sgemm_otcopy
 #define        SGEMM_INCOPY            gotoblas -> sgemm_incopy
index 99388e7..860e588 100644 (file)
@@ -275,8 +275,8 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
 #ifdef DYNAMIC_ARCH
  if (support_avx512() )
 #endif  
-  if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && sgemm_kernel_direct_performant(m,n,k)) {
-       sgemm_kernel_direct(m, n, k, a, lda, b, ldb, c, ldc);
+  if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && SGEMM_DIRECT_PERFORMANT(m,n,k)) {
+       SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
        return;
   }
 
index d1349c5..d9fba6a 100644 (file)
@@ -134,6 +134,20 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
       set(USE_TRMM true)
     endif ()
 
+    set(USE_DIRECT_SGEMM false)
+    if (X86_64)
+       set(USE_DIRECT_SGEMM true)
+    endif()
+
+    if (USE_DIRECT_SGEMM)
+           #       if (NOT DEFINED SGEMMDIRECTKERNEL)
+         set (SGEMMDIRECTKERNEL sgemm_direct_skylakex.c)
+         set (SGEMMDIRECTPERFORMANT sgemm_direct_performant.c)
+         # endif()
+         GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTKERNEL}" "" "gemm_direct" false "" "" false SINGLE)
+         GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPERFORMANT}" "" "gemm_direct_performant" false "" "" false  SINGLE)
+    endif()
+
     foreach (float_type SINGLE DOUBLE HALF)
       string(SUBSTRING ${float_type} 0 1 float_char)
       if (${float_type} STREQUAL "HALF")
index 8df306d..a176b47 100644 (file)
@@ -9,6 +9,10 @@ ifeq ($(ARCH), x86_64)
 USE_GEMM3M = 1
 endif
 
+ifeq ($(ARCH), x86_64)
+USE_DIRECT_SGEMM = 1
+endif
+
 ifeq ($(ARCH), ia64)
 USE_GEMM3M = 1
 endif
@@ -65,6 +69,13 @@ ifeq ($(CORE), Z14)
 USE_TRMM = 1
 endif
 
+ifdef USE_DIRECT_SGEMM
+ifndef SGEMMDIRECTKERNEL
+SGEMMDIRECTKERNEL = sgemm_direct_skylakex.c
+SGEMMDIRECTPERFORMANT = sgemm_direct_performant.c
+endif
+endif
+
 ifeq ($(BUILD_HALF), 1)
 ifndef SHGEMMKERNEL
 SHGEMM_BETA = ../generic/gemm_beta.c
@@ -90,6 +101,12 @@ SKERNELOBJS += \
        $(SGEMMINCOPYOBJ) $(SGEMMITCOPYOBJ) \
        $(SGEMMONCOPYOBJ) $(SGEMMOTCOPYOBJ)
 
+ifdef USE_DIRECT_SGEMM
+SKERNELOBJS += \
+       sgemm_direct$(TSUFFIX).$(SUFFIX) \
+       sgemm_direct_performant$(TSUFFIX).$(SUFFIX) 
+endif
+
 DKERNELOBJS    += \
        dgemm_kernel$(TSUFFIX).$(SUFFIX) \
        $(DGEMMINCOPYOBJ) $(DGEMMITCOPYOBJ) \
@@ -668,6 +685,13 @@ else
        $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
 endif
 
+ifdef USE_DIRECT_SGEMM
+$(KDIR)sgemm_direct_performant$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTPERFORMANT)
+       $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
+$(KDIR)sgemm_direct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL)
+       $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
+endif
+
 ifeq ($(BUILD_HALF), 1)
 
 $(KDIR)shgemm_kernel$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SHGEMMKERNEL) $(SHGEMMDEPEND)
index d3aa030..d384500 100644 (file)
@@ -135,6 +135,11 @@ gotoblas_t TABLE_NAME = {
   sgemv_nTS,  sgemv_tTS, sger_kTS,
   ssymv_LTS, ssymv_UTS,
 
+#ifdef ARCH_X86_64
+  sgemm_directTS,
+  sgemm_direct_performantTS,   
+#endif
+       
   sgemm_kernelTS, sgemm_betaTS,
 #if SGEMM_DEFAULT_UNROLL_M != SGEMM_DEFAULT_UNROLL_N
   sgemm_incopyTS, sgemm_itcopyTS,
diff --git a/kernel/x86_64/sgemm_direct_performant.c b/kernel/x86_64/sgemm_direct_performant.c
new file mode 100644 (file)
index 0000000..5a20ce3
--- /dev/null
@@ -0,0 +1,30 @@
+#include "common.h"
+/* helper for the direct sgemm code written by Arjan van der Ven */
+
+
+
+
+int CNAME(BLASLONG M, BLASLONG N, BLASLONG K)
+{
+       unsigned long long mnk = M * N * K;
+       /* large matrixes -> not performant */
+       if (mnk >= 28 * 512 * 512)
+               return 0;
+
+       /*
+        * if the B matrix is not a nice multiple if 4 we get many unaligned accesses,
+        * and the regular sgemm copy/realignment of data pays off much quicker
+        */
+       if ((N & 3) != 0 && (mnk >= 8 * 512 * 512))
+               return 0;
+
+#ifdef SMP
+       /* if we can run multithreaded, the threading changes the based threshold */
+       if (mnk > 2 * 350 * 512 && num_cpu_avail(3)> 1)
+               return 0;
+#endif
+
+       return 1;
+}
+
+
index 0e8f131..a7cddbb 100644 (file)
@@ -1,7 +1,7 @@
-
+#if defined(SKYLAKEX) || defined (COOPERLAKE)
 /* the direct sgemm code written by Arjan van der Ven */
-//#include <immintrin.h>
-
+#include <immintrin.h>
+#include "common.h"
 /*
  * "Direct sgemm" code. This code operates directly on the inputs and outputs
  * of the sgemm call, avoiding the copies, memory realignments and threading,
@@ -38,6 +38,7 @@
 #define MATMUL_SCALAR(N,M) result##N##M +=  Aval##M * Bval##N;
 #define STORE_SCALAR(N,M)  R[(i+M) * strideR + j + N] = result##N##M;
 
+#if 0
 int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K)
 {
        unsigned long long mnk = M * N * K;
@@ -61,9 +62,10 @@ int sgemm_kernel_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K)
        return 1;
 }
 
+#endif
 
-
-void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A, BLASLONG strideA, float * __restrict B, BLASLONG strideB , float * __restrict R, BLASLONG strideR)
+//void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A, BLASLONG strideA, float * __restrict B, BLASLONG strideB , float * __restrict R, BLASLONG strideR)
+void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A, BLASLONG strideA, float * __restrict B, BLASLONG strideB , float * __restrict R, BLASLONG strideR)
 {
        int i, j, k;
 
@@ -465,3 +467,8 @@ void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict
                }
        }
 }
+#else
+#include "common.h"
+void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A, BLASLONG strideA, float * __restrict B, BLASLONG strideB , float * __restrict R, BLASLONG strideR)
+{}
+#endif
index 3b1af33..f3d6142 100644 (file)
@@ -512,4 +512,4 @@ CNAME(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float * __restrict__ A, f
     return 0;
 }
 #include <immintrin.h>
-#include "sgemm_direct_skylakex.c"
+//#include "sgemm_direct_skylakex.c"