2 /*******************************************************************************
3 * Copyright 2019 Intel Corporation
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
18 #ifndef JIT_UNI_DW_CONVOLUTION_UTILS_HPP
19 #define JIT_UNI_DW_CONVOLUTION_UTILS_HPP
22 #include "type_helpers.hpp"
25 #include "c_types_map.hpp"
26 #include "memory_tracking.hpp"
28 #include "jit_generator.hpp"
29 #include "jit_primitive_conf.hpp"
30 #include "jit_uni_eltwise.hpp"
32 #include "jit_avx512_core_bf16_dw_conv_kernel.hpp"
33 #include "jit_uni_dw_conv_kernel_f32.hpp"
39 using namespace mkldnn::impl::memory_format;
40 using namespace mkldnn::impl::memory_tracking::names;
41 using namespace mkldnn::impl::utils;
// Forward depthwise-convolution kernel wrapper: owns the ISA/data-type
// specific JIT kernel instance and exposes its generated entry point.
// NOTE(review): this excerpt appears to be missing lines — the
// constructor's closing brace, the `ker_` member declaration, and the
// struct's closing `};` are not visible; verify against the original file.
43 template <cpu_isa_t isa, data_type_t kernel_dt>
44 struct jit_uni_dw_conv_fwd_kernel {
// Builds the underlying JIT kernel and caches its generated-code pointer.
46 jit_uni_dw_conv_fwd_kernel(jit_conv_conf_t ajcp, const primitive_attr_t &attr)
47 : jit_ker(nullptr), ker_(nullptr) {
48 ker_ = new jit_kernel_t(ajcp, attr);
49 jit_ker = ker_->jit_ker;
// Owns ker_ (allocated with new in the constructor).
51 ~jit_uni_dw_conv_fwd_kernel() { delete ker_; }
// Returns true iff the attached post-ops chain is supported by this kernel.
53 static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr, bool is_bf16);
// Validates descriptors/formats and fills jcp with the blocking and
// unrolling parameters the generated kernel needs.
55 static status_t init_conf(jit_conv_conf_t &jcp,
56 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
57 const memory_desc_wrapper &weights_d,
58 const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
// Books scratchpad memory the primitive will need at execution time.
60 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
61 const jit_conv_conf_t &jcp);
// Generated-code entry point; called with a per-invocation argument struct.
63 void (*jit_ker)(jit_conv_call_s *);
// Kernel dispatch: bf16 on avx512_core uses the dedicated bf16 kernel;
// all other <isa, data type> combinations use the templated f32 kernel.
66 using jit_kernel_t = typename utils::conditional<isa == avx512_core
67 && kernel_dt == data_type::bf16,
68 jit_avx512_dw_conv_fwd_kernel_bf16,
69 jit_uni_dw_conv_fwd_kernel_f32<isa>>::type;
// Checks whether the attribute's post-ops chain can be fused by this
// kernel. Allowed chains: empty, a single sum or eltwise, or sum followed
// by eltwise; additionally every entry must be sum/eltwise/depthwise, at
// most one sum is allowed, and sum (if present) must come first.
// NOTE(review): lines appear to be dropped from this excerpt — the
// `switch` header the `case` labels at 83-86 belong to (presumably
// switching on p.len_), and the `ok` declaration / return / closing of the
// all_post_ops_supported lambda; verify against the original file.
73 template <cpu_isa_t isa, data_type_t kernel_dt>
74 bool jit_uni_dw_conv_fwd_kernel<isa, kernel_dt>::post_ops_ok(
75 jit_conv_conf_t &jcp, const primitive_attr_t &attr, bool is_bf16) {
76 const auto &p = attr.post_ops_;
// Small predicates over the post-ops list by entry index.
79 auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
80 auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
83 case 0: return true; // no post_ops
84 case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
85 case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
86 default: return false;
// Every post-op entry must be one of the kinds the kernel can emit code for.
89 auto all_post_ops_supported = [&]() {
92 for (int i = 0; i < p.len_; i++) {
93 ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise);
// Helpers over the whole chain: presence, index, and count of a given kind.
97 auto contain = [&](mkldnn::impl::primitive_kind_t kind) { return p.find(kind) != -1; };
98 auto position = [&](mkldnn::impl::primitive_kind_t kind) { return p.find(kind); };
99 auto count = [&](mkldnn::impl::primitive_kind_t kind) { return p.count(kind); };
// sum, when present, must be the first post-op (it reads the dst buffer).
101 return all_post_ops_supported() &&
102 count(primitive_kind::sum) <= 1 &&
103 IMPLICATION(contain(primitive_kind::sum), position(primitive_kind::sum) == 0);
// Fills jcp for the forward depthwise kernel from the convolution
// descriptor and memory descriptors; returns status::unimplemented when
// the problem shape/format/ISA combination is not supported.
// NOTE(review): this excerpt appears to be missing lines — the closing
// brace of the ok_to_pad_channels block, the `desired_act_fmt` /
// `desired_wei_fmt` / `args_ok` / `nb_ch_blocking` declaration lines, the
// `: sizeof(float)` ternary continuations, and the function's closing
// brace; verify against the original file before editing.
109 template <cpu_isa_t isa, data_type_t kernel_dt>
110 status_t jit_uni_dw_conv_fwd_kernel<isa, kernel_dt>::init_conf(
111 jit_conv_conf_t &jcp, const convolution_desc_t &cd,
112 const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
113 const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) {
115 jcp.dst_dt = cd.dst_desc.data_type;
116 const bool is_bf16 = src_d.data_type() == data_type::bf16;
// bf16 requires at least avx512_core regardless of the requested isa.
118 if (!mayiuse(isa) || (is_bf16 && !mayiuse(avx512_core)))
119 return status::unimplemented;
// Vector lane count: 16 for avx512 variants, 8 for avx2/sse42.
121 const int simd_w = one_of(isa, avx512_common, avx512_core) ? 16 : 8;
123 jcp.prop_kind = cd.prop_kind;
// Depthwise conv is expressed as a grouped conv (extra leading weights dim).
125 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
126 if (!with_groups) return status::unimplemented;
128 jcp.ngroups = weights_d.dims()[0];
129 jcp.mb = src_d.dims()[0];
131 jcp.oc = dst_d.dims()[1];
132 jcp.oc_without_padding = jcp.oc;
133 jcp.ic = src_d.dims()[1];
135 jcp.ih = src_d.dims()[2];
136 jcp.iw = src_d.dims()[3];
137 jcp.oh = dst_d.dims()[2];
138 jcp.ow = dst_d.dims()[3];
140 jcp.kh = weights_d.dims()[3];
141 jcp.kw = weights_d.dims()[4];
143 jcp.t_pad = cd.padding[0][0];
144 jcp.l_pad = cd.padding[0][1];
145 jcp.b_pad = cd.padding[1][0];
146 jcp.r_pad = cd.padding[1][1];
148 jcp.stride_h = cd.strides[0];
149 jcp.stride_w = cd.strides[1];
151 jcp.dilate_h = cd.dilates[0];
152 jcp.dilate_w = cd.dilates[1];
154 jcp.src_fmt = src_d.format();
155 jcp.with_bias = cd.bias_desc.format != memory_format::undef;
157 if (!post_ops_ok(jcp, attr, is_bf16))
158 return status::unimplemented;
// Cache post-op flags/params from the attributes for kernel generation.
160 const auto &p = attr.post_ops_;
161 jcp.with_sum = p.find(primitive_kind::sum) != -1;
162 const int eltwise_ind = p.find(primitive_kind::eltwise);
163 jcp.with_eltwise = eltwise_ind != -1;
164 if (jcp.with_eltwise)
165 jcp.eltwise = p.entry_[eltwise_ind].eltwise;
// Pad channels up to simd_w only for the true-depthwise case
// (oc == ic == ngroups) on ISAs whose kernels tolerate padded tails.
167 bool ok_to_pad_channels = true
168 && jcp.oc == jcp.ngroups
169 && jcp.ic == jcp.ngroups
170 && one_of(isa, avx512_common, avx512_core, avx2);
171 if (ok_to_pad_channels) {
172 jcp.oc = rnd_up(jcp.oc, simd_w);
// NOTE(review): rounds jcp.ic using jcp.oc — harmless here because
// ok_to_pad_channels guarantees oc == ic, but it reads like a copy-paste;
// consider rnd_up(jcp.ic, simd_w) for clarity.
173 jcp.ic = rnd_up(jcp.oc, simd_w);
174 jcp.ngroups = rnd_up(jcp.ngroups, simd_w);
// Required layouts: blocked-channel activations, per-group blocked weights.
178 = one_of(isa, avx512_common, avx512_core) ? nChw16c : nChw8c;
180 = one_of(isa, avx512_common, avx512_core) ? Goihw16g : Goihw8g;
183 && jcp.oc == jcp.ngroups
184 && jcp.ic == jcp.ngroups
185 && jcp.ngroups % simd_w == 0
186 && src_d.format() == desired_act_fmt
187 && weights_d.format() == desired_wei_fmt
188 && one_of(cd.bias_desc.format, memory_format::undef, any, x)
189 && dst_d.format() == desired_act_fmt
// Logical channel counts must fit inside the (possibly padded) buffers.
190 && jcp.ic <= src_d.blocking_desc().padding_dims[1]
191 && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
192 && jcp.ngroups <= weights_d.blocking_desc().padding_dims[0];
193 if (!args_ok) return status::unimplemented;
// is_cpx: true when native bf16 instructions (avx512_core_bf16) exist.
195 jcp.is_cpx = (mayiuse(avx512_core_bf16)) ? true : false;
197 jcp.typesize_out = jcp.dst_dt == data_type::bf16 ? sizeof(mkldnn_bfloat16_t)
199 jcp.typesize_in = src_d.data_type() == data_type::bf16
200 ? sizeof(mkldnn_bfloat16_t)
// Output-width unroll factor, tuned per ISA / bf16 emulation mode.
203 jcp.ur_w = is_bf16 ? (jcp.is_cpx ? 6 : 4)
204 : isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;
206 jcp.ch_block = simd_w;
207 jcp.nb_ch = jcp.oc / jcp.ch_block;
// Channel-block unroll, clamped below to the available block count.
209 = one_of(isa, avx512_common, avx512_core) ? 4 : isa == avx2 ? 3 : 2;
210 if (jcp.nb_ch < jcp.nb_ch_blocking)
211 jcp.nb_ch_blocking = jcp.nb_ch;
213 return status::success;
// Books a padded-bias buffer when channels were rounded up: the user bias
// only covers oc_without_padding floats, so a scratch copy of jcp.oc
// floats is needed. NOTE(review): the function's closing brace is not
// visible in this excerpt; verify against the original file.
216 template <cpu_isa_t isa, data_type_t kernel_dt>
217 void jit_uni_dw_conv_fwd_kernel<isa, kernel_dt>::init_scratchpad(
218 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
219 if (jcp.with_bias && jcp.oc_without_padding != jcp.oc)
220 scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
// Explicit instantiations of the forward kernel wrapper for every
// supported <isa, kernel data type> combination.
223 template struct jit_uni_dw_conv_fwd_kernel<avx512_core, data_type::bf16>;
224 template struct jit_uni_dw_conv_fwd_kernel<avx512_common, data_type::f32>;
225 template struct jit_uni_dw_conv_fwd_kernel<avx2, data_type::f32>;
226 template struct jit_uni_dw_conv_fwd_kernel<sse42, data_type::f32>;
// Backward-data depthwise-convolution kernel wrapper; mirrors the forward
// wrapper above (owns the JIT kernel, exposes its entry point).
// NOTE(review): lines appear to be dropped from this excerpt — the
// constructor's closing brace, the destructor's body (presumably
// `delete ker_;`), the `ker_` member declaration and the closing `};`;
// verify against the original file.
228 template <cpu_isa_t isa, data_type_t kernel_dt>
229 struct jit_uni_dw_conv_bwd_data_kernel {
// Builds the underlying JIT kernel and caches its generated-code pointer.
231 jit_uni_dw_conv_bwd_data_kernel(jit_conv_conf_t ajcp)
232 : jit_ker(nullptr), ker_(nullptr) {
233 ker_ = new jit_kernel_t(ajcp);
234 jit_ker = ker_->jit_ker;
236 ~jit_uni_dw_conv_bwd_data_kernel(){
// Validates descriptors/formats and fills jcp for the backward-data pass.
240 static status_t init_conf(jit_conv_conf_t &jcp,
241 const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
242 const memory_desc_wrapper &weights_d,
243 const memory_desc_wrapper &diff_dst_d);
// Books scratchpad memory needed at execution time.
245 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
246 const jit_conv_conf_t &jcp);
// Generated-code entry point.
248 void (*jit_ker)(jit_conv_call_s *);
// Same kernel dispatch rule as the forward wrapper: bf16 on avx512_core
// gets the dedicated bf16 kernel, everything else the templated f32 one.
251 using jit_kernel_t = typename utils::conditional<isa == avx512_core
252 && kernel_dt == data_type::bf16,
253 jit_avx512_dw_conv_bwd_data_kernel_bf16,
254 jit_uni_dw_conv_bwd_data_kernel_f32<isa>>::type;
// Fills jcp for the backward-data depthwise kernel; largely parallels the
// forward init_conf but works on diff_src/diff_dst descriptors and also
// checks the spatial-shape consistency of the problem.
// NOTE(review): this excerpt appears to be missing lines — the closing
// brace of the ok_to_pad_channels block, the `desired_act_fmt` /
// `desired_wei_fmt` / `args_ok` / `nb_ch_blocking` declarations, some
// args_ok conjuncts around line 326-327, the `: sizeof(float)` ternary
// continuations, and the function's closing brace; verify against the
// original file before editing.
258 template <cpu_isa_t isa, data_type_t kernel_dt>
259 status_t jit_uni_dw_conv_bwd_data_kernel<isa, kernel_dt>::init_conf(
260 jit_conv_conf_t &jcp, const convolution_desc_t &cd,
261 const memory_desc_wrapper &diff_src_d,
262 const memory_desc_wrapper &weights_d,
263 const memory_desc_wrapper &diff_dst_d) {
265 jcp.dsrc_dt = cd.diff_src_desc.data_type;
266 const bool is_bf16 = diff_dst_d.data_type() == data_type::bf16;
// bf16 requires at least avx512_core regardless of the requested isa.
268 if (!mayiuse(isa) || (is_bf16 && !mayiuse(avx512_core)))
269 return status::unimplemented;
271 const int simd_w = one_of(isa, avx512_common, avx512_core) ? 16 : 8;
// Depthwise conv is expressed as a grouped conv (extra leading weights dim).
273 const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
274 if (!with_groups) return status::unimplemented;
276 jcp.ngroups = weights_d.dims()[0];
277 jcp.mb = diff_src_d.dims()[0];
279 jcp.oc = diff_dst_d.dims()[1];
280 jcp.oc_without_padding = jcp.oc;
281 jcp.ic = diff_src_d.dims()[1];
283 jcp.ih = diff_src_d.dims()[2];
284 jcp.iw = diff_src_d.dims()[3];
285 jcp.oh = diff_dst_d.dims()[2];
286 jcp.ow = diff_dst_d.dims()[3];
288 jcp.kh = weights_d.dims()[3];
289 jcp.kw = weights_d.dims()[4];
291 jcp.t_pad = cd.padding[0][0];
292 jcp.l_pad = cd.padding[0][1];
293 jcp.b_pad = cd.padding[1][0];
294 jcp.r_pad = cd.padding[1][1];
296 jcp.stride_h = cd.strides[0];
297 jcp.stride_w = cd.strides[1];
299 jcp.dilate_h = cd.dilates[0];
300 jcp.dilate_w = cd.dilates[1];
// Padded input extents, used below to cross-check the output shape.
302 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
303 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
305 jcp.src_fmt = diff_src_d.format();
// Pad channels up to simd_w only for the true-depthwise case.
307 bool ok_to_pad_channels = true
308 && jcp.oc == jcp.ngroups
309 && jcp.ic == jcp.ngroups
310 && one_of(isa, avx512_common, avx512_core, avx2, sse42);
311 if (ok_to_pad_channels) {
312 jcp.oc = rnd_up(jcp.oc, simd_w);
// NOTE(review): rounds jcp.ic using jcp.oc — harmless here because
// ok_to_pad_channels guarantees oc == ic, but looks like a copy-paste;
// consider rnd_up(jcp.ic, simd_w) for clarity.
313 jcp.ic = rnd_up(jcp.oc, simd_w);
314 jcp.ngroups = rnd_up(jcp.ngroups, simd_w);
// Required layouts: blocked-channel activations, per-group blocked weights.
318 = one_of(isa, avx512_common, avx512_core) ? nChw16c : nChw8c;
320 = one_of(isa, avx512_common, avx512_core) ? Goihw16g : Goihw8g;
323 && jcp.oc == jcp.ngroups
324 && jcp.ic == jcp.ngroups
325 && jcp.ngroups % simd_w == 0
328 && diff_src_d.format() == desired_act_fmt
329 && weights_d.format() == desired_wei_fmt
330 && diff_dst_d.format() == desired_act_fmt
// Output extents must be consistent with padded input, kernel and stride.
331 && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
332 && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1
// Logical channel counts must fit inside the (possibly padded) buffers.
333 && jcp.ic <= diff_src_d.blocking_desc().padding_dims[1]
334 && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
335 && jcp.ngroups <= weights_d.blocking_desc().padding_dims[0];
336 if (!args_ok) return status::unimplemented;
// is_cpx: true when native bf16 instructions (avx512_core_bf16) exist.
338 jcp.is_cpx = (mayiuse(avx512_core_bf16)) ? true : false;
340 jcp.typesize_out = diff_src_d.data_type() == data_type::bf16
341 ? sizeof(mkldnn_bfloat16_t)
343 jcp.typesize_in = diff_dst_d.data_type() == data_type::bf16
344 ? sizeof(mkldnn_bfloat16_t)
// Output-width unroll factor, tuned per ISA / bf16 emulation mode.
347 jcp.ur_w = is_bf16 ? (jcp.is_cpx ? 6 : 4)
348 : isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;
350 jcp.ch_block = simd_w;
// Backward-data blocks over input channels (vs output channels in fwd).
351 jcp.nb_ch = jcp.ic / jcp.ch_block;
// Channel-block unroll, clamped below to the available block count.
353 = one_of(isa, avx512_common, avx512_core) ? 4 : isa == avx2 ? 3 : 2;
354 if (jcp.nb_ch < jcp.nb_ch_blocking)
355 jcp.nb_ch_blocking = jcp.nb_ch;
357 return status::success;
// Backward-data pass needs no scratchpad; this is a no-op registrar hook.
// NOTE(review): the body and closing brace are not visible in this
// excerpt (presumably just UNUSED() markers and `}`); verify against the
// original file.
360 template <cpu_isa_t isa, data_type_t kernel_dt>
361 void jit_uni_dw_conv_bwd_data_kernel<isa, kernel_dt>::init_scratchpad(
362 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
// Explicit instantiations of the backward-data kernel wrapper for every
// supported <isa, kernel data type> combination.
367 template struct jit_uni_dw_conv_bwd_data_kernel<avx512_core, data_type::bf16>;
368 template struct jit_uni_dw_conv_bwd_data_kernel<avx512_common, data_type::f32>;
369 template struct jit_uni_dw_conv_bwd_data_kernel<avx2, data_type::f32>;
370 template struct jit_uni_dw_conv_bwd_data_kernel<sse42, data_type::f32>;
// Backward-weights depthwise-convolution kernel wrapper; additionally
// exposes balance() to split work between groups and minibatch.
// NOTE(review): lines appear to be dropped from this excerpt — the
// constructor's closing brace, the `ker_` member declaration and the
// struct's closing `};`; verify against the original file.
372 template <cpu_isa_t isa, data_type_t kernel_dt>
373 struct jit_uni_dw_conv_bwd_weights_kernel {
// Builds the underlying JIT kernel and caches its generated-code pointer.
375 jit_uni_dw_conv_bwd_weights_kernel(jit_conv_conf_t ajcp)
376 : jit_ker(nullptr), ker_(nullptr) {
377 ker_ = new jit_kernel_t(ajcp);
378 jit_ker = ker_->jit_ker;
// Owns ker_ (allocated with new in the constructor).
381 ~jit_uni_dw_conv_bwd_weights_kernel() { delete ker_; }
// Validates descriptors/formats and fills jcp; also runs balance() to
// pick the thread partitioning (nthreads is the available thread count).
383 static status_t init_conf(jit_conv_conf_t &jcp,
384 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
385 const memory_desc_wrapper &diff_weights_d,
386 const memory_desc_wrapper &diff_dst_d, int nthreads);
// Books per-thread reduction buffers (see implementation below).
388 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
389 const jit_conv_conf_t &jcp);
// Splits nthreads across groups (nthr_g) and minibatch (nthr_mb).
391 static void balance(jit_conv_conf_t &jcp, int nthreads);
// Generated-code entry point; note the bwd-weights-specific call struct.
393 void (*jit_ker)(jit_dw_conv_call_s *);
// Same kernel dispatch rule as the other wrappers.
396 using jit_kernel_t = typename utils::conditional<isa == avx512_core
397 && kernel_dt == data_type::bf16,
398 jit_avx512_dw_conv_bwd_weights_kernel_bf16,
399 jit_uni_dw_conv_bwd_weights_kernel_f32<isa>>::type;
// Fills jcp for the backward-weights depthwise kernel, checks problem
// applicability (formats, padding bounds, kernel size), and partitions
// the available threads via balance().
// NOTE(review): this excerpt appears to be missing lines — the
// `desired_act_fmt` / `desired_wei_fmt` declarations, the `if (!args_ok)`
// and `if (!boundaries_ok)` guard lines in front of the bare
// `return status::unimplemented;` at 472/489, the `: sizeof(float);`
// ternary continuation, and the function's closing brace; verify against
// the original file before editing.
403 template <cpu_isa_t isa, data_type_t kernel_dt>
404 status_t jit_uni_dw_conv_bwd_weights_kernel<isa, kernel_dt>::init_conf(
405 jit_conv_conf_t &jcp, const convolution_desc_t &cd,
406 const memory_desc_wrapper &src_d,
407 const memory_desc_wrapper &diff_weights_d,
408 const memory_desc_wrapper &diff_dst_d, int nthreads) {
410 jcp.dwei_dt = cd.diff_weights_desc.data_type;
411 const bool is_bf16 = src_d.data_type() == data_type::bf16;
// bf16 requires at least avx512_core regardless of the requested isa.
413 if (!mayiuse(isa) || (is_bf16 && !mayiuse(avx512_core)))
414 return status::unimplemented;
// Per-group channel counts: a depthwise problem has oc == ic == 1.
416 jcp.ngroups = diff_weights_d.dims()[0];
417 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
418 jcp.ic = src_d.dims()[1] / jcp.ngroups;
420 const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
422 jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.oc, jcp.ic);
424 if (!jcp.is_depthwise)
425 return status::unimplemented;
427 jcp.ch_block = one_of(isa, avx512_common, avx512_core) ? 16 : 8;
429 jcp.mb = src_d.dims()[0];
431 jcp.ih = src_d.dims()[2];
432 jcp.iw = src_d.dims()[3];
433 jcp.oh = diff_dst_d.dims()[2];
434 jcp.ow = diff_dst_d.dims()[3];
436 jcp.kh = diff_weights_d.dims()[3];
437 jcp.kw = diff_weights_d.dims()[4];
439 jcp.stride_h = cd.strides[0];
440 jcp.stride_w = cd.strides[1];
442 jcp.t_pad = cd.padding[0][0];
443 jcp.b_pad = cd.padding[1][0];
445 jcp.l_pad = cd.padding[0][1];
446 jcp.r_pad = cd.padding[1][1];
448 jcp.dilate_h = cd.dilates[0];
449 jcp.dilate_w = cd.dilates[1];
// Padded input extents, used below to cross-check the output shape.
451 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
452 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
454 jcp.src_fmt = src_d.format();
456 jcp.with_bias = cd.diff_bias_desc.format != memory_format::undef;
// Required layouts: blocked-channel activations, per-group blocked weights.
459 = one_of(isa, avx512_common, avx512_core) ? nChw16c : nChw8c;
461 = one_of(isa, avx512_common, avx512_core) ? Goihw16g : Goihw8g;
// This kernel supports no dilation and kernel width up to 3 only.
463 bool args_ok = true && src_d.format() == desired_act_fmt
464 && diff_weights_d.format() == desired_wei_fmt
465 && diff_dst_d.format() == desired_act_fmt
466 && one_of(cd.bias_desc.format, memory_format::undef, any, x)
467 && jcp.ngroups % jcp.ch_block == 0 && jcp.dilate_h == 0
468 && jcp.dilate_w == 0 && jcp.kw <= 3
469 && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
470 && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
472 return status::unimplemented;
// bf16 is only wired up for the 16-channel blocked formats.
474 if (!IMPLICATION(is_bf16, desired_act_fmt == mkldnn_nChw16c
475 && desired_wei_fmt == mkldnn_Goihw16g))
476 return status::unimplemented;
478 jcp.nb_ch = jcp.ngroups / jcp.ch_block;
480 /* kernel applicability check wrt boundaries
481 * the conditions are quite general across the kernels we have,
482 * but ideally the check should belong to a specific kernel... */
483 const int max_hpad = (jcp.kh - 1 + 1) / 2;
484 const int max_wpad = (jcp.kw - 1 + 1) / 2;
485 const bool boundaries_ok = true && jcp.t_pad <= max_hpad
486 && jcp.b_pad <= max_hpad && jcp.l_pad <= max_wpad
487 && jcp.r_pad <= max_wpad;
489 return status::unimplemented;
// is_cpx: true when native bf16 instructions (avx512_core_bf16) exist.
491 jcp.is_cpx = (mayiuse(avx512_core_bf16)) ? true : false;
493 /* BF16: accumulation of output happens in f32, down-conversion to bf16
494 * happens during the reduction phase. */
495 jcp.typesize_out = sizeof(float);
496 jcp.typesize_in = src_d.data_type() == data_type::bf16
497 ? sizeof(mkldnn_bfloat16_t)
// Decide the g/mb thread partitioning for this problem.
500 balance(jcp, nthreads);
502 return status::success;
// Books the f32 reduction buffers the backward-weights pass needs when
// work is split over minibatch (per-thread partial weights/bias), and a
// single f32 weights buffer for bf16 down-conversion when it is not.
// NOTE(review): continuation lines appear to be dropped from this
// excerpt (the rest of the `mb` ternary at 512-513, the `with_bias` guard
// before 518, and closing braces); verify against the original file.
505 template <cpu_isa_t isa, data_type_t kernel_dt>
506 void jit_uni_dw_conv_bwd_weights_kernel<isa, kernel_dt>::init_scratchpad(
507 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
508 /* Notes: if splitting thread work on 'mb', then a reduction has to take
509 * place. Hence, book a per-thread, local weights-buffer for the
511 if (jcp.nthr_mb > 1) {
512 const size_t mb = jcp.dwei_dt == data_type::bf16 ? jcp.nthr_mb
514 const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
515 scratchpad.book(key_conv_wei_reduction, sizeof(float) * wei_size * mb);
// Bias partials for the nthr_mb - 1 non-primary minibatch threads.
518 scratchpad.book(key_conv_bia_reduction,
519 sizeof(float) * jcp.ngroups * (jcp.nthr_mb - 1));
// Single-thread bf16 case still needs one f32 accumulation buffer.
520 } else if (jcp.nthr_mb == 1 && jcp.dwei_dt == data_type::bf16) {
521 const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
522 scratchpad.book(key_conv_wei_reduction, sizeof(float) * wei_size);
// Splits the thread count across groups first (independent work), then
// across minibatch (requires a weights reduction); the product becomes
// the effective jcp.nthr. NOTE(review): the `jcp.nthr = nthreads;`
// initialization that line 541-544 rely on, and the closing brace, are
// not visible in this excerpt; verify against the original file.
526 template <cpu_isa_t isa, data_type_t kernel_dt>
527 void jit_uni_dw_conv_bwd_weights_kernel<isa, kernel_dt>::balance(
528 jit_conv_conf_t &jcp, int nthreads) {
530 jcp.nthr_g = jcp.nthr_mb = 1;
532 /* Basic-Heuristics for parallel strategy:
533 * 1) Tries to parallel on the number of Groups (g) where tasks are
534 * independent. Otherwise,
535 * 2) Tries to split the work across g and MiniBatch (mb).
536 * Parallelizing on mb requires computing a reduction for weights.
538 * NOTE: because of 'task partitioning' scheme, there will be unbalanced
539 * per-thread load when the number of threads is high (e.g. > 16).
// At most one thread per channel block; leftover threads go to mb.
541 jcp.nthr_g = nstl::min(jcp.nb_ch, jcp.nthr);
542 jcp.nthr_mb = nstl::min(nstl::max(1, jcp.nthr / jcp.nthr_g), jcp.mb);
544 jcp.nthr = jcp.nthr_g * jcp.nthr_mb;
// Explicit instantiations of the backward-weights kernel wrapper.
// NOTE(review): the wrapped continuation lines of the first two
// instantiations (presumably `data_type::bf16>;` and `data_type::f32>;`)
// are not visible in this excerpt; verify against the original file.
547 template struct jit_uni_dw_conv_bwd_weights_kernel<avx512_core,
549 template struct jit_uni_dw_conv_bwd_weights_kernel<avx512_common,
551 template struct jit_uni_dw_conv_bwd_weights_kernel<avx2, data_type::f32>;
552 template struct jit_uni_dw_conv_bwd_weights_kernel<sse42, data_type::f32>;
556 #endif /* JIT_UNI_DW_CONVOLUTION_UTILS_HPP */