Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_sse42_conv_kernel_f32.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "nstl.hpp"
19 #include "type_helpers.hpp"
20 #include "cpu_memory.hpp"
21
22 #include "jit_sse42_conv_kernel_f32.hpp"
23
24 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
30 using namespace mkldnn::impl::prop_kind;
31 using namespace mkldnn::impl::memory_format;
32 using namespace mkldnn::impl::memory_tracking::names;
33 using namespace mkldnn::impl::utils;
34
35 using namespace Xbyak;
36
37 void jit_sse42_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w,
38         int pad_l, int pad_r, int oc_blocks)
39 {
40     int iw = jcp.iw;
41     int ih = jcp.ih;
42     int kw = jcp.kw;
43     int kh = jcp.kh;
44     int nb_ic = jcp.nb_ic;
45     int stride_w = jcp.stride_w;
46     int dilate_w = jcp.dilate_w + 1;
47     int ic_blk = jcp.ic_block;
48     int oc_blk = jcp.oc_block;
49
50     for (int ki = 0; ki < kw; ki++) {
51         int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
52         int jj_end = ur_w
53         - nstl::max(0, div_up(ki*dilate_w + pad_r - (kw-1)*dilate_w, stride_w));
54         for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
55             for (int jj = jj_start; jj < jj_end; jj++) {
56                 int inp_off;
57                 if (one_of(jcp.src_fmt, ncw, nchw))
58                     inp_off = ifm2*ih*iw + (ki*dilate_w + jj*stride_w - pad_l);
59                 else
60                     inp_off = (ki*dilate_w + jj*stride_w - pad_l)*ic_blk + ifm2;
61
62                 movss(Xmm(oc_blocks * ur_w + jj + 1),
63                       ptr[aux_reg_input + sizeof(float) * inp_off]);
64                 shufps(Xmm(oc_blocks * ur_w + jj + 1),
65                        Xmm(oc_blocks * ur_w + jj + 1), 0x0);
66             }
67
68             for (int ii = 0; ii < oc_blocks; ii++) {
69                 int ker_off = ii * nb_ic * kh * kw * ic_blk * oc_blk
70                               + ki * ic_blk * oc_blk + ifm2 * oc_blk;
71
72                 for (int jj = jj_start; jj < jj_end; jj++)
73                 {
74                     movups(xmm0,
75                       ptr[aux_reg_kernel + sizeof(float) * ker_off]);
76                     mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1));
77                     addps(Xmm(ur_w * ii + jj + 1), xmm0);
78                 }
79             }
80         }
81     }
82 }
83
84 void jit_sse42_conv_fwd_kernel_f32::oh_step_nopad(int ur_w,
85         int pad_l, int pad_r, int oc_blocks)
86 {
87     Label kw_loop;
88
89     int iw = jcp.iw;
90     int ih = jcp.ih;
91     int kw = jcp.kw;
92     int kh = jcp.kh;
93     int nb_ic = jcp.nb_ic;
94     int stride_w = jcp.stride_w;
95     int dilate_w = jcp.dilate_w + 1;
96     int ic_blk = jcp.ic_block;
97     int oc_blk = jcp.oc_block;
98
99     xor_(ki_iter, ki_iter);
100     L(kw_loop);
101     {
102         int jj_start = 0;
103         int jj_end = ur_w;
104         for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
105             for (int jj = jj_start; jj < jj_end; jj++) {
106                 int inp_off;
107                 if (one_of(jcp.src_fmt, ncw, nchw))
108                     inp_off = ifm2 * ih * iw + (jj * stride_w - pad_l);
109                 else
110                     inp_off = (jj * stride_w - pad_l) * ic_blk + ifm2;
111
112                 movss(Xmm(oc_blocks * ur_w + jj + 1),
113                       ptr[aux_reg_input + sizeof(float) * inp_off]);
114                 shufps(Xmm(oc_blocks * ur_w + jj + 1),
115                        Xmm(oc_blocks * ur_w + jj + 1), 0x0);
116             }
117             for (int ii = 0; ii < oc_blocks; ii++) {
118                 int aux_kernel_offset = ii * nb_ic * kh * kw * ic_blk * oc_blk
119                                         + ifm2 * oc_blk;
120                 for (int jj = jj_start; jj < jj_end; jj++) {
121                     movups(xmm0,
122                       ptr[aux_reg_kernel + sizeof(float) * aux_kernel_offset]);
123                     mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1));
124                     addps(Xmm(ur_w * ii + jj + 1), xmm0);
125                 }
126             }
127         }
128         add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk);
129         add(aux_reg_input, sizeof(float) * (one_of(jcp.src_fmt, ncw, nchw) ?
130             dilate_w : ic_blk * dilate_w));
131
132         inc(ki_iter);
133         cmp(ki_iter, kw);
134         jl(kw_loop, T_NEAR);
135     }
136 }
137
138 void jit_sse42_conv_fwd_kernel_f32::width_blk_step(int ur_w,
139         int pad_l, int pad_r, int oc_blocks)
140 {
141     int iw = jcp.iw;
142     int kw = jcp.kw;
143     int ow = jcp.ow;
144     int oh = jcp.oh;
145     int dilate_h = jcp.dilate_h + 1;
146     int dilate_w = jcp.dilate_w + 1;
147     int ic_blk = jcp.ic_block;
148     int oc_blk = jcp.oc_block;
149     const int inp_mult = one_of(jcp.src_fmt, ncw, nchw)
150         ? dilate_h : ic_blk * dilate_h;
151     const int inp_off = one_of(jcp.src_fmt, ncw, nchw)
152         ? dilate_w : ic_blk * dilate_w;
153
154     xor_(simd_iter, simd_iter);
155
156     mov(aux_reg_input, reg_input);
157     mov(aux_reg_kernel, reg_kernel);
158
159     Label init_simd_iter_loop;
160     Label init_done;
161     Label init_first;
162
163     L(init_simd_iter_loop);
164
165     if (!jcp.with_sum) {
166         test(reg_ci_flag, FLAG_IC_FIRST);
167         jne(init_first, T_NEAR);
168     }
169
170     for (int ii = 0; ii < oc_blocks; ii++)
171         for (int jj = 0; jj < ur_w; jj++) {
172             int o_off;
173             if (jcp.with_dw_conv)
174                 o_off = (ii * jcp_dw.kh * ow + jj) * oc_blk;
175             else
176                 o_off = (ii * oh * ow + jj) * oc_blk;
177
178             movups(Xmm(ur_w * ii + jj + 1), xword[reg_output
179                 + sizeof(float) * o_off]);
180         }
181
182     if (jcp.with_sum && jcp.with_bias) {
183         test(reg_ci_flag, FLAG_IC_FIRST);
184         je(init_done, T_NEAR);
185
186         for (int ii = 0; ii < oc_blocks; ii++)
187             for (int jj = 0; jj < ur_w; jj++)
188                 addps(Xmm(ur_w * ii + jj + 1),
189                     xword[reg_bias + sizeof(float) * ii * oc_blk]);
190     }
191
192     jmp(init_done);
193
194     L(init_first);
195     if (this->jcp.with_bias) {
196         for (int ii = 0; ii < oc_blocks; ii++)
197             for (int jj = 0; jj < ur_w; jj++)
198                 movups(Xmm(ur_w * ii + jj + 1),
199                        xword[reg_bias + sizeof(float) * ii * oc_blk]);
200     } else {
201         for (int ii = 0; ii < oc_blocks; ii++)
202             for (int jj = 0; jj < ur_w; jj++)
203                 pxor(Xmm(ur_w * ii + jj + 1), Xmm(ur_w * ii + jj + 1));
204     }
205
206     L(init_done);
207
208     Label skip_kh_loop;
209     mov(kj, reg_kh);
210     if ((jcp.dilate_h >= jcp.ih)
211             || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
212         cmp(kj, 0);
213         je(skip_kh_loop, T_NEAR);
214     }
215     Label kh_loop;
216     L(kh_loop);
217     {
218         if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) {
219             oh_step_nopad(ur_w, pad_l, pad_r, oc_blocks);
220             sub(aux_reg_input, sizeof(float) * kw * inp_off);
221             add(aux_reg_input, sizeof(float) * iw * inp_mult);
222         } else {
223             oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks);
224             add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk);
225             add(aux_reg_input, sizeof(float) * iw * inp_mult);
226         }
227
228         dec(kj);
229         cmp(kj, 0);
230         jg(kh_loop, T_NEAR);
231     }
232
233     L(skip_kh_loop);
234
235     Label done;
236     Label regular_store;
237
238     test(reg_ci_flag, FLAG_IC_LAST);
239     je(regular_store, T_NEAR);
240
241     int eltwise_inj_idx = 0;
242     int depthwise_inj_idx = 0;
243     const auto &p = attr_.post_ops_;
244
245     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
246     for (int i = 0; i < end_idx; i++) {
247         auto& post_op = p.entry_[i];
248         if (post_op.is_eltwise()) {
249             eltwise_injectors[eltwise_inj_idx]->compute_vector_range(1, oc_blocks * ur_w + 1);
250             eltwise_inj_idx++;
251         } else if (post_op.is_depthwise()) {
252             mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
253             mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
254
255             add(reg_d_weights, reg_oc_off);
256             add(reg_d_bias, reg_oc_off);
257
258             for (int ii = 0; ii < oc_blocks; ii++) {
259                 depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
260                         ur_w * ii + 1, ur_w * ii + ur_w + 1, reg_d_weights, reg_d_bias);
261
262                 add(reg_d_weights, oc_blk * sizeof(float));
263                 add(reg_d_bias, oc_blk * sizeof(float));
264             }
265
266             depthwise_inj_idx++;
267         }
268     }
269
270     L(regular_store);
271
272     for (int ii = 0; ii < oc_blocks; ii++) {
273         for (int jj = 0; jj < ur_w; jj++) {
274             int o_off;
275             if (jcp.with_dw_conv)
276                 o_off = (ii * jcp_dw.kh * ow + jj) * oc_blk;
277             else
278                 o_off = (ii * oh * ow + jj) * oc_blk;
279
280             Xmm reg_out = Xmm(ur_w * ii + jj + 1);
281             movups(xword[reg_output + sizeof(float) * o_off], reg_out);
282         }
283     }
284
285     mov(aux_reg_kernel, reg_kernel);
286     mov(aux_reg_input, reg_input);
287     add(aux_reg_kernel, sizeof(float) * 4);
288     add(reg_output, sizeof(float) * 4);
289     add(reg_bias,   sizeof(float) * 4);
290     add(reg_oc_off, sizeof(float) * 4);
291
292     inc(simd_iter);
293     cmp(simd_iter, 2);
294     jl(init_simd_iter_loop, T_NEAR);
295
296     sub(reg_output, sizeof(float) * 8);
297     sub(reg_bias,   sizeof(float) * 8);
298     sub(reg_oc_off, sizeof(float) * 8);
299 }
300
301 inline void jit_sse42_conv_fwd_kernel_f32::solve_common(int oc_blocks)
302 {
303     int ur_w = jcp.ur_w;
304     int ur_w_tail = jcp.ur_w_tail;
305     int n_oi = jcp.ow / ur_w;
306     int iw = jcp.iw;
307     int kw = jcp.kw;
308     int ic_blk = jcp.ic_block;
309     int oc_blk = jcp.oc_block;
310     int dilate_w = jcp.dilate_w + 1;
311     int str_w = jcp.stride_w;
312     const int inp_mult = one_of(jcp.src_fmt, ncw, nchw) ? 1 : ic_blk;
313
314     int l_pad = jcp.l_pad;
315     int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w
316         - (iw + l_pad - 1));
317     int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w
318         - (iw + l_pad - 1);
319     if (r_pad1 > 0) n_oi--;
320
321     if (l_pad > 0) {
322         n_oi--;
323         if (n_oi < 0 && r_pad1 > 0)
324             width_blk_step(ur_w, l_pad, r_pad1, oc_blocks); // "lrpad"
325         else
326             width_blk_step(ur_w, l_pad, 0, oc_blocks); // "lpad"
327         add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult);
328         add(reg_output, sizeof(float) * ur_w * oc_blk);
329     }
330
331     Label ow_loop;
332     xor_(oi_iter, oi_iter);
333
334     if (n_oi > 0) {
335         L(ow_loop);
336
337         width_blk_step(ur_w, 0, 0, oc_blocks); // "middle"
338         add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
339         add(reg_output, sizeof(float) * ur_w * oc_blk);
340
341         inc(oi_iter);
342         cmp(oi_iter, n_oi);
343         jl(ow_loop, T_NEAR);
344     }
345
346     if (r_pad1 > 0 && n_oi >=0) {
347         width_blk_step(ur_w, 0, r_pad1, oc_blocks); // "rpad"
348         add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
349         add(reg_output, sizeof(float) * ur_w * oc_blk);
350     }
351
352     if (ur_w_tail != 0)
353         width_blk_step(ur_w_tail, 0, r_pad, oc_blocks); // "tail"
354 }
355
356 void jit_sse42_conv_fwd_kernel_f32::generate()
357 {
358     const auto &p = attr_.post_ops_;
359     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
360     for (int i = 0; i < end_idx; i++) {
361         auto &post_op = p.entry_[i];
362         if (post_op.is_eltwise()) {
363             eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<sse42>(
364                     this,
365                     post_op.eltwise.alg,
366                     post_op.eltwise.alpha,
367                     post_op.eltwise.beta
368             ));
369         } else if (post_op.is_depthwise()) {
370             depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<sse42>(
371                     this,
372                     post_op.depthwise.alg
373             ));
374         }
375     }
376
377     this->preamble();
378
379     mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
380     mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
381     mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
382     if (jcp.with_bias)
383         mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
384     mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
385     mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
386     mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]);
387     mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
388
389     int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking;
390     Label tail, exit;
391
392     cmp(reg_oc_blocks, jcp.nb_oc_blocking);
393     jne(nb_oc_tail ? tail : exit, T_NEAR);
394
395     solve_common(jcp.nb_oc_blocking);
396     jmp(exit, T_NEAR);
397
398     if (nb_oc_tail) {
399         L(tail);
400         cmp(reg_oc_blocks, nb_oc_tail);
401         jne(exit, T_NEAR);
402         solve_common(nb_oc_tail);
403     }
404
405     L(exit);
406
407     this->postamble();
408
409     for (auto& inj : eltwise_injectors)
410         inj->prepare_table();
411 }
412
413 bool jit_sse42_conv_fwd_kernel_f32::post_ops_ok(
414         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
415     const auto &p = attr.post_ops_;
416
417     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
418     auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
419     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
420     auto is_dw_conv = [&](int idx) { return p.entry_[idx].is_dw_conv(); };
421     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
422
423     switch (p.len_) {
424         case 0: return true;
425         case 1: return is_simple(0) || is_sum(0) || is_dw_conv(0);
426         case 2: return (is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
427                        (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
428                        (is_simple(0) && is_simple(1));
429         case 3: return (is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
430                        (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
431                        (is_sum(0) && is_simple(1) && is_simple(2));
432         case 4: return (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
433         default: return false;
434     }
435
436     return false;
437 }
438
439 status_t jit_sse42_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
440         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
441         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
442         const primitive_attr_t &attr)
443 {
444     if (!mayiuse(sse42)) return status::unimplemented;
445
446     jcp.prop_kind = cd.prop_kind;
447
448     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
449     const int ndims = src_d.ndims();
450     jcp.ndims = ndims;
451
452     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
453     jcp.mb = src_d.dims()[0];
454
455     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
456     jcp.oc_without_padding = jcp.oc;
457     jcp.ic = src_d.dims()[1] / jcp.ngroups;
458
459     jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
460     jcp.iw = src_d.dims()[ndims - 1];
461     jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
462     jcp.ow = dst_d.dims()[ndims - 1];
463
464     jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
465     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
466
467     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
468     jcp.l_pad = cd.padding[0][ndims - 3];
469
470     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
471     jcp.stride_w = cd.strides[ndims - 3];
472
473     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[0];
474     jcp.dilate_w = cd.dilates[ndims - 3];
475     jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
476             - (jcp.ih + jcp.t_pad - 1);
477
478     jcp.src_fmt = src_d.format();
479     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
480
481     if (!post_ops_ok(jcp, attr))
482         return status::unimplemented;
483
484     const auto &p = attr.post_ops_;
485
486     int dw_conv_ind = p.find(primitive_kind::convolution);
487     jcp.with_dw_conv = dw_conv_ind != -1;
488     if (jcp.with_dw_conv) {
489         jcp.dw_conv_oh = jcp.oh;
490         jcp.dw_conv_ow = jcp.ow;
491         jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
492         jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
493     }
494
495     jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
496
497     jcp.src_dt = cd.src_desc.data_type;
498     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
499     jcp.dst_dt = cd.dst_desc.data_type;
500
501     const bool flat = jcp.ic == 3 || jcp.ic == 1;
502     const bool mimo = !flat;
503
504     bool args_ok = true
505         && IMPLICATION(flat, one_of(src_d.format(), ncw, nwc, nchw, nhwc)
506                 && one_of(weights_d.format(), Owi8o, gOwi8o, Ohwi8o, gOhwi8o))
507         && IMPLICATION(mimo, one_of(src_d.format(), nCw8c, nChw8c)
508                 && one_of(weights_d.format(), OIw8i8o, gOIw8i8o, OIhw8i8o,
509                     gOIhw8i8o))
510         && one_of(cd.bias_desc.format, memory_format::undef, any, x)
511         && one_of(dst_d.format(), nCw8c, nChw8c);
512     if (!args_ok) return status::unimplemented;
513
514     bool ok_to_pad_channels = true
515                               && jcp.ngroups == 1;
516
517     const int simd_w = 8; // 2 SSE vectors processing at once
518     if (ok_to_pad_channels) {
519         jcp.oc = rnd_up(jcp.oc, simd_w);
520         if (mimo)
521             jcp.ic = rnd_up(jcp.ic, simd_w);
522     }
523
524     jcp.ur_h = 1; /* no code-unrolling by h so far */
525     jcp.ur_w = 3;
526     if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
527     jcp.ur_w_tail = jcp.ow % jcp.ur_w;
528
529     jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */
530
531     args_ok = true
532         && jcp.oc % simd_w == 0
533         && jcp.l_pad <= jcp.ur_w
534         && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0)
535                 || (jcp.stride_w == 1 && jcp.stride_h == 1))
536         && IMPLICATION(mimo, jcp.ic % simd_w == 0);
537     if (!args_ok) return status::unimplemented;
538
539     int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
540         + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
541
542     // kernel needs 1 temporary YMM register
543     const int num_avail_regs = 15;
544     if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) {
545         /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
546         jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail,
547                 nstl::min(jcp.ow, num_avail_regs / 2));
548         jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w;
549         jcp.ur_w_tail = jcp.ow % jcp.ur_w;
550         /* check again ... */
551         r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
552             + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
553         if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail))
554             return status::unimplemented;
555     }
556     assert(jcp.nb_oc_blocking > 0);
557     assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs);
558
559     jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w;
560     jcp.nb_ic = jcp.ic / jcp.ic_block;
561
562     jcp.oc_block = simd_w;
563     jcp.nb_oc = jcp.oc / jcp.oc_block;
564
565     if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
566         jcp.nb_ic_blocking = 12;
567         jcp.nb_ic_blocking_max = 16;
568     } else {
569         jcp.nb_ic_blocking = 1;
570         jcp.nb_ic_blocking_max = jcp.nb_ic_blocking;
571     }
572
573     return status::success;
574 }
575
576 void jit_sse42_conv_fwd_kernel_f32::init_scratchpad(
577         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw) {
578     if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
579         scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
580
581     if (jcp.with_dw_conv) {
582         const int nthreads = mkldnn_get_max_threads();
583         size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking;
584         scratchpad.book(key_dw_conv_buffer, sizeof(float) * dw_conv_buffer_size_ * nthreads);
585
586         if (jcp.oc != jcp.oc_without_padding)
587             scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc);
588     }
589 }
590
591 }
592 }
593 }