Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / bnorm / bnorm.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <stdlib.h>
18 #include <stddef.h>
19 #include <stdio.h>
20 #include <float.h>
21 #include <math.h>
22
23 #include "mkldnn.h"
24
25 #include "src/common/mkldnn_thread.hpp"
26
27 #include "mkldnn_common.hpp"
28 #include "mkldnn_memory.hpp"
29 #include "norm.hpp"
30
31 #include "bnorm/bnorm.hpp"
32
33 namespace bnorm {
34
35 static int prepare_fwd(const prb_t *p, dnn_mem_t &src, dnn_mem_t &mean,
36         dnn_mem_t &var, dnn_mem_t &ss) {
37     /** Idea: choose src[] values so that both mean and variance are computed
38      * exactly (independently of the order of the computations).
39      *
40      * The `exactness` is achieved via [a1]: src[i] + src[i+1] = 2 * mean.
41      *
42      * The variation in src is allowed in the last flex_bits bits.
43      * If the sequence (L) is too big (flex_bits <= min_flex_bits), the mean
44      * value is set to 0 and src is partially filled with zeros (according to
45      * density so that at least want_flex_bits is reserved for src variation.
46      * Once src is set, variance is computed.
47      *
48      * ALG_0: mean is set to 0
49      * ALG_1: mean is set to 2^p, where p \in {-2, -1, ..., 4}
50      * ALG_AUTO: choose between ALG_0 and ALG_1 automatically */
51     const int exact_bits = 24;
52     const int L = p->mb * p->id * p->ih * p->iw;
53     const int logL = (int)ceilf(log2f(L));
54
55     assert(logL <= 0 || (1<<(logL-1)) < L);
56     assert(L <= (1<<logL));
57
58     const int min_flex_bits = 3;
59     const int want_flex_bits = 6;
60
61     check_alg_t alg = p->check_alg;
62     if (alg == ALG_AUTO) /* choose appropriate checking algorithm */
63         alg = (exact_bits - logL) / 2 - 1 >= min_flex_bits ? ALG_1 : ALG_0;
64
65     const int flex_bits = alg == ALG_0
66         ? want_flex_bits : ((exact_bits - logL) / 2 - 1);
67
68     if (flex_bits < min_flex_bits)
69         return FAIL;
70
71     const int flex_mask = (1 << flex_bits) - 1;
72
73     /* density: (exact_bits - log_2(L * density)) / 2 >= flex_bits */
74     const float density = alg == ALG_0
75         ? 1.f * (1 << (exact_bits - 2 * flex_bits)) / L : 1.f;
76     assert((exact_bits - ceilf(log2f(L * density))) / 2 >= flex_bits);
77
78     print(6, "check_alg: %s, density = %g, flex_bits = %d\n",
79             check_alg2str(alg), density, flex_bits);
80
81     mkldnn::impl::parallel_nd(p->ic, [&](int c) {
82         const float m = ((float *)mean)[c] =
83             alg == ALG_0 ? 0.f : 0.25f * (1 << (c % 7));
84         float v = 0; /* current variance */
85
86         for (int mb = 0; mb < p->mb; ++mb) {
87             size_t l_base = mb * p->id * p->ih * p->iw + c * 239 * 2; // l[0] must be even
88             float *s = (float *)src + data_off(p, mb, c, 0, 0, 0);
89
90             for (int d = 0; d < p->id; ++d)
91             for (int h = 0; h < p->ih; ++h)
92             for (int w = 0; w < p->iw; ++w) {
93
94                 const int sp = d * p->ih * p->iw + h * p->iw + w;
95                 const size_t l = l_base + sp;
96
97                 if (alg == ALG_0 && !flip_coin(l/2 * 257ULL, density)) {
98                     s[sp] = 0;
99                     continue;
100                 }
101
102                 const size_t gen = (l / 2 * 1637) & flex_mask;
103                 const int sgn = l % 2 == 0 ? 1 : -1; /* [a1] */
104                 const float f = 1.f * sgn * gen / (1 << flex_bits);
105
106                 s[sp] = alg == ALG_0 ? f : m * (1.f + f);
107                 if (L % 2 && (mb * p->id * p->ih * p->iw + sp == L - 1)) {
108                     s[sp] = m;
109                 }
110                 v += (s[sp] - m) * (s[sp] - m);
111             }
112         }
113
114         ((float *)var)[c] = v / (p->mb * p->id * p->ih * p->iw);
115
116         if (p->flags & USE_SCALESHIFT) {
117             ((float *)ss)[c] = 1.f / 8 * (1 << (c % 7));
118             ((float *)ss)[p->ic + c] = ((c % 3) - 1) * ((float *)ss)[c] / 64;
119         } else {
120             ((float *)ss)[c] = 1;
121             ((float *)ss)[p->ic + c] = 0;
122         }
123     });
124
125     return OK;
126 }
127
128 /** @brief L = 2^k * P, P % 2 != 0 */
129 static void decompose2(int L, int &k, int &P) {
130     P = L;
131     for (k = 0; P % 2 == 0; ++k)
132         P /= 2;
133 }
134
135 static int prepare_bwd(const prb_t *p, dnn_mem_t &src, dnn_mem_t &d_dst,
136         dnn_mem_t &mean, dnn_mem_t &var, dnn_mem_t &ss, dnn_mem_t &mask) {
137     const int exact_bits = 24;
138
139     const int L = p->mb * p->id * p->ih * p->iw;
140     if (L < 2)
141         return FAIL;
142
143     /** Stabilization idea...
144      * Since
145      *      d_src = func(d_beta / L, d_gamma' / L, ...)
146      * try to make d_beta = L / 2^t_beta and d_gamma' = L / 2^t_gamma,
147      * where both t_beta and t_gamma are in {1, .., max_k}.
148      * Currently, with no obvious reason, max_k is set to 4 for
149      * reasonably small problems and to 8 for big problems.
150      *
151      * Here d_gamma' = d_gamma / sqrt(var + eps).
152      * We might hope that division by L would be exact in that case,
153      * but that might happen iff L is less than 2^exact_bits, hence
154      * restriction [r1]. */
155
156     int k, P;
157     decompose2(L, k, P);
158
159     int log2P = (int)ceilf(log2f(P));
160     if (log2P >= exact_bits)
161         return FAIL; /* [r1] */
162
163     const int max_k = L > (1<<20) ? 8 : 4;
164     if (k > max_k && exact_bits - log2P > max_k + 4) {
165         log2P += (k - max_k);
166         P <<= k - max_k;
167         k = max_k;
168     }
169
170     const int param_dd_p2 = 7;   // factor_dd <- 2^{0, .., -param_db_p2+1}
171     const int param_dd_gen = 32; // gen_dd <- {1, .., param_dd_gen}
172
173     const int param_f_p2 = 1;    // factor_f <- 2^{-param_dg_p2}
174     const int param_f_gen = 16;  // gen_f <- {2..param_s_gen}
175
176     const float ub_dg = param_dd_gen * param_f_gen / 2 * L;
177     const float ub_db = param_dd_gen * L;
178     const float density = MIN3(1.f, (1<<exact_bits) / ub_dg,
179             (1<<exact_bits) / ub_db);
180
181     print(5, "prep_bwd: k:%d, P:%d log2P:%d, density = %g\n",
182             k, P, log2P, density);
183
184     mkldnn::impl::parallel_nd(p->ic, [&](int c) {
185         const float m = ((float *)mean)[c] = c % 2;
186
187         /* var + eps \in {1/4, 1, 4} */
188         const float ve_denom = 4.f / (1 << 2 * (c % 3));
189         ((float *)var)[c] = ve_denom - p->eps;
190
191         const int dd_p2 = (c * 127 % param_dd_p2);
192         const float factor_dd = 1.f / (1 << dd_p2);
193
194         const int f_p2 = 1 + (c % param_f_p2);
195         const float factor_f = 1.f / (1 << f_p2);
196
197         const float target_db = factor_dd * P;
198         const float target_dg = ve_denom * 2 * target_db;
199
200         float dg = 0, db = 0; /* current d_beta and d_gamma */
201         for (int mb = 0; mb < p->mb; ++mb) {
202             const int l_base = mb * p->id * p->ih * p->iw;
203
204             const auto off = data_off(p, mb, c, 0, 0, 0);
205             float *s = (float *)src + off;
206             float *dd = (float *)d_dst + off;
207             float *rmask = (float *)mask + off;
208
209             for (int d = 0; d < p->id; ++d)
210             for (int h = 0; h < p->ih; ++h)
211             for (int w = 0; w < p->iw; ++w) {
212
213                 const int sp = d * p->ih * p->iw + h * p->iw + w;
214                 if (!flip_coin(l_base + sp, density) && l_base + sp + 100 < L) {
215                     dd[sp] = 0;
216                     s[sp] = m;
217                     rmask[sp] = 1;
218                     continue;
219                 }
220                 if (l_base + sp + 2 >= L) continue; /* last 2 are special */
221                 const int l = l_base + sp * 7 + c * 19 + mb * 13;
222
223                 int rmask_v = 1;
224                 if (p->flags & FUSE_BN_RELU)
225                     rmask[sp] = rmask_v = l % 5 != 1;
226
227                 const int sgn_dd = db < target_db ? 1 : -1;
228                 dd[sp] = sgn_dd * factor_dd * (1 + (l * 3 % param_dd_gen));
229                 if (rmask_v) db += dd[sp];
230
231                 const int sgn_f = dg < target_dg ? 1 : -1;
232                 const float f =
233                     sgn_f * factor_f * (2 + (l * 7 % (param_f_gen - 1)));
234
235                 if (rmask_v) dg += f * dd[sp];
236                 s[sp] = f + m;
237             }
238         }
239
240         if (1) {
241             /* the last 2 elements in src and d_dst are set, so that:
242              *      db == target_db
243              *      dg == target_dg
244              * For this we need to solve the system:
245              *      d_dst[l1]           + d_dst[l0]           = target_db - db
246              *      d_dst[l1] * src[l1] + d_dst[l0] * src[l0] = target_dg - dg
247              *
248              * Here l0 -- last index, l1 -- last but one.
249              * More over, let's assume src[l1] = 1 and src[l0] = -1. */
250             size_t l0 = data_off(p, p->mb - 1, c, p->id - 1, p->ih - 1,
251                 p->iw - 1);
252             size_t l1 = l0 - 1;
253             if (p->id == 1 && p->ih == 1 && p->iw == 1)
254                 l1 = data_off(p, p->mb - 2, c, p->id - 1, p->ih - 1, p->iw - 1);
255
256             ((float *)src)[l1] = 1.f;
257             ((float *)src)[l0] = -1.f;
258             if (p->flags & FUSE_BN_RELU)
259                 ((float *)mask)[l0] = ((float *)mask)[l1] = 1;
260
261             float f1 = ((target_db - db) + (target_dg - dg)) /2;
262             float f0 = ((target_db - db) - (target_dg - dg)) /2;
263
264             ((float *)d_dst)[l1] = f1 + m;
265             ((float *)d_dst)[l0] = f0 + m;
266         }
267
268         if (p->flags & USE_SCALESHIFT) {
269             ((float *)ss)[c] = 1.f / 2 * (1 << (c % 7));
270             ((float *)ss)[p->ic + c] = ((float *)ss)[c] / 64;
271         } else {
272             ((float *)ss)[c] = 1;
273             ((float *)ss)[p->ic + c] = 0;
274         }
275     });
276
277     return OK;
278 }
279
280 static int compare(const prb_t *p, data_kind_t kind, const dnn_mem_t &fp_mem,
281         const dnn_mem_t &dt_mem, res_t *r) {
282     const char *skind = data_kind2str(kind);
283     const float eps = p->dir & FLAG_FWD
284         ? (kind == DATA ? 5e-7 : 0)
285         : (kind == DATA ? 2e-7 : 0);
286
287     /* With all the stability tricks bwd_d is still pretty unstable.
288      * So let's rely on relative error in L1, L2, and L_inf norms.
289      * TODO: make computations for bwd_d more stable and use `L0` here. */
290     const bool rely_on_norm = false
291         || (kind == DATA && (p->dir & FLAG_BWD) && (p->flags | GLOB_STATS));
292
293     const size_t nelems = kind == DATA
294         ? (size_t)p->mb * p->ic * p->id * p->ih * p->iw
295         : (size_t)p->ic * (kind == SS ? 2 : 1);
296     r->total += rely_on_norm ? 1 : nelems;
297
298     diff_norm_t diff_norm;
299
300     for (size_t i = 0; i < nelems; ++i) {
301         const float fp = ((const float *)fp_mem)[i];
302         const float dt = ((const float *)dt_mem)[i];
303         diff_norm.update(fp, dt);
304
305         if (rely_on_norm)
306             continue;
307
308         const float diff = fabsf(fp - dt);
309         const float rel_diff = diff / (fabsf(fp) > FLT_MIN ? fabsf(fp) : 1);
310         const bool ok = (fabs(fp) > 1e-5 ? rel_diff : diff) <= eps;
311
312         r->errors += !ok;
313
314         bool dump = false
315             || (!ok && (r->errors < 10 || verbose >= 10))
316             || (verbose >= 50 && i < 30);
317         if (dump) {
318             const int ind_str_len = 32;
319             char ind_str[ind_str_len] = {'\0'};
320             if (kind == DATA) {
321                 int mb, c, d, h, w;
322                 inv_data_off(p, i, mb, c, d, h, w);
323                 snprintf(ind_str, ind_str_len, "%d,%d,%d,%d,%d", mb, c, d,h, w);
324             } else if (kind == SS) {
325                 snprintf(ind_str, ind_str_len, "%d,%d",
326                         (int)i / p->ic, (int)i % p->ic);
327             } else {
328                 snprintf(ind_str, ind_str_len, "%d", (int)i);
329             }
330
331             print(0, "[%lu][%s%s][%s] fp:%8g dt:%8g diff:%8g rdiff:%8g\n",
332                     (unsigned long)i, p->dir & FLAG_BWD ? "D_" : "", skind,
333                     ind_str, fp, dt, diff, rel_diff);
334         }
335     }
336
337     diff_norm.done();
338
339     if (rely_on_norm) {
340         r->errors += false
341             || diff_norm.rel_diff(norm_t::L1) > eps
342             || diff_norm.rel_diff(norm_t::L2) > eps
343             || diff_norm.rel_diff(norm_t::L8) > eps;
344     }
345
346     if (r->errors || verbose >= 5) {
347         const int vl = r->errors ? 0 : 2;
348         print(vl, "@@@ [%s%s] diff: l0(``%g``) "
349                 "l1:(%g,%g,%g,``%g``) "
350                 "l2:(%g,%g,%g,``%g``) "
351                 "l8:(%g,%g,%g,``%g``)\n",
352                 p->dir & FLAG_BWD ? "D_" : "", skind,
353                 diff_norm.rel_diff(norm_t::L0),
354                 diff_norm.a_[norm_t::L1], diff_norm.b_[norm_t::L1],
355                 diff_norm.diff_[norm_t::L1], diff_norm.rel_diff(norm_t::L1),
356                 diff_norm.a_[norm_t::L2], diff_norm.b_[norm_t::L2],
357                 diff_norm.diff_[norm_t::L2], diff_norm.rel_diff(norm_t::L2),
358                 diff_norm.a_[norm_t::L8], diff_norm.b_[norm_t::L8],
359                 diff_norm.diff_[norm_t::L8], diff_norm.rel_diff(norm_t::L8));
360     }
361
362     if (r->errors)
363         r->state = FAILED;
364
365     if (r->state == UNTESTED)
366         r->state = PASSED; /* optimism */
367
368     return r->state == FAILED ? FAIL : OK;
369 }
370
371 int check_fwd_ws(const dnn_mem_t &data_dt, const dnn_mem_t &ws_dt, res_t *r) {
372     /* so far we know ws is just bit-mask of whether value was negative or
373      * positive */
374     const size_t nelems = data_dt.nelems(true);
375     const float *d = (const float *)data_dt;
376     const uint8_t *ws = (const uint8_t *)ws_dt;
377
378     /* some internal knowledge: flags in ws are either stored as bytes (e.g.
379      * for the ref implementation) or as bits (e.g. for the jitted one); in
380      * the latter case the ws memory has fewer elements than the data memory */
381     enum { ws_byte, ws_bit } ws_type;
382     ws_type = ws_dt.nelems(true) < nelems ? ws_bit : ws_byte;
383
384     /* more internal knowledge: data_dt and ws_dt are expected to have exactly
385      * the same data layout, and data_dt padded regions are expected to be
386      * zero, and the respective ws_dt elements should be set accordingly */
387     for (size_t i = 0; i < nelems; i += 8) {
388         for (size_t j = 0; j < MIN2(8, nelems - i); ++j) {
389             const bool want = *d > 0;
390             const bool bit_set = ws_type == ws_byte ? *ws : !!(*ws & (1<<j));
391
392             const bool ok = bit_set == want;
393             r->errors += !ok;
394
395             bool dump = false
396                 || (!ok && (r->errors < 10 || verbose >= 10))
397                 || (verbose >= 50 && i < 30);
398             if (dump) {
399                 print(0, "[%lu] ws exp:%d got:%d (data:%g:%a)\n",
400                         (unsigned long)(i + j), want, bit_set, *d, *d);
401             }
402
403             ++d;
404             if (ws_type == ws_byte) ++ws;
405         }
406         if (ws_type == ws_bit) ++ws;
407     }
408
409     if (r->errors)
410         r->state = FAILED;
411
412     if (r->state == UNTESTED)
413         r->state = PASSED; /* optimism */
414
415     return r->state == FAILED ? FAIL : OK;
416 }
417
418 static int init_pd(const prb_t *p, mkldnn_batch_normalization_desc_t &bd,
419         mkldnn_primitive_desc_t &bpd, res_t *r) {
420     mkldnn_memory_desc_t data_d;
421     mkldnn_dims_t data_dims = {p->mb, p->ic, p->ih, p->iw};
422     mkldnn_dims_t data_dims_3d = {p->mb, p->ic, p->id, p->ih, p->iw};
423     DNN_SAFE(mkldnn_memory_desc_init(&data_d, is_bnorm_3d(p) ? 5 : 4,
424         is_bnorm_3d(p) ? data_dims_3d : data_dims, p->dt, p->fmt), WARN);
425
426     auto flags = (mkldnn_batch_normalization_flag_t)p->flags;
427     if (p->dir & FLAG_FWD) {
428         auto prop = p->dir & FLAG_INF
429             ? mkldnn_forward_inference : mkldnn_forward_training;
430         DNN_SAFE(mkldnn_batch_normalization_forward_desc_init(&bd, prop,
431                     &data_d, p->eps, flags), WARN);
432
433     } else {
434         auto prop = p->dir & FLAG_WEI
435             ? mkldnn_backward : mkldnn_backward_data;
436         DNN_SAFE(mkldnn_batch_normalization_backward_desc_init(&bd, prop,
437                     &data_d, &data_d, p->eps, flags), WARN);
438     }
439
440     auto mkldnn_attr = create_mkldnn_attr(p->attr, 1, NULL);
441
442     mkldnn_primitive_desc_t hint_fwd_pd = NULL;
443     if (p->dir & FLAG_BWD) {
444         mkldnn_batch_normalization_desc_t bd_fwd;
445         DNN_SAFE(mkldnn_batch_normalization_forward_desc_init(&bd_fwd,
446                     mkldnn_forward_training, &data_d, p->eps, flags), WARN);
447         DNN_SAFE(mkldnn_primitive_desc_create_v2(&hint_fwd_pd, &bd_fwd, NULL,
448                     engine, NULL), WARN);
449     }
450     mkldnn_status_t init_status = mkldnn_primitive_desc_create_v2(&bpd, &bd,
451             mkldnn_attr, engine, hint_fwd_pd);
452
453     mkldnn_primitive_desc_destroy(hint_fwd_pd);
454     mkldnn_primitive_attr_destroy(mkldnn_attr);
455
456     if (init_status == mkldnn_unimplemented)
457         return r->state = UNIMPLEMENTED, OK;
458     else
459         SAFE(init_status, WARN);
460
461     const char *impl_str = query_impl_info(bpd);
462     if (maybe_skip(skip_impl, impl_str)) {
463         print(2, "SKIPPED: mkldnn implementation: %s\n", impl_str);
464         DNN_SAFE(mkldnn_primitive_desc_destroy(bpd), WARN);
465         return r->state = SKIPPED, OK;
466     } else {
467         print(5, "mkldnn implementation: %s\n", impl_str);
468         if (!strstr(impl_str, "jit")) {
469             print(1, "WARNING: %s",
470                     "accuracy of the implementation being tested "
471                     "depends on the compiler and might give false-positives.\n");
472             print(1, "         %s",
473                     "please consider recompiling the sources with"
474                     " `-prec-div -fp-model precise` for a reliable testing.\n");
475         }
476     }
477
478     return OK;
479 }
480
481 /** converts benchdnn-understandable mask of {0, 1} to workspace */
482 static int cvt_mask_to_ws(const prb_t *p, const dnn_mem_t &mask_fp,
483         dnn_mem_t &ws_dt) {
484     mkldnn_dims_t data_dims = {p->mb, p->ic, p->ih, p->iw};
485     mkldnn_dims_t data_dims_3d = {p->mb, p->ic, p->id, p->ih, p->iw};
486
487     dnn_mem_t data(is_bnorm_3d(p) ? 5 : 4,
488         is_bnorm_3d(p) ? data_dims_3d : data_dims, mkldnn_f32, p->fmt);
489     SAFE(data.reorder(mask_fp), WARN);
490
491     ptrdiff_t ic = p->ic;
492     dnn_mem_t mean(1, &ic, mkldnn_f32, mkldnn_x);
493     dnn_mem_t var(1, &ic, mkldnn_f32, mkldnn_x);
494     for (int c = 0; c < p->ic; ++c) ((float *)mean)[c] = 0.5;
495     for (int c = 0; c < p->ic; ++c) ((float *)var)[c] = 1;
496
497     mkldnn_batch_normalization_desc_t bd;
498     auto flags = (mkldnn_batch_normalization_flag_t)
499         (mkldnn_use_global_stats | mkldnn_fuse_bn_relu);
500     DNN_SAFE(mkldnn_batch_normalization_forward_desc_init(&bd,
501                 mkldnn_forward_training, &data.md_, 0, flags), WARN);
502
503     mkldnn_primitive_desc_t bpd;
504     DNN_SAFE(mkldnn_primitive_desc_create_v2(&bpd, &bd, NULL, engine, NULL),
505             WARN);
506
507     mkldnn_primitive_t b{};
508     mkldnn_primitive_at_t inputs[3] = {
509         {data.p_, 0}, {mean.p_, 0}, {var.p_, 0}};
510     const_mkldnn_primitive_t outputs[2] = {data.p_, ws_dt.p_};
511     DNN_SAFE(mkldnn_primitive_create(&b, bpd, inputs, outputs), WARN);
512     SAFE(execute(b), WARN);
513
514     DNN_SAFE(mkldnn_primitive_desc_destroy(bpd), CRIT);
515     DNN_SAFE(mkldnn_primitive_destroy(b), CRIT);
516
517     return OK;
518 }
519
520 int doit(const prb_t *p, res_t *r) {
521     res_t res_zero{};
522     *r = res_zero;
523
524     mkldnn_batch_normalization_desc_t bd;
525     mkldnn_primitive_desc_t bpd;
526     mkldnn_primitive_t b{};
527
528     SAFE(init_pd(p, bd, bpd, r), WARN);
529     if (r->state == SKIPPED || r->state == UNIMPLEMENTED)
530         return OK;
531
532     const auto fp = mkldnn_f32;
533     auto &data_dt_d = bd.data_desc;
534
535     const mkldnn_dims_t dims1d = {p->ic};
536     const mkldnn_dims_t dims2d = {2, p->ic};
537     const auto src_format = is_bnorm_3d(p) ? mkldnn_ncdhw : mkldnn_nchw;
538
539     dnn_mem_t data_fp(data_dt_d, fp, src_format),
540               data_dt(data_dt_d);
541     dnn_mem_t d_data_fp(data_dt_d, fp, src_format),
542               d_data_dt(data_dt_d);
543
544     dnn_mem_t mean_fp(1, dims1d, fp, mkldnn_x),
545               mean_dt(mean_fp.md_);
546     dnn_mem_t var_fp(1, dims1d, fp, mkldnn_x),
547               var_dt(var_fp.md_);
548
549     dnn_mem_t ss_fp(2, dims2d, fp, mkldnn_nc),
550               ss_dt(ss_fp.md_);
551     dnn_mem_t d_ss_fp(2, dims2d, fp, mkldnn_nc),
552               d_ss_dt(d_ss_fp.md_);
553
554     dnn_mem_t ws_fp(data_fp.md_);
555     dnn_mem_t *p_ws_dt = NULL;
556     if ((p->flags & FUSE_BN_RELU) && !(p->dir & FLAG_INF)) {
557         const auto ws_pd = mkldnn_primitive_desc_query_pd(bpd,
558                 mkldnn_query_workspace_pd, 0);
559         SAFE(ws_pd != NULL ? OK : FAIL, WARN);
560         p_ws_dt = new dnn_mem_t(*mkldnn_primitive_desc_query_memory_d(ws_pd));
561     } else {
562         p_ws_dt = new dnn_mem_t();
563     }
564     dnn_mem_t &ws_dt = *p_ws_dt;
565
566     if (p->dir & FLAG_FWD) {
567         if (prepare_fwd(p, data_fp, mean_fp, var_fp, ss_fp) != OK)
568             return r->state = MISTRUSTED, OK;
569
570         mkldnn_primitive_at_t inputs[4];
571         const_mkldnn_primitive_t outputs[4];
572
573         int idx = 0;
574
575         SAFE(data_dt.reorder(data_fp), WARN);
576         inputs[idx++] = {data_dt.p_, 0};
577
578         if (p->flags & GLOB_STATS) {
579             SAFE(mean_dt.reorder(mean_fp), WARN);
580             SAFE(var_dt.reorder(var_fp), WARN);
581             inputs[idx++] = {mean_dt.p_, 0};
582             inputs[idx++] = {var_dt.p_, 0};
583         }
584         if (p->flags & USE_SCALESHIFT) {
585             SAFE(ss_dt.reorder(ss_fp), WARN);
586             inputs[idx++] = {ss_dt.p_, 0};
587         }
588
589         idx = 0;
590         outputs[idx++] = data_dt.p_; /* always in-place so far... */
591         if (!(p->flags & GLOB_STATS)) {
592             outputs[idx++] = mean_dt.p_;
593             outputs[idx++] = var_dt.p_;
594         }
595
596         if (p->flags & FUSE_BN_RELU)
597             outputs[idx++] = ws_dt.p_;
598
599         DNN_SAFE(mkldnn_primitive_create(&b, bpd, inputs, outputs), WARN);
600         SAFE(execute(b), WARN);
601         if (bench_mode & CORR) {
602             compute_ref_fwd(p, data_fp, mean_fp, var_fp, ss_fp, data_fp);
603             if (!(p->flags & GLOB_STATS) && !(p->dir & FLAG_INF)) {
604                 SAFE(compare(p, MEAN, mean_fp, mean_dt, r), WARN);
605                 SAFE(compare(p, VAR, var_fp, var_dt, r), WARN);
606             }
607             dnn_mem_t data(data_dt, fp, src_format);
608             SAFE(compare(p, DATA, data_fp, data, r), WARN);
609             if ((p->flags & FUSE_BN_RELU) && !(p->dir & FLAG_INF))
610                 SAFE(check_fwd_ws(data_dt, ws_dt, r), WARN);
611         }
612     } else {
613         if (prepare_bwd(p, data_fp, d_data_fp, mean_fp, var_fp, ss_fp, ws_fp)
614                 != OK)
615             return r->state = MISTRUSTED, OK;
616
617         mkldnn_primitive_at_t inputs[6];
618         const_mkldnn_primitive_t outputs[2];
619
620         int idx = 0;
621
622         SAFE(data_dt.reorder(data_fp), WARN);
623         inputs[idx++] = {data_dt.p_, 0};
624
625         SAFE(mean_dt.reorder(mean_fp), WARN);
626         SAFE(var_dt.reorder(var_fp), WARN);
627         inputs[idx++] = {mean_dt.p_, 0};
628         inputs[idx++] = {var_dt.p_, 0};
629
630         SAFE(d_data_dt.reorder(d_data_fp), WARN);
631         inputs[idx++] = {d_data_dt.p_, 0};
632
633         if (p->flags & USE_SCALESHIFT) {
634             SAFE(ss_dt.reorder(ss_fp), WARN);
635             inputs[idx++] = {ss_dt.p_, 0};
636         }
637
638         if (p->flags & FUSE_BN_RELU) {
639             SAFE(cvt_mask_to_ws(p, ws_fp, ws_dt), WARN);
640             inputs[idx++] = {ws_dt.p_, 0};
641         }
642
643         idx = 0;
644         outputs[idx++] = d_data_dt.p_; /* always in-place so far... */
645         if ((p->flags & USE_SCALESHIFT) && (p->dir & FLAG_WEI))
646             outputs[idx++] = d_ss_dt.p_;
647
648         DNN_SAFE(mkldnn_primitive_create(&b, bpd, inputs, outputs), WARN);
649         SAFE(execute(b), WARN);
650         if (bench_mode & CORR) {
651             compute_ref_bwd(p, data_fp, mean_fp, var_fp, d_data_fp, ss_fp,
652                     ws_fp, d_data_fp, d_ss_fp);
653             if ((p->flags & USE_SCALESHIFT) && (p->dir & FLAG_WEI))
654                 SAFE(compare(p, SS, d_ss_fp, d_ss_dt, r), WARN);
655             dnn_mem_t d_data(d_data_dt, fp,
656             is_bnorm_3d(p) ? mkldnn_ncdhw : mkldnn_nchw);
657             SAFE(compare(p, DATA, d_data_fp, d_data, r), WARN);
658         }
659     }
660
661     if (bench_mode & PERF) {
662         auto &t = r->timer;
663         t.reset();
664         while (true) {
665             SAFE(execute(b), WARN);
666             t.stamp();
667             const bool stop = false
668                 || (fix_times_per_prb && t.times() >= fix_times_per_prb)
669                 || (!fix_times_per_prb
670                         && t.total_ms() >= max_ms_per_prb
671                         && t.times() >= min_times_per_prb);
672             if (stop) break;
673         }
674     }
675
676     delete p_ws_dt;
677     DNN_SAFE(mkldnn_primitive_desc_destroy(bpd), CRIT);
678     DNN_SAFE(mkldnn_primitive_destroy(b), CRIT);
679
680     return OK;
681 }
682
683 }