21fcf34a567a3a741bcc34421c394d8ebbc99269
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm_u8s8s32x_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_types.h"
18
19 #include "c_types_map.hpp"
20 #include "utils.hpp"
21 #include "type_helpers.hpp"
22 #include "mkldnn_thread.hpp"
23 #include "math_utils.hpp"
24
25 #include "simple_q10n.hpp"
26
27 #include "gemm_u8s8s32x_convolution.hpp"
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 using namespace mkldnn::impl::utils;
34 using namespace mkldnn::impl::math;
35
36 template <bool with_relu, data_type_t dst_type>
37 void _gemm_u8s8s32x_convolution_fwd_t<with_relu, dst_type>::execute_forward() {
38     auto src_base = reinterpret_cast<const src_data_t *>(this->input_memory(0));
39     auto wei_base = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
40     auto bia_base = reinterpret_cast<const char *>(this->input_memory(2));
41     auto dst_base = reinterpret_cast<dst_data_t *>(this->memory());
42
43     jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
44
45     char *scratchpad = (char *)this->scratchpad_->get();
46     src_data_t *col = (src_data_t *)scratchpad;
47     parallel_nd(jcp.im2col_sz * jcp.nthr,
48             [&](ptrdiff_t i) { col[i] = (src_data_t)0; });
49
50     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
51         execute_forward_thr(ithr, nthr, src_base, wei_base, bia_base,
52                 dst_base, scratchpad);
53     });
54 }
55
56 template <bool with_relu, data_type_t dst_type>
57 void _gemm_u8s8s32x_convolution_fwd_t<with_relu, dst_type>
58 ::execute_forward_thr(const int ithr, const int nthr,
59         const src_data_t *src_base, const wei_data_t *wei_base,
60         const char *bia_base, dst_data_t *dst_base, char *scratchpad) {
61 #if USE_MKL_IGEMM
62     jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
63
64     const auto src_md = memory_desc_wrapper(conf_.src_pd());
65     const size_t src_mb_stride = src_md.blk_off(1);
66     const size_t src_g_stride = src_md.blk_off(0, 1) * jcp.ic;
67
68     const auto wei_md = memory_desc_wrapper(conf_.weights_pd(0));
69     const size_t wei_g_stride = conf_.with_groups() ? wei_md.blk_off(1) : 0;
70
71     const auto dst_md = memory_desc_wrapper(conf_.dst_pd());
72     const size_t dst_mb_stride = dst_md.blk_off(1);
73     const size_t dst_g_stride = dst_md.blk_off(0, 1) * jcp.oc;
74     const size_t dst_os_stride = dst_md.blk_off(0, 0, 0, 1);
75
76     auto get_bias = [=, &bia_base](size_t off) -> acc_data_t {
77 #       define CASE(dt) case dt: return (acc_data_t)\
78         (*((const prec_traits<dt>::type *)bia_base + off))
79         switch (conf_.cdesc()->bias_desc.data_type) {
80         CASE(data_type::s8);
81         CASE(data_type::u8);
82         CASE(data_type::s32);
83         CASE(data_type::f32);
84         default: assert(!"unimplemented");
85         }
86 #       undef CASE
87         return 0;
88     };
89
90     /* scale_idx_mult = 1 for per_oc scales and 0, otherwise */
91     const int scale_idx_mult = conf_.attr()->output_scales_.mask_ == (1 << 1);
92     const float *scales = conf_.attr()->output_scales_.scales_;
93
94     const auto rmode = conf_.attr()->round_mode_;
95
96     const bool use_fast_path = true
97         && scale_idx_mult == 0
98         && jcp.ngroups == 1
99         && !jcp.with_bias;
100     const float fast_path_alpha = scales[0];
101
102     const auto &post_ops = conf_.attr()->post_ops_;
103     const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
104     const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;
105
106     float nslope = jcp.with_relu ? jcp.relu_negative_slope : 0;
107     int entry_idx = -1;
108     for (int idx = 0; idx < post_ops.len_; ++idx) {
109         const auto &e = post_ops.entry_[idx];
110         if (e.is_relu(true, false)) {
111             entry_idx = idx;
112             nslope = e.eltwise.alpha;
113             break;
114         }
115     }
116     const bool do_relu = jcp.with_relu || (entry_idx >= 0);
117
118     src_data_t *_col = (src_data_t *)scratchpad;
119     ptrdiff_t offset = (ptrdiff_t)jcp.im2col_sz
120                                    * sizeof(src_data_t) * jcp.nthr;
121     acc_data_t *_acc = (acc_data_t *)(scratchpad + offset);
122
123     src_data_t *col = _col + (ptrdiff_t)ithr * jcp.im2col_sz;
124     acc_data_t *acc = _acc + (ptrdiff_t)ithr * jcp.os * jcp.oc;
125
126     int n{0}, g{0};
127     size_t start = 0, end = 0;
128
129     const size_t work_amount = jcp.ngroups * jcp.mb;
130     balance211(work_amount, nthr, ithr, start, end);
131     nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);
132
133     for (size_t iwork = start; iwork < end; ++iwork) {
134         const src_data_t *src = src_base + n * src_mb_stride
135             + g * src_g_stride;
136         const wei_data_t *wei = wei_base + g * wei_g_stride;
137         dst_data_t *dst = dst_base + n * dst_mb_stride + g * dst_g_stride;
138
139         if (jcp.im2col_sz)
140             jit_gemm_convolution_utils::im2col_u8(jcp, src, col);
141
142         const int M = jcp.oc;
143         const int K = jcp.ks * jcp.ic;
144         const int N = jcp.os;
145         const int8_t off_a = 0, off_b = 0;
146         const int32_t off_c = 0;
147
148         cblas_gemm_s8u8s32(CblasColMajor, CblasNoTrans, CblasNoTrans,
149                 CblasFixOffset, M, N, K, 1., wei, M * jcp.ngroups, off_a,
150                 jcp.im2col_sz ? col : src, K, off_b, 0., acc, M, &off_c);
151
152         if (use_fast_path) {
153             auto body = [&](int o) {
154                 float d = fast_path_alpha * acc[o] + sum_scale * dst[o];
155                 if (do_relu && d < 0) d *= nslope;
156                 dst[o] = qz_a1b0<float, dst_data_t>()(d, rmode);
157             };
158
159 #           if _OPENMP >= 201307
160 #           pragma omp parallel for simd
161             for (int o = 0; o < jcp.os * jcp.oc; ++o) body(o);
162 #           else
163             parallel_nd(jcp.os * jcp.oc, body);
164 #           endif
165         } else {
166             parallel_nd(jcp.os, jcp.oc, [&](const int os, const int oc) {
167                 const size_t acc_off = os * jcp.oc + oc;
168                 float d = (float)acc[acc_off];
169
170                 if (jcp.with_bias)
171                     d += get_bias(g * jcp.oc + oc);
172
173                 d *= scales[(g * jcp.oc + oc) * scale_idx_mult];
174
175                 const size_t dst_off = os * dst_os_stride + oc;
176                 if (do_sum) d += sum_scale * dst[dst_off];
177                 if (do_relu && d < 0) d *= nslope;
178                 dst[dst_off] = qz_a1b0<float, dst_data_t>()(d, rmode);
179             });
180         }
181         nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
182     }
183 #endif
184 }
185
186 template <data_type_t dst_type>
187 void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::execute_backward_data() {
188     auto diff_dst_base = reinterpret_cast<const diff_dst_data_t *>
189             (this->input_memory(0));
190     auto wei_base = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
191     auto bia_base = reinterpret_cast<const char *>(this->input_memory(2));
192     auto diff_src_base = reinterpret_cast<diff_src_data_t *>(this->memory());
193
194     jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
195     char *scratchpad = (char *)this->scratchpad_->get();
196
197     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
198         execute_backward_data_thr(ithr, nthr, diff_dst_base, wei_base,
199                 bia_base, diff_src_base, scratchpad);
200     });
201 }
202
203 template <data_type_t dst_type>
204 void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>
205 ::execute_backward_data_thr(const int ithr, const int nthr,
206         const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base,
207         const char *bia_base, diff_src_data_t *diff_src_base, char *scratchpad)
208 {
209 #if USE_MKL_IGEMM
210     jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
211
212     const auto diff_dst_md = memory_desc_wrapper(conf_.diff_dst_pd());
213     const size_t diff_dst_mb_stride = diff_dst_md.blk_off(1);
214     const size_t diff_dst_g_stride = diff_dst_md.blk_off(0, 1) * jcp.oc;
215
216     const auto wei_md = memory_desc_wrapper(conf_.weights_pd(0));
217     const size_t wei_g_stride = conf_.with_groups() ? wei_md.blk_off(1) : 0;
218
219     const auto diff_src_md = memory_desc_wrapper(conf_.diff_src_pd());
220     const size_t diff_src_mb_stride = diff_src_md.blk_off(1);
221     const size_t diff_src_g_stride = diff_src_md.blk_off(0, 1) * jcp.ic;
222     const size_t diff_src_os_stride = diff_src_md.blk_off(0, 0, 0, 1);
223
224     auto get_bias = [=, &bia_base](size_t off) -> acc_data_t {
225 #       define CASE(dt) case dt: return (acc_data_t)\
226         (*((const prec_traits<dt>::type *)bia_base + off))
227         switch (conf_.desc()->bias_desc.data_type) {
228         CASE(data_type::s8);
229         CASE(data_type::u8);
230         CASE(data_type::s32);
231         CASE(data_type::f32);
232         default: assert(!"unimplemented");
233         }
234 #       undef CASE
235         return 0;
236     };
237
238     /* scale_idx_mult = 1 for per_oc scales and 0, otherwise */
239     const int scale_idx_mult = conf_.attr()->output_scales_.mask_ == (1 << 1);
240     const float *scales = conf_.attr()->output_scales_.scales_;
241     const auto rmode = conf_.attr()->round_mode_;
242     const size_t work_amount = jcp.ngroups * jcp.mb;
243
244     acc_data_t *_col = (acc_data_t *)scratchpad;
245     ptrdiff_t offset = (ptrdiff_t)jcp.im2col_sz
246                                     * sizeof(acc_data_t) * jcp.nthr;
247     acc_data_t *_acc = (acc_data_t *)(scratchpad + offset);
248
249     acc_data_t *col = _col + (ptrdiff_t)ithr * jcp.im2col_sz;
250     acc_data_t *acc = _acc + (ptrdiff_t)ithr * jcp.is * jcp.ic;
251
252     int n{0}, g{0};
253     size_t start = 0, end = 0;
254
255     balance211(work_amount, nthr, ithr, start, end);
256     nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);
257
258     for (size_t iwork = start; iwork < end; ++iwork) {
259         const diff_dst_data_t *diff_dst = diff_dst_base
260             + n * diff_dst_mb_stride + g * diff_dst_g_stride;
261         const wei_data_t *wei = wei_base + g * wei_g_stride;
262         diff_src_data_t *diff_src = diff_src_base + n * diff_src_mb_stride
263             + g * diff_src_g_stride;
264
265         const int M = jcp.ks * jcp.ic;
266         const int N = jcp.os;
267         const int K = jcp.oc;
268         const int8_t off_a = 0, off_b = 0;
269         const int32_t off_c = 0;
270
271         cblas_gemm_s8u8s32(CblasColMajor, CblasTrans, CblasNoTrans,
272                 CblasFixOffset, M, N, K, 1., wei, K * jcp.ngroups, off_a,
273                 diff_dst, K * jcp.ngroups, off_b, 0., jcp.im2col_sz ? col
274                 : acc, M, &off_c);
275
276         if (jcp.im2col_sz)
277             jit_gemm_convolution_utils::col2im_s32(jcp, col, acc);
278
279         parallel_nd(jcp.is, jcp.ic, [&](int is, int ic) {
280             float d = (float)acc[is * jcp.ic + ic];
281             if (jcp.with_bias)
282                 d += get_bias(g * jcp.ic + ic);
283             d *= scales[(g * jcp.ic + ic) * scale_idx_mult];
284             const size_t diff_src_off = is * diff_src_os_stride + ic;
285             diff_src[diff_src_off] =
286                 qz_a1b0<float, diff_src_data_t>()(d, rmode);
287         });
288         nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
289     }
290 #endif
291 }
292
293 using namespace data_type;
294
295 template struct _gemm_u8s8s32x_convolution_fwd_t<true, f32>;
296 template struct _gemm_u8s8s32x_convolution_fwd_t<true, s32>;
297 template struct _gemm_u8s8s32x_convolution_fwd_t<true, s8>;
298 template struct _gemm_u8s8s32x_convolution_fwd_t<true, u8>;
299 template struct _gemm_u8s8s32x_convolution_fwd_t<false, f32>;
300 template struct _gemm_u8s8s32x_convolution_fwd_t<false, s32>;
301 template struct _gemm_u8s8s32x_convolution_fwd_t<false, s8>;
302 template struct _gemm_u8s8s32x_convolution_fwd_t<false, u8>;
303
304 template struct _gemm_u8s8s32x_convolution_bwd_data_t<f32>;
305 template struct _gemm_u8s8s32x_convolution_bwd_data_t<s32>;
306 template struct _gemm_u8s8s32x_convolution_bwd_data_t<s8>;
307 template struct _gemm_u8s8s32x_convolution_bwd_data_t<u8>;
308 }
309 }
310 }