1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_types.h"
19 #include "c_types_map.hpp"
21 #include "type_helpers.hpp"
22 #include "mkldnn_thread.hpp"
23 #include "math_utils.hpp"
25 #include "simple_q10n.hpp"
27 #include "gemm_u8s8s32x_convolution.hpp"
33 using namespace mkldnn::impl::utils;
34 using namespace mkldnn::impl::math;
// Forward int8 (u8 src, s8 wei, s32 acc) GEMM convolution driver: zeroes the
// shared im2col scratchpad, then dispatches the per-thread worker.
36 template <bool with_relu, data_type_t dst_type>
37 void _gemm_u8s8s32x_convolution_fwd_t<with_relu, dst_type>::execute_forward() {
// Resolve base pointers: src (input 0), weights (input 1), bias (input 2,
// kept as raw bytes because its data type is switched on later), dst output.
38 auto src_base = reinterpret_cast<const src_data_t *>(this->input_memory(0));
39 auto wei_base = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
40 auto bia_base = reinterpret_cast<const char *>(this->input_memory(2));
41 auto dst_base = reinterpret_cast<dst_data_t *>(this->memory());
43 jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
// The scratchpad starts with jcp.nthr per-thread im2col buffers of
// jcp.im2col_sz elements each; zero them in parallel before the workers run.
45 char *scratchpad = (char *)this->scratchpad_->get();
46 src_data_t *col = (src_data_t *)scratchpad;
47 parallel_nd(jcp.im2col_sz * jcp.nthr,
48 [&](ptrdiff_t i) { col[i] = (src_data_t)0; });
// Fan out: each thread computes its share of the (mb, group) work items.
50 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
51 execute_forward_thr(ithr, nthr, src_base, wei_base, bia_base,
52 dst_base, scratchpad);
// Per-thread forward body: for each assigned (minibatch, group) item, run
// im2col (u8), an s8u8s32 GEMM into an s32 accumulator, then apply bias,
// output scales, optional sum/ReLU post-ops and quantize into dst.
56 template <bool with_relu, data_type_t dst_type>
57 void _gemm_u8s8s32x_convolution_fwd_t<with_relu, dst_type>
58 ::execute_forward_thr(const int ithr, const int nthr,
59 const src_data_t *src_base, const wei_data_t *wei_base,
60 const char *bia_base, dst_data_t *dst_base, char *scratchpad) {
62 jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
// Strides (in elements) derived from the memory descriptors: per-minibatch
// and per-group offsets for src/wei/dst, plus the dst spatial stride.
64 const auto src_md = memory_desc_wrapper(conf_.src_pd());
65 const size_t src_mb_stride = src_md.blk_off(1);
66 const size_t src_g_stride = src_md.blk_off(0, 1) * jcp.ic;
68 const auto wei_md = memory_desc_wrapper(conf_.weights_pd(0));
69 const size_t wei_g_stride = conf_.with_groups() ? wei_md.blk_off(1) : 0;
71 const auto dst_md = memory_desc_wrapper(conf_.dst_pd());
72 const size_t dst_mb_stride = dst_md.blk_off(1);
73 const size_t dst_g_stride = dst_md.blk_off(0, 1) * jcp.oc;
74 const size_t dst_os_stride = dst_md.blk_off(0, 0, 0, 1);
// Read bias element `off`, converting from its runtime data type (the CASE
// macro expands one `case` per supported type) to the accumulator type.
76 auto get_bias = [=, &bia_base](size_t off) -> acc_data_t {
77 # define CASE(dt) case dt: return (acc_data_t)\
78 (*((const prec_traits<dt>::type *)bia_base + off))
79 switch (conf_.cdesc()->bias_desc.data_type) {
84 default: assert(!"unimplemented");
90 /* scale_idx_mult = 1 for per_oc scales and 0, otherwise */
91 const int scale_idx_mult = conf_.attr()->output_scales_.mask_ == (1 << 1);
92 const float *scales = conf_.attr()->output_scales_.scales_;
94 const auto rmode = conf_.attr()->round_mode_;
// Fast path requires a single common output scale (further conditions —
// presumably no bias and a dense dst layout — are not visible here; confirm
// against the full source).
96 const bool use_fast_path = true
97 && scale_idx_mult == 0
100 const float fast_path_alpha = scales[0];
// Post-ops: an optional leading sum (beta accumulation into dst) ...
102 const auto &post_ops = conf_.attr()->post_ops_;
103 const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
104 const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;
// ... and an optional ReLU, either fused via jcp.with_relu or found as an
// eltwise entry in the post-op chain (its alpha is the negative slope).
106 float nslope = jcp.with_relu ? jcp.relu_negative_slope : 0;
108 for (int idx = 0; idx < post_ops.len_; ++idx) {
109 const auto &e = post_ops.entry_[idx];
110 if (e.is_relu(true, false)) {
112 nslope = e.eltwise.alpha;
116 const bool do_relu = jcp.with_relu || (entry_idx >= 0);
// Carve this thread's slices out of the scratchpad: im2col buffers first,
// then per-thread s32 accumulators of jcp.os * jcp.oc elements.
118 src_data_t *_col = (src_data_t *)scratchpad;
119 ptrdiff_t offset = (ptrdiff_t)jcp.im2col_sz
120 * sizeof(src_data_t) * jcp.nthr;
121 acc_data_t *_acc = (acc_data_t *)(scratchpad + offset);
123 src_data_t *col = _col + (ptrdiff_t)ithr * jcp.im2col_sz;
124 acc_data_t *acc = _acc + (ptrdiff_t)ithr * jcp.os * jcp.oc;
// Static balancing of the ngroups * mb work items across threads; the
// (n, g) iterator tracks the current minibatch/group pair.
127 size_t start = 0, end = 0;
129 const size_t work_amount = jcp.ngroups * jcp.mb;
130 balance211(work_amount, nthr, ithr, start, end);
131 nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);
133 for (size_t iwork = start; iwork < end; ++iwork) {
134 const src_data_t *src = src_base + n * src_mb_stride
136 const wei_data_t *wei = wei_base + g * wei_g_stride;
137 dst_data_t *dst = dst_base + n * dst_mb_stride + g * dst_g_stride;
140 jit_gemm_convolution_utils::im2col_u8(jcp, src, col);
// Column-major GEMM: acc(M x N) = wei(M x K) * col(K x N), with lda of
// M * ngroups because group weights are interleaved along oc. When
// im2col_sz == 0 (1x1 case) src is used directly instead of col. All
// zero-points (off_a/off_b/off_c) are zero.
142 const int M = jcp.oc;
143 const int K = jcp.ks * jcp.ic;
144 const int N = jcp.os;
145 const int8_t off_a = 0, off_b = 0;
146 const int32_t off_c = 0;
148 cblas_gemm_s8u8s32(CblasColMajor, CblasNoTrans, CblasNoTrans,
149 CblasFixOffset, M, N, K, 1., wei, M * jcp.ngroups, off_a,
150 jcp.im2col_sz ? col : src, K, off_b, 0., acc, M, &off_c);
// Fast path: single scale, no bias lookup — scale, optional sum/ReLU,
// then round+saturate to the dst type via qz_a1b0.
153 auto body = [&](int o) {
154 float d = fast_path_alpha * acc[o] + sum_scale * dst[o];
155 if (do_relu && d < 0) d *= nslope;
156 dst[o] = qz_a1b0<float, dst_data_t>()(d, rmode);
// With OpenMP 4.0+ use a simd-annotated loop; otherwise fall back to
// the generic parallel_nd driver (the #else is elided in this view).
159 # if _OPENMP >= 201307
160 # pragma omp parallel for simd
161 for (int o = 0; o < jcp.os * jcp.oc; ++o) body(o);
163 parallel_nd(jcp.os * jcp.oc, body);
// Slow path: per-(os, oc) bias, per-oc (or common) scale, optional sum
// accumulation and ReLU, then quantization into the strided dst layout.
166 parallel_nd(jcp.os, jcp.oc, [&](const int os, const int oc) {
167 const size_t acc_off = os * jcp.oc + oc;
168 float d = (float)acc[acc_off];
171 d += get_bias(g * jcp.oc + oc);
173 d *= scales[(g * jcp.oc + oc) * scale_idx_mult];
175 const size_t dst_off = os * dst_os_stride + oc;
176 if (do_sum) d += sum_scale * dst[dst_off];
177 if (do_relu && d < 0) d *= nslope;
178 dst[dst_off] = qz_a1b0<float, dst_data_t>()(d, rmode);
181 nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
// Backward-data int8 GEMM convolution driver: resolves I/O pointers and
// dispatches the per-thread worker over the scratchpad.
186 template <data_type_t dst_type>
187 void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::execute_backward_data() {
// Inputs: diff_dst (0), weights (1), bias (2, raw bytes — type switched on
// later); output: diff_src.
188 auto diff_dst_base = reinterpret_cast<const diff_dst_data_t *>
189 (this->input_memory(0));
190 auto wei_base = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
191 auto bia_base = reinterpret_cast<const char *>(this->input_memory(2));
192 auto diff_src_base = reinterpret_cast<diff_src_data_t *>(this->memory());
194 jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
195 char *scratchpad = (char *)this->scratchpad_->get();
// Note: unlike the forward driver, the col scratch is not pre-zeroed here —
// presumably the GEMM (beta = 0) overwrites it fully; confirm in full source.
197 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
198 execute_backward_data_thr(ithr, nthr, diff_dst_base, wei_base,
199 bia_base, diff_src_base, scratchpad);
// Per-thread backward-data body: for each assigned (minibatch, group) item,
// compute col = wei^T * diff_dst with an s8u8s32 GEMM, scatter via col2im
// into an s32 accumulator, then apply bias/scales and quantize to diff_src.
203 template <data_type_t dst_type>
204 void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>
205 ::execute_backward_data_thr(const int ithr, const int nthr,
206 const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base,
207 const char *bia_base, diff_src_data_t *diff_src_base, char *scratchpad)
210 jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
// Per-minibatch / per-group / spatial strides from the memory descriptors.
212 const auto diff_dst_md = memory_desc_wrapper(conf_.diff_dst_pd());
213 const size_t diff_dst_mb_stride = diff_dst_md.blk_off(1);
214 const size_t diff_dst_g_stride = diff_dst_md.blk_off(0, 1) * jcp.oc;
216 const auto wei_md = memory_desc_wrapper(conf_.weights_pd(0));
217 const size_t wei_g_stride = conf_.with_groups() ? wei_md.blk_off(1) : 0;
219 const auto diff_src_md = memory_desc_wrapper(conf_.diff_src_pd());
220 const size_t diff_src_mb_stride = diff_src_md.blk_off(1);
221 const size_t diff_src_g_stride = diff_src_md.blk_off(0, 1) * jcp.ic;
222 const size_t diff_src_os_stride = diff_src_md.blk_off(0, 0, 0, 1);
// Read bias element `off`, converting from its runtime data type to the
// accumulator type (only s32 and f32 cases are visible here).
224 auto get_bias = [=, &bia_base](size_t off) -> acc_data_t {
225 # define CASE(dt) case dt: return (acc_data_t)\
226 (*((const prec_traits<dt>::type *)bia_base + off))
227 switch (conf_.desc()->bias_desc.data_type) {
230 CASE(data_type::s32);
231 CASE(data_type::f32);
232 default: assert(!"unimplemented");
238 /* scale_idx_mult = 1 for per_oc scales and 0, otherwise */
239 const int scale_idx_mult = conf_.attr()->output_scales_.mask_ == (1 << 1);
240 const float *scales = conf_.attr()->output_scales_.scales_;
241 const auto rmode = conf_.attr()->round_mode_;
242 const size_t work_amount = jcp.ngroups * jcp.mb;
// Thread-private scratch slices: im2col-shaped s32 col buffers first, then
// accumulators of jcp.is * jcp.ic elements per thread.
244 acc_data_t *_col = (acc_data_t *)scratchpad;
245 ptrdiff_t offset = (ptrdiff_t)jcp.im2col_sz
246 * sizeof(acc_data_t) * jcp.nthr;
247 acc_data_t *_acc = (acc_data_t *)(scratchpad + offset);
249 acc_data_t *col = _col + (ptrdiff_t)ithr * jcp.im2col_sz;
250 acc_data_t *acc = _acc + (ptrdiff_t)ithr * jcp.is * jcp.ic;
// Static balancing of ngroups * mb work items across threads.
253 size_t start = 0, end = 0;
255 balance211(work_amount, nthr, ithr, start, end);
256 nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);
258 for (size_t iwork = start; iwork < end; ++iwork) {
259 const diff_dst_data_t *diff_dst = diff_dst_base
260 + n * diff_dst_mb_stride + g * diff_dst_g_stride;
261 const wei_data_t *wei = wei_base + g * wei_g_stride;
262 diff_src_data_t *diff_src = diff_src_base + n * diff_src_mb_stride
263 + g * diff_src_g_stride;
// Column-major GEMM with transposed weights: col(M x N) =
// wei^T(M x K) * diff_dst(K x N); lda/ldb are K * ngroups because group
// data is interleaved along oc. Zero-points are all zero. (The non-im2col
// destination branch of the ?: is elided in this view.)
265 const int M = jcp.ks * jcp.ic;
266 const int N = jcp.os;
267 const int K = jcp.oc;
268 const int8_t off_a = 0, off_b = 0;
269 const int32_t off_c = 0;
271 cblas_gemm_s8u8s32(CblasColMajor, CblasTrans, CblasNoTrans,
272 CblasFixOffset, M, N, K, 1., wei, K * jcp.ngroups, off_a,
273 diff_dst, K * jcp.ngroups, off_b, 0., jcp.im2col_sz ? col
// Scatter-accumulate the col buffer back into image layout (s32).
277 jit_gemm_convolution_utils::col2im_s32(jcp, col, acc);
// Apply bias and per-ic (or common) scale, then round+saturate into the
// strided diff_src layout via qz_a1b0.
279 parallel_nd(jcp.is, jcp.ic, [&](int is, int ic) {
280 float d = (float)acc[is * jcp.ic + ic];
282 d += get_bias(g * jcp.ic + ic);
283 d *= scales[(g * jcp.ic + ic) * scale_idx_mult];
284 const size_t diff_src_off = is * diff_src_os_stride + ic;
285 diff_src[diff_src_off] =
286 qz_a1b0<float, diff_src_data_t>()(d, rmode);
288 nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
// Explicit template instantiations: forward convolution with and without a
// fused ReLU for every supported destination data type, and backward-data
// for the same destination-type set.
293 using namespace data_type;
295 template struct _gemm_u8s8s32x_convolution_fwd_t<true, f32>;
296 template struct _gemm_u8s8s32x_convolution_fwd_t<true, s32>;
297 template struct _gemm_u8s8s32x_convolution_fwd_t<true, s8>;
298 template struct _gemm_u8s8s32x_convolution_fwd_t<true, u8>;
299 template struct _gemm_u8s8s32x_convolution_fwd_t<false, f32>;
300 template struct _gemm_u8s8s32x_convolution_fwd_t<false, s32>;
301 template struct _gemm_u8s8s32x_convolution_fwd_t<false, s8>;
302 template struct _gemm_u8s8s32x_convolution_fwd_t<false, u8>;
304 template struct _gemm_u8s8s32x_convolution_bwd_data_t<f32>;
305 template struct _gemm_u8s8s32x_convolution_bwd_data_t<s32>;
306 template struct _gemm_u8s8s32x_convolution_bwd_data_t<s8>;
307 template struct _gemm_u8s8s32x_convolution_bwd_data_t<u8>;