1 /*********************************************************************/
2 /* Copyright 2009, 2010 The University of Texas at Austin. */
3 /* All rights reserved. */
5 /* Redistribution and use in source and binary forms, with or */
6 /* without modification, are permitted provided that the following */
7 /* conditions are met: */
9 /* 1. Redistributions of source code must retain the above */
10 /* copyright notice, this list of conditions and the following */
13 /* 2. Redistributions in binary form must reproduce the above */
14 /* copyright notice, this list of conditions and the following */
15 /* disclaimer in the documentation and/or other materials */
16 /* provided with the distribution. */
18 /* THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF TEXAS AT */
19 /* AUSTIN ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, */
20 /* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF */
21 /* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE */
22 /* DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT */
23 /* AUSTIN OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, */
24 /* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES */
25 /* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE */
26 /* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR */
27 /* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF */
28 /* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT */
29 /* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT */
30 /* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE */
31 /* POSSIBILITY OF SUCH DAMAGE. */
33 /* The views and conclusions contained in the software and */
34 /* documentation are those of the authors and should not be */
35 /* interpreted as representing official policies, either expressed */
36 /* or implied, of The University of Texas at Austin. */
37 /*********************************************************************/
42 #ifdef FUNCTION_PROFILE
43 #include "functable.h"
47 #define SMP_THRESHOLD_MIN 65536.0
49 #define ERROR_NAME "QGEMM "
51 #define ERROR_NAME "DGEMM "
53 #define ERROR_NAME "SGEMM "
56 #define SMP_THRESHOLD_MIN 8192.0
59 #define ERROR_NAME "XGEMM "
61 #define ERROR_NAME "ZGEMM "
63 #define ERROR_NAME "CGEMM "
67 #define ERROR_NAME "XGEMM3M "
69 #define ERROR_NAME "ZGEMM3M "
71 #define ERROR_NAME "CGEMM3M "
76 #ifndef GEMM_MULTITHREAD_THRESHOLD
77 #define GEMM_MULTITHREAD_THRESHOLD 4
80 static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLASLONG) = {
82 GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN,
83 GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT,
84 GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR,
85 GEMM_NC, GEMM_TC, GEMM_RC, GEMM_CC,
86 #if defined(SMP) && !defined(USE_SIMPLE_THREADED_LEVEL3)
87 GEMM_THREAD_NN, GEMM_THREAD_TN, GEMM_THREAD_RN, GEMM_THREAD_CN,
88 GEMM_THREAD_NT, GEMM_THREAD_TT, GEMM_THREAD_RT, GEMM_THREAD_CT,
89 GEMM_THREAD_NR, GEMM_THREAD_TR, GEMM_THREAD_RR, GEMM_THREAD_CR,
90 GEMM_THREAD_NC, GEMM_THREAD_TC, GEMM_THREAD_RC, GEMM_THREAD_CC,
93 GEMM3M_NN, GEMM3M_TN, GEMM3M_RN, GEMM3M_CN,
94 GEMM3M_NT, GEMM3M_TT, GEMM3M_RT, GEMM3M_CT,
95 GEMM3M_NR, GEMM3M_TR, GEMM3M_RR, GEMM3M_CR,
96 GEMM3M_NC, GEMM3M_TC, GEMM3M_RC, GEMM3M_CC,
97 #if defined(SMP) && !defined(USE_SIMPLE_THREADED_LEVEL3)
98 GEMM3M_THREAD_NN, GEMM3M_THREAD_TN, GEMM3M_THREAD_RN, GEMM3M_THREAD_CN,
99 GEMM3M_THREAD_NT, GEMM3M_THREAD_TT, GEMM3M_THREAD_RT, GEMM3M_THREAD_CT,
100 GEMM3M_THREAD_NR, GEMM3M_THREAD_TR, GEMM3M_THREAD_RR, GEMM3M_THREAD_CR,
101 GEMM3M_THREAD_NC, GEMM3M_THREAD_TC, GEMM3M_THREAD_RC, GEMM3M_THREAD_CC,
108 void NAME(char *TRANSA, char *TRANSB,
109 blasint *M, blasint *N, blasint *K,
111 FLOAT *a, blasint *ldA,
112 FLOAT *b, blasint *ldB,
114 FLOAT *c, blasint *ldC){
118 int transa, transb, nrowa, nrowb;
129 int mode = BLAS_XDOUBLE | BLAS_REAL;
130 #elif defined(DOUBLE)
131 int mode = BLAS_DOUBLE | BLAS_REAL;
133 int mode = BLAS_SINGLE | BLAS_REAL;
137 int mode = BLAS_XDOUBLE | BLAS_COMPLEX;
138 #elif defined(DOUBLE)
139 int mode = BLAS_DOUBLE | BLAS_COMPLEX;
141 int mode = BLAS_SINGLE | BLAS_COMPLEX;
146 #if defined(SMP) && !defined(NO_AFFINITY) && !defined(USE_SIMPLE_THREADED_LEVEL3)
164 args.alpha = (void *)alpha;
165 args.beta = (void *)beta;
176 if (transA == 'N') transa = 0;
177 if (transA == 'T') transa = 1;
179 if (transA == 'R') transa = 0;
180 if (transA == 'C') transa = 1;
182 if (transA == 'R') transa = 2;
183 if (transA == 'C') transa = 3;
186 if (transB == 'N') transb = 0;
187 if (transB == 'T') transb = 1;
189 if (transB == 'R') transb = 0;
190 if (transB == 'C') transb = 1;
192 if (transB == 'R') transb = 2;
193 if (transB == 'C') transb = 3;
197 if (transa & 1) nrowa = args.k;
199 if (transb & 1) nrowb = args.n;
203 if (args.ldc < args.m) info = 13;
204 if (args.ldb < nrowb) info = 10;
205 if (args.lda < nrowa) info = 8;
206 if (args.k < 0) info = 5;
207 if (args.n < 0) info = 4;
208 if (args.m < 0) info = 3;
209 if (transb < 0) info = 2;
210 if (transa < 0) info = 1;
213 BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
219 void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB,
220 blasint m, blasint n, blasint k,
223 FLOAT *a, blasint lda,
224 FLOAT *b, blasint ldb,
226 FLOAT *c, blasint ldc) {
229 void *va, blasint lda,
230 void *vb, blasint ldb,
232 void *vc, blasint ldc) {
233 FLOAT *alpha = (FLOAT*) valpha;
234 FLOAT *beta = (FLOAT*) vbeta;
235 FLOAT *a = (FLOAT*) va;
236 FLOAT *b = (FLOAT*) vb;
237 FLOAT *c = (FLOAT*) vc;
242 blasint nrowa, nrowb, info;
251 int mode = BLAS_XDOUBLE | BLAS_REAL;
252 #elif defined(DOUBLE)
253 int mode = BLAS_DOUBLE | BLAS_REAL;
255 int mode = BLAS_SINGLE | BLAS_REAL;
259 int mode = BLAS_XDOUBLE | BLAS_COMPLEX;
260 #elif defined(DOUBLE)
261 int mode = BLAS_DOUBLE | BLAS_COMPLEX;
263 int mode = BLAS_SINGLE | BLAS_COMPLEX;
268 #if defined(SMP) && !defined(NO_AFFINITY) && !defined(USE_SIMPLE_THREADED_LEVEL3)
274 #if !defined(COMPLEX) && !defined(DOUBLE) && defined(USE_SGEMM_KERNEL_DIRECT)
275 if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && sgemm_kernel_direct_performant(m,n,k)) {
276 sgemm_kernel_direct(m, n, k, a, lda, b, ldb, c, ldc);
283 args.alpha = (void *)α
284 args.beta = (void *)β
286 args.alpha = (void *)alpha;
287 args.beta = (void *)beta;
294 if (order == CblasColMajor) {
307 if (TransA == CblasNoTrans) transa = 0;
308 if (TransA == CblasTrans) transa = 1;
310 if (TransA == CblasConjNoTrans) transa = 0;
311 if (TransA == CblasConjTrans) transa = 1;
313 if (TransA == CblasConjNoTrans) transa = 2;
314 if (TransA == CblasConjTrans) transa = 3;
316 if (TransB == CblasNoTrans) transb = 0;
317 if (TransB == CblasTrans) transb = 1;
319 if (TransB == CblasConjNoTrans) transb = 0;
320 if (TransB == CblasConjTrans) transb = 1;
322 if (TransB == CblasConjNoTrans) transb = 2;
323 if (TransB == CblasConjTrans) transb = 3;
327 if (transa & 1) nrowa = args.k;
329 if (transb & 1) nrowb = args.n;
333 if (args.ldc < args.m) info = 13;
334 if (args.ldb < nrowb) info = 10;
335 if (args.lda < nrowa) info = 8;
336 if (args.k < 0) info = 5;
337 if (args.n < 0) info = 4;
338 if (args.m < 0) info = 3;
339 if (transb < 0) info = 2;
340 if (transa < 0) info = 1;
343 if (order == CblasRowMajor) {
356 if (TransB == CblasNoTrans) transa = 0;
357 if (TransB == CblasTrans) transa = 1;
359 if (TransB == CblasConjNoTrans) transa = 0;
360 if (TransB == CblasConjTrans) transa = 1;
362 if (TransB == CblasConjNoTrans) transa = 2;
363 if (TransB == CblasConjTrans) transa = 3;
365 if (TransA == CblasNoTrans) transb = 0;
366 if (TransA == CblasTrans) transb = 1;
368 if (TransA == CblasConjNoTrans) transb = 0;
369 if (TransA == CblasConjTrans) transb = 1;
371 if (TransA == CblasConjNoTrans) transb = 2;
372 if (TransA == CblasConjTrans) transb = 3;
376 if (transa & 1) nrowa = args.k;
378 if (transb & 1) nrowb = args.n;
382 if (args.ldc < args.m) info = 13;
383 if (args.ldb < nrowb) info = 10;
384 if (args.lda < nrowa) info = 8;
385 if (args.k < 0) info = 5;
386 if (args.n < 0) info = 4;
387 if (args.m < 0) info = 3;
388 if (transb < 0) info = 2;
389 if (transa < 0) info = 1;
394 BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
400 if ((args.m == 0) || (args.n == 0)) return;
403 fprintf(stderr, "m = %4d n = %d k = %d lda = %4d ldb = %4d ldc = %4d\n",
404 args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
409 FUNCTION_PROFILE_START();
411 buffer = (XFLOAT *)blas_memory_alloc(0);
413 sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A);
414 sb = (XFLOAT *)(((BLASLONG)sa + ((GEMM_P * GEMM_Q * COMPSIZE * SIZE + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
417 mode |= (transa << BLAS_TRANSA_SHIFT);
418 mode |= (transb << BLAS_TRANSB_SHIFT);
420 MNK = (double) args.m * (double) args.n * (double) args.k;
421 if ( MNK <= (SMP_THRESHOLD_MIN * (double) GEMM_MULTITHREAD_THRESHOLD) )
424 args.nthreads = num_cpu_avail(3);
427 if (args.nthreads == 1) {
430 (gemm[(transb << 2) | transa])(&args, NULL, NULL, sa, sb, 0);
436 #ifndef USE_SIMPLE_THREADED_LEVEL3
439 nodes = get_num_nodes();
441 if ((nodes > 1) && get_node_equal()) {
443 args.nthreads /= nodes;
445 gemm_thread_mn(mode, &args, NULL, NULL, gemm[16 | (transb << 2) | transa], sa, sb, nodes);
450 (gemm[16 | (transb << 2) | transa])(&args, NULL, NULL, sa, sb, 0);
454 GEMM_THREAD(mode, &args, NULL, NULL, gemm[(transb << 2) | transa], sa, sb, args.nthreads);
458 #ifndef USE_SIMPLE_THREADED_LEVEL3
470 blas_memory_free(buffer);
472 FUNCTION_PROFILE_END(COMPSIZE * COMPSIZE, args.m * args.k + args.k * args.n + args.m * args.n, 2 * args.m * args.n * args.k);