/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "type_helpers.hpp"
22 #include "jit_generator.hpp"
24 #include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
30 using namespace mkldnn::impl::status;
31 using namespace mkldnn::impl::memory_format;
32 using namespace mkldnn::impl::memory_tracking::names;
33 using namespace mkldnn::impl::utils;
36 template <typename T, typename U>
37 void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end,
38 T nx, T &nx_start, T &nx_end, T nx_divider)
40 const T grp_size = utils::div_up(nthr, nx_divider);
41 const T grp_count = utils::div_up(nthr, grp_size);
43 T grp = ithr / grp_size;
44 T grp_ithr = ithr % grp_size;
45 T grp_nthr = grp_size;
46 T first_grps = nthr % grp_count;
47 if (first_grps > 0 && grp >= first_grps) {
48 ithr -= first_grps * grp_size;
50 grp = ithr / grp_nthr + first_grps;
51 grp_ithr = ithr % grp_nthr;
53 balance211(nx, grp_count, grp, nx_start, nx_end);
54 balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end);
58 /* convolution forward */
59 template <data_type_t src_type, data_type_t dst_type>
60 void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t
61 <src_type, dst_type>::execute_forward() const
63 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
65 reinterpret_cast<const wei_data_t *>(this->input_memory(1));
66 auto bias = reinterpret_cast<const char *>(this->input_memory(2));
67 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
69 auto scratchpad = this->scratchpad();
71 if (pd()->jcp_.signed_input && pd()->jcp_.ver != ver_vnni) {
72 auto local_scales = scratchpad.template get<float>(
73 key_conv_adjusted_scales);
74 auto scales = pd()->attr()->output_scales_.scales_;
75 size_t count = pd()->attr()->output_scales_.count_;
76 float factor = 1.f / pd()->jcp_.wei_adj_scale;
78 utils::array_set(local_scales, scales[0] * factor, 16);
80 for (size_t c = 0; c < count; c++)
81 local_scales[c] = scales[c] * factor;
85 const int work_amount = pd()->jcp_.mb * pd()->jcp_.ngroups * pd()->jcp_.nb_bcast * pd()->jcp_.nb_load;
87 parallel(kernel_->jcp.nthr, (size_t)work_amount, [&](const int ithr, const int nthr) {
88 execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad);
92 template <data_type_t src_type, data_type_t dst_type>
93 void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type, dst_type>
94 ::execute_forward_thr(const int ithr, const int nthr, const src_data_t *src,
95 const wei_data_t *weights, const char *bias, dst_data_t *dst,
96 const memory_tracking::grantor_t &scratchpad) const {
97 const memory_desc_wrapper src_d(pd()->src_pd());
98 const memory_desc_wrapper dst_d(pd()->dst_pd());
99 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
101 const size_t bia_dt_size = pd()->with_bias()
102 ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
104 const auto &jcp = kernel_->jcp;
105 auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space);
106 auto local_scales = scratchpad.get<float>(key_conv_adjusted_scales);
108 const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
110 const int stride_h = pd()->desc()->strides[0];
111 const int stride_w = pd()->desc()->strides[1];
112 const int pad_t = pd()->desc()->padding[0][0];
113 const int pad_l = pd()->desc()->padding[0][1];
115 const auto &oscales = pd()->attr()->output_scales_;
117 int offset = jcp.ngroups * (jcp.oc / jcp.oc_block) * (jcp.ic / jcp.ic_block)
118 * jcp.oc_block * jcp.ic_block;
119 wei_data_t *w = const_cast<wei_data_t *>(weights);
120 int32_t* compensation = (jcp.signed_input)
121 ? reinterpret_cast<int32_t *>(w + offset) : 0;
123 auto step = [](int default_step, int remaining, int tail_step) {
124 assert(default_step <= tail_step);
125 return remaining < tail_step ? remaining : default_step;
128 auto p = jit_1x1_conv_call_s();
130 auto rp = rtus_driver_t<avx512_common>::call_params_t();
131 const int nb_oc = jcp.nb_load;
132 const int os_block = jcp.bcast_block;
134 int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0};
135 balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
136 jcp.nb_load / jcp.nb_load_chunk, ocb_start, ocb_end,
138 if (jcp.nb_load_chunk > 1) {
139 ocb_start *= jcp.nb_load_chunk;
140 ocb_end *= jcp.nb_load_chunk;
143 auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step,
144 int &oh, int &ow, int &ih, int &iw)
147 nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
149 bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
150 jcp.nb_bcast_blocking_max);
151 bcast_step = nstl::min(bcast_step, bcast_end - iwork);
153 const int os = osb * os_block;
157 ih = nstl::max(oh * stride_h - pad_t, 0);
158 iw = nstl::max(ow * stride_w - pad_l, 0);
161 p.bcast_dim = this_block_size(os, jcp.os,
162 bcast_step * os_block);
166 auto init_load = [&](int ocb, int &load_step)
168 load_step = step(jcp.nb_load_blocking, ocb_end - ocb,
169 jcp.nb_load_blocking_max);
170 p.load_dim = this_block_size(ocb * jcp.oc_block,
171 ocb_end * jcp.oc_block, load_step * jcp.oc_block);
173 if (ocb + load_step >= nb_oc)
174 p.first_last_flag |= FLAG_OC_LAST;
176 p.first_last_flag &= ~FLAG_OC_LAST;
180 auto init_reduce = [&]()
182 p.reduce_dim = this_block_size(0, jcp.ic, jcp.ic);
183 rp.icb = p.reduce_dim / jcp.reduce_block;
186 auto inner_ker = [&](int ocb, int n, int g, int oh, int ow,
189 const int icb = 0; // Start from the first IC block
190 const int _ocb = g * nb_oc + ocb;
193 const size_t dst_off = dst_d.blk_off(n, _ocb * jcp.oc_block, oh, ow);
195 p.output_data = &dst[dst_off];
196 p.load_data = &weights[pd()->with_groups()
197 ? weights_d.blk_off(g, ocb, icb)
198 : weights_d.blk_off(ocb, icb)];
199 p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size];
200 p.compensation = (jcp.signed_input)
201 ? &compensation[_ocb * jcp.oc_block] : 0;
202 p.scales = (jcp.signed_input && jcp.ver != ver_vnni)
203 ? &local_scales[jcp.is_oc_scale * _ocb * jcp.oc_block]
204 : &oscales.scales_[jcp.is_oc_scale * _ocb * jcp.oc_block];
205 if (pd()->rtus_.reduce_src_) {
206 rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
207 + _icb * jcp.is * jcp.ic_block;
208 if (ocb == ocb_start) {
209 rp.src = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw);
210 rtus_driver_->ker_(&rp);
212 p.bcast_data = rp.ws;
214 p.bcast_data = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw);
216 p.oc_off = _ocb * jcp.oc_block * sizeof(float);
218 kernel_->jit_ker(&p);
221 if (jcp.loop_order == loop_rlb) {
224 while (ocb < ocb_end) {
226 init_load(ocb, load_step);
227 int iwork = bcast_start;
228 while (iwork < bcast_end) {
229 int n, g, bcast_step, oh, ow, ih, iw;
230 init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
231 inner_ker(ocb, n, g, oh, ow, ih, iw);
236 } else if (jcp.loop_order == loop_lbr) {
238 while (ocb < ocb_end) {
240 init_load(ocb, load_step);
241 int iwork = bcast_start;
242 while (iwork < bcast_end) {
243 int n, g, bcast_step, oh, ow, ih, iw;
244 init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
246 inner_ker(ocb, n, g, oh, ow, ih, iw);
251 } else if (jcp.loop_order == loop_rbl) {
253 int iwork = bcast_start;
254 while (iwork < bcast_end) {
255 int n, g, bcast_step, oh, ow, ih, iw;
256 init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
258 while (ocb < ocb_end) {
260 init_load(ocb, load_step);
261 inner_ker(ocb, n, g, oh, ow, ih, iw);
266 } else if (jcp.loop_order == loop_blr) {
267 int iwork = bcast_start;
268 while (iwork < bcast_end) {
269 int n, g, bcast_step, oh, ow, ih, iw;
270 init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
272 while (ocb < ocb_end) {
274 init_load(ocb, load_step);
276 inner_ker(ocb, n, g, oh, ow, ih, iw);
282 assert(!"unsupported loop order");
286 using namespace data_type;
287 template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, u8>;
288 template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, u8>;
289 template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s8>;
290 template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s8>;
291 template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s32>;
292 template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s32>;
293 template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, f32>;
294 template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, f32>;