1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
25 #include "src/common/mkldnn_thread.hpp"
27 #include "mkldnn_common.hpp"
28 #include "mkldnn_memory.hpp"
31 #include "bnorm/bnorm.hpp"
/** Fills fwd bnorm inputs (src, mean, var, scale/shift) so that the
 *  reference statistics are exactly representable in f32. */
static int prepare_fwd(const prb_t *p, dnn_mem_t &src, dnn_mem_t &mean,
        dnn_mem_t &var, dnn_mem_t &ss) {
    /** Idea: choose src[] values so that both mean and variance are computed
     * exactly (independently of the order of the computations).
     * The `exactness` is achieved via [a1]: src[i] + src[i+1] = 2 * mean.
     * The variation in src is allowed in the last flex_bits bits.
     * If the sequence (L) is too big (flex_bits <= min_flex_bits), the mean
     * value is set to 0 and src is partially filled with zeros (according to
     * density so that at least want_flex_bits is reserved for src variation.
     * Once src is set, variance is computed.
     *
     * ALG_0: mean is set to 0
     * ALG_1: mean is set to 2^p, where p \in {-2, -1, ..., 4}
     * ALG_AUTO: choose between ALG_0 and ALG_1 automatically */
    const int exact_bits = 24; // f32 significand width: ints < 2^24 are exact
    const int L = p->mb * p->id * p->ih * p->iw; // reduction length per channel
    const int logL = (int)ceilf(log2f(L));

    // sanity: logL is the smallest exponent with 2^logL >= L
    assert(logL <= 0 || (1<<(logL-1)) < L);
    assert(L <= (1<<logL));

    const int min_flex_bits = 3;  // below this, not enough variation is left
    const int want_flex_bits = 6; // desired variation bits for ALG_0

    check_alg_t alg = p->check_alg;
    if (alg == ALG_AUTO) /* choose appropriate checking algorithm */
        alg = (exact_bits - logL) / 2 - 1 >= min_flex_bits ? ALG_1 : ALG_0;

    // number of low-order bits in which src values may vary while keeping
    // the accumulated mean/variance exact
    const int flex_bits = alg == ALG_0
        ? want_flex_bits : ((exact_bits - logL) / 2 - 1);

    if (flex_bits < min_flex_bits)
        // NOTE(review): the guarded statement is not visible in this excerpt
        // (presumably `return FAIL;`) -- confirm against the full source.

    const int flex_mask = (1 << flex_bits) - 1; // mask of the variation bits

    /* density: (exact_bits - log_2(L * density)) / 2 >= flex_bits */
    const float density = alg == ALG_0
        ? 1.f * (1 << (exact_bits - 2 * flex_bits)) / L : 1.f;
    assert((exact_bits - ceilf(log2f(L * density))) / 2 >= flex_bits);

    print(6, "check_alg: %s, density = %g, flex_bits = %d\n",
            check_alg2str(alg), density, flex_bits);

    // Per-channel fill: mean is fixed up-front, src is generated pairwise
    // around it (invariant [a1]), then the exact variance is accumulated.
    mkldnn::impl::parallel_nd(p->ic, [&](int c) {
        const float m = ((float *)mean)[c] =
            alg == ALG_0 ? 0.f : 0.25f * (1 << (c % 7));
        float v = 0; /* current variance */

        for (int mb = 0; mb < p->mb; ++mb) {
            size_t l_base = mb * p->id * p->ih * p->iw + c * 239 * 2; // l[0] must be even
            float *s = (float *)src + data_off(p, mb, c, 0, 0, 0);

            for (int d = 0; d < p->id; ++d)
            for (int h = 0; h < p->ih; ++h)
            for (int w = 0; w < p->iw; ++w) {
                const int sp = d * p->ih * p->iw + h * p->iw + w;
                const size_t l = l_base + sp; // linear index driving the PRNG

                // ALG_0 sparsification per `density` (the skip path is
                // elided from this excerpt -- presumably zero-fill).
                if (alg == ALG_0 && !flip_coin(l/2 * 257ULL, density)) {

                // gen: flex_bits-wide pseudo-random magnitude; sgn alternates
                // per element so consecutive pairs sum to 2*m ([a1]).
                const size_t gen = (l / 2 * 1637) & flex_mask;
                const int sgn = l % 2 == 0 ? 1 : -1; /* [a1] */
                const float f = 1.f * sgn * gen / (1 << flex_bits);

                s[sp] = alg == ALG_0 ? f : m * (1.f + f);
                // odd L: the final element has no pair partner and needs the
                // special handling (body elided from this excerpt).
                if (L % 2 && (mb * p->id * p->ih * p->iw + sp == L - 1)) {
                v += (s[sp] - m) * (s[sp] - m);

        // exact variance: mean of squared deviations over the whole reduction
        ((float *)var)[c] = v / (p->mb * p->id * p->ih * p->iw);

        if (p->flags & USE_SCALESHIFT) {
            ((float *)ss)[c] = 1.f / 8 * (1 << (c % 7));
            ((float *)ss)[p->ic + c] = ((c % 3) - 1) * ((float *)ss)[c] / 64;
            // NOTE(review): an `} else {` appears elided here; the following
            // two assignments are the identity scale/shift defaults.
            ((float *)ss)[c] = 1;
            ((float *)ss)[p->ic + c] = 0;
/** @brief L = 2^k * P, P % 2 != 0
 *  Factors L into a power-of-two exponent k and an odd remainder P
 *  (both returned through the reference out-parameters). */
static void decompose2(int L, int &k, int &P) {
    // NOTE(review): the initialization of P (presumably `P = L;`) and the
    // loop body (halving P each iteration) are elided from this excerpt;
    // as shown, P is read uninitialized -- confirm against the full source.
    for (k = 0; P % 2 == 0; ++k)
/** Fills bwd bnorm inputs so the reference d_beta/d_gamma reductions hit
 *  exactly-representable targets; returns FAIL when L is too large ([r1]). */
static int prepare_bwd(const prb_t *p, dnn_mem_t &src, dnn_mem_t &d_dst,
        dnn_mem_t &mean, dnn_mem_t &var, dnn_mem_t &ss, dnn_mem_t &mask) {
    const int exact_bits = 24; // f32 significand width: ints < 2^24 are exact

    const int L = p->mb * p->id * p->ih * p->iw; // reduction length

    /** Stabilization idea...
     * d_src = func(d_beta / L, d_gamma' / L, ...)
     * try to make d_beta = L / 2^t_beta and d_gamma' = L / 2^t_gamma,
     * where both t_beta and t_gamma are in {1, .., max_k}.
     * Currently, with no obvious reason, max_k is set to 4 for
     * reasonably small problems and to 8 for big problems.
     *
     * Here d_gamma' = d_gamma / sqrt(var + eps).
     * We might hope that division by L would be exact in that case,
     * but that might happen iff L is less than 2^exact_bits, hence
     * restriction [r1]. */

    // NOTE(review): k and P presumably come from `decompose2(L, k, P)` on a
    // line elided from this excerpt -- confirm against the full source.
    int log2P = (int)ceilf(log2f(P));
    if (log2P >= exact_bits)
        return FAIL; /* [r1] */

    const int max_k = L > (1<<20) ? 8 : 4;
    if (k > max_k && exact_bits - log2P > max_k + 4) {
        // fold the excess power of two back into P (remainder of this
        // branch is elided from this excerpt)
        log2P += (k - max_k);

    // generators/factors driving the pseudo-random d_dst and src values
    const int param_dd_p2 = 7; // factor_dd <- 2^{0, .., -param_db_p2+1}
    const int param_dd_gen = 32; // gen_dd <- {1, .., param_dd_gen}

    const int param_f_p2 = 1; // factor_f <- 2^{-param_dg_p2}
    const int param_f_gen = 16; // gen_f <- {2..param_s_gen}

    // upper bounds on |d_gamma| and |d_beta|; density keeps the running sums
    // inside the exactly-representable integer range
    const float ub_dg = param_dd_gen * param_f_gen / 2 * L;
    const float ub_db = param_dd_gen * L;
    const float density = MIN3(1.f, (1<<exact_bits) / ub_dg,
            (1<<exact_bits) / ub_db);

    print(5, "prep_bwd: k:%d, P:%d log2P:%d, density = %g\n",
            k, P, log2P, density);

    mkldnn::impl::parallel_nd(p->ic, [&](int c) {
        const float m = ((float *)mean)[c] = c % 2; // mean alternates 0/1

        /* var + eps \in {1/4, 1, 4} */
        const float ve_denom = 4.f / (1 << 2 * (c % 3));
        ((float *)var)[c] = ve_denom - p->eps;

        const int dd_p2 = (c * 127 % param_dd_p2);
        const float factor_dd = 1.f / (1 << dd_p2);

        const int f_p2 = 1 + (c % param_f_p2);
        const float factor_f = 1.f / (1 << f_p2);

        // targets the generated data must reach exactly (see header comment)
        const float target_db = factor_dd * P;
        const float target_dg = ve_denom * 2 * target_db;

        float dg = 0, db = 0; /* current d_beta and d_gamma */
        for (int mb = 0; mb < p->mb; ++mb) {
            const int l_base = mb * p->id * p->ih * p->iw;

            const auto off = data_off(p, mb, c, 0, 0, 0);
            float *s = (float *)src + off;
            float *dd = (float *)d_dst + off;
            float *rmask = (float *)mask + off;

            for (int d = 0; d < p->id; ++d)
            for (int h = 0; h < p->ih; ++h)
            for (int w = 0; w < p->iw; ++w) {
                const int sp = d * p->ih * p->iw + h * p->iw + w;
                // sparsification per `density` (the skip path is elided
                // from this excerpt -- presumably zero-fill + continue).
                if (!flip_coin(l_base + sp, density) && l_base + sp + 100 < L) {

                if (l_base + sp + 2 >= L) continue; /* last 2 are special */
                const int l = l_base + sp * 7 + c * 19 + mb * 13; // PRNG seed

                // ReLU mask: mostly 1s, occasional 0s, pattern driven by l
                if (p->flags & FUSE_BN_RELU)
                    rmask[sp] = rmask_v = l % 5 != 1;

                // steer the running d_beta toward target_db
                const int sgn_dd = db < target_db ? 1 : -1;
                dd[sp] = sgn_dd * factor_dd * (1 + (l * 3 % param_dd_gen));
                if (rmask_v) db += dd[sp];

                // steer the running d_gamma toward target_dg
                // NOTE(review): the head of this expression (presumably
                // `const float f = s[sp] = ...`) is elided from this excerpt.
                const int sgn_f = dg < target_dg ? 1 : -1;
                    sgn_f * factor_f * (2 + (l * 7 % (param_f_gen - 1)));

                if (rmask_v) dg += f * dd[sp];

        /* the last 2 elements in src and d_dst are set, so that:
         * For this we need to solve the system:
         *     d_dst[l1]           + d_dst[l0]           = target_db - db
         *     d_dst[l1] * src[l1] + d_dst[l0] * src[l0] = target_dg - dg
         *
         * Here l0 -- last index, l1 -- last but one.
         * More over, let's assume src[l1] = 1 and src[l0] = -1. */
        size_t l0 = data_off(p, p->mb - 1, c, p->id - 1, p->ih - 1,
        // NOTE(review): the tail of the l0 computation and the default l1
        // computation appear elided from this excerpt.
        if (p->id == 1 && p->ih == 1 && p->iw == 1)
            l1 = data_off(p, p->mb - 2, c, p->id - 1, p->ih - 1, p->iw - 1);

        ((float *)src)[l1] = 1.f;
        ((float *)src)[l0] = -1.f;
        if (p->flags & FUSE_BN_RELU)
            ((float *)mask)[l0] = ((float *)mask)[l1] = 1;

        // closed-form solution of the 2x2 system above (src[l1]=1, src[l0]=-1)
        float f1 = ((target_db - db) + (target_dg - dg)) /2;
        float f0 = ((target_db - db) - (target_dg - dg)) /2;

        ((float *)d_dst)[l1] = f1 + m;
        ((float *)d_dst)[l0] = f0 + m;

        if (p->flags & USE_SCALESHIFT) {
            ((float *)ss)[c] = 1.f / 2 * (1 << (c % 7));
            ((float *)ss)[p->ic + c] = ((float *)ss)[c] / 64;
            // NOTE(review): an `} else {` appears elided here; the following
            // two assignments are the identity scale/shift defaults.
            ((float *)ss)[c] = 1;
            ((float *)ss)[p->ic + c] = 0;
/** Compares the f32 reference (fp_mem) against the library result (dt_mem)
 *  for one data kind, updating error counters and pass/fail state in r. */
static int compare(const prb_t *p, data_kind_t kind, const dnn_mem_t &fp_mem,
        const dnn_mem_t &dt_mem, res_t *r) {
    const char *skind = data_kind2str(kind);
    // only DATA gets a nonzero tolerance; statistics must match exactly
    const float eps = p->dir & FLAG_FWD
        ? (kind == DATA ? 5e-7 : 0)
        : (kind == DATA ? 2e-7 : 0);

    /* With all the stability tricks bwd_d is still pretty unstable.
     * So let's rely on relative error in L1, L2, and L_inf norms.
     * TODO: make computations for bwd_d more stable and use `L0` here. */
    const bool rely_on_norm = false
        // NOTE(review): `(p->flags | GLOB_STATS)` is always non-zero, making
        // this true for every DATA/BWD case regardless of flags; `&` was most
        // likely intended -- confirm against the full source / upstream fix.
        || (kind == DATA && (p->dir & FLAG_BWD) && (p->flags | GLOB_STATS));

    const size_t nelems = kind == DATA
        ? (size_t)p->mb * p->ic * p->id * p->ih * p->iw
        : (size_t)p->ic * (kind == SS ? 2 : 1);
    r->total += rely_on_norm ? 1 : nelems;

    diff_norm_t diff_norm;

    for (size_t i = 0; i < nelems; ++i) {
        const float fp = ((const float *)fp_mem)[i];
        const float dt = ((const float *)dt_mem)[i];
        diff_norm.update(fp, dt);

        // point-wise check: relative error for large values, absolute for
        // values near zero (lines between here and above appear elided)
        const float diff = fabsf(fp - dt);
        const float rel_diff = diff / (fabsf(fp) > FLT_MIN ? fabsf(fp) : 1);
        const bool ok = (fabs(fp) > 1e-5 ? rel_diff : diff) <= eps;

        // dump condition: first few mismatches, or leading elements at high
        // verbosity (the head of this expression is elided from this excerpt)
            || (!ok && (r->errors < 10 || verbose >= 10))
            || (verbose >= 50 && i < 30);

        const int ind_str_len = 32;
        char ind_str[ind_str_len] = {'\0'};
        // decode the flat index into logical coordinates for the message
        // (the `if (kind == DATA) {` head appears elided from this excerpt)
            inv_data_off(p, i, mb, c, d, h, w);
            snprintf(ind_str, ind_str_len, "%d,%d,%d,%d,%d", mb, c, d,h, w);
        } else if (kind == SS) {
            snprintf(ind_str, ind_str_len, "%d,%d",
                    (int)i / p->ic, (int)i % p->ic);
            // (the `} else {` for the scalar-index case appears elided)
            snprintf(ind_str, ind_str_len, "%d", (int)i);

        print(0, "[%lu][%s%s][%s] fp:%8g dt:%8g diff:%8g rdiff:%8g\n",
                (unsigned long)i, p->dir & FLAG_BWD ? "D_" : "", skind,
                ind_str, fp, dt, diff, rel_diff);

    // norm-based acceptance: any relative norm above eps fails
    // (the head of this expression is elided from this excerpt)
        || diff_norm.rel_diff(norm_t::L1) > eps
        || diff_norm.rel_diff(norm_t::L2) > eps
        || diff_norm.rel_diff(norm_t::L8) > eps;

    if (r->errors || verbose >= 5) {
        const int vl = r->errors ? 0 : 2;
        // full norm report: (a, b, |a-b|, rel) for L1/L2/L8 plus L0
        print(vl, "@@@ [%s%s] diff: l0(``%g``) "
                "l1:(%g,%g,%g,``%g``) "
                "l2:(%g,%g,%g,``%g``) "
                "l8:(%g,%g,%g,``%g``)\n",
                p->dir & FLAG_BWD ? "D_" : "", skind,
                diff_norm.rel_diff(norm_t::L0),
                diff_norm.a_[norm_t::L1], diff_norm.b_[norm_t::L1],
                diff_norm.diff_[norm_t::L1], diff_norm.rel_diff(norm_t::L1),
                diff_norm.a_[norm_t::L2], diff_norm.b_[norm_t::L2],
                diff_norm.diff_[norm_t::L2], diff_norm.rel_diff(norm_t::L2),
                diff_norm.a_[norm_t::L8], diff_norm.b_[norm_t::L8],
                diff_norm.diff_[norm_t::L8], diff_norm.rel_diff(norm_t::L8));

    if (r->state == UNTESTED)
        r->state = PASSED; /* optimism */

    return r->state == FAILED ? FAIL : OK;
/** Checks the forward workspace (fused-ReLU mask) against the forward data:
 *  each workspace flag must equal (data element > 0). */
int check_fwd_ws(const dnn_mem_t &data_dt, const dnn_mem_t &ws_dt, res_t *r) {
    /* so far we know ws is just a bit-mask of whether the value was
     * negative or positive */
    const size_t nelems = data_dt.nelems(true);
    const float *d = (const float *)data_dt;
    const uint8_t *ws = (const uint8_t *)ws_dt;

    /* some internal knowledge: flags in ws are either stored as bytes (e.g.
     * for the ref implementation) or as bits (e.g. for the jitted one); in
     * the latter case the ws memory has fewer elements than the data memory */
    enum { ws_byte, ws_bit } ws_type;
    ws_type = ws_dt.nelems(true) < nelems ? ws_bit : ws_byte;

    /* more internal knowledge: data_dt and ws_dt are expected to have exactly
     * the same data layout, and data_dt padded regions are expected to be
     * zero, and the respective ws_dt elements should be set accordingly */
    for (size_t i = 0; i < nelems; i += 8) {
        for (size_t j = 0; j < MIN2(8, nelems - i); ++j) {
            const bool want = *d > 0;
            // byte layout: whole byte is the flag; bit layout: bit j of *ws
            const bool bit_set = ws_type == ws_byte ? *ws : !!(*ws & (1<<j));

            const bool ok = bit_set == want;

            // dump condition: first mismatches / leading elements at high
            // verbosity (head of this expression is elided in this excerpt)
                || (!ok && (r->errors < 10 || verbose >= 10))
                || (verbose >= 50 && i < 30);

            print(0, "[%lu] ws exp:%d got:%d (data:%g:%a)\n",
                    (unsigned long)(i + j), want, bit_set, *d, *d);

            // byte layout consumes one ws byte per data element
            if (ws_type == ws_byte) ++ws;
        // bit layout consumes one ws byte per 8 data elements
        if (ws_type == ws_bit) ++ws;

    if (r->state == UNTESTED)
        r->state = PASSED; /* optimism */

    return r->state == FAILED ? FAIL : OK;
/** Creates the bnorm operation descriptor (bd) and primitive descriptor
 *  (bpd) for problem p. Returns OK with r->state set to UNIMPLEMENTED or
 *  SKIPPED when the library cannot / should not run this case. */
static int init_pd(const prb_t *p, mkldnn_batch_normalization_desc_t &bd,
        mkldnn_primitive_desc_t &bpd, res_t *r) {
    mkldnn_memory_desc_t data_d;
    mkldnn_dims_t data_dims = {p->mb, p->ic, p->ih, p->iw};
    mkldnn_dims_t data_dims_3d = {p->mb, p->ic, p->id, p->ih, p->iw};
    // 5D descriptor for 3D problems, 4D otherwise
    DNN_SAFE(mkldnn_memory_desc_init(&data_d, is_bnorm_3d(p) ? 5 : 4,
            is_bnorm_3d(p) ? data_dims_3d : data_dims, p->dt, p->fmt), WARN);

    auto flags = (mkldnn_batch_normalization_flag_t)p->flags;
    if (p->dir & FLAG_FWD) {
        auto prop = p->dir & FLAG_INF
            ? mkldnn_forward_inference : mkldnn_forward_training;
        DNN_SAFE(mkldnn_batch_normalization_forward_desc_init(&bd, prop,
                &data_d, p->eps, flags), WARN);
        // NOTE(review): the `} else {` introducing the backward branch
        // appears elided from this excerpt.
        auto prop = p->dir & FLAG_WEI
            ? mkldnn_backward : mkldnn_backward_data;
        DNN_SAFE(mkldnn_batch_normalization_backward_desc_init(&bd, prop,
                &data_d, &data_d, p->eps, flags), WARN);

    auto mkldnn_attr = create_mkldnn_attr(p->attr, 1, NULL);

    // backward primitive descriptors require a forward "hint" pd
    mkldnn_primitive_desc_t hint_fwd_pd = NULL;
    if (p->dir & FLAG_BWD) {
        mkldnn_batch_normalization_desc_t bd_fwd;
        DNN_SAFE(mkldnn_batch_normalization_forward_desc_init(&bd_fwd,
                mkldnn_forward_training, &data_d, p->eps, flags), WARN);
        DNN_SAFE(mkldnn_primitive_desc_create_v2(&hint_fwd_pd, &bd_fwd, NULL,
                engine, NULL), WARN);

    mkldnn_status_t init_status = mkldnn_primitive_desc_create_v2(&bpd, &bd,
            mkldnn_attr, engine, hint_fwd_pd);

    // hint and attr are only needed for creation; release unconditionally
    mkldnn_primitive_desc_destroy(hint_fwd_pd);
    mkldnn_primitive_attr_destroy(mkldnn_attr);

    if (init_status == mkldnn_unimplemented)
        return r->state = UNIMPLEMENTED, OK;

    SAFE(init_status, WARN);

    const char *impl_str = query_impl_info(bpd);
    if (maybe_skip(skip_impl, impl_str)) {
        print(2, "SKIPPED: mkldnn implementation: %s\n", impl_str);
        DNN_SAFE(mkldnn_primitive_desc_destroy(bpd), WARN);
        return r->state = SKIPPED, OK;
    // warn about non-jit (reference) implementations: their accuracy is
    // compiler-dependent, so failures may be false positives
    print(5, "mkldnn implementation: %s\n", impl_str);
    if (!strstr(impl_str, "jit")) {
        print(1, "WARNING: %s",
                "accuracy of the implementation being tested "
                "depends on the compiler and might give false-positives.\n");
        // NOTE(review): the head of the second warning's print call appears
        // elided from this excerpt.
                "please consider recompiling the sources with"
                " `-prec-div -fp-model precise` for a reliable testing.\n");
/** converts benchdnn-understandable mask of {0, 1} to workspace */
static int cvt_mask_to_ws(const prb_t *p, const dnn_mem_t &mask_fp,
    // NOTE(review): the trailing parameter (presumably `dnn_mem_t &ws_dt`,
    // which is written below) and the opening brace appear elided from this
    // excerpt -- confirm against the full source.
    mkldnn_dims_t data_dims = {p->mb, p->ic, p->ih, p->iw};
    mkldnn_dims_t data_dims_3d = {p->mb, p->ic, p->id, p->ih, p->iw};

    // copy the {0,1} mask into a data tensor in the problem's layout
    dnn_mem_t data(is_bnorm_3d(p) ? 5 : 4,
            is_bnorm_3d(p) ? data_dims_3d : data_dims, mkldnn_f32, p->fmt);
    SAFE(data.reorder(mask_fp), WARN);

    // trivial statistics (mean 0.5, var 1): with global stats, bnorm maps
    // mask value 0 below zero and 1 above zero, so fused ReLU sets the ws
    ptrdiff_t ic = p->ic;
    dnn_mem_t mean(1, &ic, mkldnn_f32, mkldnn_x);
    dnn_mem_t var(1, &ic, mkldnn_f32, mkldnn_x);
    for (int c = 0; c < p->ic; ++c) ((float *)mean)[c] = 0.5;
    for (int c = 0; c < p->ic; ++c) ((float *)var)[c] = 1;

    // run a fwd-training bnorm with fuse_bn_relu so the library itself emits
    // the workspace in its native (byte or bit) layout
    mkldnn_batch_normalization_desc_t bd;
    auto flags = (mkldnn_batch_normalization_flag_t)
        (mkldnn_use_global_stats | mkldnn_fuse_bn_relu);
    DNN_SAFE(mkldnn_batch_normalization_forward_desc_init(&bd,
            mkldnn_forward_training, &data.md_, 0, flags), WARN);

    mkldnn_primitive_desc_t bpd;
    // NOTE(review): the closing `WARN);` of this DNN_SAFE appears elided.
    DNN_SAFE(mkldnn_primitive_desc_create_v2(&bpd, &bd, NULL, engine, NULL),

    mkldnn_primitive_t b{};
    mkldnn_primitive_at_t inputs[3] = {
        {data.p_, 0}, {mean.p_, 0}, {var.p_, 0}};
    const_mkldnn_primitive_t outputs[2] = {data.p_, ws_dt.p_};
    DNN_SAFE(mkldnn_primitive_create(&b, bpd, inputs, outputs), WARN);
    SAFE(execute(b), WARN);

    DNN_SAFE(mkldnn_primitive_desc_destroy(bpd), CRIT);
    DNN_SAFE(mkldnn_primitive_destroy(b), CRIT);
/** Top-level driver for one bnorm problem: builds the primitive, prepares
 *  input data, executes, and (in CORR mode) compares against the reference;
 *  in PERF mode re-executes until the timing budget is exhausted. */
int doit(const prb_t *p, res_t *r) {
    mkldnn_batch_normalization_desc_t bd;
    mkldnn_primitive_desc_t bpd;
    mkldnn_primitive_t b{};

    SAFE(init_pd(p, bd, bpd, r), WARN);
    if (r->state == SKIPPED || r->state == UNIMPLEMENTED)
        // NOTE(review): the early-return statement (presumably `return OK;`)
        // appears elided from this excerpt.

    const auto fp = mkldnn_f32; // reference side always computes in f32
    auto &data_dt_d = bd.data_desc;

    const mkldnn_dims_t dims1d = {p->ic};    // mean / variance
    const mkldnn_dims_t dims2d = {2, p->ic}; // scale + shift
    const auto src_format = is_bnorm_3d(p) ? mkldnn_ncdhw : mkldnn_nchw;

    // _fp tensors: plain-layout reference buffers; _dt: library-layout
    // NOTE(review): several declarator continuations (e.g. `data_dt(...)`,
    // `var_dt(...)`, `ss_dt(...)`) appear elided from this excerpt.
    dnn_mem_t data_fp(data_dt_d, fp, src_format),
    dnn_mem_t d_data_fp(data_dt_d, fp, src_format),
            d_data_dt(data_dt_d);

    dnn_mem_t mean_fp(1, dims1d, fp, mkldnn_x),
            mean_dt(mean_fp.md_);
    dnn_mem_t var_fp(1, dims1d, fp, mkldnn_x),

    dnn_mem_t ss_fp(2, dims2d, fp, mkldnn_nc),
    dnn_mem_t d_ss_fp(2, dims2d, fp, mkldnn_nc),
            d_ss_dt(d_ss_fp.md_);

    // workspace: query the pd for the library's native ws layout when the
    // fused ReLU actually needs one; otherwise use an empty placeholder
    dnn_mem_t ws_fp(data_fp.md_);
    dnn_mem_t *p_ws_dt = NULL;
    if ((p->flags & FUSE_BN_RELU) && !(p->dir & FLAG_INF)) {
        const auto ws_pd = mkldnn_primitive_desc_query_pd(bpd,
                mkldnn_query_workspace_pd, 0);
        SAFE(ws_pd != NULL ? OK : FAIL, WARN);
        p_ws_dt = new dnn_mem_t(*mkldnn_primitive_desc_query_memory_d(ws_pd));
        // NOTE(review): the `} else {` appears elided; the next line is the
        // no-workspace placeholder.
        p_ws_dt = new dnn_mem_t();
    dnn_mem_t &ws_dt = *p_ws_dt;

    if (p->dir & FLAG_FWD) {
        // MISTRUSTED: data could not be generated with exact statistics
        if (prepare_fwd(p, data_fp, mean_fp, var_fp, ss_fp) != OK)
            return r->state = MISTRUSTED, OK;

        mkldnn_primitive_at_t inputs[4];
        const_mkldnn_primitive_t outputs[4];

        SAFE(data_dt.reorder(data_fp), WARN);
        inputs[idx++] = {data_dt.p_, 0};

        // with global stats, mean/var are inputs; otherwise they are outputs
        if (p->flags & GLOB_STATS) {
            SAFE(mean_dt.reorder(mean_fp), WARN);
            SAFE(var_dt.reorder(var_fp), WARN);
            inputs[idx++] = {mean_dt.p_, 0};
            inputs[idx++] = {var_dt.p_, 0};

        if (p->flags & USE_SCALESHIFT) {
            SAFE(ss_dt.reorder(ss_fp), WARN);
            inputs[idx++] = {ss_dt.p_, 0};

        outputs[idx++] = data_dt.p_; /* always in-place so far... */
        if (!(p->flags & GLOB_STATS)) {
            outputs[idx++] = mean_dt.p_;
            outputs[idx++] = var_dt.p_;

        if (p->flags & FUSE_BN_RELU)
            outputs[idx++] = ws_dt.p_;

        DNN_SAFE(mkldnn_primitive_create(&b, bpd, inputs, outputs), WARN);
        SAFE(execute(b), WARN);
        if (bench_mode & CORR) {
            compute_ref_fwd(p, data_fp, mean_fp, var_fp, ss_fp, data_fp);
            // computed statistics are only checked when the primitive
            // produced them (training, no global stats)
            if (!(p->flags & GLOB_STATS) && !(p->dir & FLAG_INF)) {
                SAFE(compare(p, MEAN, mean_fp, mean_dt, r), WARN);
                SAFE(compare(p, VAR, var_fp, var_dt, r), WARN);
            // reorder the dt result to plain layout before comparing
            dnn_mem_t data(data_dt, fp, src_format);
            SAFE(compare(p, DATA, data_fp, data, r), WARN);
            if ((p->flags & FUSE_BN_RELU) && !(p->dir & FLAG_INF))
                SAFE(check_fwd_ws(data_dt, ws_dt, r), WARN);
        // NOTE(review): the `} else {` introducing the backward path appears
        // elided from this excerpt (as does the `!= OK)` tail below).
        if (prepare_bwd(p, data_fp, d_data_fp, mean_fp, var_fp, ss_fp, ws_fp)
            return r->state = MISTRUSTED, OK;

        mkldnn_primitive_at_t inputs[6];
        const_mkldnn_primitive_t outputs[2];

        SAFE(data_dt.reorder(data_fp), WARN);
        inputs[idx++] = {data_dt.p_, 0};

        // bwd always consumes the statistics
        SAFE(mean_dt.reorder(mean_fp), WARN);
        SAFE(var_dt.reorder(var_fp), WARN);
        inputs[idx++] = {mean_dt.p_, 0};
        inputs[idx++] = {var_dt.p_, 0};

        SAFE(d_data_dt.reorder(d_data_fp), WARN);
        inputs[idx++] = {d_data_dt.p_, 0};

        if (p->flags & USE_SCALESHIFT) {
            SAFE(ss_dt.reorder(ss_fp), WARN);
            inputs[idx++] = {ss_dt.p_, 0};

        if (p->flags & FUSE_BN_RELU) {
            // convert the fp {0,1} mask into the library's native workspace
            SAFE(cvt_mask_to_ws(p, ws_fp, ws_dt), WARN);
            inputs[idx++] = {ws_dt.p_, 0};

        outputs[idx++] = d_data_dt.p_; /* always in-place so far... */
        if ((p->flags & USE_SCALESHIFT) && (p->dir & FLAG_WEI))
            outputs[idx++] = d_ss_dt.p_;

        DNN_SAFE(mkldnn_primitive_create(&b, bpd, inputs, outputs), WARN);
        SAFE(execute(b), WARN);
        if (bench_mode & CORR) {
            compute_ref_bwd(p, data_fp, mean_fp, var_fp, d_data_fp, ss_fp,
                    ws_fp, d_data_fp, d_ss_fp);
            if ((p->flags & USE_SCALESHIFT) && (p->dir & FLAG_WEI))
                SAFE(compare(p, SS, d_ss_fp, d_ss_dt, r), WARN);
            // reorder the dt gradient to plain layout before comparing
            dnn_mem_t d_data(d_data_dt, fp,
                    is_bnorm_3d(p) ? mkldnn_ncdhw : mkldnn_nchw);
            SAFE(compare(p, DATA, d_data_fp, d_data, r), WARN);

    // PERF mode: keep re-executing until the fixed-iteration or the
    // time+min-iteration budget is reached (timer loop partially elided)
    if (bench_mode & PERF) {
        SAFE(execute(b), WARN);

        const bool stop = false
            || (fix_times_per_prb && t.times() >= fix_times_per_prb)
            || (!fix_times_per_prb
                    && t.total_ms() >= max_ms_per_prb
                    && t.times() >= min_times_per_prb);

    DNN_SAFE(mkldnn_primitive_desc_destroy(bpd), CRIT);
    DNN_SAFE(mkldnn_primitive_destroy(b), CRIT);