Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / conv / conv_aux.cpp
index 8301e87..44a504d 100644 (file)
@@ -30,6 +30,7 @@ namespace conv {
 
 alg_t str2alg(const char *str) {
 #define CASE(_alg) if (!strcasecmp(STRINGIFY(_alg), str)) return _alg
+    CASE(AUTO);
     CASE(DIRECT);
     CASE(WINO);
 #undef CASE
@@ -38,26 +39,19 @@ alg_t str2alg(const char *str) {
 }
 
 const char *alg2str(alg_t alg) {
+    if (alg == AUTO) return "auto";
     if (alg == DIRECT) return "direct";
     if (alg == WINO) return "wino";
     assert(!"unknown algorithm");
     return "unknown algorithm";
 }
 
-merge_t str2merge(const char *str) {
-#define CASE(_mrg) if (!strcasecmp(STRINGIFY(_mrg), str)) return _mrg
-    CASE(NONE);
-    CASE(RELU);
-#undef CASE
-    assert(!"unknown merge");
-    return NONE;
-}
-
-const char *merge2str(merge_t merge) {
-    if (merge == NONE) return "none";
-    if (merge == RELU) return "relu";
-    assert(!"unknown merge");
-    return "unknown merge";
+alg_t alg_kind2alg(mkldnn_alg_kind_t alg) {
+    if (alg == mkldnn_convolution_auto) return AUTO;
+    if (alg == mkldnn_convolution_direct) return DIRECT;
+    if (alg == mkldnn_convolution_winograd) return WINO;
+    assert(!"unknown algorithm");
+    return DIRECT;
 }
 
 int str2desc(desc_t *desc, const char *str, bool is_deconv) {
@@ -78,7 +72,9 @@ int str2desc(desc_t *desc, const char *str, bool is_deconv) {
      *  - if padding is undefined => compute trivial padding
      */
 
-    d.g = 1; d.mb = 2; d.sd = d.sh = d.sw = 1; d.dd = d.dh = d.dw = 0; d.name = "\"wip\"";
+    d.g = 1; d.mb = 2; d.sd = d.sh = d.sw = 1; d.dd = d.dh = d.dw = 0;
+    d.has_groups = false, d.name = "\"wip\"";
+    d.pw = -1; d.ph = -1; d.pd = -1;
 
     const char *s = str;
     assert(s);
@@ -87,6 +83,7 @@ int str2desc(desc_t *desc, const char *str, bool is_deconv) {
         if (!strncmp(p, s, strlen(p))) { \
             ok = 1; s += strlen(p); \
             char *end_s; d. c = strtol(s, &end_s, 10); s += (end_s - s); \
+            if (!strncmp(p, "g", 1)) d.has_groups = true; \
             /* printf("@@@debug: %s: %d\n", p, d. c); */ \
         } \
     } while (0)
@@ -123,34 +120,35 @@ int str2desc(desc_t *desc, const char *str, bool is_deconv) {
             return ((o - 1) * s - i + ((k - 1) * (d + 1) + 1)) / 2;
     };
 
-    const bool no_d = (d.id | d.kd | d.od | d.pd | d.dd) == 0 && d.sd == 1;
-    const bool no_h = (d.ih | d.kh | d.oh | d.ph | d.dh) == 0 && d.sh == 1;
-    const bool no_w = (d.iw | d.kw | d.ow | d.pw | d.dw) == 0 && d.sw == 1;
-
+    const bool no_d = (d.id | d.kd | d.od | d.dd) == 0 && d.sd == 1 && d.pd < 1;
+    const bool no_h = (d.ih | d.kh | d.oh | d.dh) == 0 && d.sh == 1 && d.ph < 1;
+    const bool no_w = (d.iw | d.kw | d.ow | d.dw) == 0 && d.sw == 1 && d.pw < 1;
     if (!no_h) {
         if (!d.ih || !d.kh) return FAIL;
-
-        if (!d.oh) d.oh = compute_out(is_deconv, d.ih, d.kh, d.sh, d.ph, d.dh);
-        else if (!d.ph && d.oh != compute_out(is_deconv, d.ih, d.kh, d.sh, d.ph, d.dh))
+        if (!d.oh) {
+            d.ph = 0;
+            d.oh = compute_out(is_deconv, d.ih, d.kh, d.sh, d.ph, d.dh);
+        } else if (d.ph < 0)
             d.ph = compute_pad(is_deconv, d.oh, d.ih, d.kh, d.sh, d.dh);
     }
 
     if (!no_w) {
         if (!d.iw || !d.kw) return FAIL;
-
-        if (!d.ow) d.ow = compute_out(is_deconv, d.iw, d.kw, d.sw, d.pw, d.dw);
-        else if (!d.pw && d.ow != compute_out(is_deconv, d.iw, d.kw, d.sw, d.pw, d.dw))
+        if (!d.ow) {
+            d.pw = 0;
+            d.ow = compute_out(is_deconv, d.iw, d.kw, d.sw, d.pw, d.dw);
+        } else if (d.pw < 0)
             d.pw = compute_pad(is_deconv, d.ow, d.iw, d.kw, d.sw, d.dw);
     }
 
     if (!no_d && d.id) {
         if (!d.id || !d.kd) return FAIL;
-
-        if (!d.od) d.od = compute_out(is_deconv, d.id, d.kd, d.sd, d.pd, d.dd);
-        else if (!d.pd && d.od != compute_out(is_deconv, d.id, d.kd, d.sd, d.pd, d.dd))
+        if (!d.od) {
+            d.pd = 0;
+            d.od = compute_out(is_deconv, d.id, d.kd, d.sd, d.pd, d.dd);
+        } else if (d.pd < 0)
             d.pd = compute_pad(is_deconv, d.od, d.id, d.kd, d.sd, d.dd);
     }
-
     if (no_w && no_h && d.id) {
         d.iw = d.ih = d.id;
         d.kw = d.kh = d.kd;
@@ -187,7 +185,7 @@ void desc2str(const desc_t *d, char *buffer, bool canonical) {
         buffer += l; rem_len -= l; \
     } while(0)
 
-    if (canonical || d->g != 1) DPRINT("g%d", d->g);
+    if (canonical || d->has_groups) DPRINT("g%d", d->g);
     if (canonical || d->mb != 2) DPRINT("mb%d", d->mb);
 
     const bool half_form = (d->ih == d->iw && d->kh == d->kw && d->oh == d->ow
@@ -230,19 +228,25 @@ void desc2str(const desc_t *d, char *buffer, bool canonical) {
 void prb_t::count_ops() {
     if (ops > 0) return;
 
+    int od_t = is_deconv ? this->id : this->od;
+    int oh_t = is_deconv ? this->ih : this->oh;
+    int ow_t = is_deconv ? this->iw : this->ow;
+    int id_t = is_deconv ? this->od : this->id;
+    int ih_t = is_deconv ? this->oh : this->ih;
+    int iw_t = is_deconv ? this->ow : this->iw;
     double sp_ops = 0;
-    for (int od = 0; od < this->od; ++od) {
-    for (int oh = 0; oh < this->oh; ++oh) {
-    for (int ow = 0; ow < this->ow; ++ow) {
+    for (int od = 0; od < od_t; ++od) {
+    for (int oh = 0; oh < oh_t; ++oh) {
+    for (int ow = 0; ow < ow_t; ++ow) {
         for (int kd = 0; kd < this->kd; ++kd) {
             const int id = od * this->sd - this->pd + kd * (this->dd + 1);
-            if (id < 0 || id >= this->id) continue;
+            if (id < 0 || id >= id_t) continue;
             for (int kh = 0; kh < this->kh; ++kh) {
                 const int ih = oh * this->sh - this->ph + kh * (this->dh + 1);
-                if (ih < 0 || ih >= this->ih) continue;
+                if (ih < 0 || ih >= ih_t) continue;
                 for (int kw = 0; kw < this->kw; ++kw) {
                     const int iw = ow * this->sw - this->pw + kw * (this->dw + 1);
-                    if (iw < 0 || iw >= this->iw) continue;
+                    if (iw < 0 || iw >= iw_t) continue;
                     sp_ops += 1;
                 }
             }
@@ -278,13 +282,11 @@ void prb_t::generate_oscales() {
 
 void prb2str(const prb_t *p, char *buffer, bool canonical) {
     char desc_buf[max_desc_len], attr_buf[max_attr_len];
-    char dir_str[32] = {0}, cfg_str[32] = {0}, alg_str[32] = {0},
-         merge_str[32] = {0};
+    char dir_str[32] = {0}, cfg_str[32] = {0}, alg_str[32] = {0};
     desc2str(p, desc_buf, canonical);
     snprintf(dir_str, sizeof(dir_str), "--dir=%s ", dir2str(p->dir));
     snprintf(cfg_str, sizeof(cfg_str), "--cfg=%s ", cfg2str(p->cfg));
     snprintf(alg_str, sizeof(alg_str), "--alg=%s ", alg2str(p->alg));
-    snprintf(merge_str, sizeof(merge_str), "--merge=%s ", merge2str(p->merge));
     bool is_attr_def = p->attr.is_def();
     if (!is_attr_def) {
         int len = snprintf(attr_buf, max_attr_len, "--attr=\"");
@@ -293,11 +295,10 @@ void prb2str(const prb_t *p, char *buffer, bool canonical) {
         len = (int)strnlen(attr_buf, max_attr_len);
         snprintf(attr_buf + len, max_attr_len - len, "\" ");
     }
-    snprintf(buffer, max_prb_len, "%s%s%s%s%s%s",
+    snprintf(buffer, max_prb_len, "%s%s%s%s%s",
             p->dir == FWD_B ? "" : dir_str,
             p->cfg == conf_f32 ? "" : cfg_str,
             p->alg == DIRECT ? "" : alg_str,
-            p->merge == NONE ? "" : merge_str,
             is_attr_def ? "" : attr_buf,
             desc_buf);
 }