int check_fwd_ws(const dnn_mem_t &data_dt, const dnn_mem_t &ws_dt, res_t *r) {
/* so far we know ws is just bit-mask of whether value was negative or
* positive */
- const size_t nelems = data_dt.nelems();
+ const size_t nelems = data_dt.nelems(true);
const float *d = (const float *)data_dt;
const uint8_t *ws = (const uint8_t *)ws_dt;
/* some internal knowledge: flags in ws are either stored as bytes (e.g.
* for the ref implementation) or as bits (e.g. for the jitted one); in
- * the first case the ws memory has fewer elements than the data memory */
+ * the latter case the ws memory has fewer elements than the data memory */
enum { ws_byte, ws_bit } ws_type;
- ws_type = ws_dt.nelems() < nelems ? ws_bit : ws_byte;
+ ws_type = ws_dt.nelems(true) < nelems ? ws_bit : ws_byte;
/* more internal knowledge: data_dt and ws_dt are expected to have exactly
* the same data layout, and data_dt padded regions are expected to be
is_bnorm_3d(p) ? data_dims_3d : data_dims, mkldnn_f32, p->fmt);
SAFE(data.reorder(mask_fp), WARN);
- dnn_mem_t mean(1, &p->ic, mkldnn_f32, mkldnn_x);
- dnn_mem_t var(1, &p->ic, mkldnn_f32, mkldnn_x);
+ ptrdiff_t ic = p->ic;
+ dnn_mem_t mean(1, &ic, mkldnn_f32, mkldnn_x);
+ dnn_mem_t var(1, &ic, mkldnn_f32, mkldnn_x);
for (int c = 0; c < p->ic; ++c) ((float *)mean)[c] = 0.5;
for (int c = 0; c < p->ic; ++c) ((float *)var)[c] = 1;
SAFE(compare(p, MEAN, mean_fp, mean_dt, r), WARN);
SAFE(compare(p, VAR, var_fp, var_dt, r), WARN);
}
- dnn_mem_t data(data_dt.md_, fp, src_format);
- SAFE(data.reorder(data_dt), WARN);
+ dnn_mem_t data(data_dt, fp, src_format);
SAFE(compare(p, DATA, data_fp, data, r), WARN);
if ((p->flags & FUSE_BN_RELU) && !(p->dir & FLAG_INF))
SAFE(check_fwd_ws(data_dt, ws_dt, r), WARN);
ws_fp, d_data_fp, d_ss_fp);
if ((p->flags & USE_SCALESHIFT) && (p->dir & FLAG_WEI))
SAFE(compare(p, SS, d_ss_fp, d_ss_dt, r), WARN);
- dnn_mem_t d_data(d_data_dt.md_, fp,
+ dnn_mem_t d_data(d_data_dt, fp,
is_bnorm_3d(p) ? mkldnn_ncdhw : mkldnn_nchw);
- SAFE(d_data.reorder(d_data_dt), WARN);
SAFE(compare(p, DATA, d_data_fp, d_data, r), WARN);
}
}