Add CPUID identification of Intel Ice Lake
[platform/upstream/openblas.git] / interface / gemm.c
1 /*********************************************************************/
2 /* Copyright 2009, 2010 The University of Texas at Austin.           */
3 /* All rights reserved.                                              */
4 /*                                                                   */
5 /* Redistribution and use in source and binary forms, with or        */
6 /* without modification, are permitted provided that the following   */
7 /* conditions are met:                                               */
8 /*                                                                   */
9 /*   1. Redistributions of source code must retain the above         */
10 /*      copyright notice, this list of conditions and the following  */
11 /*      disclaimer.                                                  */
12 /*                                                                   */
13 /*   2. Redistributions in binary form must reproduce the above      */
14 /*      copyright notice, this list of conditions and the following  */
15 /*      disclaimer in the documentation and/or other materials       */
16 /*      provided with the distribution.                              */
17 /*                                                                   */
18 /*    THIS  SOFTWARE IS PROVIDED  BY THE  UNIVERSITY OF  TEXAS AT    */
19 /*    AUSTIN  ``AS IS''  AND ANY  EXPRESS OR  IMPLIED WARRANTIES,    */
20 /*    INCLUDING, BUT  NOT LIMITED  TO, THE IMPLIED  WARRANTIES OF    */
21 /*    MERCHANTABILITY  AND FITNESS FOR  A PARTICULAR  PURPOSE ARE    */
22 /*    DISCLAIMED.  IN  NO EVENT SHALL THE UNIVERSITY  OF TEXAS AT    */
23 /*    AUSTIN OR CONTRIBUTORS BE  LIABLE FOR ANY DIRECT, INDIRECT,    */
24 /*    INCIDENTAL,  SPECIAL, EXEMPLARY,  OR  CONSEQUENTIAL DAMAGES    */
25 /*    (INCLUDING, BUT  NOT LIMITED TO,  PROCUREMENT OF SUBSTITUTE    */
26 /*    GOODS  OR  SERVICES; LOSS  OF  USE,  DATA,  OR PROFITS;  OR    */
27 /*    BUSINESS INTERRUPTION) HOWEVER CAUSED  AND ON ANY THEORY OF    */
28 /*    LIABILITY, WHETHER  IN CONTRACT, STRICT  LIABILITY, OR TORT    */
29 /*    (INCLUDING NEGLIGENCE OR OTHERWISE)  ARISING IN ANY WAY OUT    */
30 /*    OF  THE  USE OF  THIS  SOFTWARE,  EVEN  IF ADVISED  OF  THE    */
31 /*    POSSIBILITY OF SUCH DAMAGE.                                    */
32 /*                                                                   */
33 /* The views and conclusions contained in the software and           */
34 /* documentation are those of the authors and should not be          */
35 /* interpreted as representing official policies, either expressed   */
36 /* or implied, of The University of Texas at Austin.                 */
37 /*********************************************************************/
38
39 #include <stdio.h>
40 #include <stdlib.h>
41 #include "common.h"
42 #ifdef FUNCTION_PROFILE
43 #include "functable.h"
44 #endif
45
46 #ifndef COMPLEX
47 #define SMP_THRESHOLD_MIN 65536.0
48 #ifdef XDOUBLE
49 #define ERROR_NAME "QGEMM "
50 #elif defined(DOUBLE)
51 #define ERROR_NAME "DGEMM "
52 #else
53 #define ERROR_NAME "SGEMM "
54 #endif
55 #else
56 #define SMP_THRESHOLD_MIN 8192.0
57 #ifndef GEMM3M
58 #ifdef XDOUBLE
59 #define ERROR_NAME "XGEMM "
60 #elif defined(DOUBLE)
61 #define ERROR_NAME "ZGEMM "
62 #else
63 #define ERROR_NAME "CGEMM "
64 #endif
65 #else
66 #ifdef XDOUBLE
67 #define ERROR_NAME "XGEMM3M "
68 #elif defined(DOUBLE)
69 #define ERROR_NAME "ZGEMM3M "
70 #else
71 #define ERROR_NAME "CGEMM3M "
72 #endif
73 #endif
74 #endif
75
76 #ifndef GEMM_MULTITHREAD_THRESHOLD
77 #define GEMM_MULTITHREAD_THRESHOLD 4
78 #endif
79
80 static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLASLONG) = {
81 #ifndef GEMM3M
82   GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN,
83   GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT,
84   GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR,
85   GEMM_NC, GEMM_TC, GEMM_RC, GEMM_CC,
86 #if defined(SMP) && !defined(USE_SIMPLE_THREADED_LEVEL3)
87   GEMM_THREAD_NN, GEMM_THREAD_TN, GEMM_THREAD_RN, GEMM_THREAD_CN,
88   GEMM_THREAD_NT, GEMM_THREAD_TT, GEMM_THREAD_RT, GEMM_THREAD_CT,
89   GEMM_THREAD_NR, GEMM_THREAD_TR, GEMM_THREAD_RR, GEMM_THREAD_CR,
90   GEMM_THREAD_NC, GEMM_THREAD_TC, GEMM_THREAD_RC, GEMM_THREAD_CC,
91 #endif
92 #else
93   GEMM3M_NN, GEMM3M_TN, GEMM3M_RN, GEMM3M_CN,
94   GEMM3M_NT, GEMM3M_TT, GEMM3M_RT, GEMM3M_CT,
95   GEMM3M_NR, GEMM3M_TR, GEMM3M_RR, GEMM3M_CR,
96   GEMM3M_NC, GEMM3M_TC, GEMM3M_RC, GEMM3M_CC,
97 #if defined(SMP) && !defined(USE_SIMPLE_THREADED_LEVEL3)
98   GEMM3M_THREAD_NN, GEMM3M_THREAD_TN, GEMM3M_THREAD_RN, GEMM3M_THREAD_CN,
99   GEMM3M_THREAD_NT, GEMM3M_THREAD_TT, GEMM3M_THREAD_RT, GEMM3M_THREAD_CT,
100   GEMM3M_THREAD_NR, GEMM3M_THREAD_TR, GEMM3M_THREAD_RR, GEMM3M_THREAD_CR,
101   GEMM3M_THREAD_NC, GEMM3M_THREAD_TC, GEMM3M_THREAD_RC, GEMM3M_THREAD_CC,
102 #endif
103 #endif
104 };
105
106 #ifndef CBLAS
107
108 void NAME(char *TRANSA, char *TRANSB,
109           blasint *M, blasint *N, blasint *K,
110           FLOAT *alpha,
111           FLOAT *a, blasint *ldA,
112           FLOAT *b, blasint *ldB,
113           FLOAT *beta,
114           FLOAT *c, blasint *ldC){
115
116   blas_arg_t args;
117
118   int transa, transb, nrowa, nrowb;
119   blasint info;
120
121   char transA, transB;
122   FLOAT *buffer;
123   FLOAT *sa, *sb;
124
125 #ifdef SMP
126   double MNK;
127 #ifndef COMPLEX
128 #ifdef XDOUBLE
129   int mode  =  BLAS_XDOUBLE | BLAS_REAL;
130 #elif defined(DOUBLE)
131   int mode  =  BLAS_DOUBLE  | BLAS_REAL;
132 #else
133   int mode  =  BLAS_SINGLE  | BLAS_REAL;
134 #endif
135 #else
136 #ifdef XDOUBLE
137   int mode  =  BLAS_XDOUBLE | BLAS_COMPLEX;
138 #elif defined(DOUBLE)
139   int mode  =  BLAS_DOUBLE  | BLAS_COMPLEX;
140 #else
141   int mode  =  BLAS_SINGLE  | BLAS_COMPLEX;
142 #endif
143 #endif
144 #endif
145
146 #if defined(SMP) && !defined(NO_AFFINITY) && !defined(USE_SIMPLE_THREADED_LEVEL3)
147   int nodes;
148 #endif
149
150   PRINT_DEBUG_NAME;
151
152   args.m = *M;
153   args.n = *N;
154   args.k = *K;
155
156   args.a = (void *)a;
157   args.b = (void *)b;
158   args.c = (void *)c;
159
160   args.lda = *ldA;
161   args.ldb = *ldB;
162   args.ldc = *ldC;
163
164   args.alpha = (void *)alpha;
165   args.beta  = (void *)beta;
166
167   transA = *TRANSA;
168   transB = *TRANSB;
169
170   TOUPPER(transA);
171   TOUPPER(transB);
172
173   transa = -1;
174   transb = -1;
175
176   if (transA == 'N') transa = 0;
177   if (transA == 'T') transa = 1;
178 #ifndef COMPLEX
179   if (transA == 'R') transa = 0;
180   if (transA == 'C') transa = 1;
181 #else
182   if (transA == 'R') transa = 2;
183   if (transA == 'C') transa = 3;
184 #endif
185
186   if (transB == 'N') transb = 0;
187   if (transB == 'T') transb = 1;
188 #ifndef COMPLEX
189   if (transB == 'R') transb = 0;
190   if (transB == 'C') transb = 1;
191 #else
192   if (transB == 'R') transb = 2;
193   if (transB == 'C') transb = 3;
194 #endif
195
196   nrowa = args.m;
197   if (transa & 1) nrowa = args.k;
198   nrowb = args.k;
199   if (transb & 1) nrowb = args.n;
200
201   info = 0;
202
203   if (args.ldc < args.m) info = 13;
204   if (args.ldb < nrowb)  info = 10;
205   if (args.lda < nrowa)  info =  8;
206   if (args.k < 0)        info =  5;
207   if (args.n < 0)        info =  4;
208   if (args.m < 0)        info =  3;
209   if (transb < 0)        info =  2;
210   if (transa < 0)        info =  1;
211
212   if (info){
213     BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
214     return;
215   }
216
217 #else
218
219 void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB,
220            blasint m, blasint n, blasint k,
221 #ifndef COMPLEX
222            FLOAT alpha,
223            FLOAT *a, blasint lda,
224            FLOAT *b, blasint ldb,
225            FLOAT beta,
226            FLOAT *c, blasint ldc) {
227 #else
228            void *valpha,
229            void *va, blasint lda,
230            void *vb, blasint ldb,
231            void *vbeta,
232            void *vc, blasint ldc) {
233   FLOAT *alpha = (FLOAT*) valpha;
234   FLOAT *beta  = (FLOAT*) vbeta;
235   FLOAT *a = (FLOAT*) va;
236   FLOAT *b = (FLOAT*) vb;
237   FLOAT *c = (FLOAT*) vc;          
238 #endif
239
240   blas_arg_t args;
241   int transa, transb;
242   blasint nrowa, nrowb, info;
243
244   XFLOAT *buffer;
245   XFLOAT *sa, *sb;
246
247 #ifdef SMP
248   double MNK;
249 #ifndef COMPLEX
250 #ifdef XDOUBLE
251   int mode  =  BLAS_XDOUBLE | BLAS_REAL;
252 #elif defined(DOUBLE)
253   int mode  =  BLAS_DOUBLE  | BLAS_REAL;
254 #else
255   int mode  =  BLAS_SINGLE  | BLAS_REAL;
256 #endif
257 #else
258 #ifdef XDOUBLE
259   int mode  =  BLAS_XDOUBLE | BLAS_COMPLEX;
260 #elif defined(DOUBLE)
261   int mode  =  BLAS_DOUBLE  | BLAS_COMPLEX;
262 #else
263   int mode  =  BLAS_SINGLE  | BLAS_COMPLEX;
264 #endif
265 #endif
266 #endif
267
268 #if defined(SMP) && !defined(NO_AFFINITY) && !defined(USE_SIMPLE_THREADED_LEVEL3)
269   int nodes;
270 #endif
271
272   PRINT_DEBUG_CNAME;
273
274 #if !defined(COMPLEX) && !defined(DOUBLE) && defined(USE_SGEMM_KERNEL_DIRECT)
275   if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && sgemm_kernel_direct_performant(m,n,k)) {
276         sgemm_kernel_direct(m, n, k, a, lda, b, ldb, c, ldc);
277         return;
278   }
279
280 #endif
281
282 #ifndef COMPLEX
283   args.alpha = (void *)&alpha;
284   args.beta  = (void *)&beta;
285 #else
286   args.alpha = (void *)alpha;
287   args.beta  = (void *)beta;
288 #endif
289
290   transa = -1;
291   transb = -1;
292   info   =  0;
293
294   if (order == CblasColMajor) {
295     args.m = m;
296     args.n = n;
297     args.k = k;
298
299     args.a = (void *)a;
300     args.b = (void *)b;
301     args.c = (void *)c;
302
303     args.lda = lda;
304     args.ldb = ldb;
305     args.ldc = ldc;
306
307     if (TransA == CblasNoTrans)     transa = 0;
308     if (TransA == CblasTrans)       transa = 1;
309 #ifndef COMPLEX
310     if (TransA == CblasConjNoTrans) transa = 0;
311     if (TransA == CblasConjTrans)   transa = 1;
312 #else
313     if (TransA == CblasConjNoTrans) transa = 2;
314     if (TransA == CblasConjTrans)   transa = 3;
315 #endif
316     if (TransB == CblasNoTrans)     transb = 0;
317     if (TransB == CblasTrans)       transb = 1;
318 #ifndef COMPLEX
319     if (TransB == CblasConjNoTrans) transb = 0;
320     if (TransB == CblasConjTrans)   transb = 1;
321 #else
322     if (TransB == CblasConjNoTrans) transb = 2;
323     if (TransB == CblasConjTrans)   transb = 3;
324 #endif
325
326     nrowa = args.m;
327     if (transa & 1) nrowa = args.k;
328     nrowb = args.k;
329     if (transb & 1) nrowb = args.n;
330
331     info = -1;
332
333     if (args.ldc < args.m) info = 13;
334     if (args.ldb < nrowb)  info = 10;
335     if (args.lda < nrowa)  info =  8;
336     if (args.k < 0)        info =  5;
337     if (args.n < 0)        info =  4;
338     if (args.m < 0)        info =  3;
339     if (transb < 0)        info =  2;
340     if (transa < 0)        info =  1;
341   }
342
343   if (order == CblasRowMajor) {
344     args.m = n;
345     args.n = m;
346     args.k = k;
347
348     args.a = (void *)b;
349     args.b = (void *)a;
350     args.c = (void *)c;
351
352     args.lda = ldb;
353     args.ldb = lda;
354     args.ldc = ldc;
355
356     if (TransB == CblasNoTrans)     transa = 0;
357     if (TransB == CblasTrans)       transa = 1;
358 #ifndef COMPLEX
359     if (TransB == CblasConjNoTrans) transa = 0;
360     if (TransB == CblasConjTrans)   transa = 1;
361 #else
362     if (TransB == CblasConjNoTrans) transa = 2;
363     if (TransB == CblasConjTrans)   transa = 3;
364 #endif
365     if (TransA == CblasNoTrans)     transb = 0;
366     if (TransA == CblasTrans)       transb = 1;
367 #ifndef COMPLEX
368     if (TransA == CblasConjNoTrans) transb = 0;
369     if (TransA == CblasConjTrans)   transb = 1;
370 #else
371     if (TransA == CblasConjNoTrans) transb = 2;
372     if (TransA == CblasConjTrans)   transb = 3;
373 #endif
374
375     nrowa = args.m;
376     if (transa & 1) nrowa = args.k;
377     nrowb = args.k;
378     if (transb & 1) nrowb = args.n;
379
380     info = -1;
381
382     if (args.ldc < args.m) info = 13;
383     if (args.ldb < nrowb)  info = 10;
384     if (args.lda < nrowa)  info =  8;
385     if (args.k < 0)        info =  5;
386     if (args.n < 0)        info =  4;
387     if (args.m < 0)        info =  3;
388     if (transb < 0)        info =  2;
389     if (transa < 0)        info =  1;
390
391   }
392
393   if (info >= 0) {
394     BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
395     return;
396   }
397
398 #endif
399
400   if ((args.m == 0) || (args.n == 0)) return;
401
402 #if 0
403   fprintf(stderr, "m = %4d  n = %d  k = %d  lda = %4d  ldb = %4d  ldc = %4d\n",
404          args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
405 #endif
406
407   IDEBUG_START;
408
409   FUNCTION_PROFILE_START();
410
411   buffer = (XFLOAT *)blas_memory_alloc(0);
412
413   sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A);
414   sb = (XFLOAT *)(((BLASLONG)sa + ((GEMM_P * GEMM_Q * COMPSIZE * SIZE + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
415
416 #ifdef SMP
417   mode |= (transa << BLAS_TRANSA_SHIFT);
418   mode |= (transb << BLAS_TRANSB_SHIFT);
419
420   MNK = (double) args.m * (double) args.n * (double) args.k;
421   if ( MNK <= (SMP_THRESHOLD_MIN  * (double) GEMM_MULTITHREAD_THRESHOLD)  )
422         args.nthreads = 1;
423   else
424         args.nthreads = num_cpu_avail(3);
425   args.common = NULL;
426
427  if (args.nthreads == 1) {
428 #endif
429
430     (gemm[(transb << 2) | transa])(&args, NULL, NULL, sa, sb, 0);
431
432 #ifdef SMP
433
434   } else {
435
436 #ifndef USE_SIMPLE_THREADED_LEVEL3
437
438 #ifndef NO_AFFINITY
439       nodes = get_num_nodes();
440
441       if ((nodes > 1) && get_node_equal()) {
442
443         args.nthreads /= nodes;
444
445         gemm_thread_mn(mode, &args, NULL, NULL, gemm[16 | (transb << 2) | transa], sa, sb, nodes);
446
447       } else {
448 #endif
449
450         (gemm[16 | (transb << 2) | transa])(&args, NULL, NULL, sa, sb, 0);
451
452 #else
453
454         GEMM_THREAD(mode, &args, NULL, NULL, gemm[(transb << 2) | transa], sa, sb, args.nthreads);
455
456 #endif
457
458 #ifndef USE_SIMPLE_THREADED_LEVEL3
459 #ifndef NO_AFFINITY
460       }
461 #endif
462 #endif
463
464 #endif
465
466 #ifdef SMP
467   }
468 #endif
469
470  blas_memory_free(buffer);
471
472   FUNCTION_PROFILE_END(COMPSIZE * COMPSIZE, args.m * args.k + args.k * args.n + args.m * args.n, 2 * args.m * args.n * args.k);
473
474   IDEBUG_END;
475
476   return;
477 }