6c8a4daf63da1ae24b481d21864d722d3099da48
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm_convolution_utils.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_types.h"
18
19 #include "c_types_map.hpp"
20 #include "type_helpers.hpp"
21 #include "mkldnn_thread.hpp"
22 #include "utils.hpp"
23
24 #include "gemm_convolution_utils.hpp"
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
30 using namespace mkldnn::impl::status;
31 using namespace mkldnn::impl::memory_format;
32 using namespace mkldnn::impl::utils;
33 using namespace prop_kind;
34 using namespace data_type;
35
36 namespace jit_gemm_convolution_utils {
37
38 void im2col_3d(jit_gemm_conv_conf_t &jcp, const float *im, float *col, int od) {
39     const size_t OHW = jcp.oh * jcp.ow;
40     const size_t im_step = jcp.ih * jcp.iw * jcp.id;
41     const size_t col_step = jcp.ks * OHW;
42
43     parallel_nd(jcp.ic, [&](int ic) {
44         const float *im_loc = im + ic * im_step;
45         float *col_loc = col + ic * col_step;
46         int id = od * jcp.stride_d - jcp.f_pad;
47         for (int kd = 0; kd < jcp.kd; ++kd) {
48             float *col_ = col_loc + kd * jcp.kh * jcp.kw * OHW;
49             if (id < 0 || id >= jcp.id) {
50                 int ih_ = -jcp.t_pad;
51                 for (int kh = 0; kh < jcp.kh; ++kh) {
52                     int ih = ih_;
53                     for (int oh = 0; oh < jcp.oh; ++oh) {
54                         if (ih < 0 || ih >= jcp.ih) {
55                             ih += jcp.stride_h;
56                             continue;
57                         }
58                         int iw_ = -jcp.l_pad;
59                         for (int kw = 0; kw < jcp.kw; ++kw) {
60                             int iw = iw_;
61                             for (int ow = 0; ow < jcp.ow; ++ow) {
62                                 if (iw < 0 || iw >= jcp.iw) {
63                                     iw += jcp.stride_w;
64                                     continue;
65                                 }
66
67                                 const size_t col_idx = kw * OHW + oh * jcp.ow
68                                     + ow;
69
70                                 col_[col_idx] = 0;
71                                 iw += jcp.stride_w;
72                             }
73                             iw_ += (1 + jcp.dilate_w);
74                         }
75                         ih += jcp.stride_h;
76                     }
77                     ih_ += (1 + jcp.dilate_h);
78                     col_ += jcp.kw * OHW;
79                 }
80             } else {
81                 const float *im_ = im_loc + id * jcp.ih * jcp.iw;
82                 int ih_ = -jcp.t_pad;
83                 for (int kh = 0; kh < jcp.kh; ++kh) {
84                     int ih = ih_;
85                     for (int oh = 0; oh < jcp.oh; ++oh) {
86                         if (ih < 0 || ih >= jcp.ih) {
87                             ih += jcp.stride_h;
88                             continue;
89                         }
90                         int iw_ = -jcp.l_pad;
91                         for (int kw = 0; kw < jcp.kw; ++kw) {
92                             int iw = iw_;
93                             for (int ow = 0; ow < jcp.ow; ++ow) {
94                                 if (iw < 0 || iw >= jcp.iw) {
95                                     iw += jcp.stride_w;
96                                     continue;
97                                 }
98
99                                 const size_t col_idx = kw * OHW + oh * jcp.ow
100                                     + ow;
101                                 const size_t im_idx = ih * jcp.iw + iw;
102
103                                 col_[col_idx] = im_[im_idx];
104                                 iw += jcp.stride_w;
105                             }
106                             iw_ += (1 + jcp.dilate_w);
107                         }
108                         ih += jcp.stride_h;
109                     }
110                     ih_ += (1 + jcp.dilate_h);
111                     col_ += jcp.kw * OHW;
112                 }
113             }
114             id += (1 + jcp.dilate_d);
115         }
116     });
117 }
118
119 void im2col(jit_gemm_conv_conf_t &jcp, const float *im, float *col) {
120     if (jcp.ic == 1) {
121         parallel_nd(jcp.kh, jcp.oh, [&](int kh, int oh) {
122             const int ih = oh * jcp.stride_h - jcp.t_pad + kh * (1 + jcp.dilate_h);
123             if (ih < 0 || ih >= jcp.ih) return;
124
125             for (int kw = 0; kw < jcp.kw; ++kw) {
126             for (int ow = 0; ow < jcp.ow; ++ow) {
127                 const int iw = ow * jcp.stride_w - jcp.l_pad + kw * (1 + jcp.dilate_w);
128                 if (iw < 0 || iw >= jcp.iw) continue;
129
130                 const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow;
131                 const size_t im_idx = ih*jcp.iw + iw;
132                 col[col_idx] = im[im_idx];
133             }}
134         });
135     } else {
136         const size_t im_step = jcp.ih * jcp.iw;
137         const size_t col_step = jcp.ks * jcp.os;
138
139         parallel_nd(jcp.ic, [&](int ic) {
140             const float *im_ = im + ic * im_step;
141             float *col_ = col + ic * col_step;
142
143             for (int kh = 0; kh < jcp.kh; ++kh) {
144             for (int oh = 0; oh < jcp.oh; ++oh) {
145                 const int ih = oh * jcp.stride_h
146                                - jcp.t_pad + kh * (1 + jcp.dilate_h);
147                 if (ih < 0 || ih >= jcp.ih) continue;
148
149                 for (int kw = 0; kw < jcp.kw; ++kw) {
150                 for (int ow = 0; ow < jcp.ow; ++ow) {
151                     const int iw = ow * jcp.stride_w
152                                    - jcp.l_pad + kw * (1 + jcp.dilate_w);
153                     if (iw < 0 || iw >= jcp.iw) continue;
154
155                     const size_t col_idx = ((kh * jcp.kw + kw) * jcp.oh+oh)
156                                            * jcp.ow + ow;
157                     const size_t im_idx = ih*jcp.iw + iw;
158                     col_[col_idx] = im_[im_idx];
159                 }}
160             }}
161         });
162     }
163 }
164
165 /* col[oh][ow][kh][kw][ic] <-- im2col_u8(im[ih][iw][ic]) */
166 void im2col_u8(jit_gemm_conv_conf_t &jcp, const uint8_t *im, uint8_t *col) {
167     parallel_nd(jcp.oh, jcp.ow, [&](int oh, int ow) {
168             for (int kh = 0; kh < jcp.kh; ++kh) {
169                 const int ih = oh * jcp.stride_h
170                     - jcp.t_pad + kh * (1 + jcp.dilate_h);
171                 if (ih < 0 || ih >= jcp.ih) continue;
172
173                 for (int kw = 0; kw < jcp.kw; ++kw) {
174                     const int iw = ow * jcp.stride_w
175                         - jcp.l_pad + kw * (1 + jcp.dilate_w);
176                     if (iw < 0 || iw >= jcp.iw) continue;
177
178                     const size_t col_idx = (((oh * jcp.ow + ow) * jcp.kh + kh)
179                             * jcp.kw + kw) * jcp.ic;
180                     const size_t im_idx
181                         = (ih * jcp.iw + iw) * jcp.ngroups * jcp.ic;
182                     PRAGMA_OMP_SIMD()
183                     for (int ic = 0; ic < jcp.ic; ++ic) {
184                         col[col_idx + ic] = im[im_idx + ic];
185                     }
186                 }
187             }
188         }
189     );
190 }
191
192 /* im[ih][iw][ic] <-- col2im_s32(col[oh][ow][kh][kw][ic]) */
193 void col2im_s32(jit_gemm_conv_conf_t &jcp, const int32_t *col, int32_t *im) {
194     parallel(0, [&](const int ithr, const int nthr) {
195         int h_nthr = nstl::min(jcp.ih, nthr);
196         int w_nthr = nstl::min(jcp.iw, nthr / h_nthr);
197         int h_ithr = 1, h_s = 0, h_e = 0, w_ithr = 1, w_s = 0, w_e = 0;
198         if (ithr < h_nthr * w_nthr) {
199             h_ithr = ithr / w_nthr;
200             w_ithr = ithr % w_nthr;
201             balance211(jcp.ih, h_nthr, h_ithr, h_s, h_e);
202             balance211(jcp.iw, w_nthr, w_ithr, w_s, w_e);
203         } else {
204             h_ithr = w_ithr = -ithr;
205             h_s = h_e = w_s = w_e = -1;
206         }
207
208         for (int ih = h_s; ih < h_e; ++ih) {
209             for (int iw = w_s; iw < w_e; ++iw) {
210                 PRAGMA_OMP_SIMD()
211                 for (int ic = 0; ic < jcp.ic; ++ic) {
212                     im[(ih * jcp.iw + iw) * jcp.ic + ic] = 0;
213                 }
214             }
215         }
216
217         // TODO: reduce region: [0.. oh] --> [h_s * sh .. h_e * sh]
218         for (int oh = 0; oh < jcp.oh; ++oh) {
219             for (int ow = 0; ow < jcp.ow; ++ow) {
220                 for (int kh = 0; kh < jcp.kh; ++kh) {
221                     const int ih = oh * jcp.stride_h
222                         - jcp.t_pad + kh * (1 + jcp.dilate_h);
223                     if (ih < h_s || ih >= h_e) continue;
224
225                     for (int kw = 0; kw < jcp.kw; ++kw) {
226                         const int iw = ow * jcp.stride_w
227                             - jcp.l_pad + kw * (1 + jcp.dilate_w);
228                         if (iw < w_s || iw >= w_e) continue;
229
230                         const size_t col_idx = (((oh * jcp.ow + ow) * jcp.kh
231                                 + kh) * jcp.kw + kw) * jcp.ic;
232                         const size_t im_idx
233                             = (ih * jcp.iw + iw) * jcp.ic;
234                         PRAGMA_OMP_SIMD()
235                         for (int ic = 0; ic < jcp.ic; ++ic) {
236                             im[im_idx + ic] += col[col_idx + ic];
237                         }
238                     }
239                 }
240             }
241         }
242     });
243 }
244
245 void col2im_3d(jit_gemm_conv_conf_t &jcp, const float *col, float *im, int od) {
246     parallel_nd(jcp.ic, [&](int ic) {
247         const float *col_ = col + (size_t)ic * jcp.ks * jcp.os;
248         float *im_ic = im + (size_t)ic * jcp.ih * jcp.iw * jcp.id;
249
250         int id = od * jcp.stride_d - jcp.f_pad;
251         for (int kd = 0; kd < jcp.kd; ++kd) {
252             if (id < 0 || id >= jcp.id) {
253                 col_ += jcp.kh * jcp.kw * jcp.os;
254                 id += (1 + jcp.dilate_d);
255                 continue;
256             }
257
258             float *im_ = im_ic + id * jcp.ih * jcp.iw;
259
260             for (int oh = 0; oh < jcp.oh; ++oh) {
261             for (int kh = 0; kh < jcp.kh; ++kh) {
262                 const int ih = oh * jcp.stride_h - jcp.t_pad
263                     + kh * (1 + jcp.dilate_h);
264                 if (ih < 0 || ih >= jcp.ih) continue;
265
266                 for (int ow = 0; ow < jcp.ow; ++ow) {
267                 for (int kw = 0; kw < jcp.kw; ++kw) {
268                     const int iw = ow * jcp.stride_w - jcp.l_pad
269                         + kw * (1 + jcp.dilate_w);
270                     if (iw < 0 || iw >= jcp.iw) continue;
271
272                     const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow;
273                     const size_t im_idx = ih*jcp.iw + iw;
274                     im_[im_idx] += col_[col_idx];
275                 }}
276             }}
277
278             col_ += jcp.kh * jcp.kw * jcp.os;
279             id += (1 + jcp.dilate_d);
280         }
281     });
282 }
283
284 void col2im(
285     jit_gemm_conv_conf_t &jcp, const float *col, float *im) {
286
287     const size_t col_step = jcp.ks * jcp.os;
288     const size_t im_step = jcp.ih * jcp.iw;
289     const int iS = jcp.ih * jcp.iw;
290
291     parallel_nd(jcp.ic, [&](int ic) {
292         float *im_ = im + ic * im_step;
293         const float *col_ = col + ic * col_step;
294         PRAGMA_OMP_SIMD()
295         for (int is = 0; is < iS; ++is) im_[is] = 0.;
296
297         for (int kh = 0; kh < jcp.kh; ++kh) {
298         for (int oh = 0; oh < jcp.oh; ++oh) {
299             const int ih = oh * jcp.stride_h - jcp.t_pad + kh * (1 + jcp.dilate_h);
300             if (ih < 0 || ih >= jcp.ih) continue;
301
302             for (int kw = 0; kw < jcp.kw; ++kw) {
303             for (int ow = 0; ow < jcp.ow; ++ow) {
304                 const int iw = ow * jcp.stride_w - jcp.l_pad + kw * (1 + jcp.dilate_w);
305                 if (iw < 0 || iw >= jcp.iw) continue;
306
307                 const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow;
308                 const size_t im_idx = ih*jcp.iw + iw;
309                 im_[im_idx] += col_[col_idx];
310             }
311             }
312         }
313         }
314     });
315 }
316
317 void init_conf(
318     jit_gemm_conv_conf_t &jcp, const convolution_desc_t &cd,
319     const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
320     const memory_desc_wrapper &dst_d, int max_threads,
321     bool with_relu, float relu_negative_slope) {
322
323     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
324     jcp.prop_kind = cd.prop_kind;
325     const int ndims = src_d.ndims();
326
327     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
328     jcp.mb = src_d.dims()[0];
329
330     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
331     jcp.ic = src_d.dims()[1] / jcp.ngroups;
332
333     jcp.id = (ndims == 4) ? 1 : src_d.dims()[2];
334     jcp.ih = src_d.dims()[ndims - 2];
335     jcp.iw = src_d.dims()[ndims - 1];
336     jcp.od = (ndims == 4) ? 1 : dst_d.dims()[2];
337     jcp.oh = dst_d.dims()[ndims - 2];
338     jcp.ow = dst_d.dims()[ndims - 1];
339
340     jcp.kd = (ndims == 4) ? 1 : weights_d.dims()[with_groups + 2];
341     jcp.kh = weights_d.dims()[with_groups + ndims - 2];
342     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
343
344     jcp.f_pad = (ndims == 4) ? 0 : cd.padding[0][0];
345     jcp.t_pad = cd.padding[0][ndims - 4];
346     jcp.l_pad = cd.padding[0][ndims - 3];
347
348     jcp.stride_d = (ndims == 4) ? 1 : cd.strides[0];
349     jcp.stride_h = cd.strides[ndims - 4];
350     jcp.stride_w = cd.strides[ndims - 3];
351
352     jcp.dilate_d = (ndims == 4) ? 0 : cd.dilates[0];
353     jcp.dilate_h = cd.dilates[ndims - 4];
354     jcp.dilate_w = cd.dilates[ndims - 3];
355
356     jcp.src_fmt = src_d.format();
357     jcp.with_bias
358         = cd.bias_desc.format != memory_format::undef
359         || cd.diff_bias_desc.format != memory_format::undef;
360     jcp.with_relu = with_relu;
361     jcp.relu_negative_slope = relu_negative_slope;
362
363     jcp.is = jcp.ih * jcp.iw;
364     jcp.os = jcp.oh * jcp.ow;
365     jcp.ks = jcp.kh * jcp.kw * jcp.kd;
366     jcp.im2col_sz = !(jcp.oh == jcp.ih && jcp.ow == jcp.iw
367                             && jcp.od == jcp.id && jcp.ks == 1)
368         ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os
369         : 0;
370
371     bool do_outer_threading = false;
372     bool is_int8_conv = (cd.src_desc.data_type == u8
373             && cd.weights_desc.data_type == s8);
374     if (is_int8_conv) {
375         bool is_depthwise =
376                 utils::everyone_is(1, jcp.ic, jcp.oc) && jcp.ngroups != 1;
377         do_outer_threading
378                 = (is_depthwise || (jcp.os / max_threads < 64 && jcp.mb != 1));
379     } else {
380         if (utils::one_of(jcp.prop_kind, forward_training, forward_inference))
381             do_outer_threading = jcp.os / max_threads < 512
382                 && utils::implication(jcp.od == 1, (jcp.mb != 1 || jcp.ngroups > 2));
383         else if (jcp.prop_kind == backward_data)
384             do_outer_threading = (jcp.mb != 1 || jcp.ngroups > 2);
385         else //(jcp.prop_kind == backward_weights)
386             do_outer_threading = jcp.os / max_threads < 256
387                        && (jcp.mb != 1 || jcp.ngroups > 2);
388     }
389     jcp.nthr = do_outer_threading ? max_threads : 1;
390     jcp.need_wei_reduction = mkldnn_thr_syncable()
391         ? (jcp.mb != 1 && jcp.nthr != 1) : false;
392 }
393
394 status_t prepare_scratchpad(jit_gemm_conv_conf_t &jcp,
395                 scratchpad_t **scratchpad_, size_t size, const int nthr) {
396     if (size > 0) {
397         *scratchpad_ = create_scratchpad(nthr * size);
398         if (*scratchpad_ == nullptr) return status::out_of_memory;
399     } else {
400         *scratchpad_ = nullptr;
401     }
402     return status::success;
403 }
404
405 void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g,
406         int &nthr_g, int &ithr_mb, int &nthr_mb) {
407     nthr_g = nstl::min(ngroups, nthr);
408     nthr_mb = nstl::min(mb, nthr / nthr_g);
409     if (ithr / nthr_mb >= ngroups) {
410         ithr_g = ithr_mb = -1;
411     } else {
412         ithr_g = ithr / nthr_mb;
413         ithr_mb = ithr % nthr_mb;
414     }
415 }
416
417 void bwd_weights_reduction_par(int ithr, int nthr, const jit_gemm_conv_conf_t &jcp,
418         const float *weights_reduce_ws, float *weights) {
419     const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
420
421     size_t weights_start{0}, weights_end{0};
422     balance211(weights_g_size, nthr, ithr, weights_start, weights_end);
423
424     for (int i = 0; i < nthr; ++i) {
425         const float *ws_i = weights_reduce_ws + i * weights_g_size;
426         for (size_t s = weights_start; s < weights_end; ++s)
427             weights[s] = (i == 0 ? 0 : weights[s]) + ws_i[s];
428     }
429 }
430
431 };
432
433 }
434 }
435 }