1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "src/common/mkldnn_thread.hpp"
18 #include "src/common/math_utils.hpp"
20 #include "conv/conv_common.hpp"
/* Forward-pass reference entry point: picks the reference implementation to
 * run for problem `p`.
 *   p                          - problem descriptor; only p->alg and
 *                                p->cfg[SRC].dt are read in this function.
 *   src_m, wei_m, bia_m, dst_m - memory objects handed through unchanged to
 *                                the selected implementation. */
24 void compute_ref_fwd(const prb_t *p, dnn_mem_t &src_m, dnn_mem_t &wei_m,
25 dnn_mem_t &bia_m, dnn_mem_t &dst_m) {
/* The WINO reference is taken only when the algorithm is WINO and the source
 * data type is mkldnn_f32. */
26 if (p->alg == WINO && p->cfg[SRC].dt == mkldnn_f32) {
27 compute_wino_ref_fwd(p, src_m, wei_m, bia_m, dst_m);
/* NOTE(review): the line separating the two calls is not visible in this
 * view; presumably an `else` makes the direct reference the fallback for
 * every other configuration - confirm against the full file. */
29 compute_ref_direct_fwd(p, src_m, wei_m, bia_m, dst_m);
/* Backward-by-data reference entry point: picks the reference implementation
 * to run for problem `p`.
 *   p                                - problem descriptor; only p->alg and
 *                                      p->cfg[SRC].dt are read here.
 *   diff_src_m, wei_m, bia_m,
 *   diff_dst_m                       - memory objects handed through
 *                                      unchanged to the selected
 *                                      implementation. */
33 void compute_ref_bwd_d(const prb_t *p, dnn_mem_t &diff_src_m, dnn_mem_t &wei_m,
34 dnn_mem_t &bia_m, dnn_mem_t &diff_dst_m) {
/* Same selection rule as the other entry points in this file: WINO algorithm
 * together with an mkldnn_f32 source data type. */
35 if (p->alg == WINO && p->cfg[SRC].dt == mkldnn_f32) {
36 compute_wino_ref_bwd_d(p, diff_src_m, wei_m, bia_m, diff_dst_m);
/* NOTE(review): the line separating the two calls is not visible in this
 * view; presumably an `else` makes the direct reference the fallback -
 * confirm against the full file. */
38 compute_ref_direct_bwd_d(p, diff_src_m, wei_m, bia_m, diff_dst_m);
/* Backward-by-weights reference entry point: picks the reference
 * implementation to run for problem `p`.
 *   p                                  - problem descriptor; only p->alg and
 *                                        p->cfg[SRC].dt are read here.
 *   src_m, diff_wei_m, diff_bia_m,
 *   diff_dst_m                         - memory objects handed through
 *                                        unchanged to the selected
 *                                        implementation. */
42 void compute_ref_bwd_w(const prb_t *p, dnn_mem_t &src_m, dnn_mem_t &diff_wei_m,
43 dnn_mem_t &diff_bia_m, dnn_mem_t &diff_dst_m) {
/* Same selection rule as the other entry points in this file: WINO algorithm
 * together with an mkldnn_f32 source data type. */
44 if (p->alg == WINO && p->cfg[SRC].dt == mkldnn_f32) {
45 compute_wino_ref_bwd_w(p, src_m, diff_wei_m, diff_bia_m, diff_dst_m);
/* NOTE(review): the line separating the two calls is not visible in this
 * view; presumably an `else` makes the direct reference the fallback -
 * confirm against the full file. */
47 compute_ref_direct_bwd_w(p, src_m, diff_wei_m, diff_bia_m, diff_dst_m);
/* Direct (non-WINO) forward reference.
 * For every (g, mb, oc, od, oh, ow) index handed out by parallel_nd the
 * function accumulates src * wei over input channels and kernel positions,
 * optionally adds the bias, then runs the output-scale and post-op helpers.
 *   p      - problem descriptor (sizes, strides, paddings, dilations, dir
 *            flags and attributes are read from it).
 *   src_m  - source values, read as float through the (float*) conversion.
 *   wei_m  - weights, read as float.
 *   bia_m  - bias, read as float and only when p->dir has FLAG_BIA set.
 *   dst_m  - destination; the current value is bound by reference and is
 *            also fed to the post-ops as the `dst` operand.
 * NOTE(review): several short lines of this function (closing braces, the
 * body of maybe_scale, the switch header, the declaration of conv_res and
 * the final store) are not visible in this view. */
51 void compute_ref_direct_fwd(const prb_t *p, dnn_mem_t &src_m,
52 dnn_mem_t &wei_m, dnn_mem_t &bia_m, dnn_mem_t &dst_m) {
/* ker: adds to `d` the products src * wei for one output point. `ic` and
 * `oc` are indices inside group `g` (the ic loop runs over p->ic/p->g). */
53 auto ker = [&](float &d, int g, int mb, int oc, int od, int oh, int ow) {
54 for (int ic = 0; ic < p->ic/p->g; ++ic) {
55 for (int kd = 0; kd < p->kd; ++kd) {
/* Input coordinate for this output/kernel pair:
 * o * s{d,h,w} - p{d,h,w} + k * (d{d,h,w} + 1). Coordinates that fall
 * outside [0, input size) are skipped, so they contribute nothing. */
56 const int id = od * p->sd - p->pd + kd * (p->dd + 1);
57 if (id < 0 || id >= p->id) continue;
58 for (int kh = 0; kh < p->kh; ++kh) {
59 const int ih = oh * p->sh - p->ph + kh * (p->dh + 1);
60 if (ih < 0 || ih >= p->ih) continue;
62 for (int kw = 0; kw < p->kw; ++kw) {
63 const int iw = ow * p->sw - p->pw + kw * (p->dw + 1);
64 if (iw < 0 || iw >= p->iw) continue;
66 size_t src_off = src_off_f(p, mb, g, ic, id, ih, iw);
67 size_t wei_off = wei_off_f(p, g, oc, ic, kd, kh, kw);
68 d += ((float*)src_m)[src_off]
69 * ((float*)wei_m)[wei_off];
/* maybe_scale: acts only when the output-scale attribute is not the default
 * one. `oc` is the channel index the caller passes in.
 * NOTE(review): the statements of both policy branches are not visible in
 * this view - confirm how COMMON and the remaining policies scale `d`. */
76 auto maybe_scale = [&](float &d, int oc) {
77 if (!p->attr.oscale.is_def()) {
78 using policy_t = attr_t::scale_t::policy_t;
79 const auto &s = p->attr.oscale;
80 if (s.policy == policy_t::COMMON) {
/* maybe_post_ops: walks p->attr.post_ops in order and updates conv_res in
 * place. `dst` is the destination value passed by the caller; it is used by
 * the SUM entry only. */
88 auto maybe_post_ops = [&](float &conv_res, float dst) {
89 using namespace mkldnn::impl::math;
91 const auto &ops = p->attr.post_ops;
92 for (int idx = 0; idx < ops.len; ++idx) {
93 using pk = attr_t::post_ops_t::kind_t;
94 const auto &e = ops.entry[idx];
/* Shorthands for the eltwise parameters of this entry. */
96 const auto &s = e.eltwise.scale;
97 const auto &a = e.eltwise.alpha;
98 const auto &b = e.eltwise.beta;
/* NOTE(review): the switch header is not visible here; presumably it
 * dispatches on the kind of entry `e`.
 * SUM accumulates e.sum.scale * dst; every eltwise kind replaces conv_res
 * with s * f(conv_res, ...) using the matching *_fwd helper. */
101 case pk::SUM: conv_res += e.sum.scale * dst; break;
102 case pk::RELU: conv_res = s*relu_fwd(conv_res, a); break;
103 case pk::TANH: conv_res = s*tanh_fwd(conv_res); break;
104 case pk::ELU: conv_res = s*elu_fwd(conv_res, a); break;
105 case pk::SQUARE: conv_res = s*square_fwd(conv_res); break;
106 case pk::ABS: conv_res = s*abs_fwd(conv_res); break;
107 case pk::SQRT: conv_res = s*sqrt_fwd(conv_res); break;
108 case pk::LINEAR: conv_res = s*linear_fwd(conv_res, a, b); break;
109 case pk::BRELU: conv_res = s*bounded_relu_fwd(conv_res, a); break;
110 case pk::SRELU: conv_res = s*soft_relu_fwd(conv_res); break;
111 case pk::LOGISTIC: conv_res = s*logistic_fwd(conv_res); break;
/* Any kind not listed above trips this assert. */
113 assert(!"unknown attr::post_ops::kind");
/* Driver: one lambda invocation per (g, mb, oc, od, oh, ow); `oc` is the
 * channel index inside the group (extent p->oc / p->g). */
118 mkldnn::impl::parallel_nd(p->g, p->mb, p->oc / p->g, p->od, p->oh, p->ow,
119 [&](int g, int mb, int oc, int od, int oh, int ow) {
120 const size_t dst_off = dst_off_f(p, mb, g, oc, od, oh, ow);
121 float &dst = ((float*)dst_m)[dst_off];
/* NOTE(review): the declaration of conv_res is not visible in this view;
 * presumably a zero-initialized float accumulator - confirm. */
124 ker(conv_res, g, mb, oc, od, oh, ow);
/* Bias is added only when the direction flags request it. */
126 if (p->dir & FLAG_BIA) {
127 const size_t bia_off = bia_off_f(p, g, oc);
128 conv_res += ((float*)bia_m)[bia_off];
/* The scale helper receives the channel index across all groups
 * (g * p->oc / p->g + oc), not the per-group one. */
131 maybe_scale(conv_res, g * p->oc / p->g + oc);
/* `dst` still holds the value found in dst_m at this offset.
 * NOTE(review): the store of conv_res into dst is not visible here. */
132 maybe_post_ops(conv_res, dst);
/* Direct (non-WINO) backward-by-data reference.
 * For every (g, mb, ic, id, ih, iw) index handed out by parallel_nd the
 * function accumulates diff_dst * wei over output channels and kernel
 * positions, optionally adds a bias, then runs the output-scale and post-op
 * helpers.
 *   p           - problem descriptor.
 *   diff_src_m  - result buffer, accessed as float; its current value is
 *                 bound by reference and fed to the post-ops.
 *   wei_m       - weights, read as float.
 *   bia_m       - bias, read as float and only when p->dir has FLAG_BIA set.
 *   diff_dst_m  - diff_dst values, read as float.
 * NOTE(review): several short lines of this function (closing braces, parts
 * of precompute_ok, the body of maybe_scale, the switch header, the
 * declaration of conv_res, the kernel selection and the final store) are not
 * visible in this view. */
139 void compute_ref_direct_bwd_d(const prb_t *p, dnn_mem_t &diff_src_m,
140 dnn_mem_t &wei_m, dnn_mem_t &bia_m, dnn_mem_t &diff_dst_m) {
/* Capacity of the per-dimension index arrays used by ker_fast. */
141 enum { precompute_size = 16 };
/* NOTE(review): `fast` looks at kh and kw only, yet ker_fast also calls
 * precompute_ok with p->kd and fills precompute_size-element arrays for the
 * depth dimension. A kd larger than precompute_size is guarded solely by the
 * assert inside precompute_ok - consider including p->kd in this check. */
142 const bool fast = MAX2(p->kh, p->kw) <= precompute_size;
144 /* pre-computes arrays of oh(ow) and kh(kw) for traversing in kernel */
/* Arguments: i - input coordinate, O - output size, K - kernel size,
 * S, P, D - the values combined as i - k * (D + 1) + P and tested against S;
 * num, _o, _k - outputs.
 * NOTE(review): the statements that fill num, _o and _k are not visible in
 * this view. */
145 auto precompute_ok = [](int i, int O, int K, int S, int P, int D,
146 int &num, int *_o, int *_k) {
147 assert(K <= precompute_size);
149 for (int k = 0; k < K; ++k) {
/* Candidate output coordinate for kernel tap k; rejected when negative or
 * not a multiple of S. */
150 int o = i - k * (D + 1) + P;
151 if (o < 0 || o % S) continue;
/* NOTE(review): a line between the two checks is not visible; by analogy
 * with the S-multiple test it presumably rescales o before this comparison
 * with O - confirm. */
153 if (o >= O) continue;
/* ker_fast: same accumulation as `ker` below, but the valid (output
 * coordinate, kernel tap) pairs are collected once per dimension by
 * precompute_ok and then traversed. */
160 auto ker_fast = [&](float &ds, int g, int mb, int ic, int id, int ih, int iw) {
161 int kd[precompute_size], od[precompute_size], num_d;
162 int kh[precompute_size], oh[precompute_size], num_h;
163 int kw[precompute_size], ow[precompute_size], num_w;
164 precompute_ok(id, p->od, p->kd, p->sd, p->pd, p->dd, num_d, od, kd);
165 precompute_ok(ih, p->oh, p->kh, p->sh, p->ph, p->dh, num_h, oh, kh);
166 precompute_ok(iw, p->ow, p->kw, p->sw, p->pw, p->dw, num_w, ow, kw);
168 for (int oc = 0; oc < p->oc/p->g; ++oc) {
169 for (int d = 0; d < num_d; ++d) {
170 for (int h = 0; h < num_h; ++h) {
171 for (int w = 0; w < num_w; ++w) {
173 size_t dst_off = dst_off_f(p, mb, g, oc, od[d], oh[h], ow[w]);
174 size_t wei_off = wei_off_f(p, g, oc, ic, kd[d], kh[h], kw[w]);
175 ds += ((float*)diff_dst_m)[dst_off]
176 * ((float*)wei_m)[wei_off];
/* ker: generic kernel with no limit on kernel size. For each kernel tap the
 * output coordinate is derived from the input coordinate and the tap is
 * skipped when that coordinate is negative, not a multiple of the s{d,h,w}
 * value, or not below the output size.
 * NOTE(review): one line after each multiple-of test is not visible here. */
183 auto ker = [&](float &ds, int g, int mb, int ic, int id, int ih, int iw) {
184 for (int oc = 0; oc < p->oc/p->g; ++oc) {
185 for (int kd = 0; kd < p->kd; ++kd) {
186 int od = id - kd * (p->dd + 1) + p->pd;
187 if (od < 0 || od % p->sd) continue;
189 if (od >= p->od) continue;
190 for (int kh = 0; kh < p->kh; ++kh) {
191 int oh = ih - kh * (p->dh + 1) + p->ph;
192 if (oh < 0 || oh % p->sh) continue;
194 if (oh >= p->oh) continue;
196 for (int kw = 0; kw < p->kw; ++kw) {
197 int ow = iw - kw * (p->dw + 1) + p->pw;
198 if (ow < 0 || ow % p->sw) continue;
200 if (ow >= p->ow) continue;
202 size_t dst_off = dst_off_f(p, mb, g, oc, od, oh, ow);
203 size_t wei_off = wei_off_f(p, g, oc, ic, kd, kh, kw);
204 ds += ((float*)diff_dst_m)[dst_off]
205 * ((float*)wei_m)[wei_off];
/* maybe_scale: acts only when the output-scale attribute is not the default
 * one. `ic` is the channel index the caller passes in.
 * NOTE(review): the statements of both policy branches are not visible. */
212 auto maybe_scale = [&](float &ds, int ic) {
213 if (!p->attr.oscale.is_def()) {
214 using policy_t = attr_t::scale_t::policy_t;
215 const auto &s = p->attr.oscale;
216 if (s.policy == policy_t::COMMON) {
224 /* Used for Deconv FWD */
/* maybe_post_ops: walks p->attr.post_ops in order and updates conv_res in
 * place; `dst` is used by the SUM entry only. */
225 auto maybe_post_ops = [&](float &conv_res, float dst) {
226 using namespace mkldnn::impl::math;
228 const auto &ops = p->attr.post_ops;
229 for (int idx = 0; idx < ops.len; ++idx) {
230 using pk = attr_t::post_ops_t::kind_t;
231 const auto &e = ops.entry[idx];
/* Shorthands for the eltwise parameters of this entry. */
233 const auto &s = e.eltwise.scale;
234 const auto &a = e.eltwise.alpha;
235 const auto &b = e.eltwise.beta;
/* NOTE(review): the switch header is not visible here; presumably it
 * dispatches on the kind of entry `e`.
 * SUM accumulates e.sum.scale * dst; every eltwise kind replaces conv_res
 * with s * f(conv_res, ...) using the matching *_fwd helper. */
238 case pk::SUM: conv_res += e.sum.scale * dst; break;
239 case pk::RELU: conv_res = s*relu_fwd(conv_res, a); break;
240 case pk::TANH: conv_res = s*tanh_fwd(conv_res); break;
241 case pk::ELU: conv_res = s*elu_fwd(conv_res, a); break;
242 case pk::SQUARE: conv_res = s*square_fwd(conv_res); break;
243 case pk::ABS: conv_res = s*abs_fwd(conv_res); break;
244 case pk::SQRT: conv_res = s*sqrt_fwd(conv_res); break;
245 case pk::LINEAR: conv_res = s*linear_fwd(conv_res, a, b); break;
246 case pk::BRELU: conv_res = s*bounded_relu_fwd(conv_res, a); break;
247 case pk::SRELU: conv_res = s*soft_relu_fwd(conv_res); break;
248 case pk::LOGISTIC: conv_res = s*logistic_fwd(conv_res); break;
/* Any kind not listed above trips this assert. */
250 assert(!"unknown attr::post_ops::kind");
/* Driver: one lambda invocation per (g, mb, ic, id, ih, iw); `ic` is the
 * channel index inside the group (extent p->ic / p->g). */
255 mkldnn::impl::parallel_nd(p->g, p->mb, p->ic / p->g, p->id, p->ih, p->iw,
256 [&](int g, int mb, int ic, int id, int ih, int iw) {
257 size_t src_off = src_off_f(p, mb, g, ic, id, ih, iw);
258 float &ds = ((float*)diff_src_m)[src_off];
/* NOTE(review): the declaration of conv_res and the lines choosing between
 * the two kernel calls are not visible; presumably `fast` selects ker_fast
 * and ker is the fallback - confirm. */
261 ker_fast(conv_res, g, mb, ic, id, ih, iw);
263 ker(conv_res, g, mb, ic, id, ih, iw);
/* Bias is indexed by the input-channel position across all groups; the
 * offset is computed inline instead of through bia_off_f. */
265 if (p->dir & FLAG_BIA) {
266 const size_t bia_off = (size_t)g * p->ic / p->g + ic;
267 conv_res += ((float*)bia_m)[bia_off];
/* The scale helper receives the same cross-group channel index. */
269 maybe_scale(conv_res, g * p->ic / p->g + ic);
/* `ds` still holds the value found in diff_src_m at this offset.
 * NOTE(review): the store of conv_res into ds is not visible here. */
270 maybe_post_ops(conv_res, ds);
/* Reference weight gradient.
 * For every weight element (g, oc, ic, kd, kh, kw) handed out by parallel_nd
 * the function accumulates diff_dst * src over the minibatch and over the
 * output positions whose matching input coordinate lies inside the input.
 *   p           - problem descriptor.
 *   src_m       - source values, read as float.
 *   diff_wei_m  - result buffer, accessed as float.
 *   diff_dst_m  - diff_dst values, read as float.
 * NOTE(review): closing braces and one statement of the driver lambda are
 * not visible in this view. */
277 void compute_ref_bwd_weights(const prb_t *p, dnn_mem_t &src_m,
278 dnn_mem_t &diff_wei_m, dnn_mem_t &diff_dst_m) {
/* compute_bounds: half-open output range [o_s, o_e) for kernel tap k.
 * With the input coordinate used by ker, i = o * S - P + k * (D + 1)
 *                                         = o * S - tmp,
 * the condition 0 <= i < I becomes tmp / S <= o < (I + tmp) / S, which is
 * then clamped to [0, O). The division is carried out in float so that
 * ceilf can round it up. */
279 auto compute_bounds = [](int I, int O, int k, int S, int P, int D,
280 int &o_s, int &o_e) {
281 const float tmp = P - k * (D + 1);
282 o_s = MAX2(0, ceilf(tmp / S));
283 o_e = MIN2(O, ceilf((I + tmp) / S));
/* ker: adds to `dw` the contribution of every (mb, od, oh, ow) inside the
 * precomputed bounds; no range check is made inside the loops, the bounds
 * are relied upon to keep id/ih/iw valid. */
286 auto ker = [&](float &dw, int g, int oc, int ic, int kd, int kh, int kw) {
287 int od_s, od_e, oh_s, oh_e, ow_s, ow_e;
288 compute_bounds(p->id, p->od, kd, p->sd, p->pd, p->dd, od_s, od_e);
289 compute_bounds(p->ih, p->oh, kh, p->sh, p->ph, p->dh, oh_s, oh_e);
290 compute_bounds(p->iw, p->ow, kw, p->sw, p->pw, p->dw, ow_s, ow_e);
292 for (int mb = 0; mb < p->mb; ++mb) {
293 for (int od = od_s; od < od_e; ++od) {
294 for (int oh = oh_s; oh < oh_e; ++oh) {
295 for (int ow = ow_s; ow < ow_e; ++ow) {
296 const int id = od * p->sd - p->pd + kd * (p->dd + 1);
297 const int ih = oh * p->sh - p->ph + kh * (p->dh + 1);
298 const int iw = ow * p->sw - p->pw + kw * (p->dw + 1);
300 size_t src_off = src_off_f(p, mb, g, ic, id, ih, iw);
301 size_t dst_off = dst_off_f(p, mb, g, oc, od, oh, ow);
302 dw += ((float*)diff_dst_m)[dst_off]
303 * ((float*)src_m)[src_off];
/* Driver: one lambda invocation per weight element; `oc` and `ic` are
 * indices inside group `g`. */
310 mkldnn::impl::parallel_nd(
311 p->g, p->oc / p->g, p->ic / p->g, p->kd, p->kh, p->kw,
312 [&](int g, int oc, int ic, int kd, int kh, int kw) {
313 size_t wei_off = wei_off_f(p, g, oc, ic, kd, kh, kw);
314 float &dw = ((float*)diff_wei_m)[wei_off];
/* NOTE(review): ker only accumulates into dw, and the line between the
 * binding above and this call is not visible; presumably dw is reset there
 * before accumulation - confirm. */
316 ker(dw, g, oc, ic, kd, kh, kw);
/* Reference bias gradient: for every (g, oc) pair handed out by parallel_nd,
 * sums diff_dst over the minibatch and all output positions and stores the
 * total at the bias offset of that pair.
 *   p           - problem descriptor.
 *   diff_bia_m  - result buffer, written as float.
 *   diff_dst_m  - diff_dst values, read as float.
 * NOTE(review): the declaration of `sum` and the closing lines are not
 * visible in this view. */
321 void compute_ref_bwd_bias(const prb_t *p, dnn_mem_t &diff_bia_m,
322 dnn_mem_t &diff_dst_m) {
/* `oc` is the channel index inside group `g` (extent p->oc / p->g). */
323 mkldnn::impl::parallel_nd(p->g, p->oc / p->g, [&](int g, int oc) {
324 size_t bia_off = bia_off_f(p, g, oc);
327 for (int mb = 0; mb < p->mb; ++mb)
328 for (int od = 0; od < p->od; ++od)
329 for (int oh = 0; oh < p->oh; ++oh)
330 for (int ow = 0; ow < p->ow; ++ow)
332 size_t dst_off = dst_off_f(p, mb, g, oc, od, oh, ow);
333 sum += ((float*)diff_dst_m)[dst_off];
/* The accumulator is converted to float on store, which suggests it is kept
 * in another (presumably wider) type - confirm against its declaration. */
335 ((float *)diff_bia_m)[bia_off] = (float)sum;
/* Direct (non-WINO) backward-by-weights reference: always computes the
 * weight gradient and, when requested, the bias gradient.
 *   p           - problem descriptor; p->dir is tested for FLAG_BIA.
 *   src_m       - forwarded to compute_ref_bwd_weights.
 *   diff_wei_m  - forwarded to compute_ref_bwd_weights.
 *   diff_bia_m  - forwarded to compute_ref_bwd_bias.
 *   diff_dst_m  - forwarded to both helpers. */
339 void compute_ref_direct_bwd_w(const prb_t *p, dnn_mem_t &src_m,
340 dnn_mem_t &diff_wei_m, dnn_mem_t &diff_bia_m, dnn_mem_t &diff_dst_m) {
341 compute_ref_bwd_weights(p, src_m, diff_wei_m, diff_dst_m);
/* Without FLAG_BIA there is nothing more to do; diff_bia_m is left alone. */
342 if (!(p->dir & FLAG_BIA)) return;
343 compute_ref_bwd_bias(p, diff_bia_m, diff_dst_m);