1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
18 #include "mkldnn_types.h"
20 #include "c_types_map.hpp"
21 #include "jit_sse42_1x1_convolution.hpp"
23 #include "mkldnn_thread.hpp"
24 #include "type_helpers.hpp"
30 #define data_blk_off(f, n, c, h, w) \
32 ? (f).blk_off(n, c, w) \
33 : (f).blk_off(n, c, h, w))
35 using namespace mkldnn::impl::status;
36 using namespace mkldnn::impl::memory_format;
37 using namespace mkldnn::impl::memory_tracking::names;
38 using namespace mkldnn::impl::utils;
void jit_sse42_1x1_convolution_fwd_t::execute_forward() const {
    // Forward pass of the SSE4.2 JIT 1x1 convolution: splits the
    // (minibatch, group, spatial-block) iteration space across threads and
    // calls the generated kernel once per (bcast, load, reduce) sub-block.
    // NOTE(review): this excerpt is elided -- several structural lines
    // (loop/brace closings and the MB/start/end/iwork/ocb declarations)
    // are not visible here; confirm against the full file.
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
    auto dst = reinterpret_cast<data_t *>(this->memory());

    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper dst_d(pd()->dst_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));

    const int ndims = src_d.ndims();
    const auto &jcp = kernel_->jcp;
    // One unit of work = one spatial (bcast) block of one group of one image.
    // NOTE(review): MB is presumably the minibatch size declared on an elided
    // line just above -- confirm.
    const int work_amount = MB * jcp.ngroups * jcp.nb_bcast;

    if (pd()->wants_padded_bias()) {
        // The kernel reads jcp.oc (padded) bias values: copy the user bias
        // into scratchpad storage and zero-fill the channel-padding tail.
        auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);

    parallel(0, [&](const int ithr, const int nthr) {
        // TODO (Roma): remove this restriction
        assert(jcp.stride_w == 1 && jcp.stride_h == 1);

        auto par_conv = jit_1x1_conv_call_s();

        const int nb_oc = jcp.nb_load;            // # output-channel blocks
        const int nb_ic = jcp.nb_reduce;          // # input-channel blocks
        const int nb_ic_blocking = jcp.nb_reduce_blocking;
        const int os_block = jcp.bcast_block;

        // Give this thread a contiguous [start, end) slice of the work.
        balance211(work_amount, nthr, ithr, start, end);

        int n{0}, g{0}, osb{0};
        nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb,

            // Take up to nb_bcast_blocking(_max) spatial blocks at once,
            // clipped to this thread's remaining work.
            const int bcast_step_rem = jcp.nb_bcast - osb;
            int bcast_step = bcast_step_rem <= jcp.nb_bcast_blocking_max
                ? bcast_step_rem : jcp.nb_bcast_blocking;
            bcast_step = nstl::min<int>(bcast_step, end - iwork);

            // Map the flat spatial offset back to (oh, ow); clamp the input
            // coordinates against the top/left padding.
            const int os = osb * os_block;
            const int ow = os % jcp.ow;
            const int oh = os / jcp.ow;
            const int iw = nstl::max<int>(ow * jcp.stride_w - jcp.l_pad, 0);
            const int ih = nstl::max<int>(oh * jcp.stride_h - jcp.t_pad, 0);

            // Possibly-partial tail block in the spatial dimension.
            par_conv.bcast_dim = this_block_size(os, jcp.os,
                    bcast_step * os_block);

            while (ocb < jcp.nb_load) {
                // Blocked sweep over output-channel (load) blocks.
                const int load_step_rem = jcp.nb_load - ocb;
                const int load_step = load_step_rem < jcp.nb_load_blocking_max
                    ? load_step_rem : jcp.nb_load_blocking;

                const size_t _ocb = g * nb_oc + ocb;
                par_conv.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc,
                        load_step * jcp.oc_block);

                const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow);
                par_conv.output_data = &dst[dst_off];

                par_conv.bias_data = &bias[_ocb * jcp.oc_block];

                // Accumulate over input-channel (reduce) blocks; the flags
                // tell the kernel when to initialize the accumulator and when
                // the reduction is complete.
                for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
                    par_conv.first_last_flag = 0
                        | (icb == 0) * FLAG_REDUCE_FIRST
                        | (icb + nb_ic_blocking >= nb_ic) * FLAG_REDUCE_LAST;

                    par_conv.reduce_dim = this_block_size(icb * jcp.ic_block,
                            jcp.ic, nb_ic_blocking * jcp.ic_block);

                    const size_t _icb = g * nb_ic + icb;
                    const size_t src_off = data_blk_off(src_d, n, _icb, ih, iw);
                    par_conv.bcast_data = &src[src_off];

                    par_conv.load_data = &weights[pd()->with_groups()
                        ? weights_d.blk_off(g, ocb, icb)
                        : weights_d.blk_off(ocb, icb)];

                    // Byte offset of the first output channel of this block
                    // (consumed by per-channel post-ops in the kernel).
                    par_conv.oc_off = _ocb * jcp.oc_block * sizeof(float);

                    kernel_->jit_ker(&par_conv);

    if (pd()->wants_zero_pad_dst())
        output_memory_primitive(0)->zero_pad();
void jit_sse42_1x1_convolution_fwd_t::execute_forward_with_dw_conv() const {
    // Forward pass of the 1x1 convolution fused with a depthwise convolution:
    // 1x1 results for a handful of rows are written into a per-thread ring
    // buffer (rows indexed modulo jcp_dw.kh, one stripe per output-channel
    // block) and the depthwise kernel consumes them row by row, so the full
    // intermediate tensor is never materialized.
    // NOTE(review): this excerpt is elided -- brace closings, the MB/iwork/
    // bcast_step declarations and the call that launches `ker` (presumably a
    // parallel(...) below) are not visible here; confirm against the full file.
    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
    auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
    auto dst = reinterpret_cast<data_t *>(this->memory());

    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));

    const auto &jcp = kernel_->jcp;
    const auto &jcp_dw = kernel_dw_->jcp;
    // Depthwise bias comes from the fused-conv parameters; may be replaced by
    // a padded copy below.
    auto dw_bias = jcp_dw.conv_biases;

    // With the fused depthwise conv, output-channel block groups join the
    // parallel iteration space.
    int ocb_work = jcp.with_dw_conv ? utils::div_up(jcp.nb_load, jcp.nb_load_blocking) : 1;
    const int work_amount = MB * jcp.ngroups * ocb_work * jcp.nb_bcast;

    // Blocked-step helper: full step inside the range, remainder at the tail.
    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;

    auto ker = [&](const int ithr, const int nthr) {
        // TODO (Roma): remove this restriction
        assert(jcp.stride_w == 1 && jcp.stride_h == 1);

        // Run the 1x1 kernel for consecutive output rows starting at `oh`,
        // storing row (oh + h) into ring-buffer slot ((oh + h) + 1) % jcp_dw.kh
        // of the per-thread workspace ws_p.
        auto compute_block_1x1 = [&](float* ws_p, int n, int g, int oh, int ow, int ih, int iw, int os, int os_block, int bcast_step, int ocb, int load_step,

            auto p = jit_1x1_conv_call_s();

            for (int h = 0; h < num_rows; h++) {
                ih = nstl::max((oh + h) * jcp.stride_h - jcp.t_pad, 0);
                if ((oh + h) < 0 || (oh + h) >= jcp.ih) {
                    // Row lies outside the image: zero the corresponding
                    // ring-buffer rows so the depthwise kernel reads zeros.
                    for (int chb = ocb; chb < ocb + load_step; chb++) {
                        memset(ws_p + (((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block +
                               (chb - ocb) * jcp_dw.kh * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));

                const int _ocb = g * jcp.nb_load + ocb;

                p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
                p.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, load_step * jcp.oc_block);

                // Output goes to the ring buffer, not to dst.
                p.output_data = &ws_p[(((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block];

                p.bias_data = &bias[_ocb * jcp.oc_block];

                // Accumulate over input-channel (reduce) blocks; the flags tell
                // the kernel when to initialize / finalize the accumulator.
                for (int icb = 0; icb < jcp.nb_reduce; icb += jcp.nb_reduce_blocking) {
                    p.first_last_flag = 0
                        | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
                        | (icb + jcp.nb_reduce_blocking >= jcp.nb_reduce
                            ? FLAG_REDUCE_LAST : 0);

                    p.reduce_dim = this_block_size(icb * jcp.ic_block, jcp.ic,
                            jcp.nb_reduce_blocking * jcp.ic_block);
                    p.load_data = &weights[pd()->with_groups()
                        ? weights_d.blk_off(g, ocb, icb)
                        : weights_d.blk_off(ocb, icb)];

                    const int _icb = g * jcp.nb_reduce + icb;
                    p.bcast_data = src + src_d.blk_off(n, _icb, ih, iw);

                    p.oc_off = _ocb * jcp.oc_block * sizeof(float);

                    kernel_->jit_ker(&p);

        // Apply the depthwise kernel to one output row (dst_idx), reading
        // three consecutive rows from the ring buffer (kernel height is
        // presumably 3 given exactly src_row0..2 -- confirm against jcp_dw.kh).
        auto compute_row_dw = [&](const float* ws_p, int n, int ocb, int load_step, int dst_idx) {

            for (int chb = ocb; chb < ocb + load_step; chb++) {
                auto par_conv_dw = jit_conv_call_s();

                // Rows dst_idx-1, dst_idx, dst_idx+1 in ring-buffer order,
                // offset to this channel block's stripe.
                par_conv_dw.src_row0 = &ws_p[(((dst_idx+1) - 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
                                             (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
                par_conv_dw.src_row1 = &ws_p[(((dst_idx+1) - 0) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
                                             (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
                par_conv_dw.src_row2 = &ws_p[(((dst_idx+1) + 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
                                             (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];

                // Destination row in the blocked-channel output tensor; oh is
                // divided by the depthwise stride to get the dst row index.
                par_conv_dw.dst = &dst[n*jcp_dw.oc*jcp_dw.oh*jcp_dw.ow + chb*jcp_dw.ch_block*jcp_dw.oh*jcp_dw.ow +
                                       dst_idx/jcp_dw.stride_h*jcp_dw.ow*jcp_dw.ch_block];

                par_conv_dw.kh_padding = jcp_dw.kh;
                par_conv_dw.filt = &jcp_dw.conv_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
                par_conv_dw.bias = &dw_bias[chb * jcp_dw.ch_block];
                par_conv_dw.ur_w = (size_t)(jcp_dw.ow);
                // Tail handling: last channel block may be partial.
                par_conv_dw.oc_work = nstl::min((chb + 1) * jcp_dw.ch_block, (int)jcp_dw.oc) - chb*jcp_dw.ch_block;
                par_conv_dw.oc_off = chb * jcp_dw.ch_block * sizeof(float);

                kernel_dw_->jit_ker(&par_conv_dw);

        assert(jcp.stride_w == 1 && jcp.stride_h == 1);

        int start{0}, end{0};
        balance211(work_amount, nthr, ithr, start, end);

        // Per-thread slice of the ring buffer: jcp_dw.kh rows of
        // iw * ch_block floats for every output-channel block.
        auto dw_conv_buffer = scratchpad().get<data_t>(key_dw_conv_buffer);
        size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * (jcp.oc / jcp.oc_block);
        auto pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;

        // One bcast (spatial) block corresponds to exactly one output row.
        const int os_block = jcp.iw;

        while (iwork < end) {
            int n{0}, g{0}, ocbb{0}, osb{0};
            nd_iterator_init(iwork, n, MB, g, jcp.ngroups, ocbb, ocb_work, osb,

            const int os = osb * os_block;
            const int oh = os / jcp.ow;
            const int ow = os % jcp.ow;

            const int ih = nstl::max(oh * jcp.stride_h - jcp.t_pad, 0);
            const int iw = nstl::max(ow * jcp.stride_w - jcp.l_pad, 0);

            int ocb = ocbb * jcp.nb_load_blocking;

            const int load_step = step(jcp.nb_load_blocking,
                    jcp.nb_load - ocb, jcp.nb_load_blocking_max);

            if (iwork == start || oh == 0) {
                // First row handled by this thread (or top of the image):
                // prime the ring buffer starting from row oh - 1; otherwise
                // only the next row (oh + 1) needs to be produced.
                bcast_step = nstl::min(1, end - iwork);
                compute_block_1x1(pbuf, n, g, oh - 1, ow, ih, iw, os, os_block, bcast_step, ocb, load_step, bcast_step + 2);

                bcast_step = nstl::min(1, end - iwork);
                compute_block_1x1(pbuf, n, g, oh + 1, ow, ih, iw, os, os_block, bcast_step, ocb, load_step, bcast_step);

            // Emit a depthwise output row only where oh is aligned with the
            // depthwise stride.
            if ((oh % jcp_dw.stride_h == 0)) {
                compute_row_dw(pbuf, n, ocb, load_step, oh);

    if (pd()->wants_padded_bias()) {
        // Zero-pad both the 1x1 bias and the depthwise bias up to the padded
        // channel count expected by the kernels.
        auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);

        auto dw_padded_bias = scratchpad().get<data_t>(key_dw_conv_padded_bias);
        utils::array_copy(dw_padded_bias, dw_bias, jcp.oc_without_padding);
        utils::array_set(dw_padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        dw_bias = dw_padded_bias;

    if (pd()->wants_zero_pad_dst())
        output_memory_primitive(0)->zero_pad();