Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / test / type_prop / broadcast.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "gtest/gtest.h"
18 #include "ngraph/ngraph.hpp"
19 #include "util/type_prop.hpp"
20
21 NGRAPH_SUPPRESS_DEPRECATED_START
22
23 using namespace std;
24 using namespace ngraph;
25
26 TEST(type_prop, broadcast_deduce)
27 {
28     auto param = make_shared<op::Parameter>(element::f32, Shape{2, 4});
29     Shape bc_shape{2, 3, 4};
30     auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
31     ASSERT_EQ(bc->get_element_type(), element::f32);
32     ASSERT_EQ(bc->get_shape(), bc_shape);
33 }
34
35 TEST(type_prop, broadcast_axes_oob)
36 {
37     auto param = make_shared<op::Parameter>(element::f32, Shape{2, 4});
38     auto bc_shape = Shape{2, 3, 4};
39
40     try
41     {
42         auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1, 3});
43         FAIL() << "Broadcast axis out of bounds not detected";
44     }
45     catch (const NodeValidationFailure& error)
46     {
47         EXPECT_HAS_SUBSTRING(error.what(),
48                              "Broadcast axis index (3) exceeds specified output shape rank");
49     }
50     catch (...)
51     {
52         FAIL() << "Deduced type check failed for unexpected reason";
53     }
54 }
55
56 TEST(type_prop, broadcast_shape_mismatch_wrong_rank)
57 {
58     auto param = make_shared<op::Parameter>(element::f32, Shape{2, 4});
59     auto bc_shape = Shape{2, 3, 4, 5};
60
61     try
62     {
63         auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
64         FAIL() << "Output shape mismatch (wrong rank) not detected";
65     }
66     catch (const NodeValidationFailure& error)
67     {
68         EXPECT_HAS_SUBSTRING(
69             error.what(),
70             "Broadcast argument shape, specified output shape, and axes are incompatible");
71     }
72     catch (...)
73     {
74         FAIL() << "Deduced type check failed for unexpected reason";
75     }
76 }
77
78 TEST(type_prop, broadcast_shape_mismatch_wrong_size)
79 {
80     auto param = make_shared<op::Parameter>(element::f32, Shape{2, 4});
81     auto bc_shape = Shape{2, 3, 5};
82
83     try
84     {
85         auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
86         FAIL() << "Output shape mismatch (wrong size) not detected";
87     }
88     catch (const NodeValidationFailure& error)
89     {
90         EXPECT_HAS_SUBSTRING(
91             error.what(),
92             "Broadcast argument shape, specified output shape, and axes are incompatible");
93     }
94     catch (...)
95     {
96         FAIL() << "Deduced type check failed for unexpected reason";
97     }
98 }
99
100 TEST(type_prop, broadcast_partial_rank_dynamic_ok)
101 {
102     auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
103     Shape bc_shape{2, 3, 4};
104     auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
105     ASSERT_EQ(bc->get_element_type(), element::f32);
106     ASSERT_EQ(bc->get_shape(), bc_shape);
107 }
108
109 TEST(type_prop, broadcast_partial_rank_dynamic_axes_oob)
110 {
111     auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
112     auto bc_shape = Shape{2, 3, 4};
113
114     try
115     {
116         auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1, 3});
117         FAIL() << "Broadcast axis out of bounds not detected";
118     }
119     catch (const NodeValidationFailure& error)
120     {
121         EXPECT_HAS_SUBSTRING(error.what(),
122                              "Broadcast axis index (3) exceeds specified output shape rank");
123     }
124     catch (...)
125     {
126         FAIL() << "Deduced type check failed for unexpected reason";
127     }
128 }
129
130 TEST(type_prop, broadcast_partial_rank_static_dynamic_ok)
131 {
132     auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4});
133     Shape bc_shape{2, 3, 4};
134     auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
135     ASSERT_EQ(bc->get_element_type(), element::f32);
136     ASSERT_EQ(bc->get_shape(), bc_shape);
137 }
138
139 TEST(type_prop, broadcast_partial_rank_static_dynamic_axes_oob)
140 {
141     auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4});
142     auto bc_shape = Shape{2, 3, 4};
143
144     try
145     {
146         auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1, 3});
147         FAIL() << "Broadcast axis out of bounds not detected";
148     }
149     catch (const NodeValidationFailure& error)
150     {
151         EXPECT_HAS_SUBSTRING(error.what(),
152                              "Broadcast axis index (3) exceeds specified output shape rank");
153     }
154     catch (...)
155     {
156         FAIL() << "Deduced type check failed for unexpected reason";
157     }
158 }
159
160 TEST(type_prop, broadcast_partial_rank_static_dynamic_shape_mismatch_wrong_rank)
161 {
162     auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4});
163     auto bc_shape = Shape{2, 3, 4, 5};
164
165     try
166     {
167         auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
168         FAIL() << "Output shape mismatch (wrong rank) not detected";
169     }
170     catch (const NodeValidationFailure& error)
171     {
172         EXPECT_HAS_SUBSTRING(
173             error.what(),
174             "Broadcast argument shape, specified output shape, and axes are incompatible");
175     }
176     catch (...)
177     {
178         FAIL() << "Deduced type check failed for unexpected reason";
179     }
180 }
181
182 TEST(type_prop, broadcast_partial_rank_static_dynamic_shape_mismatch_wrong_size)
183 {
184     auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4});
185     auto bc_shape = Shape{2, 3, 5};
186
187     try
188     {
189         auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
190         FAIL() << "Output shape mismatch (wrong size) not detected";
191     }
192     catch (const NodeValidationFailure& error)
193     {
194         EXPECT_HAS_SUBSTRING(
195             error.what(),
196             "Broadcast argument shape, specified output shape, and axes are incompatible");
197     }
198     catch (...)
199     {
200         FAIL() << "Deduced type check failed for unexpected reason";
201     }
202 }
203
204 // Because v3::Broadcast is backward compatible to v1::Broadcast all v1::Broadcast tests should pass
205 template <typename T>
206 class BroadcastTests : public ::testing::Test
207 {
208 };
209 TYPED_TEST_CASE_P(BroadcastTests);
210
211 TYPED_TEST_P(BroadcastTests, broadcast_numpy)
212 {
213     auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
214     auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 6});
215
216     auto bc = make_shared<TypeParam>(param, target_shape);
217     ASSERT_EQ(bc->get_element_type(), element::f32);
218     ASSERT_EQ(bc->get_shape(), (Shape{2, 3, 6}));
219 }
220
221 TYPED_TEST_P(BroadcastTests, broadcast_axes_mapping)
222 {
223     auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
224     auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 1});
225     auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{2}, {1, 2});
226
227     auto bc = make_shared<TypeParam>(param, target_shape, axes_mapping);
228     ASSERT_EQ(bc->get_element_type(), element::f32);
229     ASSERT_EQ(bc->get_shape(), (Shape{2, 3, 1}));
230 }
231
232 TYPED_TEST_P(BroadcastTests, broadcast_target_shape_as_concat_with_constants)
233 {
234     auto param = make_shared<op::Parameter>(element::f32, Shape{16});
235     auto target_shape_constant_1 = op::Constant::create<int64_t>(element::i64, Shape{1}, {1});
236     auto target_shape_constant_2 = op::Constant::create<int64_t>(element::i64, Shape{1}, {16});
237     auto target_shape_constant_3 = op::Constant::create<int64_t>(element::i64, Shape{1}, {50});
238     auto target_shape_constant_4 = op::Constant::create<int64_t>(element::i64, Shape{1}, {50});
239     std::int64_t axis = 0;
240     std::vector<std::shared_ptr<Node>> args{target_shape_constant_1,
241                                             target_shape_constant_2,
242                                             target_shape_constant_3,
243                                             target_shape_constant_4};
244     auto target_shape = make_shared<op::Concat>(args, axis);
245     auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{1}, {1});
246     auto bc = make_shared<TypeParam>(param, target_shape, axes_mapping, "NONE");
247     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
248     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
249     ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
250     ASSERT_TRUE(bc->get_output_partial_shape(0).same_scheme(PartialShape{1, 16, 50, 50}));
251 }
252
253 TYPED_TEST_P(BroadcastTests, broadcast_target_shape_as_concat_with_node)
254 {
255     auto param = make_shared<op::Parameter>(element::f32, Shape{16});
256     auto target_shape_constant_1 = make_shared<op::Parameter>(element::i64, Shape{1});
257     auto target_shape_constant_2 = op::Constant::create<int64_t>(element::i64, Shape{1}, {16});
258     auto target_shape_constant_3 = op::Constant::create<int64_t>(element::i64, Shape{1}, {50});
259     auto target_shape_constant_4 = op::Constant::create<int64_t>(element::i64, Shape{1}, {50});
260     std::int64_t axis = 0;
261     std::vector<std::shared_ptr<Node>> args{target_shape_constant_1,
262                                             target_shape_constant_2,
263                                             target_shape_constant_3,
264                                             target_shape_constant_4};
265     auto target_shape = make_shared<op::Concat>(args, axis);
266     auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{1}, {1});
267     auto bc = make_shared<TypeParam>(param, target_shape, axes_mapping, "NONE");
268     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
269     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
270     ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
271     ASSERT_TRUE(bc->get_output_partial_shape(0).same_scheme(
272         PartialShape{Dimension::dynamic(), 16, 50, 50}));
273 }
274
275 TYPED_TEST_P(BroadcastTests, broadcast_fail_rank)
276 {
277     auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
278     auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 1});
279     auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{3}, {1, 2, 3});
280
281     try
282     {
283         auto bc = make_shared<TypeParam>(param, target_shape, axes_mapping);
284         FAIL() << "Broadcast: target shape mismatch with input rank not detected";
285     }
286     catch (const NodeValidationFailure& error)
287     {
288         EXPECT_HAS_SUBSTRING(
289             error.what(),
290             "Broadcast axes_mapping shape Shape{3} doesn't match rank of input tensor 2");
291     }
292     catch (...)
293     {
294         FAIL() << "Deduced type check failed for unexpected reason";
295     }
296 }
297
298 TYPED_TEST_P(BroadcastTests, broadcast_fail_transpose)
299 {
300     auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
301     auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 1, 3});
302     auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 1});
303
304     try
305     {
306         auto bc = make_shared<TypeParam>(param, target_shape, axes_mapping);
307         FAIL() << "Broadcast: transpose prohibition not detected";
308     }
309     catch (const NodeValidationFailure& error)
310     {
311         EXPECT_HAS_SUBSTRING(error.what(),
312                              "Broadcast doesn't permit transposes. axes_mapping AxisVector{2, 1} "
313                              "not in sorted order");
314     }
315     catch (...)
316     {
317         FAIL() << "Deduced type check failed for unexpected reason";
318     }
319 }
320
321 TYPED_TEST_P(BroadcastTests, broadcast_fail_axes_map)
322 {
323     auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
324     auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 1});
325     auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{2}, {1, 3});
326
327     try
328     {
329         auto bc = make_shared<TypeParam>(param, target_shape, axes_mapping);
330         FAIL() << "Broadcast: wrong axes_map not detected";
331     }
332     catch (const NodeValidationFailure& error)
333     {
334         EXPECT_HAS_SUBSTRING(error.what(), "Broadcast axes_mapping[1]: 3 exceeds target rank 3");
335     }
336     catch (...)
337     {
338         FAIL() << "Deduced type check failed for unexpected reason";
339     }
340 }
341
342 TYPED_TEST_P(BroadcastTests, broadcast_fail_axes_map_shape)
343 {
344     auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
345     auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 3});
346     auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{2}, {1, 2});
347
348     try
349     {
350         auto bc = make_shared<TypeParam>(param, target_shape, axes_mapping);
351         FAIL() << "Broadcast: wrong target shape not detected";
352     }
353     catch (const NodeValidationFailure& error)
354     {
355         EXPECT_HAS_SUBSTRING(error.what(), "Broadcast target[axes_mapping[1]] Expected 1. Got 3");
356     }
357     catch (...)
358     {
359         FAIL() << "Deduced type check failed for unexpected reason";
360     }
361 }
362
363 TYPED_TEST_P(BroadcastTests, broadcast_axes_wrong_rank)
364 {
365     auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
366     auto bc_shape = make_shared<op::Parameter>(element::i64, Shape{1});
367     auto bc_axes = make_shared<op::Parameter>(element::i64, Shape{2, 2});
368
369     try
370     {
371         auto bc = make_shared<TypeParam>(arg, bc_shape, bc_axes);
372         FAIL() << "Broadcast: axes shape rank not detected";
373     }
374     catch (const NodeValidationFailure& error)
375     {
376         EXPECT_HAS_SUBSTRING(error.what(), "Broadcast axes rank must be 1");
377     }
378     catch (...)
379     {
380         FAIL() << "Deduced type check failed for unexpected reason";
381     }
382 }
383
384 TYPED_TEST_P(BroadcastTests, broadcast_fully_dynamic_target_shape)
385 {
386     auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
387     auto bc_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
388     auto bc_axes = make_shared<op::Parameter>(element::i64, Shape{2});
389
390     auto bc = make_shared<TypeParam>(arg, bc_shape, bc_axes);
391     ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
392
393     bc_shape = make_shared<op::Parameter>(element::i64, Shape{1});
394     bc = make_shared<TypeParam>(arg, bc_shape, bc_axes);
395     ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
396 }
397
398 TYPED_TEST_P(BroadcastTests, broadcast_broadcast_shape_et_wrong)
399 {
400     auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
401     // wrong element type
402     auto bc_shape = make_shared<op::Parameter>(element::boolean, Shape{1});
403     auto bc_axes = make_shared<op::Parameter>(element::i64, Shape{2});
404
405     try
406     {
407         auto bc = make_shared<TypeParam>(arg, bc_shape, bc_axes);
408         FAIL() << "Broadcast: did not detect shape element type not integral number";
409     }
410     catch (const NodeValidationFailure& error)
411     {
412         EXPECT_HAS_SUBSTRING(error.what(),
413                              std::string("Broadcast shape must be an integral number"));
414     }
415     catch (...)
416     {
417         FAIL() << "Deduced type check failed for unexpected reason";
418     }
419 }
420
421 TYPED_TEST_P(BroadcastTests, broadcast_axes_et_wrong)
422 {
423     auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
424     auto bc_shape = make_shared<op::Parameter>(element::i64, Shape{1});
425     // wrong element type
426     auto bc_axes = make_shared<op::Parameter>(element::f32, Shape{2});
427
428     try
429     {
430         auto bc = make_shared<TypeParam>(arg, bc_shape, bc_axes);
431         FAIL() << "Broadcast: did not detect axes element type not integral numbers";
432     }
433     catch (const NodeValidationFailure& error)
434     {
435         EXPECT_HAS_SUBSTRING(error.what(),
436                              std::string("Broadcast axes must be integral numbers, but are:"));
437     }
438     catch (...)
439     {
440         FAIL() << "Deduced type check failed for unexpected reason";
441     }
442 }
443
444 // EXPLICIT MODE
445
446 TYPED_TEST_P(BroadcastTests, broadcast_explicit_all_inputs_dynamic)
447 {
448     const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
449     const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
450     const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
451
452     auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
453     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
454
455     // const axes mapping
456     const auto axes_mapping_const =
457         op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 1, 2});
458     bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
459     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
460 }
461
462 TYPED_TEST_P(BroadcastTests, broadcast_explicit_target_shape_static_rank)
463 {
464     const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
465     const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
466     const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
467
468     auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
469     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
470
471     // const axes mapping
472     const auto axes_mapping_const =
473         op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 1, 2});
474     bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
475     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
476 }
477
478 TYPED_TEST_P(BroadcastTests, broadcast_explicit_const_target_shape)
479 {
480     const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
481     const auto target_shape =
482         op::Constant::create(element::i64, Shape{3}, vector<int64_t>{1, 2, 3});
483     const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
484
485     auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
486
487     ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
488     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
489     ASSERT_EQ(bc->get_shape(), (Shape{1, 2, 3}));
490
491     // const axes mapping
492     const auto axes_mapping_const =
493         op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 2, 1});
494     bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
495
496     ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
497     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
498     ASSERT_EQ(bc->get_shape(), (Shape{1, 2, 3}));
499 }
500
501 TYPED_TEST_P(BroadcastTests, broadcast_explicit_input_rank_static)
502 {
503     const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
504     const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
505     const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
506
507     auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
508     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
509
510     // const axes mapping
511     const auto axes_mapping_const =
512         op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 2, 1});
513     bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
514     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
515 }
516
517 TYPED_TEST_P(BroadcastTests, broadcast_explicit_target_shape_and_input_data_rank_static)
518 {
519     // static rank data
520     const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
521     const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
522     auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
523
524     auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
525     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
526
527     // const axes mapping
528     const auto axes_mapping_const =
529         op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 2, 1});
530     bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
531     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
532 }
533
534 TYPED_TEST_P(BroadcastTests, broadcast_explicit_const_target_shape_static_rank_input)
535 {
536     const auto target_shape =
537         op::Constant::create(element::i64, Shape{4}, vector<int64_t>{1, 1, 5, 10});
538     // static rank data
539     const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
540     auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
541
542     auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
543     ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
544     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
545     ASSERT_EQ(bc->get_shape(), (Shape{1, 1, 5, 10}));
546
547     // const axes mapping
548     const auto axes_mapping_const =
549         op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 2, 1, 3});
550     bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
551     ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
552     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
553     ASSERT_EQ(bc->get_shape(), (Shape{1, 1, 5, 10}));
554 }
555
556 TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_input_shape)
557 {
558     const auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4});
559     // dynamic target shape and axes mapping
560     auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
561     auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
562
563     auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
564     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
565
566     // const axes mapping
567     const auto axes_mapping_const =
568         op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 2, 1, 3});
569     bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
570     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
571
572     // static rank target shape
573     target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
574     bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
575     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
576
577     // static rank target shape and const axes mapping
578     target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
579     bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
580     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
581 }
582
583 TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_input_shape_const_target_shape)
584 {
585     const auto data = make_shared<op::Parameter>(element::f32, PartialShape{4});
586     auto target_shape = op::Constant::create(element::i64, Shape{4}, vector<int64_t>{1, 4, 2, 3});
587     // dynamic axes mapping
588     const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
589
590     auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
591     ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
592     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
593     ASSERT_EQ(bc->get_shape(), (Shape{1, 4, 2, 3}));
594
595     // const axes mapping
596     const auto axes_mapping_const =
597         op::Constant::create(element::i64, Shape{1}, vector<int64_t>{1});
598     bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
599     ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
600     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
601     ASSERT_EQ(bc->get_shape(), (Shape{1, 4, 2, 3}));
602 }
603
604 TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_target_shape)
605 {
606     // dynamic input
607     auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
608     const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape{4});
609     const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
610
611     auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
612     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
613     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
614     ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
615
616     // static rank input
617     data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(2));
618     bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
619     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
620     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
621     ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
622 }
623
624 // NUMPY MODE
625
626 TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_shape_dynamic)
627 {
628     const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
629     // dynamic output shape
630     auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
631
632     auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
633     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
634
635     // static rank target shape
636     target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
637     bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
638     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
639 }
640
641 TYPED_TEST_P(BroadcastTests, broadcast_numpy_target_shape_constant)
642 {
643     // dynamic data
644     auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
645     const auto target_shape =
646         op::Constant::create(element::i64, Shape{3}, vector<int64_t>{1, 2, 3});
647
648     auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
649     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
650     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
651
652     // static rank data
653     data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(2));
654     bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
655     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
656     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
657 }
658
659 TYPED_TEST_P(BroadcastTests, broadcast_numpy_target_shape_dynamic)
660 {
661     // static rank data
662     auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
663     const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
664
665     auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
666     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
667
668     // static shape data
669     data = make_shared<op::Parameter>(element::f32, PartialShape{3, 4, 5, 6});
670     bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
671     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
672 }
673
674 TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_target_shape_static_rank)
675 {
676     const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
677     const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
678
679     const auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
680     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
681 }
682
683 TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_static_shape)
684 {
685     const auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
686     // static rank target_shape
687     auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
688
689     auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
690     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
691
692     // constant target_shape
693     const auto target_shape_const =
694         op::Constant::create(element::i64, Shape{3}, vector<int64_t>{3, 2, 3});
695     bc = make_shared<TypeParam>(data, target_shape_const, "NUMPY");
696     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
697     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
698     ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
699     ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{3, 2, 3}));
700 }
701
702 TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_partially_dynamic)
703 {
704     const Shape expected_target_shape{1, 2, 3, 4};
705     const auto target_shape = op::Constant::create(
706         element::i64,
707         {expected_target_shape.size()},
708         std::vector<int64_t>(expected_target_shape.begin(), expected_target_shape.end()));
709
710     auto data = make_shared<op::Parameter>(element::f32, PartialShape{2, 3, Dimension::dynamic()});
711     auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
712     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
713     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
714     ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape);
715
716     data = make_shared<op::Parameter>(element::f32,
717                                       PartialShape{Dimension::dynamic(), 3, Dimension::dynamic()});
718     bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
719     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
720     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
721     ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape);
722
723     data = make_shared<op::Parameter>(element::f32,
724                                       PartialShape{2, Dimension::dynamic(), Dimension::dynamic()});
725     bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
726     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
727     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
728     ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape);
729
730     data = make_shared<op::Parameter>(
731         element::f32,
732         PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});
733     bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
734     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
735     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
736     ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape);
737 }
738
739 TYPED_TEST_P(BroadcastTests, broadcast_numpy_static_dims_incorrect)
740 {
741     const auto target_shape = op::Constant::create(element::i64, Shape{4}, {1, 2, 3, 4});
742
743     auto data =
744         make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 999, 3, 4});
745     try
746     {
747         auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
748     }
749     catch (const NodeValidationFailure& error)
750     {
751         EXPECT_HAS_SUBSTRING(error.what(),
752                              "Input shape dimension equal 999 cannot be broadcasted (numpy mode) "
753                              "to 2. Allowed input dimension value would be 1 or 2");
754     }
755     catch (...)
756     {
757         FAIL() << "Deduced type check failed for unexpected reason";
758     }
759
760     data = make_shared<op::Parameter>(
761         element::f32,
762         PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 888});
763     try
764     {
765         auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
766     }
767     catch (const NodeValidationFailure& error)
768     {
769         EXPECT_HAS_SUBSTRING(error.what(),
770                              "Input shape dimension equal 888 cannot be broadcasted (numpy mode) "
771                              "to 4. Allowed input dimension value would be 1 or 4");
772     }
773     catch (...)
774     {
775         FAIL() << "Deduced type check failed for unexpected reason";
776     }
777
778     data = make_shared<op::Parameter>(
779         element::f32,
780         PartialShape{5, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});
781     try
782     {
783         auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
784     }
785     catch (const NodeValidationFailure& error)
786     {
787         EXPECT_HAS_SUBSTRING(error.what(),
788                              "Input shape dimension equal 5 cannot be broadcasted (numpy mode) to "
789                              "1. Allowed input dimension value would be 1");
790     }
791     catch (...)
792     {
793         FAIL() << "Deduced type check failed for unexpected reason";
794     }
795 }
796
797 REGISTER_TYPED_TEST_CASE_P(BroadcastTests,
798                            broadcast_numpy,
799                            broadcast_axes_mapping,
800                            broadcast_target_shape_as_concat_with_constants,
801                            broadcast_target_shape_as_concat_with_node,
802                            broadcast_fail_rank,
803                            broadcast_fail_transpose,
804                            broadcast_fail_axes_map,
805                            broadcast_fail_axes_map_shape,
806                            broadcast_axes_wrong_rank,
807                            broadcast_fully_dynamic_target_shape,
808                            broadcast_broadcast_shape_et_wrong,
809                            broadcast_axes_et_wrong,
810                            broadcast_explicit_all_inputs_dynamic,
811                            broadcast_explicit_target_shape_static_rank,
812                            broadcast_explicit_const_target_shape,
813                            broadcast_explicit_input_rank_static,
814                            broadcast_explicit_target_shape_and_input_data_rank_static,
815                            broadcast_explicit_const_target_shape_static_rank_input,
816                            broadcast_explicit_static_input_shape,
817                            broadcast_explicit_static_input_shape_const_target_shape,
818                            broadcast_explicit_static_target_shape,
819                            broadcast_numpy_input_shape_dynamic,
820                            broadcast_numpy_target_shape_constant,
821                            broadcast_numpy_target_shape_dynamic,
822                            broadcast_numpy_input_target_shape_static_rank,
823                            broadcast_numpy_input_static_shape,
824                            broadcast_numpy_input_partially_dynamic,
825                            broadcast_numpy_static_dims_incorrect);
826
827 typedef ::testing::Types<op::v1::Broadcast, op::v3::Broadcast> BroadcastTypes;
828 // the last empty argument resolves compiler warning on MAC:
829 // `must specify at least one argument for '...'` (variadic macro)
830 INSTANTIATE_TYPED_TEST_CASE_P(type_prop, BroadcastTests, BroadcastTypes, );
831
// v1 takes its broadcast mode as an AutoBroadcastSpec while v3 takes a BroadcastModeSpec,
// so the PDPD tests cannot share the typed tests above and are written per version.
833 TEST(type_prop, broadcast_v1_pdpd)
834 {
835     auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
836     auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 6});
837
838     auto bc = make_shared<op::v1::Broadcast>(
839         param, target_shape, op::AutoBroadcastSpec(op::AutoBroadcastType::PDPD, 1));
840     ASSERT_EQ(bc->get_element_type(), element::f32);
841     ASSERT_EQ(bc->get_shape(), (Shape{2, 3, 6}));
842 }
843
844 TEST(type_prop, broadcast_v3_pdpd)
845 {
846     auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
847     auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 6});
848
849     auto bc = make_shared<op::v3::Broadcast>(
850         param, target_shape, op::BroadcastModeSpec(op::BroadcastType::PDPD, 1));
851     ASSERT_EQ(bc->get_element_type(), element::f32);
852     ASSERT_EQ(bc->get_shape(), (Shape{2, 3, 6}));
853 }
854
855 TEST(type_prop, broadcast_v3_bidirectional_mode_string)
856 {
857     const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 4, 1});
858     const auto shape = make_shared<op::Parameter>(element::i32, Shape{2});
859
860     const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, "BIDIRECTIONAL");
861
862     ASSERT_EQ(broadcast_v3->get_broadcast_spec(), op::BroadcastType::BIDIRECTIONAL);
863     ASSERT_EQ(broadcast_v3->get_version(), 3);
864 }
865
866 TEST(type_prop, broadcast_v3_shape_unexpected_axes_mapping_input)
867 {
868     const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 4, 1});
869     const auto shape = make_shared<op::Parameter>(element::i16, Shape{2});
870     const auto axes_mapping = make_shared<op::Parameter>(element::f32, Shape{3});
871     const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
872
873     try
874     {
875         const auto broadcast_v3 =
876             make_shared<op::v3::Broadcast>(arg, shape, axes_mapping, broadcast_spec);
877         FAIL() << "Unexpected axes mapping input exception not thrown";
878     }
879     catch (const NodeValidationFailure& error)
880     {
881         EXPECT_HAS_SUBSTRING(
882             error.what(),
883             std::string("axes_mapping input should not be provided for mode other than explicit"));
884     }
885     catch (...)
886     {
887         FAIL() << "Deduced type check failed for unexpected reason";
888     }
889 }
890
891 TEST(type_prop, broadcast_v3_not_provided_axes_input_for_explicit_mode)
892 {
893     const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 4, 1});
894     const auto shape = make_shared<op::Parameter>(element::i16, Shape{2});
895     const auto broadcast_spec = op::BroadcastType::EXPLICIT;
896
897     try
898     {
899         const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
900         FAIL() << "axes_mapping input should be provided if explicit mode is used";
901     }
902     catch (const NodeValidationFailure& error)
903     {
904         EXPECT_HAS_SUBSTRING(
905             error.what(),
906             std::string("axes_mapping input should be provided if explicit mode is used"));
907     }
908     catch (...)
909     {
910         FAIL() << "Deduced type check failed for unexpected reason";
911     }
912 }
913
914 TEST(type_prop, broadcast_v3_shape)
915 {
916     const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 4, 1});
917     const auto shape = op::Constant::create(element::i64, {2}, {1, 4});
918     const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
919
920     const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
921
922     ASSERT_EQ(broadcast_v3->get_element_type(), element::f32);
923     ASSERT_EQ(broadcast_v3->get_shape(), (Shape{1, 4, 4}));
924     ASSERT_EQ(broadcast_v3->get_broadcast_axes(), (make_pair<bool, AxisSet>(true, AxisSet{2})));
925 }
926
927 TEST(type_prop, broadcast_v3_shape_2)
928 {
929     const auto arg = make_shared<op::Parameter>(element::f32, Shape{3, 1});
930     const auto shape = op::Constant::create(element::i64, {3}, {2, 1, 6});
931     const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
932
933     const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
934
935     ASSERT_EQ(broadcast_v3->get_element_type(), element::f32);
936     ASSERT_EQ(broadcast_v3->get_shape(), (Shape{2, 3, 6}));
937     ASSERT_EQ(broadcast_v3->get_broadcast_axes(), (make_pair<bool, AxisSet>(true, AxisSet{0, 2})));
938 }
939
940 TEST(type_prop, broadcast_v3_shape_3)
941 {
942     const auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 1});
943     const auto shape = op::Constant::create(element::i64, {2}, {2, 4});
944     const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
945
946     const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
947
948     ASSERT_EQ(broadcast_v3->get_element_type(), element::f32);
949     ASSERT_EQ(broadcast_v3->get_shape(), (Shape{2, 4}));
950     ASSERT_EQ(broadcast_v3->get_broadcast_axes(), (make_pair<bool, AxisSet>(true, AxisSet{1})));
951 }
952
953 TEST(type_prop, broadcast_v3_shape_4)
954 {
955     const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 3, 1});
956     const auto shape = op::Constant::create(element::i64, {2}, {3, 1});
957     const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
958
959     const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
960
961     ASSERT_EQ(broadcast_v3->get_element_type(), element::f32);
962     ASSERT_EQ(broadcast_v3->get_shape(), (Shape{1, 3, 1}));
963     ASSERT_EQ(broadcast_v3->get_broadcast_axes(), (make_pair<bool, AxisSet>(true, AxisSet{})));
964 }
965
966 TEST(type_prop, broadcast_v3_shape_5)
967 {
968     const auto arg = make_shared<op::Parameter>(element::f32, Shape{16, 1, 1});
969     const auto shape = op::Constant::create(element::i64, {4}, {1, 1, 50, 50});
970     const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
971
972     const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
973
974     ASSERT_EQ(broadcast_v3->get_element_type(), element::f32);
975     ASSERT_EQ(broadcast_v3->get_shape(), (Shape{1, 16, 50, 50}));
976     ASSERT_EQ(broadcast_v3->get_broadcast_axes(),
977               (make_pair<bool, AxisSet>(true, AxisSet{0, 2, 3})));
978 }
979
980 TEST(type_prop, broadcast_v3_shape_6)
981 {
982     const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 3, 1});
983     const auto shape = op::Constant::create(element::i64, {3}, {3, 1, 3});
984     const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
985
986     const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
987
988     ASSERT_EQ(broadcast_v3->get_element_type(), element::f32);
989     ASSERT_EQ(broadcast_v3->get_shape(), (Shape{3, 3, 3}));
990     ASSERT_EQ(broadcast_v3->get_broadcast_axes(), (make_pair<bool, AxisSet>(true, AxisSet{0, 2})));
991 }
992
993 TEST(type_prop, broadcast_v3_shape_6_type_infer)
994 {
995     const auto arg = make_shared<op::Parameter>(element::u16, Shape{1, 3, 1});
996     const auto shape = op::Constant::create(element::i64, {3}, {3, 1, 3});
997     const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
998
999     const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
1000
1001     ASSERT_EQ(broadcast_v3->get_element_type(), element::u16);
1002     ASSERT_EQ(broadcast_v3->get_shape(), (Shape{3, 3, 3}));
1003     ASSERT_EQ(broadcast_v3->get_broadcast_axes(), (make_pair<bool, AxisSet>(true, AxisSet{0, 2})));
1004 }
1005
1006 TEST(type_prop, broadcast_v3_incorrect_target_shape)
1007 {
1008     const auto arg = make_shared<op::Parameter>(element::f32, Shape{4, 3, 2});
1009     const auto shape = op::Constant::create(element::i64, {3}, {8, 6, 4});
1010     const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
1011
1012     try
1013     {
1014         const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
1015         FAIL() << "Not applicable breadcast exception not thrown";
1016     }
1017     catch (const NodeValidationFailure& error)
1018     {
1019         EXPECT_HAS_SUBSTRING(
1020             error.what(),
1021             std::string("Broadcast incorrect target shape. Expecting either 1 or 4. Got 8"));
1022     }
1023     catch (...)
1024     {
1025         FAIL() << "Deduced type check failed for unexpected reason";
1026     }
1027 }
1028
1029 TEST(type_prop, broadcast_v3_incorrect_target_shape_2)
1030 {
1031     const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 1, 2});
1032     const auto shape = op::Constant::create(element::i64, {2}, {2, 3});
1033     const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
1034
1035     try
1036     {
1037         const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
1038         FAIL() << "Not applicable breadcast exception not thrown";
1039     }
1040     catch (const NodeValidationFailure& error)
1041     {
1042         EXPECT_HAS_SUBSTRING(
1043             error.what(),
1044             std::string("Broadcast incorrect target shape. Expecting either 1 or 2. Got 3"));
1045     }
1046     catch (...)
1047     {
1048         FAIL() << "Deduced type check failed for unexpected reason";
1049     }
1050 }
1051
1052 TEST(type_prop, broadcast_v3_output_rank_not_deduced)
1053 {
1054     const auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
1055     const auto shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
1056     const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
1057
1058     const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
1059
1060     ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
1061 }
1062
1063 TEST(type_prop, broadcast_v3_output_rank_deduced_from_arg)
1064 {
1065     const auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
1066     const auto shape = op::Constant::create(element::i64, {3}, {8, 6, 4});
1067     const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
1068
1069     const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
1070     ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme(
1071         PartialShape{Dimension::dynamic(), 8, 6, 4}));
1072 }
1073
1074 TEST(type_prop, broadcast_v3_output_rank_deduced_from_new_shape_input)
1075 {
1076     const auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
1077     const auto shape = op::Constant::create(element::i64, {5}, {8, 6, 1, 5, 1});
1078     const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
1079
1080     const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
1081     ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static());
1082     ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 5);
1083     ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme(
1084         PartialShape{8, 6, Dimension::dynamic(), 5, Dimension::dynamic()}));
1085 }
1086
1087 TEST(type_prop, broadcast_v3_bidirectional_dynamic_input)
1088 {
1089     const auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
1090
1091     // dynamic target shape
1092     auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
1093     auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
1094     ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
1095
1096     // static rank target shape
1097     target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
1098     broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
1099     ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
1100
1101     // constant target shape
1102     const auto target_shape_const = op::Constant::create(element::i64, {3}, {2, 4, 6});
1103     broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape_const, "BIDIRECTIONAL");
1104     ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
1105 }
1106
1107 TEST(type_prop, broadcast_v3_bidirectional_static_rank_input)
1108 {
1109     const auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
1110
1111     // dynamic target shape
1112     auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
1113     auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
1114     ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
1115
1116     // static rank target shape
1117     target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
1118     broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
1119     ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
1120
1121     // constant target shape
1122     const auto target_shape_const = op::Constant::create(element::i64, {3}, {2, 4, 6});
1123     broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape_const, "BIDIRECTIONAL");
1124     ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static());
1125     ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 4);
1126     ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).is_dynamic());
1127 }
1128
1129 TEST(type_prop, broadcast_v3_bidirectional_static_shape_input)
1130 {
1131     const auto arg = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 1});
1132
1133     // dynamic target shape
1134     auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
1135     auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
1136     ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
1137
1138     // static rank target shape
1139     target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
1140     broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
1141     ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
1142
1143     // constant target shape
1144     auto target_shape_const = op::Constant::create(element::i64, {4}, {2, 2, 3, 2});
1145     broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape_const, "BIDIRECTIONAL");
1146     ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static());
1147     ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 4);
1148     ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).is_static());
1149     ASSERT_EQ(broadcast_v3->get_output_partial_shape(0), (PartialShape{2, 2, 3, 2}));
1150
1151     target_shape_const = op::Constant::create(element::i64, {4}, {5, 2, 3, 7});
1152     broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape_const, "BIDIRECTIONAL");
1153     ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static());
1154     ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 4);
1155     ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).is_static());
1156     ASSERT_EQ(broadcast_v3->get_output_partial_shape(0), (PartialShape{5, 2, 3, 7}));
1157 }
1158
1159 TEST(type_prop, broadcast_v3_bidirectional_partially_dynamic_input)
1160 {
1161     const auto target_shape =
1162         op::Constant::create(element::i64, Shape{4}, vector<int64_t>{1, 1, 50, 50});
1163
1164     auto data = make_shared<op::Parameter>(element::f32, PartialShape{16, 1, Dimension::dynamic()});
1165     auto bc = make_shared<op::v3::Broadcast>(data, target_shape, "BIDIRECTIONAL");
1166     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
1167     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
1168     ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, 16, 50, 50}));
1169
1170     data = make_shared<op::Parameter>(element::f32,
1171                                       PartialShape{Dimension::dynamic(), 1, Dimension::dynamic()});
1172     bc = make_shared<op::v3::Broadcast>(data, target_shape, "BIDIRECTIONAL");
1173     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
1174     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
1175     ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic(), 50, 50}));
1176
1177     data = make_shared<op::Parameter>(element::f32,
1178                                       PartialShape{16, Dimension::dynamic(), Dimension::dynamic()});
1179     bc = make_shared<op::v3::Broadcast>(data, target_shape, "BIDIRECTIONAL");
1180     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
1181     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
1182     ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, 16, 50, 50}));
1183
1184     data = make_shared<op::Parameter>(
1185         element::f32,
1186         PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});
1187     bc = make_shared<op::v3::Broadcast>(data, target_shape, "BIDIRECTIONAL");
1188     ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
1189     ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
1190     ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic(), 50, 50}));
1191 }