1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef JIT_AVX2_1x1_CONV_KERNEL_F32_HPP
18 #define JIT_AVX2_1x1_CONV_KERNEL_F32_HPP
20 #include "c_types_map.hpp"
21 #include "jit_generator.hpp"
22 #include "jit_primitive_conf.hpp"
23 #include "jit_uni_eltwise.hpp"
29 struct jit_avx2_1x1_conv_kernel_f32: public jit_generator {
// Project macro invoked with this struct's own name as its argument.
// NOTE(review): the expansion is not defined in this header — presumably it
// injects auxiliary members (e.g. a kernel name) used by jit_generator;
// confirm against the macro definition.
30 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_1x1_conv_kernel_f32)
// Constructor.
//   ajcp - convolution configuration, taken by value and copied into `jcp`.
//   attr - primitive attributes; `attr_` is bound to this object as a
//          reference, no copy is made.
// After member initialization, the pointer returned by getCode() is cast to
// the kernel entry-point signature `void (*)(jit_1x1_conv_call_s *)` and
// cached in `jit_ker`.
// NOTE(review): because `attr_` is a reference member, `attr` presumably has
// to outlive this kernel object — confirm against the callers.
32 jit_avx2_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp,
33 const primitive_attr_t &attr): jcp(ajcp), attr_(attr)
36 jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode();
// Static predicate taking a convolution configuration and primitive
// attributes; only the declaration appears here.
// `jcp` is passed by non-const reference, so the definition is able to
// modify it.
// NOTE(review): judging by the name, this presumably reports whether the
// post-ops carried by `attr` are supported by this kernel — confirm in the
// definition.
39 static bool post_ops_ok(jit_1x1_conv_conf_t &jcp,
40 const primitive_attr_t &attr);
// Static configuration routine (declaration only); returns a status_t.
//   jcp                 - configuration object, passed by non-const
//                         reference (NOTE(review): presumably filled in as
//                         the output — confirm in the definition).
//   cd                  - convolution descriptor.
//   src_d, weights_d,
//   dst_d               - memory descriptor wrappers for the source,
//                         weights and destination, per their names.
//   attr                - primitive attributes.
//   with_relu           - NOTE(review): presumably enables a fused ReLU —
//                         confirm in the definition.
//   relu_negative_slope - NOTE(review): presumably the slope applied to
//                         negative inputs when with_relu is set — confirm.
42 static status_t init_conf(jit_1x1_conv_conf_t &jcp,
43 const convolution_desc_t &cd,
44 const memory_desc_wrapper &src_d,
45 const memory_desc_wrapper &weights_d,
46 const memory_desc_wrapper &dst_d,
47 const primitive_attr_t &attr,
48 bool with_relu, float relu_negative_slope);
// Convenience overload without the two ReLU arguments.
// Forwards every argument unchanged to the eight-argument init_conf,
// passing `false` for with_relu and `0.0` for relu_negative_slope, and
// returns that call's status_t.
// Note: `0.0` is a double literal converted to the float parameter.
50 static status_t init_conf(jit_1x1_conv_conf_t &jcp,
51 const convolution_desc_t &cd,
52 const memory_desc_wrapper &src_d,
53 const memory_desc_wrapper &weights_d,
54 const memory_desc_wrapper &dst_d,
55 const primitive_attr_t &attr)
57 return init_conf(jcp, cd, src_d, weights_d, dst_d, attr, false, 0.0);
// Convolution configuration; an owned copy of the constructor's `ajcp`.
60 jit_1x1_conv_conf_t jcp;
// Primitive attributes, held by const reference (bound in the constructor,
// not copied).
61 const primitive_attr_t &attr_;
// Kernel entry point: a function pointer taking a jit_1x1_conv_call_s *.
// The constructor sets it from getCode().
62 void (*jit_ker)(jit_1x1_conv_call_s *);
// Shorthand aliases for Xbyak register types. Both are const-qualified, so
// every member declared with them below is a const object.
65 using reg64_t = const Xbyak::Reg64;
66 using ymm_t = const Xbyak::Ymm;
// Fixed general-purpose register assignments, one named member per role.
// NOTE(review): the roles are read off the member names only — confirm each
// against the code generator that uses them.
// Several names deliberately map to the same physical register (see the
// notes on r12, r14 and r15 below); NOTE(review): the sharing names are
// presumably never live at the same time — confirm.
68 reg64_t reg_bcast_data = rax;
69 reg64_t reg_load_data = rsi;
70 reg64_t reg_output_data = rbx;
71 reg64_t aux_reg_bcast_data = rdx;
// `abi_not_param1` and `abi_param1` are not defined in this struct.
// NOTE(review): presumably ABI-dependent registers supplied by
// jit_generator — confirm.
72 reg64_t aux1_reg_bcast_data = abi_not_param1;
73 reg64_t aux_reg_load_data = abi_param1;
74 reg64_t aux_reg_output_data = rbp;
75 reg64_t reg_load_loop_work = r9;
76 reg64_t reg_bcast_loop_work = r10;
77 reg64_t reg_reduce_loop_work = r11;
78 reg64_t load_loop_iter = r13;
79 reg64_t bcast_loop_iter = r14;
80 reg64_t reduce_loop_iter = r15;
// Same register as reduce_loop_iter (r15).
81 reg64_t imm_addr64 = reduce_loop_iter;
82 reg64_t reg_reduce_pos_flag = r8;
// reg_output_stride and reg_bias_data are both r12.
83 reg64_t reg_output_stride = r12;
84 reg64_t reg_bias_data = r12;
// Same register as bcast_loop_iter (r14).
85 reg64_t reg_diff_bias_data = bcast_loop_iter;
// Stack bookkeeping constants with in-class default values.
// NOTE(review): presumably a byte offset of the saved reg_diff_bias_data
// value within a stack area of `stack_space_needed` bytes (8 = one 64-bit
// register) — confirm in the code generator.
87 int reg_diff_bias_data_stack_offt = 0;
88 int stack_space_needed = 8;
// Ymm register constructed with index 15.
// NOTE(review): judging by the name, presumably holds the broadcast value
// in the compute loop — confirm in the code generator.
90 ymm_t vreg_bcast = ymm_t(15);
// Helper object of type jit_uni_eltwise_vector_f32 instantiated for avx2.
// NOTE(review): presumably emits the element-wise post-op code for this
// kernel — confirm where it is used.
93 jit_uni_eltwise_vector_f32<avx2> eltwise_generator;
// Declaration only; takes a load-loop block count and a one-character tag.
// NOTE(review): presumably emits the loop over the broadcast dimension,
// with the tag used to keep generated labels unique — confirm in the
// definition.
95 void bcast_loop(int load_loop_blk, char load_loop_tag);
// Declaration of the reduce-loop helper. Parameters visible on this line: a
// load-loop block count, `ur`, and a one-character load-loop tag.
// NOTE(review): `ur` is presumably an unroll factor, and the routine
// presumably emits the reduction-dimension loop — confirm in the definition.
96 void reduce_loop(int load_loop_blk, int ur, char load_loop_tag,
// Declaration only; takes a load-loop block count and a one-character tag.
// NOTE(review): presumably emits the bias-gradient accumulation code that
// uses reg_diff_bias_data — confirm in the definition.
98 void diff_bias_loop(int load_loop_blk, char load_loop_tag);