/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "c_types_map.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "jit_generator.hpp"

#include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;

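/* Splits nthr threads into a 2D grid: up to nx_divider thread groups are
 * spread over the nx dimension, and the threads within a group share the ny
 * dimension. When nthr does not divide evenly, the trailing groups run with
 * one thread fewer; each group still receives a contiguous [start, end)
 * slice of both dimensions via balance211. */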
namespace {
template <typename T, typename U>
void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end,
    T nx, T &nx_start, T &nx_end, T nx_divider)
{
    const T grp_size = utils::div_up(nthr, nx_divider);
    const T grp_count = utils::div_up(nthr, grp_size);

    T grp = ithr / grp_size;
    T grp_ithr = ithr % grp_size;
    T grp_nthr = grp_size;
    T first_grps = nthr % grp_count;
    if (first_grps > 0 && grp >= first_grps) {
        ithr -= first_grps * grp_size;
        grp_nthr--;
        grp = ithr / grp_nthr + first_grps;
        grp_ithr = ithr % grp_nthr;
    }
    balance211(nx, grp_count, grp, nx_start, nx_end);
    balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end);
}
}

/* convolution forward */
template <data_type_t src_type, data_type_t dst_type>
void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t
                              <src_type, dst_type>::execute_forward() const
{
    auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
    auto weights =
        reinterpret_cast<const wei_data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const char *>(this->input_memory(2));
    auto dst = reinterpret_cast<dst_data_t *>(this->memory());

    auto scratchpad = this->scratchpad();

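    /* For s8 input on hardware without VNNI the weights were pre-scaled by
     * jcp_.wei_adj_scale to keep the intermediate products from saturating,
     * so fold 1 / wei_adj_scale into a scratchpad copy of the output scales
     * to compensate. */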
    if (pd()->jcp_.signed_input && pd()->jcp_.ver != ver_vnni) {
        auto local_scales = scratchpad.template get<float>(
                key_conv_adjusted_scales);
        auto scales = pd()->attr()->output_scales_.scales_;
        size_t count = pd()->attr()->output_scales_.count_;
        float factor = 1.f / pd()->jcp_.wei_adj_scale;
        if (count == 1) {
            utils::array_set(local_scales, scales[0] * factor, 16);
        } else {
            for (size_t c = 0; c < count; c++)
                local_scales[c] = scales[c] * factor;
        }
    }

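    /* Total work: every (mini-batch, group, bcast block, load block)
     * combination is an independent unit picked up by
     * execute_forward_thr(). */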
    const int work_amount = pd()->jcp_.mb * pd()->jcp_.ngroups
        * pd()->jcp_.nb_bcast * pd()->jcp_.nb_load;

    parallel(kernel_->jcp.nthr, (size_t)work_amount,
            [&](const int ithr, const int nthr) {
        execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad);
    });
}

template <data_type_t src_type, data_type_t dst_type>
void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type, dst_type>
::execute_forward_thr(const int ithr, const int nthr, const src_data_t *src,
        const wei_data_t *weights, const char *bias, dst_data_t *dst,
        const memory_tracking::grantor_t &scratchpad) const {
    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper dst_d(pd()->dst_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));

    const size_t bia_dt_size = pd()->with_bias()
        ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;

    const auto &jcp = kernel_->jcp;
    auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space);
    auto local_scales = scratchpad.get<float>(key_conv_adjusted_scales);

    const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;

    const int stride_h = pd()->desc()->strides[0];
    const int stride_w = pd()->desc()->strides[1];
    const int pad_t = pd()->desc()->padding[0][0];
    const int pad_l = pd()->desc()->padding[0][1];

    const auto &oscales = pd()->attr()->output_scales_;

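    /* For signed input the weights reorder stores one int32 compensation
     * value per output channel immediately after the weights data; `offset`
     * is the size of the weights proper, in elements. */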
    int offset = jcp.ngroups * (jcp.oc / jcp.oc_block) * (jcp.ic / jcp.ic_block)
        * jcp.oc_block * jcp.ic_block;
    wei_data_t *w = const_cast<wei_data_t *>(weights);
    int32_t *compensation = (jcp.signed_input)
        ? reinterpret_cast<int32_t *>(w + offset) : nullptr;

    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    auto p = jit_1x1_conv_call_s();

    auto rp = rtus_driver_t<avx512_common>::call_params_t();
    const int nb_oc = jcp.nb_load;
    const int os_block = jcp.bcast_block;

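    /* Split the threads 2D-wise between the bcast dimension
     * (mb x groups x spatial blocks) and the load dimension (output-channel
     * block chunks), with at most jcp.load_grp_count groups along the load
     * dimension. */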
    int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0};
    balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
        jcp.nb_load / jcp.nb_load_chunk, ocb_start, ocb_end,
        jcp.load_grp_count);
    if (jcp.nb_load_chunk > 1) {
        ocb_start *= jcp.nb_load_chunk;
        ocb_end *= jcp.nb_load_chunk;
    }

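    /* Decode a flat bcast-work index into (mb, group, spatial block)
     * coordinates and set the bcast-related kernel parameters, clamping the
     * step to both the row tail and this thread's share of the work. */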
    auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step,
            int &oh, int &ow, int &ih, int &iw)
    {
        int osb{0};
        nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
            jcp.nb_bcast);
        bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
                jcp.nb_bcast_blocking_max);
        bcast_step = nstl::min(bcast_step, bcast_end - iwork);

        const int os = osb * os_block;
        oh = os / jcp.ow;
        ow = os % jcp.ow;

        ih = nstl::max(oh * stride_h - pad_t, 0);
        iw = nstl::max(ow * stride_w - pad_l, 0);
        rp.iw_start = iw;

        p.bcast_dim = this_block_size(os, jcp.os,
            bcast_step * os_block);
        rp.os = p.bcast_dim;
    };

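    /* Pick the number of output-channel blocks to process next and flag the
     * last chunk so the kernel can apply its output-channel tail logic. */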
    auto init_load = [&](int ocb, int &load_step)
    {
        load_step = step(jcp.nb_load_blocking, ocb_end - ocb,
            jcp.nb_load_blocking_max);
        p.load_dim = this_block_size(ocb * jcp.oc_block,
            ocb_end * jcp.oc_block, load_step * jcp.oc_block);

        if (ocb + load_step >= nb_oc)
            p.first_last_flag |= FLAG_OC_LAST;
        else
            p.first_last_flag &= ~FLAG_OC_LAST;
    };

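    /* The 1x1 kernel reduces over the full input-channel range in one shot,
     * so the reduce dimension is simply jcp.ic. */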
    auto init_reduce = [&]()
    {
        p.reduce_dim = this_block_size(0, jcp.ic, jcp.ic);
        rp.icb = p.reduce_dim / jcp.reduce_block;
    };

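    /* Fill in the per-call pointers (dst, weights, bias, compensation,
     * scales, src) and invoke the JIT kernel. When the reduce-src driver is
     * enabled, the strided source is first repacked into this thread's rtus
     * scratch buffer, once per bcast block (i.e. only for the first ocb). */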
    auto inner_ker = [&](int ocb, int n, int g, int oh, int ow,
            int ih, int iw)
    {
        const int icb = 0; // Start from the first IC block
        const int _ocb = g * nb_oc + ocb;
        const int _icb = g;

        const size_t dst_off = dst_d.blk_off(n, _ocb * jcp.oc_block, oh, ow);

        p.output_data = &dst[dst_off];
        p.load_data = &weights[pd()->with_groups()
            ? weights_d.blk_off(g, ocb, icb)
            : weights_d.blk_off(ocb, icb)];
        p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size];
        p.compensation = (jcp.signed_input)
            ? &compensation[_ocb * jcp.oc_block] : nullptr;
        p.scales = (jcp.signed_input && jcp.ver != ver_vnni)
            ? &local_scales[jcp.is_oc_scale * _ocb * jcp.oc_block]
            : &oscales.scales_[jcp.is_oc_scale * _ocb * jcp.oc_block];
        if (pd()->rtus_.reduce_src_) {
            rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
                + _icb * jcp.is * jcp.ic_block;
            if (ocb == ocb_start) {
                rp.src = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw);
                rtus_driver_->ker_(&rp);
            }
            p.bcast_data = rp.ws;
        } else
            p.bcast_data = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw);

        p.oc_off = _ocb * jcp.oc_block * sizeof(float);

        kernel_->jit_ker(&p);
    };

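    /* Four supported loop nestings over r(educe), l(oad), and b(cast):
     * the outermost-to-innermost order is given by jcp.loop_order, and
     * init_reduce() is hoisted out of the loops whenever reduce comes
     * first. */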
    if (jcp.loop_order == loop_rlb) {
        init_reduce();
        int ocb = ocb_start;
        while (ocb < ocb_end) {
            int load_step;
            init_load(ocb, load_step);
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n, g, bcast_step, oh, ow, ih, iw;
                init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
                inner_ker(ocb, n, g, oh, ow, ih, iw);
                iwork += bcast_step;
            }
            ocb += load_step;
        }
    } else if (jcp.loop_order == loop_lbr) {
        int ocb = ocb_start;
        while (ocb < ocb_end) {
            int load_step;
            init_load(ocb, load_step);
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n, g, bcast_step, oh, ow, ih, iw;
                init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
                init_reduce();
                inner_ker(ocb, n, g, oh, ow, ih, iw);
                iwork += bcast_step;
            }
            ocb += load_step;
        }
    } else if (jcp.loop_order == loop_rbl) {
        init_reduce();
        int iwork = bcast_start;
        while (iwork < bcast_end) {
            int n, g, bcast_step, oh, ow, ih, iw;
            init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, load_step);
                inner_ker(ocb, n, g, oh, ow, ih, iw);
                ocb += load_step;
            }
            iwork += bcast_step;
        }
    } else if (jcp.loop_order == loop_blr) {
        int iwork = bcast_start;
        while (iwork < bcast_end) {
            int n, g, bcast_step, oh, ow, ih, iw;
            init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, load_step);
                init_reduce();
                inner_ker(ocb, n, g, oh, ow, ih, iw);
                ocb += load_step;
            }
            iwork += bcast_step;
        }
    } else {
        assert(!"unsupported loop order");
    }
}

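/* Explicit instantiations for every supported (src, dst) data type pair. */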
using namespace data_type;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, u8>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, u8>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s8>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s8>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s32>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s32>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, f32>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, f32>;

}
}
}