1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
// Size in bytes reserved for JIT-generated GEMM code.
// NOTE(review): presumably the per-kernel code buffer size -- confirm
// against the generator that consumes it.
#define GEMM_CODE_SIZE (4096L * 32)

// AVX-512 unroll factors and blocking sizes. The meanings below are
// inferred from the names (M/N/K unroll, M/K blocking, small-K path);
// NOTE(review): confirm against the code that reads them.
#define AVX512_UNROLL_M 48
#define AVX512_UNROLL_N 8
#define AVX512_UNROLL_K 1
#define AVX512_BM 9984
#define AVX512_BK_VNNI 1536
#define AVX512_BK_TRADITIONAL 384
#define AVX512_BLOCKING_SMALL_K 48
#define AVX512_BN_SMALL_K 24

// Byte size of `x` elements of `size` bytes each, rounded up to the next
// multiple of PAGESIZE. PAGESIZE must be defined wherever this expands.
// The whole expansion is parenthesized so it composes correctly inside
// larger expressions, e.g. `n / PADD_BYTESIZE_ONPAGE(x, size)`.
#define PADD_BYTESIZE_ONPAGE(x, size) \
    ((((x) * (size) + PAGESIZE - 1) / PAGESIZE) * PAGESIZE)

// Number of `size`-byte elements that span the page-padded byte count of
// `x` such elements. Use it as a per-thread stride that starts each
// thread's slice on a page boundary. `size` and the full expansion are
// parenthesized, so expression arguments such as `a + b` divide as a unit.
#define NEXT_THR_STRIDE(x, size) (PADD_BYTESIZE_ONPAGE(x, size) / (size))
39 #include "jit_generator.hpp"
48 PARTITION_2D_COL_MAJOR,
// Generic 2D partitioning is an alias for the column-major 2D variant.
49 PARTITION_2D = PARTITION_2D_COL_MAJOR,
64 // Alias for any dimension related variable (matrix sizes, leading
// dimensions, unroll and blocking factors). `long long int` is at least
// 64 bits wide, so large problem sizes do not overflow a 32-bit int.
65 typedef long long int dim_t;
68 // Interface arguments.
// NOTE(review): transa/transb presumably hold the transposition flags of
// A and B, and offsetc the kind of offset applied to C -- confirm where
// this struct is populated.
69 int transa, transb, offsetc;
// Scaling factors held by pointer. NOTE(review): presumably the GEMM
// alpha and beta; the kernel pointers below also take `const float *alpha`.
75 const float *alpha, *beta;
// Unroll factors (um/un/uk) and block sizes (bm/bn/bk).
// NOTE(review): presumably initialized from the AVX512_UNROLL_* /
// AVX512_B* macros on the AVX-512 path -- confirm.
81 dim_t um, un, uk, bm, bn, bk;
// Alternative blocking sizes. The names mirror AVX512_BN_SMALL_K,
// AVX512_BK_TRADITIONAL and AVX512_BLOCKING_SMALL_K; presumably used for
// small-K problems and non-VNNI hardware -- TODO confirm.
82 dim_t bn_small_k, bk_traditional, blocking_small_k;
// Copy routine for the int8 matrix A: reads `a` (leading dimension `lda`)
// and writes into buffer `b`. All sizes are passed by pointer.
// NOTE(review): presumably packs A into the kernel's layout and, when
// `row_col_sum` is non-null, stores per-row/column sums for offset
// compensation; dummy1/dummy2 look like unused placeholders -- confirm.
// The meaning of the int return value is not visible here.
84 int (*copyA)(const dim_t *m, const dim_t *n, const int8_t *a,
85 const dim_t *lda, const int8_t *alpha, int8_t *b,
86 const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);

// Copy routine for the uint8 matrix B. The signature mirrors copyA, with
// uint8_t in place of int8_t for the source, alpha and destination.
88 int (*copyB)(const dim_t *m, const dim_t *n, const uint8_t *a,
89 const dim_t *lda, const uint8_t *alpha, uint8_t *b,
90 const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
// Compute kernels. All variants share one signature: m/n/k and alpha by
// pointer, an int8 `a` and a uint8 `b` operand, an int32 output `c` with
// leading dimension `ldc` (passed by value), plus column and row offset
// arrays.
// NOTE(review): the suffixes presumably encode beta handling (`b0` ->
// beta == 0) and which offsets are applied (`_b` both, `_r` row, `_c`
// column, none -> no offset) -- confirm in the kernel dispatch code.
92 int (*kernel)(const dim_t *m, const dim_t *n, const dim_t *k,
93 const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
94 const dim_t ldc, const int32_t *col_offset,
95 const int32_t *row_offset);

97 int (*kernel_b)(const dim_t *m, const dim_t *n, const dim_t *k,
98 const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
99 const dim_t ldc, const int32_t *col_offset,
100 const int32_t *row_offset);

102 int (*kernel_r)(const dim_t *m, const dim_t *n, const dim_t *k,
103 const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
104 const dim_t ldc, const int32_t *col_offset,
105 const int32_t *row_offset);

107 int (*kernel_c)(const dim_t *m, const dim_t *n, const dim_t *k,
108 const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
109 const dim_t ldc, const int32_t *col_offset,
110 const int32_t *row_offset);

// `b0` variants: same signature as above (presumably the beta == 0 path).
112 int (*kernel_b0)(const dim_t *m, const dim_t *n, const dim_t *k,
113 const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
114 const dim_t ldc, const int32_t *col_offset,
115 const int32_t *row_offset);

117 int (*kernel_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k,
118 const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
119 const dim_t ldc, const int32_t *col_offset,
120 const int32_t *row_offset);

122 int (*kernel_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k,
123 const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
124 const dim_t ldc, const int32_t *col_offset,
125 const int32_t *row_offset);

127 int (*kernel_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k,
128 const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
129 const dim_t ldc, const int32_t *col_offset,
130 const int32_t *row_offset);
// Matrix-vector kernels with unnamed parameters, all passed by value
// except the data pointers.
// NOTE(review): by type, the parameters are presumably
// (m, n, alpha, matrix, ld, vector, beta, y) -- confirm at the call site.
// s8u8s32: int8 matrix x uint8 vector -> int32 output.
133 void (*gemv_s8u8s32_kernel)(const dim_t, const dim_t, const float,
134 const int8_t*, const dim_t, const uint8_t*,
135 const float, int32_t*);

// u8s8s32: uint8 matrix x int8 vector -> int32 output (operand types
// swapped relative to gemv_s8u8s32_kernel).
137 void (*gemv_u8s8s32_kernel)(const dim_t, const dim_t, const float,
138 const uint8_t*, const dim_t, const int8_t*,
139 const float, int32_t*);
// JIT code generator (derives from jit_generator) for an AVX-512 u8 GEMM
// copy kernel. NOTE(review): "an" presumably means matrix A,
// non-transposed layout -- confirm in the implementation file.
147 class jit_avx512_core_u8_copy_an_kern : public jit_generator {
// jit_generator boilerplate macro, presumably from jit_generator.hpp.
148 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_an_kern);
// Constructor, declared only here (presumably emits the kernel code).
151 jit_avx512_core_u8_copy_an_kern();
// JIT code generator (derives from jit_generator) for an AVX-512 u8 GEMM
// copy kernel. NOTE(review): "at" presumably means matrix A, transposed
// layout -- confirm in the implementation file.
154 class jit_avx512_core_u8_copy_at_kern : public jit_generator {
// jit_generator boilerplate macro, presumably from jit_generator.hpp.
155 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_at_kern);
// Constructor, declared only here (presumably emits the kernel code).
158 jit_avx512_core_u8_copy_at_kern();
// JIT code generator (derives from jit_generator) for an AVX-512 u8 GEMM
// copy kernel. NOTE(review): "bn" presumably means matrix B,
// non-transposed layout -- confirm in the implementation file.
161 class jit_avx512_core_u8_copy_bn_kern : public jit_generator {
// jit_generator boilerplate macro, presumably from jit_generator.hpp.
162 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bn_kern);
// Constructor, declared only here (presumably emits the kernel code).
165 jit_avx512_core_u8_copy_bn_kern();
// JIT code generator (derives from jit_generator) for an AVX-512 u8 GEMM
// copy kernel. NOTE(review): "bt" presumably means matrix B, transposed
// layout -- confirm in the implementation file.
168 class jit_avx512_core_u8_copy_bt_kern : public jit_generator {
// jit_generator boilerplate macro, presumably from jit_generator.hpp.
169 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bt_kern);
// Constructor, declared only here (presumably emits the kernel code).
172 jit_avx512_core_u8_copy_bt_kern();
// JIT code generator (derives from jit_generator) for an AVX-512 u8 GEMM
// copy kernel, "sum" variant. NOTE(review): presumably copies matrix A
// (non-transposed) and also computes the row/column sums passed back
// through a `row_col_sum` argument -- confirm in the implementation file.
175 class jit_avx512_core_u8_copy_sum_an_kern : public jit_generator {
// jit_generator boilerplate macro, presumably from jit_generator.hpp.
176 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_an_kern);
// Constructor, declared only here (presumably emits the kernel code).
179 jit_avx512_core_u8_copy_sum_an_kern();
// JIT code generator (derives from jit_generator) for an AVX-512 u8 GEMM
// copy kernel, "sum" variant. NOTE(review): presumably copies matrix A
// (transposed) and also computes the row/column sums passed back through
// a `row_col_sum` argument -- confirm in the implementation file.
182 class jit_avx512_core_u8_copy_sum_at_kern : public jit_generator {
// jit_generator boilerplate macro, presumably from jit_generator.hpp.
183 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_at_kern);
// Constructor, declared only here (presumably emits the kernel code).
186 jit_avx512_core_u8_copy_sum_at_kern();
// JIT code generator (derives from jit_generator) for an AVX-512 u8 GEMM
// copy kernel, "sum" variant. NOTE(review): presumably copies matrix B
// (non-transposed) and also computes the row/column sums passed back
// through a `row_col_sum` argument -- confirm in the implementation file.
189 class jit_avx512_core_u8_copy_sum_bn_kern : public jit_generator {
// jit_generator boilerplate macro, presumably from jit_generator.hpp.
190 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bn_kern);
// Constructor, declared only here (presumably emits the kernel code).
193 jit_avx512_core_u8_copy_sum_bn_kern();
// JIT code generator (derives from jit_generator) for an AVX-512 u8 GEMM
// copy kernel, "sum" variant. NOTE(review): presumably copies matrix B
// (transposed) and also computes the row/column sums passed back through
// a `row_col_sum` argument -- confirm in the implementation file.
196 class jit_avx512_core_u8_copy_sum_bt_kern : public jit_generator {
// jit_generator boilerplate macro, presumably from jit_generator.hpp.
197 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bt_kern);
// Constructor, declared only here (presumably emits the kernel code).
200 jit_avx512_core_u8_copy_sum_bt_kern();