// Publishing 2019 R1 content
// [platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_binary_convolution.cpp
/*******************************************************************************
* Copyright 2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
17 #include <cstring>
18 #include "mkldnn_types.h"
19
20 #include "c_types_map.hpp"
21 #include "jit_uni_binary_convolution.hpp"
22 #include "utils.hpp"
23 #include "mkldnn_thread.hpp"
24 #include "type_helpers.hpp"
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
30 using namespace mkldnn::impl::status;
31 using namespace mkldnn::impl::memory_format;
32 using namespace mkldnn::impl::memory_tracking::names;
33 using namespace mkldnn::impl::utils;
34
template <cpu_isa_t isa>
void jit_uni_binary_convolution_fwd_t<isa>::execute_forward() const {
    // Driver for the JIT binary-convolution kernel: splits the output tensor
    // across threads and fills one jit_conv_call_s descriptor per work item
    // (one output row of one oc-block chunk of one group of one image).
    //
    // Binarized tensors are bit-packed, so pointers are taken as uint8_t and
    // byte offsets are obtained by dividing element offsets by 8 (`nbits`).
    auto src = reinterpret_cast<const uint8_t*>(this->input_memory(0));
    auto weights = reinterpret_cast<const uint8_t*>(this->input_memory(1));
    // Both views alias the same output buffer: bit-packed u8 when a
    // binarization post-op is fused, plain f32 otherwise.
    auto dst_u8 = reinterpret_cast<uint8_t*>(this->memory());
    auto dst_f32 = reinterpret_cast<float*>(this->memory());

    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper dst_d(pd()->dst_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));

    const auto &jcp = kernel_->jcp;
    const int MB = pd()->MB();

    // Work decomposition: (mini-batch, group, chunk of oc blocks, output row).
    int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
    const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.oh;

    int nbits = 8; // bits per byte of bit-packed binary data

    auto ker = [&](const int ithr, const int nthr) {
        size_t start{0}, end{0};
        balance211(work_amount, nthr, ithr, start, end);

        size_t n{0}, g{0}, ocbb{0}, oh{0};
        nd_iterator_init(start, n, MB, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh);
        for (size_t iwork = start; iwork < end; ++iwork) {
            int ocb = ocbb * jcp.nb_oc_blocking;
            int ocb_num = jcp.nb_oc_blocking;

            auto par_conv = jit_conv_call_s();

            // Number of (dilated) kernel rows falling into the top/bottom
            // padding for this output row.
            const int ij = oh * jcp.stride_h;
            const int i_t_overflow = nstl::min(jcp.kh, div_up(nstl::max(0, jcp.t_pad - ij), (jcp.dilate_h+1)));
            const int i_b_overflow = nstl::min(jcp.kh, div_up(nstl::max(jcp.ih, ij + (jcp.kh-1) * (jcp.dilate_h+1) -
                                                                                jcp.t_pad+1) - jcp.ih, (jcp.dilate_h + 1)));

            const size_t _oc = g * jcp.nb_oc + ocb;
            const size_t _ic = g * jcp.nb_ic;

            // First input row that actually contributes to this output row.
            const int ih = nstl::max(ij - jcp.t_pad + i_t_overflow * (jcp.dilate_h + 1), 0);
            par_conv.src = &src[src_d.blk_off(n, _ic*jcp.ic_block, ih, 0) / nbits];

            if (jcp.with_binarization) {
                // Bit-packed destination: element offset -> byte offset.
                par_conv.dst = &dst_u8[dst_d.blk_off(n, _oc*jcp.oc_block, oh, 0) / nbits];
            } else {
                par_conv.dst = &dst_f32[dst_d.blk_off(n, _oc*jcp.oc_block, oh, 0)];
            }

            // With exclude_pad the kernel skips weights rows that correspond
            // to padded input rows, so start from the first non-padded row.
            const int wh = jcp.exclude_pad ? i_t_overflow : 0;
            int widx = weights_d.blk_off(ocb, 0, wh, 0);
            par_conv.filt = &weights[widx / nbits];

            // Real channel count in this chunk; the last chunk may be partial
            // when jcp.oc is not a multiple of the block size.
            par_conv.oc_work = nstl::min((ocb + ocb_num) * jcp.oc_block, jcp.oc) - ocb*jcp.oc_block;

            par_conv.kw_padding = 0;
            // Kernel rows that overlap valid (non-padded) input.
            const int kh_padding = jcp.kh - i_t_overflow - i_b_overflow;
            par_conv.kh_padding = nstl::max(0, kh_padding);
            par_conv.t_overflow = i_t_overflow;
            par_conv.b_overflow = i_b_overflow;

            // Byte offset of this chunk's first channel, used by post-ops.
            par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);

            kernel_->jit_ker(&par_conv);

            nd_iterator_step(n, MB, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh);
        }
    };

    parallel(0, ker);
}
105
template <cpu_isa_t isa>
void jit_uni_binary_convolution_fwd_t<isa>::execute_forward_with_dw_conv() const {
    // Fused binary convolution + depthwise convolution. The main (binary)
    // convolution writes its output rows into a small per-thread circular
    // row buffer in scratchpad memory; the depthwise kernel then consumes
    // three adjacent rows from that buffer (src_row0/1/2 — the code assumes
    // jcp_dw_conv.kh == 3; TODO confirm against kernel init) and writes the
    // final result, so intermediate rows never round-trip through memory.
    auto src = reinterpret_cast<const uint8_t*>(this->input_memory(0));
    auto weights = reinterpret_cast<const uint8_t*>(this->input_memory(1));
    // Final destination of the fused dw conv: bit-packed u8 when the dw conv
    // has a fused binarization, plain f32 otherwise (same buffer aliased).
    auto dst_u8 = reinterpret_cast<uint8_t*>(this->memory());
    auto dst_f32 = reinterpret_cast<float*>(this->memory());

    const memory_desc_wrapper src_d(pd()->src_pd());
    const memory_desc_wrapper weights_d(pd()->weights_pd(0));

    const auto &jcp = kernel_->jcp;
    const auto &jcp_dw_conv = dw_conv_kernel_->jcp;
    const int MB = pd()->MB();

    // Depthwise weights/bias come pre-attached to the dw conv descriptor.
    auto dw_conv_bias = jcp_dw_conv.conv_biases;
    auto dw_conv_weights = reinterpret_cast<const float*>(jcp_dw_conv.conv_weights);

    // Work decomposition mirrors execute_forward():
    // (mini-batch, group, chunk of oc blocks, output row of the main conv).
    int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
    const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.oh;

    int nbits = 8; // bits per byte of bit-packed binary data

    auto ker = [&](const int ithr, const int nthr) {
        // Compute `num_rows` rows of the main conv starting at row `oh`,
        // storing each into slot ((row + 1) % jcp_dw_conv.kh) of the circular
        // buffer `ws_p`. Rows outside [0, jcp.oh) are zero-filled: they act
        // as the vertical padding of the fused depthwise conv.
        auto compute_row_generic_conv = [&](float* ws_p, int n, int g, int ocb, int ocb_num, int oh, int num_rows) {
            for (int h = 0; h < num_rows; h++) {
                if ((oh + h) < 0 || (oh + h) >= jcp.oh) {
                    for (int chb = ocb; chb < ocb + ocb_num; chb++) {
                        memset(ws_p + (((oh + h) + 1) % jcp_dw_conv.kh) * jcp.ow * jcp.oc_block +
                               (chb - ocb) * jcp_dw_conv.kh * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));
                    }
                } else {
                    auto par_conv = jit_conv_call_s();

                    // Padding/overflow arithmetic is identical to the
                    // non-fused path in execute_forward().
                    const int ij = (oh + h) * jcp.stride_h;
                    const int i_t_overflow = nstl::min(jcp.kh, div_up(nstl::max(0, jcp.t_pad - ij), (jcp.dilate_h+1)));
                    const int i_b_overflow = nstl::min(jcp.kh, div_up(nstl::max(jcp.ih, ij + (jcp.kh-1) * (jcp.dilate_h+1) -
                                                                                        jcp.t_pad+1) - jcp.ih, (jcp.dilate_h + 1)));

                    const size_t _oc = g * jcp.nb_oc + ocb;
                    const size_t _ic = g * jcp.nb_ic;

                    const int ih = nstl::max(ij - jcp.t_pad + i_t_overflow * (jcp.dilate_h + 1), 0);
                    par_conv.src = &src[src_d.blk_off(n, _ic*jcp.ic_block, ih, 0) / nbits];

                    // Destination is the circular-buffer slot, not user memory.
                    par_conv.dst = &ws_p[(((oh + h) + 1) % jcp_dw_conv.kh) * jcp.ow * jcp.oc_block];

                    const int wh = jcp.exclude_pad ? i_t_overflow : 0;
                    int widx = weights_d.blk_off(ocb, 0, wh, 0);
                    par_conv.filt = &weights[widx / nbits];

                    par_conv.oc_work = nstl::min((ocb + ocb_num) * jcp.oc_block, jcp.oc) - ocb*jcp.oc_block;

                    par_conv.kw_padding = 0;
                    const int kh_padding = jcp.kh - i_t_overflow - i_b_overflow;
                    par_conv.kh_padding = nstl::max(0, kh_padding);
                    par_conv.t_overflow = i_t_overflow;
                    par_conv.b_overflow = i_b_overflow;

                    par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);

                    kernel_->jit_ker(&par_conv);
                }
            }
        };

        // Run the depthwise kernel for output row `dst_idx`, reading rows
        // dst_idx-1, dst_idx, dst_idx+1 from the circular buffer (each per
        // channel block) and writing one row of the final destination.
        auto compute_row_dw_conv = [&](const float* ws_p, int n, int ocb, int ocb_num, int dst_idx) {
            for (int chb = ocb; chb < nstl::min(ocb + ocb_num, jcp.nb_oc); chb++) {
                auto par_conv_dw = jit_conv_call_s();

                // Slot indices use the same ((row + 1) % kh) mapping that
                // compute_row_generic_conv used when filling the buffer.
                par_conv_dw.src_row0 = &ws_p[(((dst_idx+1) - 1) % jcp_dw_conv.kh) * jcp_dw_conv.iw * jcp_dw_conv.ch_block +
                                             (chb - ocb) * jcp_dw_conv.kh * jcp_dw_conv.iw * jcp_dw_conv.ch_block];
                par_conv_dw.src_row1 = &ws_p[(((dst_idx+1) - 0) % jcp_dw_conv.kh) * jcp_dw_conv.iw * jcp_dw_conv.ch_block +
                                             (chb - ocb) * jcp_dw_conv.kh * jcp_dw_conv.iw * jcp_dw_conv.ch_block];
                par_conv_dw.src_row2 = &ws_p[(((dst_idx+1) + 1) % jcp_dw_conv.kh) * jcp_dw_conv.iw * jcp_dw_conv.ch_block +
                                             (chb - ocb) * jcp_dw_conv.kh * jcp_dw_conv.iw * jcp_dw_conv.ch_block];

                if (jcp_dw_conv.with_binarization) {
                    int nbits = 8; // shadows outer nbits; same meaning

                    // Dense (nchw-style flat) element offset into the final
                    // bit-packed destination, then converted to bytes.
                    int didx = n*jcp_dw_conv.oc*jcp_dw_conv.oh*jcp_dw_conv.ow +
                               dst_idx/jcp_dw_conv.stride_h*jcp_dw_conv.ow*jcp_dw_conv.oc + chb*jcp_dw_conv.ch_block;
                    par_conv_dw.dst = &dst_u8[didx / nbits];
                } else {
                    par_conv_dw.dst = &dst_f32[n*jcp_dw_conv.oc*jcp_dw_conv.oh*jcp_dw_conv.ow +
                                               dst_idx/jcp_dw_conv.stride_h*jcp_dw_conv.ow*jcp_dw_conv.oc + chb*jcp_dw_conv.ch_block];
                }

                par_conv_dw.kh_padding = jcp_dw_conv.kh;
                par_conv_dw.filt = &dw_conv_weights[chb * jcp_dw_conv.kh * jcp_dw_conv.kw * jcp_dw_conv.ch_block];
                par_conv_dw.bias = &dw_conv_bias[chb * jcp_dw_conv.ch_block];
                par_conv_dw.ur_w = (size_t)(jcp_dw_conv.ow);
                par_conv_dw.oc_work = nstl::min((chb + 1) * jcp_dw_conv.ch_block, jcp_dw_conv.oc) - chb*jcp_dw_conv.ch_block;
                par_conv_dw.oc_off = chb * jcp_dw_conv.ch_block * sizeof(float);

                dw_conv_kernel_->jit_ker(&par_conv_dw);
            }
        };

        size_t start{0}, end{0};
        balance211(work_amount, nthr, ithr, start, end);
        // Per-thread circular row buffer: kh rows x iw x ch_block floats for
        // each of the nb_oc_blocking channel blocks handled at once.
        auto dw_conv_buffer_ = scratchpad().template get<float>(key_dw_conv_buffer);
        size_t dw_conv_buffer_size_ = (size_t)jcp_dw_conv.kh * jcp_dw_conv.iw * jcp_dw_conv.ch_block * jcp.nb_oc_blocking;
        auto pbuf = dw_conv_buffer_ + ithr * dw_conv_buffer_size_;

        size_t n{0}, g{0}, ocbb{0}, oh{0};
        nd_iterator_init(start, n, MB, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh);
        for (size_t iwork = start; iwork < end; ++iwork) {
            int ocb = ocbb * jcp.nb_oc_blocking;
            int ocb_num = jcp.nb_oc_blocking;

            if (iwork == start || oh == 0) {
                // Prime the buffer: rows oh-1 and oh (oh-1 may be the
                // zero-filled top padding row).
                compute_row_generic_conv(pbuf, n, g, ocb, ocb_num, oh - 1, 2);
            } else {
                // Steady state: only the new row oh is computed.
                compute_row_generic_conv(pbuf, n, g, ocb, ocb_num, oh, 1);
            }

            // Once row oh exists, row oh-1 has both neighbors available.
            if (iwork > start && ((oh - 1) % jcp_dw_conv.stride_h == 0) && oh > 0) {
                compute_row_dw_conv(pbuf, n, ocb, ocb_num, oh - 1);
            }

            // Tail of a thread's range (or of the image): produce row oh+1
            // (possibly zero padding) so row oh can be finished too.
            if ((iwork == end - 1 || (int) oh == jcp.oh - 1) && ((oh) % jcp_dw_conv.stride_h == 0)) {
                compute_row_generic_conv(pbuf, n, g, ocb, ocb_num, oh + 1, 1);
                compute_row_dw_conv(pbuf, n, ocb, ocb_num, oh);
            }

            nd_iterator_step(n, MB, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh);
        }
    };

    // When oc is padded up to a block multiple, use a zero-padded copy of the
    // bias so the kernel can safely read a full block for the last chunk.
    if (jcp.oc != jcp.oc_padded) {
        auto dw_conv_padded_bias = scratchpad().template get<float>(key_dw_conv_padded_bias);
        utils::array_copy(dw_conv_padded_bias, dw_conv_bias, jcp.oc);
        utils::array_set(dw_conv_padded_bias + jcp.oc, 0.f, jcp.oc_padded - jcp.oc);
        dw_conv_bias = dw_conv_padded_bias;
    }

    parallel(0, ker);
}
244
// Explicit instantiations for every ISA this primitive dispatches to.
template struct jit_uni_binary_convolution_fwd_t<avx512_common>;
template struct jit_uni_binary_convolution_fwd_t<avx2>;
template struct jit_uni_binary_convolution_fwd_t<sse42>;

} // namespace cpu
} // namespace impl
} // namespace mkldnn