e7cad6c06a0cc2ae48244897c7703b9dfcba4eee
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / transformations / algebraic_simplification.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6
7 #include "common_test_utils/test_common.hpp"
8 #include <string>
9 #include <sstream>
10 #include <memory>
11 #include <queue>
12
13 #include <ngraph/function.hpp>
14 #include <ngraph/opsets/opset2.hpp>
15 #include <ngraph/opsets/opset3.hpp>
16 #include <ngraph/pass/manager.hpp>
17 #include <ngraph/pass/constant_folding.hpp>
18 #include <transformations/common_optimizations/algebraic_simplification.hpp>
19 #include <transformations/utils/utils.hpp>
20 #include <transformations/init_node_info.hpp>
21
22 #include "common_test_utils/ngraph_test_utils.hpp"
23
24 using namespace ngraph;
25 using namespace std;
26
27 TEST(algebraic_simplification, add_negative_tests) {
28     Shape shape{};
29     auto type = element::f32;
30     pass::Manager pass_manager;
31     pass_manager.register_pass<pass::AlgebraicSimplification>();
32
33     auto a = make_shared<op::Parameter>(type, shape);
34     auto b = make_shared<op::Parameter>(type, shape);
35     auto c = make_shared<op::Parameter>(type, shape);
36     auto abs_a = make_shared<op::Abs>(a);
37     auto iconst2 = ngraph::make_constant_from_string("2", type, shape);
38     auto add_a_0 = a + iconst2;
39     auto add_a_0_0 = add_a_0 + iconst2;
40     auto add_b_0 = b + abs_a;
41     auto add_b_0_0 = add_b_0 + abs_a;
42
43     auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
44                                         ParameterVector{a, b, c});
45     pass_manager.run_passes(f);
46
47     auto expected = ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0};
48     auto results = f->get_results();
49     for (size_t i = 0; i < results.size(); i++) {
50         ASSERT_EQ(expected.at(i), results.at(i)->input_value(0).get_node_shared_ptr());
51     }
52 }
53
54 TEST(algebraic_simplification, multiply_negative_tests) {
55     Shape shape{};
56     auto type = element::f32;
57     pass::Manager pass_manager;
58     pass_manager.register_pass<pass::AlgebraicSimplification>();
59
60     auto a = make_shared<op::Parameter>(type, shape);
61     auto b = make_shared<op::Parameter>(type, shape);
62     auto c = make_shared<op::Parameter>(type, shape);
63     auto abs_a = make_shared<op::Abs>(a);
64     auto iconst2 = ngraph::make_constant_from_string("2", type, shape);
65     auto add_a_0 = a * iconst2;
66     auto add_a_0_0 = add_a_0 * iconst2;
67     auto add_b_0 = b * abs_a;
68     auto add_b_0_0 = add_b_0 * abs_a;
69
70     auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
71                                         ParameterVector{a, b, c});
72     pass_manager.run_passes(f);
73
74     auto expected = ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0};
75     auto results = f->get_results();
76     for (size_t i = 0; i < results.size(); i++) {
77         ASSERT_EQ(expected.at(i), results.at(i)->input_value(0).get_node_shared_ptr());
78     }
79 }
80
81 TEST(algebraic_simplification, multiply_prod_negative) {
82     auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{2}, {1.0, 1.0});
83     auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{2, 5}, AxisSet{1});
84     auto prod_fconst1 = std::make_shared<op::Product>(broadcast, AxisSet{0, 1});
85
86     pass::Manager pass_manager;
87     pass_manager.register_pass<pass::AlgebraicSimplification>();
88
89     auto f = std::make_shared<Function>(ngraph::NodeVector{prod_fconst1}, ParameterVector{});
90     pass_manager.run_passes(f);
91     auto f_prod = f->get_results().at(0)->input_value(0).get_node_shared_ptr();
92     ASSERT_EQ(f_prod, prod_fconst1);
93 }
94
95 TEST(algebraic_simplification, multiply_sum_negative) {
96     auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{2}, {1.0, 1.0});
97     auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{2, 5}, AxisSet{1});
98     auto sum_fconst1 = std::make_shared<op::Sum>(broadcast, AxisSet{0, 1});
99
100     pass::Manager pass_manager;
101     pass_manager.register_pass<pass::AlgebraicSimplification>();
102
103     auto f = std::make_shared<Function>(ngraph::NodeVector{sum_fconst1}, ParameterVector{});
104     pass_manager.run_passes(f);
105     auto f_sum = f->get_results().at(0)->input_value(0).get_node_shared_ptr();
106     ASSERT_EQ(f_sum, sum_fconst1);
107 }
108
109 TEST(algebraic_simplification, concat_parameter_slices_reversed) {
110     auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
111     auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
112     auto slice2 = make_shared<op::Slice>(a, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
113     auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
114
115     size_t concat_axis = 0;
116     auto concat = make_shared<op::Concat>(NodeVector{slice3, slice2, slice1}, concat_axis);
117
118     pass::Manager pass_manager;
119     pass_manager.register_pass<pass::AlgebraicSimplification>();
120
121     auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
122     pass_manager.run_passes(f);
123     ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), concat);
124 }
125
126 TEST(algebraic_simplification, concat_parameter_slices_element_count) {
127     auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
128     // slicing 30 elements out of 96; should trigger a check that some elements are missing
129     auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{10, 100}, Strides{1, 1});
130     auto slice2 = make_shared<op::Slice>(a, Coordinate{10, 0}, Coordinate{20, 100}, Strides{1, 1});
131     auto slice3 = make_shared<op::Slice>(a, Coordinate{20, 0}, Coordinate{30, 100}, Strides{1, 1});
132
133     size_t concat_axis = 0;
134     auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
135
136     pass::Manager pass_manager;
137     pass_manager.register_pass<pass::AlgebraicSimplification>();
138
139     auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
140     pass_manager.run_passes(f);
141     ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), concat);
142 }
143
144 TEST(algebraic_simplification, concat_parameter_non_uniform_slices) {
145     auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
146     auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{38, 100}, Strides{1, 1});
147     auto slice2 = make_shared<op::Slice>(a, Coordinate{38, 0}, Coordinate{64, 100}, Strides{1, 1});
148     auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
149
150     size_t concat_axis = 0;
151     auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
152
153     pass::Manager pass_manager;
154     pass_manager.register_pass<pass::AlgebraicSimplification>();
155
156     auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
157     pass_manager.run_passes(f);
158     ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), concat);
159 }
160
161 TEST(algebraic_simplification, concat_different_inputs) {
162     auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
163     auto goe1 = -a;
164     auto goe2 = -a;
165     auto slice1 =
166         make_shared<op::Slice>(goe1, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
167     auto slice2 =
168         make_shared<op::Slice>(goe2, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
169     auto slice3 =
170         make_shared<op::Slice>(goe1, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
171
172     size_t concat_axis = 0;
173     auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
174
175     pass::Manager pass_manager;
176     pass_manager.register_pass<pass::AlgebraicSimplification>();
177
178     auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
179     pass_manager.run_passes(f);
180     ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), concat);
181 }
182
183 TEST(algebraic_simplification, log_no_exp) {
184     auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
185     auto b = make_shared<op::Parameter>(element::f32, Shape{96, 100});
186     auto abs_a = make_shared<op::Abs>(a);
187     auto div = abs_a / b;
188     auto log_div = make_shared<op::Log>(div);
189
190     auto neg_inner = make_shared<op::Negative>(log_div);
191     auto neg2 = make_shared<op::Negative>(neg_inner);
192     auto neg3 = make_shared<op::Negative>(neg2);
193     auto neg4 = make_shared<op::Negative>(neg3);
194
195     pass::Manager pass_manager;
196     pass_manager.register_pass<pass::AlgebraicSimplification>();
197
198     auto f = std::make_shared<Function>(ngraph::NodeVector{neg4}, ParameterVector{a, b});
199     pass_manager.run_passes(f);
200     ASSERT_EQ(neg_inner->input_value(0).get_node_shared_ptr(), log_div);
201 }
202
203 TEST(algebraic_simplification, log_no_divide) {
204     auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
205     auto b = make_shared<op::Parameter>(element::f32, Shape{96, 100});
206     auto exp_a = make_shared<op::Exp>(a);
207     auto mul = exp_a * b;
208     auto log_mul = make_shared<op::Log>(mul);
209
210     auto neg_inner = make_shared<op::Negative>(log_mul);
211     auto neg2 = make_shared<op::Negative>(neg_inner);
212     auto neg3 = make_shared<op::Negative>(neg2);
213     auto neg4 = make_shared<op::Negative>(neg3);
214
215     pass::Manager pass_manager;
216     pass_manager.register_pass<pass::AlgebraicSimplification>();
217
218     auto f = std::make_shared<Function>(ngraph::NodeVector{neg4}, ParameterVector{a, b});
219     pass_manager.run_passes(f);
220     ASSERT_EQ(neg_inner->input_value(0).get_node_shared_ptr(), log_mul);
221 }
222
223 TEST(algebraic_simplification, pass_property) {
224     auto pass = std::make_shared<ngraph::pass::AlgebraicSimplification>();
225
226     ASSERT_FALSE(pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
227 }
228
229 TEST(algebraic_simplification, replace_transpose_with_reshape) {
230     auto check_usecase = [](const PartialShape& shape,
231                             const std::vector<int64_t>& perm_val,
232                             bool i32,
233                             bool multiout,
234                             size_t num) {
235         static size_t id = 0;
236         auto casename = string("usecase #") + to_string(++id);
237
238         shared_ptr<Node> perm;
239         if (i32) {
240             std::vector<int32_t> perm_val_i32(perm_val.begin(), perm_val.end());
241             perm =
242                 op::Constant::create<int32_t>(element::i32, Shape{perm_val.size()}, perm_val_i32);
243         } else {
244             perm = op::Constant::create<int64_t>(element::i64, Shape{perm_val.size()}, perm_val);
245         }
246         auto param = make_shared<op::Parameter>(element::f32, shape);
247         shared_ptr<Node> A1;
248         if (multiout) {
249             auto last_dim = shape.rank().get_length() - 1;
250             A1 = make_shared<op::v0::TopK>(param, last_dim, element::i32);
251         } else {
252             A1 = make_shared<op::v0::Abs>(param);
253         }
254         auto transpose = make_shared<op::v1::Transpose>((multiout ? A1->output(0) : A1), perm);
255         auto transpose1 = make_shared<op::v0::Abs>(transpose);
256         auto baseline_f = make_shared<Function>(transpose1, ParameterVector{param});
257         auto optimized_f = clone_function(*baseline_f);
258
259         pass::Manager pass_manager;
260         pass_manager.register_pass<pass::Validate>();
261         pass_manager.register_pass<pass::AlgebraicSimplification>();
262         pass_manager.register_pass<pass::ConstantFolding>();
263         pass_manager.run_passes(optimized_f);
264
265         auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
266         auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
267         EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
268         ASSERT_EQ(ps.rank().get_length(), ps_r.rank().get_length()) << casename;
269
270         ASSERT_EQ(count_ops_of_type<op::v1::Transpose>(baseline_f), 1);
271         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 0);
272         ASSERT_EQ(count_ops_of_type<op::v1::Transpose>(optimized_f), num);
273         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(optimized_f), (num ? 0 : 1));
274     };
275
276     for (auto& i32 : {true, false})
277         for (auto& multiout : {true, false}) {
278             check_usecase(Shape{1, 3}, vector<int64_t>{1, 0}, i32, multiout, 0);
279             check_usecase(Shape{2, 3, 1}, vector<int64_t>{2, 0, 1}, i32, multiout, 0);
280             check_usecase(Shape{10, 20, 1, 1}, vector<int64_t>{0, 2, 3, 1}, i32, multiout, 0);
281             check_usecase(Shape{10, 1, 1, 20}, vector<int64_t>{0, 3, 1, 2}, i32, multiout, 0);
282             check_usecase(Shape{10, 20, 1, 2}, vector<int64_t>{0, 2, 1, 3}, i32, multiout, 0);
283             check_usecase(Shape{10, 1, 1, 1, 20}, vector<int64_t>{0, 4, 1, 2, 3}, i32, multiout, 0);
284             check_usecase(Shape{10, 20, 1, 1, 1}, vector<int64_t>{0, 2, 3, 4, 1}, i32, multiout, 0);
285             check_usecase(Shape{10, 1, 1, 1, 1}, vector<int64_t>{1, 4, 2, 3, 0}, i32, multiout, 0);
286             check_usecase(Shape{10, 1, 1, 1, 1}, vector<int64_t>{4, 2, 0, 1, 3}, i32, multiout, 0);
287             check_usecase(Shape{10, 20, 1, 2}, vector<int64_t>{0, 2, 3, 1}, i32, multiout, 1);
288             check_usecase(Shape{10, 20, 1, 2}, vector<int64_t>{0, 3, 1, 2}, i32, multiout, 1);
289             check_usecase(Shape{10, 20}, vector<int64_t>{1, 0}, i32, multiout, 1);
290
291             check_usecase(PartialShape{Dimension::dynamic(), 20, 1, 1},
292                           vector<int64_t>{
293                               0, 2, 3, 1,
294                           },
295                           i32,
296                           multiout,
297                           0);
298             check_usecase(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 20, 1, 1},
299                           vector<int64_t>{0, 1, 3, 2, 4},
300                           i32,
301                           multiout,
302                           0);
303             check_usecase(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 20, 1, 1},
304                           vector<int64_t>{0, 2, 1, 4, 3},
305                           i32,
306                           multiout,
307                           1);
308         }
309 }
310
// The following Gather test checks in which cases a Gather is a no-op, and is therefore
// expected to be removed by the `simplify_gather` step of the AlgebraicSimplification pass.
314
315 TEST(algebraic_simplification, gather_3d_indices_constant_axis_1) {
316     auto check_usecase = [](const PartialShape& pshape,
317                             bool i32,
318                             bool multiout,
319                             const std::vector<int64_t>& indices_val,
320                             int64_t axis_val,
321                             size_t num) {
322         static size_t id = 0;
323         auto casename = string("usecase #") + to_string(++id);
324
325         shared_ptr<Node> indices;
326         shared_ptr<Node> axis;
327         if (i32) {
328             std::vector<int32_t> indices_val_i32(indices_val.begin(), indices_val.end());
329             indices = op::Constant::create<int32_t>(
330                 element::i32, Shape{indices_val.size()}, indices_val_i32);
331             axis = op::Constant::create<int32_t>(element::i32, Shape{}, {(int32_t)axis_val});
332         } else {
333             indices =
334                 op::Constant::create<int64_t>(element::i64, Shape{indices_val.size()}, indices_val);
335             axis = op::Constant::create<int64_t>(element::i64, Shape{}, {axis_val});
336         }
337
338         auto A = make_shared<op::Parameter>(element::f32, pshape);
339         shared_ptr<Node> A1;
340         if (multiout) {
341             auto last_dim = pshape.rank().get_length() - 1;
342             A1 = make_shared<op::v0::TopK>(A, last_dim, element::i32);
343         } else {
344             A1 = make_shared<op::v0::Abs>(A);
345         }
346         auto G = make_shared<op::v1::Gather>((multiout ? A1->output(0) : A1), indices, axis);
347
348         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(G), ParameterVector{A});
349         auto optimized_f = clone_function(*baseline_f);
350
351         pass::Manager pass_manager;
352         pass_manager.register_pass<pass::Validate>();
353         pass_manager.register_pass<pass::AlgebraicSimplification>();
354         pass_manager.run_passes(optimized_f);
355
356         auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
357         auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
358         EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
359         ASSERT_EQ(ps.rank().get_length(), ps_r.rank().get_length()) << casename;
360
361         ASSERT_EQ(count_ops_of_type<op::v1::Gather>(baseline_f), 1) << casename;
362         // the pass should short cut the Gather i/p with the gather users
363         // since we are fetching the whole tensor using gather op
364         ASSERT_EQ(count_ops_of_type<op::v1::Gather>(optimized_f), num) << casename;
365     };
366     for (auto& i32 : {true, false})
367         for (auto& multiout : {true, false}) {
368             check_usecase(PartialShape{1, 3, 2}, i32, multiout, std::vector<int64_t>{1}, 0, 0);
369             check_usecase(PartialShape{3, 2, 1}, i32, multiout, std::vector<int64_t>{0, 1}, 1, 0);
370             check_usecase(PartialShape{3, 2, 1}, i32, multiout, std::vector<int64_t>{1}, 2, 0);
371             check_usecase(PartialShape{1, 16}, i32, multiout, std::vector<int64_t>{0, 0}, 0, 1);
372         }
373 }
374
375 TEST(algebraic_simplification, gather_shapeof) {
376     auto check_usecase = [](const PartialShape& pshape,
377                             bool is_scalar_index,
378                             bool opset2,
379                             bool i32,
380                             bool multiout,
381                             bool multiout_1,
382                             const std::vector<int64_t>& indices_val,
383                             int64_t axis_val) {
384         static size_t id = 0;
385         auto casename = string("usecase #") + to_string(++id);
386
387         shared_ptr<Node> indices;
388         shared_ptr<Node> axis;
389         if (i32) {
390             std::vector<int32_t> indices_val_i32(indices_val.begin(), indices_val.end());
391             indices = is_scalar_index
392                           ? op::Constant::create<int32_t>(element::i32, Shape{}, indices_val_i32)
393                           : op::Constant::create<int32_t>(
394                                 element::i32, Shape{indices_val.size()}, indices_val_i32);
395             axis = op::Constant::create<int32_t>(element::i32, Shape{}, {(int32_t)axis_val});
396         } else {
397             indices = is_scalar_index
398                           ? op::Constant::create<int64_t>(element::i64, Shape{}, indices_val)
399                           : op::Constant::create<int64_t>(
400                                 element::i64, Shape{indices_val.size()}, indices_val);
401             axis = op::Constant::create<int64_t>(element::i64, Shape{}, {axis_val});
402         }
403
404         auto dims_1 = std::vector<Dimension>(pshape);
405         dims_1.push_back(11);
406         dims_1.push_back(13);
407         auto pshape_1 = PartialShape(dims_1);
408         auto A = make_shared<op::Parameter>(element::f32, pshape);
409         auto AA = make_shared<op::Parameter>(element::f64, pshape_1);
410         shared_ptr<Node> A1;
411         if (multiout) {
412             A1 = make_shared<TestOpMultiOut>(A, AA);
413         } else {
414             A1 = make_shared<op::v0::Abs>(A);
415         }
416         auto B = make_shared<op::v1::Gather>(
417             (multiout ? (multiout_1 ? A1->output(1) : A1->output(0)) : A1), indices, axis);
418         shared_ptr<Node> B1;
419         if (opset2) {
420             B1 = make_shared<op::v0::ShapeOf>(B);
421         } else {
422             B1 = make_shared<op::v3::ShapeOf>(B);
423         }
424         auto baseline_f = make_shared<Function>(
425             make_shared<op::v0::Abs>(B1), (multiout ? ParameterVector{A, AA} : ParameterVector{A}));
426         auto optimized_f = clone_function(*baseline_f);
427
428         pass::Manager pass_manager;
429         pass_manager.register_pass<pass::Validate>();
430         pass_manager.register_pass<pass::AlgebraicSimplification>();
431         pass_manager.run_passes(optimized_f);
432
433         ASSERT_EQ(baseline_f->get_results().at(0)->get_element_type(),
434                   optimized_f->get_results().at(0)->get_element_type());
435
436         auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
437         auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
438         EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
439         EXPECT_TRUE(ps.same_scheme(ps_r)) << casename;
440
441         ASSERT_EQ(count_ops_of_type<op::v1::Gather>(baseline_f), 1) << casename;
442
443         auto last_node = optimized_f->get_results()[0]->input_value(0).get_node_shared_ptr();
444         if (is_scalar_index) {
445             ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(optimized_f), 1) << casename;
446             ASSERT_EQ(count_ops_of_type<op::v1::Gather>(optimized_f), 1) << casename;
447             EXPECT_TRUE(
448                 as_type_ptr<op::v1::Gather>(last_node->input_value(0).get_node_shared_ptr()))
449                 << casename;
450         } else {
451             ASSERT_EQ(count_ops_of_type<op::v0::Concat>(optimized_f), 1) << casename;
452             EXPECT_TRUE(
453                 as_type_ptr<op::v0::Concat>(last_node->input_value(0).get_node_shared_ptr()))
454                 << casename;
455         }
456     };
457
458     for (auto& opset2 : {true, false})
459         for (auto& i32 : {true, false})
460             for (auto& multiout : {true, false})
461                 for (auto& multiout_1 : {true, false}) {
462                     check_usecase(PartialShape{2, 3, 2, 1},
463                                   true,
464                                   opset2,
465                                   i32,
466                                   multiout,
467                                   multiout_1,
468                                   std::vector<int64_t>{0},
469                                   3);
470                     check_usecase(PartialShape{2, Dimension::dynamic(), 2, 1},
471                                   true,
472                                   opset2,
473                                   i32,
474                                   multiout,
475                                   multiout_1,
476                                   std::vector<int64_t>{0},
477                                   3);
478                 }
479     for (auto& opset2 : {true, false})
480         for (auto& i32 : {true, false})
481             for (auto& multiout : {true, false})
482                 for (auto& multiout_1 : {true, false}) {
483                     check_usecase(PartialShape{2, 3, 2, 1},
484                                   false,
485                                   opset2,
486                                   i32,
487                                   multiout,
488                                   multiout_1,
489                                   std::vector<int64_t>{0, 2},
490                                   1);
491                     check_usecase(PartialShape{2, Dimension::dynamic(), 2, 1},
492                                   false,
493                                   opset2,
494                                   i32,
495                                   multiout,
496                                   multiout_1,
497                                   std::vector<int64_t>{0, 2},
498                                   1);
499                 }
500 }