Add CPUID identification of Intel Ice Lake
[platform/upstream/openblas.git] / interface / trsm.c
1 /*********************************************************************/
2 /* Copyright 2009, 2010 The University of Texas at Austin.           */
3 /* All rights reserved.                                              */
4 /*                                                                   */
5 /* Redistribution and use in source and binary forms, with or        */
6 /* without modification, are permitted provided that the following   */
7 /* conditions are met:                                               */
8 /*                                                                   */
9 /*   1. Redistributions of source code must retain the above         */
10 /*      copyright notice, this list of conditions and the following  */
11 /*      disclaimer.                                                  */
12 /*                                                                   */
13 /*   2. Redistributions in binary form must reproduce the above      */
14 /*      copyright notice, this list of conditions and the following  */
15 /*      disclaimer in the documentation and/or other materials       */
16 /*      provided with the distribution.                              */
17 /*                                                                   */
18 /*    THIS  SOFTWARE IS PROVIDED  BY THE  UNIVERSITY OF  TEXAS AT    */
19 /*    AUSTIN  ``AS IS''  AND ANY  EXPRESS OR  IMPLIED WARRANTIES,    */
20 /*    INCLUDING, BUT  NOT LIMITED  TO, THE IMPLIED  WARRANTIES OF    */
21 /*    MERCHANTABILITY  AND FITNESS FOR  A PARTICULAR  PURPOSE ARE    */
22 /*    DISCLAIMED.  IN  NO EVENT SHALL THE UNIVERSITY  OF TEXAS AT    */
23 /*    AUSTIN OR CONTRIBUTORS BE  LIABLE FOR ANY DIRECT, INDIRECT,    */
24 /*    INCIDENTAL,  SPECIAL, EXEMPLARY,  OR  CONSEQUENTIAL DAMAGES    */
25 /*    (INCLUDING, BUT  NOT LIMITED TO,  PROCUREMENT OF SUBSTITUTE    */
26 /*    GOODS  OR  SERVICES; LOSS  OF  USE,  DATA,  OR PROFITS;  OR    */
27 /*    BUSINESS INTERRUPTION) HOWEVER CAUSED  AND ON ANY THEORY OF    */
28 /*    LIABILITY, WHETHER  IN CONTRACT, STRICT  LIABILITY, OR TORT    */
29 /*    (INCLUDING NEGLIGENCE OR OTHERWISE)  ARISING IN ANY WAY OUT    */
30 /*    OF  THE  USE OF  THIS  SOFTWARE,  EVEN  IF ADVISED  OF  THE    */
31 /*    POSSIBILITY OF SUCH DAMAGE.                                    */
32 /*                                                                   */
33 /* The views and conclusions contained in the software and           */
34 /* documentation are those of the authors and should not be          */
35 /* interpreted as representing official policies, either expressed   */
36 /* or implied, of The University of Texas at Austin.                 */
37 /*********************************************************************/
38
39 #include <stdio.h>
40 #include <ctype.h>
41 #include "common.h"
42 #ifdef FUNCTION_PROFILE
43 #include "functable.h"
44 #endif
45
46 #ifndef TRMM
47 #ifndef COMPLEX
48 #ifdef XDOUBLE
49 #define ERROR_NAME "QTRSM "
50 #elif defined(DOUBLE)
51 #define ERROR_NAME "DTRSM "
52 #else
53 #define ERROR_NAME "STRSM "
54 #endif
55 #else
56 #ifdef XDOUBLE
57 #define ERROR_NAME "XTRSM "
58 #elif defined(DOUBLE)
59 #define ERROR_NAME "ZTRSM "
60 #else
61 #define ERROR_NAME "CTRSM "
62 #endif
63 #endif
64 #else
65 #ifndef COMPLEX
66 #ifdef XDOUBLE
67 #define ERROR_NAME "QTRMM "
68 #elif defined(DOUBLE)
69 #define ERROR_NAME "DTRMM "
70 #else
71 #define ERROR_NAME "STRMM "
72 #endif
73 #else
74 #ifdef XDOUBLE
75 #define ERROR_NAME "XTRMM "
76 #elif defined(DOUBLE)
77 #define ERROR_NAME "ZTRMM "
78 #else
79 #define ERROR_NAME "CTRMM "
80 #endif
81 #endif
82 #endif
83
84 #ifndef COMPLEX
85 #define SMP_FACTOR 256
86 #else
87 #define SMP_FACTOR 128
88 #endif
89
90 static int (*trsm[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLASLONG) = {
91 #ifndef TRMM
92   TRSM_LNUU, TRSM_LNUN, TRSM_LNLU, TRSM_LNLN,
93   TRSM_LTUU, TRSM_LTUN, TRSM_LTLU, TRSM_LTLN,
94   TRSM_LRUU, TRSM_LRUN, TRSM_LRLU, TRSM_LRLN,
95   TRSM_LCUU, TRSM_LCUN, TRSM_LCLU, TRSM_LCLN,
96   TRSM_RNUU, TRSM_RNUN, TRSM_RNLU, TRSM_RNLN,
97   TRSM_RTUU, TRSM_RTUN, TRSM_RTLU, TRSM_RTLN,
98   TRSM_RRUU, TRSM_RRUN, TRSM_RRLU, TRSM_RRLN,
99   TRSM_RCUU, TRSM_RCUN, TRSM_RCLU, TRSM_RCLN,
100 #else
101   TRMM_LNUU, TRMM_LNUN, TRMM_LNLU, TRMM_LNLN,
102   TRMM_LTUU, TRMM_LTUN, TRMM_LTLU, TRMM_LTLN,
103   TRMM_LRUU, TRMM_LRUN, TRMM_LRLU, TRMM_LRLN,
104   TRMM_LCUU, TRMM_LCUN, TRMM_LCLU, TRMM_LCLN,
105   TRMM_RNUU, TRMM_RNUN, TRMM_RNLU, TRMM_RNLN,
106   TRMM_RTUU, TRMM_RTUN, TRMM_RTLU, TRMM_RTLN,
107   TRMM_RRUU, TRMM_RRUN, TRMM_RRLU, TRMM_RRLN,
108   TRMM_RCUU, TRMM_RCUN, TRMM_RCLU, TRMM_RCLN,
109 #endif
110 };
111
112 #ifndef CBLAS
113
114 void NAME(char *SIDE, char *UPLO, char *TRANS, char *DIAG,
115            blasint *M, blasint *N, FLOAT *alpha,
116            FLOAT *a, blasint *ldA, FLOAT *b, blasint *ldB){
117
118   char side_arg  = *SIDE;
119   char uplo_arg  = *UPLO;
120   char trans_arg = *TRANS;
121   char diag_arg  = *DIAG;
122
123   blas_arg_t args;
124
125   FLOAT *buffer;
126   FLOAT *sa, *sb;
127
128 #ifdef SMP
129 #ifndef COMPLEX
130 #ifdef XDOUBLE
131   int mode  =  BLAS_XDOUBLE | BLAS_REAL;
132 #elif defined(DOUBLE)
133   int mode  =  BLAS_DOUBLE  | BLAS_REAL;
134 #else
135   int mode  =  BLAS_SINGLE  | BLAS_REAL;
136 #endif
137 #else
138 #ifdef XDOUBLE
139   int mode  =  BLAS_XDOUBLE | BLAS_COMPLEX;
140 #elif defined(DOUBLE)
141   int mode  =  BLAS_DOUBLE  | BLAS_COMPLEX;
142 #else
143   int mode  =  BLAS_SINGLE  | BLAS_COMPLEX;
144 #endif
145 #endif
146 #endif
147
148   blasint info;
149   int side;
150   int uplo;
151   int unit;
152   int trans;
153   int nrowa;
154
155   PRINT_DEBUG_NAME;
156
157   args.m = *M;
158   args.n = *N;
159
160   args.a = (void *)a;
161   args.b = (void *)b;
162
163   args.lda = *ldA;
164   args.ldb = *ldB;
165
166   args.beta = (void *)alpha;
167
168   TOUPPER(side_arg);
169   TOUPPER(uplo_arg);
170   TOUPPER(trans_arg);
171   TOUPPER(diag_arg);
172
173   side  = -1;
174   trans = -1;
175   unit  = -1;
176   uplo  = -1;
177
178   if (side_arg  == 'L') side  = 0;
179   if (side_arg  == 'R') side  = 1;
180
181   if (trans_arg == 'N') trans = 0;
182   if (trans_arg == 'T') trans = 1;
183   if (trans_arg == 'R') trans = 2;
184   if (trans_arg == 'C') trans = 3;
185
186   if (diag_arg  == 'U') unit  = 0;
187   if (diag_arg  == 'N') unit  = 1;
188
189   if (uplo_arg  == 'U') uplo  = 0;
190   if (uplo_arg  == 'L') uplo  = 1;
191
192   nrowa = args.m;
193   if (side & 1) nrowa = args.n;
194
195   info = 0;
196
197   if (args.ldb < MAX(1,args.m)) info = 11;
198   if (args.lda < MAX(1,nrowa))  info =  9;
199   if (args.n < 0)               info =  6;
200   if (args.m < 0)               info =  5;
201   if (unit < 0)                 info =  4;
202   if (trans < 0)                info =  3;
203   if (uplo  < 0)                info =  2;
204   if (side  < 0)                info =  1;
205
206   if (info != 0) {
207     BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)-1);
208     return;
209   }
210
211 #else
212
213 void CNAME(enum CBLAS_ORDER order,
214            enum CBLAS_SIDE Side,  enum CBLAS_UPLO Uplo,
215            enum CBLAS_TRANSPOSE Trans, enum CBLAS_DIAG Diag,
216            blasint m, blasint n,
217 #ifndef COMPLEX
218            FLOAT alpha,
219            FLOAT *a, blasint lda,
220            FLOAT *b, blasint ldb) {
221 #else
222            void *valpha,
223            void *va, blasint lda,
224            void *vb, blasint ldb) {
225   FLOAT *alpha = (FLOAT*) valpha;
226   FLOAT *a = (FLOAT*) va;
227   FLOAT *b = (FLOAT*) vb;
228 #endif
229
230   blas_arg_t args;
231   int side, uplo, trans, unit;
232   blasint info, nrowa;
233
234   XFLOAT *buffer;
235   XFLOAT *sa, *sb;
236
237 #ifdef SMP
238 #ifndef COMPLEX
239 #ifdef XDOUBLE
240   int mode  =  BLAS_XDOUBLE | BLAS_REAL;
241 #elif defined(DOUBLE)
242   int mode  =  BLAS_DOUBLE  | BLAS_REAL;
243 #else
244   int mode  =  BLAS_SINGLE  | BLAS_REAL;
245 #endif
246 #else
247 #ifdef XDOUBLE
248   int mode  =  BLAS_XDOUBLE | BLAS_COMPLEX;
249 #elif defined(DOUBLE)
250   int mode  =  BLAS_DOUBLE  | BLAS_COMPLEX;
251 #else
252   int mode  =  BLAS_SINGLE  | BLAS_COMPLEX;
253 #endif
254 #endif
255 #endif
256
257   PRINT_DEBUG_CNAME;
258
259   args.a = (void *)a;
260   args.b = (void *)b;
261
262   args.lda = lda;
263   args.ldb = ldb;
264
265 #ifndef COMPLEX
266   args.beta = (void *)&alpha;
267 #else
268   args.beta = (void *)alpha;
269 #endif
270
271   side   = -1;
272   uplo   = -1;
273   trans  = -1;
274   unit   = -1;
275   info   =  0;
276
277   if (order == CblasColMajor) {
278     args.m = m;
279     args.n = n;
280
281     if (Side == CblasLeft)         side  = 0;
282     if (Side == CblasRight)        side  = 1;
283
284     if (Uplo == CblasUpper)        uplo  = 0;
285     if (Uplo == CblasLower)        uplo  = 1;
286
287     if (Trans == CblasNoTrans)     trans = 0;
288     if (Trans == CblasTrans)       trans = 1;
289 #ifndef COMPLEX
290     if (Trans == CblasConjNoTrans) trans = 0;
291     if (Trans == CblasConjTrans)   trans = 1;
292 #else
293     if (Trans == CblasConjNoTrans) trans = 2;
294     if (Trans == CblasConjTrans)   trans = 3;
295 #endif
296
297     if (Diag == CblasUnit)          unit  = 0;
298     if (Diag == CblasNonUnit)       unit  = 1;
299
300     info = -1;
301
302     nrowa = args.m;
303     if (side & 1) nrowa = args.n;
304
305     if (args.ldb < MAX(1,args.m)) info = 11;
306     if (args.lda < MAX(1,nrowa))  info =  9;
307     if (args.n < 0)               info =  6;
308     if (args.m < 0)               info =  5;
309     if (unit < 0)                 info =  4;
310     if (trans < 0)                info =  3;
311     if (uplo  < 0)                info =  2;
312     if (side  < 0)                info =  1;
313   }
314
315   if (order == CblasRowMajor) {
316     args.m = n;
317     args.n = m;
318
319     if (Side == CblasLeft)         side  = 1;
320     if (Side == CblasRight)        side  = 0;
321
322     if (Uplo == CblasUpper)        uplo  = 1;
323     if (Uplo == CblasLower)        uplo  = 0;
324
325     if (Trans == CblasNoTrans)     trans = 0;
326     if (Trans == CblasTrans)       trans = 1;
327 #ifndef COMPLEX
328     if (Trans == CblasConjNoTrans) trans = 0;
329     if (Trans == CblasConjTrans)   trans = 1;
330 #else
331     if (Trans == CblasConjNoTrans) trans = 2;
332     if (Trans == CblasConjTrans)   trans = 3;
333 #endif
334
335     if (Diag == CblasUnit)         unit  = 0;
336     if (Diag == CblasNonUnit)      unit  = 1;
337
338     info = -1;
339
340     nrowa = args.m;
341     if (side & 1) nrowa = args.n;
342
343     if (args.ldb < MAX(1,args.m)) info = 11;
344     if (args.lda < MAX(1,nrowa))  info =  9;
345     if (args.n < 0)               info =  6;
346     if (args.m < 0)               info =  5;
347     if (unit < 0)                 info =  4;
348     if (trans < 0)                info =  3;
349     if (uplo  < 0)                info =  2;
350     if (side  < 0)                info =  1;
351   }
352
353   if (info >= 0) {
354     BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
355     return;
356   }
357
358 #endif
359
360   if ((args.m == 0) || (args.n == 0)) return;
361
362   IDEBUG_START;
363
364   FUNCTION_PROFILE_START();
365
366   buffer = (FLOAT *)blas_memory_alloc(0);
367
368   sa = (FLOAT *)((BLASLONG)buffer + GEMM_OFFSET_A);
369   sb = (FLOAT *)(((BLASLONG)sa + ((GEMM_P * GEMM_Q * COMPSIZE * SIZE + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
370
371 #ifdef SMP
372   mode |= (trans << BLAS_TRANSA_SHIFT);
373   mode |= (side  << BLAS_RSIDE_SHIFT);
374
375 /*
376   if ( args.m < 2 * GEMM_MULTITHREAD_THRESHOLD )
377         args.nthreads = 1;
378   else
379         if ( args.n < 2 * GEMM_MULTITHREAD_THRESHOLD )
380                 args.nthreads = 1;
381 */
382   if ( args.m * args.n < SMP_FACTOR * GEMM_MULTITHREAD_THRESHOLD)
383         args.nthreads = 1;
384   else
385         args.nthreads = num_cpu_avail(3);
386                 
387
388   if (args.nthreads == 1) {
389 #endif
390
391     (trsm[(side<<4) | (trans<<2) | (uplo<<1) | unit])(&args, NULL, NULL, sa, sb, 0);
392
393 #ifdef SMP
394   } else {
395     if (!side) {
396       gemm_thread_n(mode, &args, NULL, NULL, trsm[(side<<4) | (trans<<2) | (uplo<<1) | unit], sa, sb, args.nthreads);
397     } else {
398       gemm_thread_m(mode, &args, NULL, NULL, trsm[(side<<4) | (trans<<2) | (uplo<<1) | unit], sa, sb, args.nthreads);
399     }
400   }
401 #endif
402
403   blas_memory_free(buffer);
404
405   FUNCTION_PROFILE_END(COMPSIZE * COMPSIZE,
406                        (!side) ? args.m * (args.m + args.n) : args.n * (args.m + args.n),
407                        (!side) ? args.m * args.m * args.n : args.m * args.n * args.n);
408
409   IDEBUG_END;
410
411   return;
412 }
413