Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_x8s8s32x_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2018-2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_types.h"
18 #include "c_types_map.hpp"
19 #include "jit_uni_x8s8s32x_convolution.hpp"
20 #include "utils.hpp"
21 #include "mkldnn_thread.hpp"
22 #include "type_helpers.hpp"
23 #include <cstring>
24
25 namespace mkldnn {
26 namespace impl {
27 namespace cpu {
28
29 using namespace mkldnn::impl::status;
30 using namespace mkldnn::impl::memory_format;
31 using namespace mkldnn::impl::memory_tracking::names;
32 using namespace mkldnn::impl::utils;
33
34 template <cpu_isa_t isa, impl::data_type_t src_type, data_type_t dst_type>
35 void _jit_uni_x8s8s32x_convolution_fwd_t<isa, src_type, dst_type>::execute_forward() const {
36     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
37     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
38     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
39     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
40
41     const memory_desc_wrapper src_d(pd()->src_pd());
42     const memory_desc_wrapper dst_d(pd()->dst_pd());
43     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
44     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
45
46     const auto &jcp = kernel_->jcp;
47
48     size_t offset = (size_t)jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block) * rnd_up(jcp.ic, jcp.ic_block) * jcp.kh * jcp.kw;
49     auto w = const_cast<wei_data_t *>(weights);
50     int32_t* compensation = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
51
52     if (bias && jcp.oc != jcp.oc_padded) {
53         auto padded_bias = this->scratchpad().template get<bia_data_t>(key_conv_padded_bias);
54         utils::array_copy(padded_bias, (bia_data_t*)bias, jcp.oc);
55         utils::array_set(padded_bias + jcp.oc, 0, jcp.oc_padded - jcp.oc);
56         bias = (char *)padded_bias;
57     }
58
59     const float *oscales = pd()->attr()->output_scales_.scales_;
60     if (jcp.signed_input) {
61         auto local_scales = scratchpad().template get<float>(key_conv_adjusted_scales);
62         size_t count = pd()->attr()->output_scales_.count_;
63         float factor = 1.f / jcp.wei_adj_scale;
64         if (count == 1) {
65             utils::array_set(local_scales, oscales[0] * factor, 8);
66         } else {
67             for (size_t c = 0; c < count; c++)
68                 local_scales[c] = oscales[c] * factor;
69         }
70         oscales = local_scales;
71
72         if (jcp.oc != jcp.oc_padded) {
73             auto padded_compensation = this->scratchpad().template get<int32_t>(key_conv_padded_compensation);
74             utils::array_copy(padded_compensation, compensation, jcp.oc);
75             utils::array_set(padded_compensation + jcp.oc, 0, jcp.oc_padded - jcp.oc);
76             compensation = padded_compensation;
77         }
78     }
79
80     int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
81     const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.oh;
82
83     auto ker = [&](const int ithr, const int nthr) {
84         size_t start{0}, end{0};
85         balance211(work_amount, nthr, ithr, start, end);
86
87         size_t n{0}, g{0}, ocbb{0}, oh{0};
88         nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
89                          oh, jcp.oh);
90         for (size_t iwork = start; iwork < end; ++iwork) {
91             int ocb = ocbb * jcp.nb_oc_blocking;
92             int ocb_num = jcp.nb_oc_blocking;
93
94             auto par_conv = jit_conv_call_s();
95
96             const int ij = oh * jcp.stride_h;
97             const int i_t_overflow = nstl::min(jcp.kh, div_up(nstl::max(0, jcp.t_pad - ij), (jcp.dilate_h+1)));
98             const int i_b_overflow = nstl::min(jcp.kh, div_up(nstl::max(jcp.ih, ij + (jcp.kh-1) * (jcp.dilate_h+1) -
99                                                jcp.t_pad+1) - jcp.ih, (jcp.dilate_h + 1)));
100
101             const size_t _oc = g * jcp.nb_oc + ocb;
102             const size_t _ic = g * jcp.nb_ic;
103
104             const int ih = nstl::max(ij - jcp.t_pad + i_t_overflow * (jcp.dilate_h + 1), 0);
105             par_conv.src = &src[src_d.blk_off(n, _ic*jcp.ic_block, ih, 0)];
106
107             size_t dst_off = dst_d.blk_off(n, _oc*jcp.oc_block, oh, 0);
108             par_conv.dst = &dst[dst_off];
109
110             const int wh = (!jcp.signed_input) ? i_t_overflow : 0;
111             par_conv.filt = &weights[pd()->with_groups()
112                                 ? weights_d.blk_off(g, ocb, 0, wh, 0)
113                                 : weights_d.blk_off(ocb, 0, wh, 0)];
114
115             if (bias)
116                 par_conv.bias = &bias[bias_d.blk_off(_oc * jcp.oc_block*jcp.typesize_bia)];
117
118             par_conv.oc_work =
119                     nstl::min((ocb + ocb_num) * jcp.oc_block, jcp.oc) - ocb*jcp.oc_block;
120
121             par_conv.kw_padding = 0;
122             const int kh_padding = jcp.kh - i_t_overflow - i_b_overflow;
123             par_conv.kh_padding = nstl::max(0, kh_padding);
124
125             par_conv.scales = &oscales[jcp.is_oc_scale * _oc * jcp.oc_block];
126
127             par_conv.compensation = (jcp.signed_input) ? compensation + _oc * jcp.oc_block : 0;
128             par_conv.t_overflow = i_t_overflow;
129             par_conv.b_overflow = i_b_overflow;
130
131             par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);
132
133             kernel_->jit_ker(&par_conv);
134             nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh);
135         }
136     };
137
138     parallel(0, ker);
139 }
140
141 template <cpu_isa_t isa, impl::data_type_t src_type, data_type_t dst_type>
142 void _jit_uni_x8s8s32x_convolution_fwd_t<isa, src_type, dst_type>::execute_forward_with_dw_conv() const {
143     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
144     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
145     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
146     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
147
148     const memory_desc_wrapper src_d(pd()->src_pd());
149     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
150     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
151
152     const auto &jcp = kernel_->jcp;
153     const auto &jcp_dw = kernel_dw_->jcp;
154     const int MB = pd()->MB();
155
156     size_t offset = (size_t)jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block) * rnd_up(jcp.ic, jcp.ic_block) * jcp.kh * jcp.kw;
157     auto w = const_cast<wei_data_t *>(weights);
158     int32_t* compensation = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
159
160     auto dw_bias = jcp_dw.conv_biases;
161     auto dw_weights = reinterpret_cast<const wei_data_t *>(jcp_dw.conv_weights);
162
163     if (jcp.oc != jcp.oc_padded) {
164         auto padded_bias = this->scratchpad().template get<bia_data_t>(key_conv_padded_bias);
165         utils::array_copy(padded_bias, (bia_data_t*)bias, jcp.oc);
166         utils::array_set(padded_bias + jcp.oc, 0, jcp.oc_padded - jcp.oc);
167         bias = (char *)padded_bias;
168
169         auto dw_padded_bias = this->scratchpad().template get<bia_data_t>(key_dw_conv_padded_bias);
170         utils::array_copy(dw_padded_bias, dw_bias, jcp.oc);
171         utils::array_set(dw_padded_bias + jcp.oc, 0.f, jcp.oc_padded - jcp.oc);
172         dw_bias = dw_padded_bias;
173     }
174
175     const float *oscales = pd()->attr()->output_scales_.scales_;
176     if (jcp.signed_input) {
177         auto local_scales = scratchpad().template get<float>(key_conv_adjusted_scales);
178         size_t count = pd()->attr()->output_scales_.count_;
179         float factor = 1.f / jcp.wei_adj_scale;
180         if (count == 1) {
181             utils::array_set(local_scales, oscales[0] * factor, 8);
182         } else {
183             for (size_t c = 0; c < count; c++)
184                 local_scales[c] = oscales[c] * factor;
185         }
186         oscales = local_scales;
187
188         if (jcp.oc != jcp.oc_padded) {
189             auto padded_compensation = this->scratchpad().template get<int32_t>(key_conv_padded_compensation);
190             utils::array_copy(padded_compensation, compensation, jcp.oc);
191             utils::array_set(padded_compensation + jcp.oc, 0, jcp.oc_padded - jcp.oc);
192             compensation = padded_compensation;
193         }
194     }
195
196     int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
197     const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.oh;
198
199     auto ker = [&](const int ithr, const int nthr) {
200         auto compute_row_gen = [&](dst_data_t* ws_p, int n, int g, int ocb, int ocb_num, int oh, int num_rows) {
201             for (int h = 0; h < num_rows; h++) {
202                 if ((oh + h) < 0 || (oh + h) >= jcp.oh) {
203                     for (int chb = ocb; chb < ocb + ocb_num; chb++) {
204                         memset(ws_p + (((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block +
205                                (chb - ocb) * jcp_dw.kh * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(dst_data_t));
206                     }
207                 } else {
208                     auto par_conv = jit_conv_call_s();
209
210                     const int ij = (oh + h) * jcp.stride_h;
211                     const int i_t_overflow = nstl::min(jcp.kh, div_up(nstl::max(0, jcp.t_pad - ij), (jcp.dilate_h+1)));
212                     const int i_b_overflow = nstl::min(jcp.kh, div_up(nstl::max(jcp.ih, ij + (jcp.kh-1) * (jcp.dilate_h+1) -
213                                                        jcp.t_pad+1) - jcp.ih, (jcp.dilate_h + 1)));
214
215                     const size_t _oc = g * jcp.nb_oc + ocb;
216                     const size_t _ic = g * jcp.nb_ic;
217
218                     const int ih = nstl::max(ij - jcp.t_pad + i_t_overflow * (jcp.dilate_h + 1), 0);
219                     par_conv.src = &src[src_d.blk_off(n, _ic*jcp.ic_block, ih, 0)];
220
221                     par_conv.dst = &ws_p[(((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block];
222
223                     const int wh = (!jcp.signed_input) ? i_t_overflow : 0;
224                     par_conv.filt = &weights[pd()->with_groups()
225                                         ? weights_d.blk_off(g, ocb, 0, wh, 0)
226                                         : weights_d.blk_off(ocb, 0, wh, 0)];
227
228                     if (bias)
229                         par_conv.bias = &bias[bias_d.blk_off(_oc * jcp.oc_block*jcp.typesize_bia)];
230
231                     par_conv.oc_work =
232                             nstl::min((ocb + ocb_num) * jcp.oc_block, jcp.oc) - ocb*jcp.oc_block;
233
234                     par_conv.kw_padding = 0;
235                     const int kh_padding = jcp.kh - i_t_overflow - i_b_overflow;
236                     par_conv.kh_padding = nstl::max(0, kh_padding);
237
238                     par_conv.scales = &oscales[jcp.is_oc_scale * _oc * jcp.oc_block];
239                     par_conv.compensation = (jcp.signed_input) ? compensation + _oc * jcp.oc_block : 0;
240                     par_conv.t_overflow = i_t_overflow;
241                     par_conv.b_overflow = i_b_overflow;
242
243                     par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);
244
245                     kernel_->jit_ker(&par_conv);
246                 }
247             }
248         };
249
250         auto compute_row_dw = [&](const dst_data_t* ws_p, int n, int ocb, int ocb_num, int dst_idx) {
251             for (int chb = ocb; chb < nstl::min(ocb + ocb_num, jcp.nb_oc); chb++) {
252                 auto par_conv_dw = jit_conv_call_s();
253
254                 par_conv_dw.src_row0 = &ws_p[(((dst_idx+1) - 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
255                                              (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
256                 par_conv_dw.src_row1 = &ws_p[(((dst_idx+1) - 0) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
257                                              (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
258                 par_conv_dw.src_row2 = &ws_p[(((dst_idx+1) + 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
259                                              (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
260
261                 par_conv_dw.dst = &dst[n*jcp_dw.oc*jcp_dw.oh*jcp_dw.ow + dst_idx/jcp_dw.stride_h*jcp_dw.ow*jcp_dw.oc + chb*jcp_dw.ch_block];
262
263                 par_conv_dw.kh_padding = jcp_dw.kh;
264                 par_conv_dw.filt = &dw_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
265                 par_conv_dw.bias = &dw_bias[chb * jcp_dw.ch_block];
266                 par_conv_dw.ur_w = (size_t)(jcp_dw.ow);
267                 par_conv_dw.oc_work = nstl::min((chb + 1) * jcp_dw.ch_block, (int)jcp_dw.oc) - chb*jcp_dw.ch_block;
268                 par_conv_dw.oc_off = chb * jcp_dw.ch_block * sizeof(float);
269
270                 kernel_dw_->jit_ker(&par_conv_dw);
271             }
272         };
273
274         size_t start{0}, end{0};
275         balance211(work_amount, nthr, ithr, start, end);
276
277         auto dw_conv_buffer = scratchpad().template get<dst_data_t>(key_dw_conv_buffer);
278         size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking;
279         auto pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
280
281         size_t n{0}, g{0}, ocbb{0}, oh{0};
282         nd_iterator_init(start, n, MB, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh);
283         for (size_t iwork = start; iwork < end; ++iwork) {
284             int ocb = ocbb * jcp.nb_oc_blocking;
285             int ocb_num = jcp.nb_oc_blocking;
286
287             if (iwork == start || oh == 0) {
288                 compute_row_gen(pbuf, n, g, ocb, ocb_num, oh - 1, 2);
289             } else {
290                 compute_row_gen(pbuf, n, g, ocb, ocb_num, oh, 1);
291             }
292
293             if (iwork > start && ((oh - 1) % jcp_dw.stride_h == 0) && oh > 0) {
294                 compute_row_dw(pbuf, n, ocb, ocb_num, oh - 1);
295             }
296
297             if ((iwork == end - 1 || (int) oh == jcp.oh - 1) && ((oh) % jcp_dw.stride_h == 0)) {
298                 compute_row_gen(pbuf, n, g, ocb, ocb_num, oh + 1, 1);
299                 compute_row_dw(pbuf, n, ocb, ocb_num, oh);
300             }
301
302             nd_iterator_step(n, MB, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh);
303         }
304     };
305
306     parallel(0, ker);
307 }
308
309 template struct _jit_uni_x8s8s32x_convolution_fwd_t<avx2, data_type::u8, data_type::u8>;
310 template struct _jit_uni_x8s8s32x_convolution_fwd_t<avx2, data_type::u8, data_type::s8>;
311 template struct _jit_uni_x8s8s32x_convolution_fwd_t<avx2, data_type::u8, data_type::s32>;
312 template struct _jit_uni_x8s8s32x_convolution_fwd_t<avx2, data_type::u8, data_type::f32>;
313
314 template struct _jit_uni_x8s8s32x_convolution_fwd_t<avx2, data_type::s8, data_type::u8>;
315 template struct _jit_uni_x8s8s32x_convolution_fwd_t<avx2, data_type::s8, data_type::s8>;
316 template struct _jit_uni_x8s8s32x_convolution_fwd_t<avx2, data_type::s8, data_type::s32>;
317 template struct _jit_uni_x8s8s32x_convolution_fwd_t<avx2, data_type::s8, data_type::f32>;
318
319 template struct _jit_uni_x8s8s32x_convolution_fwd_t<sse42, data_type::u8, data_type::u8>;
320 template struct _jit_uni_x8s8s32x_convolution_fwd_t<sse42, data_type::u8, data_type::s8>;
321 template struct _jit_uni_x8s8s32x_convolution_fwd_t<sse42, data_type::u8, data_type::s32>;
322 template struct _jit_uni_x8s8s32x_convolution_fwd_t<sse42, data_type::u8, data_type::f32>;
323
324 template struct _jit_uni_x8s8s32x_convolution_fwd_t<sse42, data_type::s8, data_type::u8>;
325 template struct _jit_uni_x8s8s32x_convolution_fwd_t<sse42, data_type::s8, data_type::s8>;
326 template struct _jit_uni_x8s8s32x_convolution_fwd_t<sse42, data_type::s8, data_type::s32>;
327 template struct _jit_uni_x8s8s32x_convolution_fwd_t<sse42, data_type::s8, data_type::f32>;
328
329 }
330 }
331 }