1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
7 #include "common_test_utils/test_common.hpp"
13 #include <ngraph/function.hpp>
14 #include <ngraph/opsets/opset1.hpp>
15 #include <ngraph/pass/manager.hpp>
16 #include <ngraph/pass/constant_folding.hpp>
17 #include <transformations/common_optimizations/nop_elimination.hpp>
18 #include <transformations/utils/utils.hpp>
19 #include <transformations/init_node_info.hpp>
20 #include <transformations/rt_info/fused_names_attribute.hpp>
22 #include "common_test_utils/ngraph_test_utils.hpp"
24 NGRAPH_SUPPRESS_DEPRECATED_START
26 using namespace ngraph;
// NopElimination must remove a v0::Sum that reduces over an empty AxisSet;
// the test asserts that no Sum op remains in the function after the pass.
29 TEST(nop_elimination, eliminate_sum) {
31 auto A = make_shared<op::Parameter>(element::f32, shape);
32 auto s = make_shared<op::v0::Sum>(A, AxisSet{});
// The Sum feeds an Abs rather than the Result directly.
33 auto f = make_shared<Function>(make_shared<op::v0::Abs>(s), ParameterVector{A});
35 pass::Manager pass_manager;
36 pass_manager.register_pass<pass::NopElimination>();
37 pass_manager.run_passes(f);
39 ASSERT_EQ(count_ops_of_type<op::v0::Sum>(f), 0);
// A v0::Convert whose destination type (f32) equals its input type (f32) is
// expected to be removed by NopElimination.
42 TEST(nop_elimination, eliminate_convert) {
44 auto type = element::f32;
45 auto A = make_shared<op::Parameter>(type, shape);
46 auto c = make_shared<op::v0::Convert>(A, element::f32);
47 auto f = make_shared<Function>(make_shared<op::v0::Abs>(c), ParameterVector{A});
49 pass::Manager pass_manager;
50 pass_manager.register_pass<pass::NopElimination>();
51 pass_manager.run_passes(f);
53 ASSERT_EQ(count_ops_of_type<op::v0::Convert>(f), 0);
// Chain: Parameter(element::from<char>()) -> Convert(element::from<uint8_t>())
// -> Convert(f32) -> NonZero -> Abs. After Validate + NopElimination no Convert
// may remain.
// NOTE(review): presumably both Converts are dropped because NonZero's result
// does not depend on the input element type; confirm against the pass.
56 TEST(nop_elimination, convert_type_agnostic) {
58 auto type = element::from<char>();
59 auto A = make_shared<op::Parameter>(type, shape);
60 auto c1 = make_shared<op::v0::Convert>(A, element::from<uint8_t>());
61 auto c = make_shared<op::v0::Convert>(c1, element::f32);
62 auto z = make_shared<op::v3::NonZero>(c);
63 auto f = make_shared<Function>(make_shared<op::v0::Abs>(z), ParameterVector{A});
65 pass::Manager pass_manager;
66 pass_manager.register_pass<pass::Validate>();
67 pass_manager.register_pass<pass::NopElimination>();
68 pass_manager.run_passes(f);
70 ASSERT_EQ(count_ops_of_type<op::v0::Convert>(f), 0);
// A v0::Slice from Coordinate{0, 0} to Coordinate{2, 2} is expected to be
// removed by NopElimination.
// NOTE(review): this presumes `shape` is {2, 2}, so that the slice spans the
// whole input; verify against the declaration of `shape`.
73 TEST(nop_elimination, eliminate_slice) {
75 auto A = make_shared<op::Parameter>(element::f32, shape);
76 auto s = make_shared<op::v0::Slice>(A, Coordinate{0, 0}, Coordinate{2, 2});
77 auto f = make_shared<Function>(make_shared<op::v0::Abs>(s), ParameterVector{A});
79 pass::Manager pass_manager;
80 pass_manager.register_pass<pass::NopElimination>();
81 pass_manager.run_passes(f);
83 ASSERT_EQ(count_ops_of_type<op::v0::Slice>(f), 0);
// A v0::Broadcast whose target shape is the input's own shape (`shape`) and
// whose broadcast axes are empty is expected to be removed by NopElimination.
86 TEST(nop_elimination, eliminate_broadcast) {
88 auto A = make_shared<op::Parameter>(element::f32, shape);
89 auto b = make_shared<op::v0::Broadcast>(A, shape, AxisSet{});
90 auto f = make_shared<Function>(make_shared<op::v0::Abs>(b), ParameterVector{A});
92 pass::Manager pass_manager;
93 pass_manager.register_pass<pass::NopElimination>();
94 pass_manager.run_passes(f);
96 ASSERT_EQ(count_ops_of_type<op::v0::Broadcast>(f), 0);
// A v0::StopGradient node is expected to be removed by NopElimination; the
// test asserts that none remains after the pass.
99 TEST(nop_elimination, eliminate_stop_gradient) {
101 auto A = make_shared<op::Parameter>(element::f32, shape);
102 auto s = make_shared<op::v0::StopGradient>(A);
103 auto f = make_shared<Function>(make_shared<op::v0::Abs>(s), ParameterVector{A});
105 pass::Manager pass_manager;
106 pass_manager.register_pass<pass::NopElimination>();
107 pass_manager.run_passes(f);
109 ASSERT_EQ(count_ops_of_type<op::v0::StopGradient>(f), 0);
// NopElimination must not report the CHANGE_DYNAMIC_STATE pass property.
// The local is deliberately not named `pass`, which would shadow the `pass`
// namespace referenced on the assertion line.
112 TEST(nop_elimination, pass_property) {
113 auto nop_elimination = std::make_shared<ngraph::pass::NopElimination>();
114 ASSERT_FALSE(nop_elimination->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
// Two chained v1::Reshape ops with identical constant patterns {8, 16, 6} are
// expected to collapse into a single Reshape, both with and without the flag
// passed as Reshape's third argument (`zero`). Functions that never go through
// the pass act as a baseline and must still contain two Reshapes.
// ASSERT_EQ (rather than ASSERT_TRUE(a == b)) prints both counts on failure.
117 TEST(nop_elimination, reshape_elimination_v1) {
118 auto generate_func = [](bool zero) {
119 auto arg = std::make_shared<op::Parameter>(element::i64, PartialShape{8, 16, 2, 3});
120 auto pattern_org = op::Constant::create(element::i64, Shape{3}, vector<int64_t>{8, 16, 6});
121 auto pattern = op::Constant::create(element::i64, Shape{3}, vector<int64_t>{8, 16, 6});
122 auto reshape_v1_org = std::make_shared<op::v1::Reshape>(arg, pattern_org, zero);
123 auto reshape_v1 = std::make_shared<op::v1::Reshape>(reshape_v1_org, pattern, zero);
124 auto abs = std::make_shared<op::v0::Abs>(reshape_v1);
125 return std::make_shared<Function>(NodeVector{abs}, ParameterVector{arg});
128 auto func = generate_func(false);
129 auto nopass_func = generate_func(false);
130 auto func_zero = generate_func(true);
131 auto nopass_func_zero = generate_func(true);
133 pass::Manager pass_manager;
134 pass_manager.register_pass<pass::NopElimination>();
135 pass_manager.run_passes(func);
136 pass_manager.run_passes(func_zero);
137 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(nopass_func), 2);
138 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(func), 1);
139 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(nopass_func_zero), 2);
140 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(func_zero), 1);
// Graph: Relu -> Squeeze(axis 2) -> Reshape back to {8, 16, 1, 3} -> Abs.
// After InitNodeInfo + NopElimination, a node with friendly name "reshape" must
// still exist, and it must be an opset4::Reshape. Its fused names, sorted, must
// be {"reshape", "squeeze"}; in other words the surviving Reshape records the
// eliminated Squeeze as one of its originals.
143 TEST(nop_elimination, squeeze_reshape_elimination_check_info) {
144 std::shared_ptr<Function> f;
146 auto arg = std::make_shared<opset4::Parameter>(element::f32, PartialShape{8, 16, 1, 3});
148 auto relu = std::make_shared<opset4::Relu>(arg);
149 relu->set_friendly_name("relu");
151 auto squeeze_axes = opset4::Constant::create(element::i64, Shape{1}, {2});
152 auto squeeze = std::make_shared<opset4::Squeeze>(relu, squeeze_axes);
153 squeeze->set_friendly_name("squeeze");
155 auto reshape_shape = opset4::Constant::create(element::i64, Shape{4}, {8, 16, 1, 3});
156 auto reshape = std::make_shared<opset4::Reshape>(squeeze, reshape_shape, false);
157 reshape->set_friendly_name("reshape");
159 auto abs = std::make_shared<opset4::Abs>(reshape);
161 f = std::make_shared<Function>(NodeVector{abs}, ParameterVector{arg});
// InitNodeInfo runs first so that fused-name information is available to
// getFusedNamesVector afterwards.
164 pass::Manager pass_manager;
165 pass_manager.register_pass<pass::InitNodeInfo>();
166 pass_manager.register_pass<pass::NopElimination>();
167 pass_manager.run_passes(f);
// Find the surviving node by friendly name and inspect its fused names.
169 bool reshape_is_missing = true;
170 for (auto node : f->get_ops()) {
171 if (node->get_friendly_name() == "reshape") {
172 reshape_is_missing = false;
173 ASSERT_TRUE(std::dynamic_pointer_cast<opset4::Reshape>(node));
174 auto original_names = getFusedNamesVector(node);
175 sort(original_names.begin(), original_names.end());
176 ASSERT_EQ(original_names, std::vector<std::string>({"reshape", "squeeze"}));
179 ASSERT_FALSE(reshape_is_missing);
// A v1::Reshape applied to a fully dynamic input, whose target pattern is a
// runtime Parameter rather than a Constant, must be left in place by
// NopElimination. ASSERT_EQ reports the actual count on failure.
182 TEST(nop_elimination, reshape_elimination_v1_dynamic) {
183 auto arg = std::make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
184 auto pattern = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
185 auto reshape_v1 = std::make_shared<op::v1::Reshape>(arg, pattern, false);
186 auto abs = std::make_shared<op::v0::Abs>(reshape_v1);
187 auto f = std::make_shared<Function>(NodeVector{abs}, ParameterVector{arg, pattern});
188 pass::Manager pass_manager;
189 pass_manager.register_pass<pass::NopElimination>();
190 pass_manager.run_passes(f);
191 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(f), 1);
// A single-input v0::Concat (concat axis `a`) whose output is the function's
// result is expected to survive NopElimination: the Concat count stays at 1.
// NOTE(review): presumably it is kept because removing it would wire the
// Result straight to the Parameter; confirm against the pass.
194 TEST(nop_elimination, concat_elimination_single_node) {
196 auto A = make_shared<op::Parameter>(element::f32, Shape{2, 3});
198 make_shared<Function>(make_shared<op::v0::Concat>(NodeVector{A}, a), ParameterVector{A});
200 pass::Manager pass_manager;
201 pass_manager.register_pass<pass::Validate>();
202 pass_manager.register_pass<pass::NopElimination>();
203 pass_manager.run_passes(f);
205 ASSERT_EQ(count_ops_of_type<op::v0::Concat>(f), 1);
// A single-input v0::Concat (concat axis `a`) that feeds an Abs is expected to
// be removed by Validate + NopElimination.
208 TEST(nop_elimination, concat_elimination_single_input) {
210 auto A = make_shared<op::Parameter>(element::f32, Shape{2, 3});
211 auto B = make_shared<op::v0::Concat>(NodeVector{A}, a);
212 auto f = make_shared<Function>(make_shared<op::v0::Abs>(B), ParameterVector{A});
214 pass::Manager pass_manager;
215 pass_manager.register_pass<pass::Validate>();
216 pass_manager.register_pass<pass::NopElimination>();
217 pass_manager.run_passes(f);
219 ASSERT_EQ(count_ops_of_type<op::v0::Concat>(f), 0);
// Same as concat_elimination_single_input, but the input's first dimension is
// dynamic; the single-input Concat must still be removed.
222 TEST(nop_elimination, concat_elimination_single_input_dynamic) {
224 auto A = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 3});
225 auto B = make_shared<op::v0::Concat>(NodeVector{A}, a);
226 auto f = make_shared<Function>(make_shared<op::v0::Abs>(B), ParameterVector{A});
228 pass::Manager pass_manager;
229 pass_manager.register_pass<pass::Validate>();
230 pass_manager.register_pass<pass::NopElimination>();
231 pass_manager.run_passes(f);
233 ASSERT_EQ(count_ops_of_type<op::v0::Concat>(f), 0);
// A lone v0::Unsqueeze (scalar axis 0) on a partially dynamic {3, ?, ?} input
// that feeds the function's result directly must be kept by NopElimination.
236 TEST(nop_elimination, unsqueeze_elimination) {
237 const auto axis = op::Constant::create<int64_t>(element::i64, {}, {0});
238 const auto A = make_shared<op::Parameter>(
239 element::f32, PartialShape{3, Dimension::dynamic(), Dimension::dynamic()});
240 const auto unsqueeze = make_shared<op::v0::Unsqueeze>(A, axis);
241 auto f = make_shared<Function>(unsqueeze, ParameterVector{A});
243 pass::Manager pass_manager;
244 pass_manager.register_pass<pass::Validate>();
245 pass_manager.register_pass<pass::NopElimination>();
246 pass_manager.run_passes(f);
248 ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(f), 1);
// Exercises Squeeze->Unsqueeze and Unsqueeze->Squeeze chains through
// Validate + NopElimination. For each use case the lambda does the following:
//  - builds the squeeze/unsqueeze axes constants as either i32 or i64;
//  - feeds the chain from output 0 of a v0::TopK when `multiout` is set,
//    otherwise from an Abs;
//  - runs the passes on a clone of the baseline function;
//  - checks that the result rank is static and unchanged;
//  - compares the surviving Squeeze / Unsqueeze / v1::Reshape counts with the
//    expected `sc` / `usc` / `rc`.
// `casename` numbers each invocation, so failure messages identify the case.
// NOTE(review): judging from the section comments, the first boolean argument
// appears to select Squeeze->Unsqueeze (true) versus Unsqueeze->Squeeze
// (false); confirm against the lambda's parameter list.
251 TEST(nop_elimination, squeeze_unsqueeze_overlap_elimination) {
252 auto check_usecase = [](const PartialShape& shape,
253 const std::vector<int64_t>& sq_axes_val,
254 const std::vector<int64_t>& unsq_axes_val,
261 static size_t id = 0;
262 auto casename = string("usecase #") + to_string(++id);
264 shared_ptr<Node> sq_axes;
265 shared_ptr<Node> unsq_axes;
267 std::vector<int32_t> sq_axes_val_i32(sq_axes_val.begin(), sq_axes_val.end());
268 std::vector<int32_t> unsq_axes_val_i32(unsq_axes_val.begin(), unsq_axes_val.end());
269 sq_axes = op::Constant::create<int32_t>(
270 element::i32, Shape{sq_axes_val.size()}, sq_axes_val_i32);
271 unsq_axes = op::Constant::create<int32_t>(
272 element::i32, Shape{unsq_axes_val.size()}, unsq_axes_val_i32);
275 op::Constant::create<int64_t>(element::i64, Shape{sq_axes_val.size()}, sq_axes_val);
276 unsq_axes = op::Constant::create<int64_t>(
277 element::i64, Shape{unsq_axes_val.size()}, unsq_axes_val);
280 auto A = make_shared<op::Parameter>(element::f32, shape);
283 auto last_dim = shape.rank().get_length() - 1;
284 A1 = make_shared<op::v0::TopK>(A, last_dim, element::i32);
286 A1 = make_shared<op::v0::Abs>(A);
291 auto B = make_shared<op::v0::Squeeze>((multiout ? A1->output(0) : A1), sq_axes);
292 B1 = make_shared<op::v0::Unsqueeze>(B, unsq_axes);
294 auto B = make_shared<op::v0::Unsqueeze>((multiout ? A1->output(0) : A1), unsq_axes);
295 B1 = make_shared<op::v0::Squeeze>(B, sq_axes);
// The baseline is left untouched; only the clone goes through the passes.
298 auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
299 auto optimized_f = clone_function(*baseline_f);
301 pass::Manager pass_manager;
302 pass_manager.register_pass<pass::Validate>();
303 pass_manager.register_pass<pass::NopElimination>();
304 pass_manager.run_passes(optimized_f);
306 auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
307 auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
308 EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
309 ASSERT_EQ(ps.rank().get_length(), ps_r.rank().get_length()) << casename;
311 ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(baseline_f), 1) << casename;
312 ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(baseline_f), 1) << casename;
313 ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(optimized_f), sc) << casename;
314 ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(optimized_f), usc) << casename;
315 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(optimized_f), rc) << casename;
318 // static shapes, all squeeze/unsqueeze replaced by reshape
319 check_usecase(PartialShape{2, 1, 1, 6}, {1, 2}, {0}, true, false, false, 0, 0, 1);
320 check_usecase(PartialShape{2, 1, 1, 6}, {1, 2}, {0}, true, true, false, 0, 0, 1);
322 check_usecase(PartialShape{2, 1, 1, 6}, {1, 2}, {0}, true, false, true, 0, 0, 1);
323 check_usecase(PartialShape{2, 1, 1, 6}, {1, 2}, {0}, true, true, true, 0, 0, 1);
324 check_usecase(PartialShape{1}, {0}, {0, 1, 2, 3}, true, true, true, 0, 0, 1);
326 // axes match - expect all squeeze/unsqueeze/reshape cancel out
327 check_usecase(PartialShape{2, 1, 1, 6}, {1, 2}, {1, 2}, true, true, true, 0, 0, 0);
329 // dynamic shapes - axes match, expect all cancel
330 check_usecase(PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic(), 1},
339 check_usecase(PartialShape{1, Dimension::dynamic(), 1, 2, 1},
349 // squeeze axes overlap fully
351 PartialShape{Dimension::dynamic(), 1, 1, 3}, {1, 2}, {1, 2, 3}, true, true, true, 0, 0, 1);
352 check_usecase(PartialShape{Dimension::dynamic(), 1, 1, Dimension::dynamic()},
361 check_usecase(PartialShape{2, 1, 1, 4}, {1, 2}, {1, 2, 3}, true, true, true, 0, 0, 1);
362 check_usecase(PartialShape{2, 1, 1, Dimension::dynamic(), Dimension::dynamic()},
371 check_usecase(PartialShape{1, Dimension::dynamic(), 1, 1, Dimension::dynamic()},
381 // unsqueeze axes overlap fully
382 check_usecase(PartialShape{1, Dimension::dynamic(), 1, 1, 1, Dimension::dynamic(), 3},
391 check_usecase(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 1, 1},
401 PartialShape{Dimension::dynamic(), 3, 1, 1}, {2, 3}, {2}, true, true, true, 0, 0, 1);
402 check_usecase(PartialShape{3, 1, 1}, {1, 2}, {1}, true, true, true, 0, 0, 1);
404 // squeeze->unsqueeze axes overlap
406 PartialShape{Dimension::dynamic(), 1, 1, 4}, {1, 2}, {0}, true, true, true, 0, 0, 1);
407 check_usecase(PartialShape{Dimension::dynamic(), 1, 1, Dimension::dynamic()},
416 check_usecase(PartialShape{3, 1, 1, 4}, {1, 2}, {0}, true, true, true, 0, 0, 1);
417 check_usecase(PartialShape{2, 1, 1, Dimension::dynamic(), Dimension::dynamic()},
426 check_usecase(PartialShape{Dimension::dynamic(), 1, 1, 3, Dimension::dynamic(), 4},
435 check_usecase(PartialShape{2, 1, Dimension::dynamic(), 1, Dimension::dynamic()},
444 check_usecase(PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic(), 1, 1, 4},
455 // Unsqueeze->Squeeze cases, testcase 23 - ..
457 // static shapes, all unsqueeze/squeeze replaced by reshape
458 check_usecase(PartialShape{2, 6, 1}, {4}, {1, 2}, false, false, false, 0, 0, 1);
459 check_usecase(PartialShape{2, 6, 1}, {4}, {1, 2}, false, true, false, 0, 0, 1);
461 check_usecase(PartialShape{2, 6, 1}, {4}, {1, 2}, false, false, true, 0, 0, 1);
462 check_usecase(PartialShape{2, 6, 1}, {4}, {1, 2}, false, true, true, 0, 0, 1);
463 check_usecase(PartialShape{1}, {0}, {0, 1, 2, 3}, false, true, true, 0, 0, 1);
464 check_usecase(PartialShape{3, 1, 1, 4}, {2, 3}, {0}, false, true, true, 0, 0, 1);
466 check_usecase(PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic(), 1},
475 check_usecase(PartialShape{Dimension::dynamic(), 1, 1, Dimension::dynamic()},
484 check_usecase(PartialShape{Dimension::dynamic(), 1, 1, 4}, {2}, {0}, true, true, true, 0, 0, 1);
485 check_usecase(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 1, 1},
// Two consecutive v0::Squeeze ops (after an Abs) on partially dynamic inputs.
// After Validate + NopElimination on a clone:
//  - the result rank must be unchanged;
//  - the number of remaining Squeeze ops must equal `sq`, which is 1 in every
//    listed case, i.e. the pair is expected to merge into one Squeeze.
// Negative axes (e.g. {-1, -5}) are covered as well.
496 TEST(nop_elimination, squeeze_squeeze_overlap_elimination) {
497 auto check_usecase = [](const PartialShape& shape,
498 const std::vector<int64_t>& sq_axes_val_1,
499 const std::vector<int64_t>& sq_axes_val_2,
501 static size_t id = 0;
502 auto casename = string("usecase #") + to_string(++id);
504 op::Constant::create<int64_t>(element::i64, Shape{sq_axes_val_1.size()}, sq_axes_val_1);
506 op::Constant::create<int64_t>(element::i64, Shape{sq_axes_val_2.size()}, sq_axes_val_2);
507 auto A = make_shared<op::Parameter>(element::f32, shape);
508 auto A1 = make_shared<op::v0::Abs>(A);
509 auto B = make_shared<op::v0::Squeeze>(A1, sq_axes_1);
510 auto B1 = make_shared<op::v0::Squeeze>(B, sq_axes_2);
511 auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
512 auto optimized_f = clone_function(*baseline_f);
514 pass::Manager pass_manager;
515 pass_manager.register_pass<pass::Validate>();
516 pass_manager.register_pass<pass::NopElimination>();
517 pass_manager.run_passes(optimized_f);
518 auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
519 auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
520 EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
521 ASSERT_EQ(ps.rank().get_length(), ps_r.rank().get_length()) << casename;
522 ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(baseline_f), 2) << casename;
523 ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(optimized_f), sq) << casename;
526 check_usecase(PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic()}, {0}, {1}, 1);
528 PartialShape{1, 1, 1, Dimension::dynamic(), 1, Dimension::dynamic(), 1}, {2, 1}, {2, 4}, 1);
530 PartialShape{1, Dimension::dynamic(), Dimension::dynamic(), 1, 1}, {-1, -5}, {2}, 1);
532 PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic(), 1}, {0}, {1, 3}, 1);
// Two consecutive v0::Unsqueeze ops (after an Abs) on partially dynamic inputs.
// After Validate + NopElimination on a clone:
//  - the result rank must be unchanged;
//  - the number of remaining Unsqueeze ops must equal `unsq`, which is 1 in
//    every listed case, i.e. the pair is expected to merge into one Unsqueeze.
// Negative axes (e.g. {-1, -3}) are covered as well.
535 TEST(nop_elimination, unsqueeze_unsqueeze_overlap_elimination) {
536 auto check_usecase = [](const PartialShape& shape,
537 const std::vector<int64_t>& unsq_axes_val_1,
538 const std::vector<int64_t>& unsq_axes_val_2,
540 static size_t id = 0;
541 auto casename = string("usecase #") + to_string(++id);
542 auto unsq_axes_1 = op::Constant::create<int64_t>(
543 element::i64, Shape{unsq_axes_val_1.size()}, unsq_axes_val_1);
544 auto unsq_axes_2 = op::Constant::create<int64_t>(
545 element::i64, Shape{unsq_axes_val_2.size()}, unsq_axes_val_2);
546 auto A = make_shared<op::Parameter>(element::f32, shape);
547 auto A1 = make_shared<op::v0::Abs>(A);
548 auto B = make_shared<op::v0::Unsqueeze>(A1, unsq_axes_1);
549 auto B1 = make_shared<op::v0::Unsqueeze>(B, unsq_axes_2);
550 auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
551 auto optimized_f = clone_function(*baseline_f);
553 pass::Manager pass_manager;
554 pass_manager.register_pass<pass::Validate>();
555 pass_manager.register_pass<pass::NopElimination>();
556 pass_manager.run_passes(optimized_f);
557 auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
558 auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
559 EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
560 ASSERT_EQ(ps.rank().get_length(), ps_r.rank().get_length()) << casename;
561 ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(baseline_f), 2) << casename;
562 ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(optimized_f), unsq) << casename;
565 check_usecase(PartialShape{Dimension::dynamic(), 1, Dimension::dynamic()}, {0}, {2}, 1);
567 PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic(), 1}, {2, 1}, {2, 4}, 1);
568 check_usecase(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 1}, {-1, -3}, {2}, 1);
569 check_usecase(PartialShape{Dimension::dynamic(), 1, Dimension::dynamic(), 1}, {0}, {1, 3}, 1);
// Unsqueeze followed by Squeeze using the *same* axes constant. After
// NopElimination (invoked directly via run_on_function) neither op may remain.
// A second function from the same generator, never passed through the pass,
// serves as the baseline and must still contain one of each.
572 TEST(nop_elimination, unsqueeze_squeeze_elimination) {
573 auto generate_func = [](const Shape& shape, const std::vector<int64_t>& axes_val) {
574 auto axes = op::Constant::create<int64_t>(element::i64, Shape{axes_val.size()}, axes_val);
575 auto A = make_shared<op::Parameter>(element::f32, shape);
576 auto A1 = make_shared<op::v0::Abs>(A);
577 auto B = make_shared<op::v0::Unsqueeze>(A1, axes);
578 auto B1 = make_shared<op::v0::Squeeze>(B, axes);
579 return make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
582 auto check_usecase = [&](const Shape& shape, const std::vector<int64_t>& axes_val) {
583 auto baseline_f = generate_func(shape, axes_val);
584 auto optimized_f = generate_func(shape, axes_val);
585 pass::NopElimination().run_on_function(optimized_f);
587 ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(baseline_f), 1);
588 ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(baseline_f), 1);
589 ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(optimized_f), 0);
590 ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(optimized_f), 0);
593 check_usecase(Shape{6}, std::vector<int64_t>{0});
594 check_usecase(Shape{3, 2}, std::vector<int64_t>{0, 3});
595 check_usecase(Shape{3, 2}, std::vector<int64_t>{0, 2, 4});
596 check_usecase(Shape{3, 2}, std::vector<int64_t>{-1, -4});
// v1::Reshape (pattern `pat_val`, third argument `zero`) followed by
// v0::Unsqueeze. After NopElimination the Unsqueeze is expected to be gone.
// The test only pins the Unsqueeze count; it does not check how many Reshapes
// remain.
// NOTE(review): the {0, -1} i64 constant created after the Reshape appears
// unused in this test.
599 TEST(nop_elimination, reshape_unsqueeze_elimination) {
600 auto check_usecase = [](const Shape& shape,
601 const std::vector<int64_t>& pat_val,
603 const std::vector<int64_t>& axes_val) {
604 auto axes = op::Constant::create<int64_t>(element::i64, Shape{axes_val.size()}, axes_val);
605 auto pat = op::Constant::create<int64_t>(element::i64, Shape{pat_val.size()}, pat_val);
606 auto A = make_shared<op::Parameter>(element::f32, shape);
607 auto A1 = make_shared<op::v0::Abs>(A);
609 auto B = make_shared<op::v1::Reshape>(A1, pat, zero);
611 op::Constant::create<int64_t>(element::i64, Shape{2}, std::vector<int64_t>{0, -1});
612 auto B1 = make_shared<op::v0::Unsqueeze>(B, axes);
613 auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
614 auto optimized_f = clone_function(*baseline_f);
615 pass::NopElimination().run_on_function(optimized_f);
617 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 1);
618 ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(baseline_f), 1);
619 ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(optimized_f), 0);
622 check_usecase(Shape{1, 2, 3, 2, 1}, {2, 3, 2}, false, {2, 4});
623 check_usecase(Shape{12}, {2, 3, 2}, false, {3});
624 check_usecase(Shape{3, 2, 1, 2}, {0, 2, 2}, true, {1, 4});
625 check_usecase(Shape{2, 3, 2}, {2, -1, 2}, false, {2});
626 check_usecase(Shape{2, 3, 2, 1}, {2, 3, 2}, false, {0});
// v1::Reshape (pattern `pat_val`, third argument `zero`) followed by
// v0::Squeeze. After NopElimination the Squeeze is expected to be gone.
// The test only pins the Squeeze count; it does not check how many Reshapes
// remain.
// NOTE(review): the {0, -1} i64 constant created after the Reshape appears
// unused in this test.
628 TEST(nop_elimination, reshape_squeeze_elimination) {
629 auto check_usecase = [](const Shape& shape,
630 const std::vector<int64_t>& pat_val,
632 const std::vector<int64_t>& axes_val) {
633 auto axes = op::Constant::create<int64_t>(element::i64, Shape{axes_val.size()}, axes_val);
634 auto pat = op::Constant::create<int64_t>(element::i64, Shape{pat_val.size()}, pat_val);
635 auto A = make_shared<op::Parameter>(element::f32, shape);
636 auto A1 = make_shared<op::v0::Abs>(A);
638 auto B = make_shared<op::v1::Reshape>(A1, pat, zero);
640 op::Constant::create<int64_t>(element::i64, Shape{2}, std::vector<int64_t>{0, -1});
641 auto B1 = make_shared<op::v0::Squeeze>(B, axes);
642 auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
643 auto optimized_f = clone_function(*baseline_f);
644 pass::NopElimination().run_on_function(optimized_f);
646 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 1);
647 ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(baseline_f), 1);
648 ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(optimized_f), 0);
651 check_usecase(Shape{1, 2, 3, 2, 1}, {2, 3, 1, 2, 1}, false, {2, 4});
652 check_usecase(Shape{12}, {2, 3, 2, 1}, false, {3});
653 check_usecase(Shape{3, 2, 1, 2}, {0, 1, 2, 2, 1}, true, {1, 4});
654 check_usecase(Shape{2, 3, 2}, {2, -1, 1, 2}, false, {2});
655 check_usecase(Shape{2, 3, 2, 1}, {1, 2, 3, 2}, false, {0});
// Two chained v1::Reshape ops (the first with pattern `pat_val` and third
// argument `zero`; the second with pattern {0, -1} and third argument true)
// are expected to be merged by NopElimination into a single Reshape.
// All use cases spell the pattern type as std::vector, for consistency with the
// rest of the file (previously two of them used the unqualified-global
// `::vector`, which only resolves through a namespace-scope using-directive).
658 TEST(nop_elimination, reshape_reshape_elimination) {
659 auto check_usecase = [](const Shape& shape, const std::vector<int64_t>& pat_val, bool zero) {
660 auto pat = op::Constant::create<int64_t>(element::i64, Shape{pat_val.size()}, pat_val);
661 auto A = make_shared<op::Parameter>(element::f32, shape);
662 auto A1 = make_shared<op::v0::Abs>(A);
664 auto B = make_shared<op::v1::Reshape>(A1, pat, zero);
666 op::Constant::create<int64_t>(element::i64, Shape{2}, std::vector<int64_t>{0, -1});
667 auto B1 = make_shared<op::v1::Reshape>(B, pat2, true);
668 auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
669 auto optimized_f = clone_function(*baseline_f);
670 pass::NopElimination().run_on_function(optimized_f);
672 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 2);
673 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(optimized_f), 1);
676 check_usecase(Shape{1, 2, 3, 2, 1}, std::vector<int64_t>{2, 3, 2}, false);
677 check_usecase(Shape{12}, std::vector<int64_t>{2, 3, 2}, false);
678 check_usecase(Shape{3, 2, 1, 2}, std::vector<int64_t>{0, 2, 2}, true);
679 check_usecase(Shape{2, 3, 2}, std::vector<int64_t>{2, -1, 2}, false);
680 check_usecase(Shape{2, 3, 2, 1}, std::vector<int64_t>{2, 3, 2}, false);
// v0::Squeeze followed by a flattening v1::Reshape (pattern {-1}). After
// NopElimination the Squeeze is expected to be gone while exactly one Reshape
// remains.
683 TEST(nop_elimination, squeeze_reshape_elimination) {
684 auto check_usecase = [](const Shape& shape, const std::vector<int64_t>& indices_val) {
686 op::Constant::create<int64_t>(element::i64, Shape{indices_val.size()}, indices_val);
687 auto A = make_shared<op::Parameter>(element::f32, shape);
688 auto A1 = make_shared<op::v0::Abs>(A);
690 auto B = make_shared<op::v0::Squeeze>(A1, indices);
691 auto pat2 = op::Constant::create<int64_t>(element::i64, Shape{1}, std::vector<int64_t>{-1});
692 auto B1 = make_shared<op::v1::Reshape>(B, pat2, false);
693 auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
694 auto optimized_f = clone_function(*baseline_f);
695 pass::NopElimination().run_on_function(optimized_f);
697 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 1);
698 ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(baseline_f), 1);
699 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(optimized_f), 1);
700 ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(optimized_f), 0);
703 check_usecase(Shape{1, 2, 3, 2, 1}, std::vector<int64_t>{0, 4});
704 check_usecase(Shape{1, 1}, std::vector<int64_t>{0, 1});
705 check_usecase(Shape{2, 3, 1, 2}, std::vector<int64_t>{2});
706 check_usecase(Shape{1, 6, 2, 1}, std::vector<int64_t>{3});
// v0::Unsqueeze followed by a flattening v1::Reshape (pattern {-1}). After
// NopElimination the Unsqueeze is expected to be gone while exactly one Reshape
// remains. A scalar (Shape{}) input is covered as well.
709 TEST(nop_elimination, unsqueeze_reshape_elimination) {
710 auto check_usecase = [](const Shape& shape, const std::vector<int64_t>& indices_val) {
712 op::Constant::create<int64_t>(element::i64, Shape{indices_val.size()}, indices_val);
713 auto A = make_shared<op::Parameter>(element::f32, shape);
714 auto A1 = make_shared<op::v0::Abs>(A);
716 auto B = make_shared<op::v0::Unsqueeze>(A1, indices);
717 auto pat2 = op::Constant::create<int64_t>(element::i64, Shape{1}, std::vector<int64_t>{-1});
718 auto B1 = make_shared<op::v1::Reshape>(B, pat2, false);
719 auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
720 auto optimized_f = clone_function(*baseline_f);
721 pass::NopElimination().run_on_function(optimized_f);
723 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 1);
724 ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(baseline_f), 1);
725 ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(optimized_f), 1);
726 ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(optimized_f), 0);
729 check_usecase(Shape{2, 3, 2}, std::vector<int64_t>{0, 4});
730 check_usecase(Shape{}, std::vector<int64_t>{0, 1});
731 check_usecase(Shape{2, 3, 2}, std::vector<int64_t>{2});
732 check_usecase(Shape{1, 6, 2}, std::vector<int64_t>{3});
// Negative case: an opset1::Squeeze that removes every dimension of a
// {1, 1, 1} input and feeds the function's result directly must NOT be
// eliminated; the Squeeze count stays at 1 after NopElimination.
735 TEST(nop_elimination, squeeze_unsqueeze_elimination_negative) {
736 auto check_usecase = [](const Shape& shape, const std::vector<int64_t>& indices_val) {
737 auto indices = op::Constant::create(element::i64, Shape{indices_val.size()}, indices_val);
738 auto input = make_shared<op::Parameter>(element::f32, shape);
739 auto squeeze = make_shared<ngraph::opset1::Squeeze>(input, indices);
740 auto baseline_f = make_shared<Function>(squeeze, ParameterVector{input});
741 auto optimized_f = clone_function(*baseline_f);
742 pass::NopElimination().run_on_function(optimized_f);
744 ASSERT_EQ(count_ops_of_type<ngraph::opset1::Squeeze>(baseline_f), 1);
745 ASSERT_EQ(count_ops_of_type<ngraph::opset1::Squeeze>(optimized_f), 1);
748 check_usecase(Shape{1, 1, 1}, std::vector<int64_t>{0, 1, 2});
751 TEST(nop_elimination, topk_convert_elimination) {
752 auto check_usecase = []() {
753 auto A = make_shared<op::Parameter>(element::f32, Shape{20, 3, 4});
754 auto A1 = make_shared<op::v0::Abs>(A);
755 auto B = make_shared<op::TopK>(A1, 0, element::i64, 10);
756 auto C = make_shared<op::Convert>(B->output(0), B->output(0).get_element_type());
757 auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(C), ParameterVector{A});
758 auto optimized_f = clone_function(*baseline_f);
759 pass::NopElimination().run_on_function(optimized_f);
761 ASSERT_EQ(count_ops_of_type<op::Convert>(baseline_f), 1);
762 ASSERT_EQ(count_ops_of_type<op::Convert>(optimized_f), 0);