1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
18 #include "mkldnn_types.h"
20 #include "c_types_map.hpp"
21 #include "jit_sse42_convolution.hpp"
22 #include "mkldnn_thread.hpp"
28 using namespace mkldnn::impl::status;
29 using namespace mkldnn::impl::memory_format;
30 using namespace mkldnn::impl::memory_tracking::names;
31 using namespace mkldnn::impl::utils;
// Offset-helper macros: compute the element offset inside a blocked
// memory descriptor, dispatching between the 1D-spatial (ndims()==3,
// no height coordinate) and 2D-spatial blk_off() overloads.
// NOTE(review): this extract is missing the condition lines of the two
// weight macros below (presumably the pd()->with_groups() check for
// wht_blk_off_ and the pd()->ndims()==3 check for wht_blk_off) --
// verify against the full file before editing.
33 #define src_blk_off(f, n, c, h, w) \
34 (pd()->ndims() == 3) \
35 ? (f).blk_off(n, c, w) \
36 : (f).blk_off(n, c, h, w)
// Weight offset with/without a leading groups dimension.
38 #define wht_blk_off_(f, g, ...) \
40 ? (f).blk_off(g, __VA_ARGS__) \
41 : (f).blk_off(__VA_ARGS__)
// Weight offset that also drops the kh coordinate for 1D-spatial convs.
42 #define wht_blk_off(f, g, oc, ic, kh, kw) \
44 ? wht_blk_off_(f, g, oc, ic, kw) \
45 : wht_blk_off_(f, g, oc, ic, kh, kw)
// Forward convolution driver for the SSE4.2 JIT kernel. Work is split
// over (minibatch, group, oc-block chunk, output row) and each work
// item issues one row-wise call into kernel_->jit_ker per ic block.
// NOTE(review): several lines are missing from this extract (e.g. the
// MB and icbb declarations, `bias = padded_bias;`, the tails of the
// nd_iterator_init/step calls, and closing braces). Comments below
// describe only what is visible here.
47 void jit_sse42_convolution_fwd_t::execute_forward() const {
// Raw argument pointers: src, weights, bias inputs and dst output.
48 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
49 auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
50 auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
51 auto dst = reinterpret_cast<data_t *>(this->memory());
// Layout-aware wrappers used to compute blocked offsets below.
53 const memory_desc_wrapper src_d(pd()->src_pd());
54 const memory_desc_wrapper dst_d(pd()->dst_pd());
55 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
56 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
// JIT kernel configuration produced at pd creation time.
58 const auto &jcp = kernel_->jcp;
// Number of oc-block chunks; total work = mb * groups * chunks * rows.
61 int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
62 const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.oh;
// When oc is padded up to a multiple of the block size, copy the bias
// into scratchpad and zero the padded tail so the kernel can read a
// full block. NOTE(review): the `bias = padded_bias;` reassignment is
// not visible in this extract -- confirm in the full file.
64 if (pd()->wants_padded_bias()) {
65 auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
66 utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
67 utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
68 jcp.oc - jcp.oc_without_padding);
// Each thread takes a contiguous [start, end) slice of the work items.
72 parallel(0, [&](const int ithr, const int nthr) {
73 size_t start{ 0 }, end{ 0 };
74 balance211(work_amount, nthr, ithr, start, end);
// Outer loop over ic blocks in chunks; partial sums accumulate into
// dst across icb iterations (FLAG_IC_FIRST/LAST drive init/finalize).
// NOTE(review): icbb's declaration/initialization is missing here.
77 while (icbb < jcp.nb_ic) {
78 int icb_step = jcp.nb_ic_blocking;
79 int icb_step_rem = jcp.nb_ic - icbb;
// Shrink the last chunk so it does not run past nb_ic.
80 if (icb_step_rem < jcp.nb_ic_blocking_max)
81 icb_step = icb_step_rem;
// Decompose the linear work index into (n, g, ocbb, oh) coordinates.
83 size_t n{0}, g{0}, ocbb{0}, oh{0};
84 nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
86 for (size_t iwork = start; iwork < end; ++iwork) {
87 int ocb = ocbb * jcp.nb_oc_blocking;
88 int ocb_num = jcp.nb_oc_blocking;
90 for (int icb = icbb; icb < icbb + icb_step; ++icb) {
91 auto par_conv = jit_conv_call_s();
// Top/bottom overflow of the kernel footprint past the input, in
// input rows, accounting for stride, dilation and top padding.
93 const int ij = oh * jcp.stride_h;
94 const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
95 const int i_b_overflow = nstl::max(jcp.ih, ij
96 + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih;
// Absolute (grouped) oc/ic block indices.
98 const size_t _oc = g * jcp.nb_oc + ocb;
99 const size_t _ic = g * jcp.nb_ic + icb;
// First valid input row for this output row after clipping overflow.
101 const int ih = nstl::max(ij - jcp.t_pad
102 + div_up(i_t_overflow,
103 (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0);
// ic==3 (plain nchw first layer) addresses channel 0 directly.
104 par_conv.src = &src[src_blk_off(src_d, n,
105 jcp.ic == 3 ? 0 : _ic, ih, 0)];
107 par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, oh, 0)];
// Skip the weight rows that fall into the clipped top region.
109 const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
110 par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb,
111 jcp.ic == 3 ? 0 : icb, wh, 0)];
// NOTE(review): the `par_conv.bias =` head and the presumable
// `if (icb == 0)` guard around FLAG_IC_FIRST are missing here.
116 &bias[bias_d.blk_off(_oc * jcp.oc_block)];
117 par_conv.flags |= FLAG_IC_FIRST;
// Last ic block: kernel applies post-ops / writes final dst values.
120 if (icb + 1 == jcp.nb_ic) {
121 par_conv.flags |= FLAG_IC_LAST;
// NOTE(review): assignment target (likely par_conv.oc_blocks) missing.
125 nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;
127 par_conv.kw_padding = 0;
// Kernel rows actually applied after clipping top/bottom overflow.
128 const int kh_padding = jcp.kh
129 - div_up(i_t_overflow, (jcp.dilate_h + 1))
130 - div_up(i_b_overflow, (jcp.dilate_h + 1));
131 par_conv.kh_padding = nstl::max(0, kh_padding);
// Byte offset of this oc block, used by per-channel post-ops.
133 par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);
135 kernel_->jit_ker(&par_conv);
// Advance (n, g, ocbb, oh) to the next work item.
137 nd_iterator_step(n, MB, g, jcp.ngroups, ocbb, ocb_work,
// Physically zero padded dst regions so blocked-layout consumers
// never read garbage in the channel tail.
144 if (pd()->wants_zero_pad_dst())
145 output_memory_primitive(0)->zero_pad();
// Forward convolution fused with a following depthwise convolution.
// The main conv writes rows into a small per-thread ring buffer of
// jcp_dw.kh rows (ws_p / pbuf); the depthwise kernel then consumes
// three adjacent buffered rows to produce one dst row, so the full
// intermediate tensor is never materialized.
// NOTE(review): this extract is missing lines (the MB declaration,
// `} else {` branches, `par_conv.bias =` heads, the parallel(...) call
// that invokes `ker`, and closing braces). Comments describe only what
// is visible here.
148 void jit_sse42_convolution_fwd_t::execute_forward_with_dw_conv() const {
149 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
150 auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
151 auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
152 auto dst = reinterpret_cast<data_t *>(this->memory());
154 const memory_desc_wrapper src_d(pd()->src_pd());
155 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
156 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
// Configurations of the main JIT kernel and the fused depthwise kernel.
158 const auto &jcp = kernel_->jcp;
159 const auto &jcp_dw = kernel_dw_->jcp;
// Depthwise weights/biases travel in jcp_dw rather than as primitive args.
162 auto dw_bias = jcp_dw.conv_biases;
164 int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
165 const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.oh;
167 auto ker = [&](const int ithr, const int nthr) {
// Compute `num_rows` main-conv output rows starting at `oh` into the
// row ring buffer ws_p (slot = (row+1) % jcp_dw.kh per oc block).
168 auto compute_row_gen = [&](float* ws_p, int n, int g, int ocb, int ocb_num, int oh, int num_rows) {
169 for (int h = 0; h < num_rows; h++) {
// Rows outside [0, jcp.oh) are zero-filled padding for the dw conv.
170 if ((oh + h) < 0 || (oh + h) >= jcp.oh) {
171 for (int chb = ocb; chb < ocb + ocb_num; chb++) {
172 memset(ws_p + (((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block +
173 (chb - ocb) * jcp_dw.kh * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));
// In-range rows: accumulate over all ic blocks into the buffer.
176 for (int icb = 0; icb < jcp.nb_ic; ++icb) {
177 auto par_conv = jit_conv_call_s();
// Same top/bottom overflow clipping as the non-fused path.
// NOTE(review): the tail of the i_b_overflow expression is missing.
179 const int ij = (oh + h) * jcp.stride_h;
180 const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
181 const int i_b_overflow = nstl::max(jcp.ih, ij
182 + (jcp.kh - 1) * (jcp.dilate_h + 1) - jcp.t_pad +
185 const size_t _oc = g * jcp.nb_oc + ocb;
186 const size_t _ic = g * jcp.nb_ic + icb;
188 const int ih = nstl::max(ij - jcp.t_pad
189 + div_up(i_t_overflow,
190 (jcp.dilate_h + 1)) * (jcp.dilate_h + 1), 0);
191 par_conv.src = &src[src_d.blk_off(n,
192 jcp.ic == 3 ? 0 : _ic, ih, 0)];
// Destination is the ring-buffer slot, not dst itself.
// NOTE(review): the continuation of this index expression is missing.
194 par_conv.dst = &ws_p[(((oh + h) + 1) % jcp_dw.kh) * jcp.ow *
197 const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
198 par_conv.filt = &weights[pd()->with_groups()
199 ? weights_d.blk_off(g, ocb,
200 jcp.ic == 3 ? 0 : icb, wh, 0)
201 : weights_d.blk_off(ocb,
202 jcp.ic == 3 ? 0 : icb, wh, 0)];
// NOTE(review): `par_conv.bias =` head and the presumable
// `if (icb == 0)` guard around FLAG_IC_FIRST are missing here.
207 &bias[bias_d.blk_off(_oc * jcp.oc_block)];
208 par_conv.flags |= FLAG_IC_FIRST;
211 if (icb + 1 == jcp.nb_ic) {
212 par_conv.flags |= FLAG_IC_LAST;
// NOTE(review): assignment target (likely par_conv.oc_blocks) missing.
216 nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;
218 par_conv.kw_padding = 0;
219 const int kh_padding = jcp.kh
220 - div_up(i_t_overflow, (jcp.dilate_h + 1))
221 - div_up(i_b_overflow, (jcp.dilate_h + 1));
222 par_conv.kh_padding = nstl::max(0, kh_padding);
224 par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);
226 kernel_->jit_ker(&par_conv);
// Run the depthwise kernel for output row dst_idx: it reads rows
// dst_idx-1 .. dst_idx+1 from the ring buffer and writes one dst row.
// NOTE(review): the trailing `int dst_idx)` parameter line is missing.
232 auto compute_row_dw = [&](const float* ws_p, int n, int ocb, int ocb_num,
234 for (int chb = ocb; chb < nstl::min(ocb + ocb_num, jcp.nb_oc); chb++) {
235 auto par_conv_dw = jit_conv_call_s();
// Three consecutive ring-buffer slots, modulo jcp_dw.kh.
237 par_conv_dw.src_row0 = &ws_p[(((dst_idx+1) - 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
238 (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
239 par_conv_dw.src_row1 = &ws_p[(((dst_idx+1) - 0) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
240 (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
241 par_conv_dw.src_row2 = &ws_p[(((dst_idx+1) + 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
242 (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
// Final destination row (dw stride folds multiple src rows into one).
244 par_conv_dw.dst = &dst[n*jcp_dw.oc*jcp_dw.oh*jcp_dw.ow + chb*jcp_dw.ch_block*jcp_dw.oh*jcp_dw.ow +
245 dst_idx/jcp_dw.stride_h*jcp_dw.ow*jcp_dw.ch_block];
247 par_conv_dw.kh_padding = jcp_dw.kh;
248 par_conv_dw.filt = &jcp_dw.conv_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
249 par_conv_dw.bias = &dw_bias[chb * jcp_dw.ch_block];
250 par_conv_dw.ur_w = (size_t)(jcp_dw.ow);
// Clamp the last channel block to the true (unpadded) oc count.
251 par_conv_dw.oc_work = nstl::min((chb + 1) * jcp_dw.ch_block, (int)jcp_dw.oc) - chb*jcp_dw.ch_block;
252 par_conv_dw.oc_off = chb * jcp_dw.ch_block * sizeof(float);
254 kernel_dw_->jit_ker(&par_conv_dw);
258 size_t start{0}, end{0};
259 balance211(work_amount, nthr, ithr, start, end);
// Per-thread slice of the shared ring-buffer scratchpad: kh rows of
// iw * ch_block floats for each of the nb_oc_blocking oc blocks.
261 auto dw_conv_buffer = scratchpad().get<data_t>(key_dw_conv_buffer);
262 size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking;
263 auto pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
265 size_t n{0}, g{0}, ocbb{0}, oh{0};
266 nd_iterator_init(start, n, MB, g, jcp.ngroups, ocbb, ocb_work,
268 for (size_t iwork = start; iwork < end; ++iwork) {
269 int ocb = ocbb * jcp.nb_oc_blocking;
270 int ocb_num = jcp.nb_oc_blocking;
// Prime the ring buffer at a slice boundary (rows oh-1 and oh; the
// oh-1 row is zero padding when oh == 0), otherwise just compute the
// current row. NOTE(review): the `} else {` between these two calls
// is missing from this extract.
272 if (iwork == start || oh == 0) {
273 compute_row_gen(pbuf, n, g, ocb, ocb_num, oh - 1, 2);
275 compute_row_gen(pbuf, n, g, ocb, ocb_num, oh, 1);
// Emit the dw row for oh-1 once its bottom neighbor (oh) is buffered.
278 if (iwork > start && ((oh - 1) % jcp_dw.stride_h == 0) && oh > 0) {
279 compute_row_dw(pbuf, n, ocb, ocb_num, oh - 1);
// At a slice/image end, buffer the row below (zero padding past the
// last image row) and emit the dw row for oh itself.
282 if ((iwork == end - 1 || (int) oh == jcp.oh - 1) && ((oh) % jcp_dw.stride_h == 0)) {
283 compute_row_gen(pbuf, n, g, ocb, ocb_num, oh + 1, 1);
284 compute_row_dw(pbuf, n, ocb, ocb_num, oh);
287 nd_iterator_step(n, MB, g, jcp.ngroups, ocbb, ocb_work,
// Pad both the main-conv bias and the depthwise bias into scratchpad
// copies when oc is rounded up to the block size.
// NOTE(review): `bias = padded_bias;` is not visible in this extract,
// and the dw padding reuses jcp.oc_without_padding -- confirm both
// against the full file.
292 if (pd()->wants_padded_bias()) {
293 auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
294 utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
295 utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
296 jcp.oc - jcp.oc_without_padding);
299 auto dw_padded_bias = scratchpad().get<data_t>(key_dw_conv_padded_bias);
300 utils::array_copy(dw_padded_bias, dw_bias, jcp.oc_without_padding);
301 utils::array_set(dw_padded_bias + jcp.oc_without_padding, 0.f,
302 jcp.oc - jcp.oc_without_padding);
303 dw_bias = dw_padded_bias;
// Physically zero padded dst regions for blocked-layout consumers.
308 if (pd()->wants_zero_pad_dst())
309 output_memory_primitive(0)->zero_pad();