Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / conv / ref_conv.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "src/common/mkldnn_thread.hpp"
18 #include "src/common/math_utils.hpp"
19
20 #include "conv/conv_common.hpp"
21
22 namespace conv {
23
24 void compute_ref_fwd(const prb_t *p, dnn_mem_t &src_m, dnn_mem_t &wei_m,
25         dnn_mem_t &bia_m, dnn_mem_t &dst_m) {
26     if (p->alg == WINO && p->cfg[SRC].dt == mkldnn_f32) {
27         compute_wino_ref_fwd(p, src_m, wei_m, bia_m, dst_m);
28     } else {
29         compute_ref_direct_fwd(p, src_m, wei_m, bia_m, dst_m);
30     }
31 }
32
33 void compute_ref_bwd_d(const prb_t *p, dnn_mem_t &diff_src_m, dnn_mem_t &wei_m,
34         dnn_mem_t &bia_m, dnn_mem_t &diff_dst_m) {
35     if (p->alg == WINO && p->cfg[SRC].dt == mkldnn_f32) {
36         compute_wino_ref_bwd_d(p, diff_src_m, wei_m, bia_m, diff_dst_m);
37     } else {
38         compute_ref_direct_bwd_d(p, diff_src_m, wei_m, bia_m, diff_dst_m);
39     }
40 }
41
42 void compute_ref_bwd_w(const prb_t *p, dnn_mem_t &src_m, dnn_mem_t &diff_wei_m,
43         dnn_mem_t &diff_bia_m, dnn_mem_t &diff_dst_m) {
44     if (p->alg == WINO && p->cfg[SRC].dt == mkldnn_f32) {
45         compute_wino_ref_bwd_w(p, src_m, diff_wei_m, diff_bia_m, diff_dst_m);
46     } else {
47         compute_ref_direct_bwd_w(p, src_m, diff_wei_m, diff_bia_m, diff_dst_m);
48     }
49 }
50
51 void compute_ref_direct_fwd(const prb_t *p, dnn_mem_t &src_m,
52         dnn_mem_t &wei_m, dnn_mem_t &bia_m, dnn_mem_t &dst_m) {
53     auto ker = [&](float &d, int g, int mb, int oc, int od, int oh, int ow) {
54         for (int ic = 0; ic < p->ic/p->g; ++ic) {
55             for (int kd = 0; kd < p->kd; ++kd) {
56                 const int id = od * p->sd - p->pd + kd * (p->dd + 1);
57                 if (id < 0 || id >= p->id) continue;
58                 for (int kh = 0; kh < p->kh; ++kh) {
59                     const int ih = oh * p->sh - p->ph + kh * (p->dh + 1);
60                     if (ih < 0 || ih >= p->ih) continue;
61
62                     for (int kw = 0; kw < p->kw; ++kw) {
63                         const int iw = ow * p->sw - p->pw + kw * (p->dw + 1);
64                         if (iw < 0 || iw >= p->iw) continue;
65
66                         size_t src_off = src_off_f(p, mb, g, ic, id, ih, iw);
67                         size_t wei_off = wei_off_f(p, g, oc, ic, kd, kh, kw);
68                         d += ((float*)src_m)[src_off]
69                             * ((float*)wei_m)[wei_off];
70                     }
71                 }
72             }
73         }
74     };
75
76     auto maybe_scale = [&](float &d, int oc) {
77         if (!p->attr.oscale.is_def()) {
78             using policy_t = attr_t::scale_t::policy_t;
79             const auto &s = p->attr.oscale;
80             if (s.policy == policy_t::COMMON) {
81                 d *= s.scale;
82             } else {
83                 d *= p->scales[oc];
84             }
85         }
86     };
87
88     auto maybe_post_ops = [&](float &conv_res, float dst) {
89         using namespace mkldnn::impl::math;
90
91         const auto &ops = p->attr.post_ops;
92         for (int idx = 0; idx < ops.len; ++idx) {
93             using pk = attr_t::post_ops_t::kind_t;
94             const auto &e = ops.entry[idx];
95
96             const auto &s = e.eltwise.scale;
97             const auto &a = e.eltwise.alpha;
98             const auto &b = e.eltwise.beta;
99
100             switch (e.kind) {
101             case pk::SUM: conv_res += e.sum.scale * dst; break;
102             case pk::RELU: conv_res = s*relu_fwd(conv_res, a); break;
103             case pk::TANH: conv_res = s*tanh_fwd(conv_res); break;
104             case pk::ELU: conv_res = s*elu_fwd(conv_res, a); break;
105             case pk::SQUARE: conv_res = s*square_fwd(conv_res); break;
106             case pk::ABS: conv_res = s*abs_fwd(conv_res); break;
107             case pk::SQRT: conv_res = s*sqrt_fwd(conv_res); break;
108             case pk::LINEAR: conv_res = s*linear_fwd(conv_res, a, b); break;
109             case pk::BRELU: conv_res = s*bounded_relu_fwd(conv_res, a); break;
110             case pk::SRELU: conv_res = s*soft_relu_fwd(conv_res); break;
111             case pk::LOGISTIC: conv_res = s*logistic_fwd(conv_res); break;
112             default:
113                 assert(!"unknown attr::post_ops::kind");
114             }
115         }
116     };
117
118     mkldnn::impl::parallel_nd(p->g, p->mb, p->oc / p->g, p->od, p->oh, p->ow,
119         [&](int g, int mb, int oc, int od, int oh, int ow) {
120             const size_t dst_off = dst_off_f(p, mb, g, oc, od, oh, ow);
121             float &dst = ((float*)dst_m)[dst_off];
122
123             float conv_res = 0;
124             ker(conv_res, g, mb, oc, od, oh, ow);
125
126             if (p->dir & FLAG_BIA) {
127                 const size_t bia_off = bia_off_f(p, g, oc);
128                 conv_res += ((float*)bia_m)[bia_off];
129             }
130
131             maybe_scale(conv_res, g * p->oc / p->g + oc);
132             maybe_post_ops(conv_res, dst);
133
134             dst = conv_res;
135         }
136     );
137 }
138
139 void compute_ref_direct_bwd_d(const prb_t *p, dnn_mem_t &diff_src_m,
140         dnn_mem_t &wei_m, dnn_mem_t &bia_m, dnn_mem_t &diff_dst_m) {
141     enum { precompute_size = 16 };
142     const bool fast = MAX2(p->kh, p->kw) <= precompute_size;
143
144     /* pre-computes arrays of oh(ow) and kh(kw) for traversing in kernel */
145     auto precompute_ok = [](int i, int O, int K, int S, int P, int D,
146             int &num, int *_o, int *_k) {
147         assert(K <= precompute_size);
148         num = 0;
149         for (int k = 0; k < K; ++k) {
150             int o = i - k * (D + 1) + P;
151             if (o < 0 || o % S) continue;
152             o /= S;
153             if (o >= O) continue;
154             _k[num] = k;
155             _o[num] = o;
156             ++num;
157         }
158     };
159
160     auto ker_fast = [&](float &ds, int g, int mb, int ic, int id, int ih, int iw) {
161         int kd[precompute_size], od[precompute_size], num_d;
162         int kh[precompute_size], oh[precompute_size], num_h;
163         int kw[precompute_size], ow[precompute_size], num_w;
164         precompute_ok(id, p->od, p->kd, p->sd, p->pd, p->dd, num_d, od, kd);
165         precompute_ok(ih, p->oh, p->kh, p->sh, p->ph, p->dh, num_h, oh, kh);
166         precompute_ok(iw, p->ow, p->kw, p->sw, p->pw, p->dw, num_w, ow, kw);
167
168         for (int oc = 0; oc < p->oc/p->g; ++oc) {
169             for (int d = 0; d < num_d; ++d) {
170                 for (int h = 0; h < num_h; ++h) {
171                     for (int w = 0; w < num_w; ++w) {
172
173                         size_t dst_off = dst_off_f(p, mb, g, oc, od[d], oh[h], ow[w]);
174                         size_t wei_off = wei_off_f(p, g, oc, ic, kd[d], kh[h], kw[w]);
175                         ds += ((float*)diff_dst_m)[dst_off]
176                         * ((float*)wei_m)[wei_off];
177                     }
178                 }
179             }
180         }
181     };
182
183     auto ker = [&](float &ds, int g, int mb, int ic, int id, int ih, int iw) {
184         for (int oc = 0; oc < p->oc/p->g; ++oc) {
185             for (int kd = 0; kd < p->kd; ++kd) {
186                 int od = id - kd * (p->dd + 1) + p->pd;
187                 if (od < 0 || od % p->sd) continue;
188                 od /= p->sd;
189                 if (od >= p->od) continue;
190                 for (int kh = 0; kh < p->kh; ++kh) {
191                     int oh = ih - kh * (p->dh + 1) + p->ph;
192                     if (oh < 0 || oh % p->sh) continue;
193                     oh /= p->sh;
194                     if (oh >= p->oh) continue;
195
196                     for (int kw = 0; kw < p->kw; ++kw) {
197                         int ow = iw - kw * (p->dw + 1) + p->pw;
198                         if (ow < 0 || ow % p->sw) continue;
199                         ow /= p->sw;
200                         if (ow >= p->ow) continue;
201
202                         size_t dst_off = dst_off_f(p, mb, g, oc, od, oh, ow);
203                         size_t wei_off = wei_off_f(p, g, oc, ic, kd, kh, kw);
204                         ds += ((float*)diff_dst_m)[dst_off]
205                         * ((float*)wei_m)[wei_off];
206                     }
207                 }
208             }
209         }
210     };
211
212     auto maybe_scale = [&](float &ds, int ic) {
213         if (!p->attr.oscale.is_def()) {
214             using policy_t = attr_t::scale_t::policy_t;
215             const auto &s = p->attr.oscale;
216             if (s.policy == policy_t::COMMON) {
217                 ds *= s.scale;
218             } else {
219                 ds *= p->scales[ic];
220             }
221         }
222     };
223
224     /* Used for Deconv FWD */
225     auto maybe_post_ops = [&](float &conv_res, float dst) {
226         using namespace mkldnn::impl::math;
227
228         const auto &ops = p->attr.post_ops;
229         for (int idx = 0; idx < ops.len; ++idx) {
230             using pk = attr_t::post_ops_t::kind_t;
231             const auto &e = ops.entry[idx];
232
233             const auto &s = e.eltwise.scale;
234             const auto &a = e.eltwise.alpha;
235             const auto &b = e.eltwise.beta;
236
237             switch (e.kind) {
238             case pk::SUM: conv_res += e.sum.scale * dst; break;
239             case pk::RELU: conv_res = s*relu_fwd(conv_res, a); break;
240             case pk::TANH: conv_res = s*tanh_fwd(conv_res); break;
241             case pk::ELU: conv_res = s*elu_fwd(conv_res, a); break;
242             case pk::SQUARE: conv_res = s*square_fwd(conv_res); break;
243             case pk::ABS: conv_res = s*abs_fwd(conv_res); break;
244             case pk::SQRT: conv_res = s*sqrt_fwd(conv_res); break;
245             case pk::LINEAR: conv_res = s*linear_fwd(conv_res, a, b); break;
246             case pk::BRELU: conv_res = s*bounded_relu_fwd(conv_res, a); break;
247             case pk::SRELU: conv_res = s*soft_relu_fwd(conv_res); break;
248             case pk::LOGISTIC: conv_res = s*logistic_fwd(conv_res); break;
249             default:
250                 assert(!"unknown attr::post_ops::kind");
251             }
252         }
253     };
254
255     mkldnn::impl::parallel_nd(p->g, p->mb, p->ic / p->g, p->id, p->ih, p->iw,
256         [&](int g, int mb, int ic, int id, int ih, int iw) {
257             size_t src_off = src_off_f(p, mb, g, ic, id, ih, iw);
258             float &ds = ((float*)diff_src_m)[src_off];
259             float conv_res = 0;
260             if (fast)
261                 ker_fast(conv_res, g, mb, ic, id, ih, iw);
262             else
263                 ker(conv_res, g, mb, ic, id, ih, iw);
264
265             if (p->dir & FLAG_BIA) {
266                 const size_t bia_off = (size_t)g * p->ic / p->g + ic;
267                 conv_res += ((float*)bia_m)[bia_off];
268             }
269             maybe_scale(conv_res, g * p->ic / p->g + ic);
270             maybe_post_ops(conv_res, ds);
271
272             ds = conv_res;
273         }
274     );
275 }
276
277 void compute_ref_bwd_weights(const prb_t *p, dnn_mem_t &src_m,
278         dnn_mem_t &diff_wei_m, dnn_mem_t &diff_dst_m) {
279     auto compute_bounds = [](int I, int O, int k, int S, int P, int D,
280             int &o_s, int &o_e) {
281         const float tmp = P - k * (D + 1);
282         o_s = MAX2(0, ceilf(tmp / S));
283         o_e = MIN2(O, ceilf((I + tmp) / S));
284     };
285
286     auto ker = [&](float &dw, int g, int oc, int ic, int kd, int kh, int kw) {
287         int od_s, od_e, oh_s, oh_e, ow_s, ow_e;
288         compute_bounds(p->id, p->od, kd, p->sd, p->pd, p->dd, od_s, od_e);
289         compute_bounds(p->ih, p->oh, kh, p->sh, p->ph, p->dh, oh_s, oh_e);
290         compute_bounds(p->iw, p->ow, kw, p->sw, p->pw, p->dw, ow_s, ow_e);
291
292         for (int mb = 0; mb < p->mb; ++mb) {
293             for (int od = od_s; od < od_e; ++od) {
294             for (int oh = oh_s; oh < oh_e; ++oh) {
295             for (int ow = ow_s; ow < ow_e; ++ow) {
296                 const int id = od * p->sd - p->pd + kd * (p->dd + 1);
297                 const int ih = oh * p->sh - p->ph + kh * (p->dh + 1);
298                 const int iw = ow * p->sw - p->pw + kw * (p->dw + 1);
299
300                 size_t src_off = src_off_f(p, mb, g, ic, id, ih, iw);
301                 size_t dst_off = dst_off_f(p, mb, g, oc, od, oh, ow);
302                 dw += ((float*)diff_dst_m)[dst_off]
303                     * ((float*)src_m)[src_off];
304             }
305             }
306             }
307         }
308     };
309
310     mkldnn::impl::parallel_nd(
311         p->g, p->oc / p->g, p->ic / p->g, p->kd, p->kh, p->kw,
312         [&](int g, int oc, int ic, int kd, int kh, int kw) {
313                 size_t wei_off = wei_off_f(p, g, oc, ic, kd, kh, kw);
314                 float &dw = ((float*)diff_wei_m)[wei_off];
315                 dw = 0;
316                 ker(dw, g, oc, ic, kd, kh, kw);
317         }
318     );
319 }
320
321 void compute_ref_bwd_bias(const prb_t *p, dnn_mem_t &diff_bia_m,
322     dnn_mem_t &diff_dst_m) {
323     mkldnn::impl::parallel_nd(p->g, p->oc / p->g, [&](int g, int oc) {
324        size_t bia_off = bia_off_f(p, g, oc);
325        double sum = 0;
326
327        for (int mb = 0; mb < p->mb; ++mb)
328        for (int od = 0; od < p->od; ++od)
329        for (int oh = 0; oh < p->oh; ++oh)
330        for (int ow = 0; ow < p->ow; ++ow)
331        {
332            size_t dst_off = dst_off_f(p, mb, g, oc, od, oh, ow);
333            sum += ((float*)diff_dst_m)[dst_off];
334        }
335        ((float *)diff_bia_m)[bia_off] = (float)sum;
336     });
337 }
338
339 void compute_ref_direct_bwd_w(const prb_t *p, dnn_mem_t &src_m,
340         dnn_mem_t &diff_wei_m, dnn_mem_t &diff_bia_m, dnn_mem_t &diff_dst_m) {
341     compute_ref_bwd_weights(p, src_m, diff_wei_m, diff_dst_m);
342     if (!(p->dir & FLAG_BIA)) return;
343     compute_ref_bwd_bias(p, diff_bia_m, diff_dst_m);
344 }
345
346 }