1 /*******************************************************************************
2 * Copyright 2018-2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_types.h"
18 #include "c_types_map.hpp"
19 #include "jit_uni_x8s8s32x_convolution.hpp"
21 #include "mkldnn_thread.hpp"
22 #include "type_helpers.hpp"
29 using namespace mkldnn::impl::status;
30 using namespace mkldnn::impl::memory_format;
31 using namespace mkldnn::impl::memory_tracking::names;
32 using namespace mkldnn::impl::utils;
// NOTE(review): this chunk is a lossy paste -- each line carries an embedded
// original line number, and several original lines are missing (closing
// braces, the `parallel(0, ker)` dispatch, the `if (count == 1)` guard,
// and the `par_conv.oc_work =` left-hand side). Annotated in place only.
//
// Forward int8 convolution (u8/s8 src, s8 weights, s32 accumulation):
// flattens the work over (mb, group, oc-block chunk, output row) and calls
// the JIT kernel once per output row.
34 template <cpu_isa_t isa, impl::data_type_t src_type, data_type_t dst_type>
35 void _jit_uni_x8s8s32x_convolution_fwd_t<isa, src_type, dst_type>::execute_forward() const {
// Raw buffers: input 0 = src, 1 = weights, 2 = bias; memory() = dst.
36 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
37 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
38 auto bias = reinterpret_cast<const char *>(this->input_memory(2));
39 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
41 const memory_desc_wrapper src_d(pd()->src_pd());
42 const memory_desc_wrapper dst_d(pd()->dst_pd());
43 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
44 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
46 const auto &jcp = kernel_->jcp;
// For signed-input convolutions the int32 compensation values are stored
// immediately after the block-padded weights; `offset` is the weights blob
// size in elements (oc and ic rounded up to full blocks).
48 size_t offset = (size_t)jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block) * rnd_up(jcp.ic, jcp.ic_block) * jcp.kh * jcp.kw;
49 auto w = const_cast<wei_data_t *>(weights);
50 int32_t* compensation = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
// Zero-pad the bias up to oc_padded in scratchpad so the kernel may read
// whole oc blocks.
52 if (bias && jcp.oc != jcp.oc_padded) {
53 auto padded_bias = this->scratchpad().template get<bia_data_t>(key_conv_padded_bias);
54 utils::array_copy(padded_bias, (bia_data_t*)bias, jcp.oc);
55 utils::array_set(padded_bias + jcp.oc, 0, jcp.oc_padded - jcp.oc);
56 bias = (char *)padded_bias;
// Signed input: output scales are divided by the weights adjustment scale;
// the adjusted copies live in scratchpad.
59 const float *oscales = pd()->attr()->output_scales_.scales_;
60 if (jcp.signed_input) {
61 auto local_scales = scratchpad().template get<float>(key_conv_adjusted_scales);
62 size_t count = pd()->attr()->output_scales_.count_;
63 float factor = 1.f / jcp.wei_adj_scale;
// NOTE(review): in the upstream source the next two statements are the
// `count == 1` / else branches of an if/else whose guard lines are
// missing from this paste.
65 utils::array_set(local_scales, oscales[0] * factor, 8);
67 for (size_t c = 0; c < count; c++)
68 local_scales[c] = oscales[c] * factor;
70 oscales = local_scales;
// Pad the compensation array to oc_padded, mirroring the bias padding.
72 if (jcp.oc != jcp.oc_padded) {
73 auto padded_compensation = this->scratchpad().template get<int32_t>(key_conv_padded_compensation);
74 utils::array_copy(padded_compensation, compensation, jcp.oc);
75 utils::array_set(padded_compensation + jcp.oc, 0, jcp.oc_padded - jcp.oc);
76 compensation = padded_compensation;
// Total work items: one per (mb, group, oc-block chunk, output row).
80 int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
81 const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.oh;
// Per-thread body; each thread walks its balanced [start, end) slice.
83 auto ker = [&](const int ithr, const int nthr) {
84 size_t start{0}, end{0};
85 balance211(work_amount, nthr, ithr, start, end);
87 size_t n{0}, g{0}, ocbb{0}, oh{0};
88 nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
90 for (size_t iwork = start; iwork < end; ++iwork) {
91 int ocb = ocbb * jcp.nb_oc_blocking;
92 int ocb_num = jcp.nb_oc_blocking;
94 auto par_conv = jit_conv_call_s();
// Kernel rows that fall into the top / bottom padding for this output
// row (dilation-aware).
96 const int ij = oh * jcp.stride_h;
97 const int i_t_overflow = nstl::min(jcp.kh, div_up(nstl::max(0, jcp.t_pad - ij), (jcp.dilate_h+1)));
98 const int i_b_overflow = nstl::min(jcp.kh, div_up(nstl::max(jcp.ih, ij + (jcp.kh-1) * (jcp.dilate_h+1) -
99 jcp.t_pad+1) - jcp.ih, (jcp.dilate_h + 1)));
101 const size_t _oc = g * jcp.nb_oc + ocb;
102 const size_t _ic = g * jcp.nb_ic;
// First input row actually read, clamped to the image.
104 const int ih = nstl::max(ij - jcp.t_pad + i_t_overflow * (jcp.dilate_h + 1), 0);
105 par_conv.src = &src[src_d.blk_off(n, _ic*jcp.ic_block, ih, 0)];
107 size_t dst_off = dst_d.blk_off(n, _oc*jcp.oc_block, oh, 0);
108 par_conv.dst = &dst[dst_off];
// With signed input the kernel processes padded rows itself (it needs the
// compensation), so the filter pointer stays at row 0.
110 const int wh = (!jcp.signed_input) ? i_t_overflow : 0;
111 par_conv.filt = &weights[pd()->with_groups()
112 ? weights_d.blk_off(g, ocb, 0, wh, 0)
113 : weights_d.blk_off(ocb, 0, wh, 0)];
116 par_conv.bias = &bias[bias_d.blk_off(_oc * jcp.oc_block*jcp.typesize_bia)];
// NOTE(review): the `par_conv.oc_work =` left-hand side (original line
// 118) is missing from this paste; the next expression is its RHS.
119 nstl::min((ocb + ocb_num) * jcp.oc_block, jcp.oc) - ocb*jcp.oc_block;
121 par_conv.kw_padding = 0;
// Rows of the kernel that land inside the image for this output row.
122 const int kh_padding = jcp.kh - i_t_overflow - i_b_overflow;
123 par_conv.kh_padding = nstl::max(0, kh_padding);
125 par_conv.scales = &oscales[jcp.is_oc_scale * _oc * jcp.oc_block];
127 par_conv.compensation = (jcp.signed_input) ? compensation + _oc * jcp.oc_block : 0;
128 par_conv.t_overflow = i_t_overflow;
129 par_conv.b_overflow = i_b_overflow;
// Channel byte-offset used by post-ops (per-channel binary/depthwise).
131 par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);
133 kernel_->jit_ker(&par_conv);
134 nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh);
// NOTE(review): same lossy paste as above -- embedded line numbers, missing
// closing braces and the `parallel(0, ker)` dispatch. Annotated in place.
//
// Forward int8 convolution fused with a following depthwise convolution:
// each thread computes 1x1-conv output rows into a small per-thread ring
// buffer of jcp_dw.kh rows, then runs the depthwise kernel over that buffer
// as soon as enough rows are available.
141 template <cpu_isa_t isa, impl::data_type_t src_type, data_type_t dst_type>
142 void _jit_uni_x8s8s32x_convolution_fwd_t<isa, src_type, dst_type>::execute_forward_with_dw_conv() const {
143 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
144 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
145 auto bias = reinterpret_cast<const char *>(this->input_memory(2));
146 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
148 const memory_desc_wrapper src_d(pd()->src_pd());
149 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
150 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
152 const auto &jcp = kernel_->jcp;
153 const auto &jcp_dw = kernel_dw_->jcp;
154 const int MB = pd()->MB();
// Compensation (signed input) sits right after the block-padded weights.
156 size_t offset = (size_t)jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block) * rnd_up(jcp.ic, jcp.ic_block) * jcp.kh * jcp.kw;
157 auto w = const_cast<wei_data_t *>(weights);
158 int32_t* compensation = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
// Weights/bias of the fused depthwise convolution come from the jit conf.
160 auto dw_bias = jcp_dw.conv_biases;
161 auto dw_weights = reinterpret_cast<const wei_data_t *>(jcp_dw.conv_weights);
// Zero-pad both the main and the depthwise bias up to oc_padded.
163 if (jcp.oc != jcp.oc_padded) {
164 auto padded_bias = this->scratchpad().template get<bia_data_t>(key_conv_padded_bias);
165 utils::array_copy(padded_bias, (bia_data_t*)bias, jcp.oc);
166 utils::array_set(padded_bias + jcp.oc, 0, jcp.oc_padded - jcp.oc);
167 bias = (char *)padded_bias;
169 auto dw_padded_bias = this->scratchpad().template get<bia_data_t>(key_dw_conv_padded_bias);
170 utils::array_copy(dw_padded_bias, dw_bias, jcp.oc)<
171 utils::array_set(dw_padded_bias + jcp.oc, 0.f, jcp.oc_padded - jcp.oc);
172 dw_bias = dw_padded_bias;
// Adjust output scales for signed input, as in execute_forward().
175 const float *oscales = pd()->attr()->output_scales_.scales_;
176 if (jcp.signed_input) {
177 auto local_scales = scratchpad().template get<float>(key_conv_adjusted_scales);
178 size_t count = pd()->attr()->output_scales_.count_;
179 float factor = 1.f / jcp.wei_adj_scale;
// NOTE(review): upstream these are the `count == 1` / else branches; the
// guard lines are missing from this paste.
181 utils::array_set(local_scales, oscales[0] * factor, 8);
183 for (size_t c = 0; c < count; c++)
184 local_scales[c] = oscales[c] * factor;
186 oscales = local_scales;
188 if (jcp.oc != jcp.oc_padded) {
189 auto padded_compensation = this->scratchpad().template get<int32_t>(key_conv_padded_compensation);
190 utils::array_copy(padded_compensation, compensation, jcp.oc);
191 utils::array_set(padded_compensation + jcp.oc, 0, jcp.oc_padded - jcp.oc);
192 compensation = padded_compensation;
196 int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
197 const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.oh;
199 auto ker = [&](const int ithr, const int nthr) {
// Computes `num_rows` 1x1-conv output rows (starting at row `oh`) into the
// per-thread ring buffer ws_p; out-of-image rows are zero-filled so the
// depthwise kernel can read its full kh window.
200 auto compute_row_gen = [&](dst_data_t* ws_p, int n, int g, int ocb, int ocb_num, int oh, int num_rows) {
201 for (int h = 0; h < num_rows; h++) {
202 if ((oh + h) < 0 || (oh + h) >= jcp.oh) {
203 for (int chb = ocb; chb < ocb + ocb_num; chb++) {
// Ring-buffer slot for row (oh+h): ((oh+h)+1) % jcp_dw.kh.
204 memset(ws_p + (((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block +
205 (chb - ocb) * jcp_dw.kh * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(dst_data_t));
208 auto par_conv = jit_conv_call_s();
// Same top/bottom overflow computation as execute_forward().
210 const int ij = (oh + h) * jcp.stride_h;
211 const int i_t_overflow = nstl::min(jcp.kh, div_up(nstl::max(0, jcp.t_pad - ij), (jcp.dilate_h+1)));
212 const int i_b_overflow = nstl::min(jcp.kh, div_up(nstl::max(jcp.ih, ij + (jcp.kh-1) * (jcp.dilate_h+1) -
213 jcp.t_pad+1) - jcp.ih, (jcp.dilate_h + 1)));
215 const size_t _oc = g * jcp.nb_oc + ocb;
216 const size_t _ic = g * jcp.nb_ic;
218 const int ih = nstl::max(ij - jcp.t_pad + i_t_overflow * (jcp.dilate_h + 1), 0);
219 par_conv.src = &src[src_d.blk_off(n, _ic*jcp.ic_block, ih, 0)];
// Destination is the ring-buffer row, not the final dst tensor.
221 par_conv.dst = &ws_p[(((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block];
223 const int wh = (!jcp.signed_input) ? i_t_overflow : 0;
224 par_conv.filt = &weights[pd()->with_groups()
225 ? weights_d.blk_off(g, ocb, 0, wh, 0)
226 : weights_d.blk_off(ocb, 0, wh, 0)];
229 par_conv.bias = &bias[bias_d.blk_off(_oc * jcp.oc_block*jcp.typesize_bia)];
// NOTE(review): the `par_conv.oc_work =` left-hand side (original line
// 231) is missing from this paste; the next expression is its RHS.
232 nstl::min((ocb + ocb_num) * jcp.oc_block, jcp.oc) - ocb*jcp.oc_block;
234 par_conv.kw_padding = 0;
235 const int kh_padding = jcp.kh - i_t_overflow - i_b_overflow;
236 par_conv.kh_padding = nstl::max(0, kh_padding);
238 par_conv.scales = &oscales[jcp.is_oc_scale * _oc * jcp.oc_block];
239 par_conv.compensation = (jcp.signed_input) ? compensation + _oc * jcp.oc_block : 0;
240 par_conv.t_overflow = i_t_overflow;
241 par_conv.b_overflow = i_b_overflow;
243 par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);
245 kernel_->jit_ker(&par_conv);
// Runs the depthwise kernel on three consecutive ring-buffer rows centered
// on intermediate row `dst_idx` (src_row0/1/2 = rows dst_idx-1, dst_idx,
// dst_idx+1), writing the final output row dst_idx / stride_h.
250 auto compute_row_dw = [&](const dst_data_t* ws_p, int n, int ocb, int ocb_num, int dst_idx) {
251 for (int chb = ocb; chb < nstl::min(ocb + ocb_num, jcp.nb_oc); chb++) {
252 auto par_conv_dw = jit_conv_call_s();
254 par_conv_dw.src_row0 = &ws_p[(((dst_idx+1) - 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
255 (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
256 par_conv_dw.src_row1 = &ws_p[(((dst_idx+1) - 0) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
257 (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
258 par_conv_dw.src_row2 = &ws_p[(((dst_idx+1) + 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
259 (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
261 par_conv_dw.dst = &dst[n*jcp_dw.oc*jcp_dw.oh*jcp_dw.ow + dst_idx/jcp_dw.stride_h*jcp_dw.ow*jcp_dw.oc + chb*jcp_dw.ch_block];
263 par_conv_dw.kh_padding = jcp_dw.kh;
264 par_conv_dw.filt = &dw_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
265 par_conv_dw.bias = &dw_bias[chb * jcp_dw.ch_block];
266 par_conv_dw.ur_w = (size_t)(jcp_dw.ow);
267 par_conv_dw.oc_work = nstl::min((chb + 1) * jcp_dw.ch_block, (int)jcp_dw.oc) - chb*jcp_dw.ch_block;
268 par_conv_dw.oc_off = chb * jcp_dw.ch_block * sizeof(float);
270 kernel_dw_->jit_ker(&par_conv_dw);
274 size_t start{0}, end{0};
275 balance211(work_amount, nthr, ithr, start, end);
// Per-thread ring buffer: jcp_dw.kh rows per oc block in the chunk.
277 auto dw_conv_buffer = scratchpad().template get<dst_data_t>(key_dw_conv_buffer);
278 size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking;
279 auto pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
281 size_t n{0}, g{0}, ocbb{0}, oh{0};
282 nd_iterator_init(start, n, MB, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh);
283 for (size_t iwork = start; iwork < end; ++iwork) {
284 int ocb = ocbb * jcp.nb_oc_blocking;
285 int ocb_num = jcp.nb_oc_blocking;
// Warm up the ring buffer at a slice boundary (rows oh-1 and oh),
// otherwise (NOTE(review): the else line is missing from this paste)
// only the new row oh is produced.
287 if (iwork == start || oh == 0) {
288 compute_row_gen(pbuf, n, g, ocb, ocb_num, oh - 1, 2);
290 compute_row_gen(pbuf, n, g, ocb, ocb_num, oh, 1);
// The depthwise pass for row oh-1 can run once row oh exists.
293 if (iwork > start && ((oh - 1) % jcp_dw.stride_h == 0) && oh > 0) {
294 compute_row_dw(pbuf, n, ocb, ocb_num, oh - 1);
// Flush at the end of the slice / bottom of the image: generate the
// (zero-padded) row oh+1 and emit the last depthwise row.
297 if ((iwork == end - 1 || (int) oh == jcp.oh - 1) && ((oh) % jcp_dw.stride_h == 0)) {
298 compute_row_gen(pbuf, n, g, ocb, ocb_num, oh + 1, 1);
299 compute_row_dw(pbuf, n, ocb, ocb_num, oh);
302 nd_iterator_step(n, MB, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh);
// Explicit instantiations: all supported (isa, src_type, dst_type)
// combinations -- avx2 and sse42, u8/s8 sources, u8/s8/s32/f32 destinations.
309 template struct _jit_uni_x8s8s32x_convolution_fwd_t<avx2, data_type::u8, data_type::u8>;
310 template struct _jit_uni_x8s8s32x_convolution_fwd_t<avx2, data_type::u8, data_type::s8>;
311 template struct _jit_uni_x8s8s32x_convolution_fwd_t<avx2, data_type::u8, data_type::s32>;
312 template struct _jit_uni_x8s8s32x_convolution_fwd_t<avx2, data_type::u8, data_type::f32>;
314 template struct _jit_uni_x8s8s32x_convolution_fwd_t<avx2, data_type::s8, data_type::u8>;
315 template struct _jit_uni_x8s8s32x_convolution_fwd_t<avx2, data_type::s8, data_type::s8>;
316 template struct _jit_uni_x8s8s32x_convolution_fwd_t<avx2, data_type::s8, data_type::s32>;
317 template struct _jit_uni_x8s8s32x_convolution_fwd_t<avx2, data_type::s8, data_type::f32>;
319 template struct _jit_uni_x8s8s32x_convolution_fwd_t<sse42, data_type::u8, data_type::u8>;
320 template struct _jit_uni_x8s8s32x_convolution_fwd_t<sse42, data_type::u8, data_type::s8>;
321 template struct _jit_uni_x8s8s32x_convolution_fwd_t<sse42, data_type::u8, data_type::s32>;
322 template struct _jit_uni_x8s8s32x_convolution_fwd_t<sse42, data_type::u8, data_type::f32>;
324 template struct _jit_uni_x8s8s32x_convolution_fwd_t<sse42, data_type::s8, data_type::u8>;
325 template struct _jit_uni_x8s8s32x_convolution_fwd_t<sse42, data_type::s8, data_type::s8>;
326 template struct _jit_uni_x8s8s32x_convolution_fwd_t<sse42, data_type::s8, data_type::s32>;
327 template struct _jit_uni_x8s8s32x_convolution_fwd_t<sse42, data_type::s8, data_type::f32>;