26 #ifdef FIXED_POINT_POSITION 28 #endif // FIXED_POINT_POSITION 30 #if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH) 33 #define DATA_TYPE uchar 34 #elif ELEMENT_SIZE == 2 35 #define DATA_TYPE ushort 36 #elif ELEMENT_SIZE == 4 37 #define DATA_TYPE uint 38 #else // ELEMENT_SIZE == 1 39 #error "Element size not supported" 40 #endif // ELEMENT_SIZE 67 uint x = get_global_id(0);
68 uint y = get_global_id(1);
69 uint z = get_global_id(2);
75 uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + y * TRANSPOSE_W *
sizeof(
DATA_TYPE) * MULT_TRANSPOSE1XW_WIDTH + (x / MULT_TRANSPOSE1XW_WIDTH) * dst_stride_y +
76 (x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W *
sizeof(
DATA_TYPE);
79 dst_addr_in_bytes += z * dst_stride_z;
85 (b0, 0, (__global
DATA_TYPE *)(dst_ptr + dst_addr_in_bytes));
87 #endif // defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH) 89 #if defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE) 117 uint x = get_global_id(0);
118 uint y = get_global_id(1);
119 uint z = get_global_id(2);
125 uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + x *
sizeof(
DATA_TYPE) * 16 * MULT_INTERLEAVE4X4_HEIGHT + (y / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y +
126 (y % MULT_INTERLEAVE4X4_HEIGHT) * 4 *
sizeof(
DATA_TYPE);
129 dst_addr_in_bytes += z * dst_stride_z;
131 __global uchar *input_ptr = src.
ptr;
135 a0 = vload4(0, (__global
DATA_TYPE *)(input_ptr + 0 * src_stride_y));
137 a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
139 a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
141 a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
144 val0 = (
VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0);
145 vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 0 * MULT_INTERLEAVE4X4_HEIGHT));
147 val0 = (
VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s1, a1.s1, a2.s1, a3.s1);
148 vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 4 * MULT_INTERLEAVE4X4_HEIGHT));
150 val0 = (
VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s2, a1.s2, a2.s2, a3.s2);
151 vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 8 * MULT_INTERLEAVE4X4_HEIGHT));
153 val0 = (
VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s3, a1.s3, a2.s3, a3.s3);
154 vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 12 * MULT_INTERLEAVE4X4_HEIGHT));
156 #endif // defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE) 158 #if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT) 194 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
195 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
196 int z = get_global_id(2);
199 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
200 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
204 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
205 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
207 #if defined(MATRIX_B_DEPTH) 209 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
210 #else // defined(MATRIX_B_DEPTH) 211 src1_addr_in_bytes += z * src1_stride_z;
212 #endif // defined(MATRIX_B_DEPTH) 214 __global
float *src_addr_a = (__global
float *)(src0_ptr + src0_addr_in_bytes);
215 __global
float *src_addr_b = (__global
float *)(src1_ptr + src1_addr_in_bytes);
218 __global
float *src_end_addr_b = src_addr_b + COLS_B;
220 src_addr_a += offset_row_a;
221 src_addr_b += offset_row_b;
229 for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
232 float4 a0 = vload4(0, src_addr_a);
233 float4 b0 = vload4(0, src_addr_b);
235 c00 += (float4)a0.s0 * b0;
236 c10 += (float4)a0.s1 * b0;
237 c20 += (float4)a0.s2 * b0;
238 c30 += (float4)a0.s3 * b0;
241 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
242 b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
244 c00 += (float4)a0.s0 * b0;
245 c10 += (float4)a0.s1 * b0;
246 c20 += (float4)a0.s2 * b0;
247 c30 += (float4)a0.s3 * b0;
250 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
253 float4 a0 = vload4(0, src_addr_a);
254 float4 b0 = vload4(0, src_addr_b);
256 c00 += (float4)a0.s0 * b0;
257 c10 += (float4)a0.s1 * b0;
258 c20 += (float4)a0.s2 * b0;
259 c30 += (float4)a0.s3 * b0;
267 c00 = c00 * (float4)ALPHA;
268 c10 = c10 * (float4)ALPHA;
269 c20 = c20 * (float4)ALPHA;
270 c30 = c30 * (float4)ALPHA;
271 #endif // defined(ALPHA) 274 __global uchar *dst_addr =
offset(&dst, 0, 0);
277 dst_addr += z * dst_stride_z;
280 vstore4(c00, 0, (__global
float *)(dst_addr + 0 * dst_stride_y));
281 vstore4(c10, 0, (__global
float *)(dst_addr + 1 * dst_stride_y));
282 vstore4(c20, 0, (__global
float *)(dst_addr + 2 * dst_stride_y));
283 vstore4(c30, 0, (__global
float *)(dst_addr + 3 * dst_stride_y));
315 __kernel
void gemm_mm_interleaved_transposed_f32_bifrost(
IMAGE_DECLARATION(src0),
322 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
323 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
324 int z = get_global_id(2);
327 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
328 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
332 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
333 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
335 #if defined(MATRIX_B_DEPTH) 337 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
338 #else // defined(MATRIX_B_DEPTH) 339 src1_addr_in_bytes += z * src1_stride_z;
340 #endif // defined(MATRIX_B_DEPTH) 342 __global
float *src_addr_a = (__global
float *)(src0_ptr + src0_addr_in_bytes);
343 __global
float *src_addr_b = (__global
float *)(src1_ptr + src1_addr_in_bytes);
345 src_addr_a += offset_row_a;
346 src_addr_b += offset_row_b;
366 #define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH)) 369 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
372 float4 a0 = vload4(0, src_addr_a);
373 float4 b0 = vload4(0, src_addr_b);
375 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
376 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
378 c00 = fma(a0.s0, b0.s0, c00);
379 c01 = fma(a0.s0, b0.s1, c01);
380 c02 = fma(a0.s0, b0.s2, c02);
381 c03 = fma(a0.s0, b0.s3, c03);
383 c10 = fma(a0.s1, b0.s0, c10);
384 c11 = fma(a0.s1, b0.s1, c11);
385 c12 = fma(a0.s1, b0.s2, c12);
386 c13 = fma(a0.s1, b0.s3, c13);
388 c20 = fma(a0.s2, b0.s0, c20);
389 c21 = fma(a0.s2, b0.s1, c21);
390 c22 = fma(a0.s2, b0.s2, c22);
391 c23 = fma(a0.s2, b0.s3, c23);
393 c30 = fma(a0.s3, b0.s0, c30);
394 c31 = fma(a0.s3, b0.s1, c31);
395 c32 = fma(a0.s3, b0.s2, c32);
396 c33 = fma(a0.s3, b0.s3, c33);
399 a0 = vload4(0, src_addr_a);
400 b0 = vload4(0, src_addr_b);
402 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
403 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
405 c00 = fma(a0.s0, b0.s0, c00);
406 c01 = fma(a0.s0, b0.s1, c01);
407 c02 = fma(a0.s0, b0.s2, c02);
408 c03 = fma(a0.s0, b0.s3, c03);
410 c10 = fma(a0.s1, b0.s0, c10);
411 c11 = fma(a0.s1, b0.s1, c11);
412 c12 = fma(a0.s1, b0.s2, c12);
413 c13 = fma(a0.s1, b0.s3, c13);
415 c20 = fma(a0.s2, b0.s0, c20);
416 c21 = fma(a0.s2, b0.s1, c21);
417 c22 = fma(a0.s2, b0.s2, c22);
418 c23 = fma(a0.s2, b0.s3, c23);
420 c30 = fma(a0.s3, b0.s0, c30);
421 c31 = fma(a0.s3, b0.s1, c31);
422 c32 = fma(a0.s3, b0.s2, c32);
423 c33 = fma(a0.s3, b0.s3, c33);
426 a0 = vload4(0, src_addr_a);
427 b0 = vload4(0, src_addr_b);
429 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
430 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
432 c00 = fma(a0.s0, b0.s0, c00);
433 c01 = fma(a0.s0, b0.s1, c01);
434 c02 = fma(a0.s0, b0.s2, c02);
435 c03 = fma(a0.s0, b0.s3, c03);
437 c10 = fma(a0.s1, b0.s0, c10);
438 c11 = fma(a0.s1, b0.s1, c11);
439 c12 = fma(a0.s1, b0.s2, c12);
440 c13 = fma(a0.s1, b0.s3, c13);
442 c20 = fma(a0.s2, b0.s0, c20);
443 c21 = fma(a0.s2, b0.s1, c21);
444 c22 = fma(a0.s2, b0.s2, c22);
445 c23 = fma(a0.s2, b0.s3, c23);
447 c30 = fma(a0.s3, b0.s0, c30);
448 c31 = fma(a0.s3, b0.s1, c31);
449 c32 = fma(a0.s3, b0.s2, c32);
450 c33 = fma(a0.s3, b0.s3, c33);
453 a0 = vload4(0, src_addr_a);
454 b0 = vload4(0, src_addr_b);
456 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
457 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
459 c00 = fma(a0.s0, b0.s0, c00);
460 c01 = fma(a0.s0, b0.s1, c01);
461 c02 = fma(a0.s0, b0.s2, c02);
462 c03 = fma(a0.s0, b0.s3, c03);
464 c10 = fma(a0.s1, b0.s0, c10);
465 c11 = fma(a0.s1, b0.s1, c11);
466 c12 = fma(a0.s1, b0.s2, c12);
467 c13 = fma(a0.s1, b0.s3, c13);
469 c20 = fma(a0.s2, b0.s0, c20);
470 c21 = fma(a0.s2, b0.s1, c21);
471 c22 = fma(a0.s2, b0.s2, c22);
472 c23 = fma(a0.s2, b0.s3, c23);
474 c30 = fma(a0.s3, b0.s0, c30);
475 c31 = fma(a0.s3, b0.s1, c31);
476 c32 = fma(a0.s3, b0.s2, c32);
477 c33 = fma(a0.s3, b0.s3, c33);
480 for(; i < (int)(COLS_MTX_B); ++i)
483 float4 a0 = vload4(0, src_addr_a);
484 float4 b0 = vload4(0, src_addr_b);
486 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
487 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
489 c00 = fma(a0.s0, b0.s0, c00);
490 c01 = fma(a0.s0, b0.s1, c01);
491 c02 = fma(a0.s0, b0.s2, c02);
492 c03 = fma(a0.s0, b0.s3, c03);
494 c10 = fma(a0.s1, b0.s0, c10);
495 c11 = fma(a0.s1, b0.s1, c11);
496 c12 = fma(a0.s1, b0.s2, c12);
497 c13 = fma(a0.s1, b0.s3, c13);
499 c20 = fma(a0.s2, b0.s0, c20);
500 c21 = fma(a0.s2, b0.s1, c21);
501 c22 = fma(a0.s2, b0.s2, c22);
502 c23 = fma(a0.s2, b0.s3, c23);
504 c30 = fma(a0.s3, b0.s0, c30);
505 c31 = fma(a0.s3, b0.s1, c31);
506 c32 = fma(a0.s3, b0.s2, c32);
507 c33 = fma(a0.s3, b0.s3, c33);
531 #endif // defined(ALPHA) 534 __global uchar *dst_addr =
offset(&dst, 0, 0);
537 dst_addr += z * dst_stride_z;
540 vstore4((float4)(c00, c01, c02, c03), 0, (__global
float *)(dst_addr + 0 * dst_stride_y));
541 vstore4((float4)(c10, c11, c12, c13), 0, (__global
float *)(dst_addr + 1 * dst_stride_y));
542 vstore4((float4)(c20, c21, c22, c23), 0, (__global
float *)(dst_addr + 2 * dst_stride_y));
543 vstore4((float4)(c30, c31, c32, c33), 0, (__global
float *)(dst_addr + 3 * dst_stride_y));
549 #if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) 585 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
586 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
587 int z = get_global_id(2);
590 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
591 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
595 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
596 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
598 #if defined(MATRIX_B_DEPTH) 600 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
601 #else // defined(MATRIX_B_DEPTH) 602 src1_addr_in_bytes += z * src1_stride_z;
603 #endif // defined(MATRIX_B_DEPTH) 605 __global
half *src_addr_a = (__global
half *)(src0_ptr + src0_addr_in_bytes);
606 __global
half *src_addr_b = (__global
half *)(src1_ptr + src1_addr_in_bytes);
609 __global
half *src_end_addr_b = src_addr_b + COLS_B;
611 src_addr_a += offset_row_a;
612 src_addr_b += offset_row_b;
620 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
623 half4 a0 = vload4(0, src_addr_a);
624 half8 b0 = vload8(0, src_addr_b);
626 c00 += (half8)a0.s0 * b0;
627 c10 += (half8)a0.s1 * b0;
628 c20 += (half8)a0.s2 * b0;
629 c30 += (half8)a0.s3 * b0;
632 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
633 b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
635 c00 += (half8)a0.s0 * b0;
636 c10 += (half8)a0.s1 * b0;
637 c20 += (half8)a0.s2 * b0;
638 c30 += (half8)a0.s3 * b0;
641 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
644 half4 a0 = vload4(0, src_addr_a);
645 half8 b0 = vload8(0, src_addr_b);
647 c00 += (half8)a0.s0 * b0;
648 c10 += (half8)a0.s1 * b0;
649 c20 += (half8)a0.s2 * b0;
650 c30 += (half8)a0.s3 * b0;
658 c00 = c00 * (half8)ALPHA;
659 c10 = c10 * (half8)ALPHA;
660 c20 = c20 * (half8)ALPHA;
661 c30 = c30 * (half8)ALPHA;
662 #endif // defined(ALPHA) 665 __global uchar *dst_addr =
offset(&dst, 0, 0);
668 dst_addr += z * dst_stride_z;
671 vstore8(c00, 0, (__global
half *)(dst_addr + 0 * dst_stride_y));
672 vstore8(c10, 0, (__global
half *)(dst_addr + 1 * dst_stride_y));
673 vstore8(c20, 0, (__global
half *)(dst_addr + 2 * dst_stride_y));
674 vstore8(c30, 0, (__global
half *)(dst_addr + 3 * dst_stride_y));
705 __kernel
void gemm_mm_interleaved_transposed_f16_bifrost(
IMAGE_DECLARATION(src0),
712 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
713 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
714 int z = get_global_id(2);
717 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
718 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
722 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
723 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
725 #if defined(MATRIX_B_DEPTH) 727 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
728 #else // defined(MATRIX_B_DEPTH) 729 src1_addr_in_bytes += z * src1_stride_z;
730 #endif // defined(MATRIX_B_DEPTH) 732 __global
half *src_addr_a = (__global
half *)(src0_ptr + src0_addr_in_bytes);
733 __global
half *src_addr_b = (__global
half *)(src1_ptr + src1_addr_in_bytes);
736 __global
half *src_end_addr_b = src_addr_b + COLS_B;
738 src_addr_a += offset_row_a;
739 src_addr_b += offset_row_b;
747 #define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH)) 750 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
752 #if MULT_INTERLEAVE4X4_HEIGHT == 1 754 half8 a0 = vload8(0, src_addr_a);
755 half8 b0 = vload8(0, src_addr_b);
757 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
758 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
760 c00 = fma((half8)a0.s0, b0, c00);
761 c10 = fma((half8)a0.s1, b0, c10);
762 c20 = fma((half8)a0.s2, b0, c20);
763 c30 = fma((half8)a0.s3, b0, c30);
766 b0 = vload8(0, src_addr_b);
768 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
770 c00 = fma((half8)a0.s4, b0, c00);
771 c10 = fma((half8)a0.s5, b0, c10);
772 c20 = fma((half8)a0.s6, b0, c20);
773 c30 = fma((half8)a0.s7, b0, c30);
776 a0 = vload8(0, src_addr_a);
777 b0 = vload8(0, src_addr_b);
779 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
780 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
782 c00 = fma((half8)a0.s0, b0, c00);
783 c10 = fma((half8)a0.s1, b0, c10);
784 c20 = fma((half8)a0.s2, b0, c20);
785 c30 = fma((half8)a0.s3, b0, c30);
788 b0 = vload8(0, src_addr_b);
790 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
792 c00 = fma((half8)a0.s4, b0, c00);
793 c10 = fma((half8)a0.s5, b0, c10);
794 c20 = fma((half8)a0.s6, b0, c20);
795 c30 = fma((half8)a0.s7, b0, c30);
796 #else // MULT_INTERLEAVE4X4_HEIGHT == 1 798 half4 a0 = vload4(0, src_addr_a);
799 half8 b0 = vload8(0, src_addr_b);
801 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
802 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
804 c00 = fma((half8)a0.s0, b0, c00);
805 c10 = fma((half8)a0.s1, b0, c10);
806 c20 = fma((half8)a0.s2, b0, c20);
807 c30 = fma((half8)a0.s3, b0, c30);
810 a0 = vload4(0, src_addr_a);
811 b0 = vload8(0, src_addr_b);
813 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
814 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
816 c00 = fma((half8)a0.s0, b0, c00);
817 c10 = fma((half8)a0.s1, b0, c10);
818 c20 = fma((half8)a0.s2, b0, c20);
819 c30 = fma((half8)a0.s3, b0, c30);
822 a0 = vload4(0, src_addr_a);
823 b0 = vload8(0, src_addr_b);
825 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
826 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
828 c00 = fma((half8)a0.s0, b0, c00);
829 c10 = fma((half8)a0.s1, b0, c10);
830 c20 = fma((half8)a0.s2, b0, c20);
831 c30 = fma((half8)a0.s3, b0, c30);
834 a0 = vload4(0, src_addr_a);
835 b0 = vload8(0, src_addr_b);
837 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
838 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
840 c00 = fma((half8)a0.s0, b0, c00);
841 c10 = fma((half8)a0.s1, b0, c10);
842 c20 = fma((half8)a0.s2, b0, c20);
843 c30 = fma((half8)a0.s3, b0, c30);
844 #endif // MULT_INTERLEAVE4X4_HEIGHT == 1 847 for(; i < (int)(COLS_MTX_B); ++i)
850 half4 a0 = vload4(0, src_addr_a);
851 half8 b0 = vload8(0, src_addr_b);
853 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
854 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
856 c00 = fma((half8)a0.s0, b0, c00);
857 c10 = fma((half8)a0.s1, b0, c10);
858 c20 = fma((half8)a0.s2, b0, c20);
859 c30 = fma((half8)a0.s3, b0, c30);
867 c00 = c00 * (half8)ALPHA;
868 c10 = c10 * (half8)ALPHA;
869 c20 = c20 * (half8)ALPHA;
870 c30 = c30 * (half8)ALPHA;
871 #endif // defined(ALPHA) 874 __global uchar *dst_addr =
offset(&dst, 0, 0);
877 dst_addr += z * dst_stride_z;
880 vstore8(c00, 0, (__global
half *)(dst_addr + 0 * dst_stride_y));
881 vstore8(c10, 0, (__global
half *)(dst_addr + 1 * dst_stride_y));
882 vstore8(c20, 0, (__global
half *)(dst_addr + 2 * dst_stride_y));
883 vstore8(c30, 0, (__global
half *)(dst_addr + 3 * dst_stride_y));
889 #endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) 891 #if defined(FIXED_POINT_POSITION) 928 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
929 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
930 int z = get_global_id(2);
933 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
934 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 16;
938 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
939 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
941 #if defined(MATRIX_B_DEPTH) 943 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
944 #else // defined(MATRIX_B_DEPTH) 945 src1_addr_in_bytes += z * src1_stride_z;
946 #endif // defined(MATRIX_B_DEPTH) 948 __global
char *src_addr_a = (__global
char *)(src0_ptr + src0_addr_in_bytes);
949 __global
char *src_addr_b = (__global
char *)(src1_ptr + src1_addr_in_bytes);
952 __global
char *src_end_addr_b = src_addr_b + COLS_B;
954 src_addr_a += offset_row_a;
955 src_addr_b += offset_row_b;
968 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
971 char4 a0 = vload4(0, src_addr_a);
972 char16 b0 = vload16(0, src_addr_b);
974 c00 =
mlal_sat_qs8x8(c00, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
975 c10 =
mlal_sat_qs8x8(c10, (char8)a0.s1, b0.s01234567, FIXED_POINT_POSITION);
976 c20 =
mlal_sat_qs8x8(c20, (char8)a0.s2, b0.s01234567, FIXED_POINT_POSITION);
977 c30 =
mlal_sat_qs8x8(c30, (char8)a0.s3, b0.s01234567, FIXED_POINT_POSITION);
979 c01 =
mlal_sat_qs8x8(c01, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
980 c11 =
mlal_sat_qs8x8(c11, (char8)a0.s1, b0.s89ABCDEF, FIXED_POINT_POSITION);
981 c21 =
mlal_sat_qs8x8(c21, (char8)a0.s2, b0.s89ABCDEF, FIXED_POINT_POSITION);
982 c31 =
mlal_sat_qs8x8(c31, (char8)a0.s3, b0.s89ABCDEF, FIXED_POINT_POSITION);
989 char16 c00_qs8 = convert_char16_sat((short16)(c00, c01));
990 char16 c10_qs8 = convert_char16_sat((short16)(c10, c11));
991 char16 c20_qs8 = convert_char16_sat((short16)(c20, c21));
992 char16 c30_qs8 = convert_char16_sat((short16)(c30, c31));
995 c00_qs8 =
mul_sat_qs8x16(c00_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
996 c10_qs8 =
mul_sat_qs8x16(c10_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
997 c20_qs8 =
mul_sat_qs8x16(c20_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
998 c30_qs8 =
mul_sat_qs8x16(c30_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
999 #endif // defined(ALPHA) 1002 __global uchar *dst_addr =
offset(&dst, 0, 0);
1005 dst_addr += z * dst_stride_z;
1008 vstore16(c00_qs8, 0, (__global
char *)(dst_addr + 0 * dst_stride_y));
1009 vstore16(c10_qs8, 0, (__global
char *)(dst_addr + 1 * dst_stride_y));
1010 vstore16(c20_qs8, 0, (__global
char *)(dst_addr + 2 * dst_stride_y));
1011 vstore16(c30_qs8, 0, (__global
char *)(dst_addr + 3 * dst_stride_y));
1050 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
1051 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
1052 int z = get_global_id(2);
1055 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
1056 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
1060 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
1061 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
1063 #if defined(MATRIX_B_DEPTH) 1065 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
1066 #else // defined(MATRIX_B_DEPTH) 1067 src1_addr_in_bytes += z * src1_stride_z;
1068 #endif // defined(MATRIX_B_DEPTH) 1070 __global
short *src_addr_a = (__global
short *)(src0_ptr + src0_addr_in_bytes);
1071 __global
short *src_addr_b = (__global
short *)(src1_ptr + src1_addr_in_bytes);
1074 __global
short *src_end_addr_b = src_addr_b + COLS_B;
1076 src_addr_a += offset_row_a;
1077 src_addr_b += offset_row_b;
1086 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
1089 short4 a0 = vload4(0, src_addr_a);
1090 short8 b0 = vload8(0, src_addr_b);
1102 short8 c00_qs16 = convert_short8_sat(c00);
1103 short8 c10_qs16 = convert_short8_sat(c10);
1104 short8 c20_qs16 = convert_short8_sat(c20);
1105 short8 c30_qs16 = convert_short8_sat(c30);
1108 c00_qs16 =
mul_sat_qs16x8(c00_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
1109 c10_qs16 =
mul_sat_qs16x8(c10_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
1110 c20_qs16 =
mul_sat_qs16x8(c20_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
1111 c30_qs16 =
mul_sat_qs16x8(c30_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
1112 #endif // defined(ALPHA) 1115 __global uchar *dst_addr =
offset(&dst, 0, 0);
1118 dst_addr += z * dst_stride_z;
1121 vstore8(c00_qs16, 0, (__global
short *)(dst_addr + 0 * dst_stride_y));
1122 vstore8(c10_qs16, 0, (__global
short *)(dst_addr + 1 * dst_stride_y));
1123 vstore8(c20_qs16, 0, (__global
short *)(dst_addr + 2 * dst_stride_y));
1124 vstore8(c30_qs16, 0, (__global
short *)(dst_addr + 3 * dst_stride_y));
1126 #endif // defined(FIXED_POINT_POSITION) 1127 #endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT) 1129 #if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y) 1130 #if defined(DATA_TYPE) 1131 #define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X) 1167 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1170 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1173 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1179 src_addr.s0 += get_global_id(2) * src0_stride_z;
1181 #if defined(MATRIX_B_DEPTH) 1183 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1184 #else // defined(MATRIX_B_DEPTH) 1185 src_addr.s1 += get_global_id(2) * src1_stride_z;
1186 #endif // defined(MATRIX_B_DEPTH) 1188 int end_row_vec_a = src_addr.s0 + (COLS_A *
sizeof(
DATA_TYPE));
1190 VECTOR_TYPE acc0 = 0.0f;
1191 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1192 VECTOR_TYPE acc1 = 0.0f;
1193 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1194 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1195 VECTOR_TYPE acc2 = 0.0f;
1196 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1197 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1198 VECTOR_TYPE acc3 = 0.0f;
1199 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1201 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)
sizeof(DATA_TYPE)); src_addr += (int2)(2 *
sizeof(DATA_TYPE), 2 * src1_stride_y))
1205 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1206 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1208 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1209 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1210 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1212 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1213 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1214 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1216 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1217 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1219 VECTOR_TYPE b0 =
VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
1220 VECTOR_TYPE b1 =
VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
1223 acc0 += b0 * (VECTOR_TYPE)a0.s0;
1224 acc0 += b1 * (VECTOR_TYPE)a0.s1;
1225 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1226 acc1 += b0 * (VECTOR_TYPE)a1.s0;
1227 acc1 += b1 * (VECTOR_TYPE)a1.s1;
1228 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1229 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1230 acc2 += b0 * (VECTOR_TYPE)a2.s0;
1231 acc2 += b1 * (VECTOR_TYPE)a2.s1;
1232 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1233 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1234 acc3 += b0 * (VECTOR_TYPE)a3.s0;
1235 acc3 += b1 * (VECTOR_TYPE)a3.s1;
1236 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1239 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(
sizeof(DATA_TYPE), src1_stride_y))
1242 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1243 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1244 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1245 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1246 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1247 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1248 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1249 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1250 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1251 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1253 VECTOR_TYPE b0 =
VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
1256 acc0 += b0 * (VECTOR_TYPE)a0;
1257 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1258 acc1 += b0 * (VECTOR_TYPE)a1;
1259 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1260 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1261 acc2 += b0 * (VECTOR_TYPE)a2;
1262 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1263 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1264 acc3 += b0 * (VECTOR_TYPE)a3;
1265 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1272 __global uchar *dst_addr =
offset(&dst, 0, 0);
1275 dst_addr += get_global_id(2) * dst_stride_z;
1279 acc0 = acc0 * (VECTOR_TYPE)ALPHA;
1280 #endif // defined(ALPHA) 1281 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
1282 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
1283 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1285 acc1 = acc1 * (VECTOR_TYPE)ALPHA;
1286 #endif // defined(ALPHA) 1287 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
1288 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
1289 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1290 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1292 acc2 = acc2 * (VECTOR_TYPE)ALPHA;
1293 #endif // defined(ALPHA) 1294 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
1295 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
1296 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1297 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1299 acc3 = acc3 * (VECTOR_TYPE)ALPHA;
1300 #endif // defined(ALPHA) 1301 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
1302 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
1303 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1305 #endif // defined(DATA_TYPE) 1343 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1346 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1349 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1352 src_addr.s1 += idx *
sizeof(float);
1355 src_addr.s0 += get_global_id(2) * src0_stride_z;
1357 #if defined(MATRIX_B_DEPTH) 1359 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1360 #else // defined(MATRIX_B_DEPTH) 1361 src_addr.s1 += get_global_id(2) * src1_stride_z;
1362 #endif // defined(MATRIX_B_DEPTH) 1370 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1375 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1377 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1382 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1384 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1389 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1393 for(; i <= ((int)COLS_A - 4); i += 4)
1396 float4 a0 = vload4(0, (__global
float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1397 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1398 float4 a1 = vload4(0, (__global
float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1399 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1400 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1401 float4 a2 = vload4(0, (__global
float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1402 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1403 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1404 float4 a3 = vload4(0, (__global
float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1405 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1406 float4 b0 = vload4(0, (__global
float *)(src1_ptr + src_addr.s1));
1407 src_addr.s1 += src1_stride_y;
1410 acc00 = fma(a0.s0, b0.s0, acc00);
1411 acc01 = fma(a0.s0, b0.s1, acc01);
1412 acc02 = fma(a0.s0, b0.s2, acc02);
1413 acc03 = fma(a0.s0, b0.s3, acc03);
1415 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1417 acc10 = fma(a1.s0, b0.s0, acc10);
1418 acc11 = fma(a1.s0, b0.s1, acc11);
1419 acc12 = fma(a1.s0, b0.s2, acc12);
1420 acc13 = fma(a1.s0, b0.s3, acc13);
1422 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1423 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1425 acc20 = fma(a2.s0, b0.s0, acc20);
1426 acc21 = fma(a2.s0, b0.s1, acc21);
1427 acc22 = fma(a2.s0, b0.s2, acc22);
1428 acc23 = fma(a2.s0, b0.s3, acc23);
1430 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1431 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1433 acc30 = fma(a3.s0, b0.s0, acc30);
1434 acc31 = fma(a3.s0, b0.s1, acc31);
1435 acc32 = fma(a3.s0, b0.s2, acc32);
1436 acc33 = fma(a3.s0, b0.s3, acc33);
1437 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1440 b0 = vload4(0, (__global
float *)(src1_ptr + src_addr.s1));
1441 src_addr.s1 += src1_stride_y;
1444 acc00 = fma(a0.s1, b0.s0, acc00);
1445 acc01 = fma(a0.s1, b0.s1, acc01);
1446 acc02 = fma(a0.s1, b0.s2, acc02);
1447 acc03 = fma(a0.s1, b0.s3, acc03);
1449 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1451 acc10 = fma(a1.s1, b0.s0, acc10);
1452 acc11 = fma(a1.s1, b0.s1, acc11);
1453 acc12 = fma(a1.s1, b0.s2, acc12);
1454 acc13 = fma(a1.s1, b0.s3, acc13);
1456 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1457 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1459 acc20 = fma(a2.s1, b0.s0, acc20);
1460 acc21 = fma(a2.s1, b0.s1, acc21);
1461 acc22 = fma(a2.s1, b0.s2, acc22);
1462 acc23 = fma(a2.s1, b0.s3, acc23);
1464 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1465 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1467 acc30 = fma(a3.s1, b0.s0, acc30);
1468 acc31 = fma(a3.s1, b0.s1, acc31);
1469 acc32 = fma(a3.s1, b0.s2, acc32);
1470 acc33 = fma(a3.s1, b0.s3, acc33);
1471 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1474 b0 = vload4(0, (__global
float *)(src1_ptr + src_addr.s1));
1475 src_addr.s1 += src1_stride_y;
1478 acc00 = fma(a0.s2, b0.s0, acc00);
1479 acc01 = fma(a0.s2, b0.s1, acc01);
1480 acc02 = fma(a0.s2, b0.s2, acc02);
1481 acc03 = fma(a0.s2, b0.s3, acc03);
1483 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1485 acc10 = fma(a1.s2, b0.s0, acc10);
1486 acc11 = fma(a1.s2, b0.s1, acc11);
1487 acc12 = fma(a1.s2, b0.s2, acc12);
1488 acc13 = fma(a1.s2, b0.s3, acc13);
1490 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1491 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1493 acc20 = fma(a2.s2, b0.s0, acc20);
1494 acc21 = fma(a2.s2, b0.s1, acc21);
1495 acc22 = fma(a2.s2, b0.s2, acc22);
1496 acc23 = fma(a2.s2, b0.s3, acc23);
1498 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1499 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1501 acc30 = fma(a3.s2, b0.s0, acc30);
1502 acc31 = fma(a3.s2, b0.s1, acc31);
1503 acc32 = fma(a3.s2, b0.s2, acc32);
1504 acc33 = fma(a3.s2, b0.s3, acc33);
1505 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1508 b0 = vload4(0, (__global
float *)(src1_ptr + src_addr.s1));
1509 src_addr.s1 += src1_stride_y;
1512 acc00 = fma(a0.s3, b0.s0, acc00);
1513 acc01 = fma(a0.s3, b0.s1, acc01);
1514 acc02 = fma(a0.s3, b0.s2, acc02);
1515 acc03 = fma(a0.s3, b0.s3, acc03);
1517 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1519 acc10 = fma(a1.s3, b0.s0, acc10);
1520 acc11 = fma(a1.s3, b0.s1, acc11);
1521 acc12 = fma(a1.s3, b0.s2, acc12);
1522 acc13 = fma(a1.s3, b0.s3, acc13);
1524 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1525 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1527 acc20 = fma(a2.s3, b0.s0, acc20);
1528 acc21 = fma(a2.s3, b0.s1, acc21);
1529 acc22 = fma(a2.s3, b0.s2, acc22);
1530 acc23 = fma(a2.s3, b0.s3, acc23);
1532 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1533 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1535 acc30 = fma(a3.s3, b0.s0, acc30);
1536 acc31 = fma(a3.s3, b0.s1, acc31);
1537 acc32 = fma(a3.s3, b0.s2, acc32);
1538 acc33 = fma(a3.s3, b0.s3, acc33);
1539 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1541 src_addr.s0 += 4 *
sizeof(float);
1544 for(; i < (int)COLS_A; ++i)
1547 float a0 = *((__global
float *)(src0_ptr + src_addr.s0));
1548 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1549 float a1 = *((__global
float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1550 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1551 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1552 float a2 = *((__global
float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1553 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1554 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1555 float a3 = *((__global
float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1556 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1558 float4 b0 = vload4(0, (__global
float *)(src1_ptr + src_addr.s1));
1559 src_addr.s1 += src1_stride_y;
1562 acc00 = fma(a0, b0.s0, acc00);
1563 acc01 = fma(a0, b0.s1, acc01);
1564 acc02 = fma(a0, b0.s2, acc02);
1565 acc03 = fma(a0, b0.s3, acc03);
1566 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1567 acc10 = fma(a1, b0.s0, acc10);
1568 acc11 = fma(a1, b0.s1, acc11);
1569 acc12 = fma(a1, b0.s2, acc12);
1570 acc13 = fma(a1, b0.s3, acc13);
1571 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1572 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1573 acc20 = fma(a2, b0.s0, acc20);
1574 acc21 = fma(a2, b0.s1, acc21);
1575 acc22 = fma(a2, b0.s2, acc22);
1576 acc23 = fma(a2, b0.s3, acc23);
1577 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1578 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1579 acc30 = fma(a3, b0.s0, acc30);
1580 acc31 = fma(a3, b0.s1, acc31);
1581 acc32 = fma(a3, b0.s2, acc32);
1582 acc33 = fma(a3, b0.s3, acc33);
1583 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1585 src_addr.s0 +=
sizeof(float);
1593 acc00 = acc00 * ALPHA;
1594 acc01 = acc01 * ALPHA;
1595 acc02 = acc02 * ALPHA;
1596 acc03 = acc03 * ALPHA;
1597 #endif // defined(ALPHA) 1600 __global uchar *dst_addr =
offset(&dst, 0, 0);
1603 dst_addr += get_global_id(2) * dst_stride_z;
1605 float4 acc0 = ((float4)(acc00, acc01, acc02, acc03));
1606 vstore4(acc0, 0, (__global
float *)(dst_addr + 0 * dst_stride_y));
1608 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1610 acc10 = acc10 * ALPHA;
1611 acc11 = acc11 * ALPHA;
1612 acc12 = acc12 * ALPHA;
1613 acc13 = acc13 * ALPHA;
1614 #endif // defined(ALPHA) 1615 float4 acc1 = ((float4)(acc10, acc11, acc12, acc13));
1616 vstore4(acc1, 0, (__global
float *)(dst_addr + 1 * dst_stride_y));
1617 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1618 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1620 acc20 = acc20 * ALPHA;
1621 acc21 = acc21 * ALPHA;
1622 acc22 = acc22 * ALPHA;
1623 acc23 = acc23 * ALPHA;
1624 #endif // defined(ALPHA) 1625 float4 acc2 = ((float4)(acc20, acc21, acc22, acc23));
1626 vstore4(acc2, 0, (__global
float *)(dst_addr + 2 * dst_stride_y));
1627 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1628 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1630 acc30 = acc30 * ALPHA;
1631 acc31 = acc31 * ALPHA;
1632 acc32 = acc32 * ALPHA;
1633 acc33 = acc33 * ALPHA;
1634 #endif // defined(ALPHA) 1635 float4 acc3 = ((float4)(acc30, acc31, acc32, acc33));
1636 vstore4(acc3, 0, (__global
float *)(dst_addr + 3 * dst_stride_y));
1637 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1678 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1681 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1684 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1687 src_addr.s1 += idx *
sizeof(float);
1690 src_addr.s0 += get_global_id(2) * src0_stride_z;
1692 #if defined(MATRIX_B_DEPTH) 1694 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1695 #else // defined(MATRIX_B_DEPTH) 1696 src_addr.s1 += get_global_id(2) * src1_stride_z;
1697 #endif // defined(MATRIX_B_DEPTH) 1703 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1706 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1707 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1710 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1711 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1714 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1718 for(; i <= ((int)COLS_A - 8); i += 8)
1721 float8 a0 = vload8(0, (__global
float *)(src0_ptr + src_addr.s0));
1724 float2 b0 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
1725 src_addr.s1 += src1_stride_y;
1726 float2 b1 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
1727 src_addr.s1 += src1_stride_y;
1728 float2 b2 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
1729 src_addr.s1 += src1_stride_y;
1730 float2 b3 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
1731 src_addr.s1 += src1_stride_y;
1732 float2 b4 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
1733 src_addr.s1 += src1_stride_y;
1734 float2 b5 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
1735 src_addr.s1 += src1_stride_y;
1736 float2 b6 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
1737 src_addr.s1 += src1_stride_y;
1738 float2 b7 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
1739 src_addr.s1 += src1_stride_y;
1742 acc00 = fma(a0.s0, b0.s0, acc00);
1743 acc00 = fma(a0.s1, b1.s0, acc00);
1744 acc00 = fma(a0.s2, b2.s0, acc00);
1745 acc00 = fma(a0.s3, b3.s0, acc00);
1746 acc00 = fma(a0.s4, b4.s0, acc00);
1747 acc00 = fma(a0.s5, b5.s0, acc00);
1748 acc00 = fma(a0.s6, b6.s0, acc00);
1749 acc00 = fma(a0.s7, b7.s0, acc00);
1751 acc01 = fma(a0.s0, b0.s1, acc01);
1752 acc01 = fma(a0.s1, b1.s1, acc01);
1753 acc01 = fma(a0.s2, b2.s1, acc01);
1754 acc01 = fma(a0.s3, b3.s1, acc01);
1755 acc01 = fma(a0.s4, b4.s1, acc01);
1756 acc01 = fma(a0.s5, b5.s1, acc01);
1757 acc01 = fma(a0.s6, b6.s1, acc01);
1758 acc01 = fma(a0.s7, b7.s1, acc01);
1760 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1761 a0 = vload8(0, (__global
float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1762 acc10 = fma(a0.s0, b0.s0, acc10);
1763 acc10 = fma(a0.s1, b1.s0, acc10);
1764 acc10 = fma(a0.s2, b2.s0, acc10);
1765 acc10 = fma(a0.s3, b3.s0, acc10);
1766 acc10 = fma(a0.s4, b4.s0, acc10);
1767 acc10 = fma(a0.s5, b5.s0, acc10);
1768 acc10 = fma(a0.s6, b6.s0, acc10);
1769 acc10 = fma(a0.s7, b7.s0, acc10);
1771 acc11 = fma(a0.s0, b0.s1, acc11);
1772 acc11 = fma(a0.s1, b1.s1, acc11);
1773 acc11 = fma(a0.s2, b2.s1, acc11);
1774 acc11 = fma(a0.s3, b3.s1, acc11);
1775 acc11 = fma(a0.s4, b4.s1, acc11);
1776 acc11 = fma(a0.s5, b5.s1, acc11);
1777 acc11 = fma(a0.s6, b6.s1, acc11);
1778 acc11 = fma(a0.s7, b7.s1, acc11);
1779 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1780 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1781 a0 = vload8(0, (__global
float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1782 acc20 = fma(a0.s0, b0.s0, acc20);
1783 acc20 = fma(a0.s1, b1.s0, acc20);
1784 acc20 = fma(a0.s2, b2.s0, acc20);
1785 acc20 = fma(a0.s3, b3.s0, acc20);
1786 acc20 = fma(a0.s4, b4.s0, acc20);
1787 acc20 = fma(a0.s5, b5.s0, acc20);
1788 acc20 = fma(a0.s6, b6.s0, acc20);
1789 acc20 = fma(a0.s7, b7.s0, acc20);
1791 acc21 = fma(a0.s0, b0.s1, acc21);
1792 acc21 = fma(a0.s1, b1.s1, acc21);
1793 acc21 = fma(a0.s2, b2.s1, acc21);
1794 acc21 = fma(a0.s3, b3.s1, acc21);
1795 acc21 = fma(a0.s4, b4.s1, acc21);
1796 acc21 = fma(a0.s5, b5.s1, acc21);
1797 acc21 = fma(a0.s6, b6.s1, acc21);
1798 acc21 = fma(a0.s7, b7.s1, acc21);
1799 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1800 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1801 a0 = vload8(0, (__global
float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1802 acc30 = fma(a0.s0, b0.s0, acc30);
1803 acc30 = fma(a0.s1, b1.s0, acc30);
1804 acc30 = fma(a0.s2, b2.s0, acc30);
1805 acc30 = fma(a0.s3, b3.s0, acc30);
1806 acc30 = fma(a0.s4, b4.s0, acc30);
1807 acc30 = fma(a0.s5, b5.s0, acc30);
1808 acc30 = fma(a0.s6, b6.s0, acc30);
1809 acc30 = fma(a0.s7, b7.s0, acc30);
1811 acc31 = fma(a0.s0, b0.s1, acc31);
1812 acc31 = fma(a0.s1, b1.s1, acc31);
1813 acc31 = fma(a0.s2, b2.s1, acc31);
1814 acc31 = fma(a0.s3, b3.s1, acc31);
1815 acc31 = fma(a0.s4, b4.s1, acc31);
1816 acc31 = fma(a0.s5, b5.s1, acc31);
1817 acc31 = fma(a0.s6, b6.s1, acc31);
1818 acc31 = fma(a0.s7, b7.s1, acc31);
1819 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1821 src_addr.s0 +=
sizeof(float) * 8;
1824 for(; i < (int)COLS_A; ++i)
1827 float a0 = *((__global
float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1828 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1829 float a1 = *((__global
float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1830 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1831 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1832 float a2 = *((__global
float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1833 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1834 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1835 float a3 = *((__global
float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1836 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1838 float2 b0 = vload2(0, (__global
float *)(src1_ptr + src_addr.s1));
1839 src_addr.s1 += src1_stride_y;
1842 acc00 = fma(a0, b0.s0, acc00);
1843 acc01 = fma(a0, b0.s1, acc01);
1844 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1845 acc10 = fma(a1, b0.s0, acc10);
1846 acc11 = fma(a1, b0.s1, acc11);
1847 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1848 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1849 acc20 = fma(a2, b0.s0, acc20);
1850 acc21 = fma(a2, b0.s1, acc21);
1851 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1852 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1853 acc30 = fma(a3, b0.s0, acc30);
1854 acc31 = fma(a3, b0.s1, acc31);
1855 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1857 src_addr.s0 +=
sizeof(float);
1864 __global uchar *dst_addr =
offset(&dst, 0, 0);
1867 dst_addr += get_global_id(2) * dst_stride_z;
1871 acc00 = acc00 * ALPHA;
1872 acc01 = acc01 * ALPHA;
1873 #endif // defined(ALPHA) 1874 float2 acc0 = ((float2)(acc00, acc01));
1875 vstore2(acc0, 0, (__global
float *)(dst_addr + 0 * dst_stride_y));
1876 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1878 acc10 = acc10 * ALPHA;
1879 acc11 = acc11 * ALPHA;
1880 #endif // defined(ALPHA) 1881 float2 acc1 = ((float2)(acc10, acc11));
1882 vstore2(acc1, 0, (__global
float *)(dst_addr + 1 * dst_stride_y));
1883 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1884 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1886 acc20 = acc20 * ALPHA;
1887 acc21 = acc21 * ALPHA;
1888 #endif // defined(ALPHA) 1889 float2 acc2 = ((float2)(acc20, acc21));
1890 vstore2(acc2, 0, (__global
float *)(dst_addr + 2 * dst_stride_y));
1891 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1892 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1894 acc30 = acc30 * ALPHA;
1895 acc31 = acc31 * ALPHA;
1896 #endif // defined(ALPHA) 1897 float2 acc3 = (float2)(acc30, acc31);
1898 vstore2(acc3, 0, (__global
float *)(dst_addr + 3 * dst_stride_y));
1899 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1938 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1941 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1944 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1947 src_addr.s1 += idx *
sizeof(
half);
1950 src_addr.s0 += get_global_id(2) * src0_stride_z;
1952 #if defined(MATRIX_B_DEPTH) 1954 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1955 #else // defined(MATRIX_B_DEPTH) 1956 src_addr.s1 += get_global_id(2) * src1_stride_z;
1957 #endif // defined(MATRIX_B_DEPTH) 1960 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1962 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1963 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1965 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1966 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1968 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1971 for(; i <= ((int)COLS_A - 4); i += 4)
1974 half4 a0 = vload4(0, (__global
half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1975 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1976 half4 a1 = vload4(0, (__global
half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1977 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1978 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1979 half4 a2 = vload4(0, (__global
half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1980 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1981 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1982 half4 a3 = vload4(0, (__global
half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1983 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1985 half8 b0 = vload8(0, (__global
half *)(src1_ptr + src_addr.s1));
1986 src_addr.s1 += src1_stride_y;
1989 acc0 = fma(b0, (half8)a0.s0, acc0);
1990 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1991 acc1 = fma(b0, (half8)a1.s0, acc1);
1992 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 1993 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1994 acc2 = fma(b0, (half8)a2.s0, acc2);
1995 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 1996 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 1997 acc3 = fma(b0, (half8)a3.s0, acc3);
1998 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2000 b0 = vload8(0, (__global
half *)(src1_ptr + src_addr.s1));
2001 src_addr.s1 += src1_stride_y;
2002 acc0 = fma(b0, (half8)a0.s1, acc0);
2003 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2004 acc1 = fma(b0, (half8)a1.s1, acc1);
2005 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2006 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2007 acc2 = fma(b0, (half8)a2.s1, acc2);
2008 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2009 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2010 acc3 = fma(b0, (half8)a3.s1, acc3);
2011 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2013 b0 = vload8(0, (__global
half *)(src1_ptr + src_addr.s1));
2014 src_addr.s1 += src1_stride_y;
2015 acc0 = fma(b0, (half8)a0.s2, acc0);
2016 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2017 acc1 = fma(b0, (half8)a1.s2, acc1);
2018 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2019 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2020 acc2 = fma(b0, (half8)a2.s2, acc2);
2021 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2022 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2023 acc3 = fma(b0, (half8)a3.s2, acc3);
2024 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2026 b0 = vload8(0, (__global
half *)(src1_ptr + src_addr.s1));
2027 src_addr.s1 += src1_stride_y;
2028 acc0 = fma(b0, (half8)a0.s3, acc0);
2029 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2030 acc1 = fma(b0, (half8)a1.s3, acc1);
2031 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2032 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2033 acc2 = fma(b0, (half8)a2.s3, acc2);
2034 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2035 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2036 acc3 = fma(b0, (half8)a3.s3, acc3);
2037 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2039 src_addr.s0 += 4 *
sizeof(
half);
2042 for(; i < (int)COLS_A; ++i)
2045 half a0 = *((__global
half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2046 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2047 half a1 = *((__global
half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2048 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2049 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2050 half a2 = *((__global
half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2051 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2052 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2053 half a3 = *((__global
half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2054 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2056 half8 b0 = vload8(0, (__global
half *)(src1_ptr + src_addr.s1));
2058 src_addr += (int2)(
sizeof(
half), src1_stride_y);
2061 acc0 = fma(b0, (half8)a0, acc0);
2062 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2063 acc1 = fma(b0, (half8)a1, acc1);
2064 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2065 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2066 acc2 = fma(b0, (half8)a2, acc2);
2067 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2068 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2069 acc3 = fma(b0, (half8)a3, acc3);
2070 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2077 __global uchar *dst_addr =
offset(&dst, 0, 0);
2080 dst_addr += get_global_id(2) * dst_stride_z;
2084 acc0 = acc0 * (half8)ALPHA;
2085 #endif // defined(ALPHA) 2086 vstore8(acc0, 0, (__global
half *)(dst_addr + 0 * dst_stride_y));
2087 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2089 acc1 = acc1 * (half8)ALPHA;
2090 #endif // defined(ALPHA) 2091 vstore8(acc1, 0, (__global
half *)(dst_addr + 1 * dst_stride_y));
2092 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2093 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2095 acc2 = acc2 * (half8)ALPHA;
2096 #endif // defined(ALPHA) 2097 vstore8(acc2, 0, (__global
half *)(dst_addr + 2 * dst_stride_y));
2098 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2099 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2101 acc3 = acc3 * (half8)ALPHA;
2102 #endif // defined(ALPHA) 2103 vstore8(acc3, 0, (__global
half *)(dst_addr + 3 * dst_stride_y));
2104 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2107 #if defined(FIXED_POINT_POSITION) 2144 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
2147 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
2150 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
2153 src_addr.s1 += idx *
sizeof(char);
2156 src_addr.s0 += get_global_id(2) * src0_stride_z;
2158 #if defined(MATRIX_B_DEPTH) 2160 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2161 #else // defined(MATRIX_B_DEPTH) 2162 src_addr.s1 += get_global_id(2) * src1_stride_z;
2163 #endif // defined(MATRIX_B_DEPTH) 2165 int end_row_vec_a = src_addr.s0 + (COLS_A *
sizeof(char));
2169 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2172 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2173 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2176 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2177 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2180 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2183 for(; src_addr.s0 <= (end_row_vec_a - 2); src_addr += (int2)(2, 2 * src1_stride_y))
2185 char2 a0 = vload2(0, (__global
char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2186 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2187 char2 a1 = vload2(0, (__global
char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2188 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2189 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2190 char2 a2 = vload2(0, (__global
char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2191 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2192 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2193 char2 a3 = vload2(0, (__global
char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2194 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2195 char16 b0 = vload16(0, (__global
char *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
2196 char16 b1 = vload16(0, (__global
char *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
2198 acc00 =
mlal_sat_qs8x8(acc00, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
2199 acc00 =
mlal_sat_qs8x8(acc00, (char8)a0.s1, b1.s01234567, FIXED_POINT_POSITION);
2200 acc01 =
mlal_sat_qs8x8(acc01, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2201 acc01 =
mlal_sat_qs8x8(acc01, (char8)a0.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2202 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2203 acc10 =
mlal_sat_qs8x8(acc10, (char8)a1.s0, b0.s01234567, FIXED_POINT_POSITION);
2204 acc10 =
mlal_sat_qs8x8(acc10, (char8)a1.s1, b1.s01234567, FIXED_POINT_POSITION);
2205 acc11 =
mlal_sat_qs8x8(acc11, (char8)a1.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2206 acc11 =
mlal_sat_qs8x8(acc11, (char8)a1.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2207 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2208 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2209 acc20 =
mlal_sat_qs8x8(acc20, (char8)a2.s0, b0.s01234567, FIXED_POINT_POSITION);
2210 acc20 =
mlal_sat_qs8x8(acc20, (char8)a2.s1, b1.s01234567, FIXED_POINT_POSITION);
2211 acc21 =
mlal_sat_qs8x8(acc21, (char8)a2.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2212 acc21 =
mlal_sat_qs8x8(acc21, (char8)a2.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2213 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2214 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2215 acc30 =
mlal_sat_qs8x8(acc30, (char8)a3.s0, b0.s01234567, FIXED_POINT_POSITION);
2216 acc30 =
mlal_sat_qs8x8(acc30, (char8)a3.s1, b1.s01234567, FIXED_POINT_POSITION);
2217 acc31 =
mlal_sat_qs8x8(acc31, (char8)a3.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2218 acc31 =
mlal_sat_qs8x8(acc31, (char8)a3.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2219 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2223 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
2225 char a0 = *((__global
char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2226 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2227 char a1 = *((__global
char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2228 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2229 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2230 char a2 = *((__global
char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2231 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2232 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2233 char a3 = *((__global
char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2234 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2235 char16 b0 = vload16(0, (__global
char *)(src1_ptr + src_addr.s1));
2237 acc00 =
mlal_sat_qs8x8(acc00, (char8)a0, b0.s01234567, FIXED_POINT_POSITION);
2238 acc01 =
mlal_sat_qs8x8(acc01, (char8)a0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2239 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2240 acc10 =
mlal_sat_qs8x8(acc10, (char8)a1, b0.s01234567, FIXED_POINT_POSITION);
2241 acc11 =
mlal_sat_qs8x8(acc11, (char8)a1, b0.s89ABCDEF, FIXED_POINT_POSITION);
2242 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2243 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2244 acc20 =
mlal_sat_qs8x8(acc20, (char8)a2, b0.s01234567, FIXED_POINT_POSITION);
2245 acc21 =
mlal_sat_qs8x8(acc21, (char8)a2, b0.s89ABCDEF, FIXED_POINT_POSITION);
2246 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2247 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2248 acc30 =
mlal_sat_qs8x8(acc30, (char8)a3, b0.s01234567, FIXED_POINT_POSITION);
2249 acc31 =
mlal_sat_qs8x8(acc31, (char8)a3, b0.s89ABCDEF, FIXED_POINT_POSITION);
2250 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2257 __global uchar *dst_addr =
offset(&dst, 0, 0);
2260 dst_addr += get_global_id(2) * dst_stride_z;
2264 acc_qs8 = convert_char16_sat((short16)(acc00, acc01));
2266 acc_qs8 =
mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
2267 #endif // defined(ALPHA) 2268 vstore16(acc_qs8, 0, (__global
char *)(dst_addr + 0 * dst_stride_y));
2269 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2270 acc_qs8 = convert_char16_sat((short16)(acc10, acc11));
2272 acc_qs8 =
mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
2273 #endif // defined(ALPHA) 2274 vstore16(acc_qs8, 0, (__global
char *)(dst_addr + 1 * dst_stride_y));
2275 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2276 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2277 acc_qs8 = convert_char16_sat((short16)(acc20, acc21));
2279 acc_qs8 =
mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
2280 #endif // defined(ALPHA) 2281 vstore16(acc_qs8, 0, (__global
char *)(dst_addr + 2 * dst_stride_y));
2282 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2283 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2284 acc_qs8 = convert_char16_sat((short16)(acc30, acc31));
2286 acc_qs8 =
mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
2287 #endif // defined(ALPHA) 2288 vstore16(acc_qs8, 0, (__global
char *)(dst_addr + 3 * dst_stride_y));
2289 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2328 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
2331 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
2334 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
2337 src_addr.s1 += idx *
sizeof(short);
2340 src_addr.s0 += get_global_id(2) * src0_stride_z;
2342 #if defined(MATRIX_B_DEPTH) 2344 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2345 #else // defined(MATRIX_B_DEPTH) 2346 src_addr.s1 += get_global_id(2) * src1_stride_z;
2347 #endif // defined(MATRIX_B_DEPTH) 2349 int end_row_vec_a = src_addr.s0 + (COLS_A *
sizeof(short));
2352 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2354 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2355 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2357 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2358 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2360 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2363 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)
sizeof(
short)); src_addr += (int2)(2 *
sizeof(
short), 2 * src1_stride_y))
2365 short2 a0 = vload2(0, (__global
short *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2366 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2367 short2 a1 = vload2(0, (__global
short *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2368 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2369 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2370 short2 a2 = vload2(0, (__global
short *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2371 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2372 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2373 short2 a3 = vload2(0, (__global
short *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2374 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2375 short8 b0 = vload8(0, (__global
short *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
2376 short8 b1 = vload8(0, (__global
short *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
2378 acc0 =
mlal_sat_qs16x8(acc0, (short8)a0.s0, b0, FIXED_POINT_POSITION);
2379 acc0 =
mlal_sat_qs16x8(acc0, (short8)a0.s1, b1, FIXED_POINT_POSITION);
2380 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2381 acc1 =
mlal_sat_qs16x8(acc1, (short8)a1.s0, b0, FIXED_POINT_POSITION);
2382 acc1 =
mlal_sat_qs16x8(acc1, (short8)a1.s1, b1, FIXED_POINT_POSITION);
2383 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2384 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2385 acc2 =
mlal_sat_qs16x8(acc2, (short8)a2.s0, b0, FIXED_POINT_POSITION);
2386 acc2 =
mlal_sat_qs16x8(acc2, (short8)a2.s1, b1, FIXED_POINT_POSITION);
2387 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2388 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2389 acc3 =
mlal_sat_qs16x8(acc3, (short8)a3.s0, b0, FIXED_POINT_POSITION);
2390 acc3 =
mlal_sat_qs16x8(acc3, (short8)a3.s1, b1, FIXED_POINT_POSITION);
2391 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2395 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(
sizeof(
short), src1_stride_y))
2397 short a0 = *((__global
short *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2398 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2399 short a1 = *((__global
short *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2400 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2401 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2402 short a2 = *((__global
short *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2403 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2404 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2405 short a3 = *((__global
short *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2406 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2407 short8 b0 = vload8(0, (__global
short *)(src1_ptr + src_addr.s1));
2410 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2412 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2413 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2415 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2416 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2418 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2425 __global uchar *dst_addr =
offset(&dst, 0, 0);
2428 dst_addr += get_global_id(2) * dst_stride_z;
2432 acc_qs16 = convert_short8_sat(acc0);
2434 acc_qs16 =
mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
2435 #endif // defined(ALPHA) 2436 vstore8(acc_qs16, 0, (__global
short *)(dst_addr + 0 * dst_stride_y));
2437 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2438 acc_qs16 = convert_short8_sat(acc1);
2440 acc_qs16 =
mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
2441 #endif // defined(ALPHA) 2442 vstore8(acc_qs16, 0, (__global
short *)(dst_addr + 1 * dst_stride_y));
2443 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 2444 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2445 acc_qs16 = convert_short8_sat(acc2);
2447 acc_qs16 =
mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
2448 #endif // defined(ALPHA) 2449 vstore8(acc_qs16, 0, (__global
short *)(dst_addr + 2 * dst_stride_y));
2450 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 2451 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2452 acc_qs16 = convert_short8_sat(acc3);
2454 acc_qs16 =
mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
2455 #endif // defined(ALPHA) 2456 vstore8(acc_qs16, 0, (__global
short *)(dst_addr + 3 * dst_stride_y));
2457 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 2459 #endif // defined(FIXED_POINT_POSITION) 2460 #endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y) 2488 float4 alpha_ab = vload4(0, (__global
float *)dst.
ptr);
2491 float4
c = vload4(0, (__global
float *)src.
ptr);
2494 float4 out = alpha_ab + (float4)BETA * c;
2497 vstore4(out, 0, (__global
float *)dst.
ptr);
2525 half8 alpha_ab = vload8(0, (__global
half *)dst.
ptr);
2528 half8 c = vload8(0, (__global
half *)src.
ptr);
2531 half8 out = alpha_ab + (half8)BETA * c;
2534 vstore8(out, 0, (__global
half *)dst.
ptr);
2537 #if defined(FIXED_POINT_POSITION) 2565 char16 alpha_ab = vload16(0, (__global
char *)dst.
ptr);
2568 char16 c = vload16(0, (__global
char *)src.
ptr);
2571 char16 out =
mla_sat_qs8x16(alpha_ab, (char16)BETA, c, FIXED_POINT_POSITION);
2574 vstore16(out, 0, (__global
char *)dst.
ptr);
2604 short8 alpha_ab = vload8(0, (__global
short *)dst.
ptr);
2607 short8 c = vload8(0, (__global
short *)src.
ptr);
2610 short8 out =
mla_sat_qs16x8(alpha_ab, (short8)BETA, c, FIXED_POINT_POSITION);
2613 vstore8(out, 0, (__global
short *)dst.
ptr);
2615 #endif // defined(FIXED_POINT_POSITION) 2616 #endif // defined(BETA) 2618 #if defined(WIDTH_VECTOR_A) 2650 int idx = get_global_id(0) * 4;
2651 int idy = get_global_id(1);
2654 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes + src0_stride_y * idy, src1_offset_first_element_in_bytes + src1_stride_z * idy));
2655 src_addr.s1 += idx *
sizeof(float);
2657 int end_row_vec_a = src_addr.s0 + (WIDTH_VECTOR_A *
sizeof(float));
2661 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)
sizeof(
float)); src_addr += (int2)(2 *
sizeof(
float), 2 * src1_stride_y))
2663 float2 a0 = vload2(0, (__global
float *)(src0_ptr + src_addr.s0));
2664 float4 b0 = vload4(0, (__global
float *)(src1_ptr + src_addr.s1));
2665 float4 b1 = vload4(0, (__global
float *)(src1_ptr + src_addr.s1 + src1_stride_y));
2667 acc += b0 * (float4)a0.s0;
2668 acc += b1 * (float4)a0.s1;
2671 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(
sizeof(
float), src1_stride_y))
2673 float a0 = *((__global
float *)(src0_ptr + src_addr.s0));
2674 float4 b0 = vload4(0, (__global
float *)(src1_ptr + src_addr.s1));
2676 acc += b0 * (float4)a0;
2682 vstore4(acc, 0, (__global
float *)(
offset(&dst, 0, 0)));
2684 #endif // defined(WIDTH_VECTOR_A) 2702 #if defined(DATA_TYPE) && defined(VECTOR_SIZE) 2703 __kernel
void gemm_accumulate_biases(
2714 biases_value =
VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)biases.ptr);
2715 #ifdef FIXED_POINT_POSITION 2716 accum_value =
ADD_SAT_OP_EXPAND(biases_value, accum_value, DATA_TYPE, VECTOR_SIZE);
2717 #else // FIXED_POINT_POSITION 2718 accum_value = biases_value + accum_value;
2719 #endif // FIXED_POINT_POSITION 2722 (accum_value, 0, (__global DATA_TYPE *)accum.
ptr);
2724 #endif // defined(DATA_TYPE) && defined(VECTOR_SIZE) Structure to hold Vector information.
qs8x16 mul_sat_qs8x16(qs8x16 VopA, qs8x16 VopB, int fixed_point_position)
qs16x8 mul_sat_qs16x8(qs16x8 VopA, qs16x8 VopB, int fixed_point_position)
#define CONVERT_TO_TENSOR3D_STRUCT(name)
#define CONVERT_TO_VECTOR_STRUCT(name)
half_float::half half
16-bit floating point type
#define IMAGE_DECLARATION(name)
Structure to hold 3D tensor information.
qs8x16 mla_sat_qs8x16(qs8x16 VopA, qs8x16 VopB, qs8x16 VopC, int fixed_point_position)
__global uchar * offset(const Image *img, int x, int y)
Get the pointer position of an Image.
qs16x8 mla_sat_qs16x8(qs16x8 VopA, qs16x8 VopB, qs16x8 VopC, int fixed_point_position)
#define CONVERT_TO_IMAGE_STRUCT(name)
qs32x8 mlal_sat_qs16x8(qs32x8 VopA, qs16x8 VopB, qs16x8 VopC, int fixed_point_position)
#define VECTOR_DECLARATION(name)
Structure to hold Image information.
#define TENSOR3D_DECLARATION(name)
__global uchar * ptr
Pointer to the starting position of the buffer.
#define VEC_DATA_TYPE(type, size)
qs16x8 mlal_sat_qs8x8(qs16x8 VopA, qs8x8 VopB, qs8x8 VopC, int fixed_point_position)
__global uchar * ptr
Pointer to the starting position of the buffer.
#define ADD_SAT_OP_EXPAND(a, b, type, size)
convolution configure & src