Remove obsoleted v0::Broadcast and BroadcastLike operators (#2779)
[platform/upstream/dldt.git] / inference-engine / tests / functional / inference_engine / transformations / algebraic_simplification.cpp
1 // Copyright (C) 2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6
7 #include "common_test_utils/test_common.hpp"
8 #include <string>
9 #include <sstream>
10 #include <memory>
11 #include <queue>
12
13 #include <ngraph/function.hpp>
14 #include <ngraph/opsets/opset2.hpp>
15 #include <ngraph/opsets/opset3.hpp>
16 #include <ngraph/pass/manager.hpp>
17 #include <ngraph/pass/constant_folding.hpp>
18 #include <ngraph/builder/autobroadcast.hpp>
19 #include <transformations/common_optimizations/algebraic_simplification.hpp>
20 #include <transformations/utils/utils.hpp>
21 #include <transformations/init_node_info.hpp>
22
23 #include "common_test_utils/ngraph_test_utils.hpp"
24
25 using namespace ngraph;
26 using namespace std;
27
28 TEST(algebraic_simplification, add_negative_tests) {
29     Shape shape{};
30     auto type = element::f32;
31     pass::Manager pass_manager;
32     pass_manager.register_pass<pass::AlgebraicSimplification>();
33
34     auto a = make_shared<op::Parameter>(type, shape);
35     auto b = make_shared<op::Parameter>(type, shape);
36     auto c = make_shared<op::Parameter>(type, shape);
37     auto abs_a = make_shared<op::Abs>(a);
38     auto iconst2 = ngraph::make_constant_from_string("2", type, shape);
39     auto add_a_0 = a + iconst2;
40     auto add_a_0_0 = add_a_0 + iconst2;
41     auto add_b_0 = b + abs_a;
42     auto add_b_0_0 = add_b_0 + abs_a;
43
44     auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
45                                         ParameterVector{a, b, c});
46     pass_manager.run_passes(f);
47
48     auto expected = ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0};
49     auto results = f->get_results();
50     for (size_t i = 0; i < results.size(); i++) {
51         ASSERT_EQ(expected.at(i), results.at(i)->input_value(0).get_node_shared_ptr());
52     }
53 }
54
55 TEST(algebraic_simplification, multiply_negative_tests) {
56     Shape shape{};
57     auto type = element::f32;
58     pass::Manager pass_manager;
59     pass_manager.register_pass<pass::AlgebraicSimplification>();
60
61     auto a = make_shared<op::Parameter>(type, shape);
62     auto b = make_shared<op::Parameter>(type, shape);
63     auto c = make_shared<op::Parameter>(type, shape);
64     auto abs_a = make_shared<op::Abs>(a);
65     auto iconst2 = ngraph::make_constant_from_string("2", type, shape);
66     auto add_a_0 = a * iconst2;
67     auto add_a_0_0 = add_a_0 * iconst2;
68     auto add_b_0 = b * abs_a;
69     auto add_b_0_0 = add_b_0 * abs_a;
70
71     auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
72                                         ParameterVector{a, b, c});
73     pass_manager.run_passes(f);
74
75     auto expected = ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0};
76     auto results = f->get_results();
77     for (size_t i = 0; i < results.size(); i++) {
78         ASSERT_EQ(expected.at(i), results.at(i)->input_value(0).get_node_shared_ptr());
79     }
80 }
81
82 TEST(algebraic_simplification, multiply_prod_negative) {
83     auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{2}, {1.0, 1.0});
84     auto broadcast = builder::opset1::make_broadcast(fconst1, Shape{2, 5}, AxisSet{1});
85     auto prod_fconst1 = std::make_shared<op::Product>(broadcast, AxisSet{0, 1});
86
87     pass::Manager pass_manager;
88     pass_manager.register_pass<pass::AlgebraicSimplification>();
89
90     auto f = std::make_shared<Function>(ngraph::NodeVector{prod_fconst1}, ParameterVector{});
91     pass_manager.run_passes(f);
92     auto f_prod = f->get_results().at(0)->input_value(0).get_node_shared_ptr();
93     ASSERT_EQ(f_prod, prod_fconst1);
94 }
95
96 TEST(algebraic_simplification, multiply_sum_negative) {
97     auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{2}, {1.0, 1.0});
98     auto broadcast = builder::opset1::make_broadcast(fconst1, Shape{2, 5}, AxisSet{1});
99     auto sum_fconst1 = std::make_shared<op::Sum>(broadcast, AxisSet{0, 1});
100
101     pass::Manager pass_manager;
102     pass_manager.register_pass<pass::AlgebraicSimplification>();
103
104     auto f = std::make_shared<Function>(ngraph::NodeVector{sum_fconst1}, ParameterVector{});
105     pass_manager.run_passes(f);
106     auto f_sum = f->get_results().at(0)->input_value(0).get_node_shared_ptr();
107     ASSERT_EQ(f_sum, sum_fconst1);
108 }
109
110 TEST(algebraic_simplification, concat_parameter_slices_reversed) {
111     auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
112     auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
113     auto slice2 = make_shared<op::Slice>(a, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
114     auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
115
116     size_t concat_axis = 0;
117     auto concat = make_shared<op::Concat>(NodeVector{slice3, slice2, slice1}, concat_axis);
118
119     pass::Manager pass_manager;
120     pass_manager.register_pass<pass::AlgebraicSimplification>();
121
122     auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
123     pass_manager.run_passes(f);
124     ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), concat);
125 }
126
127 TEST(algebraic_simplification, concat_parameter_slices_element_count) {
128     auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
129     // slicing 30 elements out of 96; should trigger a check that some elements are missing
130     auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{10, 100}, Strides{1, 1});
131     auto slice2 = make_shared<op::Slice>(a, Coordinate{10, 0}, Coordinate{20, 100}, Strides{1, 1});
132     auto slice3 = make_shared<op::Slice>(a, Coordinate{20, 0}, Coordinate{30, 100}, Strides{1, 1});
133
134     size_t concat_axis = 0;
135     auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
136
137     pass::Manager pass_manager;
138     pass_manager.register_pass<pass::AlgebraicSimplification>();
139
140     auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
141     pass_manager.run_passes(f);
142     ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), concat);
143 }
144
145 TEST(algebraic_simplification, concat_parameter_non_uniform_slices) {
146     auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
147     auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{38, 100}, Strides{1, 1});
148     auto slice2 = make_shared<op::Slice>(a, Coordinate{38, 0}, Coordinate{64, 100}, Strides{1, 1});
149     auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
150
151     size_t concat_axis = 0;
152     auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
153
154     pass::Manager pass_manager;
155     pass_manager.register_pass<pass::AlgebraicSimplification>();
156
157     auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
158     pass_manager.run_passes(f);
159     ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), concat);
160 }
161
162 TEST(algebraic_simplification, concat_different_inputs) {
163     auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
164     auto goe1 = -a;
165     auto goe2 = -a;
166     auto slice1 =
167         make_shared<op::Slice>(goe1, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
168     auto slice2 =
169         make_shared<op::Slice>(goe2, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
170     auto slice3 =
171         make_shared<op::Slice>(goe1, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
172
173     size_t concat_axis = 0;
174     auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
175
176     pass::Manager pass_manager;
177     pass_manager.register_pass<pass::AlgebraicSimplification>();
178
179     auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
180     pass_manager.run_passes(f);
181     ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), concat);
182 }
183
184 TEST(algebraic_simplification, log_no_exp) {
185     auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
186     auto b = make_shared<op::Parameter>(element::f32, Shape{96, 100});
187     auto abs_a = make_shared<op::Abs>(a);
188     auto div = abs_a / b;
189     auto log_div = make_shared<op::Log>(div);
190
191     auto neg_inner = make_shared<op::Negative>(log_div);
192     auto neg2 = make_shared<op::Negative>(neg_inner);
193     auto neg3 = make_shared<op::Negative>(neg2);
194     auto neg4 = make_shared<op::Negative>(neg3);
195
196     pass::Manager pass_manager;
197     pass_manager.register_pass<pass::AlgebraicSimplification>();
198
199     auto f = std::make_shared<Function>(ngraph::NodeVector{neg4}, ParameterVector{a, b});
200     pass_manager.run_passes(f);
201     ASSERT_EQ(neg_inner->input_value(0).get_node_shared_ptr(), log_div);
202 }
203
204 TEST(algebraic_simplification, log_no_divide) {
205     auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
206     auto b = make_shared<op::Parameter>(element::f32, Shape{96, 100});
207     auto exp_a = make_shared<op::Exp>(a);
208     auto mul = exp_a * b;
209     auto log_mul = make_shared<op::Log>(mul);
210
211     auto neg_inner = make_shared<op::Negative>(log_mul);
212     auto neg2 = make_shared<op::Negative>(neg_inner);
213     auto neg3 = make_shared<op::Negative>(neg2);
214     auto neg4 = make_shared<op::Negative>(neg3);
215
216     pass::Manager pass_manager;
217     pass_manager.register_pass<pass::AlgebraicSimplification>();
218
219     auto f = std::make_shared<Function>(ngraph::NodeVector{neg4}, ParameterVector{a, b});
220     pass_manager.run_passes(f);
221     ASSERT_EQ(neg_inner->input_value(0).get_node_shared_ptr(), log_mul);
222 }
223
224 TEST(algebraic_simplification, pass_property) {
225     auto pass = std::make_shared<ngraph::pass::AlgebraicSimplification>();
226
227     ASSERT_FALSE(pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
228 }
229
230 TEST(algebraic_simplification, replace_transpose_with_reshape) {
231     auto check_usecase = [](const PartialShape& shape,
232                             const std::vector<int64_t>& perm_val,
233                             bool i32,
234                             bool multiout,
235                             size_t num) {
236         static size_t id = 0;
237         auto casename = string("usecase #") + to_string(++id);
238
239         shared_ptr<Node> perm;
240         if (i32) {
241             std::vector<int32_t> perm_val_i32(perm_val.begin(), perm_val.end());
242             perm =
243                 op::Constant::create<int32_t>(element::i32, Shape{perm_val.size()}, perm_val_i32);
244         } else {
245             perm = op::Constant::create<int64_t>(element::i64, Shape{perm_val.size()}, perm_val);
246         }
247         auto param = make_shared<op::Parameter>(element::f32, shape);
248         shared_ptr<Node> A1;
249         if (multiout) {
250             auto last_dim = shape.rank().get_length() - 1;
251             A1 = make_shared<op::v0::TopK>(param, last_dim, element::i32);
252         } else {
253             A1 = make_shared<op::v0::Abs>(param);
254         }
255         auto transpose = make_shared<op::v1::Transpose>((multiout ? A1->output(0) : A1), perm);
256         auto transpose1 = make_shared<op::v0::Abs>(transpose);
257         auto baseline_f = make_shared<Function>(transpose1, ParameterVector{param});
258         auto optimized_f = clone_function(*baseline_f);
259
260         pass::Manager pass_manager;
261         pass_manager.register_pass<pass::Validate>();
262         pass_manager.register_pass<pass::AlgebraicSimplification>();
263         pass_manager.register_pass<pass::ConstantFolding>();
264         pass_manager.run_passes(optimized_f);
265
266         auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
267         auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
268         EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
269         ASSERT_EQ(ps.rank().get_length(), ps_r.rank().get_length()) << casename;
270
271         ASSERT_EQ(count_ops_of_type<op::v1::Transpose>(baseline_f), 1);
272         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(baseline_f), 0);
273         ASSERT_EQ(count_ops_of_type<op::v1::Transpose>(optimized_f), num);
274         ASSERT_EQ(count_ops_of_type<op::v1::Reshape>(optimized_f), (num ? 0 : 1));
275     };
276
277     for (auto& i32 : {true, false})
278         for (auto& multiout : {true, false}) {
279             check_usecase(Shape{1, 3}, vector<int64_t>{1, 0}, i32, multiout, 0);
280             check_usecase(Shape{2, 3, 1}, vector<int64_t>{2, 0, 1}, i32, multiout, 0);
281             check_usecase(Shape{10, 20, 1, 1}, vector<int64_t>{0, 2, 3, 1}, i32, multiout, 0);
282             check_usecase(Shape{10, 1, 1, 20}, vector<int64_t>{0, 3, 1, 2}, i32, multiout, 0);
283             check_usecase(Shape{10, 20, 1, 2}, vector<int64_t>{0, 2, 1, 3}, i32, multiout, 0);
284             check_usecase(Shape{10, 1, 1, 1, 20}, vector<int64_t>{0, 4, 1, 2, 3}, i32, multiout, 0);
285             check_usecase(Shape{10, 20, 1, 1, 1}, vector<int64_t>{0, 2, 3, 4, 1}, i32, multiout, 0);
286             check_usecase(Shape{10, 1, 1, 1, 1}, vector<int64_t>{1, 4, 2, 3, 0}, i32, multiout, 0);
287             check_usecase(Shape{10, 1, 1, 1, 1}, vector<int64_t>{4, 2, 0, 1, 3}, i32, multiout, 0);
288             check_usecase(Shape{10, 20, 1, 2}, vector<int64_t>{0, 2, 3, 1}, i32, multiout, 1);
289             check_usecase(Shape{10, 20, 1, 2}, vector<int64_t>{0, 3, 1, 2}, i32, multiout, 1);
290             check_usecase(Shape{10, 20}, vector<int64_t>{1, 0}, i32, multiout, 1);
291
292             check_usecase(PartialShape{Dimension::dynamic(), 20, 1, 1},
293                           vector<int64_t>{
294                               0, 2, 3, 1,
295                           },
296                           i32,
297                           multiout,
298                           0);
299             check_usecase(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 20, 1, 1},
300                           vector<int64_t>{0, 1, 3, 2, 4},
301                           i32,
302                           multiout,
303                           0);
304             check_usecase(PartialShape{Dimension::dynamic(), Dimension::dynamic(), 20, 1, 1},
305                           vector<int64_t>{0, 2, 1, 4, 3},
306                           i32,
307                           multiout,
308                           1);
309         }
310 }
311
312 // the following gather test will be used to test when
313 // gather is Nop and will be removed during `simplify_gather`
314 // algebraic_simplification pass
315
316 TEST(algebraic_simplification, gather_3d_indices_constant_axis_1) {
317     auto check_usecase = [](const PartialShape& pshape,
318                             bool i32,
319                             bool multiout,
320                             const std::vector<int64_t>& indices_val,
321                             int64_t axis_val,
322                             size_t num) {
323         static size_t id = 0;
324         auto casename = string("usecase #") + to_string(++id);
325
326         shared_ptr<Node> indices;
327         shared_ptr<Node> axis;
328         if (i32) {
329             std::vector<int32_t> indices_val_i32(indices_val.begin(), indices_val.end());
330             indices = op::Constant::create<int32_t>(
331                 element::i32, Shape{indices_val.size()}, indices_val_i32);
332             axis = op::Constant::create<int32_t>(element::i32, Shape{}, {(int32_t)axis_val});
333         } else {
334             indices =
335                 op::Constant::create<int64_t>(element::i64, Shape{indices_val.size()}, indices_val);
336             axis = op::Constant::create<int64_t>(element::i64, Shape{}, {axis_val});
337         }
338
339         auto A = make_shared<op::Parameter>(element::f32, pshape);
340         shared_ptr<Node> A1;
341         if (multiout) {
342             auto last_dim = pshape.rank().get_length() - 1;
343             A1 = make_shared<op::v0::TopK>(A, last_dim, element::i32);
344         } else {
345             A1 = make_shared<op::v0::Abs>(A);
346         }
347         auto G = make_shared<op::v1::Gather>((multiout ? A1->output(0) : A1), indices, axis);
348
349         auto baseline_f = make_shared<Function>(make_shared<op::v0::Abs>(G), ParameterVector{A});
350         auto optimized_f = clone_function(*baseline_f);
351
352         pass::Manager pass_manager;
353         pass_manager.register_pass<pass::Validate>();
354         pass_manager.register_pass<pass::AlgebraicSimplification>();
355         pass_manager.run_passes(optimized_f);
356
357         auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
358         auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
359         EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
360         ASSERT_EQ(ps.rank().get_length(), ps_r.rank().get_length()) << casename;
361
362         ASSERT_EQ(count_ops_of_type<op::v1::Gather>(baseline_f), 1) << casename;
363         // the pass should short cut the Gather i/p with the gather users
364         // since we are fetching the whole tensor using gather op
365         ASSERT_EQ(count_ops_of_type<op::v1::Gather>(optimized_f), num) << casename;
366     };
367     for (auto& i32 : {true, false})
368         for (auto& multiout : {true, false}) {
369             check_usecase(PartialShape{1, 3, 2}, i32, multiout, std::vector<int64_t>{1}, 0, 0);
370             check_usecase(PartialShape{3, 2, 1}, i32, multiout, std::vector<int64_t>{0, 1}, 1, 0);
371             check_usecase(PartialShape{3, 2, 1}, i32, multiout, std::vector<int64_t>{1}, 2, 0);
372             check_usecase(PartialShape{1, 16}, i32, multiout, std::vector<int64_t>{0, 0}, 0, 1);
373         }
374 }
375
// Checks the ShapeOf(Gather(...)) simplification. Per the assertions below:
// with a scalar index the optimized graph keeps one Gather feeding a v3
// ShapeOf; with a 1-D index vector the ShapeOf result is assembled via a
// Concat instead. Element type and partial shape of the result must be
// preserved in both cases.
TEST(algebraic_simplification, gather_shapeof) {
    auto check_usecase = [](const PartialShape& pshape,
                            bool is_scalar_index,
                            bool opset2,
                            bool i32,
                            bool multiout,
                            bool multiout_1,
                            const std::vector<int64_t>& indices_val,
                            int64_t axis_val) {
        // `id` is shared across all invocations so each case gets a unique label.
        static size_t id = 0;
        auto casename = string("usecase #") + to_string(++id);

        // Build indices (scalar or 1-D) and axis constants in the requested
        // element type (i32 or i64).
        shared_ptr<Node> indices;
        shared_ptr<Node> axis;
        if (i32) {
            std::vector<int32_t> indices_val_i32(indices_val.begin(), indices_val.end());
            indices = is_scalar_index
                          ? op::Constant::create<int32_t>(element::i32, Shape{}, indices_val_i32)
                          : op::Constant::create<int32_t>(
                                element::i32, Shape{indices_val.size()}, indices_val_i32);
            axis = op::Constant::create<int32_t>(element::i32, Shape{}, {(int32_t)axis_val});
        } else {
            indices = is_scalar_index
                          ? op::Constant::create<int64_t>(element::i64, Shape{}, indices_val)
                          : op::Constant::create<int64_t>(
                                element::i64, Shape{indices_val.size()}, indices_val);
            axis = op::Constant::create<int64_t>(element::i64, Shape{}, {axis_val});
        }

        // Second parameter gets two extra trailing dims so the two outputs of
        // TestOpMultiOut have different shapes (and element types: f32 vs f64).
        auto dims_1 = std::vector<Dimension>(pshape);
        dims_1.push_back(11);
        dims_1.push_back(13);
        auto pshape_1 = PartialShape(dims_1);
        auto A = make_shared<op::Parameter>(element::f32, pshape);
        auto AA = make_shared<op::Parameter>(element::f64, pshape_1);
        // Optionally feed the Gather from a multi-output producer; `multiout_1`
        // selects which of the two outputs is consumed.
        shared_ptr<Node> A1;
        if (multiout) {
            A1 = make_shared<TestOpMultiOut>(A, AA);
        } else {
            A1 = make_shared<op::v0::Abs>(A);
        }
        auto B = make_shared<op::v1::Gather>(
            (multiout ? (multiout_1 ? A1->output(1) : A1->output(0)) : A1), indices, axis);
        // Exercise both ShapeOf versions (v0 from opset2, v3 from opset3).
        shared_ptr<Node> B1;
        if (opset2) {
            B1 = make_shared<op::v0::ShapeOf>(B);
        } else {
            B1 = make_shared<op::v3::ShapeOf>(B);
        }
        auto baseline_f = make_shared<Function>(
            make_shared<op::v0::Abs>(B1), (multiout ? ParameterVector{A, AA} : ParameterVector{A}));
        auto optimized_f = clone_function(*baseline_f);

        pass::Manager pass_manager;
        pass_manager.register_pass<pass::Validate>();
        pass_manager.register_pass<pass::AlgebraicSimplification>();
        pass_manager.run_passes(optimized_f);

        // The simplification must not change the result's element type...
        ASSERT_EQ(baseline_f->get_results().at(0)->get_element_type(),
                  optimized_f->get_results().at(0)->get_element_type());

        // ...nor its partial shape.
        auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
        auto ps_r = optimized_f->get_results()[0]->get_output_partial_shape(0);
        EXPECT_TRUE(ps.rank().is_static() && ps_r.rank().is_static()) << casename;
        EXPECT_TRUE(ps.same_scheme(ps_r)) << casename;

        ASSERT_EQ(count_ops_of_type<op::v1::Gather>(baseline_f), 1) << casename;

        // Inspect what feeds the final Abs in the optimized graph.
        auto last_node = optimized_f->get_results()[0]->input_value(0).get_node_shared_ptr();
        if (is_scalar_index) {
            // Scalar index: expect a single v3 ShapeOf with one Gather behind it.
            ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(optimized_f), 1) << casename;
            ASSERT_EQ(count_ops_of_type<op::v1::Gather>(optimized_f), 1) << casename;
            EXPECT_TRUE(
                as_type_ptr<op::v1::Gather>(last_node->input_value(0).get_node_shared_ptr()))
                << casename;
        } else {
            // Vector index: the shape is rebuilt via a single Concat.
            ASSERT_EQ(count_ops_of_type<op::v0::Concat>(optimized_f), 1) << casename;
            EXPECT_TRUE(
                as_type_ptr<op::v0::Concat>(last_node->input_value(0).get_node_shared_ptr()))
                << casename;
        }
    };

    // Scalar-index cases (static and dynamic batch dimension).
    for (auto& opset2 : {true, false})
        for (auto& i32 : {true, false})
            for (auto& multiout : {true, false})
                for (auto& multiout_1 : {true, false}) {
                    check_usecase(PartialShape{2, 3, 2, 1},
                                  true,
                                  opset2,
                                  i32,
                                  multiout,
                                  multiout_1,
                                  std::vector<int64_t>{0},
                                  3);
                    check_usecase(PartialShape{2, Dimension::dynamic(), 2, 1},
                                  true,
                                  opset2,
                                  i32,
                                  multiout,
                                  multiout_1,
                                  std::vector<int64_t>{0},
                                  3);
                }
    // Vector-index cases (static and dynamic batch dimension).
    for (auto& opset2 : {true, false})
        for (auto& i32 : {true, false})
            for (auto& multiout : {true, false})
                for (auto& multiout_1 : {true, false}) {
                    check_usecase(PartialShape{2, 3, 2, 1},
                                  false,
                                  opset2,
                                  i32,
                                  multiout,
                                  multiout_1,
                                  std::vector<int64_t>{0, 2},
                                  1);
                    check_usecase(PartialShape{2, Dimension::dynamic(), 2, 1},
                                  false,
                                  opset2,
                                  i32,
                                  multiout,
                                  multiout_1,
                                  std::vector<int64_t>{0, 2},
                                  1);
                }
}