1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef JIT_UNI_BIN_CONV_KERNEL_HPP
18 #define JIT_UNI_BIN_CONV_KERNEL_HPP
20 #include "c_types_map.hpp"
21 #include "jit_generator.hpp"
22 #include "jit_primitive_conf.hpp"
23 #include "cpu_memory.hpp"
24 #include "jit_uni_eltwise.hpp"
25 #include "jit_uni_depthwise.hpp"
31 template <cpu_isa_t isa>
32 struct jit_uni_bin_conv_fwd_kernel: public jit_generator {
// Constructs the kernel for one binary-convolution configuration.
//   ajcp         - binary convolution configuration; copied into jcp.
//   ajcp_dw_conv - secondary convolution configuration (presumably a fused
//                  depthwise convolution, given the name); copied into
//                  jcp_dw_conv.
//   attr         - primitive attributes; bound to the reference member
//                  attr_, NOT copied.
// jit_ker is obtained by casting this->getCode() to a function pointer that
// takes a jit_conv_call_s *.
// NOTE(review): because attr_ is a reference, the caller's primitive_attr_t
// must outlive every later read of attr_ - confirm at the construction site.
33 jit_uni_bin_conv_fwd_kernel(jit_bin_conv_conf_t ajcp, jit_conv_conf_t ajcp_dw_conv,
34 const primitive_attr_t &attr): jcp(ajcp), jcp_dw_conv(ajcp_dw_conv), attr_(attr)
37 jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
// Destructor: iterates over eltwise_injectors and then calls clear() on it,
// and does the same for depthwise_injectors. Both vectors hold raw injector
// pointers.
// NOTE(review): the per-element loop bodies are not visible here. Presumably
// each injector is deleted, meaning this kernel owns them - confirm. If so,
// storing owning smart pointers would express that ownership in the type.
// NOTE(review): no copy operations are declared in the visible part of this
// struct. A copy would double-delete the injectors unless the jit_generator
// base already forbids copying - confirm.
40 ~jit_uni_bin_conv_fwd_kernel() {
41 for (auto inj : eltwise_injectors)
43 eltwise_injectors.clear();
// Same treatment for the depthwise injectors.
45 for (auto inj : depthwise_injectors)
47 depthwise_injectors.clear();
// Project macro that declares this kernel's jit_generator auxiliary
// functions (its expansion is not visible here).
50 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_bin_conv_fwd_kernel)
// Static configuration helpers, callable without a kernel instance.
// Their definitions are not visible here.
//
// post_ops_ok: returns whether the post-ops described by attr are acceptable
//   for jcp. jcp is taken by non-const reference, so it may also be updated.
52 static bool post_ops_ok(jit_bin_conv_conf_t &jcp, const primitive_attr_t &attr);
// init_conf: presumably fills jcp from the binary convolution descriptor cd,
//   the src/weights/dst memory descriptors and attr. Returns a status_t
//   reporting whether the configuration can be handled.
53 static status_t init_conf(jit_bin_conv_conf_t &jcp,
54 const binary_convolution_desc_t &cd, const memory_desc_wrapper &src_d,
55 const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
// init_scratchpad: presumably books, on the given registrar, the scratchpad
//   buffers needed for jcp and jcp_dw_conv - confirm in the definition.
56 static void init_scratchpad(
57 memory_tracking::registrar_t &scratchpad, const jit_bin_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw_conv);
// Copy of the binary convolution configuration passed to the constructor.
59 jit_bin_conv_conf_t jcp;
// Copy of the secondary convolution configuration passed to the constructor
// (presumably a fused depthwise convolution, given the name).
60 jit_conv_conf_t jcp_dw_conv;
// Non-owning reference to the caller's primitive attributes. The referenced
// object must stay alive while attr_ is used.
61 const primitive_attr_t &attr_;
// Entry point of the generated code, assigned from getCode() in the
// constructor. It receives a jit_conv_call_s * (presumably the per-call
// arguments).
62 void (*jit_ker)(jit_conv_call_s *);
// Vector register type for the target ISA: Xmm for sse42, Ymm for avx2, and
// Zmm for any other isa value.
65 using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
66 isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
// Const-qualified shorthands for the Xbyak register classes used below.
// Ymm here names the const-qualified Xbyak::Ymm within this struct.
67 using Ymm = const Xbyak::Ymm;
68 using reg64_t = const Xbyak::Reg64;
69 using reg32_t = const Xbyak::Reg32;
70 using reg8_t = const Xbyak::Reg8;
// General-purpose register assignment for the generated kernel.
// Several names below are aliases of the same physical register. Code using
// them must keep their live ranges disjoint. Presumably the kernel
// implementation ensures this; confirm before reusing any of them in new code.
// These are default member initializers, so declaration order matters:
// an alias must be declared after the register it copies.
72 reg64_t reg_input = r13;
73 reg64_t reg_output = rbp;
74 reg64_t reg_input_base = rax;
75 reg64_t aux_reg_input = r8;
76 reg64_t reg_kernel_base = rdx;
77 reg64_t aux_reg_kernel = r9;
78 reg64_t reg_output_base = rsi;
// aux1_reg_input shares rax with reg_input_base.
// aux1_reg_kernel shares rsi with reg_output_base.
79 reg64_t aux1_reg_input = reg_input_base;
80 reg64_t aux1_reg_kernel = reg_output_base;
83 reg64_t oi_iter = r11;
// reg_kh uses abi_not_param1. Per its name, this is a register other than the
// first ABI argument register. reg_overflow is a second name for it.
84 reg64_t reg_kh = abi_not_param1;
85 reg64_t reg_overflow = reg_kh;
// r14 is named both reg_oc_work and reg_icb_iter (below).
86 reg64_t reg_oc_work = r14;
// Presumably holds the address of the constant table at l_table - confirm.
87 reg64_t reg_table = r15;
88 reg64_t reg_icb_iter = reg_oc_work;
// r12 viewed at 32-, 64- and 8-bit widths.
90 reg32_t reg_tmp_32 = r12d;
91 reg64_t reg_tmp_64 = r12;
92 reg8_t reg_tmp_8 = r12b;
// reg_d_weights is r8 (aux_reg_input) and reg_d_bias is r9 (aux_reg_kernel).
// The d_ prefix presumably stands for depthwise post-op data.
94 reg64_t reg_d_weights = aux_reg_input;
95 reg64_t reg_d_bias = aux_reg_kernel;
// reg_oc_off aliases kj, which is declared outside this view. kj must be
// declared before this line so that its initializer has already run.
96 reg64_t reg_oc_off = kj;
// 64- and 32-bit views of the same register as reg_oc_off.
97 reg64_t reg_tmp2_64 = reg_oc_off;
98 reg32_t reg_tmp2_32 = reg_oc_off.cvt32();
// Further reuse of the same registers: reg_b_weights and reg_shift are r8
// (aux_reg_input), and reg_b_mask is r9 (aux_reg_kernel).
100 reg64_t reg_b_weights = aux_reg_input;
101 reg64_t reg_b_mask = aux_reg_kernel;
103 reg64_t reg_shift = aux_reg_input;
// Vector register assignment (indices into the ISA's vector register file).
// Some names share a physical register:
//   10: vmm_sum, ymm_tmp (always a Ymm, regardless of isa), vmm_tmp, vmm_thr
//   14: vmm_scale, vmm_one_u8
//   15: vmm_shift, vmm_one_s16
// Sharing is only safe if the live ranges of these names never overlap in the
// generated code. Presumably the implementation ensures this; confirm when
// changing register usage. Indices 1-8 are not named in this block.
105 Vmm vmm_scale = Vmm(14);
106 Vmm vmm_shift = Vmm(15);
107 Vmm vmm_sum = Vmm(10);
108 Vmm vmm_lookup = Vmm(12);
109 Vmm vmm_mask = Vmm(13);
110 Vmm vmm_one_u8 = Vmm(14);
111 Vmm vmm_one_s16 = Vmm(15);
112 Ymm ymm_tmp = Ymm(10);
113 Vmm vmm_tmp = Vmm(10);
114 Vmm vmm_tmp1 = Vmm(11);
115 Vmm vmm_src = Vmm(0);
116 Vmm vmm_tmp2 = Vmm(9);
117 Vmm vmm_thr = Vmm(10);
// Label marking the kernel's constant table. Presumably prepare_table() emits
// the table and reg_table addresses it - confirm in the implementation.
119 Xbyak::Label l_table;
// Injectors for eltwise and depthwise post-ops (presumably one per post-op).
// They are held as raw pointers and processed by the destructor.
121 nstl::vector<jit_uni_eltwise_injector_f32<isa>*> eltwise_injectors;
122 nstl::vector<jit_uni_depthwise_injector_f32<isa>*> depthwise_injectors;
// Code-generation helpers. Their definitions are not visible here.
// NOTE(review): these are declared inline, so each definition must be visible
// in every translation unit that calls it (presumably only this kernel's own
// .cpp).
//   cvt2ps    - presumably loads op into vmm_in and converts from type_in to
//               f32. scalar_load selects a single-element load.
//   store_dst - presumably stores vmm_dst to op. scalar_store selects a
//               single-element store.
//   apply_filter / oh_step_unroll_kw / kh_loop / width_blk_step / solve_common
//             - presumably the nested levels of the convolution loop. They
//               are parameterized by unroll width ur_w, left/right padding
//               pad_l/pad_r, and output-channel blocking oc_blocks/oc_step.
//   prepare_table - presumably emits the constant table bound to l_table.
124 inline void cvt2ps(data_type_t type_in, Vmm vmm_in, const Xbyak::Operand &op, bool scalar_load);
125 inline void store_dst(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store);
126 inline void apply_filter(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step, int ic_blocks, bool last_icb, bool h_padded);
127 inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step, bool h_padded);
128 inline void kh_loop(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step);
129 inline void width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step);
130 inline void solve_common(int oc_blocks, int oc_step);
131 inline void prepare_table();