// Publishing 2019 R1 content
// [platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_sse42_1x1_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <cstring>
18 #include "mkldnn_types.h"
19
20 #include "c_types_map.hpp"
21 #include "jit_sse42_1x1_convolution.hpp"
22 #include "utils.hpp"
23 #include "mkldnn_thread.hpp"
24 #include "type_helpers.hpp"
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
// Offset of element (n, c, h, w) inside memory descriptor `f`.
// 3D tensors have no height dimension, so `h` is dropped for them.
// NOTE: expands to code that reads a variable named `ndims`, which must be
// in scope at every use site.
#define data_blk_off(f, n, c, h, w) \
    ((ndims) == 3 \
            ? (f).blk_off((n), (c), (w)) \
            : (f).blk_off((n), (c), (h), (w)))
34
35 using namespace mkldnn::impl::status;
36 using namespace mkldnn::impl::memory_format;
37 using namespace mkldnn::impl::memory_tracking::names;
38 using namespace mkldnn::impl::utils;
39
40 void jit_sse42_1x1_convolution_fwd_t::execute_forward() const {
41     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
42     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
43     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
44     auto dst = reinterpret_cast<data_t *>(this->memory());
45
46     const memory_desc_wrapper src_d(pd()->src_pd());
47     const memory_desc_wrapper dst_d(pd()->dst_pd());
48     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
49
50     const int ndims = src_d.ndims();
51     const auto &jcp = kernel_->jcp;
52     int MB = pd()->MB();
53
54     const int work_amount = MB * jcp.ngroups * jcp.nb_bcast;
55
56     if (pd()->wants_padded_bias()) {
57         auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
58         utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
59         utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
60                 jcp.oc - jcp.oc_without_padding);
61         bias = padded_bias;
62     }
63
64     parallel(0, [&](const int ithr, const int nthr) {
65         // TODO (Roma): remove this restriction
66         assert(jcp.stride_w == 1 && jcp.stride_h == 1);
67
68         auto par_conv = jit_1x1_conv_call_s();
69
70         const int nb_oc = jcp.nb_load;
71         const int nb_ic = jcp.nb_reduce;
72         const int nb_ic_blocking = jcp.nb_reduce_blocking;
73         const int os_block = jcp.bcast_block;
74
75         int start{0}, end{0};
76         balance211(work_amount, nthr, ithr, start, end);
77
78         int iwork = start;
79         while (iwork < end) {
80             int n{0}, g{0}, osb{0};
81             nd_iterator_init(iwork, n, MB, g, jcp.ngroups, osb,
82                     jcp.nb_bcast);
83
84             const int bcast_step_rem = jcp.nb_bcast - osb;
85             int bcast_step = bcast_step_rem <= jcp.nb_bcast_blocking_max
86                 ? bcast_step_rem : jcp.nb_bcast_blocking;
87             bcast_step = nstl::min<int>(bcast_step, end - iwork);
88
89             const int os = osb * os_block;
90             const int ow = os % jcp.ow;
91             const int oh = os / jcp.ow;
92             const int iw = nstl::max<int>(ow * jcp.stride_w - jcp.l_pad, 0);
93             const int ih = nstl::max<int>(oh * jcp.stride_h - jcp.t_pad, 0);
94
95             par_conv.bcast_dim = this_block_size(os, jcp.os,
96                     bcast_step * os_block);
97
98             int ocb = 0;
99             while (ocb < jcp.nb_load) {
100                 const int load_step_rem = jcp.nb_load - ocb;
101                 const int load_step = load_step_rem < jcp.nb_load_blocking_max
102                     ? load_step_rem : jcp.nb_load_blocking;
103
104                 const size_t _ocb = g * nb_oc + ocb;
105                 par_conv.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc,
106                         load_step * jcp.oc_block);
107
108                 const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow);
109                 par_conv.output_data = &dst[dst_off];
110
111                 par_conv.bias_data = &bias[_ocb * jcp.oc_block];
112
113                 for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
114                     par_conv.first_last_flag = 0
115                         | (icb == 0) * FLAG_REDUCE_FIRST
116                         | (icb + nb_ic_blocking >= nb_ic) * FLAG_REDUCE_LAST;
117
118                     par_conv.reduce_dim = this_block_size(icb * jcp.ic_block,
119                             jcp.ic, nb_ic_blocking * jcp.ic_block);
120
121                     const size_t _icb = g * nb_ic + icb;
122                     const size_t src_off = data_blk_off(src_d, n, _icb, ih, iw);
123                     par_conv.bcast_data = &src[src_off];
124
125                     par_conv.load_data = &weights[pd()->with_groups()
126                         ? weights_d.blk_off(g, ocb, icb)
127                         : weights_d.blk_off(ocb, icb)];
128
129                     par_conv.oc_off = _ocb * jcp.oc_block * sizeof(float);
130
131                     kernel_->jit_ker(&par_conv);
132                 }
133
134                 ocb += load_step;
135             }
136
137             iwork += bcast_step;
138         }
139     });
140
141     if (pd()->wants_zero_pad_dst())
142         output_memory_primitive(0)->zero_pad();
143 }
144
145 void jit_sse42_1x1_convolution_fwd_t::execute_forward_with_dw_conv() const {
146     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
147     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
148     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
149     auto dst = reinterpret_cast<data_t *>(this->memory());
150
151     const memory_desc_wrapper src_d(pd()->src_pd());
152     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
153
154     const auto &jcp = kernel_->jcp;
155     const auto &jcp_dw = kernel_dw_->jcp;
156     int MB = pd()->MB();
157
158     auto dw_bias = jcp_dw.conv_biases;
159
160     int ocb_work = jcp.with_dw_conv ? utils::div_up(jcp.nb_load, jcp.nb_load_blocking) : 1;
161     const int work_amount = MB * jcp.ngroups * ocb_work * jcp.nb_bcast;
162
163     auto step = [](int default_step, int remaining, int tail_step) {
164         assert(default_step <= tail_step);
165         return remaining < tail_step ? remaining : default_step;
166     };
167
168     auto ker = [&](const int ithr, const int nthr) {
169         // TODO (Roma): remove this restriction
170         assert(jcp.stride_w == 1 && jcp.stride_h == 1);
171
172         auto compute_block_1x1 = [&](float* ws_p, int n, int g, int oh, int ow, int ih, int iw, int os, int os_block, int bcast_step, int ocb, int load_step,
173                                     int num_rows) {
174             auto p = jit_1x1_conv_call_s();
175
176             for (int h = 0; h < num_rows; h++) {
177                 ih = nstl::max((oh + h) * jcp.stride_h - jcp.t_pad, 0);
178
179                 if ((oh + h) < 0 || (oh + h) >= jcp.ih) {
180                     for (int chb = ocb; chb < ocb + load_step; chb++) {
181                         memset(ws_p + (((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block +
182                                (chb - ocb) * jcp_dw.kh * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));
183                     }
184                 } else {
185                     const int _ocb = g * jcp.nb_load + ocb;
186
187                     p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
188                     p.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, load_step * jcp.oc_block);
189
190                     p.output_data = &ws_p[(((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block];
191
192                     p.bias_data = &bias[_ocb * jcp.oc_block];
193
194                     for (int icb = 0; icb < jcp.nb_reduce; icb += jcp.nb_reduce_blocking) {
195                         p.first_last_flag = 0
196                                             | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
197                                             | (icb + jcp.nb_reduce_blocking >= jcp.nb_reduce
198                                                ? FLAG_REDUCE_LAST : 0);
199
200                         p.reduce_dim = this_block_size(icb * jcp.ic_block, jcp.ic,
201                                                        jcp.nb_reduce_blocking * jcp.ic_block);
202                         p.load_data = &weights[pd()->with_groups()
203                                                ? weights_d.blk_off(g, ocb, icb)
204                                                : weights_d.blk_off(ocb, icb)];
205
206                         const int _icb = g * jcp.nb_reduce + icb;
207                         p.bcast_data = src + src_d.blk_off(n, _icb, ih, iw);
208
209                         p.oc_off = _ocb * jcp.oc_block * sizeof(float);
210
211                         kernel_->jit_ker(&p);
212                     }
213                 }
214             }
215         };
216
217         auto compute_row_dw = [&](const float* ws_p, int n, int ocb, int load_step, int dst_idx) {
218             for (int chb = ocb; chb < ocb + load_step; chb++) {
219                 auto par_conv_dw = jit_conv_call_s();
220
221                 par_conv_dw.src_row0 = &ws_p[(((dst_idx+1) - 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
222                                              (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
223                 par_conv_dw.src_row1 = &ws_p[(((dst_idx+1) - 0) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
224                                              (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
225                 par_conv_dw.src_row2 = &ws_p[(((dst_idx+1) + 1) % jcp_dw.kh) * jcp_dw.iw * jcp_dw.ch_block +
226                                              (chb - ocb) * jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block];
227
228                 par_conv_dw.dst = &dst[n*jcp_dw.oc*jcp_dw.oh*jcp_dw.ow + chb*jcp_dw.ch_block*jcp_dw.oh*jcp_dw.ow +
229                                        dst_idx/jcp_dw.stride_h*jcp_dw.ow*jcp_dw.ch_block];
230
231                 par_conv_dw.kh_padding = jcp_dw.kh;
232                 par_conv_dw.filt = &jcp_dw.conv_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
233                 par_conv_dw.bias = &dw_bias[chb * jcp_dw.ch_block];
234                 par_conv_dw.ur_w = (size_t)(jcp_dw.ow);
235                 par_conv_dw.oc_work = nstl::min((chb + 1) * jcp_dw.ch_block, (int)jcp_dw.oc) - chb*jcp_dw.ch_block;
236                 par_conv_dw.oc_off = chb * jcp_dw.ch_block * sizeof(float);
237
238                 kernel_dw_->jit_ker(&par_conv_dw);
239             }
240         };
241
242         assert(jcp.stride_w == 1 && jcp.stride_h == 1);
243
244         int start{0}, end{0};
245         balance211(work_amount, nthr, ithr, start, end);
246
247         auto dw_conv_buffer = scratchpad().get<data_t>(key_dw_conv_buffer);
248         size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * (jcp.oc / jcp.oc_block);
249         auto pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
250
251         const int os_block = jcp.iw;
252
253         int iwork = start;
254         while (iwork < end) {
255             int n{0}, g{0}, ocbb{0}, osb{0};
256             nd_iterator_init(iwork, n, MB, g, jcp.ngroups, ocbb, ocb_work, osb,
257                              jcp.nb_bcast);
258
259             int bcast_step = 1;
260
261             const int os = osb * os_block;
262             const int oh = os / jcp.ow;
263             const int ow = os % jcp.ow;
264
265             const int ih = nstl::max(oh * jcp.stride_h - jcp.t_pad, 0);
266             const int iw = nstl::max(ow * jcp.stride_w - jcp.l_pad, 0);
267
268             int ocb = ocbb * jcp.nb_load_blocking;
269
270             const int load_step = step(jcp.nb_load_blocking,
271                                        jcp.nb_load - ocb, jcp.nb_load_blocking_max);
272
273             if (iwork == start || oh == 0) {
274                 bcast_step = nstl::min(1, end - iwork);
275                 compute_block_1x1(pbuf, n, g, oh - 1, ow, ih, iw, os, os_block, bcast_step, ocb, load_step, bcast_step + 2);
276             } else {
277                 bcast_step = nstl::min(1, end - iwork);
278                 compute_block_1x1(pbuf, n, g, oh + 1, ow, ih, iw, os, os_block, bcast_step, ocb, load_step, bcast_step);
279             }
280
281             if ((oh % jcp_dw.stride_h == 0)) {
282                 compute_row_dw(pbuf, n, ocb, load_step, oh);
283             }
284
285             iwork += bcast_step;
286         }
287     };
288
289     if (pd()->wants_padded_bias()) {
290         auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
291         utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
292         utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
293                 jcp.oc - jcp.oc_without_padding);
294         bias = padded_bias;
295
296         auto dw_padded_bias = scratchpad().get<data_t>(key_dw_conv_padded_bias);
297         utils::array_copy(dw_padded_bias, dw_bias, jcp.oc_without_padding);
298         utils::array_set(dw_padded_bias + jcp.oc_without_padding, 0.f,
299                          jcp.oc - jcp.oc_without_padding);
300         dw_bias = dw_padded_bias;
301     }
302
303     parallel(0, ker);
304
305     if (pd()->wants_zero_pad_dst())
306         output_memory_primitive(0)->zero_pad();
307 }
308
309 }
310 }
311 }