Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm / s8x8s32 / jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp
1 /*******************************************************************************
2  * Copyright 2019 Intel Corporation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *******************************************************************************/
16
17 #include "jit_generator.hpp"
18 #include "common.hpp"
19
20 namespace mkldnn {
21 namespace impl {
22 namespace cpu {
23
24 class jit_avx512_core_gemv_s8u8s32_kern : jit_generator {
25
26     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemv_s8u8s32_kern);
27
28     // assumes untoll_{m,n} are a power of 2
29     static constexpr unsigned int unroll_m = 4; // real unrolling factor is 2^unroll_m
30     const int mask_um = 0xFFFFFFF0;
31     static constexpr unsigned int unroll_n = 6; // real unrolling factor is 2^unroll_n
32     const int mask_un = 0xFFFFFFC0;
33     const int size_vec_reg = 64; // bytes
34
35     void aligned_label(Xbyak::Label &label, int alignment = 16) {
36         align(alignment);
37         L(label);
38     }
39
40     void vnni(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, bool, int);
41     void n_loop_body(int, int, int, int, Xbyak::Reg64, Xbyak::Reg64,
42                      Xbyak::Reg64, Xbyak::Zmm, Xbyak::Zmm, bool, int, int, Xbyak::Opmask);
43     void shuffle_and_add(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm);
44     void update_c(int, Xbyak::Reg64, int, int, Xbyak::Xmm, int, Xbyak::Opmask);
45
46 public:
47     jit_avx512_core_gemv_s8u8s32_kern() : jit_generator(nullptr, GEMM_CODE_SIZE) {};
48
49     // m, n, alpha, a, lda, x, beta, y
50     typedef void (*gemv_s8u8s32_kernel_t)(const dim_t, const dim_t, const float,
51                                           const int8_t*, const dim_t, const uint8_t*,
52                                           const float, int32_t*);
53     typedef void (*gemv_u8s8s32_kernel_t)(const dim_t, const dim_t, const float,
54                                           const uint8_t*, const dim_t, const int8_t*,
55                                           const float, int32_t*);
56
57     template <typename T>
58     T generate(int use_vnni);
59
60 };
61
62 }
63 }
64 }