Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / examples / simple_rnn.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <cstring>
18 #include <iostream>
19 #include <math.h>
20 #include <numeric>
21 #include <string>
22
23 #include "mkl_cblas.h"
24
25 #include "mkldnn.hpp"
26
27 // MSVC doesn't support collapse clause in omp parallel
28 #if defined(_MSC_VER) && !defined(__INTEL_COMPILER)
29 #define collapse(x)
30 #endif
31
32 using namespace mkldnn;
33
34 const int batch = 128;
35 const int src_seq_length_max = 28;
36 const int tgt_seq_length_max = 28;
37
38 const int feature_size = 1024;
39
40 const int enc_bidir_n_layers = 1;
41 const int enc_unidir_n_layers = 7;
42 const int dec_n_layers = 8;
43
44 const int lstm_n_gates = 4;
45 const int lstm_n_states = 2;
46 std::vector<float> weighted_src_layer(batch *feature_size, 1.0f);
47 std::vector<float> alignment_model(
48         src_seq_length_max *batch *feature_size, 1.0f);
49 std::vector<float> alignments(src_seq_length_max *batch, 1.0f);
50 std::vector<float> exp_sums(batch, 1.0f);
51
52 void compute_weighted_annotations(float *weighted_annotations,
53         int src_seq_length_max, int batch, int feature_size,
54         float *weights_annot, float *annotations) {
55     // annotations(aka enc_dst_layer) is (t, n, 2c)
56     // weights_annot is (2c, c)
57
58     // annotation[i] = GEMM(weights_annot, enc_dst_layer[i]);
59     cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, feature_size,
60             src_seq_length_max * batch, feature_size, 1.0f, weights_annot,
61             feature_size, annotations, feature_size, 0.0f, weighted_annotations,
62             feature_size);
63 }
64
65 void compute_attention(float *context_vectors, int src_seq_length_max,
66         int batch, int feature_size, float *weights_src_layer,
67         float *dec_src_layer, float *annotations, float *weighted_annotations,
68         float *weights_alignments) {
69     // dst_iter : (n, c) matrix
70     // src_layer: (n, c) matrix
71     // weighted_annotations (t, n, c)
72
73     // weights_yi is (c, c)
74     // weights_ai is (c, 1)
75     // tmp[i] is (n, c)
76     // a[i] is (n, 1)
77     // p is (n, 1)
78
79     // first we precompute the weighted_dec_src_layer
80     cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, feature_size, batch,
81             feature_size, 1.0f, weights_src_layer, feature_size, dec_src_layer,
82             feature_size, 0.0f, weighted_src_layer.data(), feature_size);
83
84     // then we compute the alignment model
85     float *alignment_model_ptr = alignment_model.data();
86 #pragma omp parallel for collapse(2)
87     for (int i = 0; i < src_seq_length_max; i++) {
88         for (int j = 0; j < batch * feature_size; j++)
89             alignment_model_ptr[i * batch * feature_size + j] = tanhf(
90                     weighted_src_layer.data()[j]
91                     + weighted_annotations[i * batch * feature_size + j]);
92     }
93
94     // gemv with alignments weights. the resulting alignments are in alignments
95     cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, 1,
96             src_seq_length_max * batch, feature_size, 1.0f, weights_alignments,
97             1, alignment_model_ptr, feature_size, 0.0f, alignments.data(), 1);
98
99 // softmax on alignments. the resulting context weights are in alignments
100 #pragma omp parallel for
101     for (int i = 0; i < batch; i++)
102         exp_sums[i] = 0.0f;
103 #pragma omp parallel for collapse(2)
104     for (int i = 0; i < src_seq_length_max; i++) {
105         for (int j = 0; j < batch; j++) {
106             alignments[i * batch + j] = expf(alignments[i * batch + j]);
107             exp_sums[j] += alignments[i * batch + j];
108         }
109     }
110
111 #pragma omp parallel for collapse(2)
112     for (int i = 0; i < src_seq_length_max; i++)
113         for (int j = 0; j < batch; j++)
114             alignments[i * batch + j] /= exp_sums[j];
115
116 // then we compute the context vectors
117 #pragma omp parallel for collapse(2)
118     for (int i = 0; i < batch; i++)
119         for (int j = 0; j < feature_size; j++)
120             context_vectors[i * (feature_size + feature_size) + feature_size
121                     + j]
122                     = 0.0f;
123
124 #pragma omp parallel for collapse(3)
125     for (int i = 0; i < batch; i++)
126         for (int k = 0; k < src_seq_length_max; k++)
127             for (int j = 0; j < feature_size; j++)
128                 context_vectors[i * (feature_size + feature_size) + feature_size
129                         + j]
130                         += alignments[k * batch + i]
131                         * annotations[j + feature_size * (i + batch * k)];
132 }
133
134 void copy_context(float *src_iter, int n_layers, int n_states, int batch,
135         int feature_size) {
136 // we copy the context from the first layer to all other layers
137 #pragma omp parallel for collapse(3)
138     for (int k = 1; k < n_layers; k++)
139         for (int j = 0; j < batch; j++)
140             for (int i = 0; i < feature_size; i++)
141                 src_iter[(k * n_states * batch + j)
142                                 * (feature_size + feature_size)
143                         + i]
144                         = src_iter[j * (feature_size + feature_size) + i];
145 }
146
147 void simple_net() {
148     auto cpu_engine = engine(engine::cpu, 0);
149     auto null_memory_ = null_memory(cpu_engine);
150
151     /*
152       GNMT Example.
153       Note, we do not implement connection yet.
154       For the encoder we use:
155       - one primitive for the bidirectional layer of the encoder
156       - one primitive for all remaining unidirectional layers in the encoder
157       For the decoder we use:
158       - one primitive for the first iteration
159       - one primitive for all subsequent iterations in the decoder. Note that
160         in this example, this primitive computes the states in place.
161       - the attention mechanism is implemented separately as there is no support
162         for the context vectors in MKL-DNN yet
163      */
164
165     std::vector<primitive> encoder_net;
166     std::vector<primitive> decoder_net;
167
168     std::vector<float> net_src(batch * src_seq_length_max * feature_size, 1.0f);
169     std::vector<float> net_dst(batch * tgt_seq_length_max * feature_size, 1.0f);
170
171     /* Encoder */
172
173     memory::dims enc_bidir_src_layer_tz
174             = { src_seq_length_max, batch, feature_size };
175     memory::dims enc_bidir_weights_layer_tz = { enc_bidir_n_layers, 2,
176         feature_size, lstm_n_gates, feature_size };
177     memory::dims enc_bidir_weights_iter_tz = { enc_bidir_n_layers, 2,
178         feature_size, lstm_n_gates, feature_size };
179     memory::dims enc_bidir_bias_tz
180             = { enc_bidir_n_layers, 2, lstm_n_gates, feature_size };
181     memory::dims enc_bidir_dst_layer_tz
182             = { src_seq_length_max, batch, 2 * feature_size };
183
184     /* GNMT encoder: 1 bidirectional layer and 7 unidirectional layers
185      */
186
187     std::vector<float> user_enc_bidir_wei_layer(
188             enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size,
189             1.0f);
190     std::vector<float> user_enc_bidir_wei_iter(
191             enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size,
192             1.0f);
193     std::vector<float> user_enc_bidir_bias(
194             enc_bidir_n_layers * 2 * lstm_n_gates * feature_size, 1.0f);
195
196     // We create the memory descriptors used by the user
197     auto user_enc_bidir_src_layer_md = mkldnn::memory::desc(
198             { enc_bidir_src_layer_tz }, mkldnn::memory::data_type::f32,
199             mkldnn::memory::format::tnc);
200
201     auto user_enc_bidir_wei_layer_md = mkldnn::memory::desc(
202             { enc_bidir_weights_layer_tz }, mkldnn::memory::data_type::f32,
203             mkldnn::memory::format::ldigo);
204
205     auto user_enc_bidir_wei_iter_md = mkldnn::memory::desc(
206             { enc_bidir_weights_iter_tz }, mkldnn::memory::data_type::f32,
207             mkldnn::memory::format::ldigo);
208
209     auto user_enc_bidir_bias_md = mkldnn::memory::desc({ enc_bidir_bias_tz },
210             mkldnn::memory::data_type::f32, mkldnn::memory::format::ldgo);
211
212     auto enc_bidir_dst_layer_md = mkldnn::memory::desc(
213             { enc_bidir_dst_layer_tz }, mkldnn::memory::data_type::f32,
214             mkldnn::memory::format::tnc);
215
216     /* We create memories */
217     auto user_enc_bidir_src_layer_memory = mkldnn::memory(
218             { user_enc_bidir_src_layer_md, cpu_engine }, net_src.data());
219     auto user_enc_bidir_wei_layer_memory
220             = mkldnn::memory({ user_enc_bidir_wei_layer_md, cpu_engine },
221                     user_enc_bidir_wei_layer.data());
222     auto user_enc_bidir_wei_iter_memory
223             = mkldnn::memory({ user_enc_bidir_wei_iter_md, cpu_engine },
224                     user_enc_bidir_wei_iter.data());
225     auto user_enc_bidir_bias_memory = mkldnn::memory(
226             { user_enc_bidir_bias_md, cpu_engine }, user_enc_bidir_bias.data());
227
228 #if 0
229     /// These will be null memories
230     /// @todo introduce predefined null_memory() ?
231     auto enc_bidir_src_iter_memory = mkldnn::memory({enc_bidir_src_iter_md, cpu_engine});
232     auto enc_bidir_dst_iter_memory = mkldnn::memory({enc_bidir_dst_iter_md, cpu_engine});
233 #endif
234
235     /// @todo fix this once cell desc is merged with rnn_desc
236     rnn_cell::desc bi_cell(algorithm::vanilla_lstm);
237     rnn_forward::desc bi_layer_desc(prop_kind::forward_inference, bi_cell,
238             rnn_direction::bidirectional_concat, user_enc_bidir_src_layer_md,
239             zero_md(), user_enc_bidir_wei_layer_md, user_enc_bidir_wei_iter_md,
240             user_enc_bidir_bias_md, enc_bidir_dst_layer_md, zero_md());
241
242     auto enc_bidir_prim_desc
243             = mkldnn::rnn_forward::primitive_desc(bi_layer_desc, cpu_engine);
244
245     // there are currently no reorders
246     /// @todo add a reorder when they will be available
247
248     auto enc_bidir_dst_layer_memory
249             = mkldnn::memory(enc_bidir_prim_desc.dst_layer_primitive_desc());
250
251     encoder_net.push_back(
252             rnn_forward(enc_bidir_prim_desc, user_enc_bidir_src_layer_memory,
253                     null_memory_, user_enc_bidir_wei_layer_memory,
254                     user_enc_bidir_wei_iter_memory, user_enc_bidir_bias_memory,
255                     enc_bidir_dst_layer_memory, null_memory_, null_memory_));
256
257     /* GNMT encoder: unidirectional layers
258      */
259     // First unidirectinal layer, the scaling from 2*feature size features
260     // comming from the previous layer come
261     /// memories
262     std::vector<float> user_enc_uni_first_wei_layer(
263             1 * 1 * 2 * feature_size * lstm_n_gates * feature_size, 1.0f);
264     std::vector<float> user_enc_uni_first_wei_iter(
265             1 * 1 * feature_size * lstm_n_gates * feature_size, 1.0f);
266     std::vector<float> user_enc_uni_first_bias(
267             1 * 1 * lstm_n_gates * feature_size, 1.0f);
268     memory::dims user_enc_uni_first_wei_layer_dims
269             = { 1, 1, 2 * feature_size, lstm_n_gates, feature_size };
270     memory::dims user_enc_uni_first_wei_iter_dims
271             = { 1, 1, feature_size, lstm_n_gates, feature_size };
272     memory::dims user_enc_uni_first_bias_dims
273             = { 1, 1, lstm_n_gates, feature_size };
274     memory::dims enc_uni_first_dst_layer_dims
275             = { src_seq_length_max, batch, feature_size };
276     auto user_enc_uni_first_wei_layer_md = mkldnn::memory::desc(
277             { user_enc_uni_first_wei_layer_dims },
278             mkldnn::memory::data_type::f32, mkldnn::memory::format::ldigo);
279     auto user_enc_uni_first_wei_iter_md = mkldnn::memory::desc(
280             { user_enc_uni_first_wei_iter_dims },
281             mkldnn::memory::data_type::f32, mkldnn::memory::format::ldigo);
282     auto user_enc_uni_first_bias_md = mkldnn::memory::desc(
283             { user_enc_uni_first_bias_dims }, mkldnn::memory::data_type::f32,
284             mkldnn::memory::format::ldgo);
285     auto enc_uni_first_dst_layer_md = mkldnn::memory::desc(
286             { enc_uni_first_dst_layer_dims }, mkldnn::memory::data_type::f32,
287             mkldnn::memory::format::tnc);
288     auto user_enc_uni_first_wei_layer_memory
289             = mkldnn::memory({ user_enc_uni_first_wei_layer_md, cpu_engine },
290                     user_enc_uni_first_wei_layer.data());
291     ;
292     auto user_enc_uni_first_wei_iter_memory
293             = mkldnn::memory({ user_enc_uni_first_wei_iter_md, cpu_engine },
294                     user_enc_uni_first_wei_iter.data());
295     auto user_enc_uni_first_bias_memory
296             = mkldnn::memory({ user_enc_uni_first_bias_md, cpu_engine },
297                     user_enc_uni_first_bias.data());
298
299     /// @todo add suport for residual connections
300     /// should it be a set residual in op_desc or a field to set manually?
301     /// should be an integer to specify at which layer to start
302     rnn_cell::desc enc_uni_first_cell(algorithm::vanilla_lstm);
303     rnn_forward::desc enc_uni_first_layer_desc(prop_kind::forward_inference,
304             enc_uni_first_cell, rnn_direction::unidirectional_left2right,
305             enc_bidir_dst_layer_md, zero_md(), user_enc_uni_first_wei_layer_md,
306             user_enc_uni_first_wei_iter_md, user_enc_uni_first_bias_md,
307             enc_uni_first_dst_layer_md, zero_md());
308     auto enc_uni_first_prim_desc = mkldnn::rnn_forward::primitive_desc(
309             enc_uni_first_layer_desc, cpu_engine);
310     auto enc_uni_first_dst_layer_memory = mkldnn::memory(
311             enc_uni_first_prim_desc.dst_layer_primitive_desc());
312
313     /// @todo add a reorder when they will be available
314     encoder_net.push_back(rnn_forward(enc_uni_first_prim_desc,
315             enc_bidir_dst_layer_memory, null_memory_,
316             user_enc_uni_first_wei_layer_memory,
317             user_enc_uni_first_wei_iter_memory, user_enc_uni_first_bias_memory,
318             enc_uni_first_dst_layer_memory, null_memory_, null_memory_));
319
320     // Remainging Unidirectional layers
321     /// memories
322     std::vector<float> user_enc_uni_wei_layer((enc_unidir_n_layers - 1) * 1
323                     * feature_size * lstm_n_gates * feature_size, 1.0f);
324     std::vector<float> user_enc_uni_wei_iter((enc_unidir_n_layers - 1) * 1
325                     * feature_size * lstm_n_gates * feature_size, 1.0f);
326     std::vector<float> user_enc_uni_bias(
327             (enc_unidir_n_layers - 1) * 1 * lstm_n_gates * feature_size, 1.0f);
328     memory::dims user_enc_uni_wei_layer_dims = { (enc_unidir_n_layers - 1), 1,
329         feature_size, lstm_n_gates, feature_size };
330     memory::dims user_enc_uni_wei_iter_dims = { (enc_unidir_n_layers - 1), 1,
331         feature_size, lstm_n_gates, feature_size };
332     memory::dims user_enc_uni_bias_dims
333             = { (enc_unidir_n_layers - 1), 1, lstm_n_gates, feature_size };
334     memory::dims enc_dst_layer_dims
335             = { src_seq_length_max, batch, feature_size };
336     auto user_enc_uni_wei_layer_md = mkldnn::memory::desc(
337             { user_enc_uni_wei_layer_dims }, mkldnn::memory::data_type::f32,
338             mkldnn::memory::format::ldigo);
339     auto user_enc_uni_wei_iter_md = mkldnn::memory::desc(
340             { user_enc_uni_wei_iter_dims }, mkldnn::memory::data_type::f32,
341             mkldnn::memory::format::ldigo);
342     auto user_enc_uni_bias_md = mkldnn::memory::desc({ user_enc_uni_bias_dims },
343             mkldnn::memory::data_type::f32, mkldnn::memory::format::ldgo);
344     auto enc_dst_layer_md = mkldnn::memory::desc({ enc_dst_layer_dims },
345             mkldnn::memory::data_type::f32, mkldnn::memory::format::tnc);
346     auto user_enc_uni_wei_layer_memory
347             = mkldnn::memory({ user_enc_uni_wei_layer_md, cpu_engine },
348                     user_enc_uni_wei_layer.data());
349     ;
350     auto user_enc_uni_wei_iter_memory
351             = mkldnn::memory({ user_enc_uni_wei_iter_md, cpu_engine },
352                     user_enc_uni_wei_iter.data());
353     auto user_enc_uni_bias_memory = mkldnn::memory(
354             { user_enc_uni_bias_md, cpu_engine }, user_enc_uni_bias.data());
355
356     /// @todo add suport for residual connections
357     /// should it be a set residual in op_desc or a field to set manually?
358     /// should be an integer to specify at which layer to start
359     rnn_cell::desc enc_uni_cell(algorithm::vanilla_lstm);
360     rnn_forward::desc enc_uni_layer_desc(prop_kind::forward_inference,
361             enc_uni_cell, rnn_direction::unidirectional_left2right,
362             enc_uni_first_dst_layer_md, zero_md(), user_enc_uni_wei_layer_md,
363             user_enc_uni_wei_iter_md, user_enc_uni_bias_md, enc_dst_layer_md,
364             zero_md());
365     auto enc_uni_prim_desc = mkldnn::rnn_forward::primitive_desc(
366             enc_uni_layer_desc, cpu_engine);
367     auto enc_dst_layer_memory
368             = mkldnn::memory(enc_uni_prim_desc.dst_layer_primitive_desc());
369
370     /// @todo add a reorder when they will be available
371     encoder_net.push_back(
372             rnn_forward(enc_uni_prim_desc, enc_uni_first_dst_layer_memory,
373                     null_memory_, user_enc_uni_wei_layer_memory,
374                     user_enc_uni_wei_iter_memory, user_enc_uni_bias_memory,
375                     enc_dst_layer_memory, null_memory_, null_memory_));
376
377     /*
378      * GNMT: decoder with attention mechanism
379      */
380     // user provided memories
381     std::vector<float> user_dec_wei_layer(
382             dec_n_layers * 1 * feature_size * lstm_n_gates * feature_size,
383             1.0f);
384     std::vector<float> user_dec_wei_iter(dec_n_layers * 1
385                     * (feature_size + feature_size) * lstm_n_gates
386                     * feature_size, 1.0f);
387     std::vector<float> user_dec_bias(
388             dec_n_layers * 1 * lstm_n_gates * feature_size, 1.0f);
389     std::vector<float> user_dec_dst(
390             tgt_seq_length_max * batch * feature_size, 1.0f);
391     std::vector<float> user_weights_attention_src_layer(
392             feature_size * feature_size, 1.0f);
393     std::vector<float> user_weights_annotation(
394             feature_size * feature_size, 1.0f);
395     std::vector<float> user_weights_alignments(feature_size, 1.0f);
396
397     memory::dims user_dec_wei_layer_dims
398             = { dec_n_layers, 1, feature_size, lstm_n_gates, feature_size };
399     memory::dims user_dec_wei_iter_dims = { dec_n_layers, 1,
400         feature_size + feature_size, lstm_n_gates, feature_size };
401     memory::dims user_dec_bias_dims
402             = { dec_n_layers, 1, lstm_n_gates, feature_size };
403
404     memory::dims dec_src_layer_dims = { 1, batch, feature_size };
405     memory::dims dec_dst_layer_dims
406             = { tgt_seq_length_max, batch, feature_size };
407
408     // We will use the same memory for dec_src_iter and dec_dst_iter
409     // However, dec_src_iter has a context vector but not
410     // dec_dst_iter.
411     // To resolve this we will create one memory that holds the
412     // context vector as well as the both the hidden and cell states.
413     // For the dst_iter, we will use a view on this memory.
414     // Note that the cell state will be padded by
415     // feature_size values. However, we do not compute or
416     // access those.
417     memory::dims dec_dst_iter_dims = { dec_n_layers, 1, lstm_n_states, batch,
418         feature_size + feature_size };
419     memory::dims dec_dst_iter_noctx_dims
420             = { dec_n_layers, 1, lstm_n_states, batch, feature_size };
421
422     auto user_dec_wei_layer_md = mkldnn::memory::desc(
423             { user_dec_wei_layer_dims }, mkldnn::memory::data_type::f32,
424             mkldnn::memory::format::ldigo);
425     auto user_dec_wei_iter_md = mkldnn::memory::desc({ user_dec_wei_iter_dims },
426             mkldnn::memory::data_type::f32, mkldnn::memory::format::ldigo);
427     auto user_dec_bias_md = mkldnn::memory::desc({ user_dec_bias_dims },
428             mkldnn::memory::data_type::f32, mkldnn::memory::format::ldgo);
429     auto dec_dst_layer_md = mkldnn::memory::desc({ dec_dst_layer_dims },
430             mkldnn::memory::data_type::f32, mkldnn::memory::format::tnc);
431     auto dec_src_layer_md = mkldnn::memory::desc({ dec_src_layer_dims },
432             mkldnn::memory::data_type::f32, mkldnn::memory::format::tnc);
433     auto dec_dst_iter_md = mkldnn::memory::desc({ dec_dst_iter_dims },
434             mkldnn::memory::data_type::f32, mkldnn::memory::format::ldsnc);
435     auto user_dec_wei_layer_memory = mkldnn::memory(
436             { user_dec_wei_layer_md, cpu_engine }, user_dec_wei_layer.data());
437     ;
438     auto user_dec_wei_iter_memory = mkldnn::memory(
439             { user_dec_wei_iter_md, cpu_engine }, user_dec_wei_iter.data());
440     auto user_dec_bias_memory = mkldnn::memory(
441             { user_dec_bias_md, cpu_engine }, user_dec_bias.data());
442     auto user_dec_dst_layer_memory = mkldnn::memory(
443             { dec_dst_layer_md, cpu_engine }, user_dec_dst.data());
444     auto dec_src_layer_memory
445             = mkldnn::memory({ dec_src_layer_md, cpu_engine });
446
447     // As mentioned above, we create a view without context out of the
448     // memory with context.
449     auto dec_dst_iter_memory = mkldnn::memory({ dec_dst_iter_md, cpu_engine });
450     auto dec_dst_iter_noctx_md = mkldnn::view::primitive_desc(
451             dec_dst_iter_memory.get_primitive_desc(), dec_dst_iter_noctx_dims,
452             { 0, 0, 0, 0, 0 }).dst_primitive_desc().desc();
453
454     /// @todo add suport for residual connections
455     /// should it be a set residual in op_desc or a field to set manually?
456     /// should be an integer to specify at which layer to start
457     rnn_cell::desc dec_cell(algorithm::vanilla_lstm);
458     rnn_forward::desc dec_ctx_desc(prop_kind::forward_inference, dec_cell,
459             rnn_direction::unidirectional_left2right, dec_src_layer_md,
460             dec_dst_iter_md, user_dec_wei_layer_md, user_dec_wei_iter_md,
461             user_dec_bias_md, dec_dst_layer_md, dec_dst_iter_noctx_md);
462     auto dec_ctx_prim_desc
463             = mkldnn::rnn_forward::primitive_desc(dec_ctx_desc, cpu_engine);
464
465     /// @todo add a reorder when they will be available
466     decoder_net.push_back(rnn_forward(dec_ctx_prim_desc, dec_src_layer_memory,
467             dec_dst_iter_memory, user_dec_wei_layer_memory,
468             user_dec_wei_iter_memory, user_dec_bias_memory,
469             user_dec_dst_layer_memory, dec_dst_iter_memory, null_memory_));
470
471     // allocating temporary buffer for attention mechanism
472     std::vector<float> weighted_annotations(
473             src_seq_length_max * batch * feature_size, 1.0f);
474
475     /*
476        Execution
477      */
478     auto execute = [&]() {
479         // We save the original handle on dst_layer as we will modify it at each
480         // iteration
481         void *dst_layer_original_handle
482                 = user_dec_dst_layer_memory.get_data_handle();
483
484         // run encoder (1 stream)
485         stream(stream::kind::eager).submit(encoder_net).wait();
486
487         // we compute the weighted annotations once before the decoder
488         compute_weighted_annotations(weighted_annotations.data(),
489                 src_seq_length_max, batch, feature_size,
490                 user_weights_annotation.data(),
491                 (float *)enc_dst_layer_memory.get_data_handle());
492
493         // We initialise dst_layer[0] to the embedding of </s>, which are
494         // assumed to
495         // be 0 here
496         memset(dst_layer_original_handle, 0,
497                 batch * feature_size * sizeof(float));
498
499         for (int i = 0; i < tgt_seq_length_max; i++) {
500             float *dst_layer_handle
501                     = (float *)user_dec_dst_layer_memory.get_data_handle();
502             float *dst_iter_handle
503                     = (float *)dec_dst_iter_memory.get_data_handle();
504
505             // Compute attention context vector into the first layer src_iter
506             compute_attention(dst_iter_handle, src_seq_length_max, batch,
507                     feature_size, user_weights_attention_src_layer.data(),
508                     dst_layer_handle,
509                     (float *)enc_bidir_dst_layer_memory.get_data_handle(),
510                     weighted_annotations.data(),
511                     user_weights_alignments.data());
512
513             // copy the context vectors to all layers of src_iter
514             copy_context(dst_iter_handle, dec_n_layers, lstm_n_states, batch,
515                     feature_size);
516
517             // We set src_layer to be the previously
518             dec_src_layer_memory.set_data_handle(dst_layer_handle);
519
520             // run the decoder iteration
521             stream(stream::kind::eager).submit(decoder_net).wait();
522
523             // Move the handle on the dst layer to the next iteration
524             user_dec_dst_layer_memory.set_data_handle(
525                     dst_layer_handle + batch * feature_size);
526         }
527         // we restore the handle to the begining of the buffer
528         user_dec_dst_layer_memory.set_data_handle(dst_layer_original_handle);
529         /// @todo run the softmax after each iteration or not?
530     };
531
532     execute();
533 }
534
535 int main(int argc, char **argv) {
536     try {
537         simple_net();
538         std::cout << "ok\n";
539     } catch (error &e) {
540         std::cerr << "status: " << e.status << std::endl;
541         std::cerr << "message: " << e.message << std::endl;
542         return 1;
543     }
544     return 0;
545 }