// inference-engine/thirdparty/clDNN/src/graph_optimizer/graph_initializations.cpp
/*
// Copyright (c) 2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////

#include "pass_manager.h"
#include "program_node.h"

#include "split_inst.h"
#include "convolution_inst.h"
#include "crop_inst.h"
#include "lstm_inst.h"
#include "reshape_inst.h"
#include "upsampling_inst.h"

#include <iomanip>

using namespace cldnn;

namespace cldnn
{
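    // Zero-pads an index to five digits, e.g. get_id_string(7) -> "00007".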
    std::string get_id_string(size_t i) {
        std::stringstream ss;
        ss << std::setw(5) << std::setfill('0') << i;
        return ss.str();
    }

    // ToDo: rewrite the methods in this class in a consistent style (maybe: handle_<primitive_name>()).
    //       Is it possible to avoid iterating over all nodes several times?
    //       Do we have any repeated code here that could be factored out for readability?
    void graph_initializations::replace_nodes(program_impl& p)
    {
        auto itr = p.nodes_map.begin();
        while (itr != p.nodes_map.end())
        {
            auto node_itr = itr++;
            auto& node = (*node_itr).second;

            if (node->is_type<split>())
            {
                //check that split is not used by any primitive, as it will be optimized out
                if (node->get_users().size() != 0)
                    throw std::logic_error("Split layer cannot be used directly! Please use split output \"" + node->id() + ":<split_output_id>\"!");

                //get the output size and validate the split primitive inputs
                auto output_layout = node->get_output_layout();
                auto output_layout_size = output_layout.size;

                auto split_prim = node->as<split>().typed_desc();
                primitive_id input_id = split_prim->input[0];
                auto split_num = split_prim->output_offsets.size();

                //create a crop for each split output provided
                for (decltype(split_num) i = 0; i < split_num; i++)
                {
                    primitive_id output_id = node->id() + ":" + split_prim->output_ids[i];

                    auto node_ptr = p.nodes_map.find(output_id)->second;

                    //calculate the crop reference input size
                    tensor reference_input_size;

                    // For all split offsets before the last one, the size can be calculated as
                    // size_of_offset[n] = offset[n + 1] - offset[n];
                    if (i != (split_num - 1))
                    {
                        reference_input_size += split_prim->output_offsets[i + 1] - split_prim->output_offsets[i];
                    }
                    // For the last split, size[split_num - 1] = split_input.size - offsets[n];
                    else
                    {
                        reference_input_size += output_layout_size - split_prim->output_offsets[i];
                    }
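                    // Worked example: with output_offsets {0, 10, 25} along an axis of extent 40,
                    // the three crops get extents 10 (10-0), 15 (25-10) and 15 (40-25).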

                    // For all the other dimensions, copy from the split_input
                    for (int dimension = 0; dimension < CLDNN_TENSOR_DIM_MAX; dimension++)
                    {
                        reference_input_size.raw[dimension]
                            = (reference_input_size.raw[dimension] == 0) ? output_layout_size.raw[dimension] : reference_input_size.raw[dimension];
                    }

                    //update the crop primitive
                    node_ptr->set_output_padding(output_layout.data_padding);
                    auto crop_prim = node_ptr->as<crop>().typed_desc();
                    crop_prim->reference_input = reference_input_size;
                }

                //remove the input->split connection and the original split node
                p.remove_connection(node->get_dependency(0), *node);
                p.optimized_out.push_back(node->id());
                p.nodes_map.erase(node->id());
                continue;
            }

            //find upsampling primitives with bilinear filtering and replace them with deconvolution using proper weights
            if (node->is_type<upsampling>())
            {
                auto upsampling_prim = node->as<upsampling>().typed_desc();

                if (upsampling_prim->sample_type != upsampling_sample_type::bilinear)
                    continue;

                //check that num_filter is not 0 (required for bilinear upsampling)
                if (upsampling_prim->num_filter == 0)
                    throw std::logic_error("num_filter in upsampling cannot be 0 in bilinear filtering mode in \"" + node->id() + "\"!");

                primitive_id upsampling_id = node->id();
                auto& input_node = node->get_dependency(0);

                primitive_id input_id = upsampling_prim->input[0];
                auto num_filter = upsampling_prim->num_filter;

                //setting deconvolution parameters based on the upsampling input
                auto scale = static_cast<tensor::value_type>(upsampling_prim->scale);
                tensor stride(1, 1, scale, scale);
                auto offset = static_cast<tensor::value_type>(std::ceil((scale - 1) / 2.f));
                tensor input_offset(0, 0, -offset, -offset);
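                // e.g. scale = 2 gives stride 2x2, offset = ceil(1 / 2.f) = 1,
                // hence input_offset = (0, 0, -1, -1) and (below) a 4x4 kernel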

                //setting weights for deconvolution
                auto kernel_size = static_cast<tensor::value_type>((2 * scale) - (scale % 2));
                layout weights_layout(data_types::f32, format::bfyx, tensor(1, 1, kernel_size, kernel_size));

                std::vector<primitive_id> weights_vec;
                for (uint32_t weights_idx = 0; weights_idx < num_filter; weights_idx++)
                {
                    memory_impl::ptr data_to_allocate = p.get_engine().allocate_memory(weights_layout);
                    mem_lock<float> dst{ data_to_allocate };
                    float *dst_data = dst.data();
                    //initialize with bilinear weights data
                    auto f = static_cast<uint32_t>(std::ceil(kernel_size / 2.0f));
                    float c = (2 * f - 1 - f % 2) / (2.f * f);
                    float x = 0.f;
                    float y = 0.f;
                    for (size_t i = 0; i < weights_layout.count(); ++i) {
                        x = static_cast<float>(i % kernel_size);
                        y = static_cast<float>((i / kernel_size) % kernel_size);
                        dst_data[i] = (1 - std::abs(x / f - c)) * (1 - std::abs(y / f - c));
                    }
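                    // Sanity check for scale = 2: kernel_size = 4, f = 2, c = 0.75, so each
                    // row/column of taps is 1 - |x/2 - 0.75| = {0.25, 0.75, 0.75, 0.25} - the
                    // standard separable bilinear interpolation kernel.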

                    //create the weights primitive with dummy memory, which will be replaced in a further step
                    primitive_id weights_id = upsampling_id + "_deconvolution_weights" + std::to_string(weights_idx);
                    layout dummy_layout(data_types::f32, format::bfyx, tensor(1, 1, 1, 1));
                    float zero = 0.f;
                    auto weights_prim = std::make_shared<data>(weights_id, memory::attach(dummy_layout, &zero, 1));
                    p.get_or_create(weights_prim);

                    weights_vec.push_back(weights_id);

                    auto weights_node_ptr = p.nodes_map.find(weights_id)->second;

                    //attach the weights buffer
                    auto& data_node = weights_node_ptr->as<data>();
                    data_node.attach_memory(*data_to_allocate, false);
                }

                //disconnect the upsampling node, rename it and move it to the optimized list
                p.remove_connection(node->get_dependency(0), *node);
                auto rename_id = upsampling_id + "_tmp";
                p.rename(*node, rename_id);

                //create the deconvolution primitive
                auto deconv_prim = std::make_shared<deconvolution>(upsampling_id, input_id, weights_vec, stride, input_offset);
                p.get_or_create(deconv_prim);

                auto deconv_node_ptr = p.nodes_map.find(upsampling_id)->second;

                auto upsampling_node_ptr = p.nodes_map.find(rename_id)->second;
                p.replace_all_usages(*upsampling_node_ptr, *deconv_node_ptr);
                p.optimized_out.push_back(rename_id);
                p.nodes_map.erase(rename_id);

                //add connections input->deconvolution and weights->deconvolution
                p.add_connection(input_node, *deconv_node_ptr);

                for (uint32_t weights_idx = 0; weights_idx < num_filter; weights_idx++)
                {
                    auto weights_node_ptr = p.nodes_map.find(weights_vec[weights_idx])->second;
                    p.add_connection(*weights_node_ptr, *deconv_node_ptr);
                }
                continue;
            }

            //find deconvolution primitives with stride 1 and replace them with convolution using transposed weights
            if (node->is_type<deconvolution>())
            {
                if (!p.get_options().get<build_option_type::optimize_data>()->enabled())
                    continue;

                auto deconv_prim = node->as<deconvolution>().typed_desc();

                //limit the optimization to stride = 1
                if (deconv_prim->stride.spatial[0] != 1 || deconv_prim->stride.spatial[1] != 1 || deconv_prim->gradient())
                    continue;

                primitive_id deconv_id = node->id();
                auto& input_node = node->get_dependency(0);

                primitive_id input_id = deconv_prim->input[0];

                //setting convolution parameters based on the deconvolution params
                auto stride = deconv_prim->stride;
                auto weights = deconv_prim->weights;
                std::vector<primitive_id> weights_vec;
                for (auto& weights_id : weights)
                    weights_vec.push_back(weights_id);
                auto biases = deconv_prim->bias;
                std::vector<primitive_id> bias_vec;
                for (auto& bias_id : biases)
                    bias_vec.push_back(bias_id);
                auto input_offset = deconv_prim->input_offset;
                auto with_activation = deconv_prim->with_activation;
                auto activation_negative_slope = deconv_prim->activation_negative_slope;
                auto output_padding = deconv_prim->output_padding;

                //remove the deconvolution node and its connections to weights and biases, rename it and move it to the optimized list
                tensor filter_size = { 1, 1, 1, 1 };
                p.remove_connection(node->get_dependency(0), *node);
                for (auto& weights_id : weights_vec)
                {
                    auto weights_node_ptr = p.nodes_map.find(weights_id)->second;
                    p.remove_connection(*weights_node_ptr, *node);
                    //get the filter spatial sizes for input offset adjustment; do this only once, as all filters should have the same size
                    if (weights_id == weights_vec[0])
                        filter_size = weights_node_ptr->get_output_layout().size;
                }

                input_offset.spatial[0] = std::abs(input_offset.spatial[0]) - (filter_size.spatial[0] - 1);
                input_offset.spatial[1] = std::abs(input_offset.spatial[1]) - (filter_size.spatial[1] - 1);
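                // Rationale (an interpretation of the adjustment above): a stride-1 deconvolution
                // with padding p and a k-wide filter is equivalent to a convolution with spatially
                // flipped weights and padding k - 1 - p; since clDNN encodes padding as a negative
                // input_offset, the new offset becomes |p| - (k - 1).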

                if (!bias_vec.empty())
                {
                    for (auto& bias_id : bias_vec)
                    {
                        auto bias_id_node_ptr = p.nodes_map.find(bias_id)->second;
                        p.remove_connection(*bias_id_node_ptr, *node);
                    }
                }
                auto rename_id = deconv_id + "_tmp";
                p.rename(*node, rename_id);

                //create convolution primitive
                if (biases.size() != 0)
                {
                    auto conv_prim = std::make_shared<convolution>(deconv_id, input_id, weights_vec, bias_vec,
                        stride, input_offset, tensor{ 1, 1, 1, 1 }, with_activation, activation_negative_slope, output_padding);
                    p.get_or_create(conv_prim);
                }
                else
                {
                    auto conv_prim = std::make_shared<convolution>(deconv_id, input_id, weights_vec,
                        stride, input_offset, tensor{ 1, 1, 1, 1 }, with_activation, activation_negative_slope, output_padding);
                    p.get_or_create(conv_prim);
                }

                auto conv_node_ptr = p.nodes_map.find(deconv_id)->second;
                auto conv_node = &conv_node_ptr->as<convolution>();
                conv_node->set_transposed(true);
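                // note: the transposed flag presumably lets the convolution kernels account for the
                // deconvolution (flipped) weight layout - no explicit weight transformation is done here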

                //add connections input->convolution, weights->convolution and bias->convolution
                p.add_connection(input_node, *conv_node_ptr);

                for (auto& weights_id : weights_vec)
                {
                    auto weights_node_ptr = p.nodes_map.find(weights_id)->second;
                    p.add_connection(*weights_node_ptr, *conv_node_ptr);
                }

                if (!bias_vec.empty())
                {
                    for (auto& bias_id : bias_vec)
                    {
                        auto bias_id_node_ptr = p.nodes_map.find(bias_id)->second;
                        p.add_connection(*bias_id_node_ptr, *conv_node_ptr);
                    }
                }

                auto deconv_node_ptr = p.nodes_map.find(rename_id)->second;
                p.replace_all_usages(*deconv_node_ptr, *conv_node_ptr);
                p.optimized_out.push_back(rename_id);
                p.nodes_map.erase(rename_id);

                continue;
            }
        }
    }

    void graph_initializations::handle_detection_output(program_impl& p)
    {
        auto itr = p.nodes_map.begin(); //note: we need to use iterators, since the currently processed element can be removed
        while (itr != p.nodes_map.end())
        {
            auto node_itr = itr++;
            auto& node = *(*node_itr).second;
            // Create the second part of the detection output primitive and replace node names - do it only once
            if ((p.get_options().get<build_option_type::detection_output_gpu>()->enabled()) &&
                (node.is_type<detection_output>()) &&
                (node.id().find("_pre") == std::string::npos))    //ToDo: this will fail if the user names a primitive with "_pre" in it (e.g. do_pre);
                                                                  //      we need to use node mark() or some other idea to prevent it
            {
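                // The work is split in two stages: the renamed "_pre" node keeps the original
                // detection computation, while a new detection_output_sort node implements the
                // final "keep top k" selection.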
                // rename detection output
                const primitive_id detect_out_node_name = node.id();
                const primitive_id new_primitive_id = detect_out_node_name + "_pre";
                p.rename(node, new_primitive_id);

                auto detect_out_prim = node.as<detection_output>().typed_desc();
                // Create the new primitive, the "keep top k" part of detection output
                // ToDo: add default parameters to the detection_output_sort constructor to get rid of this initialization here
                auto detect_out_sort_prim = std::make_shared<detection_output_sort>(
                    detect_out_node_name,
                    node.id(),
                    // params not important here - they will be set in the "primitive_impl* create" func in "detection_output_sort_gpu"
                    0,      // num_images
                    0,      // num_classes
                    0,      // keep_top_k
                    false,  // share_location
                    0,      // top_k
                    -1,     // background_label_id
                    detect_out_prim->output_padding);

                p.get_or_create(detect_out_sort_prim);

                auto sort_node = p.nodes_map.find(detect_out_node_name)->second;

                // Add a connection to the second part of detection output
                if (node.get_users().size())
                {
                    p.add_intermediate(*sort_node, *(node.get_users().front()), 0, false);
                }
                else
                {
                    p.add_connection(node, *sort_node);
                }
            }
        }
    }

    void graph_initializations::handle_lstm(program_impl& p)
    {
        bool has_lstm_children = false;
        auto itr = p.nodes_map.begin(); //note: we need to use iterators, since the currently processed element can be removed
        while (itr != p.nodes_map.end())
        {
            auto node_itr = itr++;
            auto& node = (*node_itr).second;
            has_lstm_children = false;
            // replace the lstm node with lstm_gemm and lstm_elt nodes
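            // Decomposition sketch: each timestep becomes an lstm_gemm node (input and
            // recurrent GEMMs plus bias) feeding an lstm_elt node (gate non-linearities and
            // cell-state update). Hidden and cell states are cropped out of each lstm_elt
            // output and chained into the next timestep; the final hidden outputs are
            // concatenated and, in the bidirectional case, reshaped.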
            if (node->is_type<lstm>()) {
                bool initial_hidden_term = node->as<lstm>().initial_hidden_term();
                bool initial_cell_term = node->as<lstm>().initial_cell_term();
                bool bias_term = node->as<lstm>().bias_term();
                auto lstm_prim = node->as<lstm>().typed_desc();
                primitive_id weights_id = lstm_prim->weights;
                primitive_id recurrent_id = lstm_prim->recurrent;
                primitive_id bias_id = bias_term ? lstm_prim->bias : "";
                primitive_id initial_hidden_id = initial_hidden_term ? lstm_prim->initial_hidden : "";
                primitive_id initial_cell_id = initial_cell_term ? lstm_prim->initial_cell : "";

                //removing the connections with weights to get a proper dependency order for the next operations
                p.remove_connection(*p.nodes_map.at(weights_id), *node);
                p.remove_connection(*p.nodes_map.at(recurrent_id), *node);
                if (bias_term)
                    p.remove_connection(*p.nodes_map.at(bias_id), *node);
                if (initial_hidden_term)
                    p.remove_connection(*p.nodes_map.at(initial_hidden_id), *node);
                if (initial_cell_term)
                    p.remove_connection(*p.nodes_map.at(initial_cell_id), *node);

                //calculating sizes
                auto input_size = node->get_dependency(0).get_output_layout().size;
                auto recurrent_size = p.nodes_map.at(recurrent_id)->get_output_layout().size;

                // hidden tensor size = [batch, seq, hidden_size, direction]
                // the output of the element-wise operation is cropped and used in the next time step,
                // so here seq = 1 and direction = 1; the backward pass is separated from the forward pass
                auto hidden_size = tensor(input_size.batch[0], 1, recurrent_size.spatial[0], 1);

                size_t directions = recurrent_size.feature[0];
                size_t input_directions = input_size.spatial[1];
                size_t num_input_dependencies = node->get_dependencies().size();
                size_t input_vector_size = node->as<lstm>().sequence_len();
                size_t sequence_len = input_vector_size;

                // Calculate the input sequence length for the lstm node.
                // Case 1: the input comes in as a concatenated input, i.e. the
                // input is not divided into sequence elements
                if (input_vector_size == 1 && num_input_dependencies == 1)
                {
                    // The input may genuinely have a single sequence element
                    auto& input = node->get_dependency(0);
                    auto input_layout = input.get_output_layout();
                    tensor input_tensor = input_layout.size;

                    // Get the sequence length from the input to LSTM
                    sequence_len = input_layout.size.feature[0];

                    // If the input's feature/sequence length field is > 1, i.e. if
                    // the sequence elements are concatenated into one single input,
                    // then it has to be split into individual sequence elements
                    if (sequence_len > 1)
                    {
                        for (size_t sequence_element = 0; sequence_element < sequence_len; sequence_element++)
                        {
                            primitive_id crop_id = input.id() + ":crop:" + get_id_string(sequence_element);
                            tensor crop_tensor{ input_tensor.batch[0], 1, input_tensor.spatial[0], input_tensor.spatial[1] };
                            tensor offset_tensor{ 0, static_cast<tensor::value_type>(sequence_element), 0, 0 };
                            auto input_crop = std::make_shared<crop>(crop_id, input.id(), crop_tensor, offset_tensor);
                            auto& input_crop_node = p.get_or_create(input_crop);

                            // Add the crop nodes as users of the input
                            p.add_connection(node->get_dependency(0), input_crop_node);

                            // Connect crop with lstm
                            p.add_connection(input_crop_node, *node);
                        }

                        // We now have the sequence elements (cropped inputs) as input to LSTM.
                        // The original input is no longer a dependency of LSTM, so remove it.
                        p.remove_connection(node->get_dependency(0), *node);

                        // Update the total number of input dependencies
                        num_input_dependencies = node->get_dependencies().size();
                    }
                }

                //if the sequence has a single element but multiple inputs, then
                //the parent of this lstm is an lstm node; if this is a bidirectional lstm,
                //then the sequence length is the number of dependencies divided by 2
                else if (input_vector_size == 1 && num_input_dependencies > 1)
                {
                    sequence_len = (directions == 1) ? num_input_dependencies : num_input_dependencies / 2;
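                    // e.g. a bidirectional (directions == 2) parent providing 8 dependencies
                    // corresponds to a sequence of length 4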
                }

                //check if this lstm node has an lstm child
                for (auto& user : node->get_users())
                {
                    if (user->is_type<lstm>())
                    {
                        has_lstm_children = true;
                    }
                }

                bool emit_last_cell = lstm_prim->output_selection == cldnn_lstm_output_hidden_cell ||
                    lstm_prim->output_selection == cldnn_lstm_output_sequence_cell;
                bool emit_sequence = lstm_prim->output_selection == cldnn_lstm_output_sequence_cell ||
                    lstm_prim->output_selection == cldnn_lstm_output_sequence;

                std::vector<program_node*> cell_list(directions * sequence_len);
                std::vector<program_node*> hidden_list(directions * sequence_len);
                std::map<size_t, std::pair<primitive_id, program_node*>> output_map;
                auto dependencies = node->get_dependencies();

                //lstm expansion
                for (size_t dir = 0; dir < directions; ++dir) {
                    auto hidden_id = initial_hidden_id;
                    auto cell_id = initial_cell_id;
                    for (size_t i = 0; i < sequence_len; ++i) {
                        size_t idx = i + dir * sequence_len;
                        primitive_id lstm_gemm_id = node->id() + ":lstm_gemm" + get_id_string(idx);
                        primitive_id lstm_elt_id = node->id() + ":lstm_elt" + get_id_string(idx);
                        primitive_id crop_id = node->id() + ":crop" + get_id_string(idx);

                        size_t input_idx = i;
                        //for bidirectional lstms: if this is the first LSTM layer, then reverse the input;
                        //for subsequent stacked layers the input is strided on the dir dimension
                        if (directions > 0) {
                            if (num_input_dependencies > sequence_len) { // stacked layer
                                input_idx = dir * sequence_len + i;
                            }
                            else
                            {
                                if ((input_directions < 2) && dir > 0) { // first layer
                                    input_idx = sequence_len - i - 1;
                                }
                            }
                        }

                        //primitive_id lstm_gemm_input_id = node->get_dependency(input_idx).get_primitive()->id;
                        //the line below requires attention: get_org_primitive_id() might not be an actual id of a node (see the rename method)
                        //ToDo: ensure that get_org_primitive_id() is suitable here
                        primitive_id lstm_gemm_input_id = node->get_dependency(input_idx).get_org_primitive_id();

                        auto lstm_gemm_node = std::make_shared<lstm_gemm>(lstm_gemm_id, lstm_gemm_input_id, weights_id, recurrent_id, bias_id, hidden_id, (uint32_t)dir);
                        auto &n1 = p.get_or_create(lstm_gemm_node);

                        auto lstm_elt_node = std::make_shared<lstm_elt>(lstm_elt_id, lstm_gemm_id, cell_id, lstm_prim->clip, lstm_prim->input_forget,
                            lstm_prim->activations, lstm_prim->activation_params, lstm_prim->offset_order, (uint32_t)dir);
                        auto &n2 = p.get_or_create(lstm_elt_node);
                        //adding lstm_elt as a user
                        p.add_connection(n1, n2);
                        //adding dependencies to the lstm_gemm node
                        //input
                        p.add_connection(node->get_dependency(input_idx), n1);
                        //adding weights and initial values to lstm_gemm
                        p.add_connection(*p.nodes_map.at(weights_id), n1);
                        p.add_connection(*p.nodes_map.at(recurrent_id), n1);
                        if (bias_term)
                            p.add_connection(*p.nodes_map.at(bias_id), n1);

                        //adding cell and hidden as dependencies
                        if (i > 0)
                        {
                            p.add_connection(*cell_list[size_t(i - 1) * directions + dir], n2);
                            p.add_connection(*hidden_list[size_t(i - 1) * directions + dir], n1);
                        }
                        //if initial values are present
                        else
                        {
                            if (initial_hidden_term)
                                p.add_connection(*p.nodes_map.at(hidden_id), n1);
                            if (initial_cell_term)
                                p.add_connection(*p.nodes_map.at(cell_id), n2);
                        }

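                        // lstm_elt emits the hidden and cell states stacked along the feature
                        // axis, so hidden is cropped at feature offset 0 and cell at offset 1.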
                        //lstm_hidden
                        {
                            hidden_id = crop_id + ":hidden";
                            auto crop_hidden = std::make_shared<crop>(hidden_id, lstm_elt_id, hidden_size, tensor{ 0,0,0,0 });
                            auto &n3 = p.get_or_create(crop_hidden);
                            //adding the eltwise node as a dependency of the hidden crop
                            p.add_connection(n2, n3);

                            //if a user of this lstm is itself an lstm, add the hidden crop as its dependency
                            if (has_lstm_children)
                            {
                                for (auto& user : node->get_users())
                                {
                                    p.add_connection(n3, *user);
                                }
                            }
                            hidden_list[i * directions + dir] = &n3;
                            if (i == sequence_len - 1 || emit_sequence)
                            {
                                output_map[i * directions + dir] = { hidden_id, &n3 };
                            }
                        }

                        //lstm_cell
                        if (i < sequence_len - 1 || emit_last_cell)
                        {
                            cell_id = crop_id + ":cell";
                            auto crop_cell = std::make_shared<crop>(cell_id, lstm_elt_id, hidden_size, tensor{ 0,1,0,0 });
                            auto &n4 = p.get_or_create(crop_cell);
                            p.add_connection(n2, n4);
                            cell_list[i * directions + dir] = &n4;
                            if (i == sequence_len - 1)
                            {
                                output_map[sequence_len * directions + dir] = { cell_id, &n4 };
                            }
                        }
                    }
                }
                //if there is no next lstm, a concatenation is created
                if (!has_lstm_children)
                {
                    std::vector<primitive_id> output_ids_offsets;
                    for (auto& e : output_map)
                    {
                        output_ids_offsets.push_back(e.second.first);
                    }
                    primitive_id original_id = node->id();
                    primitive_id concatenation_id = original_id + ":concat";
                    auto concatenation_primitive = std::make_shared<concatenation>(concatenation_id, output_ids_offsets, concatenation::along_f);
                    auto &concatenation_node = p.get_or_create(concatenation_primitive);
                    for (auto& e : output_map)
                    {
                        p.add_connection(*e.second.second, concatenation_node);
                    }
                    if (directions == 2) {
                        // bidirectional support requires concatenations along the direction and sequence axes;
                        // instead we can concatenate along the sequence axis and reshape the tensor to
                        // account for the direction
                        size_t concatenate_len = emit_sequence ? sequence_len : 1;
                        if (emit_last_cell) concatenate_len++;

                        tensor output_size{ input_size.batch[0], static_cast<int32_t>(concatenate_len), hidden_size.spatial[0], (int32_t)directions };
                        primitive_id reshape_id = original_id + ":reshape";
                        auto reshape_primitive = std::make_shared<reshape>(reshape_id, concatenation_id, output_size);
                        auto &reshape_node = p.get_or_create(reshape_primitive);
                        p.add_connection(concatenation_node, reshape_node);
                        p.replace_all_usages(*node, reshape_node);
                    }
                    else
                    {
                        p.replace_all_usages(*node, concatenation_node);
                    }
                }
                //removing the expanded node
                p.remove_all_connections(*node);
                p.nodes_map.erase(node->id());
                continue;
            }
        }
    }

    void graph_initializations::set_outputs(program_impl& p)
    {
        auto outputs_option = p.get_options().get<build_option_type::outputs>();
        if (!outputs_option->outputs.empty())
        {
            for (auto const& output : outputs_option->outputs)
            {
                auto o_node = p.nodes_map.at(output);
                o_node->set_output(true);
                p.outputs.push_back(o_node.get());
            }
        }
        else
        {
            for (auto& node : p.nodes_map)
                if (node.second->is_endpoint())
                {
                    node.second->set_output(true);
                    p.outputs.push_back(node.second.get());
                }
        }
    }

    void graph_initializations::run(program_impl& p)
    {
        replace_nodes(p);
        handle_detection_output(p);
        handle_lstm(p);
        set_outputs(p);
        p.get_processing_order().calc_processing_order(p);
    }
}