Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / transformations / nop_elimination.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6
7 #include "common_test_utils/test_common.hpp"
8 #include <string>
9 #include <sstream>
10 #include <memory>
11 #include <queue>
12
13 #include <ngraph/function.hpp>
14 #include <ngraph/opsets/opset1.hpp>
15 #include <ngraph/pass/manager.hpp>
16 #include <ngraph/pass/constant_folding.hpp>
17 #include <transformations/common_optimizations/nop_elimination.hpp>
18 #include <transformations/utils/utils.hpp>
19 #include <transformations/init_node_info.hpp>
20
21 #include "common_test_utils/ngraph_test_utils.hpp"
22
23 NGRAPH_SUPPRESS_DEPRECATED_START
24
25 using namespace ngraph;
26 using namespace std;
27
28 TEST(nop_elimination, eliminate_pad) {
29     Shape shape_a{2};
30     auto A = make_shared<op::Parameter>(element::f32, shape_a);
31     Shape shape_b{};
32     auto B = make_shared<op::Parameter>(element::f32, shape_b);
33     CoordinateDiff padding_below{0};
34     CoordinateDiff padding_above{0};
35     auto p = make_shared<op::v0::Pad>(A, B, padding_below, padding_above);
36     auto f = make_shared<Function>(make_shared<op::v0::Abs>(p), ParameterVector{A, B});
37
38     pass::Manager pass_manager;
39     pass_manager.register_pass<pass::NopElimination>();
40     pass_manager.run_passes(f);
41
42     ASSERT_EQ(count_ops_of_type<op::v0::Pad>(f), 0);
43 }
44
45 TEST(nop_elimination, eliminate_sum) {
46     Shape shape{2, 2};
47     auto A = make_shared<op::Parameter>(element::f32, shape);
48     auto s = make_shared<op::v0::Sum>(A, AxisSet{});
49     auto f = make_shared<Function>(make_shared<op::v0::Abs>(s), ParameterVector{A});
50
51     pass::Manager pass_manager;
52     pass_manager.register_pass<pass::NopElimination>();
53     pass_manager.run_passes(f);
54
55     ASSERT_EQ(count_ops_of_type<op::v0::Sum>(f), 0);
56 }
57
58 TEST(nop_elimination, eliminate_convert) {
59     Shape shape{};
60     auto type = element::f32;
61     auto A = make_shared<op::Parameter>(type, shape);
62     auto c = make_shared<op::v0::Convert>(A, element::f32);
63     auto f = make_shared<Function>(make_shared<op::v0::Abs>(c), ParameterVector{A});
64
65     pass::Manager pass_manager;
66     pass_manager.register_pass<pass::NopElimination>();
67     pass_manager.run_passes(f);
68
69     ASSERT_EQ(count_ops_of_type<op::v0::Convert>(f), 0);
70 }
71
72 TEST(nop_elimination, convert_type_agnostic) {
73     Shape shape{};
74     auto type = element::from<char>();
75     auto A = make_shared<op::Parameter>(type, shape);
76     auto c1 = make_shared<op::v0::Convert>(A, element::from<uint8_t>());
77     auto c = make_shared<op::v0::Convert>(c1, element::f32);
78     auto z = make_shared<op::v3::NonZero>(c);
79     auto f = make_shared<Function>(make_shared<op::v0::Abs>(z), ParameterVector{A});
80
81     pass::Manager pass_manager;
82     pass_manager.register_pass<pass::Validate>();
83     pass_manager.register_pass<pass::NopElimination>();
84     pass_manager.run_passes(f);
85
86     ASSERT_EQ(count_ops_of_type<op::v0::Convert>(f), 0);
87 }
88
89 TEST(nop_elimination, eliminate_slice) {
90     Shape shape{2, 2};
91     auto A = make_shared<op::Parameter>(element::f32, shape);
92     auto s = make_shared<op::v0::Slice>(A, Coordinate{0, 0}, Coordinate{2, 2});
93     auto f = make_shared<Function>(make_shared<op::v0::Abs>(s), ParameterVector{A});
94
95     pass::Manager pass_manager;
96     pass_manager.register_pass<pass::NopElimination>();
97     pass_manager.run_passes(f);
98
99     ASSERT_EQ(count_ops_of_type<op::v0::Slice>(f), 0);
100 }
101
102 TEST(nop_elimination, eliminate_broadcast) {
103     Shape shape{};
104     auto A = make_shared<op::Parameter>(element::f32, shape);
105     auto b = make_shared<op::v0::Broadcast>(A, shape, AxisSet{});
106     auto f = make_shared<Function>(make_shared<op::v0::Abs>(b), ParameterVector{A});
107
108     pass::Manager pass_manager;
109     pass_manager.register_pass<pass::NopElimination>();
110     pass_manager.run_passes(f);
111
112     ASSERT_EQ(count_ops_of_type<op::v0::Broadcast>(f), 0);
113 }
114
115 TEST(nop_elimination, eliminate_stop_gradient) {
116     Shape shape{};
117     auto A = make_shared<op::Parameter>(element::f32, shape);
118     auto s = make_shared<op::v0::StopGradient>(A);
119     auto f = make_shared<Function>(make_shared<op::v0::Abs>(s), ParameterVector{A});
120
121     pass::Manager pass_manager;
122     pass_manager.register_pass<pass::NopElimination>();
123     pass_manager.run_passes(f);
124
125     ASSERT_EQ(count_ops_of_type<op::v0::StopGradient>(f), 0);
126 }
127
128 TEST(nop_elimination, pass_property) {
129     auto pass = std::make_shared<ngraph::pass::NopElimination>();
130     ASSERT_FALSE(pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
131 }
132
133 TEST(nop_elimination, reshape_elimination_v1) {
134     auto generate_func = [](bool zero) {
135         auto arg = std::make_shared<op::Parameter>(element::i64, PartialShape{8, 16, 2, 3});
136         auto pattern_org = op::Constant::create(element::i64, Shape{3}, vector<int64_t>{8, 16, 6});
137         auto pattern = op::Constant::create(element::i64, Shape{3}, vector<int64_t>{8, 16, 6});
138         auto reshape_v1_org = std::make_shared<op::v1::Reshape>(arg, pattern_org, zero);
139         auto reshape_v1 = std::make_shared<op::v1::Reshape>(reshape_v1_org, pattern, zero);
140         auto abs = std::make_shared<op::v0::Abs>(reshape_v1);
141         return std::make_shared<Function>(NodeVector{abs}, ParameterVector{arg});
142     };
143
144     auto func = generate_func(false);
145     auto nopass_func = generate_func(false);
146     auto func_zero = generate_func(true);
147     auto nopass_func_zero = generate_func(true);
148
149     pass::Manager pass_manager;
150     pass_manager.register_pass<pass::NopElimination>();
151     pass_manager.run_passes(func);
152     pass_manager.run_passes(func_zero);
153     ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(nopass_func) == 2);
154     ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(func) == 1);
155     ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(nopass_func_zero) == 2);
156     ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(func_zero) == 1);
157 }
158
159 TEST(nop_elimination, reshape_elimination_v1_dynamic) {
160     auto arg = std::make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
161     auto pattern = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
162     auto reshape_v1 = std::make_shared<op::v1::Reshape>(arg, pattern, false);
163     auto abs = std::make_shared<op::v0::Abs>(reshape_v1);
164     auto f = std::make_shared<Function>(NodeVector{abs}, ParameterVector{arg, pattern});
165     pass::Manager pass_manager;
166     pass_manager.register_pass<pass::NopElimination>();
167     pass_manager.run_passes(f);
168     ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(f) == 1);
169 }
170
171 TEST(nop_elimination, concat_elimination_single_node) {
172     int64_t a = 0;
173     auto A = make_shared<op::Parameter>(element::f32, Shape{2, 3});
174     auto f =
175         make_shared<Function>(make_shared<op::v0::Concat>(NodeVector{A}, a), ParameterVector{A});
176
177     pass::Manager pass_manager;
178     pass_manager.register_pass<pass::Validate>();
179     pass_manager.register_pass<pass::NopElimination>();
180     pass_manager.run_passes(f);
181
182     ASSERT_EQ(count_ops_of_type<op::v0::Concat>(f), 1);
183 }
184
185 TEST(nop_elimination, concat_elimination_single_input) {
186     int64_t a = 0;
187     auto A = make_shared<op::Parameter>(element::f32, Shape{2, 3});
188     auto B = make_shared<op::v0::Concat>(NodeVector{A}, a);
189     auto f = make_shared<Function>(make_shared<op::v0::Abs>(B), ParameterVector{A});
190
191     pass::Manager pass_manager;
192     pass_manager.register_pass<pass::Validate>();
193     pass_manager.register_pass<pass::NopElimination>();
194     pass_manager.run_passes(f);
195
196     ASSERT_EQ(count_ops_of_type<op::v0::Concat>(f), 0);
197 }
198
199 TEST(nop_elimination, concat_elimination_single_input_dynamic) {
200     int64_t a = 0;
201     auto A = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 3});
202     auto B = make_shared<op::v0::Concat>(NodeVector{A}, a);
203     auto f = make_shared<Function>(make_shared<op::v0::Abs>(B), ParameterVector{A});
204
205     pass::Manager pass_manager;
206     pass_manager.register_pass<pass::Validate>();
207     pass_manager.register_pass<pass::NopElimination>();
208     pass_manager.run_passes(f);
209
210     ASSERT_EQ(count_ops_of_type<op::v0::Concat>(f), 0);
211 }
212
213 TEST(nop_elimination, unsqueeze_elimination) {
214     const auto axis = op::Constant::create<int64_t>(element::i64, {}, {0});
215     const auto A = make_shared<op::Parameter>(
216         element::f32, PartialShape{3, Dimension::dynamic(), Dimension::dynamic()});
217     const auto unsqueeze = make_shared<op::v0::Unsqueeze>(A, axis);
218     auto f = make_shared<Function>(unsqueeze, ParameterVector{A});
219
220     pass::Manager pass_manager;
221     pass_manager.register_pass<pass::Validate>();
222     pass_manager.register_pass<pass::NopElimination>();
223     pass_manager.run_passes(f);
224
225     ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(f), 1);
226 }
227
228 TEST(nop_elimination, squeeze_unsqueeze_overlap_elimination) {
229     auto check_usecase = [](const PartialShape& shape,
230                             const std::vector<int64_t>& sq_axes_val,
231                             const std::vector<int64_t>& unsq_axes_val,
232                             bool sq_to_unsq,
233                             bool i32,
234                             bool multiout,
235                             size_t sc,
236                             size_t usc,
237                             size_t rc) {
238         static size_t id = 0;
239         auto casename = string("usecase #") + to_string(++id);
240
241         shared_ptr<Node> sq_axes;
242         shared_ptr<Node> unsq_axes;
243         if (i32) {
244             std::vector<int32_t> sq_axes_val_i32(sq_axes_val.begin(), sq_axes_val.end());
245             std::vector<int32_t> unsq_axes_val_i32(unsq_axes_val.begin(), unsq_axes_val.end());
246             sq_axes = op::Constant::create<int32_t>(
247                 element::i32, Shape{sq_axes_val.size()}, sq_axes_val_i32);
248             unsq_axes = op::Constant::create<int32_t>(
249                 element::i32, Shape{unsq_axes_val.size()}, unsq_axes_val_i32);
250         } else {
251             sq_axes =
252                 op::Constant::create<int64_t>(element::i64, Shape{sq_axes_val.size()}, sq_axes_val);
253             unsq_axes = op::Constant::create<int64_t>(
254                 element::i64, Shape{unsq_axes_val.size()}, unsq_axes_val);
255         }
256
257         auto A = make_shared<op::Parameter>(element::f32, shape);
258         shared_ptr<Node> A1;
259         if (multiout) {
260             auto last_dim = shape.rank().get_length() - 1;
261             A1 = make_shared<op::v0::TopK>(A, last_dim, element::i32);
262         } else {
263             A1 = make_shared<op::v0::Abs>(A);
264         }
265
266         shared_ptr<Node> B1;
267         if (sq_to_unsq) {
268             auto B = make_shared<op::v0::Squeeze>((multiout ? A1->output(0) : A1), sq_axes);
269             B1 = make_shared<op::v0::Unsqueeze>(B, unsq_axes);
270         } else {
271             auto B = make_shared<op::v0::Unsqueeze>((multiout ? A1->output(0) : A1), unsq_axes);
272             B1 = make_shared<op::v0::Squeeze>(B, sq_axes);
273         }
274
275         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
276         auto optimized_f = clone_function(*baseline_f);
277
278         pass::Manager pass_manager;
279         pass_manager.register_pass<pass::Validate>();
280         pass_manager.register_pass<pass::NopElimination>();
281         pass_manager.run_passes(optimized_f);
282
283         auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
284         auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
285         EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
286         ASSERT_EQ(ps.rank().get_length(), ps_r.rank().get_length()) << casename;
287
288         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(baseline_f), 1) << casename;
289         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(baseline_f), 1) << casename;
290         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(optimized_f), sc) << casename;
291         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(optimized_f), usc) << casename;
292         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(optimized_f), rc) << casename;
293     };
294
295     // static shapes, all squeeze/unsqueeze replaced by reshape
296     check_usecase(PartialShape{2, 1, 1, 6}, {1, 2}, {0}, true, false, false, 0, 0, 1);
297     check_usecase(PartialShape{2, 1, 1, 6}, {1, 2}, {0}, true, true, false, 0, 0, 1);
298     // multioutout ops
299     check_usecase(PartialShape{2, 1, 1, 6}, {1, 2}, {0}, true, false, true, 0, 0, 1);
300     check_usecase(PartialShape{2, 1, 1, 6}, {1, 2}, {0}, true, true, true, 0, 0, 1);
301     check_usecase(PartialShape{1}, {0}, {0, 1, 2, 3}, true, true, true, 0, 0, 1);
302
303     // axes match - expect all squeeze/unsqueeze/reshape cancel out
304     check_usecase(PartialShape{2, 1, 1, 6}, {1, 2}, {1, 2}, true, true, true, 0, 0, 0);
305
306     // dynamic shapes - axes match, expect all cancel
307     check_usecase(PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic(), 1},
308                   {0, 2, 4},
309                   {0, 2, 4},
310                   true,
311                   true,
312                   true,
313                   0,
314                   0,
315                   0);
316     check_usecase(PartialShape{1, Dimension::dynamic(), 1, 2, 1},
317                   {0, 2, 4},
318                   {0, 2, 4},
319                   true,
320                   false,
321                   true,
322                   0,
323                   0,
324                   0);
325
326     // squeeze axes overlap fully
327     check_usecase(
328         PartialShape{Dimension::dynamic(), 1, 1, 3}, {1, 2}, {1, 2, 3}, true, true, true, 0, 0, 1);
329     check_usecase(PartialShape{Dimension::dynamic(), 1, 1, Dimension::dynamic()},
330                   {1, 2},
331                   {1, 2, 3},
332                   true,
333                   true,
334                   true,
335                   0,
336                   1,
337                   0);
338     check_usecase(PartialShape{2, 1, 1, 4}, {1, 2}, {1, 2, 3}, true, true, true, 0, 0, 1);
339     check_usecase(PartialShape{2, 1, 1, Dimension::dynamic(), Dimension::dynamic()},
340                   {1, 2},
341                   {1, 2, 3},
342                   true,
343                   true,
344                   true,
345                   0,
346                   1,
347                   0);
348     check_usecase(PartialShape{1, Dimension::dynamic(), 1, 1, Dimension::dynamic()},
349                   {2, 3},
350                   {2, 3, 5},
351                   true,
352                   true,
353                   true,
354                   0,
355                   1,
356                   0);
357
358     // unsqueeze axes overlap fully
359     check_usecase(PartialShape{1, Dimension::dynamic(), 1, 1, 1, Dimension::dynamic(), 3},
360                   {2, 3},
361                   {2},
362                   true,
363                   true,
364                   true,
365                   1,
366                   0,
367                   0);
368     check_usecase(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 1, 1},
369                   {2, 3},
370                   {2},
371                   true,
372                   true,
373                   true,
374                   1,
375                   0,
376                   0);
377     check_usecase(
378         PartialShape{Dimension::dynamic(), 3, 1, 1}, {2, 3}, {2}, true, true, true, 0, 0, 1);
379     check_usecase(PartialShape{3, 1, 1}, {1, 2}, {1}, true, true, true, 0, 0, 1);
380
381     // squeeze->unsqueeze axes overlap
382     check_usecase(
383         PartialShape{Dimension::dynamic(), 1, 1, 4}, {1, 2}, {0}, true, true, true, 0, 0, 1);
384     check_usecase(PartialShape{Dimension::dynamic(), 1, 1, Dimension::dynamic()},
385                   {1, 2},
386                   {0},
387                   true,
388                   true,
389                   true,
390                   1,
391                   1,
392                   0);
393     check_usecase(PartialShape{3, 1, 1, 4}, {1, 2}, {0}, true, true, true, 0, 0, 1);
394     check_usecase(PartialShape{2, 1, 1, Dimension::dynamic(), Dimension::dynamic()},
395                   {1, 2},
396                   {2},
397                   true,
398                   true,
399                   true,
400                   1,
401                   1,
402                   0);
403     check_usecase(PartialShape{Dimension::dynamic(), 1, 1, 3, Dimension::dynamic(), 4},
404                   {1, 2},
405                   {2},
406                   true,
407                   true,
408                   true,
409                   1,
410                   1,
411                   0);
412     check_usecase(PartialShape{2, 1, Dimension::dynamic(), 1, Dimension::dynamic()},
413                   {1, 3},
414                   {3},
415                   true,
416                   true,
417                   true,
418                   1,
419                   1,
420                   0);
421     check_usecase(PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic(), 1, 1, 4},
422                   {4, 5},
423                   {1, 5},
424                   true,
425                   true,
426                   true,
427                   1,
428                   1,
429                   0);
430
431     //
432     // Unsqueeze->Squeeze cases, testcase 23 - ..
433     //
434     // static shapes, all unsqueeze/squeeze replaced by reshape
435     check_usecase(PartialShape{2, 6, 1}, {4}, {1, 2}, false, false, false, 0, 0, 1);
436     check_usecase(PartialShape{2, 6, 1}, {4}, {1, 2}, false, true, false, 0, 0, 1);
437     // multioutout ops
438     check_usecase(PartialShape{2, 6, 1}, {4}, {1, 2}, false, false, true, 0, 0, 1);
439     check_usecase(PartialShape{2, 6, 1}, {4}, {1, 2}, false, true, true, 0, 0, 1);
440     check_usecase(PartialShape{1}, {0}, {0, 1, 2, 3}, false, true, true, 0, 0, 1);
441     check_usecase(PartialShape{3, 1, 1, 4}, {2, 3}, {0}, false, true, true, 0, 0, 1);
442     // dynamic shapes
443     check_usecase(PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic(), 1},
444                   {0, 2, 4},
445                   {0, 2, 4},
446                   false,
447                   true,
448                   true,
449                   0,
450                   0,
451                   0);
452     check_usecase(PartialShape{Dimension::dynamic(), 1, 1, Dimension::dynamic()},
453                   {2},
454                   {0},
455                   true,
456                   true,
457                   true,
458                   1,
459                   1,
460                   0);
461     check_usecase(PartialShape{Dimension::dynamic(), 1, 1, 4}, {2}, {0}, true, true, true, 0, 0, 1);
462     check_usecase(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 1, 1},
463                   {2, 3},
464                   {2},
465                   true,
466                   true,
467                   true,
468                   1,
469                   0,
470                   0);
471 }
472
473 TEST(nop_elimination, squeeze_squeeze_overlap_elimination) {
474     auto check_usecase = [](const PartialShape& shape,
475                             const std::vector<int64_t>& sq_axes_val_1,
476                             const std::vector<int64_t>& sq_axes_val_2,
477                             size_t sq) {
478         static size_t id = 0;
479         auto casename = string("usecase #") + to_string(++id);
480         auto sq_axes_1 =
481             op::Constant::create<int64_t>(element::i64, Shape{sq_axes_val_1.size()}, sq_axes_val_1);
482         auto sq_axes_2 =
483             op::Constant::create<int64_t>(element::i64, Shape{sq_axes_val_2.size()}, sq_axes_val_2);
484         auto A = make_shared<op::Parameter>(element::f32, shape);
485         auto A1 = make_shared<op::v0::Abs>(A);
486         auto B = make_shared<op::v0::Squeeze>(A1, sq_axes_1);
487         auto B1 = make_shared<op::v0::Squeeze>(B, sq_axes_2);
488         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
489         auto optimized_f = clone_function(*baseline_f);
490
491         pass::Manager pass_manager;
492         pass_manager.register_pass<pass::Validate>();
493         pass_manager.register_pass<pass::NopElimination>();
494         pass_manager.run_passes(optimized_f);
495         auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
496         auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
497         EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
498         ASSERT_EQ(ps.rank().get_length(), ps_r.rank().get_length()) << casename;
499         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(baseline_f), 2) << casename;
500         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(optimized_f), sq) << casename;
501     };
502
503     check_usecase(PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic()}, {0}, {1}, 1);
504     check_usecase(
505         PartialShape{1, 1, 1, Dimension::dynamic(), 1, Dimension::dynamic(), 1}, {2, 1}, {2, 4}, 1);
506     check_usecase(
507         PartialShape{1, Dimension::dynamic(), Dimension::dynamic(), 1, 1}, {-1, -5}, {2}, 1);
508     check_usecase(
509         PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic(), 1}, {0}, {1, 3}, 1);
510 }
511
512 TEST(nop_elimination, unsqueeze_unsqueeze_overlap_elimination) {
513     auto check_usecase = [](const PartialShape& shape,
514                             const std::vector<int64_t>& unsq_axes_val_1,
515                             const std::vector<int64_t>& unsq_axes_val_2,
516                             size_t unsq) {
517         static size_t id = 0;
518         auto casename = string("usecase #") + to_string(++id);
519         auto unsq_axes_1 = op::Constant::create<int64_t>(
520             element::i64, Shape{unsq_axes_val_1.size()}, unsq_axes_val_1);
521         auto unsq_axes_2 = op::Constant::create<int64_t>(
522             element::i64, Shape{unsq_axes_val_2.size()}, unsq_axes_val_2);
523         auto A = make_shared<op::Parameter>(element::f32, shape);
524         auto A1 = make_shared<op::v0::Abs>(A);
525         auto B = make_shared<op::v0::Unsqueeze>(A1, unsq_axes_1);
526         auto B1 = make_shared<op::v0::Unsqueeze>(B, unsq_axes_2);
527         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
528         auto optimized_f = clone_function(*baseline_f);
529
530         pass::Manager pass_manager;
531         pass_manager.register_pass<pass::Validate>();
532         pass_manager.register_pass<pass::NopElimination>();
533         pass_manager.run_passes(optimized_f);
534         auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
535         auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
536         EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
537         ASSERT_EQ(ps.rank().get_length(), ps_r.rank().get_length()) << casename;
538         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(baseline_f), 2) << casename;
539         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(optimized_f), unsq) << casename;
540     };
541
542     check_usecase(PartialShape{Dimension::dynamic(), 1, Dimension::dynamic()}, {0}, {2}, 1);
543     check_usecase(
544         PartialShape{1, Dimension::dynamic(), 1, Dimension::dynamic(), 1}, {2, 1}, {2, 4}, 1);
545     check_usecase(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 1}, {-1, -3}, {2}, 1);
546     check_usecase(PartialShape{Dimension::dynamic(), 1, Dimension::dynamic(), 1}, {0}, {1, 3}, 1);
547 }
548
549 TEST(nop_elimination, unsqueeze_squeeze_elimination) {
550     auto generate_func = [](const Shape& shape, const std::vector<int64_t>& axes_val) {
551         auto axes = op::Constant::create<int64_t>(element::i64, Shape{axes_val.size()}, axes_val);
552         auto A = make_shared<op::Parameter>(element::f32, shape);
553         auto A1 = make_shared<op::v0::Abs>(A);
554         auto B = make_shared<op::v0::Unsqueeze>(A1, axes);
555         auto B1 = make_shared<op::v0::Squeeze>(B, axes);
556         return make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
557     };
558
559     auto check_usecase = [&](const Shape& shape, const std::vector<int64_t>& axes_val) {
560         auto baseline_f = generate_func(shape, axes_val);
561         auto optimized_f = generate_func(shape, axes_val);
562         pass::NopElimination().run_on_function(optimized_f);
563
564         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(baseline_f), 1);
565         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(baseline_f), 1);
566         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(optimized_f), 0);
567         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(optimized_f), 0);
568     };
569
570     check_usecase(Shape{6}, std::vector<int64_t>{0});
571     check_usecase(Shape{3, 2}, std::vector<int64_t>{0, 3});
572     check_usecase(Shape{3, 2}, std::vector<int64_t>{0, 2, 4});
573     check_usecase(Shape{3, 2}, std::vector<int64_t>{-1, -4});
574 }
575
576 TEST(nop_elimination, reshape_unsqueeze_elimination) {
577     auto check_usecase = [](const Shape& shape,
578                             const std::vector<int64_t>& pat_val,
579                             bool zero,
580                             const std::vector<int64_t>& axes_val) {
581         auto axes = op::Constant::create<int64_t>(element::i64, Shape{axes_val.size()}, axes_val);
582         auto pat = op::Constant::create<int64_t>(element::i64, Shape{pat_val.size()}, pat_val);
583         auto A = make_shared<op::Parameter>(element::f32, shape);
584         auto A1 = make_shared<op::v0::Abs>(A);
585
586         auto B = make_shared<op::v1::Reshape>(A1, pat, zero);
587         auto pat2 =
588             op::Constant::create<int64_t>(element::i64, Shape{2}, std::vector<int64_t>{0, -1});
589         auto B1 = make_shared<op::v0::Unsqueeze>(B, axes);
590         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
591         auto optimized_f = clone_function(*baseline_f);
592         pass::NopElimination().run_on_function(optimized_f);
593
594         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 1);
595         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(baseline_f), 1);
596         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(optimized_f), 0);
597     };
598
599     check_usecase(Shape{1, 2, 3, 2, 1}, {2, 3, 2}, false, {2, 4});
600     check_usecase(Shape{12}, {2, 3, 2}, false, {3});
601     check_usecase(Shape{3, 2, 1, 2}, {0, 2, 2}, true, {1, 4});
602     check_usecase(Shape{2, 3, 2}, {2, -1, 2}, false, {2});
603     check_usecase(Shape{2, 3, 2, 1}, {2, 3, 2}, false, {0});
604 }
605 TEST(nop_elimination, reshape_squeeze_elimination) {
606     auto check_usecase = [](const Shape& shape,
607                             const std::vector<int64_t>& pat_val,
608                             bool zero,
609                             const std::vector<int64_t>& axes_val) {
610         auto axes = op::Constant::create<int64_t>(element::i64, Shape{axes_val.size()}, axes_val);
611         auto pat = op::Constant::create<int64_t>(element::i64, Shape{pat_val.size()}, pat_val);
612         auto A = make_shared<op::Parameter>(element::f32, shape);
613         auto A1 = make_shared<op::v0::Abs>(A);
614
615         auto B = make_shared<op::v1::Reshape>(A1, pat, zero);
616         auto pat2 =
617             op::Constant::create<int64_t>(element::i64, Shape{2}, std::vector<int64_t>{0, -1});
618         auto B1 = make_shared<op::v0::Squeeze>(B, axes);
619         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
620         auto optimized_f = clone_function(*baseline_f);
621         pass::NopElimination().run_on_function(optimized_f);
622
623         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 1);
624         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(baseline_f), 1);
625         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(optimized_f), 0);
626     };
627
628     check_usecase(Shape{1, 2, 3, 2, 1}, {2, 3, 1, 2, 1}, false, {2, 4});
629     check_usecase(Shape{12}, {2, 3, 2, 1}, false, {3});
630     check_usecase(Shape{3, 2, 1, 2}, {0, 1, 2, 2, 1}, true, {1, 4});
631     check_usecase(Shape{2, 3, 2}, {2, -1, 1, 2}, false, {2});
632     check_usecase(Shape{2, 3, 2, 1}, {1, 2, 3, 2}, false, {0});
633 }
634
635 TEST(nop_elimination, reshape_reshape_elimination) {
636     auto check_usecase = [](const Shape& shape, const std::vector<int64_t>& pat_val, bool zero) {
637         auto pat = op::Constant::create<int64_t>(element::i64, Shape{pat_val.size()}, pat_val);
638         auto A = make_shared<op::Parameter>(element::f32, shape);
639         auto A1 = make_shared<op::v0::Abs>(A);
640
641         auto B = make_shared<op::v1::Reshape>(A1, pat, zero);
642         auto pat2 =
643             op::Constant::create<int64_t>(element::i64, Shape{2}, std::vector<int64_t>{0, -1});
644         auto B1 = make_shared<op::v1::Reshape>(B, pat2, true);
645         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
646         auto optimized_f = clone_function(*baseline_f);
647         pass::NopElimination().run_on_function(optimized_f);
648
649         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 2);
650         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(optimized_f), 1);
651     };
652
653     check_usecase(Shape{1, 2, 3, 2, 1}, std::vector<int64_t>{2, 3, 2}, false);
654     check_usecase(Shape{12}, std::vector<int64_t>{2, 3, 2}, false);
655     check_usecase(Shape{3, 2, 1, 2}, std::vector<int64_t>{0, 2, 2}, true);
656     check_usecase(Shape{2, 3, 2}, ::vector<int64_t>{2, -1, 2}, false);
657     check_usecase(Shape{2, 3, 2, 1}, ::vector<int64_t>{2, 3, 2}, false);
658 }
659
660 TEST(nop_elimination, squeeze_reshape_elimination) {
661     auto check_usecase = [](const Shape& shape, const std::vector<int64_t>& indices_val) {
662         auto indices =
663             op::Constant::create<int64_t>(element::i64, Shape{indices_val.size()}, indices_val);
664         auto A = make_shared<op::Parameter>(element::f32, shape);
665         auto A1 = make_shared<op::v0::Abs>(A);
666
667         auto B = make_shared<op::v0::Squeeze>(A1, indices);
668         auto pat2 = op::Constant::create<int64_t>(element::i64, Shape{1}, std::vector<int64_t>{-1});
669         auto B1 = make_shared<op::v1::Reshape>(B, pat2, false);
670         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
671         auto optimized_f = clone_function(*baseline_f);
672         pass::NopElimination().run_on_function(optimized_f);
673
674         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 1);
675         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(baseline_f), 1);
676         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(optimized_f), 1);
677         ASSERT_EQ(count_ops_of_type<op::v0::Squeeze>(optimized_f), 0);
678     };
679
680     check_usecase(Shape{1, 2, 3, 2, 1}, std::vector<int64_t>{0, 4});
681     check_usecase(Shape{1, 1}, std::vector<int64_t>{0, 1});
682     check_usecase(Shape{2, 3, 1, 2}, std::vector<int64_t>{2});
683     check_usecase(Shape{1, 6, 2, 1}, std::vector<int64_t>{3});
684 }
685
686 TEST(nop_elimination, unsqueeze_reshape_elimination) {
687     auto check_usecase = [](const Shape& shape, const std::vector<int64_t>& indices_val) {
688         auto indices =
689             op::Constant::create<int64_t>(element::i64, Shape{indices_val.size()}, indices_val);
690         auto A = make_shared<op::Parameter>(element::f32, shape);
691         auto A1 = make_shared<op::v0::Abs>(A);
692
693         auto B = make_shared<op::v0::Unsqueeze>(A1, indices);
694         auto pat2 = op::Constant::create<int64_t>(element::i64, Shape{1}, std::vector<int64_t>{-1});
695         auto B1 = make_shared<op::v1::Reshape>(B, pat2, false);
696         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(B1), ParameterVector{A});
697         auto optimized_f = clone_function(*baseline_f);
698         pass::NopElimination().run_on_function(optimized_f);
699
700         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 1);
701         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(baseline_f), 1);
702         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(optimized_f), 1);
703         ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(optimized_f), 0);
704     };
705
706     check_usecase(Shape{2, 3, 2}, std::vector<int64_t>{0, 4});
707     check_usecase(Shape{}, std::vector<int64_t>{0, 1});
708     check_usecase(Shape{2, 3, 2}, std::vector<int64_t>{2});
709     check_usecase(Shape{1, 6, 2}, std::vector<int64_t>{3});
710 }
711
712 TEST(nop_elimination, squeeze_unsqueeze_elimination_negative) {
713     auto check_usecase = [](const Shape& shape, const std::vector<int64_t>& indices_val) {
714         auto indices = op::Constant::create(element::i64, Shape{indices_val.size()}, indices_val);
715         auto input = make_shared<op::Parameter>(element::f32, shape);
716         auto squeeze = make_shared<ngraph::opset1::Squeeze>(input, indices);
717         auto baseline_f = make_shared<Function>(squeeze, ParameterVector{input});
718         auto optimized_f = clone_function(*baseline_f);
719         pass::NopElimination().run_on_function(optimized_f);
720
721         ASSERT_EQ(count_ops_of_type<ngraph::opset1::Squeeze>(baseline_f), 1);
722         ASSERT_EQ(count_ops_of_type<ngraph::opset1::Squeeze>(optimized_f), 1);
723     };
724
725     check_usecase(Shape{1, 1, 1}, std::vector<int64_t>{0, 1, 2});
726 }
727
728 TEST(nop_elimination, topk_convert_elimination) {
729     auto check_usecase = []() {
730         auto A = make_shared<op::Parameter>(element::f32, Shape{20, 3, 4});
731         auto A1 = make_shared<op::v0::Abs>(A);
732         auto B = make_shared<op::TopK>(A1, 0, element::i64, 10);
733         auto C = make_shared<op::Convert>(B->output(0), B->output(0).get_element_type());
734         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(C), ParameterVector{A});
735         auto optimized_f = clone_function(*baseline_f);
736         pass::NopElimination().run_on_function(optimized_f);
737
738         ASSERT_EQ(count_ops_of_type<op::Convert>(baseline_f), 1);
739         ASSERT_EQ(count_ops_of_type<op::Convert>(optimized_f), 0);
740     };
741
742     check_usecase();
743 }