Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / s8x8s32 / common.hpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef COMMON_H
18 #define COMMON_H
19
20 #define GEMM_CODE_SIZE          (4096L * 32)
21
22 #define AVX512_UNROLL_M                   48
23 #define AVX512_UNROLL_N                    8
24 #define AVX512_UNROLL_K                    1
25 #define AVX512_BM                       9984
26 #define AVX512_BN                        384
27 #define AVX512_BK                        768
28 #define AVX512_BK_VNNI                  1536
29 #define AVX512_BK_TRADITIONAL            384
30 #define AVX512_BLOCKING_SMALL_K           48
31 #define AVX512_BN_SMALL_K                 24
32
33
34 #define PAGESIZE 4096
35
36 #define PADD_BYTESIZE_ONPAGE(x, size) (((x) * (size) + PAGESIZE - 1) / PAGESIZE) * PAGESIZE
37 #define NEXT_THR_STRIDE(x, size) (PADD_BYTESIZE_ONPAGE(x, size)) / size
38
39 #include "jit_generator.hpp"
40
41 namespace mkldnn {
42 namespace impl {
43 namespace cpu {
44
45 enum {
46     PARTITION_1D_ROW,
47     PARTITION_1D_COL,
48     PARTITION_2D_COL_MAJOR,
49     PARTITION_2D = PARTITION_2D_COL_MAJOR,
50 };
51
52 enum {
53     COPY_NONE,
54     COPY_A,
55 };
56
57 enum {
58     NO_OFFSET,
59     FIX_OFFSET,
60     COL_OFFSET,
61     ROW_OFFSET,
62 };
63
64 // Alias for any dimension related variable.
65 typedef long long int dim_t;
66
67 typedef struct {
68     // Interface arguments.
69     int transa, transb, offsetc;
70     dim_t m, n, k;
71     dim_t lda, ldb, ldc;
72     const int8_t *a;
73     const uint8_t *b;
74     int32_t *c;
75     const float *alpha, *beta;
76
77     int8_t ao, bo;
78     const int32_t *co;
79
80     // Kernel parameters.
81     dim_t um, un, uk, bm, bn, bk;
82     dim_t bn_small_k, bk_traditional, blocking_small_k;
83
84     int (*copyA)(const dim_t *m, const dim_t *n, const int8_t *a,
85             const dim_t *lda, const int8_t *alpha, int8_t *b,
86             const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
87
88     int (*copyB)(const dim_t *m, const dim_t *n, const uint8_t *a,
89             const dim_t *lda, const uint8_t *alpha, uint8_t *b,
90             const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
91
92     int (*kernel)(const dim_t *m, const dim_t *n, const dim_t *k,
93             const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
94             const dim_t ldc, const int32_t *col_offset,
95             const int32_t *row_offset);
96
97     int (*kernel_b)(const dim_t *m, const dim_t *n, const dim_t *k,
98             const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
99             const dim_t ldc, const int32_t *col_offset,
100             const int32_t *row_offset);
101
102     int (*kernel_r)(const dim_t *m, const dim_t *n, const dim_t *k,
103             const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
104             const dim_t ldc, const int32_t *col_offset,
105             const int32_t *row_offset);
106
107     int (*kernel_c)(const dim_t *m, const dim_t *n, const dim_t *k,
108             const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
109             const dim_t ldc, const int32_t *col_offset,
110             const int32_t *row_offset);
111
112     int (*kernel_b0)(const dim_t *m, const dim_t *n, const dim_t *k,
113             const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
114             const dim_t ldc, const int32_t *col_offset,
115             const int32_t *row_offset);
116
117     int (*kernel_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k,
118             const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
119             const dim_t ldc, const int32_t *col_offset,
120             const int32_t *row_offset);
121
122     int (*kernel_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k,
123             const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
124             const dim_t ldc, const int32_t *col_offset,
125             const int32_t *row_offset);
126
127     int (*kernel_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k,
128             const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
129             const dim_t ldc, const int32_t *col_offset,
130             const int32_t *row_offset);
131
132     // Gemv kernels
133     void (*gemv_s8u8s32_kernel)(const dim_t, const dim_t, const float,
134                                 const int8_t*, const dim_t, const uint8_t*,
135                                 const float, int32_t*);
136
137     void (*gemv_u8s8s32_kernel)(const dim_t, const dim_t, const float,
138                                 const uint8_t*, const dim_t, const int8_t*,
139                                 const float, int32_t*);
140
141     // Gemv parameters
142     int swap;
143
144 } blas_t;
145
146
147 class jit_avx512_core_u8_copy_an_kern : public jit_generator {
148     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_an_kern);
149
150     public:
151         jit_avx512_core_u8_copy_an_kern();
152 };
153
154 class jit_avx512_core_u8_copy_at_kern : public jit_generator {
155     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_at_kern);
156
157     public:
158         jit_avx512_core_u8_copy_at_kern();
159 };
160
161 class jit_avx512_core_u8_copy_bn_kern : public jit_generator {
162     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bn_kern);
163
164     public:
165         jit_avx512_core_u8_copy_bn_kern();
166 };
167
168 class jit_avx512_core_u8_copy_bt_kern : public jit_generator {
169     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bt_kern);
170
171     public:
172         jit_avx512_core_u8_copy_bt_kern();
173 };
174
175 class jit_avx512_core_u8_copy_sum_an_kern : public jit_generator {
176     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_an_kern);
177
178     public:
179         jit_avx512_core_u8_copy_sum_an_kern();
180 };
181
182 class jit_avx512_core_u8_copy_sum_at_kern : public jit_generator {
183     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_at_kern);
184
185     public:
186         jit_avx512_core_u8_copy_sum_at_kern();
187 };
188
189 class jit_avx512_core_u8_copy_sum_bn_kern : public jit_generator {
190     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bn_kern);
191
192     public:
193         jit_avx512_core_u8_copy_sum_bn_kern();
194 };
195
196 class jit_avx512_core_u8_copy_sum_bt_kern : public jit_generator {
197     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bt_kern);
198
199     public:
200         jit_avx512_core_u8_copy_sum_bt_kern();
201 };
202
203 }
204 }
205 }
206 #endif