//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstdint>
#include <functional>
#include <memory>
#include <set>
#include <unordered_map>
#include <vector>

#include <ngraph/log.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/util.hpp>

#include <transformations/common_optimizations/nop_elimination.hpp>
28 NGRAPH_SUPPRESS_DEPRECATED_START
31 using namespace ngraph;
33 #define TI(x) x::type_info
35 static bool eliminate_nop(const std::shared_ptr<Node>& node) {
36 // skip if shapes are dynamic
37 if (node->get_input_partial_shape(0).is_dynamic() ||
38 node->get_output_partial_shape(0).is_dynamic()) {
42 if (node->get_input_shape(0) == node->get_output_shape(0)) {
43 return replace_output_update_name(node->output(0), node->input_value(0));
48 static bool eliminate_sum(const std::shared_ptr<Node>& node) {
49 auto sum = as_type_ptr<op::v0::Sum>(node);
50 if (sum->get_reduction_axes().empty()) {
51 return replace_output_update_name(node->output(0), node->input_value(0));
56 static bool eliminate_convert(const std::shared_ptr<Node>& node) {
57 bool is_out_type_agnostic = false;
58 static const std::set<NodeTypeInfo> type_agnostic{TI(opset3::NonZero)};
59 if (node->output(0).get_target_inputs().size() == 1) {
60 Input<Node> out = *node->output(0).get_target_inputs().begin();
61 is_out_type_agnostic = type_agnostic.count(out.get_node()->get_type_info()) == 1;
63 auto convert = as_type_ptr<opset3::Convert>(node);
64 auto input = convert->input_value(0);
65 if (convert->get_convert_element_type() == input.get_element_type() || is_out_type_agnostic) {
66 if (is_out_type_agnostic && is_type<opset3::Convert>(input.get_node())) {
67 input = input.get_node()->input_value(0);
69 return replace_output_update_name(node->output(0), input);
74 static bool eliminate_concat(const std::shared_ptr<Node>& node) {
75 auto node_input = node->input_value(0);
77 // remove concat with single input
78 if (node->get_input_size() == 1) {
79 return replace_output_update_name(node->output(0), node_input);
84 static bool eliminate_reshape_v1(const std::shared_ptr<Node>& node) {
85 auto input = node->input_value(0);
86 // check if reshape is not identity op
87 if (input.get_partial_shape().is_dynamic() || node->get_output_partial_shape(0).is_dynamic()) {
88 NGRAPH_DEBUG << node << " has dynamic shapes.";
92 if (input.get_shape() == node->get_output_shape(0)) {
93 return replace_output_update_name(node->output(0), input);
95 // eliminate redundant reshape, squeeze, or unsqueeze
96 auto input_node = input.get_node_shared_ptr();
97 if (as_type_ptr<opset3::Squeeze>(input_node) ||
98 as_type_ptr<opset3::Unsqueeze>(input_node) ||
99 as_type_ptr<opset3::Reshape>(input_node)) {
100 auto shape = node->get_output_shape(0);
101 std::vector<int64_t> vi;
102 vi.assign(shape.begin(), shape.end());
103 auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
105 make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
106 new_reshape->set_friendly_name(node->get_friendly_name());
107 copy_runtime_info({input_node, node}, new_reshape);
108 replace_node(node, new_reshape);
115 static size_t count_unknown_dims(const PartialShape& ps) {
117 if (ps.is_static()) {
120 for (auto i = 0; i < ps.rank().get_length(); i++) {
121 if (ps[i].is_dynamic()) {
128 static bool replace_squeeze_unsqueeze(const std::shared_ptr<Node>& node) {
129 auto shape_ps = node->get_output_partial_shape(0);
130 if (shape_ps.rank().get_length() == 0) {
133 if (count_unknown_dims(shape_ps) > 1) {
136 std::vector<int64_t> target_shape;
137 for (auto i = 0; i < shape_ps.rank().get_length(); i++) {
138 if (shape_ps[i].is_dynamic()) {
139 target_shape.emplace_back(-1);
141 target_shape.emplace_back(shape_ps[i].get_length());
145 shared_ptr<Node> reshape;
146 auto input = node->input_value(0).get_node_shared_ptr();
148 opset3::Constant::create<int64_t>(element::i64, Shape{target_shape.size()}, target_shape);
150 if (is_type<opset3::Reshape>(input) || is_type<opset3::Squeeze>(input) ||
151 is_type<opset3::Unsqueeze>(input)) {
152 reshape = make_shared<opset3::Reshape>(input->input_value(0), pat, false);
154 reshape = make_shared<opset3::Reshape>(node->input_value(0), pat, false);
157 // skip if reshape is nop
158 if (reshape->get_input_partial_shape(0).same_scheme(shape_ps)) {
159 return replace_output_update_name(node->output(0), reshape->input_value(0));
161 return replace_node_update_name(node, reshape);
165 static std::vector<int64_t> get_unsqueeze_axes(const PartialShape& data_shape,
166 const PartialShape& out_shape) {
167 std::vector<int64_t> axes;
169 for (auto o = 0; o < out_shape.rank().get_length(); o++) {
170 if (i < data_shape.rank().get_length() && data_shape[i].same_scheme(out_shape[o])) {
174 if (out_shape[o].is_static() && out_shape[o] == 1) {
181 static std::vector<int64_t> get_squeeze_axes(const PartialShape& data_shape,
182 const PartialShape& out_shape) {
183 std::vector<int64_t> axes;
185 for (auto i = 0; i < data_shape.rank().get_length(); i++) {
186 if (out_i < out_shape.rank().get_length() && data_shape[i].same_scheme(out_shape[out_i])) {
190 if (data_shape[i].is_static() && data_shape[i] == 1) {
197 static bool eliminate_unsqueeze(const std::shared_ptr<Node>& node) {
198 auto out_shape = node->get_output_partial_shape(0);
199 // try to replace all squeeze/unsqueeze with reshape
200 if (out_shape.rank().is_static() && out_shape.rank().get_length() != 0 && count_unknown_dims(out_shape) < 2) {
201 return replace_squeeze_unsqueeze(node);
204 auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(node);
205 auto input = unsqueeze->input_value(0).get_node_shared_ptr();
206 auto squeeze = as_type_ptr<opset3::Squeeze>(input);
207 auto replace_unsqueeze_only = [&](const vector<int64_t>& axes) {
208 auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
209 auto new_unsq = make_shared<opset3::Unsqueeze>(input->input_value(0), axes_const);
210 if (unsqueeze->get_output_partial_shape(0).same_scheme(
211 new_unsq->get_output_partial_shape(0))) {
212 return replace_node_update_name(unsqueeze, new_unsq);
216 // eliminate redundant squeeze->unsqueeze
218 const auto& data_shape = squeeze->input_value(0).get_partial_shape();
219 if (ngraph::compare_constants(squeeze->input_value(1).get_node_shared_ptr(),
220 unsqueeze->input_value(1).get_node_shared_ptr())) {
221 return replace_output_update_name(unsqueeze->output(0), squeeze->input_value(0));
223 if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic()) {
226 if (out_shape.rank().get_length() > data_shape.rank().get_length()) {
227 // check if single unsqueeze can handle this
228 auto axes = get_unsqueeze_axes(data_shape, out_shape);
229 if (axes.size() + data_shape.rank().get_length() == out_shape.rank().get_length()) {
230 return replace_unsqueeze_only(axes);
233 if (out_shape.rank().get_length() < data_shape.rank().get_length()) {
234 // check if single squeeze can handle this
235 auto axes = get_squeeze_axes(data_shape, out_shape);
236 if (data_shape.rank().get_length() - axes.size() == out_shape.rank().get_length()) {
238 opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
239 auto new_sq = make_shared<opset3::Squeeze>(input->input_value(0), axes_const);
240 if (unsqueeze->get_output_partial_shape(0).same_scheme(
241 new_sq->get_output_partial_shape(0))) {
242 return replace_node_update_name(unsqueeze, new_sq);
249 // eliminate redundant unsqueeze->unsqueeze
250 auto unsqueeze_i = as_type_ptr<opset3::Unsqueeze>(input);
252 const auto& data_shape = unsqueeze_i->input_value(0).get_partial_shape();
253 if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic()) {
256 auto axes = get_unsqueeze_axes(data_shape, out_shape);
257 return replace_unsqueeze_only(axes);
263 static bool eliminate_squeeze(const std::shared_ptr<Node>& node) {
264 auto out_shape = node->get_output_partial_shape(0);
265 // try to replace all unsqueeze/squeeze with reshape
266 if (out_shape.rank().is_static() && out_shape.rank().get_length() != 0 && count_unknown_dims(out_shape) < 2) {
267 return replace_squeeze_unsqueeze(node);
270 auto squeeze = as_type_ptr<opset3::Squeeze>(node);
271 auto input = squeeze->input_value(0).get_node_shared_ptr();
272 auto replace_squeeze_only = [&](const vector<int64_t>& axes) {
273 auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
274 auto new_sq = make_shared<opset3::Squeeze>(input->input_value(0), axes_const);
275 if (squeeze->get_output_partial_shape(0).same_scheme(new_sq->get_output_partial_shape(0))) {
276 return replace_node_update_name(squeeze, new_sq);
280 // eliminate redundant unsqueeze->squeeze
281 if (auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(input)) {
282 PartialShape data_shape;
283 if (op::is_parameter(input)) {
284 data_shape = unsqueeze->input(0).get_partial_shape();
286 data_shape = input->input(0).get_partial_shape();
288 if (ngraph::compare_constants(unsqueeze->input_value(1).get_node_shared_ptr(),
289 squeeze->input_value(1).get_node_shared_ptr())) {
290 return replace_output_update_name(squeeze->output(0), unsqueeze->input_value(0));
292 if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic()) {
295 if (out_shape.rank().get_length() < data_shape.rank().get_length()) {
296 // check if single squeeze can handle this
297 auto axes = get_squeeze_axes(data_shape, out_shape);
298 if (data_shape.rank().get_length() == out_shape.rank().get_length() + axes.size()) {
299 return replace_squeeze_only(axes);
302 if (out_shape.rank().get_length() > data_shape.rank().get_length()) {
303 // check if single unsqueeze can handle this
304 auto axes = get_unsqueeze_axes(data_shape, out_shape);
305 if (data_shape.rank().get_length() + axes.size() == out_shape.rank().get_length()) {
307 opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
308 auto new_unsq = make_shared<opset3::Unsqueeze>(input->input_value(0), axes_const);
309 if (squeeze->get_output_partial_shape(0).same_scheme(
310 new_unsq->get_output_partial_shape(0))) {
311 replace_output_update_name(squeeze, new_unsq);
318 // eliminate redundant squeeze->squeeze
319 if (auto squeeze_i = as_type_ptr<opset3::Squeeze>(input)) {
320 PartialShape data_shape;
321 if (op::is_parameter(input)) {
322 data_shape = squeeze_i->input(0).get_partial_shape();
324 data_shape = input->input(0).get_partial_shape();
326 if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic()) {
329 auto axes = get_squeeze_axes(data_shape, out_shape);
330 return replace_squeeze_only(axes);
335 static bool eliminate_stop_gradient(const std::shared_ptr<Node>& node) {
336 replace_output_update_name(node->output(0), node->input_value(0));
340 bool pass::NopElimination::run_on_function(std::shared_ptr<Function> function) {
341 static const std::unordered_map<NodeTypeInfo, std::function<bool(const std::shared_ptr<Node>&)>>
342 dispatcher{{TI(opset3::Pad), &eliminate_nop},
343 {TI(op::v0::Sum), &eliminate_sum},
344 {TI(opset3::Convert), &eliminate_convert},
345 {TI(op::v0::Slice), &eliminate_nop},
346 {TI(op::v0::StopGradient), &eliminate_stop_gradient},
347 {TI(opset3::Reshape), &eliminate_reshape_v1},
348 {TI(opset3::Concat), &eliminate_concat},
349 {TI(opset3::Squeeze), &eliminate_squeeze},
350 {TI(op::v1::Broadcast), &eliminate_nop},
351 {TI(opset3::Unsqueeze), &eliminate_unsqueeze}};
353 bool clobbered = false;
355 for (const auto& node : function->get_ops()) {
356 // Recursively apply transformation for sub-graph based operations
357 if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
358 if (auto sub_graph = sub_graph_node->get_function()) {
359 clobbered |= run_on_function(sub_graph);
362 auto handler = dispatcher.find(node->get_type_info());
363 if (handler != dispatcher.end()) {
364 clobbered |= handler->second(node);