Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_sse42_conv_kernel_f32.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "nstl.hpp"
19 #include "type_helpers.hpp"
20 #include "cpu_memory.hpp"
21
22 #include "jit_sse42_conv_kernel_f32.hpp"
23
24 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
30 using namespace mkldnn::impl::prop_kind;
31 using namespace mkldnn::impl::memory_format;
32 using namespace mkldnn::impl::utils;
33
34 using namespace Xbyak;
35
36 void jit_sse42_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w,
37         int pad_l, int pad_r, int oc_blocks)
38 {
39     int iw = jcp.iw;
40     int ih = jcp.ih;
41     int kw = jcp.kw;
42     int kh = jcp.kh;
43     int nb_ic = jcp.nb_ic;
44     int stride_w = jcp.stride_w;
45     int dilate_w = jcp.dilate_w + 1;
46     int ic_blk = jcp.ic_block;
47     int oc_blk = jcp.oc_block;
48
49     for (int ki = 0; ki < kw; ki++) {
50         int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
51         int jj_end = ur_w
52         - nstl::max(0, div_up(ki*dilate_w + pad_r - (kw-1)*dilate_w, stride_w));
53         for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
54             for (int jj = jj_start; jj < jj_end; jj++) {
55                 int inp_off;
56                 if (jcp.src_fmt == nchw)
57                     inp_off = ifm2*ih*iw + (ki*dilate_w + jj*stride_w - pad_l);
58                 else
59                     inp_off = (ki*dilate_w + jj*stride_w - pad_l)*ic_blk + ifm2;
60
61                 movss(Xmm(oc_blocks * ur_w + jj + 1),
62                       ptr[aux_reg_input + sizeof(float) * inp_off]);
63                 shufps(Xmm(oc_blocks * ur_w + jj + 1),
64                        Xmm(oc_blocks * ur_w + jj + 1), 0x0);
65             }
66
67             for (int ii = 0; ii < oc_blocks; ii++) {
68                 int ker_off = ii * nb_ic * kh * kw * ic_blk * oc_blk
69                               + ki * ic_blk * oc_blk + ifm2 * oc_blk;
70
71                 for (int jj = jj_start; jj < jj_end; jj++)
72                 {
73                     movups(xmm0,
74                       ptr[aux_reg_kernel + sizeof(float) * ker_off]);
75                     mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1));
76                     addps(Xmm(ur_w * ii + jj + 1), xmm0);
77                 }
78             }
79         }
80     }
81 }
82
83 void jit_sse42_conv_fwd_kernel_f32::oh_step_nopad(int ur_w,
84         int pad_l, int pad_r, char pad_tag,
85         int oc_blocks, char oc_blocks_tag)
86 {
87     jit_tagged_label kw_label("kw", pad_tag, oc_blocks_tag);
88
89     int iw = jcp.iw;
90     int ih = jcp.ih;
91     int kw = jcp.kw;
92     int kh = jcp.kh;
93     int nb_ic = jcp.nb_ic;
94     int stride_w = jcp.stride_w;
95     int dilate_w = jcp.dilate_w + 1;
96     int ic_blk = jcp.ic_block;
97     int oc_blk = jcp.oc_block;
98
99     xor_(ki_iter, ki_iter);
100     L(kw_label);
101     {
102         int jj_start = 0;
103         int jj_end = ur_w;
104         for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
105             for (int jj = jj_start; jj < jj_end; jj++) {
106                 int inp_off;
107                 if (jcp.src_fmt == nchw)
108                     inp_off = ifm2 * ih * iw + (jj * stride_w - pad_l);
109                 else
110                     inp_off = (jj * stride_w - pad_l) * ic_blk + ifm2;
111
112                 movss(Xmm(oc_blocks * ur_w + jj + 1),
113                       ptr[aux_reg_input + sizeof(float) * inp_off]);
114                 shufps(Xmm(oc_blocks * ur_w + jj + 1),
115                        Xmm(oc_blocks * ur_w + jj + 1), 0x0);
116             }
117             for (int ii = 0; ii < oc_blocks; ii++) {
118                 int aux_kernel_offset = ii * nb_ic * kh * kw * ic_blk * oc_blk
119                                         + ifm2 * oc_blk;
120                 for (int jj = jj_start; jj < jj_end; jj++) {
121                     movups(xmm0,
122                       ptr[aux_reg_kernel + sizeof(float) * aux_kernel_offset]);
123                     mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1));
124                     addps(Xmm(ur_w * ii + jj + 1), xmm0);
125                 }
126             }
127         }
128         add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk);
129         add(aux_reg_input, sizeof(float) * (jcp.src_fmt == nchw ?
130             dilate_w : ic_blk * dilate_w));
131
132         inc(ki_iter);
133         cmp(ki_iter, kw);
134         jl(kw_label, T_NEAR);
135     }
136 }
137
138 void jit_sse42_conv_fwd_kernel_f32::width_blk_step(int ur_w,
139         int pad_l, int pad_r, char pad_tag,
140         int oc_blocks, char oc_blocks_tag)
141 {
142     int iw = jcp.iw;
143     int kw = jcp.kw;
144     int ow = jcp.ow;
145     int oh = jcp.oh;
146     int dilate_h = jcp.dilate_h + 1;
147     int dilate_w = jcp.dilate_w + 1;
148     int ic_blk = jcp.ic_block;
149     int oc_blk = jcp.oc_block;
150     const int inp_mult = jcp.src_fmt == nchw ? dilate_h : ic_blk * dilate_h;
151     const int inp_off = jcp.src_fmt == nchw ? dilate_w : ic_blk * dilate_w;
152
153     xor_(simd_iter, simd_iter);
154
155     mov(aux_reg_input, reg_input);
156     mov(aux_reg_kernel, reg_kernel);
157
158     jit_tagged_label init_simd_iter_label("simd_iter", pad_tag, oc_blocks_tag);
159     jit_tagged_label init_done_label("init", pad_tag, oc_blocks_tag);
160     jit_tagged_label init_first_label("first", pad_tag, oc_blocks_tag);
161
162     L(init_simd_iter_label);
163
164     if (!jcp.with_sum) {
165         test(reg_ci_flag, FLAG_IC_FIRST);
166         jne(init_first_label, T_NEAR);
167     }
168
169     for (int ii = 0; ii < oc_blocks; ii++)
170         for (int jj = 0; jj < ur_w; jj++) {
171             int o_off;
172             if (jcp.with_dw_conv)
173                 o_off = (ii * jcp.dw_conv_ker_h * ow + jj) * oc_blk;
174             else
175                 o_off = (ii * oh * ow + jj) * oc_blk;
176
177             movups(Xmm(ur_w * ii + jj + 1), xword[reg_output
178                 + sizeof(float) * o_off]);
179         }
180
181     if (jcp.with_sum && jcp.with_bias) {
182         test(reg_ci_flag, FLAG_IC_FIRST);
183         je(init_done_label, T_NEAR);
184
185         for (int ii = 0; ii < oc_blocks; ii++)
186             for (int jj = 0; jj < ur_w; jj++)
187                 addps(Xmm(ur_w * ii + jj + 1),
188                     xword[reg_bias + sizeof(float) * ii * oc_blk]);
189     }
190
191     jmp(init_done_label);
192
193     L(init_first_label);
194     if (this->jcp.with_bias) {
195         for (int ii = 0; ii < oc_blocks; ii++)
196             for (int jj = 0; jj < ur_w; jj++)
197                 movups(Xmm(ur_w * ii + jj + 1),
198                        xword[reg_bias + sizeof(float) * ii * oc_blk]);
199     } else {
200         for (int ii = 0; ii < oc_blocks; ii++)
201             for (int jj = 0; jj < ur_w; jj++)
202                 pxor(Xmm(ur_w * ii + jj + 1), Xmm(ur_w * ii + jj + 1));
203     }
204
205     L(init_done_label);
206
207     Label skip_kh_loop;
208     mov(kj, reg_kh);
209     if (jcp.kh <= jcp.t_pad) {
210         cmp(kj, 0);
211         je(skip_kh_loop, T_NEAR);
212     }
213     jit_tagged_label kh_label("kh", pad_tag, oc_blocks_tag);
214     L(kh_label);
215     {
216         if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) {
217             oh_step_nopad(ur_w, pad_l, pad_r, pad_tag, oc_blocks,
218                           oc_blocks_tag);
219             sub(aux_reg_input, sizeof(float) * kw * inp_off);
220             add(aux_reg_input, sizeof(float) * iw * inp_mult);
221         } else {
222             oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks);
223             add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk);
224             add(aux_reg_input, sizeof(float) * iw * inp_mult);
225         }
226
227         dec(kj);
228         cmp(kj, 0);
229         jg(kh_label, T_NEAR);
230     }
231
232     L(skip_kh_loop);
233
234     jit_tagged_label done_label("done", pad_tag, oc_blocks_tag);
235     jit_tagged_label regular_store_label("store", pad_tag, oc_blocks_tag);
236
237     if (jcp.with_eltwise) {
238         assert(oc_blocks * ur_w < 15);
239         test(reg_ci_flag, FLAG_IC_LAST);
240         je(regular_store_label, T_NEAR);
241
242         inject(eltwise_generator.prepareConstants(jcp.eltwise_alpha, jcp.eltwise_beta));
243
244         // TODO (dmitrygo): need to find appropriate way to share labels.
245         mov(imm_addr64, l_table);
246         for (int ii = 0; ii < oc_blocks; ii++) {
247             for (int jj = 0; jj < ur_w; jj++) {
248                 Xmm reg_out = Xmm(ur_w * ii + jj + 1);
249
250                 inject(eltwise_generator.computeVector(reg_out, reg_out));
251             }
252         }
253
254         L(regular_store_label);
255     }
256
257     for (int ii = 0; ii < oc_blocks; ii++) {
258         for (int jj = 0; jj < ur_w; jj++) {
259             int o_off;
260             if (jcp.with_dw_conv)
261                 o_off = (ii * jcp.dw_conv_ker_h * ow + jj) * oc_blk;
262             else
263                 o_off = (ii * oh * ow + jj) * oc_blk;
264
265             Xmm reg_out = Xmm(ur_w * ii + jj + 1);
266             movups(xword[reg_output + sizeof(float) * o_off], reg_out);
267         }
268     }
269
270     L(done_label);
271
272     mov(aux_reg_kernel, reg_kernel);
273     mov(aux_reg_input, reg_input);
274     add(aux_reg_kernel, sizeof(float) * 4);
275     add(reg_output, sizeof(float) * 4);
276     add(reg_bias,   sizeof(float) * 4);
277
278     inc(simd_iter);
279     cmp(simd_iter, 2);
280     jl(init_simd_iter_label, T_NEAR);
281
282     sub(reg_output, sizeof(float) * 8);
283     sub(reg_bias,   sizeof(float) * 8);
284 }
285
286 inline void jit_sse42_conv_fwd_kernel_f32::solve_common(
287         int oc_blocks, char oc_blocks_tag)
288 {
289     int ur_w = jcp.ur_w;
290     int ur_w_tail = jcp.ur_w_tail;
291     int n_oi = jcp.ow / ur_w;
292     int iw = jcp.iw;
293     int kw = jcp.kw;
294     int ic_blk = jcp.ic_block;
295     int oc_blk = jcp.oc_block;
296     int dilate_w = jcp.dilate_w + 1;
297     int str_w = jcp.stride_w;
298     const int inp_mult = jcp.src_fmt == nchw ? 1 : ic_blk;
299
300     int l_pad = jcp.l_pad;
301     int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w
302         - (iw + l_pad - 1));
303     int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w
304         - (iw + l_pad - 1);
305     if (r_pad1 > 0) n_oi--;
306
307     if (l_pad > 0) {
308         n_oi--;
309         if (n_oi < 0 && r_pad1 > 0)
310             width_blk_step(ur_w, l_pad, r_pad1,
311                            'l', oc_blocks, oc_blocks_tag); // "lrpad"
312         else
313             width_blk_step(ur_w, l_pad, 0,
314                            'l', oc_blocks, oc_blocks_tag); // "lpad"
315         add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult);
316         add(reg_output, sizeof(float) * ur_w * oc_blk);
317     }
318
319     jit_tagged_label ow_loop_label("ow", oc_blocks_tag);
320     xor_(oi_iter, oi_iter);
321
322     if (n_oi > 0) {
323         L(ow_loop_label);
324
325         width_blk_step(ur_w, 0, 0,
326                        'm', oc_blocks, oc_blocks_tag); // "middle"
327         add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
328         add(reg_output, sizeof(float) * ur_w * oc_blk);
329
330         inc(oi_iter);
331         cmp(oi_iter, n_oi);
332         jl(ow_loop_label, T_NEAR);
333     }
334
335     if (r_pad1 > 0 && n_oi >=0) {
336         width_blk_step(ur_w, 0, r_pad1,
337                        'r', oc_blocks, oc_blocks_tag); // "rpad"
338         add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
339         add(reg_output, sizeof(float) * ur_w * oc_blk);
340     }
341
342     if (ur_w_tail != 0)
343         width_blk_step(ur_w_tail, 0, r_pad,
344                        't', oc_blocks, oc_blocks_tag); // "tail"
345 }
346
347 void jit_sse42_conv_fwd_kernel_f32::generate()
348 {
349     if (jcp.with_eltwise) {
350         nstl::vector<int> shared_vecs;
351         shared_vecs.push_back(0);
352         shared_vecs.push_back(13);
353         shared_vecs.push_back(14);
354         shared_vecs.push_back(15);
355
356         nstl::vector<Reg64> shared_regs;
357         shared_regs.push_back(imm_addr64);
358
359         eltwise_generator.init(jcp.eltwise_alg, shared_vecs, shared_regs);
360     }
361
362     this->preamble();
363
364     mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
365     mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
366     mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
367     if (jcp.with_bias)
368         mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
369     mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
370     mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
371     mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]);
372
373     int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking;
374     const char *tail_label = ".tail";
375     const char *exit_label = ".exit";
376
377     cmp(reg_oc_blocks, jcp.nb_oc_blocking);
378     jne(nb_oc_tail ? tail_label : exit_label, T_NEAR);
379
380     solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking);
381     jmp(exit_label, T_NEAR);
382
383     if (nb_oc_tail) {
384         L(tail_label);
385         cmp(reg_oc_blocks, nb_oc_tail);
386         jne(exit_label, T_NEAR);
387         solve_common(nb_oc_tail, '0' + nb_oc_tail);
388     }
389
390     L(exit_label);
391
392     this->postamble();
393
394     if (jcp.with_eltwise) {
395         // TODO (dmitrygo): need to find appropriate way to share labels.
396         align(64);
397         L(l_table);
398         inject(eltwise_generator.prepareTable());
399         eltwise_generator.release();
400     }
401 }
402
403 bool jit_sse42_conv_fwd_kernel_f32::post_ops_ok(
404         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
405     const auto &p = attr.post_ops_;
406
407     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
408     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
409     auto is_dw_conv = [&](int idx) { return p.entry_[idx].is_dw_conv(); };
410
411     switch (p.len_) {
412     case 0: return true; // no post_ops
413     case 1:
414         return true // sum OR eltwise OR dw_conv
415                 && !jcp.with_eltwise && (is_eltwise(0) || is_sum(0) || is_dw_conv(0));
416     case 2:
417         return true // sum->eltwise or dw_conv->eltwise or eltwise->dw_conv or dw_conv->sum
418                 && !jcp.with_eltwise && ((is_sum(0) && is_eltwise(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
419                                          (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)));
420     case 3:
421         return true // eltwise->dw_conv->eltwise or dw_conv->sum->eltwise
422                 && !jcp.with_eltwise && ((is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
423                                          (is_dw_conv(0) && is_sum(1) && is_eltwise(2)));
424     case 4: return true // eltwise->dw_conv->sum->eltwise
425             && !jcp.with_eltwise && (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
426     default: return false;
427     }
428
429     return false;
430 }
431
432 status_t jit_sse42_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
433         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
434         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
435         const primitive_attr_t &attr, bool with_relu, float relu_negative_slope)
436 {
437     if (!mayiuse(sse42)) return status::unimplemented;
438
439     jcp.prop_kind = cd.prop_kind;
440
441     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
442
443     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
444     jcp.mb = src_d.dims()[0];
445
446     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
447     jcp.ic = src_d.dims()[1] / jcp.ngroups;
448
449     jcp.ih = src_d.dims()[2];
450     jcp.iw = src_d.dims()[3];
451     jcp.oh = dst_d.dims()[2];
452     jcp.ow = dst_d.dims()[3];
453
454     jcp.kh = weights_d.dims()[with_groups + 2];
455     jcp.kw = weights_d.dims()[with_groups + 3];
456
457     jcp.t_pad = cd.padding[0][0];
458     jcp.l_pad = cd.padding[0][1];
459
460     jcp.stride_h = cd.strides[0];
461     jcp.stride_w = cd.strides[1];
462
463     jcp.dilate_h = cd.dilates[0];
464     jcp.dilate_w = cd.dilates[1];
465
466     jcp.src_fmt = src_d.format();
467     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
468     jcp.with_eltwise = with_relu;
469     jcp.eltwise_alg = mkldnn_eltwise_relu;
470     jcp.eltwise_alpha = relu_negative_slope;
471
472     if (!post_ops_ok(jcp, attr))
473         return status::unimplemented;
474
475     const auto &p = attr.post_ops_;
476     jcp.with_dw_conv = false;
477     int dw_conv_ind = p.find(primitive_kind::convolution);
478     if (dw_conv_ind != -1) {
479         jcp.with_dw_conv = true;
480         jcp.dw_conv_in_h = p.entry_[dw_conv_ind].dw_conv.in_h;
481         jcp.dw_conv_in_w = p.entry_[dw_conv_ind].dw_conv.in_w;
482         jcp.dw_conv_ker_h = p.entry_[dw_conv_ind].dw_conv.ker_h;
483         jcp.dw_conv_ker_w = p.entry_[dw_conv_ind].dw_conv.ker_w;
484         jcp.dw_conv_str_h = p.entry_[dw_conv_ind].dw_conv.str_h;
485         jcp.dw_conv_str_w = p.entry_[dw_conv_ind].dw_conv.str_w;
486         jcp.dw_conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
487         jcp.dw_conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
488     }
489
490     if (!jcp.with_eltwise) {
491         int eltwise_ind = p.find(primitive_kind::eltwise, 0, dw_conv_ind);
492         if (eltwise_ind != -1) {
493             jcp.with_eltwise  = true;
494             jcp.eltwise_alg   = p.entry_[eltwise_ind].eltwise.alg;
495             jcp.eltwise_alpha = p.entry_[eltwise_ind].eltwise.alpha;
496             jcp.eltwise_beta  = p.entry_[eltwise_ind].eltwise.beta;
497             jcp.eltwise_scale = p.entry_[eltwise_ind].eltwise.scale;
498         }
499     }
500
501     if (jcp.with_dw_conv) {
502         int dw_conv_eltwise_ind = p.find(primitive_kind::eltwise, dw_conv_ind);
503         if (dw_conv_eltwise_ind != -1) {
504             jcp.dw_conv_with_eltwise = true;
505             jcp.dw_conv_eltwise_alg = p.entry_[dw_conv_eltwise_ind].eltwise.alg;
506             jcp.dw_conv_eltwise_alpha = p.entry_[dw_conv_eltwise_ind].eltwise.alpha;
507             jcp.dw_conv_eltwise_beta = p.entry_[dw_conv_eltwise_ind].eltwise.beta;
508         }
509     }
510
511     jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
512     if (jcp.with_dw_conv) {
513         jcp.dw_conv_with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
514     }
515
516     if (jcp.with_dw_conv) {
517         jcp.oh = jcp.dw_conv_in_h;
518         jcp.ow = jcp.dw_conv_in_w;
519     }
520
521     const bool flat = jcp.ic == 3 || jcp.ic == 1;
522     const bool mimo = !flat;
523
524     bool args_ok = true
525         && implication(flat, one_of(src_d.format(), nchw, nhwc)
526                 && one_of(weights_d.format(), Ohwi8o, gOhwi8o))
527         && implication(mimo, src_d.format() == nChw8c
528                 && one_of(weights_d.format(), OIhw8i8o, gOIhw8i8o))
529         && one_of(cd.bias_desc.format, memory_format::undef, any, x)
530         && dst_d.format() == nChw8c;
531     if (!args_ok) return status::unimplemented;
532
533     const int simd_w = 8; // 2 SSE vectors processing at once
534
535     jcp.ur_h = 1; /* no code-unrolling by h so far */
536     jcp.ur_w = 3;
537     if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
538     jcp.ur_w_tail = jcp.ow % jcp.ur_w;
539
540     jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */
541
542     args_ok = true
543         && jcp.oc % simd_w == 0
544         && jcp.l_pad <= jcp.ur_w
545         && implication(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0)
546                 || (jcp.stride_w == 1 && jcp.stride_h == 1))
547         && implication(mimo, jcp.ic % simd_w == 0);
548     if (!args_ok) return status::unimplemented;
549
550     int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
551         + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
552
553     if (r_pad_no_tail > jcp.ur_w) {
554         /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
555         jcp.ur_w = r_pad_no_tail + 1;
556         jcp.nb_oc_blocking = ((16 - 1)-jcp.ur_w)/jcp.ur_w;
557         jcp.ur_w_tail = jcp.ow % jcp.ur_w;
558         /* check again ... */
559         r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
560             + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
561         if ((r_pad_no_tail > jcp.ur_w) || (jcp.ow < jcp.ur_w))
562             return status::unimplemented;
563     }
564     if (jcp.l_pad > jcp.ur_w) return status::unimplemented;
565
566     jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w;
567     jcp.nb_ic = jcp.ic / jcp.ic_block;
568
569     jcp.oc_block = simd_w;
570     jcp.nb_oc = jcp.oc / jcp.oc_block;
571
572     if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
573         jcp.nb_ic_blocking = 12;
574         jcp.nb_ic_blocking_max = 16;
575     } else {
576         jcp.nb_ic_blocking = 1;
577         jcp.nb_ic_blocking_max = jcp.nb_ic_blocking;
578     }
579
580     if (jcp.with_eltwise) {
581         int nvecs_elt = jit_uni_eltwise_vector_f32<sse42>::sharedVecsCount(jcp.eltwise_alg);
582         int nvecs_conv = 16 - nvecs_elt;
583         while (jcp.ur_w * jcp.nb_oc_blocking > nvecs_conv) {
584             if (jcp.nb_oc_blocking <= 1) {
585                 break;
586             }
587
588             jcp.nb_oc_blocking -= 1;
589         }
590
591         if (jcp.ur_w * jcp.nb_oc_blocking > nvecs_conv)
592             return status::unimplemented;
593     }
594
595     return status::success;
596 }
597
598 }
599 }
600 }