Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_sse42_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <cstring>
18 #include "mkldnn_types.h"
19
20 #include "c_types_map.hpp"
21 #include "jit_sse42_convolution.hpp"
22 #include "mkldnn_thread.hpp"
23
24 namespace mkldnn {
25 namespace impl {
26 namespace cpu {
27
28 using namespace mkldnn::impl::status;
29 using namespace mkldnn::impl::memory_format;
30 using namespace mkldnn::impl::memory_tracking::names;
31 using namespace mkldnn::impl::utils;
32
// Offset of element (n, c, h, w) in the src/dst memory descriptor `f`.
// For 1D convolutions (ndims == 3) the `h` coordinate is dropped.
// The whole expansion is parenthesized so the macro can safely be used as an
// operand of a larger expression (e.g. `base + src_blk_off(...)`); without
// the outer parentheses the conditional operator would bind to the
// surrounding expression instead.
#define src_blk_off(f, n, c, h, w) \
    ((pd()->ndims() == 3) \
    ? (f).blk_off(n, c, w) \
    : (f).blk_off(n, c, h, w))

// Weights offset helper: prepends the group index `g` when the primitive has
// groups, otherwise forwards only the remaining coordinates.
#define wht_blk_off_(f, g, ...) \
    (pd()->with_groups() \
    ? (f).blk_off(g, __VA_ARGS__) \
    : (f).blk_off(__VA_ARGS__))
// Offset of weights element (g, oc, ic, kh, kw) in descriptor `f`; the `kh`
// coordinate is dropped for 1D convolutions (ndims == 3).
#define wht_blk_off(f, g, oc, ic, kh, kw) \
    ((pd()->ndims() == 3) \
    ? wht_blk_off_(f, g, oc, ic, kw) \
    : wht_blk_off_(f, g, oc, ic, kh, kw))
46
47 void jit_sse42_convolution_fwd_t::execute_forward() const {
48     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
49     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
50     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
51     auto dst = reinterpret_cast<data_t *>(this->memory());
52
53     const memory_desc_wrapper src_d(pd()->src_pd());
54     const memory_desc_wrapper dst_d(pd()->dst_pd());
55     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
56     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
57
58     const auto &jcp = kernel_->jcp;
59     int MB = pd()->MB();
60
61     int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
62     const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.oh;
63
64     if (pd()->wants_padded_bias()) {
65         auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
66         utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
67         utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
68                 jcp.oc - jcp.oc_without_padding);
69         bias = padded_bias;
70     }
71
72     parallel(0, [&](const int ithr, const int nthr) {
73         size_t start{ 0 }, end{ 0 };
74         balance211(work_amount, nthr, ithr, start, end);
75
76         int icbb = 0;
77         while (icbb < jcp.nb_ic) {
78             int icb_step = jcp.nb_ic_blocking;
79             int icb_step_rem = jcp.nb_ic - icbb;
80             if (icb_step_rem < jcp.nb_ic_blocking_max)
81                 icb_step = icb_step_rem;
82
83             size_t n{0}, g{0}, ocbb{0}, oh{0};
84             nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
85                              oh, jcp.oh);
86             for (size_t iwork = start; iwork < end; ++iwork) {
87                 int ocb = ocbb * jcp.nb_oc_blocking;
88                 int ocb_num = jcp.nb_oc_blocking;
89
90                 for (int icb = icbb; icb < icbb + icb_step; ++icb) {
91                     auto par_conv = jit_conv_call_s();
92
93                     const int ij = oh * jcp.stride_h;
94                     const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
95                     const int i_b_overflow = nstl::max(jcp.ih, ij
96                         + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih;
97
98                     const size_t _oc = g * jcp.nb_oc + ocb;
99                     const size_t _ic = g * jcp.nb_ic + icb;
100
101                     const int ih = nstl::max(ij - jcp.t_pad
102                         + div_up(i_t_overflow,
103                                  (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0);
104                     par_conv.src = &src[src_blk_off(src_d, n,
105                         jcp.ic == 3 ? 0 : _ic, ih, 0)];
106
107                     par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, oh, 0)];
108
109                     const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
110                     par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb,
111                         jcp.ic == 3 ? 0 : icb, wh, 0)];
112
113                     if (icb == 0) {
114                         if (bias)
115                             par_conv.bias =
116                                     &bias[bias_d.blk_off(_oc * jcp.oc_block)];
117                         par_conv.flags |= FLAG_IC_FIRST;
118                     }
119
120                     if (icb + 1 == jcp.nb_ic) {
121                         par_conv.flags |= FLAG_IC_LAST;
122                     }
123
124                     par_conv.oc_blocks =
125                             nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;
126
127                     par_conv.kw_padding = 0;
128                     const int kh_padding = jcp.kh
129                         - div_up(i_t_overflow, (jcp.dilate_h + 1))
130                         - div_up(i_b_overflow, (jcp.dilate_h + 1));
131                     par_conv.kh_padding = nstl::max(0, kh_padding);
132
133                     par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);
134
135                     kernel_->jit_ker(&par_conv);
136                 }
137                 nd_iterator_step(n, MB, g, jcp.ngroups, ocbb, ocb_work,
138                                  oh, jcp.oh);
139             }
140             icbb += icb_step;
141         }
142     });
143
144     if (pd()->wants_zero_pad_dst())
145         output_memory_primitive(0)->zero_pad();
146 }
147
148 void jit_sse42_convolution_fwd_t::execute_forward_with_dw_conv() const {
149     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
150     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
151     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
152     auto dst = reinterpret_cast<data_t *>(this->memory());
153
154     const memory_desc_wrapper src_d(pd()->src_pd());
155     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
156     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
157
158     const auto &jcp = kernel_->jcp;
159     const auto &jcp_dw = kernel_dw_->jcp;
160     int MB = pd()->MB();
161
162     auto dw_bias = jcp_dw.conv_biases;
163
164     int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
165     const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.oh;
166
167     auto ker = [&](const int ithr, const int nthr) {
168         auto compute_row_gen = [&](float* ws_p, int n, int g, int ocb, int ocb_num, int oh, int num_rows) {
169             for (int h = 0; h < num_rows; h++) {
170                 if ((oh + h) < 0 || (oh + h) >= jcp.oh) {
171                     for (int chb = ocb; chb < ocb + ocb_num; chb++) {
172                         memset(ws_p + (((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block +
173                                (chb - ocb) * jcp_dw.kh * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));
174                     }
175                 } else {
176                     for (int icb = 0; icb < jcp.nb_ic; ++icb) {
177                         auto par_conv = jit_conv_call_s();
178
179                         const int ij = (oh + h) * jcp.stride_h;
180                         const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
181                         const int i_b_overflow = nstl::max(jcp.ih, ij
182                                                                    + (jcp.kh - 1) * (jcp.dilate_h + 1) - jcp.t_pad +
183                                                                    1) - jcp.ih;
184
185                         const size_t _oc = g * jcp.nb_oc + ocb;
186                         const size_t _ic = g * jcp.nb_ic + icb;
187
188                         const int ih = nstl::max(ij - jcp.t_pad
189                                                  + div_up(i_t_overflow,
190                                                           (jcp.dilate_h + 1)) * (jcp.dilate_h + 1), 0);
191                         par_conv.src = &src[src_d.blk_off(n,
192                                                           jcp.ic == 3 ? 0 : _ic, ih, 0)];
193
194                         par_conv.dst = &ws_p[(((oh + h) + 1) % jcp_dw.kh) * jcp.ow *
195                                              jcp.oc_block];
196
197                         const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
198                         par_conv.filt = &weights[pd()->with_groups()
199                                                  ? weights_d.blk_off(g, ocb,
200                                                                      jcp.ic == 3 ? 0 : icb, wh, 0)
201                                                  : weights_d.blk_off(ocb,
202                                                                      jcp.ic == 3 ? 0 : icb, wh, 0)];
203
204                         if (icb == 0) {
205                             if (bias)
206                                 par_conv.bias =
207                                         &bias[bias_d.blk_off(_oc * jcp.oc_block)];
208                             par_conv.flags |= FLAG_IC_FIRST;
209                         }
210
211                         if (icb + 1 == jcp.nb_ic) {
212                             par_conv.flags |= FLAG_IC_LAST;
213                         }
214
215                         par_conv.oc_blocks =
216                                 nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;
217
218                         par_conv.kw_padding = 0;
219                         const int kh_padding = jcp.kh
220                                                - div_up(i_t_overflow, (jcp.dilate_h + 1))
221                                                - div_up(i_b_overflow, (jcp.dilate_h + 1));
222                         par_conv.kh_padding = nstl::max(0, kh_padding);
223
224                         par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);
225
226                         kernel_->jit_ker(&par_conv);
227                     }
228                 }
229             }
230         };
231
232         auto compute_row_dw = [&](const float* ws_p, int n, int ocb, int ocb_num,
233                                   int dst_idx) {
234             for (int chb = ocb; chb < nstl::min(ocb + ocb_num, jcp.nb_oc); chb++) {
235                 auto par_conv_dw = jit_conv_call_s();
236
237                 par_conv_dw.src_row0 = &ws_p[(((dst_idx+1) - 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
238                                              (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
239                 par_conv_dw.src_row1 = &ws_p[(((dst_idx+1) - 0) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
240                                              (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
241                 par_conv_dw.src_row2 = &ws_p[(((dst_idx+1) + 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
242                                              (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
243
244                 par_conv_dw.dst = &dst[n*jcp_dw.oc*jcp_dw.oh*jcp_dw.ow + chb*jcp_dw.ch_block*jcp_dw.oh*jcp_dw.ow +
245                                        dst_idx/jcp_dw.stride_h*jcp_dw.ow*jcp_dw.ch_block];
246
247                 par_conv_dw.kh_padding = jcp_dw.kh;
248                 par_conv_dw.filt = &jcp_dw.conv_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
249                 par_conv_dw.bias = &dw_bias[chb * jcp_dw.ch_block];
250                 par_conv_dw.ur_w = (size_t)(jcp_dw.ow);
251                 par_conv_dw.oc_work = nstl::min((chb + 1) * jcp_dw.ch_block, (int)jcp_dw.oc) - chb*jcp_dw.ch_block;
252                 par_conv_dw.oc_off = chb * jcp_dw.ch_block * sizeof(float);
253
254                 kernel_dw_->jit_ker(&par_conv_dw);
255             }
256         };
257
258         size_t start{0}, end{0};
259         balance211(work_amount, nthr, ithr, start, end);
260
261         auto dw_conv_buffer = scratchpad().get<data_t>(key_dw_conv_buffer);
262         size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking;
263         auto pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
264
265         size_t n{0}, g{0}, ocbb{0}, oh{0};
266         nd_iterator_init(start, n, MB, g, jcp.ngroups, ocbb, ocb_work,
267                          oh, jcp.oh);
268         for (size_t iwork = start; iwork < end; ++iwork) {
269             int ocb = ocbb * jcp.nb_oc_blocking;
270             int ocb_num = jcp.nb_oc_blocking;
271
272             if (iwork == start || oh == 0) {
273                 compute_row_gen(pbuf, n, g, ocb, ocb_num, oh - 1, 2);
274             } else {
275                 compute_row_gen(pbuf, n, g, ocb, ocb_num, oh, 1);
276             }
277
278             if (iwork > start && ((oh - 1) % jcp_dw.stride_h == 0) && oh > 0) {
279                 compute_row_dw(pbuf, n, ocb, ocb_num, oh - 1);
280             }
281
282             if ((iwork == end - 1 || (int) oh == jcp.oh - 1) && ((oh) % jcp_dw.stride_h == 0)) {
283                 compute_row_gen(pbuf, n, g, ocb, ocb_num, oh + 1, 1);
284                 compute_row_dw(pbuf, n, ocb, ocb_num, oh);
285             }
286
287             nd_iterator_step(n, MB, g, jcp.ngroups, ocbb, ocb_work,
288                              oh, jcp.oh);
289         }
290     };
291
292     if (pd()->wants_padded_bias()) {
293         auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
294         utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
295         utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
296                 jcp.oc - jcp.oc_without_padding);
297         bias = padded_bias;
298
299         auto dw_padded_bias = scratchpad().get<data_t>(key_dw_conv_padded_bias);
300         utils::array_copy(dw_padded_bias, dw_bias, jcp.oc_without_padding);
301         utils::array_set(dw_padded_bias + jcp.oc_without_padding, 0.f,
302                          jcp.oc - jcp.oc_without_padding);
303         dw_bias = dw_padded_bias;
304     }
305
306     parallel(0, ker);
307
308     if (pd()->wants_zero_pad_dst())
309         output_memory_primitive(0)->zero_pad();
310 }
311
312 }
313 }
314 }