Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / include / pass_manager.h
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #pragma once
18
19 #include "program_impl.h"
20 #include "layout_optimizer.h"
21
22 namespace cldnn
23 {
24     class base_pass
25     {
26         friend class pass_manager;
27     public:
28         base_pass(const std::string& pass_name) : name(pass_name) {}
29         virtual void run(program_impl& p) = 0;
30         std::string get_name() { return name; }
31         void clean_marks(program_impl& p) {
32             for (auto& node : p.get_processing_order())
33             {
34                 node->unmark();
35             }
36         }
37     private:
38         const std::string name;
39     };
40
41     class pass_manager
42     {
43     public:
44         pass_manager()
45         {
46             pass_count = 0;
47         }
48         void run(program_impl& p, base_pass& pass)
49         {
50             pass.run(p);
51             std::string dump_file_name;
52             if (pass_count < 10)
53                 dump_file_name += "0";
54             dump_file_name += std::to_string(pass_count) + "_" + pass.get_name();
55             p.dump_program(dump_file_name.c_str(), true);
56             pass.clean_marks(p);
57             pass_count++;
58         }
59         uint32_t get_pass_count() { return pass_count; }
60         uint32_t inc_pass_count() { return ++pass_count; }
61         ~pass_manager() {}
62     private:
63         uint32_t pass_count;
64     };
65
66     class add_required_reorders : public base_pass
67     {
68     public:
69         add_required_reorders() : base_pass("add_required_reorders") {}
70     private:
71         virtual void run(program_impl& p) override;
72         void add_reorder(program_impl& p, program_node* node, program_node* usr, layout reorder_layout);
73     };
74
75     class add_reshape_to_primitives : public base_pass
76     {
77     public:
78         add_reshape_to_primitives() : base_pass("add_reshape_to_primitives_pass") {}
79     private:
80         virtual void run(program_impl& p) override;
81     };
82
83     class calculate_prior_boxes : public base_pass
84     {
85     public: 
86         calculate_prior_boxes() : base_pass("calculated_prior_boxes") {}
87     private:
88         virtual void run(program_impl& p) override;
89     };
90
91     class compile_graph: public base_pass
92     {
93     public:
94         compile_graph() : base_pass("compile_graph") {}
95     private:
96         virtual void run(program_impl& p) override;
97     };
98
99     class eltwise_shrinking : public base_pass
100     {
101     public:
102         eltwise_shrinking() : base_pass("eltwise_shrinking") {}
103     private:
104         virtual void run(program_impl& p) override;
105     };
106
107     class eltwise_remove_stride : public base_pass
108     {
109     public:
110         eltwise_remove_stride() : base_pass("eltwise_remove_stride") {}
111     private:
112         virtual void run(program_impl& p) override;
113         void conv_stride_extend(program_impl & p, program_node & node, cldnn::tensor & tensor);
114     };
115
116     class graph_initializations : public base_pass 
117     {
118     public:
119         graph_initializations() : base_pass("init") {}
120     private:
121         virtual void run(program_impl& p) override;
122         void replace_nodes(program_impl& p);
123         void handle_detection_output(program_impl& p);
124         void handle_lstm(program_impl& p);
125         void set_outputs(program_impl& p);  
126     };
127
128     class handle_input_padding : public base_pass
129     {
130     public:
131         handle_input_padding() : base_pass("handle_input_padding") {}
132     private:
133         virtual void run(program_impl& p) override;
134     };
135
136     class mark_nodes : public base_pass
137     {
138     public:
139         mark_nodes() : base_pass("analyzed_graph") {}
140     private:
141         virtual void run(program_impl& p) override;
142         void mark_constants(program_impl& p);
143         void mark_data_flow(program_impl& p);
144     };
145
146     class prepare_buffer_fusing : public base_pass
147     {
148     public:
149         prepare_buffer_fusing() : base_pass("prepare_buffer_fusing") {}
150     private:
151         virtual void run(program_impl& p) override;
152     };
153
154     class prepare_conv_eltw_fusing : public base_pass
155     {
156     public:
157         prepare_conv_eltw_fusing() : base_pass("prepare_conv_eltw_fusing") {}
158     private:
159         virtual void run(program_impl& p) override;
160         void fuse_conv_eltwise(program_impl& p, program_node* node);
161     };
162
163     class prepare_conv_eltw_read_write_opt : public base_pass
164     {
165     public:
166         prepare_conv_eltw_read_write_opt() : base_pass("prepare_conv_eltw_read_write_opt") {}
167     private:
168         virtual void run(program_impl& p) override;
169         void conv_eltwise_read_write_opt(program_impl& p, program_node* node);
170     };
171
172     class prepare_depthwise_sep_opt : public base_pass
173     {
174     public:
175         prepare_depthwise_sep_opt() : base_pass("prepare_depthwise_sep_opt") {}
176     private:
177         virtual void run(program_impl& p) override;
178         template <typename T> void optimize_depthwise_sep_pre(T& node);
179     };
180
181     class prep_opt_depthwise_sep_post : public base_pass
182     {
183     public:
184         prep_opt_depthwise_sep_post() : base_pass("prep_opt_depthwise_sep_post") {}
185     private:
186         virtual void run(program_impl& p) override;
187         template <typename T> void optimize_depthwise_sep_pre(program_impl& p, T& node);
188     };
189
190     class prepare_primitive_fusing : public base_pass
191     {
192     public:
193         prepare_primitive_fusing() : base_pass("prepare_primitive_fusing") {}
194     private:
195         virtual void run(program_impl& p) override;
196         void fuse_skip_layers(program_impl& p, program_node* node);
197         void fuse_conv_bn_scale(program_impl& p, program_node* node);
198     };
199
200     class pre_optimize_bias : public base_pass
201     {
202     public:
203         pre_optimize_bias(layout_optimizer& lo_ref);
204     private:
205         virtual void run(program_impl& p) override;
206         virtual void run(program_impl& p, layout_optimizer& lo);
207         template <typename T>
208         void optimize_bias(T& node, layout_optimizer& lo, program_impl& p);
209         layout_optimizer& _lo;
210     };
211
212     class prepare_padding : public base_pass
213     {
214     public:
215         prepare_padding(bool output_size_handling_enabled_switch) : base_pass("prepare_padding"),
216             output_size_handling_enabled(output_size_handling_enabled_switch) {}
217     private:
218         virtual void run(program_impl& p) override;
219         bool output_size_handling_enabled;
220     };
221
222     class post_optimize_weights : public base_pass
223     {
224     public:
225         post_optimize_weights(layout_optimizer& lo_ref);
226     private:
227         virtual void run(program_impl& p) override;
228         virtual void run(program_impl& p, layout_optimizer& lo);
229         template <typename T>
230         void optimize_weights(T& node, layout_optimizer& lo, program_impl& p);
231         layout_optimizer& _lo;
232     };
233
234     class propagate_constants : public base_pass
235     {
236     public:
237         propagate_constants() : base_pass("propagate_constants") {}
238     private:
239         virtual void run(program_impl& p) override;
240         std::list<std::pair<primitive_id, memory_impl::ptr>> calculate(engine_impl &engine);
241         bool has_non_const_user(program_node& node) const;
242         void handle_constant(program_impl& prog, program_node& node);
243         void add_constant(program_impl& prog, program_node& node);
244         void add_deps_to_tpl(program_impl& prog, const std::vector<program_node*>& node);
245
246         bool has_non_trivial_constants = false;
247         std::list<typed_program_node<data>*> const_inputs;
248         std::vector<primitive_id> const_outputs;
249         std::set<std::shared_ptr<program_node>> nodes;
250     };
251
252     class remove_redundant_reorders : public base_pass
253     {
254     public:
255         remove_redundant_reorders() : base_pass("remove_redundant_reorders") {}
256         virtual void run(program_impl& p) override;
257     };
258
259     class reorder_inputs : public base_pass
260     {
261     public:
262         reorder_inputs(layout_optimizer& lo_ref);
263     private:
264         virtual void run(program_impl& p) override;
265         virtual void run(program_impl& p, layout_optimizer& lo);
266         layout_optimizer& _lo;
267     };
268
269     class trim_to_outputs : public base_pass
270     {
271     public:
272         trim_to_outputs() : base_pass("trimmed") {}
273     private:
274         virtual void run(program_impl& p) override;
275     };
276 }