1 /*********************************************************************/
2 /* Copyright 2009, 2010 The University of Texas at Austin. */
3 /* All rights reserved. */
5 /* Redistribution and use in source and binary forms, with or */
6 /* without modification, are permitted provided that the following */
7 /* conditions are met: */
9 /* 1. Redistributions of source code must retain the above */
10 /* copyright notice, this list of conditions and the following */
13 /* 2. Redistributions in binary form must reproduce the above */
14 /* copyright notice, this list of conditions and the following */
15 /* disclaimer in the documentation and/or other materials */
16 /* provided with the distribution. */
18 /* THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF TEXAS AT */
19 /* AUSTIN ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, */
20 /* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF */
21 /* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE */
22 /* DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT */
23 /* AUSTIN OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, */
24 /* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES */
25 /* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE */
26 /* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR */
27 /* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF */
28 /* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT */
29 /* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT */
30 /* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE */
31 /* POSSIBILITY OF SUCH DAMAGE. */
33 /* The views and conclusions contained in the software and */
34 /* documentation are those of the authors and should not be */
35 /* interpreted as representing official policies, either expressed */
36 /* or implied, of The University of Texas at Austin. */
37 /*********************************************************************/
42 #ifdef FUNCTION_PROFILE
43 #include "functable.h"
/*
 * Build-time constants for this GEMM interface.
 *
 * ERROR_NAME is the routine name handed to xerbla(), together with the
 * index of the offending argument and sizeof(ERROR_NAME), when argument
 * validation fails.  Several alternatives are listed below.
 * NOTE(review): presumably exactly one is selected by the precision and
 * variant macros of the build (only the BFLOAT16 selector is visible
 * here) -- confirm.
 *
 * SMP_THRESHOLD_MIN is multiplied by GEMM_MULTITHREAD_THRESHOLD and the
 * product is compared against MNK = m * n * k (computed in double) right
 * next to where args.nthreads is assigned.
 * NOTE(review): presumably problems at or below that product stay
 * single-threaded -- the branch body is not visible here, confirm.
 */
47 #define SMP_THRESHOLD_MIN 65536.0
49 #define ERROR_NAME "QGEMM "
51 #define ERROR_NAME "DGEMM "
52 #elif defined(BFLOAT16)
53 #define ERROR_NAME "SBGEMM "
55 #define ERROR_NAME "SGEMM "
/*
 * Second SMP_THRESHOLD_MIN value.
 * NOTE(review): presumably the complex build, given the X/Z/C routine
 * names that follow -- confirm against the enclosing conditional.
 */
58 #define SMP_THRESHOLD_MIN 8192.0
61 #define ERROR_NAME "XGEMM "
63 #define ERROR_NAME "ZGEMM "
65 #define ERROR_NAME "CGEMM "
/* Routine names reported to xerbla() for the GEMM3M entry points. */
69 #define ERROR_NAME "XGEMM3M "
71 #define ERROR_NAME "ZGEMM3M "
73 #define ERROR_NAME "CGEMM3M "
/*
 * Default multiplier applied to SMP_THRESHOLD_MIN; used only when the
 * build has not already defined GEMM_MULTITHREAD_THRESHOLD.
 */
78 #ifndef GEMM_MULTITHREAD_THRESHOLD
79 #define GEMM_MULTITHREAD_THRESHOLD 4
/*
 * Dispatch table of GEMM drivers.
 *
 * The interface code indexes this table with (transb << 2) | transa for
 * the first sixteen entries and with 16 | (transb << 2) | transa for the
 * threaded entries.  Each row of four below therefore shares one transb
 * value, and the position inside the row is transa.  In the two-letter
 * suffix the first letter belongs to A and the second to B (for example
 * GEMM_TN sits at index 1: transa = 1, transb = 0).
 *
 * The interface maps N to 0 and T to 1; R and C map either to 0 and 1 or
 * to 2 and 3, depending on a build conditional.
 *
 * Entries are either called directly as f(&args, NULL, NULL, sa, sb, 0)
 * or handed to gemm_thread_mn() / GEMM_THREAD() as the worker routine.
 *
 * The order of the entries is load-bearing: do not reorder them without
 * changing the index computation in the callers.
 */
82 static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = {
84 GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN,
85 GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT,
86 GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR,
87 GEMM_NC, GEMM_TC, GEMM_RC, GEMM_CC,
/*
 * Threaded drivers (indices 16..31): present only in SMP builds that do
 * not define USE_SIMPLE_THREADED_LEVEL3.
 */
88 #if defined(SMP) && !defined(USE_SIMPLE_THREADED_LEVEL3)
89 GEMM_THREAD_NN, GEMM_THREAD_TN, GEMM_THREAD_RN, GEMM_THREAD_CN,
90 GEMM_THREAD_NT, GEMM_THREAD_TT, GEMM_THREAD_RT, GEMM_THREAD_CT,
91 GEMM_THREAD_NR, GEMM_THREAD_TR, GEMM_THREAD_RR, GEMM_THREAD_CR,
92 GEMM_THREAD_NC, GEMM_THREAD_TC, GEMM_THREAD_RC, GEMM_THREAD_CC,
/*
 * GEMM3M drivers, laid out exactly like the entries above.
 * NOTE(review): presumably these replace the GEMM entries under a
 * conditional that is not visible here -- confirm.
 */
95 GEMM3M_NN, GEMM3M_TN, GEMM3M_RN, GEMM3M_CN,
96 GEMM3M_NT, GEMM3M_TT, GEMM3M_RT, GEMM3M_CT,
97 GEMM3M_NR, GEMM3M_TR, GEMM3M_RR, GEMM3M_CR,
98 GEMM3M_NC, GEMM3M_TC, GEMM3M_RC, GEMM3M_CC,
/* Threaded GEMM3M drivers, under the same SMP condition as above. */
99 #if defined(SMP) && !defined(USE_SIMPLE_THREADED_LEVEL3)
100 GEMM3M_THREAD_NN, GEMM3M_THREAD_TN, GEMM3M_THREAD_RN, GEMM3M_THREAD_CN,
101 GEMM3M_THREAD_NT, GEMM3M_THREAD_TT, GEMM3M_THREAD_RT, GEMM3M_THREAD_CT,
102 GEMM3M_THREAD_NR, GEMM3M_THREAD_TR, GEMM3M_THREAD_RR, GEMM3M_THREAD_CR,
103 GEMM3M_THREAD_NC, GEMM3M_THREAD_TC, GEMM3M_THREAD_RC, GEMM3M_THREAD_CC,
108 #ifdef SMALL_MATRIX_OPT
109 // Only s/dgemm small-matrix optimization is supported so far.
/*
 * General small-matrix kernels, indexed by (transb << 2) | transa.
 *
 * The interface consults this table in its !defined(COMPLEX) section
 * when MNK = m * n * k is at most 100.0 * 100.0 * 100.0, and calls the
 * selected entry as
 *     f(m, n, k, a, lda, alpha, b, ldb, beta, c, ldc)
 * with alpha and beta passed by value.
 *
 * Two rows are listed, covering transb values 0 and 1.  NULL marks the
 * transa values 2 and 3, for which no kernel is listed.
 */
110 static int (*gemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG) = {
113 GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, NULL, NULL,
114 GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, NULL, NULL,
/*
 * Small-matrix kernels for the case alpha == 1.0 and beta == 0.0, with
 * the same indexing and layout as gemm_small_kernel.  The interface
 * picks this table when both conditions hold and calls the entry as
 *     f(m, n, k, a, lda, b, ldb, c, ldc)
 * so these kernels take no alpha or beta argument.
 */
119 static int (*gemm_small_kernel_a1b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = {
122 GEMM_SMALL_KERNEL_A1B0_NN, GEMM_SMALL_KERNEL_A1B0_TN, NULL, NULL,
123 GEMM_SMALL_KERNEL_A1B0_NT, GEMM_SMALL_KERNEL_A1B0_TT, NULL, NULL,
131 void NAME(char *TRANSA, char *TRANSB,
132 blasint *M, blasint *N, blasint *K,
134 IFLOAT *a, blasint *ldA,
135 IFLOAT *b, blasint *ldB,
137 FLOAT *c, blasint *ldC){
141 int transa, transb, nrowa, nrowb;
148 #if defined (SMP) || defined(SMALL_MATRIX_OPT)
150 #if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY)
153 int mode = BLAS_XDOUBLE | BLAS_REAL;
154 #elif defined(DOUBLE)
155 int mode = BLAS_DOUBLE | BLAS_REAL;
157 int mode = BLAS_SINGLE | BLAS_REAL;
161 int mode = BLAS_XDOUBLE | BLAS_COMPLEX;
162 #elif defined(DOUBLE)
163 int mode = BLAS_DOUBLE | BLAS_COMPLEX;
165 int mode = BLAS_SINGLE | BLAS_COMPLEX;
171 #if defined(SMP) && !defined(NO_AFFINITY) && !defined(USE_SIMPLE_THREADED_LEVEL3)
189 args.alpha = (void *)alpha;
190 args.beta = (void *)beta;
201 if (transA == 'N') transa = 0;
202 if (transA == 'T') transa = 1;
204 if (transA == 'R') transa = 0;
205 if (transA == 'C') transa = 1;
207 if (transA == 'R') transa = 2;
208 if (transA == 'C') transa = 3;
211 if (transB == 'N') transb = 0;
212 if (transB == 'T') transb = 1;
214 if (transB == 'R') transb = 0;
215 if (transB == 'C') transb = 1;
217 if (transB == 'R') transb = 2;
218 if (transB == 'C') transb = 3;
222 if (transa & 1) nrowa = args.k;
224 if (transb & 1) nrowb = args.n;
228 if (args.ldc < args.m) info = 13;
229 if (args.ldb < nrowb) info = 10;
230 if (args.lda < nrowa) info = 8;
231 if (args.k < 0) info = 5;
232 if (args.n < 0) info = 4;
233 if (args.m < 0) info = 3;
234 if (transb < 0) info = 2;
235 if (transa < 0) info = 1;
238 BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
244 void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB,
245 blasint m, blasint n, blasint k,
248 FLOAT *a, blasint lda,
249 FLOAT *b, blasint ldb,
251 FLOAT *c, blasint ldc) {
254 void *va, blasint lda,
255 void *vb, blasint ldb,
257 void *vc, blasint ldc) {
258 FLOAT *alpha = (FLOAT*) valpha;
259 FLOAT *beta = (FLOAT*) vbeta;
260 FLOAT *a = (FLOAT*) va;
261 FLOAT *b = (FLOAT*) vb;
262 FLOAT *c = (FLOAT*) vc;
267 blasint nrowa, nrowb, info;
272 #if defined (SMP) || defined(SMALL_MATRIX_OPT)
277 #if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY)
280 int mode = BLAS_XDOUBLE | BLAS_REAL;
281 #elif defined(DOUBLE)
282 int mode = BLAS_DOUBLE | BLAS_REAL;
284 int mode = BLAS_SINGLE | BLAS_REAL;
288 int mode = BLAS_XDOUBLE | BLAS_COMPLEX;
289 #elif defined(DOUBLE)
290 int mode = BLAS_DOUBLE | BLAS_COMPLEX;
292 int mode = BLAS_SINGLE | BLAS_COMPLEX;
298 #if defined(SMP) && !defined(NO_AFFINITY) && !defined(USE_SIMPLE_THREADED_LEVEL3)
304 #if !defined(COMPLEX) && !defined(DOUBLE) && defined(USE_SGEMM_KERNEL_DIRECT)
306 if (support_avx512() )
308 if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && SGEMM_DIRECT_PERFORMANT(m,n,k)) {
309 SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
316 args.alpha = (void *)α
317 args.beta = (void *)β
319 args.alpha = (void *)alpha;
320 args.beta = (void *)beta;
327 if (order == CblasColMajor) {
340 if (TransA == CblasNoTrans) transa = 0;
341 if (TransA == CblasTrans) transa = 1;
343 if (TransA == CblasConjNoTrans) transa = 0;
344 if (TransA == CblasConjTrans) transa = 1;
346 if (TransA == CblasConjNoTrans) transa = 2;
347 if (TransA == CblasConjTrans) transa = 3;
349 if (TransB == CblasNoTrans) transb = 0;
350 if (TransB == CblasTrans) transb = 1;
352 if (TransB == CblasConjNoTrans) transb = 0;
353 if (TransB == CblasConjTrans) transb = 1;
355 if (TransB == CblasConjNoTrans) transb = 2;
356 if (TransB == CblasConjTrans) transb = 3;
360 if (transa & 1) nrowa = args.k;
362 if (transb & 1) nrowb = args.n;
366 if (args.ldc < args.m) info = 13;
367 if (args.ldb < nrowb) info = 10;
368 if (args.lda < nrowa) info = 8;
369 if (args.k < 0) info = 5;
370 if (args.n < 0) info = 4;
371 if (args.m < 0) info = 3;
372 if (transb < 0) info = 2;
373 if (transa < 0) info = 1;
376 if (order == CblasRowMajor) {
389 if (TransB == CblasNoTrans) transa = 0;
390 if (TransB == CblasTrans) transa = 1;
392 if (TransB == CblasConjNoTrans) transa = 0;
393 if (TransB == CblasConjTrans) transa = 1;
395 if (TransB == CblasConjNoTrans) transa = 2;
396 if (TransB == CblasConjTrans) transa = 3;
398 if (TransA == CblasNoTrans) transb = 0;
399 if (TransA == CblasTrans) transb = 1;
401 if (TransA == CblasConjNoTrans) transb = 0;
402 if (TransA == CblasConjTrans) transb = 1;
404 if (TransA == CblasConjNoTrans) transb = 2;
405 if (TransA == CblasConjTrans) transb = 3;
409 if (transa & 1) nrowa = args.k;
411 if (transb & 1) nrowb = args.n;
415 if (args.ldc < args.m) info = 13;
416 if (args.ldb < nrowb) info = 10;
417 if (args.lda < nrowa) info = 8;
418 if (args.k < 0) info = 5;
419 if (args.n < 0) info = 4;
420 if (args.m < 0) info = 3;
421 if (transb < 0) info = 2;
422 if (transa < 0) info = 1;
427 BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
433 if ((args.m == 0) || (args.n == 0)) return;
436 fprintf(stderr, "m = %4d n = %d k = %d lda = %4d ldb = %4d ldc = %4d\n",
437 args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
442 FUNCTION_PROFILE_START();
444 #if defined(SMP) || defined(SMALL_MATRIX_OPT)
445 MNK = (double) args.m * (double) args.n * (double) args.k;
448 #ifdef SMALL_MATRIX_OPT
449 #if !defined(COMPLEX)
450 //need to tune small matrices cases.
451 if(MNK <= 100.0*100.0*100.0){
453 if(*(FLOAT *)(args.alpha) == 1.0 && *(FLOAT *)(args.beta) == 0.0){
454 (gemm_small_kernel_a1b0[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda,args.b, args.ldb, args.c, args.ldc);
456 (gemm_small_kernel[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc);
465 buffer = (XFLOAT *)blas_memory_alloc(0);
467 sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A);
468 sb = (XFLOAT *)(((BLASLONG)sa + ((GEMM_P * GEMM_Q * COMPSIZE * SIZE + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
471 #if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY)
472 mode |= (transa << BLAS_TRANSA_SHIFT);
473 mode |= (transb << BLAS_TRANSB_SHIFT);
477 if ( MNK <= (SMP_THRESHOLD_MIN * (double) GEMM_MULTITHREAD_THRESHOLD) )
480 args.nthreads = num_cpu_avail(3);
483 if (args.nthreads == 1) {
486 (gemm[(transb << 2) | transa])(&args, NULL, NULL, sa, sb, 0);
492 #ifndef USE_SIMPLE_THREADED_LEVEL3
495 nodes = get_num_nodes();
497 if ((nodes > 1) && get_node_equal()) {
499 args.nthreads /= nodes;
501 gemm_thread_mn(mode, &args, NULL, NULL, gemm[16 | (transb << 2) | transa], sa, sb, nodes);
506 (gemm[16 | (transb << 2) | transa])(&args, NULL, NULL, sa, sb, 0);
510 GEMM_THREAD(mode, &args, NULL, NULL, gemm[(transb << 2) | transa], sa, sb, args.nthreads);
514 #ifndef USE_SIMPLE_THREADED_LEVEL3
526 blas_memory_free(buffer);
528 FUNCTION_PROFILE_END(COMPSIZE * COMPSIZE, args.m * args.k + args.k * args.n + args.m * args.n, 2 * args.m * args.n * args.k);