/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
17 #include "c_types_map.hpp"
19 #include "type_helpers.hpp"
21 #include "cpu_memory.hpp"
23 #include "jit_uni_dw_conv_kernel_f32.hpp"
25 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
31 using namespace mkldnn::impl::prop_kind;
32 using namespace mkldnn::impl::memory_format;
33 using namespace mkldnn::impl::utils;
35 using namespace Xbyak;
37 template <cpu_isa_t isa>
38 void jit_uni_dw_conv_fwd_kernel_f32<isa>::load_src(int ur_ch_blocks, int ur_w) {
39 int repeats = isa == sse42 ? 2 : 1;
40 for (int i = 0; i < repeats; i++) {
41 for (int ch = 0; ch < ur_ch_blocks; ch++) {
42 for (int ow = 0; ow < ur_w; ow++) {
43 Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
45 int b_off = ch*jcp.ch_block + i*4;
46 if (this->jcp.with_bias)
48 vmmword[reg_bias + b_off*sizeof(float)]);
50 uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
52 int o_off = ch*jcp.oh*jcp.ow*jcp.ch_block
53 + ow*jcp.ch_block + i*4;
54 if (this->jcp.with_sum)
55 uni_vaddps(vmm_acc, vmm_acc,
56 vmmword[reg_output + o_off*sizeof(float)]);
62 template <cpu_isa_t isa>
63 void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter(
64 int ur_ch_blocks, int ur_w) {
65 int ch_blk = jcp.ch_block;
66 int dilate_h = jcp.dilate_h + 1;
67 int dilate_w = jcp.dilate_w + 1;
68 int stride_w = jcp.stride_w;
70 Label iter_exit_label;
73 je(iter_exit_label, T_NEAR);
75 je(iter_exit_label, T_NEAR);
81 mov(aux1_reg_input, aux_reg_input);
82 mov(aux1_reg_kernel, aux_reg_kernel);
86 int repeats = isa == sse42 ? 2 : 1;
87 for (int i = 0; i < repeats; i++) {
88 for (int ch = 0; ch < ur_ch_blocks; ch++) {
89 int ker_off = ch*jcp.kh*jcp.kw*ch_blk + i*4;
90 Vmm vmm_ker = get_ker_reg(0);
91 uni_vmovups(vmm_ker, ptr[aux1_reg_kernel
92 + ker_off*sizeof(float)]);
94 for (int ow = 0; ow < ur_w; ow++) {
95 int inp_off = ch*jcp.ih*jcp.iw*ch_blk
96 + ow*stride_w*ch_blk + i*4;
97 Vmm vmm_src = get_src_reg(0);
98 uni_vmovups(vmm_src, ptr[aux1_reg_input
99 + inp_off*sizeof(float)]);
101 Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w
103 uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
107 add(aux1_reg_kernel, ch_blk*sizeof(float));
108 add(aux1_reg_input, ch_blk*dilate_w*sizeof(float));
112 jg(kw_label, T_NEAR);
114 add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
115 add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float));
119 jg(kh_label, T_NEAR);
125 template <cpu_isa_t isa>
126 void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter_unrolled(
127 int ur_ch_blocks, int ur_w) {
128 int ch_blk = jcp.ch_block;
129 int dilate_h = jcp.dilate_h + 1;
130 int dilate_w = jcp.dilate_w + 1;
131 int stride_w = jcp.stride_w;
133 Label iter_exit_label;
136 je(iter_exit_label, T_NEAR);
138 mov(iter_kh, reg_kh);
141 int repeats = isa == sse42 ? 2 : 1;
142 for (int i = 0; i < repeats; i++) {
143 for (int ch = 0; ch < ur_ch_blocks; ch++) {
144 for (int kw = 0; kw < jcp.kw; kw++) {
145 int ker_off = ch*jcp.kh*jcp.kw*ch_blk + kw*ch_blk + i*4;
147 Vmm vmm_ker = get_ker_reg(0);
148 uni_vmovups(vmm_ker, ptr[aux_reg_kernel
149 + ker_off*sizeof(float)]);
151 for (int ow = 0; ow < ur_w; ow++) {
152 int inp_off = ch*jcp.ih*jcp.iw*ch_blk
153 + ow*stride_w*ch_blk + kw*ch_blk*dilate_w + i*4;
155 Vmm vmm_src = get_src_reg(0);
156 uni_vmovups(vmm_src, ptr[aux_reg_input
157 + inp_off*sizeof(float)]);
159 Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w
161 uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
167 add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
168 add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float));
172 jg(kh_label, T_NEAR);
178 template <cpu_isa_t isa>
179 void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_activation(int ur_ch_blocks, int ur_w) {
180 if (this->jcp.with_eltwise) {
181 inject(eltwise_generator.prepareConstants(jcp.eltwise_alpha, jcp.eltwise_beta));
183 // TODO (dmitrygo): need to find appropriate way to share labels.
184 mov(imm_addr64, l_table);
185 int repeats = isa == sse42 ? 2 : 1;
186 for (int i = 0; i < repeats; i++) {
187 for (int ch = 0; ch < ur_ch_blocks; ch++) {
188 for (int ow = 0; ow < ur_w; ow++) {
189 Vmm vmm_dst = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
191 inject(eltwise_generator.computeVector(vmm_dst, vmm_dst));
198 template <cpu_isa_t isa>
199 void jit_uni_dw_conv_fwd_kernel_f32<isa>::store_dst(
200 int ur_ch_blocks, int ur_w) {
201 int ch_blk = jcp.ch_block;
203 int repeats = isa == sse42 ? 2 : 1;
204 for (int i = 0; i < repeats; i++) {
205 for (int ch = 0; ch < ur_ch_blocks; ch++) {
206 for (int ow = 0; ow < ur_w; ow++) {
207 int o_off = ch*jcp.oh*jcp.ow*ch_blk + ow*ch_blk + i*4;
208 Vmm vmm_dst = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
210 uni_vmovups(vmmword[reg_output + o_off*sizeof(float)], vmm_dst);
216 template <cpu_isa_t isa>
217 void jit_uni_dw_conv_fwd_kernel_f32<isa>::loop_body(int ur_ch_blocks) {
218 Label unrolled_w_label;
222 L(unrolled_w_label); {
226 jl(tail_w_label, T_NEAR);
228 mov(aux_reg_input, reg_input);
229 mov(aux_reg_kernel, reg_kernel);
231 load_src(ur_ch_blocks, ur_w);
232 apply_filter_unrolled(ur_ch_blocks, ur_w);
233 apply_activation(ur_ch_blocks, ur_w);
234 store_dst(ur_ch_blocks, ur_w);
236 add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
237 add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
240 jmp(unrolled_w_label);
247 jl(exit_label, T_NEAR);
249 mov(aux_reg_input, reg_input);
250 mov(aux_reg_kernel, reg_kernel);
252 load_src(ur_ch_blocks, ur_w);
253 apply_filter(ur_ch_blocks, ur_w);
254 apply_activation(ur_ch_blocks, ur_w);
255 store_dst(ur_ch_blocks, ur_w);
257 add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
258 add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
267 template <cpu_isa_t isa>
268 void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate()
270 nstl::vector<int> shared_vecs;
271 shared_vecs.push_back(0);
272 shared_vecs.push_back(1);
273 shared_vecs.push_back(2);
274 shared_vecs.push_back(3);
275 if (isa == avx512_common)
276 shared_vecs.push_back(31);
278 nstl::vector<Reg64> shared_regs;
279 shared_regs.push_back(imm_addr64);
281 eltwise_generator.init(jcp.eltwise_alg, shared_vecs, shared_regs);
285 mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
286 mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
287 mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
289 mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
290 mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
291 mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
292 mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]);
293 mov(reg_ur_w, ptr[this->param1 + GET_OFF(ur_w)]);
295 Label ch_blocks_tail_label;
298 int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
300 cmp(reg_ch_blocks, jcp.nb_ch_blocking);
301 jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR);
303 loop_body(jcp.nb_ch_blocking); // channel main loop
305 if (ch_blocks_tail) {
306 L(ch_blocks_tail_label);
308 cmp(reg_ch_blocks, ch_blocks_tail);
309 jne(exit_label, T_NEAR);
311 loop_body(ch_blocks_tail); // channel tail loop
318 // TODO (dmitrygo): need to find appropriate way to share labels.
321 inject(eltwise_generator.prepareTable());
322 eltwise_generator.release();
325 template <cpu_isa_t isa>
326 bool jit_uni_dw_conv_fwd_kernel_f32<isa>::post_ops_ok(
327 jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
328 const auto &p = attr.post_ops_;
330 auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
331 auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
334 case 0: return true; // no post_ops
335 case 1: return !jcp.with_eltwise && (is_eltwise(0) || is_sum(0)); // sum OR relu
336 case 2: return !jcp.with_eltwise && (is_sum(0) && is_eltwise(1)); // sum->relu
337 default: return false;
343 template <cpu_isa_t isa>
344 status_t jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp,
345 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
346 const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
347 const primitive_attr_t &attr, bool with_relu, float relu_negative_slope)
349 if (!mayiuse(isa)) return status::unimplemented;
351 const int simd_w = isa == avx512_common ? 16 : 8;
353 jcp.prop_kind = cd.prop_kind;
355 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
356 if (!with_groups) return status::unimplemented;
358 jcp.ngroups = weights_d.dims()[0];
359 jcp.mb = src_d.dims()[0];
361 jcp.oc = dst_d.dims()[1];
362 jcp.oc_without_padding = jcp.oc;
363 jcp.ic = src_d.dims()[1];
365 jcp.ih = src_d.dims()[2];
366 jcp.iw = src_d.dims()[3];
367 jcp.oh = dst_d.dims()[2];
368 jcp.ow = dst_d.dims()[3];
370 jcp.kh = weights_d.dims()[3];
371 jcp.kw = weights_d.dims()[4];
373 jcp.t_pad = cd.padding[0][0];
374 jcp.l_pad = cd.padding[0][1];
375 jcp.b_pad = cd.padding[1][0];
376 jcp.r_pad = cd.padding[1][1];
378 jcp.stride_h = cd.strides[0];
379 jcp.stride_w = cd.strides[1];
381 jcp.dilate_h = cd.dilates[0];
382 jcp.dilate_w = cd.dilates[1];
384 jcp.src_fmt = src_d.format();
385 jcp.with_bias = cd.bias_desc.format != memory_format::undef;
386 jcp.with_eltwise = with_relu;
387 jcp.eltwise_alg = mkldnn_eltwise_relu;
388 jcp.eltwise_alpha = relu_negative_slope;
390 if (!post_ops_ok(jcp, attr))
391 return status::unimplemented;
393 const auto &p = attr.post_ops_;
394 jcp.with_sum = p.find(primitive_kind::sum) != -1;
395 if (!jcp.with_eltwise) {
396 int eltwise_ind = p.find(primitive_kind::eltwise);
397 if (eltwise_ind != -1) {
398 jcp.with_eltwise = true;
399 jcp.eltwise_alg = p.entry_[eltwise_ind].eltwise.alg;
400 jcp.eltwise_alpha = p.entry_[eltwise_ind].eltwise.alpha;
401 jcp.eltwise_beta = p.entry_[eltwise_ind].eltwise.beta;
402 jcp.eltwise_scale = p.entry_[eltwise_ind].eltwise.scale;
406 bool ok_to_pad_channels = true
407 && jcp.oc == jcp.ngroups
408 && jcp.ic == jcp.ngroups
409 && isa == avx512_common;
410 if (ok_to_pad_channels) {
411 jcp.oc = rnd_up(jcp.oc, simd_w);
412 jcp.ic = rnd_up(jcp.oc, simd_w);
413 jcp.ngroups = rnd_up(jcp.ngroups, simd_w);
416 auto desired_act_fmt = isa == avx512_common ? nChw16c : nChw8c;
417 auto desired_wei_fmt = isa == avx512_common ? Goihw16g : Goihw8g;
420 && jcp.oc == jcp.ngroups
421 && jcp.ic == jcp.ngroups
422 && jcp.ngroups % simd_w == 0
423 && src_d.format() == desired_act_fmt
424 && weights_d.format() == desired_wei_fmt
425 && one_of(cd.bias_desc.format, memory_format::undef, any, x)
426 && dst_d.format() == desired_act_fmt
427 && jcp.ic <= src_d.blocking_desc().padding_dims[1]
428 && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
429 && jcp.ngroups <= weights_d.blocking_desc().padding_dims[0];
430 if (!args_ok) return status::unimplemented;
432 jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;
434 jcp.ch_block = simd_w;
435 jcp.nb_ch = jcp.oc / jcp.ch_block;
436 jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2;
437 if (jcp.nb_ch < jcp.nb_ch_blocking)
438 jcp.nb_ch_blocking = jcp.nb_ch;
440 if (jcp.with_eltwise) {
441 int nvecs_elt = jit_uni_eltwise_vector_f32<isa>::sharedVecsCount(jcp.eltwise_alg);
442 int nvecs_conv = isa == avx512_common ? 32 - nvecs_elt : 16 - nvecs_elt;
443 int isa_mult = isa == sse42 ? 2 : 1;
444 while (isa_mult * jcp.ur_w * jcp.nb_ch_blocking > nvecs_conv) {
445 if (jcp.nb_ch_blocking <= 1) {
449 jcp.nb_ch_blocking -= 1;
452 if (isa_mult * jcp.ur_w * jcp.nb_ch_blocking > nvecs_conv)
453 return status::unimplemented;
456 return status::success;
459 template struct jit_uni_dw_conv_fwd_kernel_f32<avx512_common>;
460 template struct jit_uni_dw_conv_fwd_kernel_f32<avx2>;
461 template struct jit_uni_dw_conv_fwd_kernel_f32<sse42>;
463 template <cpu_isa_t isa>
464 inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::load_ddst(
465 int ur_ch_blocks, int ur_str_w) {
466 int repeats = isa == sse42 ? 2 : 1;
467 for (int i = 0; i < repeats; i++) {
468 for (int ch = 0; ch < ur_ch_blocks; ch++) {
469 for (int w = 0; w < ur_str_w; w++) {
470 Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
472 uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
478 template <cpu_isa_t isa>
479 inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::apply_filter(
480 int ur_ch_blocks, int ur_str_w) {
486 int ch_blk = jcp.ch_block;
487 int stride_h = jcp.stride_h;
488 int stride_w = jcp.stride_w;
490 Label iter_exit_label;
493 je(iter_exit_label, T_NEAR);
496 je(iter_exit_label, T_NEAR);
498 mov(iter_kh, reg_kh);
501 mov(aux1_reg_ddst, aux_reg_ddst);
502 mov(aux1_reg_kernel, aux_reg_kernel);
504 mov(iter_kw, reg_kw);
507 int repeats = isa == sse42 ? 2 : 1;
508 for (int i = 0; i < repeats; i++) {
509 for (int ch = 0; ch < ur_ch_blocks; ch++) {
510 int ker_off = ch*kh*kw*ch_blk + i*4;
511 Vmm vmm_ker = get_ker_reg(0);
512 uni_vmovups(vmm_ker, ptr[aux1_reg_kernel
513 + ker_off*sizeof(float)]);
515 for (int w = 0; w < ur_str_w; w++) {
516 int ddst_off = (ch*oh*ow + w)*ch_blk + i*4;
518 Vmm vmm_src = get_src_reg(0);
519 uni_vmovups(vmm_src, ptr[aux1_reg_ddst
520 + ddst_off*sizeof(float)]);
522 Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
524 uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
529 add(aux1_reg_kernel, ch_blk*stride_w*sizeof(float));
530 sub(aux1_reg_ddst, ch_blk*sizeof(float));
532 sub(iter_kw, stride_w);
534 jg(kw_label, T_NEAR);
537 add(aux_reg_kernel, kw*ch_blk*stride_h*sizeof(float));
538 sub(aux_reg_ddst, ow*ch_blk*sizeof(float));
540 sub(iter_kh, stride_h);
542 jg(kh_label, T_NEAR);
548 template <cpu_isa_t isa>
549 inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::store_dsrc(
550 int ur_ch_blocks, int ur_str_w) {
551 int ch_blk = jcp.ch_block;
554 int stride_w = jcp.stride_w;
556 int repeats = isa == sse42 ? 2 : 1;
557 for (int i = 0; i < repeats; i++) {
558 for (int ch = 0; ch < ur_ch_blocks; ch++) {
559 for (int w = 0; w < ur_str_w; w++) {
560 int dsrc_off = (ch*ih*iw + w*stride_w)*ch_blk + i*4;
561 Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
564 uni_vmovups(ptr[reg_dsrc + dsrc_off*sizeof(float)], vmm_acc);
570 template <cpu_isa_t isa>
571 inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::loop_body(
573 Label unrolled_w_label;
577 L(unrolled_w_label); {
580 cmp(reg_ur_str_w, ur_w);
581 jl(tail_w_label, T_NEAR);
583 mov(aux_reg_ddst, reg_ddst);
584 mov(aux_reg_kernel, reg_kernel);
586 load_ddst(ur_ch_blocks, ur_w);
587 apply_filter(ur_ch_blocks, ur_w);
588 store_dsrc(ur_ch_blocks, ur_w);
590 add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
591 add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block);
593 sub(reg_ur_str_w, ur_w);
594 jmp(unrolled_w_label);
600 cmp(reg_ur_str_w, ur_w);
601 jl(exit_label, T_NEAR);
603 mov(aux_reg_ddst, reg_ddst);
604 mov(aux_reg_kernel, reg_kernel);
606 load_ddst(ur_ch_blocks, ur_w);
607 apply_filter(ur_ch_blocks, ur_w);
608 store_dsrc(ur_ch_blocks, ur_w);
610 add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
611 add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block);
613 sub(reg_ur_str_w, ur_w);
620 template <cpu_isa_t isa>
621 void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::generate() {
624 mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
625 mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
626 mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
627 mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
628 mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
629 mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]);
630 mov(reg_ur_str_w, ptr[this->param1 + GET_OFF(ur_str_w)]);
632 Label ch_blocks_tail_label;
635 int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
637 cmp(reg_ch_blocks, jcp.nb_ch_blocking);
638 jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR);
640 loop_body(jcp.nb_ch_blocking); // channel main loop
642 if (ch_blocks_tail) {
643 L(ch_blocks_tail_label);
645 cmp(reg_ch_blocks, ch_blocks_tail);
646 jne(exit_label, T_NEAR);
648 loop_body(ch_blocks_tail); // channel tail loop
656 template <cpu_isa_t isa>
657 status_t jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_conf(
658 jit_conv_conf_t &jcp, const convolution_desc_t &cd,
659 const memory_desc_wrapper &diff_src_d,
660 const memory_desc_wrapper &weights_d,
661 const memory_desc_wrapper &diff_dst_d) {
662 if (!mayiuse(isa)) return status::unimplemented;
664 const int simd_w = isa == avx512_common ? 16 : 8;
666 const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
667 if (!with_groups) return status::unimplemented;
669 jcp.ngroups = weights_d.dims()[0];
670 jcp.mb = diff_src_d.dims()[0];
672 jcp.oc = diff_dst_d.dims()[1];
673 jcp.oc_without_padding = jcp.oc;
674 jcp.ic = diff_src_d.dims()[1];
676 jcp.ih = diff_src_d.dims()[2];
677 jcp.iw = diff_src_d.dims()[3];
678 jcp.oh = diff_dst_d.dims()[2];
679 jcp.ow = diff_dst_d.dims()[3];
681 jcp.kh = weights_d.dims()[3];
682 jcp.kw = weights_d.dims()[4];
684 jcp.t_pad = cd.padding[0][0];
685 jcp.l_pad = cd.padding[0][1];
686 jcp.b_pad = cd.padding[1][0];
687 jcp.r_pad = cd.padding[1][1];
689 jcp.stride_h = cd.strides[0];
690 jcp.stride_w = cd.strides[1];
692 jcp.dilate_h = cd.dilates[0];
693 jcp.dilate_w = cd.dilates[1];
695 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
696 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
698 jcp.src_fmt = diff_src_d.format();
700 bool ok_to_pad_channels = true
701 && jcp.oc == jcp.ngroups
702 && jcp.ic == jcp.ngroups
703 && isa == avx512_common;
704 if (ok_to_pad_channels) {
705 jcp.oc = rnd_up(jcp.oc, simd_w);
706 jcp.ic = rnd_up(jcp.oc, simd_w);
707 jcp.ngroups = rnd_up(jcp.ngroups, simd_w);
710 auto desired_act_fmt = isa == avx512_common ? nChw16c : nChw8c;
711 auto desired_wei_fmt = isa == avx512_common ? Goihw16g : Goihw8g;
714 && jcp.oc == jcp.ngroups
715 && jcp.ic == jcp.ngroups
716 && jcp.ngroups % simd_w == 0
719 && diff_src_d.format() == desired_act_fmt
720 && weights_d.format() == desired_wei_fmt
721 && diff_dst_d.format() == desired_act_fmt
722 && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
723 && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1
724 && jcp.ic <= diff_src_d.blocking_desc().padding_dims[1]
725 && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
726 && jcp.ngroups <= weights_d.blocking_desc().padding_dims[0];
727 if (!args_ok) return status::unimplemented;
729 jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;
731 jcp.ch_block = simd_w;
732 jcp.nb_ch = jcp.ic / jcp.ch_block;
733 jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2;
734 if (jcp.nb_ch < jcp.nb_ch_blocking)
735 jcp.nb_ch_blocking = jcp.nb_ch;
737 return status::success;
740 template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx512_common>;
741 template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx2>;
742 template struct jit_uni_dw_conv_bwd_data_kernel_f32<sse42>;