1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "type_helpers.hpp"
22 #include "jit_avx512_common_convolution.hpp"
28 using namespace mkldnn::impl::status;
29 using namespace mkldnn::impl::memory_format;
30 using namespace mkldnn::impl::memory_tracking::names;
31 using namespace mkldnn::impl::utils;
35 using jit_conv_ker_t = void (*)(jit_conv_call_s *);
// PIPELINE(field): software-pipelining helper. Rotates `field` through the
// jit_conv_call_s struct `p`: the previously prefetched value becomes the
// current one, and the freshly supplied value is stored as the next
// prefetch. NOTE(review): this listing is truncated — original line 38
// (a continuation of the macro) is missing here.
37 #define PIPELINE(field) \
39 p.field = p.field ## _prf; \
40 p.field ## _prf = field; \
// Pipelined invocation of a JIT convolution kernel: fills `p` with the
// given src/dst/weights/bias pointers plus the channel index, padded
// kernel-height count and output-channel byte offset, presumably via the
// PIPELINE macro, then calls `ker`. NOTE(review): the function body is not
// visible in this listing (embedded numbering jumps 45 -> 58) — only the
// signature survives; confirm against the full source.
43 inline void jit_conv_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
44 const void *src, const void *dst, const void *filt, const void *bias,
45 int channel, int kh_padding, int oc_off)
58 // The special case for the driver with ow-parallelization (FWD)
59 // TODO: implement it for BWD_D and BWD_W too
// Same as jit_conv_ker_pipeline but additionally carries the output-width
// block index `owb`, used by forward drivers that parallelize over ow.
// NOTE(review): body missing from this truncated listing — signature only.
60 inline void jit_conv_ker_pipeline_ow_thr(jit_conv_ker_t ker, jit_conv_call_s &p,
61 const void *src, const void *dst, const void *filt, const void *bias,
62 int channel, int kh_padding, int owb, int oc_off)
// 3D (volumetric) variant of the pipelined kernel call: adds the padded
// kernel-depth count `kd_padding` alongside `kh_padding`. NOTE(review):
// body not visible in this truncated listing — signature only.
77 inline void jit_conv_3d_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
78 const void *src, const void *dst, const void *filt, const void *bias,
79 int channel, int kh_padding, int kd_padding, int oc_off)
93 // The special case for the driver with ow-parallelization (FWD)
94 // TODO: implement it for BWD_D and BWD_W too
// 3D pipelined kernel call with ow-block index `owb` for drivers that
// split work along the output-width dimension. NOTE(review): the body is
// only partially visible in this truncated listing (lines 98-103 and the
// tail are missing); the two PIPELINE rotations below are what survives.
95 inline void jit_conv_3d_ker_pipeline_ow_thr(jit_conv_ker_t ker,
96 jit_conv_call_s &p, const void *src, const void *dst, const void *filt,
97 const void *bias, int channel, int kh_padding, int kd_padding, int owb, int oc_off)
// Rotate the padded kernel-height and kernel-depth counts through the
// two-deep software pipeline.
104 PIPELINE(kh_padding);
105 PIPELINE(kd_padding);
// Pipelined kernel call for the 3D backward-by-weights driver. In addition
// to the usual pointers it rotates the current depth index / depth work
// size (`d_index` / `d_worksize`) and the padded kernel-depth count
// (`kd_padding`, aliased "kd_work_size") through the pipeline; `kd_offset`
// is presumably a precomputed byte offset into the weights along kd —
// TODO(review): confirm, the lines using it are missing from this listing.
113 void jit_conv_3d_ker_bwd_w_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
114 const void *src, const void *dst, const void *filt, const void *bias,
115 int channel, int d_index, int d_worksize,
116 int kd_padding /* kd_work_size */, size_t kd_offset) {
122 PIPELINE(kd_padding);
123 PIPELINE(d_worksize);
// wht_blk_off(d, g, ...): weights block offset that transparently handles
// grouped vs. non-grouped layouts — the group index `g` participates in
// the offset computation only when the primitive descriptor has groups.
130 #define wht_blk_off(d, g, ...) \
131 (pd()->with_groups() \
132 ? (d).blk_off((g), __VA_ARGS__) \
133 : (d).blk_off(__VA_ARGS__))
// When the primitive descriptor requests a zero-padded bias, copy the
// caller's bias into the scratchpad and zero-fill the padded tail
// [oc_without_padding, oc). NOTE(review): this listing is truncated —
// the expected final reassignment of `bias` to the padded copy (and the
// closing brace) are missing (numbering jumps 144 -> 148).
135 template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
136 void jit_avx512_common_convolution_fwd_t<src_type, wei_type, dst_type>::
137 prepare_padded_bias(const dst_data_t *&bias) const {
// Fast exit: nothing to do for primitives that don't want padded bias.
138 if (!pd()->wants_padded_bias()) return;
140 auto padded_bias = scratchpad().template get<dst_data_t>(
141 key_conv_padded_bias);
142 utils::array_copy(padded_bias, bias, pd()->jcp_.oc_without_padding);
143 utils::array_set(padded_bias + pd()->jcp_.oc_without_padding,
144 (dst_data_t)0, pd()->jcp_.oc - pd()->jcp_.oc_without_padding);
// Forward 1D convolution driver. Partitions the iteration space
// (mb x groups x oc-chunks x ow-blocks) across threads with balance211,
// walks it with nd_iterator_init/jump in the order given by
// jcp.loop_order, and issues pipelined JIT kernel calls per input-channel
// block. NOTE(review): this listing is truncated — several lines of the
// body (e.g. the `nthr` declaration, `dummy`, some braces) are missing.
148 template <data_type_t src_type, data_type_t wei_type,
149 data_type_t dst_type>
150 void jit_avx512_common_convolution_fwd_t
151 <src_type, wei_type, dst_type>::execute_forward_1d() const
// Raw input/output buffers from the primitive's memory arguments.
153 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
154 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
155 auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2));
156 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
// May redirect `bias` to a zero-padded scratchpad copy.
158 prepare_padded_bias(bias);
160 const memory_desc_wrapper src_d(pd()->src_pd());
161 const memory_desc_wrapper dst_d(pd()->dst_pd());
162 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
164 const auto &jcp = pd()->jcp_;
165 assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
// Total work items: one per (mb, group, oc-chunk, ow-block) tuple.
167 int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
168 int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.nb_ow;
// Honor a caller-aligned thread count when the jit config provides one.
171 if (jcp.aligned_threads)
172 nthr = jcp.aligned_threads;
174 nthr = mkldnn_get_max_threads();
176 parallel(nthr, [&](const int ithr, const int nthr) {
177 int start{0}, end{0}, start_copy;
178 balance211(work_amount, nthr, ithr, start, end);
181 auto par_conv = jit_conv_call_s();
182 size_t src_c_stride = src_d.blk_off(0, 1);
183 size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);
// Outer blocking over input channels for L2 locality.
185 for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
187 int n{0}, g{0}, occ{0}, owb{0};
// Decompose the flat work index into loop counters per loop_order.
189 if (jcp.loop_order == loop_cwgn) {
191 nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow,
192 g, jcp.ngroups, n, jcp.mb, dummy, 1);
193 } else if (jcp.loop_order == loop_gncw) {
195 nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, occ,
196 oc_chunks, owb, jcp.nb_ow, dummy, 1);
198 assert(!"unsupported loop order");
201 while (start < end) {
// Resolve block indices into element offsets.
202 int ocb = occ * jcp.nb_oc_blocking;
203 int g_ocb = g * jcp.nb_oc + ocb;
204 int g_oc = g_ocb * jcp.oc_block;
205 int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
207 int ow_s = owb * jcp.ow_block;
208 int iw_s = ow_s * jcp.stride_w;
209 auto bias_w = bias ? bias + g_oc : nullptr;
210 auto dst_w = dst + dst_d.blk_off(n, g_ocb, ow_s);
211 auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, iw_s);
212 auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2);
214 int oc_off = g_oc * sizeof(dst_data_t);
// Accumulate over the input-channel blocks of this L2 chunk.
216 for (int icb = icb_l2;
217 icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) {
218 jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv,
219 src_w, dst_w, wht_w, bias_w, icb, 1, owb, oc_off);
221 src_w += src_c_stride;
222 wht_w += wht_ic_stride;
// Advance the nd-iterator to the next work item.
224 if (jcp.loop_order == loop_cwgn) {
226 nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
227 g, jcp.ngroups, n, jcp.mb, dummy, 1);
228 } else if (jcp.loop_order == loop_gncw) {
230 nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb,
231 occ, oc_chunks, owb, jcp.nb_ow, dummy, 1);
233 assert(!"unsupported loop order");
// Drain the two-deep software pipeline with a dummy final call.
237 jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv,
238 src, dst, weights, bias, 0, 0, 0, 0);
// Forward 2D convolution driver. Same structure as the 1D driver but the
// work space also spans output height (oh), and kernel-height padding is
// computed per output row from top/bottom overflow against dilation.
// NOTE(review): truncated listing — some declarations (`nthr`, `src_c`,
// `dst_c`, loop braces) are missing between the surviving lines.
242 template <data_type_t src_type, data_type_t wei_type,
243 data_type_t dst_type>
244 void jit_avx512_common_convolution_fwd_t
245 <src_type, wei_type, dst_type>::execute_forward_2d() const
247 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
248 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
249 auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2));
250 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
252 prepare_padded_bias(bias);
254 const memory_desc_wrapper src_d(pd()->src_pd());
255 const memory_desc_wrapper dst_d(pd()->dst_pd());
256 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
258 const auto &jcp = pd()->jcp_;
259 const int MB = pd()->MB();
260 assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
// Work items: (mb, group, oc-chunk, oh row, ow-block).
262 int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
263 int work_amount = MB * jcp.ngroups * oc_chunks * jcp.oh * jcp.nb_ow;
266 if (jcp.aligned_threads)
267 nthr = jcp.aligned_threads;
269 nthr = mkldnn_get_max_threads();
271 parallel(nthr, [&](const int ithr, const int nthr) {
272 int start{0}, end{0}, start_copy;
273 balance211(work_amount, nthr, ithr, start, end);
276 auto par_conv = jit_conv_call_s();
// Strides are made relative by subtracting the base logical offset.
277 size_t src_h_stride = src_d.blk_off(0, 0, 1) - src_d.off_l(0);
278 size_t src_c_stride = src_d.blk_off(0, 1) - src_d.off_l(0);
279 size_t dst_h_stride = dst_d.blk_off(0, 0, 1) - dst_d.off_l(0);
280 size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
281 size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);
283 for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
285 int n{0}, g{0}, occ{0}, oh_s{0}, owb{0};
287 if (jcp.loop_order == loop_cwgn)
288 nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow,
289 g, jcp.ngroups, n, MB, oh_s, jcp.oh)
290 else if (jcp.loop_order == loop_gncw)
291 nd_iterator_init(start, g, jcp.ngroups, n, MB,
292 occ, oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
294 assert(!"unsupported loop order");
296 while (start < end) {
297 int ocb = occ * jcp.nb_oc_blocking;
298 int g_ocb = g * jcp.nb_oc + ocb;
299 int g_oc = g_ocb * jcp.oc_block;
300 int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
// Clip the contiguous run of output rows to this thread's range.
302 int work_rem = end - start;
304 int ow_s = owb * jcp.ow_block;
305 int iw_s = ow_s * jcp.stride_w;
306 int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
307 auto bias_w = bias ? bias + g_oc : nullptr;
// Process output rows in h_blocking-sized groups.
309 for (int oh_b = oh_s; oh_b < oh_e; oh_b += jcp.h_blocking) {
310 int ih_b = -jcp.t_pad + oh_b * jcp.stride_h;
312 auto dst_w = dst + dst_d.blk_off(n, g_ocb, oh_b, ow_s);
314 = src + src_d.blk_off(n, g_icb + icb_l2, ih_b, iw_s);
316 = weights + wht_blk_off(weights_d, g, ocb, icb_l2);
318 for (int icb = icb_l2;
319 icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2);
323 for (int oj = oh_b, ij = ih_b;
324 oj < min(oh_e, oh_b + jcp.h_blocking);
325 ++oj, ij += jcp.stride_h) {
// Rows of the filter falling outside the padded input
// (top/bottom) are skipped; div_up accounts for dilation
// "holes" in the filter footprint.
326 int dilate_h = jcp.dilate_h + 1;
327 int i_t_overflow = div_up(max(0, -ij), dilate_h);
328 int i_b_overflow = div_up(max(0, ij - jcp.ih
329 + (jcp.kh - 1) * dilate_h + 1), dilate_h);
330 int kh_padding = nstl::max(
331 0, jcp.kh - i_t_overflow - i_b_overflow);
334 + i_t_overflow * dilate_h * src_h_stride;
335 auto aux_wht = wht_w + i_t_overflow * wht_h_stride;
337 int oc_off = g_oc * sizeof(dst_data_t);
339 jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker,
340 par_conv, aux_src, dst_c, aux_wht, bias_w, icb,
341 kh_padding, owb, oc_off);
343 src_c += src_h_stride * jcp.stride_h;
344 dst_c += dst_h_stride;
346 src_w += src_c_stride;
347 wht_w += wht_ic_stride;
351 if (jcp.loop_order == loop_cwgn)
352 nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
353 g, jcp.ngroups, n, MB, oh_s, jcp.oh);
354 else if (jcp.loop_order == loop_gncw)
355 nd_iterator_jump(start, end, g, jcp.ngroups, n, MB, occ,
356 oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
358 assert(!"unsupported loop order");
// Drain the software pipeline.
362 jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv,
363 src, dst, weights, bias, 0, 0, 0, 0);
// Forward 3D convolution driver. Extends the 2D driver with an output
// depth (od) dimension: kernel-depth padding is computed once per work
// item from front/back overflow, and kernel-height padding per output row.
// NOTE(review): truncated listing — `nthr` handling, `src_c`/`dst_c`
// declarations and various braces are missing between surviving lines.
367 template <data_type_t src_type, data_type_t wei_type,
368 data_type_t dst_type>
369 void jit_avx512_common_convolution_fwd_t
370 <src_type, wei_type, dst_type>::execute_forward_3d() const
372 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
373 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
374 auto bias = reinterpret_cast<const dst_data_t *>(this->input_memory(2));
375 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
377 prepare_padded_bias(bias);
379 const memory_desc_wrapper src_d(pd()->src_pd());
380 const memory_desc_wrapper dst_d(pd()->dst_pd());
381 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
382 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
384 const auto &jcp = pd()->jcp_;
385 const int MB = pd()->MB();
386 assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
388 parallel(0, [&](const int ithr, const int nthr) {
389 int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
390 int start{0}, end{0}, start_copy;
391 int work_amount = MB * jcp.ngroups * oc_chunks * jcp.od * jcp.oh
393 balance211(work_amount, nthr, ithr, start, end);
396 auto par_conv = jit_conv_call_s();
397 size_t src_d_stride = src_d.blk_off(0, 0, 1);
398 size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
399 size_t src_c_stride = src_d.blk_off(0, 1);
400 size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
401 size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
402 size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
403 size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);
405 for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
407 int n{0}, g{0}, occ{0}, oh_s{0}, od_s{0}, owb{0};
409 if (jcp.loop_order == loop_cwgn)
410 nd_iterator_init(start,
411 occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, MB,
412 od_s, jcp.od, oh_s, jcp.oh);
413 else if (jcp.loop_order == loop_gncw)
414 nd_iterator_init(start,
415 g, jcp.ngroups, n, MB, occ, oc_chunks, owb, jcp.nb_ow,
416 od_s, jcp.od, oh_s, jcp.oh);
418 assert(!"unsupported loop order");
420 while (start < end) {
421 int ocb = occ * jcp.nb_oc_blocking;
422 int g_ocb = g * jcp.nb_oc + ocb;
423 int g_oc = g_ocb * jcp.oc_block;
424 int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
426 int work_rem = end - start;
427 int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
428 int ow_s = owb * jcp.ow_block;
429 int iw_s = ow_s * jcp.stride_w;
430 int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
// Depth padding: clip the filter's kd extent against the front
// (f_pad) and back of the input volume, dilation-aware.
432 int id_s = -jcp.f_pad + od_s * jcp.stride_d;
434 int dilate_d = jcp.dilate_d + 1;
435 int d_t_overflow = div_up(max(0, -id_s), dilate_d);
436 int d_b_overflow = div_up(
437 max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1),
439 int kd_padding = nstl::max(0,
440 jcp.kd - d_t_overflow - d_b_overflow);
442 auto bias_w = bias ? bias + bias_d.blk_off(g_oc) : 0;
443 auto dst_w = dst + dst_d.blk_off(n, g_ocb, od_s, oh_s, ow_s);
// src/weights start offset already skips the clipped front rows.
444 auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, id_s, ih_s,
445 iw_s) + d_t_overflow * dilate_d * src_d_stride;
446 auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2)
447 + d_t_overflow * wht_d_stride;
449 for (int icb = icb_l2;
450 icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) {
453 for (int oj = oh_s, ij = ih_s;
454 oj < oh_e; ++oj, ij += jcp.stride_h)
// Height padding per output row, same scheme as 2D.
456 int dilate_h = jcp.dilate_h + 1;
457 int i_t_overflow = div_up(max(0, -ij), dilate_h);
458 int i_b_overflow = div_up(
459 max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h
462 int kh_padding = nstl::max(0,
463 jcp.kh - i_t_overflow - i_b_overflow);
465 int oc_off = g_oc * sizeof(dst_data_t);
467 jit_conv_3d_ker_pipeline_ow_thr(kernel_->jit_ker,
469 src_c + i_t_overflow * dilate_h * src_h_stride,
470 dst_c, wht_w + i_t_overflow * wht_h_stride,
471 bias_w, icb, kh_padding, kd_padding, owb, oc_off);
473 src_c += src_h_stride * jcp.stride_h;
474 dst_c += dst_h_stride;
476 src_w += src_c_stride;
477 wht_w += wht_ic_stride;
480 if (jcp.loop_order == loop_cwgn)
481 nd_iterator_jump(start, end,
482 occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, MB,
483 od_s, jcp.od, oh_s, jcp.oh);
484 else if (jcp.loop_order == loop_gncw)
485 nd_iterator_jump(start, end,
486 g, jcp.ngroups, n, MB, occ, oc_chunks, owb, jcp.nb_ow,
487 od_s, jcp.od, oh_s, jcp.oh);
489 assert(!"unsupported loop order");
// Drain the software pipeline.
492 jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv,
493 src, dst, weights, bias, 0, 0, 0, 0);
// Explicit instantiations of the forward primitive: pure f32, and
// s16 x s16 -> s32.
497 template struct jit_avx512_common_convolution_fwd_t<data_type::f32>;
498 template struct jit_avx512_common_convolution_fwd_t<data_type::s16,
499 data_type::s16, data_type::s32>;
// Backward-by-data 1D driver: computes diff_src from diff_dst and weights.
// Work space is (groups x mb x ic-chunks x ih); iteration mirrors the
// forward driver but blocks over output channels (nb_oc_L2) instead.
// NOTE(review): truncated listing — `dummy`, some braces, and a few
// continuation lines (e.g. after line 532) are missing.
501 template <data_type_t diff_dst_type, data_type_t wei_type,
502 data_type_t diff_src_type>
503 void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
504 diff_src_type>::execute_backward_data_1d() const {
505 auto diff_dst = reinterpret_cast<const diff_dst_data_t *>
506 (this->input_memory(0));
507 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
508 auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());
510 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
511 const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
512 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
514 const auto &jcp = kernel_->jcp;
516 parallel(0, [&](const int ithr, const int nthr) {
517 int start{0}, end{0}, start_copy;
518 int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
519 int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih;
520 balance211(work_amount, nthr, ithr, start, end);
523 auto par_conv = jit_conv_call_s();
524 size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
525 size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);
// L2 blocking over output channels (the reduction dim for bwd_d).
527 for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
529 int n{0}, g{0}, icc{0};
530 if (jcp.loop_order == loop_cgn) {
532 nd_iterator_init(start, icc, ic_chunks, g, jcp.ngroups, n,
534 } else if (jcp.loop_order == loop_gnc) {
536 nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, icc,
537 ic_chunks, dummy, 1);
539 assert(!"unsupported loop order");
542 while (start < end) {
543 int icb = icc * jcp.nb_ic_blocking;
544 int g_icb = g * jcp.nb_ic + icb;
545 int g_ocb = g * jcp.nb_oc;
547 auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb);
548 auto diff_dst_w = diff_dst
549 + diff_dst_d.blk_off(n, g_ocb + ocb_l2);
550 auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb);
// Accumulate over this L2 chunk's output-channel blocks.
552 for (int ocb = ocb_l2;
553 ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) {
554 jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
555 diff_src_w, diff_dst_w, wht_w, 0, ocb, 1, 0);
556 diff_dst_w += diff_dst_c_stride;
557 wht_w += wht_oc_stride;
560 if (jcp.loop_order == loop_cgn) {
562 nd_iterator_jump(start, end, icc, ic_chunks, g, jcp.ngroups,
563 n, jcp.mb, dummy, 1);
564 } else if (jcp.loop_order == loop_gnc) {
566 nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, icc,
567 ic_chunks, dummy, 1);
569 assert(!"unsupported loop order");
// Drain the software pipeline.
574 jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
575 diff_src, diff_dst, weights, 0, 0, 1, 0);
// Backward-by-data 2D driver. For each input row ij it determines which
// filter rows can contribute (k_len), the first contributing filter row
// (k_lo) and the matching output row (oj), with three cases: unit stride /
// no dilation (fast path), dilated with unit stride, and strided without
// dilation. NOTE(review): truncated listing — `k_len`/`k_lo`/`oj`
// declarations and several continuation lines are missing.
579 template <data_type_t diff_dst_type, data_type_t wei_type,
580 data_type_t diff_src_type>
581 void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
582 diff_src_type>::execute_backward_data_2d() const {
583 auto diff_dst = reinterpret_cast<const diff_dst_data_t *>
584 (this->input_memory(0));
585 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
586 auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());
588 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
589 const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
590 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
592 const auto &jcp = kernel_->jcp;
593 const int MB = pd()->MB();
595 parallel(0, [&](const int ithr, const int nthr) {
596 int start{0}, end{0}, start_copy;
597 int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
598 int work_amount = jcp.ngroups * MB * ic_chunks * jcp.ih;
599 balance211(work_amount, nthr, ithr, start, end);
602 auto par_conv = jit_conv_call_s();
603 size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1);
604 size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1);
605 size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
606 size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
607 size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);
609 bool is_fast_path = jcp.dilate_h == 0 && jcp.stride_h == 1;
611 for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
613 int n{0}, g{0}, icc{0}, ih_s{0};
614 if (jcp.loop_order == loop_cgn)
615 nd_iterator_init(start,
616 icc, ic_chunks, g, jcp.ngroups, n, MB, ih_s, jcp.ih);
617 else if (jcp.loop_order == loop_gnc)
618 nd_iterator_init(start,
619 g, jcp.ngroups, n, MB, icc, ic_chunks, ih_s, jcp.ih);
621 assert(!"unsupported loop order");
623 while (start < end) {
624 int icb = icc * jcp.nb_ic_blocking;
625 int g_icb = g * jcp.nb_ic + icb;
626 int g_ocb = g * jcp.nb_oc;
628 int work_rem = end - start;
629 int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;
631 auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb);
632 auto diff_dst_w = diff_dst
633 + diff_dst_d.blk_off(n, g_ocb + ocb_l2);
634 auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb);
636 for (int ocb = ocb_l2;
637 ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) {
638 for (int ij = ih_s; ij < ih_e; ++ij) {
// Case 1: dilate == 0, stride == 1 — simple clipping.
640 if (is_fast_path) { // dilate == 0 && stride == 1
641 int i_t_overflow = max(0, jcp.kh - 1 - ij
643 int i_b_overflow = max(0, jcp.kh - jcp.ih + ij
645 k_len = jcp.kh - i_t_overflow - i_b_overflow;
647 oj = ij + jcp.t_pad - i_b_overflow;
// Case 2: dilated filter, unit stride.
648 } else if (jcp.dilate_h != 0) { // stride == 1
649 int dilate_h = jcp.dilate_h + 1;
650 // Note: use div_up to account for "holes" in filter
652 = div_up(max(0, (jcp.kh - 1) * dilate_h
653 - ij - jcp.t_pad), dilate_h);
655 = div_up(max(0, (jcp.kh - 1) * dilate_h + 1
656 - jcp.ih + ij - jcp.b_pad), dilate_h);
657 k_len = jcp.kh - i_t_overflow - i_b_overflow;
659 oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
// Case 3: strided, no dilation — only filter rows whose
// position is congruent with the stride contribute.
660 } else { // dilate == 0
661 int i_t_overflow = max(0, (jcp.kh - 1 - ij
662 - jcp.t_pad) / jcp.stride_h);
663 int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij
664 - jcp.b_pad) / jcp.stride_h);
665 int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
666 + jcp.b_pad - ij) % jcp.stride_h);
667 int overflow_kh_lo = (ij + jcp.t_pad)
670 k_len = (overflow_kh_hi - overflow_kh_lo)
671 / jcp.stride_h + 1 - i_t_overflow
673 k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
674 oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h;
678 jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
679 diff_src_w + ij * diff_src_h_stride,
680 diff_dst_w + oj * diff_dst_h_stride,
681 wht_w + k_lo * wht_h_stride,
684 diff_dst_w += diff_dst_c_stride;
685 wht_w += wht_oc_stride;
688 if (jcp.loop_order == loop_cgn)
689 nd_iterator_jump(start, end,
690 icc, ic_chunks, g, jcp.ngroups, n, MB, ih_s, jcp.ih);
691 else if (jcp.loop_order == loop_gnc)
692 nd_iterator_jump(start, end,
693 g, jcp.ngroups, n, MB, icc, ic_chunks, ih_s, jcp.ih);
695 assert(!"unsupported loop order");
// Drain the software pipeline.
699 jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
700 diff_src, diff_dst, weights, 0, 0, 1, 0);
// Backward-by-data 3D driver. Applies the same three-case contributing-
// filter-row analysis as the 2D driver independently along depth (kd,
// computed once per work item) and height (kh, per input row ij).
// NOTE(review): truncated listing — several declarations (`k_len`,
// `k_lo`, `oj`, braces) are missing between the surviving lines.
704 template <data_type_t diff_dst_type, data_type_t wei_type,
705 data_type_t diff_src_type>
706 void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
707 diff_src_type>::execute_backward_data_3d() const {
708 auto diff_dst = reinterpret_cast<const diff_dst_data_t *>
709 (this->input_memory(0));
710 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
711 auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());
713 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
714 const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
715 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
717 const auto &jcp = kernel_->jcp;
718 const int MB = pd()->MB();
720 parallel(0, [&](const int ithr, const int nthr) {
721 int start{0}, end{0}, start_copy;
722 int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
723 int work_amount = jcp.ngroups * MB * ic_chunks * jcp.id * jcp.ih;
724 balance211(work_amount, nthr, ithr, start, end);
727 auto par_conv = jit_conv_call_s();
728 size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 0, 1);
729 size_t diff_src_d_stride = diff_src_d.blk_off(0, 0, 1);
730 size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 0, 1);
731 size_t diff_dst_d_stride = diff_dst_d.blk_off(0, 0, 1);
732 size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
733 size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
734 size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
735 size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);
737 bool is_fast_path_d = jcp.dilate_d == 0 && jcp.stride_d == 1;
738 bool is_fast_path_h = jcp.dilate_h == 0 && jcp.stride_h == 1;
740 for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
742 int n{0}, g{0}, icc{0}, ih_s{0}, id_s{0};
743 if (jcp.loop_order == loop_cgn)
744 nd_iterator_init(start,
745 icc, ic_chunks, g, jcp.ngroups, n, MB, id_s, jcp.id,
747 else if (jcp.loop_order == loop_gnc)
748 nd_iterator_init(start,
749 g, jcp.ngroups, n, MB, icc, ic_chunks, id_s, jcp.id,
752 assert(!"unsupported loop order");
754 while (start < end) {
755 int icb = icc * jcp.nb_ic_blocking;
756 int g_icb = g * jcp.nb_ic + icb;
757 int g_ocb = g * jcp.nb_oc;
759 int work_rem = end - start;
760 int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;
// Depth analysis: d_len contributing filter slices starting at
// d_lo, mapped to output slice d_oj. Three cases as in 2D.
761 int d_len = 0, d_lo = 0, d_oj = 0;
762 if (is_fast_path_d) { // dilate == 0 && stride == 1
763 int d_t_overflow = max(0, jcp.kd - 1 - id_s
765 int d_b_overflow = max(0, jcp.kd - jcp.id + id_s
767 d_len = jcp.kd - d_t_overflow - d_b_overflow;
769 d_oj = id_s + jcp.f_pad - d_b_overflow;
770 } else if (jcp.dilate_d != 0) { // stride == 1
771 int dilate_d = jcp.dilate_d + 1;
772 // Note: use div_up to account for "holes" in filter
773 int d_t_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d
774 - id_s - jcp.f_pad), dilate_d);
775 int d_b_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d + 1
776 - jcp.id + id_s - jcp.back_pad), dilate_d);
777 d_len = jcp.kd - d_t_overflow - d_b_overflow;
779 d_oj = id_s + jcp.f_pad - d_b_overflow * dilate_d;
780 } else { // dilate == 0
781 int d_t_overflow = max(0, (jcp.kd - 1 - id_s
782 - jcp.f_pad) / jcp.stride_d);
783 int d_b_overflow = max(0, (jcp.kd - jcp.id + id_s
784 - jcp.back_pad) / jcp.stride_d);
785 int overflow_kd_hi = jcp.kd - 1 - abs((jcp.id - 1
786 + jcp.back_pad - id_s) % jcp.stride_d);
787 int overflow_kd_lo = (id_s + jcp.f_pad)
790 d_len = (overflow_kd_hi - overflow_kd_lo)
791 / jcp.stride_d + 1 - d_t_overflow
793 d_lo = overflow_kd_lo + d_b_overflow * jcp.stride_d;
794 d_oj = (id_s + jcp.f_pad - d_lo) / jcp.stride_d;
// Base pointers already shifted by the depth analysis above.
798 auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb)
799 + id_s * diff_src_d_stride;
800 auto diff_dst_w = diff_dst
801 + diff_dst_d.blk_off(n, g_ocb + ocb_l2)
802 + d_oj * diff_dst_d_stride;
803 auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb)
804 + d_lo * wht_d_stride;
806 for (int ocb = ocb_l2;
807 ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) {
808 for (int ij = ih_s; ij < ih_e; ++ij) {
// Height analysis per input row, three cases as in 2D.
810 if (is_fast_path_h) { // dilate == 0 && stride == 1
811 int i_t_overflow = max(0, jcp.kh - 1 - ij
813 int i_b_overflow = max(0, jcp.kh - jcp.ih + ij
815 k_len = jcp.kh - i_t_overflow - i_b_overflow;
817 oj = ij + jcp.t_pad - i_b_overflow;
818 } else if (jcp.dilate_h != 0) { // stride == 1
819 int dilate_h = jcp.dilate_h + 1;
820 // Note: use div_up to account for "holes" in filter
822 = div_up(max(0, (jcp.kh - 1) * dilate_h
823 - ij - jcp.t_pad), dilate_h);
825 = div_up(max(0, (jcp.kh - 1) * dilate_h + 1
826 - jcp.ih + ij - jcp.b_pad), dilate_h);
827 k_len = jcp.kh - i_t_overflow - i_b_overflow;
829 oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
830 } else { // dilate == 0
831 int i_t_overflow = max(0, (jcp.kh - 1 - ij
832 - jcp.t_pad) / jcp.stride_h);
833 int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij
834 - jcp.b_pad) / jcp.stride_h);
835 int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
836 + jcp.b_pad - ij) % jcp.stride_h);
837 int overflow_kh_lo = (ij + jcp.t_pad)
840 k_len = (overflow_kh_hi - overflow_kh_lo)
841 / jcp.stride_h + 1 - i_t_overflow
843 k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
844 oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h;
848 jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv,
849 diff_src_w + ij * diff_src_h_stride,
850 diff_dst_w + oj * diff_dst_h_stride,
851 wht_w + k_lo * wht_h_stride,
852 0, ocb, k_len, d_len, 0);
854 diff_dst_w += diff_dst_c_stride;
855 wht_w += wht_oc_stride;
858 if (jcp.loop_order == loop_cgn)
859 nd_iterator_jump(start, end,
860 icc, ic_chunks, g, jcp.ngroups, n, MB, id_s, jcp.id,
862 else if (jcp.loop_order == loop_gnc)
863 nd_iterator_jump(start, end,
864 g, jcp.ngroups, n, MB, icc, ic_chunks, id_s, jcp.id,
867 assert(!"unsupported loop order");
// Drain the software pipeline.
871 jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv,
872 diff_src, diff_dst, weights, 0, 0, 1, 1, 0);
// Explicit instantiations of the backward-by-data primitive: pure f32,
// and s16 x s16 -> s32.
876 template struct jit_avx512_common_convolution_bwd_data_t<data_type::f32>;
877 template struct jit_avx512_common_convolution_bwd_data_t<data_type::s16,
878 data_type::s16, data_type::s32>;
// Backward-by-weights primitive constructor: caches the thread
// decomposition chosen at pd-creation time (mb / oc-block / ic-block
// splits), builds the JIT weights-gradient kernel, and — for the 4fma /
// 4vnni / vnni ISA variants — source (and destination) transpose kernels.
// NOTE(review): truncated listing — lines creating e.g. `nthr_` / guards
// around acc_ker_/reducer_bias_ and the destructor side are missing;
// raw `new` here is matched by deletes elsewhere in the class (not
// visible in this view).
880 template <data_type_t src_type, data_type_t diff_dst_type,
881 data_type_t diff_weights_type>
882 jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
884 jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd,
885 const input_vector &inputs, const output_vector &outputs)
886 : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr)
887 , trans_kernel_(nullptr), trans_dst_kernel_(nullptr), acc_ker_(nullptr)
888 , reducer_bias_(nullptr)
890 const auto &j = pd()->jcp_;
// Thread decomposition computed during pd creation.
893 nthr_mb_ = j.nthr_mb;
895 nthr_oc_b_ = j.nthr_oc_b;
896 nthr_ic_b_ = j.nthr_ic_b;
898 kernel_ = new jit_avx512_common_conv_bwd_weights_kernel_f32(j);
// Transpose kernels are only needed for the packed-ISA code paths.
900 if (utils::one_of(j.ver, ver_4fma, ver_4vnni, ver_vnni)) {
901 trans_kernel_ = create_trans_src(&j);
902 if (utils::one_of(j.ver, ver_4vnni, ver_vnni))
903 trans_dst_kernel_ = create_trans_dst(&j);
907 acc_ker_ = new cpu_accumulator_1d_t<diff_weights_type>();
910 new cpu_reducer_t<diff_weights_type>(pd()->reducer_bia_conf_);
// Per-thread bookkeeping for the backward-by-weights driver: resolves the
// primitive's buffers and scratchpad regions, decodes the flat thread id
// into the 4D thread grid (ic-block x oc-block x group x mb), and balances
// each dimension's work range for this thread.
913 template <data_type_t src_type, data_type_t diff_dst_type,
914 data_type_t diff_weights_type>
915 struct jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
916 diff_weights_type>::thread_info_t {
917 const src_data_t *src;
918 const diff_dst_data_t *diff_dst;
919 const diff_weights_data_t *diff_weights;
920 diff_weights_data_t *diff_bias;
922 const memory_tracking::grantor_t scratchpad;
// Scratchpad views: transposed src/diff_dst and their barriers.
925 simple_barrier::ctx_t *tr_src_bctx;
927 diff_dst_data_t *tr_diff_dst;
928 simple_barrier::ctx_t *tr_diff_dst_bctx;
// Reduction buffer for partial weight/bias gradients + its barrier.
930 diff_weights_data_t *wei_bia_reduction;
931 simple_barrier::ctx_t *wei_bia_reduction_bctx;
// This thread's coordinates in the (ic_b, oc_b, g, mb) thread grid.
934 int ithr_ic_b, ithr_oc_b, ithr_g, ithr_mb;
// Per-dimension work ranges assigned to this thread.
938 int img_start = 0, img_end = 0, img_work;
939 int g_start = 0, g_end = 0, g_work;
940 int oc_b_start = 0, oc_b_end = 0, oc_b_work;
941 int ic_b_start = 0, ic_b_end = 0, ic_b_work;
943 thread_info_t(const jit_avx512_common_convolution_bwd_weights_t *self,
944 int ithr): scratchpad(self->scratchpad()), ithr(ithr) {
945 src = reinterpret_cast<const src_data_t *>(self->input_memory(0));
946 diff_dst = reinterpret_cast<const diff_dst_data_t *>(
947 self->input_memory(1));
948 diff_weights = reinterpret_cast<diff_weights_data_t *>(self->memory(0));
// Padded-bias case writes into scratchpad, not the user buffer.
949 diff_bias = self->pd()->wants_padded_bias()
950 ? scratchpad.template get<diff_weights_data_t>(
951 key_conv_padded_bias)
952 : reinterpret_cast<diff_weights_data_t *>(self->memory(1));
954 tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
955 tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
956 key_conv_tr_src_bctx);
958 tr_diff_dst = scratchpad.template get<diff_dst_data_t>(
959 key_conv_tr_diff_dst);
960 tr_diff_dst_bctx = scratchpad.template get<simple_barrier::ctx_t>(
961 key_conv_tr_diff_dst_bctx);
963 wei_bia_reduction = scratchpad.template get<diff_weights_data_t>(
964 key_conv_wei_bia_reduction);
965 wei_bia_reduction_bctx = scratchpad.template get<simple_barrier::ctx_t>(
966 key_conv_wei_bia_reduction_bctx);
// Decode flat thread id: ic_b fastest, then oc_b, g, mb.
968 ithr_ic_b = ithr % self->nthr_ic_b_;
969 ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_;
970 ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_;
971 ithr_mb = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ / self->nthr_g_;
// Flat ids with the oc (resp. ic) dimension factored out, used for
// barriers among threads sharing all other coordinates.
973 ithr_but_oc = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_ic_b_
976 ithr_but_ic = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_oc_b_
979 const auto &jcp = self->kernel_->jcp;
981 /* reduction dimension */
982 balance211(jcp.mb*jcp.od, self->nthr_mb_, ithr_mb, img_start, img_end);
983 img_work = img_end - img_start;
985 /* independent dimensions */
986 balance211(jcp.ngroups, self->nthr_g_, ithr_g, g_start, g_end);
987 g_work = g_end - g_start;
989 balance211(jcp.nb_oc, self->nthr_oc_b_, ithr_oc_b, oc_b_start,
991 oc_b_work = oc_b_end - oc_b_start;
993 balance211(jcp.nb_ic, self->nthr_ic_b_, ithr_ic_b, ic_b_start,
995 ic_b_work = ic_b_end - ic_b_start;
999 template <data_type_t src_type, data_type_t diff_dst_type,
1000 data_type_t diff_weights_type>
// Per-thread computation of diff_weights (and, on the 4fma first-conv path,
// diff_bias) for 1D/2D convolution. Reduction-thread 0 accumulates directly
// into the user diff_weights buffer; every other reduction thread writes its
// partial sums into a private slice of the wei_bia_reduction scratchpad,
// which reduce_diff_weights() later folds back in.
// NOTE(review): this listing has gaps (the embedded line numbers jump), so
// some statements and closing braces are not visible here; the comments
// below describe only what the visible lines establish.
1001 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1002 diff_weights_type>::compute_diff_weights(const thread_info_t *ti) const {
1003 const memory_desc_wrapper src_d(pd()->src_pd(0));
1004 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
1005 const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
1007 const auto &jcp = kernel_->jcp;
// Total weight-tensor element count; strides the per-thread reduction slices.
1008 const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh*jcp.kw*jcp.kd;
// Thread 0 of the mb-reduction dim writes in place; thread k > 0 uses
// scratch slice (k - 1).
1010 diff_weights_data_t *diff_wei = ti->ithr_mb == 0
1011 ? (diff_weights_data_t*)ti->diff_weights
1012 : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
// Bias partials are stored after all (nthr_mb_ - 1) weight slices in the
// same scratchpad buffer.
1013 diff_weights_data_t *diff_bia = ti->ithr_mb == 0
1014 ? (diff_weights_data_t*)ti->diff_bias
1015 : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size
1016 + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc;
1018 // TODO: use memory descriptor with the same fmt as src (or use a macro :))
// Offset into the transposed-src scratchpad for (ic block, row ij).
// NOTE(review): the ithr_mb parameter is ignored -- the body reads
// ti->ithr_mb instead; callers pass various first arguments (even `img` at
// the tr_diff_dst call sites below) and all resolve to the same per-thread
// base. Presumably intentional, but worth confirming upstream.
1019 auto tr_src_off = [&](int ithr_mb, int ic, int ij) {
1020 const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
1021 const size_t tr_chn_size = tr_row_size * jcp.ih;
1022 const size_t tr_img_size = tr_chn_size * jcp.nb_ic * jcp.ngroups;
1024 return ti->ithr_mb * tr_img_size + ic * tr_chn_size + ij * tr_row_size;
// Transpose the src rows of image `img` into tr_src; the (g, ic_b, ih) work
// is split across the nthr_oc_b_ threads that share this tr_src buffer.
1027 auto uker_trans = [&](int img) {
1028 const int work_amount = ti->g_work * ti->ic_b_work * jcp.ih;
1030 int start{0}, end{0};
1031 balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, start, end);
1032 const int my_work = end - start;
1034 int g{0}, ic_b{0}, j{0};
1035 nd_iterator_init(start, g, ti->g_work, ic_b, ti->ic_b_work, j, jcp.ih);
1037 ic_b += ti->ic_b_start;
1039 const int _ic = g * jcp.nb_ic + ic_b;
1040 src_data_t *src1 = (src_data_t*)&ti->src[src_d.blk_off(img, _ic, j)];
1041 src_data_t *tr_src1 = &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, j)];
1043 assert(jcp.ic_block == 16);
1044 const int src_stride = jcp.iw * jcp.ic_block;
1045 const int tr_src_stride = jcp.tr_iw * jcp.ic_block;
// Depth-2 software pipeline: run the JIT transpose for the previous row
// while the current row's pointers are staged as prefetch hints.
1047 const int pf_depth = 2;
1048 struct { src_data_t *src, *tr_src; } pf_circ_buf[pf_depth];
1050 for (int iwork = 0; iwork < my_work + pf_depth - 1; iwork++) {
1051 pf_circ_buf[iwork % pf_depth] = {src1, tr_src1};
1053 if (iwork >= pf_depth - 1) {
1054 int old_idx = (iwork - pf_depth + 1) % pf_depth;
1055 auto ctx = jit_trans_src_t::ctx_t();
1056 ctx.src = pf_circ_buf[old_idx].src;
1057 ctx.tr_src = pf_circ_buf[old_idx].tr_src;
1059 ctx.tr_src_prf = tr_src1;
1060 (*trans_kernel_)(&ctx);
1063 tr_src1 += tr_src_stride;
1066 // reference transposition
1067 const int l_pad = jcp.l_pad;
1068 const int iwlp = l_pad + jcp.iw;
1069 const int tr_iw = jcp.tr_iw;
1071 for (size_t iwork = start; iwork < end; iwork++) {
// Zero the left-padding columns of the transposed row.
1074 for (int i = 0; i < l_pad; i++)
1075 for (int j = 0; j < jcp.ic_block; j++)
1076 tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0;
// Copy the payload, transposing [iw][16c] -> [16c][tr_iw].
1080 for (int i = l_pad; i < iwlp; i++)
1081 for (int j = 0; j < jcp.ic_block; j++)
1082 tr_src1[j * jcp.tr_iw + i]
1083 = (src_data_t)src1[(i - l_pad) * 16 + j];
// Zero the right-padding/guard columns up to tr_iw.
1087 for (int i = iwlp; i < tr_iw; i++)
1088 for (int j = 0; j < jcp.ic_block; j++)
1089 tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0;
1092 tr_src1 += tr_src_stride;
// Offset into the transposed-diff_dst scratchpad.
// NOTE(review): like tr_src_off, the ithr_mb parameter is unused; the body
// reads ti->ithr_mb.
1097 auto tr_diff_dst_off = [&](int ithr_mb, int oc, int oj) {
1098 const size_t tr_row_size = jcp.tr_ow * jcp.oc_block;
1099 const size_t tr_chn_size = tr_row_size * jcp.oh;
1100 const size_t tr_img_size = tr_chn_size * jcp.nb_oc * jcp.ngroups;
1101 return ti->ithr_mb * tr_img_size + oc * tr_chn_size + oj * tr_row_size;
// Transpose/interleave diff_dst rows of image `img` for the vnni layouts;
// the (g, oc_b, oh) work is shared by the nthr_ic_b_ threads using this
// buffer.
1104 auto diff_dst_trans = [&](int img) {
1105 const size_t work_amount = ti->g_work * ti->oc_b_work * jcp.oh;
1107 size_t start{0}, end{0};
1108 balance211(work_amount, nthr_ic_b_, ti->ithr_ic_b, start, end);
1109 const int my_work = end - start;
1111 int g{0}, oc_b{0}, j{0};
1112 nd_iterator_init(start, g, ti->g_work, oc_b, ti->oc_b_work, j, jcp.oh);
1114 oc_b += ti->oc_b_start;
1115 const int oc = g * jcp.nb_oc + oc_b;
1116 const diff_dst_data_t *diff_dst1
1117 = &ti->diff_dst[diff_dst_d.blk_off(img, oc, j)];
1118 diff_dst_data_t *tr_diff_dst1
1119 = &ti->tr_diff_dst[tr_diff_dst_off(img, oc, j)];
// NOTE(review): asserts ic_block although this loop walks oc blocks --
// presumably both blocks are 16 for this kernel; confirm.
1122 assert(jcp.ic_block == 16);
1123 const int diff_dst_stride = jcp.ow * jcp.oc_block;
1124 const int tr_diff_dst_stride = jcp.tr_ow * jcp.oc_block;
// Same depth-2 transpose/prefetch pipeline as in uker_trans above.
1126 const int pf_depth = 2;
1127 struct { diff_dst_data_t *diff_dst, *tr_diff_dst; }
1128 pf_circ_buf[pf_depth];
1130 for (int iwork = 0; iwork < my_work + pf_depth - 1; iwork++) {
1131 pf_circ_buf[iwork % pf_depth]
1132 = {(diff_dst_data_t*)diff_dst1, tr_diff_dst1};
1134 if (iwork >= pf_depth - 1) {
1135 int old_idx = (iwork - pf_depth + 1) % pf_depth;
1136 auto ctx = jit_trans_dst_t::ctx_t();
1137 ctx.src = pf_circ_buf[old_idx].diff_dst;
1138 ctx.tr_src = pf_circ_buf[old_idx].tr_diff_dst;
1139 ctx.src_prf = diff_dst1;
1140 ctx.tr_src_prf = tr_diff_dst1;
1141 (*trans_dst_kernel_)(&ctx);
1143 diff_dst1 += diff_dst_stride;
1144 tr_diff_dst1 += tr_diff_dst_stride;
1147 // reference transposition
// Pair even/odd ow elements: [oh][ow][16c] -> [oh][ow/2][16c][2ow].
1148 int r_pad = jcp.ow % 2;
1149 for(size_t work = start; work < end; ++work) {
1151 for (int j = 0; j < jcp.oc_block; ++j) {
1153 for (int i = 0; i < jcp.ow / 2; i++) {
1154 tr_diff_dst1[i*jcp.oc_block*2 + j*2] =
1155 diff_dst1[2*i*jcp.oc_block + j];
1156 tr_diff_dst1[i*jcp.oc_block*2 + j*2 + 1] =
1157 diff_dst1[(2*i+1)*jcp.oc_block + j];
// Tail element when ow is odd (r_pad computed above).
1160 const int last_w = jcp.ow / 2;
1161 tr_diff_dst1[last_w * jcp.oc_block * 2 + j * 2] =
1162 diff_dst1[last_w * jcp.oc_block * 2 + j];
1163 tr_diff_dst1[last_w * jcp.oc_block * 2 + j * 2 + 1] =
1169 diff_dst1 += diff_dst_stride;
1170 tr_diff_dst1 += tr_diff_dst_stride;
// Special 4fma path for the first convolution: src is transposed once per
// (img, g, ic_b) with the oc-dimension threads cooperating via a barrier
// context embedded in tr_ctx.
1175 if (jcp.is_1stconv && jcp.ver == ver_4fma) {
1176 /* prepare contexts */
1177 auto tr_ctx = jit_trans_src_t::ctx_t();
1178 tr_ctx.tr_src = ti->tr_src
1179 + ti->ithr_but_oc * jcp.ih * jcp.stride_w * jcp.tr_ld;
1181 assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_oc_b_ == 1));
1182 tr_ctx.nthr_oc_b = nthr_oc_b_;
1183 int ih_start{0}, ih_end{0};
1184 balance211(jcp.ih, nthr_oc_b_, ti->ithr_oc_b, ih_start, ih_end);
1185 tr_ctx.tr_src_ih_start = ih_start;
1186 tr_ctx.tr_src_ih_end = ih_end;
1187 tr_ctx.tr_src_bctx = ti->tr_src_bctx + ti->ithr_but_oc;
1189 auto p = jit_conv_call_s();
1190 p.src = tr_ctx.tr_src;
1192 /* zero diff_bias if applicable */
1193 if (jcp.with_bias && ti->ithr_ic_b == 0) {
1194 assert(jcp.oc_block == 16);
// NOTE(review): loop starts at ti->ic_b_start but is bounded by
// ti->oc_b_end -- upstream uses ti->oc_b_start here. Looks like a typo that
// mis-zeroes bias blocks whenever ic_b_start != oc_b_start; confirm and fix.
1195 for (int oc_b = ti->ic_b_start; oc_b < ti->oc_b_end; ++oc_b) {
1196 diff_weights_data_t *db = &diff_bia[oc_b * 16];
1197 for (int o = 0; o < 16; ++o)
// Main 1st-conv loop: FLAG_MB_FIRST tells the kernel to initialize (rather
// than accumulate) on the first image of this thread's range.
1202 for (int img = ti->img_start; img < ti->img_end; ++img) {
1203 p.flags = (img == ti->img_start) * FLAG_MB_FIRST;
1205 for (int g = ti->g_start; g < ti->g_end; ++g) {
1206 for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
1207 const int _ic = g * jcp.nb_ic + ic_b;
1208 tr_ctx.src = &ti->src[src_d.blk_off(img, _ic)];
1210 (*trans_kernel_)(&tr_ctx);
1213 p.flags |= FLAG_IC_FIRST;
1215 p.flags &= ~FLAG_IC_FIRST;
1217 for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
1218 const int _oc = g * jcp.nb_oc + oc_b;
1219 p.dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)];
1222 wht_blk_off(diff_weights_d, g, oc_b, ic_b);
1223 p.filt = diff_wei + off;
1224 p.bias = diff_bia + _oc * jcp.oc_block;
1226 kernel_->jit_ker(&p);
// Generic path: per image, optionally transpose src/diff_dst (barrier-
// bracketed so all sharers see complete data), then run the pipelined
// weight-gradient kernel for every (g, oc_b, ic_b) this thread owns.
1232 for (int img = ti->img_start; img < ti->img_end; ++img) {
1233 auto p = jit_conv_call_s();
1235 if (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) {
1236 /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */
1237 using simple_barrier::barrier;
1239 barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
1242 barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
1245 if (utils::one_of(jcp.ver, ver_4vnni, ver_vnni)) {
1246 /* tr_diff_dst[nb_oc][OW][oh][16c][2ow]
1247 * <- diff_dst[nb_oc][oh][ow][16c] */
1249 barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
1250 diff_dst_trans(img);
1252 barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
1255 for (int g = ti->g_start; g < ti->g_end; ++g) {
1256 for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
1257 for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
1258 const int _oc = g * jcp.nb_oc + oc_b;
1259 const int _ic = g * jcp.nb_ic + ic_b;
// The pipeline runs one call behind: each invocation executes the previous
// argument set while staging these pointers as prefetch hints (see the
// PIPELINE macro at the top of the file).
1261 jit_conv_ker_pipeline(kernel_->jit_ker, p,
1262 (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
1263 ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)]
1264 : &ti->src[src_d.blk_off(img, _ic)]),
1265 utils::one_of(jcp.ver, ver_4vnni, ver_vnni)
1266 ? &ti->tr_diff_dst[tr_diff_dst_off(ti->ithr_mb, _oc, 0)]
1267 : &ti->diff_dst[diff_dst_d.blk_off(img, _oc)],
1268 diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b),
1269 0, (img == ti->img_start), 0, 0);
// Drain the pipeline: one extra call so the final staged iteration actually
// executes (pointers use img + 1 only as prefetch hints).
1275 const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start;
1276 const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start;
1277 jit_conv_ker_pipeline(kernel_->jit_ker, p,
1278 (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
1279 ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)]
1280 : &ti->src[src_d.blk_off(img + 1, _ic)]),
1281 utils::one_of(jcp.ver, ver_4vnni, ver_vnni)
1282 ? &ti->tr_diff_dst[tr_diff_dst_off(ti->ithr_mb, _oc, 0)]
1283 : &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)],
1284 diff_wei + wht_blk_off(
1285 diff_weights_d, ti->g_start,
1286 ti->oc_b_start, ti->ic_b_start),
1292 template <data_type_t src_type, data_type_t diff_dst_type,
1293 data_type_t diff_weights_type>
// 3D variant of compute_diff_weights: the reduction dimension is (mb x od),
// walked with an nd_iterator so each thread handles whole od slabs of each
// image. Uses the 3D pipelined kernel wrapper with explicit kd padding.
// NOTE(review): listing has gaps (embedded line numbers jump); some
// statements/braces are missing from view.
1294 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1295 diff_weights_type>::compute_diff_weights_3d(const thread_info_t *ti) const
1297 const memory_desc_wrapper src_d(pd()->src_pd(0));
1298 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
1299 const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
1301 const auto &jcp = kernel_->jcp;
1303 = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw * jcp.kd;
// Same reduction-slice scheme as the 2D path: thread 0 in place, others into
// the wei_bia_reduction scratchpad.
1305 diff_weights_data_t *diff_wei = ti->ithr_mb == 0
1306 ? (diff_weights_data_t*)ti->diff_weights
1307 : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
1308 diff_weights_data_t *diff_bia = ti->ithr_mb == 0
1309 ? (diff_weights_data_t*)ti->diff_bias
1310 : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size
1311 + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc;
// First conv reads non-blocked src (1 channel per element), otherwise the
// blocked layout advances by ic_block.
1313 const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
1314 const int input_step = jcp.stride_d * jcp.ih * jcp.iw * inp_mult;
1315 const int output_step = jcp.ow * jcp.oh * jcp.oc_block;
1316 int img{0}, od_s{0};
1317 int img_start = ti->img_start, img_end = ti->img_end;
// Decompose the flat [mb x od] range into (img, od_s).
1318 nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
1319 const int img_first = img;
1321 while (img_start < img_end) {
1322 auto p = jit_conv_call_s();
// Clamp this step to the end of the current image's od range.
1324 int work_rem = img_end - img_start;
1325 const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
1326 const int id_s = od_s * jcp.stride_d;
// Depth-padding bookkeeping: how far into id we start, and how many kd taps
// fall into front/back padding (skipped via kd_pad_off byte offset).
1327 const int ik_overlap = nstl::max(0, id_s - jcp.f_pad);
1328 const int kd_front_pad = nstl::max(0, jcp.f_pad - id_s);
1329 const int kd_back_pad = nstl::max(0, id_s + 1 + jcp.back_pad - jcp.od);
1330 int kd_pad_off = kd_front_pad * jcp.kh * jcp.kw * jcp.ic_block
1331 * jcp.oc_block * jcp.typesize_out;
1333 for (int g = ti->g_start; g < ti->g_end; ++g) {
1334 for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
1335 for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
1336 const int _oc = g * jcp.nb_oc + oc_b;
1337 const int _ic = g * jcp.nb_ic + ic_b;
1339 auto src = &ti->src[src_d.blk_off(img, _ic)
1340 + ik_overlap * input_step];
1341 auto dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)
1342 + od_s * output_step];
// Pipelined 3D kernel call; (img == img_first) requests accumulator init.
1344 jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p, src, dst,
1345 diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b),
1346 diff_bia + _oc * 16, (img == img_first), od_s, od_e,
1347 jcp.kd - nstl::max(kd_front_pad, kd_back_pad), kd_pad_off);
1349 if (ic_b == 0) p.flags = 0;
// Drain the pipeline with a dummy call (img + 1 pointers are prefetch hints
// only; zero extents mean no extra work).
1355 const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start;
1356 const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start;
1357 jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p,
1358 &ti->src[src_d.blk_off(img + 1, _ic)],
1359 &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)],
1360 diff_wei + wht_blk_off(diff_weights_d, ti->g_start,
1361 ti->oc_b_start, ti->ic_b_start),
1362 diff_bia, 0, 0, 0, 0, 0);
// Advance (img, od_s) to the next chunk of this thread's flat range.
1363 nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
1367 template <data_type_t src_type, data_type_t diff_dst_type,
1368 data_type_t diff_weights_type>
// Reduce the per-mb-thread partial diff_weights (and, for the 4fma 1st-conv
// case, diff_bias) from the wei_bia_reduction scratchpad into the user
// buffers. All nthr_ threads rendezvous first, then split the reduction work.
// NOTE(review): listing has gaps; some lines are missing from view.
1369 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1370 diff_weights_type>::reduce_diff_weights(const thread_info_t *ti) const {
1371 const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
1373 const auto &jcp = kernel_->jcp;
1374 const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
1375 const int bia_size = jcp.ngroups * jcp.oc;
// Bias partials follow the (nthr_mb_ - 1) weight slices in the scratchpad.
1376 const diff_weights_data_t *diff_bias_ws
1377 = ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size;
1379 /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
// Ensure every producer has finished writing its slice before anyone reads.
1380 simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);
// Reduction work units are kh rows of a (g, oc_b, ic_b) weight block.
1382 const int ic_b_kh_work = ti->ic_b_work * jcp.kh;
1383 const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;
1385 int start{0}, end{0};
1386 balance211(work, nthr_mb_, ti->ithr_mb, start, end);
1387 if (start == end) return;
// Accumulate every other thread's slice into the main diff_weights buffer.
1389 for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
1391 int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0};
1392 nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
1393 ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
1395 const int g = ti->g_start + sub_g_start;
1396 const int oc_b = ti->oc_b_start + sub_oc_b_start;
1397 const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kh;
1398 const int kh = sub_ic_b_kh_start % jcp.kh;
// Contiguous run length, bounded by both the assigned range and the end of
// the current ic_b's kh span.
1401 = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start)
1402 * jcp.kw * jcp.ic_block * jcp.oc_block;
1405 = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kh);
1407 diff_weights_data_t *d
1408 = (diff_weights_data_t *)ti->diff_weights + off;
1409 diff_weights_data_t *s
1410 = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;
// Vectorized accumulate: d[:] += s[:].
1412 acc_ker_->accumulate(d, s, acc_size);
1414 nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
1415 ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
// Bias was produced in the scratchpad only on the 4fma 1st-conv path.
1418 if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) {
1420 acc_ker_->accumulate((diff_weights_data_t *)ti->diff_bias,
1421 diff_bias_ws, bia_size);
1422 diff_bias_ws += bia_size;
1427 template <data_type_t src_type, data_type_t diff_dst_type,
1428 data_type_t diff_weights_type>
// 3D counterpart of reduce_diff_weights: identical scheme, but reduction
// work units are kd slabs (each covering a full kh * kw plane set) instead
// of kh rows. NOTE(review): listing has gaps; some lines are missing.
1429 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1430 diff_weights_type>::reduce_diff_weights_3d(const thread_info_t *ti) const {
1431 const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
1433 const auto &jcp = kernel_->jcp;
1434 const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw
1437 /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
// All producers must have finished before the reduction starts.
1438 simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);
// Work units: kd slabs of a (g, oc_b, ic_b) weight block.
1440 const int ic_b_kh_work = ti->ic_b_work * jcp.kd;
1441 const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;
1443 int start{0}, end{0};
1444 balance211(work, nthr_mb_, ti->ithr_mb, start, end);
1445 if (start == end) return;
1447 for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
1449 int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0};
1450 nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
1451 ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
1453 const int g = ti->g_start + sub_g_start;
1454 const int oc_b = ti->oc_b_start + sub_oc_b_start;
1455 const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kd;
1456 const int kd = sub_ic_b_kh_start % jcp.kd;
// Contiguous run: each kd unit spans a full kh * kw * ic_block * oc_block
// sub-tensor.
1459 = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start)
1460 * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.kh;
1463 = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kd);
1464 diff_weights_data_t *d
1465 = (diff_weights_data_t *)ti->diff_weights + off;
1466 diff_weights_data_t *s
1467 = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;
1468 acc_ker_->accumulate(d, s, acc_size);
1470 nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
1471 ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
1476 template <data_type_t src_type, data_type_t diff_dst_type,
1477 data_type_t diff_weights_type>
// Per-thread bias gradient: sums diff_dst over the spatial dims and the
// thread's minibatch share into reducer_bias_'s per-thread workspace, then
// lets the reducer fold the partials together.
// NOTE(review): listing has gaps; some lines are missing from view.
1478 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1479 diff_weights_type>::compute_diff_bias(const thread_info_t *ti) const {
1480 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
1482 auto rb = this->reducer_bias_;
1483 assert(nthr_ == rb->balancer().nthr_);
1485 const auto reducer_bia_scratchpad = memory_tracking::grantor_t(
1486 ti->scratchpad, prefix_reducer_bia);
1488 const auto &jcp = kernel_->jcp;
// The 4fma 1st-conv path computes diff_bias inside the conv kernel itself
// (see compute_diff_weights), so skip the reducer path here.
1490 if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) return;
1492 const int b_job_start = rb->balancer().ithr_job_off(ti->ithr);
1493 const int b_njobs = rb->balancer().ithr_njobs(ti->ithr);
1495 if (b_njobs == 0) return;
1497 /* reduction dimension */
1498 int img_start{0}, img_end{0};
1499 balance211(jcp.mb, rb->balancer().nthr_per_group_,
1500 rb->balancer().id_in_group(ti->ithr), img_start, img_end);
// Decompose this thread's first job index into (group, oc block).
1503 int g_start{0}, ocb_start{0};
1504 nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_oc);
1505 for (int img = img_start; img < img_end; ++img) {
1506 int g = g_start, ocb = ocb_start;
1507 for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
1508 const size_t _oc = g * jcp.nb_oc + ocb;
1510 const diff_dst_data_t *d_dst
1511 = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)];
1512 diff_weights_data_t *d_bias = rb->get_local_ptr(ti->ithr,
1513 ti->diff_bias, reducer_bia_scratchpad)
1514 + b_job_loc * rb->balancer().job_size_;
// On the first image, (re)initialize the 16-wide accumulator -- 16 is the
// oc_block width here; the init statement itself is outside this view.
1516 if (img == img_start)
1517 for (int o = 0; o < 16; ++o)
// Accumulate over every spatial position (oh * ow * od).
1519 for (int hw = 0; hw < jcp.oh * jcp.ow * jcp.od; ++hw) {
1521 for (int o = 0; o < 16; ++o)
1522 d_bias[o] += d_dst[o];
1526 nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc);
// Fold the per-thread partial bias sums into ti->diff_bias.
1530 rb->reduce(ti->ithr, ti->diff_bias, reducer_bia_scratchpad);
1533 template <data_type_t src_type, data_type_t diff_dst_type,
1534 data_type_t diff_weights_type>
// 3D bias reduction: the conv kernel already produced per-mb-thread bias
// partials in the wei_bia_reduction scratchpad (after the weight slices);
// here they are summed into the user diff_bias after a barrier.
// NOTE(review): listing has gaps; the guard around the accumulate loop is
// not fully visible here.
1535 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1536 diff_weights_type>::compute_diff_bias_3d(const thread_info_t *ti) const {
1538 const auto &jcp = kernel_->jcp;
1540 const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh
1542 const int bia_size = jcp.ngroups * jcp.oc;
// Bias partials follow all (nthr_mb_ - 1) weight slices.
1543 const diff_weights_data_t *diff_bias_ws
1544 = ti->wei_bia_reduction + (size_t)(nthr_mb_ - 1) * wei_size;
// Wait for all producers before reading their partials.
1546 if (nthr_mb_ > 1) mkldnn_thr_barrier();
1550 for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
1551 acc_ker_->accumulate(ti->diff_bias, diff_bias_ws, bia_size);
1552 diff_bias_ws += bia_size;
1557 template <data_type_t src_type, data_type_t diff_dst_type,
1558 data_type_t diff_weights_type>
// One-time (per-execution) scratchpad preparation: zero the tr_src guard
// elements, initialize all simple-barrier contexts, and set up the bias
// reducer workspace. NOTE(review): listing has gaps; some statements are
// missing from view.
1559 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1560 diff_weights_type>::prepare_scratchpad_data() const
1562 const auto &j = pd()->jcp_;
1563 auto scratchpad = this->scratchpad();
1565 if (utils::one_of(j.ver, ver_4fma, ver_4vnni, ver_vnni)) {
1566 if (!j.is_1stconv) {
1567 // XXX: See the comment about tr_iw and guarding elements in
1568 // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf()
1569 const int max_nthr = j.nthr_mb * j.ngroups * j.nb_ic;
1570 const int min_tr_src_size_per_thr = j.ih * j.ic_block * j.tr_iw;
1572 auto tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
1573 /* to avoid NaNs in computations we zero tail num_guard_elems for
1574 * each possible thread group */
1576 for (int ithr = 1; ithr <= max_nthr; ++ithr) {
1577 src_data_t *ts = &tr_src[ithr * min_tr_src_size_per_thr];
1578 for (int i = 0; i < j.tr_src_num_guard_elems; ++i)
// One barrier context per group of threads sharing a tr_src buffer.
1583 if (j.nthr_oc_b > 1) {
1584 const int tr_src_bctx_size = j.nthr / j.nthr_oc_b;
1585 auto tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
1586 key_conv_tr_src_bctx);
1587 for (int i = 0; i < tr_src_bctx_size; ++i)
1588 simple_barrier::ctx_init(&tr_src_bctx[i]);
// Same, for groups sharing a tr_diff_dst buffer (vnni layouts only).
1591 if (utils::one_of(j.ver, ver_4vnni, ver_vnni) && j.nthr_ic_b > 1) {
1592 const int tr_diff_dst_bctx_size = j.nthr / j.nthr_ic_b;
1593 auto tr_diff_dst_bctx =
1594 scratchpad.template get<simple_barrier::ctx_t>(
1595 key_conv_tr_diff_dst_bctx);
1596 for (int i = 0; i < tr_diff_dst_bctx_size; ++i)
1597 simple_barrier::ctx_init(&tr_diff_dst_bctx[i]);
// Barrier used by the weight/bias reduction phase.
1602 simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>(
1603 key_conv_wei_bia_reduction_bctx));
1606 const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
1607 prefix_reducer_bia);
1608 auto rb = this->reducer_bias_;
1609 rb->init(reducer_bia_scratchpad);
1612 template <data_type_t src_type, data_type_t diff_dst_type,
1613 data_type_t diff_weights_type>
// Top-level backward-weights driver: prepare scratchpad, run the per-thread
// compute + reduce + bias phases inside one parallel region (2D vs 3D path
// chosen by ndims), then copy back the bias when it was computed padded.
// NOTE(review): listing has gaps; some lines are missing from view.
1614 void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
1615 diff_weights_type>::execute_backward_weights() const {
1616 prepare_scratchpad_data();
1618 parallel(nthr_, [&](const int ithr, const int nthr) {
// The balancing in pd() fixed the thread count; the runtime must honor it.
1619 assert(nthr_ == nthr);
1621 thread_info_t thread_info(this, ithr);
1623 if (utils::one_of(pd()->ndims(), 3, 4)) {
1624 compute_diff_weights(&thread_info);
1625 if (nthr_mb_ > 1) reduce_diff_weights(&thread_info);
1626 if (pd()->with_bias()) compute_diff_bias(&thread_info);
1627 } else if (pd()->ndims() == 5) {
1628 compute_diff_weights_3d(&thread_info);
1629 if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info);
1630 if (pd()->with_bias()) compute_diff_bias_3d(&thread_info);
1636 /* TODO: put that into compute_diff_bias() */
// Bias was accumulated into a padded scratch buffer; copy only the real
// (unpadded) oc entries out to the user memory.
1637 if (pd()->wants_padded_bias()) {
1638 auto diff_bias = scratchpad().template get<const diff_weights_data_t>(
1639 key_conv_padded_bias);
1641 = reinterpret_cast<diff_weights_data_t *>(this->memory(1));
1642 for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc)
1643 diff_bias_in[oc] = diff_bias[oc];
// Explicit instantiations: f32 (all three data types default to f32) and the
// s16-in / s16-diff_dst / s32-diff_weights flavor.
1647 template struct jit_avx512_common_convolution_bwd_weights_t<data_type::f32>;
1648 template struct jit_avx512_common_convolution_bwd_weights_t<data_type::s16,
1649 data_type::s16, data_type::s32>;
1655 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s