1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
25 // MSVC doesn't support collapse clause in omp parallel
26 #if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
30 using namespace mkldnn;
33 const int src_seq_length_max = 25;
34 const int tgt_seq_length_max = 27;
36 const int feature_size = 1024;
38 const int enc_bidir_n_layers = 1;
39 const int enc_unidir_n_layers = 7;
40 const int dec_n_layers = 8;
42 const int lstm_n_gates = 4;
43 const int lstm_n_states = 2;
44 std::vector<int32_t> weighted_src_layer(batch *feature_size, 1);
45 std::vector<float> alignment_model(
46 src_seq_length_max *batch *feature_size, 1.0f);
47 std::vector<float> alignments(src_seq_length_max *batch, 1.0f);
48 std::vector<float> exp_sums(batch, 1.0f);
50 const float onef = 1.0, zerof = 0.0;
53 void compute_weighted_annotations(float *weighted_annotations,
54 int src_seq_length_max, int batch, int feature_size,
55 float *weights_annot, float *annotations) {
56 // annotations(aka enc_dst_layer) is (t, n, 2c)
57 // weights_annot is (2c, c)
59 int num_weighted_annotations = src_seq_length_max * batch;
60 // annotation[i] = GEMM(weights_annot, enc_dst_layer[i]);
61 mkldnn_sgemm("N", "N", &feature_size, &num_weighted_annotations,
62 &feature_size, &onef, weights_annot, &feature_size, annotations,
63 &feature_size, &zerof, weighted_annotations, &feature_size);
// Sums each column of the int8 matrix `a` (stored column-major as
// cols x rows, i.e. element (j, i) at a[i * rows + j]) into a_reduced[i].
// Used to precompute the compensation term for the s8u8s32 GEMM in
// compute_attention.
// Fix: a_reduced[i] is now zeroed before accumulation — the caller passes a
// buffer pre-filled with 1s, which previously skewed every column sum.
void compute_sum_of_rows(int8_t *a, int rows, int cols, int32_t *a_reduced) {
#pragma omp parallel for
    for (int i = 0; i < cols; i++) {
        a_reduced[i] = 0;
        for (int j = 0; j < rows; j++) {
            a_reduced[i] += (int32_t)a[i * rows + j];
        }
    }
}
// Computes the attention context vector for one decoder step and writes it
// into the second half of each row of `context_vectors` (which aliases the
// first-layer src_iter of the decoder).
// NOTE(review): several interior lines of this routine are truncated in this
// chunk (the `ao`/`bo` offset declarations and trailing arguments of the
// s8u8s32 GEMM, the `tnc_offset` declaration, the exp_sums reset loop body,
// and various closing braces) — confirm against the full file before editing.
void compute_attention(float *context_vectors, int src_seq_length_max,
        int batch, int feature_size, int8_t *weights_src_layer,
        float weights_src_layer_scale, int32_t *compensation,
        uint8_t *dec_src_layer, float dec_src_layer_scale,
        float dec_src_layer_shift, uint8_t *annotations,
        float *weighted_annotations, float *weights_alignments) {
    // dst_iter : (n, c) matrix
    // src_layer: (n, c) matrix
    // weighted_annotations (t, n, c)
    // weights_yi is (c, c)
    // weights_ai is (c, 1)
    // first we precompute the weighted_dec_src_layer
    // int8 GEMM ("F" = fixed offset mode) producing an s32 result;
    // NOTE(review): the trailing arguments of this call are truncated here.
    mkldnn_gemm_s8u8s32("N", "N", "F", &feature_size, &batch, &feature_size,
            &onef, weights_src_layer, &feature_size, &ao, dec_src_layer,
            &feature_size, &bo, &zerof, weighted_src_layer.data(),
    // then we compute the alignment model
    float *alignment_model_ptr = alignment_model.data();
#pragma omp parallel for collapse(2)
    for (int i = 0; i < src_seq_length_max; i++) {
        for (int j = 0; j < batch; j++) {
            for (int k = 0; k < feature_size; k++) {
                // offset into the (t, n, c) alignment tensor
                // NOTE(review): the `int tnc_offset` declaration line is
                // missing from this chunk.
                    = i * batch * feature_size + j * feature_size + k;
                // dequantize the s32 GEMM output (subtracting the
                // shift * compensation term) before adding the annotation
                alignment_model_ptr[tnc_offset] = tanhf(
                        (float)(weighted_src_layer.data()[j * feature_size + k]
                                - dec_src_layer_shift * compensation[k])
                        / (dec_src_layer_scale
                                * weights_src_layer_scale)
                        + weighted_annotations[tnc_offset]);
    // gemv with alignments weights. the resulting alignments are in alignments
    int num_weighted_annotations = src_seq_length_max * batch;
    mkldnn_sgemm("N", "N", &onei, &num_weighted_annotations, &feature_size,
            &onef, weights_alignments, &onei, alignment_model_ptr,
            &feature_size, &zerof, alignments.data(), &onei);
    // softmax on alignments. the resulting context weights are in alignments
#pragma omp parallel for
    for (int i = 0; i < batch; i++)
    // NOTE(review): the loop body above is truncated — presumably it resets
    // exp_sums[i] to 0 before the accumulation below; verify.
#pragma omp parallel for collapse(2)
    for (int i = 0; i < src_seq_length_max; i++) {
        for (int j = 0; j < batch; j++) {
            alignments[i * batch + j] = expf(alignments[i * batch + j]);
            exp_sums[j] += alignments[i * batch + j];
    // normalize each exponential by its per-batch sum
#pragma omp parallel for collapse(2)
    for (int i = 0; i < src_seq_length_max; i++)
        for (int j = 0; j < batch; j++)
            alignments[i * batch + j] /= exp_sums[j];
    // then we compute the context vectors
    // (this first loop apparently zeroes the context half of each row;
    // its trailing tokens are truncated in this chunk)
#pragma omp parallel for collapse(2)
    for (int i = 0; i < batch; i++)
        for (int j = 0; j < feature_size; j++)
            context_vectors[i * (feature_size + feature_size) + feature_size
    // weighted sum of the dequantized encoder annotations
#pragma omp parallel for collapse(3)
    for (int i = 0; i < batch; i++)
        for (int k = 0; k < src_seq_length_max; k++)
            for (int j = 0; j < feature_size; j++)
                context_vectors[i * (feature_size + feature_size) + feature_size
                        += alignments[k * batch + i]
                        * (((float)annotations[j
                                   + feature_size * (i + batch * k)]
                                  - dec_src_layer_shift)
                                / dec_src_layer_scale);
// Broadcasts the attention context vector, computed into the first layer's
// src_iter, to the src_iter of every other decoder layer. Each row of
// src_iter is 2 * feature_size wide (hidden/cell state followed by context
// padding); only the first feature_size entries of layer 0, batch j are
// copied.
// NOTE(review): the signature continuation (`int feature_size`) and the
// `+ i]` of the destination index were truncated in this chunk — restored.
void copy_context(float *src_iter, int n_layers, int n_states, int batch,
        int feature_size) {
    // we copy the context from the first layer to all other layers
#pragma omp parallel for collapse(3)
    for (int k = 1; k < n_layers; k++)
        for (int j = 0; j < batch; j++)
            for (int i = 0; i < feature_size; i++)
                src_iter[(k * n_states * batch + j)
                                * (feature_size + feature_size)
                        + i]
                        = src_iter[j * (feature_size + feature_size) + i];
}
    /* Engine and a null memory used for the RNN ports this example leaves
     * unconnected. */
    auto cpu_engine = engine(engine::cpu, 0);
    auto null_memory_ = null_memory(cpu_engine);
    /*
    GNMT low precision example.
    Note, we do not implement connection yet.
    For the encoder we use:
    - one primitive for the bidirectional layer of the encoder
    - one primitive for all remaining unidirectional layers in the encoder
    For the decoder we use:
    - one primitive for the first iteration
    - one primitive for all subsequent iterations in the decoder. Note that
    in this example, this primitive computes the states in place.
    - the attention mechanism is implemented separately as there is no support
    for the context vectors in MKL-DNN yet
    */
    std::vector<primitive> weights_reorders;
    std::vector<primitive> encoder_net;
    std::vector<primitive> decoder_net;
    std::vector<float> net_src(batch * src_seq_length_max * feature_size, 0.1f);
    std::vector<float> net_dst(batch * tgt_seq_length_max * feature_size, 0.1f);
    /* Quantization factors for fp32 data */
    const float data_shift = 64.;
    const float data_scale = 63.;
    const int weights_scale_mask = 3; // 11 for last two dimensions of ldigo
    std::vector<float> weights_scales(lstm_n_gates * feature_size);
    /* assign halves of vector with arbitrary values */
    const int scales_half = lstm_n_gates * feature_size / 2;
    // NOTE(review): the std::fill call for the first half is truncated here —
    // only its argument list is visible. Also note the second fill starts at
    // scales_half + 1, leaving weights_scales[scales_half] unassigned; verify
    // against the full file whether that is intentional.
            weights_scales.begin(), weights_scales.begin() + scales_half, 30.f);
    std::fill(weights_scales.begin() + scales_half + 1, weights_scales.end(),
    /* Dimensions for the bidirectional encoder layer; weights are
     * (layers, directions, input, gates, output). */
    memory::dims enc_bidir_src_layer_tz
            = { src_seq_length_max, batch, feature_size };
    memory::dims enc_bidir_weights_layer_tz = { enc_bidir_n_layers, 2,
            feature_size, lstm_n_gates, feature_size };
    memory::dims enc_bidir_weights_iter_tz = { enc_bidir_n_layers, 2,
            feature_size, lstm_n_gates, feature_size };
    memory::dims enc_bidir_bias_tz
            = { enc_bidir_n_layers, 2, lstm_n_gates, feature_size };
    memory::dims enc_bidir_dst_layer_tz
            = { src_seq_length_max, batch, 2 * feature_size };
    /* GNMT encoder: 1 bidirectional layer and 7 unidirectional layers */
    // NOTE(review): the fill-value arguments of the next two vector
    // constructors are truncated in this chunk.
    std::vector<float> user_enc_bidir_wei_layer(
            enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size,
    std::vector<float> user_enc_bidir_wei_iter(
            enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size,
    std::vector<float> user_enc_bidir_bias(
            enc_bidir_n_layers * 2 * lstm_n_gates * feature_size, 1.0f);
    /* Create the memory for user data */
    auto user_enc_bidir_src_layer_md = memory::desc({ enc_bidir_src_layer_tz },
            memory::data_type::f32, memory::format::tnc);
    auto user_enc_bidir_wei_layer_md
            = memory::desc({ enc_bidir_weights_layer_tz },
                    memory::data_type::f32, memory::format::ldigo);
    auto user_enc_bidir_wei_iter_md
            = memory::desc({ enc_bidir_weights_iter_tz },
                    memory::data_type::f32, memory::format::ldigo);
    auto user_enc_bidir_bias_md = memory::desc({ enc_bidir_bias_tz },
            memory::data_type::f32, memory::format::ldgo);
    auto user_enc_bidir_src_layer_memory = memory(
            { user_enc_bidir_src_layer_md, cpu_engine }, net_src.data());
    auto user_enc_bidir_wei_layer_memory
            = memory({ user_enc_bidir_wei_layer_md, cpu_engine },
                    user_enc_bidir_wei_layer.data());
    auto user_enc_bidir_wei_iter_memory
            = memory({ user_enc_bidir_wei_iter_md, cpu_engine },
                    user_enc_bidir_wei_iter.data());
    auto user_enc_bidir_bias_memory = memory(
            { user_enc_bidir_bias_md, cpu_engine }, user_enc_bidir_bias.data());
    /* Create memory descriptors for RNN data w/o specified layout;
     * format::any lets the primitive pick its preferred int8 layout. */
    auto enc_bidir_src_layer_md = memory::desc({ enc_bidir_src_layer_tz },
            memory::data_type::u8, memory::format::any);
    auto enc_bidir_wei_layer_md = memory::desc({ enc_bidir_weights_layer_tz },
            memory::data_type::s8, memory::format::any);
    auto enc_bidir_wei_iter_md = memory::desc({ enc_bidir_weights_iter_tz },
            memory::data_type::s8, memory::format::any);
    auto enc_bidir_dst_layer_md = memory::desc({ enc_bidir_dst_layer_tz },
            memory::data_type::u8, memory::format::any);
    /* Create bidirectional RNN */
    rnn_cell::desc bi_cell(algorithm::vanilla_lstm);
    /* Check if int8 RNN is supported */
    // NOTE(review): this probe is normally wrapped in a try/catch whose
    // opening and catch clause (binding `e`) are not visible in this chunk —
    // as written, `bi_layer_desc` is declared twice and `e` is unbound.
    rnn_forward::desc bi_layer_desc(prop_kind::forward_inference, bi_cell,
            rnn_direction::bidirectional_concat, enc_bidir_src_layer_md,
            zero_md(), enc_bidir_wei_layer_md, enc_bidir_wei_iter_md,
            user_enc_bidir_bias_md, enc_bidir_dst_layer_md, zero_md());
    if (e.status == mkldnn_unimplemented) {
            << "Dependency on Intel(R) MKL version 2019u2 or newer is "
               "required for int8 RNN"
    rnn_forward::desc bi_layer_desc(prop_kind::forward_inference, bi_cell,
            rnn_direction::bidirectional_concat, enc_bidir_src_layer_md,
            zero_md(), enc_bidir_wei_layer_md, enc_bidir_wei_iter_md,
            user_enc_bidir_bias_md, enc_bidir_dst_layer_md, zero_md());
    /* Define RNN attributes that store quantization parameters */
    // NOTE(review): the declaration of `attr` (a primitive_attr) is not
    // visible in this chunk.
    attr.set_int_output_round_mode(round_mode::round_nearest);
    attr.set_rnn_data_qparams(data_scale, data_shift);
    attr.set_rnn_weights_qparams(weights_scale_mask, weights_scales);
    auto enc_bidir_prim_desc
            = rnn_forward::primitive_desc(bi_layer_desc, attr, cpu_engine);
    /* Create memory primitives for input data and use reorders to quantize.
     * NOTE: same attributes are used when creating RNN primitive and reorders
     * so the same scales/shifts are applied. */
    auto enc_bidir_src_layer_memory
            = memory(enc_bidir_prim_desc.src_layer_primitive_desc());
    auto enc_bidir_src_layer_reorder_pd = reorder::primitive_desc(
            user_enc_bidir_src_layer_memory.get_primitive_desc(),
            enc_bidir_src_layer_memory.get_primitive_desc(), attr);
    encoder_net.push_back(reorder(enc_bidir_src_layer_reorder_pd,
            user_enc_bidir_src_layer_memory, enc_bidir_src_layer_memory));
    auto enc_bidir_wei_layer_memory
            = memory(enc_bidir_prim_desc.weights_layer_primitive_desc());
    auto enc_bidir_wei_layer_reorder_pd = reorder::primitive_desc(
            user_enc_bidir_wei_layer_memory.get_primitive_desc(),
            enc_bidir_wei_layer_memory.get_primitive_desc(), attr);
    weights_reorders.push_back(reorder(enc_bidir_wei_layer_reorder_pd,
            user_enc_bidir_wei_layer_memory, enc_bidir_wei_layer_memory));
    auto enc_bidir_wei_iter_memory
            = memory(enc_bidir_prim_desc.weights_iter_primitive_desc());
    auto enc_bidir_wei_iter_reorder_pd = reorder::primitive_desc(
            user_enc_bidir_wei_iter_memory.get_primitive_desc(),
            enc_bidir_wei_iter_memory.get_primitive_desc(), attr);
    weights_reorders.push_back(reorder(enc_bidir_wei_iter_reorder_pd,
            user_enc_bidir_wei_iter_memory, enc_bidir_wei_iter_memory));
    auto enc_bidir_dst_layer_memory
            = memory(enc_bidir_prim_desc.dst_layer_primitive_desc());
    /* Enqueue the bidirectional layer itself. */
    encoder_net.push_back(
            rnn_forward(enc_bidir_prim_desc, enc_bidir_src_layer_memory,
                    null_memory_, enc_bidir_wei_layer_memory,
                    enc_bidir_wei_iter_memory, user_enc_bidir_bias_memory,
                    enc_bidir_dst_layer_memory, null_memory_, null_memory_));
    /* GNMT encoder: unidirectional layers */
    // First unidirectional layer scales 2 * feature_size output of bidirectional
    // layer to feature_size output
    std::vector<float> user_enc_uni_first_wei_layer(
            1 * 1 * 2 * feature_size * lstm_n_gates * feature_size, 0.3f);
    std::vector<float> user_enc_uni_first_wei_iter(
            1 * 1 * feature_size * lstm_n_gates * feature_size, 0.2f);
    std::vector<float> user_enc_uni_first_bias(
            1 * 1 * lstm_n_gates * feature_size, 1.0f);
    /* Dims: note the 2 * feature_size input coming from the concat output. */
    memory::dims user_enc_uni_first_wei_layer_dims
            = { 1, 1, 2 * feature_size, lstm_n_gates, feature_size };
    memory::dims user_enc_uni_first_wei_iter_dims
            = { 1, 1, feature_size, lstm_n_gates, feature_size };
    memory::dims user_enc_uni_first_bias_dims
            = { 1, 1, lstm_n_gates, feature_size };
    memory::dims enc_uni_first_dst_layer_dims
            = { src_seq_length_max, batch, feature_size };
    auto user_enc_uni_first_wei_layer_md
            = memory::desc({ user_enc_uni_first_wei_layer_dims },
                    memory::data_type::f32, memory::format::ldigo);
    auto user_enc_uni_first_wei_iter_md
            = memory::desc({ user_enc_uni_first_wei_iter_dims },
                    memory::data_type::f32, memory::format::ldigo);
    auto user_enc_uni_first_bias_md
            = memory::desc({ user_enc_uni_first_bias_dims },
                    memory::data_type::f32, memory::format::ldgo);
    auto user_enc_uni_first_wei_layer_memory
            = memory({ user_enc_uni_first_wei_layer_md, cpu_engine },
                    user_enc_uni_first_wei_layer.data());
    auto user_enc_uni_first_wei_iter_memory
            = memory({ user_enc_uni_first_wei_iter_md, cpu_engine },
                    user_enc_uni_first_wei_iter.data());
    auto user_enc_uni_first_bias_memory
            = memory({ user_enc_uni_first_bias_md, cpu_engine },
                    user_enc_uni_first_bias.data());
    /* int8 weight descriptors with format::any for the primitive to choose. */
    auto enc_uni_first_wei_layer_md
            = memory::desc({ user_enc_uni_first_wei_layer_dims },
                    memory::data_type::s8, memory::format::any);
    auto enc_uni_first_wei_iter_md
            = memory::desc({ user_enc_uni_first_wei_iter_dims },
                    memory::data_type::s8, memory::format::any);
    auto enc_uni_first_dst_layer_md
            = memory::desc({ enc_uni_first_dst_layer_dims },
                    memory::data_type::u8, memory::format::any);
    rnn_cell::desc enc_uni_first_cell(algorithm::vanilla_lstm);
    rnn_forward::desc enc_uni_first_layer_desc(prop_kind::forward_inference,
            enc_uni_first_cell, rnn_direction::unidirectional_left2right,
            enc_bidir_dst_layer_md, zero_md(), enc_uni_first_wei_layer_md,
            enc_uni_first_wei_iter_md, user_enc_uni_first_bias_md,
            enc_uni_first_dst_layer_md, zero_md());
    auto enc_uni_first_prim_desc = rnn_forward::primitive_desc(
            enc_uni_first_layer_desc, attr, cpu_engine);
    auto enc_uni_first_wei_layer_memory
            = memory(enc_uni_first_prim_desc.weights_layer_primitive_desc());
    auto enc_uni_first_wei_layer_reorder_pd = reorder::primitive_desc(
            user_enc_uni_first_wei_layer_memory.get_primitive_desc(),
            enc_uni_first_wei_layer_memory.get_primitive_desc(), attr);
    weights_reorders.push_back(reorder(enc_uni_first_wei_layer_reorder_pd,
            user_enc_uni_first_wei_layer_memory,
            enc_uni_first_wei_layer_memory));
    auto enc_uni_first_wei_iter_memory
            = memory(enc_uni_first_prim_desc.weights_iter_primitive_desc());
    auto enc_uni_first_wei_iter_reorder_pd = reorder::primitive_desc(
            user_enc_uni_first_wei_iter_memory.get_primitive_desc(),
            enc_uni_first_wei_iter_memory.get_primitive_desc(), attr);
    weights_reorders.push_back(reorder(enc_uni_first_wei_iter_reorder_pd,
            user_enc_uni_first_wei_iter_memory, enc_uni_first_wei_iter_memory));
    auto enc_uni_first_dst_layer_memory
            = memory(enc_uni_first_prim_desc.dst_layer_primitive_desc());
    /* Enqueue the first unidirectional layer, fed by the bidirectional
     * layer's dst. */
    encoder_net.push_back(rnn_forward(enc_uni_first_prim_desc,
            enc_bidir_dst_layer_memory, null_memory_,
            enc_uni_first_wei_layer_memory, enc_uni_first_wei_iter_memory,
            user_enc_uni_first_bias_memory, enc_uni_first_dst_layer_memory,
            null_memory_, null_memory_));
    /* Remaining unidirectional layers (enc_unidir_n_layers - 1 of them,
     * handled by a single multi-layer primitive). */
    // NOTE(review): the fill-value arguments of the next two vector
    // constructors are truncated in this chunk.
    std::vector<float> user_enc_uni_wei_layer((enc_unidir_n_layers - 1) * 1
                    * feature_size * lstm_n_gates * feature_size,
    std::vector<float> user_enc_uni_wei_iter((enc_unidir_n_layers - 1) * 1
                    * feature_size * lstm_n_gates * feature_size,
    std::vector<float> user_enc_uni_bias(
            (enc_unidir_n_layers - 1) * 1 * lstm_n_gates * feature_size, 1.0f);
    memory::dims user_enc_uni_wei_layer_dims = { (enc_unidir_n_layers - 1), 1,
            feature_size, lstm_n_gates, feature_size };
    memory::dims user_enc_uni_wei_iter_dims = { (enc_unidir_n_layers - 1), 1,
            feature_size, lstm_n_gates, feature_size };
    memory::dims user_enc_uni_bias_dims
            = { (enc_unidir_n_layers - 1), 1, lstm_n_gates, feature_size };
    memory::dims enc_dst_layer_dims
            = { src_seq_length_max, batch, feature_size };
    auto user_enc_uni_wei_layer_md
            = memory::desc({ user_enc_uni_wei_layer_dims },
                    memory::data_type::f32, memory::format::ldigo);
    auto user_enc_uni_wei_iter_md = memory::desc({ user_enc_uni_wei_iter_dims },
            memory::data_type::f32, memory::format::ldigo);
    auto user_enc_uni_bias_md = memory::desc({ user_enc_uni_bias_dims },
            memory::data_type::f32, memory::format::ldgo);
    auto user_enc_uni_wei_layer_memory
            = memory({ user_enc_uni_wei_layer_md, cpu_engine },
                    user_enc_uni_wei_layer.data());
    auto user_enc_uni_wei_iter_memory
            = memory({ user_enc_uni_wei_iter_md, cpu_engine },
                    user_enc_uni_wei_iter.data());
    auto user_enc_uni_bias_memory = memory(
            { user_enc_uni_bias_md, cpu_engine }, user_enc_uni_bias.data());
    auto enc_uni_wei_layer_md = memory::desc({ user_enc_uni_wei_layer_dims },
            memory::data_type::s8, memory::format::any);
    auto enc_uni_wei_iter_md = memory::desc({ user_enc_uni_wei_iter_dims },
            memory::data_type::s8, memory::format::any);
    /* Final encoder output stays f32 — the attention consumes it directly. */
    auto enc_dst_layer_md = memory::desc({ enc_dst_layer_dims },
            memory::data_type::f32, memory::format::any);
    rnn_cell::desc enc_uni_cell(algorithm::vanilla_lstm);
    // NOTE(review): the last argument line of this descriptor (presumably
    // zero_md()); is truncated in this chunk.
    rnn_forward::desc enc_uni_layer_desc(prop_kind::forward_inference,
            enc_uni_cell, rnn_direction::unidirectional_left2right,
            enc_uni_first_dst_layer_md, zero_md(), enc_uni_wei_layer_md,
            enc_uni_wei_iter_md, user_enc_uni_bias_md, enc_dst_layer_md,
    auto enc_uni_prim_desc
            = rnn_forward::primitive_desc(enc_uni_layer_desc, attr, cpu_engine);
    auto enc_uni_wei_layer_memory
            = memory(enc_uni_prim_desc.weights_layer_primitive_desc());
    auto enc_uni_wei_layer_reorder_pd = reorder::primitive_desc(
            user_enc_uni_wei_layer_memory.get_primitive_desc(),
            enc_uni_wei_layer_memory.get_primitive_desc(), attr);
    weights_reorders.push_back(reorder(enc_uni_wei_layer_reorder_pd,
            user_enc_uni_wei_layer_memory, enc_uni_wei_layer_memory));
    auto enc_uni_wei_iter_memory
            = memory(enc_uni_prim_desc.weights_iter_primitive_desc());
    auto enc_uni_wei_iter_reorder_pd = reorder::primitive_desc(
            user_enc_uni_wei_iter_memory.get_primitive_desc(),
            enc_uni_wei_iter_memory.get_primitive_desc(), attr);
    weights_reorders.push_back(reorder(enc_uni_wei_iter_reorder_pd,
            user_enc_uni_wei_iter_memory, enc_uni_wei_iter_memory));
    auto enc_dst_layer_memory
            = memory(enc_uni_prim_desc.dst_layer_primitive_desc());
    /* Enqueue the remaining unidirectional layers. */
    encoder_net.push_back(
            rnn_forward(enc_uni_prim_desc, enc_uni_first_dst_layer_memory,
                    null_memory_, enc_uni_wei_layer_memory,
                    enc_uni_wei_iter_memory, user_enc_uni_bias_memory,
                    enc_dst_layer_memory, null_memory_, null_memory_));
    /* Decoder with attention mechanism */
    // NOTE(review): the fill-value / trailing arguments of the next two
    // vector constructors are truncated in this chunk.
    std::vector<float> user_dec_wei_layer(
            dec_n_layers * 1 * feature_size * lstm_n_gates * feature_size,
    std::vector<float> user_dec_wei_iter(dec_n_layers * 1
                    * (feature_size + feature_size) * lstm_n_gates
    std::vector<float> user_dec_bias(
            dec_n_layers * 1 * lstm_n_gates * feature_size, 1.0f);
    /* Attention weights are int8 with an explicit scale. */
    std::vector<int8_t> user_weights_attention_src_layer(
            feature_size * feature_size, 1);
    float weights_attention_scale = 127.;
    std::vector<float> user_weights_annotation(
            feature_size * feature_size, 1.0f);
    std::vector<float> user_weights_alignments(feature_size, 1.0f);
    // Buffer to store decoder output for all iterations
    std::vector<uint8_t> dec_dst(tgt_seq_length_max * batch * feature_size, 0);
    /* Note the 2 * feature_size iter input: hidden state plus context. */
    memory::dims user_dec_wei_layer_dims
            = { dec_n_layers, 1, feature_size, lstm_n_gates, feature_size };
    memory::dims user_dec_wei_iter_dims = { dec_n_layers, 1,
            feature_size + feature_size, lstm_n_gates, feature_size };
    memory::dims user_dec_bias_dims
            = { dec_n_layers, 1, lstm_n_gates, feature_size };
    memory::dims dec_src_layer_dims = { 1, batch, feature_size };
    memory::dims dec_dst_layer_dims = { 1, batch, feature_size };
    // We will use the same memory for dec_src_iter and dec_dst_iter
    // However, dec_src_iter has a context vector but not
    // To resolve this we will create one memory that holds the
    // context vector as well as the both the hidden and cell states.
    // For the dst_iter, we will use a view on this memory.
    // Note that the cell state will be padded by
    // feature_size values. However, we do not compute or
    memory::dims dec_dst_iter_dims = { dec_n_layers, 1, lstm_n_states, batch,
            feature_size + feature_size };
    memory::dims dec_dst_iter_noctx_dims
            = { dec_n_layers, 1, lstm_n_states, batch, feature_size };
    auto user_dec_wei_layer_md = memory::desc({ user_dec_wei_layer_dims },
            memory::data_type::f32, memory::format::ldigo);
    auto user_dec_wei_iter_md = memory::desc({ user_dec_wei_iter_dims },
            memory::data_type::f32, memory::format::ldigo);
    auto user_dec_bias_md = memory::desc({ user_dec_bias_dims },
            memory::data_type::f32, memory::format::ldgo);
    auto dec_src_layer_md = memory::desc(
            { dec_src_layer_dims }, memory::data_type::u8, memory::format::tnc);
    auto dec_dst_layer_md = memory::desc(
            { dec_dst_layer_dims }, memory::data_type::u8, memory::format::tnc);
    auto dec_dst_iter_md = memory::desc({ dec_dst_iter_dims },
            memory::data_type::f32, memory::format::ldsnc);
    auto user_dec_wei_layer_memory = memory(
            { user_dec_wei_layer_md, cpu_engine }, user_dec_wei_layer.data());
    auto user_dec_wei_iter_memory = memory(
            { user_dec_wei_iter_md, cpu_engine }, user_dec_wei_iter.data());
    auto user_dec_bias_memory
            = memory({ user_dec_bias_md, cpu_engine }, user_dec_bias.data());
    auto dec_src_layer_memory = memory({ dec_src_layer_md, cpu_engine });
    auto dec_dst_layer_memory
            = memory({ dec_dst_layer_md, cpu_engine }, dec_dst.data());
    /* Create memory descriptors for RNN data w/o specified layout */
    auto dec_wei_layer_md = memory::desc({ user_dec_wei_layer_dims },
            memory::data_type::s8, memory::format::any);
    auto dec_wei_iter_md = memory::desc({ user_dec_wei_iter_dims },
            memory::data_type::s8, memory::format::any);
    /* As mentioned above, we create a view without context out of the
       memory with context. */
    auto dec_dst_iter_memory = memory({ dec_dst_iter_md, cpu_engine });
    // NOTE(review): the end of this expression (presumably .desc();) is
    // truncated in this chunk.
    auto dec_dst_iter_noctx_md
            = view::primitive_desc(dec_dst_iter_memory.get_primitive_desc(),
                    dec_dst_iter_noctx_dims, { 0, 0, 0, 0, 0 })
                      .dst_primitive_desc()
    rnn_cell::desc dec_cell(algorithm::vanilla_lstm);
    rnn_forward::desc dec_ctx_desc(prop_kind::forward_inference, dec_cell,
            rnn_direction::unidirectional_left2right, dec_src_layer_md,
            dec_dst_iter_md, dec_wei_layer_md, dec_wei_iter_md,
            user_dec_bias_md, dec_dst_layer_md, dec_dst_iter_noctx_md);
    auto dec_ctx_prim_desc
            = rnn_forward::primitive_desc(dec_ctx_desc, attr, cpu_engine);
    /* Create memory primitives for input data and use reorders to quantize */
    auto dec_wei_layer_memory
            = memory(dec_ctx_prim_desc.weights_layer_primitive_desc());
    auto dec_wei_layer_reorder_pd = reorder::primitive_desc(
            user_dec_wei_layer_memory.get_primitive_desc(),
            dec_wei_layer_memory.get_primitive_desc(), attr);
    weights_reorders.push_back(reorder(dec_wei_layer_reorder_pd,
            user_dec_wei_layer_memory, dec_wei_layer_memory));
    auto dec_wei_iter_memory
            = memory(dec_ctx_prim_desc.weights_iter_primitive_desc());
    auto dec_wei_iter_reorder_pd = reorder::primitive_desc(
            user_dec_wei_iter_memory.get_primitive_desc(),
            dec_wei_iter_memory.get_primitive_desc(), attr);
    weights_reorders.push_back(reorder(dec_wei_iter_reorder_pd,
            user_dec_wei_iter_memory, dec_wei_iter_memory));
    /* One in-place decoder iteration: src_iter and dst_iter share memory. */
    // NOTE(review): the closing arguments of this push_back are truncated in
    // this chunk.
    decoder_net.push_back(rnn_forward(dec_ctx_prim_desc, dec_src_layer_memory,
            dec_dst_iter_memory, dec_wei_layer_memory, dec_wei_iter_memory,
            user_dec_bias_memory, dec_dst_layer_memory, dec_dst_iter_memory,
    /* Allocating temporary buffers for attention mechanism */
    std::vector<float> weighted_annotations(
            src_seq_length_max * batch * feature_size, 1.0f);
    std::vector<int32_t> weights_attention_sum_rows(feature_size, 1);
    /* Execution: reorder weights once, run the whole encoder, then iterate
     * the decoder one target time-step at a time with the attention
     * computation between steps. */
    auto execute = [&]() {
        // reorder weights to MKLDNN internal representation
        stream(stream::kind::eager).submit(weights_reorders).wait();
        // run encoder (1 stream)
        stream(stream::kind::eager).submit(encoder_net).wait();
        // compute the weighted annotations once before the decoder
        compute_weighted_annotations(weighted_annotations.data(),
                src_seq_length_max, batch, feature_size,
                user_weights_annotation.data(),
                (float *)enc_dst_layer_memory.get_data_handle());
        // precompute compensation for s8u8s32 gemm in compute attention
        compute_sum_of_rows(user_weights_attention_src_layer.data(),
                feature_size, feature_size, weights_attention_sum_rows.data());
        // We initialise src_layer to the embedding of </s>, which
        // are assumed to be 0 here
        memset(dec_src_layer_memory.get_data_handle(), 0,
                dec_src_layer_memory.get_primitive_desc().get_size());
        // From now on, src points to the output of the last iteration
        for (int i = 0; i < tgt_seq_length_max; i++) {
            uint8_t *src_att_layer_handle
                    = (uint8_t *)dec_src_layer_memory.get_data_handle();
            float *src_att_iter_handle
                    = (float *)dec_dst_iter_memory.get_data_handle();
            // Compute attention context vector into the first layer src_iter
            compute_attention(src_att_iter_handle, src_seq_length_max, batch,
                    feature_size, user_weights_attention_src_layer.data(),
                    weights_attention_scale, weights_attention_sum_rows.data(),
                    src_att_layer_handle, data_scale, data_shift,
                    (uint8_t *)enc_bidir_dst_layer_memory.get_data_handle(),
                    weighted_annotations.data(),
                    user_weights_alignments.data());
            // copy the context vectors to all layers of src_iter
            copy_context(src_att_iter_handle, dec_n_layers, lstm_n_states,
                    batch, feature_size);
            // run the decoder iteration
            stream(stream::kind::eager).submit(decoder_net).wait();
            // Move the handle on the src/dst layer to the next iteration:
            // this step's output becomes next step's input (in dec_dst)
            auto dst_layer_handle
                    = (uint8_t *)dec_dst_layer_memory.get_data_handle();
            dec_src_layer_memory.set_data_handle(dst_layer_handle);
            dec_dst_layer_memory.set_data_handle(
                    dst_layer_handle + batch * feature_size);
            // NOTE(review): the closing braces of this loop and lambda (and
            // the call site of `execute`) are not visible in this chunk.
// Entry point: runs the int8 GNMT example and reports MKL-DNN errors.
// NOTE(review): the body is truncated in this chunk — the try block around
// the example and the catch clause binding `e` (an mkldnn::error) are not
// visible; these prints appear to belong to that catch handler.
int main(int argc, char **argv) {
        std::cerr << "status: " << e.status << std::endl;
        std::cerr << "message: " << e.message << std::endl;