1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
21 #include "mkldnn_types.h"
25 #include "jit_avx512_core_gemm_s8u8s32.hpp"
26 #include "jit_avx512_core_gemm_s8u8s32_kern.hpp"
27 #include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp"
44 static inline void round_to_nearest(int32_t *rounded_val, double fp_val) {
47 if (fp_val > INT32_MAX) {
52 if (fp_val < INT32_MIN) {
56 *rounded_val = (int32_t) fp_val;
59 static inline void add_results(const dim_t m, const dim_t n, const dim_t k,
60 const float alpha, const float beta, const int32_t *c_partial_sum,
61 const dim_t ldcp, int32_t *c_data, const dim_t ldc,
62 const int32_t *a_row_sum, const int32_t *b_col_sum, const int8_t ao,
63 const int8_t bo, const int32_t *co, const int offsetc)
65 for (dim_t j = 0; j < n; ++j) {
66 for (dim_t i = 0; i < m; ++i) {
67 int32_t ctemp = c_partial_sum[i + j * ldcp];
71 c_data[i + j * ldc] = ctemp;
73 double c_float = (double) beta
74 * (double) c_data[i + j * ldc];
75 c_float += (double) ctemp;
76 round_to_nearest(&c_data[i + j * ldc], c_float);
78 } else if (alpha == -1.0f) {
80 c_data[i + j * ldc] = -ctemp;
82 double c_float = (double) beta
83 * (double) c_data[i + j * ldc];
84 c_float -= (double) ctemp;
85 round_to_nearest(&c_data[i + j * ldc], c_float);
89 double c_float = alpha * (double) ctemp;
90 round_to_nearest(&c_data[i + j * ldc], c_float);
92 double c_float = alpha * (double) ctemp +
93 beta * (double) c_data[i + j * ldc];
94 round_to_nearest(&c_data[i + j * ldc], c_float);
98 if (offsetc == FIX_OFFSET) {
99 c_data[i + j * ldc] += co[0];
100 } else if (offsetc == ROW_OFFSET) {
101 c_data[i + j * ldc] += co[j];
102 } else if (offsetc == COL_OFFSET) {
103 c_data[i + j * ldc] += co[i];
109 // TODO Find a better place for those functions.
110 static inline dim_t ld_padd(const dim_t x)
112 return ((x + ((2048 / sizeof(int32_t)) - 1)) / (2048 / sizeof(int32_t)))
113 * (2048 / sizeof(int32_t)) + (64 / sizeof(int32_t));
116 void igemm_inner_kernel(const dim_t m, const dim_t n, const dim_t k,
117 const int8_t *a, const uint8_t *b, float beta, int32_t *c,
118 const dim_t ldc, const int32_t *a_row_sum, const int32_t *b_col_sum,
119 const int32_t *co, const int offsetc, const blas_t *arg)
123 int32_t co_0 = (offsetc == NO_OFFSET)? 0 : co[0];
125 // Since m and n are limited by blocking, stack overflow may not happen;
127 #if !defined(_MSC_VER)
128 int32_t col_offset[m];
129 int32_t row_offset[n];
131 int32_t *col_offset = (int32_t *) _alloca(sizeof(*col_offset) * m);
132 int32_t *row_offset = (int32_t *) _alloca(sizeof(*row_offset) * n);
138 if ((bo != 0) || (offsetc == COL_OFFSET))
140 if ((ao != 0) || (offsetc == ROW_OFFSET))
143 // It needs one of colum or row offsets, but it doesn't need both
144 if (((ao != 0) && (bo != 0)) || ((offsetc == FIX_OFFSET) && (co_0 != 0))) {
145 if ((col_req == 0) && (row_req == 0)) {
155 for (dim_t i = 0; i < m; i++)
158 if (offsetc == COL_OFFSET) {
159 for (dim_t i = 0; i < m; i++)
160 col_offset[i] += co[i];
164 for (dim_t i = 0; i < m; i++)
165 col_offset[i] += bo * a_row_sum[i];
170 for (dim_t i = 0; i < n; i++)
173 if (offsetc == ROW_OFFSET) {
174 for (dim_t i = 0; i < n; i++)
175 row_offset[i] += co[i];
179 for (dim_t i = 0; i < n; i++)
180 row_offset[i] += ao * b_col_sum[i];
184 if ((offsetc == FIX_OFFSET) && (co_0 != 0)) {
186 for (dim_t i = 0; i < m; i++)
187 col_offset[i] += co_0;
189 for (dim_t i = 0; i < n; i++)
190 row_offset[i] += co_0;
194 if ((ao != 0) && (bo != 0)) {
196 for (dim_t i = 0; i < m; i++)
197 col_offset[i] += (int32_t) k * ao * bo;
199 for (dim_t i = 0; i < n; i++)
200 row_offset[i] += (int32_t) k * ao * bo;
207 arg->kernel_b0(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
210 arg->kernel(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
215 arg->kernel_b0_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
218 arg->kernel_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
225 arg->kernel_b0_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
228 arg->kernel_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
233 arg->kernel_b0_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
236 arg->kernel_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
243 static inline void *align(void *ptr, size_t alignment)
245 return (void *) utils::rnd_up((uintptr_t) ptr, alignment);
248 static int gemm_kernel_driver(const dim_t m, const dim_t n, const dim_t k,
249 const int8_t *a, const uint8_t *b, int32_t *c, const int32_t *co,
252 dim_t lda = arg->lda;
253 dim_t ldb = arg->ldb;
254 dim_t ldc = arg->ldc;
257 float alpha = *arg->alpha;
258 float beta = *arg->beta;
260 if (m <= 0 || n <= 0) {
264 // Padding along K dimension.
266 if (k <= arg->bk_traditional) {
267 k_padd = utils::rnd_up(k, arg->uk);
268 k_padd = nstl::max(128LL, k_padd);
269 } else if (k < 2 * arg->bk) {
270 k_padd = utils::rnd_up(k / 2, arg->uk);
275 // Padding along M dimension.
276 dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm),
279 // Padding along N dimension.
281 if (k < arg->blocking_small_k) {
282 n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un),
283 arg->bn_small_k), arg->un);
285 n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn),
289 // Padding for temporary buffer for C
290 dim_t ldc_buf = ld_padd(m_padd);
292 dim_t strideAm = (arg->transa == 0)? 1 : lda;
293 dim_t strideAn = (arg->transa != 0)? 1 : lda;
294 dim_t strideBm = (arg->transb == 0)? 1 : ldb;
295 dim_t strideBn = (arg->transb != 0)? 1 : ldb;
297 size_t a_buf_nelems = m_padd * k_padd;
298 size_t b_buf_nelems = k_padd * n_padd;
299 size_t a_row_sum_nelems = m_padd;
300 size_t b_col_sum_nelems = n_padd;
302 size_t mem_size = a_buf_nelems * sizeof(*a) + PAGE_4K
303 + b_buf_nelems * sizeof(*b) + PAGE_4K
304 + a_row_sum_nelems * sizeof(*c) + PAGE_4K
305 + b_col_sum_nelems * sizeof(*c) + PAGE_4K;
307 bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0);
309 size_t c_buf_nelems = ldc_buf * n_padd;
310 mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K;
313 char *mem = (char *) malloc(mem_size, 128);
319 int8_t *bufferA = (int8_t *) align(mem, PAGE_4K);
320 uint8_t *bufferB = (uint8_t *) align(bufferA + a_buf_nelems, PAGE_4K);
321 int32_t *a_row_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K);
322 int32_t *b_col_sum = (int32_t *) align(a_row_sum + a_row_sum_nelems,
325 int32_t *bufferC = NULL;
327 bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K);
330 float beta_saved = beta;
332 int a_block_copied = 0;
334 for (dim_t Bm = 0; Bm < m; Bm += sizeM) {
340 for (dim_t Bk = 0; Bk < k; Bk += sizeK) {
345 // Scale C blocks by beta only for the first time
351 // Apply C offset when to the last k-block of the partial sum.
352 int offsetc = NO_OFFSET;
354 offsetc = arg->offsetc;
357 for (dim_t Bn = 0; Bn < n; Bn += sizeN) {
362 const uint8_t *b_block = b + Bk * strideBm + Bn * strideBn;
363 arg->copyB(&sizeK, &sizeN, b_block, &ldb, NULL, bufferB, NULL,
367 for (dim_t Um = 0; Um < sizeM; Um += sizeUM) {
369 if (sizeUM > arg->um)
373 * Use the whole A buffer only if we have multiple B blocks
374 * for k-dimension, otherwise we are wasting cache to store
381 const int8_t *a_block = a + (Bm + Um) * strideAm
383 if (!a_block_copied) {
384 arg->copyA(&sizeK, &sizeUM, a_block, &lda, NULL,
385 bufferA + Um_forA * sizeK, NULL, NULL,
386 a_row_sum + Um_forA);
389 int32_t *c_block = c + (Bm + Um) + Bn * ldc;
391 if (offsetc == FIX_OFFSET) {
393 } else if (offsetc == ROW_OFFSET) {
395 } else if (offsetc == COL_OFFSET) {
399 igemm_inner_kernel(sizeUM, sizeN, sizeK,
400 bufferA + Um_forA * sizeK, bufferB, 0.0f,
401 bufferC + Um, ldc_buf, a_row_sum + Um_forA,
402 b_col_sum, NULL, NO_OFFSET, arg);
404 // Finish the block adding the necessary alpha, beta
406 add_results(sizeUM, sizeN, sizeK, alpha, beta,
407 bufferC + Um, ldc_buf, c_block, ldc,
408 a_row_sum + Um_forA, b_col_sum, ao, bo,
409 co + co_stride, offsetc);
411 igemm_inner_kernel(sizeUM, sizeN, sizeK,
412 bufferA + Um_forA * sizeK, bufferB, beta,
413 c_block, ldc, a_row_sum + Um_forA, b_col_sum,
414 co + co_stride, offsetc, arg);
428 static int kernel_driver_parallel_acopiedbcopy(const dim_t m, const dim_t n,
429 const dim_t k, const int8_t *bufferA, const uint8_t *b,
430 const float beta, int32_t *c, const int offsetc, const int32_t *co,
431 const int32_t *a_row_sum, const blas_t *arg)
433 dim_t ldb = arg->ldb;
434 dim_t ldc = arg->ldc;
437 float alpha = *arg->alpha;
439 if (m <= 0 || n <= 0) {
443 // Padding along N dimension.
445 if (k < arg->blocking_small_k) {
446 n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un),
447 arg->bn_small_k), arg->un);
449 n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn),
453 // Padding for temporary buffer for C
454 dim_t ldc_buf = ld_padd(m);
456 dim_t strideBn = (arg->transb != 0)? 1 : ldb;
458 size_t b_buf_nelems = k * n_padd;
459 size_t b_col_sum_nelems = n_padd;
461 size_t mem_size = b_buf_nelems * sizeof(*b) + PAGE_4K
462 + b_col_sum_nelems * sizeof(*c) + PAGE_4K;
464 bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0);
466 size_t c_buf_nelems = ldc_buf * n_padd;
467 mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K;
470 char *mem = (char *) malloc(mem_size, 128);
476 uint8_t *bufferB = (uint8_t *) align(mem, PAGE_4K);
477 int32_t *b_col_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K);
479 int32_t *bufferC = NULL;
481 bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K);
485 for (dim_t Bn = 0; Bn < n; Bn += sizeN) {
490 // Implement the kernel here.
491 const uint8_t *b_block = b + Bn * strideBn;
492 arg->copyB(&k, &sizeN, b_block, &ldb, NULL, bufferB, NULL, NULL,
496 if (offsetc == FIX_OFFSET) {
498 } else if (offsetc == ROW_OFFSET) {
500 } else if (offsetc == COL_OFFSET) {
503 int32_t *c_block = c + Bn * ldc;
505 igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, 0.0f, bufferC,
506 ldc_buf, a_row_sum, b_col_sum, NULL, NO_OFFSET, arg);
508 // Finish the block adding the necessary alpha, beta and offsets.
509 add_results(m, sizeN, k, alpha, beta, bufferC, ldc_buf, c_block,
510 ldc, a_row_sum, b_col_sum, ao, bo, co + co_stride,
513 igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, beta, c_block,
514 ldc, a_row_sum, b_col_sum, co + co_stride, offsetc, arg);
524 #define N2D_MAX_AVX512 384
525 #define M2D_MIN_AVX512 384
528 static inline void set_thread_opts_avx512(int *p_nthrs,
529 blas_thread_t *thread_info, const blas_t *arg)
531 int nthrs = *p_nthrs;
535 thread_info->nthrs_m = 0;
536 thread_info->nthrs_n = 0;
537 thread_info->copy_type = COPY_NONE; // By default don't do parallel copy.
539 int condition_2D_bsrc = -1;
540 if ((256 * m > nthrs * n) && (nthrs * m < 256 * n)) {
541 condition_2D_bsrc = 1;
543 condition_2D_bsrc = 0;
546 int condition_1D_copya = 0;
547 if ((m >= 1000) && (n >= nthrs * N2D_MAX_AVX512 / 4)) {
548 condition_2D_bsrc = 0;
549 condition_1D_copya = 1;
552 // If offset is non-zero, we need to keep 1D_copya to reduce update overhead
553 if (arg->ao != 0 || arg->bo != 0 || arg->co[0] != 0
554 || arg->offsetc != FIX_OFFSET) {
555 condition_2D_bsrc = 0;
556 condition_1D_copya = 1;
559 if (condition_2D_bsrc == 1) {
563 while ((nthrs_n % 2 == 0) &&
564 (n / nthrs > N2D_MAX_AVX512 ||
565 n / nthrs_n <= N2D_MAX_AVX512 / 2) &&
566 (m / nthrs_m >= 2 * M2D_MIN_AVX512) &&
572 thread_info->nthrs_m = nthrs_m;
573 thread_info->nthrs_n = nthrs_n;
574 thread_info->partition = PARTITION_2D;
576 // Reset the total number of threads that will be used.
577 *p_nthrs = nthrs_m * nthrs_n;
579 } else if (condition_1D_copya && mkldnn_thr_syncable()) {
580 // Use parallel copy A algorithm
581 thread_info->copy_type = COPY_A;
582 thread_info->partition = PARTITION_1D_COL;
584 if ((m > n) && (m / nthrs >= VECLEN || n < NCONS * nthrs)) {
585 thread_info->partition = PARTITION_1D_ROW;
587 thread_info->partition = PARTITION_1D_COL;
591 #undef N2D_MAX_AVX512
592 #undef M2D_MIN_AVX512
596 static inline void partition_1d(const int ithr, const int nthrs, const dim_t n,
597 dim_t *t_offset, dim_t *t_block)
599 dim_t band = n / nthrs;
601 dim_t tail = n - (nthrs - 1) * band;
602 if (tail > (band + 1))
604 tail = n - (nthrs - 1) * band;
606 if (ithr < (nthrs - 1))
611 *t_offset = ithr * band;
613 if (*t_offset >= n) {
616 } else if ((*t_offset + *t_block) > n) {
617 *t_block = n - *t_offset;
621 static inline void partition_2d(const int ithr, int *nthrs, const int ithr_i,
622 const int ithr_j, const int nthrs_m, const int nthrs_n, const dim_t m,
623 const dim_t n, dim_t *p_m_disp, dim_t *p_m_band, dim_t *p_n_disp,
626 dim_t m_disp = 0, n_disp = 0;
627 dim_t m_band = 0, n_band = 0;
632 dim_t m_bandt = m / mdiv; /* size per thread */
633 dim_t n_bandt = n / ndiv; /* size per thread */
634 int firstmgroup = mdiv - 1;
635 int firstngroup = ndiv - 1;
636 dim_t firstmval = m_bandt;
637 dim_t firstnval = n_bandt;
639 int mthr_used = mdiv;
640 if (m - (mdiv - 1) * m_bandt > m_bandt + 1) {
641 if (m - (mdiv - 1) * m_bandt > mdiv)
644 firstmval = m_bandt + 1;
645 mthr_used = (int) (m / firstmval);
647 if (mthr_used * firstmval < m)
650 firstmgroup = mthr_used - 1;
653 int nthr_used = ndiv;
654 if (n - (ndiv - 1) * n_bandt > n_bandt + 1) {
655 firstnval = n_bandt + 1;
656 nthr_used = (int) (n / firstnval);
658 if (nthr_used * firstnval < n)
661 firstngroup = nthr_used - 1;
664 *nthrs = mthr_used * nthr_used;
667 if (ithr_i < firstmgroup) {
669 m_disp = ithr_i * firstmval;
670 } else if (ithr_i <= mthr_used - 2) {
672 m_disp = firstmgroup * firstmval + (ithr_i - firstmgroup) * m_bandt;
674 m_disp = firstmgroup * firstmval
675 + (mthr_used - 1 - firstmgroup) * m_bandt;
676 m_band = nstl::max(0LL, m - m_disp);
679 if (ithr_j < firstngroup) {
681 n_disp = ithr_j * firstnval;
682 } else if (ithr_j <= nthr_used - 2) {
684 n_disp = firstngroup * firstnval + (ithr_j - firstngroup) * n_bandt;
686 n_disp = firstngroup * firstnval
687 + (nthr_used - 1 - firstngroup) * n_bandt;
688 n_band = nstl::max(0LL, n - n_disp);
690 m_disp = nstl::max(nstl::min(m_disp, m - 1), 0LL);
691 n_disp = nstl::max(nstl::min(n_disp, n - 1), 0LL);
709 static inline void decompose_matrices(const int ithr, int *nthrs, dim_t *m,
710 dim_t *n, dim_t *k, const int8_t **a, const uint8_t **b, int32_t **c,
711 const int32_t **co, const blas_thread_t *thread_info, const blas_t *arg)
713 dim_t strideAm = (arg->transa == 0)? 1 : arg->lda;
714 dim_t strideBn = (arg->transb != 0)? 1 : arg->ldb;
715 int offsetc = arg->offsetc;
717 switch (thread_info->partition) {
718 case PARTITION_1D_ROW:
722 partition_1d(ithr, *nthrs, arg->m, &offset, &block);
729 *a = arg->a + offset * strideAm;
735 *c = arg->c + offset;
737 // Set offset vector for C matrix
739 if (offsetc == FIX_OFFSET) {
741 } else if (offsetc == ROW_OFFSET) {
743 } else if (offsetc == COL_OFFSET) {
746 *co = arg->co + co_stride;
750 case PARTITION_1D_COL:
754 partition_1d(ithr, *nthrs, arg->n, &offset, &block);
764 *b = arg->b + offset * strideBn;
767 *c = arg->c + offset * arg->ldc;
769 // Set offset vector for C matrix
771 if (offsetc == FIX_OFFSET) {
773 } else if (offsetc == ROW_OFFSET) {
775 } else if (offsetc == COL_OFFSET) {
778 *co = arg->co + co_stride;
782 case PARTITION_2D_COL_MAJOR:
784 int nthrs_m = thread_info->nthrs_m;
785 int nthrs_n = thread_info->nthrs_n;
786 int ithr_i = ithr % nthrs_m;
787 int ithr_j = ithr / nthrs_m;
794 partition_2d(ithr, nthrs, ithr_i, ithr_j, nthrs_m, nthrs_n,
795 arg->m, arg->n, &m_disp, &m_band, &n_disp, &n_band);
802 *a = arg->a + m_disp * strideAm;
805 *b = arg->b + n_disp * strideBn;
808 *c = arg->c + m_disp + n_disp * arg->ldc;
810 // Set offset vector for C matrix
812 if (offsetc == FIX_OFFSET) {
814 } else if (offsetc == ROW_OFFSET) {
816 } else if (offsetc == COL_OFFSET) {
819 *co = arg->co + co_stride;
825 #define MULTIPLIER 10
826 static int parallel_a_copy(const int ithr, const int nthrs, const dim_t m,
827 const dim_t n, const dim_t k, const int8_t *a, const uint8_t *b,
828 int32_t *c, const int32_t *co, const blas_t *arg,
831 const dim_t lda = arg->lda;
832 const dim_t ldb = arg->ldb;
833 const dim_t strideAm = (arg->transa == 0)? 1 : lda;
834 const dim_t strideAn = (arg->transa != 0)? 1 : lda;
835 const dim_t strideBm = (arg->transb == 0)? 1 : ldb;
837 // Padding along M dimension.
838 dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm),
841 // Padding along K dimension.
843 if (k <= arg->bk_traditional) {
844 k_padd = utils::rnd_up(k, arg->uk);
845 k_padd = nstl::max(128LL, k_padd);
846 } else if (k < 2 * arg->bk) {
847 k_padd = utils::rnd_up(k / 2, arg->uk);
852 m_padd *= nthrs > MULTIPLIER ? MULTIPLIER : nthrs;
854 m_padd = utils::rnd_up(m, arg->um);
857 size_t a_buf_nelems = m_padd * k_padd;
859 // Allocate shared memory for A and its row sum buffers in master thread.
860 if (ithr == 0) { // If thread master
861 size_t a_row_sum_nelems = m_padd;
863 size_t mem_size = (a_buf_nelems * sizeof(*a) + PAGE_4K)
864 + a_row_sum_nelems * sizeof(*c) + PAGE_4K;
866 *p_shared_mem = (char *) malloc(mem_size, 128);
869 mkldnn_thr_barrier();
871 char *mem = *p_shared_mem;
872 int8_t *bufferA = (int8_t *) align(mem, PAGE_4K);
873 int32_t *a_row_sum = (int32_t *) align(bufferA + a_buf_nelems, PAGE_4K);
879 int result = 0; // Return status
882 for (dim_t Bk = 0; Bk < k; Bk += sizeK) {
887 // Scale C blocks by beta only for the first term of partial sum.
892 // Apply C offset for the last k-block of the partial sum.
893 int offsetc = NO_OFFSET;
895 offsetc = arg->offsetc;
898 for (dim_t Bm = 0; Bm < m; Bm += sizeM) {
904 dim_t band = (sizeM + nthrs - 1) / nthrs;
905 band = utils::rnd_up(band, arg->um);
907 dim_t offset = band * ithr;
909 // If offset is too large don't use that thread for copying.
910 if (offset >= sizeM) {
915 // Handle the tail of the copy.
916 if (offset + band > sizeM) {
917 band = sizeM - offset;
921 const int8_t *a_block = a + (Bm + offset) * strideAm
923 arg->copyA(&sizeK, &band, a_block, &lda, NULL,
924 bufferA + offset * sizeK, NULL, NULL,
928 mkldnn_thr_barrier(); // Wait for finishing parallel copy.
930 const uint8_t *b_block = b + Bk * strideBm;
931 int32_t *c_block = c + Bm;
933 if (offsetc == FIX_OFFSET) {
935 } else if (offsetc == ROW_OFFSET) {
937 } else if (offsetc == COL_OFFSET) {
941 result = kernel_driver_parallel_acopiedbcopy(sizeM, n, sizeK,
942 bufferA, b_block, beta, c_block, offsetc, co + co_stride,
945 mkldnn_thr_barrier(); // Wait for kernel computations to finish.
949 // Free memory allocated in master thread
958 static inline void get_omp_thread_count(dim_t m, dim_t n, dim_t k,
959 double fp_per_cycle, int *nthrs)
961 double omp_overhead_small_core = 3.0e+3;
962 double omp_intercept_big_core = 4.0e+3;
963 double omp_slope_big_core = 5.0e+2;
965 double gemm_cycles = 8.0 * m * n * k / fp_per_cycle;
969 // Use a different model for omp overheads if nthrs is <= 4
970 if (*nthrs <= 4 && omp_overhead_small_core > 0) {
971 double omp_cycles = omp_overhead_small_core;
972 if (gemm_cycles < omp_cycles) {
977 if (omp_cycles * i < gemm_cycles * (i - 1)) break;
982 if (gemm_cycles < (omp_intercept_big_core + 2 * omp_slope_big_core)) {
987 // adaptive decrement to march faster·
989 double omp_cycles = omp_intercept_big_core + i * omp_slope_big_core;
990 if (omp_cycles * i < gemm_cycles * (i - 1))
1008 #define CACHE_LINE_SIZE 64
1009 static int gemm_threading_driver(blas_t *arg)
1011 if ((arg->m <= 0) || (arg->n <= 0))
1012 return mkldnn_success;
1014 if (gemm_s8u8s32_jump_to_gemv_s8u8s32(arg)) {
1015 return mkldnn_success;
1018 int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
1019 get_omp_thread_count(arg->m, arg->n, arg->k, 64.0, &nthr);
1022 return gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a, arg->b,
1023 arg->c, arg->co, arg);
1026 int *results = (int *) malloc(sizeof(*results) * nthr * CACHE_LINE_SIZE,
1033 for (int i = 0; i < nthr; i++) {
1034 results[i * CACHE_LINE_SIZE] = 0; // Initialize to success
1037 char *shared_mem = NULL;
1039 parallel(nthr, [&](const int ithr, const int nthr) {
1042 results[0] = gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a,
1043 arg->b, arg->c, arg->co, arg);
1045 blas_thread_t thread_info;
1046 set_thread_opts_avx512(&nthrs, &thread_info, arg);
1048 const int8_t *a = NULL;
1049 const uint8_t *b = NULL;
1051 const int32_t *co = NULL;
1055 decompose_matrices(ithr, &nthrs, &m, &n, &k, &a, &b, &c, &co,
1059 switch (thread_info.copy_type) {
1061 results[ithr * CACHE_LINE_SIZE] =
1062 parallel_a_copy(ithr, nthrs, m, n, k, a, b, c, co, arg,
1068 results[ithr * CACHE_LINE_SIZE] =
1069 gemm_kernel_driver(m, n, k, a, b, c, co, arg);
1076 int result = 0; // Initialize to success
1077 for (int i = 0; i < nthr; i++) {
1078 if (results[i] != 0) {
1079 result = results[i * CACHE_LINE_SIZE];
1088 #undef CACHE_LINE_SIZE
1090 static jit_avx512_core_u8_copy_an_kern *copy_an;
1091 static jit_avx512_core_u8_copy_at_kern *copy_at;
1092 static jit_avx512_core_u8_copy_bn_kern *copy_bn;
1093 static jit_avx512_core_u8_copy_bt_kern *copy_bt;
1094 static jit_avx512_core_u8_copy_sum_an_kern *copy_sum_an;
1095 static jit_avx512_core_u8_copy_sum_at_kern *copy_sum_at;
1096 static jit_avx512_core_u8_copy_sum_bn_kern *copy_sum_bn;
1097 static jit_avx512_core_u8_copy_sum_bt_kern *copy_sum_bt;
1098 static jit_avx512_core_gemm_s8u8s32_kern *kernel;
1099 static jit_avx512_core_gemm_s8u8s32_kern *kernel_b;
1100 static jit_avx512_core_gemm_s8u8s32_kern *kernel_r;
1101 static jit_avx512_core_gemm_s8u8s32_kern *kernel_c;
1102 static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0;
1103 static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_b;
1104 static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_r;
1105 static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_c;
1106 static jit_avx512_core_gemv_s8u8s32_kern *gemv_s8u8s32_kernel;
1107 static jit_avx512_core_gemv_s8u8s32_kern *gemv_u8s8s32_kernel;
1109 static void jit_init(blas_t *arg)
1111 static int (*copyAn)(const dim_t *m, const dim_t *n, const int8_t *a,
1112 const dim_t *lda, const int8_t *alpha, int8_t *b,
1113 const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
1115 static int (*copyAt)(const dim_t *m, const dim_t *n, const int8_t *a,
1116 const dim_t *lda, const int8_t *alpha, int8_t *b,
1117 const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
1119 static int (*copyBn)(const dim_t *m, const dim_t *n, const uint8_t *a,
1120 const dim_t *lda, const uint8_t *alpha, uint8_t *b,
1121 const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
1123 static int (*copyBt)(const dim_t *m, const dim_t *n, const uint8_t *a,
1124 const dim_t *lda, const uint8_t *alpha, uint8_t *b,
1125 const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
1127 static int (*copySumAn)(const dim_t *m, const dim_t *n, const int8_t *a,
1128 const dim_t *lda, const int8_t *alpha, int8_t *b,
1129 const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
1131 static int (*copySumAt)(const dim_t *m, const dim_t *n, const int8_t *a,
1132 const dim_t *lda, const int8_t *alpha, int8_t *b,
1133 const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
1135 static int (*copySumBn)(const dim_t *m, const dim_t *n, const uint8_t *a,
1136 const dim_t *lda, const uint8_t *alpha, uint8_t *b,
1137 const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
1139 static int (*copySumBt)(const dim_t *m, const dim_t *n, const uint8_t *a,
1140 const dim_t *lda, const uint8_t *alpha, uint8_t *b,
1141 const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
1143 static int (*kern)(const dim_t *m, const dim_t *n, const dim_t *k,
1144 const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
1145 const dim_t ldc, const int32_t *col_offset,
1146 const int32_t *row_offset);
1148 static int (*kern_b)(const dim_t *m, const dim_t *n, const dim_t *k,
1149 const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
1150 const dim_t ldc, const int32_t *col_offset,
1151 const int32_t *row_offset);
1153 static int (*kern_r)(const dim_t *m, const dim_t *n, const dim_t *k,
1154 const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
1155 const dim_t ldc, const int32_t *col_offset,
1156 const int32_t *row_offset);
1158 static int (*kern_c)(const dim_t *m, const dim_t *n, const dim_t *k,
1159 const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
1160 const dim_t ldc, const int32_t *col_offset,
1161 const int32_t *row_offset);
1163 static int (*kern_b0)(const dim_t *m, const dim_t *n, const dim_t *k,
1164 const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
1165 const dim_t ldc, const int32_t *col_offset,
1166 const int32_t *row_offset);
1168 static int (*kern_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k,
1169 const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
1170 const dim_t ldc, const int32_t *col_offset,
1171 const int32_t *row_offset);
1173 static int (*kern_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k,
1174 const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
1175 const dim_t ldc, const int32_t *col_offset,
1176 const int32_t *row_offset);
1178 static int (*kern_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k,
1179 const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
1180 const dim_t ldc, const int32_t *col_offset,
1181 const int32_t *row_offset);
1183 static void (*gemv_s8u8s32_kern)(const dim_t, const dim_t, const float,
1184 const int8_t*, const dim_t, const uint8_t*,
1185 const float, int32_t*);
1187 static void (*gemv_u8s8s32_kern)(const dim_t, const dim_t, const float,
1188 const uint8_t*, const dim_t, const int8_t*,
1189 const float, int32_t*);
1191 if (mayiuse(avx512_core_vnni)) {
1192 arg->um = AVX512_UNROLL_M;
1193 arg->un = AVX512_UNROLL_N;
1194 arg->uk = AVX512_UNROLL_K;
1195 arg->bm = AVX512_BM;
1196 arg->bn = AVX512_BN;
1197 arg->bk = AVX512_BK_VNNI;
1199 arg->bk_traditional = AVX512_BK_TRADITIONAL;
1200 arg->bn_small_k = AVX512_BN_SMALL_K;
1201 arg->blocking_small_k = AVX512_BLOCKING_SMALL_K;
1203 arg->um = AVX512_UNROLL_M;
1204 arg->un = AVX512_UNROLL_N;
1205 arg->uk = AVX512_UNROLL_K;
1206 arg->bm = AVX512_BM;
1207 arg->bn = AVX512_BN;
1208 arg->bk = AVX512_BK;
1210 arg->bk_traditional = AVX512_BK_TRADITIONAL;
1211 arg->bn_small_k = AVX512_BN_SMALL_K;
1212 arg->blocking_small_k = AVX512_BLOCKING_SMALL_K;
1215 static std::once_flag initialized;
1216 std::call_once(initialized, []{
1218 copy_an = new jit_avx512_core_u8_copy_an_kern();
1219 copy_at = new jit_avx512_core_u8_copy_at_kern();
1220 copy_bn = new jit_avx512_core_u8_copy_bn_kern();
1221 copy_bt = new jit_avx512_core_u8_copy_bt_kern();
1223 copy_sum_an = new jit_avx512_core_u8_copy_sum_an_kern();
1224 copy_sum_at = new jit_avx512_core_u8_copy_sum_at_kern();
1225 copy_sum_bn = new jit_avx512_core_u8_copy_sum_bn_kern();
1226 copy_sum_bt = new jit_avx512_core_u8_copy_sum_bt_kern();
1228 kernel = new jit_avx512_core_gemm_s8u8s32_kern(false, false, false);
1229 kernel_b = new jit_avx512_core_gemm_s8u8s32_kern(false, true, true);
1230 kernel_r = new jit_avx512_core_gemm_s8u8s32_kern(false, false, true);
1231 kernel_c = new jit_avx512_core_gemm_s8u8s32_kern(false, true, false);
1232 kernel_b0 = new jit_avx512_core_gemm_s8u8s32_kern(true, false, false);
1233 kernel_b0_b = new jit_avx512_core_gemm_s8u8s32_kern(true, true, true);
1234 kernel_b0_r = new jit_avx512_core_gemm_s8u8s32_kern(true, false, true);
1235 kernel_b0_c = new jit_avx512_core_gemm_s8u8s32_kern(true, true, false);
1237 gemv_s8u8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern();
1238 gemv_u8s8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern();
1241 copyAn = copy_an->getCode<int (*)(const dim_t *, const dim_t *,
1242 const int8_t *, const dim_t *, const int8_t *, int8_t *,
1243 const dim_t *, const dim_t *, int32_t *)>();
1245 copyAt = copy_at->getCode<int (*)(const dim_t *, const dim_t *,
1246 const int8_t *, const dim_t *, const int8_t *, int8_t *,
1247 const dim_t *, const dim_t *, int32_t *)>();
1249 copyBn = copy_bn->getCode<int (*)(const dim_t *, const dim_t *,
1250 const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
1251 const dim_t *, const dim_t *, int32_t *)>();
1253 copyBt = copy_bt->getCode<int (*)(const dim_t *, const dim_t *,
1254 const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
1255 const dim_t *, const dim_t *, int32_t *)>();
1257 copySumAn = copy_sum_an->getCode<int (*)(const dim_t *, const dim_t *,
1258 const int8_t *, const dim_t *, const int8_t *, int8_t *,
1259 const dim_t *, const dim_t *, int32_t *)>();
1261 copySumAt = copy_sum_at->getCode<int (*)(const dim_t *, const dim_t *,
1262 const int8_t *, const dim_t *, const int8_t *, int8_t *,
1263 const dim_t *, const dim_t *, int32_t *)>();
1265 copySumBn = copy_sum_bn->getCode<int (*)(const dim_t *, const dim_t *,
1266 const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
1267 const dim_t *, const dim_t *, int32_t *)>();
1269 copySumBt = copy_sum_bt->getCode<int (*)(const dim_t *, const dim_t *,
1270 const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
1271 const dim_t *, const dim_t *, int32_t *)>();
1273 kern = kernel->getCode<int (*)(const dim_t *, const dim_t *,
1274 const dim_t *, const float *, const int8_t *, const uint8_t *,
1275 int32_t *, const dim_t, const int32_t *, const int32_t *)>();
1277 kern_b = kernel_b->getCode<int (*)(const dim_t *, const dim_t *,
1278 const dim_t *, const float *, const int8_t *, const uint8_t *,
1279 int32_t *, const dim_t, const int32_t *, const int32_t *)>();
1281 kern_r = kernel_r->getCode<int (*)(const dim_t *, const dim_t *,
1282 const dim_t *, const float *, const int8_t *, const uint8_t *,
1283 int32_t *, const dim_t, const int32_t *, const int32_t *)>();
1285 kern_c = kernel_c->getCode<int (*)(const dim_t *, const dim_t *,
1286 const dim_t *, const float *, const int8_t *, const uint8_t *,
1287 int32_t *, const dim_t, const int32_t *, const int32_t *)>();
1289 kern_b0 = kernel_b0->getCode<int (*)(const dim_t *, const dim_t *,
1290 const dim_t *, const float *, const int8_t *, const uint8_t *,
1291 int32_t *, const dim_t, const int32_t *, const int32_t *)>();
1293 kern_b0_b = kernel_b0_b->getCode<int (*)(const dim_t *, const dim_t *,
1294 const dim_t *, const float *, const int8_t *, const uint8_t *,
1295 int32_t *, const dim_t, const int32_t *, const int32_t *)>();
1297 kern_b0_r = kernel_b0_r->getCode<int (*)(const dim_t *, const dim_t *,
1298 const dim_t *, const float *, const int8_t *, const uint8_t *,
1299 int32_t *, const dim_t, const int32_t *, const int32_t *)>();
1301 kern_b0_c = kernel_b0_c->getCode<int (*)(const dim_t *, const dim_t *,
1302 const dim_t *, const float *, const int8_t *, const uint8_t *,
1303 int32_t *, const dim_t, const int32_t *, const int32_t *)>();
1306 gemv_s8u8s32_kernel -> generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t>
1307 (mayiuse(avx512_core_vnni));
1309 gemv_u8s8s32_kernel -> generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t>
1310 (mayiuse(avx512_core_vnni));
1313 if (arg->bo == 0) { // No need to compute A row sum if bo is zero
1314 if (arg->transa == 0) {
1315 arg->copyA = copyAn;
1317 arg->copyA = copyAt;
1320 if (arg->transa == 0) {
1321 arg->copyA = copySumAn;
1323 arg->copyA = copySumAt;
1327 if (arg->ao == 0) { // No need to compute B column sum if ao is zero
1328 if (arg->transb == 0) {
1329 arg->copyB = copyBn;
1331 arg->copyB = copyBt;
1334 if (arg->transb == 0) {
1335 arg->copyB = copySumBn;
1337 arg->copyB = copySumBt;
1342 arg->kernel_b = kern_b;
1343 arg->kernel_r = kern_r;
1344 arg->kernel_c = kern_c;
1345 arg->kernel_b0 = kern_b0;
1346 arg->kernel_b0_b = kern_b0_b;
1347 arg->kernel_b0_r = kern_b0_r;
1348 arg->kernel_b0_c = kern_b0_c;
1349 arg -> gemv_s8u8s32_kernel = gemv_s8u8s32_kern;
1350 arg -> gemv_u8s8s32_kernel = gemv_u8s8s32_kern;
1353 mkldnn_status_t jit_avx512_core_gemm_s8u8s32(
1354 const char *transA, const char *transB, const char *offsetC,
1355 const int *m, const int *n, const int *k,
1356 const float *alpha, const int8_t *a, const int *lda, const int8_t *oa,
1357 const uint8_t *b, const int *ldb, const int8_t *ob,
1358 const float *beta, int32_t *c, const int *ldc, const int32_t *oc)
1360 char transa = *transA;
1361 char transb = *transB;
1362 char offsetc = *offsetC;
1366 // Initialize blas structure
1378 args.transa = (transa == 'N' || transa == 'n') ? 0 : 1;
1379 args.transb = (transb == 'N' || transb == 'n') ? 0 : 1;
1388 args.kernel_b0 = NULL;
1393 if (offsetc == 'F' || offsetc == 'f') {
1394 args.offsetc = FIX_OFFSET;
1395 } else if (offsetc == 'R' || offsetc == 'r') {
1396 args.offsetc = ROW_OFFSET;
1397 } else { // offsetc == 'C' || offsetc == 'c'
1398 args.offsetc = COL_OFFSET;
1402 int result = gemm_threading_driver(&args);
1404 return (result < 0) ? mkldnn_out_of_memory : mkldnn_success;