1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
24 #include "src/common/mkldnn_thread.hpp"
26 #include "mkldnn_common.hpp"
27 #include "mkldnn_memory.hpp"
30 #include "conv/conv_common.hpp"
// Predicate on a convolution problem descriptor. In this file a true result
// makes init_pd() build 5-dimensional (mb, c, d, h, w) tensors, and
// is_conv_1d() requires the result to be false.
34 inline bool is_conv_3d(const prb_t *p) {
// Returns true when the problem is treated as a 1D convolution: it is not 3D,
// both the input height and the kernel height equal 1, and the source data
// type is neither s8 nor u8 (int8 problems are deliberately excluded, see the
// workaround remark on the lines below).
38 inline bool is_conv_1d(const prb_t *p) {
39 return !is_conv_3d(p) && p->ih == 1 && p->kh == 1
40 && p->cfg[SRC].dt != mkldnn_s8 // temporary workaround until
41 && p->cfg[SRC].dt != mkldnn_u8; // int8 jit supports 1d
// Computes the "non-zero trust level" for tensor `kind`: compare_dat() compares
// the fraction of non-zero elements against this value to decide whether the
// test can be trusted. One path simply returns the configured sparsity of the
// tensor; the remaining code derives a level from the problem geometry.
// NOTE(review): the conditions selecting between these paths are not visible
// in this view.
44 double get_trust_nz_level(const prb_t *p, data_kind_t kind,
47 return p->cfg[kind].f_sparsity;
// Helper that walks the attached post-ops looking at the kinds RELU, ELU, SQRT
// and BRELU; its result is compared against 0 at the end of this function.
49 auto negative_to_zero = [&]() {
50 using pk = attr_t::post_ops_t::kind_t;
51 const auto &po = p->attr.post_ops;
53 for (int i = 0; i < po.len; ++i) {
// Kind of the i-th post-op entry.
54 auto k = po.entry[i].kind;
56 k == pk::RELU || k == pk::ELU || k == pk::SQRT || k == pk::BRELU;
// Starting value of the trust level.
61 double trust = 0.3; /* why? */
// Divide the level by the product of the three strides.
64 trust /= p->sd * p->sh * p->sw;
// Divide by (kernel volume / MIN3(kernel volume, input volume, output volume)),
// i.e. multiply by MIN3(...) / kernel volume. The leading `1.` forces the
// division to be done in floating point.
67 trust /= 1. * p->kd * p->kh * p->kw
68 / MIN3(p->kd * p->kh * p->kw, p->id * p->ih * p->iw
69 , p->od * p->oh * p->ow);
// Alternative level taken as 80% of the configured DST sparsity.
72 trust = 0.8 * p->cfg[DST].f_sparsity; /* why? */
// Halve the level when negative_to_zero() yields a non-zero result; leave it
// unchanged otherwise.
75 trust /= negative_to_zero() == 0 ? 1 : 2;
// Tells whether the post-ops attached to the problem call for the norm-based
// ("integral") check that get_eps(), get_result() and compare_dat() apply.
// Returns false immediately when there are no post-ops. While scanning the
// entries, SUM and ABS are skipped, as is RELU with alpha == 0.
// NOTE(review): the handling of the remaining entries and the final return are
// not visible in this view.
82 inline bool post_ops_require_integral_check(const prb_t *p) {
83 if (p->attr.post_ops.len == 0) return false;
85 using pk = attr_t::post_ops_t::kind_t;
86 const auto &ops = p->attr.post_ops;
88 // assumptions: at most 1 eltwise, scale = 1.
89 for (int idx = 0; idx < ops.len; ++idx) {
90 const auto &e = ops.entry[idx];
// These entries never trigger the integral check on their own.
91 if (e.kind == pk::SUM || e.kind == pk::ABS) continue;
92 if (e.kind == pk::RELU && e.eltwise.alpha == 0.f) continue;
// Returns the comparison tolerance for tensor `kind`.
//  - Winograd with a weights direction (alg has WINO and dir has FLAG_WEI):
//    the configured eps multiplied by max(1, 10^(0.4 * log10(0.125*mb*oh*ow))).
//  - Post-ops that require the integral check: at least 1e-5.
//  - Otherwise: the configured eps unchanged.
99 inline double get_eps(const prb_t *p, const data_kind_t kind) {
100 // Winograd specifics
101 if (p->alg & WINO && p->dir & FLAG_WEI) {
102 /*This is an empirical equation derived by observing growth error
103 with increasing 'k' dimension in gemm of winograd*/
104 return p->cfg[kind].eps *
105 (MAX2(1, pow(10, 0.4 * log10(0.125 * p->mb * p->oh * p->ow))));
108 // post-ops specifics
109 if (post_ops_require_integral_check(p))
110 return MAX2(1e-5, p->cfg[kind].eps);
// Default: tolerance straight from the configuration.
112 return p->cfg[kind].eps;
// Finalizes the error verdict for one comparison. The element-wise error count
// in `r` is reset to zero when the relative L2 difference is within get_eps()
// and either the algorithm is Winograd or the post-ops require the integral
// check. If errors remain afterwards, r->state becomes FAILED.
//   diff_norm - norms accumulated by the caller over all compared elements.
115 inline void get_result(const prb_t *p, const data_kind_t kind, res_t *r,
116 const diff_norm_t diff_norm) {
// Note: get_eps() returns double; the tolerance is narrowed to float here.
117 const float eps = get_eps(p, kind);
119 /* Ignoring element-wise errors for Winograd and in some cases of post-ops,
120 * since large relative error in few elements (which are anyways close
121 * to zero) results in false positive failures */
123 bool wino_test = (p->alg & WINO) && diff_norm.rel_diff(norm_t::L2) <= eps;
124 if (wino_test) r->errors = 0;
// Same override for post-ops that need the integral check.
126 bool post_ops_test = post_ops_require_integral_check(p)
127 && diff_norm.rel_diff(norm_t::L2) <= eps;
128 if (post_ops_test) r->errors = 0;
// Any error that survived the overrides fails the test.
130 if (r->errors) r->state = FAILED;
// Element-wise comparison of a library tensor (mem_dt) against its reference
// (mem_fp) for data kind `kind`; both memories are read through a float* view.
//   p             - problem descriptor (cfg limits, attributes, algorithm).
//   kind          - which tensor is compared: SRC, WEI, BIA or DST.
//   r             - result record; its errors, total and state fields are
//                   accessed here.
//   final_compare - when false, per-element diagnostics are tagged "REORDER ";
//                   when true, the verbose dump is enabled and an UNTESTED
//                   state is promoted to PASSED at the end.
// Returns FAIL when r->state is FAILED on exit, OK otherwise.
133 inline int compare_dat(const prb_t *p, data_kind_t kind, dnn_mem_t &mem_dt,
134 dnn_mem_t &mem_fp, res_t *r, bool final_compare = false) {
// When set, per-element complaints are suppressed below (unless verbose >= 10).
135 const bool dont_complain = false
137 || post_ops_require_integral_check(p);
139 size_t nelems = mem_dt.nelems();
141 const char *skind = data_kind2str(kind);
// Tallies reported in the trust summary below as in/below/above, together with
// the matching "ok" counts.
143 int in = 0, below = 0, above = 0;
144 int in_ok = 0, below_ok = 0, above_ok = 0;
147 diff_norm_t diff_norm;
// Visit every element of the library tensor.
152 for (size_t i = 0; i < nelems; ++i) {
153 const float dt = ((float*)mem_dt)[i];
154 const float fp0 = ((float*)mem_fp)[i];
// For non-f32 data the reference value is rounded according to the integer
// rounding mode attribute (floor for DOWN, nearbyintf for NEAREST).
157 if (p->cfg[kind].dt != mkldnn_f32) {
158 using R = attr_t::round_mode_t;
159 switch (p->attr.irmode) {
160 case R::DOWN: fp = floorf(fp0); break;
161 case R::NEAREST: fp = nearbyintf(fp0); break;
// Absolute difference, and difference relative to |fp| (the denominator is 1
// when |fp| <= FLT_MIN, so tiny references fall back to the absolute value).
167 const float diff = fabsf(fp - dt);
168 const float rel_diff = diff / (fabsf(fp) > FLT_MIN ? fabsf(fp) : 1);
// A reference value below cfg.min or above cfg.max must match that bound
// exactly. Otherwise the element passes when the relative difference (the
// absolute one when |fp| <= 1e-5) does not exceed get_eps().
171 if (fp < p->cfg[kind].min) {
172 diff_norm.update(p->cfg[kind].min, dt);
173 ok = dt == p->cfg[kind].min;
176 } else if (fp > p->cfg[kind].max) {
177 diff_norm.update(p->cfg[kind].max, dt);
178 ok = dt == p->cfg[kind].max;
182 diff_norm.update(fp, dt);
183 ok = (fabs(fp) > 1e-5 ? rel_diff : diff) <= get_eps(p, kind);
// Print details for this element while fewer than 10 errors are recorded and
// complaints are not suppressed, or always at verbose >= 10. The flat index is
// first decoded into logical coordinates for the given tensor kind.
189 if ((!dont_complain && r->errors < 10) || verbose >=10) {
190 int mb_or_g = 0, g_or_oc = 0, c = 0, d = 0, h = 0, w = 0;
192 case SRC: inv_src_off_f(p, i, mb_or_g, g_or_oc, c, d, h, w); break;
193 case WEI: inv_wei_off_f(p, i, mb_or_g, g_or_oc, c, d, h, w); break;
194 case BIA: inv_bia_off_f(p, i, mb_or_g, g_or_oc); break;
195 case DST: inv_dst_off_f(p, i, mb_or_g, g_or_oc, c, d, h, w); break;
197 print(0, "[%4lu][%s%s][%d,%d,%d,%d,%d,%d] "
198 "fp:%8g fp0:%8g dt:%8g diff:%8g rdiff:%8g\n",
200 final_compare == false ? "REORDER " : "",
201 skind, mb_or_g, g_or_oc, c, d, h, w,
202 fp, fp0, dt, diff, rel_diff);
206 /* for debug purposes only: dump the output */
207 if (final_compare && verbose >= 50 && i < 30) {
208 int mb_or_g = 0, g_or_oc = 0, c = 0, d = 0, h = 0, w = 0;
210 case SRC: inv_src_off_f(p, i, mb_or_g, g_or_oc, c, d, h, w); break;
211 case WEI: inv_wei_off_f(p, i, mb_or_g, g_or_oc, c, d, h, w); break;
212 case BIA: inv_bia_off_f(p, i, mb_or_g, g_or_oc); break;
213 case DST: inv_dst_off_f(p, i, mb_or_g, g_or_oc, c, d, h, w); break;
216 print(0, "[%4lu][%s][%d,%d,%d,%d,%d,%d] fp:%8g fp0:%8g dt:%8g\n",
218 skind, mb_or_g, g_or_oc, c, d, h, w, fp, fp0, dt);
// Let get_result() apply its norm-based overrides and mark the result FAILED
// if errors remain.
225 get_result(p, kind, r, diff_norm);
// Summary of the accumulated norms; print() is given level 0 when errors exist
// and level 2 otherwise.
227 if (final_compare || r->errors) {
228 const int vl = r->errors ? 0 : 2;
229 print(vl, "@@@ [%s] %sdiff: err:%d, l0(``%g``) "
230 "l1:(%g,%g,%g,``%g``) "
231 "l2:(%g,%g,%g,``%g``) "
232 "l8:(%g,%g,%g,``%g``)\n",
233 skind, final_compare ? "final: " : "", (int)r->errors,
234 diff_norm.rel_diff(norm_t::L0),
235 diff_norm.a_[norm_t::L1], diff_norm.b_[norm_t::L1],
236 diff_norm.diff_[norm_t::L1], diff_norm.rel_diff(norm_t::L1),
237 diff_norm.a_[norm_t::L2], diff_norm.b_[norm_t::L2],
238 diff_norm.diff_[norm_t::L2], diff_norm.rel_diff(norm_t::L2),
239 diff_norm.a_[norm_t::L8], diff_norm.b_[norm_t::L8],
240 diff_norm.diff_[norm_t::L8], diff_norm.rel_diff(norm_t::L8));
// Trust check: required levels for the "range" and "non-zero" fractions.
243 const double trust_rg_level = 0.3;
244 const double trust_nz_level = get_trust_nz_level(p, kind, final_compare);
// Measured fractions, both relative to r->total.
246 const double trust_rg = (double)in / r->total;
247 const double trust_nz = (double)non_zero / r->total;
// no_trust requires at least one measured fraction to be below its level.
249 const bool no_trust = true /* ...in the test ...at all */
251 && (trust_rg < trust_rg_level || trust_nz < trust_nz_level);
// dump is set at verbose >= 20, or at verbose >= 10 when either fraction is
// below 1.
253 const bool dump = verbose >= 20
254 || (verbose >= 10 && (trust_rg < 1. || trust_nz < 1.));
256 print(0, "@@@ [%s] %strust range:%.2f nz:%.2f "
257 "(level range:%.2f nz:%.2f). "
258 "in:%d (ok:%d) below:%d (ok:%d) above:%d (ok:%d) nz:%d "
259 "total:%lu\n", skind, final_compare ? "final: " : "",
260 trust_rg, trust_nz, trust_rg_level, trust_nz_level, in, in_ok,
261 below, below_ok, above, above_ok, non_zero,
262 (unsigned long)r->total);
// Flag the result as MISTRUSTED and report the fractions that caused it.
266 r->state = MISTRUSTED;
267 print(0, "@@@ [%s] test-bug: trust is too low. "
268 "range:%.2f (?<%.2f) nz:%.2f (?<%.2f) (nz: %d total: %lu)\n",
269 skind, trust_rg, trust_rg_level, trust_nz, trust_nz_level,
270 non_zero, (unsigned long)r->total);
// A final comparison that left the state at UNTESTED counts as a pass.
273 if (final_compare && r->state == UNTESTED)
274 r->state = PASSED; /* optimism */
276 return r->state == FAILED ? FAIL : OK;
// Compares a source tensor: forwards to compare_dat() with kind SRC and
// returns its status unchanged.
279 int compare_src(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
280 res_t *r, bool final_compare)
281 { return compare_dat(p, SRC, mem_dt, mem_fp, r, final_compare); }
// Compares a weights tensor: forwards to compare_dat() with kind WEI and
// returns its status unchanged.
282 int compare_wei(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
283 res_t *r, bool final_compare)
284 { return compare_dat(p, WEI, mem_dt, mem_fp, r, final_compare); }
// Compares a bias tensor: forwards to compare_dat() with kind BIA and returns
// its status unchanged.
285 int compare_bia(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
286 res_t *r, bool final_compare)
287 { return compare_dat(p, BIA, mem_dt, mem_fp, r, final_compare); }
// Compares a destination tensor: forwards to compare_dat() with kind DST and
// returns its status unchanged.
288 int compare_dst(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
289 res_t *r, bool final_compare)
290 { return compare_dat(p, DST, mem_dt, mem_fp, r, final_compare); }
// Fills the source tensor with deterministic pseudo-random values. Values are
// generated as floats into a staging memory (mem_00), reordered into the
// library memory (mem_dt), reordered back into the reference memory (mem_fp),
// and the round trip is checked with compare_src().
292 int fill_src(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
// A separate f32 staging memory is allocated when the library and reference
// memories use different data types.
// NOTE(review): the other branch of this conditional and the release of the
// `new` allocation are not visible in this view - confirm the memory is freed
// on every return path, including any early return taken by SAFE (presumably
// it returns on failure).
294 const bool extra_mem = mem_dt.dt() != mem_fp.dt();
295 dnn_mem_t *p_mem_00 = extra_mem
296 ? new dnn_mem_t(mem_dt.md_, mkldnn_f32,
297 get_default_format(mem_dt.md_.ndims, DATA))
299 dnn_mem_t &mem_00 = *p_mem_00;
// Number of distinct integer values in the configured [f_min, f_max] interval.
301 const auto &c = p->cfg[SRC];
302 const int range = c.f_max - c.f_min + 1;
// One lambda call per (mb, ic, id, ih, iw) point, dispatched via parallel_nd.
304 mkldnn::impl::parallel_nd(p->mb, p->ic, p->id, p->ih, p->iw,
305 [&](int mb, int ic, int id, int ih, int iw) {
// `gen` is a per-point seed built from the coordinates. flip_coin() uses it
// with the configured sparsity to choose between a generated value,
// c.f_min + ((gen * c.f_step) % range), and the base value c.f_base.
306 const int gen = 5 * id + 17 * ih + 13 * iw + 13 * mb + 19 * ic + 1637;
307 const bool non_base = flip_coin(gen, c.f_sparsity);
309 non_base ? c.f_min + gen * c.f_step % range : c.f_base;
311 ((float*)mem_00)[src_off_f(p, mb, 0, ic, id, ih, iw)] = value;
// Staging -> library memory.
314 SAFE(mem_dt.reorder(mem_00), WARN);
// Library -> reference memory, then verify the round trip against staging.
316 SAFE(mem_fp.reorder(mem_dt), WARN);
317 SAFE(compare_src(p, mem_fp, mem_00, r), WARN);
// Fills the weights tensor with deterministic pseudo-random values. Values are
// generated as floats into a staging memory (mem_00), reordered into the
// library memory (mem_dt), reordered back into the reference memory (mem_fp),
// and the round trip is checked with compare_wei().
324 int fill_wei(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
// check_reorder is true only when the two memories differ in data type and the
// problem is neither Winograd with s8 weights nor s8 weights with s8 source.
326 const bool wino_s8 = p->alg == WINO && p->cfg[WEI].dt == mkldnn_s8;
327 const bool s8_s8 = p->cfg[WEI].dt == mkldnn_s8 && p->cfg[SRC].dt == mkldnn_s8;
328 const bool diff_data_type = mem_dt.dt() != mem_fp.dt();
329 const bool check_reorder = diff_data_type && !wino_s8 && !s8_s8;
// A separate f32 staging memory is allocated when check_reorder holds; its
// format depends on whether the problem has groups.
// NOTE(review): the other branch of this conditional and the release of the
// `new` allocation are not visible in this view - confirm the memory is freed
// on every return path, including any early return taken by SAFE (presumably
// it returns on failure).
331 dnn_mem_t *p_mem_00 = check_reorder
332 ? new dnn_mem_t(mem_dt.md_, mkldnn_f32,
333 get_default_format(mem_dt.md_.ndims, p->has_groups ? GWEI : WEI))
335 dnn_mem_t &mem_00 = *p_mem_00;
// Number of distinct integer values in the configured [f_min, f_max] interval.
337 const auto &c = p->cfg[WEI];
338 const int range = c.f_max - c.f_min + 1;
// One lambda call per (g, oc, ic, kd, kh, kw) point; oc and ic are per-group.
340 mkldnn::impl::parallel_nd(
341 p->g, p->oc / p->g, p->ic / p->g, p->kd, p->kh, p->kw,
342 [&](int g, int oc, int ic, int kd, int kh, int kw) {
// `gen` is a per-point seed (the group index g does not take part in it).
// flip_coin() uses it with the configured sparsity to choose between a
// generated value, c.f_min + ((gen * c.f_step) % range), and c.f_base.
343 const int gen = 5 * kd + 17 * kh + 13 * kw + 13 * oc + 19 * ic + 38;
344 const bool non_base = flip_coin(gen, c.f_sparsity);
346 non_base ? c.f_min + gen * c.f_step % range : c.f_base;
348 ((float*)mem_00)[wei_off_f(p, g, oc, ic, kd, kh, kw)] = value;
// Staging -> library memory.
351 SAFE(mem_dt.reorder(mem_00), WARN);
// Library -> reference memory, then verify the round trip against staging.
353 SAFE(mem_fp.reorder(mem_dt), WARN);
354 SAFE(compare_wei(p, mem_fp, mem_00, r), WARN);
// Fills the bias tensor with deterministic pseudo-random values. Values are
// generated as floats into a staging memory (mem_00), reordered into the
// library memory (mem_dt), reordered back into the reference memory (mem_fp),
// and the round trip is checked with compare_bia().
361 int fill_bia(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
// A separate f32 staging memory in the mkldnn_x format is allocated when the
// library and reference memories use different data types.
// NOTE(review): the other branch of this conditional and the release of the
// `new` allocation are not visible in this view - confirm the memory is freed
// on every return path, including any early return taken by SAFE (presumably
// it returns on failure).
363 const bool extra_mem = mem_dt.dt() != mem_fp.dt();
364 dnn_mem_t *p_mem_00 = extra_mem
365 ? new dnn_mem_t(mem_dt.md_, mkldnn_f32, mkldnn_x)
367 dnn_mem_t &mem_00 = *p_mem_00;
// Number of distinct integer values in the configured [f_min, f_max] interval.
369 const auto &c = p->cfg[BIA];
370 const int range = c.f_max - c.f_min + 1;
// Sequential fill over every element of the staging memory.
372 const size_t sz = mem_00.nelems();
373 for (size_t i = 0; i < sz; ++i) {
// `gen` is a per-element seed. flip_coin() uses it with the configured
// sparsity to choose between a generated value,
// c.f_min + ((gen * c.f_step) % range), and the base value c.f_base.
374 const int gen = (int)(19 * i);
375 const bool non_base = flip_coin(gen, c.f_sparsity);
377 non_base ? c.f_min + gen * c.f_step % range : c.f_base;
379 ((float*)mem_00)[i] = value;
// Staging -> library memory.
382 SAFE(mem_dt.reorder(mem_00), WARN);
// Library -> reference memory, then verify the round trip against staging.
384 SAFE(mem_fp.reorder(mem_dt), WARN);
385 SAFE(compare_bia(p, mem_fp, mem_00, r), WARN);
// Fills the destination tensor with deterministic pseudo-random values. Values
// are generated as floats into a staging memory (mem_00), reordered into the
// library memory (mem_dt), reordered back into the reference memory (mem_fp),
// and the round trip is checked with compare_dst().
392 int fill_dst(const prb_t *p, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
// A separate f32 staging memory is allocated when the library and reference
// memories use different data types.
// NOTE(review): the other branch of this conditional and the release of the
// `new` allocation are not visible in this view - confirm the memory is freed
// on every return path, including any early return taken by SAFE (presumably
// it returns on failure).
394 const bool extra_mem = mem_dt.dt() != mem_fp.dt();
395 dnn_mem_t *p_mem_00 = extra_mem
396 ? new dnn_mem_t(mem_dt.md_, mkldnn_f32,
397 get_default_format(mem_dt.md_.ndims, DATA))
399 dnn_mem_t &mem_00 = *p_mem_00;
// Number of distinct integer values in the configured [f_min, f_max] interval.
401 const auto &c = p->cfg[DST];
402 const int range = c.f_max - c.f_min + 1;
// One lambda call per (mb, oc, od, oh, ow) point, dispatched via parallel_nd.
404 mkldnn::impl::parallel_nd(p->mb, p->oc, p->od, p->oh, p->ow,
405 [&](int mb, int oc, int od, int oh, int ow) {
// `gen` is a per-point seed built from the coordinates. flip_coin() uses it
// with the configured sparsity to choose between a generated value,
// c.f_min + ((gen * c.f_step) % range), and the base value c.f_base.
406 const int gen = 7 * od + 19 * oh + 17 * ow + 13 * mb + 13 * oc + 223;
407 const bool non_base = flip_coin(gen, c.f_sparsity);
409 non_base ? c.f_min + gen * c.f_step % range : c.f_base;
411 ((float*)mem_00)[dst_off_f(p, mb, 0, oc, od, oh, ow)] = value;
// Staging -> library memory.
414 SAFE(mem_dt.reorder(mem_00), WARN);
// Library -> reference memory, then verify the round trip against staging.
416 SAFE(mem_fp.reorder(mem_dt), WARN);
417 SAFE(compare_dst(p, mem_fp, mem_00, r), WARN);
// Builds the convolution descriptor `cd` and the primitive descriptor `cpd`
// for problem `p`.
//   cd  - filled by the direction-specific *_desc_init call, then updated with
//         the memory descriptors queried back from `cpd`.
//   cpd - created here; destroyed again before returning when the
//         implementation is skipped.
//   r   - r->state is set to UNIMPLEMENTED or SKIPPED when applicable; in both
//         cases the function still returns OK.
424 inline int init_pd(const prb_t *p, mkldnn_convolution_desc_t &cd,
425 mkldnn_primitive_desc_t &cpd, res_t *r) {
426 mkldnn_memory_desc_t src_d, wei_d, bia_d, dst_d;
// Tensor rank: 5 for 3D, 3 for 1D, 4 otherwise.
428 int ndims = is_conv_3d(p) ? 5 : is_conv_1d(p) ? 3 : 4;
// Candidate dimension arrays for every rank; the matching one is picked below.
429 mkldnn_dims_t src_1d_dims = {p->mb, p->ic, p->iw};
430 mkldnn_dims_t src_2d_dims = {p->mb, p->ic, p->ih, p->iw};
431 mkldnn_dims_t src_3d_dims = {p->mb, p->ic, p->id, p->ih, p->iw};
// Weights dims always start with the group count; when the problem has no
// groups, the descriptor is built from element [1] onward (see below).
433 mkldnn_dims_t wei_1d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kw};
434 mkldnn_dims_t wei_2d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kh, p->kw};
435 mkldnn_dims_t wei_3d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kd, p->kh, p->kw};
437 mkldnn_dims_t bia_dims = {p->oc};
439 mkldnn_dims_t dst_1d_dims = {p->mb, p->oc, p->ow};
440 mkldnn_dims_t dst_2d_dims = {p->mb, p->oc, p->oh, p->ow};
441 mkldnn_dims_t dst_3d_dims = {p->mb, p->oc, p->od, p->oh, p->ow};
// Memory descriptors are created with the mkldnn_any format.
443 DNN_SAFE(mkldnn_memory_desc_init(&src_d, ndims,
444 is_conv_3d(p) ? src_3d_dims : is_conv_1d(p) ? src_1d_dims : src_2d_dims,
445 p->cfg[SRC].dt, mkldnn_any), WARN);
// Weights have one extra dimension when groups are present; indexing with
// [!p->has_groups] starts at the group dim (index 0) or just after it (1).
447 DNN_SAFE(mkldnn_memory_desc_init(&wei_d, ndims + p->has_groups,
449 ? &wei_3d_dims[!p->has_groups]
451 ? &wei_1d_dims[!p->has_groups]
452 : &wei_2d_dims[!p->has_groups],
453 p->cfg[WEI].dt, mkldnn_any), WARN);
455 DNN_SAFE(mkldnn_memory_desc_init(&bia_d, 1, bia_dims, p->cfg[BIA].dt,
458 DNN_SAFE(mkldnn_memory_desc_init(&dst_d, ndims,
459 is_conv_3d(p) ? dst_3d_dims : is_conv_1d(p) ? dst_1d_dims : dst_2d_dims,
460 p->cfg[DST].dt, mkldnn_any), WARN);
// Spatial parameters in (depth, height, width) order.
462 ptrdiff_t strides_nd[] = {p->sd, p->sh, p->sw};
463 ptrdiff_t dilates_nd[] = {p->dd, p->dh, p->dw};
464 ptrdiff_t padding_nd[] = {p->pd, p->ph, p->pw};
// End-side padding for one spatial dimension, derived from the input size,
// output size, kernel size, stride, begin-side padding and dilation.
466 auto bph = [&](int ih, int oh, int kh, int sh, int ph, int dh) {
467 return (oh - 1) * sh - ih + ((kh - 1) * (dh + 1) + 1) - ph;
469 ptrdiff_t padding_r_nd[] = {
470 bph(p->id, p->od, p->kd, p->sd, p->pd, p->dd),
471 bph(p->ih, p->oh, p->kh, p->sh, p->ph, p->dh),
472 bph(p->iw, p->ow, p->kw, p->sw, p->pw, p->dw)};
// Skip the leading entries that do not exist for this rank: offset 0 for
// ndims == 5, 1 for ndims == 4 (no depth), 2 for ndims == 3 (width only).
474 ptrdiff_t *strides = strides_nd + (5 - ndims);
475 ptrdiff_t *dilates = dilates_nd + (5 - ndims);
476 ptrdiff_t *padding = padding_nd + (5 - ndims);
477 ptrdiff_t *padding_r = padding_r_nd + (5 - ndims);
// Map the benchmark algorithm to the library one; direct is the default.
479 mkldnn_alg_kind_t alg = mkldnn_convolution_direct;
480 if (p->alg == WINO) alg = mkldnn_convolution_winograd;
481 if (p->alg == AUTO) alg = mkldnn_convolution_auto;
// Direction-specific descriptor initialization. Bias is passed only for FWD_B
// (forward) and for directions other than BWD_W (backward weights).
484 case FWD_D: case FWD_B: case FWD_I:
485 DNN_SAFE(mkldnn_dilated_convolution_forward_desc_init(&cd,
487 ? mkldnn_forward_inference
488 : mkldnn_forward_training,
490 p->dir == FWD_B ? &bia_d : NULL, &dst_d,
491 strides, dilates, padding, padding_r,
492 mkldnn_padding_zero), WARN);
495 DNN_SAFE(mkldnn_dilated_convolution_backward_data_desc_init(&cd, alg,
496 &src_d, &wei_d, &dst_d, strides, dilates, padding, padding_r,
497 mkldnn_padding_zero), WARN);
499 case BWD_W: case BWD_WB:
500 DNN_SAFE(mkldnn_dilated_convolution_backward_weights_desc_init(&cd,
501 alg, &src_d, &wei_d, p->dir == BWD_W ? NULL : &bia_d, &dst_d,
502 strides, dilates, padding, padding_r,
503 mkldnn_padding_zero), WARN);
505 default: DNN_SAFE(mkldnn_invalid_arguments, CRIT);
// The accumulation data type chosen for the descriptor must match the
// configured one; a mismatch is reported as unimplemented at CRIT level.
508 DNN_SAFE(cd.accum_data_type == p->cfg[ACC].dt
509 ? mkldnn_success : mkldnn_unimplemented, CRIT);
// The attribute object is needed only while the primitive descriptor is
// created and is destroyed right afterwards.
511 auto mkldnn_attr = create_mkldnn_attr(p->attr, p->oc, p->scales);
513 mkldnn_status_t init_status = mkldnn_success;
514 init_status = mkldnn_primitive_desc_create_v2(&cpd, &cd, mkldnn_attr,
517 mkldnn_primitive_attr_destroy(mkldnn_attr);
// "Unimplemented" is recorded in the result but is not an error for the caller.
519 if (init_status == mkldnn_unimplemented)
520 return r->state = UNIMPLEMENTED, OK;
522 SAFE(init_status, WARN);
// Honor the skip-implementation filter: release cpd and report SKIPPED.
524 const char *impl_str = query_impl_info(cpd);
525 if (maybe_skip(skip_impl, impl_str)) {
526 print(2, "SKIPPED: mkldnn implementation: %s\n", impl_str);
527 DNN_SAFE(mkldnn_primitive_desc_destroy(cpd), WARN);
528 return r->state = SKIPPED, OK;
530 print(5, "mkldnn implementation: %s\n", impl_str);
// Helper returning (by value) the memory descriptor that `cpd` reports for the
// given query and index.
533 auto q = [=](mkldnn_query_t query, int index = 0) {
534 return *mkldnn_primitive_desc_query_memory_d(
535 mkldnn_primitive_desc_query_pd(cpd, query, index));
// For AUTO, copy the algorithm kind reported by the primitive descriptor back
// into `cd`.
538 if (p->alg == AUTO) {
539 mkldnn_convolution_desc_t *temp_conv_desc = {0};
540 DNN_SAFE(mkldnn_primitive_desc_query(cpd,
541 mkldnn_query_convolution_d, 0, &temp_conv_desc), CRIT);
542 cd.alg_kind = temp_conv_desc->alg_kind;
// Replace the memory descriptors in `cd` with the ones queried from `cpd`,
// using the diff_* variants where the direction calls for them.
546 cd.diff_src_desc = q(mkldnn_query_diff_src_pd);
548 cd.src_desc = q(mkldnn_query_src_pd);
550 if (p->dir & FLAG_WEI)
551 cd.diff_weights_desc = q(mkldnn_query_diff_weights_pd);
553 cd.weights_desc = q(mkldnn_query_weights_pd);
// Bias is queried as the second (index 1) weights descriptor.
555 if (p->dir & FLAG_BIA) {
556 if (p->dir & FLAG_BWD)
557 cd.diff_bias_desc = q(mkldnn_query_diff_weights_pd, 1);
559 cd.bias_desc = q(mkldnn_query_weights_pd, 1);
562 if (p->dir & FLAG_BWD)
563 cd.diff_dst_desc = q(mkldnn_query_diff_dst_pd);
565 cd.dst_desc = q(mkldnn_query_dst_pd);
// Runs one convolution problem end to end: builds the primitive descriptor,
// allocates and fills the library and reference tensors, creates and executes
// the primitive for the requested direction, validates the output against the
// reference when bench_mode has CORR, re-executes it for timing when
// bench_mode has PERF, and finally destroys the primitive descriptor and the
// primitive.
//   p - problem to run.
//   r - result record updated by the helpers called from here.
570 int doit(const prb_t *p, res_t *r) {
574 mkldnn_convolution_desc_t cd;
575 mkldnn_primitive_desc_t cpd;
576 mkldnn_primitive_t c{};
578 SAFE(init_pd(p, cd, cpd, r), WARN);
// For AUTO and WINO a copy of the problem is made; for AUTO its algorithm is
// replaced by the one recorded in `cd`, and its cfg is re-derived via
// auto_cfg() in both cases.
// NOTE(review): how p_temp is used afterwards and where it is released is not
// visible in this view - confirm it is deleted on every return path.
580 prb_t *p_temp = nullptr;
581 if (p->alg == AUTO || p->alg == WINO) {
582 p_temp = new prb_t((desc_t)*p, p->dir, p->cfg,
583 p->alg, p->attr, p->mb);
584 if (p->alg == AUTO) p_temp->alg = alg_kind2alg(cd.alg_kind);
585 p_temp->cfg = auto_cfg(p_temp->alg, p->cfg);
// Skipped or unimplemented problems are detected here, before any tensor is
// allocated.
590 if (r->state == SKIPPED || r->state == UNIMPLEMENTED)
593 auto &src_dt_d = p->dir == BWD_D ? cd.diff_src_desc : cd.src_desc;
594 auto &wei_dt_d = p->dir & FLAG_WEI ? cd.diff_weights_desc : cd.weights_desc;
595 auto &bia_dt_d = p->dir & FLAG_BWD ? cd.diff_bias_desc : cd.bias_desc;
596 auto &dst_dt_d = p->dir & FLAG_BWD ? cd.diff_dst_desc: cd.dst_desc;
// Library-side tensors, using the descriptors selected above and the
// configured data types.
598 dnn_mem_t src_dt(src_dt_d, p->cfg[SRC].dt);
599 dnn_mem_t wei_dt(wei_dt_d, p->cfg[WEI].dt);
600 dnn_mem_t dst_dt(dst_dt_d, p->cfg[DST].dt);
// Bias memory is heap-allocated in both cases (default-constructed when the
// direction has no bias).
// NOTE(review): the matching release is not visible in this view - confirm it
// also happens when a SAFE/DNN_SAFE call below returns early (presumably they
// return on failure).
601 dnn_mem_t *p_bia_dt = p->dir & FLAG_BIA
602 ? new dnn_mem_t(bia_dt_d, p->cfg[BIA].dt) : new dnn_mem_t();
603 dnn_mem_t &bia_dt = *p_bia_dt;
// Default formats used for the f32 reference tensors.
605 auto src_format = get_default_format(src_dt.md_.ndims, DATA);
606 auto wei_format = get_default_format(wei_dt.md_.ndims,
607 p->has_groups ? GWEI : WEI);
// Reference-side tensors are always f32; dst shares the source data format.
609 const auto fp = mkldnn_f32;
610 dnn_mem_t src_fp(src_dt_d, fp, src_format);
611 dnn_mem_t wei_fp(wei_dt_d, fp, wei_format);
612 dnn_mem_t dst_fp(dst_dt_d, fp, src_format);
613 dnn_mem_t *p_bia_fp = p->dir & FLAG_BIA
614 ? new dnn_mem_t(bia_dt_d, fp, mkldnn_x) : new dnn_mem_t();
615 dnn_mem_t &bia_fp = *p_bia_fp;
// Populate every tensor; bias only when the direction includes it.
617 SAFE(fill_src(p, src_dt, src_fp, r), WARN);
618 SAFE(fill_wei(p, wei_dt, wei_fp, r), WARN);
619 SAFE(fill_dst(p, dst_dt, dst_fp, r), WARN);
620 if (p->dir & FLAG_BIA)
621 SAFE(fill_bia(p, bia_dt, bia_fp, r), WARN);
// Forward: inputs are src, weights and optional bias; output is dst.
623 if (p->dir & FLAG_FWD) {
624 mkldnn_primitive_at_t inputs[3] = { {src_dt.p_, 0}, {wei_dt.p_, 0},
625 {p->dir & FLAG_BIA ? bia_dt.p_ : NULL, 0}
627 const_mkldnn_primitive_t outputs[] = { dst_dt.p_ };
628 DNN_SAFE(mkldnn_primitive_create(&c, cpd, inputs, outputs), WARN);
629 SAFE(execute(c), WARN);
630 if (bench_mode & CORR) {
631 compute_ref_fwd(p, src_fp, wei_fp, bia_fp, dst_fp);
632 dnn_mem_t dst(dst_dt, fp, src_format);
633 SAFE(compare_dst(p, dst, dst_fp, r, true), WARN);
// Backward by data: inputs are diff dst and weights; output is diff src.
635 } else if (p->dir == BWD_D) {
636 mkldnn_primitive_at_t inputs[3] = { {dst_dt.p_, 0}, {wei_dt.p_, 0}, };
637 const_mkldnn_primitive_t outputs[] = { src_dt.p_ };
638 DNN_SAFE(mkldnn_primitive_create(&c, cpd, inputs, outputs), WARN);
639 SAFE(execute(c), WARN);
640 if (bench_mode & CORR) {
641 compute_ref_bwd_d(p, src_fp, wei_fp, bia_fp, dst_fp);
642 dnn_mem_t src(src_dt, fp, src_format);
643 SAFE(compare_src(p, src, src_fp, r, true), WARN);
// Backward by weights: inputs are src and diff dst; outputs are diff weights
// and, when the direction has bias, diff bias.
645 } else if (p->dir & FLAG_BWD && p->dir & FLAG_WEI) {
646 mkldnn_primitive_at_t inputs[3] = { {src_dt.p_, 0}, {dst_dt.p_, 0}, };
647 const_mkldnn_primitive_t outputs[] = { wei_dt.p_,
648 p->dir & FLAG_BIA ? bia_dt.p_ : NULL,
650 DNN_SAFE(mkldnn_primitive_create(&c, cpd, inputs, outputs), WARN);
651 SAFE(execute(c), WARN);
652 if (bench_mode & CORR) {
653 compute_ref_bwd_w(p, src_fp, wei_fp, bia_fp, dst_fp);
654 dnn_mem_t wei(wei_dt, fp, wei_format);
655 SAFE(compare_wei(p, wei, wei_fp, r, true), WARN);
656 if (p->dir & FLAG_BIA) {
657 dnn_mem_t bia(bia_dt, fp, mkldnn_x);
658 SAFE(compare_bia(p, bia, bia_fp, r, true), WARN);
// Performance mode: the primitive is executed again. The stop condition is
// either a fixed number of runs (fix_times_per_prb), or, when that is zero,
// having spent at least max_ms_per_prb with at least min_times_per_prb runs.
667 if (bench_mode & PERF) {
671 SAFE(execute(c), WARN);
673 const bool stop = false
674 || (fix_times_per_prb && t.times() >= fix_times_per_prb)
675 || (!fix_times_per_prb
676 && t.total_ms() >= max_ms_per_prb
677 && t.times() >= min_times_per_prb);
// Release the library objects created above.
682 DNN_SAFE(mkldnn_primitive_desc_destroy(cpd), CRIT);
683 DNN_SAFE(mkldnn_primitive_destroy(c), CRIT);