Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / examples / simple_rnn.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <cstring>
18 #include <iostream>
19 #include <math.h>
20 #include <numeric>
21 #include <string>
22
23 #include "mkldnn.hpp"
24
25 // MSVC doesn't support collapse clause in omp parallel
26 #if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
27 #define collapse(x)
28 #endif
29
30 using namespace mkldnn;
31
// Problem sizes for the GNMT-style example (fixed at compile time).
const int batch = 128;
const int src_seq_length_max = 28;
const int tgt_seq_length_max = 28;

const int feature_size = 1024;

// Encoder: 1 bidirectional layer followed by 7 unidirectional layers;
// decoder: 8 unidirectional layers (see simple_net below).
const int enc_bidir_n_layers = 1;
const int enc_unidir_n_layers = 7;
const int dec_n_layers = 8;

// LSTM: 4 gates per cell, 2 recurrent states (hidden and cell).
const int lstm_n_gates = 4;
const int lstm_n_states = 2;
// Mutable scratch buffers reused by the attention helpers on every decoder
// step; sized from the constants above. Not thread-safe across callers.
std::vector<float> weighted_src_layer(batch *feature_size, 1.0f);
std::vector<float> alignment_model(
        src_seq_length_max *batch *feature_size, 1.0f);
std::vector<float> alignments(src_seq_length_max *batch, 1.0f);
std::vector<float> exp_sums(batch, 1.0f);

// Scalar GEMM arguments; mkldnn_sgemm takes them by address.
const float onef = 1.0, zerof = 0.0;
const int onei = 1;
52
53 void compute_weighted_annotations(float *weighted_annotations,
54         int src_seq_length_max, int batch, int feature_size,
55         float *weights_annot, float *annotations) {
56     // annotations(aka enc_dst_layer) is (t, n, 2c)
57     // weights_annot is (2c, c)
58
59     // annotation[i] = GEMM(weights_annot, enc_dst_layer[i]);
60     int num_weighted_annotations = src_seq_length_max * batch;
61     mkldnn_sgemm("N", "N",
62             &feature_size, &num_weighted_annotations, &feature_size,
63             &onef, weights_annot, &feature_size, annotations, &feature_size,
64             &zerof, weighted_annotations, &feature_size);
65 }
66
67 void compute_attention(float *context_vectors, int src_seq_length_max,
68         int batch, int feature_size, float *weights_src_layer,
69         float *dec_src_layer, float *annotations, float *weighted_annotations,
70         float *weights_alignments) {
71     // dst_iter : (n, c) matrix
72     // src_layer: (n, c) matrix
73     // weighted_annotations (t, n, c)
74
75     // weights_yi is (c, c)
76     // weights_ai is (c, 1)
77     // tmp[i] is (n, c)
78     // a[i] is (n, 1)
79     // p is (n, 1)
80
81     // first we precompute the weighted_dec_src_layer
82     mkldnn_sgemm("N", "N",
83             &feature_size, &batch, &feature_size, &onef,
84             weights_src_layer, &feature_size, dec_src_layer, &feature_size,
85             &zerof, weighted_src_layer.data(), &feature_size);
86
87     // then we compute the alignment model
88     float *alignment_model_ptr = alignment_model.data();
89 #ifdef _OPENMP
90 #pragma omp parallel for collapse(2)
91 #endif
92     for (int i = 0; i < src_seq_length_max; i++) {
93         for (int j = 0; j < batch * feature_size; j++)
94             alignment_model_ptr[i * batch * feature_size + j] = tanhf(
95                     weighted_src_layer.data()[j]
96                     + weighted_annotations[i * batch * feature_size + j]);
97     }
98
99     // gemv with alignments weights. the resulting alignments are in alignments
100     int num_weighted_annotations = src_seq_length_max * batch;
101     mkldnn_sgemm("N", "N",
102             &onei, &num_weighted_annotations, &feature_size, &onef,
103             weights_alignments, &onei, alignment_model_ptr, &feature_size,
104             &zerof, alignments.data(), &onei);
105
106     // softmax on alignments. the resulting context weights are in alignments
107 #ifdef _OPENMP
108 #pragma omp parallel for
109 #endif
110     for (int i = 0; i < batch; i++)
111         exp_sums[i] = 0.0f;
112 #ifdef _OPENMP
113 #pragma omp parallel for collapse(2)
114 #endif
115     for (int i = 0; i < src_seq_length_max; i++) {
116         for (int j = 0; j < batch; j++) {
117             alignments[i * batch + j] = expf(alignments[i * batch + j]);
118             exp_sums[j] += alignments[i * batch + j];
119         }
120     }
121
122 #ifdef _OPENMP
123 #pragma omp parallel for collapse(2)
124 #endif
125     for (int i = 0; i < src_seq_length_max; i++)
126         for (int j = 0; j < batch; j++)
127             alignments[i * batch + j] /= exp_sums[j];
128
129     // then we compute the context vectors
130 #ifdef _OPENMP
131 #pragma omp parallel for collapse(2)
132 #endif
133     for (int i = 0; i < batch; i++)
134         for (int j = 0; j < feature_size; j++)
135             context_vectors[i * (feature_size + feature_size) + feature_size
136                     + j]
137                     = 0.0f;
138
139 #ifdef _OPENMP
140 #pragma omp parallel for collapse(3)
141 #endif
142     for (int i = 0; i < batch; i++)
143         for (int k = 0; k < src_seq_length_max; k++)
144             for (int j = 0; j < feature_size; j++)
145                 context_vectors[i * (feature_size + feature_size) + feature_size
146                         + j]
147                         += alignments[k * batch + i]
148                         * annotations[j + feature_size * (i + batch * k)];
149 }
150
/* Broadcasts the attention context stored with the first decoder layer into
 * the corresponding src_iter slot of every remaining layer. Rows are padded
 * to (feature_size + feature_size); only the first feature_size entries of
 * the first layer's row are replicated. */
void copy_context(float *src_iter, int n_layers, int n_states, int batch,
        int feature_size) {
    const int row_stride = feature_size + feature_size;
#ifdef _OPENMP
#pragma omp parallel for collapse(3)
#endif
    for (int layer = 1; layer < n_layers; layer++)
        for (int b = 0; b < batch; b++)
            for (int f = 0; f < feature_size; f++) {
                const int dst = (layer * n_states * batch + b) * row_stride + f;
                src_iter[dst] = src_iter[b * row_stride + f];
            }
}
165
/* Builds and runs the whole GNMT-style inference flow: an encoder made of
 * one bidirectional LSTM layer plus enc_unidir_n_layers unidirectional
 * layers, then a dec_n_layers LSTM decoder executed one time step at a
 * time with an attention context injected through src_iter.
 * Throws mkldnn::error if any primitive cannot be created. */
void simple_net() {
    auto cpu_engine = engine(engine::cpu, 0);
    auto null_memory_ = null_memory(cpu_engine);

    /*
      GNMT Example.
      Note, we do not implement connection yet.
      For the encoder we use:
      - one primitive for the bidirectional layer of the encoder
      - one primitive for all remaining unidirectional layers in the encoder
      For the decoder we use:
      - one primitive for the first iteration
      - one primitive for all subsequent iterations in the decoder. Note that
        in this example, this primitive computes the states in place.
      - the attention mechanism is implemented separately as there is no support
        for the context vectors in MKL-DNN yet
     */

    // Primitives are collected into three nets and submitted per stage.
    std::vector<primitive> weights_reorders;
    std::vector<primitive> encoder_net;
    std::vector<primitive> decoder_net;

    // Stand-in input/output data (all ones).
    // NOTE(review): net_dst is allocated but never read back in this example.
    std::vector<float> net_src(batch * src_seq_length_max * feature_size, 1.0f);
    std::vector<float> net_dst(batch * tgt_seq_length_max * feature_size, 1.0f);

    /* Encoder */

    memory::dims enc_bidir_src_layer_tz
            = { src_seq_length_max, batch, feature_size };
    memory::dims enc_bidir_weights_layer_tz = { enc_bidir_n_layers, 2,
        feature_size, lstm_n_gates, feature_size };
    memory::dims enc_bidir_weights_iter_tz = { enc_bidir_n_layers, 2,
        feature_size, lstm_n_gates, feature_size };
    memory::dims enc_bidir_bias_tz
            = { enc_bidir_n_layers, 2, lstm_n_gates, feature_size };
    // bidirectional_concat: dst_layer concatenates both directions -> 2c.
    memory::dims enc_bidir_dst_layer_tz
            = { src_seq_length_max, batch, 2 * feature_size };

    /* GNMT encoder: 1 bidirectional layer and 7 unidirectional layers */

    std::vector<float> user_enc_bidir_wei_layer(
            enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size,
            1.0f);
    std::vector<float> user_enc_bidir_wei_iter(
            enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size,
            1.0f);
    std::vector<float> user_enc_bidir_bias(
            enc_bidir_n_layers * 2 * lstm_n_gates * feature_size, 1.0f);

    /* Create the memory for user data */
    auto user_enc_bidir_src_layer_md = mkldnn::memory::desc(
            { enc_bidir_src_layer_tz }, mkldnn::memory::data_type::f32,
            mkldnn::memory::format::tnc);

    auto user_enc_bidir_wei_layer_md = mkldnn::memory::desc(
            { enc_bidir_weights_layer_tz }, mkldnn::memory::data_type::f32,
            mkldnn::memory::format::ldigo);

    auto user_enc_bidir_wei_iter_md = mkldnn::memory::desc(
            { enc_bidir_weights_iter_tz }, mkldnn::memory::data_type::f32,
            mkldnn::memory::format::ldigo);

    auto user_enc_bidir_bias_md = mkldnn::memory::desc({ enc_bidir_bias_tz },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::ldgo);

    auto user_enc_bidir_src_layer_memory = mkldnn::memory(
            { user_enc_bidir_src_layer_md, cpu_engine }, net_src.data());
    auto user_enc_bidir_wei_layer_memory
            = mkldnn::memory({ user_enc_bidir_wei_layer_md, cpu_engine },
                    user_enc_bidir_wei_layer.data());
    auto user_enc_bidir_wei_iter_memory
            = mkldnn::memory({ user_enc_bidir_wei_iter_md, cpu_engine },
                    user_enc_bidir_wei_iter.data());
    auto user_enc_bidir_bias_memory = mkldnn::memory(
            { user_enc_bidir_bias_md, cpu_engine }, user_enc_bidir_bias.data());

    /* Create memory descriptors for RNN data w/o specified layout */
    // format::any lets the primitive pick its preferred internal layout;
    // reorders below convert the user data into it.
    auto enc_bidir_wei_layer_md = memory::desc({ enc_bidir_weights_layer_tz },
            memory::data_type::f32, memory::format::any);

    auto enc_bidir_wei_iter_md = memory::desc({ enc_bidir_weights_iter_tz },
            memory::data_type::f32, memory::format::any);

    auto enc_bidir_dst_layer_md = memory::desc({ enc_bidir_dst_layer_tz },
            memory::data_type::f32, memory::format::any);

    /* Create bidirectional RNN */
    rnn_cell::desc bi_cell(algorithm::vanilla_lstm);
    rnn_forward::desc bi_layer_desc(prop_kind::forward_inference, bi_cell,
            rnn_direction::bidirectional_concat, user_enc_bidir_src_layer_md,
            zero_md(), enc_bidir_wei_layer_md, enc_bidir_wei_iter_md,
            user_enc_bidir_bias_md, enc_bidir_dst_layer_md, zero_md());

    auto enc_bidir_prim_desc
            = mkldnn::rnn_forward::primitive_desc(bi_layer_desc, cpu_engine);

    /* Create memory primitives for input data and use reorders to reorder
     * user data to internal representation
     */
    auto enc_bidir_wei_layer_memory
            = memory(enc_bidir_prim_desc.weights_layer_primitive_desc());
    auto enc_bidir_wei_layer_reorder_pd = reorder::primitive_desc(
            user_enc_bidir_wei_layer_memory.get_primitive_desc(),
            enc_bidir_wei_layer_memory.get_primitive_desc());
    weights_reorders.push_back(reorder(enc_bidir_wei_layer_reorder_pd,
            user_enc_bidir_wei_layer_memory, enc_bidir_wei_layer_memory));

    auto enc_bidir_wei_iter_memory
            = memory(enc_bidir_prim_desc.weights_iter_primitive_desc());
    auto enc_bidir_wei_iter_reorder_pd = reorder::primitive_desc(
            user_enc_bidir_wei_iter_memory.get_primitive_desc(),
            enc_bidir_wei_iter_memory.get_primitive_desc());
    weights_reorders.push_back(reorder(enc_bidir_wei_iter_reorder_pd,
            user_enc_bidir_wei_iter_memory, enc_bidir_wei_iter_memory));

    auto enc_bidir_dst_layer_memory
            = mkldnn::memory(enc_bidir_prim_desc.dst_layer_primitive_desc());

    encoder_net.push_back(
            rnn_forward(enc_bidir_prim_desc, user_enc_bidir_src_layer_memory,
                    null_memory_, enc_bidir_wei_layer_memory,
                    enc_bidir_wei_iter_memory, user_enc_bidir_bias_memory,
                    enc_bidir_dst_layer_memory, null_memory_, null_memory_));

    /* GNMT encoder: unidirectional layers */
    // First unidirectinal layer scales 2 * feature_size output of bidirectional
    // layer to feature_size output
    std::vector<float> user_enc_uni_first_wei_layer(
            1 * 1 * 2 * feature_size * lstm_n_gates * feature_size, 1.0f);
    std::vector<float> user_enc_uni_first_wei_iter(
            1 * 1 * feature_size * lstm_n_gates * feature_size, 1.0f);
    std::vector<float> user_enc_uni_first_bias(
            1 * 1 * lstm_n_gates * feature_size, 1.0f);
    memory::dims user_enc_uni_first_wei_layer_dims
            = { 1, 1, 2 * feature_size, lstm_n_gates, feature_size };
    memory::dims user_enc_uni_first_wei_iter_dims
            = { 1, 1, feature_size, lstm_n_gates, feature_size };
    memory::dims user_enc_uni_first_bias_dims
            = { 1, 1, lstm_n_gates, feature_size };
    memory::dims enc_uni_first_dst_layer_dims
            = { src_seq_length_max, batch, feature_size };
    auto user_enc_uni_first_wei_layer_md = mkldnn::memory::desc(
            { user_enc_uni_first_wei_layer_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::ldigo);
    auto user_enc_uni_first_wei_iter_md = mkldnn::memory::desc(
            { user_enc_uni_first_wei_iter_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::ldigo);
    auto user_enc_uni_first_bias_md = mkldnn::memory::desc(
            { user_enc_uni_first_bias_dims }, mkldnn::memory::data_type::f32,
            mkldnn::memory::format::ldgo);
    auto user_enc_uni_first_wei_layer_memory
            = mkldnn::memory({ user_enc_uni_first_wei_layer_md, cpu_engine },
                    user_enc_uni_first_wei_layer.data());
    auto user_enc_uni_first_wei_iter_memory
            = mkldnn::memory({ user_enc_uni_first_wei_iter_md, cpu_engine },
                    user_enc_uni_first_wei_iter.data());
    auto user_enc_uni_first_bias_memory
            = mkldnn::memory({ user_enc_uni_first_bias_md, cpu_engine },
                    user_enc_uni_first_bias.data());

    auto enc_uni_first_wei_layer_md
            = memory::desc({ user_enc_uni_first_wei_layer_dims },
                    memory::data_type::f32, memory::format::any);
    auto enc_uni_first_wei_iter_md
            = memory::desc({ user_enc_uni_first_wei_iter_dims },
                    memory::data_type::f32, memory::format::any);
    auto enc_uni_first_dst_layer_md
            = memory::desc({ enc_uni_first_dst_layer_dims },
                    memory::data_type::f32, memory::format::any);

    /// @todo add suport for residual connections
    /// should it be a set residual in op_desc or a field to set manually?
    /// should be an integer to specify at which layer to start
    rnn_cell::desc enc_uni_first_cell(algorithm::vanilla_lstm);
    rnn_forward::desc enc_uni_first_layer_desc(prop_kind::forward_inference,
            enc_uni_first_cell, rnn_direction::unidirectional_left2right,
            enc_bidir_dst_layer_md, zero_md(), enc_uni_first_wei_layer_md,
            enc_uni_first_wei_iter_md, user_enc_uni_first_bias_md,
            enc_uni_first_dst_layer_md, zero_md());
    auto enc_uni_first_prim_desc = mkldnn::rnn_forward::primitive_desc(
            enc_uni_first_layer_desc, cpu_engine);

    auto enc_uni_first_wei_layer_memory
            = memory(enc_uni_first_prim_desc.weights_layer_primitive_desc());
    auto enc_uni_first_wei_layer_reorder_pd = reorder::primitive_desc(
            user_enc_uni_first_wei_layer_memory.get_primitive_desc(),
            enc_uni_first_wei_layer_memory.get_primitive_desc());
    weights_reorders.push_back(reorder(enc_uni_first_wei_layer_reorder_pd,
            user_enc_uni_first_wei_layer_memory,
            enc_uni_first_wei_layer_memory));

    auto enc_uni_first_wei_iter_memory
            = memory(enc_uni_first_prim_desc.weights_iter_primitive_desc());
    auto enc_uni_first_wei_iter_reorder_pd = reorder::primitive_desc(
            user_enc_uni_first_wei_iter_memory.get_primitive_desc(),
            enc_uni_first_wei_iter_memory.get_primitive_desc());
    weights_reorders.push_back(reorder(enc_uni_first_wei_iter_reorder_pd,
            user_enc_uni_first_wei_iter_memory, enc_uni_first_wei_iter_memory));

    auto enc_uni_first_dst_layer_memory = mkldnn::memory(
            enc_uni_first_prim_desc.dst_layer_primitive_desc());

    encoder_net.push_back(rnn_forward(enc_uni_first_prim_desc,
            enc_bidir_dst_layer_memory, null_memory_,
            enc_uni_first_wei_layer_memory,
            enc_uni_first_wei_iter_memory, user_enc_uni_first_bias_memory,
            enc_uni_first_dst_layer_memory, null_memory_, null_memory_));

    /* Remainging unidirectional layers */
    std::vector<float> user_enc_uni_wei_layer((enc_unidir_n_layers - 1) * 1
                    * feature_size * lstm_n_gates * feature_size, 1.0f);
    std::vector<float> user_enc_uni_wei_iter((enc_unidir_n_layers - 1) * 1
                    * feature_size * lstm_n_gates * feature_size, 1.0f);
    std::vector<float> user_enc_uni_bias(
            (enc_unidir_n_layers - 1) * 1 * lstm_n_gates * feature_size, 1.0f);
    memory::dims user_enc_uni_wei_layer_dims = { (enc_unidir_n_layers - 1), 1,
        feature_size, lstm_n_gates, feature_size };
    memory::dims user_enc_uni_wei_iter_dims = { (enc_unidir_n_layers - 1), 1,
        feature_size, lstm_n_gates, feature_size };
    memory::dims user_enc_uni_bias_dims
            = { (enc_unidir_n_layers - 1), 1, lstm_n_gates, feature_size };
    memory::dims enc_dst_layer_dims
            = { src_seq_length_max, batch, feature_size };
    auto user_enc_uni_wei_layer_md = mkldnn::memory::desc(
            { user_enc_uni_wei_layer_dims }, mkldnn::memory::data_type::f32,
            mkldnn::memory::format::ldigo);
    auto user_enc_uni_wei_iter_md = mkldnn::memory::desc(
            { user_enc_uni_wei_iter_dims }, mkldnn::memory::data_type::f32,
            mkldnn::memory::format::ldigo);
    auto user_enc_uni_bias_md = mkldnn::memory::desc({ user_enc_uni_bias_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::ldgo);
    auto user_enc_uni_wei_layer_memory
            = mkldnn::memory({ user_enc_uni_wei_layer_md, cpu_engine },
                    user_enc_uni_wei_layer.data());
    auto user_enc_uni_wei_iter_memory
            = mkldnn::memory({ user_enc_uni_wei_iter_md, cpu_engine },
                    user_enc_uni_wei_iter.data());
    auto user_enc_uni_bias_memory = mkldnn::memory(
            { user_enc_uni_bias_md, cpu_engine }, user_enc_uni_bias.data());

    auto enc_uni_wei_layer_md = memory::desc({ user_enc_uni_wei_layer_dims },
            memory::data_type::f32, memory::format::any);
    auto enc_uni_wei_iter_md = memory::desc({ user_enc_uni_wei_iter_dims },
            memory::data_type::f32, memory::format::any);
    auto enc_dst_layer_md = memory::desc({ enc_dst_layer_dims },
            memory::data_type::f32, memory::format::any);

    /// @todo add suport for residual connections
    /// should it be a set residual in op_desc or a field to set manually?
    /// should be an integer to specify at which layer to start
    rnn_cell::desc enc_uni_cell(algorithm::vanilla_lstm);
    rnn_forward::desc enc_uni_layer_desc(prop_kind::forward_inference,
            enc_uni_cell, rnn_direction::unidirectional_left2right,
            enc_uni_first_dst_layer_md, zero_md(), enc_uni_wei_layer_md,
            enc_uni_wei_iter_md, user_enc_uni_bias_md, enc_dst_layer_md,
            zero_md());
    auto enc_uni_prim_desc = mkldnn::rnn_forward::primitive_desc(
            enc_uni_layer_desc, cpu_engine);

    auto enc_uni_wei_layer_memory
            = memory(enc_uni_prim_desc.weights_layer_primitive_desc());
    auto enc_uni_wei_layer_reorder_pd = reorder::primitive_desc(
            user_enc_uni_wei_layer_memory.get_primitive_desc(),
            enc_uni_wei_layer_memory.get_primitive_desc());
    weights_reorders.push_back(reorder(enc_uni_wei_layer_reorder_pd,
            user_enc_uni_wei_layer_memory, enc_uni_wei_layer_memory));

    auto enc_uni_wei_iter_memory
            = memory(enc_uni_prim_desc.weights_iter_primitive_desc());
    auto enc_uni_wei_iter_reorder_pd = reorder::primitive_desc(
            user_enc_uni_wei_iter_memory.get_primitive_desc(),
            enc_uni_wei_iter_memory.get_primitive_desc());
    weights_reorders.push_back(reorder(enc_uni_wei_iter_reorder_pd,
            user_enc_uni_wei_iter_memory, enc_uni_wei_iter_memory));

    auto enc_dst_layer_memory
            = mkldnn::memory(enc_uni_prim_desc.dst_layer_primitive_desc());

    encoder_net.push_back(
            rnn_forward(enc_uni_prim_desc, enc_uni_first_dst_layer_memory,
                    null_memory_, enc_uni_wei_layer_memory,
                    enc_uni_wei_iter_memory, user_enc_uni_bias_memory,
                    enc_dst_layer_memory, null_memory_, null_memory_));

    /* GNMT: decoder with attention mechanism */
    std::vector<float> user_dec_wei_layer(
            dec_n_layers * 1 * feature_size * lstm_n_gates * feature_size,
            1.0f);
    std::vector<float> user_dec_wei_iter(dec_n_layers * 1
                    * (feature_size + feature_size) * lstm_n_gates
                    * feature_size, 1.0f);
    std::vector<float> user_dec_bias(
            dec_n_layers * 1 * lstm_n_gates * feature_size, 1.0f);
    std::vector<float> user_dec_dst(
            tgt_seq_length_max * batch * feature_size, 1.0f);
    std::vector<float> user_weights_attention_src_layer(
            feature_size * feature_size, 1.0f);
    std::vector<float> user_weights_annotation(
            feature_size * feature_size, 1.0f);
    std::vector<float> user_weights_alignments(feature_size, 1.0f);

    memory::dims user_dec_wei_layer_dims
            = { dec_n_layers, 1, feature_size, lstm_n_gates, feature_size };
    memory::dims user_dec_wei_iter_dims = { dec_n_layers, 1,
        feature_size + feature_size, lstm_n_gates, feature_size };
    memory::dims user_dec_bias_dims
            = { dec_n_layers, 1, lstm_n_gates, feature_size };

    // Decoder runs a single time step per iteration (t == 1).
    memory::dims dec_src_layer_dims = { 1, batch, feature_size };
    memory::dims dec_dst_layer_dims = { 1, batch, feature_size };

    // We will use the same memory for dec_src_iter and dec_dst_iter
    // However, dec_src_iter has a context vector but not
    // dec_dst_iter.
    // To resolve this we will create one memory that holds the
    // context vector as well as the both the hidden and cell states.
    // For the dst_iter, we will use a view on this memory.
    // Note that the cell state will be padded by
    // feature_size values. However, we do not compute or
    // access those.
    memory::dims dec_dst_iter_dims = { dec_n_layers, 1, lstm_n_states, batch,
        feature_size + feature_size };
    memory::dims dec_dst_iter_noctx_dims
            = { dec_n_layers, 1, lstm_n_states, batch, feature_size };

    auto user_dec_wei_layer_md = mkldnn::memory::desc(
            { user_dec_wei_layer_dims }, mkldnn::memory::data_type::f32,
            mkldnn::memory::format::ldigo);
    auto user_dec_wei_iter_md = mkldnn::memory::desc({ user_dec_wei_iter_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::ldigo);
    auto user_dec_bias_md = mkldnn::memory::desc({ user_dec_bias_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::ldgo);
    auto dec_dst_layer_md = mkldnn::memory::desc({ dec_dst_layer_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::tnc);
    auto dec_src_layer_md = mkldnn::memory::desc({ dec_src_layer_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::tnc);
    auto dec_dst_iter_md = mkldnn::memory::desc({ dec_dst_iter_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::ldsnc);
    auto user_dec_wei_layer_memory = mkldnn::memory(
            { user_dec_wei_layer_md, cpu_engine }, user_dec_wei_layer.data());
    auto user_dec_wei_iter_memory = mkldnn::memory(
            { user_dec_wei_iter_md, cpu_engine }, user_dec_wei_iter.data());
    auto user_dec_bias_memory = mkldnn::memory(
            { user_dec_bias_md, cpu_engine }, user_dec_bias.data());
    auto user_dec_dst_layer_memory = mkldnn::memory(
            { dec_dst_layer_md, cpu_engine }, user_dec_dst.data());
    auto dec_src_layer_memory
            = mkldnn::memory({ dec_src_layer_md, cpu_engine });

    auto dec_wei_layer_md = mkldnn::memory::desc(
            { user_dec_wei_layer_dims }, mkldnn::memory::data_type::f32,
            mkldnn::memory::format::any);
    auto dec_wei_iter_md = mkldnn::memory::desc({ user_dec_wei_iter_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::any);

    // As mentioned above, we create a view without context out of the
    // memory with context.
    auto dec_dst_iter_memory = mkldnn::memory({ dec_dst_iter_md, cpu_engine });
    auto dec_dst_iter_noctx_md = mkldnn::view::primitive_desc(
            dec_dst_iter_memory.get_primitive_desc(), dec_dst_iter_noctx_dims,
            { 0, 0, 0, 0, 0 }).dst_primitive_desc().desc();

    /// @todo add suport for residual connections
    /// should it be a set residual in op_desc or a field to set manually?
    /// should be an integer to specify at which layer to start
    rnn_cell::desc dec_cell(algorithm::vanilla_lstm);
    rnn_forward::desc dec_ctx_desc(prop_kind::forward_inference, dec_cell,
            rnn_direction::unidirectional_left2right, dec_src_layer_md,
            dec_dst_iter_md, dec_wei_layer_md, dec_wei_iter_md,
            user_dec_bias_md, dec_dst_layer_md, dec_dst_iter_noctx_md);
    auto dec_ctx_prim_desc
            = mkldnn::rnn_forward::primitive_desc(dec_ctx_desc, cpu_engine);

    auto dec_wei_layer_memory
            = memory(dec_ctx_prim_desc.weights_layer_primitive_desc());
    auto dec_wei_layer_reorder_pd = reorder::primitive_desc(
            user_dec_wei_layer_memory.get_primitive_desc(),
            dec_wei_layer_memory.get_primitive_desc());
    weights_reorders.push_back(reorder(dec_wei_layer_reorder_pd,
            user_dec_wei_layer_memory, dec_wei_layer_memory));

    auto dec_wei_iter_memory
            = memory(dec_ctx_prim_desc.weights_iter_primitive_desc());
    auto dec_wei_iter_reorder_pd = reorder::primitive_desc(
            user_dec_wei_iter_memory.get_primitive_desc(),
            dec_wei_iter_memory.get_primitive_desc());
    weights_reorders.push_back(reorder(dec_wei_iter_reorder_pd,
            user_dec_wei_iter_memory, dec_wei_iter_memory));

    // dec_dst_iter_memory serves as both src_iter and dst_iter: the decoder
    // updates its recurrent state in place between iterations.
    decoder_net.push_back(rnn_forward(dec_ctx_prim_desc, dec_src_layer_memory,
            dec_dst_iter_memory, dec_wei_layer_memory,
            dec_wei_iter_memory, user_dec_bias_memory,
            user_dec_dst_layer_memory, dec_dst_iter_memory, null_memory_));

    // allocating temporary buffer for attention mechanism
    std::vector<float> weighted_annotations(
            src_seq_length_max * batch * feature_size, 1.0f);

    /*
       Execution
     */
    auto execute = [&]() {
        // reorder weights to MKLDNN internal representation
        stream(stream::kind::eager).submit(weights_reorders).wait();

        // run encoder (1 stream)
        stream(stream::kind::eager).submit(encoder_net).wait();

        // we compute the weighted annotations once before the decoder
        compute_weighted_annotations(weighted_annotations.data(),
                src_seq_length_max, batch, feature_size,
                user_weights_annotation.data(),
                (float *)enc_dst_layer_memory.get_data_handle());

        // We initialise src_layer to the embedding of </s>, which
        // are assumed to be 0 here
        memset(dec_src_layer_memory.get_data_handle(), 0,
               dec_src_layer_memory.get_primitive_desc().get_size());
        // From now on, src points to the output of the last iteration

        for (int i = 0; i < tgt_seq_length_max; i++) {
            float *src_att_layer_handle
                    = (float *) dec_src_layer_memory.get_data_handle();
            float *src_att_iter_handle
                    = (float *) dec_dst_iter_memory.get_data_handle();

            // Compute attention context vector into the first layer src_iter
            compute_attention(src_att_iter_handle, src_seq_length_max, batch,
                    feature_size, user_weights_attention_src_layer.data(),
                    src_att_layer_handle,
                    (float *)enc_bidir_dst_layer_memory.get_data_handle(),
                    weighted_annotations.data(),
                    user_weights_alignments.data());

            // copy the context vectors to all layers of src_iter
            copy_context(src_att_iter_handle, dec_n_layers, lstm_n_states, batch,
                    feature_size);

            // run the decoder iteration
            stream(stream::kind::eager).submit(decoder_net).wait();

            // Move the handle on the src/dst layer to the next iteration:
            // src now aliases this step's output inside user_dec_dst, and
            // dst advances one (batch, feature_size) slab forward.
            auto dst_layer_handle = (float *) user_dec_dst_layer_memory.get_data_handle();
            dec_src_layer_memory.set_data_handle(dst_layer_handle);
            user_dec_dst_layer_memory.set_data_handle(
                    dst_layer_handle + batch * feature_size);
        }

    };

    execute();
}
618
619 int main(int argc, char **argv) {
620     try {
621         simple_net();
622         std::cout << "ok\n";
623     } catch (error &e) {
624         std::cerr << "status: " << e.status << std::endl;
625         std::cerr << "message: " << e.message << std::endl;
626         return 1;
627     }
628     return 0;
629 }