2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
19 #include "error_handler.h"
20 #include "kernel_selector_helper.h"
21 #include "internal_primitive.h"
22 #include "internal_primitive_type_base.h"
23 #include "layout_optimizer.h"
24 #include "pass_manager.h"
25 #include "primitive_type.h"
26 #include "program_dump_graph.h"
27 #include "program_helpers.h"
28 #include "program_impl.h"
29 #include "sliding_window_utils.h"
31 #include "convolution_inst.h"
32 #include "concatenation_inst.h"
33 #include "crop_inst.h"
34 #include "data_inst.h"
35 #include "deconvolution_inst.h"
36 #include "detection_output_inst.h"
37 #include "input_layout_inst.h"
38 #include "lstm_inst.h"
39 #include "lstm_elt_inst.h"
40 #include "lstm_gemm_inst.h"
41 #include "mutable_data_inst.h"
42 #include "pooling_inst.h"
43 #include "primitive_inst.h"
44 #include "prior_box_inst.h"
45 #include "proposal_inst.h"
46 #include "reorder_inst.h"
47 #include "reshape_inst.h"
48 #include "split_inst.h"
50 #include "gpu/ocl_toolkit.h"
// Builds a program from a user-provided topology: materializes program_nodes
// for every primitive, then runs the full optimization/build pipeline.
// NOTE(review): listing lines 65-67 are elided here (likely the
// 'no_optimizations' branch and/or set_options()) — confirm against full source.
// NOTE(review): 'processing_order(* new nodes_ordering)' — no matching delete
// is visible in this view; verify ownership/lifetime of the ordering object.
60 program_impl::program_impl(engine_impl& engine_ref, topology_impl const& topology, build_options const& options, bool is_internal, bool no_optimizations)
61 : engine(&engine_ref), options(options), processing_order(* new nodes_ordering), pm(std::unique_ptr<pass_manager>(new pass_manager()))
64 prepare_nodes(topology);
68 build_program(is_internal);
// Builds a (typically internal) sub-program directly from an existing set of
// program nodes — used by constant propagation to compile constant subgraphs.
// NOTE(review): listing lines 73-75 are elided (likely set_options() /
// prepare_nodes(nodes)) — confirm against full source.
71 program_impl::program_impl(engine_impl& engine_ref, std::set<std::shared_ptr<program_node>> const& nodes, build_options const& options, bool is_internal)
72 : engine(&engine_ref), options(options), processing_order(*new nodes_ordering), pm(std::unique_ptr<pass_manager>(new pass_manager()))
76 build_program(is_internal);
79 program_impl::~program_impl() = default;
81 program_node& program_impl::get_node(primitive_id const& id)
85 return *nodes_map.at(id);
89 throw std::runtime_error("Program doesn't contain primtive node: " + id);
93 program_node const& program_impl::get_node(primitive_id const& id) const
97 return *nodes_map.at(id);
101 throw std::runtime_error("Program doesn't contain primtive node: " + id);
105 // TODO: Remove once we will get full support for input/output padding in all primitive implementations.
// Returns true when at least one conv/deconv/pooling node's user-specified
// output size differs from the size computed from its sliding-window
// parameters — in that case explicit output-size handling must stay enabled.
106 bool program_impl::analyze_output_size_handling_need()
108 bool handling_needed = false;
110 // Calculate output size and compare with specified.
111 for (const auto& node : processing_order)
113 if (node->is_type<convolution>())
115 auto& prim_node = node->as<convolution>();
116 const auto& prim = prim_node.get_primitive();
// Only primitives with a user-specified output size can mismatch.
118 if (!prim->with_output_size)
121 tensor specified_output_range({ 0, 0, prim->output_size.spatial[0], prim->output_size.spatial[1] }, 1);
123 auto filter_size = prim_node.weights(0).get_output_layout().size;
// Convolution: expected output from input size, filter, offset, stride, dilation.
125 auto calc_output_range = calc_sliding_window_output_range<swor_mode::all>(
126 prim_node.input().get_output_layout().size,
127 filter_size, prim->input_offset, prim->stride, prim->dilation, true, 1);
129 if (specified_output_range != calc_output_range)
130 handling_needed = true;
132 else if (node->is_type<deconvolution>())
134 auto& prim_node = node->as<deconvolution>();
135 const auto& prim = prim_node.get_primitive();
137 if (!prim->with_output_size)
140 tensor specified_output_range({ 0, 0, prim->output_size.spatial[0], prim->output_size.spatial[1] }, 1);
142 auto filter_size = prim_node.weights(0).get_output_layout().size;
// Deconvolution uses the "needed input range" formula (inverse of conv).
144 auto calc_output_range = calc_sliding_window_needed_input_range(
145 prim_node.input().get_output_layout().size,
146 filter_size, prim->input_offset, prim->stride, { 1, 1, 1, 1 }, true, 1);
148 if (specified_output_range != calc_output_range)
149 handling_needed = true;
151 else if (node->is_type<pooling>())
153 auto& prim_node = node->as<pooling>();
154 const auto& prim = prim_node.get_primitive();
156 if (!prim->with_output_size)
159 tensor specified_output_range({ 0, 0, prim->output_size.spatial[0], prim->output_size.spatial[1] }, 1);
161 // TODO: Check compatibility of output size calculation (with caffe).
162 auto calc_output_range = calc_sliding_window_output_range<swor_mode::exceed_once_data>(
163 prim_node.input().get_output_layout().size,
164 prim->size, prim->input_offset, prim->stride, { 1, 1, 1, 1 }, true, 1);
166 if (specified_output_range != calc_output_range)
167 handling_needed = true;
171 return handling_needed;
174 // create new nodes for a program based on the set of nodes
175 // method created to be used by propagate_constants to build sub program from constant nodes
// Phase 1: create a node for every source node ('data' nodes become
// 'input_layout' so the subgraph can be fed externally).
// Phase 2: wire dependencies (copied from the matching source node) and
// collect nodes with no dependencies as program inputs.
176 void program_impl::prepare_nodes(std::set<std::shared_ptr<program_node>>const &nodes)
178 for (const auto& itr : nodes)
180 if (itr.get()->is_type<data>())
// NOTE(review): the surrounding call for this expression is elided in this
// listing (presumably get_or_create(...) wrapping the new input_layout).
183 std::make_shared<input_layout>(itr.get()->id(), itr.get()->as<data>().get_primitive()->mem.get_layout())
188 get_or_create(itr->desc);
191 for (const auto& node : nodes_map)
193 auto node_ptr = node.second;
194 if (node_ptr == nullptr)
195 throw error("NULL pointer in nodes_map.", CLDNN_ERROR);
196 //ToDo: avoid O(n^2) run time here (pass map instead of set?)
198 for (const auto& src_node : nodes)
200 if (src_node == nullptr)
201 throw error("NULL pointer in nodes_map.", CLDNN_ERROR)
202 if (node.first == src_node->get_primitive()->id)
204 copy_node_dependencies(node_ptr.get(), src_node.get());
211 add_node_dependencies(node_ptr.get());
213 if (node_ptr->dependencies.size() == 0)
214 inputs.push_back(node_ptr.get());
218 // create all nodes from topology primitives, add dependencies among them and create inputs list
// Topology variant: one node per topology primitive; dependency edges come
// from each primitive's declared dependencies; dependency-free nodes are
// the program inputs.
219 void program_impl::prepare_nodes(topology_impl const &topology)
221 auto const& topo_map = topology.get_primitives();
222 for (const auto& prim : topo_map)
224 get_or_create(prim.second);
227 for (const auto& node : nodes_map)
229 auto node_ptr = node.second.get();
230 if (node_ptr == nullptr)
231 throw error("NULL pointer in nodes_map.", CLDNN_ERROR);
232 add_node_dependencies(node_ptr);
233 if (node_ptr->dependencies.size()==0)
235 inputs.push_back(node_ptr);
240 // add node's dependencies from its primitive dependencies
// For each dependency id declared by the node's primitive, link the two
// nodes bidirectionally (dependencies list + users list). The elided lines
// around the nodes_map.at() call form a try/catch that converts a missing
// dependency into the runtime_error below.
241 void program_impl::add_node_dependencies(program_node* node)
243 auto deps = node->get_primitive()->dependencies();
244 //add pointers to node's dependencies
245 for (auto& dep : deps)
248 auto dep_node = nodes_map.at(dep);
249 node->dependencies.push_back(dep_node.get());
250 dep_node->users.push_back(node);
253 throw std::runtime_error("Program doesn't contain primitive: " + dep +
254 " that is input to: " + node->get_primitive()->id);
259 /* helper method for program_impl constructor from list of nodes which
260 copies src_node dependencies to the destination node dest_node dependencies.
261 But only to those which appear in this program implementation nodes_map */
// Precondition: dest_node is the in-program copy of src_node (same id).
// Dependencies pointing outside the subgraph are silently skipped.
262 void program_impl::copy_node_dependencies(program_node* dest_node, program_node* src_node)
264 if (dest_node->get_primitive()->id != src_node->get_primitive()->id)
266 throw std::runtime_error("Node " + src_node->get_primitive()->id + " and its copy " + dest_node->get_primitive()->id + " do not match.");
268 auto src_deps = src_node->get_dependencies();
269 //add pointers to node's dependencies
270 for (auto& src_dep : src_deps)
272 // do not copy dependencies to nodes which does not belong to the new (subgraph) topology
273 if (nodes_map.find(src_dep->get_primitive()->id) == nodes_map.end()) continue;
276 auto dest_dep = nodes_map.at(src_dep->get_primitive()->id);
277 dest_node->dependencies.push_back(dest_dep.get());
278 dest_dep->users.push_back(dest_node);
281 throw std::runtime_error("Program doesn't contain primitive: " + src_dep->get_primitive()->id +
282 " that is input to: " + src_node->get_primitive()->id);
// Assigns a process-wide unique program id and validates build options.
// NOTE(review): the prog_id assignment from id_gen is in an elided line
// (between the declaration and the assert) — confirm against full source.
287 void program_impl::set_options()
289 static std::atomic<uint32_t> id_gen{ 0 };
291 assert(prog_id != 0);
// tune_and_cache tuning requires profiling data, hence this hard check.
293 if ((options.get<build_option_type::tuning_config>()->config.mode == tuning_mode::tuning_tune_and_cache) &&
294 !engine->configuration().enable_profiling)
296 throw std::invalid_argument("Engine must be created with profiling enabled in tune_and_cache mode!");
// Full build pipeline: pre-optimization passes, kernel compilation,
// post-optimization passes, then program compilation on the engine and a
// final graph dump. Elided lines likely include init_graph() and similar
// stages — confirm against full source.
300 void program_impl::build_program(bool is_internal)
304 pre_optimize_graph(is_internal);
306 run_graph_compilation();
308 post_optimize_graph(is_internal);
310 engine->compile_program(*this);
311 this->dump_program("finished", true);
315 void program_impl::init_graph()
317 graph_initializations graph_initializations_pass;
318 pm->run(*this, graph_initializations_pass);
320 calculate_prior_boxes calculate_prior_boxes_pass;
321 pm->run(*this, calculate_prior_boxes_pass);
323 mark_nodes mark_nodes_pass;
324 pm->run(*this, mark_nodes_pass);
327 void program_impl::run_graph_compilation() {
328 compile_graph compile_graph_pass;
329 pm->run(*this, compile_graph_pass);
// Pre-compilation optimization pipeline. Order matters: trimming and padding
// handling first, then layout-driven passes (gated on optimize_data), then
// reorder cleanup, padding preparation, constant propagation and buffer
// fusing. NOTE(review): 'is_internal' is not used in any visible line here —
// it may gate passes in the elided lines (e.g. around propagate_constants);
// confirm against full source.
332 void program_impl::pre_optimize_graph(bool is_internal)
334 trim_to_outputs trim_pass; //trim to outputs
335 pm->run(*this, trim_pass); // ToDo remove hidden dependencies from trimm pass
337 handle_input_padding handle_input_padding; // handle symmetric and asymmetric padding for input
338 pm->run(*this, handle_input_padding);
340 add_reshape_to_primitives add_reshape_to_primitives_pass; // add reshape to input/parameters for some primitives
341 pm->run(*this, add_reshape_to_primitives_pass);
343 processing_order.calculate_BFS_processing_order(); // this method makes sense only for OOOQ (out of order execution queue)
345 bool output_size_handling_enabled = analyze_output_size_handling_need();
// Force output-layout calculation for all real (non-internal, non-data) nodes
// so later passes can rely on cached layouts.
346 for (auto& node : processing_order)
348 if (!node->is_type<internal_primitive>() && !node->is_type<data>())
349 node->get_output_layout();
352 if (options.get<build_option_type::optimize_data>()->enabled())
354 prepare_primitive_fusing prepare_primitive_fusing_pass;
355 pm->run(*this, prepare_primitive_fusing_pass);
357 layout_optimizer lo(output_size_handling_enabled);
358 reorder_inputs reorder_inputs_pass(lo);
359 pm->run(*this, reorder_inputs_pass);
361 // this code should be moved to post compilation after kernel selector will support handling reorder bias
362 pre_optimize_bias pre_optimize_bias_pass(lo);
363 pm->run(*this, pre_optimize_bias_pass);
365 // passes regarding conv + eltwise optimizations
367 // shrinking eltwise if users are conv 1x1 with stride > 1 optimization
368 eltwise_shrinking eltwise_shrinking_pass;
369 pm->run(*this, eltwise_shrinking_pass);
371 // trying to set stride to 1x1 by shrinking convolutions before eltwise if doable
372 eltwise_remove_stride eltwise_remove_stride_pass;
373 pm->run(*this, eltwise_remove_stride_pass);
375 prepare_conv_eltw_fusing prepare_conv_eltw_fusing_pass;
376 pm->run(*this, prepare_conv_eltw_fusing_pass);
378 prepare_conv_eltw_read_write_opt prepare_conv_eltw_read_write_opt_pass;
379 pm->run(*this, prepare_conv_eltw_read_write_opt_pass);
384 remove_redundant_reorders remove_redundant_reorders_pass;
385 pm->run(*this, remove_redundant_reorders_pass);
387 prepare_padding prepare_padding_pass(output_size_handling_enabled);
388 pm->run(*this, prepare_padding_pass);
390 prepare_depthwise_sep_opt prepare_depthwise_sep_opt_pass;
391 pm->run(*this, prepare_depthwise_sep_opt_pass);
395 propagate_constants propagate_constants_pass; // ToDo remove hidden dependencies from propagate_constants pass
396 pm->run(*this, propagate_constants_pass);
399 //try to fuse buffers (i.e. depth_concat in bfyx format) after padding calculations
400 if (options.get<build_option_type::optimize_data>()->enabled())
402 prepare_buffer_fusing prepare_buffer_fusing_pass;
403 pm->run(*this, prepare_buffer_fusing_pass);
406 //check if there exists some layout incompatibilities and add an reorder node if required
407 add_required_reorders add_required_reorders_pass;
408 pm->run(*this, add_required_reorders_pass);
// Post-compilation optimization pipeline: weight reordering for chosen
// kernels, reorder cleanup, constant propagation, depthwise-separable
// post-processing, and finally memory-reuse dependency analysis.
// NOTE(review): the 'lo' (layout_optimizer) used below is declared in an
// elided line; 'is_internal' may also gate passes in elided lines — confirm.
411 void program_impl::post_optimize_graph(bool is_internal)
414 post_optimize_weights post_optimize_weights_pass(lo);
415 pm->run(*this, post_optimize_weights_pass);
417 remove_redundant_reorders remove_redundant_reorders_pass;
418 pm->run(*this, remove_redundant_reorders_pass); //TODO: do we need it at this place also?
422 propagate_constants propagate_constants_pass; // ToDo remove hidden dependencies from propagate_constants pass
423 pm->run(*this, propagate_constants_pass);
426 prep_opt_depthwise_sep_post prep_opt_depthwise_sep_post_pass;
427 pm->run(*this, prep_opt_depthwise_sep_post_pass);
429 prepare_memory_dependencies();
432 // mark if the node is constant assuming that all dependencies are marked properly
// A dependency-free node is constant only if it is a prior_box; otherwise a
// node is constant iff all of its dependencies are constant (the non-constant
// case falls through to the assignment on the last visible line; control-flow
// lines are elided in this listing).
433 void program_impl::mark_if_constant(program_node& node)
435 if (node.get_dependencies().empty())
437 if (node.is_type<prior_box>())
439 node.constant = true;
440 for (auto& dep : node.get_dependencies())
444 node.constant = false;
450 // mark if the node is in data flow assuming that all dependencies are marked properly
// mutable_data / input_layout are always in data flow; any other node is in
// data flow iff at least one considered dependency is.
451 void program_impl::mark_if_data_flow(program_node& node)
453 if (node.is_type<mutable_data>() || node.is_type<input_layout>())
455 node.data_flow = true;
459 node.data_flow = false;
460 size_t inputs_count = node.get_dependencies().size();
461 if (node.is_type<detection_output>() || node.is_type<proposal>())
462 inputs_count = 2; //ignore third input as it is related to prior boxes (i.e. concat of prior-boxes)
463 for (size_t idx = 0; idx < inputs_count; idx++)
465 if (node.get_dependency(idx).is_in_data_flow())
467 node.data_flow = true;
// Final cleanup after the build: forces output-layout calculation for all
// non-internal nodes; in debug builds additionally marks every node as an
// output so users can inspect buffers of all non-optimized nodes.
473 void program_impl::cleanup()
475 for (auto& node : processing_order)
476 if (!node->is_type<internal_primitive>())
477 node->get_output_layout();
479 //in debug build, at the end, mark all nodes as outputs so user can query for buffers of all not-optimized nodes, including internal ones etc.
480 if (is_debug_build())
482 for (auto& node : processing_order)
484 if (!node->is_output())
486 node->set_output(true);
487 outputs.push_back(node);
// For each split node, creates one crop node per declared output offset.
// The crops are registered in nodes_map via get_or_create; iteration uses a
// post-incremented iterator so insertions do not invalidate the walk.
// NOTE(review): the crop size is a dummy {1,1,1,1} — presumably fixed up by a
// later pass; confirm.
493 void program_impl::add_split_outputs() {
494 auto itr = nodes_map.begin();
495 while (itr != nodes_map.end())
497 auto node_itr = itr++;
498 auto& node = (*node_itr).second;
500 if (node->is_type<split>())
502 auto split_prim = node->as<split>().typed_desc();
503 primitive_id input_id = split_prim->input[0];
504 auto split_num = split_prim->output_offsets.size();
506 //create crop for each split output provided
507 for (decltype(split_num) i = 0; i < split_num; i++)
509 primitive_id output_id = node->id() + ":" + split_prim->output_ids[i];
511 //create dummy crop primitive and add it to nodes map
512 auto crop_prim = std::make_shared<crop>(output_id, input_id, tensor{ 1,1,1,1 }, split_prim->output_offsets[i]);
513 get_or_create(crop_prim);
// Mutable accessor for the program's node processing order.
519 program_impl::nodes_ordering& program_impl::get_processing_order()
521 return processing_order;
// Read-only accessor for the program's node processing order.
524 const program_impl::nodes_ordering& program_impl::get_processing_order() const
526 return processing_order;
// Free helper: records that 'node' must not share a memory buffer with 'dep'.
// If either side can be optimized out, the restriction is pushed through to
// dep's own dependencies recursively (in both directions), since an
// optimized-out node aliases its inputs' buffers. The self-edge guard below
// stops the recursion from adding a node against itself.
529 void add_memory_dependency(program_node* node, program_node* dep)
531 if (node->can_be_optimized() ||
532 !dep->can_be_optimized())
534 node->add_memory_dependency(dep->id());
538 if (node->id() == dep->id())
542 for (auto subdep : dep->get_dependencies())
544 add_memory_dependency(node, subdep);
545 add_memory_dependency(subdep, node);
// Baseline memory-reuse restrictions: a node may never share buffers with
// its direct inputs, and any node processed after an output must not reuse
// that output's buffer. The current-node declaration is in an elided line
// between the loop header and the 'data' check.
550 void program_impl::basic_memory_dependencies()
552 auto itr = processing_order.begin();
553 std::vector<primitive_id> past_outputs;
554 while (itr != processing_order.end())
559 //data primitive can't be reused
560 if (node->is_type<data>())
563 // add my dependencies to restriction list (can't share input.output buffers)
564 for (auto it : node->get_dependencies())
566 add_memory_dependency(node, it);
567 add_memory_dependency(it, node);
570 // Note we iterate over processing order, it means if primitive has processing num greater than any of outputs, this output
571 // has to land on the primitive restriction list. Otherwise memory reuse can corrupt final results.
572 node->add_memory_dependency(past_outputs);
573 // if current node is an output add it to the outputs list after restriction.
574 if (node->is_output())
575 past_outputs.push_back(node->id());
// Protects skipped branches: every node processed between B and B's last
// user must not share a buffer with B, otherwise B's data could be clobbered
// before that user consumes it. Several locals (nodeB, nodeA, the lastUsr
// update inside the inner loop) live in elided lines of this listing.
580 void program_impl::skipped_branch_memory_dependencies()
582 // Primitive A can't use primitive B buffer if processing_num(B) < processing_num(A) and for any usr - the user of B processing_num(usr) > processing_num(A)
583 // Otherwise it could override data that has to be used in the future.
584 auto itrB = processing_order.begin();
585 while (itrB != processing_order.end())
589 if (nodeB->get_users().size()==0)
592 // find the last user of B in processing order
593 auto itrUsr = nodeB->get_users().begin();
594 auto lastUsr = itrUsr++;
595 while (itrUsr != nodeB->get_users().end())
597 if (processing_order.get_processing_number(*lastUsr) < processing_order.get_processing_number(*itrUsr))
602 //mark all nodes in between B and lastUsr of B as forbidden to share buffer with B
603 while (itrA != processing_order.get_processing_iterator(**lastUsr))
607 add_memory_dependency(nodeA, nodeB);
608 add_memory_dependency(nodeB, nodeA);
// Out-of-order-queue restrictions: nodes between two sync points (a "sync
// region") may execute concurrently, so no two of them — nor a region node
// and any region dependency — may share a buffer. A barrier is detected when
// a node depends on something processed at/after the last barrier point.
// The region-reset and loop-body control flow lines are elided in this
// listing.
613 void program_impl::oooq_memory_dependencies()
615 auto itr = processing_order.begin();
616 // This order let us build dependencies based on syncing points.
617 // Set of nodes between two syncing points will be called sync_region.
618 // Major rules is: can't share resource with nodes in my sync_region
620 int32_t last_barrier = 0;
621 bool needs_barrier = false;
622 std::vector<cldnn::program_node*> sync_region;
623 while (itr != processing_order.end())
628 // if any of dep has proccess num after barrier -> needs barrier
629 for (auto dep : node->get_dependencies())
631 if (processing_order.get_processing_number(dep) >= last_barrier)
633 needs_barrier = true;
640 last_barrier = processing_order.get_processing_number(node);
641 needs_barrier = false;
642 // add each pair bi-direction dependency
643 for (auto nd1 = sync_region.begin(); nd1 + 1 != sync_region.end(); nd1++)
645 for (auto nd2 = nd1 + 1; nd2 != sync_region.end(); nd2++)
647 add_memory_dependency(*nd1, *nd2);
648 add_memory_dependency(*nd2, *nd1);
652 // collect dependencies of every node in sync region
653 std::vector<cldnn::program_node*> deps;
654 for (auto& nd_in_region : sync_region)
655 for (auto& dep : nd_in_region->get_dependencies())
656 deps.emplace_back(dep);
659 for (auto& nd_in_region : sync_region)
660 for (auto& dep : deps)
662 add_memory_dependency(nd_in_region, dep);
663 add_memory_dependency(dep, nd_in_region);
668 sync_region.push_back(node);
672 void program_impl::prepare_memory_dependencies()
674 if (!get_engine().configuration().enable_memory_pool)
677 basic_memory_dependencies();
678 skipped_branch_memory_dependencies();
679 oooq_memory_dependencies();
682 std::string program_impl::get_memory_dependencies_string() const
684 std::string mem_dep = "Memory dependencies/restrictions:\n";
685 auto itr = processing_order.begin();
686 while (itr != processing_order.end())
690 mem_dep = mem_dep.append("primitive: ").append(node->id()).append(" restricted list: ");
691 for (auto it : node->get_memory_dependencies())
692 mem_dep == mem_dep.append(it).append(", ");
693 mem_dep = mem_dep.append("\n");
698 void program_impl::handle_reshape()
700 //reshape primitive by definition does not change underlying data, only shape description
701 //however during graph initialization and data optimization the layouts can be changed without user's knowledge,
702 //when reshape is followed by reorder, it is likely that reorder's output will not be as expected (for example reshape with flattened shape)
703 //this pass resolved the issue by changing graph in the following way
704 //- in case reshape has multiple users with reshape->reorder sequence, it will be split to multiple reshape primitives with single user
705 //- in case of reshape->reorder sequence, the additional reorder before reshape will be added,
706 //  if last reorder does not contain padding or mean subtract, it will be removed later in the graph
708 for (const auto& node : processing_order)
710 if (node->is_type<reshape>())
712 auto& input_node = node->get_dependency(0);
714 if (input_node.is_type<reorder>())
// In-place reshapes need no buffer of their own; mark them optimized.
717 node->get_output_layout();
718 if (node->as<reshape>().is_in_place())
719 node->optimized = true;
721 //vector for storing nodes that are reorder type, for which split primitives are needed (except for the first one where original reshape will be used)
722 std::vector<program_node*> reorder_node_to_split;
724 //find the users of reshape that are reorder type, if none present then skip the current node
725 for (const auto& user : node->get_users())
727 if (user->is_type<reorder>())
728 reorder_node_to_split.push_back(user);
731 if (!reorder_node_to_split.empty())
733 auto& prim_node = node->as<reshape>();
734 const auto& prim = prim_node.get_primitive();
735 auto output_shape = prim->output_shape;
737 //vector for storing reshape nodes to connect to new reorder nodes (if needed)
738 std::vector<program_node*> reorder_reshape_nodes;
740 bool skip_first_user = false;
741 auto reshape_users = node->get_users();
742 for (const auto& user : reshape_users)
744 //reshape node for first user will be the original reshape from the graph
745 if (!skip_first_user)
747 if (std::find(reorder_node_to_split.begin(), reorder_node_to_split.end(), user) != reorder_node_to_split.end())
748 reorder_reshape_nodes.push_back(node);
749 skip_first_user = true;
753 //other reshapes will be clones of the original one connected to reshape->reorder sequences
754 if (std::find(reorder_node_to_split.begin(), reorder_node_to_split.end(), user) != reorder_node_to_split.end())
756 auto new_reshape = std::make_shared<reshape>("_reshape_split_" + user->id() + "_" + node->id(), input_node.id(), output_shape);
757 auto& new_reshape_node = get_or_create(new_reshape);
758 add_connection(input_node, new_reshape_node);
759 user->replace_dependency(0, new_reshape_node);
760 processing_order.insert_next(&input_node, &new_reshape_node);
761 reorder_reshape_nodes.push_back(&new_reshape_node);
765 //add new reorder nodes to proper reshape node
766 auto reshape_reorder_id = 0;
767 for (const auto& reorder_node : reorder_node_to_split)
// NOTE(review): the six lines below repeat the clone-creation block from the
// user loop above and reference 'user', which appears out of scope in this
// loop over reorder_node_to_split — looks like an accidental duplication
// (possibly a bad merge); verify against upstream and remove if confirmed.
770 auto new_reshape = std::make_shared<reshape>("_reshape_split_" + user->id() + "_" + node->id(), input_node.id(), output_shape);
771 auto& new_reshape_node = get_or_create(new_reshape);
772 add_connection(input_node, new_reshape_node);
773 user->replace_dependency(0, new_reshape_node);
774 processing_order.insert(std::next(processing_order.get_processing_iterator(input_node)), &new_reshape_node);
775 reorder_reshape_nodes.push_back(&new_reshape_node);
777 auto& reorder_reshape_node = reorder_reshape_nodes[reshape_reorder_id];
778 auto reshape_in_layout = reorder_node->get_output_layout();
779 auto reshape_input = std::make_shared<reorder>("_reshape_input_" + reorder_node->id() + "_" + reorder_reshape_node->id(), input_node.id(),
780 reshape_in_layout.format, reshape_in_layout.data_type);
781 auto& reshape_input_node = get_or_create(reshape_input);
782 add_intermediate(reshape_input_node, *reorder_reshape_node, 0, reshape_input_node.dependencies.empty());
783 reshape_reorder_id++;
// Non-reorder-fed reshapes in a non-bfyx layout get an explicit reorder to
// bfyx before them (and a reorder back after each user), since pitch info is
// lost at the reshape stage.
787 auto reshape_layout = node->get_output_layout();
788 if (!(node->is_output()) && (reshape_layout.format != cldnn::format::bfyx))
790 auto bfyx_layout = layout({ reshape_layout.data_type, cldnn::format::bfyx, reshape_layout.size });
791 //when some primitive does an implicit reorder to some other format then we lose the info about pitches in reshape stage
792 //we assume user provides the input vector in bfyx
793 if (!program_helpers::are_layouts_identical(reshape_layout, bfyx_layout).second)
795 auto reshape_input = std::make_shared<reorder>("_reshape_input_" + node->id(), input_node.id(), cldnn::format::bfyx, reshape_layout.data_type);
796 auto& reshape_input_node = get_or_create(reshape_input);
797 add_intermediate(reshape_input_node, *node, 0, reshape_input_node.dependencies.empty());
799 auto reshape_users = node->get_users();
800 for (const auto& user : reshape_users)
803 for (size_t i = 0; i < user->get_dependencies().size(); i++)
805 auto& input = user->get_dependency(i);
806 if (input.id() == node->id()) {
// NOTE(review): 'idx' below is presumably set from 'i' in an elided line of
// this inner search loop — confirm.
811 auto reshape_output = std::make_shared<reorder>("_reshape_output_" + node->id(), user->id(), reshape_layout.format, reshape_layout.data_type);
812 auto& reshape_output_node = get_or_create(reshape_output);
813 add_intermediate(reshape_output_node, *user, idx, reshape_output_node.dependencies.empty());
// Ensures prev_node's output carries 'needed_padding'. Input-like nodes
// (input_layout / mutable_data) cannot be padded in place, so a padded
// reorder is inserted between them and 'node' instead; other nodes simply
// merge the padding into their output layout. Early 'return' statements live
// in elided lines after the short-circuit check and the reorder insertion.
821 void program_impl::apply_needed_padding(program_node& node, program_node& prev_node,
822 const padding& needed_padding)
824 auto target_layout = prev_node.get_output_layout();
826 // Short circuit if padding did not change.
827 if (target_layout.data_padding == needed_padding)
830 // Special handling for input nodes.
831 if (prev_node.is_type<input_layout>() || prev_node.is_type<mutable_data>())
833 target_layout.data_padding = needed_padding;
835 auto r_prim = std::make_shared<reorder>("reorder_input_" + node.id(), prev_node.id(), target_layout);
836 add_intermediate(r_prim, node, 0);
840 prev_node.merge_output_padding(needed_padding);
// Flips the direction of an existing dep_node -> user_node edge so that
// dep_node becomes the user. Throws if the edge does not actually exist
// (the 'else' sits in an elided line before the throw).
843 void program_impl::reverse_connection(program_node& dep_node, program_node& user_node)
845 if (std::find(dep_node.users.begin(), dep_node.users.end(), &user_node) != dep_node.users.end())
847 remove_connection(dep_node, user_node);
848 add_connection(user_node, dep_node);
851 throw std::runtime_error("Trying to reverse connection, but nodes are wrongly or not connected.");
// Returns the existing node for prim->id, or creates (via the primitive's
// type factory) and registers a new one. lower_bound + hinted insert keeps
// this to a single map lookup. The two 'return' statements live in elided
// lines after the existence check and after the insert.
854 program_node& program_impl::get_or_create(std::shared_ptr<primitive> prim)
856 auto itr = nodes_map.lower_bound(prim->id);
857 if (itr != nodes_map.end() && itr->first == prim->id)
860 auto new_node = prim->type->create_node(*this, prim);
861 nodes_map.insert(itr, { prim->id, new_node });
// Inserts 'node' between 'next' and its dependency at 'prev_idx'.
// connect_int_node_with_old_dep: also wire prev -> node and place node right
//   after prev in the processing order (requires node to have no deps yet).
// move_usrs_of_prev_to_node: redirect all of prev's other users onto node
//   and recompute constant/data-flow flags; otherwise node inherits prev's
//   flags directly.
865 void program_impl::add_intermediate(program_node& node, program_node& next, size_t prev_idx,
866 bool connect_int_node_with_old_dep, bool move_usrs_of_prev_to_node)
868 if (connect_int_node_with_old_dep && !node.dependencies.empty())
869 throw std::invalid_argument("Node which is about to be added in between two other nodes should not have any existing dependencies");
871 auto& prev = next.get_dependency(prev_idx);
872 //firstly add connection, later replace dependency, so 'prev' won't become dangling and therefore removed
873 if (connect_int_node_with_old_dep)
875 add_connection(prev, node);
876 if (processing_order.size() != 0)
878 processing_order.insert_next(&prev, &node);
882 if (move_usrs_of_prev_to_node) {
883 auto itr = prev.get_users().begin();
884 while(itr!= prev.get_users().end())
// 'usr' is taken from the iterator in an elided line; the guard below keeps
// the just-inserted node from being redirected onto itself.
888 if (usr->id() != node.id())
889 usr->replace_dependency(prev, node);
891 mark_if_constant(prev);
892 mark_if_constant(node);
893 mark_if_data_flow(prev);
894 mark_if_data_flow(node);
897 next.replace_dependency(prev_idx, node);
898 node.constant = prev.constant;
899 node.data_flow = prev.data_flow;
903 void program_impl::add_intermediate(std::shared_ptr<primitive> prim, program_node& next, size_t prev_idx,
904 bool connect_int_node_with_old_dep, bool move_usrs_of_prev_to_node)
906 add_intermediate(get_or_create(prim), next, prev_idx, connect_int_node_with_old_dep, move_usrs_of_prev_to_node);
909 void program_impl::add_connection(program_node& prev, program_node& next)
911 prev.users.push_back(&next);
912 next.dependencies.push_back(&prev);
915 void program_impl::remove_connection(program_node& prev, program_node& next)
917 prev.users.remove(&next);
918 next.dependencies.erase(std::remove(next.dependencies.begin(), next.dependencies.end(), &prev), next.dependencies.end());
921 void program_impl::remove_all_connections(program_node& node) {
922 // since the graph is not topological sorted, we need to remove the node from both dependencies and users
// Scrub 'node' from every neighbor's adjacency list, then clear its own.
// NOTE(review): the matching node.users.clear() is expected in an elided
// line right after dependencies.clear() — confirm.
923 for (auto &e : node.users)
925 e->dependencies.erase(std::remove(e->dependencies.begin(), e->dependencies.end(), &node), e->dependencies.end());
927 for (auto &e : node.dependencies)
929 e->users.remove(&node);
931 node.dependencies.clear();
// Re-keys 'node' under 'new_id' in nodes_map and patches the id stored in
// the node itself (primitive desc for regular nodes; internal_id for
// internal primitives — the 'else' between the two cases is elided).
// Renaming outputs is forbidden because output lookup is id-based.
935 void program_impl::rename(program_node & node, primitive_id const & new_id)
937 if (nodes_map.count(new_id))
938 throw std::runtime_error("Trying to rename program_node but node with id " + new_id + " already exists");
939 if (node.is_output())
940 throw std::invalid_argument("Trying to rename an output node. If you intend to do that, please clear 'output' flag manually.");
942 auto node_ptr = nodes_map.find(node.id())->second;
943 nodes_map.emplace(new_id, node_ptr);
944 nodes_map.erase(node.id());
// const_cast is deliberate here: primitive ids are immutable to users, but
// the program owns the desc and must keep it consistent with nodes_map.
946 if (!node.is_type<internal_primitive>())
947 const_cast<primitive_id&>(node.desc->id) = new_id;
949 reinterpret_cast<details::internal_program_node_base&>(node).internal_id = new_id;
// Exchanges the identities of two nodes: swaps their shared_ptr slots in
// nodes_map and swaps the id stored inside each node (regular vs internal
// primitive ids are handled by the lambda).
952 void program_impl::swap_names(program_node& node1, program_node& node2)
954 const auto _extract_id = [](program_node& node) -> primitive_id&
956 if (!node.is_type<internal_primitive>())
957 return const_cast<primitive_id&>(node.desc->id);
959 return reinterpret_cast<details::internal_program_node_base&>(node).internal_id;
962 nodes_map.at(node1.id()).swap(nodes_map.at(node2.id()));
963 std::swap(_extract_id(node1), _extract_id(node2));
// Redirects every user of old_node to depend on new_node instead. The 'end'
// flag is re-evaluated before each replace because replace_dependency
// mutates old_node.users while we iterate (the loop header is in an elided
// line).
966 void program_impl::replace_all_usages(program_node & old_node, program_node & new_node)
968 auto itr = old_node.users.begin();
969 bool end = (itr == old_node.users.end());
972 auto& usage = (*itr++);
973 end = (itr == old_node.users.end());
974 usage->replace_dependency(old_node, new_node);
// Substitutes new_node for old_node everywhere: layout, graph edges, input/
// output lists, flags, processing order, and finally the id itself (via
// rename). new_node must start fully detached.
978 void program_impl::replace(program_node& old_node, program_node& new_node)
980 if (!new_node.dependencies.empty() || !new_node.users.empty())
981 throw std::invalid_argument("Node which is about to replace other node should be detached");
983 if (new_node.is_output())
984 throw std::invalid_argument("Replacement node shouldn't be marked as an output since it's impossible to rename such node.");
986 auto id = old_node.id();
987 new_node.output_layout = old_node.get_output_layout();
988 new_node.valid_output_layout = old_node.valid_output_layout;
991 //copy old's dependencies
992 while (!old_node.dependencies.empty())
994 auto& dep = old_node.dependencies.front();
995 add_connection(*dep, new_node);
996 remove_connection(*dep, old_node);
// Take over old_node's users by rewriting their dependency pointers in place.
1000 for (auto& user : old_node.users)
1002 new_node.users.push_back(user);
1003 for (auto& users_dep : user->dependencies)
1005 if (users_dep == &old_node)
1007 users_dep = &new_node;
1013 old_node.users.clear();
1015 bool old_was_output = false;
// If old node was an output, drop the flag now — rename() refuses to rename
// outputs — and restore it on the new node after renaming.
1017 if (old_node.is_output())
1019 old_was_output = true;
1020 old_node.set_output(false);
1021 outputs.erase(std::remove(outputs.begin(), outputs.end(), &old_node), outputs.end());
1023 if (new_node.is_input())
1024 inputs.push_back(&new_node);
1025 if (old_node.is_input())
1026 inputs.remove(&old_node);
1028 new_node.constant = old_node.constant;
1029 new_node.user_mark = old_node.user_mark;
1031 processing_order.insert(&old_node, &new_node);
1032 if (processing_order.get_processing_iterator(old_node) != processing_order.end())
1033 processing_order.erase(&old_node);
1034 nodes_map.erase(id);
1035 rename(new_node, id);
1037 //mark new node as an output after renaming
// The 'if (old_was_output)' guard sits in an elided line before these two.
1040 new_node.set_output(true);
1041 outputs.push_back(&new_node);
// Removes 'node' if it has neither users nor dependencies (and is not a
// protected output). Early 'return false' statements for the guarded cases,
// and the final 'return true', live in elided lines. Removed ids are kept
// in optimized_out for later queries.
1045 bool program_impl::remove_if_dangling(program_node& node)
1047 if (!node.users.empty())
1049 if (!node.dependencies.empty())
// Outputs are only removable in debug builds (where everything is an output).
1052 if (!node.is_output() || is_debug_build())
1054 if (node.is_input())
1055 inputs.remove(&node);
1057 if (std::find(processing_order.begin(), processing_order.end(), &node) != processing_order.end())
1058 processing_order.erase(&node);
1059 optimized_out.push_back(node.id());
1060 nodes_map.erase(node.id());
// Splices a single-input node out of the graph: its users are rewired to its
// only dependency, then the node is removed if dangling. If the node was an
// output, the output flag (and the id, via the rename dance) is transferred
// to its predecessor. Early 'return false' paths live in elided lines.
1065 bool program_impl::extract_and_remove(program_node& node)
1067 if (node.get_dependencies().size() != 1)
1070 if (node.is_output() && node.get_dependency(0).is_output() && !is_debug_build()) //TODO: add a mechanism to support removal of nodes which are marked as outputs
1073 if (node.is_output() && !is_debug_build())
1075 auto& prev = node.get_dependency(0);
1076 auto node_id = node.id();
1078 node.set_output(false);
1079 outputs.erase(std::remove(outputs.begin(), outputs.end(), &node), outputs.end());
// Swap ids through a temp name so callers querying by output id get 'prev'.
1081 rename(node, "_cldnn_tmp_" + node_id);
1082 rename(prev, node_id);
1084 prev.set_output(true);
1085 outputs.push_back(&prev);
1088 auto& input = node.get_dependency(0);
1089 node.dependencies.clear();
1090 input.users.remove(&node);
1092 if (!node.is_endpoint())
1093 replace_all_usages(node, input);
1095 remove_if_dangling(node);
// Bulk-removes nodes: detaches each one from its neighbors' adjacency
// lists, drops it from the inputs list / processing order / nodes_map, and
// records its id in optimized_out.
1100 void program_impl::remove_nodes(std::list<program_node*>& to_remove)
1102 for (auto const& node : to_remove)
1104 if (node->is_input())
1105 get_inputs().remove(node);
1108 for (auto& dep : node->dependencies)
1109 dep->users.remove(node);
1111 for (auto& user : node->users)
1113 user->dependencies.erase(std::remove(user->dependencies.begin(),
1114 user->dependencies.end(), node),
1115 user->dependencies.end());
1117 get_processing_order().erase(node);
1118 optimized_out.push_back(node->id());
1119 nodes_map.erase(node->id());
// Dumps the engine's memory-pool state plus the restriction lists to
// cldnn_memory_pool.log in the configured dump directory, then dumps the
// program graph. No-op when the memory pool is disabled (the early 'return'
// and the empty-path guard live in elided lines).
1123 void program_impl::dump_memory_pool() const
1125 if (!get_engine().configuration().enable_memory_pool)
1127 auto path = get_dir_path(options);
1132 path += "cldnn_memory_pool.log";
1133 auto dep = get_memory_dependencies_string();
1134 get_engine().dump_memory_pool(*this, path, dep);
1135 std::string dump_file_name = std::to_string(pm->get_pass_count()+1) + "_memory_pool";
1136 dump_program(dump_file_name.c_str(), true);
1139 //TODO: break this function into number of smaller ones + add per-primitive fields (possibly use primitive_inst::to_string?)
// Writes graph dumps for the given pipeline stage into the configured dump
// directory: .graph always; .info, .order and .optimized only when
// with_full_info is set. 'filter' limits which nodes are emitted. The
// empty-path guard and early return after the .graph dump live in elided
// lines; the function may continue past this view.
1140 void program_impl::dump_program(const char* stage, bool with_full_info, std::function<bool(program_node const&)> const& filter) const
1142 std::string path = get_dir_path(options);
1148 std::ofstream graph(path + "cldnn_program_" + std::to_string(prog_id) + "_" + stage + ".graph");
1149 dump_graph_init(graph, *this, filter);
1151 if (!with_full_info)
1156 graph.open(path + "cldnn_program_" + std::to_string(prog_id) + "_" + stage + ".info");
1157 dump_graph_info(graph, *this, filter);
1159 graph.open(path + "cldnn_program_" + std::to_string(prog_id) + "_" + stage + ".order");
1160 dump_graph_processing_order(graph, *this);
1162 graph.open(path + "cldnn_program_" + std::to_string(prog_id) + "_" + stage + ".optimized");
1163 dump_graph_optimized(graph, *this);