Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / ip / ip.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <stdlib.h>
18 #include <stdio.h>
19 #include <float.h>
20 #include <math.h>
21
22 #include "mkldnn.h"
23
24 #include "src/common/mkldnn_thread.hpp"
25
26 #include "mkldnn_common.hpp"
27 #include "mkldnn_memory.hpp"
28
29 #include "ip/ip.hpp"
30
31 namespace ip {
32 inline bool is_3d(const prb_t *p) {
33     return p->id > 1;
34 }
35
36 inline int init_pd(const prb_t *p, mkldnn_inner_product_desc_t &ipd,
37         mkldnn_primitive_desc_t &ippd, res_t *r) {
38     mkldnn_memory_desc_t src_d, wei_d, bia_d, dst_d;
39
40     int ndims = is_3d(p) ? 5 : 4;
41     mkldnn_dims_t src_dims = {p->mb, p->ic, p->ih, p->iw};
42     mkldnn_dims_t src_3d_dims = {p->mb, p->ic, p->id, p->ih, p->iw};
43     mkldnn_dims_t wei_dims = {p->oc, p->ic, p->ih, p->iw};
44     mkldnn_dims_t wei_3d_dims = {p->oc, p->ic, p->id, p->ih, p->iw};
45     mkldnn_dims_t bia_dims = {p->oc};
46     mkldnn_dims_t dst_dims = {p->mb, p->oc};
47
48     DNN_SAFE(mkldnn_memory_desc_init(&src_d, ndims, is_3d(p) ? src_3d_dims : src_dims,
49             p->cfg[SRC].dt, mkldnn_any), WARN);
50     DNN_SAFE(mkldnn_memory_desc_init(&wei_d, ndims, is_3d(p) ? wei_3d_dims : wei_dims,
51             p->cfg[WEI].dt, mkldnn_any), WARN);
52     DNN_SAFE(mkldnn_memory_desc_init(&bia_d, 1, bia_dims, p->cfg[BIA].dt, mkldnn_any), WARN);
53     DNN_SAFE(mkldnn_memory_desc_init(&dst_d, 2, dst_dims, p->cfg[DST].dt, mkldnn_any), WARN);
54
55     switch (p->dir) {
56     case FWD_D: case FWD_B:
57         DNN_SAFE(mkldnn_inner_product_forward_desc_init(&ipd, mkldnn_forward,
58                     &src_d, &wei_d, p->dir == FWD_D ? NULL : &bia_d, &dst_d),
59                 WARN);
60         break;
61     case BWD_D:
62         DNN_SAFE(mkldnn_inner_product_backward_data_desc_init(&ipd, &src_d,
63                     &wei_d, &dst_d), WARN);
64         break;
65     case BWD_W: case BWD_WB:
66         DNN_SAFE(mkldnn_inner_product_backward_weights_desc_init(&ipd, &src_d,
67                     &wei_d, p->dir == BWD_W ? NULL : &bia_d, &dst_d), WARN);
68         break;
69     default: DNN_SAFE(mkldnn_invalid_arguments, CRIT);
70     }
71
72     DNN_SAFE(ipd.accum_data_type == p->cfg[ACC].dt
73             ? mkldnn_success : mkldnn_unimplemented, CRIT);
74
75     auto mkldnn_attr = create_mkldnn_attr(p->attr, p->oc, p->scales);
76
77     mkldnn_status_t init_status = mkldnn_success;
78     init_status = mkldnn_primitive_desc_create_v2(&ippd, &ipd, mkldnn_attr,
79             engine, NULL);
80
81     mkldnn_primitive_attr_destroy(mkldnn_attr);
82
83     if (init_status == mkldnn_unimplemented)
84         return r->state = UNIMPLEMENTED, OK;
85     else
86         SAFE(init_status, WARN);
87
88     auto q = [=](mkldnn_query_t query, int index = 0) {
89         return *mkldnn_primitive_desc_query_memory_d(
90                 mkldnn_primitive_desc_query_pd(ippd, query, index));
91     };
92
93     if (p->dir == BWD_D)
94         ipd.diff_src_desc = q(mkldnn_query_diff_src_pd);
95     else
96         ipd.src_desc = q(mkldnn_query_src_pd);
97
98     if (p->dir & FLAG_WEI)
99         ipd.diff_weights_desc = q(mkldnn_query_diff_weights_pd);
100     else
101         ipd.weights_desc = q(mkldnn_query_weights_pd);
102
103     if (p->dir & FLAG_BIA) {
104         if (p->dir & FLAG_BWD)
105             ipd.diff_bias_desc = q(mkldnn_query_diff_weights_pd, 1);
106         else
107             ipd.bias_desc = q(mkldnn_query_weights_pd, 1);
108     }
109
110     if (p->dir & FLAG_BWD)
111         ipd.diff_dst_desc = q(mkldnn_query_diff_dst_pd);
112     else
113         ipd.dst_desc = q(mkldnn_query_dst_pd);
114
115     return OK;
116 }
117
118 inline int compare_dat(const prb_t *p, data_kind_t kind, dnn_mem_t &mem_dt,
119         dnn_mem_t &mem_fp, res_t *r) {
120     size_t nelems = mem_dt.nelems();
121     int non_zero = 0;
122     const char *skind = data_kind2str(kind);
123
124     r->errors = 0;
125     r->total = nelems;
126
127     for (size_t i = 0; i < nelems; ++i) {
128         float dt = ((float*)mem_dt)[i];
129         float fp0 = ((float *)mem_fp)[i];
130
131         float fp = fp0;
132         if (p->cfg[kind].dt != mkldnn_f32) {
133             using R = attr_t::round_mode_t;
134             switch (p->attr.irmode) {
135                 case R::DOWN: fp = floorf(fp0); break;
136                 case R::NEAREST: fp = nearbyintf(fp0); break;
137                 default:
138                     return UNTESTED;
139             }
140         }
141
142         float diff = fabsf(fp - dt);
143         float rel_diff = diff / (fabsf(fp) > FLT_MIN ? fabsf(fp) : 1);
144
145         bool ok = true;
146         if (fp < p->cfg[kind].min)
147             ok = dt == p->cfg[kind].min;
148         else if (fp > p->cfg[kind].max)
149             ok = dt == p->cfg[kind].max;
150         else
151             ok = (fabs(fp) > 1e-5 ? rel_diff : diff) <= p->cfg[kind].eps;
152
153         if (!ok) {
154             r->errors++;
155             if (r->errors < 10 || verbose >= 10) {
156                 print(0, "[%4lu][%s]"
157                          "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
158                         (unsigned long)i, skind, fp, fp0, dt, diff, rel_diff);
159             }
160         }
161         non_zero += fp != 0;
162     }
163
164     const double trust_nz = (double)non_zero / r->total;
165     bool no_trust = trust_nz < 0.1;
166     if (no_trust) {
167         r->state = MISTRUSTED;
168         const char *skind = data_kind2str(kind);
169         print(0, "@@@ [%s] test-bug: trust is too low."
170                  " Nonzeros in output: %.2f\n",
171                 skind, trust_nz);
172     }
173
174     if (r->errors)
175         r->state = FAILED;
176
177     if (r->state == UNTESTED)
178         r->state = PASSED; /* optimism */
179
180     return r->state == FAILED ? FAIL : OK;
181 }
182
183 int fill_src(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *r) {
184     dnn_mem_t mem_00(
185             mem_dt.md_, mkldnn_f32, is_3d(p) ? mkldnn_ncdhw : mkldnn_nchw);
186
187     const auto &c = p->cfg[SRC];
188     const int range = c.f_max - c.f_min + 1;
189
190     mkldnn::impl::parallel_nd(p->mb, p->ic, p->id, p->ih, p->iw,
191         [&](int mb, int ic, int id, int ih, int iw) {
192             const int gen
193                 = 5 * id + 17 * ih + 13 * iw + 13 * mb + 19 * ic + 1637;
194             const bool non_base = flip_coin(gen, c.f_sparsity);
195             const float value = non_base
196                 ?  c.f_min + gen * c.f_step % range : c.f_base;
197
198             ((float *)mem_00)[src_off_f(p, mb, ic, id, ih, iw)] = value;
199         }
200     );
201
202     SAFE(mem_dt.reorder(mem_00), WARN);
203     SAFE(mem_fp.reorder(mem_dt), WARN);
204     return OK;
205 }
206
207 int fill_wei(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *r) {
208     dnn_mem_t mem_00(
209             mem_dt.md_, mkldnn_f32, is_3d(p) ? mkldnn_goihw : mkldnn_oihw);
210
211     const auto &c = p->cfg[WEI];
212     const int range = c.f_max - c.f_min + 1;
213
214     mkldnn::impl::parallel_nd(p->oc, p->ic, p->id, p->ih, p->iw,
215         [&](int oc, int ic, int id, int ih, int iw) {
216             const int gen = 5 * id + 17 * ih + 13 * iw + 13 * oc + 19 * ic + 38;
217             const bool non_base = flip_coin(gen, c.f_sparsity);
218             const float value = non_base
219                     ?  c.f_min + gen * c.f_step % range : c.f_base;
220
221             ((float *)mem_00)[wei_off_f(p, oc, ic, id, ih, iw)] = value;
222         }
223     );
224
225     SAFE(mem_dt.reorder(mem_00), WARN);
226     SAFE(mem_fp.reorder(mem_dt), WARN);
227     return OK;
228 }
229
230 int fill_bia(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *r) {
231     dnn_mem_t mem_00(mem_dt.md_, mkldnn_f32, mkldnn_x);
232
233     const auto &c = p->cfg[BIA];
234     const int range = c.f_max - c.f_min + 1;
235
236     const size_t sz = mem_00.nelems();
237     for (size_t i = 0; i < sz; ++i) {
238         const int gen = (int)(19 * i);
239         const bool non_base = flip_coin(gen, c.f_sparsity);
240         const float value = non_base
241                 ? c.f_min + gen * c.f_step % range : c.f_base;
242
243         ((float *)mem_00)[i] = value;
244     }
245
246     SAFE(mem_dt.reorder(mem_00), WARN);
247     SAFE(mem_fp.reorder(mem_dt), WARN);
248     return OK;
249 }
250
251 int fill_dst(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *r) {
252     dnn_mem_t mem_00(mem_dt.md_, mkldnn_f32, mkldnn_nc);
253
254     const auto &c = p->cfg[DST];
255     const int range = c.f_max - c.f_min + 1;
256
257     mkldnn::impl::parallel_nd(p->mb, p->oc, [&](int mb, int oc) {
258         const int gen = 17 * mb + 13 * oc + 12;
259         const bool non_base = flip_coin(gen, c.f_sparsity);
260         const float value = non_base
261                 ? c.f_min + gen * c.f_step % range : c.f_base;
262
263         ((float *)mem_00)[dst_off_f(p, mb, oc)] = value;
264     });
265
266     SAFE(mem_dt.reorder(mem_00), WARN);
267     SAFE(mem_fp.reorder(mem_dt), WARN);
268
269     return OK;
270 }
271
272 int doit(const prb_t *p, res_t *r) {
273     mkldnn_inner_product_desc_t ipd;
274     mkldnn_primitive_desc_t ippd;
275     mkldnn_primitive_t ip;
276
277     SAFE(init_pd(p, ipd, ippd, r), WARN);
278     if (r->state == SKIPPED || r->state == UNIMPLEMENTED)
279         return OK;
280
281     auto &src_dt_d = p->dir == BWD_D ? ipd.diff_src_desc : ipd.src_desc;
282     auto &wei_dt_d = p->dir & FLAG_WEI ? ipd.diff_weights_desc : ipd.weights_desc;
283     auto &bia_dt_d = p->dir & FLAG_BWD ? ipd.diff_bias_desc : ipd.bias_desc;
284     auto &dst_dt_d = p->dir & FLAG_BWD ? ipd.diff_dst_desc: ipd.dst_desc;
285
286     const auto fp = mkldnn_f32;
287     dnn_mem_t src_dt(src_dt_d, p->cfg[SRC].dt);
288     dnn_mem_t wei_dt(wei_dt_d, p->cfg[WEI].dt);
289     dnn_mem_t dst_dt(dst_dt_d, p->cfg[DST].dt);
290     dnn_mem_t bia_dt = p->dir & FLAG_BIA
291         ? dnn_mem_t(bia_dt_d, p->cfg[BIA].dt) : dnn_mem_t();
292
293     auto src_format = is_3d(p) ? mkldnn_ncdhw : mkldnn_nchw;
294     auto wei_format = is_3d(p) ? mkldnn_oidhw : mkldnn_oihw;
295     dnn_mem_t src_fp(src_dt_d, fp, src_format);
296     dnn_mem_t wei_fp(wei_dt_d, fp, wei_format);
297     dnn_mem_t dst_fp(dst_dt_d, fp, mkldnn_nc);
298     dnn_mem_t bia_fp = p->dir & FLAG_BIA
299         ? dnn_mem_t(bia_dt_d, fp, mkldnn_x) : dnn_mem_t();
300
301     SAFE(fill_src(p, src_dt, src_fp, r), WARN);
302     SAFE(fill_wei(p, wei_dt, wei_fp, r), WARN);
303     SAFE(fill_dst(p, dst_dt, dst_fp, r), WARN);
304     if (p->dir & FLAG_BIA)
305         SAFE(fill_bia(p, bia_dt, bia_fp, r), WARN);
306
307     if (p->dir & FLAG_FWD) {
308         mkldnn_primitive_at_t inputs[3] = { {src_dt.p_, 0}, {wei_dt.p_, 0},
309             {p->dir & FLAG_BIA ? bia_dt.p_ : NULL, 0}
310         };
311         const_mkldnn_primitive_t outputs[] = { dst_dt.p_ };
312         DNN_SAFE(mkldnn_primitive_create(&ip, ippd, inputs, outputs), WARN);
313         SAFE(execute(ip), WARN);
314         if (bench_mode & CORR) {
315             compute_ref_fwd(p, src_fp, wei_fp, bia_fp, dst_fp);
316             dnn_mem_t dst(dst_dt, fp, mkldnn_nc);
317             SAFE(compare_dat(p, DST, dst, dst_fp, r), WARN);
318         }
319     } else if (p->dir == BWD_D) {
320         mkldnn_primitive_at_t inputs[3] = { {dst_dt.p_, 0}, {wei_dt.p_, 0}, };
321         const_mkldnn_primitive_t outputs[] = { src_dt.p_ };
322         DNN_SAFE(mkldnn_primitive_create(&ip, ippd, inputs, outputs), WARN);
323         SAFE(execute(ip), WARN);
324         if (bench_mode & CORR) {
325             compute_ref_bwd_d(p, src_fp, wei_fp, dst_fp);
326             dnn_mem_t src(src_dt, fp, src_format);
327             SAFE(compare_dat(p, SRC, src, src_fp, r), WARN);
328         }
329     } else if (p->dir & FLAG_BWD && p->dir & FLAG_WEI) {
330         mkldnn_primitive_at_t inputs[3] = { {src_dt.p_, 0}, {dst_dt.p_, 0}, };
331         const_mkldnn_primitive_t outputs[] = { wei_dt.p_,
332             p->dir & FLAG_BIA ? bia_dt.p_ : NULL,
333         };
334         DNN_SAFE(mkldnn_primitive_create(&ip, ippd, inputs, outputs), WARN);
335         SAFE(execute(ip), WARN);
336         if (bench_mode & CORR) {
337             compute_ref_bwd_w(p, src_fp, wei_fp, bia_fp, dst_fp);
338             dnn_mem_t wei(wei_dt, fp, wei_format);
339             if (compare_dat(p, WEI, wei, wei_fp, r) != OK) return FAIL;
340             if (p->dir & FLAG_BIA) {
341                 dnn_mem_t bia(bia_dt, fp, mkldnn_x);
342                 SAFE(compare_dat(p, BIA, bia, bia_fp, r), WARN);
343             }
344         }
345     }
346
347     if (bench_mode & PERF) {
348         auto &t = r->timer;
349         t.reset();
350         while (true) {
351             SAFE(execute(ip), WARN);
352             t.stamp();
353             const bool stop = false
354                 || (fix_times_per_prb && t.times() >= fix_times_per_prb)
355                 || (!fix_times_per_prb
356                         && t.total_ms() >= max_ms_per_prb
357                         && t.times() >= min_times_per_prb);
358             if (stop) break;
359         }
360     }
361
362     DNN_SAFE(mkldnn_primitive_desc_destroy(ippd), CRIT);
363     DNN_SAFE(mkldnn_primitive_destroy(ip), CRIT);
364
365     return OK;
366 }
367
368 }