5897b2c17ea8286e1e24eabcc0ba0cf9ace6fc94
[platform/upstream/openblas.git] / kernel / zarch / dgemv_n_4.c
1 /***************************************************************************
2 Copyright (c) 2017, The OpenBLAS Project
3 All rights reserved.
4 Redistribution and use in source and binary forms, with or without
5 modification, are permitted provided that the following conditions are
6 met:
7 1. Redistributions of source code must retain the above copyright
8 notice, this list of conditions and the following disclaimer.
9 2. Redistributions in binary form must reproduce the above copyright
10 notice, this list of conditions and the following disclaimer in
11 the documentation and/or other materials provided with the
12 distribution.
13 3. Neither the name of the OpenBLAS project nor the names of
14 its contributors may be used to endorse or promote products
15 derived from this software without specific prior written permission.
16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
25 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 *****************************************************************************/
27
28
29 #include "common.h"
30
31 #define NBMAX 2048
32
33 #define HAVE_KERNEL_4x4_VEC 1
34 #define HAVE_KERNEL_4x2_VEC 1
35 #define HAVE_KERNEL_4x1_VEC 1
36
37 #if defined(HAVE_KERNEL_4x4_VEC) || defined(HAVE_KERNEL_4x2_VEC) || defined(HAVE_KERNEL_4x1_VEC)
38  #include <vecintrin.h>
39 #endif
40
41 #ifdef HAVE_KERNEL_4x4
42
43 #elif HAVE_KERNEL_4x4_VEC
44
45 static void dgemv_kernel_4x4(BLASLONG n, FLOAT **ap, FLOAT *xo, FLOAT *y, FLOAT *alpha)
46 {
47     BLASLONG i;
48     FLOAT x0,x1,x2,x3;
49     x0 = xo[0] * *alpha;
50     x1 = xo[1] * *alpha;
51     x2 = xo[2] * *alpha;
52     x3 = xo[3] * *alpha;
53     __vector double   v_x0 = {x0,x0};
54     __vector double   v_x1 = {x1,x1};
55     __vector double   v_x2 = {x2,x2};
56     __vector double   v_x3 = {x3,x3};
57     __vector double* v_y =(__vector double*)y;      
58     __vector double* va0 = (__vector double*)ap[0];
59     __vector double* va1 = (__vector double*)ap[1];
60     __vector double* va2 = (__vector double*)ap[2];
61     __vector double* va3 = (__vector double*)ap[3]; 
62
63     for ( i=0; i< n/2; i+=2 )
64     {
65         v_y[i]   += v_x0 * va0[i]   +  v_x1 * va1[i]   + v_x2 * va2[i]   + v_x3 * va3[i] ;
66         v_y[i+1] += v_x0 * va0[i+1] +  v_x1 * va1[i+1] + v_x2 * va2[i+1] + v_x3 * va3[i+1] ;        
67     }
68 }
69
70 #else
71
72 static void dgemv_kernel_4x4(BLASLONG n, FLOAT **ap, FLOAT *xo, FLOAT *y, FLOAT *alpha)
73 {
74     BLASLONG i;
75     FLOAT *a0,*a1,*a2,*a3;
76     FLOAT x[4]  __attribute__ ((aligned (16)));
77     a0 = ap[0];
78     a1 = ap[1];
79     a2 = ap[2];
80     a3 = ap[3];
81
82     for ( i=0; i<4; i++)
83         x[i] = xo[i] * *alpha;
84
85     for ( i=0; i< n; i+=4 )
86     {
87         y[i] += a0[i]*x[0] + a1[i]*x[1] + a2[i]*x[2] + a3[i]*x[3];        
88         y[i+1] += a0[i+1]*x[0] + a1[i+1]*x[1] + a2[i+1]*x[2] + a3[i+1]*x[3];        
89         y[i+2] += a0[i+2]*x[0] + a1[i+2]*x[1] + a2[i+2]*x[2] + a3[i+2]*x[3];        
90         y[i+3] += a0[i+3]*x[0] + a1[i+3]*x[1] + a2[i+3]*x[2] + a3[i+3]*x[3];        
91     }
92 }
93
94
95 #endif
96
97 #ifdef HAVE_KERNEL_4x2
98
99 #elif HAVE_KERNEL_4x2_VEC
100
101 static void dgemv_kernel_4x2(BLASLONG n, FLOAT **ap, FLOAT *xo, FLOAT *y, FLOAT *alpha)
102 {
103     BLASLONG i;
104     FLOAT x0,x1;
105     x0 = xo[0] * *alpha;
106     x1 = xo[1] * *alpha; 
107     __vector double   v_x0 = {x0,x0};
108     __vector double   v_x1 = {x1,x1}; 
109     __vector double* v_y =(__vector double*)y;      
110     __vector double* va0 = (__vector double*)ap[0];
111     __vector double* va1 = (__vector double*)ap[1]; 
112
113     for ( i=0; i< n/2; i+=2 )
114     {
115         v_y[i]   += v_x0 * va0[i] +  v_x1 * va1[i]   ;
116         v_y[i+1] += v_x0 * va0[i+1] +  v_x1 * va1[i+1]  ;        
117     } 
118 }
119 #else
120
121 static void dgemv_kernel_4x2(BLASLONG n, FLOAT **ap, FLOAT *xo, FLOAT *y, FLOAT *alpha)
122 {
123     BLASLONG i;
124     FLOAT *a0,*a1;
125     FLOAT x[4]  __attribute__ ((aligned (16)));
126     a0 = ap[0];
127     a1 = ap[1];
128
129     for ( i=0; i<2; i++)
130         x[i] = xo[i] * *alpha;
131
132     for ( i=0; i< n; i+=4 )
133     {
134         y[i] += a0[i]*x[0] + a1[i]*x[1];        
135         y[i+1] += a0[i+1]*x[0] + a1[i+1]*x[1];        
136         y[i+2] += a0[i+2]*x[0] + a1[i+2]*x[1];        
137         y[i+3] += a0[i+3]*x[0] + a1[i+3]*x[1];        
138     }
139 }
140
141
142 #endif
143
144 #ifdef HAVE_KERNEL_4x1
145
146 #elif HAVE_KERNEL_4x1_VEC
147 static void dgemv_kernel_4x1(BLASLONG n, FLOAT *ap, FLOAT *xo, FLOAT *y, FLOAT *alpha)
148 {
149     
150     BLASLONG i;
151     FLOAT x0;
152     x0 = xo[0] * *alpha;
153     __vector double   v_x0 = {x0,x0};
154     __vector double* v_y =(__vector double*)y;      
155     __vector double* va0 = (__vector double*)ap;
156
157     for ( i=0; i< n/2; i+=2 )
158     {
159         v_y[i] += v_x0 * va0[i]    ;
160         v_y[i+1] += v_x0 * va0[i+1]  ;        
161     }
162         
163  
164 }
165
166 #else
167 static void dgemv_kernel_4x1(BLASLONG n, FLOAT *ap, FLOAT *xo, FLOAT *y, FLOAT *alpha)
168 {
169     BLASLONG i;
170     FLOAT *a0;
171     FLOAT x[4]  __attribute__ ((aligned (16)));
172     a0 = ap;
173
174     for ( i=0; i<1; i++)
175         x[i] = xo[i] * *alpha;
176
177     for ( i=0; i< n; i+=4 )
178     {
179         y[i] += a0[i]*x[0];        
180         y[i+1] += a0[i+1]*x[0];        
181         y[i+2] += a0[i+2]*x[0];        
182         y[i+3] += a0[i+3]*x[0];        
183     }
184 }
185
186
187 #endif
188
189
190  
191 static void add_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest) __attribute__ ((noinline));
192
193 static void add_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest)
194 {
195     BLASLONG i;
196         
197     for ( i=0; i<n; i++ ){
198             *dest += *src;
199             src++;
200             dest += inc_dest;
201     }
202     return;
203      
204 }
205
206 int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT *buffer)
207 {
208     BLASLONG i;
209     BLASLONG j;
210     FLOAT *a_ptr;
211     FLOAT *x_ptr;
212     FLOAT *y_ptr;
213     FLOAT *ap[4];
214     BLASLONG n1;
215     BLASLONG m1;
216     BLASLONG m2;
217     BLASLONG m3;
218     BLASLONG n2;
219     BLASLONG lda4 =  lda << 2;
220     FLOAT xbuffer[8],*ybuffer;
221
222     if ( m < 1 ) return(0);
223     if ( n < 1 ) return(0);
224
225     ybuffer = buffer;
226     
227     n1 = n >> 2 ;
228     n2 = n &  3 ;
229
230     m3 = m & 3  ;
231     m1 = m & -4 ;
232     m2 = (m & (NBMAX-1)) - m3 ;
233
234     y_ptr = y;
235
236     BLASLONG NB = NBMAX;
237
238     while ( NB == NBMAX )
239     {
240         
241         m1 -= NB;
242         if ( m1 < 0)
243         {
244             if ( m2 == 0 ) break;    
245             NB = m2;
246         }
247         
248         a_ptr = a;
249         x_ptr = x;
250         
251         ap[0] = a_ptr;
252         ap[1] = a_ptr + lda;
253         ap[2] = ap[1] + lda;
254         ap[3] = ap[2] + lda;
255
256         if ( inc_y != 1 )
257             memset(ybuffer,0,NB*8);
258         else
259             ybuffer = y_ptr;
260
261         if ( inc_x == 1 )
262         {
263
264
265             for( i = 0; i < n1 ; i++)
266             {
267                 dgemv_kernel_4x4(NB,ap,x_ptr,ybuffer,&alpha);
268                 ap[0] += lda4; 
269                 ap[1] += lda4; 
270                 ap[2] += lda4; 
271                 ap[3] += lda4; 
272                 a_ptr += lda4;
273                 x_ptr += 4;    
274             }
275
276             if ( n2 & 2 )
277             {
278                 dgemv_kernel_4x2(NB,ap,x_ptr,ybuffer,&alpha);
279                 a_ptr += lda*2;
280                 x_ptr += 2;    
281             }
282
283
284             if ( n2 & 1 )
285             {
286                 dgemv_kernel_4x1(NB,a_ptr,x_ptr,ybuffer,&alpha);
287                 a_ptr += lda;
288                 x_ptr += 1;    
289
290             }
291
292
293         }
294         else
295         {
296
297             for( i = 0; i < n1 ; i++)
298             {
299                 xbuffer[0] = x_ptr[0];
300                 x_ptr += inc_x;    
301                 xbuffer[1] =  x_ptr[0];
302                 x_ptr += inc_x;    
303                 xbuffer[2] =  x_ptr[0];
304                 x_ptr += inc_x;    
305                 xbuffer[3] = x_ptr[0];
306                 x_ptr += inc_x;    
307                 dgemv_kernel_4x4(NB,ap,xbuffer,ybuffer,&alpha);
308                 ap[0] += lda4; 
309                 ap[1] += lda4; 
310                 ap[2] += lda4; 
311                 ap[3] += lda4; 
312                 a_ptr += lda4;
313             }
314
315             for( i = 0; i < n2 ; i++)
316             {
317                 xbuffer[0] = x_ptr[0];
318                 x_ptr += inc_x;    
319                 dgemv_kernel_4x1(NB,a_ptr,xbuffer,ybuffer,&alpha);
320                 a_ptr += lda;
321
322             }
323
324         }
325
326         a     += NB;
327         if ( inc_y != 1 )
328         {
329             add_y(NB,ybuffer,y_ptr,inc_y);
330             y_ptr += NB * inc_y;
331         }
332         else
333             y_ptr += NB ;
334
335     }
336
337     if ( m3 == 0 ) return(0);
338
339     if ( m3 == 3 )
340     {
341         a_ptr = a;
342         x_ptr = x;
343         FLOAT temp0 = 0.0;
344         FLOAT temp1 = 0.0;
345         FLOAT temp2 = 0.0;
346         if ( lda == 3 && inc_x ==1 )
347         {
348
349             for( i = 0; i < ( n & -4 ); i+=4 )
350             {
351
352                 temp0 += a_ptr[0] * x_ptr[0] + a_ptr[3] * x_ptr[1];
353                 temp1 += a_ptr[1] * x_ptr[0] + a_ptr[4] * x_ptr[1];
354                 temp2 += a_ptr[2] * x_ptr[0] + a_ptr[5] * x_ptr[1];
355
356                 temp0 += a_ptr[6] * x_ptr[2] + a_ptr[9]  * x_ptr[3];
357                 temp1 += a_ptr[7] * x_ptr[2] + a_ptr[10] * x_ptr[3];
358                 temp2 += a_ptr[8] * x_ptr[2] + a_ptr[11] * x_ptr[3];
359
360                 a_ptr += 12;
361                 x_ptr += 4;
362             }
363
364             for( ; i < n; i++ )
365             {
366                 temp0 += a_ptr[0] * x_ptr[0];
367                 temp1 += a_ptr[1] * x_ptr[0];
368                 temp2 += a_ptr[2] * x_ptr[0];
369                 a_ptr += 3;
370                 x_ptr ++;
371             }
372
373         }
374         else
375         {
376
377             for( i = 0; i < n; i++ )
378             {
379                 temp0 += a_ptr[0] * x_ptr[0];
380                 temp1 += a_ptr[1] * x_ptr[0];
381                 temp2 += a_ptr[2] * x_ptr[0];
382                 a_ptr += lda;
383                 x_ptr += inc_x;
384
385
386             }
387
388         }
389         y_ptr[0] += alpha * temp0;
390         y_ptr += inc_y;
391         y_ptr[0] += alpha * temp1;
392         y_ptr += inc_y;
393         y_ptr[0] += alpha * temp2;
394         return(0);
395     }
396
397
398     if ( m3 == 2 )
399     {
400         a_ptr = a;
401         x_ptr = x;
402         FLOAT temp0 = 0.0;
403         FLOAT temp1 = 0.0;
404         if ( lda == 2 && inc_x ==1 )
405         {
406
407             for( i = 0; i < (n & -4) ; i+=4 )
408             {
409                 temp0 += a_ptr[0] * x_ptr[0] + a_ptr[2] * x_ptr[1];
410                 temp1 += a_ptr[1] * x_ptr[0] + a_ptr[3] * x_ptr[1];
411                 temp0 += a_ptr[4] * x_ptr[2] + a_ptr[6] * x_ptr[3];
412                 temp1 += a_ptr[5] * x_ptr[2] + a_ptr[7] * x_ptr[3];
413                 a_ptr += 8;
414                 x_ptr += 4;
415
416             }
417
418
419             for( ; i < n; i++ )
420             {
421                 temp0 += a_ptr[0]   * x_ptr[0];
422                 temp1 += a_ptr[1]   * x_ptr[0];
423                 a_ptr += 2;
424                 x_ptr ++;
425             }
426
427         }
428         else
429         {
430
431             for( i = 0; i < n; i++ )
432             {
433                 temp0 += a_ptr[0] * x_ptr[0];
434                 temp1 += a_ptr[1] * x_ptr[0];
435                 a_ptr += lda;
436                 x_ptr += inc_x;
437
438
439             }
440
441         }
442         y_ptr[0] += alpha * temp0;
443         y_ptr += inc_y;
444         y_ptr[0] += alpha * temp1;
445         return(0);
446     }
447
448     if ( m3 == 1 )
449     {
450         a_ptr = a;
451         x_ptr = x;
452         FLOAT temp = 0.0;
453         if ( lda == 1 && inc_x ==1 )
454         {
455
456             for( i = 0; i < (n & -4); i+=4 )
457             {
458                 temp += a_ptr[i] * x_ptr[i] + a_ptr[i+1] * x_ptr[i+1] + a_ptr[i+2] * x_ptr[i+2] + a_ptr[i+3] * x_ptr[i+3];
459     
460             }
461
462             for( ; i < n; i++ )
463             {
464                 temp += a_ptr[i] * x_ptr[i];
465             }
466
467         }
468         else
469         {
470
471             for( i = 0; i < n; i++ )
472             {
473                 temp += a_ptr[0] * x_ptr[0];
474                 a_ptr += lda;
475                 x_ptr += inc_x;
476             }
477
478         }
479         y_ptr[0] += alpha * temp;
480         return(0);
481     }
482
483
484     return(0);
485 }
486
487