Change a1b0 gemm to b0 gemm.
[platform/upstream/openblas.git] / interface / gemm.c
1 /*********************************************************************/
2 /* Copyright 2009, 2010 The University of Texas at Austin.           */
3 /* All rights reserved.                                              */
4 /*                                                                   */
5 /* Redistribution and use in source and binary forms, with or        */
6 /* without modification, are permitted provided that the following   */
7 /* conditions are met:                                               */
8 /*                                                                   */
9 /*   1. Redistributions of source code must retain the above         */
10 /*      copyright notice, this list of conditions and the following  */
11 /*      disclaimer.                                                  */
12 /*                                                                   */
13 /*   2. Redistributions in binary form must reproduce the above      */
14 /*      copyright notice, this list of conditions and the following  */
15 /*      disclaimer in the documentation and/or other materials       */
16 /*      provided with the distribution.                              */
17 /*                                                                   */
18 /*    THIS  SOFTWARE IS PROVIDED  BY THE  UNIVERSITY OF  TEXAS AT    */
19 /*    AUSTIN  ``AS IS''  AND ANY  EXPRESS OR  IMPLIED WARRANTIES,    */
20 /*    INCLUDING, BUT  NOT LIMITED  TO, THE IMPLIED  WARRANTIES OF    */
21 /*    MERCHANTABILITY  AND FITNESS FOR  A PARTICULAR  PURPOSE ARE    */
22 /*    DISCLAIMED.  IN  NO EVENT SHALL THE UNIVERSITY  OF TEXAS AT    */
23 /*    AUSTIN OR CONTRIBUTORS BE  LIABLE FOR ANY DIRECT, INDIRECT,    */
24 /*    INCIDENTAL,  SPECIAL, EXEMPLARY,  OR  CONSEQUENTIAL DAMAGES    */
25 /*    (INCLUDING, BUT  NOT LIMITED TO,  PROCUREMENT OF SUBSTITUTE    */
26 /*    GOODS  OR  SERVICES; LOSS  OF  USE,  DATA,  OR PROFITS;  OR    */
27 /*    BUSINESS INTERRUPTION) HOWEVER CAUSED  AND ON ANY THEORY OF    */
28 /*    LIABILITY, WHETHER  IN CONTRACT, STRICT  LIABILITY, OR TORT    */
29 /*    (INCLUDING NEGLIGENCE OR OTHERWISE)  ARISING IN ANY WAY OUT    */
30 /*    OF  THE  USE OF  THIS  SOFTWARE,  EVEN  IF ADVISED  OF  THE    */
31 /*    POSSIBILITY OF SUCH DAMAGE.                                    */
32 /*                                                                   */
33 /* The views and conclusions contained in the software and           */
34 /* documentation are those of the authors and should not be          */
35 /* interpreted as representing official policies, either expressed   */
36 /* or implied, of The University of Texas at Austin.                 */
37 /*********************************************************************/
38
39 #include <stdio.h>
40 #include <stdlib.h>
41 #include "common.h"
42 #ifdef FUNCTION_PROFILE
43 #include "functable.h"
44 #endif
45
46 #ifndef COMPLEX
47 #define SMP_THRESHOLD_MIN 65536.0
48 #ifdef XDOUBLE
49 #define ERROR_NAME "QGEMM "
50 #elif defined(DOUBLE)
51 #define ERROR_NAME "DGEMM "
52 #elif defined(BFLOAT16)
53 #define ERROR_NAME "SBGEMM "
54 #else
55 #define ERROR_NAME "SGEMM "
56 #endif
57 #else
58 #define SMP_THRESHOLD_MIN 8192.0
59 #ifndef GEMM3M
60 #ifdef XDOUBLE
61 #define ERROR_NAME "XGEMM "
62 #elif defined(DOUBLE)
63 #define ERROR_NAME "ZGEMM "
64 #else
65 #define ERROR_NAME "CGEMM "
66 #endif
67 #else
68 #ifdef XDOUBLE
69 #define ERROR_NAME "XGEMM3M "
70 #elif defined(DOUBLE)
71 #define ERROR_NAME "ZGEMM3M "
72 #else
73 #define ERROR_NAME "CGEMM3M "
74 #endif
75 #endif
76 #endif
77
78 #ifndef GEMM_MULTITHREAD_THRESHOLD
79 #define GEMM_MULTITHREAD_THRESHOLD 4
80 #endif
81
82 static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = {
83 #ifndef GEMM3M
84   GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN,
85   GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT,
86   GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR,
87   GEMM_NC, GEMM_TC, GEMM_RC, GEMM_CC,
88 #if defined(SMP) && !defined(USE_SIMPLE_THREADED_LEVEL3)
89   GEMM_THREAD_NN, GEMM_THREAD_TN, GEMM_THREAD_RN, GEMM_THREAD_CN,
90   GEMM_THREAD_NT, GEMM_THREAD_TT, GEMM_THREAD_RT, GEMM_THREAD_CT,
91   GEMM_THREAD_NR, GEMM_THREAD_TR, GEMM_THREAD_RR, GEMM_THREAD_CR,
92   GEMM_THREAD_NC, GEMM_THREAD_TC, GEMM_THREAD_RC, GEMM_THREAD_CC,
93 #endif
94 #else
95   GEMM3M_NN, GEMM3M_TN, GEMM3M_RN, GEMM3M_CN,
96   GEMM3M_NT, GEMM3M_TT, GEMM3M_RT, GEMM3M_CT,
97   GEMM3M_NR, GEMM3M_TR, GEMM3M_RR, GEMM3M_CR,
98   GEMM3M_NC, GEMM3M_TC, GEMM3M_RC, GEMM3M_CC,
99 #if defined(SMP) && !defined(USE_SIMPLE_THREADED_LEVEL3)
100   GEMM3M_THREAD_NN, GEMM3M_THREAD_TN, GEMM3M_THREAD_RN, GEMM3M_THREAD_CN,
101   GEMM3M_THREAD_NT, GEMM3M_THREAD_TT, GEMM3M_THREAD_RT, GEMM3M_THREAD_CT,
102   GEMM3M_THREAD_NR, GEMM3M_THREAD_TR, GEMM3M_THREAD_RR, GEMM3M_THREAD_CR,
103   GEMM3M_THREAD_NC, GEMM3M_THREAD_TC, GEMM3M_THREAD_RC, GEMM3M_THREAD_CC,
104 #endif
105 #endif
106 };
107
108 #ifdef SMALL_MATRIX_OPT
109 //Only support s/dgemm small matrix optimiztion so far.
110 static int (*gemm_small_kernel[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT ,FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG) = {
111 #ifndef GEMM3M
112 #ifndef COMPLEX
113         GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, NULL, NULL,
114         GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, NULL, NULL,
115 #endif
116 #endif
117 };
118
119 static int (*gemm_small_kernel_b0[])(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG) = {
120 #ifndef GEMM3M
121 #ifndef COMPLEX
122         GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, NULL, NULL,
123         GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, NULL, NULL,
124 #endif
125 #endif
126 };
127 #endif
128
129 #ifndef CBLAS
130
131 void NAME(char *TRANSA, char *TRANSB,
132           blasint *M, blasint *N, blasint *K,
133           FLOAT *alpha,
134           IFLOAT *a, blasint *ldA,
135           IFLOAT *b, blasint *ldB,
136           FLOAT *beta,
137           FLOAT *c, blasint *ldC){
138
139   blas_arg_t args;
140
141   int transa, transb, nrowa, nrowb;
142   blasint info;
143
144   char transA, transB;
145   IFLOAT *buffer;
146   IFLOAT *sa, *sb;
147
148 #if defined (SMP) || defined(SMALL_MATRIX_OPT)
149   double MNK;
150 #if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY)
151 #ifndef COMPLEX
152 #ifdef XDOUBLE
153   int mode  =  BLAS_XDOUBLE | BLAS_REAL;
154 #elif defined(DOUBLE)
155   int mode  =  BLAS_DOUBLE  | BLAS_REAL;
156 #else
157   int mode  =  BLAS_SINGLE  | BLAS_REAL;
158 #endif
159 #else
160 #ifdef XDOUBLE
161   int mode  =  BLAS_XDOUBLE | BLAS_COMPLEX;
162 #elif defined(DOUBLE)
163   int mode  =  BLAS_DOUBLE  | BLAS_COMPLEX;
164 #else
165   int mode  =  BLAS_SINGLE  | BLAS_COMPLEX;
166 #endif
167 #endif
168 #endif
169 #endif
170
171 #if defined(SMP) && !defined(NO_AFFINITY) && !defined(USE_SIMPLE_THREADED_LEVEL3)
172   int nodes;
173 #endif
174
175   PRINT_DEBUG_NAME;
176
177   args.m = *M;
178   args.n = *N;
179   args.k = *K;
180
181   args.a = (void *)a;
182   args.b = (void *)b;
183   args.c = (void *)c;
184
185   args.lda = *ldA;
186   args.ldb = *ldB;
187   args.ldc = *ldC;
188
189   args.alpha = (void *)alpha;
190   args.beta  = (void *)beta;
191
192   transA = *TRANSA;
193   transB = *TRANSB;
194
195   TOUPPER(transA);
196   TOUPPER(transB);
197
198   transa = -1;
199   transb = -1;
200
201   if (transA == 'N') transa = 0;
202   if (transA == 'T') transa = 1;
203 #ifndef COMPLEX
204   if (transA == 'R') transa = 0;
205   if (transA == 'C') transa = 1;
206 #else
207   if (transA == 'R') transa = 2;
208   if (transA == 'C') transa = 3;
209 #endif
210
211   if (transB == 'N') transb = 0;
212   if (transB == 'T') transb = 1;
213 #ifndef COMPLEX
214   if (transB == 'R') transb = 0;
215   if (transB == 'C') transb = 1;
216 #else
217   if (transB == 'R') transb = 2;
218   if (transB == 'C') transb = 3;
219 #endif
220
221   nrowa = args.m;
222   if (transa & 1) nrowa = args.k;
223   nrowb = args.k;
224   if (transb & 1) nrowb = args.n;
225
226   info = 0;
227
228   if (args.ldc < args.m) info = 13;
229   if (args.ldb < nrowb)  info = 10;
230   if (args.lda < nrowa)  info =  8;
231   if (args.k < 0)        info =  5;
232   if (args.n < 0)        info =  4;
233   if (args.m < 0)        info =  3;
234   if (transb < 0)        info =  2;
235   if (transa < 0)        info =  1;
236
237   if (info){
238     BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
239     return;
240   }
241
242 #else
243
244 void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB,
245            blasint m, blasint n, blasint k,
246 #ifndef COMPLEX
247            FLOAT alpha,
248            FLOAT *a, blasint lda,
249            FLOAT *b, blasint ldb,
250            FLOAT beta,
251            FLOAT *c, blasint ldc) {
252 #else
253            void *valpha,
254            void *va, blasint lda,
255            void *vb, blasint ldb,
256            void *vbeta,
257            void *vc, blasint ldc) {
258   FLOAT *alpha = (FLOAT*) valpha;
259   FLOAT *beta  = (FLOAT*) vbeta;
260   FLOAT *a = (FLOAT*) va;
261   FLOAT *b = (FLOAT*) vb;
262   FLOAT *c = (FLOAT*) vc;          
263 #endif
264
265   blas_arg_t args;
266   int transa, transb;
267   blasint nrowa, nrowb, info;
268
269   XFLOAT *buffer;
270   XFLOAT *sa, *sb;
271
272 #if defined (SMP) || defined(SMALL_MATRIX_OPT)
273   double MNK;
274 #endif
275
276 #ifdef SMP
277 #if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY)
278 #ifndef COMPLEX
279 #ifdef XDOUBLE
280   int mode  =  BLAS_XDOUBLE | BLAS_REAL;
281 #elif defined(DOUBLE)
282   int mode  =  BLAS_DOUBLE  | BLAS_REAL;
283 #else
284   int mode  =  BLAS_SINGLE  | BLAS_REAL;
285 #endif
286 #else
287 #ifdef XDOUBLE
288   int mode  =  BLAS_XDOUBLE | BLAS_COMPLEX;
289 #elif defined(DOUBLE)
290   int mode  =  BLAS_DOUBLE  | BLAS_COMPLEX;
291 #else
292   int mode  =  BLAS_SINGLE  | BLAS_COMPLEX;
293 #endif
294 #endif
295 #endif
296 #endif
297
298 #if defined(SMP) && !defined(NO_AFFINITY) && !defined(USE_SIMPLE_THREADED_LEVEL3)
299   int nodes;
300 #endif
301
302   PRINT_DEBUG_CNAME;
303
304 #if !defined(COMPLEX) && !defined(DOUBLE) && defined(USE_SGEMM_KERNEL_DIRECT)
305 #ifdef DYNAMIC_ARCH
306  if (support_avx512() )
307 #endif  
308   if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && SGEMM_DIRECT_PERFORMANT(m,n,k)) {
309         SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
310         return;
311   }
312
313 #endif
314
315 #ifndef COMPLEX
316   args.alpha = (void *)&alpha;
317   args.beta  = (void *)&beta;
318 #else
319   args.alpha = (void *)alpha;
320   args.beta  = (void *)beta;
321 #endif
322
323   transa = -1;
324   transb = -1;
325   info   =  0;
326
327   if (order == CblasColMajor) {
328     args.m = m;
329     args.n = n;
330     args.k = k;
331
332     args.a = (void *)a;
333     args.b = (void *)b;
334     args.c = (void *)c;
335
336     args.lda = lda;
337     args.ldb = ldb;
338     args.ldc = ldc;
339
340     if (TransA == CblasNoTrans)     transa = 0;
341     if (TransA == CblasTrans)       transa = 1;
342 #ifndef COMPLEX
343     if (TransA == CblasConjNoTrans) transa = 0;
344     if (TransA == CblasConjTrans)   transa = 1;
345 #else
346     if (TransA == CblasConjNoTrans) transa = 2;
347     if (TransA == CblasConjTrans)   transa = 3;
348 #endif
349     if (TransB == CblasNoTrans)     transb = 0;
350     if (TransB == CblasTrans)       transb = 1;
351 #ifndef COMPLEX
352     if (TransB == CblasConjNoTrans) transb = 0;
353     if (TransB == CblasConjTrans)   transb = 1;
354 #else
355     if (TransB == CblasConjNoTrans) transb = 2;
356     if (TransB == CblasConjTrans)   transb = 3;
357 #endif
358
359     nrowa = args.m;
360     if (transa & 1) nrowa = args.k;
361     nrowb = args.k;
362     if (transb & 1) nrowb = args.n;
363
364     info = -1;
365
366     if (args.ldc < args.m) info = 13;
367     if (args.ldb < nrowb)  info = 10;
368     if (args.lda < nrowa)  info =  8;
369     if (args.k < 0)        info =  5;
370     if (args.n < 0)        info =  4;
371     if (args.m < 0)        info =  3;
372     if (transb < 0)        info =  2;
373     if (transa < 0)        info =  1;
374   }
375
376   if (order == CblasRowMajor) {
377     args.m = n;
378     args.n = m;
379     args.k = k;
380
381     args.a = (void *)b;
382     args.b = (void *)a;
383     args.c = (void *)c;
384
385     args.lda = ldb;
386     args.ldb = lda;
387     args.ldc = ldc;
388
389     if (TransB == CblasNoTrans)     transa = 0;
390     if (TransB == CblasTrans)       transa = 1;
391 #ifndef COMPLEX
392     if (TransB == CblasConjNoTrans) transa = 0;
393     if (TransB == CblasConjTrans)   transa = 1;
394 #else
395     if (TransB == CblasConjNoTrans) transa = 2;
396     if (TransB == CblasConjTrans)   transa = 3;
397 #endif
398     if (TransA == CblasNoTrans)     transb = 0;
399     if (TransA == CblasTrans)       transb = 1;
400 #ifndef COMPLEX
401     if (TransA == CblasConjNoTrans) transb = 0;
402     if (TransA == CblasConjTrans)   transb = 1;
403 #else
404     if (TransA == CblasConjNoTrans) transb = 2;
405     if (TransA == CblasConjTrans)   transb = 3;
406 #endif
407
408     nrowa = args.m;
409     if (transa & 1) nrowa = args.k;
410     nrowb = args.k;
411     if (transb & 1) nrowb = args.n;
412
413     info = -1;
414
415     if (args.ldc < args.m) info = 13;
416     if (args.ldb < nrowb)  info = 10;
417     if (args.lda < nrowa)  info =  8;
418     if (args.k < 0)        info =  5;
419     if (args.n < 0)        info =  4;
420     if (args.m < 0)        info =  3;
421     if (transb < 0)        info =  2;
422     if (transa < 0)        info =  1;
423
424   }
425
426   if (info >= 0) {
427     BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
428     return;
429   }
430
431 #endif
432
433   if ((args.m == 0) || (args.n == 0)) return;
434
435 #if 0
436   fprintf(stderr, "m = %4d  n = %d  k = %d  lda = %4d  ldb = %4d  ldc = %4d\n",
437          args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
438 #endif
439
440   IDEBUG_START;
441
442   FUNCTION_PROFILE_START();
443
444 #if defined(SMP) || defined(SMALL_MATRIX_OPT)
445   MNK = (double) args.m * (double) args.n * (double) args.k;
446 #endif
447
448 #ifdef SMALL_MATRIX_OPT
449 #if !defined(COMPLEX)
450   //need to tune small matrices cases.
451   if(MNK <= 100.0*100.0*100.0){
452
453           if(*(FLOAT *)(args.beta) == 0.0){
454                   (gemm_small_kernel_b0[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, args.c, args.ldc);
455           }else{
456                   (gemm_small_kernel[(transb << 2) | transa])(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc);
457           }
458           
459           return;
460   }
461 #endif
462 #endif
463   
464
465   buffer = (XFLOAT *)blas_memory_alloc(0);
466
467   sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A);
468   sb = (XFLOAT *)(((BLASLONG)sa + ((GEMM_P * GEMM_Q * COMPSIZE * SIZE + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
469
470 #ifdef SMP
471 #if defined(USE_SIMPLE_THREADED_LEVEL3) || !defined(NO_AFFINITY)
472   mode |= (transa << BLAS_TRANSA_SHIFT);
473   mode |= (transb << BLAS_TRANSB_SHIFT);
474 #endif
475
476
477   if ( MNK <= (SMP_THRESHOLD_MIN  * (double) GEMM_MULTITHREAD_THRESHOLD)  )
478         args.nthreads = 1;
479   else
480         args.nthreads = num_cpu_avail(3);
481   args.common = NULL;
482
483  if (args.nthreads == 1) {
484 #endif
485
486     (gemm[(transb << 2) | transa])(&args, NULL, NULL, sa, sb, 0);
487
488 #ifdef SMP
489
490   } else {
491
492 #ifndef USE_SIMPLE_THREADED_LEVEL3
493
494 #ifndef NO_AFFINITY
495       nodes = get_num_nodes();
496
497       if ((nodes > 1) && get_node_equal()) {
498
499         args.nthreads /= nodes;
500
501         gemm_thread_mn(mode, &args, NULL, NULL, gemm[16 | (transb << 2) | transa], sa, sb, nodes);
502
503       } else {
504 #endif
505
506         (gemm[16 | (transb << 2) | transa])(&args, NULL, NULL, sa, sb, 0);
507
508 #else
509
510         GEMM_THREAD(mode, &args, NULL, NULL, gemm[(transb << 2) | transa], sa, sb, args.nthreads);
511
512 #endif
513
514 #ifndef USE_SIMPLE_THREADED_LEVEL3
515 #ifndef NO_AFFINITY
516       }
517 #endif
518 #endif
519
520 #endif
521
522 #ifdef SMP
523   }
524 #endif
525
526  blas_memory_free(buffer);
527
528   FUNCTION_PROFILE_END(COMPSIZE * COMPSIZE, args.m * args.k + args.k * args.n + args.m * args.n, 2 * args.m * args.n * args.k);
529
530   IDEBUG_END;
531
532   return;
533 }