2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
19 #include "program_impl.h"
20 #include "layout_optimizer.h"
// Grants pass_manager access to base_pass's non-public members.
26 friend class pass_manager;
28 base_pass(const std::string& pass_name) : name(pass_name) {}
// Entry point that every concrete pass overrides. Declared pure virtual (= 0),
// which makes base_pass abstract.
29 virtual void run(program_impl& p) = 0;
30 std::string get_name() { return name; }
// Iterates over every node in p.get_processing_order().
// NOTE(review): the loop body is not visible in this view. Judging by the name,
// it presumably clears per-node marks left behind by a pass — confirm against the
// full header.
31 void clean_marks(program_impl& p) {
32 for (auto& node : p.get_processing_order())
// Pass name fixed at construction; get_name() returns it.
38 const std::string name;
// Per-pass bookkeeping in pass_manager. The visible part builds a dump file name
// of the form "<pass_count>_<pass name>" (with a "0" appended first) and passes
// it to p.dump_program.
// NOTE(review): several lines of this function are missing from this view,
// including whatever invokes `pass` itself and whatever guards the `+= "0"` line.
// The "0" looks like zero-padding for single-digit counters — confirm the
// condition. The meaning of dump_program's `true` argument is defined by
// program_impl, which is not visible here.
48 void run(program_impl& p, base_pass& pass)
51 std::string dump_file_name;
53 dump_file_name += "0";
// Append the current counter and the pass's registered name.
54 dump_file_name += std::to_string(pass_count) + "_" + pass.get_name();
55 p.dump_program(dump_file_name.c_str(), true);
59 uint32_t get_pass_count() { return pass_count; }
60 uint32_t inc_pass_count() { return ++pass_count; }
// Pass registered under the name "add_required_reorders" (the string get_name()
// returns). Overrides base_pass::run and declares an add_reorder helper taking the
// program, a node, one of its users (`usr`) and a target layout. Presumably the
// helper inserts a reorder to `reorder_layout` between `node` and `usr` — confirm
// in the implementation file.
// NOTE(review): `reorder_layout` is taken by value; a const reference would avoid
// a copy (the out-of-line definition must change with it).
66 class add_required_reorders : public base_pass
69 add_required_reorders() : base_pass("add_required_reorders") {}
71 virtual void run(program_impl& p) override;
72 void add_reorder(program_impl& p, program_node* node, program_node* usr, layout reorder_layout);
// Pass registered as "add_reshape_to_primitives_pass". Unlike most passes here,
// this name carries an extra "_pass" suffix relative to the class name. The name
// feeds the dump file names built in pass_manager::run, so renaming it would
// change those file names.
75 class add_reshape_to_primitives : public base_pass
78 add_reshape_to_primitives() : base_pass("add_reshape_to_primitives_pass") {}
80 virtual void run(program_impl& p) override;
// Pass registered as "calculated_prior_boxes". Note that the registered name
// ("calculated") differs from the class name ("calculate"). The name is used in
// pass_manager's dump file names, so changing it would change those file names.
83 class calculate_prior_boxes : public base_pass
86 calculate_prior_boxes() : base_pass("calculated_prior_boxes") {}
88 virtual void run(program_impl& p) override;
// Pass registered as "compile_graph". It only overrides base_pass::run; its work
// is defined out of line and is not visible here.
91 class compile_graph: public base_pass
94 compile_graph() : base_pass("compile_graph") {}
96 virtual void run(program_impl& p) override;
// Pass registered as "eltwise_shrinking". It only overrides base_pass::run; the
// transformation itself is defined out of line.
99 class eltwise_shrinking : public base_pass
102 eltwise_shrinking() : base_pass("eltwise_shrinking") {}
104 virtual void run(program_impl& p) override;
// Pass registered as "eltwise_remove_stride". Declares a conv_stride_extend helper
// that takes `tensor` by non-const reference. Presumably it updates that tensor in
// place for the given node — confirm in the implementation file.
107 class eltwise_remove_stride : public base_pass
110 eltwise_remove_stride() : base_pass("eltwise_remove_stride") {}
112 virtual void run(program_impl& p) override;
113 void conv_stride_extend(program_impl & p, program_node & node, cldnn::tensor & tensor);
// Pass registered under the short name "init", which differs from the class name
// and is what appears in dump file names. Besides run, it declares four helpers
// that each operate on the whole program: replace_nodes, handle_detection_output,
// handle_lstm and set_outputs. Their order of use inside run is not visible here.
116 class graph_initializations : public base_pass
119 graph_initializations() : base_pass("init") {}
121 virtual void run(program_impl& p) override;
122 void replace_nodes(program_impl& p);
123 void handle_detection_output(program_impl& p);
124 void handle_lstm(program_impl& p);
125 void set_outputs(program_impl& p);
// Pass registered as "handle_input_padding". It only overrides base_pass::run.
128 class handle_input_padding : public base_pass
131 handle_input_padding() : base_pass("handle_input_padding") {}
133 virtual void run(program_impl& p) override;
// Pass registered as "analyzed_graph", not "mark_nodes"; the registered name is
// what appears in dump file names. Declares two program-wide helpers,
// mark_constants and mark_data_flow. Judging by their names, they tag nodes that
// later passes inspect — confirm in the implementation file.
136 class mark_nodes : public base_pass
139 mark_nodes() : base_pass("analyzed_graph") {}
141 virtual void run(program_impl& p) override;
142 void mark_constants(program_impl& p);
143 void mark_data_flow(program_impl& p);
// Pass registered as "prepare_buffer_fusing". It only overrides base_pass::run.
146 class prepare_buffer_fusing : public base_pass
149 prepare_buffer_fusing() : base_pass("prepare_buffer_fusing") {}
151 virtual void run(program_impl& p) override;
// Pass registered as "prepare_conv_eltw_fusing". Declares a per-node helper,
// fuse_conv_eltwise, taking a raw (non-owning) program_node pointer. Presumably
// run calls it for candidate nodes — confirm in the implementation file.
154 class prepare_conv_eltw_fusing : public base_pass
157 prepare_conv_eltw_fusing() : base_pass("prepare_conv_eltw_fusing") {}
159 virtual void run(program_impl& p) override;
160 void fuse_conv_eltwise(program_impl& p, program_node* node);
// Pass registered as "prepare_conv_eltw_read_write_opt". Declares a per-node
// helper, conv_eltwise_read_write_opt, taking a raw program_node pointer. Its
// exact effect is defined out of line.
163 class prepare_conv_eltw_read_write_opt : public base_pass
166 prepare_conv_eltw_read_write_opt() : base_pass("prepare_conv_eltw_read_write_opt") {}
168 virtual void run(program_impl& p) override;
169 void conv_eltwise_read_write_opt(program_impl& p, program_node* node);
// Pass registered as "prepare_depthwise_sep_opt". Declares a member template,
// optimize_depthwise_sep_pre, generic over the node type T. Its definition is not
// in this header, so it can only be instantiated where that definition is visible
// (presumably the pass's .cpp file — confirm).
172 class prepare_depthwise_sep_opt : public base_pass
175 prepare_depthwise_sep_opt() : base_pass("prepare_depthwise_sep_opt") {}
177 virtual void run(program_impl& p) override;
178 template <typename T> void optimize_depthwise_sep_pre(T& node);
// Pass registered as "prep_opt_depthwise_sep_post". Declares a member template
// over the node type T that also receives the program.
// NOTE(review): the helper is named optimize_depthwise_sep_pre even though this
// is the "post" pass. This looks like copy-paste naming from
// prepare_depthwise_sep_opt. Renaming it requires updating the out-of-line
// definition as well.
181 class prep_opt_depthwise_sep_post : public base_pass
184 prep_opt_depthwise_sep_post() : base_pass("prep_opt_depthwise_sep_post") {}
186 virtual void run(program_impl& p) override;
187 template <typename T> void optimize_depthwise_sep_pre(program_impl& p, T& node);
// Pass registered as "prepare_primitive_fusing". Declares two per-node helpers,
// fuse_skip_layers and fuse_conv_bn_scale, each taking a raw program_node
// pointer. Which nodes run passes to them is decided out of line.
190 class prepare_primitive_fusing : public base_pass
193 prepare_primitive_fusing() : base_pass("prepare_primitive_fusing") {}
195 virtual void run(program_impl& p) override;
196 void fuse_skip_layers(program_impl& p, program_node* node);
197 void fuse_conv_bn_scale(program_impl& p, program_node* node);
// Pass parameterized by a layout_optimizer. The constructor is defined out of line,
// so its registered pass name is not visible here.
// Besides the base_pass::run override, it declares a second virtual overload
// run(p, lo) and a member template optimize_bias over the node type T.
// _lo is a reference member: only a reference to the optimizer is kept
// (presumably bound to lo_ref in the constructor — confirm), so the referenced
// layout_optimizer must outlive this pass. The reference member also makes the
// class non-copy-assignable.
// NOTE(review): the single-argument constructor is not `explicit`.
200 class pre_optimize_bias : public base_pass
203 pre_optimize_bias(layout_optimizer& lo_ref);
205 virtual void run(program_impl& p) override;
206 virtual void run(program_impl& p, layout_optimizer& lo);
207 template <typename T>
208 void optimize_bias(T& node, layout_optimizer& lo, program_impl& p);
209 layout_optimizer& _lo;
// Pass registered as "prepare_padding". The constructor stores its bool argument in
// output_size_handling_enabled. Judging by the name, the flag toggles extra
// output-size handling inside run — confirm in the implementation file.
// NOTE(review): the single-argument constructor is not `explicit`, so any bool
// (or anything convertible to bool) converts implicitly into a prepare_padding.
212 class prepare_padding : public base_pass
215 prepare_padding(bool output_size_handling_enabled_switch) : base_pass("prepare_padding"),
216 output_size_handling_enabled(output_size_handling_enabled_switch) {}
218 virtual void run(program_impl& p) override;
219 bool output_size_handling_enabled;
// Pass parameterized by a layout_optimizer, with the same shape as
// pre_optimize_bias. The constructor is defined out of line (registered name not
// visible here). Besides the run override, it declares a second virtual overload
// run(p, lo) and a member template optimize_weights over the node type T.
// _lo is a reference member (presumably bound to lo_ref — confirm): the referenced
// layout_optimizer must outlive this pass, and the class is non-copy-assignable.
// NOTE(review): the single-argument constructor is not `explicit`.
222 class post_optimize_weights : public base_pass
225 post_optimize_weights(layout_optimizer& lo_ref);
227 virtual void run(program_impl& p) override;
228 virtual void run(program_impl& p, layout_optimizer& lo);
229 template <typename T>
230 void optimize_weights(T& node, layout_optimizer& lo, program_impl& p);
231 layout_optimizer& _lo;
// Pass registered as "propagate_constants".
// calculate(engine) returns a list of (primitive_id, memory_impl::ptr) pairs.
// Presumably these are the evaluated values of constant sub-graphs — confirm.
// has_non_const_user is a const query on a single node. handle_constant,
// add_constant and add_deps_to_tpl mutate pass state for the given program.
// NOTE(review): add_deps_to_tpl's parameter is a vector of nodes but is named
// `node`.
234 class propagate_constants : public base_pass
237 propagate_constants() : base_pass("propagate_constants") {}
239 virtual void run(program_impl& p) override;
240 std::list<std::pair<primitive_id, memory_impl::ptr>> calculate(engine_impl &engine);
241 bool has_non_const_user(program_node& node) const;
242 void handle_constant(program_impl& prog, program_node& node);
243 void add_constant(program_impl& prog, program_node& node);
244 void add_deps_to_tpl(program_impl& prog, const std::vector<program_node*>& node);
// Per-run state. The flag starts false. const_inputs holds non-owning pointers to
// data nodes. const_outputs holds primitive ids. nodes holds shared ownership of
// program nodes.
246 bool has_non_trivial_constants = false;
247 std::list<typed_program_node<data>*> const_inputs;
248 std::vector<primitive_id> const_outputs;
249 std::set<std::shared_ptr<program_node>> nodes;
// Pass registered as "remove_redundant_reorders". It only overrides base_pass::run.
252 class remove_redundant_reorders : public base_pass
255 remove_redundant_reorders() : base_pass("remove_redundant_reorders") {}
256 virtual void run(program_impl& p) override;
// Pass parameterized by a layout_optimizer. The constructor is defined out of line
// (registered name not visible here). It declares the run override plus a virtual
// run(p, lo) overload. _lo is a reference member (presumably bound to lo_ref —
// confirm), so the layout_optimizer must outlive this pass.
// NOTE(review): the single-argument constructor is not `explicit`.
259 class reorder_inputs : public base_pass
262 reorder_inputs(layout_optimizer& lo_ref);
264 virtual void run(program_impl& p) override;
265 virtual void run(program_impl& p, layout_optimizer& lo);
266 layout_optimizer& _lo;
// Pass registered as "trimmed", not "trim_to_outputs". The registered name is what
// appears in pass_manager's dump file names. It only overrides base_pass::run.
269 class trim_to_outputs : public base_pass
272 trim_to_outputs() : base_pass("trimmed") {}
274 virtual void run(program_impl& p) override;