Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / dnn_types.cpp
index 2bb3429..ca200ae 100644 (file)
@@ -102,6 +102,7 @@ data_kind_t fmt2data_kind(mkldnn_memory_format_t fmt) {
     case mkldnn_gOIw8i16o2i:
     case mkldnn_goihw:
     case mkldnn_hwigo:
+    case mkldnn_giohw:
     case mkldnn_hwigo_s8s8:
     case mkldnn_gOIhw8i8o:
     case mkldnn_gOIhw16i16o:
@@ -119,6 +120,7 @@ data_kind_t fmt2data_kind(mkldnn_memory_format_t fmt) {
     case mkldnn_gOhwi16o:
     case mkldnn_Goihw8g:
     case mkldnn_Goihw16g:
+    case mkldnn_Goihw16g_s8s8:
     case mkldnn_gOhIw16o4i:
     case mkldnn_goidhw:
     case mkldnn_gOIdhw16i16o:
@@ -192,18 +194,56 @@ attr_t::post_ops_t::kind_t attr_t::post_ops_t::str2kind(const char *str) {
 #define CASE(_knd) if (!strcasecmp(STRINGIFY(_knd), str)) return _knd
     CASE(SUM);
     CASE(RELU);
+    CASE(TANH);
+    CASE(ELU);
+    CASE(SQUARE);
+    CASE(ABS);
+    CASE(SQRT);
+    CASE(LINEAR);
+    CASE(BRELU);
+    CASE(SRELU);
+    CASE(LOGISTIC);
 #undef CASE
     assert(!"unknown attr::post_ops::kind");
     return KIND_TOTAL;
 }
 
 const char *attr_t::post_ops_t::kind2str(attr_t::post_ops_t::kind_t kind) {
-    if (kind == SUM) return "sum";
-    if (kind == RELU) return "relu";
+#define CASE(_knd, str) if (kind == _knd) return str
+    CASE(SUM, "sum");
+    CASE(RELU, "relu");
+    CASE(TANH, "tanh");
+    CASE(ELU, "elu");
+    CASE(SQUARE, "square");
+    CASE(ABS, "abs");
+    CASE(SQRT, "sqrt");
+    CASE(LINEAR, "linear");
+    CASE(BRELU, "brelu");
+    CASE(SRELU, "srelu");
+    CASE(LOGISTIC, "logistic");
+#undef CASE
     assert(!"unknown attr::post_ops::kind");
     return "unknown attr::post_ops::kind";
 }
 
+mkldnn_alg_kind_t attr_t::post_ops_t::kind2mkldnn_kind(
+        attr_t::post_ops_t::kind_t kind) {
+#define CASE(_knd, _mknd) if (kind == _knd) return _mknd
+    CASE(RELU, mkldnn_eltwise_relu);
+    CASE(TANH, mkldnn_eltwise_tanh);
+    CASE(ELU, mkldnn_eltwise_elu);
+    CASE(SQUARE, mkldnn_eltwise_square);
+    CASE(ABS, mkldnn_eltwise_abs);
+    CASE(SQRT, mkldnn_eltwise_sqrt);
+    CASE(LINEAR, mkldnn_eltwise_linear);
+    CASE(BRELU, mkldnn_eltwise_bounded_relu);
+    CASE(SRELU, mkldnn_eltwise_soft_relu);
+    CASE(LOGISTIC, mkldnn_eltwise_logistic);
+#undef CASE
+    assert(!"unknown attr::post_ops::kind");
+    return mkldnn_alg_kind_undef;
+}
+
 int attr_t::post_ops_t::from_str(const char *str, const char **end_s) {
     *this = post_ops_t();
 
@@ -236,9 +276,26 @@ int attr_t::post_ops_t::from_str(const char *str, const char **end_s) {
                     } else {
                         e.sum.scale = 1.f;
                     }
-                } else if (k == RELU) {
+                } else {
+                    e.eltwise.alg = kind2mkldnn_kind(k);
                     e.eltwise.scale = 1.f;
                     e.eltwise.alpha = e.eltwise.beta = 0.f;
+
+                    for (int i = 0; i < 3; ++i) {
+                        // :alpha:beta:scale
+                        float &val = i == 0 ? e.eltwise.alpha
+                            : i == 1 ? e.eltwise.beta : e.eltwise.scale;
+                        if (*s == ':') {
+                            char *end;
+                            val = strtof(++s, &end);
+                            if (end == s) return FAIL;
+                            s = end;
+                        } else {
+                            break;
+                        }
+                    }
+
+                    if (e.eltwise.scale <= 0) return FAIL;
                 }
 
                 break;
@@ -265,7 +322,18 @@ void attr_t::post_ops_t::to_str(char *buffer, char **end_b) const {
             buffer += sprintf(buffer, "%s:%g", kind2str(e.kind), e.sum.scale);
             break;
         case RELU:
-            buffer += sprintf(buffer, "%s", kind2str(e.kind));
+        case TANH:
+        case ELU:
+        case SQUARE:
+        case ABS:
+        case SQRT:
+        case LINEAR:
+        case BRELU:
+        case SRELU:
+        case LOGISTIC:
+            buffer += sprintf(buffer, "%s:%g", kind2str(e.kind), e.eltwise.alpha);
+            if (e.eltwise.beta != 0.f || e.eltwise.scale != 1.f)
+                buffer += sprintf(buffer, ":%g:%g", e.eltwise.beta, e.eltwise.scale);
             break;
         default:
             assert(!"unknown kind");
@@ -372,9 +440,17 @@ mkldnn_primitive_attr_t create_mkldnn_attr(const attr_t &attr, int scale_cnt,
                 DNN_SAFE_V(mkldnn_post_ops_append_sum(ops, e.sum.scale));
                 break;
             case attr_t::post_ops_t::RELU:
+            case attr_t::post_ops_t::TANH:
+            case attr_t::post_ops_t::ELU:
+            case attr_t::post_ops_t::SQUARE:
+            case attr_t::post_ops_t::ABS:
+            case attr_t::post_ops_t::SQRT:
+            case attr_t::post_ops_t::LINEAR:
+            case attr_t::post_ops_t::BRELU:
+            case attr_t::post_ops_t::SRELU:
+            case attr_t::post_ops_t::LOGISTIC:
                 DNN_SAFE_V(mkldnn_post_ops_append_eltwise(ops, e.eltwise.scale,
-                            mkldnn_eltwise_relu, e.eltwise.alpha,
-                            e.eltwise.beta));
+                            e.eltwise.alg, e.eltwise.alpha, e.eltwise.beta));
                 break;
             default:
                 assert(!"unknown attr::post_ops::kind");