Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / program.cpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18
19 #include "error_handler.h"
20 #include "kernel_selector_helper.h"
21 #include "internal_primitive.h"
22 #include "internal_primitive_type_base.h"
23 #include "layout_optimizer.h"
24 #include "pass_manager.h"
25 #include "primitive_type.h"
26 #include "program_dump_graph.h"
27 #include "program_helpers.h"
28 #include "program_impl.h"
29 #include "sliding_window_utils.h"
30
31 #include "convolution_inst.h"
32 #include "concatenation_inst.h"
33 #include "crop_inst.h"
34 #include "data_inst.h"
35 #include "deconvolution_inst.h"
36 #include "detection_output_inst.h"
37 #include "input_layout_inst.h"
38 #include "lstm_inst.h"
39 #include "lstm_elt_inst.h"
40 #include "lstm_gemm_inst.h"
41 #include "mutable_data_inst.h"
42 #include "pooling_inst.h"
43 #include "primitive_inst.h"
44 #include "prior_box_inst.h"
45 #include "proposal_inst.h"
46 #include "reorder_inst.h"
47 #include "reshape_inst.h"
48 #include "split_inst.h"
49
50 #include "gpu/ocl_toolkit.h"
51
52 #include <fstream>
53 #include <algorithm>
54 #include <stdio.h>
55 #include <iostream>
56 #include <sstream>
57 #include <iomanip>
58 #include <memory>
59
60 program_impl::program_impl(engine_impl& engine_ref, topology_impl const& topology, build_options const& options, bool is_internal, bool no_optimizations)
61     : engine(&engine_ref), options(options), processing_order(* new nodes_ordering), pm(std::unique_ptr<pass_manager>(new pass_manager()))
62 {
63     set_options();
64     prepare_nodes(topology);
65     if (no_optimizations)
66         init_graph();
67     else
68         build_program(is_internal);
69 }
70
71 program_impl::program_impl(engine_impl& engine_ref, std::set<std::shared_ptr<program_node>> const& nodes, build_options const& options, bool is_internal)
72     : engine(&engine_ref), options(options), processing_order(*new nodes_ordering), pm(std::unique_ptr<pass_manager>(new pass_manager()))
73 {
74     set_options();
75     prepare_nodes(nodes);
76     build_program(is_internal);
77 }
78
79 program_impl::~program_impl() = default;
80
81 program_node& program_impl::get_node(primitive_id const& id)
82 {
83     try
84     {
85         return *nodes_map.at(id);
86     }
87     catch (...)
88     {
89         throw std::runtime_error("Program doesn't contain primtive node: " + id);
90     }
91 }
92
93 program_node const& program_impl::get_node(primitive_id const& id) const
94 {
95     try
96     {
97         return *nodes_map.at(id);
98     }
99     catch (...)
100     {
101         throw std::runtime_error("Program doesn't contain primtive node: " + id);
102     }
103 }
104
105 // TODO: Remove once we will get full support for input/output padding in all primitive implementations.
106 bool program_impl::analyze_output_size_handling_need()
107 {
108     bool handling_needed = false;
109
110     // Calculate output size and compare with specified.
111     for (const auto& node : processing_order)
112     {
113         if (node->is_type<convolution>())
114         {
115             auto& prim_node = node->as<convolution>();
116             const auto& prim = prim_node.get_primitive();
117
118             if (!prim->with_output_size)
119                 continue;
120
121             tensor specified_output_range({ 0, 0, prim->output_size.spatial[0], prim->output_size.spatial[1] }, 1);
122
123             auto filter_size = prim_node.weights(0).get_output_layout().size;
124
125             auto calc_output_range = calc_sliding_window_output_range<swor_mode::all>(
126                 prim_node.input().get_output_layout().size,
127                 filter_size, prim->input_offset, prim->stride, prim->dilation, true, 1);
128
129             if (specified_output_range != calc_output_range)
130                 handling_needed = true;
131         }
132         else if (node->is_type<deconvolution>())
133         {
134             auto& prim_node = node->as<deconvolution>();
135             const auto& prim = prim_node.get_primitive();
136
137             if (!prim->with_output_size)
138                 continue;
139
140             tensor specified_output_range({ 0, 0, prim->output_size.spatial[0], prim->output_size.spatial[1] }, 1);
141
142             auto filter_size = prim_node.weights(0).get_output_layout().size;
143
144             auto calc_output_range = calc_sliding_window_needed_input_range(
145                 prim_node.input().get_output_layout().size,
146                 filter_size, prim->input_offset, prim->stride, { 1, 1, 1, 1 }, true, 1);
147
148             if (specified_output_range != calc_output_range)
149                 handling_needed = true;
150         }
151         else if (node->is_type<pooling>())
152         {
153             auto& prim_node = node->as<pooling>();
154             const auto& prim = prim_node.get_primitive();
155
156             if (!prim->with_output_size)
157                 continue;
158
159             tensor specified_output_range({ 0, 0, prim->output_size.spatial[0], prim->output_size.spatial[1] }, 1);
160
161             // TODO: Check compatibility of output size calculation (with caffe).
162             auto calc_output_range = calc_sliding_window_output_range<swor_mode::exceed_once_data>(
163                 prim_node.input().get_output_layout().size,
164                 prim->size, prim->input_offset, prim->stride, { 1, 1, 1, 1 }, true, 1);
165
166             if (specified_output_range != calc_output_range)
167                 handling_needed = true;
168         }
169     }
170
171     return handling_needed;
172 }
173
174 // create new nodes for a program based on the set of nodes
175 // method created to be used by propagate_constants to build sub program from constant nodes 
176 void program_impl::prepare_nodes(std::set<std::shared_ptr<program_node>>const &nodes)
177 {
178     for (const auto& itr : nodes)
179     {
180         if (itr.get()->is_type<data>())
181         {
182             get_or_create(
183                 std::make_shared<input_layout>(itr.get()->id(), itr.get()->as<data>().get_primitive()->mem.get_layout())
184             );
185         }
186         else
187         {
188             get_or_create(itr->desc);
189         }
190     }
191     for (const auto& node : nodes_map)
192     {
193         auto node_ptr = node.second;
194         if (node_ptr == nullptr)
195             throw error("NULL pointer in nodes_map.", CLDNN_ERROR);
196         //ToDo: avoid O(n^2) run time here (pass map instead of set?)
197         bool found = false;
198         for (const auto& src_node : nodes)
199         {
200             if (src_node == nullptr)
201                 throw error("NULL pointer in nodes_map.", CLDNN_ERROR);
202             if (node.first == src_node->get_primitive()->id)
203             {
204                 copy_node_dependencies(node_ptr.get(), src_node.get());
205                 found = true;
206                 break;
207             }
208         }
209         if (!found)
210         {
211             add_node_dependencies(node_ptr.get());
212         }
213         if (node_ptr->dependencies.size() == 0)
214             inputs.push_back(node_ptr.get());
215     }
216 }
217
218 // create all nodes from topology primitives, add dependencies among them and create inputs list
219 void program_impl::prepare_nodes(topology_impl const &topology)
220 {
221     auto const& topo_map = topology.get_primitives();
222     for (const auto& prim : topo_map)
223     {
224         get_or_create(prim.second);
225     }
226     add_split_outputs();
227     for (const auto& node : nodes_map)
228     {
229         auto node_ptr = node.second.get();
230         if (node_ptr == nullptr)
231             throw error("NULL pointer in nodes_map.", CLDNN_ERROR);
232         add_node_dependencies(node_ptr);
233         if (node_ptr->dependencies.size()==0)
234         {
235             inputs.push_back(node_ptr);
236         }
237     }
238 }
239
240 // add node's dependecies from its primitive dependencies
241 void program_impl::add_node_dependencies(program_node* node)
242 {
243     auto deps = node->get_primitive()->dependencies();
244     //add pointers to node's dependencies
245     for (auto& dep : deps)
246     {
247         try {
248             auto dep_node = nodes_map.at(dep);
249             node->dependencies.push_back(dep_node.get());
250             dep_node->users.push_back(node);
251         }
252         catch (...) {
253             throw std::runtime_error("Program doesn't contain primitive: " + dep +
254                 " that is input to: " + node->get_primitive()->id);
255         }
256     }
257 }
258
259 /* helper method for program_impl constructor from list of nodes which
260    copies src_node dependecies to the destination node dest_node dependencies.
261    But only to those which appaer in this program implementation nodes_map */
262 void program_impl::copy_node_dependencies(program_node* dest_node, program_node* src_node)
263 {
264     if (dest_node->get_primitive()->id != src_node->get_primitive()->id)
265     {
266         throw std::runtime_error("Node " + src_node->get_primitive()->id +  " and its copy " + dest_node->get_primitive()->id + " do not match.");
267     }
268     auto src_deps = src_node->get_dependencies();
269     //add pointers to node's dependencies
270     for (auto& src_dep : src_deps)
271     {
272         // do not copy dependencies to nodes which does not belong to the new (subgraph) topology
273         if (nodes_map.find(src_dep->get_primitive()->id) == nodes_map.end()) continue;
274
275         try {
276             auto dest_dep = nodes_map.at(src_dep->get_primitive()->id);
277             dest_node->dependencies.push_back(dest_dep.get());
278             dest_dep->users.push_back(dest_node);
279         }
280         catch (...) {
281             throw std::runtime_error("Program doesn't contain primitive: " + src_dep->get_primitive()->id +
282                 " that is input to: " + src_node->get_primitive()->id);
283         }
284     }
285 }
286
287 void program_impl::set_options()
288 {
289     static std::atomic<uint32_t> id_gen{ 0 };
290     prog_id = ++id_gen;
291     assert(prog_id != 0);
292
293     if ((options.get<build_option_type::tuning_config>()->config.mode == tuning_mode::tuning_tune_and_cache) &&
294         !engine->configuration().enable_profiling)
295     {
296         throw std::invalid_argument("Engine must be created with profiling enabled in tune_and_cache mode!");
297     }
298 }
299
300 void program_impl::build_program(bool is_internal)
301 {
302     init_graph();
303     {
304         pre_optimize_graph(is_internal);
305     }
306     run_graph_compilation();
307     {
308         post_optimize_graph(is_internal);
309     }
310     engine->compile_program(*this);
311     this->dump_program("finished", true);
312     cleanup();
313 }
314
315 void program_impl::init_graph()
316 {
317     graph_initializations graph_initializations_pass;
318     pm->run(*this, graph_initializations_pass);
319
320     calculate_prior_boxes calculate_prior_boxes_pass;
321     pm->run(*this, calculate_prior_boxes_pass);
322     
323     mark_nodes mark_nodes_pass;
324     pm->run(*this, mark_nodes_pass);
325 }
326
327 void program_impl::run_graph_compilation() {
328     compile_graph compile_graph_pass;
329     pm->run(*this, compile_graph_pass);
330 }
331
332 void program_impl::pre_optimize_graph(bool is_internal)
333 {
334     trim_to_outputs trim_pass; //trim to outputs
335     pm->run(*this, trim_pass); // ToDo remove hidden dependencies from trimm pass
336
337     handle_input_padding handle_input_padding; // handle symmetric and asymmetric padding for input
338     pm->run(*this, handle_input_padding);
339
340     add_reshape_to_primitives add_reshape_to_primitives_pass; // add reshape to input/parameters for some primitives
341     pm->run(*this, add_reshape_to_primitives_pass);
342
343     processing_order.calculate_BFS_processing_order(); // this method makes sense only for OOOQ (out of order execution queue)
344
345     bool output_size_handling_enabled = analyze_output_size_handling_need();
346     for (auto& node : processing_order)
347     {
348         if (!node->is_type<internal_primitive>() && !node->is_type<data>())
349             node->get_output_layout();
350     }
351
352     if (options.get<build_option_type::optimize_data>()->enabled())
353     {
354         prepare_primitive_fusing prepare_primitive_fusing_pass;
355         pm->run(*this, prepare_primitive_fusing_pass);
356
357         layout_optimizer lo(output_size_handling_enabled);
358         reorder_inputs reorder_inputs_pass(lo);
359         pm->run(*this, reorder_inputs_pass);
360
361         // this code should be moved to post compilation after kernel selector will support handling reorder bias
362         pre_optimize_bias pre_optimize_bias_pass(lo);
363         pm->run(*this, pre_optimize_bias_pass);
364
365         // passes regarding conv + eltwise optimizations
366
367         // shrinking eltwise if users are conv 1x1 with stride > 1 optimization
368         eltwise_shrinking eltwise_shrinking_pass;
369         pm->run(*this, eltwise_shrinking_pass);
370
371         // trying to set stride to 1x1 by shrinking convolutions before eltwise if doable
372         eltwise_remove_stride eltwise_remove_stride_pass;
373         pm->run(*this, eltwise_remove_stride_pass);
374
375         prepare_conv_eltw_fusing prepare_conv_eltw_fusing_pass;
376         pm->run(*this, prepare_conv_eltw_fusing_pass);
377
378         prepare_conv_eltw_read_write_opt prepare_conv_eltw_read_write_opt_pass;
379         pm->run(*this, prepare_conv_eltw_read_write_opt_pass);
380     }
381
382     handle_reshape();
383
384     remove_redundant_reorders remove_redundant_reorders_pass;
385     pm->run(*this, remove_redundant_reorders_pass);
386
387     prepare_padding prepare_padding_pass(output_size_handling_enabled);
388     pm->run(*this, prepare_padding_pass);
389
390     prepare_depthwise_sep_opt prepare_depthwise_sep_opt_pass;
391     pm->run(*this, prepare_depthwise_sep_opt_pass);
392
393     if (!is_internal)
394     {
395         propagate_constants propagate_constants_pass;  // ToDo remove hidden dependencies from propagate_constants pass
396         pm->run(*this, propagate_constants_pass);
397     }
398
399     //try to fuse buffers (i.e. depth_concat in bfyx format) after padding calculations
400     if (options.get<build_option_type::optimize_data>()->enabled())
401     {
402         prepare_buffer_fusing prepare_buffer_fusing_pass;
403         pm->run(*this, prepare_buffer_fusing_pass);
404     }
405
406     //check if there exists some layout incompatibilities and add an reorder node if required
407     add_required_reorders add_required_reorders_pass;
408     pm->run(*this, add_required_reorders_pass);
409 }
410
411 void program_impl::post_optimize_graph(bool is_internal)
412 {
413     layout_optimizer lo;
414     post_optimize_weights post_optimize_weights_pass(lo);
415     pm->run(*this, post_optimize_weights_pass);
416
417     remove_redundant_reorders remove_redundant_reorders_pass;
418     pm->run(*this, remove_redundant_reorders_pass); //TODO: do we need it at this place also?
419
420     if (!is_internal)
421     {
422         propagate_constants propagate_constants_pass;  // ToDo remove hidden dependencies from propagate_constants pass
423         pm->run(*this, propagate_constants_pass);
424     }
425
426     prep_opt_depthwise_sep_post prep_opt_depthwise_sep_post_pass;
427     pm->run(*this, prep_opt_depthwise_sep_post_pass);
428    
429     prepare_memory_dependencies();
430 }
431
432 // mark if the node is constant assuming that all dependencies are marked properly
433 void program_impl::mark_if_constant(program_node& node) 
434 {
435     if (node.get_dependencies().empty())
436         return;
437     if (node.is_type<prior_box>())
438         return;
439     node.constant = true;
440     for (auto& dep : node.get_dependencies())
441     {
442         if (!dep->constant)
443         {
444             node.constant = false;
445             break;
446         }
447     }
448 }
449
450 // mark if the node is in data flow assuming that all dependencies are marked properly
451 void program_impl::mark_if_data_flow(program_node& node) 
452 {
453     if (node.is_type<mutable_data>() || node.is_type<input_layout>())
454     {
455         node.data_flow = true;
456     }
457     else
458     {
459         node.data_flow = false;
460         size_t inputs_count = node.get_dependencies().size();
461         if (node.is_type<detection_output>() || node.is_type<proposal>())
462             inputs_count = 2; //ignore third input as it is related to prior boxes (i.e. concat of prior-boxes)
463         for (size_t idx = 0; idx < inputs_count; idx++)
464         {
465             if (node.get_dependency(idx).is_in_data_flow())
466             {
467                 node.data_flow = true;
468                 return;
469             }
470         }
471     }
472 }
473 void program_impl::cleanup()
474 {
475     for (auto& node : processing_order)
476         if (!node->is_type<internal_primitive>())
477             node->get_output_layout();
478
479     //in debug build, at the end, mark all nodes as outputs so user can query for buffers of all not-optimized nodes, including internal ones etc.
480     if (is_debug_build())
481     {
482         for (auto& node : processing_order)
483         {
484             if (!node->is_output())
485             {
486                 node->set_output(true);
487                 outputs.push_back(node);
488             }
489         }
490     }
491 }
492
493 void program_impl::add_split_outputs() {
494     auto itr = nodes_map.begin();
495     while (itr != nodes_map.end())
496     {
497         auto node_itr = itr++;
498         auto& node = (*node_itr).second;
499
500         if (node->is_type<split>())
501         {
502             auto split_prim = node->as<split>().typed_desc();
503             primitive_id input_id = split_prim->input[0];
504             auto split_num = split_prim->output_offsets.size();
505
506             //create crop for each split ouptut provided
507             for (decltype(split_num) i = 0; i < split_num; i++)
508             {
509                 primitive_id output_id = node->id() + ":" + split_prim->output_ids[i];
510
511                 //create dummy crop primitive and add it to nodes map
512                 auto crop_prim = std::make_shared<crop>(output_id, input_id, tensor{ 1,1,1,1 }, split_prim->output_offsets[i]);
513                 get_or_create(crop_prim);
514             }
515         }
516     }
517 }
518
519 program_impl::nodes_ordering& program_impl::get_processing_order()
520 {
521     return processing_order;
522 }
523
524 const program_impl::nodes_ordering& program_impl::get_processing_order() const
525 {
526     return processing_order;
527 }
528
529 void add_memory_dependency(program_node* node, program_node* dep)
530 {
531     if (node->can_be_optimized() ||
532         !dep->can_be_optimized())
533     {
534         node->add_memory_dependency(dep->id());
535     }
536     else
537     {
538         if (node->id() == dep->id())
539         {
540             return;
541         }
542         for (auto subdep : dep->get_dependencies())
543         {
544             add_memory_dependency(node, subdep);
545             add_memory_dependency(subdep, node);
546         }
547     }
548 }
549
550 void program_impl::basic_memory_dependencies()
551 {
552     auto itr = processing_order.begin();
553     std::vector<primitive_id> past_outputs;
554     while (itr != processing_order.end())
555     {
556         auto& node = *itr;
557         itr++;
558
559         //data primitive can't be reused
560         if (node->is_type<data>())
561             continue;
562
563         // add my dependencies to restriction list (can't share input.output buffers)
564         for (auto it : node->get_dependencies())
565         {
566             add_memory_dependency(node, it);
567             add_memory_dependency(it, node);
568         }
569
570         // Note we iterate over processing order, it means if primitve has processing num greater than any of outputs, this output
571         // has to land on the primitve restriction list. Otherwise memory reuse can corrupt final results.
572         node->add_memory_dependency(past_outputs);
573         // if current node is an output add it to the outputs list after restriction.
574         if (node->is_output())
575             past_outputs.push_back(node->id());
576     }
577 }
578
579
580 void program_impl::skipped_branch_memory_dependencies()
581 {
582     // Primitive A can't use primitive B buffer if processing_num(B) < processing_num(A) and for any usr - the user of B processing_num(usr) > processing_num(A)
583     // Otherwise it could override data that has to be used in the future.
584     auto itrB = processing_order.begin();
585     while (itrB != processing_order.end())
586     {
587         auto& nodeB = *itrB;
588         auto itrA = ++itrB;
589         if (nodeB->get_users().size()==0)
590             continue;
591
592         // find the last user of B in processing order
593         auto itrUsr = nodeB->get_users().begin();
594         auto lastUsr = itrUsr++;
595         while (itrUsr != nodeB->get_users().end())
596         {
597             if (processing_order.get_processing_number(*lastUsr) < processing_order.get_processing_number(*itrUsr))
598                 lastUsr = itrUsr;
599             itrUsr++;
600         }
601
602         //mark all nodes in between B and lastUsr of B as forbidden to share buffer with B
603         while (itrA != processing_order.get_processing_iterator(**lastUsr))
604         {
605             auto& nodeA = *itrA;
606             itrA++;
607             add_memory_dependency(nodeA, nodeB);
608             add_memory_dependency(nodeB, nodeA);
609         }
610     }
611 }
612
613 void program_impl::oooq_memory_dependencies()
614 {
615     auto itr = processing_order.begin();
616     // This order let us build dependencies based on syncing points.
617     // Set of nodes between two syncing points will be called sync_region.
618     // Major rules is: can't share resource with nodes in my sync_region
619
620     int32_t last_barrier = 0;
621     bool needs_barrier = false;
622     std::vector<cldnn::program_node*> sync_region;
623     while (itr != processing_order.end())
624     {
625         auto& node = *itr;
626         itr++;
627
628         // if any of dep has proccess num after barrier -> needs barrier
629         for (auto dep : node->get_dependencies())
630         {
631             if (processing_order.get_processing_number(dep) >= last_barrier)
632             {
633                 needs_barrier = true;
634                 break;
635             }
636         }
637
638         if (needs_barrier)
639         {
640             last_barrier = processing_order.get_processing_number(node);
641             needs_barrier = false;
642             // add each pair bi-direction dependency
643             for (auto nd1 = sync_region.begin(); nd1 + 1 != sync_region.end(); nd1++)
644             {
645                 for (auto nd2 = nd1 + 1; nd2 != sync_region.end(); nd2++)
646                 {
647                     add_memory_dependency(*nd1, *nd2);
648                     add_memory_dependency(*nd2, *nd1);
649                 }
650             }
651
652             // collect dependencies of every node in sync region
653             std::vector<cldnn::program_node*> deps;
654             for (auto& nd_in_region : sync_region)
655                 for (auto& dep : nd_in_region->get_dependencies())
656                     deps.emplace_back(dep);
657
658
659             for (auto& nd_in_region : sync_region)
660                 for (auto& dep : deps)
661                 {
662                     add_memory_dependency(nd_in_region, dep);
663                     add_memory_dependency(dep, nd_in_region);
664                 }
665
666             sync_region.clear();
667         }
668         sync_region.push_back(node);
669     }
670 }
671
672 void program_impl::prepare_memory_dependencies()
673 {
674     if (!get_engine().configuration().enable_memory_pool)
675         return;
676
677     basic_memory_dependencies();
678     skipped_branch_memory_dependencies();
679     oooq_memory_dependencies();
680 }
681
682 std::string program_impl::get_memory_dependencies_string() const
683 {
684     std::string mem_dep = "Memory dependencies/restrictions:\n";
685     auto itr = processing_order.begin();
686     while (itr != processing_order.end())
687     {
688         auto& node = *itr;
689         itr++;
690         mem_dep = mem_dep.append("primitive: ").append(node->id()).append(" restricted list: ");
691         for (auto it : node->get_memory_dependencies())
692             mem_dep == mem_dep.append(it).append(", ");
693         mem_dep = mem_dep.append("\n");
694     }
695     return mem_dep;
696 }
697
698 void program_impl::handle_reshape()
699 {
700     //reshape primitive by definition does not change underlying data, only shape description
701     //however during graph initialization and data optimization the layouts can be changed without user's knowledge,
702     //when reshape is followed by reorder, it is likely that reorder's output will not be as expected (for example reshape with flattened shape)
703     //this pass resolved the issue by changing graph in the following way
704     //- in case reshape has multiple users with reshape->reorder sequence, it will be splitted to multiple reshape primitives with single user
705     //- in case of reshape->reorder sequence, the additional reorder before reshape will be added,
706     //  if last reorder does not contain padding or mean subtract, it will be removed later in the graph
707
708     for (const auto& node : processing_order)
709     {
710         if (node->is_type<reshape>())
711         {
712             auto& input_node = node->get_dependency(0);
713
714             if (input_node.is_type<reorder>())
715                 continue;
716
717             node->get_output_layout();
718             if (node->as<reshape>().is_in_place())
719                 node->optimized = true;
720
721             //vector for storing nodes that are reorder type, for which splitted primitives are needed (except for the first one where orginal reshape will be used)
722             std::vector<program_node*> reorder_node_to_split;
723
724             //find the users of reshape that are reorder type, if none present then skip the current node
725             for (const auto& user : node->get_users())
726             {
727                 if (user->is_type<reorder>())
728                     reorder_node_to_split.push_back(user);
729             }
730
731             if (!reorder_node_to_split.empty())
732             {
733                 auto& prim_node = node->as<reshape>();
734                 const auto& prim = prim_node.get_primitive();
735                 auto output_shape = prim->output_shape;
736
737                 //vector for storing reshape nodes to connect to new reorder nodes (if needed)
738                 std::vector<program_node*> reorder_reshape_nodes;
739
740                 bool skip_first_user = false;
741                 auto reshape_users = node->get_users();
742                 for (const auto& user : reshape_users)
743                 {
744                     //reshape node for first user will be the orginal reshape from the graph
745                     if (!skip_first_user)
746                     {
747                         if (std::find(reorder_node_to_split.begin(), reorder_node_to_split.end(), user) != reorder_node_to_split.end())
748                             reorder_reshape_nodes.push_back(node);
749                         skip_first_user = true;
750                         continue;
751                     }
752
753                     //other reshapes will be clones of the orginal one connected to reshape->reorder sequences
754                     if (std::find(reorder_node_to_split.begin(), reorder_node_to_split.end(), user) != reorder_node_to_split.end())
755                     {
756                         auto new_reshape = std::make_shared<reshape>("_reshape_split_" + user->id() + "_" + node->id(), input_node.id(), output_shape);
757                         auto& new_reshape_node = get_or_create(new_reshape);
758                         add_connection(input_node, new_reshape_node);
759                         user->replace_dependency(0, new_reshape_node);
760                         processing_order.insert_next(&input_node, &new_reshape_node);
761                         reorder_reshape_nodes.push_back(&new_reshape_node);
762                     }
763                 }
764
765                 //add new reorder nodes to proper reshape node
766                 auto reshape_reorder_id = 0;
767                 for (const auto& reorder_node : reorder_node_to_split)
768                 {
769                     /*
770                     auto new_reshape = std::make_shared<reshape>("_reshape_split_" + user->id() + "_" + node->id(), input_node.id(), output_shape);
771                     auto& new_reshape_node = get_or_create(new_reshape);
772                     add_connection(input_node, new_reshape_node);
773                     user->replace_dependency(0, new_reshape_node);
774                     processing_order.insert(std::next(processing_order.get_processing_iterator(input_node)), &new_reshape_node);
775                     reorder_reshape_nodes.push_back(&new_reshape_node);
776                     */
777                     auto& reorder_reshape_node = reorder_reshape_nodes[reshape_reorder_id];
778                     auto reshape_in_layout = reorder_node->get_output_layout();
779                     auto reshape_input = std::make_shared<reorder>("_reshape_input_" + reorder_node->id() + "_" + reorder_reshape_node->id(), input_node.id(),
780                         reshape_in_layout.format, reshape_in_layout.data_type);
781                     auto& reshape_input_node = get_or_create(reshape_input);
782                     add_intermediate(reshape_input_node, *reorder_reshape_node, 0, reshape_input_node.dependencies.empty());
783                     reshape_reorder_id++;
784                 }
785             }
786
787             auto reshape_layout = node->get_output_layout();
788             if (!(node->is_output()) && (reshape_layout.format != cldnn::format::bfyx))
789             {
790                 auto bfyx_layout = layout({ reshape_layout.data_type, cldnn::format::bfyx, reshape_layout.size });
791                 //when some primitive does an implicit reorder to some other format then we lose the info about pitches in reshape stage
792                 //we assume user provides the input vector in bfyx
793                 if (!program_helpers::are_layouts_identical(reshape_layout, bfyx_layout).second)
794                 {
795                     auto reshape_input = std::make_shared<reorder>("_reshape_input_" + node->id(), input_node.id(), cldnn::format::bfyx, reshape_layout.data_type);
796                     auto& reshape_input_node = get_or_create(reshape_input);
797                     add_intermediate(reshape_input_node, *node, 0, reshape_input_node.dependencies.empty());
798
799                     auto reshape_users = node->get_users();
800                     for (const auto& user : reshape_users)
801                     {
802                         size_t idx = 0;
803                         for (size_t i = 0; i < user->get_dependencies().size(); i++)
804                         {
805                             auto& input = user->get_dependency(i);
806                             if (input.id() == node->id()) {
807                                 idx = i;
808                                 break;
809                             }
810                         }
811                         auto reshape_output = std::make_shared<reorder>("_reshape_output_" + node->id(), user->id(), reshape_layout.format, reshape_layout.data_type);
812                         auto& reshape_output_node = get_or_create(reshape_output);
813                         add_intermediate(reshape_output_node, *user, idx, reshape_output_node.dependencies.empty());
814                     }
815                 }
816             }
817         }
818     }
819 }
820
821 void program_impl::apply_needed_padding(program_node& node, program_node& prev_node,
822     const padding& needed_padding)
823 {
824     auto target_layout = prev_node.get_output_layout();
825
826     // Short circuit if padding did not change.
827     if (target_layout.data_padding == needed_padding)
828         return;
829
830     // Special handling for input nodes.
831     if (prev_node.is_type<input_layout>() || prev_node.is_type<mutable_data>())
832     {
833         target_layout.data_padding = needed_padding;
834
835         auto r_prim = std::make_shared<reorder>("reorder_input_" + node.id(), prev_node.id(), target_layout);
836         add_intermediate(r_prim, node, 0);
837         return;
838     }
839
840     prev_node.merge_output_padding(needed_padding);
841 }
842
843 void program_impl::reverse_connection(program_node& dep_node, program_node& user_node)
844 {
845     if (std::find(dep_node.users.begin(), dep_node.users.end(), &user_node) != dep_node.users.end())
846     {
847         remove_connection(dep_node, user_node);
848         add_connection(user_node, dep_node);
849     }
850     else
851         throw std::runtime_error("Trying to reverse connection, but nodes are wrongly or not connected.");
852 }
853
854 program_node& program_impl::get_or_create(std::shared_ptr<primitive> prim)
855 {
856     auto itr = nodes_map.lower_bound(prim->id);
857     if (itr != nodes_map.end() && itr->first == prim->id)
858         return *itr->second;
859
860     auto new_node = prim->type->create_node(*this, prim);
861     nodes_map.insert(itr, { prim->id, new_node });
862     return *new_node;
863 }
864
865 void program_impl::add_intermediate(program_node& node, program_node& next, size_t prev_idx,
866     bool connect_int_node_with_old_dep, bool move_usrs_of_prev_to_node)
867 {
868     if (connect_int_node_with_old_dep && !node.dependencies.empty())
869         throw std::invalid_argument("Node which is about to be added in between two other nodes should not have any existing dependencies");
870
871     auto& prev = next.get_dependency(prev_idx);
872     //firstly add connection, later replace dependency, so 'prev' won't become dangling and therefore removed
873     if (connect_int_node_with_old_dep)
874     {
875         add_connection(prev, node);
876         if (processing_order.size() != 0)
877         {
878             processing_order.insert_next(&prev, &node);
879         }
880     }
881
882     if (move_usrs_of_prev_to_node) {
883         auto itr = prev.get_users().begin();
884         while(itr!= prev.get_users().end())
885         {
886             auto usr = *itr;
887             itr++;
888             if (usr->id() != node.id())
889                 usr->replace_dependency(prev, node);
890         }
891         mark_if_constant(prev);
892         mark_if_constant(node);
893         mark_if_data_flow(prev);
894         mark_if_data_flow(node);
895     }
896     else {
897         next.replace_dependency(prev_idx, node);
898         node.constant = prev.constant;
899         node.data_flow = prev.data_flow;
900     }
901 }
902
903 void program_impl::add_intermediate(std::shared_ptr<primitive> prim, program_node& next, size_t prev_idx, 
904     bool connect_int_node_with_old_dep, bool move_usrs_of_prev_to_node)
905 {
906     add_intermediate(get_or_create(prim), next, prev_idx, connect_int_node_with_old_dep, move_usrs_of_prev_to_node);
907 }
908
909 void program_impl::add_connection(program_node& prev, program_node& next)
910 {
911     prev.users.push_back(&next);
912     next.dependencies.push_back(&prev);
913 }
914
915 void program_impl::remove_connection(program_node& prev, program_node& next)
916 {
917     prev.users.remove(&next);
918     next.dependencies.erase(std::remove(next.dependencies.begin(), next.dependencies.end(), &prev), next.dependencies.end());
919 }
920
921 void program_impl::remove_all_connections(program_node& node) {
922     // since the graph is not topological sorted, we need to remove the node from both dependencies and users
923     for (auto &e : node.users)
924     {
925         e->dependencies.erase(std::remove(e->dependencies.begin(), e->dependencies.end(), &node), e->dependencies.end());
926     }
927     for (auto &e : node.dependencies) 
928     {
929         e->users.remove(&node);
930     }
931     node.dependencies.clear();
932     node.users.clear();
933 }
934
935 void program_impl::rename(program_node & node, primitive_id const & new_id)
936 {
937     if (nodes_map.count(new_id))
938         throw std::runtime_error("Trying to rename program_node but node with id " + new_id + " already exists");
939     if (node.is_output())
940         throw std::invalid_argument("Trying to rename an output node. If you intend to do that, please clear 'output' flag manually.");
941
942     auto node_ptr = nodes_map.find(node.id())->second;
943     nodes_map.emplace(new_id, node_ptr);
944     nodes_map.erase(node.id());
945
946     if (!node.is_type<internal_primitive>())
947         const_cast<primitive_id&>(node.desc->id) = new_id;
948     else
949         reinterpret_cast<details::internal_program_node_base&>(node).internal_id = new_id;
950 }
951
952 void program_impl::swap_names(program_node& node1, program_node& node2)
953 {
954     const auto _extract_id = [](program_node& node) -> primitive_id&
955     {
956         if (!node.is_type<internal_primitive>())
957             return const_cast<primitive_id&>(node.desc->id);
958         else
959             return reinterpret_cast<details::internal_program_node_base&>(node).internal_id;
960     };
961
962     nodes_map.at(node1.id()).swap(nodes_map.at(node2.id()));
963     std::swap(_extract_id(node1), _extract_id(node2));
964 }
965
966 void program_impl::replace_all_usages(program_node & old_node, program_node & new_node)
967 {
968     auto itr = old_node.users.begin();
969     bool end = (itr == old_node.users.end());
970     while (!end)
971     {
972         auto& usage = (*itr++);
973         end = (itr == old_node.users.end());
974         usage->replace_dependency(old_node, new_node);
975     }
976 }
977
978 void program_impl::replace(program_node& old_node, program_node& new_node)
979 {
980     if (!new_node.dependencies.empty() || !new_node.users.empty())
981         throw std::invalid_argument("Node which is about to replace other node should be detached");
982
983     if (new_node.is_output())
984         throw std::invalid_argument("Replacement node shouldn't be marked as an output since it's impossible to rename such node.");
985
986     auto id = old_node.id();
987     new_node.output_layout = old_node.get_output_layout();
988     new_node.valid_output_layout = old_node.valid_output_layout;
989
990     
991     //copy old's dependencies
992     while (!old_node.dependencies.empty())
993     {
994         auto& dep = old_node.dependencies.front();
995         add_connection(*dep, new_node);
996         remove_connection(*dep, old_node);
997     }
998
999     //append users
1000     for (auto& user : old_node.users)
1001     {
1002         new_node.users.push_back(user);
1003         for (auto& users_dep : user->dependencies)
1004         {
1005             if (users_dep == &old_node)
1006             {
1007                 users_dep = &new_node;
1008                 break;
1009             }
1010         }
1011     }
1012
1013     old_node.users.clear();
1014
1015     bool old_was_output = false;
1016     //copy node's state
1017     if (old_node.is_output())
1018     {
1019         old_was_output = true;
1020         old_node.set_output(false);
1021         outputs.erase(std::remove(outputs.begin(), outputs.end(), &old_node), outputs.end());
1022     }
1023     if (new_node.is_input())
1024         inputs.push_back(&new_node);
1025     if (old_node.is_input())
1026         inputs.remove(&old_node);
1027
1028     new_node.constant = old_node.constant;
1029     new_node.user_mark = old_node.user_mark;
1030
1031     processing_order.insert(&old_node, &new_node);
1032     if (processing_order.get_processing_iterator(old_node) != processing_order.end())
1033         processing_order.erase(&old_node);
1034     nodes_map.erase(id);
1035     rename(new_node, id);
1036
1037     //mark new node as an output after renaming
1038     if (old_was_output)
1039     {
1040         new_node.set_output(true);
1041         outputs.push_back(&new_node);
1042     }
1043 }
1044
1045 bool program_impl::remove_if_dangling(program_node& node)
1046 {
1047     if (!node.users.empty())
1048         return false;
1049     if (!node.dependencies.empty())
1050         return false;
1051
1052     if (!node.is_output() || is_debug_build())
1053     {
1054         if (node.is_input())
1055             inputs.remove(&node);
1056
1057         if (std::find(processing_order.begin(), processing_order.end(), &node) != processing_order.end())
1058             processing_order.erase(&node);
1059         optimized_out.push_back(node.id());
1060         nodes_map.erase(node.id());
1061     }
1062     return true;
1063 }
1064
1065 bool program_impl::extract_and_remove(program_node& node)
1066 {
1067     if (node.get_dependencies().size() != 1)
1068         return false;
1069
1070     if (node.is_output() && node.get_dependency(0).is_output() && !is_debug_build()) //TODO: add a mechanism to support removal of nodes which are marked as outputs
1071         return false;
1072
1073     if (node.is_output() && !is_debug_build())
1074     {
1075         auto& prev = node.get_dependency(0);
1076         auto node_id = node.id();
1077
1078         node.set_output(false);
1079         outputs.erase(std::remove(outputs.begin(), outputs.end(), &node), outputs.end());
1080
1081         rename(node, "_cldnn_tmp_" + node_id);
1082         rename(prev, node_id);
1083
1084         prev.set_output(true);
1085         outputs.push_back(&prev);
1086     }
1087
1088     auto& input = node.get_dependency(0);
1089     node.dependencies.clear();
1090     input.users.remove(&node);
1091
1092     if (!node.is_endpoint())
1093         replace_all_usages(node, input);
1094     else
1095         remove_if_dangling(node);
1096
1097     return true;
1098 }
1099
1100 void program_impl::remove_nodes(std::list<program_node*>& to_remove)
1101 {
1102     for (auto const& node : to_remove)
1103     {
1104         if (node->is_input())
1105             get_inputs().remove(node);
1106         else
1107         {
1108             for (auto& dep : node->dependencies)
1109                 dep->users.remove(node);
1110         }
1111         for (auto& user : node->users)
1112         {
1113             user->dependencies.erase(std::remove(user->dependencies.begin(),
1114                 user->dependencies.end(), node),
1115                 user->dependencies.end());
1116         }
1117         get_processing_order().erase(node);
1118         optimized_out.push_back(node->id());
1119         nodes_map.erase(node->id());
1120     }
1121 }
1122
1123 void program_impl::dump_memory_pool() const
1124 {
1125     if (!get_engine().configuration().enable_memory_pool)
1126         return;
1127     auto path = get_dir_path(options);
1128     if (path.empty())
1129     {
1130         return;
1131     }
1132     path += "cldnn_memory_pool.log";
1133     auto dep = get_memory_dependencies_string();
1134     get_engine().dump_memory_pool(*this, path, dep);
1135     std::string dump_file_name = std::to_string(pm->get_pass_count()+1) + "_memory_pool";
1136     dump_program(dump_file_name.c_str(), true);
1137 }
1138
1139 //TODO: break this function into number of smaller ones + add per-primitive fields (possibly use primitive_inst::to_string?)
1140 void program_impl::dump_program(const char* stage, bool with_full_info, std::function<bool(program_node const&)> const& filter) const
1141 {
1142     std::string path = get_dir_path(options);
1143     if (path.empty())
1144     {
1145         return;
1146     }
1147
1148     std::ofstream graph(path + "cldnn_program_" + std::to_string(prog_id) + "_" + stage + ".graph");
1149     dump_graph_init(graph, *this, filter);
1150
1151     if (!with_full_info)
1152     {
1153         return;
1154     }
1155
1156     graph.open(path + "cldnn_program_" + std::to_string(prog_id) + "_" + stage + ".info");
1157     dump_graph_info(graph, *this, filter);
1158
1159     graph.open(path + "cldnn_program_" + std::to_string(prog_id) + "_" + stage + ".order");
1160     dump_graph_processing_order(graph, *this);
1161
1162     graph.open(path + "cldnn_program_" + std::to_string(prog_id) + "_" + stage + ".optimized");
1163     dump_graph_optimized(graph, *this);
1164 }
1165
1166