/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "c_types_map.hpp"
#include "memory_tracking.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "cpu_memory.hpp"

#include "jit_uni_dw_conv_kernel_f32.hpp"

#define GET_OFF(field) offsetof(jit_conv_call_s, field)

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;

using namespace Xbyak;

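/* Fill the accumulator registers for an ur_ch_blocks x ur_w output tile:
 * start from the bias when present, otherwise from zero, and pre-add the
 * existing destination values when a sum post-op is fused. On sse42 every
 * 8-wide channel block is processed as two 4-float halves, hence the
 * 'repeats' factor. */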
template <cpu_isa_t isa>
void jit_uni_dw_conv_fwd_kernel_f32<isa>::load_src(int ur_ch_blocks, int ur_w) {
    int repeats = isa == sse42 ? 2 : 1;
    for (int i = 0; i < repeats; i++) {
        for (int ch = 0; ch < ur_ch_blocks; ch++) {
            for (int ow = 0; ow < ur_w; ow++) {
                Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);

                int b_off = ch*jcp.ch_block + i*4;
                if (this->jcp.with_bias)
                    uni_vmovups(vmm_acc,
                            vmmword[reg_bias + b_off*sizeof(float)]);
                else
                    uni_vpxor(vmm_acc, vmm_acc, vmm_acc);

                int o_off = ch*jcp.oh*jcp.ow*jcp.ch_block
                        + ow*jcp.ch_block + i*4;
                if (this->jcp.with_sum)
                    uni_vaddps(vmm_acc, vmm_acc,
                            vmmword[reg_output + o_off*sizeof(float)]);
            }
        }
    }
}

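/* Generic inner loop used for the tail of the output row: iterates over the
 * (possibly partially padded-away) kh x kw filter points with run-time
 * bounds in reg_kh/reg_kw, accumulating one FMA per filter point and output
 * pixel. */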
template <cpu_isa_t isa>
void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter(
        int ur_ch_blocks, int ur_w) {
    int ch_blk = jcp.ch_block;
    int dilate_h = jcp.dilate_h + 1;
    int dilate_w = jcp.dilate_w + 1;
    int stride_w = jcp.stride_w;

    Label iter_exit_label;

    cmp(reg_kh, 0);
    je(iter_exit_label, T_NEAR);
    cmp(reg_kw, 0);
    je(iter_exit_label, T_NEAR);

    mov(iter_kh, reg_kh);
    Label kh_label;
    L(kh_label); {
        mov(aux1_reg_input, aux_reg_input);
        mov(aux1_reg_kernel, aux_reg_kernel);

        mov(iter_kw, reg_kw);
        Label kw_label;
        L(kw_label); {
            int repeats = isa == sse42 ? 2 : 1;
            for (int i = 0; i < repeats; i++) {
                for (int ch = 0; ch < ur_ch_blocks; ch++) {
                    int ker_off = ch*jcp.kh*jcp.kw*ch_blk + i*4;
                    Vmm vmm_ker = get_ker_reg(0);
                    uni_vmovups(vmm_ker, ptr[aux1_reg_kernel
                            + ker_off*sizeof(float)]);

                    for (int ow = 0; ow < ur_w; ow++) {
                        int inp_off = ch*jcp.ih*jcp.iw*ch_blk
                                + ow*stride_w*ch_blk + i*4;
                        Vmm vmm_src = get_src_reg(0);
                        uni_vmovups(vmm_src, ptr[aux1_reg_input
                                + inp_off*sizeof(float)]);

                        Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w
                                + ch*ur_w + ow);
                        uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
                    }
                }
            }
            add(aux1_reg_kernel, ch_blk*sizeof(float));
            add(aux1_reg_input, ch_blk*dilate_w*sizeof(float));

            dec(iter_kw);
            cmp(iter_kw, 0);
            jg(kw_label, T_NEAR);
        }
        add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
        add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float));

        dec(iter_kh);
        cmp(iter_kh, 0);
        jg(kh_label, T_NEAR);
    }

    L(iter_exit_label);
}

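/* Fully unrolled variant used for the main part of the output row: the kw
 * loop is unrolled at code-generation time, so only the kh loop remains in
 * the generated code. */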
template <cpu_isa_t isa>
void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter_unrolled(
        int ur_ch_blocks, int ur_w) {
    int ch_blk = jcp.ch_block;
    int dilate_h = jcp.dilate_h + 1;
    int dilate_w = jcp.dilate_w + 1;
    int stride_w = jcp.stride_w;

    Label iter_exit_label;

    cmp(reg_kh, 0);
    je(iter_exit_label, T_NEAR);

    mov(iter_kh, reg_kh);
    Label kh_label;
    L(kh_label); {
        int repeats = isa == sse42 ? 2 : 1;
        for (int i = 0; i < repeats; i++) {
            for (int ch = 0; ch < ur_ch_blocks; ch++) {
                for (int kw = 0; kw < jcp.kw; kw++) {
                    int ker_off = ch*jcp.kh*jcp.kw*ch_blk + kw*ch_blk + i*4;

                    Vmm vmm_ker = get_ker_reg(0);
                    uni_vmovups(vmm_ker, ptr[aux_reg_kernel
                            + ker_off*sizeof(float)]);

                    for (int ow = 0; ow < ur_w; ow++) {
                        int inp_off = ch*jcp.ih*jcp.iw*ch_blk
                                + ow*stride_w*ch_blk + kw*ch_blk*dilate_w + i*4;

                        Vmm vmm_src = get_src_reg(0);
                        uni_vmovups(vmm_src, ptr[aux_reg_input
                                + inp_off*sizeof(float)]);

                        Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w
                                + ch*ur_w + ow);
                        uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
                    }
                }
            }
        }

        add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
        add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float));

        dec(iter_kh);
        cmp(iter_kh, 0);
        jg(kh_label, T_NEAR);
    }

    L(iter_exit_label);
}

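/* Apply the fused post-ops chain to the accumulators. Eltwise post-ops are
 * delegated to the eltwise injectors over the whole accumulator range;
 * depthwise post-ops load their per-channel weights/bias through
 * reg_d_weights/reg_d_bias, offset by oc_off from the call arguments. */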
template <cpu_isa_t isa>
void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_postprocess(int ur_ch_blocks, int ur_w) {
    int repeats = isa == sse42 ? 2 : 1;

    int eltwise_inj_idx = 0;
    int depthwise_inj_idx = 0;
    const auto &p = attr_.post_ops_;

    for (int i = 0; i < p.len_; i++) {
        auto& post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            int start_idx = get_acc_reg(0).getIdx();
            int end_idx = get_acc_reg(repeats * ur_w * ur_ch_blocks).getIdx();

            eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, end_idx);
            eltwise_inj_idx++;
        } else if (post_op.is_depthwise()) {
            mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
            mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));

            add(reg_d_weights, ptr[this->param1 + GET_OFF(oc_off)]);
            add(reg_d_bias, ptr[this->param1 + GET_OFF(oc_off)]);

            for (int ch = 0; ch < ur_ch_blocks; ch++) {
                for (int k = 0; k < repeats; k++) {
                    int start_idx = get_acc_reg(k*ur_ch_blocks*ur_w + ur_w * ch).getIdx();
                    int end_idx = get_acc_reg(k*ur_ch_blocks*ur_w + ur_w * ch + ur_w).getIdx();

                    depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
                            start_idx, end_idx, reg_d_weights, reg_d_bias);

                    add(reg_d_weights, jcp.ch_block / repeats * sizeof(float));
                    add(reg_d_bias, jcp.ch_block / repeats * sizeof(float));
                }
            }
            depthwise_inj_idx++;
        }
    }
}

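/* Write the accumulators back to the destination tensor (nChw{8,16}c
 * layout). */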
template <cpu_isa_t isa>
void jit_uni_dw_conv_fwd_kernel_f32<isa>::store_dst(
        int ur_ch_blocks, int ur_w) {
    int ch_blk = jcp.ch_block;

    int repeats = isa == sse42 ? 2 : 1;
    for (int i = 0; i < repeats; i++) {
        for (int ch = 0; ch < ur_ch_blocks; ch++) {
            for (int ow = 0; ow < ur_w; ow++) {
                int o_off = ch*jcp.oh*jcp.ow*ch_blk + ow*ch_blk + i*4;
                Vmm vmm_dst = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);

                uni_vmovups(vmmword[reg_output + o_off*sizeof(float)], vmm_dst);
            }
        }
    }
}

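/* Outer loop over the output width for a fixed number of channel blocks:
 * process jcp.ur_w pixels at a time with the unrolled kernel, then fall
 * through to a one-pixel-at-a-time tail loop. */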
template <cpu_isa_t isa>
void jit_uni_dw_conv_fwd_kernel_f32<isa>::loop_body(int ur_ch_blocks) {
    Label unrolled_w_label;
    Label tail_w_label;
    Label exit_label;

    L(unrolled_w_label); {
        int ur_w = jcp.ur_w;

        cmp(reg_ur_w, ur_w);
        jl(tail_w_label, T_NEAR);

        mov(aux_reg_input, reg_input);
        mov(aux_reg_kernel, reg_kernel);

        load_src(ur_ch_blocks, ur_w);
        apply_filter_unrolled(ur_ch_blocks, ur_w);
        apply_postprocess(ur_ch_blocks, ur_w);
        store_dst(ur_ch_blocks, ur_w);

        add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
        add(reg_output, sizeof(float) * ur_w * jcp.ch_block);

        sub(reg_ur_w, ur_w);
        jmp(unrolled_w_label);
    }

    L(tail_w_label); {
        int ur_w = 1;

        cmp(reg_ur_w, ur_w);
        jl(exit_label, T_NEAR);

        mov(aux_reg_input, reg_input);
        mov(aux_reg_kernel, reg_kernel);

        load_src(ur_ch_blocks, ur_w);
        apply_filter(ur_ch_blocks, ur_w);
        apply_postprocess(ur_ch_blocks, ur_w);
        store_dst(ur_ch_blocks, ur_w);

        add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
        add(reg_output, sizeof(float) * ur_w * jcp.ch_block);

        sub(reg_ur_w, ur_w);
        jmp(tail_w_label);
    }

    L(exit_label);
}

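/* Kernel entry point: create the post-op injectors, load the call arguments,
 * and emit the main and tail channel loops. */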
template <cpu_isa_t isa>
void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate() {
    const auto &p = attr_.post_ops_;
    for (int i = 0; i < p.len_; i++) {
        auto &post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
                    this,
                    post_op.eltwise.alg,
                    post_op.eltwise.alpha,
                    post_op.eltwise.beta
            ));
        } else if (post_op.is_depthwise()) {
            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
                    this,
                    post_op.depthwise.alg
            ));
        }
    }

    this->preamble();

    mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
    mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
    if (jcp.with_bias)
        mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
    mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
    mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]);
    mov(reg_ur_w, ptr[this->param1 + GET_OFF(ur_w)]);

    Label ch_blocks_tail_label;
    Label exit_label;

    int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;

    cmp(reg_ch_blocks, jcp.nb_ch_blocking);
    jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR);

    loop_body(jcp.nb_ch_blocking); // channel main loop

    if (ch_blocks_tail) {
        L(ch_blocks_tail_label);

        cmp(reg_ch_blocks, ch_blocks_tail);
        jne(exit_label, T_NEAR);

        loop_body(ch_blocks_tail); // channel tail loop
    }

    L(exit_label);

    this->postamble();

    for (auto& inj : eltwise_injectors)
        inj->prepare_table();
}

template <cpu_isa_t isa>
bool jit_uni_dw_conv_fwd_kernel_f32<isa>::post_ops_ok(
        jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
    const auto &p = attr.post_ops_;

    auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
    auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };

    switch (p.len_) {
    case 0: return true; // no post-ops
    case 1: return is_simple(0) || is_sum(0);
    case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_simple(1));
    case 3: return is_sum(0) && is_simple(1) && is_simple(2);
    default: return false;
    }

    return false;
}

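/* Parse the convolution descriptor and memory formats into jcp and verify
 * that this kernel can handle the case. */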
template <cpu_isa_t isa>
status_t jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
        const primitive_attr_t &attr)
{
    if (!mayiuse(isa)) return status::unimplemented;

    const int simd_w = isa == avx512_common ? 16 : 8;

    jcp.prop_kind = cd.prop_kind;

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    if (!with_groups) return status::unimplemented;

    jcp.ngroups = weights_d.dims()[0];
    jcp.mb = src_d.dims()[0];

    jcp.oc = dst_d.dims()[1];
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1];

    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = dst_d.dims()[2];
    jcp.ow = dst_d.dims()[3];

    jcp.kh = weights_d.dims()[3];
    jcp.kw = weights_d.dims()[4];

    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];
    jcp.b_pad = cd.padding[1][0];
    jcp.r_pad = cd.padding[1][1];

    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];

    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];

    jcp.src_fmt = src_d.format();
    jcp.with_bias = cd.bias_desc.format != memory_format::undef;

    if (!post_ops_ok(jcp, attr))
        return status::unimplemented;

    const auto &p = attr.post_ops_;
    jcp.with_sum = p.find(primitive_kind::sum) != -1;

    bool ok_to_pad_channels = true
        && jcp.oc == jcp.ngroups
        && jcp.ic == jcp.ngroups
        && one_of(isa, avx512_common, avx2, sse42);
    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        jcp.ic = rnd_up(jcp.ic, simd_w);
        jcp.ngroups = rnd_up(jcp.ngroups, simd_w);
    }

    auto desired_act_fmt = isa == avx512_common ? nChw16c : nChw8c;
    auto desired_wei_fmt = isa == avx512_common ? Goihw16g : Goihw8g;

    bool args_ok = true
        && jcp.oc == jcp.ngroups
        && jcp.ic == jcp.ngroups
        && jcp.ngroups % simd_w == 0
        && src_d.format() == desired_act_fmt
        && weights_d.format() == desired_wei_fmt
        && one_of(cd.bias_desc.format, memory_format::undef, any, x)
        && dst_d.format() == desired_act_fmt
        && jcp.ic <= src_d.blocking_desc().padding_dims[1]
        && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
        && jcp.ngroups <= weights_d.blocking_desc().padding_dims[0];
    if (!args_ok) return status::unimplemented;

    jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;

    jcp.ch_block = simd_w;
    jcp.nb_ch = jcp.oc / jcp.ch_block;
    jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2;
    if (jcp.nb_ch < jcp.nb_ch_blocking)
        jcp.nb_ch_blocking = jcp.nb_ch;

    return status::success;
}

template <cpu_isa_t isa>
void jit_uni_dw_conv_fwd_kernel_f32<isa>::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
    if (jcp.with_bias && jcp.oc_without_padding != jcp.oc)
        scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
}

template struct jit_uni_dw_conv_fwd_kernel_f32<avx512_common>;
template struct jit_uni_dw_conv_fwd_kernel_f32<avx2>;
template struct jit_uni_dw_conv_fwd_kernel_f32<sse42>;

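/* Backward-data kernel: computes diff_src from diff_dst and the filter.
 * Zero-initialize the accumulators for an ur_ch_blocks x ur_str_w tile of
 * diff_src. */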
template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::load_ddst(
        int ur_ch_blocks, int ur_str_w) {
    int repeats = isa == sse42 ? 2 : 1;
    for (int i = 0; i < repeats; i++) {
        for (int ch = 0; ch < ur_ch_blocks; ch++) {
            for (int w = 0; w < ur_str_w; w++) {
                Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
                        + ch*ur_str_w + w);
                uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
            }
        }
    }
}

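/* Accumulate the transposed convolution: walk the filter in steps of the
 * stride (run-time bounds in reg_kh/reg_kw), multiplying diff_dst vectors by
 * filter vectors. Note that the kernel pointer advances while the diff_dst
 * pointer walks backwards. */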
template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::apply_filter(
        int ur_ch_blocks, int ur_str_w) {
    int kw = jcp.kw;
    int kh = jcp.kh;
    int ow = jcp.ow;
    int oh = jcp.oh;

    int ch_blk = jcp.ch_block;
    int stride_h = jcp.stride_h;
    int stride_w = jcp.stride_w;

    Label iter_exit_label;

    cmp(reg_kh, 0);
    je(iter_exit_label, T_NEAR);
    cmp(reg_kw, 0);
    je(iter_exit_label, T_NEAR);

    mov(iter_kh, reg_kh);
    Label kh_label;
    L(kh_label); {
        mov(aux1_reg_ddst, aux_reg_ddst);
        mov(aux1_reg_kernel, aux_reg_kernel);

        mov(iter_kw, reg_kw);
        Label kw_label;
        L(kw_label); {
            int repeats = isa == sse42 ? 2 : 1;
            for (int i = 0; i < repeats; i++) {
                for (int ch = 0; ch < ur_ch_blocks; ch++) {
                    int ker_off = ch*kh*kw*ch_blk + i*4;
                    Vmm vmm_ker = get_ker_reg(0);
                    uni_vmovups(vmm_ker, ptr[aux1_reg_kernel
                            + ker_off*sizeof(float)]);

                    for (int w = 0; w < ur_str_w; w++) {
                        int ddst_off = (ch*oh*ow + w)*ch_blk + i*4;

                        Vmm vmm_src = get_src_reg(0);
                        uni_vmovups(vmm_src, ptr[aux1_reg_ddst
                                + ddst_off*sizeof(float)]);

                        Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
                                + ch*ur_str_w + w);
                        uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
                    }
                }
            }

            add(aux1_reg_kernel, ch_blk*stride_w*sizeof(float));
            sub(aux1_reg_ddst, ch_blk*sizeof(float));

            sub(iter_kw, stride_w);
            cmp(iter_kw, 0);
            jg(kw_label, T_NEAR);
        }

        add(aux_reg_kernel, kw*ch_blk*stride_h*sizeof(float));
        sub(aux_reg_ddst, ow*ch_blk*sizeof(float));

        sub(iter_kh, stride_h);
        cmp(iter_kh, 0);
        jg(kh_label, T_NEAR);
    }

    L(iter_exit_label);
}

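/* Scatter the accumulated diff_src tile back to memory with the width
 * stride. */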
template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::store_dsrc(
        int ur_ch_blocks, int ur_str_w) {
    int ch_blk = jcp.ch_block;
    int iw = jcp.iw;
    int ih = jcp.ih;
    int stride_w = jcp.stride_w;

    int repeats = isa == sse42 ? 2 : 1;
    for (int i = 0; i < repeats; i++) {
        for (int ch = 0; ch < ur_ch_blocks; ch++) {
            for (int w = 0; w < ur_str_w; w++) {
                int dsrc_off = (ch*ih*iw + w*stride_w)*ch_blk + i*4;
                Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
                        + ch*ur_str_w + w);

                uni_vmovups(ptr[reg_dsrc + dsrc_off*sizeof(float)], vmm_acc);
            }
        }
    }
}

template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::loop_body(
        int ur_ch_blocks) {
    Label unrolled_w_label;
    Label tail_w_label;
    Label exit_label;

    L(unrolled_w_label); {
        int ur_w = jcp.ur_w;

        cmp(reg_ur_str_w, ur_w);
        jl(tail_w_label, T_NEAR);

        mov(aux_reg_ddst, reg_ddst);
        mov(aux_reg_kernel, reg_kernel);

        load_ddst(ur_ch_blocks, ur_w);
        apply_filter(ur_ch_blocks, ur_w);
        store_dsrc(ur_ch_blocks, ur_w);

        add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
        add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block);

        sub(reg_ur_str_w, ur_w);
        jmp(unrolled_w_label);
    }

    L(tail_w_label); {
        int ur_w = 1;

        cmp(reg_ur_str_w, ur_w);
        jl(exit_label, T_NEAR);

        mov(aux_reg_ddst, reg_ddst);
        mov(aux_reg_kernel, reg_kernel);

        load_ddst(ur_ch_blocks, ur_w);
        apply_filter(ur_ch_blocks, ur_w);
        store_dsrc(ur_ch_blocks, ur_w);

        add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
        add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block);

        sub(reg_ur_str_w, ur_w);
        jmp(tail_w_label);
    }

    L(exit_label);
}

625 void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::generate() {
628 mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
629 mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
630 mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
631 mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
632 mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
633 mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]);
634 mov(reg_ur_str_w, ptr[this->param1 + GET_OFF(ur_str_w)]);
636 Label ch_blocks_tail_label;
639 int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
641 cmp(reg_ch_blocks, jcp.nb_ch_blocking);
642 jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR);
644 loop_body(jcp.nb_ch_blocking); // channel main loop
646 if (ch_blocks_tail) {
647 L(ch_blocks_tail_label);
649 cmp(reg_ch_blocks, ch_blocks_tail);
650 jne(exit_label, T_NEAR);
652 loop_body(ch_blocks_tail); // channel tail loop
template <cpu_isa_t isa>
status_t jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_conf(
        jit_conv_conf_t &jcp, const convolution_desc_t &cd,
        const memory_desc_wrapper &diff_src_d,
        const memory_desc_wrapper &weights_d,
        const memory_desc_wrapper &diff_dst_d) {
    if (!mayiuse(isa)) return status::unimplemented;

    const int simd_w = isa == avx512_common ? 16 : 8;

    const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
    if (!with_groups) return status::unimplemented;

    jcp.ngroups = weights_d.dims()[0];
    jcp.mb = diff_src_d.dims()[0];

    jcp.oc = diff_dst_d.dims()[1];
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = diff_src_d.dims()[1];

    jcp.ih = diff_src_d.dims()[2];
    jcp.iw = diff_src_d.dims()[3];
    jcp.oh = diff_dst_d.dims()[2];
    jcp.ow = diff_dst_d.dims()[3];

    jcp.kh = weights_d.dims()[3];
    jcp.kw = weights_d.dims()[4];

    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];
    jcp.b_pad = cd.padding[1][0];
    jcp.r_pad = cd.padding[1][1];

    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];

    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];

    jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
    jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;

    jcp.src_fmt = diff_src_d.format();

    bool ok_to_pad_channels = true
        && jcp.oc == jcp.ngroups
        && jcp.ic == jcp.ngroups
        && one_of(isa, avx512_common, avx2);
    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        jcp.ic = rnd_up(jcp.ic, simd_w);
        jcp.ngroups = rnd_up(jcp.ngroups, simd_w);
    }

    auto desired_act_fmt = isa == avx512_common ? nChw16c : nChw8c;
    auto desired_wei_fmt = isa == avx512_common ? Goihw16g : Goihw8g;

    bool args_ok = true
        && jcp.oc == jcp.ngroups
        && jcp.ic == jcp.ngroups
        && jcp.ngroups % simd_w == 0
        && jcp.dilate_h == 0
        && jcp.dilate_w == 0
        && diff_src_d.format() == desired_act_fmt
        && weights_d.format() == desired_wei_fmt
        && diff_dst_d.format() == desired_act_fmt
        && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
        && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1
        && jcp.ic <= diff_src_d.blocking_desc().padding_dims[1]
        && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
        && jcp.ngroups <= weights_d.blocking_desc().padding_dims[0];
    if (!args_ok) return status::unimplemented;

    jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;

    jcp.ch_block = simd_w;
    jcp.nb_ch = jcp.ic / jcp.ch_block;
    jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2;
    if (jcp.nb_ch < jcp.nb_ch_blocking)
        jcp.nb_ch_blocking = jcp.nb_ch;

    return status::success;
}

template <cpu_isa_t isa>
void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
    /* no scratchpad is required for the backward-data kernel */
    UNUSED(scratchpad);
    UNUSED(jcp);
}

template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx512_common>;
template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx2>;
template struct jit_uni_dw_conv_bwd_data_kernel_f32<sse42>;

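/* Backward-weights kernel: accumulates diff_weights (and optionally
 * diff_bias) over a block of output rows. 'reg_repeats' plays the same role
 * as 'repeats' above: two 4-float passes per 8-wide channel block on sse42.
 * Zero the in-register filter accumulators for one filter row. */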
template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_filter() {
    for (int r = 0; r < reg_repeats; ++r) {
        for (int i = 0; i < jcp.kw; ++i) {
            Vmm vmm_acc = get_acc_reg(r * jcp.kw + i);
            uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
        }
    }
}

template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_filter() {
    for (int r = 0; r < reg_repeats; ++r) {
        const int reg_set = r * jcp.kw;
        for (int i = 0; i < jcp.kw; ++i) {
            int off_filter = (reg_set + i) * simd_w;
            Vmm vmm_acc = get_acc_reg(reg_set + i);
            uni_vmovups(vmm_acc,
                    vmmword[reg_tmp_filter + off_filter * sizeof(float)]);
        }
    }
}

template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_bias() {
    for (int r = 0; r < reg_repeats; ++r) {
        Vmm vmm_bias = get_bias_reg(r);
        uni_vpxor(vmm_bias, vmm_bias, vmm_bias);
    }
}

template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_bias() {
    for (int r = 0; r < reg_repeats; ++r) {
        Vmm vmm_bias = get_bias_reg(r);
        uni_vmovups(
                vmm_bias, vmmword[reg_bias_baddr + r * simd_w * sizeof(float)]);
    }
}

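/* Core FMA section of the weights-gradient computation: for each unrolled
 * output pixel load its diff_dst vector, refresh the rotating window of
 * input registers (a full set of loads for the first pixel, afterwards only
 * the 'cascade_input' newly-uncovered columns), then issue one FMA per
 * filter column. FMAs that would touch the left/right padding are skipped
 * at code-generation time. */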
template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_step_unroll(
        int unroll_w, int l_pad, int pad_offset, int ow_block) {

    const int iw_block = ow_block * jcp.stride_w;
    const int right_border = jcp.iw - iw_block;

    const int cascade_input = nstl::min(jcp.stride_w, jcp.kw);

    /* preamble count for number of cascaded LOAD + FMA operations */
    const int input_overlap = nstl::max(jcp.kw - l_pad, 0);

    /* LOAD initial input registers, then cascade LOADs and FMAs */
    for (int r = 0; r < reg_repeats; ++r) {
        for (int i_ur = 0; i_ur < unroll_w; ++i_ur) {
            int off_output = (i_ur * reg_repeats + r) * simd_w;
            Vmm vmm_output = get_output_reg(r);
            uni_vmovups(vmm_output,
                    ptr[reg_tmp_output + off_output * sizeof(float)]);
            if (i_ur == 0) {
                for (int c = 0; c < input_overlap; ++c) {
                    int off_input
                            = ((c - pad_offset) * reg_repeats + r) * simd_w;
                    Vmm vmm_input
                            = get_input_reg((c % jcp.kw) * reg_repeats + r);
                    uni_vmovups(vmm_input,
                            ptr[reg_tmp_input + off_input * sizeof(float)]);
                }
            } else {
                for (int c = 0; c < cascade_input; ++c) {
                    int overlap = (i_ur - 1) * jcp.stride_w + input_overlap;
                    int off_input
                            = ((overlap + c - pad_offset) * reg_repeats + r)
                            * simd_w;
                    Vmm vmm_input = get_input_reg(
                            ((overlap + c) % jcp.kw) * reg_repeats + r);
                    uni_vmovups(vmm_input,
                            ptr[reg_tmp_input + off_input * sizeof(float)]);
                }
            }

            for (int i_kw = 0; i_kw < jcp.kw; ++i_kw) {
                int io_overlap = i_kw + (i_ur * jcp.stride_w);

                /* Don't apply FMAs that fall into the padded region */
                if (io_overlap - l_pad < 0
                        || io_overlap - jcp.l_pad >= right_border)
                    continue;

                Vmm vmm_input = get_input_reg(
                        ((io_overlap - l_pad) % jcp.kw) * reg_repeats + r);
                Vmm vmm_acc = get_acc_reg(i_kw * reg_repeats + r);
                Vmm vmm_aux = isa == sse42 ? get_aux_reg() : vmm_input;
                if (isa == sse42)
                    uni_vmovups(vmm_aux, vmm_input);
                uni_vfmadd231ps(vmm_acc, vmm_aux, vmm_output);
            }
        }
    }
}

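/* Accumulate diff_bias over 'unroll_w' output pixels by summing diff_dst
 * vectors into the bias registers. */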
template <cpu_isa_t isa>
inline void
jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias_step_unroll(
        const int unroll_w) {
    for (int r = 0; r < reg_repeats; ++r) {
        for (int i = 0; i < unroll_w; ++i) {
            Vmm vmm_bias = get_bias_reg(r);
            int off_output = (i * reg_repeats + r) * simd_w;
            if (isa == sse42) {
                /* Need to support unaligned address loads for SSE42 */
                Vmm vmm_output = get_output_reg(1 + r);
                uni_vmovups(vmm_output,
                        ptr[reg_tmp_output + off_output * sizeof(float)]);
                uni_vaddps(vmm_bias, vmm_bias, vmm_output);
            } else {
                uni_vaddps(vmm_bias, vmm_bias,
                        vmmword[reg_tmp_output + off_output * sizeof(float)]);
            }
        }
    }
}

template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_filter() {
    for (int r = 0; r < reg_repeats; ++r) {
        const int reg_set = r * jcp.kw;
        for (int i = 0; i < jcp.kw; ++i) {
            int off_filter = (i + reg_set) * simd_w;
            Vmm vmm_acc = get_acc_reg(i + reg_set);
            uni_vmovups(vmmword[reg_tmp_filter + off_filter * sizeof(float)],
                    vmm_acc);
        }
    }
}

template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_bias() {
    for (int r = 0; r < reg_repeats; ++r) {
        Vmm vmm_bias = get_bias_reg(r);
        uni_vmovups(
                vmmword[reg_bias_baddr + r * simd_w * sizeof(float)], vmm_bias);
    }
}

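/* Loop over the assigned output rows and accumulate the bias gradient,
 * processing each row in blocks of 'block_size' pixels plus a tail. */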
template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias_loop(
        const int block_size) {
    Label oh_label;
    Label ow_blk_label;

    const int unroll_w = nstl::min(block_size, jcp.ow);
    const int unroll_w_trips = jcp.ow / unroll_w;
    const int tail_w = jcp.ow > block_size ? jcp.ow % block_size : 0;

    const int ch_offset = jcp.ch_block;

    mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]);
    mov(reg_oh_worksize,
            ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]);

    mov(reg_tmp_output, reg_output_baddr);
    L(oh_label);
    {
        mov(iter_ow_blk, unroll_w_trips);
        L(ow_blk_label);
        {
            compute_bias_step_unroll(unroll_w);
            add(reg_tmp_output, unroll_w * ch_offset * sizeof(float));

            dec(iter_ow_blk);
            cmp(iter_ow_blk, 0);
            jg(ow_blk_label, T_NEAR);
        }

        if (tail_w > 0) {
            compute_bias_step_unroll(tail_w);
            add(reg_tmp_output, tail_w * ch_offset * sizeof(float));
        }

        inc(reg_oh);
        cmp(reg_oh, reg_oh_worksize);
        jl(oh_label, T_NEAR);
    }
}

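/* Zero the filter accumulation buffer in memory when this thread is the
 * first writer (FLAG_ZERO_FILTER set in exec_flags). */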
template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_zero_filter() {

    const int ch_offset = jcp.ch_block;

    Label kh_loop_label, skip_zeroing_label;

    mov(reg_exec_flags,
            ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]);
    and_(reg_exec_flags, FLAG_ZERO_FILTER);
    test(reg_exec_flags, reg_exec_flags);
    je(skip_zeroing_label);

    zero_filter();

    mov(reg_tmp_filter, reg_filter_baddr);
    mov(iter_kh, jcp.kh);
    L(kh_loop_label);
    {
        store_filter();

        add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
        dec(iter_kh);
        cmp(iter_kh, 0);
        jg(kh_loop_label, T_NEAR);
    }

    /* Comeback pointers */
    sub(reg_tmp_filter, jcp.kh * jcp.kw * ch_offset * sizeof(float));

    L(skip_zeroing_label);
}

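/* One height step: for every valid filter row, load the current filter
 * accumulators, apply the unrolled width computation, and store them back;
 * then rewind the input and filter pointers. */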
template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_h_step(
        int unroll_w, int l_pad, int pad_offset, int ow_block) {

    const int ch_offset = jcp.ch_block;

    Label kh_loop_label, skip_loop_label;

    cmp(reg_kh_count, 0);
    je(skip_loop_label, T_NEAR);

    mov(reg_kh, reg_kh_count);
    L(kh_loop_label);
    {
        load_filter();
        compute_ow_step_unroll(unroll_w, l_pad, pad_offset, ow_block);
        store_filter();

        add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
        add(reg_tmp_input, jcp.iw * ch_offset * sizeof(float));
        dec(reg_kh);
        cmp(reg_kh, 0);
        jg(kh_loop_label, T_NEAR);
    }

    /* Comeback pointers */
    Label kh_comeback_label;
    mov(reg_kh, reg_kh_count);
    L(kh_comeback_label);
    {
        sub(reg_tmp_input, jcp.iw * ch_offset * sizeof(float));
        sub(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
        dec(reg_kh);
        cmp(reg_kh, 0);
        jg(kh_comeback_label, T_NEAR);
    }

    L(skip_loop_label);
}

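/* Loop over the output rows of this thread's h-block, adjusting kh_count and
 * the input/filter pointers while the filter overlaps the top and bottom
 * padding. */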
template <cpu_isa_t isa>
inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_h_loop(
        int unroll_w, int l_pad, int pad_offset, int ow_block) {

    const size_t io_overlap = jcp.ih / jcp.stride_h < jcp.oh ?
            jcp.ih / jcp.stride_h - 1 :
            jcp.oh - jcp.b_pad - 1;
    const int ch_offset = jcp.ch_block;
    const int t_overlap_off = jcp.t_pad % jcp.stride_h == 0 ? jcp.stride_h : 1;
    const int b_overlap_off = jcp.b_pad % jcp.stride_h == 0 ? jcp.stride_h : 1;

    Label tpad_loop_label, h_loop_label, skip_tpad_label, skip_bpad_label,
            end_h_loop_label;

    mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]);
    mov(reg_oh_worksize,
            ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]);
    mov(reg_kh_count,
            ptr[this->param1 + offsetof(jit_dw_conv_call_s, kh_count)]);

    mov(reg_tmp_output, reg_output_baddr);
    mov(reg_tmp_input, reg_input_baddr);
    mov(reg_tmp_filter, reg_filter_baddr);

    L(h_loop_label);
    {
        compute_h_step(unroll_w, l_pad, pad_offset, ow_block);

        add(reg_tmp_output, jcp.ow * ch_offset * sizeof(float));

        /* If within the top_pad region */
        if (jcp.t_pad > 0) {
            /* Skip t_pad area if no longer in initial h_block */
            cmp(reg_oh, jcp.t_pad);
            jg(skip_tpad_label, T_NEAR);

            cmp(reg_kh_count, jcp.kh);
            jge(skip_tpad_label, T_NEAR);

            add(reg_kh_count, t_overlap_off);
            sub(reg_tmp_filter,
                    t_overlap_off * jcp.kw * ch_offset * sizeof(float));

            /* kernel has moved beyond padding (adjust for stride effects) */
            if (jcp.t_pad % jcp.stride_h != 0) {
                int inp_corr = jcp.stride_h - jcp.t_pad % jcp.stride_h;
                add(reg_tmp_input,
                        inp_corr * jcp.iw * ch_offset * sizeof(float));
            }
            jmp(tpad_loop_label, T_NEAR);
        }

        L(skip_tpad_label);

        cmp(reg_oh, io_overlap);
        jl(skip_bpad_label, T_NEAR);
        sub(reg_kh_count, b_overlap_off);

        L(skip_bpad_label);
        add(reg_tmp_input, jcp.stride_h * jcp.iw * ch_offset * sizeof(float));

        L(tpad_loop_label);

        cmp(reg_oh, jcp.ih / jcp.stride_h);
        jge(end_h_loop_label, T_NEAR);

        inc(reg_oh);

        cmp(reg_oh, reg_oh_worksize);
        jl(h_loop_label, T_NEAR);
    }
    L(end_h_loop_label);
}

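/* Top-level width tiling: choose an unroll factor and trip count, handle the
 * optional bias accumulation, then emit the left-padded block, the middle
 * blocks (in a loop when there is more than one trip), and the right-padded
 * tail. */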
template <cpu_isa_t isa>
inline void
jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_block_unroll() {

    const int ch_offset = jcp.ch_block;
    int ow = jcp.ow;
    int pad_offset = 0;
    int l_pad = jcp.l_pad;

    /* Calculate effective padding */
    int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
            + (jcp.kw - 1) * (jcp.dilate_w + 1)
            - (jcp.iw + jcp.l_pad - 1));

    /* Is this strictly defined by:
     * -code-size (?)
     * -address size (?) */
    const int max_unroll_w = 30;
    const int block_size = 15;

    int unroll_w_tail = 0;
    int unroll_w = 0;
    int unroll_w_trips = 0;

    if (jcp.ow > max_unroll_w) {
        unroll_w = nstl::min(block_size, jcp.ow);
        unroll_w_trips = ow / unroll_w;
        /* calculate tail */
        unroll_w_tail = ow % unroll_w;
        /* Perform some rebalancing if the tail is too small */
        if ((unroll_w_tail == 0 && r_pad != 0)
                || (r_pad > 0 && r_pad >= unroll_w_tail)) {
            if (unroll_w_trips > 1) {
                unroll_w_tail += unroll_w;
                unroll_w_trips--;
            } else {
                /* Ideally, this case shouldn't happen */
                unroll_w_tail += (unroll_w - unroll_w / 2);
                unroll_w = unroll_w / 2;
            }
        }
    } else {
        unroll_w = jcp.ow;
        unroll_w_trips = nstl::max(1, ow / unroll_w);
    }
    if (jcp.with_bias) {
        Label skip_load_bias;
        mov(reg_bias_baddr,
                ptr[this->param1 + offsetof(jit_dw_conv_call_s, bias)]);

        zero_bias();

        mov(reg_exec_flags,
                ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]);
        and_(reg_exec_flags, FLAG_ZERO_BIAS);
        test(reg_exec_flags, reg_exec_flags);
        jne(skip_load_bias);

        load_bias();

        L(skip_load_bias);
        compute_bias_loop(block_size);

        store_bias();
    }

    /* Pass filter address, then offset for h_padding. */
    compute_zero_filter();
    mov(reg_kh_offset,
            ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter_pad_off)]);
    add(reg_filter_baddr, reg_kh_offset);

    /* compute left padded block */
    if (l_pad) {
        compute_h_loop(unroll_w, l_pad, 0, 0);
        add(reg_output_baddr, unroll_w * ch_offset * sizeof(float));
        add(reg_input_baddr,
                unroll_w * jcp.stride_w * ch_offset * sizeof(float));
        unroll_w_trips--;
        pad_offset = l_pad;
        l_pad = 0;
    }

    /* compute middle block */
    Label ow_blk_label;

    /* Insert loop for 'ow' block when middle block needs to execute more
     * than once */
    bool do_ow_blk_loop = unroll_w_trips > 1;
    if (do_ow_blk_loop) {
        mov(iter_ow_blk, unroll_w_trips);
        L(ow_blk_label);
    }
    if (unroll_w_trips > 0) {
        compute_h_loop(unroll_w, l_pad, pad_offset, 0);
        add(reg_output_baddr, unroll_w * ch_offset * sizeof(float));
        add(reg_input_baddr,
                unroll_w * jcp.stride_w * ch_offset * sizeof(float));
    }
    if (do_ow_blk_loop) {
        dec(iter_ow_blk);
        cmp(iter_ow_blk, 0);
        jg(ow_blk_label, T_NEAR);
    }

    /* compute right padded block */
    if (unroll_w_tail) {
        compute_h_loop(unroll_w_tail, 0, pad_offset, jcp.ow - unroll_w_tail);
    }
}

template <cpu_isa_t isa>
void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::generate() {
    preamble();

    mov(reg_input_baddr,
            ptr[this->param1 + offsetof(jit_dw_conv_call_s, input)]);
    mov(reg_output_baddr,
            ptr[this->param1 + offsetof(jit_dw_conv_call_s, output)]);
    mov(reg_filter_baddr,
            ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter)]);

    compute_ow_block_unroll();

    this->postamble();
}

template <cpu_isa_t isa>
status_t jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_conf(
        jit_conv_conf_t &jcp, const convolution_desc_t &cd,
        const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &diff_weights_d,
        const memory_desc_wrapper &diff_dst_d, int nthreads) {
    if (!mayiuse(isa))
        return status::unimplemented;

    jcp.ngroups = diff_weights_d.dims()[0];
    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;

    jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.oc, jcp.ic);

    if (!jcp.is_depthwise)
        return status::unimplemented;

    jcp.ch_block = isa == avx512_common ? 16 : 8;

    jcp.mb = src_d.dims()[0];

    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = diff_dst_d.dims()[2];
    jcp.ow = diff_dst_d.dims()[3];

    jcp.kh = diff_weights_d.dims()[3];
    jcp.kw = diff_weights_d.dims()[4];

    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];

    jcp.t_pad = cd.padding[0][0];
    jcp.b_pad = cd.padding[1][0];

    jcp.l_pad = cd.padding[0][1];
    jcp.r_pad = cd.padding[1][1];

    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];

    jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
    jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;

    jcp.src_fmt = src_d.format();

    jcp.with_bias = cd.diff_bias_desc.format != memory_format::undef;

    auto desired_act_fmt = isa == avx512_common ? nChw16c : nChw8c;
    auto desired_wei_fmt = isa == avx512_common ? Goihw16g : Goihw8g;

    bool args_ok = true && src_d.format() == desired_act_fmt
            && diff_weights_d.format() == desired_wei_fmt
            && diff_dst_d.format() == desired_act_fmt
            && one_of(cd.bias_desc.format, memory_format::undef, any, x)
            && jcp.ngroups % jcp.ch_block == 0 && jcp.dilate_h == 0
            && jcp.dilate_w == 0 && jcp.kw <= 3
            && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
            && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
    if (!args_ok)
        return status::unimplemented;

    jcp.nb_ch = jcp.ngroups / jcp.ch_block;

    /* kernel applicability check wrt boundaries
     * the conditions are quite general across the kernels we have,
     * but ideally the check should belong to a specific kernel... */
    const int max_hpad = (jcp.kh - 1 + 1) / 2;
    const int max_wpad = (jcp.kw - 1 + 1) / 2;
    const bool boundaries_ok = true && jcp.t_pad <= max_hpad
            && jcp.b_pad <= max_hpad && jcp.l_pad <= max_wpad
            && jcp.r_pad <= max_wpad;
    if (!boundaries_ok)
        return status::unimplemented;

    balance(jcp, nthreads);

    return status::success;
}

template <cpu_isa_t isa>
void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
    /* Notes: if splitting thread work on 'mb', then a reduction has to take
     * place. Hence, book a per-thread, local weights-buffer for the
     * reduction */
    if (jcp.nthr_mb > 1) {
        const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
        scratchpad.book(key_conv_wei_reduction,
                sizeof(float) * wei_size * (jcp.nthr_mb - 1));

        if (jcp.with_bias)
            scratchpad.book(key_conv_bia_reduction,
                    sizeof(float) * jcp.ngroups * (jcp.nthr_mb - 1));
    }
}

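/* Distribute 'nthreads' between the group and minibatch dimensions; the
 * minibatch split requires the weights reduction booked above. */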
template <cpu_isa_t isa>
void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::balance(jit_conv_conf_t &jcp,
        int nthreads) {
    jcp.nthr = nthreads;
    jcp.nthr_g = jcp.nthr_mb = 1;

    /* Basic heuristics for the parallel strategy:
     * 1) Try to parallelize over the number of groups (g), where tasks are
     * independent. Otherwise,
     * 2) Try to split the work across g and the minibatch (mb).
     * Parallelizing on mb requires computing a reduction for weights.
     *
     * NOTE: because of the 'task partitioning' scheme, the per-thread load
     * will be unbalanced when the number of threads is high (e.g. > 16).
     */
    jcp.nthr_g = nstl::min(jcp.nb_ch, jcp.nthr);
    jcp.nthr_mb = nstl::min(nstl::max(1, jcp.nthr / jcp.nthr_g), jcp.mb);

    jcp.nthr = jcp.nthr_g * jcp.nthr_mb;
}

template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx512_common>;
template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx2>;
template struct jit_uni_dw_conv_bwd_weights_kernel_f32<sse42>;

}
}
}