Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_types.h"
18
19 #include "c_types_map.hpp"
20 #include "gemm_convolution.hpp"
21 #include "utils.hpp"
22 #include "type_helpers.hpp"
23 #include "mkldnn_thread.hpp"
24 #include "ref_eltwise.hpp"
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
30 using namespace mkldnn::impl::status;
31 using namespace mkldnn::impl::memory_format;
32 using namespace mkldnn::impl::memory_tracking::names;
33 using namespace mkldnn::impl::utils;
34
void gemm_convolution_fwd_t::execute_forward() const {
    // Forward convolution as (optional) im2col followed by a single SGEMM per
    // work item, with bias and post-ops (eltwise / depthwise) applied in place
    // on the destination. Work items are (group, batch, od, oh-block, ow-block).
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
    auto dst = reinterpret_cast<data_t*>(this->memory());

    // Per-thread im2col buffer; each thread gets a jcp.im2col_sz slice.
    auto col = scratchpad().get<data_t>(key_conv_gemm_col);

    const auto &jcp = this->pd()->jcp_;
    const int MB = pd()->MB();

    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper dst_d(pd()->dst_pd());

    // GEMM sizes: M = full spatial output volume, used as the dst leading
    // dimension; per-call row count `m` below may be a smaller oh/ow block.
    const int M = jcp.os * jcp.od;
    // Stride between consecutive images, divided by group count to get the
    // per-(image, group) step; off_l(0) is subtracted to drop the base offset.
    const size_t src_step = (src_d.blk_off(1) - src_d.off_l(0)) / jcp.ngroups;
    const size_t dst_step = (dst_d.blk_off(1) - dst_d.off_l(0)) / jcp.ngroups;
    const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
    src += src_d.off_l(0);
    dst += dst_d.off_l(0);

    // Spatial blocking is only supported for the 2D case (id == 1), and
    // ow-blocking implies single-row oh blocks.
    assert(IMPLICATION(
            jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow));
    assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));

    const int K = jcp.ic * jcp.ks;
    const int N = jcp.oc;

    // im2col_3d accumulates into the col buffer, so it must start zeroed.
    if (jcp.im2col_sz && jcp.id != 1)
        parallel_nd(jcp.im2col_sz * jcp.nthr,
                [&](ptrdiff_t i) { col[i] = (data_t)0; });

    const int nb_oh = div_up(jcp.oh, jcp.oh_block);
    const int nb_ow = div_up(jcp.ow, jcp.ow_block);
    const size_t work_amount = jcp.ngroups * MB * jcp.od * nb_oh * nb_ow;
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;

        int g{ 0 }, n{ 0 }, od{ 0 }, ohb{ 0 }, owb{ 0 };
        size_t start = 0, end = 0;

        // Statically partition the 5D iteration space across threads.
        balance211(work_amount, nthr, ithr, start, end);
        nd_iterator_init(start, g, jcp.ngroups, n, MB, od, jcp.od, ohb,
                nb_oh, owb, nb_ow);
        for (size_t iwork = start; iwork < end; ++iwork) {
            int oh = ohb * jcp.oh_block;
            int ow = owb * jcp.ow_block;
            const data_t *_src = src + (n * jcp.ngroups + g) * src_step;
            const data_t *_weights = weights + g * weights_g_size;
            data_t *_dst_im = dst + (n * jcp.ngroups + g) * dst_step;
            // Clamp the last block against the tensor edge.
            const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
            const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
            if (jcp.im2col_sz) {
                if (jcp.id == 1)
                    jit_gemm_convolution_utils::im2col(
                            jcp, _src, _col, oh, h_step, ow, w_step);
                else
                    jit_gemm_convolution_utils::im2col_3d(jcp, _src, _col, od);
            }

            const data_t one = 1.0;

            // m = rows of this GEMM call (current spatial block). When im2col
            // is skipped (1x1 case), the source itself is the A matrix with
            // leading dimension M.
            const int m = h_step * w_step;
            const int LDA = jcp.im2col_sz ? m : M;
            data_t *_dst = _dst_im + od * jcp.os + oh * jcp.ow + ow;

            // C(m x N) += A(m x K) * B(K x N); beta_ selects accumulate vs
            // overwrite (e.g. for sum post-op).
            extended_sgemm("N", "N", &m, &N, &K, &one,
                    jcp.im2col_sz ? _col : _src + od * m, &LDA, _weights, &K,
                    &this->beta_, _dst, &M);

            data_t *d = _dst;
            const auto &p = pd()->attr()->post_ops_;
            // Bias is folded into the first post-op pass when one exists;
            // need_bias tracks whether it has been applied yet.
            bool need_bias = jcp.with_bias;
            if (use_fast_relu) {
                // Fused bias + leaky-ReLU fast path.
                parallel_nd(jcp.oc, [&](const int oc) {
                    data_t b = need_bias ? bias[g * jcp.oc + oc] : 0;
                    data_t *d_ = d + oc * M;
                    PRAGMA_OMP_SIMD()
                    for (int oS = 0; oS < m; ++oS) {
                        d_[oS] += b;
                        if (d_[oS] < 0) d_[oS] *= fast_relu_ns;
                    }
                });

                need_bias = false;
            } else if (p.len_ > 0) {
                int eltwise_inj_idx = 0;
                int depthwise_inj_idx = 0;

                // Apply the post-op chain in order; each injector index
                // advances with its own post-op kind.
                for (int i = 0; i < p.len_; i++) {
                    auto& post_op = p.entry_[i];
                    if (post_op.is_eltwise()) {
                        parallel_nd(jcp.oc, [&](const int oc) {
                            data_t b = need_bias ? bias[g * jcp.oc + oc] : 0;
                            data_t *d_ = d + oc * M;
                            PRAGMA_OMP_SIMD()
                            for (int oS = 0; oS < m; ++oS) {
                                d_[oS] += b;
                                d_[oS] = eltwise_injectors[eltwise_inj_idx]->compute_scalar(d_[oS]);
                            }
                        });

                        eltwise_inj_idx++;
                        need_bias = false;
                    } else if (post_op.is_depthwise()) {
                        auto depthwise_weights = post_op.depthwise.weights_data;
                        auto depthwise_bias = post_op.depthwise.biases_data;

                        parallel_nd(jcp.oc, [&](const int oc) {
                            data_t b = need_bias ? bias[g * jcp.oc + oc] : 0;
                            data_t *d_ = d + oc * M;
                            PRAGMA_OMP_SIMD()
                            for (int oS = 0; oS < m; ++oS) {
                                d_[oS] += b;
                                d_[oS] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(d_[oS],
                                                                  depthwise_weights + g * jcp.oc + oc,
                                                                  depthwise_bias + g * jcp.oc + oc);
                            }
                        });

                        depthwise_inj_idx++;
                        need_bias = false;
                    }
                }
            }

            // No post-op consumed the bias: add it in a standalone pass.
            if (need_bias) {
                parallel_nd(jcp.oc, [&](const int oc) {
                    data_t b = bias[g * jcp.oc + oc];
                    data_t *d_ = d + oc * M;
                    PRAGMA_OMP_SIMD()
                    for (int oS = 0; oS < m; ++oS) {
                        d_[oS] += b;
                    }
                });
            }

            nd_iterator_step(g, jcp.ngroups, n, MB, od, jcp.od, ohb, nb_oh,
                    owb, nb_ow);
        }
    });
}
177
178 void gemm_convolution_bwd_data_t::execute_backward_data() const {
179     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
180     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
181     auto diff_src = reinterpret_cast<data_t*>(this->memory());
182
183     auto col = scratchpad().get<data_t>(key_conv_gemm_col);
184
185     const auto &jcp = this->pd()->jcp_;
186     const int MB = pd()->MB();
187
188     const int M = jcp.os * jcp.od;
189     const size_t src_step_to_clean = jcp.ic * jcp.ih * jcp.iw * jcp.id;
190     const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
191     const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
192     const size_t src_step = diff_src_d.blk_off(1) / jcp.ngroups;
193     const size_t dst_step = diff_dst_d.blk_off(1) / jcp.ngroups;
194     const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
195
196     const int m = jcp.os;
197     const int K = jcp.oc;
198     const int N = jcp.ic * jcp.ks;
199     const int LDC = jcp.im2col_sz ? m : M;
200
201     const size_t work_amount = (size_t)jcp.ngroups * MB;
202
203     if (jcp.id > 1) {
204         for (size_t j = 0; j < work_amount; j++) {
205             int j_step = src_step * j;
206             const ptrdiff_t diff_src_sz = (ptrdiff_t)(src_step_to_clean);
207             parallel_nd(diff_src_sz, [&](ptrdiff_t i) { diff_src[j_step + i] = (data_t)0; });
208         }
209     }
210
211     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
212         data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
213
214         int g{0}, n{0};
215         size_t start = 0, end = 0;
216         balance211(work_amount, nthr, ithr, start, end);
217         nd_iterator_init(start, g, jcp.ngroups, n, MB);
218         for (size_t iwork = start; iwork < end; ++iwork) {
219
220             data_t *_diff_src = diff_src + (n * jcp.ngroups + g) * src_step;
221             const data_t *_weights = weights + g * weights_g_size;
222             for (int od = 0; od < jcp.od; ++od) {
223                 const data_t *_diff_dst = diff_dst + (n * jcp.ngroups + g)
224                     *dst_step + od * m;
225
226                 const data_t zero = 0.0, one = 1.0;
227                 extended_sgemm("N", "T", &m, &N, &K, &one, _diff_dst, &M,
228                     _weights, &N, &zero,
229                     jcp.im2col_sz ? _col:_diff_src + od * m, &LDC);
230
231                 if (jcp.im2col_sz) {
232                     if (jcp.id == 1)
233                         jit_gemm_convolution_utils::col2im(jcp, _col,
234                             _diff_src);
235                     else
236                         jit_gemm_convolution_utils::col2im_3d(jcp, _col,
237                             _diff_src, od);
238                 }
239             }
240             nd_iterator_step(g, jcp.ngroups, n, MB);
241         }
242     });
243 }
244
void gemm_convolution_bwd_weights_t::execute_backward_weights() const {
    // Backward-by-weights convolution: per (group, image, od) an im2col +
    // SGEMM accumulates into diff_weights, with an optional cross-thread
    // reduction when the minibatch is split over threads. Bias gradient is
    // a separate reduction over the output spatial volume.
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto diff_weights = reinterpret_cast<data_t*>(this->memory(0));
    auto diff_bias = reinterpret_cast<data_t *>(this->memory(1));

    // Per-thread im2col buffer and per-thread partial-weights buffers.
    auto col = scratchpad().get<data_t>(key_conv_gemm_col);
    auto wei_reduction = scratchpad().get<data_t>(key_conv_wei_reduction);

    const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;

    // GEMM sizes: C(M x N) += A(M x k)^T-style product per od slice; K is
    // the full spatial leading dimension of diff_dst.
    const int K = jcp.os * jcp.od;
    const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
    const size_t dst_step = jcp.oc * K;
    const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;

    const int k = jcp.os;
    const int N = jcp.oc;
    const int M = jcp.ic * jcp.ks;
    const int LDA = jcp.im2col_sz ? k : K;

    // im2col_3d accumulates into the col buffer, so it must start zeroed.
    parallel_nd(jcp.im2col_sz * jcp.nthr,
            [&](ptrdiff_t i) { col[i] = (data_t)0; });

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int ithr_g, nthr_g, ithr_mb, nthr_mb;
        size_t g_start{0}, g_end{0}, mb_start{0}, mb_end{0};

        // Split threads across groups and (optionally) the minibatch; a
        // minibatch split requires a reduction of per-thread partials.
        const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
        jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups,
                mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb);

        assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
        const int need_reduction = nthr_mb != 1;

        if (ithr_g != -1 && ithr_mb != -1) {
            balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
            balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);

            // A thread handling multiple groups must own whole minibatches.
            assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

            data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
            data_t *weights_reduce_base = wei_reduction
                    + ithr_g * nthr_mb * weights_g_size;
            data_t *weights_reduce = weights_reduce_base
                    + ithr_mb * weights_g_size;

            for (size_t g = g_start; g < g_end; ++g) {
                // With reduction, accumulate into this thread's private
                // buffer; otherwise write diff_weights directly.
                data_t *_diff_weights = need_reduction
                        ? weights_reduce : (diff_weights + g * weights_g_size);
                for (size_t mb = mb_start; mb < mb_end; ++mb) {
                    const data_t *_src = src + (mb*jcp.ngroups+g)*src_step;
                    for (int od = 0; od < jcp.od; ++od) {
                    const data_t *_diff_dst = diff_dst
                            + (mb*jcp.ngroups+g)*dst_step + od * k;

                    if (jcp.im2col_sz) {
                        if (jcp.id == 1)
                            jit_gemm_convolution_utils::im2col(
                                    jcp, _src, _col, 0, jcp.oh, 0, jcp.ow);
                        else
                            jit_gemm_convolution_utils::im2col_3d(jcp, _src,
                                _col, od);
                    }

                    // beta = 0 only on the very first (mb, od) slice so the
                    // output is overwritten once, then accumulated into.
                    const data_t zero = 0.0, one = 1.0;
                    extended_sgemm(
                        "T", "N", &M, &N, &k, &one,
                        jcp.im2col_sz ? _col : _src + od * k,
                        &LDA, _diff_dst, &K,
                        mb == mb_start && od == 0 ? &zero : &one,
                        _diff_weights, &M);
                    }
                }
            }
            if (need_reduction) {
                // All minibatch-partials must be complete before reducing.
                mkldnn_thr_barrier();
                data_t *weights_base = diff_weights + g_start * weights_g_size;
                jit_gemm_convolution_utils::bwd_weights_reduction_par(
                    ithr_mb, nthr_mb, jcp, weights_reduce_base, weights_base);
            }
        } else
            // Idle threads must still hit the barrier to avoid deadlock.
            if (need_reduction) { mkldnn_thr_barrier(); }
    });

    if (jcp.with_bias) {
        // diff_bias[g, oc] = sum over (mb, od, oh, ow) of diff_dst.
        parallel_nd(jcp.ngroups, jcp.oc, [&](int g, int oc) {
            data_t db = 0;
            size_t offset_ = (size_t)g * dst_step + (size_t)oc * K;
            for (int mb = 0; mb < jcp.mb; ++mb)
            {
                size_t offset = offset_ + (size_t)mb * jcp.ngroups * dst_step;
                for (int od = 0; od < jcp.od; ++od)
                for (int oh = 0; oh < jcp.oh; ++oh)
                PRAGMA_OMP_SIMD(reduction(+:db))
                for (int ow = 0; ow < jcp.ow; ++ow) {
                    db += diff_dst[offset];
                    offset++;
                }
            }
            diff_bias[g*jcp.oc+oc] = db;
        });
    }
}
349
350 }
351 }
352 }