1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef JIT_SSE42_CONV_KERNEL_F32_HPP
18 #define JIT_SSE42_CONV_KERNEL_F32_HPP
20 #include "c_types_map.hpp"
21 #include "jit_generator.hpp"
22 #include "jit_primitive_conf.hpp"
23 #include "jit_uni_eltwise.hpp"
29 struct jit_sse42_conv_fwd_kernel_f32: public jit_generator {
30 jit_sse42_conv_fwd_kernel_f32(jit_conv_conf_t ajcp,
31 const primitive_attr_t &attr): jcp(ajcp), attr_(attr)
34 jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
// Static predicate declared on the kernel; only the declaration is visible here.
// Returns bool. NOTE(review): by its name this presumably reports whether the
// post-ops in `attr` are supported by this kernel — confirm in the definition.
// `jcp` is taken by non-const reference, so the definition is free to update it.
37 static bool post_ops_ok(jit_conv_conf_t &jcp,
38 const primitive_attr_t &attr);
// Static configuration entry point; returns a status_t.
// Parameters:
//   jcp         - non-const reference; presumably the kernel configuration
//                 filled in by this call (TODO confirm against the definition).
//   cd          - the convolution descriptor.
//   src_d, weights_d, dst_d - memory descriptor wrappers for the source,
//                 weights and destination tensors.
//   attr        - primitive attributes.
//   with_relu   - defaults to false.
//   relu_negative_slope - defaults to 0.
// NOTE(review): relu_negative_slope is a float but its default is the double
// literal `0.`; harmless, though `0.f` would avoid the implicit conversion.
40 static status_t init_conf(jit_conv_conf_t &jcp,
41 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
42 const memory_desc_wrapper &weights_d,
43 const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
44 bool with_relu = false, float relu_negative_slope = 0.);
// JIT boilerplate macro; its definition is not in this file (presumably it
// comes from jit_generator.hpp — confirm).
46 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_conv_fwd_kernel_f32)
// Held by const reference. The constructor binds it to its `attr` argument, so
// the caller's primitive_attr_t must outlive this kernel object.
48 const primitive_attr_t &attr_;
// Entry point of the generated code. The constructor sets it by casting
// `this->getCode()` to this signature. Its argument is a jit_conv_call_s
// pointer (presumably the per-invocation arguments — confirm).
49 void (*jit_ker)(jit_conv_call_s *);
// Fixed register assignment used by the code generator. reg64_t is a
// `const Xbyak::Reg64`, so each member below is an immutable alias for one
// physical register. Declaration order matters: members are initialized in
// the order declared, and imm_addr64 below is initialized from reg_oc_blocks.
52 using reg64_t = const Xbyak::Reg64;
53 reg64_t reg_input = rax;
54 reg64_t aux_reg_input = r8;
55 reg64_t reg_kernel = rdx;
56 reg64_t aux_reg_kernel = r9;
57 reg64_t reg_output = rsi;
58 reg64_t reg_bias = rbx;
61 reg64_t oi_iter = r11;
62 reg64_t ki_iter = r12;
// abi_not_param1 is not defined in this file (presumably in jit_generator.hpp).
// By its name it is a register that does not hold the first ABI argument — confirm.
63 reg64_t reg_kh = abi_not_param1;
64 reg64_t simd_iter = r15;
65 reg64_t reg_oc_blocks = r14;
// NOTE: imm_addr64 and reg_oc_blocks name the same physical register (r14).
// Presumably their live ranges never overlap in the generated code — verify
// in the .cpp before reusing either one.
66 reg64_t imm_addr64 = reg_oc_blocks;
// 32-bit register view (r13d), unlike the 64-bit aliases above.
67 Xbyak::Reg32 reg_ci_flag = r13d;
// Eltwise vector code generator instantiated for the sse42 ISA (template
// declared in jit_uni_eltwise.hpp). NOTE(review): presumably used to emit a
// fused activation such as ReLU — confirm in the .cpp.
70 jit_uni_eltwise_vector_f32<sse42> eltwise_generator;
72 inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r,
// Code-generation helper member functions; only their declarations are
// visible here. Each takes `int` blocking/padding parameters and `char` label
// parameters.
// NOTE(review): the `char` *_label arguments presumably make emitted assembler
// labels unique per call site — confirm in the .cpp.
74 inline void oh_step_nopad(int ur_w, int pad_l, int pad_r,
75 char pad_label, int oc_blocks, char oc_blocks_label);
76 inline void width_blk_step(int ur_w, int pad_l, int pad_r,
77 char pad_label, int oc_blocks, char oc_blocks_label);
78 inline void solve_common(int oc_blocks, char oc_blocks_label);