Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_dw_conv_kernel_f32.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "nstl.hpp"
19 #include "type_helpers.hpp"
20 #include "utils.hpp"
21 #include "cpu_memory.hpp"
22
23 #include "jit_uni_dw_conv_kernel_f32.hpp"
24
25 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 using namespace mkldnn::impl::prop_kind;
32 using namespace mkldnn::impl::memory_format;
33 using namespace mkldnn::impl::utils;
34
35 using namespace Xbyak;
36
37 template <cpu_isa_t isa>
38 void jit_uni_dw_conv_fwd_kernel_f32<isa>::load_src(int ur_ch_blocks, int ur_w) {
39     int repeats = isa == sse42 ? 2 : 1;
40     for (int i = 0; i < repeats; i++) {
41         for (int ch = 0; ch < ur_ch_blocks; ch++) {
42             for (int ow = 0; ow < ur_w; ow++) {
43                 Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
44
45                 int b_off = ch*jcp.ch_block + i*4;
46                 if (this->jcp.with_bias)
47                     uni_vmovups(vmm_acc,
48                         vmmword[reg_bias + b_off*sizeof(float)]);
49                 else
50                     uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
51
52                 int o_off = ch*jcp.oh*jcp.ow*jcp.ch_block
53                     + ow*jcp.ch_block + i*4;
54                 if (this->jcp.with_sum)
55                     uni_vaddps(vmm_acc, vmm_acc,
56                         vmmword[reg_output + o_off*sizeof(float)]);
57             }
58         }
59     }
60 }
61
62 template <cpu_isa_t isa>
63 void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter(
64         int ur_ch_blocks, int ur_w) {
65     int ch_blk = jcp.ch_block;
66     int dilate_h = jcp.dilate_h + 1;
67     int dilate_w = jcp.dilate_w + 1;
68     int stride_w = jcp.stride_w;
69
70     Label iter_exit_label;
71
72     cmp(reg_kh, 0);
73     je(iter_exit_label, T_NEAR);
74     cmp(reg_kw, 0);
75     je(iter_exit_label, T_NEAR);
76
77     mov(iter_kh, reg_kh);
78     Label kh_label;
79     L(kh_label); {
80         mov(iter_kw, reg_kw);
81         mov(aux1_reg_input, aux_reg_input);
82         mov(aux1_reg_kernel, aux_reg_kernel);
83
84         Label kw_label;
85         L(kw_label); {
86             int repeats = isa == sse42 ? 2 : 1;
87             for (int i = 0; i < repeats; i++) {
88                 for (int ch = 0; ch < ur_ch_blocks; ch++) {
89                     int ker_off = ch*jcp.kh*jcp.kw*ch_blk + i*4;
90                     Vmm vmm_ker = get_ker_reg(0);
91                     uni_vmovups(vmm_ker, ptr[aux1_reg_kernel
92                         + ker_off*sizeof(float)]);
93
94                     for (int ow = 0; ow < ur_w; ow++) {
95                         int inp_off = ch*jcp.ih*jcp.iw*ch_blk
96                             + ow*stride_w*ch_blk + i*4;
97                         Vmm vmm_src = get_src_reg(0);
98                         uni_vmovups(vmm_src, ptr[aux1_reg_input
99                             + inp_off*sizeof(float)]);
100
101                         Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w
102                             + ch*ur_w + ow);
103                         uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
104                     }
105                 }
106             }
107             add(aux1_reg_kernel, ch_blk*sizeof(float));
108             add(aux1_reg_input, ch_blk*dilate_w*sizeof(float));
109
110             dec(iter_kw);
111             cmp(iter_kw, 0);
112             jg(kw_label, T_NEAR);
113         }
114         add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
115         add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float));
116
117         dec(iter_kh);
118         cmp(iter_kh, 0);
119         jg(kh_label, T_NEAR);
120     }
121
122     L(iter_exit_label);
123 }
124
125 template <cpu_isa_t isa>
126 void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter_unrolled(
127         int ur_ch_blocks, int ur_w) {
128     int ch_blk = jcp.ch_block;
129     int dilate_h = jcp.dilate_h + 1;
130     int dilate_w = jcp.dilate_w + 1;
131     int stride_w = jcp.stride_w;
132
133     Label iter_exit_label;
134
135     cmp(reg_kh, 0);
136     je(iter_exit_label, T_NEAR);
137
138     mov(iter_kh, reg_kh);
139     Label kh_label;
140     L(kh_label); {
141         int repeats = isa == sse42 ? 2 : 1;
142         for (int i = 0; i < repeats; i++) {
143             for (int ch = 0; ch < ur_ch_blocks; ch++) {
144                 for (int kw = 0; kw < jcp.kw; kw++) {
145                     int ker_off = ch*jcp.kh*jcp.kw*ch_blk + kw*ch_blk + i*4;
146
147                     Vmm vmm_ker = get_ker_reg(0);
148                     uni_vmovups(vmm_ker, ptr[aux_reg_kernel
149                         + ker_off*sizeof(float)]);
150
151                     for (int ow = 0; ow < ur_w; ow++) {
152                         int inp_off = ch*jcp.ih*jcp.iw*ch_blk
153                             + ow*stride_w*ch_blk + kw*ch_blk*dilate_w + i*4;
154
155                         Vmm vmm_src = get_src_reg(0);
156                         uni_vmovups(vmm_src, ptr[aux_reg_input
157                             + inp_off*sizeof(float)]);
158
159                         Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w
160                             + ch*ur_w + ow);
161                         uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
162                     }
163                 }
164             }
165         }
166
167         add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
168         add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float));
169
170         dec(iter_kh);
171         cmp(iter_kh, 0);
172         jg(kh_label, T_NEAR);
173     }
174
175     L(iter_exit_label);
176 }
177
178 template <cpu_isa_t isa>
179 void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_activation(int ur_ch_blocks, int ur_w) {
180     if (this->jcp.with_eltwise) {
181         inject(eltwise_generator.prepareConstants(jcp.eltwise_alpha, jcp.eltwise_beta));
182
183         // TODO (dmitrygo): need to find appropriate way to share labels.
184         mov(imm_addr64, l_table);
185         int repeats = isa == sse42 ? 2 : 1;
186         for (int i = 0; i < repeats; i++) {
187             for (int ch = 0; ch < ur_ch_blocks; ch++) {
188                 for (int ow = 0; ow < ur_w; ow++) {
189                     Vmm vmm_dst = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
190
191                     inject(eltwise_generator.computeVector(vmm_dst, vmm_dst));
192                 }
193             }
194         }
195     }
196 }
197
198 template <cpu_isa_t isa>
199 void jit_uni_dw_conv_fwd_kernel_f32<isa>::store_dst(
200         int ur_ch_blocks, int ur_w) {
201     int ch_blk = jcp.ch_block;
202
203     int repeats = isa == sse42 ? 2 : 1;
204     for (int i = 0; i < repeats; i++) {
205         for (int ch = 0; ch < ur_ch_blocks; ch++) {
206             for (int ow = 0; ow < ur_w; ow++) {
207                 int o_off = ch*jcp.oh*jcp.ow*ch_blk + ow*ch_blk + i*4;
208                 Vmm vmm_dst = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
209
210                 uni_vmovups(vmmword[reg_output + o_off*sizeof(float)], vmm_dst);
211             }
212         }
213     }
214 }
215
216 template <cpu_isa_t isa>
217 void jit_uni_dw_conv_fwd_kernel_f32<isa>::loop_body(int ur_ch_blocks) {
218     Label unrolled_w_label;
219     Label tail_w_label;
220     Label exit_label;
221
222     L(unrolled_w_label); {
223         int ur_w = jcp.ur_w;
224
225         cmp(reg_ur_w, ur_w);
226         jl(tail_w_label, T_NEAR);
227
228         mov(aux_reg_input, reg_input);
229         mov(aux_reg_kernel, reg_kernel);
230
231         load_src(ur_ch_blocks, ur_w);
232         apply_filter_unrolled(ur_ch_blocks, ur_w);
233         apply_activation(ur_ch_blocks, ur_w);
234         store_dst(ur_ch_blocks, ur_w);
235
236         add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
237         add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
238
239         sub(reg_ur_w, ur_w);
240         jmp(unrolled_w_label);
241     }
242
243     L(tail_w_label); {
244         int ur_w = 1;
245
246         cmp(reg_ur_w, ur_w);
247         jl(exit_label, T_NEAR);
248
249         mov(aux_reg_input, reg_input);
250         mov(aux_reg_kernel, reg_kernel);
251
252         load_src(ur_ch_blocks, ur_w);
253         apply_filter(ur_ch_blocks, ur_w);
254         apply_activation(ur_ch_blocks, ur_w);
255         store_dst(ur_ch_blocks, ur_w);
256
257         add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
258         add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
259
260         sub(reg_ur_w, ur_w);
261         jmp(tail_w_label);
262     }
263
264     L(exit_label);
265 }
266
267 template <cpu_isa_t isa>
268 void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate()
269 {
270     nstl::vector<int> shared_vecs;
271     shared_vecs.push_back(0);
272     shared_vecs.push_back(1);
273     shared_vecs.push_back(2);
274     shared_vecs.push_back(3);
275     if (isa == avx512_common)
276         shared_vecs.push_back(31);
277
278     nstl::vector<Reg64> shared_regs;
279     shared_regs.push_back(imm_addr64);
280
281     eltwise_generator.init(jcp.eltwise_alg, shared_vecs, shared_regs);
282
283     this->preamble();
284
285     mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
286     mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
287     mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
288     if (jcp.with_bias)
289         mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
290     mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
291     mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
292     mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]);
293     mov(reg_ur_w, ptr[this->param1 + GET_OFF(ur_w)]);
294
295     Label ch_blocks_tail_label;
296     Label exit_label;
297
298     int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
299
300     cmp(reg_ch_blocks, jcp.nb_ch_blocking);
301     jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR);
302
303     loop_body(jcp.nb_ch_blocking); // channel main loop
304
305     if (ch_blocks_tail) {
306         L(ch_blocks_tail_label);
307
308         cmp(reg_ch_blocks, ch_blocks_tail);
309         jne(exit_label, T_NEAR);
310
311         loop_body(ch_blocks_tail); // channel tail loop
312     }
313
314     L(exit_label);
315
316     this->postamble();
317
318     // TODO (dmitrygo): need to find appropriate way to share labels.
319     align(64);
320     L(l_table);
321     inject(eltwise_generator.prepareTable());
322     eltwise_generator.release();
323 }
324
325 template <cpu_isa_t isa>
326 bool jit_uni_dw_conv_fwd_kernel_f32<isa>::post_ops_ok(
327         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
328     const auto &p = attr.post_ops_;
329
330     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
331     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
332
333     switch (p.len_) {
334     case 0: return true; // no post_ops
335     case 1: return !jcp.with_eltwise && (is_eltwise(0) || is_sum(0)); // sum OR relu
336     case 2: return !jcp.with_eltwise && (is_sum(0) && is_eltwise(1)); // sum->relu
337     default: return false;
338     }
339
340     return false;
341 }
342
343 template <cpu_isa_t isa>
344 status_t jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp,
345         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
346         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
347         const primitive_attr_t &attr, bool with_relu, float relu_negative_slope)
348 {
349     if (!mayiuse(isa)) return status::unimplemented;
350
351     const int simd_w = isa == avx512_common ? 16 : 8;
352
353     jcp.prop_kind = cd.prop_kind;
354
355     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
356     if (!with_groups) return status::unimplemented;
357
358     jcp.ngroups = weights_d.dims()[0];
359     jcp.mb = src_d.dims()[0];
360
361     jcp.oc = dst_d.dims()[1];
362     jcp.oc_without_padding = jcp.oc;
363     jcp.ic = src_d.dims()[1];
364
365     jcp.ih = src_d.dims()[2];
366     jcp.iw = src_d.dims()[3];
367     jcp.oh = dst_d.dims()[2];
368     jcp.ow = dst_d.dims()[3];
369
370     jcp.kh = weights_d.dims()[3];
371     jcp.kw = weights_d.dims()[4];
372
373     jcp.t_pad = cd.padding[0][0];
374     jcp.l_pad = cd.padding[0][1];
375     jcp.b_pad = cd.padding[1][0];
376     jcp.r_pad = cd.padding[1][1];
377
378     jcp.stride_h = cd.strides[0];
379     jcp.stride_w = cd.strides[1];
380
381     jcp.dilate_h = cd.dilates[0];
382     jcp.dilate_w = cd.dilates[1];
383
384     jcp.src_fmt = src_d.format();
385     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
386     jcp.with_eltwise = with_relu;
387     jcp.eltwise_alg = mkldnn_eltwise_relu;
388     jcp.eltwise_alpha = relu_negative_slope;
389
390     if (!post_ops_ok(jcp, attr))
391         return status::unimplemented;
392
393     const auto &p = attr.post_ops_;
394     jcp.with_sum = p.find(primitive_kind::sum) != -1;
395     if (!jcp.with_eltwise) {
396         int eltwise_ind = p.find(primitive_kind::eltwise);
397         if (eltwise_ind != -1) {
398             jcp.with_eltwise  = true;
399             jcp.eltwise_alg   = p.entry_[eltwise_ind].eltwise.alg;
400             jcp.eltwise_alpha = p.entry_[eltwise_ind].eltwise.alpha;
401             jcp.eltwise_beta  = p.entry_[eltwise_ind].eltwise.beta;
402             jcp.eltwise_scale = p.entry_[eltwise_ind].eltwise.scale;
403         }
404     }
405
406     bool ok_to_pad_channels = true
407         && jcp.oc == jcp.ngroups
408         && jcp.ic == jcp.ngroups
409         && isa == avx512_common;
410     if (ok_to_pad_channels) {
411         jcp.oc = rnd_up(jcp.oc, simd_w);
412         jcp.ic = rnd_up(jcp.oc, simd_w);
413         jcp.ngroups = rnd_up(jcp.ngroups, simd_w);
414     }
415
416     auto desired_act_fmt = isa == avx512_common ? nChw16c : nChw8c;
417     auto desired_wei_fmt = isa == avx512_common ? Goihw16g : Goihw8g;
418
419     bool args_ok = true
420         && jcp.oc == jcp.ngroups
421         && jcp.ic == jcp.ngroups
422         && jcp.ngroups % simd_w == 0
423         && src_d.format() == desired_act_fmt
424         && weights_d.format() == desired_wei_fmt
425         && one_of(cd.bias_desc.format, memory_format::undef, any, x)
426         && dst_d.format() == desired_act_fmt
427         && jcp.ic <= src_d.blocking_desc().padding_dims[1]
428         && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
429         && jcp.ngroups <= weights_d.blocking_desc().padding_dims[0];
430     if (!args_ok) return status::unimplemented;
431
432     jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;
433
434     jcp.ch_block = simd_w;
435     jcp.nb_ch = jcp.oc / jcp.ch_block;
436     jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2;
437     if (jcp.nb_ch < jcp.nb_ch_blocking)
438         jcp.nb_ch_blocking = jcp.nb_ch;
439
440     if (jcp.with_eltwise) {
441         int nvecs_elt = jit_uni_eltwise_vector_f32<isa>::sharedVecsCount(jcp.eltwise_alg);
442         int nvecs_conv = isa == avx512_common ? 32 - nvecs_elt : 16 - nvecs_elt;
443         int isa_mult = isa == sse42 ? 2 : 1;
444         while (isa_mult * jcp.ur_w * jcp.nb_ch_blocking > nvecs_conv) {
445             if (jcp.nb_ch_blocking <= 1) {
446                 break;
447             }
448
449             jcp.nb_ch_blocking -= 1;
450         }
451
452         if (isa_mult * jcp.ur_w * jcp.nb_ch_blocking > nvecs_conv)
453             return status::unimplemented;
454     }
455
456     return status::success;
457 }
458
459 template struct jit_uni_dw_conv_fwd_kernel_f32<avx512_common>;
460 template struct jit_uni_dw_conv_fwd_kernel_f32<avx2>;
461 template struct jit_uni_dw_conv_fwd_kernel_f32<sse42>;
462
463 template <cpu_isa_t isa>
464 inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::load_ddst(
465         int ur_ch_blocks, int ur_str_w) {
466     int repeats = isa == sse42 ? 2 : 1;
467     for (int i = 0; i < repeats; i++) {
468         for (int ch = 0; ch < ur_ch_blocks; ch++) {
469             for (int w = 0; w < ur_str_w; w++) {
470                 Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
471                     + ch*ur_str_w + w);
472                 uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
473             }
474         }
475     }
476 }
477
478 template <cpu_isa_t isa>
479 inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::apply_filter(
480         int ur_ch_blocks, int ur_str_w) {
481     int kw = jcp.kw;
482     int kh = jcp.kh;
483     int ow = jcp.ow;
484     int oh = jcp.oh;
485
486     int ch_blk = jcp.ch_block;
487     int stride_h = jcp.stride_h;
488     int stride_w = jcp.stride_w;
489
490     Label iter_exit_label;
491
492     cmp(reg_kh, 0);
493     je(iter_exit_label, T_NEAR);
494
495     cmp(reg_kw, 0);
496     je(iter_exit_label, T_NEAR);
497
498     mov(iter_kh, reg_kh);
499     Label kh_label;
500     L(kh_label); {
501         mov(aux1_reg_ddst, aux_reg_ddst);
502         mov(aux1_reg_kernel, aux_reg_kernel);
503
504         mov(iter_kw, reg_kw);
505         Label kw_label;
506         L(kw_label); {
507             int repeats = isa == sse42 ? 2 : 1;
508             for (int i = 0; i < repeats; i++) {
509                 for (int ch = 0; ch < ur_ch_blocks; ch++) {
510                     int ker_off = ch*kh*kw*ch_blk + i*4;
511                     Vmm vmm_ker = get_ker_reg(0);
512                     uni_vmovups(vmm_ker, ptr[aux1_reg_kernel
513                         + ker_off*sizeof(float)]);
514
515                     for (int w = 0; w < ur_str_w; w++) {
516                         int ddst_off = (ch*oh*ow + w)*ch_blk + i*4;
517
518                         Vmm vmm_src = get_src_reg(0);
519                         uni_vmovups(vmm_src, ptr[aux1_reg_ddst
520                             + ddst_off*sizeof(float)]);
521
522                         Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
523                             + ch*ur_str_w + w);
524                         uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
525                     }
526                 }
527             }
528
529             add(aux1_reg_kernel, ch_blk*stride_w*sizeof(float));
530             sub(aux1_reg_ddst, ch_blk*sizeof(float));
531
532             sub(iter_kw, stride_w);
533             cmp(iter_kw, 0);
534             jg(kw_label, T_NEAR);
535         }
536
537         add(aux_reg_kernel, kw*ch_blk*stride_h*sizeof(float));
538         sub(aux_reg_ddst, ow*ch_blk*sizeof(float));
539
540         sub(iter_kh, stride_h);
541         cmp(iter_kh, 0);
542         jg(kh_label, T_NEAR);
543     }
544
545     L(iter_exit_label);
546 }
547
548 template <cpu_isa_t isa>
549 inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::store_dsrc(
550         int ur_ch_blocks, int ur_str_w) {
551     int ch_blk = jcp.ch_block;
552     int iw = jcp.iw;
553     int ih = jcp.ih;
554     int stride_w = jcp.stride_w;
555
556     int repeats = isa == sse42 ? 2 : 1;
557     for (int i = 0; i < repeats; i++) {
558         for (int ch = 0; ch < ur_ch_blocks; ch++) {
559             for (int w = 0; w < ur_str_w; w++) {
560                 int dsrc_off = (ch*ih*iw + w*stride_w)*ch_blk + i*4;
561                 Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
562                     + ch*ur_str_w + w);
563
564                 uni_vmovups(ptr[reg_dsrc + dsrc_off*sizeof(float)], vmm_acc);
565             }
566         }
567     }
568 }
569
570 template <cpu_isa_t isa>
571 inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::loop_body(
572         int ur_ch_blocks) {
573     Label unrolled_w_label;
574     Label tail_w_label;
575     Label exit_label;
576
577     L(unrolled_w_label); {
578         int ur_w = jcp.ur_w;
579
580         cmp(reg_ur_str_w, ur_w);
581         jl(tail_w_label, T_NEAR);
582
583         mov(aux_reg_ddst, reg_ddst);
584         mov(aux_reg_kernel, reg_kernel);
585
586         load_ddst(ur_ch_blocks, ur_w);
587         apply_filter(ur_ch_blocks, ur_w);
588         store_dsrc(ur_ch_blocks, ur_w);
589
590         add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
591         add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block);
592
593         sub(reg_ur_str_w, ur_w);
594         jmp(unrolled_w_label);
595     }
596
597     L(tail_w_label); {
598         int ur_w = 1;
599
600         cmp(reg_ur_str_w, ur_w);
601         jl(exit_label, T_NEAR);
602
603         mov(aux_reg_ddst, reg_ddst);
604         mov(aux_reg_kernel, reg_kernel);
605
606         load_ddst(ur_ch_blocks, ur_w);
607         apply_filter(ur_ch_blocks, ur_w);
608         store_dsrc(ur_ch_blocks, ur_w);
609
610         add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
611         add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block);
612
613         sub(reg_ur_str_w, ur_w);
614         jmp(tail_w_label);
615     }
616
617     L(exit_label);
618 }
619
620 template <cpu_isa_t isa>
621 void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::generate() {
622     preamble();
623
624     mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
625     mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
626     mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
627     mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
628     mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
629     mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]);
630     mov(reg_ur_str_w, ptr[this->param1 + GET_OFF(ur_str_w)]);
631
632     Label ch_blocks_tail_label;
633     Label exit_label;
634
635     int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
636
637     cmp(reg_ch_blocks, jcp.nb_ch_blocking);
638     jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR);
639
640     loop_body(jcp.nb_ch_blocking); // channel main loop
641
642     if (ch_blocks_tail) {
643         L(ch_blocks_tail_label);
644
645         cmp(reg_ch_blocks, ch_blocks_tail);
646         jne(exit_label, T_NEAR);
647
648         loop_body(ch_blocks_tail); // channel tail loop
649     }
650
651     L(exit_label);
652
653     this->postamble();
654 }
655
656 template <cpu_isa_t isa>
657 status_t jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_conf(
658         jit_conv_conf_t &jcp, const convolution_desc_t &cd,
659         const memory_desc_wrapper &diff_src_d,
660         const memory_desc_wrapper &weights_d,
661         const memory_desc_wrapper &diff_dst_d) {
662     if (!mayiuse(isa)) return status::unimplemented;
663
664     const int simd_w = isa == avx512_common ? 16 : 8;
665
666     const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
667     if (!with_groups) return status::unimplemented;
668
669     jcp.ngroups = weights_d.dims()[0];
670     jcp.mb = diff_src_d.dims()[0];
671
672     jcp.oc = diff_dst_d.dims()[1];
673     jcp.oc_without_padding = jcp.oc;
674     jcp.ic = diff_src_d.dims()[1];
675
676     jcp.ih = diff_src_d.dims()[2];
677     jcp.iw = diff_src_d.dims()[3];
678     jcp.oh = diff_dst_d.dims()[2];
679     jcp.ow = diff_dst_d.dims()[3];
680
681     jcp.kh = weights_d.dims()[3];
682     jcp.kw = weights_d.dims()[4];
683
684     jcp.t_pad = cd.padding[0][0];
685     jcp.l_pad = cd.padding[0][1];
686     jcp.b_pad = cd.padding[1][0];
687     jcp.r_pad = cd.padding[1][1];
688
689     jcp.stride_h = cd.strides[0];
690     jcp.stride_w = cd.strides[1];
691
692     jcp.dilate_h = cd.dilates[0];
693     jcp.dilate_w = cd.dilates[1];
694
695     jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
696     jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
697
698     jcp.src_fmt = diff_src_d.format();
699
700     bool ok_to_pad_channels = true
701         && jcp.oc == jcp.ngroups
702         && jcp.ic == jcp.ngroups
703         && isa == avx512_common;
704     if (ok_to_pad_channels) {
705         jcp.oc = rnd_up(jcp.oc, simd_w);
706         jcp.ic = rnd_up(jcp.oc, simd_w);
707         jcp.ngroups = rnd_up(jcp.ngroups, simd_w);
708     }
709
710     auto desired_act_fmt = isa == avx512_common ? nChw16c : nChw8c;
711     auto desired_wei_fmt = isa == avx512_common ? Goihw16g : Goihw8g;
712
713     bool args_ok = true
714         && jcp.oc == jcp.ngroups
715         && jcp.ic == jcp.ngroups
716         && jcp.ngroups % simd_w == 0
717         && jcp.dilate_h == 0
718         && jcp.dilate_w == 0
719         && diff_src_d.format() == desired_act_fmt
720         && weights_d.format() == desired_wei_fmt
721         && diff_dst_d.format() == desired_act_fmt
722         && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
723         && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1
724         && jcp.ic <= diff_src_d.blocking_desc().padding_dims[1]
725         && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
726         && jcp.ngroups <= weights_d.blocking_desc().padding_dims[0];
727     if (!args_ok) return status::unimplemented;
728
729     jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;
730
731     jcp.ch_block = simd_w;
732     jcp.nb_ch = jcp.ic / jcp.ch_block;
733     jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2;
734     if (jcp.nb_ch < jcp.nb_ch_blocking)
735         jcp.nb_ch_blocking = jcp.nb_ch;
736
737     return status::success;
738 }
739
740 template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx512_common>;
741 template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx2>;
742 template struct jit_uni_dw_conv_bwd_data_kernel_f32<sse42>;
743
744 }
745 }
746 }