1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_types.h"
19 #include "c_types_map.hpp"
20 #include "gemm_convolution.hpp"
22 #include "type_helpers.hpp"
23 #include "mkldnn_thread.hpp"
24 #include "ref_eltwise.hpp"
30 using namespace mkldnn::impl::status;
31 using namespace mkldnn::impl::memory_format;
32 using namespace mkldnn::impl::memory_tracking::names;
33 using namespace mkldnn::impl::utils;
// Forward pass of the GEMM-based convolution.
// Visible strategy: each thread owns a private im2col scratch slice
// (`_col`), iterates over flattened (group, mb, od, oh-block, ow-block)
// work items, runs im2col / im2col_3d followed by a single sgemm into
// dst, then applies bias and attr post-ops per output channel.
// NOTE(review): the embedded numbering in this view skips lines, so
// several statements below are fragments of multi-line constructs;
// comments describe only what is visible here.
35 void gemm_convolution_fwd_t::execute_forward() const {
// Inputs: 0 = src, 1 = weights, 2 = bias; output: dst.
36 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
37 auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
38 auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
39 auto dst = reinterpret_cast<data_t*>(this->memory());
// Scratchpad holds one im2col workspace per thread (indexed by ithr below).
41 auto col = scratchpad().get<data_t>(key_conv_gemm_col);
43 const auto &jcp = this->pd()->jcp_;
44 const int MB = pd()->MB();
46 const memory_desc_wrapper src_d(pd()->src_pd());
47 const memory_desc_wrapper dst_d(pd()->dst_pd());
// M: full output spatial size per (mb, group) = os * od.
49 const int M = jcp.os * jcp.od;
// Per-group element strides within one image; off_l(0) subtracts the
// layout's leading offset so the steps are pure per-group extents.
50 const size_t src_step = (src_d.blk_off(1) - src_d.off_l(0)) / jcp.ngroups;
51 const size_t dst_step = (dst_d.blk_off(1) - dst_d.off_l(0)) / jcp.ngroups;
52 const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
// Advance base pointers past the layout's leading offset once.
53 src += src_d.off_l(0);
54 dst += dst_d.off_l(0);
// Fragment of an assert (opening elided): 3D case (id != 1) implies no
// oh/ow blocking; ow blocking implies single-row oh blocks.
57 jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow));
58 assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));
// K: GEMM reduction dim = input channels * kernel spatial size.
60 const int K = jcp.ic * jcp.ks;
// 3D path: zero all per-thread im2col buffers once up front
// (presumably because im2col_3d writes only non-padded entries — TODO
// confirm against jit_gemm_convolution_utils).
63 if (jcp.im2col_sz && jcp.id != 1)
64 parallel_nd(jcp.im2col_sz * jcp.nthr,
65 [&](ptrdiff_t i) { col[i] = (data_t)0; });
// Work decomposition: groups * minibatch * od * oh-blocks * ow-blocks.
67 const int nb_oh = div_up(jcp.oh, jcp.oh_block);
68 const int nb_ow = div_up(jcp.ow, jcp.ow_block);
69 const size_t work_amount = jcp.ngroups * MB * jcp.od * nb_oh * nb_ow;
70 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
// Each thread uses its own slice of the shared im2col scratchpad.
71 data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
73 int g{ 0 }, n{ 0 }, od{ 0 }, ohb{ 0 }, owb{ 0 };
74 size_t start = 0, end = 0;
// Evenly split the flattened work range, then decode the starting
// multi-index (nd_iterator_init continuation line elided in this view).
76 balance211(work_amount, nthr, ithr, start, end);
77 nd_iterator_init(start, g, jcp.ngroups, n, MB, od, jcp.od, ohb,
79 for (size_t iwork = start; iwork < end; ++iwork) {
// Top-left corner of the current output block.
80 int oh = ohb * jcp.oh_block;
81 int ow = owb * jcp.ow_block;
// Base pointers for this (mb, group) pair.
82 const data_t *_src = src + (n * jcp.ngroups + g) * src_step;
83 const data_t *_weights = weights + g * weights_g_size;
84 data_t *_dst_im = dst + (n * jcp.ngroups + g) * dst_step;
// Clamp the last partial block at the spatial borders.
85 const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
86 const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
// 2D vs 3D im2col (the guarding conditions are elided in this view).
89 jit_gemm_convolution_utils::im2col(
90 jcp, _src, _col, oh, h_step, ow, w_step);
92 jit_gemm_convolution_utils::im2col_3d(jcp, _src, _col, od);
95 const data_t one = 1.0;
// m: GEMM row count = spatial points in the current output block.
97 const int m = h_step * w_step;
// Leading dim of A: dense im2col buffer (m) vs strided direct src (M).
98 const int LDA = jcp.im2col_sz ? m : M;
99 data_t *_dst = _dst_im + od * jcp.os + oh * jcp.ow + ow;
// C(m x N) = alpha*A(m x K)*B(K x N) + beta_*C; beta_ is a member,
// presumably encoding a fused sum post-op — TODO confirm at pd().
101 extended_sgemm("N", "N", &m, &N, &K, &one,
102 jcp.im2col_sz ? _col : _src + od * m, &LDA, _weights, &K,
103 &this->beta_, _dst, &M);
// --- post-GEMM epilogue: bias and attr post-ops ---
106 const auto &p = pd()->attr()->post_ops_;
107 bool need_bias = jcp.with_bias;
// Fast path (enclosing condition elided): fused bias + leaky-ReLU.
// `d` is declared on an elided line; it points at the GEMM output.
109 parallel_nd(jcp.oc, [&](const int oc) {
110 data_t b = need_bias ? bias[g * jcp.oc + oc] : 0;
111 data_t *d_ = d + oc * M;
113 for (int oS = 0; oS < m; ++oS) {
// Scale negative values by the negative slope (bias add elided here).
115 if (d_[oS] < 0) d_[oS] *= fast_relu_ns;
// General path: walk the post-ops chain in declaration order.
120 } else if (p.len_ > 0) {
121 int eltwise_inj_idx = 0;
122 int depthwise_inj_idx = 0;
124 for (int i = 0; i < p.len_; i++) {
125 auto& post_op = p.entry_[i];
126 if (post_op.is_eltwise()) {
// Element-wise activation via the i-th eltwise injector.
127 parallel_nd(jcp.oc, [&](const int oc) {
128 data_t b = need_bias ? bias[g * jcp.oc + oc] : 0;
129 data_t *d_ = d + oc * M;
131 for (int oS = 0; oS < m; ++oS) {
133 d_[oS] = eltwise_injectors[eltwise_inj_idx]->compute_scalar(d_[oS]);
139 } else if (post_op.is_depthwise()) {
// Per-output-channel scale/shift data for the depthwise post-op.
140 auto depthwise_weights = post_op.depthwise.weights_data;
141 auto depthwise_bias = post_op.depthwise.biases_data;
143 parallel_nd(jcp.oc, [&](const int oc) {
144 data_t b = need_bias ? bias[g * jcp.oc + oc] : 0;
145 data_t *d_ = d + oc * M;
147 for (int oS = 0; oS < m; ++oS) {
// Depthwise injector indexes its params by absolute channel g*oc+oc.
149 d_[oS] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(d_[oS],
150 depthwise_weights + g * jcp.oc + oc,
151 depthwise_bias + g * jcp.oc + oc);
// Bias-only path (enclosing condition elided in this view).
162 parallel_nd(jcp.oc, [&](const int oc) {
163 data_t b = bias[g * jcp.oc + oc];
164 data_t *d_ = d + oc * M;
166 for (int oS = 0; oS < m; ++oS) {
// Advance to the next (g, n, od, ohb, owb) work item (call continues
// on an elided line).
172 nd_iterator_step(g, jcp.ngroups, n, MB, od, jcp.od, ohb, nb_oh,
// Backward-data pass of the GEMM-based convolution.
// Visible strategy: for each (group, mb) work item and each depth slice,
// sgemm computes columns = diff_dst * weights^T, then col2im /
// col2im_3d scatter-accumulates the columns back into diff_src.
// NOTE(review): the embedded numbering in this view skips lines; some
// statements below are fragments of multi-line constructs.
178 void gemm_convolution_bwd_data_t::execute_backward_data() const {
// Inputs: 0 = diff_dst, 1 = weights; output: diff_src.
179 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
180 auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
181 auto diff_src = reinterpret_cast<data_t*>(this->memory());
// Per-thread column-format scratch buffer.
183 auto col = scratchpad().get<data_t>(key_conv_gemm_col);
185 const auto &jcp = this->pd()->jcp_;
186 const int MB = pd()->MB();
// M: full output spatial size per (mb, group) = os * od.
188 const int M = jcp.os * jcp.od;
// Number of diff_src elements to zero per (mb, group) image.
189 const size_t src_step_to_clean = jcp.ic * jcp.ih * jcp.iw * jcp.id;
190 const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
191 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
// Per-group element strides within one image.
192 const size_t src_step = diff_src_d.blk_off(1) / jcp.ngroups;
193 const size_t dst_step = diff_dst_d.blk_off(1) / jcp.ngroups;
194 const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
// GEMM dims per depth slice: m rows, reduce over K = oc, N = ic * ks.
196 const int m = jcp.os;
197 const int K = jcp.oc;
198 const int N = jcp.ic * jcp.ks;
// Leading dim of C: dense col buffer (m) vs direct strided diff_src (M).
199 const int LDC = jcp.im2col_sz ? m : M;
201 const size_t work_amount = (size_t)jcp.ngroups * MB;
// Zero diff_src up front — presumably because col2im accumulates into
// it (the enclosing condition for this clean-up is elided in this view).
204 for (size_t j = 0; j < work_amount; j++) {
205 int j_step = src_step * j;
206 const ptrdiff_t diff_src_sz = (ptrdiff_t)(src_step_to_clean);
207 parallel_nd(diff_src_sz, [&](ptrdiff_t i) { diff_src[j_step + i] = (data_t)0; });
211 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
// Each thread uses its own slice of the shared col scratchpad.
212 data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
215 size_t start = 0, end = 0;
// Split (group, mb) pairs across threads; g/n are declared on elided
// lines.
216 balance211(work_amount, nthr, ithr, start, end);
217 nd_iterator_init(start, g, jcp.ngroups, n, MB);
218 for (size_t iwork = start; iwork < end; ++iwork) {
220 data_t *_diff_src = diff_src + (n * jcp.ngroups + g) * src_step;
221 const data_t *_weights = weights + g * weights_g_size;
222 for (int od = 0; od < jcp.od; ++od) {
// diff_dst pointer for this depth slice (offset expression continues
// on an elided line).
223 const data_t *_diff_dst = diff_dst + (n * jcp.ngroups + g)
226 const data_t zero = 0.0, one = 1.0;
// C(m x N) = diff_dst(m x K) * weights^T(K x N); written either to the
// col buffer (then col2im) or directly into diff_src.
227 extended_sgemm("N", "T", &m, &N, &K, &one, _diff_dst, &M,
229 jcp.im2col_sz ? _col:_diff_src + od * m, &LDC);
// Scatter-accumulate the columns back into diff_src (2D vs 3D; the
// guarding conditions and argument tails are elided in this view).
233 jit_gemm_convolution_utils::col2im(jcp, _col,
236 jit_gemm_convolution_utils::col2im_3d(jcp, _col,
// Advance to the next (group, mb) pair.
240 nd_iterator_step(g, jcp.ngroups, n, MB);
// Backward-weights pass of the GEMM-based convolution.
// Visible strategy: threads are partitioned over (groups x minibatch)
// by bwd_weights_balance; each thread accumulates diff_weights partials
// via im2col + sgemm (into a private reduction buffer when several
// threads share a group), then partials are reduced behind a barrier.
// diff_bias is computed separately as a parallel_nd sum over diff_dst.
// NOTE(review): the embedded numbering in this view skips lines, and
// the function's closing braces run past the end of this view; comments
// describe only what is visible.
245 void gemm_convolution_bwd_weights_t::execute_backward_weights() const {
// Inputs: 0 = src, 1 = diff_dst; outputs: 0 = diff_weights, 1 = diff_bias.
246 auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
247 auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
248 auto diff_weights = reinterpret_cast<data_t*>(this->memory(0));
249 auto diff_bias = reinterpret_cast<data_t *>(this->memory(1));
// Scratchpads: per-thread im2col buffers and per-thread weight partials.
251 auto col = scratchpad().get<data_t>(key_conv_gemm_col);
252 auto wei_reduction = scratchpad().get<data_t>(key_conv_wei_reduction);
254 const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
// K: full output spatial size per (mb, group) = os * od.
256 const int K = jcp.os * jcp.od;
// Per-(mb, group) element strides (dense layout assumed here).
257 const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
258 const size_t dst_step = jcp.oc * K;
259 const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
// GEMM dims per depth slice: reduce over k = os; M x N = (ic*ks) x oc.
261 const int k = jcp.os;
262 const int N = jcp.oc;
263 const int M = jcp.ic * jcp.ks;
// Leading dim of A: dense im2col buffer (k) vs strided direct src (K).
264 const int LDA = jcp.im2col_sz ? k : K;
// Zero all per-thread im2col buffers once before the parallel region.
266 parallel_nd(jcp.im2col_sz * jcp.nthr,
267 [&](ptrdiff_t i) { col[i] = (data_t)0; });
269 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
270 int ithr_g, nthr_g, ithr_mb, nthr_mb;
271 size_t g_start{0}, g_end{0}, mb_start{0}, mb_end{0};
// Balance over mb only when a cross-thread weight reduction is wanted.
273 const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
274 jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups,
275 mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb);
277 assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
// Reduction is needed iff several threads split the minibatch.
278 const int need_reduction = nthr_mb != 1;
// Threads with no assigned work get (-1, -1) from the balancer.
280 if (ithr_g != -1 && ithr_mb != -1) {
281 balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
282 balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);
// A thread never both owns multiple groups and participates in a
// minibatch reduction.
284 assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));
286 data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
// Private partial-weights slot: one per (group-slot, mb-slot).
287 data_t *weights_reduce_base = wei_reduction
288 + ithr_g * nthr_mb * weights_g_size;
289 data_t *weights_reduce = weights_reduce_base
290 + ithr_mb * weights_g_size;
292 for (size_t g = g_start; g < g_end; ++g) {
// Accumulate into the private buffer when reducing, else directly
// into the final diff_weights for this group.
293 data_t *_diff_weights = need_reduction
294 ? weights_reduce : (diff_weights + g * weights_g_size);
295 for (size_t mb = mb_start; mb < mb_end; ++mb) {
296 const data_t *_src = src + (mb*jcp.ngroups+g)*src_step;
297 for (int od = 0; od < jcp.od; ++od) {
298 const data_t *_diff_dst = diff_dst
299 + (mb*jcp.ngroups+g)*dst_step + od * k;
// 2D vs 3D im2col over the full spatial range (guarding conditions
// elided in this view).
303 jit_gemm_convolution_utils::im2col(
304 jcp, _src, _col, 0, jcp.oh, 0, jcp.ow);
306 jit_gemm_convolution_utils::im2col_3d(jcp, _src,
310 const data_t zero = 0.0, one = 1.0;
// C(M x N) += A^T(M x k) * B(k x N); beta = 0 only on the first
// (mb, od) iteration so later iterations accumulate (sgemm call name
// and trailing arguments are on elided lines).
312 "T", "N", &M, &N, &k, &one,
313 jcp.im2col_sz ? _col : _src + od * k,
315 mb == mb_start && od == 0 ? &zero : &one,
// Reduce per-thread partials into diff_weights; barriers keep all
// reducers in lock-step (all threads must reach them).
320 if (need_reduction) {
321 mkldnn_thr_barrier();
322 data_t *weights_base = diff_weights + g_start * weights_g_size;
323 jit_gemm_convolution_utils::bwd_weights_reduction_par(
324 ithr_mb, nthr_mb, jcp, weights_reduce_base, weights_base);
// Matching barrier for threads that skipped the work block above.
327 if (need_reduction) { mkldnn_thr_barrier(); }
// diff_bias: sum diff_dst over mb and all spatial dims, per (g, oc).
// The accumulator `db` is declared on an elided line.
331 parallel_nd(jcp.ngroups, jcp.oc, [&](int g, int oc) {
333 size_t offset_ = (size_t)g * dst_step + (size_t)oc * K;
334 for (int mb = 0; mb < jcp.mb; ++mb)
336 size_t offset = offset_ + (size_t)mb * jcp.ngroups * dst_step;
337 for (int od = 0; od < jcp.od; ++od)
338 for (int oh = 0; oh < jcp.oh; ++oh)
// Vectorized innermost reduction over ow.
339 PRAGMA_OMP_SIMD(reduction(+:db))
340 for (int ow = 0; ow < jcp.ow; ++ow) {
341 db += diff_dst[offset];
345 diff_bias[g*jcp.oc+oc] = db;