Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx2_conv_kernel_f32.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 * Copyright 2018 YANDEX LLC
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
17
18 #include "c_types_map.hpp"
19 #include "nstl.hpp"
20 #include "type_helpers.hpp"
21 #include "utils.hpp"
22 #include "cpu_memory.hpp"
23
24 #include "jit_avx2_conv_kernel_f32.hpp"
25
26 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 using namespace mkldnn::impl::prop_kind;
33 using namespace mkldnn::impl::memory_format;
34 using namespace mkldnn::impl::memory_tracking::names;
35 using namespace mkldnn::impl::utils;
36
37 using namespace Xbyak;
38
39 void jit_avx2_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w,
40         int pad_l, int pad_r, int oc_blocks)
41 {
42     int iw = jcp.iw;
43     int ih = jcp.ih;
44     int id = jcp.id;
45     int kw = jcp.kw;
46     int kh = jcp.kh;
47     int kd = jcp.kd;
48     int nb_ic = jcp.nb_ic;
49     int stride_w = jcp.stride_w;
50     int dilate_w = jcp.dilate_w + 1;
51     int ic_blk = jcp.ic_block;
52     int oc_blk = jcp.oc_block;
53
54     for (int ki = 0; ki < kw; ki++) {
55         int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
56         int jj_end = ur_w
57             - nstl::max(0, div_up(ki*dilate_w+pad_r-(kw-1)*dilate_w, stride_w));
58         for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
59             for (int jj = jj_start; jj < jj_end; jj++) {
60                 size_t inp_off;
61                 if (one_of(jcp.src_fmt, ncw, nchw, ncdhw))
62                     inp_off = sizeof(float)*((size_t)ifm2*id*ih*iw
63                         + (ki*dilate_w + jj*stride_w - pad_l));
64                 else
65                     inp_off = sizeof(float)*((ki*dilate_w + jj*stride_w
66                                 - pad_l)*ic_blk + ifm2);
67                 vbroadcastss(Ymm(oc_blocks * ur_w + jj),
68                         make_safe_addr(aux_reg_input, inp_off, reg_long_offt));
69             }
70
71             for (int ii = 0; ii < oc_blocks; ii++) {
72                 int ker_off = ii * nb_ic * kd * kh * kw * ic_blk * oc_blk
73                         + ki * ic_blk * oc_blk + ifm2 * oc_blk;
74                 vmovups(ymm15, ptr[aux_reg_kernel + sizeof(float) * ker_off]);
75                 for (int jj = jj_start; jj < jj_end; jj++)
76                     if (mayiuse(avx2))
77                         vfmadd231ps(Ymm(ur_w * ii + jj),
78                                 Ymm(oc_blocks * ur_w + jj), ymm15);
79                     else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support
80                         vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj));
81                         vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp);
82                     }
83             }
84         }
85     }
86 }
87
88 void jit_avx2_conv_fwd_kernel_f32::oh_step_nopad(int ur_w,
89         int pad_l, int pad_r, char pad_tag,
90         int oc_blocks, char oc_blocks_tag)
91 {
92     Label kw_loop;
93
94     int iw = jcp.iw;
95     int ih = jcp.ih;
96     int id = jcp.id;
97     int kw = jcp.kw;
98     int kh = jcp.kh;
99     int kd = jcp.kd;
100     int nb_ic = jcp.nb_ic;
101     int stride_w = jcp.stride_w;
102     int dilate_w = jcp.dilate_w + 1;
103     int ic_blk = jcp.ic_block;
104     int oc_blk = jcp.oc_block;
105
106     xor_(ki_iter, ki_iter);
107     L(kw_loop);
108     {
109         int jj_start = 0;
110         int jj_end = ur_w;
111         for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
112             for (int jj = jj_start; jj < jj_end; jj++) {
113                 size_t inp_off;
114                 if (one_of(jcp.src_fmt, ncw, nchw, ncdhw))
115                     inp_off = sizeof(float)*((size_t)ifm2 * id * ih * iw
116                             + (jj * stride_w - pad_l));
117                 else
118                     inp_off = sizeof(float)*((jj * stride_w - pad_l) * ic_blk
119                             + ifm2);
120                 vbroadcastss(Ymm(oc_blocks * ur_w + jj),
121                     make_safe_addr(aux_reg_input, inp_off, reg_long_offt));
122             }
123             for (int ii = 0; ii < oc_blocks; ii++) {
124                 int aux_kernel_offset =
125                     ii * nb_ic * kd * kh * kw * ic_blk * oc_blk + ifm2 * oc_blk;
126                 vmovups(ymm15, ptr[aux_reg_kernel
127                         + sizeof(float) * aux_kernel_offset]);
128                 for (int jj = jj_start; jj < jj_end; jj++)
129                     if (mayiuse(avx2))
130                         vfmadd231ps(Ymm(ur_w * ii + jj),
131                                 Ymm(oc_blocks * ur_w + jj), ymm15);
132                     else { // Intel AVX support
133                         vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj));
134                         vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp);
135                     }
136             }
137         }
138         add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk);
139         add(aux_reg_input, sizeof(float) * (one_of(jcp.src_fmt, ncw, nchw, ncdhw)
140                 ? dilate_w : ic_blk * dilate_w));
141
142         inc(ki_iter);
143         cmp(ki_iter, kw);
144         jl(kw_loop, T_NEAR);
145     }
146 }
147
148 void jit_avx2_conv_fwd_kernel_f32::width_blk_step(int ur_w,
149         int pad_l, int pad_r, char pad_tag,
150         int oc_blocks, char oc_blocks_tag)
151 {
152     int iw = jcp.iw;
153     int kw = jcp.kw;
154     int ow = jcp.ow;
155     int oh = jcp.oh;
156     int od = jcp.od;
157     int dilate_h = jcp.dilate_h + 1;
158     int dilate_w = jcp.dilate_w + 1;
159     int ic_blk = jcp.ic_block;
160     int oc_blk = jcp.oc_block;
161     const int inp_mult = one_of(jcp.src_fmt, ncw, nchw, ncdhw)
162         ? 1 : ic_blk;
163     const int inp_off = one_of(jcp.src_fmt, ncw, nchw, ncdhw)
164         ? dilate_w : ic_blk * dilate_w;
165
166     Label init_done, init_first;
167
168     if (!jcp.with_sum) {
169         test(reg_ci_flag, FLAG_IC_FIRST);
170         jne(init_first, T_NEAR);
171     }
172
173     for (int ii = 0; ii < oc_blocks; ii++) {
174         for (int jj = 0; jj < ur_w; jj++) {
175             size_t offt;
176             if (jcp.with_dw_conv)
177                 offt = sizeof(float) * ((size_t)ii * od * jcp_dw.kh * ow + jj) * oc_blk;
178             else
179                 offt = sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk;
180             vmovups(Ymm(ur_w * ii + jj),
181                     make_safe_addr(reg_output, offt, reg_long_offt));
182         }
183     }
184
185     if (jcp.with_sum && jcp.with_bias) {
186         test(reg_ci_flag, FLAG_IC_FIRST);
187         je(init_done, T_NEAR);
188
189         for (int ii = 0; ii < oc_blocks; ii++)
190             for (int jj = 0; jj < ur_w; jj++)
191                 vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
192                     yword[reg_bias + sizeof(float) * ii * oc_blk]);
193     }
194
195     jmp(init_done);
196
197     L(init_first);
198     if (this->jcp.with_bias) {
199         for (int ii = 0; ii < oc_blocks; ii++)
200             for (int jj = 0; jj < ur_w; jj++)
201                 vmovups(Ymm(ur_w * ii + jj),
202                         yword[reg_bias + sizeof(float) * ii * oc_blk]);
203     } else {
204         for (int ii = 0; ii < oc_blocks; ii++)
205             for (int jj = 0; jj < ur_w; jj++)
206                 uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj));
207     }
208
209     L(init_done);
210
211     if (one_of(jcp.ndims, 3, 4)) {
212         mov(aux_reg_input, reg_input);
213         mov(aux_reg_kernel, reg_kernel);
214     }
215
216     Label skip_kh_loop, skip_kd_loop, kd_loop;
217     if (jcp.ndims == 5) {
218         push(reg_output);
219         push(oi_iter);
220
221         mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
222         mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
223         mov(aux_reg_inp_d, reg_input);
224
225         if ((jcp.dilate_d >= jcp.id)
226                 || (jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) {
227             cmp(reg_ki, 0);
228             je(skip_kd_loop, T_NEAR);
229         }
230         L(kd_loop);
231         mov(kj, ptr[param1 + GET_OFF(kh_padding)]);
232     } else {
233         mov(kj, reg_kh);
234     }
235
236     if (jcp.ndims == 5) {
237         mov(aux_reg_input, aux_reg_inp_d);
238         mov(aux_reg_kernel, aux_reg_ker_d);
239     }
240
241     if ((jcp.dilate_h >= jcp.ih)
242             || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
243         cmp(kj, 0);
244         je(skip_kh_loop, T_NEAR);
245     }
246     Label kh_loop;
247     L(kh_loop);
248     {
249         if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) {
250             oh_step_nopad(ur_w, pad_l, pad_r, pad_tag, oc_blocks,
251                     oc_blocks_tag);
252             sub(aux_reg_input, sizeof(float) * kw * inp_off);
253             add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult);
254         } else {
255             oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks);
256             add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk);
257             add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult);
258         }
259
260         dec(kj);
261         cmp(kj, 0);
262         jg(kh_loop, T_NEAR);
263     }
264
265     L(skip_kh_loop);
266
267     if (jcp.ndims == 5) {
268         add(aux_reg_inp_d,
269             sizeof(float) * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mult);
270         add(aux_reg_ker_d, sizeof(float) * jcp.kw * jcp.kh * jcp.oc_block
271             * jcp.ic_block);
272
273         dec(reg_ki);
274         cmp(reg_ki, 0);
275         jg(kd_loop, T_NEAR);
276         L(skip_kd_loop);
277
278         pop(oi_iter);
279         pop(reg_output);
280     }
281
282     Label regular_store;
283
284     test(reg_ci_flag, FLAG_IC_LAST);
285     je(regular_store, T_NEAR);
286
287     int eltwise_inj_idx = 0;
288     int depthwise_inj_idx = 0;
289     const auto &p = attr_.post_ops_;
290
291     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
292     for (int i = 0; i < end_idx; i++) {
293         auto& post_op = p.entry_[i];
294         if (post_op.is_eltwise()) {
295             eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, oc_blocks * ur_w);
296             eltwise_inj_idx++;
297         } else if (post_op.is_depthwise()) {
298             mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
299             mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
300
301             add(reg_d_weights, ptr[this->param1 + GET_OFF(oc_off)]);
302             add(reg_d_bias, ptr[this->param1 + GET_OFF(oc_off)]);
303
304             for (int ii = 0; ii < oc_blocks; ii++) {
305                 depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
306                         ur_w * ii, ur_w * ii + ur_w, reg_d_weights, reg_d_bias);
307
308                 add(reg_d_weights, jcp.oc_block * sizeof(float));
309                 add(reg_d_bias, jcp.oc_block * sizeof(float));
310             }
311
312             depthwise_inj_idx++;
313         }
314     }
315
316     L(regular_store);
317
318     for (int ii = 0; ii < oc_blocks; ii++) {
319         for (int jj = 0; jj < ur_w; jj++) {
320             size_t o_off;
321             if (jcp.with_dw_conv)
322                 o_off = sizeof(float) * ((size_t)ii * od * jcp_dw.kh * ow + jj) * oc_blk;
323             else
324                 o_off = sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk;
325             Ymm reg_out = Ymm(ur_w * ii + jj);
326             vmovups(make_safe_addr(reg_output, o_off, reg_long_offt), reg_out);
327         }
328     }
329 }
330
331 inline void jit_avx2_conv_fwd_kernel_f32::solve_common(
332         int oc_blocks, char oc_blocks_tag)
333 {
334     int ur_w = jcp.ur_w;
335     int ur_w_tail = jcp.ur_w_tail;
336     int n_oi = jcp.ow / ur_w;
337     int iw = jcp.iw;
338     int kw = jcp.kw;
339     int ic_blk = jcp.ic_block;
340     int oc_blk = jcp.oc_block;
341     int dilate_w = jcp.dilate_w + 1;
342     int str_w = jcp.stride_w;
343     const int inp_mult = one_of(jcp.src_fmt, ncw, nchw, ncdhw) ? 1 : ic_blk;
344
345     int l_pad = jcp.l_pad;
346     int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w
347             - (iw + l_pad - 1));
348     int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w
349             - (iw + l_pad - 1);
350     if (r_pad1 > 0) n_oi--;
351
352     if (l_pad > 0) {
353         n_oi--;
354         if (n_oi < 0 && r_pad1 > 0)
355             width_blk_step(ur_w, l_pad, r_pad1,
356                     'l', oc_blocks, oc_blocks_tag); // "lrpad"
357         else
358             width_blk_step(ur_w, l_pad, 0,
359                     'l', oc_blocks, oc_blocks_tag); // "lpad"
360         add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult);
361         add(reg_output, sizeof(float) * ur_w * oc_blk);
362     }
363
364     Label ow_loop;
365     xor_(oi_iter, oi_iter);
366
367     if (n_oi > 0) {
368         L(ow_loop);
369
370         width_blk_step(ur_w, 0, 0,
371                 'm', oc_blocks, oc_blocks_tag); // "middle"
372         add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
373         add(reg_output, sizeof(float) * ur_w * oc_blk);
374
375         inc(oi_iter);
376         cmp(oi_iter, n_oi);
377         jl(ow_loop, T_NEAR);
378     }
379
380     if (r_pad1 > 0 && n_oi >=0) {
381         width_blk_step(ur_w, 0, r_pad1,
382                 'r', oc_blocks, oc_blocks_tag); // "rpad"
383         add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
384         add(reg_output, sizeof(float) * ur_w * oc_blk);
385     }
386
387     if (ur_w_tail != 0)
388         width_blk_step(ur_w_tail, 0, r_pad,
389                 't', oc_blocks, oc_blocks_tag); // "tail"
390 }
391
392 void jit_avx2_conv_fwd_kernel_f32::generate()
393 {
394     const auto &p = attr_.post_ops_;
395     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
396     for (int i = 0; i < end_idx; i++) {
397         auto &post_op = p.entry_[i];
398         if (post_op.is_eltwise()) {
399             eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx2>(
400                     this,
401                     post_op.eltwise.alg,
402                     post_op.eltwise.alpha,
403                     post_op.eltwise.beta
404             ));
405         } else if (post_op.is_depthwise()) {
406             depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx2>(
407                     this,
408                     post_op.depthwise.alg
409             ));
410         }
411     }
412
413     this->preamble();
414
415     mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
416     mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
417     mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
418     if (jcp.with_bias)
419         mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
420     mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
421     mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
422     mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]);
423
424     int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking;
425     Label tail, exit;
426
427     if (jcp.nb_oc > jcp.nb_oc_blocking) {
428         cmp(reg_oc_blocks, jcp.nb_oc_blocking);
429         jne(nb_oc_tail ? tail : exit, T_NEAR);
430
431         solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking);
432         jmp(exit, T_NEAR);
433
434         if (nb_oc_tail) {
435             L(tail);
436             cmp(reg_oc_blocks, nb_oc_tail);
437             jne(exit, T_NEAR);
438             solve_common(nb_oc_tail, '0' + nb_oc_tail);
439         }
440
441         L(exit);
442     } else if (jcp.nb_oc == jcp.nb_oc_blocking) {
443         solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking);
444     } else {
445         solve_common(nb_oc_tail, '0' + nb_oc_tail);
446     }
447
448     this->postamble();
449
450     for (auto& inj : eltwise_injectors)
451         inj->prepare_table();
452 }
453
454 bool jit_avx2_conv_fwd_kernel_f32::post_ops_ok(
455         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
456     const auto &p = attr.post_ops_;
457
458     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
459     auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
460     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
461     auto is_dw_conv = [&](int idx) { return p.entry_[idx].is_dw_conv(); };
462     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
463
464     switch (p.len_) {
465         case 0: return true;
466         case 1: return is_simple(0) || is_sum(0) || is_dw_conv(0);
467         case 2: return (is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
468                        (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
469                        (is_simple(0) && is_simple(1));
470         case 3: return (is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
471                        (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
472                        (is_sum(0) && is_simple(1) && is_simple(2));
473         case 4: return (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
474         default: return false;
475     }
476
477     return false;
478 }
479
480 status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
481         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
482         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
483         const primitive_attr_t &attr)
484 {
485     if (!mayiuse(avx)) return status::unimplemented;
486
487     jcp.prop_kind = cd.prop_kind;
488
489     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
490     int ndims = src_d.ndims();
491     jcp.ndims = ndims;
492
493     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
494     jcp.mb = src_d.dims()[0];
495
496     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
497     jcp.oc_without_padding = jcp.oc;
498     jcp.ic = src_d.dims()[1] / jcp.ngroups;
499
500     jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
501     jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
502     jcp.iw = src_d.dims()[ndims-1];
503     jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
504     jcp.oh = (ndims == 3) ? 1 :dst_d.dims()[ndims-2];
505     jcp.ow = dst_d.dims()[ndims-1];
506     jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
507     jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2];
508     jcp.kw = weights_d.dims()[with_groups + ndims-1];
509
510     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
511     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
512     jcp.l_pad = cd.padding[0][ndims-3];
513     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
514     jcp.stride_h = (ndims == 3) ? 1 :cd.strides[ndims-4];
515     jcp.stride_w = cd.strides[ndims-3];
516
517     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
518     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
519     jcp.dilate_w = cd.dilates[ndims-3];
520
521     jcp.src_fmt = src_d.format();
522     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
523
524     if (!post_ops_ok(jcp, attr))
525         return status::unimplemented;
526
527     const auto &p = attr.post_ops_;
528
529     int dw_conv_ind = p.find(primitive_kind::convolution);
530     jcp.with_dw_conv = dw_conv_ind != -1;
531     if (jcp.with_dw_conv) {
532         jcp.dw_conv_oh = jcp.oh;
533         jcp.dw_conv_ow = jcp.ow;
534         jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
535         jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
536     }
537
538     jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
539                 - (jcp.ih + jcp.t_pad - 1);
540
541     if (jcp.with_dw_conv && !mayiuse(avx2))
542         return status::unimplemented;
543
544     if (jcp.with_dw_conv && jcp.ndims == 5)
545         return status::unimplemented;
546
547     if (!mayiuse(avx2)) {
548         for (int i = 0; i < p.len_; i++) {
549             auto &post_op = p.entry_[i];
550             if (post_op.is_eltwise()) {
551                 if (post_op.eltwise.alg != alg_kind::eltwise_relu)
552                     return status::unimplemented;
553             } else if (post_op.is_depthwise()) {
554                 return status::unimplemented;
555             }
556         }
557     }
558
559     jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
560
561     jcp.src_dt = cd.src_desc.data_type;
562     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
563     jcp.dst_dt = cd.dst_desc.data_type;
564
565     const int simd_w = 8;
566     const bool flat = jcp.ic < simd_w;
567     const bool mimo = !flat;
568
569
570     /* Grouped channel offset to support 'non-blocked data' format for
571      * convolution sizes with '(input_channel / ngroups) < simd' */
572     jcp.nonblk_group_off
573             = (one_of(src_d.format(), ncw, nchw, ncdhw) && jcp.ngroups > 1) ?
574             jcp.ic :
575             1;
576
577     bool ok_to_pad_channels = true
578         && jcp.ngroups == 1;
579
580     if (ok_to_pad_channels) {
581         jcp.oc = rnd_up(jcp.oc, simd_w);
582         if (mimo)
583             jcp.ic = rnd_up(jcp.ic, simd_w);
584     }
585
586     bool args_ok = true
587         && IMPLICATION(flat, one_of(src_d.format(), ncw, nwc, nchw, nhwc,
588             ncdhw, ndhwc)
589             && one_of(weights_d.format(), Owi8o, gOwi8o, Ohwi8o, gOhwi8o,
590                 Odhwi8o, gOdhwi8o))
591         && IMPLICATION(mimo, one_of(src_d.format(), nCw8c, nChw8c, nCdhw8c)
592             && one_of(weights_d.format(), OIw8i8o, gOIw8i8o, OIhw8i8o,
593                 gOIhw8i8o, OIdhw8i8o, gOIdhw8i8o))
594         && one_of(cd.bias_desc.format, memory_format::undef, any, x)
595         && one_of(dst_d.format(), nCw8c, nChw8c, nCdhw8c);
596     if (!args_ok) return status::unimplemented;
597
598     jcp.ur_h = 1; /* no code-unrolling by h so far */
599     jcp.ur_w = 3;
600
601     jcp.oc_block = simd_w;
602     jcp.nb_oc = jcp.oc / jcp.oc_block;
603
604     jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */
605
606     // Intel AVX and Intel AVX2 kernels need 2 and 1 temporary YMMs, respectively
607     // Thus, we can only assign 14 or 15 YMMs for data storage
608     const int num_avail_regs = mayiuse(avx2) ? 15 : 14;
609     if (!mayiuse(avx2)) {
610         if ((jcp.nb_oc_blocking + 1) * jcp.ur_w > num_avail_regs) {
611             // current register assignment requires more YMMs than available
612             // adjust one of nb_oc_block, ur_w preserving to ur_w >= l_pad
613             if (jcp.ur_w > jcp.l_pad && jcp.ur_w > 1)
614                 jcp.ur_w -= 1;
615             else
616                 for (int b = 3; b > 1; b--)
617                     if (jcp.nb_oc % b == 0) {
618                         jcp.nb_oc_blocking = b;
619                         break;
620                     }
621         }
622     }
623
624     if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
625     jcp.ur_w_tail = jcp.ow % jcp.ur_w;
626
627     args_ok = true
628         && jcp.oc % simd_w == 0
629         && jcp.l_pad <= jcp.ur_w
630         && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0)
631                 || (jcp.stride_w == 1 && jcp.stride_h == 1))
632         && IMPLICATION(mimo, jcp.ic % simd_w == 0);
633     if (!args_ok) return status::unimplemented;
634
635     int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
636         + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
637
638     if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) {
639         /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
640         jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail,
641                 nstl::min(jcp.ow, num_avail_regs / 2));
642         jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w;
643         jcp.ur_w_tail = jcp.ow % jcp.ur_w;
644         /* check again ... */
645         r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
646             + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
647         if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail))
648             return status::unimplemented;
649     }
650     assert(jcp.nb_oc_blocking > 0);
651     assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs);
652
653     jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w;
654     jcp.nb_ic = jcp.ic / jcp.ic_block;
655
656     if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
657         jcp.nb_ic_blocking = 12;
658         jcp.nb_ic_blocking_max = 16;
659     } else {
660         jcp.nb_ic_blocking = 1;
661         jcp.nb_ic_blocking_max = jcp.nb_ic_blocking;
662     }
663
664     return status::success;
665 }
666
667 void jit_avx2_conv_fwd_kernel_f32::init_scratchpad(
668         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw) {
669     if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
670         scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
671
672     if (jcp.with_dw_conv) {
673         const int nthreads = mkldnn_get_max_threads();
674         size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking;
675         scratchpad.book(key_dw_conv_buffer, sizeof(float) * dw_conv_buffer_size_ * nthreads);
676
677         if (jcp.oc != jcp.oc_without_padding)
678             scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc);
679     }
680 }
681
682 void jit_avx2_conv_bwd_data_kernel_f32::compute_loop(int ur_w, int l_overflow,
683         int r_overflow)
684 {
685     int kw = jcp.kw;
686     int kh = jcp.kh;
687     int kd = jcp.kd;
688     int iw = jcp.iw;
689     int ih = jcp.ih;
690     int id = jcp.id;
691     int ow = jcp.ow;
692
693     int ic_block = jcp.ic_block;
694     int oc_block = jcp.oc_block;
695     int nb_ic_block = jcp.nb_ic_blocking;
696     int stride_w = jcp.stride_w;
697     int stride_h = jcp.stride_h;
698
699     Label kd_loop, skip_kd_loop;
700     Label oc_loop, skip_oc_loop;
701
702     for (int ii = 0; ii < nb_ic_block; ii++)
703         for (int jj = 0; jj < ur_w; jj++) {
704             uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
705                       Ymm(ur_w * ii + jj));
706         }
707
708     if (one_of(jcp.ndims, 3, 4)) {
709         cmp(reg_channel_work, 0);
710         jle(skip_oc_loop, T_NEAR);
711         xor_(reg_channel, reg_channel);
712
713         mov(aux_reg_ddst_oc_loop, reg_ddst);
714         mov(aux_reg_kernel_oc_loop, reg_kernel);
715
716         L(oc_loop);
717         mov(aux_reg_ddst, aux_reg_ddst_oc_loop);
718         mov(aux_reg_kernel, aux_reg_kernel_oc_loop);
719     }
720
721     if (jcp.ndims == 5) {
722         assert(jcp.nb_oc_blocking == 1);
723         push(oi_iter);
724
725         mov(reg_ki, ptr[this->param1 + GET_OFF(kd_padding)]);
726         mov(aux_reg_dst_d, reg_ddst);
727         mov(aux_reg_ker_d, ptr[this->param1 + GET_OFF(filt)]);
728
729         L(kd_loop);
730         mov(kj, ptr[this->param1 + GET_OFF(kh_padding)]);
731     } else {
732         mov(kj, reg_kh);
733     }
734
735     if (jcp.ndims == 5) {
736         mov(aux_reg_ddst, aux_reg_dst_d);
737         mov(aux_reg_kernel, aux_reg_ker_d);
738     }
739
740     Label kh_loop, skip_kh_loop;
741     cmp(kj, 0);
742     jle(skip_kh_loop, T_NEAR);
743     L(kh_loop); {
744         for (int ki = 0; ki < kw; ki++) {
745             int jj_start = get_iw_start(ki, l_overflow); // 0;
746             int jj_end = get_iw_end(ur_w, ki, r_overflow); // ur_w;
747             for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) {
748
749                 for (int jj = jj_start ; jj < jj_end; jj += stride_w) {
750                     int aux_output_offset
751                       = (jj + jcp.l_pad - ki) / stride_w * jcp.oc_block + ofm2;
752                     vbroadcastss(Ymm(nb_ic_block * ur_w + jj / stride_w),
753                             ptr[aux_reg_ddst
754                             + sizeof(float) * aux_output_offset]);
755                 }
756
757                 for (int ii = 0; ii  < nb_ic_block; ii++) {
758                     int aux_kernel_offset
759                         = ii * kd * kh * kw * jcp.ic_block * jcp.oc_block
760                         + ki * jcp.ic_block * jcp.oc_block
761                         + ofm2 * jcp.ic_block;
762                     vmovups(ymm15,
763                             ptr[aux_reg_kernel
764                             + sizeof(float) * aux_kernel_offset]);
765                     for (int jj = jj_start; jj  < jj_end; jj += stride_w)
766                         vfmadd231ps(Ymm(ur_w * ii + jj),
767                                 Ymm(nb_ic_block * ur_w + jj / stride_w), ymm15);
768                 }
769             }
770         }
771         add(aux_reg_kernel, sizeof(float) * stride_h * kw  * oc_block
772                                           * ic_block);
773         sub(aux_reg_ddst, sizeof(float) * ow * oc_block);
774
775         dec(kj);
776         cmp(kj, 0);
777         jg(kh_loop, T_NEAR);
778     }
779     L(skip_kh_loop);
780
781     if (jcp.ndims == 5) {
782         sub(aux_reg_dst_d,
783                 sizeof(float) * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
784         add(aux_reg_ker_d,
785                 sizeof(float) * jcp.kw * jcp.kh * oc_block * ic_block);
786
787         dec(reg_ki);
788         cmp(reg_ki, 0);
789         jg(kd_loop, T_NEAR);
790         L(skip_kd_loop);
791
792         pop(oi_iter);
793     }
794
795     if (one_of(jcp.ndims, 3, 4)) {
796         int ddst_oc_shift = sizeof(float) * jcp.od * jcp.oh * jcp.ow
797                           * jcp.oc_block;
798         int kernel_oc_shift = sizeof(float) * jcp.kd * jcp.kh * jcp.kw
799                           * jcp.ic * jcp.oc_block;
800
801         add(aux_reg_ddst_oc_loop, ddst_oc_shift);
802         add(aux_reg_kernel_oc_loop, kernel_oc_shift);
803
804         inc(reg_channel);
805         cmp(reg_channel, reg_channel_work);
806         jl(oc_loop, T_NEAR);
807
808         L(skip_oc_loop);
809         mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
810     }
811
812     Label no_update_label;
813     cmp(reg_channel, 0);
814     je(no_update_label, T_NEAR);
815     for (int ii = 0; ii < nb_ic_block; ii++) {
816         for (int jj = 0; jj < ur_w; jj++) {
817             size_t offt =
818                 sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block;
819             vmovups(Ymm(15),
820                     make_safe_addr(reg_dsrc, offt, reg_long_offt));
821             vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
822                     Ymm(15));
823
824         }
825     }
826     L(no_update_label);
827
828     for (int ii = 0; ii < nb_ic_block; ii++)
829         for (int jj = 0; jj < ur_w; jj++) {
830             size_t offt =
831                 sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block;
832             vmovups(make_safe_addr(reg_dsrc, offt, reg_long_offt),
833                     Ymm(ur_w * ii + jj));
834         }
835 }
836
837 void jit_avx2_conv_bwd_data_kernel_f32::generate() {
838     preamble();
839
840     mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
841     mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
842     mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
843     mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
844     mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
845     mov(reg_channel_work, ptr[param1 + GET_OFF(ch_blocks)]);
846
847     int ddst_shift = sizeof(float) * (jcp.ur_w / jcp.stride_w) * jcp.ic_block;
848     int dsrc_shift = sizeof(float) * jcp.ur_w * jcp.oc_block;
849
850     int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w);
851     int r_overflow = nstl::max(0, (jcp.kw - 1
852                     - nstl::max(0, jcp.r_pad)) / jcp.stride_w);
853     int r_overflow1 = nstl::max(0, (jcp.kw - 1
854                     - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
855
856     int n_oi = jcp.iw / jcp.ur_w;
857     if (r_overflow1 > 0)
858         n_oi--;
859
860     if (jcp.ur_w == jcp.iw) {
861         compute_loop(jcp.ur_w, l_overflow, r_overflow);
862     } else if (n_oi == 0) {
863         compute_loop(jcp.ur_w, l_overflow, r_overflow1);
864         add(reg_dsrc, dsrc_shift);
865         add(reg_ddst, ddst_shift);
866         if (jcp.ur_w_tail != 0)
867             compute_loop(jcp.ur_w_tail, 0, r_overflow);
868     } else {
869         xor_(oi_iter, oi_iter);
870         if (l_overflow > 0) {
871             compute_loop(jcp.ur_w, l_overflow, 0);
872             add(reg_dsrc, dsrc_shift);
873             add(reg_ddst, ddst_shift);
874             inc(oi_iter);
875         }
876
877         if ((l_overflow <= 0 && n_oi > 0) || (l_overflow >  0 && n_oi > 1)) {
878             Label ow_loop;
879             L(ow_loop); {
880                 compute_loop(jcp.ur_w, 0, 0);
881                 add(reg_dsrc, dsrc_shift);
882                 add(reg_ddst, ddst_shift);
883                 inc(oi_iter);
884                 cmp(oi_iter, n_oi); jl(ow_loop, T_NEAR);
885             }
886         }
887
888         if (r_overflow1 > 0 ) {
889             compute_loop(jcp.ur_w, 0, r_overflow1);
890             add(reg_dsrc, dsrc_shift);
891             add(reg_ddst, ddst_shift);
892         }
893
894         if (jcp.ur_w_tail != 0)
895             compute_loop(jcp.ur_w_tail, 0, r_overflow);
896     }
897
898     this->postamble();
899 }
900
901 status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp,
902         const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
903         const memory_desc_wrapper &weights_d,
904         const memory_desc_wrapper &diff_dst_d)
905 {
906     if (!mayiuse(avx2)) return status::unimplemented;
907
908     const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
909
910     int ndims = diff_src_d.ndims();
911     jcp.ndims = ndims;
912
913     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
914     jcp.mb = diff_src_d.dims()[0];
915
916     jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
917     jcp.oc_without_padding = jcp.oc;
918     jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
919
920     jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
921     jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2];
922     jcp.iw = diff_src_d.dims()[ndims-1];
923     jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
924     jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
925     jcp.ow = diff_dst_d.dims()[ndims-1];
926
927     jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
928     jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
929     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
930
931     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
932     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
933     jcp.l_pad = cd.padding[0][ndims-3];
934
935     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
936     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
937     jcp.stride_w = cd.strides[ndims-3];
938
939     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
940     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
941     jcp.dilate_w = cd.dilates[ndims-3];
942
943     const int simd_w = 8;
944
945     /* derivatives */
946     jcp.idp = jcp.id + 2 * jcp.f_pad;
947     jcp.ihp = jcp.ih + 2 * jcp.t_pad;
948     jcp.iwp = jcp.iw + 2 * jcp.l_pad;
949     jcp.ohp = jcp.oh; /* do we really need */
950     jcp.owp = jcp.ow; /* padded output ??? */
951
952     bool ok_to_pad_channels = true
953         && jcp.ngroups == 1;
954
955     /* gemm-based convolution performs better in these cases */
956     if (jcp.ic < simd_w && jcp.kw > 3 && jcp.stride_w > 1)
957         return status::unimplemented;
958
959     if (ok_to_pad_channels) {
960         jcp.oc = rnd_up(jcp.oc, simd_w);
961         jcp.ic = rnd_up(jcp.ic, simd_w);
962     }
963
964     jcp.ic_block = (jcp.ic % simd_w) ? 1 : simd_w;
965     jcp.nb_ic = jcp.ic / jcp.ic_block;
966
967     jcp.oc_block = simd_w;
968     if (jcp.oc % jcp.oc_block) return status::unimplemented;
969     jcp.nb_oc = jcp.oc / jcp.oc_block;
970
971     jcp.ur_h = 1; /* no code-unrolling by h so far */
972     jcp.nb_ic_blocking = 1;
973     jcp.nb_oc_blocking = 1;
974     jcp.ur_w = 1;
975
976     if(one_of(ndims, 3, 4) && jcp.ow < 40)
977         jcp.nb_oc_blocking = jcp.ow < 15 ? 4 : 2;
978
979     jcp.src_fmt = diff_src_d.format();
980
981     bool args_ok = true
982         && one_of(diff_src_d.format(), nCw8c, nChw8c, nCdhw8c)
983         && one_of(weights_d.format(), gOIw8o8i, OIw8i8o, gOIhw8o8i, OIhw8o8i,
984                 gOIdhw8o8i, OIdhw8o8i)
985         && one_of(diff_dst_d.format(), nCw8c, nChw8c, nCdhw8c)
986         && jcp.stride_w == jcp.stride_h
987         && jcp.stride_d == 1
988         && jcp.dilate_d == 0
989         && jcp.dilate_h == 0
990         && jcp.dilate_w == 0
991         && jcp.ic % simd_w == 0
992         && jcp.oc % simd_w == 0
993         && jcp.od == (jcp.idp - jcp.kd) / jcp.stride_d + 1
994         && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
995         && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
996     if (!args_ok) return status::unimplemented;
997     jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad;
998     jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad;
999     int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w);
1000
1001     const int max_regs = 15; /* Maximun number of registers available for
1002                                 result accumulation and delta dst data.
1003                                 One additional register is reserved for weights
1004                                 data. */
1005
1006     /* Find the best blocking with maximum number of fma instructions
1007        per ur_w * nb_ic_blocking compute loops. Number of required registers
1008        is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs.
1009        ur_w must be divisible by stride_w */
1010     if (jcp.stride_w + 1 > max_regs)  /* Minimal possible registers
1011                                          distribution exceeds max_regs */
1012         return status::unimplemented;
1013
1014     int best_nfmas = 0;
1015     for (int b = 1; b <= 4; b++)
1016     {
1017         if (jcp.nb_ic % b != 0)
1018             continue;
1019
1020         for (int u = jcp.stride_w;
1021              u * b + u / jcp.stride_w <= max_regs && u < jcp.iw + jcp.stride_w;
1022              u += jcp.stride_w)
1023         {
1024             int ur_w = nstl::min(u, jcp.iw);
1025             /* maximum 1 step with l_overflow so far */
1026             if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw)
1027                 continue;
1028             int nfmas = utils::div_up(ur_w, jcp.stride_w) * b;
1029             if (nfmas > best_nfmas
1030                || (nfmas == best_nfmas && jcp.ur_w < ur_w)) {
1031                 jcp.ur_w = ur_w;
1032                 jcp.nb_ic_blocking = b;
1033                 best_nfmas = nfmas;
1034             }
1035         }
1036     }
1037     if (best_nfmas == 0) /* can't find appropriate blocking */
1038         return status::unimplemented;
1039
1040     jcp.ur_w_tail = jcp.iw % jcp.ur_w;
1041
1042     int r_overflow_no_tail = nstl::max(0, (jcp.kw - 1 - jcp.ur_w_tail
1043                     - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
1044     /* maximum 1 ur_w block with r_overflow so far */
1045     if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
1046         return status::unimplemented;
1047
1048     if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
1049         return status::unimplemented;
1050
1051     return status::success;
1052 }
1053
1054 void jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(
1055         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
1056     UNUSED(scratchpad);
1057     UNUSED(jcp);
1058 }
1059
1060 void jit_avx2_conv_bwd_weights_kernel_f32::generate() {
1061     this->preamble();
1062
1063     mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
1064     mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
1065     mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
1066     compute_oh_loop_common();
1067     this->postamble();
1068 }
1069
1070 status_t jit_avx2_conv_bwd_weights_kernel_f32::init_conf(jit_conv_conf_t &jcp,
1071         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
1072         const memory_desc_wrapper &diff_weights_d,
1073         const memory_desc_wrapper &diff_dst_d) {
1074     if (!mayiuse(avx2)) return status::unimplemented;
1075
1076     const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
1077     int ndims = src_d.ndims();
1078     jcp.ndims = ndims;
1079
1080     jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
1081     jcp.mb = src_d.dims()[0];
1082
1083     jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
1084     jcp.oc_without_padding = jcp.oc;
1085     jcp.ic = src_d.dims()[1] / jcp.ngroups;
1086
1087     jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
1088     jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
1089     jcp.iw = src_d.dims()[ndims-1];
1090     jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
1091     jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
1092     jcp.ow = diff_dst_d.dims()[ndims-1];
1093
1094     jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
1095     jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2];
1096     jcp.kw = diff_weights_d.dims()[with_groups + ndims-1];
1097
1098     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
1099     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
1100     jcp.l_pad = cd.padding[0][ndims-3];
1101
1102     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
1103     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
1104     jcp.stride_w = cd.strides[ndims-3];
1105
1106     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
1107     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
1108     jcp.dilate_w = cd.dilates[ndims-3];
1109
1110     jcp.src_fmt = src_d.format();
1111     jcp.with_bias = cd.diff_bias_desc.format != memory_format::undef;
1112
1113     const bool flat = jcp.ic == 3;
1114     const bool mimo = !flat;
1115
1116     const int simd_w = 8;
1117
1118     int back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + jcp.kd - jcp.id
1119         - jcp.f_pad);
1120     if (ndims == 5)
1121         if (jcp.f_pad != 0 || back_pad != 0)
1122             return status::unimplemented;
1123
1124     bool ok_to_pad_channels = true
1125         && jcp.ngroups == 1;
1126
1127     if (ok_to_pad_channels) {
1128         jcp.oc = rnd_up(jcp.oc, simd_w);
1129         if (mimo)
1130             jcp.ic = rnd_up(jcp.ic, simd_w);
1131     }
1132
1133     bool args_ok = true
1134         && IMPLICATION(flat, one_of(src_d.format(), ncw, nwc, nchw, nhwc, ncdhw,
1135                 ndhwc)
1136                 && one_of(diff_weights_d.format(), Owi8o, gOwi8o, Ohwi8o,
1137                     gOhwi8o, Odhwi8o, gOdhwi8o))
1138         && IMPLICATION(mimo, one_of(src_d.format(), nCw8c, nChw8c, nCdhw8c)
1139                 && one_of(diff_weights_d.format(), OIw8i8o, gOIw8i8o, OIhw8i8o,
1140                     gOIhw8i8o, OIdhw8i8o, gOIdhw8i8o))
1141         && one_of(cd.bias_desc.format, memory_format::undef, any, x)
1142         && one_of(diff_dst_d.format(), nCw8c, nChw8c, nCdhw8c)
1143         && IMPLICATION(mimo, jcp.ic % simd_w == 0)
1144         && jcp.oc % simd_w == 0
1145         && jcp.kw < 14
1146         && jcp.kh <= jcp.t_pad + jcp.ih /* [bwd_w:r1] */
1147         && jcp.kh <= jcp.ih /* [bwd_w:r2] */
1148         && jcp.kd <= jcp.f_pad + jcp.id
1149         && jcp.kd <= jcp.id
1150         && jcp.t_pad < jcp.kh /* XXX: must fix the kernel! */
1151         && jcp.dilate_d == 0
1152         && jcp.dilate_h == 0
1153         && jcp.dilate_w == 0;
1154     if (!args_ok) return status::unimplemented;
1155
1156     jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w;
1157     jcp.nb_ic = jcp.ic / jcp.ic_block;
1158
1159     jcp.oc_block = simd_w;
1160     jcp.nb_oc = jcp.oc / jcp.oc_block;
1161     jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
1162
1163     return status::success;
1164 }
1165
1166 void jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(
1167         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
1168     if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
1169         scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
1170 }
1171
1172 inline void jit_avx2_conv_bwd_weights_kernel_f32::od_step_comeback_pointers()
1173 {
1174     Label kd_comeback_loop;
1175     mov(kj, jcp.kd); //FIXME (Anton): this works only if f_pad = back_pad = 0
1176     L(kd_comeback_loop); {
1177         const int inp_mult = one_of(jcp.src_fmt, ncw, nchw, ncdhw)
1178             ? 1 : jcp.ic_block;
1179         sub(aux_reg_input, sizeof(float) * jcp.iw * jcp.ih * inp_mult);
1180         sub(aux_reg_kernel, sizeof(float) * jcp.kw * jcp.kh * jcp.ic_block
1181                 * jcp.oc_block);
1182         dec(kj);
1183         cmp(kj, 0);
1184         jg(kd_comeback_loop, T_NEAR);
1185     }
1186 }
1187
1188 inline void jit_avx2_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers()
1189 {
1190     mov(kj, reg_kh);
1191     Label kh_comeback_loop;
1192     L(kh_comeback_loop); {
1193         const int inp_mult = one_of(jcp.src_fmt, ncw, nchw, ncdhw)
1194             ? 1 : jcp.ic_block;
1195         sub(reg_input, sizeof(float) * jcp.iw * inp_mult);
1196         sub(reg_kernel, sizeof(float) * jcp.kw * jcp.ic_block * jcp.oc_block);
1197         dec(kj);
1198         cmp(kj, 0);
1199         jg(kh_comeback_loop, T_NEAR);
1200     }
1201 }
1202
1203 inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_ic_block_step(
1204         int ur_w, int pad_l, int pad_r, int ic_block_step, int input_offset,
1205         int kernel_offset, int output_offset)
1206 {
1207     const int kw = jcp.kw;
1208     const int ic_block = jcp.ic_block;
1209     const int oc_block = jcp.oc_block;
1210     for (int i_kw = 0; i_kw < kw; i_kw++)
1211         for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1212             size_t off
1213                 = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block
1214                 + kernel_offset;
1215             vmovups(Ymm(i_kw * ic_block_step + i_ic), yword[reg_kernel + off]);
1216         }
1217
1218     for (int i_ur = 0; i_ur < ur_w; i_ur++) {
1219         vmovups(Ymm(kw * ic_block_step + 0),
1220                 yword[reg_output
1221                 + sizeof(float) * i_ur * oc_block + output_offset]);
1222
1223         for (int i_kw = 0; i_kw < kw; i_kw++) {
1224             int i_iw = i_ur * jcp.stride_w + i_kw;
1225             if (i_iw - pad_l < 0
1226                     || i_iw > (ur_w - 1) * jcp.stride_w + kw - 1 - pad_r)
1227                 continue;
1228             for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1229                 size_t i_off = (size_t)input_offset + sizeof(float)*(
1230                     one_of(jcp.src_fmt, ncw, nchw, ncdhw)
1231                         ? (i_iw - pad_l) + i_ic
1232                         * ((size_t)jcp.id * jcp.ih * jcp.iw)
1233                         : (i_iw - pad_l) * ic_block + i_ic);
1234                 vbroadcastss(Ymm(kw * ic_block_step + 1),
1235                         make_safe_addr(reg_input, i_off, reg_long_offt));
1236                 vfmadd231ps(Ymm(i_kw * ic_block_step + i_ic),
1237                         Ymm(kw * ic_block_step + 0),
1238                         Ymm(kw * ic_block_step + 1));
1239             }
1240         }
1241     }
1242
1243     for (int i_kw = 0; i_kw < kw; i_kw++)
1244         for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1245             size_t off
1246                 = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block
1247                 + kernel_offset;
1248             vmovups(yword[reg_kernel + off],
1249                     Ymm(i_kw * ic_block_step + i_ic));
1250         }
1251 }
1252
1253 inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_disp()
1254 {
1255     int ic_block_step;
1256     if (one_of(jcp.src_fmt, ncw, nchw, ncdhw)) {
1257         ic_block_step = jcp.kw >= 5 ? 1 : jcp.ic_block;
1258     } else {
1259         ic_block_step = jcp.kw > 7 ? 1
1260         : jcp.kw > 3 ? 2
1261         : jcp.kw > 1 ? 4 : 8;
1262     }
1263
1264     const int max_ur_w = jcp.ow > 56 ? 14 : 28;
1265
1266     if (jcp.ow <= max_ur_w)
1267         compute_oh_step_unroll_ow(ic_block_step, max_ur_w);
1268     else
1269         compute_oh_step_common(ic_block_step, max_ur_w);
1270
1271     if (jcp.ndims == 5) {
1272         od_step_comeback_pointers();
1273         mov(reg_input, aux_reg_input);
1274         mov(reg_kernel, aux_reg_kernel);
1275     } else {
1276         oh_step_comeback_pointers();
1277     }
1278 }
1279
1280 inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_unroll_ow(
1281         int ic_block_step, int max_ur_w)
1282 {
1283     UNUSED(max_ur_w);
1284
1285     const int ic_block = jcp.ic_block;
1286     const int oc_block = jcp.oc_block;
1287     int inp_mul = one_of(jcp.src_fmt, ncw, nchw, ncdhw) ? 1 : jcp.ic_block;
1288     Label kd_loop;
1289
1290     const int r_pad
1291         = nstl::max(0,
1292                 (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
1293
1294     if (jcp.ndims == 5) {
1295         mov(aux_reg_input, reg_input);
1296         mov(aux_reg_kernel, reg_kernel);
1297         mov(ki, jcp.kd);
1298         L(kd_loop);
1299         mov(reg_input, aux_reg_input);
1300         mov(reg_kernel, aux_reg_kernel);
1301     }
1302
1303     mov(kj, reg_kh);
1304     Label kh_loop;
1305     L(kh_loop); {
1306         xor_(b_ic, b_ic);
1307         Label ic_block_loop;
1308         L(ic_block_loop); {
1309             compute_ic_block_step(jcp.ow, jcp.l_pad, r_pad, ic_block_step, 0,
1310                     0, 0);
1311             size_t inp_icblk_stride = sizeof(float) * ic_block_step
1312                 * (one_of(jcp.src_fmt, ncw, nchw, ncdhw)
1313                 ? jcp.id*jcp.ih*jcp.iw : 1);
1314             safe_add(reg_input, inp_icblk_stride, reg_long_offt);
1315             add(reg_kernel, sizeof(float) * ic_block_step * oc_block);
1316             add(b_ic, ic_block_step);
1317             cmp(b_ic, ic_block);
1318             jl(ic_block_loop, T_NEAR);
1319         }
1320         if(one_of(jcp.src_fmt, ncw, nchw, ncdhw)) {
1321             size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block;
1322             safe_sub(reg_input, offt, reg_long_offt);
1323             add(reg_input, sizeof(float) * jcp.iw);
1324         } else {
1325             add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block);
1326         }
1327         add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block);
1328         dec(kj);
1329         cmp(kj, 0);
1330         jg(kh_loop, T_NEAR);
1331     }
1332
1333     if (jcp.ndims == 5) {
1334         add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul);
1335         add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block
1336             * oc_block);
1337         dec(ki);
1338         cmp(ki, 0);
1339         jg(kd_loop, T_NEAR);
1340     }
1341
1342 }
1343
1344 inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_common(
1345         int ic_block_step, int max_ur_w)
1346 {
1347     const int ic_block = jcp.ic_block;
1348     const int oc_block = jcp.oc_block;
1349     const int stride_w = jcp.stride_w;
1350     int inp_mul = one_of(jcp.src_fmt, ncw, nchw, ncdhw) ? 1 : jcp.ic_block;
1351     Label kd_loop;
1352
1353     const int r_pad
1354         = nstl::max(0,
1355                 (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
1356
1357     int ur_w = nstl::min(jcp.ow, max_ur_w);
1358     int ur_w_trips = jcp.ow / ur_w;
1359     int ur_w_tail = jcp.ow % ur_w;
1360     if ((ur_w_tail == 0 && r_pad != 0) || r_pad >= ur_w_tail) {
1361         if (ur_w_trips > 1) {
1362             ur_w_tail += ur_w;
1363             ur_w_trips--;
1364         } else {
1365             ur_w_tail += (ur_w - ur_w / 2);
1366             ur_w = ur_w / 2;
1367         }
1368     }
1369     const int inp_mult = one_of(jcp.src_fmt, ncw, nchw, ncdhw) ? 1 : ic_block;
1370
1371     int input_comeback = (ur_w_trips * ur_w * stride_w - jcp.l_pad) * inp_mult;
1372     int output_comeback = ur_w_trips * ur_w * oc_block;
1373
1374     if (jcp.ndims == 5) {
1375         mov(aux_reg_input, reg_input);
1376         mov(aux_reg_kernel, reg_kernel);
1377         mov(ki, jcp.kd);
1378         L(kd_loop);
1379         mov(reg_input, aux_reg_input);
1380         mov(reg_kernel, aux_reg_kernel);
1381     }
1382
1383     mov(kj, reg_kh);
1384     Label kh_loop;
1385     L(kh_loop); {
1386         xor_(b_ic, b_ic);
1387         Label ic_block_loop;
1388         L(ic_block_loop); {
1389             if (jcp.l_pad != 0) {
1390                 ur_w_trips--;
1391                 compute_ic_block_step(ur_w,
1392                         jcp.l_pad, 0, ic_block_step, 0, 0, 0);
1393                 add(reg_input, sizeof(float)
1394                         * (ur_w * stride_w - jcp.l_pad) * inp_mult);
1395                 add(reg_output, sizeof(float) * ur_w * oc_block);
1396             }
1397
1398             if (ur_w_trips > 0) {
1399                 xor_(reg_ur_w_trips, reg_ur_w_trips);
1400                 Label ow_block_loop;
1401                 L(ow_block_loop); {
1402                     compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
1403                     add(reg_input, sizeof(float) * ur_w * stride_w * inp_mult);
1404                     add(reg_output, sizeof(float) * ur_w * oc_block);
1405
1406                     inc(reg_ur_w_trips);
1407                     cmp(reg_ur_w_trips, ur_w_trips);
1408                     jl(ow_block_loop, T_NEAR);
1409                 }
1410             }
1411
1412             if (ur_w_tail > 0)
1413                 compute_ic_block_step(ur_w_tail,
1414                         0, r_pad, ic_block_step, 0, 0, 0);
1415
1416             sub(reg_input, sizeof(float) * input_comeback);
1417             sub(reg_output, sizeof(float) * output_comeback);
1418
1419             size_t inp_icblk_stride = sizeof(float) * ic_block_step
1420                 * (one_of(jcp.src_fmt, ncw, nchw, ncdhw)
1421                 ? jcp.id*jcp.ih*jcp.iw : 1);
1422             safe_add(reg_input, inp_icblk_stride, reg_long_offt);
1423             add(reg_kernel, sizeof(float) * ic_block_step * oc_block);
1424
1425             add(b_ic, ic_block_step);
1426             cmp(b_ic, jcp.ic_block);
1427             jl(ic_block_loop, T_NEAR);
1428         }
1429         if (one_of(jcp.src_fmt, ncw, nchw, ncdhw)) {
1430             size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block;
1431             safe_sub(reg_input, offt, reg_long_offt);
1432             add(reg_input, sizeof(float) * jcp.iw);
1433         } else {
1434             add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block);
1435         }
1436         add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block);
1437         dec(kj);
1438         cmp(kj, 0);
1439         jg(kh_loop, T_NEAR);
1440     }
1441
1442     if (jcp.ndims == 5) {
1443         add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul);
1444         add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block
1445             * oc_block);
1446         dec(ki);
1447         cmp(ki, 0);
1448         jg(kd_loop, T_NEAR);
1449     }
1450
1451 }
1452
1453 inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_loop_common()
1454 {
1455     const int icoc_block = jcp.ic_block * jcp.oc_block;
1456     const int t_pad = jcp.t_pad;
1457     const int stride_h = jcp.stride_h;
1458     const int inp_mult = one_of(jcp.src_fmt, ncw, nchw, ncdhw)
1459         ? 1 : jcp.ic_block;
1460     int b_pad
1461         = nstl::max(0, (jcp.oh - 1) * stride_h + jcp.kh - jcp.ih - t_pad);
1462
1463     Label oh_tpad_loop, oh_loop, oh_loop_end;
1464
1465     mov(reg_kh, jcp.kh);
1466     xor_(reg_ih_count, reg_ih_count);
1467     xor_(reg_oj, reg_oj);
1468     if (t_pad > 0) {
1469         assert(jcp.kh <= t_pad + jcp.ih); /* [bwd_w:r1] */
1470         mov(reg_kh, jcp.kh <= t_pad + jcp.ih ? jcp.kh - t_pad : jcp.ih);
1471         add(reg_kernel, sizeof(float) * t_pad * jcp.kw * icoc_block);
1472
1473         L(oh_tpad_loop); {
1474             compute_oh_step_disp();
1475             add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block);
1476             sub(reg_kernel, sizeof(float) * stride_h * jcp.kw * icoc_block);
1477
1478             inc(reg_oj);
1479             add(reg_ih_count, stride_h);
1480             add(reg_kh, stride_h);
1481
1482             /* the overlap between input and kernel may not reach kernel size.
1483              * so far we do not support that (until we put constant here) */
1484             const int final_inp_ker_overlap = jcp.kh; /* [bwd_w:r2] */
1485             cmp(reg_kh, final_inp_ker_overlap);
1486             jl(oh_tpad_loop, T_NEAR);
1487         }
1488
1489         if (t_pad % stride_h != 0) {
1490             int inp_corr = stride_h - t_pad % stride_h;
1491             add(reg_kernel, sizeof(float) * inp_corr * jcp.kw * icoc_block);
1492             add(reg_input, sizeof(float) * inp_corr * jcp.iw * inp_mult);
1493         }
1494     }
1495     cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1);
1496     jge(oh_loop_end, T_NEAR);
1497     cmp(reg_oj, jcp.oh);
1498     jge(oh_loop, T_NEAR);
1499
1500     mov(reg_kh, jcp.kh);
1501     L(oh_loop); {
1502         compute_oh_step_disp();
1503         add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult);
1504         add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block);
1505
1506         inc(reg_oj);
1507         add(reg_ih_count, stride_h);
1508
1509         cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1);
1510         jge(oh_loop_end, T_NEAR);
1511
1512         cmp(reg_oj, jcp.oh);
1513         jl(oh_loop, T_NEAR);
1514     }
1515     L(oh_loop_end);
1516     if (b_pad > 0) {
1517         Label oh_bpad_loop, oh_bpad_loop_end;
1518         cmp(reg_oj, jcp.oh);
1519         jge(oh_bpad_loop_end, T_NEAR);
1520
1521         mov(reg_kh, jcp.ih + t_pad);
1522         sub(reg_kh, reg_ih_count);
1523         L(oh_bpad_loop); {
1524             compute_oh_step_disp();
1525             add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult);
1526             add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block);
1527
1528             sub(reg_kh, stride_h);
1529             cmp(reg_kh, 0);
1530             jle(oh_bpad_loop_end, T_NEAR);
1531
1532             inc(reg_oj);
1533             cmp(reg_oj, jcp.oh);
1534             jl(oh_bpad_loop, T_NEAR);
1535         }
1536         L(oh_bpad_loop_end);
1537     }
1538 }
1539
1540 }
1541 }
1542 }
1543
1544 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s