Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / conv / deconv.cpp
index ec0e0d0..034acfe 100644 (file)
@@ -33,15 +33,15 @@ using namespace conv;
 
 namespace deconv {
 
-inline static void swap(int &a, int &b)
-{
-    int temp = a;
+template <typename T>
+inline static void swap(T &a, T &b) {
+    T temp = a;
     a = b;
     b = temp;
 }
-inline bool is_deconv_3d(const prb_t *p)
-{
-    return (p->id > 1 || p->od > 1) ? 1 : 0;
+
+inline bool is_deconv_3d(const prb_t *p) {
+    return p->id > 1;
 }
 
 inline int transpose_data_wei(const prb_t *p, dnn_mem_t &wei, dnn_mem_t &wei_tr) {
@@ -61,43 +61,42 @@ inline int init_pd(const prb_t *p, mkldnn_deconvolution_desc_t &cd,
     int ndims = is_deconv_3d(p) ? 5 : 4;
 
     mkldnn_memory_desc_t src_d, wei_d, bia_d, dst_d;
-    mkldnn_dims_t src_dims = {p->mb, p->ic, p->ih, p->iw};
+    mkldnn_dims_t src_2d_dims = {p->mb, p->ic, p->ih, p->iw};
     mkldnn_dims_t src_3d_dims = {p->mb, p->ic, p->id, p->ih, p->iw};
-    mkldnn_dims_t wei_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kh, p->kw};
+    mkldnn_dims_t wei_2d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kh, p->kw};
     mkldnn_dims_t wei_3d_dims = {p->g, p->oc / p->g, p->ic / p->g, p->kd, p->kh, p->kw};
     mkldnn_dims_t bia_dims = {p->oc};
-    mkldnn_dims_t dst_dims = {p->mb, p->oc, p->oh, p->ow};
+    mkldnn_dims_t dst_2d_dims = {p->mb, p->oc, p->oh, p->ow};
     mkldnn_dims_t dst_3d_dims = {p->mb, p->oc, p->od, p->oh, p->ow};
-
     DNN_SAFE(mkldnn_memory_desc_init(&src_d, ndims,
-        is_deconv_3d(p) ? src_3d_dims : src_dims, p->cfg[SRC].dt, mkldnn_any), WARN);
-    DNN_SAFE(mkldnn_memory_desc_init(&wei_d, ndims + 1,
-        is_deconv_3d(p) ? wei_3d_dims : wei_dims, p->cfg[WEI].dt, mkldnn_any), WARN);
+        is_deconv_3d(p) ? src_3d_dims : src_2d_dims, p->cfg[SRC].dt, mkldnn_any), WARN);
+    DNN_SAFE(mkldnn_memory_desc_init(&wei_d, ndims + p->has_groups,
+        is_deconv_3d(p)
+        ? &wei_3d_dims[!p->has_groups]
+        : &wei_2d_dims[!p->has_groups],
+        p->cfg[WEI].dt, mkldnn_any), WARN);
     DNN_SAFE(mkldnn_memory_desc_init(&bia_d, 1, bia_dims, p->cfg[BIA].dt, mkldnn_any), WARN);
     DNN_SAFE(mkldnn_memory_desc_init(&dst_d, ndims,
-        is_deconv_3d(p) ? dst_3d_dims : dst_dims, p->cfg[DST].dt, mkldnn_any), WARN);
-    int strides_2d[] = {p->sh, p->sw};
-    int dilates_2d[] = {p->dh, p->dw};
-    int padding_2d[] = {p->ph, p->pw};
-    int strides_3d[] = {p->sd, p->sh, p->sw};
-    int dilates_3d[] = {p->dd, p->dh, p->dw};
-    int padding_3d[] = {p->pd, p->ph, p->pw};
+        is_deconv_3d(p) ? dst_3d_dims : dst_2d_dims, p->cfg[DST].dt, mkldnn_any), WARN);
+
+    ptrdiff_t strides_nd[] = {p->sd, p->sh, p->sw};
+    ptrdiff_t dilates_nd[] = {p->dd, p->dh, p->dw};
+    ptrdiff_t padding_nd[] = {p->pd, p->ph, p->pw};
 
     auto bph = [&](int ih, int oh, int kh, int sh, int ph, int dh) {
         return (oh - 1) * sh - ih + ((kh - 1) * (dh + 1) + 1) - ph;
     };
-    int padding_r_3d[] = {
+
+    ptrdiff_t padding_r_nd[] = {
         bph(p->od, p->id, p->kd, p->sd, p->pd, p->dd),
         bph(p->oh, p->ih, p->kh, p->sh, p->ph, p->dh),
         bph(p->ow, p->iw, p->kw, p->sw, p->pw, p->dw)};
-    int padding_r_2d[] = {
-        bph(p->oh, p->ih, p->kh, p->sh, p->ph, p->dh),
-        bph(p->ow, p->iw, p->kw, p->sw, p->pw, p->dw)};
 
-    int *strides = is_deconv_3d(p) ? strides_3d : strides_2d;
-    int *dilates = is_deconv_3d(p) ? dilates_3d : dilates_2d;
-    int *padding = is_deconv_3d(p) ? padding_3d : padding_2d;
-    int *padding_r = is_deconv_3d(p) ? padding_r_3d : padding_r_2d;
+    ptrdiff_t *strides = strides_nd + (5 - ndims);
+    ptrdiff_t *dilates = dilates_nd + (5 - ndims);
+    ptrdiff_t *padding = padding_nd + (5 - ndims);
+    ptrdiff_t *padding_r = padding_r_nd + (5 - ndims);
+
     mkldnn_alg_kind_t alg = mkldnn_deconvolution_direct;
     if (p->alg == WINO) alg = mkldnn_deconvolution_winograd;
 
@@ -182,7 +181,7 @@ int doit(const prb_t *p, res_t *r) {
     *r = res_zero;
     bool with_groups = 1;
 
-    prb_t p_tr((desc_t)*p, p->dir, p->cfg, p->alg, p->merge, p->attr, p->mb);
+    prb_t p_tr((desc_t)*p, p->dir, p->cfg, p->alg, p->attr, p->mb, true);
     swap(p_tr.ic,  p_tr.oc);
     swap(p_tr.ih,  p_tr.oh);
     swap(p_tr.id,  p_tr.od);
@@ -210,8 +209,9 @@ int doit(const prb_t *p, res_t *r) {
         ? new dnn_mem_t(bia_dt_d, p->cfg[BIA].dt) : new dnn_mem_t();
     dnn_mem_t &bia_dt = *p_bia_dt;
 
-    auto src_format = is_deconv_3d(p) ? mkldnn_ncdhw : mkldnn_nchw;
-    auto wei_format = is_deconv_3d(p) ? mkldnn_goidhw : mkldnn_goihw;
+    auto src_format = get_default_format(src_dt.md_.ndims, DATA);
+    auto wei_format = get_default_format(wei_dt.md_.ndims,
+        p->has_groups ? GWEI : WEI);
 
     const auto fp = mkldnn_f32;
 
@@ -243,7 +243,6 @@ int doit(const prb_t *p, res_t *r) {
         if (bench_mode & CORR) {
             compute_ref_bwd_d(&p_tr, dst_fp, wei_tr_fp, bia_fp, src_fp);
             dnn_mem_t dst(dst_dt, fp, src_format);
-            SAFE(dst.reorder(dst_dt), WARN);
             SAFE(compare_dst(p, dst, dst_fp, r, true), WARN);
         }
     } else if (p->dir == BWD_D) {
@@ -254,7 +253,6 @@ int doit(const prb_t *p, res_t *r) {
         if (bench_mode & CORR) {
             compute_ref_fwd(&p_tr, dst_fp, wei_tr_fp, zero_fp, src_fp);
             dnn_mem_t src(src_dt, fp, src_format);
-            SAFE(src.reorder(src_dt), WARN);
             SAFE(compare_src(p, src, src_fp, r, true), WARN);
         }
     } else if (p->dir & FLAG_BWD && p->dir & FLAG_WEI) {
@@ -268,12 +266,10 @@ int doit(const prb_t *p, res_t *r) {
             compute_ref_bwd_weights(&p_tr, dst_fp, wei_tr_fp, src_fp);
             transpose_data_wei(&p_tr, wei_tr_fp, wei_fp);
             dnn_mem_t wei(wei_dt, fp, wei_format);
-            SAFE(wei.reorder(wei_dt), WARN);
             SAFE(compare_wei(&p_tr, wei, wei_fp, r, true), WARN);
             if (p->dir & FLAG_BIA) {
                 compute_ref_bwd_bias(p, bia_fp, dst_fp);
                 dnn_mem_t bia(bia_dt, fp, mkldnn_x);
-                SAFE(bia.reorder(bia_dt), WARN);
                 SAFE(compare_bia(p, bia, bia_fp, r, true), WARN);
             }
         }