Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / rnn / rnn_aux.cpp
index 124cbec..c6068da 100644 (file)
@@ -39,6 +39,24 @@ alg_t str2alg(const char *str) {
     return VANILLA_RNN;
 }
 
+policy_t str2policy(const char *str) {
+#define CASE(_plc) if (!strcasecmp(STRINGIFY(_plc), str)) return _plc
+    CASE(NONE);
+    CASE(COMMON);
+    CASE(PER_OC);
+#undef CASE
+    assert(!"unknown policy");
+    return NONE;
+}
+
+const char * policy2str(policy_t policy) {
+    if (policy == NONE) return "none";
+    if (policy == COMMON) return "common";
+    if (policy == PER_OC) return "per_oc";
+    assert(!"unknown policy");
+    return "unknown policy";
+}
+
 const char *alg2str(alg_t alg) {
     if (alg == VANILLA_RNN)
         return "VANILLA_RNN";
@@ -99,6 +117,25 @@ mkldnn_alg_kind_t activation2kind(activation_t act) {
     return alg_kind;
 }
 
+mkldnn_prop_kind_t str2prop(const char *str) {
+    if (!strcasecmp("FWD_D", str))
+        return mkldnn_forward;
+    if (!strcasecmp("BWD_D", str))
+        return mkldnn_backward;
+    assert(!"unknown propagation");
+    return mkldnn_forward;
+}
+
+const char *prop2str(mkldnn_prop_kind_t prop) {
+    if (prop == mkldnn_forward)
+        return "FWD_D";
+    if (prop == mkldnn_backward)
+        return "BWD_DW";
+    assert(!"unknown propagation");
+    return "unknown propagation";
+
+}
+
 mkldnn_rnn_direction_t str2direction(const char *str) {
     if (!strcasecmp("left2right", str))
         return mkldnn_unidirectional_left2right;
@@ -185,8 +222,11 @@ int str2desc(rnn_desc_t *desc, const char *str) {
 void prb2str(const rnn_prb_t *p, const res_t *res, char *buffer) {
     int rem_len = max_prb_len;
 
-    DPRINT("%s,%s,%s,", alg2str(p->alg), activation2str(p->activation),
-            direction2str(p->direction));
+    DPRINT("--prop=%s --alg=%s --activation=%s --direction=%s --cfg=%s "
+           "--scaling=%s ",
+            prop2str(p->prop), alg2str(p->alg), activation2str(p->activation),
+            direction2str(p->direction), cfg2str(p->cfg),
+            policy2str(p->scale_policy));
     DPRINT("l%d", p->n_layer);
     DPRINT("t%d", p->n_iter);
     DPRINT("mb%d", p->mb);
@@ -203,10 +243,20 @@ void init_buffer(float *buf, int size, float value) {
 }
 
 float logistic(float x) {
-    return 1.0f / (1.0f + expf(-x));
+    if (x < 0)
+        return (expf(x) / (1 + expf(x)));
+    else
+        return 1.0f - (expf(-x) / (1 + expf(-x)));
 }
 float dlogistic(float x) {
-    return x * (1 - x);
+    float tmp = logistic(x);
+    return tmp * (1 - tmp);
+}
+float dtanhf(float x) {
+    return (1 - tanhf(x)) * (1 + tanhf(x));
+}
+float x_m_square(float x) {
+    return x - x * x;
 }
 float relu(float x) {
     return x > 0 ? x : 0;
@@ -214,8 +264,8 @@ float relu(float x) {
 float drelu(float x) {
     return float(x > 0);
 }
-float dtanhf(float x) {
-    return (1 - x) * (1 + x);
+float one_m_square(float x) {
+    return 1 - x * x;
 }
 
 int compare_dat(const rnn_prb_t *p, rnn_data_kind_t kind, dnn_mem_t &mem_dt,
@@ -414,4 +464,32 @@ int compare_dst_last_iteration(const rnn_prb_t *p, dnn_mem_t &mem_dt,
     return compare_dat(p, dst_last_iteration, mem_dt, mem_fp, r, final_compare);
 }
 
+void rnn_prb_t::set_qparams(float fp_min, float fp_max) {
+    if (cfg == conf_f32) {
+        data_shift = 0.;
+        data_scale = 1.;
+        wei_scale = 1.;
+        return;
+    }
+
+    /* Set parameters for quantization of src and weights from fp32 data
+     * in [-1, 1] to int8 data in a range specified in cfg */
+    float fp_range = fp_max - fp_min;
+    float int8_src_range = cfg[input].f_max - cfg[input].f_min,
+          int8_wei_range = cfg[weights_input].f_max - cfg[weights_input].f_min;
+
+    data_shift = cfg[input].f_mean;
+    data_scale = int8_src_range / fp_range;
+
+    if (scale_policy == COMMON) {
+        wei_scale = int8_wei_range / fp_range;
+    } else if (scale_policy == PER_OC) {
+        float K = int8_wei_range / fp_range;
+        int nelems = dic * n_gates();
+        for (int i = 0; i < nelems; i++) {
+            wei_oc_scales[i] = K * (1. + (float)i / nelems);
+        }
+    }
+}
+
 } // namespace rnn