updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_x8s8s32x_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "type_helpers.hpp"
20 #include "utils.hpp"
21
22 #include "jit_avx512_core_x8s8s32x_convolution.hpp"
23
24 namespace mkldnn {
25 namespace impl {
26 namespace cpu {
27
28 using namespace mkldnn::impl::status;
29 using namespace mkldnn::impl::memory_format;
30 using namespace mkldnn::impl::memory_tracking::names;
31 using namespace mkldnn::impl::utils;
32
33 using namespace nstl;
34
35 using jit_conv_ker_t = void (*)(jit_conv_call_s *);
36
37 #define wht_blk_off(d, g, ...) \
38         (pd()->with_groups() \
39          ? (d).blk_off((g), __VA_ARGS__) \
40          : (d).blk_off(__VA_ARGS__))
41
42 template <data_type_t src_type, data_type_t dst_type>
43 void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type,
44         dst_type>::execute_forward_1d() const {
45     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
46     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
47     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
48     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
49
50     const memory_desc_wrapper src_d(pd()->src_pd());
51     const memory_desc_wrapper dst_d(pd()->dst_pd());
52     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
53     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
54
55     const size_t bia_dt_size = pd()->with_bias()
56         ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
57
58     const auto &jcp = pd()->jcp_;
59     assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
60     assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
61
62     const float *oscales = pd()->attr()->output_scales_.scales_;
63     if (jcp.signed_input && jcp.ver != ver_vnni) {
64         auto local_scales = scratchpad().template get<float>(
65                 key_conv_adjusted_scales);
66         size_t count = pd()->attr()->output_scales_.count_;
67         float factor = 1.f / pd()->jcp_.wei_adj_scale;
68         if (count == 1) {
69             utils::array_set(local_scales, oscales[0] * factor, 16);
70         } else {
71             for (size_t c = 0; c < count; c++)
72                 local_scales[c] = oscales[c] * factor;
73         }
74         oscales = local_scales;
75     }
76
77     size_t offset = weights_d.size() - weights_d.additional_buffer_size();
78     auto w = const_cast<wei_data_t *>(weights);
79     int32_t* compensation = (jcp.signed_input)
80                                 ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
81     int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
82     int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
83     int group_block = jcp.ch_block;
84     int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow;
85
86     parallel(0, [&](const int ithr, const int nthr) {
87
88         int start{ 0 }, end{ 0 };
89         balance211(work_amount, nthr, ithr, start, end);
90
91         auto p = jit_conv_call_s();
92
93         int n{ 0 }, gg{ 0 }, occ{ 0 }, owb{ 0 };
94         switch (jcp.loop_order) {
95         case loop_cwgn:
96             nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
97                     nb_groups, n, jcp.mb);
98             break;
99         case loop_gncw:
100             nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks,
101                     owb, jcp.nb_ow);
102             break;
103         case loop_ngcw:
104             nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ, oc_chunks,
105                     owb, jcp.nb_ow);
106             break;
107         case loop_nwcg:
108             nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks,
109                     gg, nb_groups);
110             break;
111         default: assert(!"unsupported loop order");
112         }
113         while (start < end) {
114             int ocb = occ * jcp.nb_oc_blocking;
115             int gb = gg * jcp.nb_ch_blocking;
116             int g = gb * group_block;
117             int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;
118             int g_ic = g * jcp.nb_ic * jcp.ic_block;
119             int ow_s = owb * jcp.ow_block;
120             int iw_s = ow_s * jcp.stride_w;
121
122             p.bias = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size) : 0;
123             p.compensation = (jcp.signed_input) ? compensation + g_oc : 0;
124             p.dst = dst + dst_d.blk_off(n, g_oc, ow_s);
125             p.src = src + src_d.blk_off(n, g_ic, iw_s);
126             p.filt = weights + wht_blk_off(weights_d, gb, ocb, 0);
127             p.scales = &oscales[jcp.is_oc_scale * g_oc];
128             p.oc_blocks = jcp.is_depthwise ? gb : ocb;
129             p.kh_padding = jcp.kh;
130             p.t_overflow = 0;
131             p.b_overflow = 0;
132             p.owb = owb;
133
134             kernel_->jit_ker(&p);
135
136             ++start;
137             switch (jcp.loop_order) {
138             case loop_cwgn:
139                 nd_iterator_step(occ, oc_chunks, owb, jcp.nb_ow, gg, nb_groups,
140                         n, jcp.mb);
141                 break;
142             case loop_gncw:
143                 nd_iterator_step(gg, nb_groups, n, jcp.mb, occ, oc_chunks, owb,
144                         jcp.nb_ow);
145                 break;
146             case loop_ngcw:
147                 nd_iterator_step(n, jcp.mb, gg, nb_groups, occ, oc_chunks, owb,
148                         jcp.nb_ow);
149                 break;
150             case loop_nwcg:
151                 nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, gg,
152                         nb_groups);
153                 break;
154             default: assert(!"unsupported loop order");
155             }
156         }
157     });
158 }
159
160 template <data_type_t src_type, data_type_t dst_type>
161 void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type,
162         dst_type>::execute_forward_2d() const {
163     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
164     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
165     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
166     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
167
168     const memory_desc_wrapper src_d(pd()->src_pd());
169     const memory_desc_wrapper dst_d(pd()->dst_pd());
170     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
171     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
172
173     const size_t bia_dt_size = pd()->with_bias()
174         ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
175
176     const auto &jcp = pd()->jcp_;
177     assert(jcp.ch_block == 1);
178     assert(jcp.nb_ch_blocking == 1);
179     assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
180     assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
181
182     const float *oscales = pd()->attr()->output_scales_.scales_;
183     if (jcp.signed_input && jcp.ver != ver_vnni) {
184         auto local_scales = scratchpad().template get<float>(
185                 key_conv_adjusted_scales);
186         size_t count = pd()->attr()->output_scales_.count_;
187         float factor = 1.f / pd()->jcp_.wei_adj_scale;
188         if (count == 1) {
189             utils::array_set(local_scales, oscales[0] * factor, 16);
190         } else {
191             for (size_t c = 0; c < count; c++)
192                 local_scales[c] = oscales[c] * factor;
193         }
194         oscales = local_scales;
195     }
196
197     size_t offset = weights_d.size() - weights_d.additional_buffer_size();
198     auto w = const_cast<wei_data_t *>(weights);
199     int32_t* compensation = (jcp.signed_input)
200                                 ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
201     int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk;
202     int nb_groups = jcp.nb_ch;
203     int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;
204
205     parallel(0, (size_t)work_amount, [&](const int ithr, const int nthr) {
206
207         int start{0}, end{0};
208         balance211(work_amount, nthr, ithr, start, end);
209
210         auto p = jit_conv_call_s();
211
212         size_t src_h_stride = src_d.blk_off(0, 0, 1);
213         size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
214         size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
215
216         int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 }, owb{ 0 };
217         switch (jcp.loop_order) {
218         case loop_cwgn:
219             nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g,
220                     nb_groups, n, jcp.mb, oh_s, jcp.oh);
221             break;
222         case loop_ngcw:
223             nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
224                     owb, jcp.nb_ow, oh_s, jcp.oh);
225             break;
226         case loop_nhwcg:
227             nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
228                     occ, oc_chunks, g, nb_groups);
229             break;
230         default: assert(!"unsupported loop order");
231         }
232         while (start < end) {
233             for (int occ1 = 0; occ1 < jcp.nb_oc_blocking_thr_chunk;
234                 occ1 += jcp.nb_oc_blocking) {
235                 int ocb = occ * jcp.nb_oc_blocking_thr_chunk + occ1;
236                 int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;
237
238                 int g_ic = g * jcp.nb_ic * jcp.ic_block;
239
240                 int work_rem = end - start;
241                 int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
242                 int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
243                 if (jcp.loop_order == loop_nhwcg)
244                     oh_e = oh_s + 1; // step instead
245                 int ow_s = owb * jcp.ow_block;
246                 int iw_s = ow_s * jcp.stride_w;
247
248                 auto bias_w = bias
249                     ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
250                     : 0;
251                 int32_t *compensation_w = (jcp.signed_input)
252                                           ? compensation + g_oc : 0;
253
254                 auto dst_w = dst + dst_d.blk_off(n, g_oc, oh_s, ow_s);
255                 auto src_w = src + src_d.blk_off(n, g_ic, ih_s, iw_s);
256                 auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);
257
258                 auto scales = &oscales[jcp.is_oc_scale * g_oc];
259
260                 for (int oj = oh_s, ij = ih_s; oj < oh_e;
261                     ++oj, ij += jcp.stride_h) {
262                     int dilate_h = jcp.dilate_h + 1;
263                     int i_t_overflow = nstl::min(jcp.kh,
264                                                 div_up(max(0, -ij), dilate_h));
265                     int i_b_overflow = nstl::min(jcp.kh, div_up(
266                             max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
267                             dilate_h));
268                     int kh_padding = nstl::max(0,
269                         jcp.kh - i_t_overflow - i_b_overflow);
270
271                     size_t wei_stride = (!jcp.signed_input)
272                                             ? i_t_overflow * wht_h_stride : 0;
273                     p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
274                     p.dst = dst_w;
275                     p.filt = wht_w + wei_stride;
276                     p.bias = bias_w;
277                     p.compensation = compensation_w;
278                     p.oc_blocks = ocb;
279                     p.kh_padding = kh_padding;
280                     p.scales = scales;
281                     p.t_overflow = i_t_overflow;
282                     p.b_overflow = i_b_overflow;
283                     p.owb = owb;
284
285                     p.oc_off = g_oc * sizeof(float);
286
287                     kernel_->jit_ker(&p);
288                     src_w += src_h_stride * jcp.stride_h;
289                     dst_w += dst_h_stride;
290                 }
291             }
292             switch (jcp.loop_order) {
293             case loop_cwgn:
294                 nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, g,
295                         nb_groups, n, jcp.mb, oh_s, jcp.oh);
296                 break;
297             case loop_ngcw:
298                 nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
299                         oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
300                 break;
301             case loop_nhwcg:
302                 ++start;
303                 nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, occ,
304                         oc_chunks, g, nb_groups);
305                 break;
306             default: assert(!"unsupported loop order");
307             }
308         }
309     });
310 }
311
312 template <data_type_t src_type, data_type_t dst_type>
313 void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type,
314         dst_type>::execute_forward_2d_dw() const {
315     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
316     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
317     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
318     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
319
320     const memory_desc_wrapper src_d(pd()->src_pd());
321     const memory_desc_wrapper dst_d(pd()->dst_pd());
322     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
323     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
324
325     const size_t bia_dt_size = pd()->with_bias()
326             ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
327
328     const auto &jcp = pd()->jcp_;
329     assert(jcp.ic_block == 1);
330     assert(jcp.oc_block == 1);
331     assert(jcp.nb_ic == 1);
332     assert(jcp.nb_oc == 1);
333     assert(jcp.nb_oc_blocking == 1);
334     assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
335
336     const float *oscales = pd()->attr()->output_scales_.scales_;
337     if (jcp.signed_input && jcp.ver != ver_vnni) {
338         auto local_scales = scratchpad().template get<float>(
339                 key_conv_adjusted_scales);
340         size_t count = pd()->attr()->output_scales_.count_;
341         float factor = 1.f / pd()->jcp_.wei_adj_scale;
342         if (count == 1) {
343             utils::array_set(local_scales, oscales[0] * factor, 16);
344         } else {
345             for (size_t c = 0; c < count; c++)
346                 local_scales[c] = oscales[c] * factor;
347         }
348         oscales = local_scales;
349     }
350
351     size_t offset = weights_d.size() - weights_d.additional_buffer_size();
352     auto w = const_cast<wei_data_t *>(weights);
353     int32_t* compensation = (jcp.signed_input)
354                                 ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
355     int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
356     int group_block = jcp.ch_block;
357
358     parallel_nd(jcp.mb, jcp.oh, jcp.nb_ow, nb_groups, [&](int n, int oh_s, int owb, int gg) {
359
360         auto p = jit_conv_call_s();
361
362         size_t src_h_stride = src_d.blk_off(0, 0, 1);
363         size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
364
365         int gb = gg * jcp.nb_ch_blocking;
366         int g = gb * group_block;
367
368         int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
369         int ow_s = owb * jcp.ow_block;
370         int iw_s = ow_s * jcp.stride_w;
371
372         auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size) : 0;
373         int32_t *compensation_w = jcp.signed_input ? compensation + g : 0;
374
375         auto dst_w = dst + dst_d.blk_off(n, g, oh_s, ow_s);
376         auto src_w = src + src_d.blk_off(n, g, ih_s, iw_s);
377         auto wht_w = weights + wht_blk_off(weights_d, gb, 0);
378
379         auto scales = &oscales[jcp.is_oc_scale * g];
380
381         int dilate_h = jcp.dilate_h + 1;
382         int i_t_overflow = nstl::min(jcp.kh, div_up(max(0, -ih_s), dilate_h));
383         int i_b_overflow = nstl::min(jcp.kh,
384                 div_up(max(0, ih_s - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
385                                              dilate_h));
386         int kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);
387
388         size_t wei_stride = jcp.signed_input ? 0 : i_t_overflow * wht_h_stride;
389         p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
390         p.dst = dst_w;
391         p.filt = wht_w + wei_stride;
392         p.bias = bias_w;
393         p.compensation = compensation_w;
394         p.oc_blocks = gb;
395         p.kh_padding = kh_padding;
396         p.scales = scales;
397         p.t_overflow = i_t_overflow;
398         p.b_overflow = i_b_overflow;
399         p.owb = owb;
400
401         kernel_->jit_ker(&p);
402     });
403 }
404
405 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
406                                                 data_type::s8, data_type::u8>;
407 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
408                                                 data_type::u8, data_type::u8>;
409 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
410                                                 data_type::s8, data_type::s8>;
411 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
412                                                 data_type::u8, data_type::s8>;
413 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
414                                                 data_type::s8, data_type::s32>;
415 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
416                                                 data_type::u8, data_type::s32>;
417 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
418                                                 data_type::s8, data_type::f32>;
419 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
420                                                 data_type::u8, data_type::f32>;
421 }
422 }
423 }
424
425 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s