/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "c_types_map.hpp"
#include "memory_tracking.hpp"
#include "mkldnn_thread.hpp"

#include "jit_uni_dw_convolution.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;

template <cpu_isa_t isa>
void _jit_uni_dw_convolution_fwd_t<isa>::execute_forward() const {
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
    auto dst = reinterpret_cast<data_t *>(this->memory());

    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper dst_d(pd()->dst_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
    const memory_desc_wrapper bias_d(pd()->weights_pd(1));

    const auto &jcp = kernel_->jcp;

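    /* If the channel dimension was padded up to the channel block, copy the
     * bias into a scratchpad buffer and zero-fill the padded tail so the
     * kernel can safely read a full block. */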
    if (pd()->wants_padded_bias()) {
        auto padded_bias = this->scratchpad().template get<data_t>(
                key_conv_padded_bias);
        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        bias = padded_bias;
    }

    int dil_h = jcp.dilate_h + 1;
    int dil_w = jcp.dilate_w + 1;
    int str_h = jcp.stride_h;
    int str_w = jcp.stride_w;

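    /* Builds the argument block for one JIT kernel call: clips the filter
     * against the left/right padding for the given output column and sets up
     * the matching src/dst/weights/bias pointers. */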
    auto kernel_params = [&](int ur_w_step, int ow, int oh, int ih, int kh,
            int kh_padding, int ch, int ch_num, int n) {
        auto par_conv = jit_conv_call_s();

        const int i_l_overflow = nstl::max(0, (jcp.l_pad - ow * str_w));
        const int i_r_overflow = nstl::max(jcp.iw, (ow * str_w
                + (jcp.kw - 1)*dil_w - jcp.l_pad + 1)) - jcp.iw;

        const int iw = nstl::max((ow*str_w - jcp.l_pad
                + div_up(i_l_overflow, dil_w)*dil_w), 0);
        const int kw = div_up(i_l_overflow, dil_w);

        const int kw_padding = jcp.kw - div_up(i_l_overflow, dil_w)
                - div_up(i_r_overflow, dil_w);

        par_conv.src = &src[src_d.blk_off(n, ch, ih, iw)];
        par_conv.dst = &dst[dst_d.blk_off(n, ch, oh, ow)];

        par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, kh, kw)];
        if (bias) par_conv.bias = &bias[bias_d.blk_off(ch*jcp.ch_block)];

        par_conv.kh_padding = (size_t)nstl::max(0, kh_padding);
        par_conv.kw_padding = (size_t)nstl::max(0, kw_padding);

        par_conv.ur_w = (size_t)ur_w_step;

        par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch;
        par_conv.oc_off = ch * jcp.ch_block * sizeof(float);

        return par_conv;
    };

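    /* The work is split across threads over minibatch images, groups of
     * jcp.nb_ch_blocking channel blocks and output rows. */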
    const int MB = pd()->MB();
    const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
    parallel_nd(MB, chb_work, jcp.oh,
            [&](int n, int chb, int oh) {
        int ch = chb * jcp.nb_ch_blocking;
        int ch_num = jcp.nb_ch_blocking;

        const int i_t_overflow = nstl::max(0, (int)(jcp.t_pad - oh*str_h));
        const int i_b_overflow = nstl::max(jcp.ih,
                (int)(oh*str_h + (jcp.kh - 1)*dil_h - jcp.t_pad + 1)) - jcp.ih;

        const int ih = nstl::max((int)(oh*str_h - jcp.t_pad
                + div_up(i_t_overflow, dil_h)*dil_h), 0);
        const int kh = div_up(i_t_overflow, dil_h);
        const int kh_padding = jcp.kh - div_up(i_t_overflow, dil_h)
                - div_up(i_b_overflow, dil_h);

        // Left border: columns that overlap the left padding are processed
        // one output column per kernel call.
        int ow = 0;
        int l_border = nstl::min(div_up(jcp.l_pad, str_w), jcp.ow);
        int ur_w_step = 1;
        for (; ow < l_border; ow++) {
            jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
                    kh, kh_padding, ch, ch_num, n);

            kernel_->jit_ker(&par_conv);
        }

        // Main region: one kernel call covers all remaining columns whose
        // filter window stays inside the image.
        ur_w_step = (jcp.iw - (jcp.kw - 1)*dil_w + jcp.l_pad - 1)
                / jcp.stride_w - ow + 1;
        if (ur_w_step > 0) {
            jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
                    kh, kh_padding, ch, ch_num, n);

            kernel_->jit_ker(&par_conv);

            ow += ur_w_step;
        }

        // Right border: columns that overlap the right padding, again one
        // output column per kernel call.
        ur_w_step = 1;
        for (; ow < jcp.ow; ow++) {
            jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
                    kh, kh_padding, ch, ch_num, n);

            kernel_->jit_ker(&par_conv);
        }
    });

    if (pd()->wants_zero_pad_dst())
        output_memory_primitive(0)->zero_pad();
}

template struct _jit_uni_dw_convolution_fwd_t<avx512_common>;
template struct _jit_uni_dw_convolution_fwd_t<avx2>;
template struct _jit_uni_dw_convolution_fwd_t<sse42>;

template <cpu_isa_t isa>
void _jit_uni_dw_convolution_bwd_data_t<isa>::execute_backward_data() const {
    auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto diff_src = reinterpret_cast<data_t *>(this->memory());

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));

    const auto &jcp = kernel_->jcp;

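    /* Builds the argument block for one JIT kernel call of the backward-data
     * pass: maps the current diff_src column to the matching diff_dst column
     * and filter offset for its stride phase, clipping against the borders. */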
    auto kernel_params = [&](int ur_str_w, int iw, int oh, int ih,
            int i_t_overflow, int i_b_overflow, int stride_off_h,
            int ch, int ch_num, int n) {
        auto par_conv = jit_conv_call_s();

        const int i_l_overflow = nstl::max(0, (jcp.kw - 1 - iw - jcp.l_pad));
        const int i_r_overflow = nstl::max(0, (jcp.kw - 1 - (jcp.iw - 1 - iw)
                - jcp.r_pad));

        int ow = iw + jcp.l_pad - i_r_overflow;
        int stride_off_w = ow % jcp.stride_w;
        ow /= jcp.stride_w;

        par_conv.src = &diff_src[diff_src_d.blk_off(n, ch, ih, iw)];
        par_conv.dst = &diff_dst[diff_dst_d.blk_off(n, ch, oh, ow)];
        par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, i_b_overflow
                + stride_off_h, i_r_overflow + stride_off_w)];

        par_conv.kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow
                - stride_off_h);
        par_conv.kw_padding = nstl::max(0, jcp.kw - i_l_overflow - i_r_overflow
                - stride_off_w);

        par_conv.ur_str_w = ur_str_w;

        par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch;

        return par_conv;
    };

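    /* The work is split across threads over minibatch images, groups of
     * channel blocks and diff_src rows; every row is walked once per
     * stride phase. */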
    const int MB = pd()->MB();
    const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
    parallel_nd(MB, chb_work, jcp.ih,
            [&](int n, int chb, int ih) {
        int ch = chb * jcp.nb_ch_blocking;
        int ch_num = jcp.nb_ch_blocking;

        const int i_t_overflow = nstl::max(0, (int)(jcp.kh - 1 - ih
                - jcp.t_pad));
        const int i_b_overflow = nstl::max(0, (int)(jcp.kh - 1
                - (jcp.ih - 1 - ih) - jcp.b_pad));

        int oh = ih + jcp.t_pad - i_b_overflow;
        int stride_off_h = oh % jcp.stride_h;
        oh /= jcp.stride_h;

        for (int i_str_w = 0; i_str_w < jcp.stride_w; i_str_w++) {
            // Left border: one strided diff_src column per kernel call.
            int iw = i_str_w;
            int l_border = nstl::min(jcp.kw - 1 - jcp.l_pad, jcp.iw);
            int ur_str_w = 1;
            for (; iw < l_border; iw += jcp.stride_w) {
                jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh,
                        ih, i_t_overflow, i_b_overflow,
                        stride_off_h, ch, ch_num, n);

                kernel_->jit_ker(&par_conv);
            }

            // Main region: one kernel call covers the remaining interior
            // columns of this stride phase.
            ur_str_w = nstl::min((jcp.iw - jcp.kw + jcp.r_pad - iw)
                    / jcp.stride_w, jcp.iw);
            if (ur_str_w > 0) {
                jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh,
                        ih, i_t_overflow, i_b_overflow,
                        stride_off_h, ch, ch_num, n);

                kernel_->jit_ker(&par_conv);

                iw += ur_str_w * jcp.stride_w;
            }

            // Right border: again one column per kernel call.
            ur_str_w = 1;
            for (; iw < jcp.iw; iw += jcp.stride_w) {
                jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh,
                        ih, i_t_overflow, i_b_overflow,
                        stride_off_h, ch, ch_num, n);

                kernel_->jit_ker(&par_conv);
            }
        }
    });
}

template struct _jit_uni_dw_convolution_bwd_data_t<avx512_common>;
template struct _jit_uni_dw_convolution_bwd_data_t<avx2>;
template struct _jit_uni_dw_convolution_bwd_data_t<sse42>;

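/* The backward-weights primitive owns its JIT kernel and, when the minibatch
 * reduction runs in parallel, a 1D accumulator used to sum the per-thread
 * partial weights. */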
template <cpu_isa_t isa>
_jit_uni_dw_convolution_bwd_weights_t<isa>::
_jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd,
        const input_vector &inputs, const output_vector &outputs)
    : cpu_primitive_t(apd, inputs, outputs)
    , kernel_(nullptr), acc_ker_(nullptr)
{
    kernel_ = new jit_uni_dw_conv_bwd_weights_kernel_f32<isa>(pd()->jcp_);
    if (pd()->jcp_.nthr_mb > 1 && do_parallel_reduction())
        acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>();
}

template <cpu_isa_t isa>
void _jit_uni_dw_convolution_bwd_weights_t<isa>::execute_backward_weights()
        const {
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
    auto diff_bias = reinterpret_cast<data_t *>(this->memory(1));

    auto diff_wei_reduction_buf =
        scratchpad().template get<data_t>(key_conv_wei_reduction);
    auto diff_bia_reduction_buf =
        scratchpad().template get<data_t>(key_conv_bia_reduction);

    const auto &jcp = kernel_->jcp;

    /* Used when executing a parallel reduction */
    simple_barrier::ctx_t reduction_bctx;
    simple_barrier::ctx_init(&reduction_bctx);

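    /* Every 'mb' thread other than the first accumulates into its own full
     * copy of the weights (and bias); wei_size and bias_size are the strides
     * between those copies in the reduction buffers. */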
    const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
    const size_t bias_size = jcp.with_bias ? jcp.ngroups : 0;

    const int ch_block = jcp.ch_block;

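    /* Fills the JIT call structure for one block of output rows: execution
     * flags, the number of filter rows to apply, and the src/diff_dst
     * offsets for the given minibatch image and channel-block group. */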
    auto set_kernel_params = [&](jit_dw_conv_call_s *conv_params,
            const int batch, const int group, const int oh_start,
            const int work_size, const unsigned char exec_flag,
            const size_t kh_padding, const size_t filter_off) {
        const int tpad_underflow_off = jcp.t_pad - filter_off;

        conv_params->exec_flags = exec_flag;
        conv_params->kh_count = jcp.kh - kh_padding;

        const int oh_s = oh_start;
        const int oh_e = oh_start + work_size;
        const int ih_s = oh_s * jcp.stride_h;

        conv_params->filter_pad_off
                = filter_off * jcp.kw * ch_block * sizeof(float);
        conv_params->oh_index = oh_s;
        conv_params->oh_count = oh_e;

        size_t diff_dst_off
                = ((batch * (jcp.ngroups / ch_block) + group) * jcp.oh
                        + oh_s) * jcp.ow;
        size_t src_off = ((batch * (jcp.ngroups / ch_block) + group) * jcp.ih
                + ih_s - tpad_underflow_off) * jcp.iw;

        conv_params->output = &diff_dst[diff_dst_off * ch_block];
        conv_params->input = &src[src_off * ch_block];
    };

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        assert(nthr == jcp.nthr);

        auto conv_params = jit_dw_conv_call_s();
        const int h_block_size = 15;
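        /* Each kernel call processes at most h_block_size output rows. */
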
        /* assign iteration space to thread */
        const int ithr_g = ithr % jcp.nthr_g;
        const int ithr_mb = (ithr / jcp.nthr_g) % jcp.nthr_mb;

        /* split dimensions */
        int g_start{ 0 }, g_end{ 0 };
        balance211(jcp.nb_ch, jcp.nthr_g, ithr_g, g_start, g_end);

        int mb_start{ 0 }, mb_end{ 0 };
        balance211(jcp.mb, jcp.nthr_mb, ithr_mb, mb_start, mb_end);

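        /* The first thread along 'mb' writes directly into the user's
         * diff_weights / diff_bias; the remaining 'mb' threads write into
         * per-thread reduction buffers that are accumulated afterwards. */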
        auto diff_wei = ithr_mb == 0
            ? diff_weights : diff_wei_reduction_buf + (ithr_mb - 1) * wei_size;
        auto diff_bia = ithr_mb == 0
            ? diff_bias : diff_bia_reduction_buf + (ithr_mb - 1) * bias_size;

        for (int g = g_start; g < g_end; ++g) {
            unsigned char zero_filter_flag = FLAG_ZERO_FILTER;
            unsigned char zero_bias_flag = jcp.with_bias ? FLAG_ZERO_BIAS : 0;

            size_t diff_wei_off = g * jcp.kh * jcp.kw;
            conv_params.filter = &diff_wei[diff_wei_off * ch_block];

            if (jcp.with_bias)
                conv_params.bias = &diff_bia[g * ch_block];

            for (int mb = mb_start; mb < mb_end; ++mb) {
                int oh = 0;
                while (oh < jcp.oh) {
                    const int h_work = nstl::min(h_block_size, jcp.oh - oh);
                    auto kh_t_padding = nstl::max(0, jcp.t_pad - oh);
                    auto kh_b_padding
                            = (oh * jcp.stride_h + jcp.kh - 1 > jcp.ih) ?
                            jcp.b_pad - (h_work - 1) :
                            0;

                    set_kernel_params(&conv_params, mb, g, oh, h_work,
                            zero_filter_flag | zero_bias_flag,
                            kh_t_padding + kh_b_padding, kh_t_padding);
                    kernel_->jit_ker(&conv_params);

                    /* only the first block of each group zero-initializes
                     * the filter / bias accumulators */
                    zero_bias_flag &= ~FLAG_ZERO_BIAS;
                    zero_filter_flag &= ~FLAG_ZERO_FILTER;

                    oh += h_work;
                }
            }
        }

        if (do_parallel_reduction() && jcp.nthr_mb > 1) {
            size_t reduct_start{ 0 }, reduct_end{ 0 };
            balance211(wei_size, nthr, ithr, reduct_start, reduct_end);

            const int acc_size = reduct_end - reduct_start;
            const size_t reduct_off = reduct_start;
            auto *acc_data = diff_weights + reduct_off;

            /* wait until every thread has produced its partial weights
             * before accumulating them */
            simple_barrier::barrier(&reduction_bctx, nthr);

            for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
                auto *src_data = diff_wei_reduction_buf
                        + (thr_mb - 1) * wei_size + reduct_off;
                acc_ker_->accumulate(acc_data, src_data, acc_size);
            }
        }
    });

    if (jcp.nthr_mb <= 1) return;

    /* Apply single-threaded 'mb' reduction */
    for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
        size_t mb_accum_offset = (thr_mb - 1) * wei_size;
        size_t b_accum_offset = (thr_mb - 1) * bias_size;

        for (int g = 0; g < jcp.nb_ch; ++g) {
            /* Reduction on Bias */
            if (jcp.with_bias) {
                for (int g_block = 0; g_block < ch_block; ++g_block) {
                    size_t bias_offset = g * ch_block + g_block;
                    diff_bias[bias_offset] += diff_bia_reduction_buf[
                            b_accum_offset + bias_offset];
                }
            }

            /* the weights were already reduced inside the parallel section */
            if (do_parallel_reduction()) continue;

            for (int kh = 0; kh < jcp.kh; ++kh)
            for (int kw = 0; kw < jcp.kw; ++kw) {
                size_t wei_offset = (g * jcp.kh + kh) * jcp.kw + kw;
                for (int g_block = 0; g_block < ch_block; ++g_block) {
                    const size_t off = wei_offset * ch_block + g_block;
                    diff_weights[off] +=
                            diff_wei_reduction_buf[mb_accum_offset + off];
                }
            }
        }
    }
}

template struct _jit_uni_dw_convolution_bwd_weights_t<avx512_common>;
template struct _jit_uni_dw_convolution_bwd_weights_t<avx2>;
template struct _jit_uni_dw_convolution_bwd_weights_t<sse42>;

}
}
}