// inference-engine/thirdparty/mkl-dnn/examples/simple_rnn_int8.cpp
/*******************************************************************************
 * Copyright 2018 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/

#include <cstring>
#include <iostream>
#include <math.h>
#include <numeric>
#include <string>

#include "mkldnn.hpp"

// MSVC doesn't support the collapse clause in omp parallel
#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
#define collapse(x)
#endif
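
// The intent of the empty function-like macro above is that a directive such
// as "#pragma omp parallel for collapse(2)" degrades to a plain
// "#pragma omp parallel for" when compiled with MSVC, so the loops below still
// build and run there, just without collapsed iteration spaces.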

using namespace mkldnn;

const int batch = 64;
const int src_seq_length_max = 25;
const int tgt_seq_length_max = 27;

const int feature_size = 1024;

const int enc_bidir_n_layers = 1;
const int enc_unidir_n_layers = 7;
const int dec_n_layers = 8;

const int lstm_n_gates = 4;
const int lstm_n_states = 2;
std::vector<int32_t> weighted_src_layer(batch * feature_size, 1);
std::vector<float> alignment_model(
        src_seq_length_max * batch * feature_size, 1.0f);
std::vector<float> alignments(src_seq_length_max * batch, 1.0f);
std::vector<float> exp_sums(batch, 1.0f);

const float onef = 1.0, zerof = 0.0;
const int onei = 1;

void compute_weighted_annotations(float *weighted_annotations,
        int src_seq_length_max, int batch, int feature_size,
        float *weights_annot, float *annotations) {
    // annotations (aka enc_dst_layer) is (t, n, c) in this example
    // weights_annot is (c, c)

    int num_weighted_annotations = src_seq_length_max * batch;
    // annotation[i] = GEMM(weights_annot, enc_dst_layer[i]);
    mkldnn_sgemm("N", "N", &feature_size, &num_weighted_annotations,
            &feature_size, &onef, weights_annot, &feature_size, annotations,
            &feature_size, &zerof, weighted_annotations, &feature_size);
}
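
// mkldnn_sgemm follows the column-major BLAS convention, so the call above
// computes C = 1.0 * A * B + 0.0 * C with A = weights_annot (feature_size x
// feature_size), B = annotations read as a feature_size x (src_seq_length_max
// * batch) matrix whose columns are the per-(t, n) annotation vectors, and
// C = weighted_annotations of the same shape as B. A single gemm therefore
// applies the weight matrix to every annotation vector at once.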

void compute_sum_of_rows(int8_t *a, int rows, int cols, int32_t *a_reduced) {
#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (int i = 0; i < cols; i++) {
        a_reduced[i] = 0;
        for (int j = 0; j < rows; j++) {
            a_reduced[i] += (int32_t)a[i * rows + j];
        }
    }
}
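
// Why this reduction exists (a sketch of the arithmetic): with u8 data
// quantized as x_q[k] = data_scale * x[k] + data_shift and s8 weights
// w_q[j][k] = weights_scale * w[j][k], the integer gemm in compute_attention
// produces
//     sum_k w_q[j][k] * x_q[k]
//         = weights_scale * data_scale * sum_k w[j][k] * x[k]
//           + data_shift * sum_k w_q[j][k].
// The second term depends only on the per-column sums of the quantized
// weights, which is what we precompute here; compute_attention subtracts
// data_shift * compensation and divides by the product of the scales to
// recover the f32 result before applying tanhf.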

void compute_attention(float *context_vectors, int src_seq_length_max,
        int batch, int feature_size, int8_t *weights_src_layer,
        float weights_src_layer_scale, int32_t *compensation,
        uint8_t *dec_src_layer, float dec_src_layer_scale,
        float dec_src_layer_shift, uint8_t *annotations,
        float *weighted_annotations, float *weights_alignments) {
    // dst_iter : (n, c) matrix
    // src_layer: (n, c) matrix
    // weighted_annotations (t, n, c)

    // weights_yi is (c, c)
    // weights_ai is (c, 1)
    // tmp[i] is (n, c)
    // a[i] is (n, 1)
    // p is (n, 1)

    // first we precompute the weighted_dec_src_layer
    int8_t ao = 0;
    uint8_t bo = 0; // the u8 operand of the gemm takes a u8 offset
    int32_t co = 0;
    mkldnn_gemm_s8u8s32("N", "N", "F", &feature_size, &batch, &feature_size,
            &onef, weights_src_layer, &feature_size, &ao, dec_src_layer,
            &feature_size, &bo, &zerof, weighted_src_layer.data(),
            &feature_size, &co);

    // then we compute the alignment model
    float *alignment_model_ptr = alignment_model.data();
#ifdef _OPENMP
#pragma omp parallel for collapse(2)
#endif
    for (int i = 0; i < src_seq_length_max; i++) {
        for (int j = 0; j < batch; j++) {
            for (int k = 0; k < feature_size; k++) {
                size_t tnc_offset
                        = i * batch * feature_size + j * feature_size + k;
                alignment_model_ptr[tnc_offset] = tanhf(
                        (float)(weighted_src_layer.data()[j * feature_size + k]
                                - dec_src_layer_shift * compensation[k])
                                / (dec_src_layer_scale
                                          * weights_src_layer_scale)
                        + weighted_annotations[tnc_offset]);
            }
        }
    }

    // GEMV with the alignment weights; the resulting scores are in alignments
    int num_weighted_annotations = src_seq_length_max * batch;
    mkldnn_sgemm("N", "N", &onei, &num_weighted_annotations, &feature_size,
            &onef, weights_alignments, &onei, alignment_model_ptr,
            &feature_size, &zerof, alignments.data(), &onei);

// softmax on alignments; the resulting context weights are in alignments.
// Note: we parallelize over the batch only, since several time steps
// accumulate into the same exp_sums[j]; a collapsed (t, n) loop here would
// race on exp_sums.
#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (int j = 0; j < batch; j++) {
        exp_sums[j] = 0.0f;
        for (int i = 0; i < src_seq_length_max; i++) {
            alignments[i * batch + j] = expf(alignments[i * batch + j]);
            exp_sums[j] += alignments[i * batch + j];
        }
    }

#ifdef _OPENMP
#pragma omp parallel for collapse(2)
#endif
    for (int i = 0; i < src_seq_length_max; i++)
        for (int j = 0; j < batch; j++)
            alignments[i * batch + j] /= exp_sums[j];
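
    // A hardened softmax would subtract the per-column maximum before expf to
    // avoid overflow, along the lines of (a sketch, not needed for the bounded
    // toy inputs of this example):
    //     float m = alignments[j];
    //     for (int i = 1; i < src_seq_length_max; i++)
    //         m = std::max(m, alignments[i * batch + j]);
    //     ... expf(alignments[i * batch + j] - m) ...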

// then we compute the context vectors
#ifdef _OPENMP
#pragma omp parallel for collapse(2)
#endif
    for (int i = 0; i < batch; i++)
        for (int j = 0; j < feature_size; j++)
            context_vectors[i * (feature_size + feature_size) + feature_size
                    + j]
                    = 0.0f;

// accumulate over the time steps; k stays innermost (and sequential) so that
// no two threads update the same context_vectors element
#ifdef _OPENMP
#pragma omp parallel for collapse(2)
#endif
    for (int i = 0; i < batch; i++)
        for (int j = 0; j < feature_size; j++)
            for (int k = 0; k < src_seq_length_max; k++)
                context_vectors[i * (feature_size + feature_size) + feature_size
                        + j]
                        += alignments[k * batch + i]
                        * (((float)annotations[j
                                   + feature_size * (i + batch * k)]
                                   - dec_src_layer_shift)
                        / dec_src_layer_scale);
}
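
// Layout note for the decoder's iteration buffer (dec_dst_iter below): along
// the channel dimension, each (layer, state, batch) row holds feature_size
// state values followed by feature_size attention-context values. The RNN
// primitive reads all 2 * feature_size channels as its recurrent input but,
// through the no-context view created later, writes back only the first half,
// so the context written by compute_attention survives each iteration.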

void copy_context(float *src_iter, int n_layers, int n_states, int batch,
        int feature_size) {
// we copy the context from the first layer to all other layers; the context
// lives in the second half of the channel dimension, hence the feature_size
// offset on both sides
#ifdef _OPENMP
#pragma omp parallel for collapse(3)
#endif
    for (int k = 1; k < n_layers; k++)
        for (int j = 0; j < batch; j++)
            for (int i = 0; i < feature_size; i++)
                src_iter[(k * n_states * batch + j)
                                * (feature_size + feature_size)
                        + feature_size + i]
                        = src_iter[j * (feature_size + feature_size)
                                + feature_size + i];
}

void simple_net() {
    auto cpu_engine = engine(engine::cpu, 0);
    auto null_memory_ = null_memory(cpu_engine);

    /*
      GNMT low precision example.
      Note: we do not implement all the connections of the full GNMT
      topology yet.
      For the encoder we use:
      - one primitive for the bidirectional layer of the encoder
      - one primitive for all remaining unidirectional layers in the encoder
      For the decoder we use:
      - one primitive for all iterations; note that in this example, this
        primitive computes the states in place
      - the attention mechanism is implemented separately as there is no
        support for the context vectors in MKL-DNN yet
     */
    std::vector<primitive> weights_reorders;
    std::vector<primitive> encoder_net;
    std::vector<primitive> decoder_net;

    std::vector<float> net_src(batch * src_seq_length_max * feature_size, 0.1f);
    std::vector<float> net_dst(batch * tgt_seq_length_max * feature_size, 0.1f);

    /* Quantization factors for fp32 data */

    const float data_shift = 64.;
    const float data_scale = 63.;
    const int weights_scale_mask = 3; // 0b11: the last two dimensions of
                                      // ldigo, i.e. gates and output channels
    std::vector<float> weights_scales(lstm_n_gates * feature_size);
    /* assign halves of the vector with arbitrary values */
    const int scales_half = lstm_n_gates * feature_size / 2;
    std::fill(
            weights_scales.begin(), weights_scales.begin() + scales_half, 30.f);
    // the second fill starts at scales_half (not scales_half + 1) so that no
    // entry is left at its zero-initialized value
    std::fill(weights_scales.begin() + scales_half, weights_scales.end(),
            65.5f);
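
    /* A worked example of what these factors imply (a sketch based on the
     * documented quantization formula u8 = f32 * data_scale + data_shift):
     * the 0.1f values in net_src map to round(0.1 * 63 + 64) = 70, and any
     * f32 value in [-1, 1] lands in [1, 127], comfortably inside u8 range.
     * weights_scale_mask == 3 (0b11) selects the last two dimensions of the
     * ldigo weights, so weights_scales carries one scale per (gate, output
     * channel) pair: lstm_n_gates * feature_size entries in total. */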

    /* Encoder */

    memory::dims enc_bidir_src_layer_tz
            = { src_seq_length_max, batch, feature_size };
    memory::dims enc_bidir_weights_layer_tz = { enc_bidir_n_layers, 2,
        feature_size, lstm_n_gates, feature_size };
    memory::dims enc_bidir_weights_iter_tz = { enc_bidir_n_layers, 2,
        feature_size, lstm_n_gates, feature_size };
    memory::dims enc_bidir_bias_tz
            = { enc_bidir_n_layers, 2, lstm_n_gates, feature_size };
    memory::dims enc_bidir_dst_layer_tz
            = { src_seq_length_max, batch, 2 * feature_size };

    /* GNMT encoder: 1 bidirectional layer and 7 unidirectional layers */

    std::vector<float> user_enc_bidir_wei_layer(
            enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size,
            0.3f);
    std::vector<float> user_enc_bidir_wei_iter(
            enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size,
            0.2f);
    std::vector<float> user_enc_bidir_bias(
            enc_bidir_n_layers * 2 * lstm_n_gates * feature_size, 1.0f);

    /* Create the memory for user data */
    auto user_enc_bidir_src_layer_md = memory::desc({ enc_bidir_src_layer_tz },
            memory::data_type::f32, memory::format::tnc);

    auto user_enc_bidir_wei_layer_md
            = memory::desc({ enc_bidir_weights_layer_tz },
                    memory::data_type::f32, memory::format::ldigo);

    auto user_enc_bidir_wei_iter_md
            = memory::desc({ enc_bidir_weights_iter_tz },
                    memory::data_type::f32, memory::format::ldigo);

    auto user_enc_bidir_bias_md = memory::desc({ enc_bidir_bias_tz },
            memory::data_type::f32, memory::format::ldgo);

    auto user_enc_bidir_src_layer_memory = memory(
            { user_enc_bidir_src_layer_md, cpu_engine }, net_src.data());
    auto user_enc_bidir_wei_layer_memory
            = memory({ user_enc_bidir_wei_layer_md, cpu_engine },
                    user_enc_bidir_wei_layer.data());
    auto user_enc_bidir_wei_iter_memory
            = memory({ user_enc_bidir_wei_iter_md, cpu_engine },
                    user_enc_bidir_wei_iter.data());
    auto user_enc_bidir_bias_memory = memory(
            { user_enc_bidir_bias_md, cpu_engine }, user_enc_bidir_bias.data());

    /* Create memory descriptors for RNN data w/o specified layout */
    auto enc_bidir_src_layer_md = memory::desc({ enc_bidir_src_layer_tz },
            memory::data_type::u8, memory::format::any);

    auto enc_bidir_wei_layer_md = memory::desc({ enc_bidir_weights_layer_tz },
            memory::data_type::s8, memory::format::any);

    auto enc_bidir_wei_iter_md = memory::desc({ enc_bidir_weights_iter_tz },
            memory::data_type::s8, memory::format::any);

    auto enc_bidir_dst_layer_md = memory::desc({ enc_bidir_dst_layer_tz },
            memory::data_type::u8, memory::format::any);
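
    /* With format::any the primitive descriptor chooses whatever (possibly
     * blocked) layout is fastest for these int8 data types; the user's plain
     * tnc/ldigo f32 buffers are then converted into that layout by the
     * reorders created below, which is also where the actual f32 -> u8/s8
     * quantization happens. */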

    /* Create bidirectional RNN */
    rnn_cell::desc bi_cell(algorithm::vanilla_lstm);

    /* Check if int8 RNN is supported */
    try {
        rnn_forward::desc bi_layer_desc(prop_kind::forward_inference, bi_cell,
                rnn_direction::bidirectional_concat, enc_bidir_src_layer_md,
                zero_md(), enc_bidir_wei_layer_md, enc_bidir_wei_iter_md,
                user_enc_bidir_bias_md, enc_bidir_dst_layer_md, zero_md());
    } catch (error &e) {
        if (e.status == mkldnn_unimplemented) {
            std::cerr
                    << "Dependency on Intel(R) MKL version 2019u2 or newer is "
                       "required for int8 RNN"
                    << std::endl;
        }
        throw;
    }

    rnn_forward::desc bi_layer_desc(prop_kind::forward_inference, bi_cell,
            rnn_direction::bidirectional_concat, enc_bidir_src_layer_md,
            zero_md(), enc_bidir_wei_layer_md, enc_bidir_wei_iter_md,
            user_enc_bidir_bias_md, enc_bidir_dst_layer_md, zero_md());

    /* Define RNN attributes that store quantization parameters */
    primitive_attr attr;
    attr.set_int_output_round_mode(round_mode::round_nearest);
    attr.set_rnn_data_qparams(data_scale, data_shift);
    attr.set_rnn_weights_qparams(weights_scale_mask, weights_scales);

    auto enc_bidir_prim_desc
            = rnn_forward::primitive_desc(bi_layer_desc, attr, cpu_engine);

    /* Create memory primitives for input data and use reorders to quantize
     * values to int8
     * NOTE: same attributes are used when creating RNN primitive and reorders
     */
    auto enc_bidir_src_layer_memory
            = memory(enc_bidir_prim_desc.src_layer_primitive_desc());
    auto enc_bidir_src_layer_reorder_pd = reorder::primitive_desc(
            user_enc_bidir_src_layer_memory.get_primitive_desc(),
            enc_bidir_src_layer_memory.get_primitive_desc(), attr);
    encoder_net.push_back(reorder(enc_bidir_src_layer_reorder_pd,
            user_enc_bidir_src_layer_memory, enc_bidir_src_layer_memory));

    auto enc_bidir_wei_layer_memory
            = memory(enc_bidir_prim_desc.weights_layer_primitive_desc());
    auto enc_bidir_wei_layer_reorder_pd = reorder::primitive_desc(
            user_enc_bidir_wei_layer_memory.get_primitive_desc(),
            enc_bidir_wei_layer_memory.get_primitive_desc(), attr);
    weights_reorders.push_back(reorder(enc_bidir_wei_layer_reorder_pd,
            user_enc_bidir_wei_layer_memory, enc_bidir_wei_layer_memory));

    auto enc_bidir_wei_iter_memory
            = memory(enc_bidir_prim_desc.weights_iter_primitive_desc());
    auto enc_bidir_wei_iter_reorder_pd = reorder::primitive_desc(
            user_enc_bidir_wei_iter_memory.get_primitive_desc(),
            enc_bidir_wei_iter_memory.get_primitive_desc(), attr);
    weights_reorders.push_back(reorder(enc_bidir_wei_iter_reorder_pd,
            user_enc_bidir_wei_iter_memory, enc_bidir_wei_iter_memory));

    auto enc_bidir_dst_layer_memory
            = memory(enc_bidir_prim_desc.dst_layer_primitive_desc());

    encoder_net.push_back(
            rnn_forward(enc_bidir_prim_desc, enc_bidir_src_layer_memory,
                    null_memory_, enc_bidir_wei_layer_memory,
                    enc_bidir_wei_iter_memory, user_enc_bidir_bias_memory,
                    enc_bidir_dst_layer_memory, null_memory_, null_memory_));

    /* GNMT encoder: unidirectional layers */
    // The first unidirectional layer maps the 2 * feature_size output of the
    // bidirectional layer down to feature_size
    std::vector<float> user_enc_uni_first_wei_layer(
            1 * 1 * 2 * feature_size * lstm_n_gates * feature_size, 0.3f);
    std::vector<float> user_enc_uni_first_wei_iter(
            1 * 1 * feature_size * lstm_n_gates * feature_size, 0.2f);
    std::vector<float> user_enc_uni_first_bias(
            1 * 1 * lstm_n_gates * feature_size, 1.0f);

    memory::dims user_enc_uni_first_wei_layer_dims
            = { 1, 1, 2 * feature_size, lstm_n_gates, feature_size };
    memory::dims user_enc_uni_first_wei_iter_dims
            = { 1, 1, feature_size, lstm_n_gates, feature_size };
    memory::dims user_enc_uni_first_bias_dims
            = { 1, 1, lstm_n_gates, feature_size };
    memory::dims enc_uni_first_dst_layer_dims
            = { src_seq_length_max, batch, feature_size };

    auto user_enc_uni_first_wei_layer_md
            = memory::desc({ user_enc_uni_first_wei_layer_dims },
                    memory::data_type::f32, memory::format::ldigo);
    auto user_enc_uni_first_wei_iter_md
            = memory::desc({ user_enc_uni_first_wei_iter_dims },
                    memory::data_type::f32, memory::format::ldigo);
    auto user_enc_uni_first_bias_md
            = memory::desc({ user_enc_uni_first_bias_dims },
                    memory::data_type::f32, memory::format::ldgo);
    auto user_enc_uni_first_wei_layer_memory
            = memory({ user_enc_uni_first_wei_layer_md, cpu_engine },
                    user_enc_uni_first_wei_layer.data());
    auto user_enc_uni_first_wei_iter_memory
            = memory({ user_enc_uni_first_wei_iter_md, cpu_engine },
                    user_enc_uni_first_wei_iter.data());
    auto user_enc_uni_first_bias_memory
            = memory({ user_enc_uni_first_bias_md, cpu_engine },
                    user_enc_uni_first_bias.data());

    auto enc_uni_first_wei_layer_md
            = memory::desc({ user_enc_uni_first_wei_layer_dims },
                    memory::data_type::s8, memory::format::any);
    auto enc_uni_first_wei_iter_md
            = memory::desc({ user_enc_uni_first_wei_iter_dims },
                    memory::data_type::s8, memory::format::any);
    auto enc_uni_first_dst_layer_md
            = memory::desc({ enc_uni_first_dst_layer_dims },
                    memory::data_type::u8, memory::format::any);

    rnn_cell::desc enc_uni_first_cell(algorithm::vanilla_lstm);
    rnn_forward::desc enc_uni_first_layer_desc(prop_kind::forward_inference,
            enc_uni_first_cell, rnn_direction::unidirectional_left2right,
            enc_bidir_dst_layer_md, zero_md(), enc_uni_first_wei_layer_md,
            enc_uni_first_wei_iter_md, user_enc_uni_first_bias_md,
            enc_uni_first_dst_layer_md, zero_md());

    auto enc_uni_first_prim_desc = rnn_forward::primitive_desc(
            enc_uni_first_layer_desc, attr, cpu_engine);

    auto enc_uni_first_wei_layer_memory
            = memory(enc_uni_first_prim_desc.weights_layer_primitive_desc());
    auto enc_uni_first_wei_layer_reorder_pd = reorder::primitive_desc(
            user_enc_uni_first_wei_layer_memory.get_primitive_desc(),
            enc_uni_first_wei_layer_memory.get_primitive_desc(), attr);
    weights_reorders.push_back(reorder(enc_uni_first_wei_layer_reorder_pd,
            user_enc_uni_first_wei_layer_memory,
            enc_uni_first_wei_layer_memory));

    auto enc_uni_first_wei_iter_memory
            = memory(enc_uni_first_prim_desc.weights_iter_primitive_desc());
    auto enc_uni_first_wei_iter_reorder_pd = reorder::primitive_desc(
            user_enc_uni_first_wei_iter_memory.get_primitive_desc(),
            enc_uni_first_wei_iter_memory.get_primitive_desc(), attr);
    weights_reorders.push_back(reorder(enc_uni_first_wei_iter_reorder_pd,
            user_enc_uni_first_wei_iter_memory, enc_uni_first_wei_iter_memory));

    auto enc_uni_first_dst_layer_memory
            = memory(enc_uni_first_prim_desc.dst_layer_primitive_desc());

    encoder_net.push_back(rnn_forward(enc_uni_first_prim_desc,
            enc_bidir_dst_layer_memory, null_memory_,
            enc_uni_first_wei_layer_memory, enc_uni_first_wei_iter_memory,
            user_enc_uni_first_bias_memory, enc_uni_first_dst_layer_memory,
            null_memory_, null_memory_));

    /* Remaining unidirectional layers */
    std::vector<float> user_enc_uni_wei_layer((enc_unidir_n_layers - 1) * 1
                    * feature_size * lstm_n_gates * feature_size,
            0.3f);
    std::vector<float> user_enc_uni_wei_iter((enc_unidir_n_layers - 1) * 1
                    * feature_size * lstm_n_gates * feature_size,
            0.2f);
    std::vector<float> user_enc_uni_bias(
            (enc_unidir_n_layers - 1) * 1 * lstm_n_gates * feature_size, 1.0f);

    memory::dims user_enc_uni_wei_layer_dims = { (enc_unidir_n_layers - 1), 1,
        feature_size, lstm_n_gates, feature_size };
    memory::dims user_enc_uni_wei_iter_dims = { (enc_unidir_n_layers - 1), 1,
        feature_size, lstm_n_gates, feature_size };
    memory::dims user_enc_uni_bias_dims
            = { (enc_unidir_n_layers - 1), 1, lstm_n_gates, feature_size };
    memory::dims enc_dst_layer_dims
            = { src_seq_length_max, batch, feature_size };

    auto user_enc_uni_wei_layer_md
            = memory::desc({ user_enc_uni_wei_layer_dims },
                    memory::data_type::f32, memory::format::ldigo);
    auto user_enc_uni_wei_iter_md = memory::desc({ user_enc_uni_wei_iter_dims },
            memory::data_type::f32, memory::format::ldigo);
    auto user_enc_uni_bias_md = memory::desc({ user_enc_uni_bias_dims },
            memory::data_type::f32, memory::format::ldgo);

    auto user_enc_uni_wei_layer_memory
            = memory({ user_enc_uni_wei_layer_md, cpu_engine },
                    user_enc_uni_wei_layer.data());
    auto user_enc_uni_wei_iter_memory
            = memory({ user_enc_uni_wei_iter_md, cpu_engine },
                    user_enc_uni_wei_iter.data());
    auto user_enc_uni_bias_memory = memory(
            { user_enc_uni_bias_md, cpu_engine }, user_enc_uni_bias.data());

    auto enc_uni_wei_layer_md = memory::desc({ user_enc_uni_wei_layer_dims },
            memory::data_type::s8, memory::format::any);
    auto enc_uni_wei_iter_md = memory::desc({ user_enc_uni_wei_iter_dims },
            memory::data_type::s8, memory::format::any);
    auto enc_dst_layer_md = memory::desc({ enc_dst_layer_dims },
            memory::data_type::f32, memory::format::any);

    rnn_cell::desc enc_uni_cell(algorithm::vanilla_lstm);
    rnn_forward::desc enc_uni_layer_desc(prop_kind::forward_inference,
            enc_uni_cell, rnn_direction::unidirectional_left2right,
            enc_uni_first_dst_layer_md, zero_md(), enc_uni_wei_layer_md,
            enc_uni_wei_iter_md, user_enc_uni_bias_md, enc_dst_layer_md,
            zero_md());
    auto enc_uni_prim_desc
            = rnn_forward::primitive_desc(enc_uni_layer_desc, attr, cpu_engine);

    auto enc_uni_wei_layer_memory
            = memory(enc_uni_prim_desc.weights_layer_primitive_desc());
    auto enc_uni_wei_layer_reorder_pd = reorder::primitive_desc(
            user_enc_uni_wei_layer_memory.get_primitive_desc(),
            enc_uni_wei_layer_memory.get_primitive_desc(), attr);
    weights_reorders.push_back(reorder(enc_uni_wei_layer_reorder_pd,
            user_enc_uni_wei_layer_memory, enc_uni_wei_layer_memory));

    auto enc_uni_wei_iter_memory
            = memory(enc_uni_prim_desc.weights_iter_primitive_desc());
    auto enc_uni_wei_iter_reorder_pd = reorder::primitive_desc(
            user_enc_uni_wei_iter_memory.get_primitive_desc(),
            enc_uni_wei_iter_memory.get_primitive_desc(), attr);
    weights_reorders.push_back(reorder(enc_uni_wei_iter_reorder_pd,
            user_enc_uni_wei_iter_memory, enc_uni_wei_iter_memory));

    auto enc_dst_layer_memory
            = memory(enc_uni_prim_desc.dst_layer_primitive_desc());

    encoder_net.push_back(
            rnn_forward(enc_uni_prim_desc, enc_uni_first_dst_layer_memory,
                    null_memory_, enc_uni_wei_layer_memory,
                    enc_uni_wei_iter_memory, user_enc_uni_bias_memory,
                    enc_dst_layer_memory, null_memory_, null_memory_));

    /* Decoder with attention mechanism */
    std::vector<float> user_dec_wei_layer(
            dec_n_layers * 1 * feature_size * lstm_n_gates * feature_size,
            0.2f);
    std::vector<float> user_dec_wei_iter(dec_n_layers * 1
                    * (feature_size + feature_size) * lstm_n_gates
                    * feature_size,
            0.3f);
    std::vector<float> user_dec_bias(
            dec_n_layers * 1 * lstm_n_gates * feature_size, 1.0f);
    std::vector<int8_t> user_weights_attention_src_layer(
            feature_size * feature_size, 1);
    float weights_attention_scale = 127.;
    std::vector<float> user_weights_annotation(
            feature_size * feature_size, 1.0f);
    std::vector<float> user_weights_alignments(feature_size, 1.0f);
    // Buffer to store decoder output for all iterations
    std::vector<uint8_t> dec_dst(tgt_seq_length_max * batch * feature_size, 0);

    memory::dims user_dec_wei_layer_dims
            = { dec_n_layers, 1, feature_size, lstm_n_gates, feature_size };
    memory::dims user_dec_wei_iter_dims = { dec_n_layers, 1,
        feature_size + feature_size, lstm_n_gates, feature_size };
    memory::dims user_dec_bias_dims
            = { dec_n_layers, 1, lstm_n_gates, feature_size };
    memory::dims dec_src_layer_dims = { 1, batch, feature_size };
    memory::dims dec_dst_layer_dims = { 1, batch, feature_size };

    // We use the same memory for dec_src_iter and dec_dst_iter. However,
    // dec_src_iter includes the attention context vector while dec_dst_iter
    // does not. To resolve this, we create one memory that holds the context
    // vector as well as both the hidden and cell states, and for dst_iter we
    // use a view on this memory. Note that the cell state is padded by
    // feature_size values that we neither compute nor access.
    memory::dims dec_dst_iter_dims = { dec_n_layers, 1, lstm_n_states, batch,
        feature_size + feature_size };
    memory::dims dec_dst_iter_noctx_dims
            = { dec_n_layers, 1, lstm_n_states, batch, feature_size };

    auto user_dec_wei_layer_md = memory::desc({ user_dec_wei_layer_dims },
            memory::data_type::f32, memory::format::ldigo);
    auto user_dec_wei_iter_md = memory::desc({ user_dec_wei_iter_dims },
            memory::data_type::f32, memory::format::ldigo);
    auto user_dec_bias_md = memory::desc({ user_dec_bias_dims },
            memory::data_type::f32, memory::format::ldgo);
    auto dec_src_layer_md = memory::desc(
            { dec_src_layer_dims }, memory::data_type::u8, memory::format::tnc);
    auto dec_dst_layer_md = memory::desc(
            { dec_dst_layer_dims }, memory::data_type::u8, memory::format::tnc);
    auto dec_dst_iter_md = memory::desc({ dec_dst_iter_dims },
            memory::data_type::f32, memory::format::ldsnc);

    auto user_dec_wei_layer_memory = memory(
            { user_dec_wei_layer_md, cpu_engine }, user_dec_wei_layer.data());
    auto user_dec_wei_iter_memory = memory(
            { user_dec_wei_iter_md, cpu_engine }, user_dec_wei_iter.data());
    auto user_dec_bias_memory
            = memory({ user_dec_bias_md, cpu_engine }, user_dec_bias.data());
    auto dec_src_layer_memory = memory({ dec_src_layer_md, cpu_engine });
    auto dec_dst_layer_memory
            = memory({ dec_dst_layer_md, cpu_engine }, dec_dst.data());

    /* Create memory descriptors for RNN data w/o specified layout */
    auto dec_wei_layer_md = memory::desc({ user_dec_wei_layer_dims },
            memory::data_type::s8, memory::format::any);
    auto dec_wei_iter_md = memory::desc({ user_dec_wei_iter_dims },
            memory::data_type::s8, memory::format::any);

    /* As mentioned above, we create a view without the context out of the
     * memory with the context. */
    auto dec_dst_iter_memory = memory({ dec_dst_iter_md, cpu_engine });
    auto dec_dst_iter_noctx_md
            = view::primitive_desc(dec_dst_iter_memory.get_primitive_desc(),
                      dec_dst_iter_noctx_dims, { 0, 0, 0, 0, 0 })
                      .dst_primitive_desc()
                      .desc();
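
    /* The view keeps the strides of the full ldsnc buffer but exposes only
     * its first feature_size channels (zero offsets, dec_dst_iter_noctx_dims).
     * Using it as the dst_iter descriptor means the primitive writes the new
     * hidden and cell states into the first half of each channel row and
     * leaves the attention context in the second half untouched. */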

    rnn_cell::desc dec_cell(algorithm::vanilla_lstm);
    rnn_forward::desc dec_ctx_desc(prop_kind::forward_inference, dec_cell,
            rnn_direction::unidirectional_left2right, dec_src_layer_md,
            dec_dst_iter_md, dec_wei_layer_md, dec_wei_iter_md,
            user_dec_bias_md, dec_dst_layer_md, dec_dst_iter_noctx_md);
    auto dec_ctx_prim_desc
            = rnn_forward::primitive_desc(dec_ctx_desc, attr, cpu_engine);

    /* Create memory primitives for input data and use reorders to quantize
     * values to int8 */
    auto dec_wei_layer_memory
            = memory(dec_ctx_prim_desc.weights_layer_primitive_desc());
    auto dec_wei_layer_reorder_pd = reorder::primitive_desc(
            user_dec_wei_layer_memory.get_primitive_desc(),
            dec_wei_layer_memory.get_primitive_desc(), attr);
    weights_reorders.push_back(reorder(dec_wei_layer_reorder_pd,
            user_dec_wei_layer_memory, dec_wei_layer_memory));

    auto dec_wei_iter_memory
            = memory(dec_ctx_prim_desc.weights_iter_primitive_desc());
    auto dec_wei_iter_reorder_pd = reorder::primitive_desc(
            user_dec_wei_iter_memory.get_primitive_desc(),
            dec_wei_iter_memory.get_primitive_desc(), attr);
    weights_reorders.push_back(reorder(dec_wei_iter_reorder_pd,
            user_dec_wei_iter_memory, dec_wei_iter_memory));

    decoder_net.push_back(rnn_forward(dec_ctx_prim_desc, dec_src_layer_memory,
            dec_dst_iter_memory, dec_wei_layer_memory, dec_wei_iter_memory,
            user_dec_bias_memory, dec_dst_layer_memory, dec_dst_iter_memory,
            null_memory_));
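
    /* dec_dst_iter_memory is passed as both the src_iter and the dst_iter of
     * the primitive: each decoder step reads the previous hidden/cell states
     * (plus the context written by compute_attention) and stores the new
     * states back into the same buffer, so no per-iteration state copy is
     * needed. */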

    /* Allocating temporary buffers for attention mechanism */
    std::vector<float> weighted_annotations(
            src_seq_length_max * batch * feature_size, 1.0f);
    std::vector<int32_t> weights_attention_sum_rows(feature_size, 1);

    /*
       Execution
     */
    auto execute = [&]() {
        // reorder weights to the MKL-DNN internal representation
        stream(stream::kind::eager).submit(weights_reorders).wait();

        // run encoder (1 stream)
        stream(stream::kind::eager).submit(encoder_net).wait();

        // compute the weighted annotations once before the decoder
        compute_weighted_annotations(weighted_annotations.data(),
                src_seq_length_max, batch, feature_size,
                user_weights_annotation.data(),
                (float *)enc_dst_layer_memory.get_data_handle());
        // precompute the compensation for the s8u8s32 gemm in compute_attention
        compute_sum_of_rows(user_weights_attention_src_layer.data(),
                feature_size, feature_size, weights_attention_sum_rows.data());

        // We initialize src_layer to the embedding of </s>, which is
        // assumed to be 0 here
        memset(dec_src_layer_memory.get_data_handle(), 0,
                dec_src_layer_memory.get_primitive_desc().get_size());
        // From now on, src points to the output of the last iteration

        for (int i = 0; i < tgt_seq_length_max; i++) {
            uint8_t *src_att_layer_handle
                    = (uint8_t *)dec_src_layer_memory.get_data_handle();
            float *src_att_iter_handle
                    = (float *)dec_dst_iter_memory.get_data_handle();

            // Compute the attention context vector into the first layer's
            // src_iter
            compute_attention(src_att_iter_handle, src_seq_length_max, batch,
                    feature_size, user_weights_attention_src_layer.data(),
                    weights_attention_scale, weights_attention_sum_rows.data(),
                    src_att_layer_handle, data_scale, data_shift,
                    (uint8_t *)enc_bidir_dst_layer_memory.get_data_handle(),
                    weighted_annotations.data(),
                    user_weights_alignments.data());

            // copy the context vectors to all layers of src_iter
            copy_context(src_att_iter_handle, dec_n_layers, lstm_n_states,
                    batch, feature_size);

            // run the decoder iteration
            stream(stream::kind::eager).submit(decoder_net).wait();

            // Move the handles on the src/dst layers to the next iteration
            auto dst_layer_handle
                    = (uint8_t *)dec_dst_layer_memory.get_data_handle();
            dec_src_layer_memory.set_data_handle(dst_layer_handle);
            dec_dst_layer_memory.set_data_handle(
                    dst_layer_handle + batch * feature_size);
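
            // After this, dec_src_layer_memory points at the time step we
            // just produced inside dec_dst, and dec_dst_layer_memory points
            // one step (batch * feature_size u8 values) further, so iteration
            // i + 1 consumes the output of iteration i and appends its own.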
        }

    };

    execute();
}

int main(int argc, char **argv) {
    try {
        simple_net();
        std::cout << "ok\n";
    } catch (error &e) {
        std::cerr << "status: " << e.status << std::endl;
        std::cerr << "message: " << e.message << std::endl;
    }
    return 0;
}