1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
18 #include "mkldnn_types.h"
20 #include "c_types_map.hpp"
21 #include "jit_uni_binary_convolution.hpp"
23 #include "mkldnn_thread.hpp"
24 #include "type_helpers.hpp"
30 using namespace mkldnn::impl::status;
31 using namespace mkldnn::impl::memory_format;
32 using namespace mkldnn::impl::memory_tracking::names;
33 using namespace mkldnn::impl::utils;
35 template <cpu_isa_t isa>
// Forward pass of a binary (1-bit) convolution. The destination buffer is
// aliased two ways: as bit-packed uint8 when the kernel binarizes its own
// output, and as float32 otherwise. Work is partitioned across threads at
// one-output-row granularity over (mb, group, oc-block chunk, oh).
36 void jit_uni_binary_convolution_fwd_t<isa>::execute_forward() const {
// Source and weights are bit-packed, hence the uint8 views and the
// `/ nbits` offset scaling below (nbits defined elsewhere in the file —
// presumably 8 bits per byte; TODO confirm).
37 auto src = reinterpret_cast<const uint8_t*>(this->input_memory(0));
38 auto weights = reinterpret_cast<const uint8_t*>(this->input_memory(1));
39 auto dst_u8 = reinterpret_cast<uint8_t*>(this->memory());
40 auto dst_f32 = reinterpret_cast<float*>(this->memory());
42 const memory_desc_wrapper src_d(pd()->src_pd());
43 const memory_desc_wrapper dst_d(pd()->dst_pd());
44 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
// jcp: JIT conv parameters baked into the generated kernel at creation time.
46 const auto &jcp = kernel_->jcp;
47 const int MB = pd()->MB();
// Output-channel blocks are processed in chunks of nb_oc_blocking.
49 int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
// Total work items: one per (batch, group, oc-chunk, output row).
50 const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.oh;
// Per-thread body: claim a contiguous [start, end) slice of the flattened
// iteration space and walk it with the nd_iterator helpers.
54 auto ker = [&](const int ithr, const int nthr) {
55 size_t start{0}, end{0};
56 balance211(work_amount, nthr, ithr, start, end);
58 size_t n{0}, g{0}, ocbb{0}, oh{0};
59 nd_iterator_init(start, n, MB, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh);
60 for (size_t iwork = start; iwork < end; ++iwork) {
61 int ocb = ocbb * jcp.nb_oc_blocking;
62 int ocb_num = jcp.nb_oc_blocking;
64 auto par_conv = jit_conv_call_s();
// ij: first input row touched by output row `oh` (before padding clamp).
66 const int ij = oh * jcp.stride_h;
// Number of kernel rows that fall into the top / bottom padding for this
// output row; the JIT kernel skips (or specially handles) those taps.
67 const int i_t_overflow = nstl::min(jcp.kh, div_up(nstl::max(0, jcp.t_pad - ij), (jcp.dilate_h+1)));
68 const int i_b_overflow = nstl::min(jcp.kh, div_up(nstl::max(jcp.ih, ij + (jcp.kh-1) * (jcp.dilate_h+1) -
69 jcp.t_pad+1) - jcp.ih, (jcp.dilate_h + 1)));
// Absolute (cross-group) oc/ic block indices for offset computation.
71 const size_t _oc = g * jcp.nb_oc + ocb;
72 const size_t _ic = g * jcp.nb_ic;
// First valid (non-padded) input row for this output row.
74 const int ih = nstl::max(ij - jcp.t_pad + i_t_overflow * (jcp.dilate_h + 1), 0);
// Bit-packed source: element offset is divided by nbits to get bytes.
75 par_conv.src = &src[src_d.blk_off(n, _ic*jcp.ic_block, ih, 0) / nbits];
// Binarized output is bit-packed uint8; plain output is float32.
// NOTE(review): lines appear elided from this view here (the f32 store is
// presumably the else branch) — verify against the full file.
77 if (jcp.with_binarization) {
78 par_conv.dst = &dst_u8[dst_d.blk_off(n, _oc*jcp.oc_block, oh, 0) / nbits];
80 par_conv.dst = &dst_f32[dst_d.blk_off(n, _oc*jcp.oc_block, oh, 0)];
// exclude_pad: start the weights pointer past the rows that fell into the
// top padding, so padded taps are excluded from the popcount.
83 const int wh = jcp.exclude_pad ? i_t_overflow : 0;
84 int widx = weights_d.blk_off(ocb, 0, wh, 0);
85 par_conv.filt = &weights[widx / nbits];
// Actual number of output channels in this chunk (tail chunk may be short).
87 par_conv.oc_work = nstl::min((ocb + ocb_num) * jcp.oc_block, jcp.oc) - ocb*jcp.oc_block;
89 par_conv.kw_padding = 0;
// Kernel rows that land on real (non-padded) input.
90 const int kh_padding = jcp.kh - i_t_overflow - i_b_overflow;
91 par_conv.kh_padding = nstl::max(0, kh_padding);
92 par_conv.t_overflow = i_t_overflow;
93 par_conv.b_overflow = i_b_overflow;
// Byte offset of this chunk's channels, used by post-ops inside the kernel.
95 par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);
97 kernel_->jit_ker(&par_conv);
99 nd_iterator_step(n, MB, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh);
106 template <cpu_isa_t isa>
// Forward pass of the binary convolution fused with a depthwise convolution
// post-op. Each thread computes binary-conv output rows into a small
// per-thread ring buffer (jcp_dw_conv.kh rows per oc block), then runs the
// depthwise JIT kernel over triples of buffered rows, writing the final
// result to the primitive's destination. This avoids materializing the full
// intermediate tensor.
107 void jit_uni_binary_convolution_fwd_t<isa>::execute_forward_with_dw_conv() const {
108 auto src = reinterpret_cast<const uint8_t*>(this->input_memory(0));
109 auto weights = reinterpret_cast<const uint8_t*>(this->input_memory(1));
110 auto dst_u8 = reinterpret_cast<uint8_t*>(this->memory());
111 auto dst_f32 = reinterpret_cast<float*>(this->memory());
113 const memory_desc_wrapper src_d(pd()->src_pd());
114 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
// jcp: binary-conv kernel params; jcp_dw_conv: fused depthwise-conv params.
116 const auto &jcp = kernel_->jcp;
117 const auto &jcp_dw_conv = dw_conv_kernel_->jcp;
118 const int MB = pd()->MB();
// Depthwise weights/bias travel via the jit params struct, not as a
// primitive input.
120 auto dw_conv_bias = jcp_dw_conv.conv_biases;
121 auto dw_conv_weights = reinterpret_cast<const float*>(jcp_dw_conv.conv_weights);
123 int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
124 const size_t work_amount = MB * jcp.ngroups * ocb_work * jcp.oh;
128 auto ker = [&](const int ithr, const int nthr) {
// Compute `num_rows` binary-conv output rows starting at `oh` into the
// ring buffer ws_p. Rows outside [0, jcp.oh) are virtual padding rows for
// the depthwise conv and are zero-filled instead of computed. The ring slot
// for row r is ((r + 1) % jcp_dw_conv.kh).
129 auto compute_row_generic_conv = [&](float* ws_p, int n, int g, int ocb, int ocb_num, int oh, int num_rows) {
130 for (int h = 0; h < num_rows; h++) {
131 if ((oh + h) < 0 || (oh + h) >= jcp.oh) {
132 for (int chb = ocb; chb < ocb + ocb_num; chb++) {
133 memset(ws_p + (((oh + h) + 1) % jcp_dw_conv.kh) * jcp.ow * jcp.oc_block +
134 (chb - ocb) * jcp_dw_conv.kh * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));
137 auto par_conv = jit_conv_call_s();
// Padding/overflow math below mirrors execute_forward() (see above).
139 const int ij = (oh + h) * jcp.stride_h;
140 const int i_t_overflow = nstl::min(jcp.kh, div_up(nstl::max(0, jcp.t_pad - ij), (jcp.dilate_h+1)));
141 const int i_b_overflow = nstl::min(jcp.kh, div_up(nstl::max(jcp.ih, ij + (jcp.kh-1) * (jcp.dilate_h+1) -
142 jcp.t_pad+1) - jcp.ih, (jcp.dilate_h + 1)));
144 const size_t _oc = g * jcp.nb_oc + ocb;
145 const size_t _ic = g * jcp.nb_ic;
147 const int ih = nstl::max(ij - jcp.t_pad + i_t_overflow * (jcp.dilate_h + 1), 0);
148 par_conv.src = &src[src_d.blk_off(n, _ic*jcp.ic_block, ih, 0) / nbits];
// Intermediate rows are always f32, written into the ring buffer slot.
150 par_conv.dst = &ws_p[(((oh + h) + 1) % jcp_dw_conv.kh) * jcp.ow * jcp.oc_block];
152 const int wh = jcp.exclude_pad ? i_t_overflow : 0;
153 int widx = weights_d.blk_off(ocb, 0, wh, 0);
154 par_conv.filt = &weights[widx / nbits];
156 par_conv.oc_work = nstl::min((ocb + ocb_num) * jcp.oc_block, jcp.oc) - ocb*jcp.oc_block;
158 par_conv.kw_padding = 0;
159 const int kh_padding = jcp.kh - i_t_overflow - i_b_overflow;
160 par_conv.kh_padding = nstl::max(0, kh_padding);
161 par_conv.t_overflow = i_t_overflow;
162 par_conv.b_overflow = i_b_overflow;
164 par_conv.oc_off = _oc * jcp.oc_block * sizeof(float);
166 kernel_->jit_ker(&par_conv);
// Run the depthwise kernel for output row `dst_idx`, reading the three
// buffered rows centered on it from the ring buffer ws_p.
// NOTE(review): the three src_row pointers imply jcp_dw_conv.kh == 3
// (3x3 depthwise) — confirm against the kernel generator.
171 auto compute_row_dw_conv = [&](const float* ws_p, int n, int ocb, int ocb_num, int dst_idx) {
172 for (int chb = ocb; chb < nstl::min(ocb + ocb_num, jcp.nb_oc); chb++) {
173 auto par_conv_dw = jit_conv_call_s();
// Rows dst_idx-1, dst_idx, dst_idx+1, each at ring slot ((row+1) % kh).
175 par_conv_dw.src_row0 = &ws_p[(((dst_idx+1) - 1) % jcp_dw_conv.kh) * jcp_dw_conv.iw * jcp_dw_conv.ch_block +
176 (chb - ocb) * jcp_dw_conv.kh * jcp_dw_conv.iw * jcp_dw_conv.ch_block];
177 par_conv_dw.src_row1 = &ws_p[(((dst_idx+1) - 0) % jcp_dw_conv.kh) * jcp_dw_conv.iw * jcp_dw_conv.ch_block +
178 (chb - ocb) * jcp_dw_conv.kh * jcp_dw_conv.iw * jcp_dw_conv.ch_block];
179 par_conv_dw.src_row2 = &ws_p[(((dst_idx+1) + 1) % jcp_dw_conv.kh) * jcp_dw_conv.iw * jcp_dw_conv.ch_block +
180 (chb - ocb) * jcp_dw_conv.kh * jcp_dw_conv.iw * jcp_dw_conv.ch_block];
// Final destination: bit-packed uint8 if the depthwise conv binarizes,
// f32 otherwise (offsets computed on an nhwc-like flat layout here).
182 if (jcp_dw_conv.with_binarization) {
185 int didx = n*jcp_dw_conv.oc*jcp_dw_conv.oh*jcp_dw_conv.ow +
186 dst_idx/jcp_dw_conv.stride_h*jcp_dw_conv.ow*jcp_dw_conv.oc + chb*jcp_dw_conv.ch_block;
187 par_conv_dw.dst = &dst_u8[didx / nbits];
189 par_conv_dw.dst = &dst_f32[n*jcp_dw_conv.oc*jcp_dw_conv.oh*jcp_dw_conv.ow +
190 dst_idx/jcp_dw_conv.stride_h*jcp_dw_conv.ow*jcp_dw_conv.oc + chb*jcp_dw_conv.ch_block];
193 par_conv_dw.kh_padding = jcp_dw_conv.kh;
194 par_conv_dw.filt = &dw_conv_weights[chb * jcp_dw_conv.kh * jcp_dw_conv.kw * jcp_dw_conv.ch_block];
195 par_conv_dw.bias = &dw_conv_bias[chb * jcp_dw_conv.ch_block];
196 par_conv_dw.ur_w = (size_t)(jcp_dw_conv.ow);
197 par_conv_dw.oc_work = nstl::min((chb + 1) * jcp_dw_conv.ch_block, jcp_dw_conv.oc) - chb*jcp_dw_conv.ch_block;
198 par_conv_dw.oc_off = chb * jcp_dw_conv.ch_block * sizeof(float);
200 dw_conv_kernel_->jit_ker(&par_conv_dw);
204 size_t start{0}, end{0};
205 balance211(work_amount, nthr, ithr, start, end);
// Per-thread slice of the shared scratchpad ring buffer:
// kh rows x iw x ch_block floats per oc block in the chunk.
206 auto dw_conv_buffer_ = scratchpad().template get<float>(key_dw_conv_buffer);
207 size_t dw_conv_buffer_size_ = (size_t)jcp_dw_conv.kh * jcp_dw_conv.iw * jcp_dw_conv.ch_block * jcp.nb_oc_blocking;
208 auto pbuf = dw_conv_buffer_ + ithr * dw_conv_buffer_size_;
210 size_t n{0}, g{0}, ocbb{0}, oh{0};
211 nd_iterator_init(start, n, MB, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh);
212 for (size_t iwork = start; iwork < end; ++iwork) {
213 int ocb = ocbb * jcp.nb_oc_blocking;
214 int ocb_num = jcp.nb_oc_blocking;
// Prime the ring buffer at slice start / top of an image: rows oh-1 and oh
// (oh-1 may be a virtual zero row); otherwise only the new row oh is
// computed. NOTE(review): some lines appear elided from this view (else
// branches / closing braces) — verify control flow against the full file.
216 if (iwork == start || oh == 0) {
217 compute_row_generic_conv(pbuf, n, g, ocb, ocb_num, oh - 1, 2);
219 compute_row_generic_conv(pbuf, n, g, ocb, ocb_num, oh, 1);
// Row oh-1 now has all three neighbors buffered — emit its dw-conv output
// (only rows matching the dw stride produce output).
222 if (iwork > start && ((oh - 1) % jcp_dw_conv.stride_h == 0) && oh > 0) {
223 compute_row_dw_conv(pbuf, n, ocb, ocb_num, oh - 1);
// Last row of the slice/image: compute the trailing virtual row oh+1 and
// flush the dw-conv output for row oh.
226 if ((iwork == end - 1 || (int) oh == jcp.oh - 1) && ((oh) % jcp_dw_conv.stride_h == 0)) {
227 compute_row_generic_conv(pbuf, n, g, ocb, ocb_num, oh + 1, 1);
228 compute_row_dw_conv(pbuf, n, ocb, ocb_num, oh);
231 nd_iterator_step(n, MB, g, jcp.ngroups, ocbb, ocb_work, oh, jcp.oh);
// When oc is not a multiple of the channel block, the dw kernel reads a
// zero-padded copy of the bias from the scratchpad.
235 if (jcp.oc != jcp.oc_padded) {
236 auto dw_conv_padded_bias = scratchpad().template get<float>(key_dw_conv_padded_bias);
237 utils::array_copy(dw_conv_padded_bias, dw_conv_bias, jcp.oc);
238 utils::array_set(dw_conv_padded_bias + jcp.oc, 0.f, jcp.oc_padded - jcp.oc);
239 dw_conv_bias = dw_conv_padded_bias;
// Explicit instantiations for every ISA this primitive supports.
245 template struct jit_uni_binary_convolution_fwd_t<avx512_common>;
246 template struct jit_uni_binary_convolution_fwd_t<avx2>;
247 template struct jit_uni_binary_convolution_fwd_t<sse42>;