Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / examples / simple_rnn_training.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <cstring>
18 #include <iostream>
19 #include <math.h>
20 #include <numeric>
21 #include <string>
22
23 #include "mkldnn.hpp"
24
25 using namespace mkldnn;
26
27 // User input is:
28 //     N0 sequences of length T0
// NOTE: rand() is never seeded in this example, so the "random" sizes below
// are in fact identical on every run; each batch size is drawn from [1, 31].
29 const int N0 = 1 + rand() % 31;
30 //     N1 sequences of length T1
31 const int N1 = 1 + rand() % 31;
32 // Assume T0 > T1
// (guaranteed by construction: T0 is in [32, 62] while T1 is in [1, 31])
33 const int T0 = 31 + 1 + rand() % 31;
34 const int T1 =      1 + rand() % 31;
35
36 // Memory required to hold it: N0 * T0 + N1 * T1
37 // However it is possible to have these coming
38 // as padded chunks in larger memory:
39 //      e.g. (N0 + N1) * T0
40 // We don't need to compact the data before processing,
41 // we can address the chunks via view primitive and
42 // process the data via two RNN primitives:
43 //     of time lengths T1 and T0 - T1.
44 // The leftmost primitive will process N0 + N1 subsequences of length T1
45 // The rightmost primitive will process remaining N0 subsequences
46 // of T0 - T1 length
47 const int leftmost_batch  = N0 + N1;
48 const int rightmost_batch = N0;
49
50 const int leftmost_seq_length  = T1;
51 const int rightmost_seq_length = T0 - T1;
52
53 // Number of channels (used for both the input and output feature size)
54 const int common_feature_size = 1024;
55
56 // RNN primitive characteristics
57 const int common_n_layers = 1; // a single stacked RNN layer
58 const int lstm_n_gates = 4; // an LSTM cell has 4 gates (i, f, c~, o)
59 const int lstm_n_states = 2; // an LSTM cell carries 2 states (h and c)
60
61 void simple_net() {
62     auto cpu_engine = engine(engine::cpu, 0);
63     auto null_memory_ = null_memory(cpu_engine);
64
65     bool is_training = true;
66     auto fwd_inf_train = is_training
67                          ? prop_kind::forward_training
68                          : prop_kind::forward_inference;
69
70     std::vector<primitive> fwd_net;
71     std::vector<primitive> bwd_net;
72
73     // Input tensor holds two batches with different sequence lengths.
74     // Shorter sequences are padded
75     memory::dims net_src_dims = {
76         /* time */ T0,                 // maximum sequence length
77         /* n    */ N0 + N1,            // total batch size
78         /* c    */ common_feature_size // common number of channels
79     };
80
81     /*
82      * Two RNN primitives for different sequence lengths,
83      * one unidirectional layer, LSTM-based
84      */
85
86     memory::dims leftmost_src_layer_dims = {
87         /* time */ leftmost_seq_length,
88         /* n    */ leftmost_batch,
89         /* c    */ common_feature_size
90     };
91     memory::dims rightmost_src_layer_dims = {
92         /* time */ rightmost_seq_length,
93         /* n    */ rightmost_batch,
94         /* c    */ common_feature_size
95     };
96     memory::dims common_weights_layer_dims = {
97         /* layers              */ common_n_layers,
98         /* directions          */ 1,
99         /* input feature size  */ common_feature_size,
100         /* gates number        */ lstm_n_gates,
101         /* output feature size */ common_feature_size
102     };
103     memory::dims common_weights_iter_dims = {
104         /* layers              */ common_n_layers,
105         /* directions          */ 1,
106         /* input feature size  */ common_feature_size,
107         /* gates number        */ lstm_n_gates,
108         /* output feature size */ common_feature_size
109     };
110     memory::dims common_bias_dims = {
111         /* layers              */ common_n_layers,
112         /* directions          */ 1,
113         /* gates number        */ lstm_n_gates,
114         /* output feature size */ common_feature_size
115     };
116     memory::dims leftmost_dst_layer_dims = {
117         /* time */ leftmost_seq_length,
118         /* n    */ leftmost_batch,
119         /* c    */ common_feature_size
120     };
121     memory::dims rightmost_dst_layer_dims = {
122         /* time */ rightmost_seq_length,
123         /* n    */ rightmost_batch,
124         /* c    */ common_feature_size
125     };
126
127     // leftmost primitive passes its states to the next RNN iteration
128     // so it needs dst_iter parameter.
129     //
130     // rightmost primitive will consume these as src_iter and will access
131     // the memory via a view because it will have different batch dimension.
132     // We have arranged our primitives so that
133     // leftmost_batch >= rightmost_batch, and so the rightmost data will fit
134     // into the memory allocated for the leftmost.
135     memory::dims leftmost_dst_iter_dims = {
136         /* layers     */ common_n_layers,
137         /* directions */ 1,
138         /* states     */ lstm_n_states,
139         /* n          */ leftmost_batch,
140         /* c          */ common_feature_size
141     };
142     memory::dims rightmost_src_iter_dims = {
143         /* layers     */ common_n_layers,
144         /* directions */ 1,
145         /* states     */ lstm_n_states,
146         /* n          */ rightmost_batch,
147         /* c          */ common_feature_size
148     };
149
150     // multiplication of tensor dimensions
151     auto tz_volume = [=](memory::dims tz_dims) {
152         return std::accumulate(
153             tz_dims.begin(), tz_dims.end(),
154             (size_t)1, std::multiplies<size_t>());
155     };
156
157     // Create an auxiliary f32 memory descriptor
158     // based on user-supplied dimensions and layout.
159     auto formatted_md = [=](memory::dims dimensions, memory::format layout) {
160         return memory::desc({ dimensions }, memory::data_type::f32, layout);
161     };
162     // Create an auxiliary generic f32 memory descriptor
163     // based on supplied dimensions, with format::any.
164     auto generic_md = [=](memory::dims dimensions) {
165         return formatted_md( dimensions, memory::format::any);
166     };
167
168     //
169     // I/O memory, coming from user
170     //
171
172     // Net input
173     std::vector<float> net_src(
174             tz_volume(net_src_dims),
175             1.0f);
176     // NOTE: in this example we study input sequences with variable batch
177     // dimension, which get processed by two separate RNN primitives, thus
178     // the destination memory for the two will have different shapes: batch
179     // is the second dimension currently: see format::tnc.
180     // We are not copying the output to some common user provided memory as we
181     // suggest that the user should rather keep the two output memories separate
182     // throughout the whole topology and only reorder to something else as
183     // needed.
184     // So there's no common net_dst, but there are two destinations instead:
185     //    leftmost_dst_layer_memory
186     //    rightmost_dst_layer_memory
187
188     // Memory primitive for the user allocated memory
189     // Suppose user data is in tnc format.
190     auto net_src_memory
191         = mkldnn::memory({ formatted_md(net_src_dims, memory::format::tnc),
192                            cpu_engine }, net_src.data());
193     // src_layer memory of the leftmost and rightmost RNN primitives
194     // are accessed through the respective views in larger memory.
195     // View primitives compute the strides to accommodate the padding.
196     auto user_leftmost_src_layer_md
197         = mkldnn::view::primitive_desc(
198             net_src_memory.get_primitive_desc(),
199             leftmost_src_layer_dims,
200             { 0, 0, 0 } /* t, n, c offsets */
201         ).dst_primitive_desc().desc();
202     auto user_rightmost_src_layer_md
203         = mkldnn::view::primitive_desc(
204             net_src_memory.get_primitive_desc(),
205             rightmost_src_layer_dims,
206             { leftmost_seq_length, 0, 0 } /* t, n, c offsets */
207         ).dst_primitive_desc().desc();
208     auto leftmost_src_layer_memory = net_src_memory;
209     auto rightmost_src_layer_memory = net_src_memory;
210
211     // Other user-provided memory arrays, descriptors and primitives with the
212     // data layouts chosen by the user. We'll have to reorder if the RNN
213     // primitive prefers them in a different format.
    // NOTE(review): user_common_weights_iter_memory below is constructed from
    // user_common_weights_layer.data() even though user_common_weights_iter is
    // allocated just above and never used — this looks like a copy-paste bug;
    // it should presumably wrap user_common_weights_iter.data(). Confirm and fix.
214     std::vector<float> user_common_weights_layer(
215             tz_volume(common_weights_layer_dims),
216             1.0f);
217     auto user_common_weights_layer_memory
218         = mkldnn::memory({ formatted_md(common_weights_layer_dims,
219                            memory::format::ldigo), cpu_engine },
220                          user_common_weights_layer.data());
221
222     std::vector<float> user_common_weights_iter(
223             tz_volume(common_weights_iter_dims),
224             1.0f);
225     auto user_common_weights_iter_memory
226         = mkldnn::memory({ formatted_md(common_weights_iter_dims,
227                            memory::format::ldigo), cpu_engine },
228                          user_common_weights_layer.data());
229
230     std::vector<float> user_common_bias(
231             tz_volume(common_bias_dims),
232             1.0f);
233     auto user_common_bias_memory
234         = mkldnn::memory({ formatted_md(common_bias_dims, memory::format::ldgo),
235                            cpu_engine }, user_common_bias.data());
236
237     std::vector<float> user_leftmost_dst_layer(
238             tz_volume(leftmost_dst_layer_dims),
239             1.0f);
240     auto user_leftmost_dst_layer_memory
241         = mkldnn::memory({
242                     formatted_md(leftmost_dst_layer_dims, memory::format::tnc),
243                     cpu_engine }, user_leftmost_dst_layer.data());
244
245     std::vector<float> user_rightmost_dst_layer(
246             tz_volume(rightmost_dst_layer_dims),
247             1.0f);
248     auto user_rightmost_dst_layer_memory
249         = mkldnn::memory({
250                     formatted_md(rightmost_dst_layer_dims, memory::format::tnc),
251                     cpu_engine }, user_rightmost_dst_layer.data());
252
253     // Describe RNN cell
254     rnn_cell::desc uni_cell(algorithm::vanilla_lstm);
255
256     // Describe layer, forward pass, leftmost primitive.
257     // There are no primitives to the left from here,
258     // so src_iter_desc needs to be zero_md()
259     rnn_forward::desc leftmost_layer_desc(
260         /* aprop_kind         */ fwd_inf_train,
261         /* cell               */ uni_cell,
262         /* direction          */ rnn_direction::unidirectional_left2right,
263         /* src_layer_desc     */ user_leftmost_src_layer_md,
264         /* src_iter_desc      */ zero_md(),
265         /* weights_layer_desc */ generic_md(common_weights_layer_dims),
266         /* weights_iter_desc  */ generic_md(common_weights_iter_dims),
267         /* bias_desc          */ generic_md(common_bias_dims),
268         /* dst_layer_desc     */ formatted_md(leftmost_dst_layer_dims,
269                                                 memory::format::tnc),
270         /* dst_iter_desc      */ generic_md(leftmost_dst_iter_dims)
271     );
272     // Describe primitive
273     auto leftmost_prim_desc
274         = mkldnn::rnn_forward::primitive_desc(leftmost_layer_desc, cpu_engine);
275
276     //
277     // Need to connect leftmost and rightmost via "iter" parameters.
278     // We allocate memory here based on the shapes provided by RNN primitive.
279     //
280
281     auto leftmost_dst_iter_memory
282         = mkldnn::memory(leftmost_prim_desc.dst_iter_primitive_desc());
283
284     // rightmost src_iter will view into dst_iter of leftmost
285     auto rightmost_src_iter_md
286         = mkldnn::view::primitive_desc(
287             leftmost_dst_iter_memory.get_primitive_desc(),
288             rightmost_src_iter_dims,
289             { 0, 0, 0, 0, 0 } /* l, d, s, n, c offsets */
290         ).dst_primitive_desc().desc();
291
292     auto rightmost_src_iter_memory = leftmost_dst_iter_memory;
293
294     // Now rightmost primitive
295     // There are no primitives to the right from here,
296     // so dst_iter_desc is explicit zero_md()
297     rnn_forward::desc rightmost_layer_desc(
298         /* aprop_kind         */ fwd_inf_train,
299         /* cell               */ uni_cell,
300         /* direction          */ rnn_direction::unidirectional_left2right,
301         /* src_layer_desc     */ user_rightmost_src_layer_md,
302         /* src_iter_desc      */ rightmost_src_iter_md,
303         /* weights_layer_desc */ generic_md(common_weights_layer_dims),
304         /* weights_iter_desc  */ generic_md(common_weights_iter_dims),
305         /* bias_desc          */ generic_md(common_bias_dims),
306         /* dst_layer_desc     */ formatted_md(rightmost_dst_layer_dims,
307                                                 memory::format::tnc),
308         /* dst_iter_desc      */ zero_md()
309     );
310     auto rightmost_prim_desc
311         = mkldnn::rnn_forward::primitive_desc(rightmost_layer_desc, cpu_engine);
312
313     //
314     // Weights and biases, layer memory
315     // Same layout should work across the layer, no reorders
316     // needed between leftmost and rightmost, only reordering
317     // user memory to the RNN-friendly shapes.
318     //
319
320     auto common_weights_layer_memory = user_common_weights_layer_memory;
321     primitive common_weights_layer_reorder;
322     auto reorder_common_weights_layer = false;
323     if (memory::primitive_desc(
324             leftmost_prim_desc.weights_layer_primitive_desc())
325         != memory::primitive_desc(
326             common_weights_layer_memory.get_primitive_desc())
327     ) {
328         common_weights_layer_memory
329             = mkldnn::memory(leftmost_prim_desc.weights_layer_primitive_desc());
330         common_weights_layer_reorder
331             = reorder(user_common_weights_layer_memory,
332                         common_weights_layer_memory);
333         reorder_common_weights_layer = true;
334     }
335
336     auto common_weights_iter_memory = user_common_weights_iter_memory;
337     primitive common_weights_iter_reorder;
338     auto reorder_common_weights_iter = false;
339     if (memory::primitive_desc(
340             leftmost_prim_desc.weights_iter_primitive_desc())
341         != memory::primitive_desc(
342             common_weights_iter_memory.get_primitive_desc())
343     ) {
344         common_weights_iter_memory
345             = mkldnn::memory(leftmost_prim_desc.weights_iter_primitive_desc());
346         common_weights_iter_reorder
347             = reorder(user_common_weights_iter_memory,
348                         common_weights_iter_memory);
349         reorder_common_weights_iter = true;
350     }
351
352
353     auto common_bias_memory = user_common_bias_memory;
354     primitive common_bias_reorder;
355     auto reorder_common_bias = false;
356     if (memory::primitive_desc(
357             leftmost_prim_desc.bias_primitive_desc())
358         != memory::primitive_desc(
359             common_bias_memory.get_primitive_desc())
360     ) {
361         common_bias_memory
362             = mkldnn::memory(leftmost_prim_desc.bias_primitive_desc());
363         common_bias_reorder
364             = reorder(user_common_bias_memory,
365                         common_bias_memory);
366         reorder_common_bias = true;
367     }
368
369     //
370     // Destination layer memory
371     //
372
373     auto leftmost_dst_layer_memory = user_leftmost_dst_layer_memory;
374     primitive leftmost_dst_layer_reorder;
375     auto reorder_leftmost_dst_layer = false;
376     if (memory::primitive_desc(
377             leftmost_prim_desc.dst_layer_primitive_desc())
378         != memory::primitive_desc(
379             leftmost_dst_layer_memory.get_primitive_desc())
380     ) {
381         leftmost_dst_layer_memory
382             = mkldnn::memory(leftmost_prim_desc.dst_layer_primitive_desc());
383         leftmost_dst_layer_reorder
384             = reorder(user_leftmost_dst_layer_memory,
385                         leftmost_dst_layer_memory);
386         reorder_leftmost_dst_layer = true;
387     }
388
389     auto rightmost_dst_layer_memory = user_rightmost_dst_layer_memory;
390     primitive rightmost_dst_layer_reorder;
391     auto reorder_rightmost_dst_layer = false;
392     if (memory::primitive_desc(
393             rightmost_prim_desc.dst_layer_primitive_desc())
394         != memory::primitive_desc(
395             rightmost_dst_layer_memory.get_primitive_desc())
396     ) {
397         rightmost_dst_layer_memory
398             = mkldnn::memory(rightmost_prim_desc.dst_layer_primitive_desc());
399         rightmost_dst_layer_reorder
400             = reorder(user_rightmost_dst_layer_memory,
401                         rightmost_dst_layer_memory);
402         reorder_rightmost_dst_layer = true;
403     }
404
405     // We also create workspace memory based on the information from
406     // the workspace_primitive_desc(). This is needed for internal
407     // communication between forward and backward primitives during
408     // training.
409     // Inference mode doesn't need it, so initialize with null_memory_
410     auto create_ws = [=](mkldnn::rnn_forward::primitive_desc &pd) {
411         auto workspace_memory = null_memory_;
412         if (is_training)
413         {
414             workspace_memory = mkldnn::memory(pd.workspace_primitive_desc());
415         }
416         return workspace_memory;
417     };
418     auto leftmost_workspace_memory = create_ws(leftmost_prim_desc);
419     auto rightmost_workspace_memory = create_ws(rightmost_prim_desc);
420
421     // Construct the RNN primitive objects
422     rnn_forward leftmost_layer = rnn_forward(
423         /* aprimitive_desc */ leftmost_prim_desc,
424         /* src_layer       */ leftmost_src_layer_memory,
425         /* src_iter        */ null_memory_,
426         /* weights_layer   */ common_weights_layer_memory,
427         /* weights_iter    */ common_weights_iter_memory,
428         /* bias            */ common_bias_memory,
429         /* dst_layer       */ leftmost_dst_layer_memory,
430         /* dst_iter        */ leftmost_dst_iter_memory,
431         /* workspace       */ leftmost_workspace_memory
432     );
433
434     rnn_forward rightmost_layer = rnn_forward(
435         /* aprimitive_desc */ rightmost_prim_desc,
436         /* src_layer       */ rightmost_src_layer_memory,
437         /* src_iter        */ rightmost_src_iter_memory,
438         /* weights_layer   */ common_weights_layer_memory,
439         /* weights_iter    */ common_weights_iter_memory,
440         /* bias            */ common_bias_memory,
441         /* dst_layer       */ rightmost_dst_layer_memory,
442         /* dst_iter        */ null_memory_,
443         /* workspace       */ rightmost_workspace_memory
444     );
445
446     // Enqueue primitives for forward execution
447     if (reorder_common_weights_layer)
448         fwd_net.push_back(common_weights_layer_reorder);
449     if (reorder_common_weights_iter)
450         fwd_net.push_back(common_weights_iter_reorder);
451     if (reorder_common_bias)
452         fwd_net.push_back(common_bias_reorder);
453     if (reorder_leftmost_dst_layer)
454         fwd_net.push_back(leftmost_dst_layer_reorder);
455
456     fwd_net.push_back(leftmost_layer);
457
458     if (reorder_rightmost_dst_layer)
459         fwd_net.push_back(rightmost_dst_layer_reorder);
460     fwd_net.push_back(rightmost_layer);
461
462     // Submit forward for execution
463     stream(stream::kind::eager).submit(fwd_net).wait();
464
465     // No backward pass for inference
466     if (!is_training) return;
467
468     //
469     // Backward primitives will reuse memory from forward
470     // and allocate/describe specifics here. Only relevant for training.
471     //
472
473     // User-provided memory for backward by data output
474     std::vector<float> net_diff_src(
475             tz_volume(net_src_dims),
476             1.0f);
477     auto net_diff_src_memory
478         = mkldnn::memory({ formatted_md(net_src_dims, memory::format::tnc),
479                            cpu_engine }, net_diff_src.data());
480
481     // diff_src follows the same layout we have for net_src
482     auto user_leftmost_diff_src_layer_md
483         = mkldnn::view::primitive_desc(
484             net_diff_src_memory.get_primitive_desc(),
485             leftmost_src_layer_dims,
486             { 0, 0, 0 } /* t, n, c offsets */
487         ).dst_primitive_desc().desc();
488     auto user_rightmost_diff_src_layer_md
489         = mkldnn::view::primitive_desc(
490             net_diff_src_memory.get_primitive_desc(),
491             rightmost_src_layer_dims,
492             { leftmost_seq_length, 0, 0 } /* t, n, c offsets */
493         ).dst_primitive_desc().desc();
494
495     auto leftmost_diff_src_layer_memory = net_diff_src_memory;
496     auto rightmost_diff_src_layer_memory = net_diff_src_memory;
497
498     // User-provided memory for backpropagation by weights
499     std::vector<float> user_common_diff_weights_layer(
500             tz_volume(common_weights_layer_dims),
501             1.0f);
502     auto user_common_diff_weights_layer_memory
503         = mkldnn::memory({ formatted_md(common_weights_layer_dims,
504                            memory::format::ldigo), cpu_engine },
505                          user_common_diff_weights_layer.data());
506
507     std::vector<float> user_common_diff_bias(
508             tz_volume(common_bias_dims),
509             1.0f);
510     auto user_common_diff_bias_memory
511         = mkldnn::memory({ formatted_md(common_bias_dims,
512                            memory::format::ldgo), cpu_engine },
513                          user_common_diff_bias.data());
514
515     // User-provided input to the backward primitive.
516     // To be updated by the user after forward pass using some cost function.
517     memory::dims net_diff_dst_dims = {
518         /* time */ T0,
519         /* n    */ N0 + N1,
520         /* c    */ common_feature_size
521     };
522     // Suppose user data is in tnc format.
523     std::vector<float> net_diff_dst(
524         tz_volume(net_diff_dst_dims),
525         1.0f);
526     auto net_diff_dst_memory
527         = mkldnn::memory({ formatted_md(net_diff_dst_dims, memory::format::tnc),
528                            cpu_engine }, net_diff_dst.data());
529     // diff_dst_layer memory of the leftmost and rightmost RNN primitives
530     // are accessed through the respective views in larger memory.
531     // View primitives compute the strides to accommodate the padding.
532     auto user_leftmost_diff_dst_layer_md
533         = mkldnn::view::primitive_desc(
534             net_diff_dst_memory.get_primitive_desc(),
535             leftmost_dst_layer_dims,
536             { 0, 0, 0 } /* t, n, c offsets */
537         ).dst_primitive_desc().desc();
538     auto user_rightmost_diff_dst_layer_md
539         = mkldnn::view::primitive_desc(
540             net_diff_dst_memory.get_primitive_desc(),
541             rightmost_dst_layer_dims,
542             { leftmost_seq_length, 0, 0 } /* t, n, c offsets */
543         ).dst_primitive_desc().desc();
544     auto leftmost_diff_dst_layer_memory = net_diff_dst_memory;
545     auto rightmost_diff_dst_layer_memory = net_diff_dst_memory;
546
547     // Backward leftmost primitive descriptor
548     rnn_backward::desc leftmost_layer_bwd_desc(
549         /* aprop_kind              */ prop_kind::backward,
550         /* cell                    */ uni_cell,
551         /* direction               */ rnn_direction::unidirectional_left2right,
552         /* src_layer_desc          */ user_leftmost_src_layer_md,
553         /* src_iter_desc           */ zero_md(),
554         /* weights_layer_desc      */ generic_md(common_weights_layer_dims),
555         /* weights_iter_desc       */ generic_md(common_weights_iter_dims),
556         /* bias_desc               */ generic_md(common_bias_dims),
557         /* dst_layer_desc          */ formatted_md(leftmost_dst_layer_dims,
558                                                     memory::format::tnc),
559         /* dst_iter_desc           */ generic_md(leftmost_dst_iter_dims),
560         /* diff_src_layer_desc     */ user_leftmost_diff_src_layer_md,
561         /* diff_src_iter_desc      */ zero_md(),
562         /* diff_weights_layer_desc */ generic_md(common_weights_layer_dims),
563         /* diff_weights_iter_desc  */ generic_md(common_weights_iter_dims),
564         /* diff_bias_desc          */ generic_md(common_bias_dims),
565         /* diff_dst_layer_desc     */ user_leftmost_diff_dst_layer_md,
566         /* diff_dst_iter_desc      */ generic_md(leftmost_dst_iter_dims)
567     );
568     auto leftmost_bwd_prim_desc
569         = mkldnn::rnn_backward::primitive_desc(
570             leftmost_layer_bwd_desc, cpu_engine, leftmost_prim_desc);
571
572     // As the batch dimensions are different between leftmost and rightmost
573     // we need to do the views. rightmost needs less memory, so it will view
574     // the memory of leftmost.
575     auto leftmost_diff_dst_iter_memory
576         = mkldnn::memory(leftmost_bwd_prim_desc.diff_dst_iter_primitive_desc());
577
578     auto rightmost_diff_src_iter_md
579         = mkldnn::view::primitive_desc(
580             leftmost_diff_dst_iter_memory.get_primitive_desc(),
581             rightmost_src_iter_dims,
582             { 0, 0, 0, 0, 0 } /* l, d, s, n, c offsets */
583         ).dst_primitive_desc().desc();
584
585     auto rightmost_diff_src_iter_memory = leftmost_diff_dst_iter_memory;
586
587     // Backward rightmost primitive descriptor
588     rnn_backward::desc rightmost_layer_bwd_desc(
589         /* aprop_kind              */ prop_kind::backward,
590         /* cell                    */ uni_cell,
591         /* direction               */ rnn_direction::unidirectional_left2right,
592         /* src_layer_desc          */ user_rightmost_src_layer_md,
593         /* src_iter_desc           */ generic_md(rightmost_src_iter_dims),
594         /* weights_layer_desc      */ generic_md(common_weights_layer_dims),
595         /* weights_iter_desc       */ generic_md(common_weights_iter_dims),
596         /* bias_desc               */ generic_md(common_bias_dims),
597         /* dst_layer_desc          */ formatted_md(rightmost_dst_layer_dims,
598                                                     memory::format::tnc),
599         /* dst_iter_desc           */ zero_md(),
600         /* diff_src_layer_desc     */ user_rightmost_diff_src_layer_md,
601         /* diff_src_iter_desc      */ rightmost_diff_src_iter_md,
602         /* diff_weights_layer_desc */ generic_md(common_weights_layer_dims),
603         /* diff_weights_iter_desc  */ generic_md(common_weights_iter_dims),
604         /* diff_bias_desc          */ generic_md(common_bias_dims),
605         /* diff_dst_layer_desc     */ user_rightmost_diff_dst_layer_md,
606         /* diff_dst_iter_desc      */ zero_md()
607     );
608     auto rightmost_bwd_prim_desc
609         = mkldnn::rnn_backward::primitive_desc(
610             rightmost_layer_bwd_desc, cpu_engine, rightmost_prim_desc);
611
612     //
613     // Memory primitives for backward pass
614     //
615
616     // src layer uses the same memory as forward
617     auto leftmost_src_layer_bwd_memory = leftmost_src_layer_memory;
618     auto rightmost_src_layer_bwd_memory = rightmost_src_layer_memory;
619
620     // Memory for weights and biases for backward pass
621     // Try to use the same memory between forward and backward, but
622     // sometimes reorders are needed.
623     auto common_weights_layer_bwd_memory = common_weights_layer_memory;
624     primitive common_weights_layer_bwd_reorder;
625     auto reorder_common_weights_layer_bwd = false;
626     if (memory::primitive_desc(
627             leftmost_bwd_prim_desc.weights_layer_primitive_desc())
628         != memory::primitive_desc(
629             leftmost_prim_desc.weights_layer_primitive_desc())
630     ) {
631         common_weights_layer_bwd_memory
632             = memory(leftmost_bwd_prim_desc.weights_layer_primitive_desc());
633         common_weights_layer_bwd_reorder
634             = reorder(common_weights_layer_memory,
635                         common_weights_layer_bwd_memory);
636         reorder_common_weights_layer_bwd = true;
637     }
638
639     auto common_weights_iter_bwd_memory = common_weights_iter_memory;
640     primitive common_weights_iter_bwd_reorder;
641     auto reorder_common_weights_iter_bwd = false;
642     if (memory::primitive_desc(
643             leftmost_bwd_prim_desc.weights_iter_primitive_desc())
644         != memory::primitive_desc(
645             leftmost_prim_desc.weights_iter_primitive_desc())
646     ) {
647         common_weights_iter_bwd_memory
648             = memory(leftmost_bwd_prim_desc.weights_iter_primitive_desc());
649         common_weights_iter_bwd_reorder
650             = reorder(common_weights_iter_memory,
651                         common_weights_iter_bwd_memory);
652         reorder_common_weights_iter_bwd = true;
653     }
654
655     auto common_bias_bwd_memory = common_bias_memory;
656     primitive common_bias_bwd_reorder;
657     auto reorder_common_bias_bwd = false;
658     if (memory::primitive_desc(
659             leftmost_bwd_prim_desc.bias_primitive_desc())
660         != memory::primitive_desc(
661             common_bias_memory.get_primitive_desc())
662     ) {
663         common_bias_bwd_memory
664             = mkldnn::memory(leftmost_bwd_prim_desc.bias_primitive_desc());
665         common_bias_bwd_reorder
666             = reorder(common_bias_memory,
667                         common_bias_bwd_memory);
668         reorder_common_bias_bwd = true;
669     }
670
671     // diff_weights and biases
672     auto common_diff_weights_layer_memory
673         = user_common_diff_weights_layer_memory;
674     primitive common_diff_weights_layer_reorder;
675     auto reorder_common_diff_weights_layer = false;
676     if (memory::primitive_desc(
677             leftmost_bwd_prim_desc.diff_weights_layer_primitive_desc())
678         != memory::primitive_desc(
679             common_diff_weights_layer_memory.get_primitive_desc())
680     ) {
681         common_diff_weights_layer_memory
682             = mkldnn::memory(
683                 leftmost_bwd_prim_desc.diff_weights_layer_primitive_desc());
684         common_diff_weights_layer_reorder
685             = reorder(user_common_diff_weights_layer_memory,
686                         common_diff_weights_layer_memory);
687         reorder_common_diff_weights_layer = true;
688     }
689
    // Conditional reorder for the bias gradient, mirroring the
    // diff_weights handling above.
    auto common_diff_bias_memory = user_common_diff_bias_memory;
    primitive common_diff_bias_reorder;
    auto reorder_common_diff_bias = false;
    if (memory::primitive_desc(
            leftmost_bwd_prim_desc.diff_bias_primitive_desc())
        != memory::primitive_desc(
            common_diff_bias_memory.get_primitive_desc())
    ) {
        common_diff_bias_memory
            = mkldnn::memory(leftmost_bwd_prim_desc.diff_bias_primitive_desc());
        common_diff_bias_reorder
            = reorder(user_common_diff_bias_memory,
                        common_diff_bias_memory);
        reorder_common_diff_bias = true;
    }
705
    // dst_layer memory for backward pass
    // The forward pass produced leftmost_dst_layer_memory; backward may
    // want it in a different layout, so set up the usual conditional
    // reorder from the forward result into a backward-preferred buffer.
    auto leftmost_dst_layer_bwd_memory = leftmost_dst_layer_memory;
    primitive leftmost_dst_layer_bwd_reorder;
    auto reorder_leftmost_dst_layer_bwd = false;
    if (memory::primitive_desc(
            leftmost_bwd_prim_desc.dst_layer_primitive_desc())
        != memory::primitive_desc(
            leftmost_dst_layer_bwd_memory.get_primitive_desc())
    ) {
        leftmost_dst_layer_bwd_memory
            = mkldnn::memory(leftmost_bwd_prim_desc.dst_layer_primitive_desc());
        leftmost_dst_layer_bwd_reorder
            = reorder(leftmost_dst_layer_memory,
                        leftmost_dst_layer_bwd_memory);
        reorder_leftmost_dst_layer_bwd = true;
    }
722
    // Same dst_layer reorder setup for the rightmost (time steps T1..T0)
    // backward primitive, driven by its own primitive descriptor.
    auto rightmost_dst_layer_bwd_memory = rightmost_dst_layer_memory;
    primitive rightmost_dst_layer_bwd_reorder;
    auto reorder_rightmost_dst_layer_bwd = false;
    if (memory::primitive_desc(
            rightmost_bwd_prim_desc.dst_layer_primitive_desc())
        != memory::primitive_desc(
            rightmost_dst_layer_bwd_memory.get_primitive_desc())
    ) {
        rightmost_dst_layer_bwd_memory
            = mkldnn::memory(
                rightmost_bwd_prim_desc.dst_layer_primitive_desc());
        rightmost_dst_layer_bwd_reorder
            = reorder(rightmost_dst_layer_memory,
                        rightmost_dst_layer_bwd_memory);
        reorder_rightmost_dst_layer_bwd = true;
    }
739
    // Similar to forward, the backward primitives are connected
    // via "iter" parameters.
    // diff_weights_iter has no user-side counterpart in this example, so
    // it is allocated directly in the primitive-preferred layout and
    // shared by both backward primitives.
    auto common_diff_weights_iter_memory
        = mkldnn::memory(
            leftmost_bwd_prim_desc.diff_weights_iter_primitive_desc());

    // Conditional reorder of the leftmost forward dst_iter (the hidden
    // state handed to the rightmost chunk) into backward layout.
    auto leftmost_dst_iter_bwd_memory = leftmost_dst_iter_memory;
    primitive leftmost_dst_iter_bwd_reorder;
    auto reorder_leftmost_dst_iter_bwd = false;
    if (memory::primitive_desc(
            leftmost_bwd_prim_desc.dst_iter_primitive_desc())
        != memory::primitive_desc(
            leftmost_dst_iter_bwd_memory.get_primitive_desc())
    ) {
        leftmost_dst_iter_bwd_memory
            = mkldnn::memory(leftmost_bwd_prim_desc.dst_iter_primitive_desc());
        leftmost_dst_iter_bwd_reorder
            = reorder(leftmost_dst_iter_memory,
                        leftmost_dst_iter_bwd_memory);
        reorder_leftmost_dst_iter_bwd = true;
    }
761
    // Construct the RNN primitive objects for backward
    // Arguments are positional; the inline /* name */ comments label each
    // slot. The leftmost chunk has no incoming hidden state, so src_iter
    // and diff_src_iter are null. Both chunks share the weights, bias and
    // their gradients (common_* memories).
    rnn_backward leftmost_layer_bwd = rnn_backward(
        /* aprimitive_desc    */ leftmost_bwd_prim_desc,
        /* src_layer          */ leftmost_src_layer_bwd_memory,
        /* src_iter           */ null_memory_,
        /* weights_layer      */ common_weights_layer_bwd_memory,
        /* weights_iter       */ common_weights_iter_bwd_memory,
        /* bias               */ common_bias_bwd_memory,
        /* dst_layer          */ leftmost_dst_layer_bwd_memory,
        /* dst_iter           */ leftmost_dst_iter_bwd_memory,
        /* diff_src_layer     */ leftmost_diff_src_layer_memory,
        /* diff_src_iter      */ null_memory_,
        /* diff_weights_layer */ common_diff_weights_layer_memory,
        /* diff_weights_iter  */ common_diff_weights_iter_memory,
        /* diff_bias          */ common_diff_bias_memory,
        /* diff_dst_layer     */ leftmost_diff_dst_layer_memory,
        /* diff_dst_iter      */ leftmost_diff_dst_iter_memory,
        /* workspace          */ leftmost_workspace_memory
    );

    // The rightmost chunk is the last in time, so it has no dst_iter or
    // diff_dst_iter; its src_iter is the hidden state produced by the
    // leftmost chunk.
    rnn_backward rightmost_layer_bwd = rnn_backward(
        /* aprimitive_desc    */ rightmost_bwd_prim_desc,
        /* src_layer          */ rightmost_src_layer_bwd_memory,
        /* src_iter           */ rightmost_src_iter_memory,
        /* weights_layer      */ common_weights_layer_bwd_memory,
        /* weights_iter       */ common_weights_iter_bwd_memory,
        /* bias               */ common_bias_bwd_memory,
        /* dst_layer          */ rightmost_dst_layer_bwd_memory,
        /* dst_iter           */ null_memory_,
        /* diff_src_layer     */ rightmost_diff_src_layer_memory,
        /* diff_src_iter      */ rightmost_diff_src_iter_memory,
        /* diff_weights_layer */ common_diff_weights_layer_memory,
        /* diff_weights_iter  */ common_diff_weights_iter_memory,
        /* diff_bias          */ common_diff_bias_memory,
        /* diff_dst_layer     */ rightmost_diff_dst_layer_memory,
        /* diff_dst_iter      */ null_memory_,
        /* workspace          */ rightmost_workspace_memory
    );
800
    // Enqueue primitives for backward execution
    // Input-side reorders (weights, bias, and the diff buffers —
    // NOTE(review): the diff reorders presumably pre-convert the user
    // gradient buffers into the primitives' preferred layout; confirm
    // against the forward-pass setup above this chunk) run first.
    if (reorder_common_weights_layer_bwd)
        bwd_net.push_back(common_weights_layer_bwd_reorder);
    if (reorder_common_weights_iter_bwd)
        bwd_net.push_back(common_weights_iter_bwd_reorder);
    if (reorder_common_bias_bwd)
        bwd_net.push_back(common_bias_bwd_reorder);
    if (reorder_common_diff_weights_layer)
        bwd_net.push_back(common_diff_weights_layer_reorder);
    if (reorder_common_diff_bias)
        bwd_net.push_back(common_diff_bias_reorder);

    if (reorder_rightmost_dst_layer_bwd)
        bwd_net.push_back(rightmost_dst_layer_bwd_reorder);

    // Backward runs in reverse time order: the rightmost (later) chunk
    // first, then the leftmost chunk.
    bwd_net.push_back(rightmost_layer_bwd);

    if (reorder_leftmost_dst_layer_bwd)
        bwd_net.push_back(leftmost_dst_layer_bwd_reorder);
    if (reorder_leftmost_dst_iter_bwd)
        bwd_net.push_back(leftmost_dst_iter_bwd_reorder);
    bwd_net.push_back(leftmost_layer_bwd);

    // Submit backward for execution
    // Eager stream: submit() runs the whole pipeline; wait() blocks until
    // every enqueued primitive has finished.
    stream(stream::kind::eager).submit(bwd_net).wait();
    //
    // User updates weights and bias using diffs
    //
}
830
831 int main(int argc, char **argv) {
832     try {
833         simple_net();
834         std::cout << "ok\n";
835     } catch (error &e) {
836         std::cerr << "status: " << e.status << std::endl;
837         std::cerr << "message: " << e.message << std::endl;
838         return 1;
839     }
840     return 0;
841 }