Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / include / program_impl.h
1 /*
2 // Copyright (c) 2019 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18
19 #pragma once
20
21 #include "api/CPP/program.hpp"
22
23 #include "refcounted_obj.h"
24 #include "engine_impl.h"
25
26 #include <list>
27
28 namespace cldnn
29 {
30
31 struct topology_impl;
32 struct primitive_impl;
33 struct program_node;
34 class layout_optimizer;
35 class pass_manager;
36 class program_impl_wrapper;
37 /*
38     cldnn_program implementation
39 */
40 struct program_impl : public refcounted_obj<program_impl>
41 {
42     friend class calculate_prior_boxes;             // to be removed when possible
43     friend class graph_initializations;             // to be removed when possible
44     friend class prepare_padding;                   // to be removed when possible
45     friend class propagate_constants;               // to be removed when possible
46     friend class prepare_primitive_fusing;          // to be removed when possible
47     friend class prepare_conv_eltw_fusing;          // to be removed when possible
48     friend class prepare_conv_eltw_read_write_opt;  // to be removed when possible
49     friend class reorder_inputs;                    // to be removed when possible
50     friend class program_impl_wrapper;              // this class is intended to extend the interface of program_impl for 
51                                                     // the usage within tests_core_internal project only
52 public:
53     struct nodes_ordering
54     {
55     public:
56         typedef std::list<program_node*> list_of_nodes;
57         typedef list_of_nodes::const_iterator const_iterator;
58         typedef list_of_nodes::iterator node_iterator;
59         const_iterator begin() const { return _processing_order.begin(); }
60         const_iterator end() const { return _processing_order.end(); }
61
62         void calc_processing_order_visit(program_node* node);
63         void calc_processing_order(program_impl& p);
64         int32_t get_processing_number(program_node* node) const { return get_processing_number(get_processing_iterator(*node)); }
65         // int32_t get_processing_number(const_iterator iter) const { return 1+(int32_t)std::distance(begin(), iter); }
66         int32_t get_processing_number(node_iterator iter) const { return 1 + (int32_t)std::distance(_processing_order.begin(), const_iterator(iter)); }
67         void calculate_BFS_processing_order();
68         size_t size() { return _processing_order.size(); }
69         bool is_correct(program_node* node);
70         
71         node_iterator get_processing_iterator(program_node& node) const
72         {
73             return processing_order_iterators.at(&node);
74         }
75         void clear()
76         {
77             processing_order_iterators.clear();
78             _processing_order.clear();
79         }
80
81         void insert(program_node* key_node, program_node* node)
82         {
83             node_iterator _where = processing_order_iterators.at(key_node);
84             processing_order_iterators[node] = _processing_order.insert(_where, node);
85         }
86
87         void insert_next(program_node* key_node, program_node* node)
88         {
89             node_iterator _where = std::next(processing_order_iterators.at(key_node));
90             processing_order_iterators[node] = _processing_order.insert(_where, node);
91         }
92
93         void erase(program_node* key_node)
94         {
95             node_iterator i = processing_order_iterators.at(key_node);
96             processing_order_iterators.erase(key_node);
97             _processing_order.erase(i);
98         }
99
100     private:
101         list_of_nodes _processing_order;
102         std::map<program_node*, node_iterator> processing_order_iterators;
103     };
104
105     template <class T>
106     struct single_element_container
107     {
108         single_element_container(T& t) : elem(&t)
109         {}
110         constexpr size_t size() const { return 1; }
111         single_element_container begin() const { return single_element_container(elem); }
112         single_element_container end() const { return single_element_container(nullptr); }
113         single_element_container& operator ++() { elem = nullptr; return *this; }
114         bool operator !=(single_element_container const& sec) { return elem != sec.elem; }
115
116        T operator *() { return *elem; }
117
118     private:
119         single_element_container(T* t) : elem(t)
120         {}
121
122         T* elem;
123     };
124     program_impl(engine_impl& engine_ref, topology_impl const& topology, build_options const& options, bool is_internal, bool no_optimizations=false);
125     /* constructor used to build a program from subset of nodes of other program (used in propagate_constants) */
126     program_impl(engine_impl& engine_ref, std::set<std::shared_ptr<program_node>> const &nodes, build_options const& options, bool is_internal);
127     ~program_impl();
128     engine_impl& get_engine() const { return *engine; }
129     const build_options& get_options() const { return options; }
130     std::list<program_node*>& get_inputs() { return inputs; }     // ToDo: redesign trim to ouptut pass to make it const as_well as get_engine and get options 
131     std::vector<program_node*>& get_outputs() { return outputs; }  // ToDo: redesign reorder-inputs pass to make it const as_well as get_engine and get options 
132     bool is_debug_build() const { return options.get<build_option_type::debug>()->enabled(); }
133     const nodes_ordering& get_processing_order() const;
134     nodes_ordering& get_processing_order();
135     const std::list<primitive_id>& get_optimized_out() const { return optimized_out; }
136     bool has_node(const primitive_id& prim) const { return nodes_map.count(prim) > 0; }
137     program_node& get_node(primitive_id const& id);
138     program_node const& get_node(primitive_id const& id) const;
139     std::shared_ptr<program_node> get_node_ptr(const primitive_id& prim) { return nodes_map.at(prim);  }
140     std::shared_ptr<program_node> get_node_ptr(const primitive_id& prim) const { return nodes_map.at(prim); }
141     void dump_memory_pool() const;
142
143     //returns already existing program_node for given primitive 'prim' (lookup in 'nodes_map')
144     //if it was previously created, otherwise creates and then returns program_node
145     program_node& get_or_create(std::shared_ptr<primitive> prim);
146
147     // Inserts given program_node 'node' as an intermediate node between 'next' and it's
148     //  dependency at 'prev_idx' index.
149     void add_intermediate(program_node& node, program_node& next, size_t prev_idx,
150                           bool connect_int_node_with_old_dep = true,
151                           bool move_usrs_of_prev_to_node = false);
152
153     // Gets or creates program_node for given primitive 'prim' and inserts it as an intermediate
154     // node between 'next' and it's dependency at 'prev_idx' index.
155     void add_intermediate(std::shared_ptr<primitive> prim, program_node& next, size_t prev_idx, 
156         bool connect_int_node_with_old_dep = true,
157         bool move_usrs_of_prev_to_node = false);
158
159     //removes a node from the graph and deletes it afterwards,
160     //prereq: node cannot be marked as output and has to have exactly one dependency
161     //returns if 'node' has been extracted and removed successfully
162     bool extract_and_remove(program_node& node);
163
164     //returns if 'node' has been removed
165     bool remove_if_dangling(program_node& node);
166
167     void mark_if_constant(program_node& node);
168     // mark if the node is in data flow assuming that all dependencies are marked properly
169     void mark_if_data_flow(program_node& node);
170     //Reverses connection - user becomes dependency.
171
172     void remove_nodes(std::list<program_node*>& to_remove);
173     void dump_program(const char* stage, bool with_full_info, std::function<bool(program_node const&)> const& filter = nullptr) const;
174
175 private:
176     uint32_t prog_id = 0;
177     engine_impl::ptr engine;
178     build_options options;
179     std::list<program_node*> inputs;
180     std::vector<program_node*> outputs;
181     nodes_ordering processing_order;
182     std::unique_ptr<pass_manager> pm;
183
184     std::map<primitive_id, std::shared_ptr<program_node>> nodes_map;
185     std::list<primitive_id> optimized_out;
186
187     /*
188     ** High-level functions, in order of usage
189     */
190     /* build nodes internal structure based on topology */
191     void prepare_nodes(topology_impl const& topology);
192     /* build nodes internal structure based on the subset of nodes of other program  (used in propagate_constants) */
193     void prepare_nodes(std::set<std::shared_ptr<program_node>> const& nodes);
194     void add_node_dependencies(program_node* node_ptr);
195     void copy_node_dependencies(program_node* dest, program_node* src);
196     void build_program(bool is_internal);
197     void init_graph();
198     void set_options();
199
200     void run_graph_compilation();
201     void pre_optimize_graph(bool is_internal);
202     void post_optimize_graph(bool is_internal);
203     void cleanup();
204
205     /*
206     ** Analysis functions
207     */
208     // TODO: Remove once we will get full support for input/output padding in all primitive implementations.
209     bool analyze_output_size_handling_need();
210
211     // handle split, deconvolution and upsampling
212     void handle_reshape();
213
214     /*
215     ** Optimization functions
216     */
217     void apply_needed_padding(program_node& node, program_node& prev_node, const padding& needed_padding);
218
219     /*
220     ** Memory pool functions
221     */
222     void prepare_memory_dependencies();
223     void basic_memory_dependencies();
224     void skipped_branch_memory_dependencies();
225     void oooq_memory_dependencies();
226     std::string get_memory_dependencies_string() const;
227
228     /*
229     ** Utilities
230     */
231     void add_split_outputs();
232     // mark if the node is constant assuming that all dependencies are marked properly
233     void reverse_connection(program_node& dep_node, program_node& user_node);
234
235     void add_connection(program_node& prev, program_node& next);
236
237     void remove_connection(program_node& prev, program_node& next);
238
239     void remove_all_connections(program_node& node);
240
241     void rename(program_node & node, primitive_id const & new_id);
242     void swap_names(program_node& node1, program_node& node2);
243     void replace_all_usages(program_node& old_node, program_node& new_node);
244
245     //old_node - node which will be replaced
246     //new_node - node which will replace the old one
247     void replace(program_node& old_node, program_node& new_node);
248 };
249
250 }
251
252 API_CAST(::cldnn_program, cldnn::program_impl)