Fixed static analysis issues (transformations) (#3276)
[platform/upstream/dldt.git] / inference-engine / src / transformations / src / transformations / common_optimizations / nop_elimination.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <functional>
18 #include <memory>
19 #include <typeindex>
20 #include <typeinfo>
21 #include <unordered_map>
22
23 #include <ngraph/opsets/opset3.hpp>
24 #include <ngraph/util.hpp>
25 #include <ngraph/log.hpp>
26 #include <transformations/common_optimizations/nop_elimination.hpp>
27
28 NGRAPH_SUPPRESS_DEPRECATED_START
29
30 using namespace std;
31 using namespace ngraph;
32
33 #define TI(x) x::type_info
34
35 static bool eliminate_nop(const std::shared_ptr<Node>& node) {
36     // skip if shapes are dynamic
37     if (node->get_input_partial_shape(0).is_dynamic() ||
38         node->get_output_partial_shape(0).is_dynamic()) {
39         return false;
40     }
41
42     if (node->get_input_shape(0) == node->get_output_shape(0)) {
43         return replace_output_update_name(node->output(0), node->input_value(0));
44     }
45     return false;
46 }
47
48 static bool eliminate_convert(const std::shared_ptr<Node>& node) {
49     bool is_out_type_agnostic = false;
50     static const std::set<NodeTypeInfo> type_agnostic{TI(opset3::NonZero)};
51     if (node->output(0).get_target_inputs().size() == 1) {
52         Input<Node> out = *node->output(0).get_target_inputs().begin();
53         is_out_type_agnostic = type_agnostic.count(out.get_node()->get_type_info()) == 1;
54     }
55     auto convert = as_type_ptr<opset3::Convert>(node);
56     auto input = convert->input_value(0);
57     if (convert->get_convert_element_type() == input.get_element_type() || is_out_type_agnostic) {
58         if (is_out_type_agnostic && is_type<opset3::Convert>(input.get_node())) {
59             input = input.get_node()->input_value(0);
60         }
61         return replace_output_update_name(node->output(0), input);
62     }
63     return false;
64 }
65
66 static bool eliminate_concat(const std::shared_ptr<Node>& node) {
67     auto node_input = node->input_value(0);
68
69     // remove concat with single input
70     if (node->get_input_size() == 1) {
71         return replace_output_update_name(node->output(0), node_input);
72     }
73     return false;
74 }
75
76 static bool eliminate_reshape_v1(const std::shared_ptr<Node>& node) {
77     auto input = node->input_value(0);
78     // check if reshape is not identity op
79     if (input.get_partial_shape().is_dynamic() || node->get_output_partial_shape(0).is_dynamic()) {
80         NGRAPH_DEBUG << node << " has dynamic shapes.";
81         return false;
82     }
83     // remove identity op
84     if (input.get_shape() == node->get_output_shape(0)) {
85         return replace_output_update_name(node->output(0), input);
86     }
87     // eliminate redundant reshape, squeeze, or unsqueeze
88     auto input_node = input.get_node_shared_ptr();
89     if (as_type_ptr<opset3::Squeeze>(input_node) ||
90         as_type_ptr<opset3::Unsqueeze>(input_node) ||
91         as_type_ptr<opset3::Reshape>(input_node)) {
92         auto shape = node->get_output_shape(0);
93         std::vector<int64_t> vi;
94         vi.assign(shape.begin(), shape.end());
95         auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
96         auto new_reshape =
97             make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
98         new_reshape->set_friendly_name(node->get_friendly_name());
99         copy_runtime_info({input_node, node}, new_reshape);
100         replace_node(node, new_reshape);
101         return true;
102     }
103
104     return false;
105 }
106
107 static size_t count_unknown_dims(const PartialShape& ps) {
108     size_t rc = 0;
109     if (ps.is_static()) {
110         return rc;
111     }
112     for (auto i = 0; i < ps.rank().get_length(); i++) {
113         if (ps[i].is_dynamic()) {
114             rc += 1;
115         }
116     }
117     return rc;
118 }
119
120 static bool replace_squeeze_unsqueeze(const std::shared_ptr<Node>& node) {
121     auto shape_ps = node->get_output_partial_shape(0);
122     if (shape_ps.rank().get_length() == 0) {
123         return false;
124     }
125     if (count_unknown_dims(shape_ps) > 1) {
126         return false;
127     }
128     std::vector<int64_t> target_shape;
129     for (auto i = 0; i < shape_ps.rank().get_length(); i++) {
130         if (shape_ps[i].is_dynamic()) {
131             target_shape.emplace_back(-1);
132         } else {
133             target_shape.emplace_back(shape_ps[i].get_length());
134         }
135     }
136
137     shared_ptr<Node> reshape;
138     auto input = node->input_value(0).get_node_shared_ptr();
139     auto pat =
140         opset3::Constant::create<int64_t>(element::i64, Shape{target_shape.size()}, target_shape);
141
142     if (is_type<opset3::Reshape>(input) || is_type<opset3::Squeeze>(input) ||
143         is_type<opset3::Unsqueeze>(input)) {
144         reshape = make_shared<opset3::Reshape>(input->input_value(0), pat, false);
145     } else {
146         reshape = make_shared<opset3::Reshape>(node->input_value(0), pat, false);
147     }
148
149     // skip if reshape is nop
150     if (reshape->get_input_partial_shape(0).same_scheme(shape_ps)) {
151         return replace_output_update_name(node->output(0), reshape->input_value(0));
152     } else {
153         return replace_node_update_name(node, reshape);
154     }
155 }
156
157 static std::vector<int64_t> get_unsqueeze_axes(const PartialShape& data_shape,
158                                                const PartialShape& out_shape) {
159     std::vector<int64_t> axes;
160     int64_t i = 0;
161     for (auto o = 0; o < out_shape.rank().get_length(); o++) {
162         if (i < data_shape.rank().get_length() && data_shape[i].same_scheme(out_shape[o])) {
163             i += 1;
164             continue;
165         }
166         if (out_shape[o].is_static() && out_shape[o] == 1) {
167             axes.push_back(o);
168         }
169     }
170     return axes;
171 }
172
173 static std::vector<int64_t> get_squeeze_axes(const PartialShape& data_shape,
174                                              const PartialShape& out_shape) {
175     std::vector<int64_t> axes;
176     int64_t out_i = 0;
177     for (auto i = 0; i < data_shape.rank().get_length(); i++) {
178         if (out_i < out_shape.rank().get_length() && data_shape[i].same_scheme(out_shape[out_i])) {
179             out_i += 1;
180             continue;
181         }
182         if (data_shape[i].is_static() && data_shape[i] == 1) {
183             axes.push_back(i);
184         }
185     }
186     return axes;
187 }
188
189 static bool eliminate_unsqueeze(const std::shared_ptr<Node>& node) {
190     auto out_shape = node->get_output_partial_shape(0);
191     // try to replace all squeeze/unsqueeze with reshape
192     if (out_shape.rank().is_static() && out_shape.rank().get_length() != 0 && count_unknown_dims(out_shape) < 2) {
193         return replace_squeeze_unsqueeze(node);
194     }
195
196     auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(node);
197     if (unsqueeze == nullptr)
198         return false;
199     auto input = unsqueeze->input_value(0).get_node_shared_ptr();
200     auto squeeze = as_type_ptr<opset3::Squeeze>(input);
201     auto replace_unsqueeze_only = [&](const vector<int64_t>& axes) {
202         auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
203         auto new_unsq = make_shared<opset3::Unsqueeze>(input->input_value(0), axes_const);
204         if (unsqueeze->get_output_partial_shape(0).same_scheme(
205                 new_unsq->get_output_partial_shape(0))) {
206             return replace_node_update_name(unsqueeze, new_unsq);
207         }
208         return false;
209     };
210     // eliminate redundant squeeze->unsqueeze
211     if (squeeze) {
212         const auto& data_shape = squeeze->input_value(0).get_partial_shape();
213         if (ngraph::compare_constants(squeeze->input_value(1).get_node_shared_ptr(),
214                                       unsqueeze->input_value(1).get_node_shared_ptr())) {
215             return replace_output_update_name(unsqueeze->output(0), squeeze->input_value(0));
216         }
217         if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic()) {
218             return false;
219         }
220         if (out_shape.rank().get_length() > data_shape.rank().get_length()) {
221             // check if single unsqueeze can handle this
222             auto axes = get_unsqueeze_axes(data_shape, out_shape);
223             if (axes.size() + data_shape.rank().get_length() == out_shape.rank().get_length()) {
224                 return replace_unsqueeze_only(axes);
225             }
226         }
227         if (out_shape.rank().get_length() < data_shape.rank().get_length()) {
228             // check if single squeeze can handle this
229             auto axes = get_squeeze_axes(data_shape, out_shape);
230             if (data_shape.rank().get_length() - axes.size() == out_shape.rank().get_length()) {
231                 auto axes_const =
232                     opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
233                 auto new_sq = make_shared<opset3::Squeeze>(input->input_value(0), axes_const);
234                 if (unsqueeze->get_output_partial_shape(0).same_scheme(
235                         new_sq->get_output_partial_shape(0))) {
236                     return replace_node_update_name(unsqueeze, new_sq);
237                 }
238                 return false;
239             }
240         }
241         return false;
242     }
243     // eliminate redundant unsqueeze->unsqueeze
244     auto unsqueeze_i = as_type_ptr<opset3::Unsqueeze>(input);
245     if (unsqueeze_i) {
246         const auto& data_shape = unsqueeze_i->input_value(0).get_partial_shape();
247         if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic()) {
248             return false;
249         }
250         auto axes = get_unsqueeze_axes(data_shape, out_shape);
251         return replace_unsqueeze_only(axes);
252     }
253
254     return false;
255 }
256
257 static bool eliminate_squeeze(const std::shared_ptr<Node>& node) {
258     auto out_shape = node->get_output_partial_shape(0);
259     // try to replace all unsqueeze/squeeze with reshape
260     if (out_shape.rank().is_static() && out_shape.rank().get_length() != 0 && count_unknown_dims(out_shape) < 2) {
261         return replace_squeeze_unsqueeze(node);
262     }
263
264     auto squeeze = as_type_ptr<opset3::Squeeze>(node);
265     if (squeeze == nullptr)
266         return false;
267     auto input = squeeze->input_value(0).get_node_shared_ptr();
268     auto replace_squeeze_only = [&](const vector<int64_t>& axes) {
269         auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
270         auto new_sq = make_shared<opset3::Squeeze>(input->input_value(0), axes_const);
271         if (squeeze->get_output_partial_shape(0).same_scheme(new_sq->get_output_partial_shape(0))) {
272             return replace_node_update_name(squeeze, new_sq);
273         }
274         return false;
275     };
276     // eliminate redundant unsqueeze->squeeze
277     if (auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(input)) {
278         PartialShape data_shape;
279         if (op::is_parameter(input)) {
280             data_shape = unsqueeze->input(0).get_partial_shape();
281         } else {
282             data_shape = input->input(0).get_partial_shape();
283         }
284         if (ngraph::compare_constants(unsqueeze->input_value(1).get_node_shared_ptr(),
285                                       squeeze->input_value(1).get_node_shared_ptr())) {
286             return replace_output_update_name(squeeze->output(0), unsqueeze->input_value(0));
287         }
288         if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic()) {
289             return false;
290         }
291         if (out_shape.rank().get_length() < data_shape.rank().get_length()) {
292             // check if single squeeze can handle this
293             auto axes = get_squeeze_axes(data_shape, out_shape);
294             if (data_shape.rank().get_length() == out_shape.rank().get_length() + axes.size()) {
295                 return replace_squeeze_only(axes);
296             }
297         }
298         if (out_shape.rank().get_length() > data_shape.rank().get_length()) {
299             // check if single unsqueeze can handle this
300             auto axes = get_unsqueeze_axes(data_shape, out_shape);
301             if (data_shape.rank().get_length() + axes.size() == out_shape.rank().get_length()) {
302                 auto axes_const =
303                     opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
304                 auto new_unsq = make_shared<opset3::Unsqueeze>(input->input_value(0), axes_const);
305                 if (squeeze->get_output_partial_shape(0).same_scheme(
306                         new_unsq->get_output_partial_shape(0))) {
307                     replace_output_update_name(squeeze, new_unsq);
308                     return true;
309                 }
310             }
311         }
312         return false;
313     }
314     // eliminate redundant squeeze->squeeze
315     if (auto squeeze_i = as_type_ptr<opset3::Squeeze>(input)) {
316         PartialShape data_shape;
317         if (op::is_parameter(input)) {
318             data_shape = squeeze_i->input(0).get_partial_shape();
319         } else {
320             data_shape = input->input(0).get_partial_shape();
321         }
322         if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic()) {
323             return false;
324         }
325         auto axes = get_squeeze_axes(data_shape, out_shape);
326         return replace_squeeze_only(axes);
327     }
328     return false;
329 }
330
331 NGRAPH_RTTI_DEFINITION(ngraph::pass::NopElimination, "NopElimination", 0);
332
333 bool pass::NopElimination::run_on_function(std::shared_ptr<Function> function) {
334     static const std::unordered_map<NodeTypeInfo, std::function<bool(const std::shared_ptr<Node>&)>>
335         dispatcher{{TI(opset3::Pad), &eliminate_nop},
336                    {TI(opset3::Convert), &eliminate_convert},
337                    {TI(opset3::Reshape), &eliminate_reshape_v1},
338                    {TI(opset3::Concat), &eliminate_concat},
339                    {TI(opset3::Squeeze), &eliminate_squeeze},
340                    {TI(op::v1::Broadcast), &eliminate_nop},
341                    {TI(opset3::Unsqueeze), &eliminate_unsqueeze}};
342
343     bool clobbered = false;
344
345     for (const auto& node : function->get_ops()) {
346         // Recursively apply transformation for sub-graph based operations
347         if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
348             if (auto sub_graph = sub_graph_node->get_function()) {
349                 clobbered |= run_on_function(sub_graph);
350             }
351         }
352         auto handler = dispatcher.find(node->get_type_info());
353         if (handler != dispatcher.end()) {
354             clobbered |= handler->second(node);
355         }
356     }
357
358     return clobbered;
359 }