1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
25 // MSVC doesn't support collapse clause in omp parallel
26 #if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
// Problem-size constants and file-scope scratch buffers shared by the
// attention helpers below (GNMT-style encoder/decoder inference example).
30 using namespace mkldnn;
// Dimensions of the translation problem: batch, source/target sequence
// lengths, and the per-token feature (channel) width.
32 const int batch = 128;
33 const int src_seq_length_max = 28;
34 const int tgt_seq_length_max = 28;
36 const int feature_size = 1024;
// Encoder: 1 bidirectional layer followed by 7 unidirectional layers;
// decoder: 8 unidirectional layers.
38 const int enc_bidir_n_layers = 1;
39 const int enc_unidir_n_layers = 7;
40 const int dec_n_layers = 8;
// LSTM cell: 4 gates, 2 recurrent states (hidden and cell).
42 const int lstm_n_gates = 4;
43 const int lstm_n_states = 2;
// Scratch storage used by compute_attention, allocated once up front.
44 std::vector<float> weighted_src_layer(batch *feature_size, 1.0f);
45 std::vector<float> alignment_model(
46 src_seq_length_max *batch *feature_size, 1.0f);
47 std::vector<float> alignments(src_seq_length_max *batch, 1.0f);
48 std::vector<float> exp_sums(batch, 1.0f);
// GEMM alpha/beta constants. NOTE(review): an integer constant `onei` is
// referenced by the GEMV call in compute_attention but its declaration is
// not visible in this chunk — presumably dropped by extraction; confirm.
50 const float onef = 1.0, zerof = 0.0;
53 void compute_weighted_annotations(float *weighted_annotations,
54 int src_seq_length_max, int batch, int feature_size,
55 float *weights_annot, float *annotations) {
56 // annotations(aka enc_dst_layer) is (t, n, 2c)
57 // weights_annot is (2c, c)
59 // annotation[i] = GEMM(weights_annot, enc_dst_layer[i]);
60 int num_weighted_annotations = src_seq_length_max * batch;
61 mkldnn_sgemm("N", "N",
62 &feature_size, &num_weighted_annotations, &feature_size,
63 &onef, weights_annot, &feature_size, annotations, &feature_size,
64 &zerof, weighted_annotations, &feature_size);
67 void compute_attention(float *context_vectors, int src_seq_length_max,
68 int batch, int feature_size, float *weights_src_layer,
69 float *dec_src_layer, float *annotations, float *weighted_annotations,
70 float *weights_alignments) {
71 // dst_iter : (n, c) matrix
72 // src_layer: (n, c) matrix
73 // weighted_annotations (t, n, c)
75 // weights_yi is (c, c)
76 // weights_ai is (c, 1)
81 // first we precompute the weighted_dec_src_layer
82 mkldnn_sgemm("N", "N",
83 &feature_size, &batch, &feature_size, &onef,
84 weights_src_layer, &feature_size, dec_src_layer, &feature_size,
85 &zerof, weighted_src_layer.data(), &feature_size);
87 // then we compute the alignment model
88 float *alignment_model_ptr = alignment_model.data();
90 #pragma omp parallel for collapse(2)
92 for (int i = 0; i < src_seq_length_max; i++) {
93 for (int j = 0; j < batch * feature_size; j++)
94 alignment_model_ptr[i * batch * feature_size + j] = tanhf(
95 weighted_src_layer.data()[j]
96 + weighted_annotations[i * batch * feature_size + j]);
99 // gemv with alignments weights. the resulting alignments are in alignments
100 int num_weighted_annotations = src_seq_length_max * batch;
101 mkldnn_sgemm("N", "N",
102 &onei, &num_weighted_annotations, &feature_size, &onef,
103 weights_alignments, &onei, alignment_model_ptr, &feature_size,
104 &zerof, alignments.data(), &onei);
106 // softmax on alignments. the resulting context weights are in alignments
108 #pragma omp parallel for
110 for (int i = 0; i < batch; i++)
113 #pragma omp parallel for collapse(2)
115 for (int i = 0; i < src_seq_length_max; i++) {
116 for (int j = 0; j < batch; j++) {
117 alignments[i * batch + j] = expf(alignments[i * batch + j]);
118 exp_sums[j] += alignments[i * batch + j];
123 #pragma omp parallel for collapse(2)
125 for (int i = 0; i < src_seq_length_max; i++)
126 for (int j = 0; j < batch; j++)
127 alignments[i * batch + j] /= exp_sums[j];
129 // then we compute the context vectors
131 #pragma omp parallel for collapse(2)
133 for (int i = 0; i < batch; i++)
134 for (int j = 0; j < feature_size; j++)
135 context_vectors[i * (feature_size + feature_size) + feature_size
140 #pragma omp parallel for collapse(3)
142 for (int i = 0; i < batch; i++)
143 for (int k = 0; k < src_seq_length_max; k++)
144 for (int j = 0; j < feature_size; j++)
145 context_vectors[i * (feature_size + feature_size) + feature_size
147 += alignments[k * batch + i]
148 * annotations[j + feature_size * (i + batch * k)];
// Broadcasts the attention context written into the first decoder layer's
// src_iter to the src_iter of every remaining layer (k = 1..n_layers-1).
// Each (k, j, i) iteration appears to write a distinct element and reads
// only layer 0 (which k >= 1 never writes), so collapse(3) looks race-free
// here — confirm once the missing lines below are restored.
// NOTE(review): original lines 152 (the `int feature_size)` parameter
// continuation) and 162 (the destination-index continuation) were dropped
// by the extraction; reconstruct them from the upstream example before use.
151 void copy_context(float *src_iter, int n_layers, int n_states, int batch,
153 // we copy the context from the first layer to all other layers
155 #pragma omp parallel for collapse(3)
157 for (int k = 1; k < n_layers; k++)
158 for (int j = 0; j < batch; j++)
159 for (int i = 0; i < feature_size; i++)
160 src_iter[(k * n_states * batch + j)
161 * (feature_size + feature_size)
163 = src_iter[j * (feature_size + feature_size) + i];
// --- Network setup (interior of simple_net(); the enclosing function's
// header was dropped by the extraction). Engine, primitive lists, and
// all-ones source/destination data.
167 auto cpu_engine = engine(engine::cpu, 0);
168 auto null_memory_ = null_memory(cpu_engine);
// NOTE(review): the /* ... */ delimiters of this design-note comment were
// dropped by extraction; re-wrapped here so the text does not parse as code.
/*
172 Note, we do not implement connection yet.
173 For the encoder we use:
174 - one primitive for the bidirectional layer of the encoder
175 - one primitive for all remaining unidirectional layers in the encoder
176 For the decoder we use:
177 - one primitive for the first iteration
178 - one primitive for all subsequent iterations in the decoder. Note that
179 in this example, this primitive computes the states in place.
180 - the attention mechanism is implemented separately as there is no support
181 for the context vectors in MKL-DNN yet
*/
// Weight reorders run once; the encoder net runs once; the decoder net is
// resubmitted once per generated target token.
184 std::vector<primitive> weights_reorders;
185 std::vector<primitive> encoder_net;
186 std::vector<primitive> decoder_net;
188 std::vector<float> net_src(batch * src_seq_length_max * feature_size, 1.0f);
189 std::vector<float> net_dst(batch * tgt_seq_length_max * feature_size, 1.0f);
// --- Encoder, bidirectional layer: dims, user memories (ldigo/ldgo/tnc),
// layout-agnostic descriptors (format::any), the RNN primitive descriptor,
// weight reorders, and the primitive itself.
193 memory::dims enc_bidir_src_layer_tz
194 = { src_seq_length_max, batch, feature_size };
195 memory::dims enc_bidir_weights_layer_tz = { enc_bidir_n_layers, 2,
196 feature_size, lstm_n_gates, feature_size };
197 memory::dims enc_bidir_weights_iter_tz = { enc_bidir_n_layers, 2,
198 feature_size, lstm_n_gates, feature_size };
199 memory::dims enc_bidir_bias_tz
200 = { enc_bidir_n_layers, 2, lstm_n_gates, feature_size };
// Destination is 2 * feature_size wide: bidirectional_concat concatenates
// the two directions along the channel axis.
201 memory::dims enc_bidir_dst_layer_tz
202 = { src_seq_length_max, batch, 2 * feature_size };
204 /* GNMT encoder: 1 bidirectional layer and 7 unidirectional layers */
// NOTE(review): the "1.0f);" initializer tails of the next two vectors
// (original lines 208 and 211) were dropped by the extraction.
206 std::vector<float> user_enc_bidir_wei_layer(
207 enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size,
209 std::vector<float> user_enc_bidir_wei_iter(
210 enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size,
212 std::vector<float> user_enc_bidir_bias(
213 enc_bidir_n_layers * 2 * lstm_n_gates * feature_size, 1.0f);
215 /* Create the memory for user data */
216 auto user_enc_bidir_src_layer_md = mkldnn::memory::desc(
217 { enc_bidir_src_layer_tz }, mkldnn::memory::data_type::f32,
218 mkldnn::memory::format::tnc);
220 auto user_enc_bidir_wei_layer_md = mkldnn::memory::desc(
221 { enc_bidir_weights_layer_tz }, mkldnn::memory::data_type::f32,
222 mkldnn::memory::format::ldigo);
224 auto user_enc_bidir_wei_iter_md = mkldnn::memory::desc(
225 { enc_bidir_weights_iter_tz }, mkldnn::memory::data_type::f32,
226 mkldnn::memory::format::ldigo);
228 auto user_enc_bidir_bias_md = mkldnn::memory::desc({ enc_bidir_bias_tz },
229 mkldnn::memory::data_type::f32, mkldnn::memory::format::ldgo);
231 auto user_enc_bidir_src_layer_memory = mkldnn::memory(
232 { user_enc_bidir_src_layer_md, cpu_engine }, net_src.data());
233 auto user_enc_bidir_wei_layer_memory
234 = mkldnn::memory({ user_enc_bidir_wei_layer_md, cpu_engine },
235 user_enc_bidir_wei_layer.data());
236 auto user_enc_bidir_wei_iter_memory
237 = mkldnn::memory({ user_enc_bidir_wei_iter_md, cpu_engine },
238 user_enc_bidir_wei_iter.data());
239 auto user_enc_bidir_bias_memory = mkldnn::memory(
240 { user_enc_bidir_bias_md, cpu_engine }, user_enc_bidir_bias.data());
242 /* Create memory descriptors for RNN data w/o specified layout */
// format::any lets the primitive pick its preferred weight layout; the
// reorders below convert the user ldigo data into it.
243 auto enc_bidir_wei_layer_md = memory::desc({ enc_bidir_weights_layer_tz },
244 memory::data_type::f32, memory::format::any);
246 auto enc_bidir_wei_iter_md = memory::desc({ enc_bidir_weights_iter_tz },
247 memory::data_type::f32, memory::format::any);
249 auto enc_bidir_dst_layer_md = memory::desc({ enc_bidir_dst_layer_tz },
250 memory::data_type::f32, memory::format::any);
252 /* Create bidirectional RNN */
253 rnn_cell::desc bi_cell(algorithm::vanilla_lstm);
254 rnn_forward::desc bi_layer_desc(prop_kind::forward_inference, bi_cell,
255 rnn_direction::bidirectional_concat, user_enc_bidir_src_layer_md,
256 zero_md(), enc_bidir_wei_layer_md, enc_bidir_wei_iter_md,
257 user_enc_bidir_bias_md, enc_bidir_dst_layer_md, zero_md());
259 auto enc_bidir_prim_desc
260 = mkldnn::rnn_forward::primitive_desc(bi_layer_desc, cpu_engine);
262 /* Create memory primitives for input data and use reorders to reorder
263 * user data to internal representation
*/
// NOTE(review): the closing "*/" of the comment above (original line 264)
// was dropped by the extraction; restored here.
265 auto enc_bidir_wei_layer_memory
266 = memory(enc_bidir_prim_desc.weights_layer_primitive_desc());
267 auto enc_bidir_wei_layer_reorder_pd = reorder::primitive_desc(
268 user_enc_bidir_wei_layer_memory.get_primitive_desc(),
269 enc_bidir_wei_layer_memory.get_primitive_desc());
270 weights_reorders.push_back(reorder(enc_bidir_wei_layer_reorder_pd,
271 user_enc_bidir_wei_layer_memory, enc_bidir_wei_layer_memory));
273 auto enc_bidir_wei_iter_memory
274 = memory(enc_bidir_prim_desc.weights_iter_primitive_desc());
275 auto enc_bidir_wei_iter_reorder_pd = reorder::primitive_desc(
276 user_enc_bidir_wei_iter_memory.get_primitive_desc(),
277 enc_bidir_wei_iter_memory.get_primitive_desc());
278 weights_reorders.push_back(reorder(enc_bidir_wei_iter_reorder_pd,
279 user_enc_bidir_wei_iter_memory, enc_bidir_wei_iter_memory));
281 auto enc_bidir_dst_layer_memory
282 = mkldnn::memory(enc_bidir_prim_desc.dst_layer_primitive_desc());
284 encoder_net.push_back(
285 rnn_forward(enc_bidir_prim_desc, user_enc_bidir_src_layer_memory,
286 null_memory_, enc_bidir_wei_layer_memory,
287 enc_bidir_wei_iter_memory, user_enc_bidir_bias_memory,
288 enc_bidir_dst_layer_memory, null_memory_, null_memory_));
// --- Encoder, first unidirectional layer. Its weights_layer input channel
// count is 2 * feature_size so it projects the concatenated bidirectional
// output back down to feature_size; hence it gets its own primitive.
290 /* GNMT encoder: unidirectional layers */
291 // First unidirectinal layer scales 2 * feature_size output of bidirectional
292 // layer to feature_size output
293 std::vector<float> user_enc_uni_first_wei_layer(
294 1 * 1 * 2 * feature_size * lstm_n_gates * feature_size, 1.0f);
295 std::vector<float> user_enc_uni_first_wei_iter(
296 1 * 1 * feature_size * lstm_n_gates * feature_size, 1.0f);
297 std::vector<float> user_enc_uni_first_bias(
298 1 * 1 * lstm_n_gates * feature_size, 1.0f);
299 memory::dims user_enc_uni_first_wei_layer_dims
300 = { 1, 1, 2 * feature_size, lstm_n_gates, feature_size };
301 memory::dims user_enc_uni_first_wei_iter_dims
302 = { 1, 1, feature_size, lstm_n_gates, feature_size };
303 memory::dims user_enc_uni_first_bias_dims
304 = { 1, 1, lstm_n_gates, feature_size };
305 memory::dims enc_uni_first_dst_layer_dims
306 = { src_seq_length_max, batch, feature_size };
307 auto user_enc_uni_first_wei_layer_md = mkldnn::memory::desc(
308 { user_enc_uni_first_wei_layer_dims },
309 mkldnn::memory::data_type::f32, mkldnn::memory::format::ldigo);
310 auto user_enc_uni_first_wei_iter_md = mkldnn::memory::desc(
311 { user_enc_uni_first_wei_iter_dims },
312 mkldnn::memory::data_type::f32, mkldnn::memory::format::ldigo);
313 auto user_enc_uni_first_bias_md = mkldnn::memory::desc(
314 { user_enc_uni_first_bias_dims }, mkldnn::memory::data_type::f32,
315 mkldnn::memory::format::ldgo);
316 auto user_enc_uni_first_wei_layer_memory
317 = mkldnn::memory({ user_enc_uni_first_wei_layer_md, cpu_engine },
318 user_enc_uni_first_wei_layer.data());
319 auto user_enc_uni_first_wei_iter_memory
320 = mkldnn::memory({ user_enc_uni_first_wei_iter_md, cpu_engine },
321 user_enc_uni_first_wei_iter.data());
322 auto user_enc_uni_first_bias_memory
323 = mkldnn::memory({ user_enc_uni_first_bias_md, cpu_engine },
324 user_enc_uni_first_bias.data());
// Layout-agnostic descriptors so the primitive picks its preferred format.
326 auto enc_uni_first_wei_layer_md
327 = memory::desc({ user_enc_uni_first_wei_layer_dims },
328 memory::data_type::f32, memory::format::any);
329 auto enc_uni_first_wei_iter_md
330 = memory::desc({ user_enc_uni_first_wei_iter_dims },
331 memory::data_type::f32, memory::format::any);
332 auto enc_uni_first_dst_layer_md
333 = memory::desc({ enc_uni_first_dst_layer_dims },
334 memory::data_type::f32, memory::format::any);
336 /// @todo add suport for residual connections
337 /// should it be a set residual in op_desc or a field to set manually?
338 /// should be an integer to specify at which layer to start
339 rnn_cell::desc enc_uni_first_cell(algorithm::vanilla_lstm);
340 rnn_forward::desc enc_uni_first_layer_desc(prop_kind::forward_inference,
341 enc_uni_first_cell, rnn_direction::unidirectional_left2right,
342 enc_bidir_dst_layer_md, zero_md(), enc_uni_first_wei_layer_md,
343 enc_uni_first_wei_iter_md, user_enc_uni_first_bias_md,
344 enc_uni_first_dst_layer_md, zero_md());
345 auto enc_uni_first_prim_desc = mkldnn::rnn_forward::primitive_desc(
346 enc_uni_first_layer_desc, cpu_engine);
// Reorder user ldigo weights into the primitive's chosen layout.
348 auto enc_uni_first_wei_layer_memory
349 = memory(enc_uni_first_prim_desc.weights_layer_primitive_desc());
350 auto enc_uni_first_wei_layer_reorder_pd = reorder::primitive_desc(
351 user_enc_uni_first_wei_layer_memory.get_primitive_desc(),
352 enc_uni_first_wei_layer_memory.get_primitive_desc());
353 weights_reorders.push_back(reorder(enc_uni_first_wei_layer_reorder_pd,
354 user_enc_uni_first_wei_layer_memory,
355 enc_uni_first_wei_layer_memory));
357 auto enc_uni_first_wei_iter_memory
358 = memory(enc_uni_first_prim_desc.weights_iter_primitive_desc());
359 auto enc_uni_first_wei_iter_reorder_pd = reorder::primitive_desc(
360 user_enc_uni_first_wei_iter_memory.get_primitive_desc(),
361 enc_uni_first_wei_iter_memory.get_primitive_desc());
362 weights_reorders.push_back(reorder(enc_uni_first_wei_iter_reorder_pd,
363 user_enc_uni_first_wei_iter_memory, enc_uni_first_wei_iter_memory));
365 auto enc_uni_first_dst_layer_memory = mkldnn::memory(
366 enc_uni_first_prim_desc.dst_layer_primitive_desc());
// Chained after the bidirectional layer: consumes its dst_layer memory.
368 encoder_net.push_back(rnn_forward(enc_uni_first_prim_desc,
369 enc_bidir_dst_layer_memory, null_memory_,
370 enc_uni_first_wei_layer_memory,
371 enc_uni_first_wei_iter_memory, user_enc_uni_first_bias_memory,
372 enc_uni_first_dst_layer_memory, null_memory_, null_memory_));
// --- Encoder, remaining (enc_unidir_n_layers - 1) unidirectional layers,
// all feature_size -> feature_size, expressed as one multi-layer primitive.
374 /* Remainging unidirectional layers */
375 std::vector<float> user_enc_uni_wei_layer((enc_unidir_n_layers - 1) * 1
376 * feature_size * lstm_n_gates * feature_size, 1.0f);
377 std::vector<float> user_enc_uni_wei_iter((enc_unidir_n_layers - 1) * 1
378 * feature_size * lstm_n_gates * feature_size, 1.0f);
379 std::vector<float> user_enc_uni_bias(
380 (enc_unidir_n_layers - 1) * 1 * lstm_n_gates * feature_size, 1.0f);
381 memory::dims user_enc_uni_wei_layer_dims = { (enc_unidir_n_layers - 1), 1,
382 feature_size, lstm_n_gates, feature_size };
383 memory::dims user_enc_uni_wei_iter_dims = { (enc_unidir_n_layers - 1), 1,
384 feature_size, lstm_n_gates, feature_size };
385 memory::dims user_enc_uni_bias_dims
386 = { (enc_unidir_n_layers - 1), 1, lstm_n_gates, feature_size };
387 memory::dims enc_dst_layer_dims
388 = { src_seq_length_max, batch, feature_size };
389 auto user_enc_uni_wei_layer_md = mkldnn::memory::desc(
390 { user_enc_uni_wei_layer_dims }, mkldnn::memory::data_type::f32,
391 mkldnn::memory::format::ldigo);
392 auto user_enc_uni_wei_iter_md = mkldnn::memory::desc(
393 { user_enc_uni_wei_iter_dims }, mkldnn::memory::data_type::f32,
394 mkldnn::memory::format::ldigo);
395 auto user_enc_uni_bias_md = mkldnn::memory::desc({ user_enc_uni_bias_dims },
396 mkldnn::memory::data_type::f32, mkldnn::memory::format::ldgo);
397 auto user_enc_uni_wei_layer_memory
398 = mkldnn::memory({ user_enc_uni_wei_layer_md, cpu_engine },
399 user_enc_uni_wei_layer.data());
400 auto user_enc_uni_wei_iter_memory
401 = mkldnn::memory({ user_enc_uni_wei_iter_md, cpu_engine },
402 user_enc_uni_wei_iter.data());
403 auto user_enc_uni_bias_memory = mkldnn::memory(
404 { user_enc_uni_bias_md, cpu_engine }, user_enc_uni_bias.data());
406 auto enc_uni_wei_layer_md = memory::desc({ user_enc_uni_wei_layer_dims },
407 memory::data_type::f32, memory::format::any);
408 auto enc_uni_wei_iter_md = memory::desc({ user_enc_uni_wei_iter_dims },
409 memory::data_type::f32, memory::format::any);
410 auto enc_dst_layer_md = memory::desc({ enc_dst_layer_dims },
411 memory::data_type::f32, memory::format::any);
413 /// @todo add suport for residual connections
414 /// should it be a set residual in op_desc or a field to set manually?
415 /// should be an integer to specify at which layer to start
416 rnn_cell::desc enc_uni_cell(algorithm::vanilla_lstm);
417 rnn_forward::desc enc_uni_layer_desc(prop_kind::forward_inference,
418 enc_uni_cell, rnn_direction::unidirectional_left2right,
419 enc_uni_first_dst_layer_md, zero_md(), enc_uni_wei_layer_md,
420 enc_uni_wei_iter_md, user_enc_uni_bias_md, enc_dst_layer_md,
// NOTE(review): the final argument line of this descriptor (original line
// 421, presumably "zero_md());") was dropped by the extraction.
422 auto enc_uni_prim_desc = mkldnn::rnn_forward::primitive_desc(
423 enc_uni_layer_desc, cpu_engine);
// Reorder user ldigo weights into the primitive's chosen layout.
425 auto enc_uni_wei_layer_memory
426 = memory(enc_uni_prim_desc.weights_layer_primitive_desc());
427 auto enc_uni_wei_layer_reorder_pd = reorder::primitive_desc(
428 user_enc_uni_wei_layer_memory.get_primitive_desc(),
429 enc_uni_wei_layer_memory.get_primitive_desc());
430 weights_reorders.push_back(reorder(enc_uni_wei_layer_reorder_pd,
431 user_enc_uni_wei_layer_memory, enc_uni_wei_layer_memory));
433 auto enc_uni_wei_iter_memory
434 = memory(enc_uni_prim_desc.weights_iter_primitive_desc());
435 auto enc_uni_wei_iter_reorder_pd = reorder::primitive_desc(
436 user_enc_uni_wei_iter_memory.get_primitive_desc(),
437 enc_uni_wei_iter_memory.get_primitive_desc());
438 weights_reorders.push_back(reorder(enc_uni_wei_iter_reorder_pd,
439 user_enc_uni_wei_iter_memory, enc_uni_wei_iter_memory));
441 auto enc_dst_layer_memory
442 = mkldnn::memory(enc_uni_prim_desc.dst_layer_primitive_desc());
// Chained after the first unidirectional layer.
444 encoder_net.push_back(
445 rnn_forward(enc_uni_prim_desc, enc_uni_first_dst_layer_memory,
446 null_memory_, enc_uni_wei_layer_memory,
447 enc_uni_wei_iter_memory, user_enc_uni_bias_memory,
448 enc_dst_layer_memory, null_memory_, null_memory_));
// --- Decoder with attention: weights/bias data, iteration-state memory
// that also carries the attention context (rows of width 2 * feature_size),
// a no-context view for dst_iter, and the in-place decoder primitive.
450 /* GNMT: decoder with attention mechanism */
// NOTE(review): the "1.0f);" tail of user_dec_wei_layer (original line 453)
// was dropped by the extraction.
451 std::vector<float> user_dec_wei_layer(
452 dec_n_layers * 1 * feature_size * lstm_n_gates * feature_size,
454 std::vector<float> user_dec_wei_iter(dec_n_layers * 1
455 * (feature_size + feature_size) * lstm_n_gates
456 * feature_size, 1.0f);
457 std::vector<float> user_dec_bias(
458 dec_n_layers * 1 * lstm_n_gates * feature_size, 1.0f);
459 std::vector<float> user_dec_dst(
460 tgt_seq_length_max * batch * feature_size, 1.0f);
// Attention weights consumed by compute_attention /
// compute_weighted_annotations, not by the RNN primitive itself.
461 std::vector<float> user_weights_attention_src_layer(
462 feature_size * feature_size, 1.0f);
463 std::vector<float> user_weights_annotation(
464 feature_size * feature_size, 1.0f);
465 std::vector<float> user_weights_alignments(feature_size, 1.0f);
// wei_iter input channels = 2 * feature_size: recurrent state plus the
// attention context appended to each row.
467 memory::dims user_dec_wei_layer_dims
468 = { dec_n_layers, 1, feature_size, lstm_n_gates, feature_size };
469 memory::dims user_dec_wei_iter_dims = { dec_n_layers, 1,
470 feature_size + feature_size, lstm_n_gates, feature_size };
471 memory::dims user_dec_bias_dims
472 = { dec_n_layers, 1, lstm_n_gates, feature_size };
473 memory::dims dec_src_layer_dims = { 1, batch, feature_size };
474 memory::dims dec_src_layer_dims = { 1, batch, feature_size };
475 memory::dims dec_dst_layer_dims = { 1, batch, feature_size };
477 // We will use the same memory for dec_src_iter and dec_dst_iter
478 // However, dec_src_iter has a context vector but not
480 // To resolve this we will create one memory that holds the
481 // context vector as well as the both the hidden and cell states.
482 // For the dst_iter, we will use a view on this memory.
483 // Note that the cell state will be padded by
484 // feature_size values. However, we do not compute or
486 memory::dims dec_dst_iter_dims = { dec_n_layers, 1, lstm_n_states, batch,
487 feature_size + feature_size };
488 memory::dims dec_dst_iter_noctx_dims
489 = { dec_n_layers, 1, lstm_n_states, batch, feature_size };
491 auto user_dec_wei_layer_md = mkldnn::memory::desc(
492 { user_dec_wei_layer_dims }, mkldnn::memory::data_type::f32,
493 mkldnn::memory::format::ldigo);
494 auto user_dec_wei_iter_md = mkldnn::memory::desc({ user_dec_wei_iter_dims },
495 mkldnn::memory::data_type::f32, mkldnn::memory::format::ldigo);
496 auto user_dec_bias_md = mkldnn::memory::desc({ user_dec_bias_dims },
497 mkldnn::memory::data_type::f32, mkldnn::memory::format::ldgo);
498 auto dec_dst_layer_md = mkldnn::memory::desc({ dec_dst_layer_dims },
499 mkldnn::memory::data_type::f32, mkldnn::memory::format::tnc);
500 auto dec_src_layer_md = mkldnn::memory::desc({ dec_src_layer_dims },
501 mkldnn::memory::data_type::f32, mkldnn::memory::format::tnc);
502 auto dec_dst_iter_md = mkldnn::memory::desc({ dec_dst_iter_dims },
503 mkldnn::memory::data_type::f32, mkldnn::memory::format::ldsnc);
504 auto user_dec_wei_layer_memory = mkldnn::memory(
505 { user_dec_wei_layer_md, cpu_engine }, user_dec_wei_layer.data());
506 auto user_dec_wei_iter_memory = mkldnn::memory(
507 { user_dec_wei_iter_md, cpu_engine }, user_dec_wei_iter.data());
508 auto user_dec_bias_memory = mkldnn::memory(
509 { user_dec_bias_md, cpu_engine }, user_dec_bias.data());
510 auto user_dec_dst_layer_memory = mkldnn::memory(
511 { dec_dst_layer_md, cpu_engine }, user_dec_dst.data());
512 auto dec_src_layer_memory
513 = mkldnn::memory({ dec_src_layer_md, cpu_engine });
515 auto dec_wei_layer_md = mkldnn::memory::desc(
516 { user_dec_wei_layer_dims }, mkldnn::memory::data_type::f32,
517 mkldnn::memory::format::any);
518 auto dec_wei_iter_md = mkldnn::memory::desc({ user_dec_wei_iter_dims },
519 mkldnn::memory::data_type::f32, mkldnn::memory::format::any);
521 // As mentioned above, we create a view without context out of the
522 // memory with context.
523 auto dec_dst_iter_memory = mkldnn::memory({ dec_dst_iter_md, cpu_engine })
524 auto dec_dst_iter_noctx_md = mkldnn::view::primitive_desc(
525 dec_dst_iter_memory.get_primitive_desc(), dec_dst_iter_noctx_dims,
526 { 0, 0, 0, 0, 0 }).dst_primitive_desc().desc();
528 /// @todo add suport for residual connections
529 /// should it be a set residual in op_desc or a field to set manually?
530 /// should be an integer to specify at which layer to start
531 rnn_cell::desc dec_cell(algorithm::vanilla_lstm);
532 rnn_forward::desc dec_ctx_desc(prop_kind::forward_inference, dec_cell,
533 rnn_direction::unidirectional_left2right, dec_src_layer_md,
534 dec_dst_iter_md, dec_wei_layer_md, dec_wei_iter_md,
535 user_dec_bias_md, dec_dst_layer_md, dec_dst_iter_noctx_md);
536 auto dec_ctx_prim_desc
537 = mkldnn::rnn_forward::primitive_desc(dec_ctx_desc, cpu_engine);
// Reorder user ldigo weights into the primitive's chosen layout.
539 auto dec_wei_layer_memory
540 = memory(dec_ctx_prim_desc.weights_layer_primitive_desc());
541 auto dec_wei_layer_reorder_pd = reorder::primitive_desc(
542 user_dec_wei_layer_memory.get_primitive_desc(),
543 dec_wei_layer_memory.get_primitive_desc());
544 weights_reorders.push_back(reorder(dec_wei_layer_reorder_pd,
545 user_dec_wei_layer_memory, dec_wei_layer_memory));
547 auto dec_wei_iter_memory
548 = memory(dec_ctx_prim_desc.weights_iter_primitive_desc());
549 auto dec_wei_iter_reorder_pd = reorder::primitive_desc(
550 user_dec_wei_iter_memory.get_primitive_desc(),
551 dec_wei_iter_memory.get_primitive_desc());
552 weights_reorders.push_back(reorder(dec_wei_iter_reorder_pd,
553 user_dec_wei_iter_memory, dec_wei_iter_memory));
// dec_dst_iter_memory serves as both src_iter and dst_iter: the decoder
// updates its recurrent state in place across iterations.
555 decoder_net.push_back(rnn_forward(dec_ctx_prim_desc, dec_src_layer_memory,
556 dec_dst_iter_memory, dec_wei_layer_memory,
557 dec_wei_iter_memory, user_dec_bias_memory,
558 user_dec_dst_layer_memory, dec_dst_iter_memory, null_memory_));
560 // allocating temporary buffer for attention mechanism
561 std::vector<float> weighted_annotations(
562 src_seq_length_max * batch * feature_size, 1.0f);
// --- Execution: run weight reorders and the encoder once, then iterate
// the decoder one target token at a time, recomputing attention each step.
// NOTE(review): the lambda/loop closing lines (original 612 onward) and the
// copy_context argument continuation (original line 602, "feature_size);")
// were dropped by the extraction.
567 auto execute = [&]() {
568 // reorder weights to MKLDNN internal representation
569 stream(stream::kind::eager).submit(weights_reorders).wait();
571 // run encoder (1 stream)
572 stream(stream::kind::eager).submit(encoder_net).wait();
574 // we compute the weighted annotations once before the decoder
575 compute_weighted_annotations(weighted_annotations.data(),
576 src_seq_length_max, batch, feature_size,
577 user_weights_annotation.data(),
578 (float *)enc_dst_layer_memory.get_data_handle());
580 // We initialise src_layer to the embedding of </s>, which
581 // are assumed to be 0 here
582 memset(dec_src_layer_memory.get_data_handle(), 0,
583 dec_src_layer_memory.get_primitive_desc().get_size());
584 // From now on, src points to the output of the last iteration
586 for (int i = 0; i < tgt_seq_length_max; i++) {
587 float *src_att_layer_handle
588 = (float *) dec_src_layer_memory.get_data_handle();
589 float *src_att_iter_handle
590 = (float *) dec_dst_iter_memory.get_data_handle();
592 // Compute attention context vector into the first layer src_iter
593 compute_attention(src_att_iter_handle, src_seq_length_max, batch,
594 feature_size, user_weights_attention_src_layer.data(),
595 src_att_layer_handle,
596 (float *)enc_bidir_dst_layer_memory.get_data_handle(),
597 weighted_annotations.data(),
598 user_weights_alignments.data());
600 // copy the context vectors to all layers of src_iter
601 copy_context(src_att_iter_handle, dec_n_layers, lstm_n_states, batch,
604 // run the decoder iteration
605 stream(stream::kind::eager).submit(decoder_net).wait();
607 // Move the handle on the src/dst layer to the next iteration:
// src reads the token just produced; dst advances one (n, c) slab.
608 auto dst_layer_handle = (float *) user_dec_dst_layer_memory.get_data_handle();
609 dec_src_layer_memory.set_data_handle(dst_layer_handle);
610 user_dec_dst_layer_memory.set_data_handle(
611 dst_layer_handle + batch * feature_size);
619 int main(int argc, char **argv) {
624 std::cerr << "status: " << e.status << std::endl;
625 std::cerr << "message: " << e.message << std::endl;