Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / graph_optimizer / prepare_buffer_fusing.cpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
18
19 #include "api/CPP/eltwise.hpp"
20 #include "api/CPP/pooling.hpp"
21 #include "api/CPP/upsampling.hpp"
22 #include "primitive_inst.h"
23 #include "activation_inst.h"
24 #include "concatenation_inst.h"
25 #include "crop_inst.h"
26 #include "eltwise_inst.h"
27 #include "reshape_inst.h"
28 #include "scale_inst.h"
29
30 #include "pass_manager.h"
31 #include "program_helpers.h"
32
33
34 using namespace cldnn;
35
36 //ToDo remove friendship relation from  program_node 
37
38 void prepare_buffer_fusing::run(program_impl& p)
39 {
40     bool is_debug = p.get_options().get<build_option_type::debug>()->enabled();
41     /*
42     We need to take care of proper ordering by types.
43     1. Concats
44     2. Crops
45     3. Others
46     Concat before crops is needed because of the crop fusing padding requirments. 
47     If crop is before concat there can be padding mismtach, since concat changes padding.
48     */
49     auto can_optimize = [](const program_node* node)
50     {
51         if (node->is_output() ||
52             (node->get_fused_activation_func() != cldnn_activation_func_t::activation_none))
53         {
54             return false;
55         }
56         return true;
57     };
58
59     //[1] First try to optimize all concats
60     auto node_itr = p.get_processing_order().begin();
61     while (node_itr != p.get_processing_order().end())
62     {
63         auto& node = (*node_itr++);
64         if (!can_optimize(node))
65             continue;
66         program_helpers::do_for_types<concatenation>(*node, [&p, is_debug](concatenation_node& node)
67         {
68             // we need to avoid mixing padded and unpadded buffer 
69             bool all_dependencies_padded = true;
70             bool all_dependencies_unpadded = true;
71             for (auto& input : node.get_dependencies()) {
72                 layout l = input->get_output_layout();
73                 if (static_cast<bool>(l.data_padding))
74                     all_dependencies_unpadded = false;
75                 else
76                     all_dependencies_padded = false;
77             }
78             auto concat_axis = node.get_primitive()->axis;
79             auto padd = node.get_output_layout().data_padding;
80
81             tensor lower_padd = padd.lower_size();
82             tensor upper_padd = padd.upper_size();
83
84             auto upper_padd_val = node.get_output_layout().get_buffer_size().raw[concat_axis] - lower_padd.raw[concat_axis];
85             tensor lower_padd_offset = lower_padd;
86
87             std::list<std::pair<const std::vector<program_node*>, tensor>> stack = { std::make_pair(node.get_dependencies(), tensor{ 0, 0, 0, 0 }) };
88             while (!stack.empty())
89             {
90                 auto nodes_list = stack.front();
91                 stack.pop_front();
92
93                 auto cascade_adjustment = nodes_list.second;
94                 upper_padd.raw[concat_axis] = upper_padd_val;
95                 lower_padd = lower_padd_offset;
96
97                 //check if concatenation in place can be applied for inputs set
98                 for (auto input : nodes_list.first)
99                 {
100                     //if any of this node's inputs is used by more than one primitive and is not optimized concatenation then do not fuse buffers,
101                     //also, if an input is marked as network output, prevent optimizations which would affect a form of its output (unless debug flag is set)
102                     // todo: in future, if this case is problem, it can be optimized further to enable buffer fusing
103                     //       per single input rather than all/none
104                     // + restrict input types to those which support padding on x,y,b and f
105                     if (!input->support_padding() ||
106                         (input->is_output() && !is_debug) ||
107                         input->get_users().size() > 2)
108                         return;
109
110                     if (input->get_users().size() > 1)
111                     {
112                         auto user_count = input->get_users().size();
113                         for (auto& user : input->get_users())
114                             if (user->is_type<concatenation>())
115                                 user_count--;
116                         if (user_count != 1) // user_cout == 0 means that input will be used only by concatenations, so we cannot apply concat in place for it
117                             return;
118                     }
119                 }
120
121                 //apply concatenation in place optimization
122                 for (auto input : nodes_list.first)
123                 {
124                     auto input_lenght = input->get_output_layout().size.raw[concat_axis];
125
126                     bool optimized_concat_input = false;
127                     if (input->type() == concatenation::type_id() && input->can_be_optimized())
128                     {
129                         if (input->as<concatenation>().get_primitive()->axis != node.get_primitive()->axis)
130                             return;
131                         optimized_concat_input = true;
132                     }
133
134                     // shrink upper pad so it points at the end of the input's buffer
135                     //
136                     //   |--- lower padd ---|                    |---------- upper padd -----------|
137                     //   |-- output padd ---| ----- input1 ------|----- input2 -----|-- out padd --|
138                     upper_padd.raw[concat_axis] -= input_lenght;
139
140                     //adjust padding sizes for cascade concatenations
141                     auto lower_padd_tmp = lower_padd;
142                     lower_padd_tmp.raw[concat_axis] += cascade_adjustment.raw[concat_axis];
143                     auto upper_padd_tmp = upper_padd;
144                     upper_padd_tmp.raw[concat_axis] -= cascade_adjustment.raw[concat_axis];
145
146                     // set new padding for input
147                     input->set_output_padding(padding(lower_padd_tmp.sizes(), upper_padd_tmp.sizes()));
148
149                     // move lower padd further
150                     //
151                     //   |-------------- lower padd -------------|---------- upper padd -----------|
152                     //   |-- output padd ---| ----- input1 ------|----- input2 -----|-- out padd --|
153
154                     lower_padd.raw[concat_axis] += input_lenght;
155
156                     if (optimized_concat_input && !input->get_dependencies().empty())
157                         stack.push_back(std::make_pair(input->get_dependencies(), input->get_output_layout().data_padding.lower_size()));
158                 }
159             }
160
161             node.can_be_optimized(true);
162             for (auto dep : node.get_users())
163             {
164                 dep->can_share_buffer(false);
165             }
166             if (!all_dependencies_padded && !all_dependencies_unpadded)
167                 node.can_share_buffer(false);
168         });
169     }
170
171     //[2] Then try to optimize all crops
172     node_itr = p.get_processing_order().begin();
173     while (node_itr != p.get_processing_order().end())
174     {
175         auto& node = (*node_itr++);
176         if (!can_optimize(node))
177             continue;
178         // zero copy
179         program_helpers::do_for_types<crop>(*node, [&p, is_debug](crop_node& node)
180         {
181             //if the node is marked as network output, prevent optimizations which would affect a form of its output, unless debug flag is set
182             if (node.is_output() && !is_debug)
183                 return;
184
185             //do not optimize when next node is concatenation which is not output
186             if (node.get_users().size() == 1 && node.get_users().front()->is_type<concatenation>() && !node.get_users().front()->is_output())
187                 return;
188
189             if (node.get_dependencies().size() == 1 &&
190                 node.get_users().size() > 0)
191             {
192                 // optimization is available for cropping across depth(features) only
193                 // if output padding has defined padding across features already it wouldn't
194                 // work because it expect to have zeros in the padded area.
195                 const auto& crop_layout = node.get_output_layout();
196                 auto format = crop_layout.format;
197                 auto crop_prim = node.get_primitive();
198                 auto input_layout = node.get_dependency(0).get_output_layout();
199                 const auto& crop_size = crop_layout.size;
200                 const auto& out_padd = crop_layout.data_padding;
201                 if (format == format::bfyx &&
202                     crop_size.batch[0] == input_layout.size.batch[0] &&
203                     crop_size.spatial[0] == input_layout.size.spatial[0] &&
204                     crop_size.spatial[1] == input_layout.size.spatial[1] &&
205                     out_padd.lower_size().feature[0] == 0 &&
206                     out_padd.upper_size().feature[0] == 0 &&
207                     out_padd.lower_size().batch[0] == 0 &&
208                     out_padd.upper_size().batch[0] == 0 &&
209                     out_padd.lower_size().spatial[0] == 0 &&
210                     out_padd.lower_size().spatial[1] == 0 &&
211                     out_padd.upper_size().spatial[0] == 0 &&
212                     out_padd.upper_size().spatial[1] == 0)
213                 {
214                     //  Regular crop
215                     //  crop input buffer
216                     //  |___________data____________|
217                     //
218                     //  crop output buffer
219                     //  |-------->| offsets[f]  |<--|
220                     //            |_____data____|
221                     //             <------------>
222                     //           reference size
223                     //
224                     //  In-place crop
225                     //  crop output buffer
226                     //  |_low_pad_|__data_size__|___|<-upper pad
227
228                     node.set_output_padding(padding(
229                         { out_padd.lower_size().batch[0], crop_prim->offsets.feature[0], out_padd.lower_size().spatial[0], out_padd.lower_size().spatial[1] },
230                         { out_padd.upper_size().batch[0], input_layout.size.feature[0] - crop_prim->offsets.feature[0] - crop_size.feature[0],
231                             out_padd.upper_size().spatial[0], out_padd.upper_size().spatial[1] }));
232                     node.can_be_optimized(true);
233                 }
234             }
235         });
236     }
237
238     //[3] Optimize all other primitives
239     node_itr = p.get_processing_order().begin();
240     while (node_itr != p.get_processing_order().end())
241     {
242         auto& node = (*node_itr++);
243         if (!can_optimize(node))
244             continue;
245         program_helpers::do_for_types<reshape>(*node, [&p](reshape_node& node)
246         {
247             node.get_output_layout();
248             if (node.is_in_place()
249                 && node.get_fused_activation_func() == activation_none)
250                 node.can_be_optimized(true);
251         });
252         program_helpers::do_for_types<reorder>(*node, [&p](reorder_node& node)
253         {
254             auto& input = node.input();
255             auto output_layout = node.get_output_layout();
256             //This is WA for topologies that due to additional reorders added perform worse with conv1x1 optimization
257             auto remove_bf8_xy_opt = ((input.is_type<pooling>() || input.is_type<concatenation>()) &&
258                 output_layout.format == format::bf8_xy16 && input.get_users().size() == 1);
259             //Remove reorder from convolution 1x1 to bfyx in some conditions
260             auto remove_byxf_opt = (input.is_type<convolution>() &&
261                 input.get_users().size() == 1 &&
262                 input.get_output_layout().format == format::byxf);
263             //check if all inputs user have the same format
264             auto all_users_same_format = true;
265             auto input_user_layout_format = input.get_users().front()->get_output_layout().format;
266             for (auto const& user : input.get_users())
267             {
268                 if (user->get_output_layout().format != input_user_layout_format)
269                 {
270                     all_users_same_format = false;
271                     break;
272                 }
273             }
274             auto same_data_type = input.get_output_layout().data_type == output_layout.data_type;
275             //Optimization only available in case of layers that support different input and output formats.
276             //todo: new api needs to be created to read such caps
277             if (!(input.is_type<pooling>() && (output_layout.format == format::bfyx || output_layout.format == format::yxfb || output_layout.format == format::byxf) && all_users_same_format && same_data_type) &&
278                 !remove_bf8_xy_opt &&
279                 !(input.is_type<convolution>() && input.get_output_layout().format == format::bf8_xy16) &&
280                 !(input.is_type<eltwise>() && (output_layout.format == format::bfyx || output_layout.format == format::yxfb || output_layout.format == format::byxf) && all_users_same_format && same_data_type) &&
281                 !(remove_byxf_opt && (node.get_users().front()->is_type<eltwise>() || node.get_users().front()->is_type<pooling>())))
282                 return;
283
284             if (remove_bf8_xy_opt)
285             {
286                 auto users_user_layout = node.get_users().front()->get_users().front()->get_output_layout();
287                 // if users_user_layout is still bf8_yx16 (stacked convolutions) then leave the reorder
288                 if (users_user_layout.format == format::bf8_xy16)
289                     return;
290                 auto input_layout = input.get_output_layout();
291                 auto target_layout = layout(input_layout.data_type, users_user_layout.format, input_layout.size, input_layout.data_padding);
292                 input.set_output_layout(target_layout, false);
293             }
294             else if (remove_byxf_opt)
295             {
296                 auto user = node.get_users().front();
297                 auto users_users = node.get_users().front()->get_users();
298
299                 for (auto const& users_user : users_users)
300                 {
301                     if (users_user->get_output_layout().format != format::byxf && !users_user->is_type<eltwise>())
302                     {
303                         remove_byxf_opt = false;
304                         break;
305                     }
306                 }
307
308                 if (remove_byxf_opt)
309                 {
310                     auto input_layout = input.get_output_layout();
311                     user->set_output_layout(input_layout, false);
312                 }
313             }
314             else
315                 input.set_output_layout(output_layout, false);
316
317             node.can_be_optimized(true);
318             p.extract_and_remove(node); //try to remove redundant reorders
319         });
320     }
321 }