1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP
18 #define CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
23 #include "cpu_memory.hpp"
25 #include "jit_generator.hpp"
26 #include "jit_primitive_conf.hpp"
27 #include "jit_uni_eltwise.hpp"
28 #include "jit_uni_depthwise.hpp"
34 template<typename Vmm>
35 struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator {
36 DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_x8s8s32x_conv_fwd_ker_t)
38 enum { STATE_FIRST_DST_LOAD = 0x1U };
// Ctor: stores the convolution configuration and primitive attributes,
// then caches the entry point of the generated code in jit_ker_.
// NOTE(review): the ctor body lines between the init-list and the
// getCode() call (presumably the code-generation call) are missing from
// this extraction -- restore from the original file before building.
40 _jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp,
41 const primitive_attr_t &attr) : jcp(ajcp), attr_(attr)
// Cast the emitted buffer to a callable taking the per-invocation
// argument struct; this is the pointer the driver code calls.
44 jit_ker_ = (void (*)(jit_conv_call_s *))getCode();
// Dtor: releases the heap-allocated post-op injector helpers owned by
// this kernel (raw pointers stored in the nstl::vector members).
// NOTE(review): the `delete inj;` loop bodies sit on lines elided from
// this extraction (original lines 49 and 53) -- confirm against the
// original file.
47 ~_jit_avx512_core_x8s8s32x_fwd_kernel() {
48 for (auto inj : eltwise_injectors)
50 eltwise_injectors.clear();
52 for (auto inj : depthwise_injectors)
54 depthwise_injectors.clear();
// Primitive attributes (post-ops, output scales) consulted while
// emitting the kernel.
58 const primitive_attr_t &attr_;
// Entry point of the generated machine code; set once in the ctor.
59 void (*jit_ker_)(jit_conv_call_s *);
// Owned post-op code generators (deleted in the dtor).
62 nstl::vector<jit_uni_eltwise_injector_f32<avx512_common>*> eltwise_injectors;
63 nstl::vector<jit_uni_depthwise_injector_f32<avx512_common>*> depthwise_injectors;
// typesize: accumulators are converted/stored as 32-bit values.
// ker_reg_base_idx / ker_dw_reg_base_idx: first vector-register index
// reserved for non-accumulator use; vmm_out()/zmm_out() assert that
// accumulator indices stay below this bound (28 normally, 30 for
// depthwise, where fewer scratch registers are needed).
66 typesize = sizeof(float),
67 ker_reg_base_idx = 28,
68 ker_dw_reg_base_idx = 30,
// General-purpose register assignments. Several registers are reused
// (aliased) for mutually-exclusive phases of the kernel -- the aliases
// are noted below; aliased pairs must never be live at the same time.
77 const Xbyak::Reg64 reg_ptr_scales = rax;
78 const Xbyak::Reg64 reg_inp = r8;
79 const Xbyak::Reg64 reg_ker = r9;
80 const Xbyak::Reg64 reg_out = r10;
81 const Xbyak::Reg64 aux_reg_inp = r11;
// Aliases aux_reg_inp (both r11).
82 const Xbyak::Reg64 reg_ptr_sum_scale = r11;
83 const Xbyak::Reg64 aux_reg_ker = r12;
84 const Xbyak::Reg64 reg_compensation = r14;
86 const Xbyak::Reg64 reg_bias_alpha = abi_not_param1;
87 const Xbyak::Reg64 reg_oi = rbx;
88 const Xbyak::Reg64 reg_bias = rdx;
89 const Xbyak::Reg64 reg_oc_blocks = rsi;
// Aliases aux_reg_ker (r12).
90 const Xbyak::Reg64 reg_owb = aux_reg_ker;
// Aliases reg_compensation (r14).
91 const Xbyak::Reg64 reg_scratch = reg_compensation;
// reg_kj and reg_overflow both alias reg_ptr_scales (rax).
92 const Xbyak::Reg64 reg_kj = reg_ptr_scales;
93 const Xbyak::Reg64 reg_overflow = reg_ptr_scales;
// Aliases reg_bias (rdx).
94 const Xbyak::Reg64 reg_icb = reg_bias;
// Scratch pointers for the depthwise post-op injector data.
96 const Xbyak::Reg64 reg_d_weights = r15;
97 const Xbyak::Reg64 reg_d_bias = r13;
// Opmask registers: ktail_mask for the channel-tail partial loads and
// stores, kblend_mask for blend operations.
99 const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
100 const Xbyak::Opmask kblend_mask = Xbyak::Opmask(3);
// Vmm(31) is deliberately multiplexed: vmm_wei, vmm_bias, vmm_prev_dst
// and vmm_zero all name the same physical register, each used during a
// different, non-overlapping phase (per the phase comments below).
102 const Vmm vmm_wei = Vmm(31);
103 /* used during bias section of store_output */
104 const Vmm vmm_comp = Vmm(30); // only for signed input
105 const Vmm vmm_bias = Vmm(31);
106 /* used during post_op sum section of store_output */
107 const Vmm vmm_prev_dst = Vmm(31);
108 /* used during write-out section of store_output */
109 const Vmm vmm_zero = Vmm(31);
111 /* used in compute_ker (but set during prepare_output) */
112 const Vmm vmm_shift = vmm_comp; // only for signed input
113 /* used in compute_ker (but only for pre-VNNI machines) */
114 const Vmm vmm_tmp = Vmm(28); // not used for depthwise
115 const Vmm vmm_one = Vmm(29); // set at start of kernel, not used for depthwise.
117 /* registers use only for depthwise
118 groups are always blocked by 16(padded if needed),
119 hence use only Zmm registers */
120 const Xbyak::Zmm zmm_wei = Xbyak::Zmm(31);
// Depthwise-only scratch: the zero-point-shifted zero vector and the
// input permutation pattern. Non-const: presumably assigned during code
// generation -- the assigning lines are not visible here.
123 Xbyak::Zmm zmm_shifted_zero;
124 Xbyak::Zmm zmm_permute;
// Accumulator register for output row i_ur / output-channel block i_oc.
// Accumulators are laid out as idx = i_ur + i_oc * ur_w and must stay
// below the register index reserved for scratch (see ker_reg_base_idx).
// NOTE(review): the `return Vmm(idx);` line is elided in this extraction.
126 Vmm vmm_out(int i_ur, int i_oc) {
127 int idx = i_ur + i_oc * jcp.ur_w;
128 assert(idx < (jcp.is_depthwise
129 ? ker_dw_reg_base_idx : ker_reg_base_idx));
// Zmm counterpart of vmm_out(): same index computation and bound check,
// but always returns a full 512-bit register (used by the depthwise
// path, which always works on full Zmm registers).
132 Xbyak::Zmm zmm_out(int i_ur, int i_oc) {
133 int idx = i_ur + i_oc * jcp.ur_w;
134 assert(idx < (jcp.is_depthwise
135 ? ker_dw_reg_base_idx : ker_reg_base_idx));
136 return Xbyak::Zmm(idx);
// Input broadcast register for input channel i_ic: placed just above
// the accumulator block (idx = i_ic + nb_x_blocking * ur_w).
// NOTE(review): the assert and `return Vmm(idx);` lines are elided here;
// compare with zmm_inp() below for the expected shape.
138 Vmm vmm_inp(int i_ic, int nb_x_blocking) {
139 int idx = i_ic + nb_x_blocking * jcp.ur_w;
// Zmm counterpart of vmm_inp(): same index computation, full-width
// register. NOTE(review): the bound-check assert line (original 145)
// is elided in this extraction.
143 Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) {
144 int idx = i_ic + nb_x_blocking * jcp.ur_w;
146 return Xbyak::Zmm(idx);
// Register holding the bias scaling factor: the first vector register
// after the accumulator block (index = nb_c_block * ur_w). Depthwise
// kernels block by channel groups, others by output channels.
148 Vmm vmm_bias_alpha() {
149 int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
150 return Vmm(nb_c_block * jcp.ur_w);
// Xmm view of the same physical register as vmm_bias_alpha() -- used
// where only a scalar/128-bit view of the bias alpha is needed.
152 Xbyak::Xmm xmm_bias_alpha() {
153 int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
154 return Xbyak::Xmm(nb_c_block * jcp.ur_w);
// First output column whose input, for kernel tap `ki`, is not in the
// left padding: ceil-divides the remaining left pad (after the tap's
// dilated offset) by the stride, clamped at 0.
// NOTE(review): the `return nstl::max(0,` line (original 157) is elided
// from this extraction.
156 int get_ow_start(int ki, int pad_l) {
158 utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
// One past the last output column whose input, for kernel tap `ki`, is
// not in the right padding: ur_w minus the ceil-divided overhang.
// NOTE(review): interior expression lines (original 162, 164-165) are
// elided from this extraction -- the visible fragment subtracts the
// tap's dilated offset and divides by the stride, mirroring
// get_ow_start() above.
160 int get_ow_end(int ur_w, int ki, int pad_r) {
161 return ur_w - nstl::max(0, utils::div_up(pad_r
163 * (jcp.dilate_w + 1),
// Code-generation stages (bodies live in the .cpp):
//   prepare_output  -- zero/initialize the accumulator registers.
//   store_output    -- apply bias/scales/post-ops and write results;
//                      last_oc_block_flag selects tail-masked stores.
//   compute_ker     -- emit the multiply-accumulate inner loop for one
//                      ur_w strip; h_padded handles vertically padded rows.
//   kh_loop         -- loop over kernel height around compute_ker.
// NOTE(review): two declaration header lines (originals 169 and 174,
// the names preceding the parameter lists on 170/175) are elided here.
167 void prepare_output(int ur_w);
168 void store_output(int ur_w, bool last_oc_block_flag);
170 int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded);
171 void compute_ker(int ur_w, int pad_l, int pad_r,
172 ic_block_t last_ic_block_flag, bool h_padded = false);
173 void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag);
175 int ur_w, int pad_l, int pad_r, bool is_last_spatial_block);
// cvt2ps: load `op` and convert from type_in to f32 in the given vmm.
177 void cvt2ps(data_type_t type_in, Vmm ymm_in, const Xbyak::Operand &op,
// vmm_mask: returns the (optionally k-masked) form of vmm_in for
// tail-safe loads/stores.
179 const Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store = false);
// Dispatcher wrapper: owns exactly one of the Zmm/Ymm/Xmm template
// instantiations, chosen by the channel block size, and exposes its
// jit_ker_ entry point uniformly.
182 struct jit_avx512_core_x8s8s32x_fwd_kernel {
184 jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp,
185 const primitive_attr_t &attr) :
187 zmm_kernel_(nullptr),
188 ymm_kernel_(nullptr),
189 xmm_kernel_(nullptr) {
// Depthwise kernels block by group (ch_block), others by input channel.
190 int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block;
// NOTE(review): the `switch`/`if` lines selecting on ch_block (and the
// `new` argument lines) are elided in this extraction; presumably the
// widest block maps to Zmm and successively narrower blocks to
// Ymm/Xmm -- confirm against the original file.
194 new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm>(
196 jit_ker = zmm_kernel_->jit_ker_;
200 new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Ymm>(
202 jit_ker = ymm_kernel_->jit_ker_;
206 new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Xmm>(
208 jit_ker = xmm_kernel_->jit_ker_;
// Any other channel block size is a programming error.
211 assert(!"invalid channel blocking");
// Dtor: NOTE(review): the body (deleting whichever kernel_ member was
// allocated) is on lines elided from this extraction.
215 ~jit_avx512_core_x8s8s32x_fwd_kernel() {
// post_ops_ok: validates that the attribute post-op chain is supported
// by this kernel (may also record info into jcp, which is non-const).
221 static bool post_ops_ok(jit_conv_conf_t &jcp,
222 const primitive_attr_t &attr);
// init_conf: fills jcp from the conv descriptor and memory descriptors;
// may set layouts on the any-format pds. Returns a status code.
224 static status_t init_conf(jit_conv_conf_t &jcp,
225 const convolution_desc_t &cd,
226 cpu_memory_t::pd_t &src_pd,
227 cpu_memory_t::pd_t &weights_pd,
228 cpu_memory_t::pd_t &dst_pd,
229 cpu_memory_t::pd_t &bias_pd,
230 const primitive_attr_t &attr,
// NOTE(review): the trailing init_conf parameter line (original 231)
// is elided in this extraction.
232 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
233 const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
// Uniform entry point copied from the owned instantiation's jit_ker_.
235 void (*jit_ker)(jit_conv_call_s *);
// Exactly one of these is non-null after construction (owning raw
// pointers, released in the dtor).
236 _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm> *zmm_kernel_;
237 _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Ymm> *ymm_kernel_;
238 _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Xmm> *xmm_kernel_;