Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / conv / conv.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <stdlib.h>
18 #include <stdio.h>
19 #include <float.h>
20 #include <math.h>
21
22 #include "mkldnn.h"
23
24 #include "src/common/mkldnn_thread.hpp"
25
26 #include "mkldnn_common.hpp"
27 #include "mkldnn_memory.hpp"
28 #include "norm.hpp"
29
30 #include "conv/conv_common.hpp"
31
32 namespace conv {
33
34 inline bool is_conv_3d(const prb_t *p) {
35     return p->id > 1;
36 }
37
38 inline bool is_conv_1d(const prb_t *p) {
39     return !is_conv_3d(p) && p->ih == 1 && p->kh == 1
40                    && p->cfg[SRC].dt != mkldnn_s8 // temporary workaround until
41                    && p->cfg[SRC].dt != mkldnn_u8; // int8 jit supports 1d
42 }
43
44 double get_trust_nz_level(const prb_t *p, data_kind_t kind,
45         bool final_compare) {
46     if (!final_compare)
47         return p->cfg[kind].f_sparsity;
48
49     auto negative_to_zero = [&]() {
50         using pk = attr_t::post_ops_t::kind_t;
51         const auto &po = p->attr.post_ops;
52         int count = 0;
53         for (int i = 0; i < po.len; ++i) {
54             auto k = po.entry[i].kind;
55             count +=
56                 k == pk::RELU || k == pk::ELU || k == pk::SQRT || k == pk::BRELU;
57         }
58         return !!count;
59     };
60
61     double trust = 0.3; /* why? */
62     switch (kind) {
63         case SRC:
64             trust /= p->sd * p->sh * p->sw;
65             break;
66         case WEI:
67             trust /= 1. * p->kd * p->kh * p->kw
68                 / MIN3(p->kd * p->kh * p->kw, p->id * p->ih * p->iw
69                 , p->od * p->oh * p->ow);
70             break;
71         case BIA:
72             trust = 0.8 * p->cfg[DST].f_sparsity; /* why? */
73             break;
74         case DST:
75             trust /= negative_to_zero() == 0 ? 1 : 2;
76             break;
77     }
78
79     return trust;
80 }
81
82 inline bool post_ops_require_integral_check(const prb_t *p) {
83     if (p->attr.post_ops.len == 0) return false;
84
85     using pk = attr_t::post_ops_t::kind_t;
86     const auto &ops = p->attr.post_ops;
87
88     // assumptions: at most 1 eltwise, scale = 1.
89     for (int idx = 0; idx < ops.len; ++idx) {
90         const auto &e = ops.entry[idx];
91         if (e.kind == pk::SUM || e.kind == pk::ABS) continue;
92         if (e.kind == pk::RELU && e.eltwise.alpha == 0.f) continue;
93         return true;
94     }
95
96     return false;
97 }
98
99 inline double get_eps(const prb_t *p, const data_kind_t kind) {
100     // Winograd specifics
101     if (p->alg & WINO && p->dir & FLAG_WEI) {
102         /*This is an empirical equation derived by observing growth error
103           with increasing 'k' dimension in gemm of winograd*/
104         return p->cfg[kind].eps *
105             (MAX2(1, pow(10, 0.4 * log10(0.125 * p->mb * p->oh * p->ow))));
106     }
107
108     // post-ops specifics
109     if (post_ops_require_integral_check(p))
110         return MAX2(1e-5, p->cfg[kind].eps);
111
112     return p->cfg[kind].eps;
113 }
114
115 inline void get_result(const prb_t *p, const data_kind_t kind, res_t *r,
116         const diff_norm_t diff_norm) {
117     const float eps = get_eps(p, kind);
118
119     /* Ignoring element-wise errors for Winograd and in some cases of post-ops,
120      * since large relative error in few elements (which are anyways close
121      * to zero) results in false positive failures */
122
123     bool wino_test = (p->alg & WINO) && diff_norm.rel_diff(norm_t::L2) <= eps;
124     if (wino_test) r->errors = 0;
125
126     bool post_ops_test = post_ops_require_integral_check(p)
127         && diff_norm.rel_diff(norm_t::L2) <= eps;
128     if (post_ops_test) r->errors = 0;
129
130     if (r->errors) r->state = FAILED;
131 }
132
133 inline int compare_dat(const prb_t *p, data_kind_t kind, dnn_mem_t &mem_dt,
134         dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
135     const bool dont_complain = false
136         || (p->alg & WINO)
137         || post_ops_require_integral_check(p);
138
139     size_t nelems = mem_dt.nelems();
140
141     const char *skind = data_kind2str(kind);
142
143     int in = 0, below = 0, above = 0;
144     int in_ok = 0, below_ok = 0, above_ok = 0;
145     int non_zero = 0;
146
147     diff_norm_t diff_norm;
148
149     r->errors = 0;
150     r->total = nelems;
151
152     for (size_t i = 0; i < nelems; ++i) {
153         const float dt = ((float*)mem_dt)[i];
154         const float fp0 = ((float*)mem_fp)[i];
155
156         float fp = fp0;
157         if (p->cfg[kind].dt != mkldnn_f32) {
158             using R = attr_t::round_mode_t;
159             switch (p->attr.irmode) {
160                 case R::DOWN: fp = floorf(fp0); break;
161                 case R::NEAREST: fp = nearbyintf(fp0); break;
162                 default:
163                     return UNTESTED;
164             }
165         }
166
167         const float diff = fabsf(fp - dt);
168         const float rel_diff = diff / (fabsf(fp) > FLT_MIN ? fabsf(fp) : 1);
169
170         bool ok = true;
171         if (fp < p->cfg[kind].min) {
172             diff_norm.update(p->cfg[kind].min, dt);
173             ok = dt == p->cfg[kind].min;
174             below += 1;
175             below_ok += ok;
176         } else if (fp > p->cfg[kind].max) {
177             diff_norm.update(p->cfg[kind].max, dt);
178             ok = dt == p->cfg[kind].max;
179             above += 1;
180             above_ok += ok;
181         } else {
182             diff_norm.update(fp, dt);
183             ok = (fabs(fp) > 1e-5 ? rel_diff : diff) <= get_eps(p, kind);
184             in += 1;
185             in_ok += ok;
186         }
187         if (!ok) {
188             r->errors++;
189             if ((!dont_complain && r->errors < 10) || verbose >=10) {
190                 int mb_or_g = 0, g_or_oc = 0, c = 0, d = 0, h = 0, w = 0;
191                 switch (kind) {
192                 case SRC: inv_src_off_f(p, i, mb_or_g, g_or_oc, c, d, h, w); break;
193                 case WEI: inv_wei_off_f(p, i, mb_or_g, g_or_oc, c, d, h, w); break;
194                 case BIA: inv_bia_off_f(p, i, mb_or_g, g_or_oc); break;
195                 case DST: inv_dst_off_f(p, i, mb_or_g, g_or_oc, c, d, h, w); break;
196                 }
197                 print(0, "[%4lu][%s%s][%d,%d,%d,%d,%d,%d] "
198                         "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
199                         (unsigned long)i,
200                         final_compare == false ? "REORDER " : "",
201                         skind, mb_or_g, g_or_oc, c, d, h, w,
202                         fp, fp0, dt, diff, rel_diff);
203             }
204         }
205
206         /* for debug purposes only: dump the output */
207         if (final_compare && verbose >= 50 && i < 30) {
208             int mb_or_g = 0, g_or_oc = 0, c = 0, d = 0, h = 0, w = 0;
209             switch (kind) {
210             case SRC: inv_src_off_f(p, i, mb_or_g, g_or_oc, c, d, h, w); break;
211             case WEI: inv_wei_off_f(p, i, mb_or_g, g_or_oc, c, d, h, w); break;
212             case BIA: inv_bia_off_f(p, i, mb_or_g, g_or_oc); break;
213             case DST: inv_dst_off_f(p, i, mb_or_g, g_or_oc, c, d, h, w); break;
214             }
215
216             print(0, "[%4lu][%s][%d,%d,%d,%d,%d,%d] fp:%8g fp0:%8g dt:%8g\n",
217                     (unsigned long)i,
218                     skind, mb_or_g, g_or_oc, c, d, h, w, fp, fp0, dt);
219         }
220
221         non_zero += fp != 0;
222     }
223
224     diff_norm.done();
225     get_result(p, kind, r, diff_norm);
226
227     if (final_compare || r->errors) {
228         const int vl = r->errors ? 0 : 2;
229         print(vl, "@@@ [%s] %sdiff: err:%d, l0(``%g``) "
230                 "l1:(%g,%g,%g,``%g``) "
231                 "l2:(%g,%g,%g,``%g``) "
232                 "l8:(%g,%g,%g,``%g``)\n",
233                 skind, final_compare ? "final: " : "", (int)r->errors,
234                 diff_norm.rel_diff(norm_t::L0),
235                 diff_norm.a_[norm_t::L1], diff_norm.b_[norm_t::L1],
236                 diff_norm.diff_[norm_t::L1], diff_norm.rel_diff(norm_t::L1),
237                 diff_norm.a_[norm_t::L2], diff_norm.b_[norm_t::L2],
238                 diff_norm.diff_[norm_t::L2], diff_norm.rel_diff(norm_t::L2),
239                 diff_norm.a_[norm_t::L8], diff_norm.b_[norm_t::L8],
240                 diff_norm.diff_[norm_t::L8], diff_norm.rel_diff(norm_t::L8));
241     }
242
243     const double trust_rg_level = 0.3;
244     const double trust_nz_level = get_trust_nz_level(p, kind, final_compare);
245
246     const double trust_rg = (double)in / r->total;
247     const double trust_nz = (double)non_zero / r->total;
248
249     const bool no_trust = true /* ...in the test ...at all */
250         && final_compare
251         && (trust_rg < trust_rg_level || trust_nz < trust_nz_level);
252
253     const bool dump = verbose >= 20
254         || (verbose >= 10 && (trust_rg < 1. || trust_nz < 1.));
255     if (dump) {
256         print(0, "@@@ [%s] %strust range:%.2f nz:%.2f "
257                 "(level range:%.2f nz:%.2f). "
258                 "in:%d (ok:%d) below:%d (ok:%d) above:%d (ok:%d) nz:%d "
259                 "total:%lu\n", skind, final_compare ? "final: " : "",
260                 trust_rg, trust_nz, trust_rg_level, trust_nz_level, in, in_ok,
261                 below, below_ok, above, above_ok, non_zero,
262                 (unsigned long)r->total);
263     }
264
265     if (no_trust) {
266         r->state = MISTRUSTED;
267         print(0, "@@@ [%s] test-bug: trust is too low. "
268                 "range:%.2f (?<%.2f) nz:%.2f (?<%.2f) (nz: %d total: %lu)\n",
269                 skind, trust_rg, trust_rg_level, trust_nz, trust_nz_level,
270                 non_zero, (unsigned long)r->total);
271     }
272
273     if (final_compare && r->state == UNTESTED)
274         r->state = PASSED; /* optimism */
275
276     return r->state == FAILED ? FAIL : OK;
277 }
278
279 int compare_src(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
280         res_t *r, bool final_compare)
281 { return compare_dat(p, SRC, mem_dt, mem_fp, r, final_compare); }
282 int compare_wei(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
283         res_t *r, bool final_compare)
284 { return compare_dat(p, WEI, mem_dt, mem_fp, r, final_compare); }
285 int compare_bia(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
286         res_t *r, bool final_compare)
287 { return compare_dat(p, BIA, mem_dt, mem_fp, r, final_compare); }
288 int compare_dst(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
289         res_t *r, bool final_compare)
290 { return compare_dat(p, DST, mem_dt, mem_fp, r, final_compare); }
291
292 int fill_src(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
293         res_t *r) {
294     const bool extra_mem = mem_dt.dt() != mem_fp.dt();
295     dnn_mem_t *p_mem_00 = extra_mem
296         ? new dnn_mem_t(mem_dt.md_, mkldnn_f32,
297             get_default_format(mem_dt.md_.ndims, DATA))
298         : &mem_fp;
299     dnn_mem_t &mem_00 = *p_mem_00;
300
301     const auto &c = p->cfg[SRC];
302     const int range = c.f_max - c.f_min + 1;
303
304     mkldnn::impl::parallel_nd(p->mb, p->ic, p->id, p->ih, p->iw,
305         [&](int mb, int ic, int id, int ih, int iw) {
306         const int gen = 5 * id + 17 * ih + 13 * iw + 13 * mb + 19 * ic + 1637;
307         const bool non_base = flip_coin(gen, c.f_sparsity);
308         const float value =
309             non_base ? c.f_min + gen * c.f_step % range : c.f_base;
310
311         ((float*)mem_00)[src_off_f(p, mb, 0, ic, id, ih, iw)] = value;
312     });
313
314     SAFE(mem_dt.reorder(mem_00), WARN);
315     if (extra_mem) {
316         SAFE(mem_fp.reorder(mem_dt), WARN);
317         SAFE(compare_src(p, mem_fp, mem_00, r), WARN);
318         delete &mem_00;
319     }
320
321     return OK;
322 }
323
324 int fill_wei(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
325     res_t *r) {
326     const bool wino_s8 = p->alg == WINO && p->cfg[WEI].dt == mkldnn_s8;
327     const bool s8_s8 = p->cfg[WEI].dt == mkldnn_s8 && p->cfg[SRC].dt == mkldnn_s8;
328     const bool diff_data_type = mem_dt.dt() != mem_fp.dt();
329     const bool check_reorder = diff_data_type && !wino_s8 && !s8_s8;
330
331     dnn_mem_t *p_mem_00 = check_reorder
332         ? new dnn_mem_t(mem_dt.md_, mkldnn_f32,
333             get_default_format(mem_dt.md_.ndims, p->has_groups ? GWEI : WEI))
334         : &mem_fp;
335     dnn_mem_t &mem_00 = *p_mem_00;
336
337     const auto &c = p->cfg[WEI];
338     const int range = c.f_max - c.f_min + 1;
339
340     mkldnn::impl::parallel_nd(
341         p->g, p->oc / p->g, p->ic / p->g, p->kd, p->kh, p->kw,
342         [&](int g, int oc, int ic, int kd, int kh, int kw) {
343         const int gen = 5 * kd + 17 * kh + 13 * kw + 13 * oc + 19 * ic + 38;
344         const bool non_base = flip_coin(gen, c.f_sparsity);
345         const float value =
346             non_base ? c.f_min + gen * c.f_step % range : c.f_base;
347
348         ((float*)mem_00)[wei_off_f(p, g, oc, ic, kd, kh, kw)] = value;
349     });
350
351     SAFE(mem_dt.reorder(mem_00), WARN);
352     if (check_reorder) {
353         SAFE(mem_fp.reorder(mem_dt), WARN);
354         SAFE(compare_wei(p, mem_fp, mem_00, r), WARN);
355         delete &mem_00;
356     }
357
358     return OK;
359 }
360
361 int fill_bia(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
362         res_t *r) {
363     const bool extra_mem = mem_dt.dt() != mem_fp.dt();
364     dnn_mem_t *p_mem_00 = extra_mem
365         ? new dnn_mem_t(mem_dt.md_, mkldnn_f32, mkldnn_x)
366         : &mem_fp;
367     dnn_mem_t &mem_00 = *p_mem_00;
368
369     const auto &c = p->cfg[BIA];
370     const int range = c.f_max - c.f_min + 1;
371
372     const size_t sz = mem_00.nelems();
373     for (size_t i = 0; i < sz; ++i) {
374         const int gen = (int)(19 * i);
375         const bool non_base = flip_coin(gen, c.f_sparsity);
376         const float value =
377             non_base ? c.f_min + gen * c.f_step % range : c.f_base;
378
379         ((float*)mem_00)[i] = value;
380     }
381
382     SAFE(mem_dt.reorder(mem_00), WARN);
383     if (extra_mem) {
384         SAFE(mem_fp.reorder(mem_dt), WARN);
385         SAFE(compare_bia(p, mem_fp, mem_00, r), WARN);
386         delete &mem_00;
387     }
388
389     return OK;
390 }
391
392 int fill_dst(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
393         res_t *r) {
394     const bool extra_mem = mem_dt.dt() != mem_fp.dt();
395     dnn_mem_t *p_mem_00 = extra_mem
396         ? new dnn_mem_t(mem_dt.md_, mkldnn_f32,
397             get_default_format(mem_dt.md_.ndims, DATA))
398         : &mem_fp;
399     dnn_mem_t &mem_00 = *p_mem_00;
400
401     const auto &c = p->cfg[DST];
402     const int range = c.f_max - c.f_min + 1;
403
404     mkldnn::impl::parallel_nd(p->mb, p->oc, p->od, p->oh, p->ow,
405         [&](int mb, int oc, int od, int oh, int ow) {
406         const int gen = 7 * od + 19 * oh + 17 * ow + 13 * mb + 13 * oc + 223;
407         const bool non_base = flip_coin(gen, c.f_sparsity);
408         const float value =
409             non_base ? c.f_min + gen * c.f_step % range : c.f_base;
410
411         ((float*)mem_00)[dst_off_f(p, mb, 0, oc, od, oh, ow)] = value;
412     });
413
414     SAFE(mem_dt.reorder(mem_00), WARN);
415     if (extra_mem) {
416         SAFE(mem_fp.reorder(mem_dt), WARN);
417         SAFE(compare_dst(p, mem_fp, mem_00, r), WARN);
418         delete &mem_00;
419     }
420
421     return OK;
422 }
423
424 inline int init_pd(const prb_t *p, mkldnn_convolution_desc_t &cd,
425         mkldnn_primitive_desc_t &cpd, res_t *r) {
426     mkldnn_memory_desc_t src_d, wei_d, bia_d, dst_d;
427
428     int ndims = is_conv_3d(p) ? 5 : is_conv_1d(p) ? 3 : 4;
429     mkldnn_dims_t src_1d_dims = {p->mb, p->ic, p->iw};
430     mkldnn_dims_t src_2d_dims = {p->mb, p->ic, p->ih, p->iw};
431     mkldnn_dims_t src_3d_dims = {p->mb, p->ic, p->id, p->ih, p->iw};
432
433     mkldnn_dims_t wei_1d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kw};
434     mkldnn_dims_t wei_2d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kh, p->kw};
435     mkldnn_dims_t wei_3d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kd, p->kh, p->kw};
436
437     mkldnn_dims_t bia_dims = {p->oc};
438
439     mkldnn_dims_t dst_1d_dims = {p->mb, p->oc, p->ow};
440     mkldnn_dims_t dst_2d_dims = {p->mb, p->oc, p->oh, p->ow};
441     mkldnn_dims_t dst_3d_dims = {p->mb, p->oc, p->od, p->oh, p->ow};
442
443     DNN_SAFE(mkldnn_memory_desc_init(&src_d, ndims,
444         is_conv_3d(p) ? src_3d_dims : is_conv_1d(p) ? src_1d_dims : src_2d_dims,
445         p->cfg[SRC].dt, mkldnn_any), WARN);
446
447     DNN_SAFE(mkldnn_memory_desc_init(&wei_d, ndims + p->has_groups,
448         is_conv_3d(p)
449         ? &wei_3d_dims[!p->has_groups]
450         : is_conv_1d(p)
451         ? &wei_1d_dims[!p->has_groups]
452         : &wei_2d_dims[!p->has_groups],
453         p->cfg[WEI].dt, mkldnn_any), WARN);
454
455     DNN_SAFE(mkldnn_memory_desc_init(&bia_d, 1, bia_dims, p->cfg[BIA].dt,
456         mkldnn_any), WARN);
457
458     DNN_SAFE(mkldnn_memory_desc_init(&dst_d, ndims,
459         is_conv_3d(p) ? dst_3d_dims : is_conv_1d(p) ? dst_1d_dims : dst_2d_dims,
460         p->cfg[DST].dt, mkldnn_any), WARN);
461
462     ptrdiff_t strides_nd[] = {p->sd, p->sh, p->sw};
463     ptrdiff_t dilates_nd[] = {p->dd, p->dh, p->dw};
464     ptrdiff_t padding_nd[] = {p->pd, p->ph, p->pw};
465
466     auto bph = [&](int ih, int oh, int kh, int sh, int ph, int dh) {
467         return (oh - 1) * sh - ih + ((kh - 1) * (dh + 1) + 1) - ph;
468     };
469     ptrdiff_t padding_r_nd[] = {
470         bph(p->id, p->od, p->kd, p->sd, p->pd, p->dd),
471         bph(p->ih, p->oh, p->kh, p->sh, p->ph, p->dh),
472         bph(p->iw, p->ow, p->kw, p->sw, p->pw, p->dw)};
473
474     ptrdiff_t *strides = strides_nd + (5 - ndims);
475     ptrdiff_t *dilates = dilates_nd + (5 - ndims);
476     ptrdiff_t *padding = padding_nd + (5 - ndims);
477     ptrdiff_t *padding_r = padding_r_nd + (5 - ndims);
478
479     mkldnn_alg_kind_t alg = mkldnn_convolution_direct;
480     if (p->alg == WINO) alg = mkldnn_convolution_winograd;
481     if (p->alg == AUTO) alg = mkldnn_convolution_auto;
482
483     switch (p->dir) {
484     case FWD_D: case FWD_B: case FWD_I:
485         DNN_SAFE(mkldnn_dilated_convolution_forward_desc_init(&cd,
486                     p->dir == FWD_I
487                         ? mkldnn_forward_inference
488                         : mkldnn_forward_training,
489                     alg, &src_d, &wei_d,
490                     p->dir == FWD_B ? &bia_d : NULL, &dst_d,
491                     strides, dilates, padding, padding_r,
492                     mkldnn_padding_zero), WARN);
493         break;
494     case BWD_D:
495         DNN_SAFE(mkldnn_dilated_convolution_backward_data_desc_init(&cd, alg,
496                     &src_d, &wei_d, &dst_d, strides, dilates, padding, padding_r,
497                     mkldnn_padding_zero), WARN);
498         break;
499     case BWD_W: case BWD_WB:
500         DNN_SAFE(mkldnn_dilated_convolution_backward_weights_desc_init(&cd,
501                     alg, &src_d, &wei_d, p->dir == BWD_W ? NULL : &bia_d, &dst_d,
502                     strides, dilates, padding, padding_r,
503                     mkldnn_padding_zero), WARN);
504         break;
505     default: DNN_SAFE(mkldnn_invalid_arguments, CRIT);
506     }
507
508     DNN_SAFE(cd.accum_data_type == p->cfg[ACC].dt
509             ? mkldnn_success : mkldnn_unimplemented, CRIT);
510
511     auto mkldnn_attr = create_mkldnn_attr(p->attr, p->oc, p->scales);
512
513     mkldnn_status_t init_status = mkldnn_success;
514     init_status = mkldnn_primitive_desc_create_v2(&cpd, &cd, mkldnn_attr,
515                 engine, NULL);
516
517     mkldnn_primitive_attr_destroy(mkldnn_attr);
518
519     if (init_status == mkldnn_unimplemented)
520         return r->state = UNIMPLEMENTED, OK;
521     else
522         SAFE(init_status, WARN);
523
524     const char *impl_str = query_impl_info(cpd);
525     if (maybe_skip(skip_impl, impl_str)) {
526         print(2, "SKIPPED: mkldnn implementation: %s\n", impl_str);
527         DNN_SAFE(mkldnn_primitive_desc_destroy(cpd), WARN);
528         return r->state = SKIPPED, OK;
529     } else {
530         print(5, "mkldnn implementation: %s\n", impl_str);
531     }
532
533     auto q = [=](mkldnn_query_t query, int index = 0) {
534         return *mkldnn_primitive_desc_query_memory_d(
535                 mkldnn_primitive_desc_query_pd(cpd, query, index));
536     };
537
538     if (p->alg == AUTO) {
539         mkldnn_convolution_desc_t *temp_conv_desc = {0};
540         DNN_SAFE(mkldnn_primitive_desc_query(cpd,
541                 mkldnn_query_convolution_d, 0, &temp_conv_desc), CRIT);
542         cd.alg_kind = temp_conv_desc->alg_kind;
543     }
544
545     if (p->dir == BWD_D)
546         cd.diff_src_desc = q(mkldnn_query_diff_src_pd);
547     else
548         cd.src_desc = q(mkldnn_query_src_pd);
549
550     if (p->dir & FLAG_WEI)
551         cd.diff_weights_desc = q(mkldnn_query_diff_weights_pd);
552     else
553         cd.weights_desc = q(mkldnn_query_weights_pd);
554
555     if (p->dir & FLAG_BIA) {
556         if (p->dir & FLAG_BWD)
557             cd.diff_bias_desc = q(mkldnn_query_diff_weights_pd, 1);
558         else
559             cd.bias_desc = q(mkldnn_query_weights_pd, 1);
560     }
561
562     if (p->dir & FLAG_BWD)
563         cd.diff_dst_desc = q(mkldnn_query_diff_dst_pd);
564     else
565         cd.dst_desc = q(mkldnn_query_dst_pd);
566
567     return OK;
568 }
569
570 int doit(const prb_t *p, res_t *r) {
571     res_t res_zero{};
572     *r = res_zero;
573
574     mkldnn_convolution_desc_t cd;
575     mkldnn_primitive_desc_t cpd;
576     mkldnn_primitive_t c{};
577
578     SAFE(init_pd(p, cd, cpd, r), WARN);
579
580     prb_t *p_temp = nullptr;
581     if (p->alg == AUTO || p->alg == WINO) {
582         p_temp = new prb_t((desc_t)*p, p->dir, p->cfg,
583                     p->alg, p->attr, p->mb);
584         if (p->alg == AUTO) p_temp->alg = alg_kind2alg(cd.alg_kind);
585         p_temp->cfg = auto_cfg(p_temp->alg, p->cfg);
586         p = p_temp;
587     }
588
589
590     if (r->state == SKIPPED || r->state == UNIMPLEMENTED)
591         return OK;
592
593     auto &src_dt_d = p->dir == BWD_D ? cd.diff_src_desc : cd.src_desc;
594     auto &wei_dt_d = p->dir & FLAG_WEI ? cd.diff_weights_desc : cd.weights_desc;
595     auto &bia_dt_d = p->dir & FLAG_BWD ? cd.diff_bias_desc : cd.bias_desc;
596     auto &dst_dt_d = p->dir & FLAG_BWD ? cd.diff_dst_desc: cd.dst_desc;
597
598     dnn_mem_t src_dt(src_dt_d, p->cfg[SRC].dt);
599     dnn_mem_t wei_dt(wei_dt_d, p->cfg[WEI].dt);
600     dnn_mem_t dst_dt(dst_dt_d, p->cfg[DST].dt);
601     dnn_mem_t *p_bia_dt = p->dir & FLAG_BIA
602         ? new dnn_mem_t(bia_dt_d, p->cfg[BIA].dt) : new dnn_mem_t();
603     dnn_mem_t &bia_dt = *p_bia_dt;
604
605     auto src_format = get_default_format(src_dt.md_.ndims, DATA);
606     auto wei_format = get_default_format(wei_dt.md_.ndims,
607         p->has_groups ? GWEI : WEI);
608
609     const auto fp = mkldnn_f32;
610     dnn_mem_t src_fp(src_dt_d, fp, src_format);
611     dnn_mem_t wei_fp(wei_dt_d, fp, wei_format);
612     dnn_mem_t dst_fp(dst_dt_d, fp, src_format);
613     dnn_mem_t *p_bia_fp = p->dir & FLAG_BIA
614         ? new dnn_mem_t(bia_dt_d, fp, mkldnn_x) : new dnn_mem_t();
615     dnn_mem_t &bia_fp = *p_bia_fp;
616
617     SAFE(fill_src(p, src_dt, src_fp, r), WARN);
618     SAFE(fill_wei(p, wei_dt, wei_fp, r), WARN);
619     SAFE(fill_dst(p, dst_dt, dst_fp, r), WARN);
620     if (p->dir & FLAG_BIA)
621         SAFE(fill_bia(p, bia_dt, bia_fp, r), WARN);
622
623     if (p->dir & FLAG_FWD) {
624         mkldnn_primitive_at_t inputs[3] = { {src_dt.p_, 0}, {wei_dt.p_, 0},
625             {p->dir & FLAG_BIA ? bia_dt.p_ : NULL, 0}
626         };
627         const_mkldnn_primitive_t outputs[] = { dst_dt.p_ };
628         DNN_SAFE(mkldnn_primitive_create(&c, cpd, inputs, outputs), WARN);
629         SAFE(execute(c), WARN);
630         if (bench_mode & CORR) {
631             compute_ref_fwd(p, src_fp, wei_fp, bia_fp, dst_fp);
632             dnn_mem_t dst(dst_dt, fp, src_format);
633             SAFE(compare_dst(p, dst, dst_fp, r, true), WARN);
634         }
635     } else if (p->dir == BWD_D) {
636         mkldnn_primitive_at_t inputs[3] = { {dst_dt.p_, 0}, {wei_dt.p_, 0}, };
637         const_mkldnn_primitive_t outputs[] = { src_dt.p_ };
638         DNN_SAFE(mkldnn_primitive_create(&c, cpd, inputs, outputs), WARN);
639         SAFE(execute(c), WARN);
640         if (bench_mode & CORR) {
641             compute_ref_bwd_d(p, src_fp, wei_fp, bia_fp, dst_fp);
642             dnn_mem_t src(src_dt, fp, src_format);
643             SAFE(compare_src(p, src, src_fp, r, true), WARN);
644         }
645     } else if (p->dir & FLAG_BWD && p->dir & FLAG_WEI) {
646         mkldnn_primitive_at_t inputs[3] = { {src_dt.p_, 0}, {dst_dt.p_, 0}, };
647         const_mkldnn_primitive_t outputs[] = { wei_dt.p_,
648             p->dir & FLAG_BIA ? bia_dt.p_ : NULL,
649         };
650         DNN_SAFE(mkldnn_primitive_create(&c, cpd, inputs, outputs), WARN);
651         SAFE(execute(c), WARN);
652         if (bench_mode & CORR) {
653             compute_ref_bwd_w(p, src_fp, wei_fp, bia_fp, dst_fp);
654             dnn_mem_t wei(wei_dt, fp, wei_format);
655             SAFE(compare_wei(p, wei, wei_fp, r, true), WARN);
656             if (p->dir & FLAG_BIA) {
657                 dnn_mem_t bia(bia_dt, fp, mkldnn_x);
658                 SAFE(compare_bia(p, bia, bia_fp, r, true), WARN);
659             }
660         }
661     } else {
662         delete p_bia_dt;
663         delete p_bia_fp;
664         SAFE(FAIL, CRIT);
665     }
666
667     if (bench_mode & PERF) {
668         auto &t = r->timer;
669         t.reset();
670         while (true) {
671             SAFE(execute(c), WARN);
672             t.stamp();
673             const bool stop = false
674                 || (fix_times_per_prb && t.times() >= fix_times_per_prb)
675                 || (!fix_times_per_prb
676                         && t.total_ms() >= max_ms_per_prb
677                         && t.times() >= min_times_per_prb);
678             if (stop) break;
679         }
680     }
681
682     DNN_SAFE(mkldnn_primitive_desc_destroy(cpd), CRIT);
683     DNN_SAFE(mkldnn_primitive_destroy(c), CRIT);
684
685     delete p_bia_dt;
686     delete p_bia_fp;
687     delete p_temp;
688
689     return OK;
690 }
691
692 }