1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "jit_generator.hpp"
// JIT kernel-generator class. Judging by its name and by the kernel
// function-pointer typedefs declared further down in this class, it targets
// GEMV with 8-bit integer inputs and a 32-bit integer output -- confirm
// against the implementation file.
// NOTE(review): no access specifier is given on the base class, so
// jit_generator is inherited privately (the default for `class`); confirm
// that this is intended.
24 class jit_avx512_core_gemv_s8u8s32_kern : jit_generator {
// Macro invoked with this class's name; its expansion is not visible here
// (presumably supplied by jit_generator.hpp -- confirm).
26 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemv_s8u8s32_kern);
// Unrolling parameters. unroll_m / unroll_n are log2 exponents, not the
// factors themselves, so the real unrolling factors (16 and 64) are powers
// of 2.
28 // assumes the real unrolling factors 2^unroll_{m,n} are powers of 2
29 static constexpr unsigned int unroll_m = 4; // real unrolling factor is 2^unroll_m
// Low unroll_m (4) bits are zero: for a 32-bit x, (x & mask_um) rounds x down
// to a multiple of 16. Hard-coded literal -- it is not derived from unroll_m
// and must be updated by hand if unroll_m changes.
30 const int mask_um = 0xFFFFFFF0;
31 static constexpr unsigned int unroll_n = 6; // real unrolling factor is 2^unroll_n
// Low unroll_n (6) bits are zero: for a 32-bit x, (x & mask_un) rounds x down
// to a multiple of 64. Hard-coded literal -- it is not derived from unroll_n
// and must be updated by hand if unroll_n changes.
32 const int mask_un = 0xFFFFFFC0;
// 64 bytes = 512 bits, i.e. the width of one Zmm register.
33 const int size_vec_reg = 64; // bytes
35 void aligned_label(Xbyak::Label &label, int alignment = 16) {
// Helper member functions, declared only: their definitions and parameter
// names are not visible here, so the role of each argument has to be read
// from the implementation. Presumably each one emits a fragment of the
// generated kernel -- confirm.

// Takes five Zmm registers, a bool and an int.
// NOTE(review): the bool presumably selects between VNNI and non-VNNI
// instruction sequences (cf. generate(int use_vnni)) -- confirm.
40 void vnni(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, bool, int);
// Takes four ints, three 64-bit general-purpose registers, two Zmm registers,
// a bool, two more ints and an opmask register.
41 void n_loop_body(int, int, int, int, Xbyak::Reg64, Xbyak::Reg64,
42 Xbyak::Reg64, Xbyak::Zmm, Xbyak::Zmm, bool, int, int, Xbyak::Opmask);
// Takes five Zmm registers.
43 void shuffle_and_add(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm);
// Takes an int, a 64-bit general-purpose register, two ints, an Xmm register,
// another int and an opmask register.
44 void update_c(int, Xbyak::Reg64, int, int, Xbyak::Xmm, int, Xbyak::Opmask);
// Default constructor: forwards nullptr and GEMM_CODE_SIZE to the
// jit_generator base constructor; the body itself is empty.
// NOTE(review): the two arguments are presumably the code-buffer pointer and
// the buffer size -- confirm against jit_generator. The ';' after the body is
// redundant (harmless, but flagged by -Wextra-semi).
47 jit_avx512_core_gemv_s8u8s32_kern() : jit_generator(nullptr, GEMM_CODE_SIZE) {};
// Function-pointer types for the two kernel variants. Both return void and
// take eight arguments in the order listed below:
//   m, n  : const dim_t
//   alpha : const float
//   a     : pointer to 8-bit integers (signedness depends on the variant)
//   lda   : const dim_t
//   x     : pointer to 8-bit integers (signedness depends on the variant)
//   beta  : const float
//   y     : int32_t* (the only non-const pointer)
49 // m, n, alpha, a, lda, x, beta, y
// s8u8s32 variant: a is const int8_t*, x is const uint8_t*.
50 typedef void (*gemv_s8u8s32_kernel_t)(const dim_t, const dim_t, const float,
51 const int8_t*, const dim_t, const uint8_t*,
52 const float, int32_t*);
// u8s8s32 variant: a is const uint8_t*, x is const int8_t*.
53 typedef void (*gemv_u8s8s32_kernel_t)(const dim_t, const dim_t, const float,
54 const uint8_t*, const dim_t, const int8_t*,
55 const float, int32_t*);
// Declared only; the definition is not visible here. Returns a T, which is
// not declared in the visible lines -- presumably a template parameter
// standing for one of the kernel function-pointer types (confirm).
// use_vnni: NOTE(review) presumably a non-zero value enables the VNNI code
// path -- confirm against the definition.
58 T generate(int use_vnni);