Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / s8x8s32 / jit_avx512_core_gemm_s8u8s32.cpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <cstdint>
18 #include <mutex>
19
20 #include "common.hpp"
21 #include "mkldnn_types.h"
22 #include "nstl.hpp"
23 #include "utils.hpp"
24
25 #include "jit_avx512_core_gemm_s8u8s32.hpp"
26 #include "jit_avx512_core_gemm_s8u8s32_kern.hpp"
27 #include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp"
28 #include "gemv.hpp"
29
30 #if defined(_MSC_VER)
31 #include <malloc.h>
32 #endif
33
34 namespace mkldnn {
35 namespace impl {
36 namespace cpu {
37
38 typedef struct {
39     int nthrs_m, nthrs_n;
40     int partition;
41     int copy_type;
42 } blas_thread_t;
43
44 static inline void round_to_nearest(int32_t *rounded_val, double fp_val) {
45     if (fp_val >= 0.) {
46         fp_val += 0.5;
47         if (fp_val > INT32_MAX) {
48             fp_val = INT32_MAX;
49         }
50     } else {
51         fp_val -= 0.5;
52         if (fp_val < INT32_MIN) {
53             fp_val = INT32_MIN;
54         }
55     }
56     *rounded_val = (int32_t) fp_val;
57 }
58
59 static inline void add_results(const dim_t m, const dim_t n, const dim_t k,
60         const float alpha, const float beta, const int32_t *c_partial_sum,
61         const dim_t ldcp, int32_t *c_data, const dim_t ldc,
62         const int32_t *a_row_sum, const int32_t *b_col_sum, const int8_t ao,
63         const int8_t bo, const int32_t *co, const int offsetc)
64 {
65     for (dim_t j = 0; j < n; ++j) {
66         for (dim_t i = 0; i < m; ++i) {
67             int32_t ctemp = c_partial_sum[i + j * ldcp];
68
69             if (alpha == 1.0f) {
70                 if (beta == 0.0f) {
71                     c_data[i + j * ldc] = ctemp;
72                 } else {
73                     double c_float = (double) beta
74                         * (double) c_data[i + j * ldc];
75                     c_float += (double) ctemp;
76                     round_to_nearest(&c_data[i + j * ldc], c_float);
77                 }
78             } else if (alpha == -1.0f) {
79                 if (beta == 0.0f) {
80                     c_data[i + j * ldc] = -ctemp;
81                 } else {
82                     double c_float = (double) beta
83                         * (double) c_data[i + j * ldc];
84                     c_float -= (double) ctemp;
85                     round_to_nearest(&c_data[i + j * ldc], c_float);
86                 }
87             } else {
88                 if (beta == 0.0f) {
89                     double c_float = alpha * (double) ctemp;
90                     round_to_nearest(&c_data[i + j * ldc], c_float);
91                 } else {
92                     double c_float = alpha * (double) ctemp +
93                         beta * (double) c_data[i + j * ldc];
94                     round_to_nearest(&c_data[i + j * ldc], c_float);
95                 }
96             }
97
98             if (offsetc == FIX_OFFSET) {
99                 c_data[i + j * ldc] += co[0];
100             } else if (offsetc == ROW_OFFSET) {
101                 c_data[i + j * ldc] += co[j];
102             } else if (offsetc == COL_OFFSET) {
103                 c_data[i + j * ldc] += co[i];
104             }
105         }
106     }
107 }
108
109 // TODO Find a better place for those functions.
110 static inline dim_t ld_padd(const dim_t x)
111 {
112     return ((x + ((2048 / sizeof(int32_t)) - 1)) / (2048 / sizeof(int32_t)))
113         * (2048 / sizeof(int32_t)) +  (64 / sizeof(int32_t));
114 }
115
116 void igemm_inner_kernel(const dim_t m, const dim_t n, const dim_t k,
117         const int8_t *a, const uint8_t *b, float beta, int32_t *c,
118         const dim_t ldc, const int32_t *a_row_sum, const int32_t *b_col_sum,
119         const int32_t *co, const int offsetc, const blas_t *arg)
120 {
121     int8_t ao = arg->ao;
122     int8_t bo = arg->bo;
123     int32_t co_0 = (offsetc == NO_OFFSET)? 0 : co[0];
124
125     // Since m and n are limited by blocking, stack overflow may not happen;
126     // it's up to 32kB
127 #if !defined(_MSC_VER)
128     int32_t col_offset[m];
129     int32_t row_offset[n];
130 #else
131     int32_t *col_offset = (int32_t *) _alloca(sizeof(*col_offset) * m);
132     int32_t *row_offset = (int32_t *) _alloca(sizeof(*row_offset) * n);
133 #endif
134
135     int col_req = 0;
136     int row_req = 0;
137
138     if ((bo != 0) || (offsetc == COL_OFFSET))
139         col_req = 1;
140     if ((ao != 0) || (offsetc == ROW_OFFSET))
141         row_req = 1;
142
143     // It needs one of colum or row offsets, but it doesn't need both
144     if (((ao != 0) && (bo != 0)) || ((offsetc == FIX_OFFSET) && (co_0 != 0))) {
145         if ((col_req == 0) && (row_req == 0)) {
146             if (m <= n) {
147                 col_req = 1;
148             } else {
149                 row_req = 1;
150             }
151         }
152     }
153
154     if (col_req) {
155         for (dim_t i = 0; i < m; i++)
156             col_offset[i] = 0;
157
158         if (offsetc == COL_OFFSET) {
159             for (dim_t i = 0; i < m; i++)
160                 col_offset[i] += co[i];
161         }
162
163         if (bo != 0) {
164             for (dim_t i = 0; i < m; i++)
165                 col_offset[i] += bo * a_row_sum[i];
166         }
167     }
168
169     if (row_req) {
170         for (dim_t i = 0; i < n; i++)
171             row_offset[i] = 0;
172
173         if (offsetc == ROW_OFFSET) {
174             for (dim_t i = 0; i < n; i++)
175                 row_offset[i] += co[i];
176         }
177
178         if (ao != 0) {
179             for (dim_t i = 0; i < n; i++)
180                 row_offset[i] += ao * b_col_sum[i];
181         }
182     }
183
184     if ((offsetc == FIX_OFFSET) && (co_0 != 0)) {
185         if (col_req) {
186             for (dim_t i = 0; i < m; i++)
187                 col_offset[i] += co_0;
188         } else {
189             for (dim_t i = 0; i < n; i++)
190                 row_offset[i] += co_0;
191         }
192     }
193
194     if ((ao != 0) && (bo != 0)) {
195         if (col_req) {
196             for (dim_t i = 0; i < m; i++)
197                 col_offset[i] += (int32_t) k * ao * bo;
198         } else {
199             for (dim_t i = 0; i < n; i++)
200                 row_offset[i] += (int32_t) k * ao * bo;
201         }
202     }
203
204     if (col_req == 0) {
205         if (row_req == 0) {
206             if (beta == 0.0) {
207                 arg->kernel_b0(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
208                         row_offset);
209             } else {
210                 arg->kernel(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
211                         row_offset);
212             }
213         } else {
214             if (beta == 0.0) {
215                 arg->kernel_b0_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
216                         row_offset);
217             } else {
218                 arg->kernel_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
219                         row_offset);
220             }
221         }
222     } else {
223         if (row_req == 0) {
224             if (beta == 0.0) {
225                 arg->kernel_b0_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
226                         row_offset);
227             } else {
228                 arg->kernel_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
229                         row_offset);
230             }
231         } else {
232             if (beta == 0.0) {
233                 arg->kernel_b0_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
234                         row_offset);
235             } else {
236                 arg->kernel_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
237                         row_offset);
238             }
239         }
240     }
241 }
242
243 static inline void *align(void *ptr, size_t alignment)
244 {
245     return (void *) utils::rnd_up((uintptr_t) ptr, alignment);
246 }
247
248 static int gemm_kernel_driver(const dim_t m, const dim_t n, const dim_t k,
249         const int8_t *a, const uint8_t *b, int32_t *c, const int32_t *co,
250         const blas_t *arg)
251 {
252     dim_t   lda   = arg->lda;
253     dim_t   ldb   = arg->ldb;
254     dim_t   ldc   = arg->ldc;
255     int8_t  ao    = arg->ao;
256     int8_t  bo    = arg->bo;
257     float   alpha = *arg->alpha;
258     float   beta  = *arg->beta;
259
260     if (m <= 0 || n <= 0) {
261         return 0;
262     }
263
264     // Padding along K dimension.
265     dim_t k_padd = 0;
266     if (k <= arg->bk_traditional) {
267         k_padd = utils::rnd_up(k, arg->uk);
268         k_padd = nstl::max(128LL, k_padd);
269     } else if (k < 2 * arg->bk) {
270         k_padd = utils::rnd_up(k / 2, arg->uk);
271     } else {
272         k_padd = arg->bk;
273     }
274
275     // Padding along M dimension.
276     dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm),
277             arg->um);
278
279     // Padding along N dimension.
280     dim_t n_padd = 0;
281     if (k < arg->blocking_small_k) {
282         n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un),
283                     arg->bn_small_k), arg->un);
284     } else {
285         n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn),
286                 arg->un);
287     }
288
289     // Padding for temporary buffer for C
290     dim_t ldc_buf = ld_padd(m_padd);
291
292     dim_t strideAm = (arg->transa == 0)? 1 : lda;
293     dim_t strideAn = (arg->transa != 0)? 1 : lda;
294     dim_t strideBm = (arg->transb == 0)? 1 : ldb;
295     dim_t strideBn = (arg->transb != 0)? 1 : ldb;
296
297     size_t a_buf_nelems = m_padd * k_padd;
298     size_t b_buf_nelems = k_padd * n_padd;
299     size_t a_row_sum_nelems = m_padd;
300     size_t b_col_sum_nelems = n_padd;
301
302     size_t mem_size = a_buf_nelems * sizeof(*a) + PAGE_4K
303         + b_buf_nelems * sizeof(*b) + PAGE_4K
304         + a_row_sum_nelems * sizeof(*c) + PAGE_4K
305         + b_col_sum_nelems * sizeof(*c) + PAGE_4K;
306
307     bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0);
308     if (need_c_buffer) {
309         size_t c_buf_nelems = ldc_buf * n_padd;
310         mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K;
311     }
312
313     char *mem = (char *) malloc(mem_size, 128);
314
315     if (!mem) {
316         return -1;
317     }
318
319     int8_t *bufferA = (int8_t *) align(mem, PAGE_4K);
320     uint8_t *bufferB = (uint8_t *) align(bufferA + a_buf_nelems, PAGE_4K);
321     int32_t *a_row_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K);
322     int32_t *b_col_sum = (int32_t *) align(a_row_sum + a_row_sum_nelems,
323             PAGE_4K);
324
325     int32_t *bufferC = NULL;
326     if (need_c_buffer) {
327         bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K);
328     }
329
330     float beta_saved = beta;
331
332     int a_block_copied = 0;
333     dim_t sizeM = 0;
334     for (dim_t Bm = 0; Bm < m; Bm += sizeM) {
335         sizeM = m - Bm;
336         if (sizeM > m_padd)
337             sizeM = m_padd;
338
339         dim_t sizeK = 0;
340         for (dim_t Bk = 0; Bk < k; Bk += sizeK) {
341             sizeK = k - Bk;
342             if (sizeK > k_padd)
343                 sizeK = k_padd;
344
345             // Scale C blocks by beta only for the first time
346             if (Bk == 0)
347                 beta = beta_saved;
348             else
349                 beta = 1.0f;
350
351             // Apply C offset when to the last k-block of the partial sum.
352             int offsetc = NO_OFFSET;
353             if (Bk + sizeK == k)
354                 offsetc = arg->offsetc;
355
356             dim_t sizeN = 0;
357             for (dim_t Bn = 0; Bn < n; Bn += sizeN) {
358                 sizeN = n - Bn;
359                 if (sizeN > n_padd)
360                     sizeN = n_padd;
361
362                 const uint8_t *b_block = b + Bk * strideBm + Bn * strideBn;
363                 arg->copyB(&sizeK, &sizeN, b_block, &ldb, NULL, bufferB, NULL,
364                         NULL, b_col_sum);
365
366                 dim_t sizeUM = 0;
367                 for (dim_t Um = 0; Um < sizeM; Um += sizeUM) {
368                     sizeUM = sizeM - Um;
369                     if (sizeUM > arg->um)
370                         sizeUM = arg->um;
371
372                     /*
373                      * Use the whole A buffer only if we have multiple B blocks
374                      * for k-dimension, otherwise we are wasting cache to store
375                      * B and C blocks.
376                      */
377                     dim_t Um_forA = 0;
378                     if (sizeN < n)
379                         Um_forA = Um;
380
381                     const int8_t *a_block = a + (Bm + Um) * strideAm
382                         + Bk * strideAn;
383                     if (!a_block_copied) {
384                         arg->copyA(&sizeK, &sizeUM, a_block, &lda, NULL,
385                                 bufferA + Um_forA * sizeK, NULL, NULL,
386                                 a_row_sum + Um_forA);
387                     }
388
389                     int32_t *c_block = c + (Bm + Um) + Bn * ldc;
390                     dim_t co_stride = 0;
391                     if (offsetc == FIX_OFFSET) {
392                         co_stride = 0;
393                     } else if (offsetc == ROW_OFFSET) {
394                         co_stride = Bn;
395                     } else if (offsetc == COL_OFFSET) {
396                         co_stride = Bm + Um;
397                     }
398                     if (need_c_buffer) {
399                         igemm_inner_kernel(sizeUM, sizeN, sizeK,
400                                 bufferA + Um_forA * sizeK, bufferB, 0.0f,
401                                 bufferC + Um, ldc_buf, a_row_sum + Um_forA,
402                                 b_col_sum, NULL, NO_OFFSET, arg);
403
404                         // Finish the block adding the necessary alpha, beta
405                         // and offsets.
406                         add_results(sizeUM, sizeN, sizeK, alpha, beta,
407                                 bufferC + Um, ldc_buf, c_block, ldc,
408                                 a_row_sum + Um_forA, b_col_sum, ao, bo,
409                                 co + co_stride, offsetc);
410                     } else {
411                         igemm_inner_kernel(sizeUM, sizeN, sizeK,
412                                 bufferA + Um_forA * sizeK, bufferB, beta,
413                                 c_block, ldc, a_row_sum + Um_forA, b_col_sum,
414                                 co + co_stride, offsetc, arg);
415                     }
416                 }
417                 a_block_copied = 1;
418             }
419             a_block_copied = 0;
420         }
421     }
422
423     free(mem);
424
425     return 0;
426 }
427
428 static int kernel_driver_parallel_acopiedbcopy(const dim_t m, const dim_t n,
429         const dim_t k, const int8_t *bufferA, const uint8_t *b,
430         const float beta, int32_t *c, const int offsetc, const int32_t *co,
431         const int32_t *a_row_sum, const blas_t *arg)
432 {
433     dim_t   ldb   = arg->ldb;
434     dim_t   ldc   = arg->ldc;
435     int8_t  ao    = arg->ao;
436     int8_t  bo    = arg->bo;
437     float   alpha = *arg->alpha;
438
439     if (m <= 0 || n <= 0) {
440         return 0;
441     }
442
443     // Padding along N dimension.
444     dim_t n_padd = 0;
445     if (k < arg->blocking_small_k) {
446         n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un),
447                     arg->bn_small_k), arg->un);
448     } else {
449         n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn),
450                 arg->un);
451     }
452
453     // Padding for temporary buffer for C
454     dim_t ldc_buf = ld_padd(m);
455
456     dim_t strideBn = (arg->transb != 0)? 1 : ldb;
457
458     size_t b_buf_nelems = k * n_padd;
459     size_t b_col_sum_nelems = n_padd;
460
461     size_t mem_size = b_buf_nelems * sizeof(*b) + PAGE_4K
462         + b_col_sum_nelems * sizeof(*c) + PAGE_4K;
463
464     bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0);
465     if (need_c_buffer) {
466         size_t c_buf_nelems = ldc_buf * n_padd;
467         mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K;
468     }
469
470     char *mem = (char *) malloc(mem_size, 128);
471
472     if (!mem) {
473         return -1;
474     }
475
476     uint8_t *bufferB = (uint8_t *) align(mem, PAGE_4K);
477     int32_t *b_col_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K);
478
479     int32_t *bufferC = NULL;
480     if (need_c_buffer) {
481         bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K);
482     }
483
484     dim_t sizeN = 0;
485     for (dim_t Bn = 0; Bn < n; Bn += sizeN) {
486         sizeN = n - Bn;
487         if (sizeN > n_padd)
488             sizeN = n_padd;
489
490         // Implement the kernel here.
491         const uint8_t *b_block = b + Bn * strideBn;
492         arg->copyB(&k, &sizeN, b_block, &ldb, NULL, bufferB, NULL, NULL,
493                 b_col_sum);
494
495             dim_t co_stride = 0;
496             if (offsetc == FIX_OFFSET) {
497                 co_stride = 0;
498             } else if (offsetc == ROW_OFFSET) {
499                 co_stride = Bn;
500             } else if (offsetc == COL_OFFSET) {
501                 co_stride = 0;
502             }
503         int32_t *c_block = c + Bn * ldc;
504         if (need_c_buffer) {
505             igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, 0.0f, bufferC,
506                     ldc_buf, a_row_sum, b_col_sum, NULL, NO_OFFSET, arg);
507
508             // Finish the block adding the necessary alpha, beta and offsets.
509             add_results(m, sizeN, k, alpha, beta, bufferC, ldc_buf, c_block,
510                     ldc, a_row_sum, b_col_sum, ao, bo, co + co_stride,
511                     offsetc);
512         } else {
513             igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, beta, c_block,
514                     ldc, a_row_sum, b_col_sum, co + co_stride, offsetc, arg);
515         }
516     }
517
518     free(mem);
519
520     return 0;
521
522 }
523
524 #define N2D_MAX_AVX512 384
525 #define M2D_MIN_AVX512 384
526 #define VECLEN         16
527 #define NCONS          1
528 static inline void set_thread_opts_avx512(int *p_nthrs,
529         blas_thread_t *thread_info, const blas_t *arg)
530 {
531     int nthrs = *p_nthrs;
532     dim_t m = arg->m;
533     dim_t n = arg->n;
534
535     thread_info->nthrs_m = 0;
536     thread_info->nthrs_n = 0;
537     thread_info->copy_type = COPY_NONE; // By default don't do parallel copy.
538
539     int condition_2D_bsrc = -1;
540     if ((256 * m > nthrs * n) && (nthrs * m < 256 * n)) {
541         condition_2D_bsrc = 1;
542     } else {
543         condition_2D_bsrc = 0;
544     }
545
546     int condition_1D_copya = 0;
547     if ((m >= 1000) && (n >= nthrs * N2D_MAX_AVX512 / 4)) {
548         condition_2D_bsrc  = 0;
549         condition_1D_copya = 1;
550     }
551
552     // If offset is non-zero, we need to keep 1D_copya to reduce update overhead
553     if (arg->ao != 0 || arg->bo != 0 || arg->co[0] != 0
554             || arg->offsetc != FIX_OFFSET) {
555         condition_2D_bsrc  = 0;
556         condition_1D_copya = 1;
557     }
558
559     if (condition_2D_bsrc == 1) {
560         int nthrs_m = 1;
561         int nthrs_n = nthrs;
562
563         while ((nthrs_n % 2 == 0) &&
564                 (n / nthrs > N2D_MAX_AVX512 ||
565                  n / nthrs_n <= N2D_MAX_AVX512 / 2) &&
566                 (m / nthrs_m >= 2 * M2D_MIN_AVX512) &&
567                 (nthrs_m < 4)) {
568             nthrs_m *= 2;
569             nthrs_n /= 2;
570         }
571
572         thread_info->nthrs_m = nthrs_m;
573         thread_info->nthrs_n = nthrs_n;
574         thread_info->partition = PARTITION_2D;
575
576         // Reset the total number of threads that will be used.
577         *p_nthrs = nthrs_m * nthrs_n;
578
579     } else if (condition_1D_copya && mkldnn_thr_syncable()) {
580         // Use parallel copy A algorithm
581         thread_info->copy_type = COPY_A;
582         thread_info->partition = PARTITION_1D_COL;
583     } else {
584         if ((m > n) && (m / nthrs >= VECLEN || n < NCONS * nthrs)) {
585             thread_info->partition = PARTITION_1D_ROW;
586         } else {
587             thread_info->partition = PARTITION_1D_COL;
588         }
589     }
590 }
591 #undef N2D_MAX_AVX512
592 #undef M2D_MIN_AVX512
593 #undef VECLEN
594 #undef NCONS
595
596 static inline void partition_1d(const int ithr, const int nthrs, const dim_t n,
597         dim_t *t_offset, dim_t *t_block)
598 {
599     dim_t band = n / nthrs;
600
601     dim_t tail = n - (nthrs - 1) * band;
602     if (tail > (band + 1))
603         band++;
604     tail = n - (nthrs - 1) * band;
605
606     if (ithr < (nthrs - 1))
607         *t_block = band;
608     else
609         *t_block = tail;
610
611     *t_offset = ithr * band;
612
613     if (*t_offset >= n) {
614         *t_block = 0;
615         *t_offset = 0;
616     } else if ((*t_offset + *t_block) > n) {
617         *t_block = n - *t_offset;
618     }
619 }
620
621 static inline void partition_2d(const int ithr, int *nthrs, const int ithr_i,
622         const int ithr_j, const int nthrs_m, const int nthrs_n, const dim_t m,
623         const dim_t n, dim_t *p_m_disp, dim_t *p_m_band, dim_t *p_n_disp,
624         dim_t *p_n_band)
625 {
626     dim_t m_disp = 0, n_disp = 0;
627     dim_t m_band = 0, n_band = 0;
628
629     int mdiv = nthrs_m;
630     int ndiv = nthrs_n;
631
632     dim_t m_bandt = m / mdiv; /* size per thread */
633     dim_t n_bandt = n / ndiv; /* size per thread */
634     int firstmgroup = mdiv - 1;
635     int firstngroup = ndiv - 1;
636     dim_t firstmval = m_bandt;
637     dim_t firstnval = n_bandt;
638
639     int mthr_used = mdiv;
640     if (m - (mdiv - 1) * m_bandt > m_bandt + 1) {
641         if (m - (mdiv - 1) * m_bandt > mdiv)
642             ++m_bandt;
643
644         firstmval = m_bandt + 1;
645         mthr_used = (int) (m / firstmval);
646
647         if (mthr_used * firstmval < m)
648             ++mthr_used;
649
650         firstmgroup = mthr_used - 1;
651     }
652
653     int nthr_used = ndiv;
654     if (n - (ndiv - 1) * n_bandt > n_bandt + 1) {
655         firstnval = n_bandt + 1;
656         nthr_used = (int) (n / firstnval);
657
658         if (nthr_used * firstnval < n)
659             ++nthr_used;
660
661         firstngroup = nthr_used - 1;
662     }
663
664     *nthrs = mthr_used * nthr_used;
665
666     if (ithr < *nthrs) {
667         if (ithr_i < firstmgroup) {
668             m_band = firstmval;
669             m_disp = ithr_i * firstmval;
670         } else if (ithr_i <= mthr_used - 2) {
671             m_band = m_bandt;
672             m_disp = firstmgroup * firstmval + (ithr_i - firstmgroup) * m_bandt;
673         } else {
674             m_disp = firstmgroup * firstmval
675                 + (mthr_used - 1 - firstmgroup) * m_bandt;
676             m_band = nstl::max(0LL, m - m_disp);
677         }
678
679         if (ithr_j < firstngroup) {
680             n_band = firstnval;
681             n_disp = ithr_j * firstnval;
682         } else if (ithr_j <= nthr_used - 2) {
683             n_band = n_bandt;
684             n_disp = firstngroup * firstnval + (ithr_j - firstngroup) * n_bandt;
685         } else {
686             n_disp = firstngroup * firstnval
687                 + (nthr_used - 1 - firstngroup) * n_bandt;
688             n_band = nstl::max(0LL, n - n_disp);
689         }
690         m_disp = nstl::max(nstl::min(m_disp, m - 1), 0LL);
691         n_disp = nstl::max(nstl::min(n_disp, n - 1), 0LL);
692     }
693
694     if (ithr < *nthrs) {
695         *p_m_disp = m_disp;
696         *p_n_disp = n_disp;
697         *p_m_band = m_band;
698         *p_n_band = n_band;
699     } else {
700         *p_m_disp = 0;
701         *p_n_disp = 0;
702         *p_m_band = 0;
703         *p_n_band = 0;
704     }
705
706     return;
707 }
708
709 static inline void decompose_matrices(const int ithr, int *nthrs, dim_t *m,
710         dim_t *n, dim_t *k, const int8_t **a, const uint8_t **b, int32_t **c,
711         const int32_t **co, const blas_thread_t *thread_info, const blas_t *arg)
712 {
713     dim_t strideAm = (arg->transa == 0)? 1 : arg->lda;
714     dim_t strideBn = (arg->transb != 0)? 1 : arg->ldb;
715     int offsetc = arg->offsetc;
716
717     switch (thread_info->partition) {
718     case PARTITION_1D_ROW:
719         {
720             dim_t offset = 0;
721             dim_t block = 0;
722             partition_1d(ithr, *nthrs, arg->m, &offset, &block);
723
724             *m = block;
725             *n = arg->n;
726             *k = arg->k;
727
728             // Set matrix A.
729             *a = arg->a + offset * strideAm;
730
731             // Set matrix B.
732             *b = arg->b;
733
734             // Set matrix C.
735             *c = arg->c + offset;
736
737             // Set offset vector for C matrix
738             dim_t co_stride = 0;
739             if (offsetc == FIX_OFFSET) {
740                 co_stride = 0;
741             } else if (offsetc == ROW_OFFSET) {
742                 co_stride = 0;
743             } else if (offsetc == COL_OFFSET) {
744                 co_stride = offset;
745             }
746             *co = arg->co + co_stride;
747             break;
748         }
749
750     case PARTITION_1D_COL:
751         {
752             dim_t offset = 0;
753             dim_t block = 0;
754             partition_1d(ithr, *nthrs, arg->n, &offset, &block);
755
756             *m = arg->m;
757             *n = block;
758             *k = arg->k;
759
760             // Set matrix A.
761             *a = arg->a;
762
763             // Set matrix B.
764             *b = arg->b + offset * strideBn;
765
766             // Set matrix C.
767             *c = arg->c + offset * arg->ldc;
768
769             // Set offset vector for C matrix
770             dim_t co_stride = 0;
771             if (offsetc == FIX_OFFSET) {
772                 co_stride = 0;
773             } else if (offsetc == ROW_OFFSET) {
774                 co_stride = offset;
775             } else if (offsetc == COL_OFFSET) {
776                 co_stride = 0;
777             }
778             *co = arg->co + co_stride;
779             break;
780         }
781
782     case PARTITION_2D_COL_MAJOR:
783         {
784             int nthrs_m = thread_info->nthrs_m;
785             int nthrs_n = thread_info->nthrs_n;
786             int ithr_i = ithr % nthrs_m;
787             int ithr_j = ithr / nthrs_m;
788
789             dim_t m_disp = 0;
790             dim_t m_band = 0;
791             dim_t n_disp = 0;
792             dim_t n_band = 0;
793
794             partition_2d(ithr, nthrs, ithr_i, ithr_j, nthrs_m, nthrs_n,
795                     arg->m, arg->n, &m_disp, &m_band, &n_disp, &n_band);
796
797             *m = m_band;
798             *n = n_band;
799             *k = arg->k;
800
801             // Set matrix A.
802             *a = arg->a + m_disp * strideAm;
803
804             // Set matrix B.
805             *b = arg->b + n_disp * strideBn;
806
807             // Set matrix C.
808             *c = arg->c + m_disp + n_disp * arg->ldc;
809
810             // Set offset vector for C matrix
811             dim_t co_stride = 0;
812             if (offsetc == FIX_OFFSET) {
813                 co_stride = 0;
814             } else if (offsetc == ROW_OFFSET) {
815                 co_stride = n_disp;
816             } else if (offsetc == COL_OFFSET) {
817                 co_stride = m_disp;
818             }
819             *co = arg->co + co_stride;
820             break;
821         }
822     }
823 }
824
825 #define MULTIPLIER 10
826 static int parallel_a_copy(const int ithr, const int nthrs, const dim_t m,
827         const dim_t n, const dim_t k, const int8_t *a, const uint8_t *b,
828         int32_t *c, const int32_t *co, const blas_t *arg,
829         char **p_shared_mem)
830 {
831     const dim_t lda = arg->lda;
832     const dim_t ldb = arg->ldb;
833     const dim_t strideAm = (arg->transa == 0)? 1 : lda;
834     const dim_t strideAn = (arg->transa != 0)? 1 : lda;
835     const dim_t strideBm = (arg->transb == 0)? 1 : ldb;
836
837     // Padding along M dimension.
838     dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm),
839             arg->um);
840
841     // Padding along K dimension.
842     dim_t k_padd = 0;
843     if (k <= arg->bk_traditional) {
844         k_padd = utils::rnd_up(k, arg->uk);
845         k_padd = nstl::max(128LL, k_padd);
846     } else if (k < 2 * arg->bk) {
847         k_padd = utils::rnd_up(k / 2, arg->uk);
848     } else {
849         k_padd = arg->bk;
850     }
851
852     m_padd *= nthrs > MULTIPLIER ? MULTIPLIER : nthrs;
853     if (m_padd > m) {
854         m_padd = utils::rnd_up(m, arg->um);
855     }
856
857     size_t a_buf_nelems = m_padd * k_padd;
858
859     // Allocate shared memory for A and its row sum buffers in master thread.
860     if (ithr == 0) { // If thread master
861         size_t a_row_sum_nelems = m_padd;
862
863         size_t mem_size = (a_buf_nelems * sizeof(*a) + PAGE_4K)
864             + a_row_sum_nelems * sizeof(*c) + PAGE_4K;
865
866         *p_shared_mem = (char *) malloc(mem_size, 128);
867
868     }
869     mkldnn_thr_barrier();
870
871     char *mem = *p_shared_mem;
872     int8_t *bufferA = (int8_t *) align(mem, PAGE_4K);
873     int32_t *a_row_sum = (int32_t *) align(bufferA + a_buf_nelems, PAGE_4K);
874
875     if (!mem) {
876         return -1;
877     }
878
879     int result = 0; // Return status
880
881     dim_t sizeK = 0;
882     for (dim_t Bk = 0; Bk < k; Bk += sizeK) {
883         sizeK = k - Bk;
884         if (sizeK > k_padd)
885             sizeK = k_padd;
886
887         // Scale C blocks by beta only for the first term of partial sum.
888         float beta = 1.0f;
889         if (Bk == 0)
890             beta = *(arg->beta);
891
892         // Apply C offset for the last k-block of the partial sum.
893         int offsetc = NO_OFFSET;
894         if (Bk + sizeK == k)
895             offsetc = arg->offsetc;
896
897         dim_t sizeM = 0;
898         for (dim_t Bm = 0; Bm < m; Bm += sizeM) {
899             sizeM = m - Bm;
900             if (sizeM > m_padd)
901                 sizeM = m_padd;
902
903             if (ithr < nthrs) {
904                 dim_t band = (sizeM + nthrs - 1) / nthrs;
905                 band = utils::rnd_up(band, arg->um);
906
907                 dim_t offset = band * ithr;
908
909                 // If offset is too large don't use that thread for copying.
910                 if (offset >= sizeM) {
911                     offset = 0;
912                     band = 0;
913                 }
914
915                 // Handle the tail of the copy.
916                 if (offset + band > sizeM) {
917                     band = sizeM - offset;
918                 }
919
920                 if (band > 0) {
921                     const int8_t *a_block = a + (Bm + offset) * strideAm
922                         + Bk * strideAn;
923                     arg->copyA(&sizeK, &band, a_block, &lda, NULL,
924                             bufferA + offset * sizeK, NULL, NULL,
925                             a_row_sum + offset);
926                 }
927             }
928             mkldnn_thr_barrier(); // Wait for finishing parallel copy.
929
930             const uint8_t *b_block = b + Bk * strideBm;
931             int32_t *c_block = c + Bm;
932             dim_t co_stride = 0;
933             if (offsetc == FIX_OFFSET) {
934                 co_stride = 0;
935             } else if (offsetc == ROW_OFFSET) {
936                 co_stride = 0;
937             } else if (offsetc == COL_OFFSET) {
938                 co_stride = Bm;
939             }
940
941             result = kernel_driver_parallel_acopiedbcopy(sizeM, n, sizeK,
942                     bufferA, b_block, beta, c_block, offsetc, co + co_stride,
943                     a_row_sum, arg);
944
945             mkldnn_thr_barrier(); // Wait for kernel computations to finish.
946         }
947     }
948
949     // Free memory allocated in master thread
950     if (ithr == 0) {
951         free(mem);
952     }
953
954     return result;
955 }
956 #undef MULTIPLIER
957
958 static inline void get_omp_thread_count(dim_t m, dim_t n, dim_t k,
959         double fp_per_cycle, int *nthrs)
960 {
961     double omp_overhead_small_core = 3.0e+3;
962     double omp_intercept_big_core = 4.0e+3;
963     double omp_slope_big_core = 5.0e+2;
964
965     double gemm_cycles = 8.0 * m * n * k / fp_per_cycle;
966
967     int i = *nthrs;
968
969     // Use a different model for omp overheads if nthrs is <= 4
970     if (*nthrs <= 4 && omp_overhead_small_core > 0) {
971         double omp_cycles = omp_overhead_small_core;
972         if (gemm_cycles < omp_cycles) {
973             *nthrs = 1;
974             return;
975         } else {
976             while (i > 1) {
977                 if (omp_cycles * i < gemm_cycles * (i - 1)) break;
978                 --i;
979             }
980         }
981     } else {
982         if (gemm_cycles < (omp_intercept_big_core + 2 * omp_slope_big_core)) {
983             *nthrs = 1;
984             return;
985         }
986
987         // adaptive decrement to march faster·
988         while (i > 1) {
989             double omp_cycles = omp_intercept_big_core + i * omp_slope_big_core;
990             if (omp_cycles * i < gemm_cycles * (i - 1))
991                 break;
992
993             if (i < 10)
994                 i -= 2;
995             else if (i < 30)
996                 i -= 4;
997             else
998                 i -= 8;
999         }
1000     }
1001
1002     if (i < 1)
1003         i = 1;
1004
1005     *nthrs = i;
1006 }
1007
1008 #define CACHE_LINE_SIZE 64
1009 static int gemm_threading_driver(blas_t *arg)
1010 {
1011     if ((arg->m <= 0) || (arg->n <= 0))
1012         return mkldnn_success;
1013
1014     if (gemm_s8u8s32_jump_to_gemv_s8u8s32(arg)) {
1015         return mkldnn_success;
1016     }
1017
1018     int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
1019     get_omp_thread_count(arg->m, arg->n, arg->k, 64.0, &nthr);
1020
1021     if (nthr == 1) {
1022         return gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a, arg->b,
1023                 arg->c, arg->co, arg);
1024     }
1025
1026     int *results = (int *) malloc(sizeof(*results) * nthr * CACHE_LINE_SIZE,
1027             PAGE_4K);
1028
1029     if (!results) {
1030         return -1;
1031     }
1032
1033     for (int i = 0; i < nthr; i++) {
1034         results[i * CACHE_LINE_SIZE] = 0; // Initialize to success
1035     }
1036
1037     char *shared_mem = NULL;
1038
1039     parallel(nthr, [&](const int ithr, const int nthr) {
1040         int nthrs = nthr;
1041         if (nthrs == 1) {
1042             results[0] = gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a,
1043                 arg->b, arg->c, arg->co, arg);
1044         } else {
1045             blas_thread_t thread_info;
1046             set_thread_opts_avx512(&nthrs, &thread_info, arg);
1047
1048             const int8_t *a = NULL;
1049             const uint8_t *b = NULL;
1050             int32_t *c = NULL;
1051             const int32_t *co = NULL;
1052             dim_t m = -1;
1053             dim_t n = -1;
1054             dim_t k = -1;
1055             decompose_matrices(ithr, &nthrs, &m, &n, &k, &a, &b, &c, &co,
1056                 &thread_info, arg);
1057
1058             if (ithr < nthrs) {
1059                 switch (thread_info.copy_type) {
1060                 case COPY_A:
1061                     results[ithr * CACHE_LINE_SIZE] =
1062                         parallel_a_copy(ithr, nthrs, m, n, k, a, b, c, co, arg,
1063                                 &shared_mem);
1064                     break;
1065
1066                 default:
1067                 case COPY_NONE:
1068                     results[ithr * CACHE_LINE_SIZE] =
1069                         gemm_kernel_driver(m, n, k, a, b, c, co, arg);
1070                     break;
1071                 }
1072             }
1073         }
1074     });
1075
1076     int result = 0;  // Initialize to success
1077     for (int i = 0; i < nthr; i++) {
1078         if (results[i] != 0) {
1079             result = results[i * CACHE_LINE_SIZE];
1080             break;
1081         }
1082     }
1083
1084     free(results);
1085
1086     return result;
1087 }
1088 #undef CACHE_LINE_SIZE
1089
1090 static jit_avx512_core_u8_copy_an_kern *copy_an;
1091 static jit_avx512_core_u8_copy_at_kern *copy_at;
1092 static jit_avx512_core_u8_copy_bn_kern *copy_bn;
1093 static jit_avx512_core_u8_copy_bt_kern *copy_bt;
1094 static jit_avx512_core_u8_copy_sum_an_kern *copy_sum_an;
1095 static jit_avx512_core_u8_copy_sum_at_kern *copy_sum_at;
1096 static jit_avx512_core_u8_copy_sum_bn_kern *copy_sum_bn;
1097 static jit_avx512_core_u8_copy_sum_bt_kern *copy_sum_bt;
1098 static jit_avx512_core_gemm_s8u8s32_kern *kernel;
1099 static jit_avx512_core_gemm_s8u8s32_kern *kernel_b;
1100 static jit_avx512_core_gemm_s8u8s32_kern *kernel_r;
1101 static jit_avx512_core_gemm_s8u8s32_kern *kernel_c;
1102 static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0;
1103 static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_b;
1104 static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_r;
1105 static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_c;
1106 static jit_avx512_core_gemv_s8u8s32_kern *gemv_s8u8s32_kernel;
1107 static jit_avx512_core_gemv_s8u8s32_kern *gemv_u8s8s32_kernel;
1108
1109 static void jit_init(blas_t *arg)
1110 {
1111     static int (*copyAn)(const dim_t *m, const dim_t *n, const int8_t *a,
1112             const dim_t *lda, const int8_t *alpha, int8_t *b,
1113             const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
1114
1115     static int (*copyAt)(const dim_t *m, const dim_t *n, const int8_t  *a,
1116             const dim_t *lda, const int8_t  *alpha, int8_t  *b,
1117             const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
1118
1119     static int (*copyBn)(const dim_t *m, const dim_t *n, const uint8_t *a,
1120             const dim_t *lda, const uint8_t *alpha, uint8_t *b,
1121             const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
1122
1123     static int (*copyBt)(const dim_t *m, const dim_t *n, const uint8_t *a,
1124             const dim_t *lda, const uint8_t *alpha, uint8_t *b,
1125             const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
1126
1127     static int (*copySumAn)(const dim_t *m, const dim_t *n, const int8_t  *a,
1128             const dim_t *lda, const int8_t  *alpha, int8_t  *b,
1129             const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
1130
1131     static int (*copySumAt)(const dim_t *m, const dim_t *n, const int8_t  *a,
1132             const dim_t *lda, const int8_t  *alpha, int8_t  *b,
1133             const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
1134
1135     static int (*copySumBn)(const dim_t *m, const dim_t *n, const uint8_t *a,
1136             const dim_t *lda, const uint8_t *alpha, uint8_t *b,
1137             const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
1138
1139     static int (*copySumBt)(const dim_t *m, const dim_t *n, const uint8_t *a,
1140             const dim_t *lda, const uint8_t *alpha, uint8_t *b,
1141             const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
1142
1143     static int (*kern)(const dim_t *m, const dim_t *n, const dim_t *k,
1144             const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
1145             const dim_t ldc, const int32_t *col_offset,
1146             const int32_t *row_offset);
1147
1148     static int (*kern_b)(const dim_t *m, const dim_t *n, const dim_t *k,
1149             const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
1150             const dim_t ldc, const int32_t *col_offset,
1151             const int32_t *row_offset);
1152
1153     static int (*kern_r)(const dim_t *m, const dim_t *n, const dim_t *k,
1154             const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
1155             const dim_t ldc, const int32_t *col_offset,
1156             const int32_t *row_offset);
1157
1158     static int (*kern_c)(const dim_t *m, const dim_t *n, const dim_t *k,
1159             const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
1160             const dim_t ldc, const int32_t *col_offset,
1161             const int32_t *row_offset);
1162
1163     static int (*kern_b0)(const dim_t *m, const dim_t *n, const dim_t *k,
1164             const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
1165             const dim_t ldc, const int32_t *col_offset,
1166             const int32_t *row_offset);
1167
1168     static int (*kern_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k,
1169             const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
1170             const dim_t ldc, const int32_t *col_offset,
1171             const int32_t *row_offset);
1172
1173     static int (*kern_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k,
1174             const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
1175             const dim_t ldc, const int32_t *col_offset,
1176             const int32_t *row_offset);
1177
1178     static int (*kern_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k,
1179             const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
1180             const dim_t ldc, const int32_t *col_offset,
1181             const int32_t *row_offset);
1182
1183     static void (*gemv_s8u8s32_kern)(const dim_t, const dim_t, const float,
1184                                      const int8_t*, const dim_t, const uint8_t*,
1185                                      const float, int32_t*);
1186
1187     static void (*gemv_u8s8s32_kern)(const dim_t, const dim_t, const float,
1188                                      const uint8_t*, const dim_t, const int8_t*,
1189                                      const float, int32_t*);
1190
1191     if (mayiuse(avx512_core_vnni)) {
1192             arg->um = AVX512_UNROLL_M;
1193             arg->un = AVX512_UNROLL_N;
1194             arg->uk = AVX512_UNROLL_K;
1195             arg->bm = AVX512_BM;
1196             arg->bn = AVX512_BN;
1197             arg->bk = AVX512_BK_VNNI;
1198
1199             arg->bk_traditional   = AVX512_BK_TRADITIONAL;
1200             arg->bn_small_k       = AVX512_BN_SMALL_K;
1201             arg->blocking_small_k = AVX512_BLOCKING_SMALL_K;
1202     } else {
1203             arg->um = AVX512_UNROLL_M;
1204             arg->un = AVX512_UNROLL_N;
1205             arg->uk = AVX512_UNROLL_K;
1206             arg->bm = AVX512_BM;
1207             arg->bn = AVX512_BN;
1208             arg->bk = AVX512_BK;
1209
1210             arg->bk_traditional   = AVX512_BK_TRADITIONAL;
1211             arg->bn_small_k       = AVX512_BN_SMALL_K;
1212             arg->blocking_small_k = AVX512_BLOCKING_SMALL_K;
1213     }
1214
1215     static std::once_flag initialized;
1216     std::call_once(initialized, []{
1217
1218         copy_an = new jit_avx512_core_u8_copy_an_kern();
1219         copy_at = new jit_avx512_core_u8_copy_at_kern();
1220         copy_bn = new jit_avx512_core_u8_copy_bn_kern();
1221         copy_bt = new jit_avx512_core_u8_copy_bt_kern();
1222
1223         copy_sum_an = new jit_avx512_core_u8_copy_sum_an_kern();
1224         copy_sum_at = new jit_avx512_core_u8_copy_sum_at_kern();
1225         copy_sum_bn = new jit_avx512_core_u8_copy_sum_bn_kern();
1226         copy_sum_bt = new jit_avx512_core_u8_copy_sum_bt_kern();
1227
1228         kernel      = new jit_avx512_core_gemm_s8u8s32_kern(false, false, false);
1229         kernel_b    = new jit_avx512_core_gemm_s8u8s32_kern(false, true,  true);
1230         kernel_r    = new jit_avx512_core_gemm_s8u8s32_kern(false, false, true);
1231         kernel_c    = new jit_avx512_core_gemm_s8u8s32_kern(false, true,  false);
1232         kernel_b0   = new jit_avx512_core_gemm_s8u8s32_kern(true,  false, false);
1233         kernel_b0_b = new jit_avx512_core_gemm_s8u8s32_kern(true,  true,  true);
1234         kernel_b0_r = new jit_avx512_core_gemm_s8u8s32_kern(true,  false, true);
1235         kernel_b0_c = new jit_avx512_core_gemm_s8u8s32_kern(true,  true,  false);
1236
1237         gemv_s8u8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern();
1238         gemv_u8s8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern();
1239
1240
1241         copyAn = copy_an->getCode<int (*)(const dim_t *, const dim_t *,
1242                 const int8_t *, const dim_t *, const int8_t *, int8_t *,
1243                 const dim_t *, const dim_t *, int32_t *)>();
1244
1245         copyAt = copy_at->getCode<int (*)(const dim_t *, const dim_t *,
1246                 const int8_t *, const dim_t *, const int8_t *, int8_t *,
1247                 const dim_t *, const dim_t *, int32_t *)>();
1248
1249         copyBn = copy_bn->getCode<int (*)(const dim_t *, const dim_t *,
1250                 const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
1251                 const dim_t *, const dim_t *, int32_t *)>();
1252
1253         copyBt = copy_bt->getCode<int (*)(const dim_t *, const dim_t *,
1254                 const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
1255                 const dim_t *, const dim_t *, int32_t *)>();
1256
1257         copySumAn = copy_sum_an->getCode<int (*)(const dim_t *, const dim_t *,
1258                 const int8_t *, const dim_t *, const int8_t *, int8_t *,
1259                 const dim_t *, const dim_t *, int32_t *)>();
1260
1261         copySumAt = copy_sum_at->getCode<int (*)(const dim_t *, const dim_t *,
1262                 const int8_t *, const dim_t *, const int8_t *, int8_t *,
1263                 const dim_t *, const dim_t *, int32_t *)>();
1264
1265         copySumBn = copy_sum_bn->getCode<int (*)(const dim_t *, const dim_t *,
1266                 const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
1267                 const dim_t *, const dim_t *, int32_t *)>();
1268
1269         copySumBt = copy_sum_bt->getCode<int (*)(const dim_t *, const dim_t *,
1270                 const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
1271                 const dim_t *, const dim_t *, int32_t *)>();
1272
1273         kern = kernel->getCode<int (*)(const dim_t *, const dim_t *,
1274                 const dim_t *, const float *, const int8_t *, const uint8_t *,
1275                 int32_t *, const dim_t, const int32_t *, const int32_t *)>();
1276
1277         kern_b = kernel_b->getCode<int (*)(const dim_t *, const dim_t *,
1278                 const dim_t *, const float *, const int8_t *, const uint8_t *,
1279                 int32_t *, const dim_t, const int32_t *, const int32_t *)>();
1280
1281         kern_r = kernel_r->getCode<int (*)(const dim_t *, const dim_t *,
1282                 const dim_t *, const float *, const int8_t *, const uint8_t *,
1283                 int32_t *, const dim_t, const int32_t *, const int32_t *)>();
1284
1285         kern_c = kernel_c->getCode<int (*)(const dim_t *, const dim_t *,
1286                 const dim_t *, const float *, const int8_t *, const uint8_t *,
1287                 int32_t *, const dim_t, const int32_t *, const int32_t *)>();
1288
1289         kern_b0 = kernel_b0->getCode<int (*)(const dim_t *, const dim_t *,
1290                 const dim_t *, const float *, const int8_t *, const uint8_t *,
1291                 int32_t *, const dim_t, const int32_t *, const int32_t *)>();
1292
1293         kern_b0_b = kernel_b0_b->getCode<int (*)(const dim_t *, const dim_t *,
1294                 const dim_t *, const float *, const int8_t *, const uint8_t *,
1295                 int32_t *, const dim_t, const int32_t *, const int32_t *)>();
1296
1297         kern_b0_r = kernel_b0_r->getCode<int (*)(const dim_t *, const dim_t *,
1298                 const dim_t *, const float *, const int8_t *, const uint8_t *,
1299                 int32_t *, const dim_t, const int32_t *, const int32_t *)>();
1300
1301         kern_b0_c = kernel_b0_c->getCode<int (*)(const dim_t *, const dim_t *,
1302                 const dim_t *, const float *, const int8_t *, const uint8_t *,
1303                 int32_t *, const dim_t, const int32_t *, const int32_t *)>();
1304
1305         gemv_s8u8s32_kern =
1306             gemv_s8u8s32_kernel -> generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t>
1307             (mayiuse(avx512_core_vnni));
1308         gemv_u8s8s32_kern =
1309             gemv_u8s8s32_kernel -> generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t>
1310             (mayiuse(avx512_core_vnni));
1311     });
1312
1313     if (arg->bo == 0) { // No need to compute A row sum if bo is zero
1314         if (arg->transa == 0) {
1315             arg->copyA = copyAn;
1316         } else {
1317             arg->copyA = copyAt;
1318         }
1319     } else {
1320         if (arg->transa == 0) {
1321             arg->copyA = copySumAn;
1322         } else {
1323             arg->copyA = copySumAt;
1324         }
1325     }
1326
1327     if (arg->ao == 0) { // No need to compute B column sum if ao is zero
1328         if (arg->transb == 0) {
1329             arg->copyB = copyBn;
1330         } else {
1331             arg->copyB = copyBt;
1332         }
1333     } else {
1334         if (arg->transb == 0) {
1335             arg->copyB = copySumBn;
1336         } else {
1337             arg->copyB = copySumBt;
1338         }
1339     }
1340
1341     arg->kernel      = kern;
1342     arg->kernel_b    = kern_b;
1343     arg->kernel_r    = kern_r;
1344     arg->kernel_c    = kern_c;
1345     arg->kernel_b0   = kern_b0;
1346     arg->kernel_b0_b = kern_b0_b;
1347     arg->kernel_b0_r = kern_b0_r;
1348     arg->kernel_b0_c = kern_b0_c;
1349     arg -> gemv_s8u8s32_kernel = gemv_s8u8s32_kern;
1350     arg -> gemv_u8s8s32_kernel = gemv_u8s8s32_kern;
1351 }
1352
1353 mkldnn_status_t jit_avx512_core_gemm_s8u8s32(
1354         const char *transA, const char *transB, const char *offsetC,
1355         const int *m, const int *n, const int *k,
1356         const float *alpha, const int8_t *a, const int *lda, const int8_t *oa,
1357         const uint8_t *b, const int *ldb, const int8_t *ob,
1358         const float *beta, int32_t *c, const int *ldc, const int32_t *oc)
1359 {
1360     char transa  = *transA;
1361     char transb  = *transB;
1362     char offsetc = *offsetC;
1363
1364     blas_t args;
1365
1366     // Initialize blas structure
1367     args.m         = *m;
1368     args.n         = *n;
1369     args.k         = *k;
1370     args.alpha     = alpha;
1371     args.a         = a;
1372     args.lda       = *lda;
1373     args.b         = b;
1374     args.ldb       = *ldb;
1375     args.beta      = beta;
1376     args.c         = c;
1377     args.ldc       = *ldc;
1378     args.transa    = (transa == 'N' || transa == 'n') ? 0 : 1;
1379     args.transb    = (transb == 'N' || transb == 'n') ? 0 : 1;
1380     args.um        = 0;
1381     args.un        = 0;
1382     args.bm        = 0;
1383     args.bn        = 0;
1384     args.bk        = 0;
1385     args.copyA     = NULL;
1386     args.copyB     = NULL;
1387     args.kernel    = NULL;
1388     args.kernel_b0 = NULL;
1389     args.ao        = *oa;
1390     args.bo        = *ob;
1391     args.co        = oc;
1392
1393     if (offsetc == 'F' || offsetc == 'f') {
1394         args.offsetc = FIX_OFFSET;
1395     } else if (offsetc == 'R' || offsetc == 'r') {
1396         args.offsetc = ROW_OFFSET;
1397     } else { // offsetc == 'C' || offsetc == 'c'
1398         args.offsetc = COL_OFFSET;
1399     }
1400
1401     jit_init(&args);
1402     int result = gemm_threading_driver(&args);
1403
1404     return (result < 0) ? mkldnn_out_of_memory : mkldnn_success;
1405 }
1406
1407 }
1408 }
1409 }