280739f5cd814dc197dcda9eee3984bb2bc0e86d
[platform/upstream/dldt.git] / inference-engine / src / transformations / src / transformations / common_optimizations / nop_elimination.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include <functional>
18 #include <memory>
19 #include <typeindex>
20 #include <typeinfo>
21 #include <unordered_map>
22
23 #include <ngraph/opsets/opset3.hpp>
24 #include <ngraph/util.hpp>
25 #include <ngraph/log.hpp>
26 #include <transformations/common_optimizations/nop_elimination.hpp>
27
28 NGRAPH_SUPPRESS_DEPRECATED_START
29
30 using namespace std;
31 using namespace ngraph;
32
33 #define TI(x) x::type_info
34
35 static bool eliminate_nop(const std::shared_ptr<Node>& node) {
36     // skip if shapes are dynamic
37     if (node->get_input_partial_shape(0).is_dynamic() ||
38         node->get_output_partial_shape(0).is_dynamic()) {
39         return false;
40     }
41
42     if (node->get_input_shape(0) == node->get_output_shape(0)) {
43         return replace_output_update_name(node->output(0), node->input_value(0));
44     }
45     return false;
46 }
47
48 static bool eliminate_sum(const std::shared_ptr<Node>& node) {
49     auto sum = as_type_ptr<op::v0::Sum>(node);
50     if (sum->get_reduction_axes().empty()) {
51         return replace_output_update_name(node->output(0), node->input_value(0));
52     }
53     return false;
54 }
55
56 static bool eliminate_convert(const std::shared_ptr<Node>& node) {
57     bool is_out_type_agnostic = false;
58     static const std::set<NodeTypeInfo> type_agnostic{TI(opset3::NonZero)};
59     if (node->output(0).get_target_inputs().size() == 1) {
60         Input<Node> out = *node->output(0).get_target_inputs().begin();
61         is_out_type_agnostic = type_agnostic.count(out.get_node()->get_type_info()) == 1;
62     }
63     auto convert = as_type_ptr<opset3::Convert>(node);
64     auto input = convert->input_value(0);
65     if (convert->get_convert_element_type() == input.get_element_type() || is_out_type_agnostic) {
66         if (is_out_type_agnostic && is_type<opset3::Convert>(input.get_node())) {
67             input = input.get_node()->input_value(0);
68         }
69         return replace_output_update_name(node->output(0), input);
70     }
71     return false;
72 }
73
74 static bool eliminate_concat(const std::shared_ptr<Node>& node) {
75     auto node_input = node->input_value(0);
76
77     // remove concat with single input
78     if (node->get_input_size() == 1) {
79         return replace_output_update_name(node->output(0), node_input);
80     }
81     return false;
82 }
83
84 static bool eliminate_reshape_v1(const std::shared_ptr<Node>& node) {
85     auto input = node->input_value(0);
86     // check if reshape is not identity op
87     if (input.get_partial_shape().is_dynamic() || node->get_output_partial_shape(0).is_dynamic()) {
88         NGRAPH_DEBUG << node << " has dynamic shapes.";
89         return false;
90     }
91     // remove identity op
92     if (input.get_shape() == node->get_output_shape(0)) {
93         return replace_output_update_name(node->output(0), input);
94     }
95     // eliminate redundant reshape, squeeze, or unsqueeze
96     auto input_node = input.get_node_shared_ptr();
97     if (as_type_ptr<opset3::Squeeze>(input_node) ||
98         as_type_ptr<opset3::Unsqueeze>(input_node) ||
99         as_type_ptr<opset3::Reshape>(input_node)) {
100         auto shape = node->get_output_shape(0);
101         std::vector<int64_t> vi;
102         vi.assign(shape.begin(), shape.end());
103         auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
104         auto new_reshape =
105             make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
106         new_reshape->set_friendly_name(node->get_friendly_name());
107         copy_runtime_info({input_node, node}, new_reshape);
108         replace_node(node, new_reshape);
109         return true;
110     }
111
112     return false;
113 }
114
115 static size_t count_unknown_dims(const PartialShape& ps) {
116     size_t rc = 0;
117     if (ps.is_static()) {
118         return rc;
119     }
120     for (auto i = 0; i < ps.rank().get_length(); i++) {
121         if (ps[i].is_dynamic()) {
122             rc += 1;
123         }
124     }
125     return rc;
126 }
127
128 static bool replace_squeeze_unsqueeze(const std::shared_ptr<Node>& node) {
129     auto shape_ps = node->get_output_partial_shape(0);
130     if (shape_ps.rank().get_length() == 0) {
131         return false;
132     }
133     if (count_unknown_dims(shape_ps) > 1) {
134         return false;
135     }
136     std::vector<int64_t> target_shape;
137     for (auto i = 0; i < shape_ps.rank().get_length(); i++) {
138         if (shape_ps[i].is_dynamic()) {
139             target_shape.emplace_back(-1);
140         } else {
141             target_shape.emplace_back(shape_ps[i].get_length());
142         }
143     }
144
145     shared_ptr<Node> reshape;
146     auto input = node->input_value(0).get_node_shared_ptr();
147     auto pat =
148         opset3::Constant::create<int64_t>(element::i64, Shape{target_shape.size()}, target_shape);
149
150     if (is_type<opset3::Reshape>(input) || is_type<opset3::Squeeze>(input) ||
151         is_type<opset3::Unsqueeze>(input)) {
152         reshape = make_shared<opset3::Reshape>(input->input_value(0), pat, false);
153     } else {
154         reshape = make_shared<opset3::Reshape>(node->input_value(0), pat, false);
155     }
156
157     // skip if reshape is nop
158     if (reshape->get_input_partial_shape(0).same_scheme(shape_ps)) {
159         return replace_output_update_name(node->output(0), reshape->input_value(0));
160     } else {
161         return replace_node_update_name(node, reshape);
162     }
163 }
164
165 static std::vector<int64_t> get_unsqueeze_axes(const PartialShape& data_shape,
166                                                const PartialShape& out_shape) {
167     std::vector<int64_t> axes;
168     size_t i = 0;
169     for (auto o = 0; o < out_shape.rank().get_length(); o++) {
170         if (i < data_shape.rank().get_length() && data_shape[i].same_scheme(out_shape[o])) {
171             i += 1;
172             continue;
173         }
174         if (out_shape[o].is_static() && out_shape[o] == 1) {
175             axes.push_back(o);
176         }
177     }
178     return axes;
179 }
180
181 static std::vector<int64_t> get_squeeze_axes(const PartialShape& data_shape,
182                                              const PartialShape& out_shape) {
183     std::vector<int64_t> axes;
184     size_t out_i = 0;
185     for (auto i = 0; i < data_shape.rank().get_length(); i++) {
186         if (out_i < out_shape.rank().get_length() && data_shape[i].same_scheme(out_shape[out_i])) {
187             out_i += 1;
188             continue;
189         }
190         if (data_shape[i].is_static() && data_shape[i] == 1) {
191             axes.push_back(i);
192         }
193     }
194     return axes;
195 }
196
197 static bool eliminate_unsqueeze(const std::shared_ptr<Node>& node) {
198     auto out_shape = node->get_output_partial_shape(0);
199     // try to replace all squeeze/unsqueeze with reshape
200     if (out_shape.rank().is_static() && out_shape.rank().get_length() != 0 && count_unknown_dims(out_shape) < 2) {
201         return replace_squeeze_unsqueeze(node);
202     }
203
204     auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(node);
205     auto input = unsqueeze->input_value(0).get_node_shared_ptr();
206     auto squeeze = as_type_ptr<opset3::Squeeze>(input);
207     auto replace_unsqueeze_only = [&](const vector<int64_t>& axes) {
208         auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
209         auto new_unsq = make_shared<opset3::Unsqueeze>(input->input_value(0), axes_const);
210         if (unsqueeze->get_output_partial_shape(0).same_scheme(
211                 new_unsq->get_output_partial_shape(0))) {
212             return replace_node_update_name(unsqueeze, new_unsq);
213         }
214         return false;
215     };
216     // eliminate redundant squeeze->unsqueeze
217     if (squeeze) {
218         const auto& data_shape = squeeze->input_value(0).get_partial_shape();
219         if (ngraph::compare_constants(squeeze->input_value(1).get_node_shared_ptr(),
220                                       unsqueeze->input_value(1).get_node_shared_ptr())) {
221             return replace_output_update_name(unsqueeze->output(0), squeeze->input_value(0));
222         }
223         if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic()) {
224             return false;
225         }
226         if (out_shape.rank().get_length() > data_shape.rank().get_length()) {
227             // check if single unsqueeze can handle this
228             auto axes = get_unsqueeze_axes(data_shape, out_shape);
229             if (axes.size() + data_shape.rank().get_length() == out_shape.rank().get_length()) {
230                 return replace_unsqueeze_only(axes);
231             }
232         }
233         if (out_shape.rank().get_length() < data_shape.rank().get_length()) {
234             // check if single squeeze can handle this
235             auto axes = get_squeeze_axes(data_shape, out_shape);
236             if (data_shape.rank().get_length() - axes.size() == out_shape.rank().get_length()) {
237                 auto axes_const =
238                     opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
239                 auto new_sq = make_shared<opset3::Squeeze>(input->input_value(0), axes_const);
240                 if (unsqueeze->get_output_partial_shape(0).same_scheme(
241                         new_sq->get_output_partial_shape(0))) {
242                     return replace_node_update_name(unsqueeze, new_sq);
243                 }
244                 return false;
245             }
246         }
247         return false;
248     }
249     // eliminate redundant unsqueeze->unsqueeze
250     auto unsqueeze_i = as_type_ptr<opset3::Unsqueeze>(input);
251     if (unsqueeze_i) {
252         const auto& data_shape = unsqueeze_i->input_value(0).get_partial_shape();
253         if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic()) {
254             return false;
255         }
256         auto axes = get_unsqueeze_axes(data_shape, out_shape);
257         return replace_unsqueeze_only(axes);
258     }
259
260     return false;
261 }
262
263 static bool eliminate_squeeze(const std::shared_ptr<Node>& node) {
264     auto out_shape = node->get_output_partial_shape(0);
265     // try to replace all unsqueeze/squeeze with reshape
266     if (out_shape.rank().is_static() && out_shape.rank().get_length() != 0 && count_unknown_dims(out_shape) < 2) {
267         return replace_squeeze_unsqueeze(node);
268     }
269
270     auto squeeze = as_type_ptr<opset3::Squeeze>(node);
271     auto input = squeeze->input_value(0).get_node_shared_ptr();
272     auto replace_squeeze_only = [&](const vector<int64_t>& axes) {
273         auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
274         auto new_sq = make_shared<opset3::Squeeze>(input->input_value(0), axes_const);
275         if (squeeze->get_output_partial_shape(0).same_scheme(new_sq->get_output_partial_shape(0))) {
276             return replace_node_update_name(squeeze, new_sq);
277         }
278         return false;
279     };
280     // eliminate redundant unsqueeze->squeeze
281     if (auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(input)) {
282         PartialShape data_shape;
283         if (op::is_parameter(input)) {
284             data_shape = unsqueeze->input(0).get_partial_shape();
285         } else {
286             data_shape = input->input(0).get_partial_shape();
287         }
288         if (ngraph::compare_constants(unsqueeze->input_value(1).get_node_shared_ptr(),
289                                       squeeze->input_value(1).get_node_shared_ptr())) {
290             return replace_output_update_name(squeeze->output(0), unsqueeze->input_value(0));
291         }
292         if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic()) {
293             return false;
294         }
295         if (out_shape.rank().get_length() < data_shape.rank().get_length()) {
296             // check if single squeeze can handle this
297             auto axes = get_squeeze_axes(data_shape, out_shape);
298             if (data_shape.rank().get_length() == out_shape.rank().get_length() + axes.size()) {
299                 return replace_squeeze_only(axes);
300             }
301         }
302         if (out_shape.rank().get_length() > data_shape.rank().get_length()) {
303             // check if single unsqueeze can handle this
304             auto axes = get_unsqueeze_axes(data_shape, out_shape);
305             if (data_shape.rank().get_length() + axes.size() == out_shape.rank().get_length()) {
306                 auto axes_const =
307                     opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
308                 auto new_unsq = make_shared<opset3::Unsqueeze>(input->input_value(0), axes_const);
309                 if (squeeze->get_output_partial_shape(0).same_scheme(
310                         new_unsq->get_output_partial_shape(0))) {
311                     replace_output_update_name(squeeze, new_unsq);
312                     return true;
313                 }
314             }
315         }
316         return false;
317     }
318     // eliminate redundant squeeze->squeeze
319     if (auto squeeze_i = as_type_ptr<opset3::Squeeze>(input)) {
320         PartialShape data_shape;
321         if (op::is_parameter(input)) {
322             data_shape = squeeze_i->input(0).get_partial_shape();
323         } else {
324             data_shape = input->input(0).get_partial_shape();
325         }
326         if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic()) {
327             return false;
328         }
329         auto axes = get_squeeze_axes(data_shape, out_shape);
330         return replace_squeeze_only(axes);
331     }
332     return false;
333 }
334
335 static bool eliminate_stop_gradient(const std::shared_ptr<Node>& node) {
336     replace_output_update_name(node->output(0), node->input_value(0));
337     return true;
338 }
339
340 bool pass::NopElimination::run_on_function(std::shared_ptr<Function> function) {
341     static const std::unordered_map<NodeTypeInfo, std::function<bool(const std::shared_ptr<Node>&)>>
342         dispatcher{{TI(opset3::Pad), &eliminate_nop},
343                    {TI(op::v0::Sum), &eliminate_sum},
344                    {TI(opset3::Convert), &eliminate_convert},
345                    {TI(op::v0::Slice), &eliminate_nop},
346                    {TI(op::v0::StopGradient), &eliminate_stop_gradient},
347                    {TI(opset3::Reshape), &eliminate_reshape_v1},
348                    {TI(opset3::Concat), &eliminate_concat},
349                    {TI(opset3::Squeeze), &eliminate_squeeze},
350                    {TI(opset3::Unsqueeze), &eliminate_unsqueeze},
351                    {TI(op::v0::Broadcast), &eliminate_nop}};
352
353     bool clobbered = false;
354
355     for (const auto& node : function->get_ops()) {
356         // Recursively apply transformation for sub-graph based operations
357         if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
358             if (auto sub_graph = sub_graph_node->get_function()) {
359                 clobbered |= run_on_function(sub_graph);
360             }
361         }
362         auto handler = dispatcher.find(node->get_type_info());
363         if (handler != dispatcher.end()) {
364             clobbered |= handler->second(node);
365         }
366     }
367
368     return clobbered;
369 }