Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_dw_conv_kernel_f32.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "nstl.hpp"
19 #include "type_helpers.hpp"
20 #include "utils.hpp"
21 #include "cpu_memory.hpp"
22
23 #include "jit_uni_dw_conv_kernel_f32.hpp"
24
25 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 using namespace mkldnn::impl::prop_kind;
32 using namespace mkldnn::impl::memory_format;
33 using namespace mkldnn::impl::memory_tracking::names;
34 using namespace mkldnn::impl::utils;
35
36 using namespace Xbyak;
37
38 template <cpu_isa_t isa>
39 void jit_uni_dw_conv_fwd_kernel_f32<isa>::load_src(int ur_ch_blocks, int ur_w) {
40     int repeats = isa == sse42 ? 2 : 1;
41     for (int i = 0; i < repeats; i++) {
42         for (int ch = 0; ch < ur_ch_blocks; ch++) {
43             for (int ow = 0; ow < ur_w; ow++) {
44                 Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
45
46                 int b_off = ch*jcp.ch_block + i*4;
47                 if (this->jcp.with_bias)
48                     uni_vmovups(vmm_acc,
49                         vmmword[reg_bias + b_off*sizeof(float)]);
50                 else
51                     uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
52
53                 int o_off = ch*jcp.oh*jcp.ow*jcp.ch_block
54                     + ow*jcp.ch_block + i*4;
55                 if (this->jcp.with_sum)
56                     uni_vaddps(vmm_acc, vmm_acc,
57                         vmmword[reg_output + o_off*sizeof(float)]);
58             }
59         }
60     }
61 }
62
63 template <cpu_isa_t isa>
64 void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter(
65         int ur_ch_blocks, int ur_w) {
66     int ch_blk = jcp.ch_block;
67     int dilate_h = jcp.dilate_h + 1;
68     int dilate_w = jcp.dilate_w + 1;
69     int stride_w = jcp.stride_w;
70
71     Label iter_exit_label;
72
73     cmp(reg_kh, 0);
74     je(iter_exit_label, T_NEAR);
75     cmp(reg_kw, 0);
76     je(iter_exit_label, T_NEAR);
77
78     mov(iter_kh, reg_kh);
79     Label kh_label;
80     L(kh_label); {
81         mov(iter_kw, reg_kw);
82         mov(aux1_reg_input, aux_reg_input);
83         mov(aux1_reg_kernel, aux_reg_kernel);
84
85         Label kw_label;
86         L(kw_label); {
87             int repeats = isa == sse42 ? 2 : 1;
88             for (int i = 0; i < repeats; i++) {
89                 for (int ch = 0; ch < ur_ch_blocks; ch++) {
90                     int ker_off = ch*jcp.kh*jcp.kw*ch_blk + i*4;
91                     Vmm vmm_ker = get_ker_reg(0);
92                     uni_vmovups(vmm_ker, ptr[aux1_reg_kernel
93                         + ker_off*sizeof(float)]);
94
95                     for (int ow = 0; ow < ur_w; ow++) {
96                         int inp_off = ch*jcp.ih*jcp.iw*ch_blk
97                             + ow*stride_w*ch_blk + i*4;
98                         Vmm vmm_src = get_src_reg(0);
99                         uni_vmovups(vmm_src, ptr[aux1_reg_input
100                             + inp_off*sizeof(float)]);
101
102                         Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w
103                             + ch*ur_w + ow);
104                         uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
105                     }
106                 }
107             }
108             add(aux1_reg_kernel, ch_blk*sizeof(float));
109             add(aux1_reg_input, ch_blk*dilate_w*sizeof(float));
110
111             dec(iter_kw);
112             cmp(iter_kw, 0);
113             jg(kw_label, T_NEAR);
114         }
115         add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
116         add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float));
117
118         dec(iter_kh);
119         cmp(iter_kh, 0);
120         jg(kh_label, T_NEAR);
121     }
122
123     L(iter_exit_label);
124 }
125
126 template <cpu_isa_t isa>
127 void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter_unrolled(
128         int ur_ch_blocks, int ur_w) {
129     int ch_blk = jcp.ch_block;
130     int dilate_h = jcp.dilate_h + 1;
131     int dilate_w = jcp.dilate_w + 1;
132     int stride_w = jcp.stride_w;
133
134     Label iter_exit_label;
135
136     cmp(reg_kh, 0);
137     je(iter_exit_label, T_NEAR);
138
139     mov(iter_kh, reg_kh);
140     Label kh_label;
141     L(kh_label); {
142         int repeats = isa == sse42 ? 2 : 1;
143         for (int i = 0; i < repeats; i++) {
144             for (int ch = 0; ch < ur_ch_blocks; ch++) {
145                 for (int kw = 0; kw < jcp.kw; kw++) {
146                     int ker_off = ch*jcp.kh*jcp.kw*ch_blk + kw*ch_blk + i*4;
147
148                     Vmm vmm_ker = get_ker_reg(0);
149                     uni_vmovups(vmm_ker, ptr[aux_reg_kernel
150                         + ker_off*sizeof(float)]);
151
152                     for (int ow = 0; ow < ur_w; ow++) {
153                         int inp_off = ch*jcp.ih*jcp.iw*ch_blk
154                             + ow*stride_w*ch_blk + kw*ch_blk*dilate_w + i*4;
155
156                         Vmm vmm_src = get_src_reg(0);
157                         uni_vmovups(vmm_src, ptr[aux_reg_input
158                             + inp_off*sizeof(float)]);
159
160                         Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w
161                             + ch*ur_w + ow);
162                         uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
163                     }
164                 }
165             }
166         }
167
168         add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
169         add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float));
170
171         dec(iter_kh);
172         cmp(iter_kh, 0);
173         jg(kh_label, T_NEAR);
174     }
175
176     L(iter_exit_label);
177 }
178
179 template <cpu_isa_t isa>
180 void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_postprocess(int ur_ch_blocks, int ur_w) {
181     int repeats = isa == sse42 ? 2 : 1;
182
183     int eltwise_inj_idx = 0;
184     int depthwise_inj_idx = 0;
185     const auto &p = attr_.post_ops_;
186
187     for (int i = 0; i < p.len_; i++) {
188         auto& post_op = p.entry_[i];
189         if (post_op.is_eltwise()) {
190             int start_idx = get_acc_reg(0).getIdx();
191             int end_idx = get_acc_reg(repeats * ur_w * ur_ch_blocks).getIdx();
192
193             eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, end_idx);
194             eltwise_inj_idx++;
195         } else if (post_op.is_depthwise()) {
196             mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
197             mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
198
199             add(reg_d_weights, ptr[this->param1 + GET_OFF(oc_off)]);
200             add(reg_d_bias, ptr[this->param1 + GET_OFF(oc_off)]);
201
202             for (int ch = 0; ch < ur_ch_blocks; ch++) {
203                 for (int k = 0; k < repeats; k++) {
204                     int start_idx = get_acc_reg(k*ur_ch_blocks*ur_w + ur_w * ch).getIdx();
205                     int end_idx = get_acc_reg(k*ur_ch_blocks*ur_w + ur_w * ch + ur_w).getIdx();
206
207                     depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
208                             start_idx, end_idx, reg_d_weights, reg_d_bias);
209
210                     add(reg_d_weights, jcp.ch_block / repeats * sizeof(float));
211                     add(reg_d_bias, jcp.ch_block / repeats * sizeof(float));
212                 }
213             }
214
215             depthwise_inj_idx++;
216         }
217     }
218 }
219
220 template <cpu_isa_t isa>
221 void jit_uni_dw_conv_fwd_kernel_f32<isa>::store_dst(
222         int ur_ch_blocks, int ur_w) {
223     int ch_blk = jcp.ch_block;
224
225     int repeats = isa == sse42 ? 2 : 1;
226     for (int i = 0; i < repeats; i++) {
227         for (int ch = 0; ch < ur_ch_blocks; ch++) {
228             for (int ow = 0; ow < ur_w; ow++) {
229                 int o_off = ch*jcp.oh*jcp.ow*ch_blk + ow*ch_blk + i*4;
230                 Vmm vmm_dst = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
231
232                 uni_vmovups(vmmword[reg_output + o_off*sizeof(float)], vmm_dst);
233             }
234         }
235     }
236 }
237
238 template <cpu_isa_t isa>
239 void jit_uni_dw_conv_fwd_kernel_f32<isa>::loop_body(int ur_ch_blocks) {
240     Label unrolled_w_label;
241     Label tail_w_label;
242     Label exit_label;
243
244     L(unrolled_w_label); {
245         int ur_w = jcp.ur_w;
246
247         cmp(reg_ur_w, ur_w);
248         jl(tail_w_label, T_NEAR);
249
250         mov(aux_reg_input, reg_input);
251         mov(aux_reg_kernel, reg_kernel);
252
253         load_src(ur_ch_blocks, ur_w);
254         apply_filter_unrolled(ur_ch_blocks, ur_w);
255         apply_postprocess(ur_ch_blocks, ur_w);
256         store_dst(ur_ch_blocks, ur_w);
257
258         add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
259         add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
260
261         sub(reg_ur_w, ur_w);
262         jmp(unrolled_w_label);
263     }
264
265     L(tail_w_label); {
266         int ur_w = 1;
267
268         cmp(reg_ur_w, ur_w);
269         jl(exit_label, T_NEAR);
270
271         mov(aux_reg_input, reg_input);
272         mov(aux_reg_kernel, reg_kernel);
273
274         load_src(ur_ch_blocks, ur_w);
275         apply_filter(ur_ch_blocks, ur_w);
276         apply_postprocess(ur_ch_blocks, ur_w);
277         store_dst(ur_ch_blocks, ur_w);
278
279         add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
280         add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
281
282         sub(reg_ur_w, ur_w);
283         jmp(tail_w_label);
284     }
285
286     L(exit_label);
287 }
288
289 template <cpu_isa_t isa>
290 void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate() {
291     const auto &p = attr_.post_ops_;
292     for (int i = 0; i < p.len_; i++) {
293         auto &post_op = p.entry_[i];
294         if (post_op.is_eltwise()) {
295             eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
296                     this,
297                     post_op.eltwise.alg,
298                     post_op.eltwise.alpha,
299                     post_op.eltwise.beta
300             ));
301         } else if (post_op.is_depthwise()) {
302             depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
303                     this,
304                     post_op.depthwise.alg
305             ));
306         }
307     }
308
309     this->preamble();
310
311     mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
312     mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
313     mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
314     if (jcp.with_bias)
315         mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
316     mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
317     mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
318     mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]);
319     mov(reg_ur_w, ptr[this->param1 + GET_OFF(ur_w)]);
320
321     Label ch_blocks_tail_label;
322     Label exit_label;
323
324     int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
325
326     cmp(reg_ch_blocks, jcp.nb_ch_blocking);
327     jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR);
328
329     loop_body(jcp.nb_ch_blocking); // channel main loop
330
331     if (ch_blocks_tail) {
332         L(ch_blocks_tail_label);
333
334         cmp(reg_ch_blocks, ch_blocks_tail);
335         jne(exit_label, T_NEAR);
336
337         loop_body(ch_blocks_tail); // channel tail loop
338     }
339
340     L(exit_label);
341
342     this->postamble();
343
344     for (auto& inj : eltwise_injectors)
345         inj->prepare_table();
346 }
347
348 template <cpu_isa_t isa>
349 bool jit_uni_dw_conv_fwd_kernel_f32<isa>::post_ops_ok(
350         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
351     const auto &p = attr.post_ops_;
352
353     auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
354     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
355     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
356     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
357
358     switch (p.len_) {
359     case 0: return true;
360     case 1: return is_simple(0) || is_sum(0);
361     case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_simple(1));
362     case 3: return is_sum(0) && is_simple(1) && is_simple(2);
363     default: return false;
364     }
365
366     return false;
367 }
368
369 template <cpu_isa_t isa>
370 status_t jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp,
371         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
372         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
373         const primitive_attr_t &attr)
374 {
375     if (!mayiuse(isa)) return status::unimplemented;
376
377     const int simd_w = isa == avx512_common ? 16 : 8;
378
379     jcp.prop_kind = cd.prop_kind;
380
381     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
382     if (!with_groups) return status::unimplemented;
383
384     jcp.ngroups = weights_d.dims()[0];
385     jcp.mb = src_d.dims()[0];
386
387     jcp.oc = dst_d.dims()[1];
388     jcp.oc_without_padding = jcp.oc;
389     jcp.ic = src_d.dims()[1];
390
391     jcp.ih = src_d.dims()[2];
392     jcp.iw = src_d.dims()[3];
393     jcp.oh = dst_d.dims()[2];
394     jcp.ow = dst_d.dims()[3];
395
396     jcp.kh = weights_d.dims()[3];
397     jcp.kw = weights_d.dims()[4];
398
399     jcp.t_pad = cd.padding[0][0];
400     jcp.l_pad = cd.padding[0][1];
401     jcp.b_pad = cd.padding[1][0];
402     jcp.r_pad = cd.padding[1][1];
403
404     jcp.stride_h = cd.strides[0];
405     jcp.stride_w = cd.strides[1];
406
407     jcp.dilate_h = cd.dilates[0];
408     jcp.dilate_w = cd.dilates[1];
409
410     jcp.src_fmt = src_d.format();
411     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
412
413     if (!post_ops_ok(jcp, attr))
414         return status::unimplemented;
415
416     const auto &p = attr.post_ops_;
417     jcp.with_sum = p.find(primitive_kind::sum) != -1;
418
419     bool ok_to_pad_channels = true
420         && jcp.oc == jcp.ngroups
421         && jcp.ic == jcp.ngroups
422         && one_of(isa, avx512_common, avx2, sse42);
423     if (ok_to_pad_channels) {
424         jcp.oc = rnd_up(jcp.oc, simd_w);
425         jcp.ic = rnd_up(jcp.oc, simd_w);
426         jcp.ngroups = rnd_up(jcp.ngroups, simd_w);
427     }
428
429     auto desired_act_fmt = isa == avx512_common ? nChw16c : nChw8c;
430     auto desired_wei_fmt = isa == avx512_common ? Goihw16g : Goihw8g;
431
432     bool args_ok = true
433         && jcp.oc == jcp.ngroups
434         && jcp.ic == jcp.ngroups
435         && jcp.ngroups % simd_w == 0
436         && src_d.format() == desired_act_fmt
437         && weights_d.format() == desired_wei_fmt
438         && one_of(cd.bias_desc.format, memory_format::undef, any, x)
439         && dst_d.format() == desired_act_fmt
440         && jcp.ic <= src_d.blocking_desc().padding_dims[1]
441         && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
442         && jcp.ngroups <= weights_d.blocking_desc().padding_dims[0];
443     if (!args_ok) return status::unimplemented;
444
445     jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;
446
447     jcp.ch_block = simd_w;
448     jcp.nb_ch = jcp.oc / jcp.ch_block;
449     jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2;
450     if (jcp.nb_ch < jcp.nb_ch_blocking)
451         jcp.nb_ch_blocking = jcp.nb_ch;
452
453     return status::success;
454 }
455
456 template <cpu_isa_t isa>
457 void jit_uni_dw_conv_fwd_kernel_f32<isa>::init_scratchpad(
458         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
459     if (jcp.with_bias && jcp.oc_without_padding != jcp.oc)
460         scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
461 }
462
463 template struct jit_uni_dw_conv_fwd_kernel_f32<avx512_common>;
464 template struct jit_uni_dw_conv_fwd_kernel_f32<avx2>;
465 template struct jit_uni_dw_conv_fwd_kernel_f32<sse42>;
466
467 template <cpu_isa_t isa>
468 inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::load_ddst(
469         int ur_ch_blocks, int ur_str_w) {
470     int repeats = isa == sse42 ? 2 : 1;
471     for (int i = 0; i < repeats; i++) {
472         for (int ch = 0; ch < ur_ch_blocks; ch++) {
473             for (int w = 0; w < ur_str_w; w++) {
474                 Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
475                     + ch*ur_str_w + w);
476                 uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
477             }
478         }
479     }
480 }
481
482 template <cpu_isa_t isa>
483 inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::apply_filter(
484         int ur_ch_blocks, int ur_str_w) {
485     int kw = jcp.kw;
486     int kh = jcp.kh;
487     int ow = jcp.ow;
488     int oh = jcp.oh;
489
490     int ch_blk = jcp.ch_block;
491     int stride_h = jcp.stride_h;
492     int stride_w = jcp.stride_w;
493
494     Label iter_exit_label;
495
496     cmp(reg_kh, 0);
497     je(iter_exit_label, T_NEAR);
498
499     cmp(reg_kw, 0);
500     je(iter_exit_label, T_NEAR);
501
502     mov(iter_kh, reg_kh);
503     Label kh_label;
504     L(kh_label); {
505         mov(aux1_reg_ddst, aux_reg_ddst);
506         mov(aux1_reg_kernel, aux_reg_kernel);
507
508         mov(iter_kw, reg_kw);
509         Label kw_label;
510         L(kw_label); {
511             int repeats = isa == sse42 ? 2 : 1;
512             for (int i = 0; i < repeats; i++) {
513                 for (int ch = 0; ch < ur_ch_blocks; ch++) {
514                     int ker_off = ch*kh*kw*ch_blk + i*4;
515                     Vmm vmm_ker = get_ker_reg(0);
516                     uni_vmovups(vmm_ker, ptr[aux1_reg_kernel
517                         + ker_off*sizeof(float)]);
518
519                     for (int w = 0; w < ur_str_w; w++) {
520                         int ddst_off = (ch*oh*ow + w)*ch_blk + i*4;
521
522                         Vmm vmm_src = get_src_reg(0);
523                         uni_vmovups(vmm_src, ptr[aux1_reg_ddst
524                             + ddst_off*sizeof(float)]);
525
526                         Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
527                             + ch*ur_str_w + w);
528                         uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
529                     }
530                 }
531             }
532
533             add(aux1_reg_kernel, ch_blk*stride_w*sizeof(float));
534             sub(aux1_reg_ddst, ch_blk*sizeof(float));
535
536             sub(iter_kw, stride_w);
537             cmp(iter_kw, 0);
538             jg(kw_label, T_NEAR);
539         }
540
541         add(aux_reg_kernel, kw*ch_blk*stride_h*sizeof(float));
542         sub(aux_reg_ddst, ow*ch_blk*sizeof(float));
543
544         sub(iter_kh, stride_h);
545         cmp(iter_kh, 0);
546         jg(kh_label, T_NEAR);
547     }
548
549     L(iter_exit_label);
550 }
551
552 template <cpu_isa_t isa>
553 inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::store_dsrc(
554         int ur_ch_blocks, int ur_str_w) {
555     int ch_blk = jcp.ch_block;
556     int iw = jcp.iw;
557     int ih = jcp.ih;
558     int stride_w = jcp.stride_w;
559
560     int repeats = isa == sse42 ? 2 : 1;
561     for (int i = 0; i < repeats; i++) {
562         for (int ch = 0; ch < ur_ch_blocks; ch++) {
563             for (int w = 0; w < ur_str_w; w++) {
564                 int dsrc_off = (ch*ih*iw + w*stride_w)*ch_blk + i*4;
565                 Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
566                     + ch*ur_str_w + w);
567
568                 uni_vmovups(ptr[reg_dsrc + dsrc_off*sizeof(float)], vmm_acc);
569             }
570         }
571     }
572 }
573
574 template <cpu_isa_t isa>
575 inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::loop_body(
576         int ur_ch_blocks) {
577     Label unrolled_w_label;
578     Label tail_w_label;
579     Label exit_label;
580
581     L(unrolled_w_label); {
582         int ur_w = jcp.ur_w;
583
584         cmp(reg_ur_str_w, ur_w);
585         jl(tail_w_label, T_NEAR);
586
587         mov(aux_reg_ddst, reg_ddst);
588         mov(aux_reg_kernel, reg_kernel);
589
590         load_ddst(ur_ch_blocks, ur_w);
591         apply_filter(ur_ch_blocks, ur_w);
592         store_dsrc(ur_ch_blocks, ur_w);
593
594         add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
595         add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block);
596
597         sub(reg_ur_str_w, ur_w);
598         jmp(unrolled_w_label);
599     }
600
601     L(tail_w_label); {
602         int ur_w = 1;
603
604         cmp(reg_ur_str_w, ur_w);
605         jl(exit_label, T_NEAR);
606
607         mov(aux_reg_ddst, reg_ddst);
608         mov(aux_reg_kernel, reg_kernel);
609
610         load_ddst(ur_ch_blocks, ur_w);
611         apply_filter(ur_ch_blocks, ur_w);
612         store_dsrc(ur_ch_blocks, ur_w);
613
614         add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
615         add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block);
616
617         sub(reg_ur_str_w, ur_w);
618         jmp(tail_w_label);
619     }
620
621     L(exit_label);
622 }
623
624 template <cpu_isa_t isa>
625 void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::generate() {
626     preamble();
627
628     mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
629     mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
630     mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
631     mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
632     mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
633     mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]);
634     mov(reg_ur_str_w, ptr[this->param1 + GET_OFF(ur_str_w)]);
635
636     Label ch_blocks_tail_label;
637     Label exit_label;
638
639     int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
640
641     cmp(reg_ch_blocks, jcp.nb_ch_blocking);
642     jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR);
643
644     loop_body(jcp.nb_ch_blocking); // channel main loop
645
646     if (ch_blocks_tail) {
647         L(ch_blocks_tail_label);
648
649         cmp(reg_ch_blocks, ch_blocks_tail);
650         jne(exit_label, T_NEAR);
651
652         loop_body(ch_blocks_tail); // channel tail loop
653     }
654
655     L(exit_label);
656
657     this->postamble();
658 }
659
660 template <cpu_isa_t isa>
661 status_t jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_conf(
662         jit_conv_conf_t &jcp, const convolution_desc_t &cd,
663         const memory_desc_wrapper &diff_src_d,
664         const memory_desc_wrapper &weights_d,
665         const memory_desc_wrapper &diff_dst_d) {
666     if (!mayiuse(isa)) return status::unimplemented;
667
668     const int simd_w = isa == avx512_common ? 16 : 8;
669
670     const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
671     if (!with_groups) return status::unimplemented;
672
673     jcp.ngroups = weights_d.dims()[0];
674     jcp.mb = diff_src_d.dims()[0];
675
676     jcp.oc = diff_dst_d.dims()[1];
677     jcp.oc_without_padding = jcp.oc;
678     jcp.ic = diff_src_d.dims()[1];
679
680     jcp.ih = diff_src_d.dims()[2];
681     jcp.iw = diff_src_d.dims()[3];
682     jcp.oh = diff_dst_d.dims()[2];
683     jcp.ow = diff_dst_d.dims()[3];
684
685     jcp.kh = weights_d.dims()[3];
686     jcp.kw = weights_d.dims()[4];
687
688     jcp.t_pad = cd.padding[0][0];
689     jcp.l_pad = cd.padding[0][1];
690     jcp.b_pad = cd.padding[1][0];
691     jcp.r_pad = cd.padding[1][1];
692
693     jcp.stride_h = cd.strides[0];
694     jcp.stride_w = cd.strides[1];
695
696     jcp.dilate_h = cd.dilates[0];
697     jcp.dilate_w = cd.dilates[1];
698
699     jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
700     jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
701
702     jcp.src_fmt = diff_src_d.format();
703
704     bool ok_to_pad_channels = true
705         && jcp.oc == jcp.ngroups
706         && jcp.ic == jcp.ngroups
707         && one_of(isa, avx512_common, avx2);
708     if (ok_to_pad_channels) {
709         jcp.oc = rnd_up(jcp.oc, simd_w);
710         jcp.ic = rnd_up(jcp.oc, simd_w);
711         jcp.ngroups = rnd_up(jcp.ngroups, simd_w);
712     }
713
714     auto desired_act_fmt = isa == avx512_common ? nChw16c : nChw8c;
715     auto desired_wei_fmt = isa == avx512_common ? Goihw16g : Goihw8g;
716
717     bool args_ok = true
718         && jcp.oc == jcp.ngroups
719         && jcp.ic == jcp.ngroups
720         && jcp.ngroups % simd_w == 0
721         && jcp.dilate_h == 0
722         && jcp.dilate_w == 0
723         && diff_src_d.format() == desired_act_fmt
724         && weights_d.format() == desired_wei_fmt
725         && diff_dst_d.format() == desired_act_fmt
726         && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
727         && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1
728         && jcp.ic <= diff_src_d.blocking_desc().padding_dims[1]
729         && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
730         && jcp.ngroups <= weights_d.blocking_desc().padding_dims[0];
731     if (!args_ok) return status::unimplemented;
732
733     jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;
734
735     jcp.ch_block = simd_w;
736     jcp.nb_ch = jcp.ic / jcp.ch_block;
737     jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2;
738     if (jcp.nb_ch < jcp.nb_ch_blocking)
739         jcp.nb_ch_blocking = jcp.nb_ch;
740
741     return status::success;
742 }
743
744 template <cpu_isa_t isa>
745 void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_scratchpad(
746         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
747     UNUSED(scratchpad);
748     UNUSED(jcp);
749 }
750
751 template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx512_common>;
752 template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx2>;
753 template struct jit_uni_dw_conv_bwd_data_kernel_f32<sse42>;
754
755 template <cpu_isa_t isa>
756 inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_filter() {
757     for (int r = 0; r < reg_repeats; ++r) {
758         for (int i = 0; i < jcp.kw; ++i) {
759             Vmm vmm_acc = get_acc_reg(r * jcp.kw + i);
760             uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
761         }
762     }
763 }
764
765 template <cpu_isa_t isa>
766 inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_filter() {
767     for (int r = 0; r < reg_repeats; ++r) {
768         const int reg_set = r * jcp.kw;
769         for (int i = 0; i < jcp.kw; ++i) {
770             int off_filter = (reg_set + i) * simd_w;
771             Vmm vmm_acc = get_acc_reg(reg_set + i);
772             uni_vmovups(vmm_acc,
773                     vmmword[reg_tmp_filter + off_filter * sizeof(float)]);
774         }
775     }
776 }
777
778 template <cpu_isa_t isa>
779 inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_bias() {
780     for (int r = 0; r < reg_repeats; ++r) {
781         Vmm vmm_bias = get_bias_reg(r);
782         uni_vpxor(vmm_bias, vmm_bias, vmm_bias);
783     }
784 }
785
786 template <cpu_isa_t isa>
787 inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_bias() {
788     for (int r = 0; r < reg_repeats; ++r) {
789         Vmm vmm_bias = get_bias_reg(r);
790         uni_vmovups(
791                 vmm_bias, vmmword[reg_bias_baddr + r * simd_w * sizeof(float)]);
792     }
793 }
794
795 template <cpu_isa_t isa>
796 inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_step_unroll(
797         int unroll_w, int l_pad, int pad_offset, int ow_block) {
798
799     const int iw_block = ow_block * jcp.stride_w;
800     const int right_border = jcp.iw - iw_block;
801
802     const int cascade_input = nstl::min(jcp.stride_w, jcp.kw);
803
804     /* preamble count for number of cascaded LOAD + FMA operation */
805     const int input_overlap = nstl::max(jcp.kw - l_pad, 0);
806
807     /* LOAD initial input registers, then cascade LOADs and FMAs*/
808     for (int r = 0; r < reg_repeats; ++r) {
809         for (int i_ur = 0; i_ur < unroll_w; ++i_ur) {
810             int off_output = (i_ur * reg_repeats + r) * simd_w;
811             Vmm vmm_output = get_output_reg(r);
812             uni_vmovups(vmm_output,
813                     ptr[reg_tmp_output + off_output * sizeof(float)]);
814             if (i_ur == 0) {
815                 for (int c = 0; c < input_overlap; ++c) {
816                     int off_input
817                             = ((c - pad_offset) * reg_repeats + r) * simd_w;
818                     Vmm vmm_input
819                             = get_input_reg((c % jcp.kw) * reg_repeats + r);
820                     uni_vmovups(vmm_input,
821                             ptr[reg_tmp_input + off_input * sizeof(float)]);
822                 }
823             } else {
824                 for (int c = 0; c < cascade_input; ++c) {
825                     int overlap = (i_ur - 1) * jcp.stride_w + input_overlap;
826                     int off_input
827                             = ((overlap + c - pad_offset) * reg_repeats + r)
828                             * simd_w;
829                     Vmm vmm_input = get_input_reg(
830                             ((overlap + c) % jcp.kw) * reg_repeats + r);
831                     uni_vmovups(vmm_input,
832                             ptr[reg_tmp_input + off_input * sizeof(float)]);
833                 }
834             }
835
836             for (int i_kw = 0; i_kw < jcp.kw; ++i_kw) {
837                 int io_overlap = i_kw + (i_ur * jcp.stride_w);
838
839                 /* Don't apply FMAs that fall into the padded region */
840                 if (io_overlap - l_pad < 0
841                         || io_overlap - jcp.l_pad >= right_border)
842                     continue;
843
844                 Vmm vmm_input = get_input_reg(
845                         ((io_overlap - l_pad) % jcp.kw) * reg_repeats + r);
846                 Vmm vmm_acc = get_acc_reg(i_kw * reg_repeats + r);
847                 Vmm vmm_aux = isa == sse42 ? get_aux_reg() : vmm_input;
848                 if (isa == sse42)
849                     uni_vmovups(vmm_aux, vmm_input);
850                 uni_vfmadd231ps(vmm_acc, vmm_aux, vmm_output);
851             }
852         }
853     }
854 }
855
856 template <cpu_isa_t isa>
857 inline void
858 jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias_step_unroll(
859         const int unroll_w) {
860     for (int r = 0; r < reg_repeats; ++r) {
861         for (int i = 0; i < unroll_w; ++i) {
862             Vmm vmm_bias = get_bias_reg(r);
863             int off_output = (i * reg_repeats + r) * simd_w;
864             if (isa == sse42) {
865                 /* Need to support unaligned address loads for SSE42*/
866                 Vmm vmm_output = get_output_reg(1 + r);
867                 uni_vmovups(vmm_output,
868                         ptr[reg_tmp_output + off_output * sizeof(float)]);
869                 uni_vaddps(vmm_bias, vmm_bias, vmm_output);
870             } else {
871                 uni_vaddps(vmm_bias, vmm_bias,
872                         vmmword[reg_tmp_output + off_output * sizeof(float)]);
873             }
874         }
875     }
876 }
877
878 template <cpu_isa_t isa>
879 inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_filter() {
880     for (int r = 0; r < reg_repeats; ++r) {
881         const int reg_set = r * jcp.kw;
882         for (int i = 0; i < jcp.kw; ++i) {
883             int off_filter = (i + reg_set) * simd_w;
884             Vmm vmm_acc = get_acc_reg(i + reg_set);
885             uni_vmovups(vmmword[reg_tmp_filter + off_filter * sizeof(float)],
886                     vmm_acc);
887         }
888     }
889 }
890
891 template <cpu_isa_t isa>
892 inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_bias() {
893     for (int r = 0; r < reg_repeats; ++r) {
894         Vmm vmm_bias = get_bias_reg(r);
895         uni_vmovups(
896                 vmmword[reg_bias_baddr + r * simd_w * sizeof(float)], vmm_bias);
897     }
898 }
899
900 template <cpu_isa_t isa>
901 inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias_loop(
902         const int block_size) {
903     Label oh_label;
904     Label ow_blk_label;
905
906     const int unroll_w = nstl::min(block_size, jcp.ow);
907     const int unroll_w_trips = jcp.ow / unroll_w;
908     const int tail_w = jcp.ow > block_size ? jcp.ow % block_size : 0;
909
910     const int ch_offset = jcp.ch_block;
911
912     mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]);
913     mov(reg_oh_worksize,
914             ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]);
915
916     mov(reg_tmp_output, reg_output_baddr);
917     L(oh_label);
918     {
919
920         mov(iter_ow_blk, unroll_w_trips);
921         L(ow_blk_label);
922         {
923
924             compute_bias_step_unroll(unroll_w);
925             add(reg_tmp_output, unroll_w * ch_offset * sizeof(float));
926
927             dec(iter_ow_blk);
928             cmp(iter_ow_blk, 0);
929             jg(ow_blk_label, T_NEAR);
930         }
931
932         if (tail_w > 0) {
933             compute_bias_step_unroll(tail_w);
934             add(reg_tmp_output, tail_w * ch_offset * sizeof(float));
935         }
936
937         inc(reg_oh);
938         cmp(reg_oh, reg_oh_worksize);
939         jl(oh_label, T_NEAR);
940     }
941 }
942
943 template <cpu_isa_t isa>
944 inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_zero_filter() {
945
946     const int ch_offset = jcp.ch_block;
947
948     Label kh_loop_label, skip_zeroing_label;
949
950     mov(reg_exec_flags,
951             ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]);
952     and_(reg_exec_flags, FLAG_ZERO_FILTER);
953     test(reg_exec_flags, reg_exec_flags);
954     je(skip_zeroing_label);
955
956     zero_filter();
957
958     mov(reg_tmp_filter, reg_filter_baddr);
959     mov(reg_kh, jcp.kh);
960     L(kh_loop_label);
961     {
962         store_filter();
963
964         add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
965         dec(reg_kh);
966         cmp(reg_kh, 0);
967         jg(kh_loop_label);
968     }
969
970     /* Comeback pointers */
971     sub(reg_tmp_filter, jcp.kh * jcp.kw * ch_offset * sizeof(float));
972
973     L(skip_zeroing_label);
974 }
975
976 template <cpu_isa_t isa>
977 inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_h_step(
978         int unroll_w, int l_pad, int pad_offset, int ow_block) {
979
980     const int ch_offset = jcp.ch_block;
981
982     Label kh_loop_label, skip_loop_label;
983
984     cmp(reg_kh_count, 0);
985     je(skip_loop_label, T_NEAR);
986
987     mov(reg_kh, reg_kh_count);
988     L(kh_loop_label);
989     {
990         load_filter();
991         compute_ow_step_unroll(unroll_w, l_pad, pad_offset, ow_block);
992         store_filter();
993
994         add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
995         add(reg_tmp_input, jcp.iw * ch_offset * sizeof(float));
996         dec(reg_kh);
997         cmp(reg_kh, 0);
998         jg(kh_loop_label);
999     }
1000
1001     /* Comeback pointers */
1002     Label kh_comeback_label;
1003     mov(reg_kh, reg_kh_count);
1004     L(kh_comeback_label);
1005     {
1006         sub(reg_tmp_input, jcp.iw * ch_offset * sizeof(float));
1007         sub(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
1008         dec(reg_kh);
1009         cmp(reg_kh, 0);
1010         jg(kh_comeback_label, T_NEAR);
1011     }
1012
1013     L(skip_loop_label);
1014 }
1015
1016 template <cpu_isa_t isa>
1017 inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_h_loop(
1018         int unroll_w, int l_pad, int pad_offset, int ow_block) {
1019
1020     const size_t io_overlap = jcp.ih / jcp.stride_h < jcp.oh ?
1021             jcp.ih / jcp.stride_h - 1 :
1022             jcp.oh - jcp.b_pad - 1;
1023     const int ch_offset = jcp.ch_block;
1024     const int t_overlap_off = jcp.t_pad % jcp.stride_h == 0 ? jcp.stride_h : 1;
1025     const int b_overlap_off = jcp.b_pad % jcp.stride_h == 0 ? jcp.stride_h : 1;
1026
1027     Label tpad_loop_label, h_loop_label, skip_tpad_label, skip_bpad_label,
1028             end_h_loop_label;
1029
1030     mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]);
1031     mov(reg_oh_worksize,
1032             ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]);
1033     mov(reg_kh_count,
1034             ptr[this->param1 + offsetof(jit_dw_conv_call_s, kh_count)]);
1035
1036     mov(reg_tmp_output, reg_output_baddr);
1037     mov(reg_tmp_input, reg_input_baddr);
1038     mov(reg_tmp_filter, reg_filter_baddr);
1039
1040     L(h_loop_label);
1041     {
1042
1043         compute_h_step(unroll_w, l_pad, pad_offset, ow_block);
1044
1045         add(reg_tmp_output, jcp.ow * ch_offset * sizeof(float));
1046
1047         /* If within the top_pad region */
1048         if (jcp.t_pad > 0) {
1049             /* Skip t_pad area if no longer in initial h_block */
1050             cmp(reg_oh, jcp.t_pad);
1051             jg(skip_tpad_label, T_NEAR);
1052
1053             cmp(reg_kh_count, jcp.kh);
1054             jge(skip_tpad_label, T_NEAR);
1055
1056             add(reg_kh_count, t_overlap_off);
1057             sub(reg_tmp_filter,
1058                     t_overlap_off * jcp.kw * ch_offset * sizeof(float));
1059
1060             /* kernel has moved beyond padding (adjust for stride effects) */
1061             if (jcp.t_pad % jcp.stride_h != 0) {
1062                 int inp_corr = jcp.stride_h - jcp.t_pad % jcp.stride_h;
1063                 add(reg_tmp_input,
1064                         inp_corr * jcp.iw * ch_offset * sizeof(float));
1065             }
1066             jmp(tpad_loop_label, T_NEAR);
1067         }
1068
1069         L(skip_tpad_label);
1070
1071         cmp(reg_oh, io_overlap);
1072         jl(skip_bpad_label, T_NEAR);
1073         sub(reg_kh_count, b_overlap_off);
1074
1075         L(skip_bpad_label);
1076         add(reg_tmp_input, jcp.stride_h * jcp.iw * ch_offset * sizeof(float));
1077
1078         L(tpad_loop_label);
1079
1080         cmp(reg_oh, jcp.ih / jcp.stride_h);
1081         jge(end_h_loop_label, T_NEAR);
1082
1083         inc(reg_oh);
1084
1085         cmp(reg_oh, reg_oh_worksize);
1086         jl(h_loop_label, T_NEAR);
1087     }
1088     L(end_h_loop_label);
1089 }
1090
1091 template <cpu_isa_t isa>
1092 inline void
1093 jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_block_unroll() {
1094
1095     const int ch_offset = jcp.ch_block;
1096     int ow = jcp.ow;
1097     int pad_offset = 0;
1098     int l_pad = jcp.l_pad;
1099
1100     /* Calculate effective padding */
1101     int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
1102                     + (jcp.kw - 1) * (jcp.dilate_w + 1)
1103                     - (jcp.iw + jcp.l_pad - 1));
1104
1105     /* Is this strictly defined by:
1106      * -code-size (?)
1107      * -address size (?) */
1108     const int max_unroll_w = 30;
1109     const int block_size = 15;
1110
1111     int unroll_w_tail = 0;
1112     int unroll_w = 0;
1113     int unroll_w_trips = 0;
1114
1115     if (jcp.ow > max_unroll_w) {
1116         unroll_w = nstl::min(block_size, jcp.ow);
1117         unroll_w_trips = ow / unroll_w;
1118         /* calculate tail */
1119         unroll_w_tail = ow % unroll_w;
1120         /* Perform some rebalancing if tail too small*/
1121         if ((unroll_w_tail == 0 && r_pad != 0)
1122                 || (r_pad > 0 && r_pad >= unroll_w_tail)) {
1123             if (unroll_w_trips > 1) {
1124                 unroll_w_tail += unroll_w;
1125                 unroll_w_trips--;
1126             } else {
1127                 /* Idealy, this case shouldn't happen */
1128                 unroll_w_tail += (unroll_w - unroll_w / 2);
1129                 unroll_w = unroll_w / 2;
1130             }
1131         }
1132     } else {
1133         unroll_w = jcp.ow;
1134         unroll_w_trips = nstl::max(1, ow / unroll_w);
1135     }
1136     if (jcp.with_bias) {
1137         Label skip_load_bias;
1138         mov(reg_bias_baddr,
1139                 ptr[this->param1 + offsetof(jit_dw_conv_call_s, bias)]);
1140
1141         zero_bias();
1142
1143         mov(reg_exec_flags,
1144                 ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]);
1145         and_(reg_exec_flags, FLAG_ZERO_BIAS);
1146         test(reg_exec_flags, reg_exec_flags);
1147         jne(skip_load_bias);
1148
1149         load_bias();
1150
1151         L(skip_load_bias);
1152         compute_bias_loop(block_size);
1153
1154         store_bias();
1155     }
1156
1157     /* Pass filter address, then offset for h_padding. */
1158     compute_zero_filter();
1159     mov(reg_kh_offset,
1160             ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter_pad_off)]);
1161     add(reg_filter_baddr, reg_kh_offset);
1162
1163     /* compute left padded block */
1164     if (l_pad) {
1165         compute_h_loop(unroll_w, l_pad, 0, 0);
1166         add(reg_output_baddr, unroll_w * ch_offset * sizeof(float));
1167         add(reg_input_baddr,
1168                 unroll_w * jcp.stride_w * ch_offset * sizeof(float));
1169         unroll_w_trips--;
1170         pad_offset = l_pad;
1171         l_pad = 0;
1172     }
1173
1174     /* compute middle block */
1175     Label ow_blk_label;
1176
1177     /* Insert loop for 'ow' block when middle block needs to execute more
1178      * than once */
1179     bool do_ow_blk_loop = unroll_w_trips > 1;
1180     if (do_ow_blk_loop) {
1181         mov(iter_ow_blk, unroll_w_trips);
1182         L(ow_blk_label);
1183     }
1184     if (unroll_w_trips > 0) {
1185         compute_h_loop(unroll_w, l_pad, pad_offset, 0);
1186         add(reg_output_baddr, unroll_w * ch_offset * sizeof(float));
1187         add(reg_input_baddr,
1188                 unroll_w * jcp.stride_w * ch_offset * sizeof(float));
1189     }
1190     if (do_ow_blk_loop) {
1191         dec(iter_ow_blk);
1192         cmp(iter_ow_blk, 0);
1193         jg(ow_blk_label, T_NEAR);
1194     }
1195
1196     /* compute right padded block */
1197     if (unroll_w_tail) {
1198         compute_h_loop(unroll_w_tail, 0, pad_offset, jcp.ow - unroll_w_tail);
1199     }
1200 }
1201
1202 template <cpu_isa_t isa>
1203 void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::generate() {
1204     preamble();
1205
1206     mov(reg_input_baddr,
1207             ptr[this->param1 + offsetof(jit_dw_conv_call_s, input)]);
1208     mov(reg_output_baddr,
1209             ptr[this->param1 + offsetof(jit_dw_conv_call_s, output)]);
1210     mov(reg_filter_baddr,
1211             ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter)]);
1212
1213     compute_ow_block_unroll();
1214
1215     this->postamble();
1216 }
1217
1218 template <cpu_isa_t isa>
1219 status_t jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_conf(
1220         jit_conv_conf_t &jcp, const convolution_desc_t &cd,
1221         const memory_desc_wrapper &src_d,
1222         const memory_desc_wrapper &diff_weights_d,
1223         const memory_desc_wrapper &diff_dst_d, int nthreads) {
1224     if (!mayiuse(isa))
1225         return status::unimplemented;
1226
1227     jcp.ngroups = diff_weights_d.dims()[0];
1228     jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
1229     jcp.ic = src_d.dims()[1] / jcp.ngroups;
1230
1231     const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
1232
1233     jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.oc, jcp.ic);
1234
1235     if (!jcp.is_depthwise)
1236         return status::unimplemented;
1237
1238     jcp.ch_block = isa == avx512_common ? 16 : 8;
1239
1240     jcp.mb = src_d.dims()[0];
1241
1242     jcp.ih = src_d.dims()[2];
1243     jcp.iw = src_d.dims()[3];
1244     jcp.oh = diff_dst_d.dims()[2];
1245     jcp.ow = diff_dst_d.dims()[3];
1246
1247     jcp.kh = diff_weights_d.dims()[3];
1248     jcp.kw = diff_weights_d.dims()[4];
1249
1250     jcp.stride_h = cd.strides[0];
1251     jcp.stride_w = cd.strides[1];
1252
1253     jcp.t_pad = cd.padding[0][0];
1254     jcp.b_pad = cd.padding[1][0];
1255
1256     jcp.l_pad = cd.padding[0][1];
1257     jcp.r_pad = cd.padding[1][1];
1258
1259     jcp.dilate_h = cd.dilates[0];
1260     jcp.dilate_w = cd.dilates[1];
1261
1262     jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
1263     jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
1264
1265     jcp.src_fmt = src_d.format();
1266
1267     jcp.with_bias = cd.diff_bias_desc.format != memory_format::undef;
1268
1269     auto desired_act_fmt = isa == avx512_common ? nChw16c : nChw8c;
1270     auto desired_wei_fmt = isa == avx512_common ? Goihw16g : Goihw8g;
1271
1272     bool args_ok = true && src_d.format() == desired_act_fmt
1273             && diff_weights_d.format() == desired_wei_fmt
1274             && diff_dst_d.format() == desired_act_fmt
1275             && one_of(cd.bias_desc.format, memory_format::undef, any, x)
1276             && jcp.ngroups % jcp.ch_block == 0 && jcp.dilate_h == 0
1277             && jcp.dilate_w == 0 && jcp.kw <= 3
1278             && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
1279             && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
1280     if (!args_ok)
1281         return status::unimplemented;
1282
1283     jcp.nb_ch = jcp.ngroups / jcp.ch_block;
1284
1285     /* kernel applicability check wrt boundaries
1286      * the conditions are quite general across the kernels we have,
1287      * but ideally the check should belong to a specific kernel... */
1288     const int max_hpad = (jcp.kh - 1 + 1) / 2;
1289     const int max_wpad = (jcp.kw - 1 + 1) / 2;
1290     const bool boundaries_ok = true && jcp.t_pad <= max_hpad
1291             && jcp.b_pad <= max_hpad && jcp.l_pad <= max_wpad
1292             && jcp.r_pad <= max_wpad;
1293     if (!boundaries_ok)
1294         return status::unimplemented;
1295
1296     balance(jcp, nthreads);
1297
1298     return status::success;
1299 }
1300
1301 template <cpu_isa_t isa>
1302 void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_scratchpad(
1303         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
1304     /* Notes: if splitting thread work on 'mb', then a reduction has to take
1305      * place. Hence, book a per-thread, local weights-buffer for the
1306      * reduction */
1307     if (jcp.nthr_mb > 1) {
1308         const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
1309         scratchpad.book(key_conv_wei_reduction,
1310                 sizeof(float) * wei_size * (jcp.nthr_mb - 1));
1311
1312         if (jcp.with_bias)
1313             scratchpad.book(key_conv_bia_reduction,
1314                     sizeof(float) * jcp.ngroups * (jcp.nthr_mb - 1));
1315     }
1316 }
1317
1318 template <cpu_isa_t isa>
1319 void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::balance(jit_conv_conf_t &jcp,
1320         int nthreads) {
1321     jcp.nthr = nthreads;
1322     jcp.nthr_g = jcp.nthr_mb = 1;
1323
1324     /* Basic-Heuristics for parallel strategy:
1325      * 1) Tries to parallel on the number of Groups (g) where tasks are
1326      * independent. Otherwise,
1327      * 2) Tries to split the work across g and MiniBatch (mb).
1328      * Parallelizing on mb requires computing a reduction for weights.
1329      *
1330      * NOTE: because of 'task partitioning' scheme, there will be unbalanced
1331      * per-thread load when the number of threads is high (e.g. > 16).
1332      */
1333     jcp.nthr_g = nstl::min(jcp.nb_ch, jcp.nthr);
1334     jcp.nthr_mb = nstl::min(nstl::max(1, jcp.nthr / jcp.nthr_g), jcp.mb);
1335
1336     jcp.nthr = jcp.nthr_g * jcp.nthr_mb;
1337 }
1338
1339 template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx512_common>;
1340 template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx2>;
1341 template struct jit_uni_dw_conv_bwd_weights_kernel_f32<sse42>;
1342
1343 }
1344 }
1345 }