updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm_convolution_utils.cpp
1 /*******************************************************************************
2 * Copyright 2016-2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_types.h"
18
19 #include "c_types_map.hpp"
20 #include "type_helpers.hpp"
21 #include "mkldnn_thread.hpp"
22 #include "utils.hpp"
23 #include "cpu_isa_traits.hpp"
24
25 #include "gemm_convolution_utils.hpp"
26 #include "jit_generator.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 using namespace mkldnn::impl::status;
33 using namespace mkldnn::impl::memory_format;
34 using namespace mkldnn::impl::utils;
35 using namespace prop_kind;
36 using namespace data_type;
37
38 namespace jit_gemm_convolution_utils {
39
40 template <typename data_type_t>
41 void im2col_3d(const jit_gemm_conv_conf_t &jcp, const data_type_t *im,
42         data_type_t *col, int od)
43 {
44     const size_t OHW = jcp.oh * jcp.ow;
45     const size_t im_step = jcp.ih * jcp.iw * jcp.id;
46     const size_t col_step = jcp.ks * OHW;
47
48     parallel_nd(jcp.ic, [&](int ic) {
49         const data_type_t *__restrict im_loc = im + ic * im_step;
50         data_type_t *__restrict col_loc = col + ic * col_step;
51         int id = od * jcp.stride_d - jcp.f_pad;
52         for (int kd = 0; kd < jcp.kd; ++kd) {
53             data_type_t *__restrict col_ = col_loc + kd * jcp.kh * jcp.kw * OHW;
54             if (id < 0 || id >= jcp.id) {
55                 int ih_ = -jcp.t_pad;
56                 for (int kh = 0; kh < jcp.kh; ++kh) {
57                     int ih = ih_;
58                     for (int oh = 0; oh < jcp.oh; ++oh) {
59                         if (ih < 0 || ih >= jcp.ih) {
60                             ih += jcp.stride_h;
61                             continue;
62                         }
63                         int iw_ = -jcp.l_pad;
64                         for (int kw = 0; kw < jcp.kw; ++kw) {
65                             int iw = iw_;
66                             for (int ow = 0; ow < jcp.ow; ++ow) {
67                                 if (iw < 0 || iw >= jcp.iw) {
68                                     iw += jcp.stride_w;
69                                     continue;
70                                 }
71
72                                 const size_t col_idx = kw * OHW + oh * jcp.ow
73                                     + ow;
74
75                                 col_[col_idx] = 0;
76                                 iw += jcp.stride_w;
77                             }
78                             iw_ += (1 + jcp.dilate_w);
79                         }
80                         ih += jcp.stride_h;
81                     }
82                     ih_ += (1 + jcp.dilate_h);
83                     col_ += jcp.kw * OHW;
84                 }
85             } else {
86                 const data_type_t *__restrict im_ =
87                     im_loc + id * jcp.ih * jcp.iw;
88                 int ih_ = -jcp.t_pad;
89                 for (int kh = 0; kh < jcp.kh; ++kh) {
90                     int ih = ih_;
91                     for (int oh = 0; oh < jcp.oh; ++oh) {
92                         if (ih < 0 || ih >= jcp.ih) {
93                             ih += jcp.stride_h;
94                             continue;
95                         }
96                         int iw_ = -jcp.l_pad;
97                         for (int kw = 0; kw < jcp.kw; ++kw) {
98                             int iw = iw_;
99                             for (int ow = 0; ow < jcp.ow; ++ow) {
100                                 if (iw < 0 || iw >= jcp.iw) {
101                                     iw += jcp.stride_w;
102                                     continue;
103                                 }
104
105                                 const size_t col_idx = kw * OHW + oh * jcp.ow
106                                     + ow;
107                                 const size_t im_idx = ih * jcp.iw + iw;
108
109                                 col_[col_idx] = im_[im_idx];
110                                 iw += jcp.stride_w;
111                             }
112                             iw_ += (1 + jcp.dilate_w);
113                         }
114                         ih += jcp.stride_h;
115                     }
116                     ih_ += (1 + jcp.dilate_h);
117                     col_ += jcp.kw * OHW;
118                 }
119             }
120             id += (1 + jcp.dilate_d);
121         }
122     });
123 }
124
125 template
126 void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col,
127         int od);
128
129 template
130 void im2col_3d(const jit_gemm_conv_conf_t &jcp, const mkldnn_bfloat16_t *im,
131          mkldnn_bfloat16_t *col, int od);
132
133 /* col[ic][kh][kw][oh][ow] <-- im2col(im[ic][ih][iw]) */
134 template <typename data_type_t>
135 void im2col(const jit_gemm_conv_conf_t &jcp, const data_type_t *__restrict im,
136        data_type_t *__restrict col, int hs, int hb, int ws, int wb) {
137     const size_t im_step = jcp.is;
138     const size_t col_step = jcp.ks * hb * wb;
139     if (jcp.stride_w == 1) {
140         // Generated code is more optimized for stride_w == 1
141         // because innermost loop is by width
142         auto ker = [&](int ic, int kh, int kw, int oh) {
143             const data_type_t *__restrict im_ = im + ic * im_step;
144             data_type_t *__restrict col_
145                 = col + ic * col_step + ((kh * jcp.kw + kw) * hb + oh) * wb;
146
147             const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
148                 + kh * (1 + jcp.dilate_h);
149             if (ih < 0 || ih >= jcp.ih) {
150                 for (int ow = 0; ow < wb; ++ow)
151                     col_[ow] = (data_type_t)0;
152             } else {
153                 for (int ow = 0; ow < wb; ++ow) {
154                     const int iw = ow + ws - jcp.l_pad
155                         + kw * (1 + jcp.dilate_w);
156                     if (iw < 0 || iw >= jcp.iw)
157                         col_[ow] = (data_type_t)0;
158                     else {
159                         const size_t im_idx = ih * jcp.iw + iw;
160                         col_[ow] = im_[im_idx];
161                     }
162                 }
163             }
164         };
165
166         if (jcp.outer_threading) {
167             for (int ic = 0; ic < jcp.ic; ic++)
168                 for (int kh = 0; kh < jcp.kh; kh++)
169                     for (int kw = 0; kw < jcp.kw; kw++)
170                         for (int oh = 0; oh < hb; oh++)
171                             ker(ic, kh, kw, oh);
172         }
173         else {
174             parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, ker);
175         }
176     } else if (jcp.ic == 1) {
177         parallel_nd(jcp.kh, hb, [&](int kh, int oh) {
178             const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
179                     + kh * (1 + jcp.dilate_h);
180             if (ih < 0 || ih >= jcp.ih)
181                 for (int kw = 0; kw < jcp.kw; ++kw) {
182                     for (int ow = 0; ow < wb; ++ow) {
183                         const size_t col_idx
184                                 = ((kh * jcp.kw + kw) * hb + oh) * wb + ow;
185                         col[col_idx] = (data_type_t)0;
186                     }
187                 }
188             else
189                 for (int kw = 0; kw < jcp.kw; ++kw) {
190                     for (int ow = 0; ow < wb; ++ow) {
191                         const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad
192                                 + kw * (1 + jcp.dilate_w);
193                         const size_t col_idx
194                                 = ((kh * jcp.kw + kw) * hb + oh) * wb + ow;
195                         const size_t im_idx = ih * jcp.iw + iw;
196                         if (iw < 0 || iw >= jcp.iw)
197                             col[col_idx] = (data_type_t)0;
198                         else
199                             col[col_idx] = im[im_idx];
200                     }
201                 }
202         });
203     } else {
204
205         parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb,
206             [&](int ic, int kh, int kw, int oh) {
207             const data_type_t *__restrict im_ = im + ic * im_step;
208             data_type_t *__restrict col_ = col + ic * col_step
209                 + ((kh * jcp.kw + kw) * hb + oh) * wb;
210
211             const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
212                 + kh * (1 + jcp.dilate_h);
213             if (ih < 0 || ih >= jcp.ih) {
214                 for (int ow = 0; ow < wb; ++ow)
215                     col_[ow] = (data_type_t)0;
216             } else {
217                 for (int ow = 0; ow < wb; ++ow) {
218                     const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad
219                         + kw * (1 + jcp.dilate_w);
220                     const size_t im_idx = ih * jcp.iw + iw;
221                     if (iw < 0 || iw >= jcp.iw)
222                         col_[ow] = (data_type_t)0;
223                     else
224                         col_[ow] = im_[im_idx];
225                 }
226             }
227         });
228     }
229 }
230
231 template
232 void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im,
233        float *__restrict col, int hs, int hb, int ws, int wb);
234
235 template
236 void im2col(const jit_gemm_conv_conf_t &jcp,
237        const mkldnn_bfloat16_t *__restrict im,
238        mkldnn_bfloat16_t *__restrict col, int hs, int hb, int ws, int wb);
239
240 inline int limit(int low, int upper, int value) {
241     return nstl::max(low, nstl::min(upper, value));
242 }
243
244 /* col[kh][kw][ic][oh][ow] <-- im2col_u8(im[ih][iw][ic]) */
245 template <typename T>
246 void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im,
247         T *__restrict imtr, uint8_t *__restrict col, int hs, int hb, int ws,
248         int wb) {
249     uint8_t shift = jcp.signed_input ? 128 : 0;
250     const int dh = 1 + jcp.dilate_h;
251     const int dw = 1 + jcp.dilate_w;
252     const int sh = jcp.stride_h;
253     const int sw = jcp.stride_w;
254     const int im_iw_stride = jcp.ic * jcp.ngroups;
255     const int im_ih_stride = jcp.iw * im_iw_stride;
256     const int tp = jcp.t_pad;
257     const int lp = jcp.l_pad;
258
259     if (jcp.outer_threading && sh == 1 && sw == 1 && dh == 1 && dw == 1) {
260         /* im[ih][iw][ic] --> imtr[ic][ih][iw] --> col[kh][kw][ic][oh][ow] */
261         const int hp = hs - tp;
262         const int wp = ws - lp;
263         const int ih_start = limit(0, jcp.ih, hp);
264         const int ih_end = limit(0, jcp.ih, hp + hb + jcp.kh);
265         const int iw_start = limit(0, jcp.iw, wp);
266         const int iw_end = limit(0, jcp.iw, wp + wb + jcp.kw);
267
268         const int ihb = ih_end - ih_start;
269         const int iwb = iw_end - iw_start;
270
271         const int imtr_ic_stride = ihb * iwb;
272         const ptrdiff_t imtr_idx_shift = ih_start * iwb + iw_start;
273         for (int ic = 0; ic < jcp.ic; ic++) {
274             const ptrdiff_t imtr_idx_ic = ic * imtr_ic_stride - imtr_idx_shift;
275             for (int ih = ih_start; ih < ih_end; ih++) {
276                 const ptrdiff_t im_idx_ih = ic + ih * im_ih_stride;
277                 const ptrdiff_t imtr_idx_ih = imtr_idx_ic + ih * iwb;
278                 for (int iw = iw_start; iw < iw_end; iw++)
279                     imtr[imtr_idx_ih + iw] = im[im_idx_ih + iw * im_iw_stride];
280             }
281         }
282
283         const int col_ic_str = hb * wb;
284         const int col_kw_stride = jcp.ic * col_ic_str;
285         const int col_kh_stride = jcp.kw * col_kw_stride;
286
287         const int oh_init = ih_start - hp;
288         const int ow_init = iw_start - wp;
289         for (int kh = 0; kh < jcp.kh; kh++) {
290             const ptrdiff_t col_idx_kh = kh * col_kh_stride;
291             const int oh_kh = oh_init - kh;
292             const int oh_start = limit(0, hb, oh_kh);
293             const int oh_end = limit(0, hb, oh_kh + ihb);
294             for (int kw = 0; kw < jcp.kw; kw++) {
295                 const ptrdiff_t col_idx_kw
296                         = col_idx_kh + kw * jcp.ic * col_ic_str;
297                 const int ow_kw = ow_init - kw;
298                 const int imtr_shift = oh_kh * iwb + ow_kw;
299                 const int ow_start = limit(0, wb, ow_kw);
300                 const int ow_end = limit(0, wb, ow_kw + iwb);
301                 for (int ic = 0; ic < jcp.ic; ic++) {
302                     const ptrdiff_t col_idx_ic = col_idx_kw + ic * col_ic_str;
303                     const int imtr_idx_ic = ic * imtr_ic_stride - imtr_shift;
304                     for (int oh = 0; oh < oh_start; oh++) {
305                         const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
306                         for (int ow = 0; ow < wb; ++ow)
307                             col[col_idx_oh + ow] = shift;
308                     }
309                     for (int oh = oh_start; oh < oh_end; oh++) {
310                         const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
311                         const ptrdiff_t imtr_idx_oh = imtr_idx_ic + oh * iwb;
312                         for (int ow = 0; ow < ow_start; ++ow)
313                             col[col_idx_oh + ow] = shift;
314                         for (int ow = ow_start; ow < ow_end; ++ow)
315                             col[col_idx_oh + ow]
316                                     = imtr[imtr_idx_oh + ow] + shift;
317                         for (int ow = ow_end; ow < wb; ++ow)
318                             col[col_idx_oh + ow] = shift;
319                     }
320                     for (int oh = oh_end; oh < hb; oh++) {
321                         const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
322                         for (int ow = 0; ow < wb; ++ow)
323                             col[col_idx_oh + ow] = shift;
324                     }
325                 }
326             }
327         }
328     } else {
329         parallel_nd(jcp.kh, jcp.kw, jcp.ic, hb,
330             [&](int kh, int kw, int ic, int oh) {
331                 const int hp = tp - kh * dh;
332                 const int ih = (oh + hs) * sh - hp;
333                 const ptrdiff_t col_idx_base
334                         = (((kh * jcp.kw + kw) * jcp.ic + ic) * hb + oh) * wb;
335                 if (ih < 0 || ih >= jcp.ih)
336                     for (int ow = 0; ow < wb; ow++)
337                         col[col_idx_base + ow] = shift;
338                 else {
339                     const int wp = lp - kw * dw;
340                     const int ow_start = limit(0, wb, div_up(wp, sw) - ws);
341                     const int ow_end
342                             = limit(0, wb, div_up(jcp.iw + wp, sw) - ws);
343                     for (int ow = 0; ow < ow_start; ow++)
344                         col[col_idx_base + ow] = shift;
345                     const int iw_base = ws * sw - wp;
346                     const ptrdiff_t im_idx_base = ih * im_ih_stride + ic;
347                     for (int ow = ow_start; ow < ow_end; ow++) {
348                         const int iw = iw_base + ow * sw;
349                         const ptrdiff_t im_idx
350                                 = im_idx_base + iw * im_iw_stride;
351                         col[col_idx_base + ow] = im[im_idx] + shift;
352                     }
353                     for (int ow = ow_end; ow < wb; ow++)
354                         col[col_idx_base + ow] = shift;
355                 }
356             });
357     }
358 }
359
360 template void im2col_u8<int8_t>(const jit_gemm_conv_conf_t &jcp,
361         const int8_t *__restrict im, int8_t *__restrict imtr,
362         uint8_t *__restrict col, int hs, int hb, int ws, int wb);
363 template void im2col_u8<uint8_t>(const jit_gemm_conv_conf_t &jcp,
364         const uint8_t *__restrict im, uint8_t *__restrict imtr,
365         uint8_t *__restrict col, int hs, int hb, int ws, int wb);
366
367 template <typename T>
368 void im2col_u8_3d(const jit_gemm_conv_conf_t &jcp, const T *__restrict im,
369                   uint8_t *__restrict col, int od) {
370     uint8_t shift = jcp.signed_input ? 128 : 0;
371     const int dh = 1 + jcp.dilate_h;
372     const int dw = 1 + jcp.dilate_w;
373     const int dd = 1 + jcp.dilate_d;
374     const int sh = jcp.stride_h;
375     const int sw = jcp.stride_w;
376     const int sd = jcp.stride_d;
377     const int im_iw_stride = jcp.ic * jcp.ngroups;
378     const int im_ih_stride = jcp.iw * im_iw_stride;
379     const int im_id_stride = jcp.ih * im_ih_stride;
380     const int tp = jcp.t_pad;
381     const int lp = jcp.l_pad;
382     const int fp = jcp.f_pad;
383
384     const T* im_loc = im + od * sd * im_id_stride;
385
386     parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic, jcp.oh, jcp.ow,
387                 [&](int kd, int kh, int kw, int ic, int oh, int ow) {
388                     int im_idx = (kd * dd - fp) * im_id_stride
389                                  + (kh * dh - tp + oh * sh) * im_ih_stride
390                                  + (kw * dw - lp + ow * sw) * im_iw_stride
391                                  + ic;
392
393                     int col_idx = kd * jcp.kh * jcp.kw * jcp.ic * jcp.oh * jcp.ow
394                                   + kh * jcp.kw * jcp.ic * jcp.oh * jcp.ow
395                                   + kw * jcp.ic * jcp.oh * jcp.ow
396                                   + ic * jcp.oh * jcp.ow
397                                   + oh * jcp.ow
398                                   + ow;
399
400                     int id = od * sd + kd * dd - fp;
401                     int ih = oh * sh + kh * dh - tp;
402                     int iw = ow * sw + kw * dw - lp;
403
404                     if (id < 0 || id >= jcp.id || ih < 0 || ih >= jcp.ih || iw < 0 || iw >= jcp.iw)
405                         col[col_idx] = shift;
406                     else
407                         col[col_idx] = im_loc[im_idx] + shift;
408     });
409 }
410
411 template void im2col_u8_3d<int8_t>(const jit_gemm_conv_conf_t &jcp, const int8_t *__restrict im,
412                                    uint8_t *__restrict col, int od);
413
414 template void im2col_u8_3d<uint8_t>(const jit_gemm_conv_conf_t &jcp, const uint8_t *__restrict im,
415                                     uint8_t *__restrict col, int od);
416
417 /* im[ih][iw][ic] <-- col2im_s32(col[oh][ow][kh][kw][ic]) */
418 void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col,
419         int32_t *__restrict im)
420 {
421     parallel(0, (size_t)mkldnn_get_max_threads(), [&](const int ithr, const int nthr) {
422         int h_nthr = nstl::min(jcp.ih, nthr);
423         int w_nthr = nstl::min(jcp.iw, nthr / h_nthr);
424         int h_ithr = 1, h_s = 0, h_e = 0, w_ithr = 1, w_s = 0, w_e = 0;
425         if (ithr < h_nthr * w_nthr) {
426             h_ithr = ithr / w_nthr;
427             w_ithr = ithr % w_nthr;
428             balance211(jcp.ih, h_nthr, h_ithr, h_s, h_e);
429             balance211(jcp.iw, w_nthr, w_ithr, w_s, w_e);
430         } else {
431             h_ithr = w_ithr = -ithr;
432             h_s = h_e = w_s = w_e = -1;
433         }
434
435         for (int ih = h_s; ih < h_e; ++ih) {
436             for (int iw = w_s; iw < w_e; ++iw) {
437                 PRAGMA_OMP_SIMD()
438                 for (int ic = 0; ic < jcp.ic; ++ic) {
439                     im[(ih * jcp.iw + iw) * jcp.ic + ic] = 0;
440                 }
441             }
442         }
443
444         // TODO: reduce region: [0.. oh] --> [h_s * sh .. h_e * sh]
445         for (int oh = 0; oh < jcp.oh; ++oh) {
446             for (int ow = 0; ow < jcp.ow; ++ow) {
447                 for (int kh = 0; kh < jcp.kh; ++kh) {
448                     const int ih = oh * jcp.stride_h
449                         - jcp.t_pad + kh * (1 + jcp.dilate_h);
450                     if (ih < h_s || ih >= h_e) continue;
451
452                     for (int kw = 0; kw < jcp.kw; ++kw) {
453                         const int iw = ow * jcp.stride_w
454                             - jcp.l_pad + kw * (1 + jcp.dilate_w);
455                         if (iw < w_s || iw >= w_e) continue;
456
457                         const size_t col_idx = (((oh * jcp.ow + ow) * jcp.kh
458                                 + kh) * jcp.kw + kw) * jcp.ic;
459                         const size_t im_idx
460                             = (ih * jcp.iw + iw) * jcp.ic;
461                         PRAGMA_OMP_SIMD()
462                         for (int ic = 0; ic < jcp.ic; ++ic) {
463                             im[im_idx + ic] += col[col_idx + ic];
464                         }
465                     }
466                 }
467             }
468         }
469     });
470 }
471
472 void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im,
473         int od)
474 {
475     parallel_nd(jcp.ic, [&](int ic) {
476         const float *__restrict col_ = col + (size_t)ic * jcp.ks * jcp.os;
477         float *__restrict im_ic = im + (size_t)ic * jcp.ih * jcp.iw * jcp.id;
478
479         int id = od * jcp.stride_d - jcp.f_pad;
480         for (int kd = 0; kd < jcp.kd; ++kd) {
481             if (id < 0 || id >= jcp.id) {
482                 col_ += jcp.kh * jcp.kw * jcp.os;
483                 id += (1 + jcp.dilate_d);
484                 continue;
485             }
486
487             float *__restrict im_ = im_ic + id * jcp.ih * jcp.iw;
488
489             for (int oh = 0; oh < jcp.oh; ++oh) {
490             for (int kh = 0; kh < jcp.kh; ++kh) {
491                 const int ih = oh * jcp.stride_h - jcp.t_pad
492                     + kh * (1 + jcp.dilate_h);
493                 if (ih < 0 || ih >= jcp.ih) continue;
494
495                 for (int ow = 0; ow < jcp.ow; ++ow) {
496                 for (int kw = 0; kw < jcp.kw; ++kw) {
497                     const int iw = ow * jcp.stride_w - jcp.l_pad
498                         + kw * (1 + jcp.dilate_w);
499                     if (iw < 0 || iw >= jcp.iw) continue;
500
501                     const size_t col_idx =
502                         ((kh * jcp.kw + kw) * jcp.oh + oh) * jcp.ow + ow;
503                     const size_t im_idx = ih*jcp.iw + iw;
504                     im_[im_idx] += col_[col_idx];
505                 }}
506             }}
507
508             col_ += jcp.kh * jcp.kw * jcp.os;
509             id += (1 + jcp.dilate_d);
510         }
511     });
512 }
513
514 void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im) {
515     const size_t col_step = jcp.ks * jcp.os;
516     const size_t im_step = jcp.ih * jcp.iw;
517     const int iS = jcp.ih * jcp.iw;
518
519     parallel_nd(jcp.ic, [&](int ic) {
520         float *__restrict im_ = im + ic * im_step;
521         const float *__restrict col_ = col + ic * col_step;
522         PRAGMA_OMP_SIMD()
523         for (int is = 0; is < iS; ++is) im_[is] = 0.;
524
525         for (int kh = 0; kh < jcp.kh; ++kh) {
526         for (int oh = 0; oh < jcp.oh; ++oh) {
527             const int ih =
528                     oh * jcp.stride_h - jcp.t_pad + kh * (1 + jcp.dilate_h);
529             if (ih < 0 || ih >= jcp.ih) continue;
530
531             for (int kw = 0; kw < jcp.kw; ++kw) {
532             for (int ow = 0; ow < jcp.ow; ++ow) {
533                 const int iw =
534                         ow * jcp.stride_w - jcp.l_pad + kw * (1 + jcp.dilate_w);
535                 if (iw < 0 || iw >= jcp.iw) continue;
536
537                 const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow;
538                 const size_t im_idx = ih*jcp.iw + iw;
539                 im_[im_idx] += col_[col_idx];
540             }
541             }
542         }
543         }
544     });
545 }
546
547 status_t init_conf(jit_gemm_conv_conf_t &jcp,
548         memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd,
549         const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
550         const memory_desc_wrapper &dst_d, int max_threads) {
551     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
552     const int ndims = src_d.ndims();
553     const int is_1d = ndims == 3;
554     const int is_3d = ndims == 5;
555
556     jcp.prop_kind = cd.prop_kind;
557
558     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
559     jcp.mb = src_d.dims()[0];
560
561     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
562     jcp.ic = src_d.dims()[1] / jcp.ngroups;
563     jcp.id = is_3d ? src_d.dims()[2] : 1;
564     jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
565     jcp.iw = src_d.dims()[ndims - 1];
566     jcp.od = is_3d ? dst_d.dims()[2] : 1;
567     jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
568     jcp.ow = dst_d.dims()[ndims - 1];
569
570     jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
571     jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
572     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
573
574     jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
575     jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
576     jcp.l_pad = cd.padding[0][ndims - 3];
577
578     jcp.stride_d = is_3d ? cd.strides[0] : 1;
579     jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
580     jcp.stride_w = cd.strides[ndims - 3];
581
582     jcp.dilate_d = is_3d ? cd.dilates[0] : 0;
583     jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
584     jcp.dilate_w = cd.dilates[ndims - 3];
585
586     jcp.src_fmt = src_d.format();
587     jcp.with_bias = cd.bias_desc.format != memory_format::undef
588         || cd.diff_bias_desc.format != memory_format::undef;
589
590     jcp.is = jcp.ih * jcp.iw;
591     jcp.os = jcp.oh * jcp.ow;
592     jcp.ks = jcp.kh * jcp.kw * jcp.kd;
593
594     jcp.signed_input = src_d.data_type() == data_type::s8;
595     jcp.wei_adj_scale =
596         !jcp.signed_input || mayiuse(avx512_core_vnni) ? 1.f : 0.5f;
597
598     jcp.im2col_sz = !everyone_is(true,
599             jcp.ow == jcp.iw, jcp.oh == jcp.ih, jcp.od == jcp.id,
600             jcp.stride_w == 1, jcp.stride_h == 1, jcp.stride_d == 1,
601             jcp.ks == 1, !jcp.signed_input)
602         ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os : 0;
603
604     jcp.outer_threading = false;
605
606     bool is_int8_conv = utils::one_of(src_d.data_type(), s32, s8, u8)
607         && weights_d.data_type() == s8;
608
609     const bool is_bwd_d = jcp.prop_kind == backward_data;
610     const bool is_bwd_w = jcp.prop_kind == backward_weights;
611     const bool is_fwd = !is_bwd_d && !is_bwd_w;
612
613     bool is_bf16_conv = false
614         || (is_fwd && utils::everyone_is(bf16,
615                 src_d.data_type(), weights_d.data_type()))
616         || (is_bwd_d && utils::everyone_is(bf16,
617                 dst_d.data_type(), weights_d.data_type()))
618         || (is_bwd_w && utils::everyone_is(bf16,
619                 src_d.data_type(), dst_d.data_type()));
620     if (is_bf16_conv && !mayiuse(avx512_core))
621         return status::unimplemented;
622
623     bool is_bf16_to_bf16_conv = is_bf16_conv
624         && ((is_fwd && bf16 == dst_d.data_type())
625                 || (is_bwd_d && bf16 == src_d.data_type())
626                 || (is_bwd_w && bf16 == weights_d.data_type()));
627
628     const int vlen = mayiuse(avx512_common)
629         ? cpu_isa_traits<avx512_common>::vlen
630         : mayiuse(avx)
631             ? cpu_isa_traits<avx>::vlen
632             : mayiuse(sse42) ? cpu_isa_traits<sse42>::vlen : 4;
633     const int data_size = (is_int8_conv ? 1 : (is_bf16_conv ? 2 : 4));
634     const int simd_w = vlen / data_size;
635
636     jcp.oh_block = is_fwd ? jcp.oh : jcp.ih;
637     jcp.ow_block = is_fwd ? jcp.ow : jcp.iw;
638
639     using namespace memory_tracking::names;
640     bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
641
642     // TODO: maybe mitigate blocking restriction
643     const int wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
644     const int L2 = get_cache_size(2, true)
645           / data_size;
646     bool is_blocking_applicable = true
647             && is_fwd && jcp.im2col_sz
648             && jcp.id == 1 && jcp.od == 1
649 // This condition was relaxed to support old behaviour
650 //            && jcp.dilate_h == 0 && jcp.dilate_w == 0
651             && !is_depthwise
652             && wei_size < L2/2;
653     if (is_blocking_applicable) {
654         // looking for oh and ow blocking
655         int h_block{ jcp.oh_block }, w_block{ jcp.ow_block };
656         const int ic = jcp.ic;
657         const int oc = jcp.oc;
658         const int iw = jcp.iw;
659         const int ow = jcp.ow;
660         const int oh = jcp.oh;
661         const int os = oh * ow;
662
663         // 1. cache requirement
664         int row_size = ic * ow * jcp.ks + 2 * (ic * iw + oc * ow);
665         if (is_int8_conv) {
666             // Heuristic rule: gemm needed a lot of memory for internal usage
667             row_size *= 5;
668             // memory for accumulators
669             row_size += oc * ow * sizeof(uint32_t);
670             // memory for transposition
671             row_size += ic * iw;
672         }
673
674         h_block = nstl::max(1, nstl::min(oh, div_up(L2 - wei_size, row_size)));
675         if (h_block == 1) {
676             int col_size = ic * jcp.ks + 2 * (ic + oc);
677             if (is_int8_conv) {
678                 col_size *= 5;
679                 col_size += oc * sizeof(uint32_t);
680                 col_size += ic;
681             }
682             w_block = nstl::max(1, nstl::min(ow, div_up(L2 - wei_size, col_size)));
683         }
684
685         // 2. threading requirement
686         if (h_block != oh)
687             h_block = nstl::max(1, rnd_dn(h_block, 4));
688         if (w_block != ow)
689             w_block = nstl::max(1, rnd_dn(w_block, simd_w));
690
691         float thr_eff = 0.f;
692         float thr_eff_treshold = 0.9f;
693         if (w_block == ow) {
694             do {
695                 int nb_h = div_up(oh, h_block);
696                 size_t work = jcp.ngroups * jcp.mb * jcp.od * nb_h;
697                 float disb = (float)oh / rnd_up(oh, h_block);
698                 thr_eff = (float)work / rnd_up(work, max_threads);
699                 thr_eff = (thr_eff + disb) / 2.f;
700                 if (thr_eff >= thr_eff_treshold)
701                     break;
702                 h_block = rnd_dn(h_block - 4, 4);
703             } while (h_block > 0);
704         }
705         if (thr_eff < thr_eff_treshold) // we didn't find suitable h_block
706         {
707             h_block = 1;
708             int nb_h = oh;
709             do {
710                 int nb_w = div_up(ow, w_block);
711                 size_t work_amount = jcp.ngroups * jcp.mb * nb_h * nb_w;
712                 float disb = (float)ow / rnd_up(ow, w_block);
713                 thr_eff = (float)work_amount / rnd_up(work_amount, max_threads);
714                 thr_eff = (thr_eff + disb) / 2.f;
715                 if (thr_eff > thr_eff_treshold)
716                     break;
717                 w_block = rnd_dn(w_block - simd_w, simd_w);
718             } while (w_block > 0);
719         }
720         h_block = nstl::max(1, h_block);
721         w_block = nstl::max(1, w_block);
722         const size_t inner_work = div_up(os, simd_w) * div_up(oc, simd_w);
723         const float inner_thr_eff
724                 = (float)inner_work / rnd_up(inner_work, max_threads);
725         if (thr_eff >= inner_thr_eff / 2 && h_block > 0 && w_block > 0) {
726             jcp.oh_block = h_block;
727             jcp.ow_block = w_block;
728             jcp.outer_threading = true;
729         }
730         // updating jcp.im2col_sz
731         if (jcp.oh_block != 1)
732             jcp.ow_block = ow;
733         jcp.im2col_sz = (ptrdiff_t)ic * jcp.ks * jcp.oh_block * jcp.ow_block;
734     }
735     //  For threading selection in bwd_d we do:
736     //  1. Rough estimation of efficiency for inner and outer threading.
737     //  2. Gemm size estimation in assumption that it does not work
738     //  so effectively for small sizes.
739     //  64K - this is heuristic gemm size per thread threshold.
740     const int gemm_thrld = 64 * 1024;
741
742     if (is_int8_conv) {
743         if (is_fwd) {
744             if (!jcp.outer_threading) {
745                 bool is_depthwise = jcp.ic == 1 && jcp.oc == 1
746                    && jcp.ngroups != 1;
747                 const size_t outer_work = jcp.ngroups * jcp.mb;
748                 const float outer_thr_eff
749                     = (float)outer_work / rnd_up(outer_work, max_threads);
750                 const size_t inner_work
751                     = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
752                 const float inner_thr_eff
753                     = (float)inner_work / rnd_up(inner_work, max_threads);
754                 jcp.outer_threading = (is_depthwise
755                  || (jcp.is / max_threads < 64 && jcp.mb != 1))
756                  && (outer_thr_eff / inner_thr_eff >= 1.f
757                      || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld);
758             }
759             jcp.nthr = jcp.outer_threading ? max_threads : 1;
760             scratchpad.book(key_conv_gemm_col,
761                 sizeof(int8_t) * jcp.nthr * jcp.im2col_sz);
762             scratchpad.book(key_conv_int_dat_in_acc_dt,
763                 sizeof(int32_t) * jcp.nthr * jcp.oh_block
764                     * jcp.ow_block * jcp.oc);
765             scratchpad.book(key_conv_gemm_imtr,
766                 sizeof(int8_t) * jcp.nthr * jcp.is * jcp.ic);
767         } else if (is_bwd_d) {
768             bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
769             const size_t outer_work = jcp.ngroups * jcp.mb;
770             const float outer_thr_eff
771                     = (float)outer_work / rnd_up(outer_work, max_threads);
772             const size_t inner_work
773                     = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
774             const float inner_thr_eff
775                     = (float)inner_work / rnd_up(inner_work, max_threads);
776             jcp.outer_threading = (is_depthwise
777                  || (jcp.is / max_threads < 64 && jcp.mb != 1))
778                  && (outer_thr_eff / inner_thr_eff >= 1.f
779                      || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld);
780
781             jcp.nthr = jcp.outer_threading ? max_threads : 1;
782             scratchpad.book(key_conv_gemm_col,
783                 sizeof(int32_t) * jcp.nthr * jcp.im2col_sz);
784             scratchpad.book(key_conv_int_dat_in_acc_dt,
785                 sizeof(int32_t) * jcp.nthr * jcp.is * jcp.ic);
786         } else if (is_bwd_w) {
787             assert(!"unimplemented prop_kind");
788             return status::unimplemented;
789         }
790     } else {
791         if (is_fwd) {
792             if (!jcp.outer_threading) {
793                 const size_t outer_work_amount = jcp.ngroups * jcp.mb * jcp.od;
794                 const float outer_thr_eff = (float)outer_work_amount
795                         / rnd_up(outer_work_amount, max_threads);
796                 const size_t inner_work_amount
797                         = div_up(jcp.os, simd_w) * div_up(jcp.oc, simd_w);
798                 const float inner_thr_eff = (float)inner_work_amount
799                         / rnd_up(inner_work_amount, max_threads);
800                 jcp.outer_threading = jcp.os / max_threads < 512
801                     && IMPLICATION(jcp.od == 1, jcp.mb != 1 || jcp.ngroups > 2)
802                     && (outer_thr_eff / inner_thr_eff >= 1.f
803                       || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld);
804             }
805         } else if (is_bwd_d) {
806             const size_t outer_work_amount = jcp.ngroups * jcp.mb;
807             const float outer_thr_eff = (float)outer_work_amount
808                 / rnd_up(outer_work_amount, max_threads);
809             const size_t inner_work
810                 = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
811             const float inner_thr_eff = (float)inner_work
812                 / rnd_up(inner_work, max_threads);
813             jcp.outer_threading = (jcp.os / max_threads < 512 || jcp.ks < 64)
814                 && (jcp.mb != 1 || jcp.ngroups > 2)
815                 && (outer_thr_eff / inner_thr_eff >= 1.f
816                     || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld);
817         } else if (is_bwd_w)
818             jcp.outer_threading = jcp.os / max_threads < 256
819                 && (jcp.mb != 1 || jcp.ngroups > 2);
820
821         jcp.nthr = jcp.outer_threading ? max_threads : 1;
822         const size_t gemm_col_datatype_size = is_bf16_conv && !is_bwd_d
823             ? sizeof(mkldnn_bfloat16_t)
824             : sizeof(float);
825         scratchpad.book(key_conv_gemm_col,
826                 gemm_col_datatype_size * jcp.nthr * jcp.im2col_sz);
827
828         const int sizeof_cacheline_float = 16;
829         if (is_bwd_w) {
830             jcp.need_wei_reduction = mkldnn_thr_syncable()
831                 ? jcp.mb != 1 && jcp.nthr != 1 : false;
832             scratchpad.book(key_conv_wei_reduction,
833                     sizeof(float) * jcp.nthr * jcp.ngroups * weights_d.size());
834
835             if (is_bf16_conv && jcp.with_bias) {
836                 const size_t ws_size = sizeof(float)
837                     * max_threads * rnd_up(jcp.ow, sizeof_cacheline_float);
838                 scratchpad.book(key_conv_dst_bf16_convert_wsp, ws_size);
839             }
840         }
841
842         if (is_bf16_to_bf16_conv) {
843             size_t conv_acc_buffer_size = 0;
844             if (is_fwd)
845                 conv_acc_buffer_size = sizeof(float) * jcp.nthr
846                     * rnd_up(jcp.oc * jcp.oh_block * jcp.ow_block,
847                           sizeof_cacheline_float);
848             else if (is_bwd_d)
849                 conv_acc_buffer_size = sizeof(float) * jcp.nthr
850                     * rnd_up(jcp.ic * jcp.ih * jcp.iw * jcp.id,
851                           sizeof_cacheline_float);
852             else if (is_bwd_w)
853                 conv_acc_buffer_size = sizeof(float) * weights_d.size();
854             scratchpad.book(key_conv_int_dat_in_acc_dt, conv_acc_buffer_size);
855         }
856     }
857
858     return status::success;
859 }
860
861 void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g,
862         int &nthr_g, int &ithr_mb, int &nthr_mb) {
863     nthr_g = nstl::min(ngroups, nthr);
864     nthr_mb = nstl::min(mb, nthr / nthr_g);
865     if (ithr / nthr_mb >= ngroups) {
866         ithr_g = ithr_mb = -1;
867     } else {
868         ithr_g = ithr / nthr_mb;
869         ithr_mb = ithr % nthr_mb;
870     }
871 }
872
873 void bwd_weights_reduction_par(int ithr, int nthr,
874         const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws,
875         float *weights) {
876     const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
877
878     size_t weights_start{0}, weights_end{0};
879     balance211(weights_g_size, nthr, ithr, weights_start, weights_end);
880
881     for (int i = 0; i < nthr; ++i) {
882         const float *ws_i = weights_reduce_ws + i * weights_g_size;
883         for (size_t s = weights_start; s < weights_end; ++s)
884             weights[s] = (i == 0 ? 0 : weights[s]) + ws_i[s];
885     }
886 }
887
888 };
889
890 }
891 }
892 }