1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
25 using namespace mkldnn;
28 // N0 sequences of length T0
29 const int N0 = 1 + rand() % 31;
30 // N1 sequences of length T1
31 const int N1 = 1 + rand() % 31;
33 const int T0 = 31 + 1 + rand() % 31;
34 const int T1 = 1 + rand() % 31;
36 // Memory required to hold it: N0 * T0 + N1 * T1
37 // However it is possible to have these coming
38 // as padded chunks in larger memory:
39 // e.g. (N0 + N1) * T0
40 // We don't need to compact the data before processing,
41 // we can address the chunks via view primitive and
42 // process the data via two RNN primitives:
43 // of time lengths T1 and T0 - T1.
44 // The leftmost primitive will process N0 + N1 subsequences of length T1
45 // The rightmost primitive will process remaining N0 subsequences
47 const int leftmost_batch = N0 + N1;
48 const int rightmost_batch = N0;
50 const int leftmost_seq_length = T1;
51 const int rightmost_seq_length = T0 - T1;
54 const int common_feature_size = 1024;
56 // RNN primitive characteristics
57 const int common_n_layers = 1;
58 const int lstm_n_gates = 4;
59 const int lstm_n_states = 2;
62 auto cpu_engine = engine(engine::cpu, 0);
63 auto null_memory_ = null_memory(cpu_engine);
65 bool is_training = true;
66 auto fwd_inf_train = is_training
67 ? prop_kind::forward_training
68 : prop_kind::forward_inference;
70 std::vector<primitive> fwd_net;
71 std::vector<primitive> bwd_net;
73 // Input tensor holds two batches with different sequence lengths.
74 // Shorter sequences are padded
75 memory::dims net_src_dims = {
76 /* time */ T0, // maximum sequence length
77 /* n */ N0 + N1, // total batch size
78 /* c */ common_feature_size // common number of channels
82 * Two RNN primitives for different sequence lengths,
83 * one unidirectional layer, LSTM-based
86 memory::dims leftmost_src_layer_dims = {
87 /* time */ leftmost_seq_length,
88 /* n */ leftmost_batch,
89 /* c */ common_feature_size
91 memory::dims rightmost_src_layer_dims = {
92 /* time */ rightmost_seq_length,
93 /* n */ rightmost_batch,
94 /* c */ common_feature_size
96 memory::dims common_weights_layer_dims = {
97 /* layers */ common_n_layers,
99 /* input feature size */ common_feature_size,
100 /* gates number */ lstm_n_gates,
101 /* output feature size */ common_feature_size
103 memory::dims common_weights_iter_dims = {
104 /* layers */ common_n_layers,
106 /* input feature size */ common_feature_size,
107 /* gates number */ lstm_n_gates,
108 /* output feature size */ common_feature_size
110 memory::dims common_bias_dims = {
111 /* layers */ common_n_layers,
113 /* gates number */ lstm_n_gates,
114 /* output feature size */ common_feature_size
116 memory::dims leftmost_dst_layer_dims = {
117 /* time */ leftmost_seq_length,
118 /* n */ leftmost_batch,
119 /* c */ common_feature_size
121 memory::dims rightmost_dst_layer_dims = {
122 /* time */ rightmost_seq_length,
123 /* n */ rightmost_batch,
124 /* c */ common_feature_size
127 // leftmost primitive passes its states to the next RNN iteration
128 // so it needs dst_iter parameter.
130 // rightmost primitive will consume these as src_iter and will access
131 // the memory via a view because it will have different batch dimension.
132 // We have arranged our primitives so that
133 // leftmost_batch >= rightmost_batch, and so the rightmost data will fit
134 // into the memory allocated for the leftmost.
135 memory::dims leftmost_dst_iter_dims = {
136 /* layers */ common_n_layers,
138 /* states */ lstm_n_states,
139 /* n */ leftmost_batch,
140 /* c */ common_feature_size
142 memory::dims rightmost_src_iter_dims = {
143 /* layers */ common_n_layers,
145 /* states */ lstm_n_states,
146 /* n */ rightmost_batch,
147 /* c */ common_feature_size
150 // multiplication of tensor dimensions
151 auto tz_volume = [=](memory::dims tz_dims) {
152 return std::accumulate(
153 tz_dims.begin(), tz_dims.end(),
154 (size_t)1, std::multiplies<size_t>());
157 // Create auxiliary f32 memory descriptor
158 // based on user-supplied dimensions and layout.
159 auto formatted_md = [=](memory::dims dimensions, memory::format layout) {
160 return memory::desc({ dimensions }, memory::data_type::f32, layout);
162 // Create auxiliary generic f32 memory descriptor
163 // based on supplied dimensions, with format::any.
164 auto generic_md = [=](memory::dims dimensions) {
165 return formatted_md( dimensions, memory::format::any);
169 // I/O memory, coming from user
173 std::vector<float> net_src(
174 tz_volume(net_src_dims),
176 // NOTE: in this example we study input sequences with variable batch
177 // dimension, which get processed by two separate RNN primitives, thus
178 // the destination memory for the two will have different shapes: batch
179 // is the second dimension currently: see format::tnc.
180 // We are not copying the output to some common user provided memory as we
181 // suggest that the user should rather keep the two output memories separate
182 // throughout the whole topology and only reorder to something else as
184 // So there's no common net_dst, but there are two destinations instead:
185 // leftmost_dst_layer_memory
186 // rightmost_dst_layer_memory
188 // Memory primitive for the user allocated memory
189 // Suppose user data is in tnc format.
191 = mkldnn::memory({ formatted_md(net_src_dims, memory::format::tnc),
192 cpu_engine }, net_src.data());
193 // src_layer memory of the leftmost and rightmost RNN primitives
194 // are accessed through the respective views in larger memory.
195 // View primitives compute the strides to accommodate for padding.
196 auto user_leftmost_src_layer_md
197 = mkldnn::view::primitive_desc(
198 net_src_memory.get_primitive_desc(),
199 leftmost_src_layer_dims,
200 { 0, 0, 0 } /* t, n, c offsets */
201 ).dst_primitive_desc().desc();
202 auto user_rightmost_src_layer_md
203 = mkldnn::view::primitive_desc(
204 net_src_memory.get_primitive_desc(),
205 rightmost_src_layer_dims,
206 { leftmost_seq_length, 0, 0 } /* t, n, c offsets */
207 ).dst_primitive_desc().desc();
208 auto leftmost_src_layer_memory = net_src_memory;
209 auto rightmost_src_layer_memory = net_src_memory;
211 // Other user provided memory arrays, descriptors and primitives with the
212 // data layouts chosen by user. We'll have to reorder if RNN
213 // primitive prefers it in a different format.
214 std::vector<float> user_common_weights_layer(
215 tz_volume(common_weights_layer_dims),
217 auto user_common_weights_layer_memory
218 = mkldnn::memory({ formatted_md(common_weights_layer_dims,
219 memory::format::ldigo), cpu_engine },
220 user_common_weights_layer.data());
222 std::vector<float> user_common_weights_iter(
223 tz_volume(common_weights_iter_dims),
225 auto user_common_weights_iter_memory
226 = mkldnn::memory({ formatted_md(common_weights_iter_dims,
227 memory::format::ldigo), cpu_engine },
// Back the iter-weights memory with its own buffer.
// (Bug fix: previously pointed at user_common_weights_layer.data(),
// aliasing the layer weights and leaving this vector unused.)
228 user_common_weights_iter.data());
230 std::vector<float> user_common_bias(
231 tz_volume(common_bias_dims),
233 auto user_common_bias_memory
234 = mkldnn::memory({ formatted_md(common_bias_dims, memory::format::ldgo),
235 cpu_engine }, user_common_bias.data());
237 std::vector<float> user_leftmost_dst_layer(
238 tz_volume(leftmost_dst_layer_dims),
240 auto user_leftmost_dst_layer_memory
242 formatted_md(leftmost_dst_layer_dims, memory::format::tnc),
243 cpu_engine }, user_leftmost_dst_layer.data());
245 std::vector<float> user_rightmost_dst_layer(
246 tz_volume(rightmost_dst_layer_dims),
248 auto user_rightmost_dst_layer_memory
250 formatted_md(rightmost_dst_layer_dims, memory::format::tnc),
251 cpu_engine }, user_rightmost_dst_layer.data());
254 rnn_cell::desc uni_cell(algorithm::vanilla_lstm);
256 // Describe layer, forward pass, leftmost primitive.
257 // There are no primitives to the left from here,
258 // so src_iter_desc needs to be zero_md()
259 rnn_forward::desc leftmost_layer_desc(
260 /* aprop_kind */ fwd_inf_train,
262 /* direction */ rnn_direction::unidirectional_left2right,
263 /* src_layer_desc */ user_leftmost_src_layer_md,
264 /* src_iter_desc */ zero_md(),
265 /* weights_layer_desc */ generic_md(common_weights_layer_dims),
266 /* weights_iter_desc */ generic_md(common_weights_iter_dims),
267 /* bias_desc */ generic_md(common_bias_dims),
268 /* dst_layer_desc */ formatted_md(leftmost_dst_layer_dims,
269 memory::format::tnc),
270 /* dst_iter_desc */ generic_md(leftmost_dst_iter_dims)
272 // Describe primitive
273 auto leftmost_prim_desc
274 = mkldnn::rnn_forward::primitive_desc(leftmost_layer_desc, cpu_engine);
277 // Need to connect leftmost and rightmost via "iter" parameters.
278 // We allocate memory here based on the shapes provided by RNN primitive.
281 auto leftmost_dst_iter_memory
282 = mkldnn::memory(leftmost_prim_desc.dst_iter_primitive_desc());
284 // rightmost src_iter will view into dst_iter of leftmost
285 auto rightmost_src_iter_md
286 = mkldnn::view::primitive_desc(
287 leftmost_dst_iter_memory.get_primitive_desc(),
288 rightmost_src_iter_dims,
289 { 0, 0, 0, 0, 0 } /* l, d, s, n, c offsets */
290 ).dst_primitive_desc().desc();
292 auto rightmost_src_iter_memory = leftmost_dst_iter_memory;
294 // Now rightmost primitive
295 // There are no primitives to the right from here,
296 // so dst_iter_desc is explicit zero_md()
297 rnn_forward::desc rightmost_layer_desc(
298 /* aprop_kind */ fwd_inf_train,
300 /* direction */ rnn_direction::unidirectional_left2right,
301 /* src_layer_desc */ user_rightmost_src_layer_md,
302 /* src_iter_desc */ rightmost_src_iter_md,
303 /* weights_layer_desc */ generic_md(common_weights_layer_dims),
304 /* weights_iter_desc */ generic_md(common_weights_iter_dims),
305 /* bias_desc */ generic_md(common_bias_dims),
306 /* dst_layer_desc */ formatted_md(rightmost_dst_layer_dims,
307 memory::format::tnc),
308 /* dst_iter_desc */ zero_md()
310 auto rightmost_prim_desc
311 = mkldnn::rnn_forward::primitive_desc(rightmost_layer_desc, cpu_engine);
314 // Weights and biases, layer memory
315 // Same layout should work across the layer, no reorders
316 // needed between leftmost and rightmost, only reordering
317 // user memory to the RNN-friendly shapes.
320 auto common_weights_layer_memory = user_common_weights_layer_memory;
321 primitive common_weights_layer_reorder;
322 auto reorder_common_weights_layer = false;
323 if (memory::primitive_desc(
324 leftmost_prim_desc.weights_layer_primitive_desc())
325 != memory::primitive_desc(
326 common_weights_layer_memory.get_primitive_desc())
328 common_weights_layer_memory
329 = mkldnn::memory(leftmost_prim_desc.weights_layer_primitive_desc());
330 common_weights_layer_reorder
331 = reorder(user_common_weights_layer_memory,
332 common_weights_layer_memory);
333 reorder_common_weights_layer = true;
336 auto common_weights_iter_memory = user_common_weights_iter_memory;
337 primitive common_weights_iter_reorder;
338 auto reorder_common_weights_iter = false;
339 if (memory::primitive_desc(
340 leftmost_prim_desc.weights_iter_primitive_desc())
341 != memory::primitive_desc(
342 common_weights_iter_memory.get_primitive_desc())
344 common_weights_iter_memory
345 = mkldnn::memory(leftmost_prim_desc.weights_iter_primitive_desc());
346 common_weights_iter_reorder
347 = reorder(user_common_weights_iter_memory,
348 common_weights_iter_memory);
349 reorder_common_weights_iter = true;
353 auto common_bias_memory = user_common_bias_memory;
354 primitive common_bias_reorder;
355 auto reorder_common_bias = false;
356 if (memory::primitive_desc(
357 leftmost_prim_desc.bias_primitive_desc())
358 != memory::primitive_desc(
359 common_bias_memory.get_primitive_desc())
362 = mkldnn::memory(leftmost_prim_desc.bias_primitive_desc());
364 = reorder(user_common_bias_memory,
366 reorder_common_bias = true;
370 // Destination layer memory
373 auto leftmost_dst_layer_memory = user_leftmost_dst_layer_memory;
374 primitive leftmost_dst_layer_reorder;
375 auto reorder_leftmost_dst_layer = false;
376 if (memory::primitive_desc(
377 leftmost_prim_desc.dst_layer_primitive_desc())
378 != memory::primitive_desc(
379 leftmost_dst_layer_memory.get_primitive_desc())
381 leftmost_dst_layer_memory
382 = mkldnn::memory(leftmost_prim_desc.dst_layer_primitive_desc());
383 leftmost_dst_layer_reorder
384 = reorder(user_leftmost_dst_layer_memory,
385 leftmost_dst_layer_memory);
386 reorder_leftmost_dst_layer = true;
389 auto rightmost_dst_layer_memory = user_rightmost_dst_layer_memory;
390 primitive rightmost_dst_layer_reorder;
391 auto reorder_rightmost_dst_layer = false;
392 if (memory::primitive_desc(
393 rightmost_prim_desc.dst_layer_primitive_desc())
394 != memory::primitive_desc(
395 rightmost_dst_layer_memory.get_primitive_desc())
397 rightmost_dst_layer_memory
398 = mkldnn::memory(rightmost_prim_desc.dst_layer_primitive_desc());
399 rightmost_dst_layer_reorder
400 = reorder(user_rightmost_dst_layer_memory,
401 rightmost_dst_layer_memory);
402 reorder_rightmost_dst_layer = true;
405 // We also create workspace memory based on the information from
406 // the workspace_primitive_desc(). This is needed for internal
407 // communication between forward and backward primitives during
409 // Inference mode doesn't need it, so initialize with null_memory_
410 auto create_ws = [=](mkldnn::rnn_forward::primitive_desc &pd) {
411 auto workspace_memory = null_memory_;
414 workspace_memory = mkldnn::memory(pd.workspace_primitive_desc());
416 return workspace_memory;
418 auto leftmost_workspace_memory = create_ws(leftmost_prim_desc);
419 auto rightmost_workspace_memory = create_ws(rightmost_prim_desc);
421 // Construct the RNN primitive objects
422 rnn_forward leftmost_layer = rnn_forward(
423 /* aprimitive_desc */ leftmost_prim_desc,
424 /* src_layer */ leftmost_src_layer_memory,
425 /* src_iter */ null_memory_,
426 /* weights_layer */ common_weights_layer_memory,
427 /* weights_iter */ common_weights_iter_memory,
428 /* bias */ common_bias_memory,
429 /* dst_layer */ leftmost_dst_layer_memory,
430 /* dst_iter */ leftmost_dst_iter_memory,
431 /* workspace */ leftmost_workspace_memory
434 rnn_forward rightmost_layer = rnn_forward(
435 /* aprimitive_desc */ rightmost_prim_desc,
436 /* src_layer */ rightmost_src_layer_memory,
437 /* src_iter */ rightmost_src_iter_memory,
438 /* weights_layer */ common_weights_layer_memory,
439 /* weights_iter */ common_weights_iter_memory,
440 /* bias */ common_bias_memory,
441 /* dst_layer */ rightmost_dst_layer_memory,
442 /* dst_iter */ null_memory_,
443 /* workspace */ rightmost_workspace_memory
446 // Enqueue primitives for forward execution
447 if (reorder_common_weights_layer)
448 fwd_net.push_back(common_weights_layer_reorder);
449 if (reorder_common_weights_iter)
450 fwd_net.push_back(common_weights_iter_reorder);
451 if (reorder_common_bias)
452 fwd_net.push_back(common_bias_reorder);
453 if (reorder_leftmost_dst_layer)
454 fwd_net.push_back(leftmost_dst_layer_reorder);
456 fwd_net.push_back(leftmost_layer);
458 if (reorder_rightmost_dst_layer)
459 fwd_net.push_back(rightmost_dst_layer_reorder);
460 fwd_net.push_back(rightmost_layer);
462 // Submit forward for execution
463 stream(stream::kind::eager).submit(fwd_net).wait();
465 // No backward pass for inference
466 if (!is_training) return;
469 // Backward primitives will reuse memory from forward
470 // and allocate/describe specifics here. Only relevant for training.
473 // User-provided memory for backward by data output
474 std::vector<float> net_diff_src(
475 tz_volume(net_src_dims),
477 auto net_diff_src_memory
478 = mkldnn::memory({ formatted_md(net_src_dims, memory::format::tnc),
479 cpu_engine }, net_diff_src.data());
481 // diff_src follows the same layout we have for net_src
482 auto user_leftmost_diff_src_layer_md
483 = mkldnn::view::primitive_desc(
484 net_diff_src_memory.get_primitive_desc(),
485 leftmost_src_layer_dims,
486 { 0, 0, 0 } /* t, n, c offsets */
487 ).dst_primitive_desc().desc();
488 auto user_rightmost_diff_src_layer_md
489 = mkldnn::view::primitive_desc(
490 net_diff_src_memory.get_primitive_desc(),
491 rightmost_src_layer_dims,
492 { leftmost_seq_length, 0, 0 } /* t, n, c offsets */
493 ).dst_primitive_desc().desc();
495 auto leftmost_diff_src_layer_memory = net_diff_src_memory;
496 auto rightmost_diff_src_layer_memory = net_diff_src_memory;
498 // User-provided memory for backpropagation by weights
499 std::vector<float> user_common_diff_weights_layer(
500 tz_volume(common_weights_layer_dims),
502 auto user_common_diff_weights_layer_memory
503 = mkldnn::memory({ formatted_md(common_weights_layer_dims,
504 memory::format::ldigo), cpu_engine },
505 user_common_diff_weights_layer.data());
507 std::vector<float> user_common_diff_bias(
508 tz_volume(common_bias_dims),
510 auto user_common_diff_bias_memory
511 = mkldnn::memory({ formatted_md(common_bias_dims,
512 memory::format::ldgo), cpu_engine },
513 user_common_diff_bias.data());
515 // User-provided input to the backward primitive.
516 // To be updated by the user after forward pass using some cost function.
517 memory::dims net_diff_dst_dims = {
520 /* c */ common_feature_size
522 // Suppose user data is in tnc format.
523 std::vector<float> net_diff_dst(
524 tz_volume(net_diff_dst_dims),
526 auto net_diff_dst_memory
527 = mkldnn::memory({ formatted_md(net_diff_dst_dims, memory::format::tnc),
528 cpu_engine }, net_diff_dst.data());
529 // diff_dst_layer memory of the leftmost and rightmost RNN primitives
530 // are accessed through the respective views in larger memory.
531 // View primitives compute the strides to accommodate for padding.
532 auto user_leftmost_diff_dst_layer_md
533 = mkldnn::view::primitive_desc(
534 net_diff_dst_memory.get_primitive_desc(),
535 leftmost_dst_layer_dims,
536 { 0, 0, 0 } /* t, n, c offsets */
537 ).dst_primitive_desc().desc();
538 auto user_rightmost_diff_dst_layer_md
539 = mkldnn::view::primitive_desc(
540 net_diff_dst_memory.get_primitive_desc(),
541 rightmost_dst_layer_dims,
542 { leftmost_seq_length, 0, 0 } /* t, n, c offsets */
543 ).dst_primitive_desc().desc();
544 auto leftmost_diff_dst_layer_memory = net_diff_dst_memory;
545 auto rightmost_diff_dst_layer_memory = net_diff_dst_memory;
547 // Backward leftmost primitive descriptor
548 rnn_backward::desc leftmost_layer_bwd_desc(
549 /* aprop_kind */ prop_kind::backward,
551 /* direction */ rnn_direction::unidirectional_left2right,
552 /* src_layer_desc */ user_leftmost_src_layer_md,
553 /* src_iter_desc */ zero_md(),
554 /* weights_layer_desc */ generic_md(common_weights_layer_dims),
555 /* weights_iter_desc */ generic_md(common_weights_iter_dims),
556 /* bias_desc */ generic_md(common_bias_dims),
557 /* dst_layer_desc */ formatted_md(leftmost_dst_layer_dims,
558 memory::format::tnc),
559 /* dst_iter_desc */ generic_md(leftmost_dst_iter_dims),
560 /* diff_src_layer_desc */ user_leftmost_diff_src_layer_md,
561 /* diff_src_iter_desc */ zero_md(),
562 /* diff_weights_layer_desc */ generic_md(common_weights_layer_dims),
563 /* diff_weights_iter_desc */ generic_md(common_weights_iter_dims),
564 /* diff_bias_desc */ generic_md(common_bias_dims),
565 /* diff_dst_layer_desc */ user_leftmost_diff_dst_layer_md,
566 /* diff_dst_iter_desc */ generic_md(leftmost_dst_iter_dims)
568 auto leftmost_bwd_prim_desc
569 = mkldnn::rnn_backward::primitive_desc(
570 leftmost_layer_bwd_desc, cpu_engine, leftmost_prim_desc);
572 // As the batch dimensions are different between leftmost and rightmost
573 // we need to do the views. rightmost needs less memory, so it will view
574 // the memory of leftmost.
575 auto leftmost_diff_dst_iter_memory
576 = mkldnn::memory(leftmost_bwd_prim_desc.diff_dst_iter_primitive_desc());
578 auto rightmost_diff_src_iter_md
579 = mkldnn::view::primitive_desc(
580 leftmost_diff_dst_iter_memory.get_primitive_desc(),
581 rightmost_src_iter_dims,
582 { 0, 0, 0, 0, 0 } /* l, d, s, n, c offsets */
583 ).dst_primitive_desc().desc();
585 auto rightmost_diff_src_iter_memory = leftmost_diff_dst_iter_memory;
587 // Backward rightmost primitive descriptor
588 rnn_backward::desc rightmost_layer_bwd_desc(
589 /* aprop_kind */ prop_kind::backward,
591 /* direction */ rnn_direction::unidirectional_left2right,
592 /* src_layer_desc */ user_rightmost_src_layer_md,
593 /* src_iter_desc */ generic_md(rightmost_src_iter_dims),
594 /* weights_layer_desc */ generic_md(common_weights_layer_dims),
595 /* weights_iter_desc */ generic_md(common_weights_iter_dims),
596 /* bias_desc */ generic_md(common_bias_dims),
597 /* dst_layer_desc */ formatted_md(rightmost_dst_layer_dims,
598 memory::format::tnc),
599 /* dst_iter_desc */ zero_md(),
600 /* diff_src_layer_desc */ user_rightmost_diff_src_layer_md,
601 /* diff_src_iter_desc */ rightmost_diff_src_iter_md,
602 /* diff_weights_layer_desc */ generic_md(common_weights_layer_dims),
603 /* diff_weights_iter_desc */ generic_md(common_weights_iter_dims),
604 /* diff_bias_desc */ generic_md(common_bias_dims),
605 /* diff_dst_layer_desc */ user_rightmost_diff_dst_layer_md,
606 /* diff_dst_iter_desc */ zero_md()
608 auto rightmost_bwd_prim_desc
609 = mkldnn::rnn_backward::primitive_desc(
610 rightmost_layer_bwd_desc, cpu_engine, rightmost_prim_desc);
613 // Memory primitives for backward pass
616 // src layer uses the same memory as forward
617 auto leftmost_src_layer_bwd_memory = leftmost_src_layer_memory;
618 auto rightmost_src_layer_bwd_memory = rightmost_src_layer_memory;
620 // Memory for weights and biases for backward pass
621 // Try to use the same memory between forward and backward, but
622 // sometimes reorders are needed.
623 auto common_weights_layer_bwd_memory = common_weights_layer_memory;
624 primitive common_weights_layer_bwd_reorder;
625 auto reorder_common_weights_layer_bwd = false;
626 if (memory::primitive_desc(
627 leftmost_bwd_prim_desc.weights_layer_primitive_desc())
628 != memory::primitive_desc(
629 leftmost_prim_desc.weights_layer_primitive_desc())
631 common_weights_layer_bwd_memory
632 = memory(leftmost_bwd_prim_desc.weights_layer_primitive_desc());
633 common_weights_layer_bwd_reorder
634 = reorder(common_weights_layer_memory,
635 common_weights_layer_bwd_memory);
636 reorder_common_weights_layer_bwd = true;
639 auto common_weights_iter_bwd_memory = common_weights_iter_memory;
640 primitive common_weights_iter_bwd_reorder;
641 auto reorder_common_weights_iter_bwd = false;
642 if (memory::primitive_desc(
643 leftmost_bwd_prim_desc.weights_iter_primitive_desc())
644 != memory::primitive_desc(
645 leftmost_prim_desc.weights_iter_primitive_desc())
647 common_weights_iter_bwd_memory
648 = memory(leftmost_bwd_prim_desc.weights_iter_primitive_desc());
649 common_weights_iter_bwd_reorder
650 = reorder(common_weights_iter_memory,
651 common_weights_iter_bwd_memory);
652 reorder_common_weights_iter_bwd = true;
655 auto common_bias_bwd_memory = common_bias_memory;
656 primitive common_bias_bwd_reorder;
657 auto reorder_common_bias_bwd = false;
658 if (memory::primitive_desc(
659 leftmost_bwd_prim_desc.bias_primitive_desc())
660 != memory::primitive_desc(
661 common_bias_memory.get_primitive_desc())
663 common_bias_bwd_memory
664 = mkldnn::memory(leftmost_bwd_prim_desc.bias_primitive_desc());
665 common_bias_bwd_reorder
666 = reorder(common_bias_memory,
667 common_bias_bwd_memory);
668 reorder_common_bias_bwd = true;
671 // diff_weights and biases
672 auto common_diff_weights_layer_memory
673 = user_common_diff_weights_layer_memory;
674 primitive common_diff_weights_layer_reorder;
675 auto reorder_common_diff_weights_layer = false;
676 if (memory::primitive_desc(
677 leftmost_bwd_prim_desc.diff_weights_layer_primitive_desc())
678 != memory::primitive_desc(
679 common_diff_weights_layer_memory.get_primitive_desc())
681 common_diff_weights_layer_memory
683 leftmost_bwd_prim_desc.diff_weights_layer_primitive_desc());
684 common_diff_weights_layer_reorder
685 = reorder(user_common_diff_weights_layer_memory,
686 common_diff_weights_layer_memory);
687 reorder_common_diff_weights_layer = true;
690 auto common_diff_bias_memory = user_common_diff_bias_memory;
691 primitive common_diff_bias_reorder;
692 auto reorder_common_diff_bias = false;
693 if (memory::primitive_desc(
694 leftmost_bwd_prim_desc.diff_bias_primitive_desc())
695 != memory::primitive_desc(
696 common_diff_bias_memory.get_primitive_desc())
698 common_diff_bias_memory
699 = mkldnn::memory(leftmost_bwd_prim_desc.diff_bias_primitive_desc());
700 common_diff_bias_reorder
701 = reorder(user_common_diff_bias_memory,
702 common_diff_bias_memory);
703 reorder_common_diff_bias = true;
706 // dst_layer memory for backward pass
707 auto leftmost_dst_layer_bwd_memory = leftmost_dst_layer_memory;
708 primitive leftmost_dst_layer_bwd_reorder;
709 auto reorder_leftmost_dst_layer_bwd = false;
710 if (memory::primitive_desc(
711 leftmost_bwd_prim_desc.dst_layer_primitive_desc())
712 != memory::primitive_desc(
713 leftmost_dst_layer_bwd_memory.get_primitive_desc())
715 leftmost_dst_layer_bwd_memory
716 = mkldnn::memory(leftmost_bwd_prim_desc.dst_layer_primitive_desc());
717 leftmost_dst_layer_bwd_reorder
718 = reorder(leftmost_dst_layer_memory,
719 leftmost_dst_layer_bwd_memory);
720 reorder_leftmost_dst_layer_bwd = true;
723 auto rightmost_dst_layer_bwd_memory = rightmost_dst_layer_memory;
724 primitive rightmost_dst_layer_bwd_reorder;
725 auto reorder_rightmost_dst_layer_bwd = false;
726 if (memory::primitive_desc(
727 rightmost_bwd_prim_desc.dst_layer_primitive_desc())
728 != memory::primitive_desc(
729 rightmost_dst_layer_bwd_memory.get_primitive_desc())
731 rightmost_dst_layer_bwd_memory
733 rightmost_bwd_prim_desc.dst_layer_primitive_desc());
734 rightmost_dst_layer_bwd_reorder
735 = reorder(rightmost_dst_layer_memory,
736 rightmost_dst_layer_bwd_memory);
737 reorder_rightmost_dst_layer_bwd = true;
740 // Similar to forward, the backward primitives are connected
741 // via "iter" parameters.
742 auto common_diff_weights_iter_memory
744 leftmost_bwd_prim_desc.diff_weights_iter_primitive_desc());
746 auto leftmost_dst_iter_bwd_memory = leftmost_dst_iter_memory;
747 primitive leftmost_dst_iter_bwd_reorder;
748 auto reorder_leftmost_dst_iter_bwd = false;
749 if (memory::primitive_desc(
750 leftmost_bwd_prim_desc.dst_iter_primitive_desc())
751 != memory::primitive_desc(
752 leftmost_dst_iter_bwd_memory.get_primitive_desc())
754 leftmost_dst_iter_bwd_memory
755 = mkldnn::memory(leftmost_bwd_prim_desc.dst_iter_primitive_desc());
756 leftmost_dst_iter_bwd_reorder
757 = reorder(leftmost_dst_iter_memory,
758 leftmost_dst_iter_bwd_memory);
759 reorder_leftmost_dst_iter_bwd = true;
762 // Construct the RNN primitive objects for backward
763 rnn_backward leftmost_layer_bwd = rnn_backward(
764 /* aprimitive_desc */ leftmost_bwd_prim_desc,
765 /* src_layer */ leftmost_src_layer_bwd_memory,
766 /* src_iter */ null_memory_,
767 /* weights_layer */ common_weights_layer_bwd_memory,
768 /* weights_iter */ common_weights_iter_bwd_memory,
769 /* bias */ common_bias_bwd_memory,
770 /* dst_layer */ leftmost_dst_layer_bwd_memory,
771 /* dst_iter */ leftmost_dst_iter_bwd_memory,
772 /* diff_src_layer */ leftmost_diff_src_layer_memory,
773 /* diff_src_iter */ null_memory_,
774 /* diff_weights_layer */ common_diff_weights_layer_memory,
775 /* diff_weights_iter */ common_diff_weights_iter_memory,
776 /* diff_bias */ common_diff_bias_memory,
777 /* diff_dst_layer */ leftmost_diff_dst_layer_memory,
778 /* diff_dst_iter */ leftmost_diff_dst_iter_memory,
779 /* workspace */ leftmost_workspace_memory
782 rnn_backward rightmost_layer_bwd = rnn_backward(
783 /* aprimitive_desc */ rightmost_bwd_prim_desc,
784 /* src_layer */ rightmost_src_layer_bwd_memory,
785 /* src_iter */ rightmost_src_iter_memory,
786 /* weights_layer */ common_weights_layer_bwd_memory,
787 /* weights_iter */ common_weights_iter_bwd_memory,
788 /* bias */ common_bias_bwd_memory,
789 /* dst_layer */ rightmost_dst_layer_bwd_memory,
790 /* dst_iter */ null_memory_,
791 /* diff_src_layer */ rightmost_diff_src_layer_memory,
792 /* diff_src_iter */ rightmost_diff_src_iter_memory,
793 /* diff_weights_layer */ common_diff_weights_layer_memory,
794 /* diff_weights_iter */ common_diff_weights_iter_memory,
795 /* diff_bias */ common_diff_bias_memory,
796 /* diff_dst_layer */ rightmost_diff_dst_layer_memory,
797 /* diff_dst_iter */ null_memory_,
798 /* workspace */ rightmost_workspace_memory
801 // Enqueue primitives for backward execution
802 if (reorder_common_weights_layer_bwd)
803 bwd_net.push_back(common_weights_layer_bwd_reorder);
804 if (reorder_common_weights_iter_bwd)
805 bwd_net.push_back(common_weights_iter_bwd_reorder);
806 if (reorder_common_bias_bwd)
807 bwd_net.push_back(common_bias_bwd_reorder);
808 if (reorder_common_diff_weights_layer)
809 bwd_net.push_back(common_diff_weights_layer_reorder);
810 if (reorder_common_diff_bias)
811 bwd_net.push_back(common_diff_bias_reorder);
813 if (reorder_rightmost_dst_layer_bwd)
814 bwd_net.push_back(rightmost_dst_layer_bwd_reorder);
816 bwd_net.push_back(rightmost_layer_bwd);
818 if (reorder_leftmost_dst_layer_bwd)
819 bwd_net.push_back(leftmost_dst_layer_bwd_reorder);
820 if (reorder_leftmost_dst_iter_bwd)
821 bwd_net.push_back(leftmost_dst_iter_bwd_reorder);
822 bwd_net.push_back(leftmost_layer_bwd);
824 // Submit backward for execution
825 stream(stream::kind::eager).submit(bwd_net).wait();
827 // User updates weights and bias using diffs
831 int main(int argc, char **argv) {
836 std::cerr << "status: " << e.status << std::endl;
837 std::cerr << "message: " << e.message << std::endl;