1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "type_helpers.hpp"
22 #include "jit_generator.hpp"
24 #include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
30 using namespace mkldnn::impl::status;
31 using namespace mkldnn::impl::memory_format;
32 using namespace mkldnn::impl::memory_tracking::names;
33 using namespace mkldnn::impl::utils;
36 template <typename T, typename U>
37 void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end,
38 T nx, T &nx_start, T &nx_end, T nx_divider)
40 const T grp_size = utils::div_up(nthr, nx_divider);
41 const T grp_count = utils::div_up(nthr, grp_size);
43 T grp = ithr / grp_size;
44 T grp_ithr = ithr % grp_size;
45 T grp_nthr = grp_size;
46 T first_grps = nthr % grp_count;
47 if (first_grps > 0 && grp >= first_grps) {
48 ithr -= first_grps * grp_size;
50 grp = ithr / grp_nthr + first_grps;
51 grp_ithr = ithr % grp_nthr;
53 balance211(nx, grp_count, grp, nx_start, nx_end);
54 balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end);
58 /* convolution forward */
59 template <data_type_t src_type, data_type_t dst_type>
60 void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t
61 <src_type, dst_type>::execute_forward() const
63 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
65 reinterpret_cast<const wei_data_t *>(this->input_memory(1));
66 auto bias = reinterpret_cast<const char *>(this->input_memory(2));
67 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
69 auto scratchpad = this->scratchpad();
71 if (pd()->jcp_.signed_input && pd()->jcp_.ver != ver_vnni) {
72 auto local_scales = scratchpad.template get<float>(
73 key_conv_adjusted_scales);
74 auto scales = pd()->attr()->output_scales_.scales_;
75 size_t count = pd()->attr()->output_scales_.count_;
76 float factor = 1.f / pd()->jcp_.wei_adj_scale;
78 utils::array_set(local_scales, scales[0] * factor, 16);
80 for (size_t c = 0; c < count; c++)
81 local_scales[c] = scales[c] * factor;
85 parallel(kernel_->jcp.nthr, [&](const int ithr, const int nthr) {
86 execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad);
90 template <data_type_t src_type, data_type_t dst_type>
91 void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type, dst_type>
92 ::execute_forward_thr(const int ithr, const int nthr, const src_data_t *src,
93 const wei_data_t *weights, const char *bias, dst_data_t *dst,
94 const memory_tracking::grantor_t &scratchpad) const {
95 const memory_desc_wrapper src_d(pd()->src_pd());
96 const memory_desc_wrapper dst_d(pd()->dst_pd());
97 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
99 const size_t bia_dt_size = pd()->with_bias()
100 ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
102 const auto &jcp = kernel_->jcp;
103 auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space);
104 auto local_scales = scratchpad.get<float>(key_conv_adjusted_scales);
106 const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
108 const int stride_h = pd()->desc()->strides[0];
109 const int stride_w = pd()->desc()->strides[1];
110 const int pad_t = pd()->desc()->padding[0][0];
111 const int pad_l = pd()->desc()->padding[0][1];
113 const auto &oscales = pd()->attr()->output_scales_;
115 int offset = jcp.ngroups * (jcp.oc / jcp.oc_block) * (jcp.ic / jcp.ic_block)
116 * jcp.oc_block * jcp.ic_block;
117 wei_data_t *w = const_cast<wei_data_t *>(weights);
118 int32_t* compensation = (jcp.signed_input)
119 ? reinterpret_cast<int32_t *>(w + offset) : 0;
121 auto step = [](int default_step, int remaining, int tail_step) {
122 assert(default_step <= tail_step);
123 return remaining < tail_step ? remaining : default_step;
126 auto p = jit_1x1_conv_call_s();
128 auto rp = rtus_driver_t<avx512_common>::call_params_t();
129 const int nb_oc = jcp.nb_load;
130 const int os_block = jcp.bcast_block;
133 int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0};
134 balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
135 jcp.nb_load, ocb_start, ocb_end, jcp.load_grp_count);
137 auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step,
138 int &oh, int &ow, int &ih, int &iw)
141 nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
143 bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
144 jcp.nb_bcast_blocking_max);
145 bcast_step = nstl::min(bcast_step, bcast_end - iwork);
147 const int os = osb * os_block;
151 ih = nstl::max(oh * stride_h - pad_t, 0);
152 iw = nstl::max(ow * stride_w - pad_l, 0);
155 p.bcast_dim = this_block_size(os, jcp.os,
156 bcast_step * os_block);
160 auto init_load = [&](int ocb, int &load_step)
162 load_step = step(jcp.nb_load_blocking, ocb_end - ocb,
163 jcp.nb_load_blocking_max);
164 p.load_dim = this_block_size(ocb * jcp.oc_block,
165 ocb_end * jcp.oc_block, load_step * jcp.oc_block);
167 if (ocb + load_step >= nb_oc)
168 p.first_last_flag |= FLAG_OC_LAST;
170 p.first_last_flag &= ~FLAG_OC_LAST;
174 auto init_reduce = [&]()
176 p.reduce_dim = this_block_size(0, jcp.ic, jcp.ic);
177 rp.icb = p.reduce_dim / jcp.reduce_block;
180 auto inner_ker = [&](int ocb, int n, int g, int oh, int ow,
183 const int icb = 0; // Start from the first IC block
184 const int _ocb = g * nb_oc + ocb;
187 const size_t dst_off = dst_d.blk_off(n, _ocb * jcp.oc_block, oh, ow);
189 p.output_data = &dst[dst_off];
190 p.load_data = &weights[pd()->with_groups()
191 ? weights_d.blk_off(g, ocb, icb)
192 : weights_d.blk_off(ocb, icb)];
193 p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size];
194 p.compensation = (jcp.signed_input)
195 ? &compensation[_ocb * jcp.oc_block] : 0;
196 p.scales = (jcp.signed_input && jcp.ver != ver_vnni)
197 ? &local_scales[jcp.is_oc_scale * _ocb * jcp.oc_block]
198 : &oscales.scales_[jcp.is_oc_scale * _ocb * jcp.oc_block];
199 if (pd()->rtus_.reduce_src_) {
200 rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
201 + _icb * jcp.is * jcp.ic_block;
202 if (ocb == ocb_start) {
203 rp.src = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw);
204 rtus_driver_->ker_(&rp);
206 p.bcast_data = rp.ws;
208 p.bcast_data = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw);
210 p.oc_off = _ocb * jcp.oc_block * sizeof(float);
212 kernel_->jit_ker(&p);
215 if (jcp.loop_order == loop_rlb) {
218 while (ocb < ocb_end) {
220 init_load(ocb, load_step);
221 int iwork = bcast_start;
222 while (iwork < bcast_end) {
223 int n, g, bcast_step, oh, ow, ih, iw;
224 init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
225 inner_ker(ocb, n, g, oh, ow, ih, iw);
230 } else if (jcp.loop_order == loop_lbr) {
232 while (ocb < ocb_end) {
234 init_load(ocb, load_step);
235 int iwork = bcast_start;
236 while (iwork < bcast_end) {
237 int n, g, bcast_step, oh, ow, ih, iw;
238 init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
240 inner_ker(ocb, n, g, oh, ow, ih, iw);
245 } else if (jcp.loop_order == loop_rbl) {
247 int iwork = bcast_start;
248 while (iwork < bcast_end) {
249 int n, g, bcast_step, oh, ow, ih, iw;
250 init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
252 while (ocb < ocb_end) {
254 init_load(ocb, load_step);
255 inner_ker(ocb, n, g, oh, ow, ih, iw);
260 } else if (jcp.loop_order == loop_blr) {
261 int iwork = bcast_start;
262 while (iwork < bcast_end) {
263 int n, g, bcast_step, oh, ow, ih, iw;
264 init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
266 while (ocb < ocb_end) {
268 init_load(ocb, load_step);
270 inner_ker(ocb, n, g, oh, ow, ih, iw);
276 assert(!"unsupported loop order");
280 using namespace data_type;
281 template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, u8>;
282 template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, u8>;
283 template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s8>;
284 template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s8>;
285 template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s32>;
286 template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s32>;
287 template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, f32>;
288 template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, f32>;