/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "c_types_map.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "jit_generator.hpp"

#include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;

namespace {
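/* Split nthr threads over a 2D (ny x nx) work space: threads are first
 * collected into groups along nx (at most nx_divider groups), then the
 * threads of each group are spread along ny. When nthr does not divide
 * evenly, the trailing groups get one thread fewer. */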
template <typename T, typename U>
void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end,
    T nx, T &nx_start, T &nx_end, T nx_divider)
{
    const T grp_size = utils::div_up(nthr, nx_divider);
    const T grp_count = utils::div_up(nthr, grp_size);

    T grp = ithr / grp_size;
    T grp_ithr = ithr % grp_size;
    T grp_nthr = grp_size;
    T first_grps = nthr % grp_count;
    if (first_grps > 0 && grp >= first_grps) {
        // Threads past the full-size groups fall into smaller groups.
        ithr -= first_grps * grp_size;
        grp_nthr--;
        grp = ithr / grp_nthr + first_grps;
        grp_ithr = ithr % grp_nthr;
    }
    // Each group takes a slice of nx; threads within a group slice ny.
    balance211(nx, grp_count, grp, nx_start, nx_end);
    balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end);
}
} // namespace

/* convolution forward */
template <data_type_t src_type, data_type_t dst_type>
void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t
                              <src_type, dst_type>::execute_forward() const
{
    auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
    auto weights =
        reinterpret_cast<const wei_data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const char *>(this->input_memory(2));
    auto dst = reinterpret_cast<dst_data_t *>(this->memory());

    auto scratchpad = this->scratchpad();

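    /* Non-VNNI kernels pre-scale signed-input weights by wei_adj_scale;
     * fold the inverse of that factor into the output scales here so the
     * final result is unaffected. */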
    if (pd()->jcp_.signed_input && pd()->jcp_.ver != ver_vnni) {
        auto local_scales = scratchpad.template get<float>(
                key_conv_adjusted_scales);
        auto scales = pd()->attr()->output_scales_.scales_;
        size_t count = pd()->attr()->output_scales_.count_;
        float factor = 1.f / pd()->jcp_.wei_adj_scale;
        if (count == 1) {
            utils::array_set(local_scales, scales[0] * factor, 16);
        } else {
            for (size_t c = 0; c < count; c++)
                local_scales[c] = scales[c] * factor;
        }
    }

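    /* Launch jcp.nthr threads; each computes its own share of the work,
     * partitioned inside execute_forward_thr(). */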
    parallel(kernel_->jcp.nthr, [&](const int ithr, const int nthr) {
        execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad);
    });
}

template <data_type_t src_type, data_type_t dst_type>
void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type, dst_type>
::execute_forward_thr(const int ithr, const int nthr, const src_data_t *src,
        const wei_data_t *weights, const char *bias, dst_data_t *dst,
        const memory_tracking::grantor_t &scratchpad) const {
    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper dst_d(pd()->dst_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));

    const size_t bia_dt_size = pd()->with_bias()
        ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;

    const auto &jcp = kernel_->jcp;
    auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space);
    auto local_scales = scratchpad.get<float>(key_conv_adjusted_scales);

    const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;

    const int stride_h = pd()->desc()->strides[0];
    const int stride_w = pd()->desc()->strides[1];
    const int pad_t = pd()->desc()->padding[0][0];
    const int pad_l = pd()->desc()->padding[0][1];

    const auto &oscales = pd()->attr()->output_scales_;

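    /* For signed input, per-output-channel compensation values are stored
     * immediately after the weights; `offset` is the total number of
     * weight elements. */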
    int offset = jcp.ngroups * (jcp.oc / jcp.oc_block) * (jcp.ic / jcp.ic_block)
        * jcp.oc_block * jcp.ic_block;
    wei_data_t *w = const_cast<wei_data_t *>(weights);
    int32_t *compensation = (jcp.signed_input)
        ? reinterpret_cast<int32_t *>(w + offset) : nullptr;

    // Take default_step blocks per iteration, shrinking for the tail.
    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    auto p = jit_1x1_conv_call_s();

    auto rp = rtus_driver_t<avx512_common>::call_params_t();
    const int nb_oc = jcp.nb_load;
    const int os_block = jcp.bcast_block;

    int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0};
    balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
        jcp.nb_load, ocb_start, ocb_end, jcp.load_grp_count);

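    /* Decompose the linear work index into (mb, group, os-block), pick the
     * broadcast step for this iteration, and derive the output coordinates
     * (oh, ow) and the matching input coordinates (ih, iw), clamped at the
     * padding border. */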
    auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step,
            int &oh, int &ow, int &ih, int &iw)
    {
        int osb{0};
        nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
            jcp.nb_bcast);
        bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
                jcp.nb_bcast_blocking_max);
        bcast_step = nstl::min(bcast_step, bcast_end - iwork);

        const int os = osb * os_block;
        oh = os / jcp.ow;
        ow = os % jcp.ow;

        ih = nstl::max(oh * stride_h - pad_t, 0);
        iw = nstl::max(ow * stride_w - pad_l, 0);
        rp.iw_start = iw;

        p.bcast_dim = this_block_size(os, jcp.os,
            bcast_step * os_block);
        rp.os = p.bcast_dim;
    };

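    /* Choose how many output-channel blocks to process this iteration and
     * flag the kernel when the last block of channels is reached. */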
    auto init_load = [&](int ocb, int &load_step)
    {
        load_step = step(jcp.nb_load_blocking, ocb_end - ocb,
            jcp.nb_load_blocking_max);
        p.load_dim = this_block_size(ocb * jcp.oc_block,
            ocb_end * jcp.oc_block, load_step * jcp.oc_block);

        if (ocb + load_step >= nb_oc)
            p.first_last_flag |= FLAG_OC_LAST;
        else
            p.first_last_flag &= ~FLAG_OC_LAST;
    };

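    /* A 1x1 convolution reduces over the full input-channel dimension in a
     * single pass. */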
    auto init_reduce = [&]()
    {
        p.reduce_dim = this_block_size(0, jcp.ic, jcp.ic);
        rp.icb = p.reduce_dim / jcp.reduce_block;
    };

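    /* Fill the kernel call parameters (output, weights, bias, compensation,
     * scales) for one (ocb, bcast) tile and invoke the JIT kernel. */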
    auto inner_ker = [&](int ocb, int n, int g, int oh, int ow,
            int ih, int iw)
    {
        const int icb = 0; // Start from the first IC block
        const int _ocb = g * nb_oc + ocb;
        const int _icb = g;

        const size_t dst_off = dst_d.blk_off(n, _ocb * jcp.oc_block, oh, ow);

        p.output_data = &dst[dst_off];
        p.load_data = &weights[pd()->with_groups()
            ? weights_d.blk_off(g, ocb, icb)
            : weights_d.blk_off(ocb, icb)];
        p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size];
        p.compensation = (jcp.signed_input)
            ? &compensation[_ocb * jcp.oc_block] : nullptr;
        p.scales = (jcp.signed_input && jcp.ver != ver_vnni)
            ? &local_scales[jcp.is_oc_scale * _ocb * jcp.oc_block]
            : &oscales.scales_[jcp.is_oc_scale * _ocb * jcp.oc_block];
        if (pd()->rtus_.reduce_src_) {
            // Reduce the strided source into a dense per-thread workspace
            // once per bcast block (at the first ocb) and reuse it after.
            rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
                + _icb * jcp.is * jcp.ic_block;
            if (ocb == ocb_start) {
                rp.src = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw);
                rtus_driver_->ker_(&rp);
            }
            p.bcast_data = rp.ws;
        } else
            p.bcast_data = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw);

        p.oc_off = _ocb * jcp.oc_block * sizeof(float);

        kernel_->jit_ker(&p);
    };

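    /* Loop-order variants: 'r' = reduce, 'l' = load (oc blocks),
     * 'b' = bcast (mb/spatial); the letters give the nesting from outer to
     * inner, with 'r' done once up front when it comes first. */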
    if (jcp.loop_order == loop_rlb) {
        init_reduce();
        int ocb = ocb_start;
        while (ocb < ocb_end) {
            int load_step;
            init_load(ocb, load_step);
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n, g, bcast_step, oh, ow, ih, iw;
                init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
                inner_ker(ocb, n, g, oh, ow, ih, iw);
                iwork += bcast_step;
            }
            ocb += load_step;
        }
    } else if (jcp.loop_order == loop_lbr) {
        int ocb = ocb_start;
        while (ocb < ocb_end) {
            int load_step;
            init_load(ocb, load_step);
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n, g, bcast_step, oh, ow, ih, iw;
                init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
                init_reduce();
                inner_ker(ocb, n, g, oh, ow, ih, iw);
                iwork += bcast_step;
            }
            ocb += load_step;
        }
    } else if (jcp.loop_order == loop_rbl) {
        init_reduce();
        int iwork = bcast_start;
        while (iwork < bcast_end) {
            int n, g, bcast_step, oh, ow, ih, iw;
            init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, load_step);
                inner_ker(ocb, n, g, oh, ow, ih, iw);
                ocb += load_step;
            }
            iwork += bcast_step;
        }
    } else if (jcp.loop_order == loop_blr) {
        int iwork = bcast_start;
        while (iwork < bcast_end) {
            int n, g, bcast_step, oh, ow, ih, iw;
            init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, load_step);
                init_reduce();
                inner_ker(ocb, n, g, oh, ow, ih, iw);
                ocb += load_step;
            }
            iwork += bcast_step;
        }
    } else {
        assert(!"unsupported loop order");
    }
}

using namespace data_type;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, u8>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, u8>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s8>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s8>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s32>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s32>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, f32>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, f32>;

} // namespace cpu
} // namespace impl
} // namespace mkldnn