1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP
18 #define JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP
20 #include "c_types_map.hpp"
21 #include "cpu_memory.hpp"
23 #include "jit_generator.hpp"
24 #include "jit_primitive_conf.hpp"
// Compile-time Winograd tiling parameters declared in this header.
30 // alpha determines the output tile size (tile_size below).
// NOTE(review): the values satisfy alpha == tile_size + 3 - 1, which would
// correspond to a 3x3 kernel -- confirm against the implementation.
31 constexpr int alpha = 6;
32 constexpr int tile_size = 4;
33 // SIMD length used for vectorization (16 lanes; sixteen 32-bit values
// fill one 512-bit vector).
34 constexpr int simd_w = 16;
// JIT code generator that the forward and backward-data Winograd kernel
// structs in this header derive from. It holds the configuration `jcp` and
// exposes two entry points into the generated code: one GEMM loop kernel
// for the first iteration and one for subsequent iterations.
36 struct _jit_avx512_common_conv_winograd_data_kernel_f32 : public jit_generator {
// Constructor: emits the GEMM loop kernel(s) for the configuration `ajcp`
// and records their entry addresses in the function-pointer members.
37 _jit_avx512_common_conv_winograd_data_kernel_f32(
38 jit_conv_winograd_conf_t ajcp)
41 //******************* First iter kernel ********************//
// Generated with is_beta_zero == true. Its entry point is whatever
// getCode() returns.
42 this->gemm_loop_generate(true);
43 gemm_loop_ker_first_iter
44 = (decltype(gemm_loop_ker_first_iter)) this->getCode();
46 //************** Subsequent iterations kernel **************//
// Generated only when jcp.dimK_nb_block > 1.
// NOTE(review): gemm_loop_ker is not assigned on the other path in the
// code shown here; presumably callers never invoke it then -- confirm.
47 if (jcp.dimK_nb_block > 1) {
// Capture the current code position before generating, so that `addr`
// marks the entry of the second kernel.
49 const Xbyak::uint8 *addr = getCurr();
50 this->gemm_loop_generate(false);
51 gemm_loop_ker = (decltype(gemm_loop_ker))addr;
55 DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_data_kernel_f32)
// Configuration helper: receives `jcp` by mutable reference together with
// the convolution descriptor and the src/weights/dst memory descriptors,
// and returns a status_t.
57 static status_t init_conf_common(jit_conv_winograd_conf_t &jcp,
58 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
59 const memory_desc_wrapper &weights_d,
60 const memory_desc_wrapper &dst_d);
// Configuration helper taking the three GEMM dimensions (dimM, dimN, dimK);
// receives `jcp` by mutable reference and returns a status_t.
62 static status_t init_conf_kernel(
63 jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK);
// Configuration consulted by the constructor when generating the kernels.
65 jit_conv_winograd_conf_t jcp;
// Entry points into the generated code; both take three pointer arguments
// (one mutable float*, two const float*).
66 void (*gemm_loop_ker)(float *, const float *, const float *);
67 void (*gemm_loop_ker_first_iter)(float *, const float *, const float *);
70 using reg64_t = const Xbyak::Reg64;
// Size in bytes of one float element.
71 enum { typesize = sizeof(float) };
// Emits one GEMM loop kernel. The constructor calls it with true for the
// first-iteration kernel and with false for the subsequent-iterations one.
// NOTE(review): is_beta_zero presumably selects overwriting vs accumulating
// into the destination -- confirm in the definition.
73 void gemm_loop_generate(bool is_beta_zero);
75 /* registers used for GEMM */
// Bound to the first three ABI parameter registers, in the same order as
// the three pointer arguments of the kernel function-pointer types above.
76 reg64_t reg_dstC = abi_param1;
77 reg64_t reg_srcA = abi_param2;
78 reg64_t reg_srcB = abi_param3;
// Loop counters for the M and K block loops.
80 reg64_t reg_dimM_block_loop_cnt = r10;
81 reg64_t reg_dimK_block_loop_cnt = r11;
// Forward Winograd kernel. Derives from
// _jit_avx512_common_conv_winograd_data_kernel_f32 and inherits its
// constructors through the using-declaration below; the only additions are
// the two static helpers.
84 struct jit_avx512_common_conv_winograd_fwd_kernel_f32
85 : _jit_avx512_common_conv_winograd_data_kernel_f32 {
86 using _jit_avx512_common_conv_winograd_data_kernel_f32::
87 _jit_avx512_common_conv_winograd_data_kernel_f32;
// Returns a bool for the given configuration and primitive attributes.
// NOTE(review): presumably reports whether the post-ops in `attr` are
// supported by this kernel -- confirm in the definition.
89 static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr);
// Configuration entry point for the forward pass: receives `jcp` by mutable
// reference together with the convolution descriptor, the src/weights/dst
// memory descriptors and the primitive attributes; returns a status_t.
91 static status_t init_conf(jit_conv_winograd_conf_t &jcp,
92 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
93 const memory_desc_wrapper &weights_d,
94 const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
// Backward-data Winograd kernel. Derives publicly from
// _jit_avx512_common_conv_winograd_data_kernel_f32 and inherits its
// constructors through the using-declaration below.
97 struct jit_avx512_common_conv_winograd_bwd_data_kernel_f32
98 : public _jit_avx512_common_conv_winograd_data_kernel_f32 {
99 using _jit_avx512_common_conv_winograd_data_kernel_f32::
100 _jit_avx512_common_conv_winograd_data_kernel_f32;
// Configuration entry point for the backward-data pass: receives `jcp` by
// mutable reference together with the convolution descriptor and the
// diff_src/weights/diff_dst memory descriptors; returns a status_t.
102 static status_t init_conf(jit_conv_winograd_conf_t &jcp,
103 const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
104 const memory_desc_wrapper &weights_d,
105 const memory_desc_wrapper &diff_dst_d);
// Backward-weights Winograd kernel. Derives directly from jit_generator and
// generates up to three kernels at construction time: a first-iteration GEMM
// loop, a subsequent-iterations GEMM loop, and a transpose kernel.
108 struct jit_avx512_common_conv_winograd_bwd_weights_kernel_f32
109 : public jit_generator {
// NOTE(review): the name passed to the macro has a leading underscore that
// the struct name above does not have -- confirm this is intended.
110 DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_bwd_weights_kernel_f32)
// Constructor: emits the kernels for the configuration `ajcp` and records
// their entry addresses in the function-pointer members.
112 jit_avx512_common_conv_winograd_bwd_weights_kernel_f32(
113 jit_conv_winograd_conf_t ajcp)
117 //******************* First iter kernel ********************//
// Always generated (is_first_tile == true). The current code position is
// captured first so that `addr` marks the kernel's entry.
120 const Xbyak::uint8 *addr = getCurr();
121 this->gemm_loop_generate(true);
122 gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))addr;
// Subsequent-iterations kernel: generated only when jcp.tile_block > 1.
// NOTE(review): gemm_loop_ker is not assigned on the other path in the
// code shown here -- confirm callers never invoke it then.
125 if (jcp.tile_block > 1) {
127 const Xbyak::uint8 *addr = getCurr();
128 this->gemm_loop_generate(false);
129 gemm_loop_ker = (decltype(gemm_loop_ker))addr;
// Transpose kernel: generated only when jcp.ver == ver_4fma; otherwise
// transpose_4fma_ker is not assigned in the code shown here.
132 if (jcp.ver == ver_4fma) {
134 const Xbyak::uint8 *addr = getCurr();
135 this->transpose_ker_generate();
136 transpose_4fma_ker = (decltype(transpose_4fma_ker))addr;
// Configuration entry point for the backward-weights pass: receives `jcp` by
// mutable reference together with the convolution descriptor and the
// src/diff_dst/diff_weights memory descriptors; returns a status_t.
140 static status_t init_conf(jit_conv_winograd_conf_t &jcp,
141 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
142 const memory_desc_wrapper &diff_dst_d,
143 const memory_desc_wrapper &diff_weights_d);
// Configuration consulted by the constructor when generating the kernels.
145 jit_conv_winograd_conf_t jcp;
// Entry points into the generated code. The GEMM kernels take three pointer
// arguments; the transpose kernel takes two mutable float pointers.
146 void (*gemm_loop_ker)(float *, const float *, const float *);
147 void (*gemm_loop_ker_first_iter)(float *, const float *, const float *);
148 void (*transpose_4fma_ker)(float *, float *);
151 using reg64_t = const Xbyak::Reg64;
// Size in bytes of one float element.
152 enum { typesize = sizeof(float) };
// Emits one GEMM loop kernel; the constructor calls it with true for the
// first-iteration kernel and with false for the subsequent-iterations one.
154 void gemm_loop_generate(bool is_first_tile);
// Emits the transpose kernel (called only for jcp.ver == ver_4fma).
155 void transpose_ker_generate();
// The two registers below share abi_param2 / abi_param1 with
// reg_srcA_const / reg_dstC further down.
// NOTE(review): presumably these are the transpose kernel's two arguments,
// so the aliasing is across different kernels -- confirm.
157 reg64_t reg_origB = abi_param2;
158 reg64_t reg_transB = abi_param1;
// GEMM kernel arguments, bound to the first three ABI parameter registers.
160 reg64_t reg_dstC = abi_param1;
161 reg64_t reg_srcA_const = abi_param2;
162 reg64_t reg_srcB = abi_param3;
164 reg64_t reg_sp = rsp;
165 reg64_t reg_srcA = r9;
166 reg64_t reg_nb_ic = r10;
167 reg64_t reg_loop_cpt = r11;
// NOTE(review): r12 and r13 (used here and below) are callee-saved under the
// common x86-64 calling conventions; presumably the generated code saves
// and restores them -- confirm.
168 reg64_t reg_transB_idx = r13;
170 /* Registers used by new kernel */
// r10 and r11 are also bound to reg_nb_ic and reg_loop_cpt above, so those
// names refer to the same physical registers as the M and N counters here.
171 reg64_t reg_dimM_block_loop_cnt = r10;
172 reg64_t reg_dimK_block_loop_cnt = r12;
173 reg64_t reg_dimN_block_loop_cnt = r11;