Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / bnorm / bnorm.cpp
index 7a6c81c..0d47b9e 100644 (file)
@@ -371,15 +371,15 @@ static int compare(const prb_t *p, data_kind_t kind, const dnn_mem_t &fp_mem,
 int check_fwd_ws(const dnn_mem_t &data_dt, const dnn_mem_t &ws_dt, res_t *r) {
     /* so far we know ws is just bit-mask of whether value was negative or
      * positive */
-    const size_t nelems = data_dt.nelems();
+    const size_t nelems = data_dt.nelems(true);
     const float *d = (const float *)data_dt;
     const uint8_t *ws = (const uint8_t *)ws_dt;
 
     /* some internal knowledge: flags in ws are either stored as bytes (e.g.
      * for the ref implementation) or as bits (e.g. for the jitted one); in
-     * the first case the ws memory has fewer elements than the data memory */
+     * the latter case the ws memory has fewer elements than the data memory */
     enum { ws_byte, ws_bit } ws_type;
-    ws_type = ws_dt.nelems() < nelems ? ws_bit : ws_byte;
+    ws_type = ws_dt.nelems(true) < nelems ? ws_bit : ws_byte;
 
     /* more internal knowledge: data_dt and ws_dt are expected to have exactly
      * the same data layout, and data_dt padded regions are expected to be
@@ -488,8 +488,9 @@ static int cvt_mask_to_ws(const prb_t *p, const dnn_mem_t &mask_fp,
         is_bnorm_3d(p) ? data_dims_3d : data_dims, mkldnn_f32, p->fmt);
     SAFE(data.reorder(mask_fp), WARN);
 
-    dnn_mem_t mean(1, &p->ic, mkldnn_f32, mkldnn_x);
-    dnn_mem_t var(1, &p->ic, mkldnn_f32, mkldnn_x);
+    ptrdiff_t ic = p->ic;
+    dnn_mem_t mean(1, &ic, mkldnn_f32, mkldnn_x);
+    dnn_mem_t var(1, &ic, mkldnn_f32, mkldnn_x);
     for (int c = 0; c < p->ic; ++c) ((float *)mean)[c] = 0.5;
     for (int c = 0; c < p->ic; ++c) ((float *)var)[c] = 1;
 
@@ -603,8 +604,7 @@ int doit(const prb_t *p, res_t *r) {
                 SAFE(compare(p, MEAN, mean_fp, mean_dt, r), WARN);
                 SAFE(compare(p, VAR, var_fp, var_dt, r), WARN);
             }
-            dnn_mem_t data(data_dt.md_, fp, src_format);
-            SAFE(data.reorder(data_dt), WARN);
+            dnn_mem_t data(data_dt, fp, src_format);
             SAFE(compare(p, DATA, data_fp, data, r), WARN);
             if ((p->flags & FUSE_BN_RELU) && !(p->dir & FLAG_INF))
                 SAFE(check_fwd_ws(data_dt, ws_dt, r), WARN);
@@ -652,9 +652,8 @@ int doit(const prb_t *p, res_t *r) {
                     ws_fp, d_data_fp, d_ss_fp);
             if ((p->flags & USE_SCALESHIFT) && (p->dir & FLAG_WEI))
                 SAFE(compare(p, SS, d_ss_fp, d_ss_dt, r), WARN);
-            dnn_mem_t d_data(d_data_dt.md_, fp,
+            dnn_mem_t d_data(d_data_dt, fp,
             is_bnorm_3d(p) ? mkldnn_ncdhw : mkldnn_nchw);
-            SAFE(d_data.reorder(d_data_dt), WARN);
             SAFE(compare(p, DATA, d_data_fp, d_data, r), WARN);
         }
     }