1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "type_helpers.hpp"
22 #include "jit_avx512_core_x8s8s32x_convolution.hpp"
28 using namespace mkldnn::impl::status;
29 using namespace mkldnn::impl::memory_format;
30 using namespace mkldnn::impl::memory_tracking::names;
31 using namespace mkldnn::impl::utils;
// Signature of the JIT-generated convolution kernel entry point: takes a
// single filled-in jit_conv_call_s descriptor per invocation.
using jit_conv_ker_t = void (*)(jit_conv_call_s *);

// Weights offset helper: when the primitive has groups the group index `g`
// participates in the blocked offset computation, otherwise it is dropped.
#define wht_blk_off(d, g, ...) \
        (pd()->with_groups() \
         ? (d).blk_off((g), __VA_ARGS__) \
         : (d).blk_off(__VA_ARGS__))
42 template <data_type_t src_type, data_type_t dst_type>
43 void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type,
44 dst_type>::execute_forward_1d() const {
45 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
46 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
47 auto bias = reinterpret_cast<const char *>(this->input_memory(2));
48 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
50 const memory_desc_wrapper src_d(pd()->src_pd());
51 const memory_desc_wrapper dst_d(pd()->dst_pd());
52 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
53 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
55 const size_t bia_dt_size = pd()->with_bias()
56 ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
58 const auto &jcp = pd()->jcp_;
59 assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
60 assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
62 const float *oscales = pd()->attr()->output_scales_.scales_;
63 if (jcp.signed_input && jcp.ver != ver_vnni) {
64 auto local_scales = scratchpad().template get<float>(
65 key_conv_adjusted_scales);
66 size_t count = pd()->attr()->output_scales_.count_;
67 float factor = 1.f / pd()->jcp_.wei_adj_scale;
69 utils::array_set(local_scales, oscales[0] * factor, 16);
71 for (size_t c = 0; c < count; c++)
72 local_scales[c] = oscales[c] * factor;
74 oscales = local_scales;
77 size_t offset = weights_d.size() - weights_d.additional_buffer_size();
78 auto w = const_cast<wei_data_t *>(weights);
79 int32_t* compensation = (jcp.signed_input)
80 ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
81 int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
82 int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
83 int group_block = jcp.ch_block;
84 int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow;
86 parallel(0, [&](const int ithr, const int nthr) {
88 int start{ 0 }, end{ 0 };
89 balance211(work_amount, nthr, ithr, start, end);
91 auto p = jit_conv_call_s();
93 int n{ 0 }, gg{ 0 }, occ{ 0 }, owb{ 0 };
94 switch (jcp.loop_order) {
96 nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
97 nb_groups, n, jcp.mb);
100 nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks,
104 nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ, oc_chunks,
108 nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks,
111 default: assert(!"unsupported loop order");
113 while (start < end) {
114 int ocb = occ * jcp.nb_oc_blocking;
115 int gb = gg * jcp.nb_ch_blocking;
116 int g = gb * group_block;
117 int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;
118 int g_ic = g * jcp.nb_ic * jcp.ic_block;
119 int ow_s = owb * jcp.ow_block;
120 int iw_s = ow_s * jcp.stride_w;
122 p.bias = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size) : 0;
123 p.compensation = (jcp.signed_input) ? compensation + g_oc : 0;
124 p.dst = dst + dst_d.blk_off(n, g_oc, ow_s);
125 p.src = src + src_d.blk_off(n, g_ic, iw_s);
126 p.filt = weights + wht_blk_off(weights_d, gb, ocb, 0);
127 p.scales = &oscales[jcp.is_oc_scale * g_oc];
128 p.oc_blocks = jcp.is_depthwise ? gb : ocb;
129 p.kh_padding = jcp.kh;
134 kernel_->jit_ker(&p);
137 switch (jcp.loop_order) {
139 nd_iterator_step(occ, oc_chunks, owb, jcp.nb_ow, gg, nb_groups,
143 nd_iterator_step(gg, nb_groups, n, jcp.mb, occ, oc_chunks, owb,
147 nd_iterator_step(n, jcp.mb, gg, nb_groups, occ, oc_chunks, owb,
151 nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, gg,
154 default: assert(!"unsupported loop order");
160 template <data_type_t src_type, data_type_t dst_type>
161 void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type,
162 dst_type>::execute_forward_2d() const {
163 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
164 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
165 auto bias = reinterpret_cast<const char *>(this->input_memory(2));
166 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
168 const memory_desc_wrapper src_d(pd()->src_pd());
169 const memory_desc_wrapper dst_d(pd()->dst_pd());
170 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
171 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
173 const size_t bia_dt_size = pd()->with_bias()
174 ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
176 const auto &jcp = pd()->jcp_;
177 assert(jcp.ch_block == 1);
178 assert(jcp.nb_ch_blocking == 1);
179 assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
180 assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
182 const float *oscales = pd()->attr()->output_scales_.scales_;
183 if (jcp.signed_input && jcp.ver != ver_vnni) {
184 auto local_scales = scratchpad().template get<float>(
185 key_conv_adjusted_scales);
186 size_t count = pd()->attr()->output_scales_.count_;
187 float factor = 1.f / pd()->jcp_.wei_adj_scale;
189 utils::array_set(local_scales, oscales[0] * factor, 16);
191 for (size_t c = 0; c < count; c++)
192 local_scales[c] = oscales[c] * factor;
194 oscales = local_scales;
197 size_t offset = weights_d.size() - weights_d.additional_buffer_size();
198 auto w = const_cast<wei_data_t *>(weights);
199 int32_t* compensation = (jcp.signed_input)
200 ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
201 int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk;
202 int nb_groups = jcp.nb_ch;
203 int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;
205 parallel(0, (size_t)work_amount, [&](const int ithr, const int nthr) {
207 int start{0}, end{0};
208 balance211(work_amount, nthr, ithr, start, end);
210 auto p = jit_conv_call_s();
212 size_t src_h_stride = src_d.blk_off(0, 0, 1);
213 size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
214 size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
216 int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 }, owb{ 0 };
217 switch (jcp.loop_order) {
219 nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g,
220 nb_groups, n, jcp.mb, oh_s, jcp.oh);
223 nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
224 owb, jcp.nb_ow, oh_s, jcp.oh);
227 nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
228 occ, oc_chunks, g, nb_groups);
230 default: assert(!"unsupported loop order");
232 while (start < end) {
233 for (int occ1 = 0; occ1 < jcp.nb_oc_blocking_thr_chunk;
234 occ1 += jcp.nb_oc_blocking) {
235 int ocb = occ * jcp.nb_oc_blocking_thr_chunk + occ1;
236 int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;
238 int g_ic = g * jcp.nb_ic * jcp.ic_block;
240 int work_rem = end - start;
241 int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
242 int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
243 if (jcp.loop_order == loop_nhwcg)
244 oh_e = oh_s + 1; // step instead
245 int ow_s = owb * jcp.ow_block;
246 int iw_s = ow_s * jcp.stride_w;
249 ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
251 int32_t *compensation_w = (jcp.signed_input)
252 ? compensation + g_oc : 0;
254 auto dst_w = dst + dst_d.blk_off(n, g_oc, oh_s, ow_s);
255 auto src_w = src + src_d.blk_off(n, g_ic, ih_s, iw_s);
256 auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);
258 auto scales = &oscales[jcp.is_oc_scale * g_oc];
260 for (int oj = oh_s, ij = ih_s; oj < oh_e;
261 ++oj, ij += jcp.stride_h) {
262 int dilate_h = jcp.dilate_h + 1;
263 int i_t_overflow = nstl::min(jcp.kh,
264 div_up(max(0, -ij), dilate_h));
265 int i_b_overflow = nstl::min(jcp.kh, div_up(
266 max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
268 int kh_padding = nstl::max(0,
269 jcp.kh - i_t_overflow - i_b_overflow);
271 size_t wei_stride = (!jcp.signed_input)
272 ? i_t_overflow * wht_h_stride : 0;
273 p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
275 p.filt = wht_w + wei_stride;
277 p.compensation = compensation_w;
279 p.kh_padding = kh_padding;
281 p.t_overflow = i_t_overflow;
282 p.b_overflow = i_b_overflow;
285 p.oc_off = g_oc * sizeof(float);
287 kernel_->jit_ker(&p);
288 src_w += src_h_stride * jcp.stride_h;
289 dst_w += dst_h_stride;
292 switch (jcp.loop_order) {
294 nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, g,
295 nb_groups, n, jcp.mb, oh_s, jcp.oh);
298 nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
299 oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
303 nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, occ,
304 oc_chunks, g, nb_groups);
306 default: assert(!"unsupported loop order");
312 template <data_type_t src_type, data_type_t dst_type>
313 void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type,
314 dst_type>::execute_forward_2d_dw() const {
315 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
316 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
317 auto bias = reinterpret_cast<const char *>(this->input_memory(2));
318 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
320 const memory_desc_wrapper src_d(pd()->src_pd());
321 const memory_desc_wrapper dst_d(pd()->dst_pd());
322 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
323 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
325 const size_t bia_dt_size = pd()->with_bias()
326 ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
328 const auto &jcp = pd()->jcp_;
329 assert(jcp.ic_block == 1);
330 assert(jcp.oc_block == 1);
331 assert(jcp.nb_ic == 1);
332 assert(jcp.nb_oc == 1);
333 assert(jcp.nb_oc_blocking == 1);
334 assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
336 const float *oscales = pd()->attr()->output_scales_.scales_;
337 if (jcp.signed_input && jcp.ver != ver_vnni) {
338 auto local_scales = scratchpad().template get<float>(
339 key_conv_adjusted_scales);
340 size_t count = pd()->attr()->output_scales_.count_;
341 float factor = 1.f / pd()->jcp_.wei_adj_scale;
343 utils::array_set(local_scales, oscales[0] * factor, 16);
345 for (size_t c = 0; c < count; c++)
346 local_scales[c] = oscales[c] * factor;
348 oscales = local_scales;
351 size_t offset = weights_d.size() - weights_d.additional_buffer_size();
352 auto w = const_cast<wei_data_t *>(weights);
353 int32_t* compensation = (jcp.signed_input)
354 ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
355 int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
356 int group_block = jcp.ch_block;
358 parallel_nd(jcp.mb, jcp.oh, jcp.nb_ow, nb_groups, [&](int n, int oh_s, int owb, int gg) {
360 auto p = jit_conv_call_s();
362 size_t src_h_stride = src_d.blk_off(0, 0, 1);
363 size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
365 int gb = gg * jcp.nb_ch_blocking;
366 int g = gb * group_block;
368 int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
369 int ow_s = owb * jcp.ow_block;
370 int iw_s = ow_s * jcp.stride_w;
372 auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size) : 0;
373 int32_t *compensation_w = jcp.signed_input ? compensation + g : 0;
375 auto dst_w = dst + dst_d.blk_off(n, g, oh_s, ow_s);
376 auto src_w = src + src_d.blk_off(n, g, ih_s, iw_s);
377 auto wht_w = weights + wht_blk_off(weights_d, gb, 0);
379 auto scales = &oscales[jcp.is_oc_scale * g];
381 int dilate_h = jcp.dilate_h + 1;
382 int i_t_overflow = nstl::min(jcp.kh, div_up(max(0, -ih_s), dilate_h));
383 int i_b_overflow = nstl::min(jcp.kh,
384 div_up(max(0, ih_s - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
386 int kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);
388 size_t wei_stride = jcp.signed_input ? 0 : i_t_overflow * wht_h_stride;
389 p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
391 p.filt = wht_w + wei_stride;
393 p.compensation = compensation_w;
395 p.kh_padding = kh_padding;
397 p.t_overflow = i_t_overflow;
398 p.b_overflow = i_b_overflow;
401 kernel_->jit_ker(&p);
/* Explicit instantiations for all supported (src, dst) data type pairs:
 * src in {s8, u8}, dst in {u8, s8, s32, f32}. */
template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
    data_type::s8, data_type::u8>;
template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
    data_type::u8, data_type::u8>;
template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
    data_type::s8, data_type::s8>;
template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
    data_type::u8, data_type::s8>;
template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
    data_type::s8, data_type::s32>;
template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
    data_type::u8, data_type::s32>;
template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
    data_type::s8, data_type::f32>;
template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
    data_type::u8, data_type::f32>;
425 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s