1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP
18 #define JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
23 #include "jit_generator.hpp"
24 #include "jit_primitive_conf.hpp"
25 #include "jit_uni_eltwise.hpp"
26 #include "jit_uni_depthwise.hpp"
// Code generator for the 1x1 convolution kernel (AVX-512 "common" flavour,
// going by the type name). Derives from jit_generator and publishes the
// generated entry point through the jit_ker member.
32 struct jit_avx512_common_1x1_conv_kernel : public jit_generator {
// Constructor.
//   ajcp - kernel configuration; taken by value and copied into jcp.
//   attr - primitive attributes; attr_ is a reference member bound to this
//          argument, so the referenced object must stay alive for as long as
//          attr_ may be read.
33 jit_avx512_common_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp,
34 const primitive_attr_t &attr)
35 : jcp(ajcp), attr_(attr)
// Store the pointer returned by getCode() as a callable taking a
// jit_1x1_conv_call_s*. NOTE(review): getCode() is not declared in this file
// (presumably inherited via jit_generator); the code is expected to have been
// generated before this point - confirm.
38 jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode();
// Destructor: visits every entry of eltwise_injectors and then of
// depthwise_injectors, emptying each vector after its loop.
// NOTE(review): both vectors hold raw pointers, so each loop is expected to
// release the injector it visits - confirm against the loop bodies.
41 ~jit_avx512_common_1x1_conv_kernel() {
42 for (auto inj : eltwise_injectors)
44 eltwise_injectors.clear();
46 for (auto inj : depthwise_injectors)
48 depthwise_injectors.clear();
// Project macro (defined outside this file) invoked with this kernel's type
// name. NOTE(review): presumably injects the auxiliary name/source-file
// helpers that jit_generator expects from every kernel - confirm.
51 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_1x1_conv_kernel)
// Static predicate over a kernel configuration and primitive attributes.
// NOTE(review): judging by the name, it reports whether the post-ops carried
// by `attr` can be handled by this kernel - confirm in the definition.
// NOTE(review): jcp is taken by non-const reference; whether it is modified
// cannot be told from the declaration.
53 static bool post_ops_ok(jit_1x1_conv_conf_t &jcp,
54 const primitive_attr_t &attr);
// Static configuration routine; returns a status_t.
//   jcp        - configuration object, passed by non-const reference
//                (presumably populated by this call - confirm).
//   cd         - convolution descriptor.
//   src_d      - memory descriptor wrapper for the source.
//   weights_d  - memory descriptor wrapper for the weights.
//   dst_d      - memory descriptor wrapper for the destination.
//   attr       - primitive attributes.
//   nthreads   - NOTE(review): presumably the number of threads the work is
//                balanced across - confirm.
//   reduce_src - NOTE(review): meaning is not evident from the declaration -
//                confirm in the definition.
56 static status_t init_conf(jit_1x1_conv_conf_t &jcp,
57 const convolution_desc_t &cd,
58 const memory_desc_wrapper &src_d,
59 const memory_desc_wrapper &weights_d,
60 const memory_desc_wrapper &dst_d,
61 const primitive_attr_t &attr,
62 int nthreads, bool reduce_src);
// Static helper taking a scratchpad registrar and a read-only configuration.
// NOTE(review): presumably books the temporary buffers this kernel needs for
// the given jcp with the registrar - confirm in the definition.
64 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
65 const jit_1x1_conv_conf_t &jcp);
// Kernel configuration; a copy of the constructor's ajcp argument.
67 jit_1x1_conv_conf_t jcp;
// Reference (not a copy) to the attributes passed to the constructor; the
// referenced object is owned elsewhere.
68 const primitive_attr_t &attr_;
// Entry point of the generated code; assigned in the constructor from
// getCode(). Called with a pointer to a jit_1x1_conv_call_s.
69 void (*jit_ker)(jit_1x1_conv_call_s *);
// Shorthand for the Xbyak register types. Both aliases are const-qualified,
// so a member declared with them cannot be reassigned after initialization.
72 using reg64_t = const Xbyak::Reg64;
73 using zmm_t = const Xbyak::Zmm;
// Symbolic names for the general-purpose registers used by this kernel.
// Several names map to the same physical register (called out below); a write
// through one such name overwrites the value seen through the others.
75 reg64_t reg_bcast_data = r8;
76 reg64_t reg_load_data = r10;
77 reg64_t reg_output_data = r9;
78 reg64_t aux_reg_bcast_data = r14;
79 reg64_t aux1_reg_bcast_data = rbx;
80 reg64_t aux_reg_load_data = r15;
// Same physical register as aux_reg_load_data (r15).
81 reg64_t imm_addr64 = aux_reg_load_data;
// abi_not_param1 is defined outside this file.
82 reg64_t aux_reg_output_data = abi_not_param1;
83 reg64_t reg_load_loop_work = rsi;
84 reg64_t reg_reduce_loop_work = r11;
85 reg64_t bcast_loop_iter = rdx;
// r13 is shared by reduce_loop_iter, reg_output_stride, reg_relu_ns and
// reg_d_bias.
86 reg64_t reduce_loop_iter = r13;
87 reg64_t reg_reduce_pos_flag = rax;
88 reg64_t reg_output_stride = r13;
89 reg64_t reg_bias_data = r12;
90 reg64_t reg_relu_ns = r13;
// Same physical register as aux1_reg_bcast_data (rbx).
91 reg64_t reg_bcast_loop_work = aux1_reg_bcast_data;
// Vector register zmm31. NOTE(review): presumably holds the broadcast operand
// inside the reduce loop - confirm in the code generator.
93 Xbyak::Zmm vreg_bcast = Xbyak::Zmm(31);
// Additional register names. NOTE(review): the d_weights/d_bias/oc_off names
// suggest they serve the depthwise post-op path - confirm at the use sites.
// abi_param1 is defined outside this file.
95 reg64_t reg_oc_off = abi_param1;
// Alias of imm_addr64, which is itself aux_reg_load_data (r15).
96 reg64_t reg_d_weights = imm_addr64;
// r13, the same register as reduce_loop_iter, reg_output_stride and
// reg_relu_ns.
97 reg64_t reg_d_bias = r13;
// Raw-pointer lists of post-op injectors, one list per injector type
// (element-wise and depthwise, both instantiated for avx512_common).
// The destructor iterates over both lists and then clears them.
99 nstl::vector<jit_uni_eltwise_injector_f32<avx512_common>*> eltwise_injectors;
100 nstl::vector<jit_uni_depthwise_injector_f32<avx512_common>*> depthwise_injectors;
// Stack-layout constants with in-class default values.
// NOTE(review): presumably the offset of the saved bcast-loop work counter
// within the kernel's stack frame and the frame size (units look like bytes) -
// confirm where the generated code adjusts rsp.
102 int bcast_loop_work_offt = 0;
103 int stack_space_needed = 16;
// Code-emission helpers for the kernel's loop nest; definitions live outside
// this file.
// NOTE(review): load_loop_blk is presumably the load-dimension blocking
// factor, ur the unroll factor of the inner loop, substep the index of the
// current reduce sub-step, and wraparound whether this is the last sub-step -
// confirm in the definitions.
105 void bcast_loop(int load_loop_blk);
106 void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound);
// Static helper taking the configuration by non-const reference and a thread
// count. NOTE(review): presumably adjusts the work-partitioning fields of jcp
// for nthreads threads - confirm in the definition.
109 static void balance(jit_1x1_conv_conf_t &jcp, int nthreads);