1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "type_helpers.hpp"
22 #include "jit_avx512_core_x8s8s32x_convolution.hpp"
28 using namespace mkldnn::impl::status;
29 using namespace mkldnn::impl::memory_format;
30 using namespace mkldnn::impl::memory_tracking::names;
31 using namespace mkldnn::impl::utils;
// Signature of the generated JIT kernel entry point: it consumes a single
// jit_conv_call_s descriptor filled in per output row (see execute_forward).
35 using jit_conv_ker_t = void (*)(jit_conv_call_s *);
// Weights block offset that works for both grouped and non-grouped
// convolutions: prepend the group index 'g' only when the primitive
// descriptor reports grouped weights.
37 #define wht_blk_off(d, g, ...) \
38     (pd()->with_groups() \
39     ? (d).blk_off((g), __VA_ARGS__) \
40     : (d).blk_off(__VA_ARGS__))
// Forward int8 (u8/s8 src, s8 wei) convolution driver for the AVX-512-core
// JIT kernel. Partitions the output over threads, then walks the assigned
// range in the jcp-selected loop order, issuing one jit_ker call per output
// row (per ow-block). All heavy lifting happens inside kernel_->jit_ker.
//
// NOTE(review): the original line numbering embedded in this listing is not
// contiguous — several statements (and the closing braces of the lambda and
// the function) appear to have been elided when this chunk was extracted.
// Truncation points are flagged below; confirm against the full file.
42 template <data_type_t src_type, data_type_t dst_type>
43 void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type, dst_type>::
44 execute_forward() const
// Raw primitive arguments: src, weights, bias (byte pointer — its data type
// is resolved via bia_dt_size below), and dst.
46     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
47     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
48     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
49     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
51     const memory_desc_wrapper src_d(pd()->src_pd());
52     const memory_desc_wrapper dst_d(pd()->dst_pd());
53     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
54     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
// Element size of the bias data type (0 when there is no bias) — used to
// scale the byte offset into the type-erased 'bias' pointer.
56     const size_t bia_dt_size = pd()->with_bias()
57         ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
59     const auto &jcp = pd()->jcp_;
// Blocking parameters must divide evenly; guaranteed by pd initialization.
60     assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
61     assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
63     const float *oscales = pd()->attr()->output_scales_.scales_;
// For signed input on pre-VNNI hardware the kernel works with adjusted
// weights, so compensate the output scales by 1/wei_adj_scale using a
// scratchpad copy; the attribute's own scales stay untouched.
64     if (jcp.signed_input && jcp.ver != ver_vnni) {
65         auto local_scales = scratchpad().template get<float>(
66                 key_conv_adjusted_scales);
67         size_t count = pd()->attr()->output_scales_.count_;
68         float factor = 1.f / pd()->jcp_.wei_adj_scale;
// Broadcast-fill first (covers the common-scale case, count == 1),
// then overwrite the 'count' leading entries with per-channel values.
70             utils::array_set(local_scales, oscales[0] * factor, 16);
72             for (size_t c = 0; c < count; c++)
73                 local_scales[c] = oscales[c] * factor;
75         oscales = local_scales;
// For signed input, per-output-channel compensation values are stored in an
// additional buffer appended past the end of the weights blob.
78     size_t offset = weights_d.size() - weights_d.additional_buffer_size();
79     auto w = const_cast<wei_data_t *>(weights);
80     int32_t* compensation = (jcp.signed_input)
81         ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
// Work decomposition: mb x group-chunks x oc-chunks x oh x ow-blocks.
82     int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
83     int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
84     int group_block = jcp.ch_block;
85     int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;
87     parallel(0, [&](const int ithr, const int nthr) {
// balance211 gives each thread a contiguous [start, end) slice of the
// flattened work_amount.
90         balance211(work_amount, nthr, ithr, start, end);
92         auto p = jit_conv_call_s();
// Strides (in elements) between consecutive rows (h dimension) of the
// src/dst/weights blocked layouts.
94         size_t src_h_stride = src_d.blk_off(0, 0, 1);
95         size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
96         size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
// Decompose 'start' into the loop counters in the jcp-chosen nesting order:
// n = minibatch, gg = group chunk, occ = oc chunk, oh_s = output row,
// owb = ow-block index.
98         int n{ 0 }, gg{ 0 }, occ{ 0 }, oh_s{ 0 }, owb{ 0 };
99         if (jcp.loop_order == loop_cwgn)
100             nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
101                     nb_groups, n, jcp.mb, oh_s, jcp.oh);
102         else if (jcp.loop_order == loop_gncw)
103             nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks,
104                     owb, jcp.nb_ow, oh_s, jcp.oh);
105         else if (jcp.loop_order == loop_ngcw)
106             nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ, oc_chunks,
107                     owb, jcp.nb_ow, oh_s, jcp.oh);
108         else if (jcp.loop_order == loop_nhwcg)
109             nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
110                     occ, oc_chunks, gg, nb_groups);
// NOTE(review): this assert is the final 'else' branch of the chain above;
// the 'else' keyword line appears elided from this listing.
112             assert(!"unsupported loop order");
113         while (start < end) {
// Translate chunk/block counters into absolute channel indices:
// g_oc / g_ic are the first output / input channel handled this iteration.
114             int ocb = occ * jcp.nb_oc_blocking;
115             int gb = gg * jcp.nb_ch_blocking;
116             int g = gb * group_block;
117             int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;
119             int g_ic = g * jcp.nb_ic * jcp.ic_block;
// Rows [oh_s, oh_e) are processed back-to-back in the inner loop below;
// ih_s is the corresponding (possibly negative, i.e. in top padding)
// input row for oh_s.
121             int work_rem = end - start;
122             int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
123             int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
// nhwcg order interleaves h with the other dims, so only one row per step.
124             if (jcp.loop_order == loop_nhwcg) oh_e = oh_s + 1; // step instead
125             int ow_s = owb * jcp.ow_block;
126             int iw_s = ow_s * jcp.stride_w;
// NOTE(review): the first half of this conditional (presumably
// 'auto bias_w = pd()->with_bias()') and its ': 0' alternative appear
// elided from this listing — confirm against the full file.
129                 ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
// Per-channel-slice view into the compensation buffer (signed input only).
131             int32_t *compensation_w = (jcp.signed_input)
132                 ? compensation + g_oc : 0;
// Base pointers for this (n, group, oc-chunk, row-range, ow-block) tile.
134             auto dst_w = dst + dst_d.blk_off(n, g_oc, oh_s, ow_s);
135             auto src_w = src + src_d.blk_off(n, g_ic, ih_s, iw_s);
136             auto wht_w = weights + wht_blk_off(weights_d, gb, ocb, 0);
// With per-oc scales, is_oc_scale == 1 selects the g_oc-th entry;
// otherwise the factor is 0 and the single common scale is used.
138             auto scales = &oscales[jcp.is_oc_scale * g_oc];
// One kernel invocation per output row; ij tracks the input row.
140             for (int oj = oh_s, ij = ih_s; oj < oh_e;
141                     ++oj, ij += jcp.stride_h) {
142                 int dilate_h = jcp.dilate_h + 1;
// Number of kernel taps falling into the top / bottom padding for this row.
143                 int i_t_overflow = nstl::min(jcp.kh,
144                         div_up(max(0, -ij), dilate_h));
// NOTE(review): the closing 'dilate_h));' of this div_up call appears
// elided from this listing.
145                 int i_b_overflow = nstl::min(jcp.kh, div_up(
146                         max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
// Kernel taps that actually land on valid input rows.
148                 int kh_padding = nstl::max(0,
149                         jcp.kh - i_t_overflow - i_b_overflow);
// For signed input the kernel itself handles padded taps (it needs them for
// compensation), so the filter pointer is NOT advanced past the overflow.
151                 size_t wei_stride = (!jcp.signed_input)
152                     ? i_t_overflow * wht_h_stride : 0;
153                 p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
155                 p.filt = wht_w + wei_stride;
157                 p.compensation = compensation_w;
// Depthwise kernels index by group block; regular ones by oc block.
158                 p.oc_blocks = jcp.is_depthwise ? gb : ocb;
159                 p.kh_padding = kh_padding;
161                 p.t_overflow = i_t_overflow;
162                 p.b_overflow = i_b_overflow;
// Byte offset of this channel slice into per-oc float arrays (e.g. for
// post-op channel-wise data inside the kernel).
165                 p.oc_off = g_oc * sizeof(float);
167                 kernel_->jit_ker(&p);
// Advance to the next output row (src advances by stride_h input rows).
169                 src_w += src_h_stride * jcp.stride_h;
170                 dst_w += dst_h_stride;
// Advance the multi-dimensional counters by the rows just consumed
// (nd_iterator_jump) or by a single step for the nhwcg order.
172             if (jcp.loop_order == loop_cwgn)
173                 nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, gg,
174                         nb_groups, n, jcp.mb, oh_s, jcp.oh);
175             else if (jcp.loop_order == loop_gncw)
176                 nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ,
177                         oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
178             else if (jcp.loop_order == loop_ngcw)
179                 nd_iterator_jump(start, end, n, jcp.mb, gg, nb_groups, occ,
180                         oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
181             else if (jcp.loop_order == loop_nhwcg) {
183                 nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, occ,
184                         oc_chunks, gg, nb_groups);
// NOTE(review): the closing brace of the branch above, the final 'else',
// and the closing braces of the while loop / lambda / function appear
// elided from this listing.
187                 assert(!"unsupported loop order");
// Explicit instantiations for every supported <src_type, dst_type> pair:
// src in {s8, u8} x dst in {u8, s8, s32, f32}.
192 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
193     data_type::s8, data_type::u8>;
194 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
195     data_type::u8, data_type::u8>;
196 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
197     data_type::s8, data_type::s8>;
198 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
199     data_type::u8, data_type::s8>;
200 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
201     data_type::s8, data_type::s32>;
202 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
203     data_type::u8, data_type::s32>;
204 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
205     data_type::s8, data_type::f32>;
206 template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
207     data_type::u8, data_type::f32>;
212 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s