updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_dw_conv_kernel_utils.hpp
1
2 /*******************************************************************************
3 * Copyright 2019 Intel Corporation
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
17
18 #ifndef JIT_UNI_DW_CONVOLUTION_UTILS_HPP
19 #define JIT_UNI_DW_CONVOLUTION_UTILS_HPP
20
21 #include "nstl.hpp"
22 #include "type_helpers.hpp"
23 #include "utils.hpp"
24
25 #include "c_types_map.hpp"
26 #include "memory_tracking.hpp"
27
28 #include "jit_generator.hpp"
29 #include "jit_primitive_conf.hpp"
30 #include "jit_uni_eltwise.hpp"
31
32 #include "jit_avx512_core_bf16_dw_conv_kernel.hpp"
33 #include "jit_uni_dw_conv_kernel_f32.hpp"
34
35 namespace mkldnn {
36 namespace impl {
37 namespace cpu {
38
39 using namespace mkldnn::impl::memory_format;
40 using namespace mkldnn::impl::memory_tracking::names;
41 using namespace mkldnn::impl::utils;
42
43 template <cpu_isa_t isa, data_type_t kernel_dt>
44 struct jit_uni_dw_conv_fwd_kernel {
45
46     jit_uni_dw_conv_fwd_kernel(jit_conv_conf_t ajcp, const primitive_attr_t &attr)
47         : jit_ker(nullptr), ker_(nullptr) {
48         ker_ = new jit_kernel_t(ajcp, attr);
49         jit_ker = ker_->jit_ker;
50     }
51     ~jit_uni_dw_conv_fwd_kernel() { delete ker_; }
52
53     static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr, bool is_bf16);
54
55     static status_t init_conf(jit_conv_conf_t &jcp,
56             const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
57             const memory_desc_wrapper &weights_d,
58             const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
59
60     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
61             const jit_conv_conf_t &jcp);
62
63     void (*jit_ker)(jit_conv_call_s *);
64
65 private:
66     using jit_kernel_t = typename utils::conditional<isa == avx512_core
67                     && kernel_dt == data_type::bf16,
68             jit_avx512_dw_conv_fwd_kernel_bf16,
69             jit_uni_dw_conv_fwd_kernel_f32<isa>>::type;
70     jit_kernel_t *ker_;
71 };
72
73 template <cpu_isa_t isa, data_type_t kernel_dt>
74 bool jit_uni_dw_conv_fwd_kernel<isa, kernel_dt>::post_ops_ok(
75         jit_conv_conf_t &jcp, const primitive_attr_t &attr, bool is_bf16) {
76     const auto &p = attr.post_ops_;
77
78     if (is_bf16) {
79         auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
80         auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
81
82         switch (p.len_) {
83             case 0: return true; // no post_ops
84             case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
85             case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
86             default: return false;
87         }
88     } else {
89         auto all_post_ops_supported = [&]() {
90             bool ok = true;
91
92             for (int i = 0; i < p.len_; i++) {
93                 ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise);
94             }
95             return ok;
96         };
97         auto contain = [&](mkldnn::impl::primitive_kind_t kind) { return p.find(kind) != -1; };
98         auto position = [&](mkldnn::impl::primitive_kind_t kind) { return p.find(kind); };
99         auto count = [&](mkldnn::impl::primitive_kind_t kind) { return p.count(kind); };
100
101         return all_post_ops_supported() &&
102                count(primitive_kind::sum) <= 1 &&
103                IMPLICATION(contain(primitive_kind::sum), position(primitive_kind::sum) == 0);
104     }
105
106     return false;
107 }
108
109 template <cpu_isa_t isa, data_type_t kernel_dt>
110 status_t jit_uni_dw_conv_fwd_kernel<isa, kernel_dt>::init_conf(
111         jit_conv_conf_t &jcp, const convolution_desc_t &cd,
112         const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
113         const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) {
114
115     jcp.dst_dt = cd.dst_desc.data_type;
116     const bool is_bf16 = src_d.data_type() == data_type::bf16;
117
118     if (!mayiuse(isa) || (is_bf16 && !mayiuse(avx512_core)))
119         return status::unimplemented;
120
121     const int simd_w = one_of(isa, avx512_common, avx512_core) ? 16 : 8;
122
123     jcp.prop_kind = cd.prop_kind;
124
125     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
126     if (!with_groups) return status::unimplemented;
127
128     jcp.ngroups = weights_d.dims()[0];
129     jcp.mb = src_d.dims()[0];
130
131     jcp.oc = dst_d.dims()[1];
132     jcp.oc_without_padding = jcp.oc;
133     jcp.ic = src_d.dims()[1];
134
135     jcp.ih = src_d.dims()[2];
136     jcp.iw = src_d.dims()[3];
137     jcp.oh = dst_d.dims()[2];
138     jcp.ow = dst_d.dims()[3];
139
140     jcp.kh = weights_d.dims()[3];
141     jcp.kw = weights_d.dims()[4];
142
143     jcp.t_pad = cd.padding[0][0];
144     jcp.l_pad = cd.padding[0][1];
145     jcp.b_pad = cd.padding[1][0];
146     jcp.r_pad = cd.padding[1][1];
147
148     jcp.stride_h = cd.strides[0];
149     jcp.stride_w = cd.strides[1];
150
151     jcp.dilate_h = cd.dilates[0];
152     jcp.dilate_w = cd.dilates[1];
153
154     jcp.src_fmt = src_d.format();
155     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
156
157     if (!post_ops_ok(jcp, attr, is_bf16))
158         return status::unimplemented;
159
160     const auto &p = attr.post_ops_;
161     jcp.with_sum = p.find(primitive_kind::sum) != -1;
162     const int eltwise_ind = p.find(primitive_kind::eltwise);
163     jcp.with_eltwise = eltwise_ind != -1;
164     if (jcp.with_eltwise)
165         jcp.eltwise = p.entry_[eltwise_ind].eltwise;
166
167     bool ok_to_pad_channels = true
168         && jcp.oc == jcp.ngroups
169         && jcp.ic == jcp.ngroups
170         && one_of(isa, avx512_common, avx512_core, avx2);
171     if (ok_to_pad_channels) {
172         jcp.oc = rnd_up(jcp.oc, simd_w);
173         jcp.ic = rnd_up(jcp.oc, simd_w);
174         jcp.ngroups = rnd_up(jcp.ngroups, simd_w);
175     }
176
177     auto desired_act_fmt
178             = one_of(isa, avx512_common, avx512_core) ? nChw16c : nChw8c;
179     auto desired_wei_fmt
180             = one_of(isa, avx512_common, avx512_core) ? Goihw16g : Goihw8g;
181
182     bool args_ok = true
183         && jcp.oc == jcp.ngroups
184         && jcp.ic == jcp.ngroups
185         && jcp.ngroups % simd_w == 0
186         && src_d.format() == desired_act_fmt
187         && weights_d.format() == desired_wei_fmt
188         && one_of(cd.bias_desc.format, memory_format::undef, any, x)
189         && dst_d.format() == desired_act_fmt
190         && jcp.ic <= src_d.blocking_desc().padding_dims[1]
191         && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
192         && jcp.ngroups <= weights_d.blocking_desc().padding_dims[0];
193     if (!args_ok) return status::unimplemented;
194
195     jcp.is_cpx = (mayiuse(avx512_core_bf16)) ? true : false;
196
197     jcp.typesize_out = jcp.dst_dt == data_type::bf16 ? sizeof(mkldnn_bfloat16_t)
198                                                      : sizeof(float);
199     jcp.typesize_in = src_d.data_type() == data_type::bf16
200             ? sizeof(mkldnn_bfloat16_t)
201             : sizeof(float);
202
203     jcp.ur_w = is_bf16 ? (jcp.is_cpx ? 6 : 4)
204                        : isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;
205
206     jcp.ch_block = simd_w;
207     jcp.nb_ch = jcp.oc / jcp.ch_block;
208     jcp.nb_ch_blocking
209             = one_of(isa, avx512_common, avx512_core) ? 4 : isa == avx2 ? 3 : 2;
210     if (jcp.nb_ch < jcp.nb_ch_blocking)
211         jcp.nb_ch_blocking = jcp.nb_ch;
212
213     return status::success;
214 }
215
216 template <cpu_isa_t isa, data_type_t kernel_dt>
217 void jit_uni_dw_conv_fwd_kernel<isa, kernel_dt>::init_scratchpad(
218         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
219     if (jcp.with_bias && jcp.oc_without_padding != jcp.oc)
220         scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
221 }
222
223 template struct jit_uni_dw_conv_fwd_kernel<avx512_core, data_type::bf16>;
224 template struct jit_uni_dw_conv_fwd_kernel<avx512_common, data_type::f32>;
225 template struct jit_uni_dw_conv_fwd_kernel<avx2, data_type::f32>;
226 template struct jit_uni_dw_conv_fwd_kernel<sse42, data_type::f32>;
227
228 template <cpu_isa_t isa, data_type_t kernel_dt>
229 struct jit_uni_dw_conv_bwd_data_kernel {
230
231     jit_uni_dw_conv_bwd_data_kernel(jit_conv_conf_t ajcp)
232         : jit_ker(nullptr), ker_(nullptr) {
233         ker_ = new jit_kernel_t(ajcp);
234         jit_ker = ker_->jit_ker;
235     }
236     ~jit_uni_dw_conv_bwd_data_kernel(){
237         delete ker_;
238     }
239
240     static status_t init_conf(jit_conv_conf_t &jcp,
241             const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
242             const memory_desc_wrapper &weights_d,
243             const memory_desc_wrapper &diff_dst_d);
244
245     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
246             const jit_conv_conf_t &jcp);
247
248     void (*jit_ker)(jit_conv_call_s *);
249
250 private:
251     using jit_kernel_t = typename utils::conditional<isa == avx512_core
252                     && kernel_dt == data_type::bf16,
253             jit_avx512_dw_conv_bwd_data_kernel_bf16,
254             jit_uni_dw_conv_bwd_data_kernel_f32<isa>>::type;
255     jit_kernel_t *ker_;
256 };
257
258 template <cpu_isa_t isa, data_type_t kernel_dt>
259 status_t jit_uni_dw_conv_bwd_data_kernel<isa, kernel_dt>::init_conf(
260         jit_conv_conf_t &jcp, const convolution_desc_t &cd,
261         const memory_desc_wrapper &diff_src_d,
262         const memory_desc_wrapper &weights_d,
263         const memory_desc_wrapper &diff_dst_d) {
264
265     jcp.dsrc_dt = cd.diff_src_desc.data_type;
266     const bool is_bf16 = diff_dst_d.data_type() == data_type::bf16;
267
268     if (!mayiuse(isa) || (is_bf16 && !mayiuse(avx512_core)))
269         return status::unimplemented;
270
271     const int simd_w = one_of(isa, avx512_common, avx512_core) ? 16 : 8;
272
273     const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
274     if (!with_groups) return status::unimplemented;
275
276     jcp.ngroups = weights_d.dims()[0];
277     jcp.mb = diff_src_d.dims()[0];
278
279     jcp.oc = diff_dst_d.dims()[1];
280     jcp.oc_without_padding = jcp.oc;
281     jcp.ic = diff_src_d.dims()[1];
282
283     jcp.ih = diff_src_d.dims()[2];
284     jcp.iw = diff_src_d.dims()[3];
285     jcp.oh = diff_dst_d.dims()[2];
286     jcp.ow = diff_dst_d.dims()[3];
287
288     jcp.kh = weights_d.dims()[3];
289     jcp.kw = weights_d.dims()[4];
290
291     jcp.t_pad = cd.padding[0][0];
292     jcp.l_pad = cd.padding[0][1];
293     jcp.b_pad = cd.padding[1][0];
294     jcp.r_pad = cd.padding[1][1];
295
296     jcp.stride_h = cd.strides[0];
297     jcp.stride_w = cd.strides[1];
298
299     jcp.dilate_h = cd.dilates[0];
300     jcp.dilate_w = cd.dilates[1];
301
302     jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
303     jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
304
305     jcp.src_fmt = diff_src_d.format();
306
307     bool ok_to_pad_channels = true
308         && jcp.oc == jcp.ngroups
309         && jcp.ic == jcp.ngroups
310         && one_of(isa, avx512_common, avx512_core, avx2, sse42);
311     if (ok_to_pad_channels) {
312         jcp.oc = rnd_up(jcp.oc, simd_w);
313         jcp.ic = rnd_up(jcp.oc, simd_w);
314         jcp.ngroups = rnd_up(jcp.ngroups, simd_w);
315     }
316
317     auto desired_act_fmt
318             = one_of(isa, avx512_common, avx512_core) ? nChw16c : nChw8c;
319     auto desired_wei_fmt
320             = one_of(isa, avx512_common, avx512_core) ? Goihw16g : Goihw8g;
321
322     bool args_ok = true
323         && jcp.oc == jcp.ngroups
324         && jcp.ic == jcp.ngroups
325         && jcp.ngroups % simd_w == 0
326         && jcp.dilate_h == 0
327         && jcp.dilate_w == 0
328         && diff_src_d.format() == desired_act_fmt
329         && weights_d.format() == desired_wei_fmt
330         && diff_dst_d.format() == desired_act_fmt
331         && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
332         && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1
333         && jcp.ic <= diff_src_d.blocking_desc().padding_dims[1]
334         && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
335         && jcp.ngroups <= weights_d.blocking_desc().padding_dims[0];
336     if (!args_ok) return status::unimplemented;
337
338     jcp.is_cpx = (mayiuse(avx512_core_bf16)) ? true : false;
339
340     jcp.typesize_out = diff_src_d.data_type() == data_type::bf16
341             ? sizeof(mkldnn_bfloat16_t)
342             : sizeof(float);
343     jcp.typesize_in = diff_dst_d.data_type() == data_type::bf16
344             ? sizeof(mkldnn_bfloat16_t)
345             : sizeof(float);
346
347     jcp.ur_w = is_bf16 ? (jcp.is_cpx ? 6 : 4)
348                        : isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;
349
350     jcp.ch_block = simd_w;
351     jcp.nb_ch = jcp.ic / jcp.ch_block;
352     jcp.nb_ch_blocking
353             = one_of(isa, avx512_common, avx512_core) ? 4 : isa == avx2 ? 3 : 2;
354     if (jcp.nb_ch < jcp.nb_ch_blocking)
355         jcp.nb_ch_blocking = jcp.nb_ch;
356
357     return status::success;
358 }
359
360 template <cpu_isa_t isa, data_type_t kernel_dt>
361 void jit_uni_dw_conv_bwd_data_kernel<isa, kernel_dt>::init_scratchpad(
362         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
363     UNUSED(scratchpad);
364     UNUSED(jcp);
365 }
366
367 template struct jit_uni_dw_conv_bwd_data_kernel<avx512_core, data_type::bf16>;
368 template struct jit_uni_dw_conv_bwd_data_kernel<avx512_common, data_type::f32>;
369 template struct jit_uni_dw_conv_bwd_data_kernel<avx2, data_type::f32>;
370 template struct jit_uni_dw_conv_bwd_data_kernel<sse42, data_type::f32>;
371
372 template <cpu_isa_t isa, data_type_t kernel_dt>
373 struct jit_uni_dw_conv_bwd_weights_kernel {
374
375     jit_uni_dw_conv_bwd_weights_kernel(jit_conv_conf_t ajcp)
376         : jit_ker(nullptr), ker_(nullptr) {
377         ker_ = new jit_kernel_t(ajcp);
378         jit_ker = ker_->jit_ker;
379     }
380
381     ~jit_uni_dw_conv_bwd_weights_kernel() { delete ker_; }
382
383     static status_t init_conf(jit_conv_conf_t &jcp,
384             const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
385             const memory_desc_wrapper &diff_weights_d,
386             const memory_desc_wrapper &diff_dst_d, int nthreads);
387
388     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
389             const jit_conv_conf_t &jcp);
390
391     static void balance(jit_conv_conf_t &jcp, int nthreads);
392
393     void (*jit_ker)(jit_dw_conv_call_s *);
394
395 private:
396     using jit_kernel_t = typename utils::conditional<isa == avx512_core
397                     && kernel_dt == data_type::bf16,
398             jit_avx512_dw_conv_bwd_weights_kernel_bf16,
399             jit_uni_dw_conv_bwd_weights_kernel_f32<isa>>::type;
400     jit_kernel_t *ker_;
401 };
402
403 template <cpu_isa_t isa, data_type_t kernel_dt>
404 status_t jit_uni_dw_conv_bwd_weights_kernel<isa, kernel_dt>::init_conf(
405         jit_conv_conf_t &jcp, const convolution_desc_t &cd,
406         const memory_desc_wrapper &src_d,
407         const memory_desc_wrapper &diff_weights_d,
408         const memory_desc_wrapper &diff_dst_d, int nthreads) {
409
410     jcp.dwei_dt = cd.diff_weights_desc.data_type;
411     const bool is_bf16 = src_d.data_type() == data_type::bf16;
412
413     if (!mayiuse(isa) || (is_bf16 && !mayiuse(avx512_core)))
414         return status::unimplemented;
415
416     jcp.ngroups = diff_weights_d.dims()[0];
417     jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
418     jcp.ic = src_d.dims()[1] / jcp.ngroups;
419
420     const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
421
422     jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.oc, jcp.ic);
423
424     if (!jcp.is_depthwise)
425         return status::unimplemented;
426
427     jcp.ch_block = one_of(isa, avx512_common, avx512_core) ? 16 : 8;
428
429     jcp.mb = src_d.dims()[0];
430
431     jcp.ih = src_d.dims()[2];
432     jcp.iw = src_d.dims()[3];
433     jcp.oh = diff_dst_d.dims()[2];
434     jcp.ow = diff_dst_d.dims()[3];
435
436     jcp.kh = diff_weights_d.dims()[3];
437     jcp.kw = diff_weights_d.dims()[4];
438
439     jcp.stride_h = cd.strides[0];
440     jcp.stride_w = cd.strides[1];
441
442     jcp.t_pad = cd.padding[0][0];
443     jcp.b_pad = cd.padding[1][0];
444
445     jcp.l_pad = cd.padding[0][1];
446     jcp.r_pad = cd.padding[1][1];
447
448     jcp.dilate_h = cd.dilates[0];
449     jcp.dilate_w = cd.dilates[1];
450
451     jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
452     jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
453
454     jcp.src_fmt = src_d.format();
455
456     jcp.with_bias = cd.diff_bias_desc.format != memory_format::undef;
457
458     auto desired_act_fmt
459             = one_of(isa, avx512_common, avx512_core) ? nChw16c : nChw8c;
460     auto desired_wei_fmt
461             = one_of(isa, avx512_common, avx512_core) ? Goihw16g : Goihw8g;
462
463     bool args_ok = true && src_d.format() == desired_act_fmt
464             && diff_weights_d.format() == desired_wei_fmt
465             && diff_dst_d.format() == desired_act_fmt
466             && one_of(cd.bias_desc.format, memory_format::undef, any, x)
467             && jcp.ngroups % jcp.ch_block == 0 && jcp.dilate_h == 0
468             && jcp.dilate_w == 0 && jcp.kw <= 3
469             && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
470             && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
471     if (!args_ok)
472         return status::unimplemented;
473
474     if (!IMPLICATION(is_bf16, desired_act_fmt == mkldnn_nChw16c
475                         && desired_wei_fmt == mkldnn_Goihw16g))
476         return status::unimplemented;
477
478     jcp.nb_ch = jcp.ngroups / jcp.ch_block;
479
480     /* kernel applicability check wrt boundaries
481      * the conditions are quite general across the kernels we have,
482      * but ideally the check should belong to a specific kernel... */
483     const int max_hpad = (jcp.kh - 1 + 1) / 2;
484     const int max_wpad = (jcp.kw - 1 + 1) / 2;
485     const bool boundaries_ok = true && jcp.t_pad <= max_hpad
486             && jcp.b_pad <= max_hpad && jcp.l_pad <= max_wpad
487             && jcp.r_pad <= max_wpad;
488     if (!boundaries_ok)
489         return status::unimplemented;
490
491     jcp.is_cpx = (mayiuse(avx512_core_bf16)) ? true : false;
492
493     /* BF16: accumulation of output happens in f32, down-conversion to bf16
494      * happens during the reduction phase. */
495     jcp.typesize_out = sizeof(float);
496     jcp.typesize_in = src_d.data_type() == data_type::bf16
497             ? sizeof(mkldnn_bfloat16_t)
498             : sizeof(float);
499
500     balance(jcp, nthreads);
501
502     return status::success;
503 }
504
505 template <cpu_isa_t isa, data_type_t kernel_dt>
506 void jit_uni_dw_conv_bwd_weights_kernel<isa, kernel_dt>::init_scratchpad(
507         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
508     /* Notes: if splitting thread work on 'mb', then a reduction has to take
509      * place. Hence, book a per-thread, local weights-buffer for the
510      * reduction */
511     if (jcp.nthr_mb > 1) {
512         const size_t mb = jcp.dwei_dt == data_type::bf16 ? jcp.nthr_mb
513                                                            : jcp.nthr_mb - 1;
514         const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
515         scratchpad.book(key_conv_wei_reduction, sizeof(float) * wei_size * mb);
516
517         if (jcp.with_bias)
518             scratchpad.book(key_conv_bia_reduction,
519                     sizeof(float) * jcp.ngroups * (jcp.nthr_mb - 1));
520     } else if (jcp.nthr_mb == 1 && jcp.dwei_dt == data_type::bf16) {
521         const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
522         scratchpad.book(key_conv_wei_reduction, sizeof(float) * wei_size);
523     }
524 }
525
526 template <cpu_isa_t isa, data_type_t kernel_dt>
527 void jit_uni_dw_conv_bwd_weights_kernel<isa, kernel_dt>::balance(
528         jit_conv_conf_t &jcp, int nthreads) {
529     jcp.nthr = nthreads;
530     jcp.nthr_g = jcp.nthr_mb = 1;
531
532     /* Basic-Heuristics for parallel strategy:
533      * 1) Tries to parallel on the number of Groups (g) where tasks are
534      * independent. Otherwise,
535      * 2) Tries to split the work across g and MiniBatch (mb).
536      * Parallelizing on mb requires computing a reduction for weights.
537      *
538      * NOTE: because of 'task partitioning' scheme, there will be unbalanced
539      * per-thread load when the number of threads is high (e.g. > 16).
540      */
541     jcp.nthr_g = nstl::min(jcp.nb_ch, jcp.nthr);
542     jcp.nthr_mb = nstl::min(nstl::max(1, jcp.nthr / jcp.nthr_g), jcp.mb);
543
544     jcp.nthr = jcp.nthr_g * jcp.nthr_mb;
545 }
546
547 template struct jit_uni_dw_conv_bwd_weights_kernel<avx512_core,
548         data_type::bf16>;
549 template struct jit_uni_dw_conv_bwd_weights_kernel<avx512_common,
550         data_type::f32>;
551 template struct jit_uni_dw_conv_bwd_weights_kernel<avx2, data_type::f32>;
552 template struct jit_uni_dw_conv_bwd_weights_kernel<sse42, data_type::f32>;
553 }
554 }
555 }
556 #endif /* JIT_UNI_DW_CONVOLUTION_UTILS_HPP */