3b2befc66d2cff1cc8a93bc309450e0090493f76
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / transformations / nop_elimination.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6
7 #include "common_test_utils/test_common.hpp"
8 #include <string>
9 #include <sstream>
10 #include <memory>
11 #include <queue>
12
13 #include <ngraph/function.hpp>
14 #include <ngraph/opsets/opset1.hpp>
15 #include <ngraph/pass/manager.hpp>
16 #include <ngraph/pass/constant_folding.hpp>
17 #include <transformations/common_optimizations/nop_elimination.hpp>
18 #include <transformations/utils/utils.hpp>
19 #include <transformations/init_node_info.hpp>
20 #include <transformations/rt_info/fused_names_attribute.hpp>
21
22 #include "common_test_utils/ngraph_test_utils.hpp"
23
24 NGRAPH_SUPPRESS_DEPRECATED_START
25
26 using namespace ngraph;
27 using namespace std;
28
29 TEST(nop_elimination, eliminate_sum) {
30     Shape shape{2, 2};
31     auto A = make_shared<op::Parameter>(element::f32, shape);
32     auto s = make_shared<op::v0::Sum>(A, AxisSet{});
33     auto f = make_shared<Function>(make_shared<op::v0::Abs>(s), ParameterVector{A});
34
35     pass::Manager pass_manager;
36     pass_manager.register_pass<pass::NopElimination>();
37     pass_manager.run_passes(f);
38
39     ASSERT_EQ(count_ops_of_type<op::v0::Sum>(f), 0);
40 }
41
42 TEST(nop_elimination, eliminate_convert) {
43     Shape shape{};
44     auto type = element::f32;
45     auto A = make_shared<op::Parameter>(type, shape);
46     auto c = make_shared<op::v0::Convert>(A, element::f32);
47     auto f = make_shared<Function>(make_shared<op::v0::Abs>(c), ParameterVector{A});
48
49     pass::Manager pass_manager;
50     pass_manager.register_pass<pass::NopElimination>();
51     pass_manager.run_passes(f);
52
53     ASSERT_EQ(count_ops_of_type<op::v0::Convert>(f), 0);
54 }
55
56 TEST(nop_elimination, convert_type_agnostic) {
57     Shape shape{};
58     auto type = element::from<char>();
59     auto A = make_shared<op::Parameter>(type, shape);
60     auto c1 = make_shared<op::v0::Convert>(A, element::from<uint8_t>());
61     auto c = make_shared<op::v0::Convert>(c1, element::f32);
62     auto z = make_shared<op::v3::NonZero>(c);
63     auto f = make_shared<Function>(make_shared<op::v0::Abs>(z), ParameterVector{A});
64
65     pass::Manager pass_manager;
66     pass_manager.register_pass<pass::Validate>();
67     pass_manager.register_pass<pass::NopElimination>();
68     pass_manager.run_passes(f);
69
70     ASSERT_EQ(count_ops_of_type<op::v0::Convert>(f), 0);
71 }
72
73 TEST(nop_elimination, eliminate_slice) {
74     Shape shape{2, 2};
75     auto A = make_shared<op::Parameter>(element::f32, shape);
76     auto s = make_shared<op::v0::Slice>(A, Coordinate{0, 0}, Coordinate{2, 2});
77     auto f = make_shared<Function>(make_shared<op::v0::Abs>(s), ParameterVector{A});
78
79     pass::Manager pass_manager;
80     pass_manager.register_pass<pass::NopElimination>();
81     pass_manager.run_passes(f);
82
83     ASSERT_EQ(count_ops_of_type<op::v0::Slice>(f), 0);
84 }
85
86 TEST(nop_elimination, eliminate_broadcast) {
87     Shape shape{};
88     auto A = make_shared<op::Parameter>(element::f32, shape);
89     auto b = make_shared<op::v0::Broadcast>(A, shape, AxisSet{});
90     auto f = make_shared<Function>(make_shared<op::v0::Abs>(b), ParameterVector{A});
91
92     pass::Manager pass_manager;
93     pass_manager.register_pass<pass::NopElimination>();
94     pass_manager.run_passes(f);
95
96     ASSERT_EQ(count_ops_of_type<op::v0::Broadcast>(f), 0);
97 }
98
99 TEST(nop_elimination, eliminate_stop_gradient) {
100     Shape shape{};
101     auto A = make_shared<op::Parameter>(element::f32, shape);
102     auto s = make_shared<op::v0::StopGradient>(A);
103     auto f = make_shared<Function>(make_shared<op::v0::Abs>(s), ParameterVector{A});
104
105     pass::Manager pass_manager;
106     pass_manager.register_pass<pass::NopElimination>();
107     pass_manager.run_passes(f);
108
109     ASSERT_EQ(count_ops_of_type<op::v0::StopGradient>(f), 0);
110 }
111
112 TEST(nop_elimination, pass_property) {
113     auto pass = std::make_shared<ngraph::pass::NopElimination>();
114     ASSERT_FALSE(pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
115 }
116
117 TEST(nop_elimination, reshape_elimination_v1) {
118     auto generate_func = [](bool zero) {
119         auto arg = std::make_shared<op::Parameter>(element::i64, PartialShape{8, 16, 2, 3});
120         auto pattern_org = op::Constant::create(element::i64, Shape{3}, vector<int64_t>{8, 16, 6});
121         auto pattern = op::Constant::create(element::i64, Shape{3}, vector<int64_t>{8, 16, 6});
122         auto reshape_v1_org = std::make_shared<op::v1::Reshape>(arg, pattern_org, zero);
123         auto reshape_v1 = std::make_shared<op::v1::Reshape>(reshape_v1_org, pattern, zero);
124         auto abs = std::make_shared<op::v0::Abs>(reshape_v1);
125         return std::make_shared<Function>(NodeVector{abs}, ParameterVector{arg});
126     };
127
128     auto func = generate_func(false);
129     auto nopass_func = generate_func(false);
130     auto func_zero = generate_func(true);
131     auto nopass_func_zero = generate_func(true);
132
133     pass::Manager pass_manager;
134     pass_manager.register_pass<pass::NopElimination>();
135     pass_manager.run_passes(func);
136     pass_manager.run_passes(func_zero);
137     ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(nopass_func) == 2);
138     ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(func) == 1);
139     ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(nopass_func_zero) == 2);
140     ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(func_zero) == 1);
141 }
142
143 TEST(nop_elimination, squeeze_reshape_elimination_check_info) {
144     std::shared_ptr<Function> f;
145     {
146         auto arg = std::make_shared<opset4::Parameter>(element::f32, PartialShape{8, 16, 1, 3});
147
148         auto relu = std::make_shared<opset4::Relu>(arg);
149         relu->set_friendly_name("relu");
150
151         auto squeeze_axes = opset4::Constant::create(element::i64, Shape{1}, {2});
152         auto squeeze = std::make_shared<opset4::Squeeze>(relu, squeeze_axes);
153         squeeze->set_friendly_name("squeeze");
154
155         auto reshape_shape = opset4::Constant::create(element::i64, Shape{4}, {8, 16, 1, 3});
156         auto reshape = std::make_shared<opset4::Reshape>(squeeze, reshape_shape, false);
157         reshape->set_friendly_name("reshape");
158
159         auto abs = std::make_shared<opset4::Abs>(reshape);
160
161         f = std::make_shared<Function>(NodeVector{abs}, ParameterVector{arg});
162     }
163
164     pass::Manager pass_manager;
165     pass_manager.register_pass<pass::InitNodeInfo>();
166     pass_manager.register_pass<pass::NopElimination>();
167     pass_manager.run_passes(f);
168
169     bool reshape_is_missing = true;
170     for (auto node : f->get_ops()) {
171         if (node->get_friendly_name() == "reshape") {
172             reshape_is_missing = false;
173             ASSERT_TRUE(std::dynamic_pointer_cast<opset4::Reshape>(node));
174             auto original_names = getFusedNamesVector(node);
175             sort(original_names.begin(), original_names.end());
176             ASSERT_EQ(original_names, std::vector<std::string>({"reshape", "squeeze"}));
177         }
178     }
179     ASSERT_FALSE(reshape_is_missing);
180 }
181
182 TEST(nop_elimination, reshape_elimination_v1_dynamic) {
183     auto arg = std::make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
184     auto pattern = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
185     auto reshape_v1 = std::make_shared<op::v1::Reshape>(arg, pattern, false);
186     auto abs = std::make_shared<op::v0::Abs>(reshape_v1);
187     auto f = std::make_shared<Function>(NodeVector{abs}, ParameterVector{arg, pattern});
188     pass::Manager pass_manager;
189     pass_manager.register_pass<pass::NopElimination>();
190     pass_manager.run_passes(f);
191     ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(f) == 1);
192 }
193
194 TEST(nop_elimination, concat_elimination_single_node) {
195     int64_t a = 0;
196     auto A = make_shared<op::Parameter>(element::f32, Shape{2, 3});
197     auto f =
198         make_shared<Function>(make_shared<op::v0::Concat>(NodeVector{A}, a), ParameterVector{A});
199
200     pass::Manager pass_manager;
201     pass_manager.register_pass<pass::Validate>();
202     pass_manager.register_pass<pass::NopElimination>();
203     pass_manager.run_passes(f);
204
205     ASSERT_EQ(count_ops_of_type<op::v0::Concat>(f), 1);
206 }
207
208 TEST(nop_elimination, concat_elimination_single_input) {
209     int64_t a = 0;
210     auto A = make_shared<op::Parameter>(element::f32, Shape{2, 3});
211     auto B = make_shared<op::v0::Concat>(NodeVector{A}, a);
212     auto f = make_shared<Function>(make_shared<op::v0::Abs>(B), ParameterVector{A});
213
214     pass::Manager pass_manager;
215     pass_manager.register_pass<pass::Validate>();
216     pass_manager.register_pass<pass::NopElimination>();
217     pass_manager.run_passes(f);
218
219     ASSERT_EQ(count_ops_of_type<op::v0::Concat>(f), 0);
220 }
221
222 TEST(nop_elimination, concat_elimination_single_input_dynamic) {
223     int64_t a = 0;
224     auto A = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 3});
225     auto B = make_shared<op::v0::Concat>(NodeVector{A}, a);
226     auto f = make_shared<Function>(make_shared<op::v0::Abs>(B), ParameterVector{A});
227
228     pass::Manager pass_manager;
229     pass_manager.register_pass<pass::Validate>();
230     pass_manager.register_pass<pass::NopElimination>();
231     pass_manager.run_passes(f);
232
233     ASSERT_EQ(count_ops_of_type<op::v0::Concat>(f), 0);
234 }
235
236 TEST(nop_elimination, unsqueeze_elimination) {
237     const auto axis = op::Constant::create<int64_t>(element::i64, {}, {0});
238     const auto A = make_shared<op::Parameter>(
239         element::f32, PartialShape{3, Dimension::dynamic(), Dimension::dynamic()});
240     const auto unsqueeze = make_shared<op::v0::Unsqueeze>(A, axis);
241     auto f = make_shared<Function>(unsqueeze, ParameterVector{A});
242
243     pass::Manager pass_manager;
244     pass_manager.register_pass<pass::Validate>();
245     pass_manager.register_pass<pass::NopElimination>();
246     pass_manager.run_passes(f);
247
248     ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(f), 1);
249 }
250
251 TEST(nop_elimination, squeeze_unsqueeze_overlap_elimination) {
252     auto check_usecase = [](const PartialShape& shape,
253                             const std::vector<int64_t>& sq_axes_val,
254                             const std::vector<int64_t>& unsq_axes_val,
255                             bool sq_to_unsq,
256                             bool i32,
257                             bool multiout,
258                             size_t sc,
259                             size_t usc,
260                             size_t rc) {
261         static size_t id = 0;
262         auto casename = string("usecase #") + to_string(++id);
263
264         shared_ptr<Node> sq_axes;
265         shared_ptr<Node> unsq_axes;
266         if (i32) {
267             std::vector<int32_t> sq_axes_val_i32(sq_axes_val.begin(), sq_axes_val.end());
268             std::vector<int32_t> unsq_axes_val_i32(unsq_axes_val.begin(), unsq_axes_val.end());
269             sq_axes = op::Constant::create<int32_t>(
270                 element::i32, Shape{sq_axes_val.size()}, sq_axes_val_i32);
271             unsq_axes = op::Constant::create<int32_t>(
272                 element::i32, Shape{unsq_axes_val.size()}, unsq_axes_val_i32);
273         } else {
274             sq_axes =
275                 op::Constant::create<int64_t>(element::i64, Shape{sq_axes_val.size()}, sq_axes_val);
276             unsq_axes = op::Constant::create<int64_t>(
277                 element::i64, Shape{unsq_axes_val.size()}, unsq_axes_val);
278         }
279
280         auto A = make_shared<op::Parameter>(element::f32, shape);
281         shared_ptr<Node> A1;
282         if (multiout) {
283             auto last_dim = shape.rank().get_length() - 1;
284             A1 = make_shared<op::v0::TopK>(A, last_dim, element::i32);
285         } else {
286             A1 = make_shared<op::v0::Abs>(A);
287         }
288
289         shared_ptr<Node> B1;
290         if (sq_to_unsq) {
291             auto B = make_shared<op::v0::Squeeze>((multiout ? A1->output(0) : A1), sq_axes);
292             B1 = make_shared<op::v0::Unsqueeze>(B, unsq_axes);
293         } else {
294             auto B = make_shared<op::v0::Unsqueeze>((multiout ? A1->output(0) : A1), unsq_axes);
295             B1 = make_shared<op::v0::Squeeze>(B, sq_axes);
296         }
297
298         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
299         auto optimized_f = clone_function(*baseline_f);
300
301         pass::Manager pass_manager;
302         pass_manager.register_pass<pass::Validate>();
303         pass_manager.register_pass<pass::NopElimination>();
304         pass_manager.run_passes(optimized_f);
305
306         auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
307         auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
308         EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
309         ASSERT_EQ(ps.rank().get_length(), ps_r.rank().get_length()) << casename;
310
311         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(baseline_f), 1) << casename;
312         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(baseline_f), 1) << casename;
313         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(optimized_f), sc) << casename;
314         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(optimized_f), usc) << casename;
315         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(optimized_f), rc) << casename;
316     };
317
318     // static shapes, all squeeze/unsqueeze replaced by reshape
319     check_usecase(PartialShape{2, 1, 1, 6}, {1, 2}, {0}, true, false, false, 0, 0, 1);
320     check_usecase(PartialShape{2, 1, 1, 6}, {1, 2}, {0}, true, true, false, 0, 0, 1);
321     // multioutout ops
322     check_usecase(PartialShape{2, 1, 1, 6}, {1, 2}, {0}, true, false, true, 0, 0, 1);
323     check_usecase(PartialShape{2, 1, 1, 6}, {1, 2}, {0}, true, true, true, 0, 0, 1);
324     check_usecase(PartialShape{1}, {0}, {0, 1, 2, 3}, true, true, true, 0, 0, 1);
325
326     // axes match - expect all squeeze/unsqueeze/reshape cancel out
327     check_usecase(PartialShape{2, 1, 1, 6}, {1, 2}, {1, 2}, true, true, true, 0, 0, 0);
328
329     // dynamic shapes - axes match, expect all cancel
330     check_usecase(PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic(), 1},
331                   {0, 2, 4},
332                   {0, 2, 4},
333                   true,
334                   true,
335                   true,
336                   0,
337                   0,
338                   0);
339     check_usecase(PartialShape{1, Dimension::dynamic(), 1, 2, 1},
340                   {0, 2, 4},
341                   {0, 2, 4},
342                   true,
343                   false,
344                   true,
345                   0,
346                   0,
347                   0);
348
349     // squeeze axes overlap fully
350     check_usecase(
351         PartialShape{Dimension::dynamic(), 1, 1, 3}, {1, 2}, {1, 2, 3}, true, true, true, 0, 0, 1);
352     check_usecase(PartialShape{Dimension::dynamic(), 1, 1, Dimension::dynamic()},
353                   {1, 2},
354                   {1, 2, 3},
355                   true,
356                   true,
357                   true,
358                   0,
359                   1,
360                   0);
361     check_usecase(PartialShape{2, 1, 1, 4}, {1, 2}, {1, 2, 3}, true, true, true, 0, 0, 1);
362     check_usecase(PartialShape{2, 1, 1, Dimension::dynamic(), Dimension::dynamic()},
363                   {1, 2},
364                   {1, 2, 3},
365                   true,
366                   true,
367                   true,
368                   0,
369                   1,
370                   0);
371     check_usecase(PartialShape{1, Dimension::dynamic(), 1, 1, Dimension::dynamic()},
372                   {2, 3},
373                   {2, 3, 5},
374                   true,
375                   true,
376                   true,
377                   0,
378                   1,
379                   0);
380
381     // unsqueeze axes overlap fully
382     check_usecase(PartialShape{1, Dimension::dynamic(), 1, 1, 1, Dimension::dynamic(), 3},
383                   {2, 3},
384                   {2},
385                   true,
386                   true,
387                   true,
388                   1,
389                   0,
390                   0);
391     check_usecase(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 1, 1},
392                   {2, 3},
393                   {2},
394                   true,
395                   true,
396                   true,
397                   1,
398                   0,
399                   0);
400     check_usecase(
401         PartialShape{Dimension::dynamic(), 3, 1, 1}, {2, 3}, {2}, true, true, true, 0, 0, 1);
402     check_usecase(PartialShape{3, 1, 1}, {1, 2}, {1}, true, true, true, 0, 0, 1);
403
404     // squeeze->unsqueeze axes overlap
405     check_usecase(
406         PartialShape{Dimension::dynamic(), 1, 1, 4}, {1, 2}, {0}, true, true, true, 0, 0, 1);
407     check_usecase(PartialShape{Dimension::dynamic(), 1, 1, Dimension::dynamic()},
408                   {1, 2},
409                   {0},
410                   true,
411                   true,
412                   true,
413                   1,
414                   1,
415                   0);
416     check_usecase(PartialShape{3, 1, 1, 4}, {1, 2}, {0}, true, true, true, 0, 0, 1);
417     check_usecase(PartialShape{2, 1, 1, Dimension::dynamic(), Dimension::dynamic()},
418                   {1, 2},
419                   {2},
420                   true,
421                   true,
422                   true,
423                   1,
424                   1,
425                   0);
426     check_usecase(PartialShape{Dimension::dynamic(), 1, 1, 3, Dimension::dynamic(), 4},
427                   {1, 2},
428                   {2},
429                   true,
430                   true,
431                   true,
432                   1,
433                   1,
434                   0);
435     check_usecase(PartialShape{2, 1, Dimension::dynamic(), 1, Dimension::dynamic()},
436                   {1, 3},
437                   {3},
438                   true,
439                   true,
440                   true,
441                   1,
442                   1,
443                   0);
444     check_usecase(PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic(), 1, 1, 4},
445                   {4, 5},
446                   {1, 5},
447                   true,
448                   true,
449                   true,
450                   1,
451                   1,
452                   0);
453
454     //
455     // Unsqueeze->Squeeze cases, testcase 23 - ..
456     //
457     // static shapes, all unsqueeze/squeeze replaced by reshape
458     check_usecase(PartialShape{2, 6, 1}, {4}, {1, 2}, false, false, false, 0, 0, 1);
459     check_usecase(PartialShape{2, 6, 1}, {4}, {1, 2}, false, true, false, 0, 0, 1);
460     // multioutout ops
461     check_usecase(PartialShape{2, 6, 1}, {4}, {1, 2}, false, false, true, 0, 0, 1);
462     check_usecase(PartialShape{2, 6, 1}, {4}, {1, 2}, false, true, true, 0, 0, 1);
463     check_usecase(PartialShape{1}, {0}, {0, 1, 2, 3}, false, true, true, 0, 0, 1);
464     check_usecase(PartialShape{3, 1, 1, 4}, {2, 3}, {0}, false, true, true, 0, 0, 1);
465     // dynamic shapes
466     check_usecase(PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic(), 1},
467                   {0, 2, 4},
468                   {0, 2, 4},
469                   false,
470                   true,
471                   true,
472                   0,
473                   0,
474                   0);
475     check_usecase(PartialShape{Dimension::dynamic(), 1, 1, Dimension::dynamic()},
476                   {2},
477                   {0},
478                   true,
479                   true,
480                   true,
481                   1,
482                   1,
483                   0);
484     check_usecase(PartialShape{Dimension::dynamic(), 1, 1, 4}, {2}, {0}, true, true, true, 0, 0, 1);
485     check_usecase(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 1, 1},
486                   {2, 3},
487                   {2},
488                   true,
489                   true,
490                   true,
491                   1,
492                   0,
493                   0);
494 }
495
496 TEST(nop_elimination, squeeze_squeeze_overlap_elimination) {
497     auto check_usecase = [](const PartialShape& shape,
498                             const std::vector<int64_t>& sq_axes_val_1,
499                             const std::vector<int64_t>& sq_axes_val_2,
500                             size_t sq) {
501         static size_t id = 0;
502         auto casename = string("usecase #") + to_string(++id);
503         auto sq_axes_1 =
504             op::Constant::create<int64_t>(element::i64, Shape{sq_axes_val_1.size()}, sq_axes_val_1);
505         auto sq_axes_2 =
506             op::Constant::create<int64_t>(element::i64, Shape{sq_axes_val_2.size()}, sq_axes_val_2);
507         auto A = make_shared<op::Parameter>(element::f32, shape);
508         auto A1 = make_shared<op::v0::Abs>(A);
509         auto B = make_shared<op::v0::Squeeze>(A1, sq_axes_1);
510         auto B1 = make_shared<op::v0::Squeeze>(B, sq_axes_2);
511         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
512         auto optimized_f = clone_function(*baseline_f);
513
514         pass::Manager pass_manager;
515         pass_manager.register_pass<pass::Validate>();
516         pass_manager.register_pass<pass::NopElimination>();
517         pass_manager.run_passes(optimized_f);
518         auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
519         auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
520         EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
521         ASSERT_EQ(ps.rank().get_length(), ps_r.rank().get_length()) << casename;
522         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(baseline_f), 2) << casename;
523         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(optimized_f), sq) << casename;
524     };
525
526     check_usecase(PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic()}, {0}, {1}, 1);
527     check_usecase(
528         PartialShape{1, 1, 1, Dimension::dynamic(), 1, Dimension::dynamic(), 1}, {2, 1}, {2, 4}, 1);
529     check_usecase(
530         PartialShape{1, Dimension::dynamic(), Dimension::dynamic(), 1, 1}, {-1, -5}, {2}, 1);
531     check_usecase(
532         PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic(), 1}, {0}, {1, 3}, 1);
533 }
534
535 TEST(nop_elimination, unsqueeze_unsqueeze_overlap_elimination) {
536     auto check_usecase = [](const PartialShape& shape,
537                             const std::vector<int64_t>& unsq_axes_val_1,
538                             const std::vector<int64_t>& unsq_axes_val_2,
539                             size_t unsq) {
540         static size_t id = 0;
541         auto casename = string("usecase #") + to_string(++id);
542         auto unsq_axes_1 = op::Constant::create<int64_t>(
543             element::i64, Shape{unsq_axes_val_1.size()}, unsq_axes_val_1);
544         auto unsq_axes_2 = op::Constant::create<int64_t>(
545             element::i64, Shape{unsq_axes_val_2.size()}, unsq_axes_val_2);
546         auto A = make_shared<op::Parameter>(element::f32, shape);
547         auto A1 = make_shared<op::v0::Abs>(A);
548         auto B = make_shared<op::v0::Unsqueeze>(A1, unsq_axes_1);
549         auto B1 = make_shared<op::v0::Unsqueeze>(B, unsq_axes_2);
550         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
551         auto optimized_f = clone_function(*baseline_f);
552
553         pass::Manager pass_manager;
554         pass_manager.register_pass<pass::Validate>();
555         pass_manager.register_pass<pass::NopElimination>();
556         pass_manager.run_passes(optimized_f);
557         auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
558         auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
559         EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
560         ASSERT_EQ(ps.rank().get_length(), ps_r.rank().get_length()) << casename;
561         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(baseline_f), 2) << casename;
562         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(optimized_f), unsq) << casename;
563     };
564
565     check_usecase(PartialShape{Dimension::dynamic(), 1, Dimension::dynamic()}, {0}, {2}, 1);
566     check_usecase(
567         PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic(), 1}, {2, 1}, {2, 4}, 1);
568     check_usecase(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 1}, {-1, -3}, {2}, 1);
569     check_usecase(PartialShape{Dimension::dynamic(), 1, Dimension::dynamic(), 1}, {0}, {1, 3}, 1);
570 }
571
572 TEST(nop_elimination, unsqueeze_squeeze_elimination) {
573     auto generate_func = [](const Shape& shape, const std::vector<int64_t>& axes_val) {
574         auto axes = op::Constant::create<int64_t>(element::i64, Shape{axes_val.size()}, axes_val);
575         auto A = make_shared<op::Parameter>(element::f32, shape);
576         auto A1 = make_shared<op::v0::Abs>(A);
577         auto B = make_shared<op::v0::Unsqueeze>(A1, axes);
578         auto B1 = make_shared<op::v0::Squeeze>(B, axes);
579         return make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
580     };
581
582     auto check_usecase = [&](const Shape& shape, const std::vector<int64_t>& axes_val) {
583         auto baseline_f = generate_func(shape, axes_val);
584         auto optimized_f = generate_func(shape, axes_val);
585         pass::NopElimination().run_on_function(optimized_f);
586
587         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(baseline_f), 1);
588         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(baseline_f), 1);
589         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(optimized_f), 0);
590         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(optimized_f), 0);
591     };
592
593     check_usecase(Shape{6}, std::vector<int64_t>{0});
594     check_usecase(Shape{3, 2}, std::vector<int64_t>{0, 3});
595     check_usecase(Shape{3, 2}, std::vector<int64_t>{0, 2, 4});
596     check_usecase(Shape{3, 2}, std::vector<int64_t>{-1, -4});
597 }
598
599 TEST(nop_elimination, reshape_unsqueeze_elimination) {
600     auto check_usecase = [](const Shape& shape,
601                             const std::vector<int64_t>& pat_val,
602                             bool zero,
603                             const std::vector<int64_t>& axes_val) {
604         auto axes = op::Constant::create<int64_t>(element::i64, Shape{axes_val.size()}, axes_val);
605         auto pat = op::Constant::create<int64_t>(element::i64, Shape{pat_val.size()}, pat_val);
606         auto A = make_shared<op::Parameter>(element::f32, shape);
607         auto A1 = make_shared<op::v0::Abs>(A);
608
609         auto B = make_shared<op::v1::Reshape>(A1, pat, zero);
610         auto pat2 =
611             op::Constant::create<int64_t>(element::i64, Shape{2}, std::vector<int64_t>{0, -1});
612         auto B1 = make_shared<op::v0::Unsqueeze>(B, axes);
613         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
614         auto optimized_f = clone_function(*baseline_f);
615         pass::NopElimination().run_on_function(optimized_f);
616
617         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 1);
618         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(baseline_f), 1);
619         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(optimized_f), 0);
620     };
621
622     check_usecase(Shape{1, 2, 3, 2, 1}, {2, 3, 2}, false, {2, 4});
623     check_usecase(Shape{12}, {2, 3, 2}, false, {3});
624     check_usecase(Shape{3, 2, 1, 2}, {0, 2, 2}, true, {1, 4});
625     check_usecase(Shape{2, 3, 2}, {2, -1, 2}, false, {2});
626     check_usecase(Shape{2, 3, 2, 1}, {2, 3, 2}, false, {0});
627 }
628 TEST(nop_elimination, reshape_squeeze_elimination) {
629     auto check_usecase = [](const Shape& shape,
630                             const std::vector<int64_t>& pat_val,
631                             bool zero,
632                             const std::vector<int64_t>& axes_val) {
633         auto axes = op::Constant::create<int64_t>(element::i64, Shape{axes_val.size()}, axes_val);
634         auto pat = op::Constant::create<int64_t>(element::i64, Shape{pat_val.size()}, pat_val);
635         auto A = make_shared<op::Parameter>(element::f32, shape);
636         auto A1 = make_shared<op::v0::Abs>(A);
637
638         auto B = make_shared<op::v1::Reshape>(A1, pat, zero);
639         auto pat2 =
640             op::Constant::create<int64_t>(element::i64, Shape{2}, std::vector<int64_t>{0, -1});
641         auto B1 = make_shared<op::v0::Squeeze>(B, axes);
642         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
643         auto optimized_f = clone_function(*baseline_f);
644         pass::NopElimination().run_on_function(optimized_f);
645
646         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 1);
647         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(baseline_f), 1);
648         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(optimized_f), 0);
649     };
650
651     check_usecase(Shape{1, 2, 3, 2, 1}, {2, 3, 1, 2, 1}, false, {2, 4});
652     check_usecase(Shape{12}, {2, 3, 2, 1}, false, {3});
653     check_usecase(Shape{3, 2, 1, 2}, {0, 1, 2, 2, 1}, true, {1, 4});
654     check_usecase(Shape{2, 3, 2}, {2, -1, 1, 2}, false, {2});
655     check_usecase(Shape{2, 3, 2, 1}, {1, 2, 3, 2}, false, {0});
656 }
657
658 TEST(nop_elimination, reshape_reshape_elimination) {
659     auto check_usecase = [](const Shape& shape, const std::vector<int64_t>& pat_val, bool zero) {
660         auto pat = op::Constant::create<int64_t>(element::i64, Shape{pat_val.size()}, pat_val);
661         auto A = make_shared<op::Parameter>(element::f32, shape);
662         auto A1 = make_shared<op::v0::Abs>(A);
663
664         auto B = make_shared<op::v1::Reshape>(A1, pat, zero);
665         auto pat2 =
666             op::Constant::create<int64_t>(element::i64, Shape{2}, std::vector<int64_t>{0, -1});
667         auto B1 = make_shared<op::v1::Reshape>(B, pat2, true);
668         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
669         auto optimized_f = clone_function(*baseline_f);
670         pass::NopElimination().run_on_function(optimized_f);
671
672         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 2);
673         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(optimized_f), 1);
674     };
675
676     check_usecase(Shape{1, 2, 3, 2, 1}, std::vector<int64_t>{2, 3, 2}, false);
677     check_usecase(Shape{12}, std::vector<int64_t>{2, 3, 2}, false);
678     check_usecase(Shape{3, 2, 1, 2}, std::vector<int64_t>{0, 2, 2}, true);
679     check_usecase(Shape{2, 3, 2}, ::vector<int64_t>{2, -1, 2}, false);
680     check_usecase(Shape{2, 3, 2, 1}, ::vector<int64_t>{2, 3, 2}, false);
681 }
682
683 TEST(nop_elimination, squeeze_reshape_elimination) {
684     auto check_usecase = [](const Shape& shape, const std::vector<int64_t>& indices_val) {
685         auto indices =
686             op::Constant::create<int64_t>(element::i64, Shape{indices_val.size()}, indices_val);
687         auto A = make_shared<op::Parameter>(element::f32, shape);
688         auto A1 = make_shared<op::v0::Abs>(A);
689
690         auto B = make_shared<op::v0::Squeeze>(A1, indices);
691         auto pat2 = op::Constant::create<int64_t>(element::i64, Shape{1}, std::vector<int64_t>{-1});
692         auto B1 = make_shared<op::v1::Reshape>(B, pat2, false);
693         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
694         auto optimized_f = clone_function(*baseline_f);
695         pass::NopElimination().run_on_function(optimized_f);
696
697         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 1);
698         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(baseline_f), 1);
699         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(optimized_f), 1);
700         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(optimized_f), 0);
701     };
702
703     check_usecase(Shape{1, 2, 3, 2, 1}, std::vector<int64_t>{0, 4});
704     check_usecase(Shape{1, 1}, std::vector<int64_t>{0, 1});
705     check_usecase(Shape{2, 3, 1, 2}, std::vector<int64_t>{2});
706     check_usecase(Shape{1, 6, 2, 1}, std::vector<int64_t>{3});
707 }
708
709 TEST(nop_elimination, unsqueeze_reshape_elimination) {
710     auto check_usecase = [](const Shape& shape, const std::vector<int64_t>& indices_val) {
711         auto indices =
712             op::Constant::create<int64_t>(element::i64, Shape{indices_val.size()}, indices_val);
713         auto A = make_shared<op::Parameter>(element::f32, shape);
714         auto A1 = make_shared<op::v0::Abs>(A);
715
716         auto B = make_shared<op::v0::Unsqueeze>(A1, indices);
717         auto pat2 = op::Constant::create<int64_t>(element::i64, Shape{1}, std::vector<int64_t>{-1});
718         auto B1 = make_shared<op::v1::Reshape>(B, pat2, false);
719         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
720         auto optimized_f = clone_function(*baseline_f);
721         pass::NopElimination().run_on_function(optimized_f);
722
723         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 1);
724         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(baseline_f), 1);
725         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(optimized_f), 1);
726         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(optimized_f), 0);
727     };
728
729     check_usecase(Shape{2, 3, 2}, std::vector<int64_t>{0, 4});
730     check_usecase(Shape{}, std::vector<int64_t>{0, 1});
731     check_usecase(Shape{2, 3, 2}, std::vector<int64_t>{2});
732     check_usecase(Shape{1, 6, 2}, std::vector<int64_t>{3});
733 }
734
735 TEST(nop_elimination, squeeze_unsqueeze_elimination_negative) {
736     auto check_usecase = [](const Shape& shape, const std::vector<int64_t>& indices_val) {
737         auto indices = op::Constant::create(element::i64, Shape{indices_val.size()}, indices_val);
738         auto input = make_shared<op::Parameter>(element::f32, shape);
739         auto squeeze = make_shared<ngraph::opset1::Squeeze>(input, indices);
740         auto baseline_f = make_shared<Function>(squeeze, ParameterVector{input});
741         auto optimized_f = clone_function(*baseline_f);
742         pass::NopElimination().run_on_function(optimized_f);
743
744         ASSERT_EQ(count_ops_of_type<ngraph::opset1::Squeeze>(baseline_f), 1);
745         ASSERT_EQ(count_ops_of_type<ngraph::opset1::Squeeze>(optimized_f), 1);
746     };
747
748     check_usecase(Shape{1, 1, 1}, std::vector<int64_t>{0, 1, 2});
749 }
750
751 TEST(nop_elimination, topk_convert_elimination) {
752     auto check_usecase = []() {
753         auto A = make_shared<op::Parameter>(element::f32, Shape{20, 3, 4});
754         auto A1 = make_shared<op::v0::Abs>(A);
755         auto B = make_shared<op::TopK>(A1, 0, element::i64, 10);
756         auto C = make_shared<op::Convert>(B->output(0), B->output(0).get_element_type());
757         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(C), ParameterVector{A});
758         auto optimized_f = clone_function(*baseline_f);
759         pass::NopElimination().run_on_function(optimized_f);
760
761         ASSERT_EQ(count_ops_of_type<op::Convert>(baseline_f), 1);
762         ASSERT_EQ(count_ops_of_type<op::Convert>(optimized_f), 0);
763     };
764
765     check_usecase();
766 }