Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / rnn / rnn.cpp
index d940831..526b2da 100644 (file)
@@ -35,6 +35,30 @@ namespace rnn {
 
 #define CALL_MKLDNN_RNN 1
 
+mkldnn_primitive_attr_t create_mkldnn_rnn_attr(const rnn_prb_t *p) {
+    mkldnn_primitive_attr_t mkldnn_attr = NULL;
+
+    DNN_SAFE_V(mkldnn_primitive_attr_create(&mkldnn_attr));
+    if (p->attr.irmode != attr_t::round_mode_t::NEAREST)
+        DNN_SAFE_V(mkldnn_primitive_attr_set_int_output_round_mode(
+                mkldnn_attr, (mkldnn_round_mode_t)p->attr.irmode));
+
+    if (p->scale_policy == PER_OC) {
+        DNN_SAFE_V(mkldnn_primitive_attr_set_rnn_weights_qparams(
+                mkldnn_attr, p->dic * p->n_gates(), 0x3, p->wei_oc_scales));
+    } else if (p->scale_policy == COMMON && p->wei_scale != 1.) {
+        DNN_SAFE_V(mkldnn_primitive_attr_set_rnn_weights_qparams(
+                mkldnn_attr, 1, 0, &p->wei_scale));
+    }
+
+    if (p->data_scale != 1.0 || p->data_shift != 0.0) {
+        DNN_SAFE_V(mkldnn_primitive_attr_set_rnn_data_qparams(
+                mkldnn_attr, p->data_scale, p->data_shift));
+    }
+
+    return mkldnn_attr;
+}
+
 int fill_memory(const rnn_prb_t *p, rnn_data_kind_t kind, dnn_mem_t &mem1,
         dnn_mem_t &mem2) {
 #ifdef CALL_MKLDNN_RNN
@@ -43,20 +67,20 @@ int fill_memory(const rnn_prb_t *p, rnn_data_kind_t kind, dnn_mem_t &mem1,
 #else
     const size_t nelems = mem2.nelems();
 #endif
-    size_t nchunks = mkldnn_get_max_threads();
-    size_t chunk_size = (nelems + nchunks - 1) / nchunks;
 
+    dt_conf_t c = p->cfg[kind];
+    float mean = c.f_mean, var = c.f_var, min = c.f_min, max = c.f_max;
     mkldnn::impl::parallel(0, [&](int ithr, int nthr) {
+        size_t chunk_size = (nelems + nthr - 1) / nthr;
         size_t idx_start = ithr * chunk_size;
         size_t idx_end = MIN2(idx_start + chunk_size, nelems);
-
         std::minstd_rand msr;
-        std::normal_distribution<float> gen(.0f, .001f);
+        msr.seed((unsigned long int)kind);
+        std::normal_distribution<float> gen(mean, var);
         msr.discard(idx_start);
-
-        for (size_t idx = idx_start; idx < idx_end; ++idx){
-            auto val = gen(msr);
-            mem2.set_elem(idx, MAX2(MIN2(val, 1.0f), -1.0f));
+        for (size_t idx = idx_start; idx < idx_end; ++idx) {
+            auto val = (c.dt == mkldnn_f32) ? gen(msr) : round(gen(msr));
+            mem2.set_elem(idx, MAX2(MIN2(val, max), min));
         }
     });
 
@@ -88,23 +112,20 @@ inline int init_pd(const rnn_prb_t *p, mkldnn_rnn_desc_t rd[2],
     mkldnn_dims_t bias_dims
             = { p->n_layer, p->n_directions(), p->n_gates() + is_gru_lbr, p->dic };
     // mkldnn_tnc
-    int lastlay_dlc = (p->direction == mkldnn_bidirectional_concat) ?
-            2 * p->dlc :
-            p->dlc;
+    int lastlay_dlc = (p->direction == mkldnn_bidirectional_concat)
+            ? 2 * p->dlc
+            : p->dlc;
     mkldnn_dims_t dst_last_layer_dims = { p->n_iter, p->mb, lastlay_dlc };
 
     DNN_SAFE(mkldnn_memory_desc_init(
-                     &input_d, 3, input_dims, p->cfg[SRC].dt, mkldnn_tnc),
+                     &input_d, 3, input_dims, p->cfg[input].dt, mkldnn_tnc),
             WARN);
     input_d.layout_desc.blocking.strides[0][0] += the_stride;
-    DNN_SAFE(mkldnn_memory_desc_init(
-                     &diff_input_d, 3, input_dims, p->cfg[SRC].dt, mkldnn_any),
-            WARN);
 
     mkldnn_dims_t states_dims
             = { p->n_layer, p->n_directions(), p->n_states(), p->mb, p->sic };
-    DNN_SAFE(mkldnn_memory_desc_init(
-                     &states_d, 5, states_dims, p->cfg[SRC].dt, mkldnn_ldsnc),
+    DNN_SAFE(mkldnn_memory_desc_init(&states_d, 5, states_dims,
+                     p->cfg[states].dt, mkldnn_ldsnc),
             WARN);
 
     states_d.layout_desc.blocking.strides[0][3] = p->sic + the_stride;
@@ -116,43 +137,28 @@ inline int init_pd(const rnn_prb_t *p, mkldnn_rnn_desc_t rd[2],
                 = states_d.layout_desc.blocking.strides[0][d + 1]
                 * states_d.dims[d + 1];
 
-    DNN_SAFE(mkldnn_memory_desc_init(&diff_states_d, 5, states_dims,
-                     p->cfg[SRC].dt, mkldnn_any),
-            WARN);
-
     DNN_SAFE(mkldnn_memory_desc_init(&weights_input_d, 5, weights_input_dims,
-                     p->cfg[SRC].dt, mkldnn_any),
-            WARN);
-    DNN_SAFE(mkldnn_memory_desc_init(&diff_weights_input_d, 5,
-                     weights_input_dims, p->cfg[SRC].dt, mkldnn_any),
+                     p->cfg[weights_input].dt, mkldnn_any),
             WARN);
 
     DNN_SAFE(mkldnn_memory_desc_init(&weights_states_d, 5, weights_states_dims,
-                     p->cfg[SRC].dt, mkldnn_any),
-            WARN);
-    DNN_SAFE(mkldnn_memory_desc_init(&diff_weights_states_d, 5,
-                     weights_states_dims, p->cfg[SRC].dt, mkldnn_any),
+                     p->cfg[weights_states].dt, mkldnn_any),
             WARN);
 
     DNN_SAFE(mkldnn_memory_desc_init(
-                     &bias_d, 4, bias_dims, p->cfg[SRC].dt, mkldnn_any),
-            WARN);
-    DNN_SAFE(mkldnn_memory_desc_init(
-                     &diff_bias_d, 4, bias_dims, p->cfg[SRC].dt, mkldnn_any),
+                     &bias_d, 4, bias_dims, p->cfg[bias].dt, mkldnn_any),
             WARN);
 
     DNN_SAFE(mkldnn_memory_desc_init(&dst_last_layer_d, 3, dst_last_layer_dims,
-                     p->cfg[SRC].dt, mkldnn_tnc),
+                     p->cfg[dst_last_layer].dt, mkldnn_tnc),
             WARN);
     dst_last_layer_d.layout_desc.blocking.strides[0][0] += the_stride;
-    DNN_SAFE(mkldnn_memory_desc_init(&diff_last_layer_d, 3, dst_last_layer_dims,
-                     p->cfg[SRC].dt, mkldnn_any),
-            WARN);
 
     mkldnn_dims_t dst_last_iteration_dims
             = { p->n_layer, p->n_directions(), p->n_states(), p->mb, p->dic };
     DNN_SAFE(mkldnn_memory_desc_init(&dst_last_iteration_d, 5,
-                     dst_last_iteration_dims, p->cfg[SRC].dt, mkldnn_ldsnc),
+                     dst_last_iteration_dims, p->cfg[dst_last_iteration].dt,
+                     mkldnn_ldsnc),
             WARN);
 
     dst_last_iteration_d.layout_desc.blocking.strides[0][3]
@@ -166,10 +172,6 @@ inline int init_pd(const rnn_prb_t *p, mkldnn_rnn_desc_t rd[2],
                 = dst_last_iteration_d.layout_desc.blocking.strides[0][d + 1]
                 * dst_last_iteration_d.dims[d + 1];
 
-    DNN_SAFE(mkldnn_memory_desc_init(&diff_last_iteration_d, 5,
-                     dst_last_iteration_dims, p->cfg[SRC].dt, mkldnn_any),
-            WARN);
-
     mkldnn_alg_kind_t kind = alg2kind(p->alg);
     mkldnn_alg_kind_t f = activation2kind(p->activation);
 
@@ -179,14 +181,43 @@ inline int init_pd(const rnn_prb_t *p, mkldnn_rnn_desc_t rd[2],
     // When inference, we use forward_inference
     // When training, we use forward_training
     {
-        DNN_SAFE(mkldnn_rnn_forward_desc_init(&rd[0], fwd_prop, &rcd,
+        mkldnn_status_t init_status = mkldnn_success;
+        init_status = mkldnn_rnn_forward_desc_init(&rd[0], fwd_prop, &rcd,
                          p->direction, &input_d, &states_d, &weights_input_d,
                          &weights_states_d, &bias_d, &dst_last_layer_d,
-                         &dst_last_iteration_d),
-                WARN);
+                         &dst_last_iteration_d);
+        if (init_status == mkldnn_unimplemented)
+            return r->state = UNIMPLEMENTED, OK;
+        else
+            SAFE(init_status, WARN);
     }
 
     if (is_bwd) {
+        DNN_SAFE(mkldnn_memory_desc_init(&diff_input_d, 3, input_dims,
+                         p->cfg[dst_diff_input].dt, mkldnn_any),
+                WARN);
+        DNN_SAFE(mkldnn_memory_desc_init(&diff_states_d, 5, states_dims,
+                         p->cfg[dst_diff_states].dt, mkldnn_any),
+                WARN);
+        DNN_SAFE(mkldnn_memory_desc_init(&diff_weights_input_d, 5,
+                         weights_input_dims, p->cfg[dst_diff_weights_input].dt,
+                         mkldnn_any),
+                WARN);
+        DNN_SAFE(mkldnn_memory_desc_init(&diff_weights_states_d, 5,
+                         weights_states_dims,
+                         p->cfg[dst_diff_weights_states].dt, mkldnn_any),
+                WARN);
+        DNN_SAFE(mkldnn_memory_desc_init(&diff_bias_d, 4, bias_dims,
+                         p->cfg[dst_diff_bias].dt, mkldnn_any),
+                WARN);
+        DNN_SAFE(mkldnn_memory_desc_init(&diff_last_layer_d, 3,
+                         dst_last_layer_dims, p->cfg[diff_last_layer].dt,
+                         mkldnn_any),
+                WARN);
+        DNN_SAFE(mkldnn_memory_desc_init(&diff_last_iteration_d, 5,
+                         dst_last_iteration_dims,
+                         p->cfg[diff_last_iteration].dt, mkldnn_any),
+                WARN);
         DNN_SAFE(mkldnn_rnn_backward_desc_init(&rd[1], p->prop, &rcd,
                          p->direction, &input_d, &states_d, &weights_input_d,
                          &weights_states_d, &bias_d, &dst_last_layer_d,
@@ -196,17 +227,17 @@ inline int init_pd(const rnn_prb_t *p, mkldnn_rnn_desc_t rd[2],
                          &diff_last_iteration_d),
                 WARN);
     }
+    auto mkldnn_attr = create_mkldnn_rnn_attr(p);
     mkldnn_status_t init_status = mkldnn_success;
     for (int i = 0; i < 1 + (int)is_bwd; i++) {
-        init_status = mkldnn_primitive_desc_create(
-                &(rpd[i]), &(rd[i]), engine, NULL);
+        init_status = mkldnn_primitive_desc_create_v2(
+                &(rpd[i]), &(rd[i]), mkldnn_attr, engine, NULL);
         if (init_status == mkldnn_unimplemented)
             return r->state = UNIMPLEMENTED, OK;
         else
             SAFE(init_status, WARN);
     }
-
-    // const char *impl_str = query_impl_info(rpd);
+    mkldnn_primitive_attr_destroy(mkldnn_attr);
 
     auto q = [=](mkldnn_query_t query, int rpd_idx, int index = 0) {
         return *mkldnn_primitive_desc_query_memory_d(
@@ -311,13 +342,17 @@ int doit(const rnn_prb_t *p, res_t *r) {
     auto &diff_dst_layer_dt_d = rd[1].diff_dst_layer_desc;
     auto &diff_dst_iter_dt_d = rd[1].diff_dst_iter_desc;
 
-    input_dt = new dnn_mem_t(input_dt_d, fp);
-    states_dt = new dnn_mem_t(states_dt_d, fp);
-    weights_input_dt = new dnn_mem_t(weights_input_dt_d, fp);
-    weights_states_dt = new dnn_mem_t(weights_states_dt_d, fp);
-    bias_dt = new dnn_mem_t(bias_dt_d, fp);
-    dst_last_layer_dt = new dnn_mem_t(dst_last_layer_dt_d, fp);
-    dst_last_iteration_dt = new dnn_mem_t(dst_last_iteration_dt_d, fp);
+    input_dt = new dnn_mem_t(input_dt_d, p->cfg[input].dt);
+    states_dt = new dnn_mem_t(states_dt_d, p->cfg[states].dt);
+    weights_input_dt
+            = new dnn_mem_t(weights_input_dt_d, p->cfg[weights_input].dt);
+    weights_states_dt
+            = new dnn_mem_t(weights_states_dt_d, p->cfg[weights_states].dt);
+    bias_dt = new dnn_mem_t(bias_dt_d, p->cfg[bias].dt);
+    dst_last_layer_dt
+            = new dnn_mem_t(dst_last_layer_dt_d, p->cfg[dst_last_layer].dt);
+    dst_last_iteration_dt = new dnn_mem_t(
+            dst_last_iteration_dt_d, p->cfg[dst_last_iteration].dt);
 
     if (is_bwd) {
         bwd_weights_input_dt = new dnn_mem_t(bwd_weights_input_dt_d, fp);
@@ -417,8 +452,6 @@ int doit(const rnn_prb_t *p, res_t *r) {
             dnn_mem_t dst_last_layer(*dst_last_layer_dt, fp, mkldnn_tnc);
             dnn_mem_t dst_last_iteration(
                     *dst_last_iteration_dt, fp, mkldnn_ldsnc);
-            SAFE(dst_last_layer.reorder(*dst_last_layer_dt), WARN);
-            SAFE(dst_last_iteration.reorder(*dst_last_iteration_dt), WARN);
             SAFE(compare_dst_last_layer(
                          p, dst_last_layer, *dst_last_layer_fp, r, true),
                     WARN);
@@ -457,8 +490,6 @@ int doit(const rnn_prb_t *p, res_t *r) {
             dnn_mem_t dst_last_layer(*dst_last_layer_dt, fp, mkldnn_tnc);
             dnn_mem_t dst_last_iteration(
                     *dst_last_iteration_dt, fp, mkldnn_ldsnc);
-            SAFE(dst_last_layer.reorder(*dst_last_layer_dt), WARN);
-            SAFE(dst_last_iteration.reorder(*dst_last_iteration_dt), WARN);
             SAFE(compare_dst_last_layer(
                          p, dst_last_layer, *dst_last_layer_fp, r, true),
                     WARN);
@@ -468,8 +499,6 @@ int doit(const rnn_prb_t *p, res_t *r) {
 
             dnn_mem_t diff_input(*dst_diff_input_dt, fp, mkldnn_tnc);
             dnn_mem_t diff_states(*dst_diff_states_dt, fp, mkldnn_ldsnc);
-            SAFE(diff_input.reorder(*dst_diff_input_dt), WARN);
-            SAFE(diff_states.reorder(*dst_diff_states_dt), WARN);
             SAFE(compare_input(p, diff_input, *dst_diff_input_fp, r, true),
                     WARN);
             SAFE(compare_states(p, diff_states, *dst_diff_states_fp, r, true),
@@ -479,9 +508,6 @@ int doit(const rnn_prb_t *p, res_t *r) {
                     *dst_diff_weights_input_dt, fp, mkldnn_ldigo);
             dnn_mem_t diff_weights_states(
                     *dst_diff_weights_states_dt, fp, mkldnn_ldigo);
-            SAFE(diff_weights_input.reorder(*dst_diff_weights_input_dt), WARN);
-            SAFE(diff_weights_states.reorder(*dst_diff_weights_states_dt),
-                    WARN);
             SAFE(compare_weights_input(p, diff_weights_input,
                          *dst_diff_weights_input_fp, r, true),
                     WARN);
@@ -490,7 +516,6 @@ int doit(const rnn_prb_t *p, res_t *r) {
                     WARN);
 
             dnn_mem_t diff_bias(*dst_diff_bias_dt, fp, mkldnn_ldgo);
-            SAFE(diff_bias.reorder(*dst_diff_bias_dt), WARN);
             SAFE(compare_bias(p, diff_bias, *dst_diff_bias_fp, r, true), WARN);
         }
     }