Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / rnn / ref_rnn.cpp
index ed668c1..9bb9a1f 100644 (file)
@@ -52,8 +52,8 @@ float activation(activation_t f, float x, bool is_fwd = true) {
     float result = 0;
     switch (f) {
     case RELU: result = is_fwd ? relu(x) : drelu(x); break;
-    case LOGISTIC: result = is_fwd ? logistic(x) : dlogistic(x); break;
-    case TANH: result = is_fwd ? tanhf(x) : dtanhf(x); break;
+    case LOGISTIC: result = is_fwd ? logistic(x) : x_m_square(x); break;
+    case TANH: result = is_fwd ? tanhf(x) : one_m_square(x); break;
     default: assert(!"unknown activation");
     }
     return result;
@@ -164,8 +164,8 @@ void gru_lbr_fwd(int sic, int slc, int dic, int wc, int batch, int n_gates,
 }
 
 // w = [weights_layer | weights_iter] : with order f, i , o, \bar(c)
-void lstm_fwd(int sic, int slc, int dic, int wc, int batch, int n_gates,
-        float *dst_iter_h_, float *c_dst_, float *gates_,
+void lstm_fwd(const rnn_prb_t *p, int sic, int slc, int dic, int wc, int batch,
+        int n_gates, float *dst_iter_h_, float *c_dst_, float *gates_,
         const float *weights_layer_, const float *weights_iter_h_,
         const float *bias_, const float *src_layer_, const float *src_iter_h_,
         const float *src_iter_c_) {
@@ -182,34 +182,64 @@ void lstm_fwd(int sic, int slc, int dic, int wc, int batch, int n_gates,
 
     gemm("C", "N", "N", batch, n_gates * dic, slc, 1.0, src_layer_, wc,
             weights_layer_, n_gates * dic, 0.0, gates_, n_gates * dic);
-    gemm("C", "N", "N", batch, n_gates * dic, sic,1.0, src_iter_h_, wc,
+    gemm("C", "N", "N", batch, n_gates * dic, sic, 1.0, src_iter_h_, wc,
             weights_iter_h_, n_gates * dic, 1.0, gates_, n_gates * dic);
 
+    auto maybe_deq_w = [&](float g, int oc) {
+        if (p->cfg == conf_f32)
+            return g;
+        float scale = 1.;
+        if (p->scale_policy == PER_OC)
+            scale = p->wei_oc_scales[oc];
+        else if (p->scale_policy == COMMON)
+            scale = p->wei_scale;
+        scale *= p->data_scale;
+        return g / scale;
+    };
+
     // add bias
     for (int i = 0; i < batch; i++)
         for (int j = 0; j < n_gates; j++)
             for (int k = 0; k < dic; k++) {
-                gates(i, j, k) += bias(j, k);
+                gates(i, j, k)
+                        = maybe_deq_w(gates(i, j, k), j * dic + k) + bias(j, k);
             }
 
     // run the eltwise
     lstm_activation(dic, n_gates, batch, gates_);
 
+    auto maybe_q_d = [&](float h) {
+        if (p->cfg == conf_f32)
+            return h;
+        float fp = p->data_scale * h;
+        using R = attr_t::round_mode_t;
+        switch (p->attr.irmode) {
+        case R::DOWN: fp = floorf(fp); break;
+        case R::NEAREST: fp = nearbyintf(fp); break;
+        default: assert(!"unkown round mode");
+        }
+        if (fp + p->data_shift > p->cfg[input].max)
+            fp = p->cfg[input].max - p->data_shift;
+        if (fp + p->data_shift < p->cfg[input].min)
+            fp = p->cfg[input].min - p->data_shift;
+        return fp;
+    };
+
     // compute C_t_l and H_t_l
     for (int i = 0; i < batch; i++)
         for (int j = 0; j < dic; j++) {
             float tmp = gates(i, ohf, j) * src_iter_c(i, j)
                     + gates(i, ohi, j) * gates(i, ohc, j);
             c_dst(i, j) = tmp;
-            h_dst(i, j) = gates(i, oho, j) * tanhf(tmp);
+            h_dst(i, j) = maybe_q_d(gates(i, oho, j) * tanhf(tmp));
         }
 }
 
-void rnn_cell_fwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
-        int batch, int n_gates, float *dst_iter_h, float *dst_iter_c,
-        float *gates, const float *weights_layer, const float *weights_iter,
-        const float *bias, const float *src_layer, const float *src_iter_h,
-        const float *src_iter_c, float *ws_local_) {
+void rnn_cell_fwd(const rnn_prb_t *p, alg_t alg, activation_t f, int sic,
+        int slc, int dic, int wc, int batch, int n_gates, float *dst_iter_h,
+        float *dst_iter_c, float *gates, const float *weights_layer,
+        const float *weights_iter, const float *bias, const float *src_layer,
+        const float *src_iter_h, const float *src_iter_c, float *ws_local_) {
     switch (alg) {
     case VANILLA_GRU:
         gru_fwd(sic, slc, dic, wc, batch, n_gates, dst_iter_h, gates,
@@ -221,7 +251,7 @@ void rnn_cell_fwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
                 ws_local_);
         break;
     case VANILLA_LSTM:
-        lstm_fwd(sic, slc, dic, wc, batch, n_gates, dst_iter_h, dst_iter_c,
+        lstm_fwd(p, sic, slc, dic, wc, batch, n_gates, dst_iter_h, dst_iter_c,
                 gates, weights_layer, weights_iter, bias, src_layer, src_iter_h,
                 src_iter_c);
         break;
@@ -232,6 +262,7 @@ void rnn_cell_fwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
     default: break;
     }
 }
+
 void copy(int dimc, int dimr, int ld_src, int ld_dst, const float *src_,
         float *dst_, rnn_action_t action = action_copy) {
     AOC<const float> src(src_, dimc, ld_src);
@@ -245,86 +276,212 @@ void copy(int dimc, int dimr, int ld_src, int ld_dst, const float *src_,
     });
 }
 
-/* FIXME: separate copy_init ???
- * fwd: ws_states = n_states
- * bwd: ws_states = n_states + 1
- *
- * lstm example:
+void shift(int dimc, int dimr, int ld_src, float *src_, float shift,
+        bool round = false, const rnn_prb_t *p = nullptr) {
+    AOC<float> src(src_, dimc, ld_src);
+    mkldnn::impl::parallel_nd(dimc, [&](int i) {
+        for (int j = 0; j < dimr; j++) {
+            float fp = src(i, j) + shift;
+            if (round) {
+                using R = attr_t::round_mode_t;
+                switch (p->attr.irmode) {
+                case R::DOWN: fp = floorf(fp); break;
+                case R::NEAREST: fp = nearbyintf(fp); break;
+                default: assert(!"unkown round mode");
+                }
+                if (fp > UINT8_MAX)
+                    fp = UINT8_MAX;
+                if (fp < 0)
+                    fp = 0;
+            }
+            src(i, j) = fp;
+        }
+    });
+}
+
+void scale(int dimc, int dimr, int ld_src, float *src_, float scale,
+        bool round = false, const rnn_prb_t *p = nullptr) {
+    AOC<float> src(src_, dimc, ld_src);
+    mkldnn::impl::parallel_nd(dimc, [&](int i) {
+        for (int j = 0; j < dimr; j++) {
+            float fp = src(i, j) * scale;
+            if (round) {
+                using R = attr_t::round_mode_t;
+                switch (p->attr.irmode) {
+                case R::DOWN: fp = floorf(fp); break;
+                case R::NEAREST: fp = nearbyintf(fp); break;
+                default: assert(!"unkown round mode");
+                }
+            }
+            src(i, j) = fp;
+        }
+    });
+}
+
+/* lstm example:
  * fwd: ws keeps {h, c} for every cell
- * bwd: wsb keeps {dh, dc, dx} for every cell
  */
-void copy_init(alg_t alg, int sic, int slc, int dic, int dlc, int wc, int batch,
-        int n_layer, int n_iter, int n_states, float *ws_,
+void copy_init_fwd(const rnn_prb_t *p, alg_t alg, int sic, int slc, int dic,
+        int dlc, int wc, int batch, int n_layer, int n_iter, int n_dir,
+        int n_states, float *ws_, const float *src_layer_,
+        const float *firstit_states_, rnn_iter_direction_t iter_dir,
+        rnn_layer_direction_t lay_dir, int dir_val) {
+    AOC<float> ws(ws_, n_layer + 2, n_dir, n_iter + 2, n_states, batch * wc);
+    AOC<const float> src_layer(src_layer_, n_iter, batch * slc);
+    AOC<const float> firstit_states(
+            firstit_states_, n_layer, n_dir, n_states, batch * sic);
+
+    int lay_dest = (lay_dir == bottom2top) ? 0 : n_layer + 1;
+    int it_dest = (iter_dir == left2right) ? 0 : n_iter + 1;
+    bool is_int8 = p->cfg[input].dt == mkldnn_u8;
+
+    // Copy input
+    for (int it = 0; it < n_iter; it++) {
+        copy(batch, slc, slc, wc, &src_layer(it, 0),
+                &ws(lay_dest, dir_val, it + 1, H, 0));
+        if (p->cfg[input].dt == mkldnn_u8)
+            // shift u8 input to s8 to avoid compensation in gemm
+            shift(batch, slc, wc, &ws(lay_dest, dir_val, it + 1, H, 0),
+                    -1. * p->data_shift);
+    }
+
+    // Copy states
+    for (int lay = 0; lay < n_layer; lay++) {
+        copy(batch, sic, sic, wc, &firstit_states(lay, dir_val, H, 0),
+                &ws(lay + 1, dir_val, it_dest, H, 0));
+        if (p->cfg[states].dt == mkldnn_u8)
+            shift(batch, sic, wc, &ws(lay + 1, dir_val, it_dest, H, 0),
+                    -1. * p->data_shift);
+        else if (p->cfg[states].dt == mkldnn_f32 && is_int8) {
+            // quantize to s8
+            scale(batch, sic, wc, &ws(lay + 1, dir_val, it_dest, H, 0),
+                    p->data_scale, true, p);
+        }
+
+        if (alg == VANILLA_LSTM) {
+            copy(batch, sic, sic, wc, &firstit_states(lay, dir_val, C, 0),
+                    &ws(lay + 1, dir_val, it_dest, C, 0));
+            if (p->cfg[states].dt == mkldnn_u8) {
+                // dequantize to f32
+                shift(batch, sic, wc, &ws(lay + 1, dir_val, it_dest, C, 0),
+                        -1. * p->data_shift);
+                scale(batch, sic, wc, &ws(lay + 1, dir_val, it_dest, C, 0),
+                        1. / p->data_scale);
+            }
+        }
+    }
+}
+
+/* lstm example:
+ * bwd: wsb keeps {dh, dc, dx} for every cell
+*/
+void copy_init_bwd(alg_t alg, int sic, int slc, int dic, int dlc, int wc,
+        int batch, int n_layer, int n_iter, int n_dir, int n_states, float *ws_,
         const float *src_layer_, const float *firstit_states_,
         rnn_iter_direction_t iter_dir, rnn_layer_direction_t lay_dir,
-        int dir_val, int n_dir, bool is_bwd = false, bool is_concat = false) {
+        int dir_val, bool is_concat = false) {
     AOC<float> ws(
-            ws_, n_layer + 2, n_dir, n_iter + 2, n_states + is_bwd, batch, wc);
-    auto c_stride = is_bwd ? (is_concat ? 2 * dlc : dlc) : slc;
+            ws_, n_layer + 2, n_dir, n_iter + 2, n_states + 1, batch * wc);
+    auto c_stride = is_concat ? 2 * dlc : dlc;
     AOC<const float> src_layer(src_layer_, n_iter, batch * c_stride);
-    AOC<const float> firstit_states(firstit_states_, n_layer, n_dir, n_states,
-            batch, is_bwd ? dic : sic);
+    AOC<const float> firstit_states(
+            firstit_states_, n_layer, n_dir, n_states, batch * dic);
 
     int lay_dest = (lay_dir == bottom2top) ? 0 : n_layer + 1;
     int it_dest = (iter_dir == left2right) ? 0 : n_iter + 1;
 
-    if (!is_bwd) {
-        for (int it = 0; it < n_iter; it++)
-            copy(batch, slc, slc, wc, &src_layer(it, 0),
-                    &ws(lay_dest, dir_val, it + 1, H, 0, 0));
-
-        for (int lay = 0; lay < n_layer; lay++) {
-            copy(batch, sic, sic, wc, &firstit_states(lay, dir_val, H, 0, 0),
-                    &ws(lay + 1, dir_val, it_dest, H, 0, 0));
-            if (alg == VANILLA_LSTM) {
-                copy(batch, sic, sic, wc,
-                        &firstit_states(lay, dir_val, C, 0, 0),
-                        &ws(lay + 1, dir_val, it_dest, C, 0, 0));
+    for (int it = 0; it < n_iter; it++)
+        copy(batch, dic, c_stride, wc,
+                &src_layer(it, dir_val * is_concat * dlc),
+                &ws(lay_dest, dir_val, it + 1, n_states, 0));
+
+    for (int lay = 0; lay < n_layer; lay++) {
+        copy(batch, dic, dic, wc, &firstit_states(lay, dir_val, H, 0),
+                &ws(lay + 1, dir_val, it_dest, H, 0));
+        if (alg == VANILLA_LSTM) {
+            copy(batch, dic, dic, wc, &firstit_states(lay, dir_val, C, 0),
+                    &ws(lay + 1, dir_val, it_dest, C, 0));
+        }
+    }
+}
+
+void copy_res_fwd(const rnn_prb_t *p, alg_t alg, int sic, int slc, int dic,
+        int dlc, int wc, int batch, int n_layer, int n_iter, int n_dir,
+        int n_states, float *lastit_states_, float *lastlay_states_,
+        const float *ws_, rnn_iter_direction_t iter_dir,
+        rnn_layer_direction_t lay_dir, int dir_val, rnn_action_t action,
+        bool is_concat = false) {
+    int lastlay_c = is_concat ? 2 * dlc : dlc;
+    AOC<float> lastit_states(
+            lastit_states_, n_layer, n_dir, n_states, batch, dic);
+    AOC<float> lastlay_states(lastlay_states_, n_iter, batch, lastlay_c);
+    AOC<const float> ws(
+            ws_, n_layer + 2, n_dir, n_iter + 2, n_states, batch, wc);
+
+    // Copy states layer
+    for (int it = 0; it < n_iter; it++) {
+        for (int nb = 0; nb < batch; nb++) {
+            auto from = &ws(n_layer, dir_val, it + 1, H, nb, 0);
+            auto to = &lastlay_states(
+                    it, nb, action == action_concat ? dlc : 0);
+            copy(1, dlc, wc, lastlay_c, from, to, action);
+
+            if (p->cfg[dst_last_layer].dt == mkldnn_u8) {
+                // shift s8 internal ws to u8
+                shift(1, dlc, lastlay_c, to, p->data_shift);
+            } else {
+                // dequantize to f32
+                scale(1, dlc, lastlay_c, to, 1. / p->data_scale);
             }
         }
-    } else {
-        for (int it = 0; it < n_iter; it++)
-            copy(batch, dic, c_stride, wc,
-                    &src_layer(it, dir_val * is_concat * dlc),
-                    &ws(lay_dest, dir_val, it + 1, n_states, 0, 0));
-
-        for (int lay = 0; lay < n_layer; lay++) {
-            copy(batch, dic, dic, wc, &firstit_states(lay, dir_val, H, 0, 0),
-                    &ws(lay + 1, dir_val, it_dest, H, 0, 0));
-            if (alg == VANILLA_LSTM) {
-                copy(batch, dic, dic, wc,
-                        &firstit_states(lay, dir_val, C, 0, 0),
-                        &ws(lay + 1, dir_val, it_dest, C, 0, 0));
+    }
+
+    int it_source = (iter_dir == left2right) ? n_iter : 1;
+
+    // Copy states iteration
+    for (int lay = 0; lay < n_layer; lay++) {
+        if (alg == VANILLA_LSTM) {
+            copy(batch, dic, wc, dic, &ws(lay + 1, dir_val, it_source, C, 0, 0),
+                    &lastit_states(lay, dir_val, C, 0, 0));
+            if (p->cfg[dst_last_iteration].dt == mkldnn_u8) {
+                // quantize internal f32 ws to u8
+                scale(batch, dic, dic, &lastit_states(lay, dir_val, C, 0, 0),
+                        p->data_scale);
+                shift(batch, dic, dic, &lastit_states(lay, dir_val, C, 0, 0),
+                        p->data_shift, true, p);
             }
         }
+        copy(batch, dic, wc, dic, &ws(lay + 1, dir_val, it_source, H, 0, 0),
+                &lastit_states(lay, dir_val, H, 0, 0));
+        if (p->cfg[dst_last_iteration].dt == mkldnn_u8) {
+            // shift s8 internal ws to u8
+            shift(batch, dic, dic, &lastit_states(lay, dir_val, H, 0, 0),
+                    p->data_shift);
+        } else {
+            // dequantize to f32
+            scale(batch, dic, dic, &lastit_states(lay, dir_val, H, 0, 0),
+                    1. / p->data_scale);
+        }
     }
 }
 
-void copy_res(alg_t alg, int sic, int slc, int dic, int dlc, int wc, int batch,
-        int n_layer, int n_iter, int n_states, float *lastit_states_,
-        float *lastlay_states_, const float *ws_,
-        mkldnn_rnn_direction_t direction, rnn_iter_direction_t iter_dir,
-        rnn_layer_direction_t lay_dir, int dir_val, int n_dir,
-        rnn_action_t action, bool is_bwd = false) {
-    int lastlay_c = is_bwd ?
-            slc :
-            (direction == mkldnn_bidirectional_concat) * dlc + dlc;
-    int lastiter_c = is_bwd ? sic : dic;
+void copy_res_bwd(alg_t alg, int sic, int slc, int dic, int dlc, int wc,
+        int batch, int n_layer, int n_iter, int n_dir, int n_states,
+        float *lastit_states_, float *lastlay_states_, const float *ws_,
+        rnn_iter_direction_t iter_dir, rnn_layer_direction_t lay_dir,
+        int dir_val, rnn_action_t action) {
     AOC<float> lastit_states(
-            lastit_states_, n_layer, n_dir, n_states, batch, lastiter_c);
-    AOC<float> lastlay_states(lastlay_states_, n_iter, batch, lastlay_c);
+            lastit_states_, n_layer, n_dir, n_states, batch, sic);
+    AOC<float> lastlay_states(lastlay_states_, n_iter, batch, slc);
     AOC<const float> ws(
-            ws_, n_layer + 2, n_dir, n_iter + 2, n_states + is_bwd, batch, wc);
+            ws_, n_layer + 2, n_dir, n_iter + 2, n_states + 1, batch, wc);
     for (int it = 0; it < n_iter; it++) {
         for (int nb = 0; nb < batch; nb++) {
             // copy H to last layer states
-            int lay = is_bwd ? 1 : n_layer;
-            int state = is_bwd ? n_states : H;
-            auto from = &ws(lay, dir_val, it + 1, state, nb, 0);
-            auto to = &lastlay_states(
-                    it, nb, (action == action_concat) && (!is_bwd) ? dlc : 0);
+            auto from = &ws(1, dir_val, it + 1, n_states, nb, 0);
+            auto to = &lastlay_states(it, nb, 0);
 
-            copy(1, is_bwd ?  slc : dlc, wc, lastlay_c, from, to, action);
+            copy(1, slc, wc, slc, from, to, action);
         }
     }
 
@@ -332,12 +489,10 @@ void copy_res(alg_t alg, int sic, int slc, int dic, int dlc, int wc, int batch,
 
     for (int lay = 0; lay < n_layer; lay++) {
         if (alg == VANILLA_LSTM) {
-            copy(batch, lastiter_c, wc, lastiter_c,
-                    &ws(lay + 1, dir_val, it_source, C, 0, 0),
+            copy(batch, sic, wc, sic, &ws(lay + 1, dir_val, it_source, C, 0, 0),
                     &lastit_states(lay, dir_val, C, 0, 0));
         }
-        copy(batch, lastiter_c, wc, lastiter_c,
-                &ws(lay + 1, dir_val, it_source, H, 0, 0),
+        copy(batch, sic, wc, sic, &ws(lay + 1, dir_val, it_source, H, 0, 0),
                 &lastit_states(lay, dir_val, H, 0, 0));
     }
 }
@@ -355,6 +510,7 @@ void rnn_linear_fwd(const rnn_prb_t *p, mkldnn_rnn_direction_t direction,
     const int dlc = p->dlc;
     const int wc = max(sic, max(slc, dic));
     bool is_lbr = p->alg == LBR_GRU;
+    bool is_concat = direction == mkldnn_bidirectional_concat;
 
     const int batch = p->mb;
     const int n_gates = p->n_gates();
@@ -380,8 +536,9 @@ void rnn_linear_fwd(const rnn_prb_t *p, mkldnn_rnn_direction_t direction,
         // we first need to copy the initial states and input into ws
         // it simplifies the logic in the following code
         print(80, "rnn_linear_fwd: call copy_init dir_val = %d\n", dir_val);
-        copy_init(alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter, n_states, ws_,
-                src_layer_, src_iter_, iter_dir, lay_dir, dir_val, n_dir);
+        copy_init_fwd(p, alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter,
+                n_dir, n_states, ws_, src_layer_, src_iter_, iter_dir, lay_dir,
+                dir_val);
 
         // We run the grid of computation
         for (int il = 0; il < n_layer; il++) {
@@ -390,7 +547,7 @@ void rnn_linear_fwd(const rnn_prb_t *p, mkldnn_rnn_direction_t direction,
                 int iter = (iter_dir == left2right) ? it + 1 : n_iter - it;
                 int prev_iter = (iter_dir == left2right) ? iter - 1 : iter + 1;
                 int lay = il + 1;
-                rnn_cell_fwd(alg, f, sic, slc, dic, wc, batch, n_gates,
+                rnn_cell_fwd(p, alg, f, sic, slc, dic, wc, batch, n_gates,
                         &ws(lay, dir_val, iter, H, 0, 0),
                         &ws(lay, dir_val, iter, C, 0, 0),
                         &gates(lay - 1, dir_val, iter - 1, 0, 0, 0),
@@ -399,15 +556,14 @@ void rnn_linear_fwd(const rnn_prb_t *p, mkldnn_rnn_direction_t direction,
                         &bias(lay - 1, dir_val, 0),
                         &ws(lay - 1, dir_val, iter, H, 0, 0),
                         &ws(lay, dir_val, prev_iter, H, 0, 0),
-                        &ws(lay, dir_val, prev_iter, C, 0, 0),
-                        ws_local_);
+                        &ws(lay, dir_val, prev_iter, C, 0, 0), ws_local_);
             }
         }
 
         // Finally we copy the results to the result buffers
-        copy_res(alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter, n_states,
-                dst_iter_, dst_layer_, ws_, direction, iter_dir, lay_dir,
-                dir_val, n_dir, action);
+        copy_res_fwd(p, alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter,
+                n_dir, n_states, dst_iter_, dst_layer_, ws_, iter_dir, lay_dir,
+                dir_val, action, is_concat);
     };
 
     switch (direction) {
@@ -533,7 +689,7 @@ void lstm_bwd(alg_t alg, int sic, int slc, int dic, int wc, int batch,
             float dh = diff_dst_layer(ib, ih) + diff_dst_iter_h(ib, ih);
             float c = dst_iter_c(ib, ih);
             float dho = tanhf(c) * dh;
-            b_gates(ib, oho, ih) = dlogistic(ho) * dho;
+            b_gates(ib, oho, ih) = x_m_square(ho) * dho;
 
             float dc_next = diff_dst_iter_c(ib, ih);
             float dc = ho * dh * dtanhf(c) + dc_next;
@@ -541,13 +697,13 @@ void lstm_bwd(alg_t alg, int sic, int slc, int dic, int wc, int batch,
 
             float c_old = src_iter_c(ib, ih);
             float dhf = c_old * dc;
-            b_gates(ib, ohf, ih) = dlogistic(hf) * dhf;
+            b_gates(ib, ohf, ih) = x_m_square(hf) * dhf;
 
             float dhi = hc * dc;
-            b_gates(ib, ohi, ih) = dlogistic(hi) * dhi;
+            b_gates(ib, ohi, ih) = x_m_square(hi) * dhi;
 
             float dhc = hi * dc;
-            b_gates(ib, ohc, ih) = dtanhf(hc) * dhc;
+            b_gates(ib, ohc, ih) = one_m_square(hc) * dhc;
         }
 
     gemm("C", "T", "N", sic, n_gates * dic, batch, 1.0, src_iter_h_, wc, b_gates_,
@@ -592,10 +748,10 @@ void gru_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
     AOC<float> dhr(dhr_, batch, wc);
     AOC<float> hr(hr_, batch, wc);
 
-// dc = (1 - u) * dh; dc^ = dtanhf(c) * dc;
-// du = (h - u) * dh; du^ = dlogistic(u) * du;
+// dc = (1 - u) * dh; dc^ = one_m_square(c) * dc;
+// du = (h - u) * dh; du^ = x_m_square(u) * du;
 // dhr = Wc dc^;
-// dr = h * dhr; dr^ = dlogistic(r) * dr;
+// dr = h * dhr; dr^ = x_m_square(r) * dr;
     const int ohu = 0;
     const int ohr = 1;
     const int ohc = 2;
@@ -607,12 +763,12 @@ void gru_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
             float dh = diff_dst_layer(ib, ih) + diff_dst_iter_h(ib, ih);
             float du = (h - c) * dh;
             float dc = (1.0f - u) * dh;
-            b_gates(ib, ohu, ih) = dlogistic(u) * du;
-            b_gates(ib, ohc, ih) = dtanhf(c) * dc;
+            b_gates(ib, ohu, ih) = x_m_square(u) * du;
+            b_gates(ib, ohc, ih) = one_m_square(c) * dc;
             diff_src_iter(ib, ih) = dh * u;
         }
-    gemm("C", "N", "T", batch, slc, dic, 1.0, &(b_gates(0, 2, 0)), n_gates * dic,
-            &(weights_layer(0, 2, 0)), n_gates * dic, 0.0, dhr_, wc);
+    gemm("C", "N", "T", batch, sic, dic, 1.0, &(b_gates(0, 2, 0)), n_gates * dic,
+            &(weights_iter_h(0, 2, 0)), n_gates * dic, 0.0, dhr_, wc);
 
     for (int ib = 0; ib < batch; ib++)
         for (int ih = 0; ih < dic; ih++) {
@@ -621,7 +777,7 @@ void gru_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
             float dr = h * dhr(ib, ih);
             hr(ib, ih) = h * r;
             diff_src_iter(ib, ih) += dhr(ib, ih) * r;
-            b_gates(ib, ohr, ih) = dlogistic(r) * dr;
+            b_gates(ib, ohr, ih) = x_m_square(r) * dr;
         }
 
 // dWx += xdu^ | xdr^ | xdc^
@@ -682,9 +838,9 @@ void gru_lbr_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
             &weights_iter_h(0, 2, 0), n_gates * dic, 1.0, Wh_b_, dic);
 
 
-// dc = (1 - u) * dh; dc^ = dtanhf(c) * dc;
-// du = (h - u) * dh; du^ = dlogistic(u) * du;
-// dr = (Wh + b) * dc^; dr^ = dlogistic(r) * dr;
+// dc = (1 - u) * dh; dc^ = one_m_square(c) * dc;
+// du = (h - c) * dh; du^ = x_m_square(u) * du;
+// dr = (Wh + b) * dc^; dr^ = x_m_square(r) * dr;
     const int ohu = 0;
     const int ohr = 1;
     const int ohc = 2;
@@ -698,11 +854,11 @@ void gru_lbr_bwd(alg_t alg, activation_t f, int sic, int slc, int dic, int wc,
             float du = (h - c) * dh;
             float dc = (1.0f - u) * dh;
 
-            b_gates(ib, ohu, ih) = dlogistic(u) * du;
-            b_gates(ib, ohc, ih) = dtanhf(c) * dc;
+            b_gates(ib, ohu, ih) = x_m_square(u) * du;
+            b_gates(ib, ohc, ih) = one_m_square(c) * dc;
 
             float dr = Wh_b(ib, ih) * b_gates(ib, ohc, ih);
-            b_gates(ib, ohr, ih) = dlogistic(r) * dr;
+            b_gates(ib, ohr, ih) = x_m_square(r) * dr;
 
             b_gates_r(ib, ohu, ih) = b_gates(ib, ohu, ih);
             b_gates_r(ib, ohr, ih) = b_gates(ib, ohr, ih);
@@ -841,9 +997,10 @@ void rnn_linear_bwd(const rnn_prb_t *p, mkldnn_rnn_direction_t direction,
             rnn_layer_direction_t lay_dir, int dir_val, rnn_action_t action) {
         // we first need to copy the initial states and input into ws
         // it simplifies the logic in the following code
-        copy_init(alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter, n_states,
-                wsb_, diff_dst_layer_, diff_dst_iter_, iter_dir, lay_dir,
-                dir_val, n_dir, true, direction == mkldnn_bidirectional_concat);
+        copy_init_bwd(alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter,
+                n_dir, n_states, wsb_, diff_dst_layer_, diff_dst_iter_,
+                iter_dir, lay_dir, dir_val,
+                direction == mkldnn_bidirectional_concat);
 
         // We run the grid of computation
         for (int j = n_layer - 1; j >= 0; j--) {
@@ -881,9 +1038,9 @@ void rnn_linear_bwd(const rnn_prb_t *p, mkldnn_rnn_direction_t direction,
         }
 
         // Finally we copy the results to the result buffers
-        copy_res(alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter, n_states,
-                diff_src_iter_, diff_src_layer_, wsb_, direction, iter_dir,
-                lay_dir, dir_val, n_dir, action, true);
+        copy_res_bwd(alg, sic, slc, dic, dlc, wc, batch, n_layer, n_iter, n_dir,
+                n_states, diff_src_iter_, diff_src_layer_, wsb_, iter_dir,
+                lay_dir, dir_val, action);
     };
 
     switch (direction) {