2 // Copyright (c) 2019 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
21 #include "api/CPP/program.hpp"
23 #include "refcounted_obj.h"
24 #include "engine_impl.h"
32 struct primitive_impl;
34 class layout_optimizer;
36 class program_impl_wrapper;
// cldnn_program implementation
40 struct program_impl : public refcounted_obj<program_impl>
42 friend class calculate_prior_boxes; // to be removed when possible
43 friend class graph_initializations; // to be removed when possible
44 friend class prepare_padding; // to be removed when possible
45 friend class propagate_constants; // to be removed when possible
46 friend class prepare_primitive_fusing; // to be removed when possible
47 friend class prepare_conv_eltw_fusing; // to be removed when possible
48 friend class prepare_conv_eltw_read_write_opt; // to be removed when possible
49 friend class reorder_inputs; // to be removed when possible
50 friend class program_impl_wrapper; // this class is intended to extend the interface of program_impl for
51 // the usage within tests_core_internal project only
56 typedef std::list<program_node*> list_of_nodes;
57 typedef list_of_nodes::const_iterator const_iterator;
58 typedef list_of_nodes::iterator node_iterator;
59 const_iterator begin() const { return _processing_order.begin(); }
60 const_iterator end() const { return _processing_order.end(); }
62 void calc_processing_order_visit(program_node* node);
63 void calc_processing_order(program_impl& p);
64 int32_t get_processing_number(program_node* node) const { return get_processing_number(get_processing_iterator(*node)); }
65 // int32_t get_processing_number(const_iterator iter) const { return 1+(int32_t)std::distance(begin(), iter); }
66 int32_t get_processing_number(node_iterator iter) const { return 1 + (int32_t)std::distance(_processing_order.begin(), const_iterator(iter)); }
67 void calculate_BFS_processing_order();
68 size_t size() { return _processing_order.size(); }
69 bool is_correct(program_node* node);
71 node_iterator get_processing_iterator(program_node& node) const
73 return processing_order_iterators.at(&node);
77 processing_order_iterators.clear();
78 _processing_order.clear();
81 void insert(program_node* key_node, program_node* node)
83 node_iterator _where = processing_order_iterators.at(key_node);
84 processing_order_iterators[node] = _processing_order.insert(_where, node);
87 void insert_next(program_node* key_node, program_node* node)
89 node_iterator _where = std::next(processing_order_iterators.at(key_node));
90 processing_order_iterators[node] = _processing_order.insert(_where, node);
93 void erase(program_node* key_node)
95 node_iterator i = processing_order_iterators.at(key_node);
96 processing_order_iterators.erase(key_node);
97 _processing_order.erase(i);
101 list_of_nodes _processing_order;
102 std::map<program_node*, node_iterator> processing_order_iterators;
// Range-like wrapper around exactly one element, usable in a range-based for loop.
// The object doubles as its own iterator: begin() wraps the element, end() wraps
// nullptr, and operator++ moves an iterator into the end state.
template <class T>
struct single_element_container
{
    // Wraps a reference to 't'; the container stores only its address.
    single_element_container(T& t) : elem(&t)
    {}

    constexpr size_t size() const { return 1; }
    single_element_container begin() const { return single_element_container(elem); }
    single_element_container end() const { return single_element_container(nullptr); }

    // Advancing past the single element yields the end state.
    single_element_container& operator ++() { elem = nullptr; return *this; }

    // const-qualified so that const iterators can be compared as well.
    bool operator !=(single_element_container const& sec) const { return elem != sec.elem; }

    // Returns the wrapped element by value; must not be called on the end state
    // (elem would be nullptr).
    T operator *() const { return *elem; }

private:
    // Pointer-based constructor used by begin()/end() to build iterator states.
    single_element_container(T* t) : elem(t)
    {}

    T* elem;
};
124 program_impl(engine_impl& engine_ref, topology_impl const& topology, build_options const& options, bool is_internal, bool no_optimizations=false);
125 /* constructor used to build a program from subset of nodes of other program (used in propagate_constants) */
126 program_impl(engine_impl& engine_ref, std::set<std::shared_ptr<program_node>> const &nodes, build_options const& options, bool is_internal);
128 engine_impl& get_engine() const { return *engine; }
129 const build_options& get_options() const { return options; }
130 std::list<program_node*>& get_inputs() { return inputs; } // ToDo: redesign trim to ouptut pass to make it const as_well as get_engine and get options
131 std::vector<program_node*>& get_outputs() { return outputs; } // ToDo: redesign reorder-inputs pass to make it const as_well as get_engine and get options
132 bool is_debug_build() const { return options.get<build_option_type::debug>()->enabled(); }
133 const nodes_ordering& get_processing_order() const;
134 nodes_ordering& get_processing_order();
135 const std::list<primitive_id>& get_optimized_out() const { return optimized_out; }
136 bool has_node(const primitive_id& prim) const { return nodes_map.count(prim) > 0; }
137 program_node& get_node(primitive_id const& id);
138 program_node const& get_node(primitive_id const& id) const;
139 std::shared_ptr<program_node> get_node_ptr(const primitive_id& prim) { return nodes_map.at(prim); }
140 std::shared_ptr<program_node> get_node_ptr(const primitive_id& prim) const { return nodes_map.at(prim); }
141 void dump_memory_pool() const;
143 //returns already existing program_node for given primitive 'prim' (lookup in 'nodes_map')
144 //if it was previously created, otherwise creates and then returns program_node
145 program_node& get_or_create(std::shared_ptr<primitive> prim);
147 // Inserts given program_node 'node' as an intermediate node between 'next' and it's
148 // dependency at 'prev_idx' index.
149 void add_intermediate(program_node& node, program_node& next, size_t prev_idx,
150 bool connect_int_node_with_old_dep = true,
151 bool move_usrs_of_prev_to_node = false);
153 // Gets or creates program_node for given primitive 'prim' and inserts it as an intermediate
154 // node between 'next' and it's dependency at 'prev_idx' index.
155 void add_intermediate(std::shared_ptr<primitive> prim, program_node& next, size_t prev_idx,
156 bool connect_int_node_with_old_dep = true,
157 bool move_usrs_of_prev_to_node = false);
159 //removes a node from the graph and deletes it afterwards,
160 //prereq: node cannot be marked as output and has to have exactly one dependency
161 //returns if 'node' has been extracted and removed successfully
162 bool extract_and_remove(program_node& node);
164 //returns if 'node' has been removed
165 bool remove_if_dangling(program_node& node);
167 void mark_if_constant(program_node& node);
168 // mark if the node is in data flow assuming that all dependencies are marked properly
169 void mark_if_data_flow(program_node& node);
170 //Reverses connection - user becomes dependency.
172 void remove_nodes(std::list<program_node*>& to_remove);
173 void dump_program(const char* stage, bool with_full_info, std::function<bool(program_node const&)> const& filter = nullptr) const;
176 uint32_t prog_id = 0;
177 engine_impl::ptr engine;
178 build_options options;
179 std::list<program_node*> inputs;
180 std::vector<program_node*> outputs;
181 nodes_ordering processing_order;
182 std::unique_ptr<pass_manager> pm;
184 std::map<primitive_id, std::shared_ptr<program_node>> nodes_map;
185 std::list<primitive_id> optimized_out;
// High-level functions, in order of usage
190 /* build nodes internal structure based on topology */
191 void prepare_nodes(topology_impl const& topology);
192 /* build nodes internal structure based on the subset of nodes of other program (used in propagate_constants) */
193 void prepare_nodes(std::set<std::shared_ptr<program_node>> const& nodes);
194 void add_node_dependencies(program_node* node_ptr);
195 void copy_node_dependencies(program_node* dest, program_node* src);
196 void build_program(bool is_internal);
200 void run_graph_compilation();
201 void pre_optimize_graph(bool is_internal);
202 void post_optimize_graph(bool is_internal);
// Analysis functions
// TODO: Remove once we get full support for input/output padding in all primitive implementations.
bool analyze_output_size_handling_need();

// Original note: "handle split, deconvolution and upsampling".
// NOTE(review): that wording does not obviously match the name handle_reshape -
// confirm the comment is still attached to the right declaration.
void handle_reshape();
// Optimization functions
217 void apply_needed_padding(program_node& node, program_node& prev_node, const padding& needed_padding);
// Memory pool functions
222 void prepare_memory_dependencies();
223 void basic_memory_dependencies();
224 void skipped_branch_memory_dependencies();
225 void oooq_memory_dependencies();
226 std::string get_memory_dependencies_string() const;
231 void add_split_outputs();
232 // mark if the node is constant assuming that all dependencies are marked properly
233 void reverse_connection(program_node& dep_node, program_node& user_node);
235 void add_connection(program_node& prev, program_node& next);
237 void remove_connection(program_node& prev, program_node& next);
239 void remove_all_connections(program_node& node);
241 void rename(program_node & node, primitive_id const & new_id);
242 void swap_names(program_node& node1, program_node& node2);
243 void replace_all_usages(program_node& old_node, program_node& new_node);
245 //old_node - node which will be replaced
246 //new_node - node which will replace the old one
247 void replace(program_node& old_node, program_node& new_node);
252 API_CAST(::cldnn_program, cldnn::program_impl)