Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm_convolution_utils.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_types.h"
18
19 #include "c_types_map.hpp"
20 #include "type_helpers.hpp"
21 #include "mkldnn_thread.hpp"
22 #include "utils.hpp"
23 #include "cpu_isa_traits.hpp"
24
25 #include "gemm_convolution_utils.hpp"
26 #include "jit_generator.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 using namespace mkldnn::impl::status;
33 using namespace mkldnn::impl::memory_format;
34 using namespace mkldnn::impl::utils;
35 using namespace prop_kind;
36 using namespace data_type;
37
38 namespace jit_gemm_convolution_utils {
39
40 void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col,
41         int od)
42 {
43     const size_t OHW = jcp.oh * jcp.ow;
44     const size_t im_step = jcp.ih * jcp.iw * jcp.id;
45     const size_t col_step = jcp.ks * OHW;
46
47     parallel_nd(jcp.ic, [&](int ic) {
48         const float *__restrict im_loc = im + ic * im_step;
49         float *__restrict col_loc = col + ic * col_step;
50         int id = od * jcp.stride_d - jcp.f_pad;
51         for (int kd = 0; kd < jcp.kd; ++kd) {
52             float *__restrict col_ = col_loc + kd * jcp.kh * jcp.kw * OHW;
53             if (id < 0 || id >= jcp.id) {
54                 int ih_ = -jcp.t_pad;
55                 for (int kh = 0; kh < jcp.kh; ++kh) {
56                     int ih = ih_;
57                     for (int oh = 0; oh < jcp.oh; ++oh) {
58                         if (ih < 0 || ih >= jcp.ih) {
59                             ih += jcp.stride_h;
60                             continue;
61                         }
62                         int iw_ = -jcp.l_pad;
63                         for (int kw = 0; kw < jcp.kw; ++kw) {
64                             int iw = iw_;
65                             for (int ow = 0; ow < jcp.ow; ++ow) {
66                                 if (iw < 0 || iw >= jcp.iw) {
67                                     iw += jcp.stride_w;
68                                     continue;
69                                 }
70
71                                 const size_t col_idx = kw * OHW + oh * jcp.ow
72                                     + ow;
73
74                                 col_[col_idx] = 0;
75                                 iw += jcp.stride_w;
76                             }
77                             iw_ += (1 + jcp.dilate_w);
78                         }
79                         ih += jcp.stride_h;
80                     }
81                     ih_ += (1 + jcp.dilate_h);
82                     col_ += jcp.kw * OHW;
83                 }
84             } else {
85                 const float *__restrict im_ = im_loc + id * jcp.ih * jcp.iw;
86                 int ih_ = -jcp.t_pad;
87                 for (int kh = 0; kh < jcp.kh; ++kh) {
88                     int ih = ih_;
89                     for (int oh = 0; oh < jcp.oh; ++oh) {
90                         if (ih < 0 || ih >= jcp.ih) {
91                             ih += jcp.stride_h;
92                             continue;
93                         }
94                         int iw_ = -jcp.l_pad;
95                         for (int kw = 0; kw < jcp.kw; ++kw) {
96                             int iw = iw_;
97                             for (int ow = 0; ow < jcp.ow; ++ow) {
98                                 if (iw < 0 || iw >= jcp.iw) {
99                                     iw += jcp.stride_w;
100                                     continue;
101                                 }
102
103                                 const size_t col_idx = kw * OHW + oh * jcp.ow
104                                     + ow;
105                                 const size_t im_idx = ih * jcp.iw + iw;
106
107                                 col_[col_idx] = im_[im_idx];
108                                 iw += jcp.stride_w;
109                             }
110                             iw_ += (1 + jcp.dilate_w);
111                         }
112                         ih += jcp.stride_h;
113                     }
114                     ih_ += (1 + jcp.dilate_h);
115                     col_ += jcp.kw * OHW;
116                 }
117             }
118             id += (1 + jcp.dilate_d);
119         }
120     });
121 }
122
123 /* col[ic][kh][kw][oh][ow] <-- im2col(im[ic][ih][iw]) */
124 void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im,
125        float *__restrict col, int hs, int hb, int ws, int wb) {
126     const size_t im_step = jcp.is;
127     const size_t col_step = jcp.ks * hb * wb;
128     if (jcp.stride_w == 1) {
129         // Generated code is more optimized for stride_w == 1
130         // because innermost loop is by width
131         auto ker = [&](int ic, int kh, int kw, int oh) {
132             const float *__restrict im_ = im + ic * im_step;
133             float *__restrict col_
134                 = col + ic * col_step + ((kh * jcp.kw + kw) * hb + oh) * wb;
135
136             const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
137                 + kh * (1 + jcp.dilate_h);
138             if (ih < 0 || ih >= jcp.ih) {
139                 for (int ow = 0; ow < wb; ++ow)
140                     col_[ow] = 0.f;
141             } else {
142                 for (int ow = 0; ow < wb; ++ow) {
143                     const int iw = ow + ws - jcp.l_pad + kw * (1 + jcp.dilate_w);
144                     if (iw < 0 || iw >= jcp.iw)
145                         col_[ow] = 0.f;
146                     else {
147                         const size_t im_idx = ih * jcp.iw + iw;
148                         col_[ow] = im_[im_idx];
149                     }
150                 }
151             }
152         };
153
154         if (jcp.outer_threading) {
155             for (int ic = 0; ic < jcp.ic; ic++)
156                 for (int kh = 0; kh < jcp.kh; kh++)
157                     for (int kw = 0; kw < jcp.kw; kw++)
158                         for (int oh = 0; oh < hb; oh++)
159                             ker(ic, kh, kw, oh);
160         }
161         else {
162             parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, ker);
163         }
164     } else if (jcp.ic == 1) {
165         parallel_nd(jcp.kh, hb, [&](int kh, int oh) {
166             const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
167                     + kh * (1 + jcp.dilate_h);
168             if (ih < 0 || ih >= jcp.ih)
169                 for (int kw = 0; kw < jcp.kw; ++kw) {
170                     for (int ow = 0; ow < wb; ++ow) {
171                         const size_t col_idx
172                                 = ((kh * jcp.kw + kw) * hb + oh) * wb + ow;
173                         col[col_idx] = 0;
174                     }
175                 }
176             else
177                 for (int kw = 0; kw < jcp.kw; ++kw) {
178                     for (int ow = 0; ow < wb; ++ow) {
179                         const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad
180                                 + kw * (1 + jcp.dilate_w);
181                         const size_t col_idx
182                                 = ((kh * jcp.kw + kw) * hb + oh) * wb + ow;
183                         const size_t im_idx = ih * jcp.iw + iw;
184                         if (iw < 0 || iw >= jcp.iw)
185                             col[col_idx] = 0;
186                         else
187                             col[col_idx] = im[im_idx];
188                     }
189                 }
190         });
191     } else {
192
193         parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb,
194             [&](int ic, int kh, int kw, int oh) {
195             const float *__restrict im_ = im + ic * im_step;
196             float *__restrict col_ = col + ic * col_step
197                 + ((kh * jcp.kw + kw) * hb + oh) * wb;
198
199             const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
200                 + kh * (1 + jcp.dilate_h);
201             if (ih < 0 || ih >= jcp.ih) {
202                 for (int ow = 0; ow < wb; ++ow)
203                     col_[ow] = 0.f;
204             } else {
205                 for (int ow = 0; ow < wb; ++ow) {
206                     const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad
207                         + kw * (1 + jcp.dilate_w);
208                     const size_t im_idx = ih * jcp.iw + iw;
209                     if (iw < 0 || iw >= jcp.iw)
210                         col_[ow] = 0.f;
211                     else
212                         col_[ow] = im_[im_idx];
213                 }
214             }
215         });
216     }
217 }
218
219 /* col[oh][ow][kh][kw][ic] <-- im2col_u8(im[ih][iw][ic]) */
220 template <typename T>
221 void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im,
222         uint8_t *__restrict col) {
223     uint8_t shift = jcp.signed_input ? 128 : 0;
224     const int dh = 1 + jcp.dilate_h;
225     const int dw = 1 + jcp.dilate_w;
226     const int sh = jcp.stride_h;
227     const int sw = jcp.stride_w;
228     if (sh == 1 && sw == 1 && jcp.oh > 2 * mkldnn_get_max_threads()) {
229         const int ihp = jcp.ih + jcp.t_pad;
230         const int iwp = jcp.iw + jcp.l_pad;
231         const int col_kw_step = jcp.ic;
232         const int col_kh_step = jcp.kw * col_kw_step;
233         const int col_ow_step = jcp.kh * col_kh_step;
234         const int col_oh_step = jcp.ow * col_ow_step;
235         const int im_iw_step = jcp.ngroups * jcp.ic;
236         const int im_ih_step = jcp.iw * im_iw_step;
237
238         const int nb_ic = jcp.ic / 4;
239         const int ic_blocked = nb_ic * 4;
240
241         parallel_nd(jcp.oh, [&](int oh) {
242             const int kh_start = nstl::max(div_up(jcp.t_pad - oh, dh), 0);
243             const int kh_end = nstl::min(div_up(ihp - oh, dh), jcp.kh);
244             const int ih_start = oh - jcp.t_pad + kh_start * dh;
245             const int col_oh_idx = oh * col_oh_step;
246
247             for (int kh = kh_start, ih = ih_start; kh < kh_end; ++kh, ih += dh)
248             {
249                 const int col_kh_idx = col_oh_idx + kh * col_kh_step;
250                 const int im_kh_idx = ih * im_ih_step;
251
252                 for (int kw = 0; kw < jcp.kw; ++kw) {
253                     const int ow_start = nstl::max(jcp.l_pad - kw * dw, 0);
254                     const int ow_end = nstl::min(iwp - kw * dw, jcp.ow);
255                     const int iw_start = ow_start - jcp.l_pad + kw * dw;
256                     const int col_kw_idx = col_kh_idx + kw * col_kw_step;
257
258                     const int col_idx_start
259                             = col_kw_idx + ow_start * col_ow_step;
260                     const int im_idx_start = im_kh_idx + iw_start * im_iw_step;
261                     const int col_idx_end = col_kw_idx + ow_end * col_ow_step;
262
263                     // loop by iw and ow
264                     if (nb_ic > 0) {
265                         for (int col_idx = col_idx_start, im_idx = im_idx_start;
266                                 col_idx < col_idx_end;
267                                 col_idx += col_ow_step, im_idx += im_iw_step) {
268                             for (int icb = 0; icb < 4 * nb_ic; icb += 4) {
269                                 PRAGMA_OMP_SIMD()
270                                 for (int ic = 0; ic < 4; ++ic) {
271                                     col[col_idx + icb + ic]
272                                             = im[im_idx + icb + ic] + shift;
273                                 }
274                             }
275                         }
276                     }
277                     if (ic_blocked != jcp.ic) {
278                         for (int col_idx = col_idx_start, im_idx = im_idx_start;
279                                 col_idx < col_idx_end;
280                                 col_idx += col_ow_step, im_idx += im_iw_step) {
281                             PRAGMA_OMP_SIMD()
282                             for (int ic = ic_blocked; ic < jcp.ic; ++ic) {
283                                 col[col_idx + ic] = im[im_idx + ic] + shift;
284                             }
285                         }
286                     }
287                 }
288             }
289         });
290     }
291     else {
292         const size_t col_kh_step = jcp.kw * jcp.ic;
293         const size_t col_ow_step = jcp.kh * col_kh_step;
294         const size_t col_oh_step = jcp.ow * col_ow_step;
295         const size_t im_ih_step = jcp.iw * jcp.ngroups * jcp.ic;
296         const size_t im_iw_step = jcp.ngroups * jcp.ic;
297         const int ih_pad = jcp.ih + jcp.t_pad;
298         const int iw_pad = jcp.iw + jcp.l_pad;
299         parallel_nd(jcp.oh, jcp.ow, [&](int oh, int ow) {
300             const int ihs = oh * sh;
301             const int ihsp = jcp.t_pad - ihs;
302             const int kh_start = nstl::max(div_up(ihsp, dh), 0);
303             const int kh_end = nstl::min(div_up(ih_pad - ihs, dh), jcp.kh);
304             const int ih_start = kh_start * dh - ihsp;
305             const int iws = ow * sw;
306             const int iwsp = jcp.l_pad - iws;
307             const int kw_start = nstl::max(div_up(iwsp, dw), 0);
308             const int kw_end = nstl::min(div_up(iw_pad - iws, dw), jcp.kw);
309             const int iw_start = kw_start * dw - iwsp;
310
311             uint8_t *__restrict col_base
312                     = col + oh * col_oh_step + ow * col_ow_step;
313             for (int kh = kh_start, ih = ih_start; kh < kh_end;
314                     ++kh, ih += dh) {
315                 uint8_t *__restrict col_ = col_base + kh * col_kh_step;
316                 const T *__restrict im_ = im + ih * im_ih_step;
317
318                 for (int kw = kw_start, iw = iw_start; kw < kw_end;
319                     ++kw, iw += dw) {
320
321                     const size_t col_idx = kw * jcp.ic;
322                     const size_t im_idx = iw * im_iw_step;
323                     PRAGMA_OMP_SIMD()
324                         for (int ic = 0; ic < jcp.ic; ++ic) {
325                             col_[col_idx + ic] = im_[im_idx + ic] + shift;
326                         }
327                 }
328             }
329         });
330     }
331
332 }
333
334 template void im2col_u8<int8_t>(const jit_gemm_conv_conf_t &jcp,
335         const int8_t *__restrict im, uint8_t *__restrict col);
336 template void im2col_u8<uint8_t>(const jit_gemm_conv_conf_t &jcp,
337         const uint8_t *__restrict im, uint8_t *__restrict col);
338
339 /* im[ih][iw][ic] <-- col2im_s32(col[oh][ow][kh][kw][ic]) */
340 void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col,
341         int32_t *__restrict im)
342 {
343     parallel(0, [&](const int ithr, const int nthr) {
344         int h_nthr = nstl::min(jcp.ih, nthr);
345         int w_nthr = nstl::min(jcp.iw, nthr / h_nthr);
346         int h_ithr = 1, h_s = 0, h_e = 0, w_ithr = 1, w_s = 0, w_e = 0;
347         if (ithr < h_nthr * w_nthr) {
348             h_ithr = ithr / w_nthr;
349             w_ithr = ithr % w_nthr;
350             balance211(jcp.ih, h_nthr, h_ithr, h_s, h_e);
351             balance211(jcp.iw, w_nthr, w_ithr, w_s, w_e);
352         } else {
353             h_ithr = w_ithr = -ithr;
354             h_s = h_e = w_s = w_e = -1;
355         }
356
357         for (int ih = h_s; ih < h_e; ++ih) {
358             for (int iw = w_s; iw < w_e; ++iw) {
359                 PRAGMA_OMP_SIMD()
360                 for (int ic = 0; ic < jcp.ic; ++ic) {
361                     im[(ih * jcp.iw + iw) * jcp.ic + ic] = 0;
362                 }
363             }
364         }
365
366         // TODO: reduce region: [0.. oh] --> [h_s * sh .. h_e * sh]
367         for (int oh = 0; oh < jcp.oh; ++oh) {
368             for (int ow = 0; ow < jcp.ow; ++ow) {
369                 for (int kh = 0; kh < jcp.kh; ++kh) {
370                     const int ih = oh * jcp.stride_h
371                         - jcp.t_pad + kh * (1 + jcp.dilate_h);
372                     if (ih < h_s || ih >= h_e) continue;
373
374                     for (int kw = 0; kw < jcp.kw; ++kw) {
375                         const int iw = ow * jcp.stride_w
376                             - jcp.l_pad + kw * (1 + jcp.dilate_w);
377                         if (iw < w_s || iw >= w_e) continue;
378
379                         const size_t col_idx = (((oh * jcp.ow + ow) * jcp.kh
380                                 + kh) * jcp.kw + kw) * jcp.ic;
381                         const size_t im_idx
382                             = (ih * jcp.iw + iw) * jcp.ic;
383                         PRAGMA_OMP_SIMD()
384                         for (int ic = 0; ic < jcp.ic; ++ic) {
385                             im[im_idx + ic] += col[col_idx + ic];
386                         }
387                     }
388                 }
389             }
390         }
391     });
392 }
393
394 void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im,
395         int od)
396 {
397     parallel_nd(jcp.ic, [&](int ic) {
398         const float *__restrict col_ = col + (size_t)ic * jcp.ks * jcp.os;
399         float *__restrict im_ic = im + (size_t)ic * jcp.ih * jcp.iw * jcp.id;
400
401         int id = od * jcp.stride_d - jcp.f_pad;
402         for (int kd = 0; kd < jcp.kd; ++kd) {
403             if (id < 0 || id >= jcp.id) {
404                 col_ += jcp.kh * jcp.kw * jcp.os;
405                 id += (1 + jcp.dilate_d);
406                 continue;
407             }
408
409             float *__restrict im_ = im_ic + id * jcp.ih * jcp.iw;
410
411             for (int oh = 0; oh < jcp.oh; ++oh) {
412             for (int kh = 0; kh < jcp.kh; ++kh) {
413                 const int ih = oh * jcp.stride_h - jcp.t_pad
414                     + kh * (1 + jcp.dilate_h);
415                 if (ih < 0 || ih >= jcp.ih) continue;
416
417                 for (int ow = 0; ow < jcp.ow; ++ow) {
418                 for (int kw = 0; kw < jcp.kw; ++kw) {
419                     const int iw = ow * jcp.stride_w - jcp.l_pad
420                         + kw * (1 + jcp.dilate_w);
421                     if (iw < 0 || iw >= jcp.iw) continue;
422
423                     const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow;
424                     const size_t im_idx = ih*jcp.iw + iw;
425                     im_[im_idx] += col_[col_idx];
426                 }}
427             }}
428
429             col_ += jcp.kh * jcp.kw * jcp.os;
430             id += (1 + jcp.dilate_d);
431         }
432     });
433 }
434
435 void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im) {
436     const size_t col_step = jcp.ks * jcp.os;
437     const size_t im_step = jcp.ih * jcp.iw;
438     const int iS = jcp.ih * jcp.iw;
439
440     parallel_nd(jcp.ic, [&](int ic) {
441         float *__restrict im_ = im + ic * im_step;
442         const float *__restrict col_ = col + ic * col_step;
443         PRAGMA_OMP_SIMD()
444         for (int is = 0; is < iS; ++is) im_[is] = 0.;
445
446         for (int kh = 0; kh < jcp.kh; ++kh) {
447         for (int oh = 0; oh < jcp.oh; ++oh) {
448             const int ih = oh * jcp.stride_h - jcp.t_pad + kh * (1 + jcp.dilate_h);
449             if (ih < 0 || ih >= jcp.ih) continue;
450
451             for (int kw = 0; kw < jcp.kw; ++kw) {
452             for (int ow = 0; ow < jcp.ow; ++ow) {
453                 const int iw = ow * jcp.stride_w - jcp.l_pad + kw * (1 + jcp.dilate_w);
454                 if (iw < 0 || iw >= jcp.iw) continue;
455
456                 const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow;
457                 const size_t im_idx = ih*jcp.iw + iw;
458                 im_[im_idx] += col_[col_idx];
459             }
460             }
461         }
462         }
463     });
464 }
465
466 status_t init_conf(jit_gemm_conv_conf_t &jcp,
467         memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd,
468         const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
469         const memory_desc_wrapper &dst_d, int max_threads) {
470     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
471     const int ndims = src_d.ndims();
472     const int is_1d = ndims == 3;
473     const int is_3d = ndims == 5;
474
475     jcp.prop_kind = cd.prop_kind;
476
477     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
478     jcp.mb = src_d.dims()[0];
479
480     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
481     jcp.ic = src_d.dims()[1] / jcp.ngroups;
482     jcp.id = is_3d ? src_d.dims()[2] : 1;
483     jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
484     jcp.iw = src_d.dims()[ndims - 1];
485     jcp.od = is_3d ? dst_d.dims()[2] : 1;
486     jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
487     jcp.ow = dst_d.dims()[ndims - 1];
488
489     jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
490     jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
491     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
492
493     jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
494     jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
495     jcp.l_pad = cd.padding[0][ndims - 3];
496
497     jcp.stride_d = is_3d ? cd.strides[0] : 1;
498     jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
499     jcp.stride_w = cd.strides[ndims - 3];
500
501     jcp.dilate_d = is_3d ? cd.dilates[0] : 0;
502     jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
503     jcp.dilate_w = cd.dilates[ndims - 3];
504
505     jcp.src_fmt = src_d.format();
506     jcp.with_bias = cd.bias_desc.format != memory_format::undef
507         || cd.diff_bias_desc.format != memory_format::undef;
508
509     jcp.is = jcp.ih * jcp.iw;
510     jcp.os = jcp.oh * jcp.ow;
511     jcp.ks = jcp.kh * jcp.kw * jcp.kd;
512
513     jcp.signed_input = src_d.data_type() == data_type::s8;
514     jcp.wei_adj_scale =
515         !jcp.signed_input || mayiuse(avx512_core_vnni) ? 1.f : 0.5f;
516
517     jcp.im2col_sz = !everyone_is(true,
518             jcp.ow == jcp.iw, jcp.oh == jcp.ih, jcp.od == jcp.id,
519             jcp.stride_w == 1, jcp.stride_h == 1, jcp.stride_d == 1,
520             jcp.ks == 1, !jcp.signed_input)
521         ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os : 0;
522
523     jcp.outer_threading = false;
524     jcp.oh_block = jcp.oh;
525     jcp.ow_block = jcp.ow;
526
527     bool is_int8_conv = utils::one_of(src_d.data_type(), s32, s8, u8)
528         && weights_d.data_type() == s8;
529
530     const int vlen = mayiuse(avx512_common)
531         ? cpu_isa_traits<avx512_common>::vlen
532         : mayiuse(avx)
533             ? cpu_isa_traits<avx>::vlen
534             : mayiuse(sse42) ? cpu_isa_traits<sse42>::vlen : 4;
535     const int simd_w = vlen / (is_int8_conv ? 1 : 4);
536
537     const bool is_bwd_d = jcp.prop_kind == backward_data;
538     const bool is_bwd_w = jcp.prop_kind == backward_weights;
539     const bool is_fwd = !is_bwd_d && !is_bwd_w;
540
541     using namespace memory_tracking::names;
542     //  For threading selection we do:
543     //  1. Rough estimation of efficiency for inner and outer threading.
544     //  2. Gemm size estimation in assumption that it does not work
545     //  so effectively for small sizes.
546     //  64K - this is heuristic gemm size per thread threshold.
547     const int gemm_threshold = 64 * 1024;
548     if (is_int8_conv) {
549         bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
550
551         const int bs = is_fwd ? jcp.os : jcp.is;
552         const int ls = is_fwd ? jcp.oc : jcp.ic;
553         const size_t outer_work_amount = jcp.ngroups * jcp.mb;
554         const float outer_thr_eff = (float)outer_work_amount
555                 / rnd_up(outer_work_amount, max_threads);
556         const size_t inner_work_amount
557                 = div_up(bs, simd_w) * div_up(ls, simd_w);
558         const float inner_thr_eff = (float)inner_work_amount
559                 / rnd_up(inner_work_amount, max_threads);
560         jcp.outer_threading = (is_depthwise
561                 || (bs  / max_threads < 64 && jcp.mb != 1))
562             && (outer_thr_eff / inner_thr_eff >= 1.f
563                    || (bs * jcp.ic * jcp.oc) / max_threads < gemm_threshold);
564         jcp.nthr = jcp.outer_threading ? max_threads : 1;
565
566         if (is_fwd) {
567             scratchpad.book(key_conv_gemm_col,
568                     sizeof(int8_t) * jcp.nthr * jcp.im2col_sz);
569             scratchpad.book(key_conv_int_dat_in_acc_dt,
570                     sizeof(int32_t) * jcp.nthr * jcp.os * jcp.oc);
571         } else if (is_bwd_d) {
572             scratchpad.book(key_conv_gemm_col,
573                     sizeof(int32_t) * jcp.nthr * jcp.im2col_sz);
574             scratchpad.book(key_conv_int_dat_in_acc_dt,
575                     sizeof(int32_t) * jcp.nthr * jcp.is * jcp.ic);
576         } else if (is_bwd_w) {
577             assert(!"unimplemented prop_kind");
578             return status::unimplemented;
579         }
580     } else {
581         if (is_fwd) {
582             const int L2 = get_cache_size(2, true) / sizeof(float);
583             const int wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
584
585             // It makes sense to try blocking for some special cases:
586             // when weights size is small and we have to do im2col
587             if (wei_size < L2/2 && jcp.im2col_sz && jcp.id == 1 && jcp.od == 1) {
588                 // looking for oh and ow blocking
589                 int h_block{ jcp.oh }, w_block{ jcp.ow };
590                 // 1. cache requirement
591                 // !!! used memory (assuming strides = 1 and dilate = 0 etc):
592                 const int row_size = jcp.ic * jcp.kh * jcp.kw * jcp.ow
593                     + 2 * jcp.ic * jcp.iw + 2 * jcp.oc * jcp.ow;
594                 h_block = nstl::max(
595                     1, nstl::min(jcp.oh, div_up(L2 - wei_size, row_size)));
596                 if (h_block == 1) {
597                     const int col_size = jcp.ic * jcp.kh * jcp.kw + 2 * jcp.ic
598                         + 2 * jcp.oc;
599                     w_block = nstl::max(
600                         1, nstl::min(jcp.ow, div_up(L2 - wei_size, col_size)));
601                 }
602
603                 // 2. threading requirement
604                 if (h_block != jcp.oh)
605                     h_block = nstl::max(1, rnd_dn(h_block, 4));
606                 if (w_block != jcp.ow)
607                     w_block = nstl::max(1, rnd_dn(w_block, simd_w));
608
609                 float thr_eff = 0.f;
610                 float thr_eff_treshold = 0.9f;
611                 if (w_block == jcp.ow) {
612                     do {
613                         int nb_oh = div_up(jcp.oh, h_block);
614                         size_t work = jcp.ngroups * jcp.mb * jcp.od * nb_oh;
615                         float disb = (float)jcp.oh / rnd_up(jcp.oh, h_block);
616                         thr_eff = (float)work
617                             / rnd_up(work, max_threads);
618                         thr_eff = (thr_eff + disb) / 2.f;
619                         if (thr_eff >= thr_eff_treshold)
620                             break;
621                         h_block = rnd_dn(h_block - 4, 4);
622                     } while (h_block > 0);
623                 }
624                 if (thr_eff < thr_eff_treshold) // we didn't find suitable h_block
625                 {
626                     h_block = 1;
627                     int nb_oh = jcp.oh;
628                     do {
629                         int nb_ow = div_up(jcp.ow, w_block);
630                         size_t work_amount
631                             = jcp.ngroups * jcp.mb * jcp.od * nb_oh * nb_ow;
632                         float disb = (float)jcp.ow / rnd_up(jcp.ow, w_block);
633                         thr_eff = (float)work_amount
634                             / rnd_up(work_amount, max_threads);
635                         thr_eff = (thr_eff + disb) / 2.f;
636                         if (thr_eff > thr_eff_treshold)
637                             break;
638                         w_block = rnd_dn(w_block - simd_w, simd_w);
639                     } while (w_block > 0);
640                 }
641                 const size_t inner_work_amount
642                     = div_up(jcp.os, simd_w) * div_up(jcp.oc, simd_w);
643                 const float inner_thr_eff = (float)inner_work_amount
644                     / rnd_up(inner_work_amount, max_threads);
645                 if (thr_eff >= inner_thr_eff / 2 && h_block > 0 && w_block > 0) {
646                     jcp.oh_block = h_block;
647                     jcp.ow_block = w_block;
648                     jcp.outer_threading = true;
649                 }
650                 // updating jcp.im2col_sz
651                 if (jcp.oh_block != 1)
652                     jcp.ow_block = jcp.ow;
653                 jcp.im2col_sz
654                     = (ptrdiff_t)jcp.ic * jcp.ks * jcp.oh_block * jcp.ow_block;
655             } else {
656                 const size_t outer_work_amount = jcp.ngroups * jcp.mb * jcp.od;
657                 const float outer_thr_eff = (float)outer_work_amount
658                         / rnd_up(outer_work_amount, max_threads);
659                 const size_t inner_work_amount
660                         = div_up(jcp.os, simd_w) * div_up(jcp.oc, simd_w);
661                 const float inner_thr_eff = (float)inner_work_amount
662                         / rnd_up(inner_work_amount, max_threads);
663                 jcp.outer_threading = jcp.os / max_threads < 512
664                     && IMPLICATION(jcp.od == 1, jcp.mb != 1 || jcp.ngroups > 2)
665                     && (outer_thr_eff / inner_thr_eff >= 1.f
666                       || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_threshold);
667             }
668         } else if (is_bwd_d) {
669             const size_t outer_work_amount = jcp.ngroups * jcp.mb;
670             const float outer_thr_eff = (float)outer_work_amount
671                 / rnd_up(outer_work_amount, max_threads);
672             const size_t inner_work_amount
673                 = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
674             const float inner_thr_eff = (float)inner_work_amount
675                 / rnd_up(inner_work_amount, max_threads);
676             jcp.outer_threading = (jcp.os / max_threads < 512 || jcp.ks < 64)
677                 && (jcp.mb != 1 || jcp.ngroups > 2)
678                 && (outer_thr_eff / inner_thr_eff >= 1.f
679                   || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_threshold);
680         } else if (is_bwd_w)
681             jcp.outer_threading = jcp.os / max_threads < 256
682                 && (jcp.mb != 1 || jcp.ngroups > 2);
683
684         jcp.nthr = jcp.outer_threading ? max_threads : 1;
685
686         scratchpad.book(key_conv_gemm_col,
687                 sizeof(float) * jcp.nthr * jcp.im2col_sz);
688
689         if (is_bwd_w) {
690             jcp.need_wei_reduction = mkldnn_thr_syncable()
691                 ? jcp.mb != 1 && jcp.nthr != 1 : false;
692
693             scratchpad.book(key_conv_wei_reduction,
694                     sizeof(float) * jcp.nthr * jcp.ngroups * weights_d.size());
695         }
696     }
697
698     return status::success;
699 }
700
701 void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g,
702         int &nthr_g, int &ithr_mb, int &nthr_mb) {
703     nthr_g = nstl::min(ngroups, nthr);
704     nthr_mb = nstl::min(mb, nthr / nthr_g);
705     if (ithr / nthr_mb >= ngroups) {
706         ithr_g = ithr_mb = -1;
707     } else {
708         ithr_g = ithr / nthr_mb;
709         ithr_mb = ithr % nthr_mb;
710     }
711 }
712
713 void bwd_weights_reduction_par(int ithr, int nthr,
714         const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws,
715         float *weights) {
716     const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
717
718     size_t weights_start{0}, weights_end{0};
719     balance211(weights_g_size, nthr, ithr, weights_start, weights_end);
720
721     for (int i = 0; i < nthr; ++i) {
722         const float *ws_i = weights_reduce_ws + i * weights_g_size;
723         for (size_t s = weights_start; s < weights_end; ++s)
724             weights[s] = (i == 0 ? 0 : weights[s]) + ws_i[s];
725     }
726 }
727
728 };
729
730 }
731 }
732 }