Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / examples / simple_rnn.cpp
index 105979a..029e3c4 100644 (file)
@@ -20,8 +20,6 @@
 #include <numeric>
 #include <string>
 
-#include "mkl_cblas.h"
-
 #include "mkldnn.hpp"
 
 // MSVC doesn't support collapse clause in omp parallel
@@ -49,6 +47,9 @@ std::vector<float> alignment_model(
 std::vector<float> alignments(src_seq_length_max *batch, 1.0f);
 std::vector<float> exp_sums(batch, 1.0f);
 
+const float onef = 1.0, zerof = 0.0;
+const int onei = 1;
+
 void compute_weighted_annotations(float *weighted_annotations,
         int src_seq_length_max, int batch, int feature_size,
         float *weights_annot, float *annotations) {
@@ -56,10 +57,11 @@ void compute_weighted_annotations(float *weighted_annotations,
     // weights_annot is (2c, c)
 
     // annotation[i] = GEMM(weights_annot, enc_dst_layer[i]);
-    cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, feature_size,
-            src_seq_length_max * batch, feature_size, 1.0f, weights_annot,
-            feature_size, annotations, feature_size, 0.0f, weighted_annotations,
-            feature_size);
+    int num_weighted_annotations = src_seq_length_max * batch;
+    mkldnn_sgemm("N", "N",
+            &feature_size, &num_weighted_annotations, &feature_size,
+            &onef, weights_annot, &feature_size, annotations, &feature_size,
+            &zerof, weighted_annotations, &feature_size);
 }
 
 void compute_attention(float *context_vectors, int src_seq_length_max,
@@ -77,13 +79,16 @@ void compute_attention(float *context_vectors, int src_seq_length_max,
     // p is (n, 1)
 
     // first we precompute the weighted_dec_src_layer
-    cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, feature_size, batch,
-            feature_size, 1.0f, weights_src_layer, feature_size, dec_src_layer,
-            feature_size, 0.0f, weighted_src_layer.data(), feature_size);
+    mkldnn_sgemm("N", "N",
+            &feature_size, &batch, &feature_size, &onef,
+            weights_src_layer, &feature_size, dec_src_layer, &feature_size,
+            &zerof, weighted_src_layer.data(), &feature_size);
 
     // then we compute the alignment model
     float *alignment_model_ptr = alignment_model.data();
+#ifdef _OPENMP
 #pragma omp parallel for collapse(2)
+#endif
     for (int i = 0; i < src_seq_length_max; i++) {
         for (int j = 0; j < batch * feature_size; j++)
             alignment_model_ptr[i * batch * feature_size + j] = tanhf(
@@ -92,15 +97,21 @@ void compute_attention(float *context_vectors, int src_seq_length_max,
     }
 
     // gemv with alignments weights. the resulting alignments are in alignments
-    cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, 1,
-            src_seq_length_max * batch, feature_size, 1.0f, weights_alignments,
-            1, alignment_model_ptr, feature_size, 0.0f, alignments.data(), 1);
-
-// softmax on alignments. the resulting context weights are in alignments
+    int num_weighted_annotations = src_seq_length_max * batch;
+    mkldnn_sgemm("N", "N",
+            &onei, &num_weighted_annotations, &feature_size, &onef,
+            weights_alignments, &onei, alignment_model_ptr, &feature_size,
+            &zerof, alignments.data(), &onei);
+
+    // softmax on alignments. the resulting context weights are in alignments
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
     for (int i = 0; i < batch; i++)
         exp_sums[i] = 0.0f;
+#ifdef _OPENMP
 #pragma omp parallel for collapse(2)
+#endif
     for (int i = 0; i < src_seq_length_max; i++) {
         for (int j = 0; j < batch; j++) {
             alignments[i * batch + j] = expf(alignments[i * batch + j]);
@@ -108,20 +119,26 @@ void compute_attention(float *context_vectors, int src_seq_length_max,
         }
     }
 
+#ifdef _OPENMP
 #pragma omp parallel for collapse(2)
+#endif
     for (int i = 0; i < src_seq_length_max; i++)
         for (int j = 0; j < batch; j++)
             alignments[i * batch + j] /= exp_sums[j];
 
-// then we compute the context vectors
+    // then we compute the context vectors
+#ifdef _OPENMP
 #pragma omp parallel for collapse(2)
+#endif
     for (int i = 0; i < batch; i++)
         for (int j = 0; j < feature_size; j++)
             context_vectors[i * (feature_size + feature_size) + feature_size
                     + j]
                     = 0.0f;
 
+#ifdef _OPENMP
 #pragma omp parallel for collapse(3)
+#endif
     for (int i = 0; i < batch; i++)
         for (int k = 0; k < src_seq_length_max; k++)
             for (int j = 0; j < feature_size; j++)
@@ -133,8 +150,10 @@ void compute_attention(float *context_vectors, int src_seq_length_max,
 
 void copy_context(float *src_iter, int n_layers, int n_states, int batch,
         int feature_size) {
-// we copy the context from the first layer to all other layers
+    // we copy the context from the first layer to all other layers
+#ifdef _OPENMP
 #pragma omp parallel for collapse(3)
+#endif
     for (int k = 1; k < n_layers; k++)
         for (int j = 0; j < batch; j++)
             for (int i = 0; i < feature_size; i++)
@@ -162,6 +181,7 @@ void simple_net() {
         for the context vectors in MKL-DNN yet
      */
 
+    std::vector<primitive> weights_reorders;
     std::vector<primitive> encoder_net;
     std::vector<primitive> decoder_net;
 
@@ -181,8 +201,7 @@ void simple_net() {
     memory::dims enc_bidir_dst_layer_tz
             = { src_seq_length_max, batch, 2 * feature_size };
 
-    /* GNMT encoder: 1 bidirectional layer and 7 unidirectional layers
-     */
+    /* GNMT encoder: 1 bidirectional layer and 7 unidirectional layers */
 
     std::vector<float> user_enc_bidir_wei_layer(
             enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size,
@@ -193,7 +212,7 @@ void simple_net() {
     std::vector<float> user_enc_bidir_bias(
             enc_bidir_n_layers * 2 * lstm_n_gates * feature_size, 1.0f);
 
-    // We create the memory descriptors used by the user
+    /* Create the memory for user data */
     auto user_enc_bidir_src_layer_md = mkldnn::memory::desc(
             { enc_bidir_src_layer_tz }, mkldnn::memory::data_type::f32,
             mkldnn::memory::format::tnc);
@@ -209,11 +228,6 @@ void simple_net() {
     auto user_enc_bidir_bias_md = mkldnn::memory::desc({ enc_bidir_bias_tz },
             mkldnn::memory::data_type::f32, mkldnn::memory::format::ldgo);
 
-    auto enc_bidir_dst_layer_md = mkldnn::memory::desc(
-            { enc_bidir_dst_layer_tz }, mkldnn::memory::data_type::f32,
-            mkldnn::memory::format::tnc);
-
-    /* We create memories */
     auto user_enc_bidir_src_layer_memory = mkldnn::memory(
             { user_enc_bidir_src_layer_md, cpu_engine }, net_src.data());
     auto user_enc_bidir_wei_layer_memory
@@ -225,40 +239,57 @@ void simple_net() {
     auto user_enc_bidir_bias_memory = mkldnn::memory(
             { user_enc_bidir_bias_md, cpu_engine }, user_enc_bidir_bias.data());
 
-#if 0
-    /// These will be null memories
-    /// @todo introduce predefined null_memory() ?
-    auto enc_bidir_src_iter_memory = mkldnn::memory({enc_bidir_src_iter_md, cpu_engine});
-    auto enc_bidir_dst_iter_memory = mkldnn::memory({enc_bidir_dst_iter_md, cpu_engine});
-#endif
+    /* Create memory descriptors for RNN data w/o specified layout */
+    auto enc_bidir_wei_layer_md = memory::desc({ enc_bidir_weights_layer_tz },
+            memory::data_type::f32, memory::format::any);
+
+    auto enc_bidir_wei_iter_md = memory::desc({ enc_bidir_weights_iter_tz },
+            memory::data_type::f32, memory::format::any);
 
-    /// @todo fix this once cell desc is merged with rnn_desc
+    auto enc_bidir_dst_layer_md = memory::desc({ enc_bidir_dst_layer_tz },
+            memory::data_type::f32, memory::format::any);
+
+    /* Create bidirectional RNN */
     rnn_cell::desc bi_cell(algorithm::vanilla_lstm);
     rnn_forward::desc bi_layer_desc(prop_kind::forward_inference, bi_cell,
             rnn_direction::bidirectional_concat, user_enc_bidir_src_layer_md,
-            zero_md(), user_enc_bidir_wei_layer_md, user_enc_bidir_wei_iter_md,
+            zero_md(), enc_bidir_wei_layer_md, enc_bidir_wei_iter_md,
             user_enc_bidir_bias_md, enc_bidir_dst_layer_md, zero_md());
 
     auto enc_bidir_prim_desc
             = mkldnn::rnn_forward::primitive_desc(bi_layer_desc, cpu_engine);
 
-    // there are currently no reorders
-    /// @todo add a reorder when they will be available
+    /* Create memory primitives for input data and use reorders to reorder
+     * user data to internal representation
+     */
+    auto enc_bidir_wei_layer_memory
+            = memory(enc_bidir_prim_desc.weights_layer_primitive_desc());
+    auto enc_bidir_wei_layer_reorder_pd = reorder::primitive_desc(
+            user_enc_bidir_wei_layer_memory.get_primitive_desc(),
+            enc_bidir_wei_layer_memory.get_primitive_desc());
+    weights_reorders.push_back(reorder(enc_bidir_wei_layer_reorder_pd,
+            user_enc_bidir_wei_layer_memory, enc_bidir_wei_layer_memory));
+
+    auto enc_bidir_wei_iter_memory
+            = memory(enc_bidir_prim_desc.weights_iter_primitive_desc());
+    auto enc_bidir_wei_iter_reorder_pd = reorder::primitive_desc(
+            user_enc_bidir_wei_iter_memory.get_primitive_desc(),
+            enc_bidir_wei_iter_memory.get_primitive_desc());
+    weights_reorders.push_back(reorder(enc_bidir_wei_iter_reorder_pd,
+            user_enc_bidir_wei_iter_memory, enc_bidir_wei_iter_memory));
 
     auto enc_bidir_dst_layer_memory
             = mkldnn::memory(enc_bidir_prim_desc.dst_layer_primitive_desc());
 
     encoder_net.push_back(
             rnn_forward(enc_bidir_prim_desc, user_enc_bidir_src_layer_memory,
-                    null_memory_, user_enc_bidir_wei_layer_memory,
-                    user_enc_bidir_wei_iter_memory, user_enc_bidir_bias_memory,
+                    null_memory_, enc_bidir_wei_layer_memory,
+                    enc_bidir_wei_iter_memory, user_enc_bidir_bias_memory,
                     enc_bidir_dst_layer_memory, null_memory_, null_memory_));
 
-    /* GNMT encoder: unidirectional layers
-     */
-    // First unidirectinal layer, the scaling from 2*feature size features
-    // comming from the previous layer come
-    /// memories
+    /* GNMT encoder: unidirectional layers */
+    // First unidirectinal layer scales 2 * feature_size output of bidirectional
+    // layer to feature_size output
     std::vector<float> user_enc_uni_first_wei_layer(
             1 * 1 * 2 * feature_size * lstm_n_gates * feature_size, 1.0f);
     std::vector<float> user_enc_uni_first_wei_iter(
@@ -282,13 +313,9 @@ void simple_net() {
     auto user_enc_uni_first_bias_md = mkldnn::memory::desc(
             { user_enc_uni_first_bias_dims }, mkldnn::memory::data_type::f32,
             mkldnn::memory::format::ldgo);
-    auto enc_uni_first_dst_layer_md = mkldnn::memory::desc(
-            { enc_uni_first_dst_layer_dims }, mkldnn::memory::data_type::f32,
-            mkldnn::memory::format::tnc);
     auto user_enc_uni_first_wei_layer_memory
             = mkldnn::memory({ user_enc_uni_first_wei_layer_md, cpu_engine },
                     user_enc_uni_first_wei_layer.data());
-    ;
     auto user_enc_uni_first_wei_iter_memory
             = mkldnn::memory({ user_enc_uni_first_wei_iter_md, cpu_engine },
                     user_enc_uni_first_wei_iter.data());
@@ -296,29 +323,55 @@ void simple_net() {
             = mkldnn::memory({ user_enc_uni_first_bias_md, cpu_engine },
                     user_enc_uni_first_bias.data());
 
+    auto enc_uni_first_wei_layer_md
+            = memory::desc({ user_enc_uni_first_wei_layer_dims },
+                    memory::data_type::f32, memory::format::any);
+    auto enc_uni_first_wei_iter_md
+            = memory::desc({ user_enc_uni_first_wei_iter_dims },
+                    memory::data_type::f32, memory::format::any);
+    auto enc_uni_first_dst_layer_md
+            = memory::desc({ enc_uni_first_dst_layer_dims },
+                    memory::data_type::f32, memory::format::any);
+
     /// @todo add suport for residual connections
     /// should it be a set residual in op_desc or a field to set manually?
     /// should be an integer to specify at which layer to start
     rnn_cell::desc enc_uni_first_cell(algorithm::vanilla_lstm);
     rnn_forward::desc enc_uni_first_layer_desc(prop_kind::forward_inference,
             enc_uni_first_cell, rnn_direction::unidirectional_left2right,
-            enc_bidir_dst_layer_md, zero_md(), user_enc_uni_first_wei_layer_md,
-            user_enc_uni_first_wei_iter_md, user_enc_uni_first_bias_md,
+            enc_bidir_dst_layer_md, zero_md(), enc_uni_first_wei_layer_md,
+            enc_uni_first_wei_iter_md, user_enc_uni_first_bias_md,
             enc_uni_first_dst_layer_md, zero_md());
     auto enc_uni_first_prim_desc = mkldnn::rnn_forward::primitive_desc(
             enc_uni_first_layer_desc, cpu_engine);
+
+    auto enc_uni_first_wei_layer_memory
+            = memory(enc_uni_first_prim_desc.weights_layer_primitive_desc());
+    auto enc_uni_first_wei_layer_reorder_pd = reorder::primitive_desc(
+            user_enc_uni_first_wei_layer_memory.get_primitive_desc(),
+            enc_uni_first_wei_layer_memory.get_primitive_desc());
+    weights_reorders.push_back(reorder(enc_uni_first_wei_layer_reorder_pd,
+            user_enc_uni_first_wei_layer_memory,
+            enc_uni_first_wei_layer_memory));
+
+    auto enc_uni_first_wei_iter_memory
+            = memory(enc_uni_first_prim_desc.weights_iter_primitive_desc());
+    auto enc_uni_first_wei_iter_reorder_pd = reorder::primitive_desc(
+            user_enc_uni_first_wei_iter_memory.get_primitive_desc(),
+            enc_uni_first_wei_iter_memory.get_primitive_desc());
+    weights_reorders.push_back(reorder(enc_uni_first_wei_iter_reorder_pd,
+            user_enc_uni_first_wei_iter_memory, enc_uni_first_wei_iter_memory));
+
     auto enc_uni_first_dst_layer_memory = mkldnn::memory(
             enc_uni_first_prim_desc.dst_layer_primitive_desc());
 
-    /// @todo add a reorder when they will be available
     encoder_net.push_back(rnn_forward(enc_uni_first_prim_desc,
             enc_bidir_dst_layer_memory, null_memory_,
-            user_enc_uni_first_wei_layer_memory,
-            user_enc_uni_first_wei_iter_memory, user_enc_uni_first_bias_memory,
+            enc_uni_first_wei_layer_memory,
+            enc_uni_first_wei_iter_memory, user_enc_uni_first_bias_memory,
             enc_uni_first_dst_layer_memory, null_memory_, null_memory_));
 
-    // Remainging Unidirectional layers
-    /// memories
+    /* Remainging unidirectional layers */
     std::vector<float> user_enc_uni_wei_layer((enc_unidir_n_layers - 1) * 1
                     * feature_size * lstm_n_gates * feature_size, 1.0f);
     std::vector<float> user_enc_uni_wei_iter((enc_unidir_n_layers - 1) * 1
@@ -341,43 +394,60 @@ void simple_net() {
             mkldnn::memory::format::ldigo);
     auto user_enc_uni_bias_md = mkldnn::memory::desc({ user_enc_uni_bias_dims },
             mkldnn::memory::data_type::f32, mkldnn::memory::format::ldgo);
-    auto enc_dst_layer_md = mkldnn::memory::desc({ enc_dst_layer_dims },
-            mkldnn::memory::data_type::f32, mkldnn::memory::format::tnc);
     auto user_enc_uni_wei_layer_memory
             = mkldnn::memory({ user_enc_uni_wei_layer_md, cpu_engine },
                     user_enc_uni_wei_layer.data());
-    ;
     auto user_enc_uni_wei_iter_memory
             = mkldnn::memory({ user_enc_uni_wei_iter_md, cpu_engine },
                     user_enc_uni_wei_iter.data());
     auto user_enc_uni_bias_memory = mkldnn::memory(
             { user_enc_uni_bias_md, cpu_engine }, user_enc_uni_bias.data());
 
+    auto enc_uni_wei_layer_md = memory::desc({ user_enc_uni_wei_layer_dims },
+            memory::data_type::f32, memory::format::any);
+    auto enc_uni_wei_iter_md = memory::desc({ user_enc_uni_wei_iter_dims },
+            memory::data_type::f32, memory::format::any);
+    auto enc_dst_layer_md = memory::desc({ enc_dst_layer_dims },
+            memory::data_type::f32, memory::format::any);
+
     /// @todo add suport for residual connections
     /// should it be a set residual in op_desc or a field to set manually?
     /// should be an integer to specify at which layer to start
     rnn_cell::desc enc_uni_cell(algorithm::vanilla_lstm);
     rnn_forward::desc enc_uni_layer_desc(prop_kind::forward_inference,
             enc_uni_cell, rnn_direction::unidirectional_left2right,
-            enc_uni_first_dst_layer_md, zero_md(), user_enc_uni_wei_layer_md,
-            user_enc_uni_wei_iter_md, user_enc_uni_bias_md, enc_dst_layer_md,
+            enc_uni_first_dst_layer_md, zero_md(), enc_uni_wei_layer_md,
+            enc_uni_wei_iter_md, user_enc_uni_bias_md, enc_dst_layer_md,
             zero_md());
     auto enc_uni_prim_desc = mkldnn::rnn_forward::primitive_desc(
             enc_uni_layer_desc, cpu_engine);
+
+    auto enc_uni_wei_layer_memory
+            = memory(enc_uni_prim_desc.weights_layer_primitive_desc());
+    auto enc_uni_wei_layer_reorder_pd = reorder::primitive_desc(
+            user_enc_uni_wei_layer_memory.get_primitive_desc(),
+            enc_uni_wei_layer_memory.get_primitive_desc());
+    weights_reorders.push_back(reorder(enc_uni_wei_layer_reorder_pd,
+            user_enc_uni_wei_layer_memory, enc_uni_wei_layer_memory));
+
+    auto enc_uni_wei_iter_memory
+            = memory(enc_uni_prim_desc.weights_iter_primitive_desc());
+    auto enc_uni_wei_iter_reorder_pd = reorder::primitive_desc(
+            user_enc_uni_wei_iter_memory.get_primitive_desc(),
+            enc_uni_wei_iter_memory.get_primitive_desc());
+    weights_reorders.push_back(reorder(enc_uni_wei_iter_reorder_pd,
+            user_enc_uni_wei_iter_memory, enc_uni_wei_iter_memory));
+
     auto enc_dst_layer_memory
             = mkldnn::memory(enc_uni_prim_desc.dst_layer_primitive_desc());
 
-    /// @todo add a reorder when they will be available
     encoder_net.push_back(
             rnn_forward(enc_uni_prim_desc, enc_uni_first_dst_layer_memory,
-                    null_memory_, user_enc_uni_wei_layer_memory,
-                    user_enc_uni_wei_iter_memory, user_enc_uni_bias_memory,
+                    null_memory_, enc_uni_wei_layer_memory,
+                    enc_uni_wei_iter_memory, user_enc_uni_bias_memory,
                     enc_dst_layer_memory, null_memory_, null_memory_));
 
-    /*
-     * GNMT: decoder with attention mechanism
-     */
-    // user provided memories
+    /* GNMT: decoder with attention mechanism */
     std::vector<float> user_dec_wei_layer(
             dec_n_layers * 1 * feature_size * lstm_n_gates * feature_size,
             1.0f);
@@ -402,8 +472,7 @@ void simple_net() {
             = { dec_n_layers, 1, lstm_n_gates, feature_size };
 
     memory::dims dec_src_layer_dims = { 1, batch, feature_size };
-    memory::dims dec_dst_layer_dims
-            = { tgt_seq_length_max, batch, feature_size };
+    memory::dims dec_dst_layer_dims = { 1, batch, feature_size };
 
     // We will use the same memory for dec_src_iter and dec_dst_iter
     // However, dec_src_iter has a context vector but not
@@ -434,7 +503,6 @@ void simple_net() {
             mkldnn::memory::data_type::f32, mkldnn::memory::format::ldsnc);
     auto user_dec_wei_layer_memory = mkldnn::memory(
             { user_dec_wei_layer_md, cpu_engine }, user_dec_wei_layer.data());
-    ;
     auto user_dec_wei_iter_memory = mkldnn::memory(
             { user_dec_wei_iter_md, cpu_engine }, user_dec_wei_iter.data());
     auto user_dec_bias_memory = mkldnn::memory(
@@ -444,6 +512,12 @@ void simple_net() {
     auto dec_src_layer_memory
             = mkldnn::memory({ dec_src_layer_md, cpu_engine });
 
+    auto dec_wei_layer_md = mkldnn::memory::desc(
+            { user_dec_wei_layer_dims }, mkldnn::memory::data_type::f32,
+            mkldnn::memory::format::any);
+    auto dec_wei_iter_md = mkldnn::memory::desc({ user_dec_wei_iter_dims },
+            mkldnn::memory::data_type::f32, mkldnn::memory::format::any);
+
     // As mentioned above, we create a view without context out of the
     // memory with context.
     auto dec_dst_iter_memory = mkldnn::memory({ dec_dst_iter_md, cpu_engine });
@@ -457,15 +531,30 @@ void simple_net() {
     rnn_cell::desc dec_cell(algorithm::vanilla_lstm);
     rnn_forward::desc dec_ctx_desc(prop_kind::forward_inference, dec_cell,
             rnn_direction::unidirectional_left2right, dec_src_layer_md,
-            dec_dst_iter_md, user_dec_wei_layer_md, user_dec_wei_iter_md,
+            dec_dst_iter_md, dec_wei_layer_md, dec_wei_iter_md,
             user_dec_bias_md, dec_dst_layer_md, dec_dst_iter_noctx_md);
     auto dec_ctx_prim_desc
             = mkldnn::rnn_forward::primitive_desc(dec_ctx_desc, cpu_engine);
 
-    /// @todo add a reorder when they will be available
+    auto dec_wei_layer_memory
+            = memory(dec_ctx_prim_desc.weights_layer_primitive_desc());
+    auto dec_wei_layer_reorder_pd = reorder::primitive_desc(
+            user_dec_wei_layer_memory.get_primitive_desc(),
+            dec_wei_layer_memory.get_primitive_desc());
+    weights_reorders.push_back(reorder(dec_wei_layer_reorder_pd,
+            user_dec_wei_layer_memory, dec_wei_layer_memory));
+
+    auto dec_wei_iter_memory
+            = memory(dec_ctx_prim_desc.weights_iter_primitive_desc());
+    auto dec_wei_iter_reorder_pd = reorder::primitive_desc(
+            user_dec_wei_iter_memory.get_primitive_desc(),
+            dec_wei_iter_memory.get_primitive_desc());
+    weights_reorders.push_back(reorder(dec_wei_iter_reorder_pd,
+            user_dec_wei_iter_memory, dec_wei_iter_memory));
+
     decoder_net.push_back(rnn_forward(dec_ctx_prim_desc, dec_src_layer_memory,
-            dec_dst_iter_memory, user_dec_wei_layer_memory,
-            user_dec_wei_iter_memory, user_dec_bias_memory,
+            dec_dst_iter_memory, dec_wei_layer_memory,
+            dec_wei_iter_memory, user_dec_bias_memory,
             user_dec_dst_layer_memory, dec_dst_iter_memory, null_memory_));
 
     // allocating temporary buffer for attention mechanism
@@ -476,10 +565,8 @@ void simple_net() {
        Execution
      */
     auto execute = [&]() {
-        // We save the original handle on dst_layer as we will modify it at each
-        // iteration
-        void *dst_layer_original_handle
-                = user_dec_dst_layer_memory.get_data_handle();
+        // reorder weights to MKLDNN internal representation
+        stream(stream::kind::eager).submit(weights_reorders).wait();
 
         // run encoder (1 stream)
         stream(stream::kind::eager).submit(encoder_net).wait();
@@ -490,43 +577,40 @@ void simple_net() {
                 user_weights_annotation.data(),
                 (float *)enc_dst_layer_memory.get_data_handle());
 
-        // We initialise dst_layer[0] to the embedding of </s>, which are
-        // assumed to
-        // be 0 here
-        memset(dst_layer_original_handle, 0,
-                batch * feature_size * sizeof(float));
+        // We initialise src_layer to the embedding of </s>, which
+        // are assumed to be 0 here
+        memset(dec_src_layer_memory.get_data_handle(), 0,
+               dec_src_layer_memory.get_primitive_desc().get_size());
+        // From now on, src points to the output of the last iteration
 
         for (int i = 0; i < tgt_seq_length_max; i++) {
-            float *dst_layer_handle
-                    = (float *)user_dec_dst_layer_memory.get_data_handle();
-            float *dst_iter_handle
-                    = (float *)dec_dst_iter_memory.get_data_handle();
+            float *src_att_layer_handle
+                    = (float *) dec_src_layer_memory.get_data_handle();
+            float *src_att_iter_handle
+                    = (float *) dec_dst_iter_memory.get_data_handle();
 
             // Compute attention context vector into the first layer src_iter
-            compute_attention(dst_iter_handle, src_seq_length_max, batch,
+            compute_attention(src_att_iter_handle, src_seq_length_max, batch,
                     feature_size, user_weights_attention_src_layer.data(),
-                    dst_layer_handle,
+                    src_att_layer_handle,
                     (float *)enc_bidir_dst_layer_memory.get_data_handle(),
                     weighted_annotations.data(),
                     user_weights_alignments.data());
 
             // copy the context vectors to all layers of src_iter
-            copy_context(dst_iter_handle, dec_n_layers, lstm_n_states, batch,
+            copy_context(src_att_iter_handle, dec_n_layers, lstm_n_states, batch,
                     feature_size);
 
-            // We set src_layer to be the previously
-            dec_src_layer_memory.set_data_handle(dst_layer_handle);
-
             // run the decoder iteration
             stream(stream::kind::eager).submit(decoder_net).wait();
 
-            // Move the handle on the dst layer to the next iteration
+            // Move the handle on the src/dst layer to the next iteration
+            auto dst_layer_handle = (float *) user_dec_dst_layer_memory.get_data_handle();
+            dec_src_layer_memory.set_data_handle(dst_layer_handle);
             user_dec_dst_layer_memory.set_data_handle(
                     dst_layer_handle + batch * feature_size);
         }
-        // we restore the handle to the begining of the buffer
-        user_dec_dst_layer_memory.set_data_handle(dst_layer_original_handle);
-        /// @todo run the softmax after each iteration or not?
+
     };
 
     execute();