/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <cstring>
#include <iostream>
#include <math.h>
#include <vector>

#include "mkl_cblas.h"
#include "mkldnn.hpp"
// MSVC doesn't support the collapse clause in omp parallel
#if defined(_MSC_VER) && !defined(__INTEL_COMPILER)
#define collapse(x)
#endif
using namespace mkldnn;

const int batch = 128;
const int src_seq_length_max = 28;
const int tgt_seq_length_max = 28;

const int feature_size = 1024;

const int enc_bidir_n_layers = 1;
const int enc_unidir_n_layers = 7;
const int dec_n_layers = 8;

const int lstm_n_gates = 4;
const int lstm_n_states = 2;
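// an LSTM cell has 4 gates (input, forget, candidate, output) and carries
// 2 states (the hidden state h and the cell state c)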
std::vector<float> weighted_src_layer(batch * feature_size, 1.0f);
std::vector<float> alignment_model(
        src_seq_length_max * batch * feature_size, 1.0f);
std::vector<float> alignments(src_seq_length_max * batch, 1.0f);
std::vector<float> exp_sums(batch, 1.0f);
void compute_weighted_annotations(float *weighted_annotations,
        int src_seq_length_max, int batch, int feature_size,
        float *weights_annot, float *annotations) {
    // annotations (aka enc_dst_layer) is (t, n, 2c)
    // weights_annot is (2c, c)

    // annotation[i] = GEMM(weights_annot, enc_dst_layer[i]);
    cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, feature_size,
            src_seq_length_max * batch, feature_size, 1.0f, weights_annot,
            feature_size, annotations, feature_size, 0.0f, weighted_annotations,
            feature_size);
}
void compute_attention(float *context_vectors, int src_seq_length_max,
        int batch, int feature_size, float *weights_src_layer,
        float *dec_src_layer, float *annotations, float *weighted_annotations,
        float *weights_alignments) {
    // dst_iter : (n, c) matrix
    // src_layer: (n, c) matrix
    // weighted_annotations (t, n, c)

    // weights_yi is (c, c)
    // weights_ai is (c, 1)
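    // This implements an additive (Bahdanau-style) attention:
    //   e[i]    = w_a^T . tanh(W_s . dec_src_layer + weighted_annotations[i])
    //   a       = softmax(e) over the src_seq_length_max time steps
    //   context = sum_i a[i] * annotations[i]
    // where w_a = weights_alignments and W_s = weights_src_layer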
    // first we precompute the weighted_dec_src_layer
    cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, feature_size, batch,
            feature_size, 1.0f, weights_src_layer, feature_size, dec_src_layer,
            feature_size, 0.0f, weighted_src_layer.data(), feature_size);
    // then we compute the alignment model
    float *alignment_model_ptr = alignment_model.data();
#pragma omp parallel for collapse(2)
    for (int i = 0; i < src_seq_length_max; i++) {
        for (int j = 0; j < batch * feature_size; j++)
            alignment_model_ptr[i * batch * feature_size + j] = tanhf(
                    weighted_src_layer.data()[j]
                    + weighted_annotations[i * batch * feature_size + j]);
    }
    // gemv with alignments weights. the resulting alignments are in alignments
    cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, 1,
            src_seq_length_max * batch, feature_size, 1.0f, weights_alignments,
            1, alignment_model_ptr, feature_size, 0.0f, alignments.data(), 1);
    // softmax on alignments. the resulting context weights are in alignments
#pragma omp parallel for
    for (int i = 0; i < batch; i++)
        exp_sums[i] = 0.0f;

    // note: the accumulation is parallelized over the batch only;
    // parallelizing the time dimension as well would race on the per-batch
    // accumulator exp_sums[j]
#pragma omp parallel for
    for (int j = 0; j < batch; j++) {
        for (int i = 0; i < src_seq_length_max; i++) {
            alignments[i * batch + j] = expf(alignments[i * batch + j]);
            exp_sums[j] += alignments[i * batch + j];
        }
    }

#pragma omp parallel for collapse(2)
    for (int i = 0; i < src_seq_length_max; i++)
        for (int j = 0; j < batch; j++)
            alignments[i * batch + j] /= exp_sums[j];
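    // The context vectors are written at offset feature_size within each row
    // of (feature_size + feature_size) values: the first half of each row
    // holds the decoder state, the second half receives the context
    // (see the dec_dst_iter layout in simple_net below).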
    // then we compute the context vectors
#pragma omp parallel for collapse(2)
    for (int i = 0; i < batch; i++)
        for (int j = 0; j < feature_size; j++)
            context_vectors[i * (feature_size + feature_size) + feature_size
                    + j]
                    = 0.0f;

    // note: k (the source time step) must stay the sequential inner loop;
    // parallelizing it would race on the accumulated context entry
#pragma omp parallel for collapse(2)
    for (int i = 0; i < batch; i++)
        for (int j = 0; j < feature_size; j++)
            for (int k = 0; k < src_seq_length_max; k++)
                context_vectors[i * (feature_size + feature_size)
                        + feature_size + j]
                        += alignments[k * batch + i]
                        * annotations[j + feature_size * (i + batch * k)];
}
void copy_context(float *src_iter, int n_layers, int n_states, int batch,
        int feature_size) {
    // we copy the context from the first layer to all other layers
#pragma omp parallel for collapse(3)
    for (int k = 1; k < n_layers; k++)
        for (int j = 0; j < batch; j++)
            for (int i = 0; i < feature_size; i++)
                src_iter[(k * n_states * batch + j)
                                * (feature_size + feature_size)
                        + i]
                        = src_iter[j * (feature_size + feature_size) + i];
}
void simple_net() {
    auto cpu_engine = engine(engine::cpu, 0);
    auto null_memory_ = null_memory(cpu_engine);
    /*
      Note, we do not implement residual connections yet.
      For the encoder we use:
      - one primitive for the bidirectional layer of the encoder
      - one primitive for all remaining unidirectional layers in the encoder
      For the decoder we use:
      - one primitive for the first iteration
      - one primitive for all subsequent iterations in the decoder. Note that
        in this example, this primitive computes the states in place.
      - the attention mechanism is implemented separately as there is no
        support for the context vectors in MKL-DNN yet
     */
    std::vector<primitive> encoder_net;
    std::vector<primitive> decoder_net;

    std::vector<float> net_src(batch * src_seq_length_max * feature_size, 1.0f);
    std::vector<float> net_dst(batch * tgt_seq_length_max * feature_size, 1.0f);
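    // All user buffers in this example are filled with the dummy value 1.0f:
    // the example demonstrates how the primitives are wired, not a trained
    // model.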
    memory::dims enc_bidir_src_layer_tz
            = { src_seq_length_max, batch, feature_size };
    memory::dims enc_bidir_weights_layer_tz = { enc_bidir_n_layers, 2,
        feature_size, lstm_n_gates, feature_size };
    memory::dims enc_bidir_weights_iter_tz = { enc_bidir_n_layers, 2,
        feature_size, lstm_n_gates, feature_size };
    memory::dims enc_bidir_bias_tz
            = { enc_bidir_n_layers, 2, lstm_n_gates, feature_size };
    memory::dims enc_bidir_dst_layer_tz
            = { src_seq_length_max, batch, 2 * feature_size };
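    // Weights dimensions follow the ldigo layout used below:
    // (layers, directions, input channels, gates, output channels);
    // data dimensions follow tnc: (time, batch, channels).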
    /* GNMT encoder: 1 bidirectional layer and 7 unidirectional layers */
    std::vector<float> user_enc_bidir_wei_layer(
            enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size,
            1.0f);
    std::vector<float> user_enc_bidir_wei_iter(
            enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size,
            1.0f);
    std::vector<float> user_enc_bidir_bias(
            enc_bidir_n_layers * 2 * lstm_n_gates * feature_size, 1.0f);
    // We create the memory descriptors used by the user
    auto user_enc_bidir_src_layer_md = mkldnn::memory::desc(
            { enc_bidir_src_layer_tz }, mkldnn::memory::data_type::f32,
            mkldnn::memory::format::tnc);

    auto user_enc_bidir_wei_layer_md = mkldnn::memory::desc(
            { enc_bidir_weights_layer_tz }, mkldnn::memory::data_type::f32,
            mkldnn::memory::format::ldigo);

    auto user_enc_bidir_wei_iter_md = mkldnn::memory::desc(
            { enc_bidir_weights_iter_tz }, mkldnn::memory::data_type::f32,
            mkldnn::memory::format::ldigo);

    auto user_enc_bidir_bias_md = mkldnn::memory::desc({ enc_bidir_bias_tz },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::ldgo);

    auto enc_bidir_dst_layer_md = mkldnn::memory::desc(
            { enc_bidir_dst_layer_tz }, mkldnn::memory::data_type::f32,
            mkldnn::memory::format::tnc);
    /* We create memories */
    auto user_enc_bidir_src_layer_memory = mkldnn::memory(
            { user_enc_bidir_src_layer_md, cpu_engine }, net_src.data());
    auto user_enc_bidir_wei_layer_memory
            = mkldnn::memory({ user_enc_bidir_wei_layer_md, cpu_engine },
                    user_enc_bidir_wei_layer.data());
    auto user_enc_bidir_wei_iter_memory
            = mkldnn::memory({ user_enc_bidir_wei_iter_md, cpu_engine },
                    user_enc_bidir_wei_iter.data());
    auto user_enc_bidir_bias_memory = mkldnn::memory(
            { user_enc_bidir_bias_md, cpu_engine }, user_enc_bidir_bias.data());
    /// These will be null memories
    /// @todo introduce predefined null_memory() ?
    auto enc_bidir_src_iter_md = mkldnn::memory::desc(
            { enc_bidir_n_layers, 2, lstm_n_states, batch, feature_size },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::ldsnc);
    auto enc_bidir_dst_iter_md = enc_bidir_src_iter_md;
    auto enc_bidir_src_iter_memory
            = mkldnn::memory({ enc_bidir_src_iter_md, cpu_engine });
    auto enc_bidir_dst_iter_memory
            = mkldnn::memory({ enc_bidir_dst_iter_md, cpu_engine });
    /// @todo fix this once cell desc is merged with rnn_desc
    rnn_cell::desc bi_cell(algorithm::vanilla_lstm);
    rnn_forward::desc bi_layer_desc(prop_kind::forward_inference, bi_cell,
            rnn_direction::bidirectional_concat, user_enc_bidir_src_layer_md,
            zero_md(), user_enc_bidir_wei_layer_md, user_enc_bidir_wei_iter_md,
            user_enc_bidir_bias_md, enc_bidir_dst_layer_md, zero_md());

    auto enc_bidir_prim_desc
            = mkldnn::rnn_forward::primitive_desc(bi_layer_desc, cpu_engine);
    // there are currently no reorders
    /// @todo add a reorder when they become available
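    // With rnn_direction::bidirectional_concat, the forward and backward
    // outputs are concatenated over channels, hence the 2 * feature_size
    // channels in enc_bidir_dst_layer_tz above.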
    auto enc_bidir_dst_layer_memory
            = mkldnn::memory(enc_bidir_prim_desc.dst_layer_primitive_desc());

    encoder_net.push_back(
            rnn_forward(enc_bidir_prim_desc, user_enc_bidir_src_layer_memory,
                    null_memory_, user_enc_bidir_wei_layer_memory,
                    user_enc_bidir_wei_iter_memory, user_enc_bidir_bias_memory,
                    enc_bidir_dst_layer_memory, null_memory_, null_memory_));
    /* GNMT encoder: unidirectional layers */
    // In the first unidirectional layer, the weights scale the
    // 2 * feature_size features coming from the previous bidirectional
    // layer down to feature_size features
    std::vector<float> user_enc_uni_first_wei_layer(
            1 * 1 * 2 * feature_size * lstm_n_gates * feature_size, 1.0f);
    std::vector<float> user_enc_uni_first_wei_iter(
            1 * 1 * feature_size * lstm_n_gates * feature_size, 1.0f);
    std::vector<float> user_enc_uni_first_bias(
            1 * 1 * lstm_n_gates * feature_size, 1.0f);
    memory::dims user_enc_uni_first_wei_layer_dims
            = { 1, 1, 2 * feature_size, lstm_n_gates, feature_size };
    memory::dims user_enc_uni_first_wei_iter_dims
            = { 1, 1, feature_size, lstm_n_gates, feature_size };
    memory::dims user_enc_uni_first_bias_dims
            = { 1, 1, lstm_n_gates, feature_size };
    memory::dims enc_uni_first_dst_layer_dims
            = { src_seq_length_max, batch, feature_size };
    auto user_enc_uni_first_wei_layer_md = mkldnn::memory::desc(
            { user_enc_uni_first_wei_layer_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::ldigo);
    auto user_enc_uni_first_wei_iter_md = mkldnn::memory::desc(
            { user_enc_uni_first_wei_iter_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::ldigo);
    auto user_enc_uni_first_bias_md = mkldnn::memory::desc(
            { user_enc_uni_first_bias_dims }, mkldnn::memory::data_type::f32,
            mkldnn::memory::format::ldgo);
    auto enc_uni_first_dst_layer_md = mkldnn::memory::desc(
            { enc_uni_first_dst_layer_dims }, mkldnn::memory::data_type::f32,
            mkldnn::memory::format::tnc);
    auto user_enc_uni_first_wei_layer_memory
            = mkldnn::memory({ user_enc_uni_first_wei_layer_md, cpu_engine },
                    user_enc_uni_first_wei_layer.data());
    auto user_enc_uni_first_wei_iter_memory
            = mkldnn::memory({ user_enc_uni_first_wei_iter_md, cpu_engine },
                    user_enc_uni_first_wei_iter.data());
    auto user_enc_uni_first_bias_memory
            = mkldnn::memory({ user_enc_uni_first_bias_md, cpu_engine },
                    user_enc_uni_first_bias.data());
    /// @todo add support for residual connections
    /// should it be a flag set in op_desc or a field to set manually?
    /// should be an integer to specify at which layer to start
    rnn_cell::desc enc_uni_first_cell(algorithm::vanilla_lstm);
    rnn_forward::desc enc_uni_first_layer_desc(prop_kind::forward_inference,
            enc_uni_first_cell, rnn_direction::unidirectional_left2right,
            enc_bidir_dst_layer_md, zero_md(), user_enc_uni_first_wei_layer_md,
            user_enc_uni_first_wei_iter_md, user_enc_uni_first_bias_md,
            enc_uni_first_dst_layer_md, zero_md());
    auto enc_uni_first_prim_desc = mkldnn::rnn_forward::primitive_desc(
            enc_uni_first_layer_desc, cpu_engine);
    auto enc_uni_first_dst_layer_memory = mkldnn::memory(
            enc_uni_first_prim_desc.dst_layer_primitive_desc());
    /// @todo add a reorder when they become available
    encoder_net.push_back(rnn_forward(enc_uni_first_prim_desc,
            enc_bidir_dst_layer_memory, null_memory_,
            user_enc_uni_first_wei_layer_memory,
            user_enc_uni_first_wei_iter_memory, user_enc_uni_first_bias_memory,
            enc_uni_first_dst_layer_memory, null_memory_, null_memory_));
    // Remaining unidirectional layers
    std::vector<float> user_enc_uni_wei_layer((enc_unidir_n_layers - 1) * 1
                    * feature_size * lstm_n_gates * feature_size, 1.0f);
    std::vector<float> user_enc_uni_wei_iter((enc_unidir_n_layers - 1) * 1
                    * feature_size * lstm_n_gates * feature_size, 1.0f);
    std::vector<float> user_enc_uni_bias(
            (enc_unidir_n_layers - 1) * 1 * lstm_n_gates * feature_size, 1.0f);
    memory::dims user_enc_uni_wei_layer_dims = { (enc_unidir_n_layers - 1), 1,
        feature_size, lstm_n_gates, feature_size };
    memory::dims user_enc_uni_wei_iter_dims = { (enc_unidir_n_layers - 1), 1,
        feature_size, lstm_n_gates, feature_size };
    memory::dims user_enc_uni_bias_dims
            = { (enc_unidir_n_layers - 1), 1, lstm_n_gates, feature_size };
    memory::dims enc_dst_layer_dims
            = { src_seq_length_max, batch, feature_size };
    auto user_enc_uni_wei_layer_md = mkldnn::memory::desc(
            { user_enc_uni_wei_layer_dims }, mkldnn::memory::data_type::f32,
            mkldnn::memory::format::ldigo);
    auto user_enc_uni_wei_iter_md = mkldnn::memory::desc(
            { user_enc_uni_wei_iter_dims }, mkldnn::memory::data_type::f32,
            mkldnn::memory::format::ldigo);
    auto user_enc_uni_bias_md = mkldnn::memory::desc({ user_enc_uni_bias_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::ldgo);
    auto enc_dst_layer_md = mkldnn::memory::desc({ enc_dst_layer_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::tnc);
    auto user_enc_uni_wei_layer_memory
            = mkldnn::memory({ user_enc_uni_wei_layer_md, cpu_engine },
                    user_enc_uni_wei_layer.data());
    auto user_enc_uni_wei_iter_memory
            = mkldnn::memory({ user_enc_uni_wei_iter_md, cpu_engine },
                    user_enc_uni_wei_iter.data());
    auto user_enc_uni_bias_memory = mkldnn::memory(
            { user_enc_uni_bias_md, cpu_engine }, user_enc_uni_bias.data());
    /// @todo add support for residual connections
    /// should it be a flag set in op_desc or a field to set manually?
    /// should be an integer to specify at which layer to start
    rnn_cell::desc enc_uni_cell(algorithm::vanilla_lstm);
    rnn_forward::desc enc_uni_layer_desc(prop_kind::forward_inference,
            enc_uni_cell, rnn_direction::unidirectional_left2right,
            enc_uni_first_dst_layer_md, zero_md(), user_enc_uni_wei_layer_md,
            user_enc_uni_wei_iter_md, user_enc_uni_bias_md, enc_dst_layer_md,
            zero_md());
    auto enc_uni_prim_desc = mkldnn::rnn_forward::primitive_desc(
            enc_uni_layer_desc, cpu_engine);
    auto enc_dst_layer_memory
            = mkldnn::memory(enc_uni_prim_desc.dst_layer_primitive_desc());
    /// @todo add a reorder when they become available
    encoder_net.push_back(
            rnn_forward(enc_uni_prim_desc, enc_uni_first_dst_layer_memory,
                    null_memory_, user_enc_uni_wei_layer_memory,
                    user_enc_uni_wei_iter_memory, user_enc_uni_bias_memory,
                    enc_dst_layer_memory, null_memory_, null_memory_));
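    // encoder_net now holds the full encoder pipeline: the bidirectional
    // layer, the first unidirectional layer, and the remaining six
    // unidirectional layers, each consuming the dst_layer of the previous one.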
    /*
     * GNMT: decoder with attention mechanism
     */
    // user provided memories
    std::vector<float> user_dec_wei_layer(
            dec_n_layers * 1 * feature_size * lstm_n_gates * feature_size,
            1.0f);
    std::vector<float> user_dec_wei_iter(dec_n_layers * 1
                    * (feature_size + feature_size) * lstm_n_gates
                    * feature_size, 1.0f);
    std::vector<float> user_dec_bias(
            dec_n_layers * 1 * lstm_n_gates * feature_size, 1.0f);
    std::vector<float> user_dec_dst(
            tgt_seq_length_max * batch * feature_size, 1.0f);
    std::vector<float> user_weights_attention_src_layer(
            feature_size * feature_size, 1.0f);
    std::vector<float> user_weights_annotation(
            feature_size * feature_size, 1.0f);
    std::vector<float> user_weights_alignments(feature_size, 1.0f);
    memory::dims user_dec_wei_layer_dims
            = { dec_n_layers, 1, feature_size, lstm_n_gates, feature_size };
    memory::dims user_dec_wei_iter_dims = { dec_n_layers, 1,
        feature_size + feature_size, lstm_n_gates, feature_size };
    memory::dims user_dec_bias_dims
            = { dec_n_layers, 1, lstm_n_gates, feature_size };

    memory::dims dec_src_layer_dims = { 1, batch, feature_size };
    memory::dims dec_dst_layer_dims
            = { tgt_seq_length_max, batch, feature_size };
    // We will use the same memory for dec_src_iter and dec_dst_iter.
    // However, dec_src_iter has a context vector but dec_dst_iter does not.
    // To resolve this we will create one memory that holds the
    // context vector as well as both the hidden and cell states.
    // For the dst_iter, we will use a view on this memory.
    // Note that the cell state will be padded by
    // feature_size values. However, we do not compute or
    // access those padding values.
    memory::dims dec_dst_iter_dims = { dec_n_layers, 1, lstm_n_states, batch,
        feature_size + feature_size };
    memory::dims dec_dst_iter_noctx_dims
            = { dec_n_layers, 1, lstm_n_states, batch, feature_size };
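    // Per (layer, direction, state, batch) entry the channels are laid out as:
    //   [0, feature_size)                : hidden or cell state
    //   [feature_size, 2 * feature_size) : attention context (unused padding
    //                                      on the cell state rows)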
    auto user_dec_wei_layer_md = mkldnn::memory::desc(
            { user_dec_wei_layer_dims }, mkldnn::memory::data_type::f32,
            mkldnn::memory::format::ldigo);
    auto user_dec_wei_iter_md = mkldnn::memory::desc({ user_dec_wei_iter_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::ldigo);
    auto user_dec_bias_md = mkldnn::memory::desc({ user_dec_bias_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::ldgo);
    auto dec_dst_layer_md = mkldnn::memory::desc({ dec_dst_layer_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::tnc);
    auto dec_src_layer_md = mkldnn::memory::desc({ dec_src_layer_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::tnc);
    auto dec_dst_iter_md = mkldnn::memory::desc({ dec_dst_iter_dims },
            mkldnn::memory::data_type::f32, mkldnn::memory::format::ldsnc);
    auto user_dec_wei_layer_memory = mkldnn::memory(
            { user_dec_wei_layer_md, cpu_engine }, user_dec_wei_layer.data());
    auto user_dec_wei_iter_memory = mkldnn::memory(
            { user_dec_wei_iter_md, cpu_engine }, user_dec_wei_iter.data());
    auto user_dec_bias_memory = mkldnn::memory(
            { user_dec_bias_md, cpu_engine }, user_dec_bias.data());
    auto user_dec_dst_layer_memory = mkldnn::memory(
            { dec_dst_layer_md, cpu_engine }, user_dec_dst.data());
    auto dec_src_layer_memory
            = mkldnn::memory({ dec_src_layer_md, cpu_engine });
    // As mentioned above, we create a view without context out of the
    // memory with context.
    auto dec_dst_iter_memory = mkldnn::memory({ dec_dst_iter_md, cpu_engine });
    auto dec_dst_iter_noctx_md = mkldnn::view::primitive_desc(
            dec_dst_iter_memory.get_primitive_desc(), dec_dst_iter_noctx_dims,
            { 0, 0, 0, 0, 0 }).dst_primitive_desc().desc();
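    // The view starts at offset { 0, 0, 0, 0, 0 } and spans only the first
    // feature_size channels, so the primitive writes dst_iter without
    // touching the context halves of the rows.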
    /// @todo add support for residual connections
    /// should it be a flag set in op_desc or a field to set manually?
    /// should be an integer to specify at which layer to start
    rnn_cell::desc dec_cell(algorithm::vanilla_lstm);
    rnn_forward::desc dec_ctx_desc(prop_kind::forward_inference, dec_cell,
            rnn_direction::unidirectional_left2right, dec_src_layer_md,
            dec_dst_iter_md, user_dec_wei_layer_md, user_dec_wei_iter_md,
            user_dec_bias_md, dec_dst_layer_md, dec_dst_iter_noctx_md);
    auto dec_ctx_prim_desc
            = mkldnn::rnn_forward::primitive_desc(dec_ctx_desc, cpu_engine);
    /// @todo add a reorder when they become available
    decoder_net.push_back(rnn_forward(dec_ctx_prim_desc, dec_src_layer_memory,
            dec_dst_iter_memory, user_dec_wei_layer_memory,
            user_dec_wei_iter_memory, user_dec_bias_memory,
            user_dec_dst_layer_memory, dec_dst_iter_memory, null_memory_));
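    // dec_dst_iter_memory is passed as both src_iter and dst_iter, so the
    // decoder updates its states in place across iterations, as noted in the
    // overview comment at the top of this function.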
    // allocating temporary buffer for attention mechanism
    std::vector<float> weighted_annotations(
            src_seq_length_max * batch * feature_size, 1.0f);
    auto execute = [&]() {
        // We save the original handle on dst_layer as we will modify it at
        // each iteration
        void *dst_layer_original_handle
                = user_dec_dst_layer_memory.get_data_handle();
        // run encoder (1 stream)
        stream(stream::kind::eager).submit(encoder_net).wait();
        // we compute the weighted annotations once before the decoder
        compute_weighted_annotations(weighted_annotations.data(),
                src_seq_length_max, batch, feature_size,
                user_weights_annotation.data(),
                (float *)enc_dst_layer_memory.get_data_handle());
        // We initialise dst_layer[0] to the embedding of </s>, which is
        // assumed to be all zeroes here
        memset(dst_layer_original_handle, 0,
                batch * feature_size * sizeof(float));
        for (int i = 0; i < tgt_seq_length_max; i++) {
            float *dst_layer_handle
                    = (float *)user_dec_dst_layer_memory.get_data_handle();
            float *dst_iter_handle
                    = (float *)dec_dst_iter_memory.get_data_handle();
            // Compute attention context vector into the first layer src_iter
            compute_attention(dst_iter_handle, src_seq_length_max, batch,
                    feature_size, user_weights_attention_src_layer.data(),
                    dst_layer_handle,
                    (float *)enc_bidir_dst_layer_memory.get_data_handle(),
                    weighted_annotations.data(),
                    user_weights_alignments.data());
            // copy the context vectors to all layers of src_iter
            copy_context(dst_iter_handle, dec_n_layers, lstm_n_states, batch,
                    feature_size);
            // We set src_layer to be the previously computed dst_layer
            dec_src_layer_memory.set_data_handle(dst_layer_handle);
            // run the decoder iteration
            stream(stream::kind::eager).submit(decoder_net).wait();
            // Move the handle on the dst layer to the next iteration
            user_dec_dst_layer_memory.set_data_handle(
                    dst_layer_handle + batch * feature_size);
        }
        // we restore the handle to the beginning of the buffer
        user_dec_dst_layer_memory.set_data_handle(dst_layer_original_handle);
        /// @todo run the softmax after each iteration or not?
    };

    execute();
}
int main(int argc, char **argv) {
    try {
        simple_net();
    } catch (error &e) {
        std::cerr << "status: " << e.status << std::endl;
        std::cerr << "message: " << e.message << std::endl;
    }
    return 0;
}