1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
17 #include "gtest/gtest.h"
18 #include "ngraph/ngraph.hpp"
19 #include "util/type_prop.hpp"
21 NGRAPH_SUPPRESS_DEPRECATED_START
24 using namespace ngraph;
26 TEST(type_prop, broadcast_deduce)
28 auto param = make_shared<op::Parameter>(element::f32, Shape{2, 4});
29 Shape bc_shape{2, 3, 4};
30 auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
31 ASSERT_EQ(bc->get_element_type(), element::f32);
32 ASSERT_EQ(bc->get_shape(), bc_shape);
35 TEST(type_prop, broadcast_axes_oob)
37 auto param = make_shared<op::Parameter>(element::f32, Shape{2, 4});
38 auto bc_shape = Shape{2, 3, 4};
42 auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1, 3});
43 FAIL() << "Broadcast axis out of bounds not detected";
45 catch (const NodeValidationFailure& error)
47 EXPECT_HAS_SUBSTRING(error.what(),
48 "Broadcast axis index (3) exceeds specified output shape rank");
52 FAIL() << "Deduced type check failed for unexpected reason";
56 TEST(type_prop, broadcast_shape_mismatch_wrong_rank)
58 auto param = make_shared<op::Parameter>(element::f32, Shape{2, 4});
59 auto bc_shape = Shape{2, 3, 4, 5};
63 auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
64 FAIL() << "Output shape mismatch (wrong rank) not detected";
66 catch (const NodeValidationFailure& error)
70 "Broadcast argument shape, specified output shape, and axes are incompatible");
74 FAIL() << "Deduced type check failed for unexpected reason";
78 TEST(type_prop, broadcast_shape_mismatch_wrong_size)
80 auto param = make_shared<op::Parameter>(element::f32, Shape{2, 4});
81 auto bc_shape = Shape{2, 3, 5};
85 auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
86 FAIL() << "Output shape mismatch (wrong size) not detected";
88 catch (const NodeValidationFailure& error)
92 "Broadcast argument shape, specified output shape, and axes are incompatible");
96 FAIL() << "Deduced type check failed for unexpected reason";
100 TEST(type_prop, broadcast_partial_rank_dynamic_ok)
102 auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
103 Shape bc_shape{2, 3, 4};
104 auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
105 ASSERT_EQ(bc->get_element_type(), element::f32);
106 ASSERT_EQ(bc->get_shape(), bc_shape);
109 TEST(type_prop, broadcast_partial_rank_dynamic_axes_oob)
111 auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
112 auto bc_shape = Shape{2, 3, 4};
116 auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1, 3});
117 FAIL() << "Broadcast axis out of bounds not detected";
119 catch (const NodeValidationFailure& error)
121 EXPECT_HAS_SUBSTRING(error.what(),
122 "Broadcast axis index (3) exceeds specified output shape rank");
126 FAIL() << "Deduced type check failed for unexpected reason";
130 TEST(type_prop, broadcast_partial_rank_static_dynamic_ok)
132 auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4});
133 Shape bc_shape{2, 3, 4};
134 auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
135 ASSERT_EQ(bc->get_element_type(), element::f32);
136 ASSERT_EQ(bc->get_shape(), bc_shape);
139 TEST(type_prop, broadcast_partial_rank_static_dynamic_axes_oob)
141 auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4});
142 auto bc_shape = Shape{2, 3, 4};
146 auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1, 3});
147 FAIL() << "Broadcast axis out of bounds not detected";
149 catch (const NodeValidationFailure& error)
151 EXPECT_HAS_SUBSTRING(error.what(),
152 "Broadcast axis index (3) exceeds specified output shape rank");
156 FAIL() << "Deduced type check failed for unexpected reason";
160 TEST(type_prop, broadcast_partial_rank_static_dynamic_shape_mismatch_wrong_rank)
162 auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4});
163 auto bc_shape = Shape{2, 3, 4, 5};
167 auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
168 FAIL() << "Output shape mismatch (wrong rank) not detected";
170 catch (const NodeValidationFailure& error)
172 EXPECT_HAS_SUBSTRING(
174 "Broadcast argument shape, specified output shape, and axes are incompatible");
178 FAIL() << "Deduced type check failed for unexpected reason";
182 TEST(type_prop, broadcast_partial_rank_static_dynamic_shape_mismatch_wrong_size)
184 auto param = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 4});
185 auto bc_shape = Shape{2, 3, 5};
189 auto bc = make_shared<op::Broadcast>(param, bc_shape, AxisSet{1});
190 FAIL() << "Output shape mismatch (wrong size) not detected";
192 catch (const NodeValidationFailure& error)
194 EXPECT_HAS_SUBSTRING(
196 "Broadcast argument shape, specified output shape, and axes are incompatible");
200 FAIL() << "Deduced type check failed for unexpected reason";
// v3::Broadcast is backward compatible with v1::Broadcast, so every v1::Broadcast test must also pass for v3::Broadcast.
205 template <typename T>
206 class BroadcastTests : public ::testing::Test
209 TYPED_TEST_CASE_P(BroadcastTests);
211 TYPED_TEST_P(BroadcastTests, broadcast_numpy)
213 auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
214 auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 6});
216 auto bc = make_shared<TypeParam>(param, target_shape);
217 ASSERT_EQ(bc->get_element_type(), element::f32);
218 ASSERT_EQ(bc->get_shape(), (Shape{2, 3, 6}));
221 TYPED_TEST_P(BroadcastTests, broadcast_axes_mapping)
223 auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
224 auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 1});
225 auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{2}, {1, 2});
227 auto bc = make_shared<TypeParam>(param, target_shape, axes_mapping);
228 ASSERT_EQ(bc->get_element_type(), element::f32);
229 ASSERT_EQ(bc->get_shape(), (Shape{2, 3, 1}));
232 TYPED_TEST_P(BroadcastTests, broadcast_target_shape_as_concat_with_constants)
234 auto param = make_shared<op::Parameter>(element::f32, Shape{16});
235 auto target_shape_constant_1 = op::Constant::create<int64_t>(element::i64, Shape{1}, {1});
236 auto target_shape_constant_2 = op::Constant::create<int64_t>(element::i64, Shape{1}, {16});
237 auto target_shape_constant_3 = op::Constant::create<int64_t>(element::i64, Shape{1}, {50});
238 auto target_shape_constant_4 = op::Constant::create<int64_t>(element::i64, Shape{1}, {50});
239 std::int64_t axis = 0;
240 std::vector<std::shared_ptr<Node>> args{target_shape_constant_1,
241 target_shape_constant_2,
242 target_shape_constant_3,
243 target_shape_constant_4};
244 auto target_shape = make_shared<op::Concat>(args, axis);
245 auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{1}, {1});
246 auto bc = make_shared<TypeParam>(param, target_shape, axes_mapping, "NONE");
247 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
248 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
249 ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
250 ASSERT_TRUE(bc->get_output_partial_shape(0).same_scheme(PartialShape{1, 16, 50, 50}));
253 TYPED_TEST_P(BroadcastTests, broadcast_target_shape_as_concat_with_node)
255 auto param = make_shared<op::Parameter>(element::f32, Shape{16});
256 auto target_shape_constant_1 = make_shared<op::Parameter>(element::i64, Shape{1});
257 auto target_shape_constant_2 = op::Constant::create<int64_t>(element::i64, Shape{1}, {16});
258 auto target_shape_constant_3 = op::Constant::create<int64_t>(element::i64, Shape{1}, {50});
259 auto target_shape_constant_4 = op::Constant::create<int64_t>(element::i64, Shape{1}, {50});
260 std::int64_t axis = 0;
261 std::vector<std::shared_ptr<Node>> args{target_shape_constant_1,
262 target_shape_constant_2,
263 target_shape_constant_3,
264 target_shape_constant_4};
265 auto target_shape = make_shared<op::Concat>(args, axis);
266 auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{1}, {1});
267 auto bc = make_shared<TypeParam>(param, target_shape, axes_mapping, "NONE");
268 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
269 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().same_scheme(Rank{4}));
270 ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
271 ASSERT_TRUE(bc->get_output_partial_shape(0).same_scheme(
272 PartialShape{Dimension::dynamic(), 16, 50, 50}));
275 TYPED_TEST_P(BroadcastTests, broadcast_fail_rank)
277 auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
278 auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 1});
279 auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{3}, {1, 2, 3});
283 auto bc = make_shared<TypeParam>(param, target_shape, axes_mapping);
284 FAIL() << "Broadcast: target shape mismatch with input rank not detected";
286 catch (const NodeValidationFailure& error)
288 EXPECT_HAS_SUBSTRING(
290 "Broadcast axes_mapping shape Shape{3} doesn't match rank of input tensor 2");
294 FAIL() << "Deduced type check failed for unexpected reason";
298 TYPED_TEST_P(BroadcastTests, broadcast_fail_transpose)
300 auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
301 auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 1, 3});
302 auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 1});
306 auto bc = make_shared<TypeParam>(param, target_shape, axes_mapping);
307 FAIL() << "Broadcast: transpose prohibition not detected";
309 catch (const NodeValidationFailure& error)
311 EXPECT_HAS_SUBSTRING(error.what(),
312 "Broadcast doesn't permit transposes. axes_mapping AxisVector{2, 1} "
313 "not in sorted order");
317 FAIL() << "Deduced type check failed for unexpected reason";
321 TYPED_TEST_P(BroadcastTests, broadcast_fail_axes_map)
323 auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
324 auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 1});
325 auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{2}, {1, 3});
329 auto bc = make_shared<TypeParam>(param, target_shape, axes_mapping);
330 FAIL() << "Broadcast: wrong axes_map not detected";
332 catch (const NodeValidationFailure& error)
334 EXPECT_HAS_SUBSTRING(error.what(), "Broadcast axes_mapping[1]: 3 exceeds target rank 3");
338 FAIL() << "Deduced type check failed for unexpected reason";
342 TYPED_TEST_P(BroadcastTests, broadcast_fail_axes_map_shape)
344 auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
345 auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 3});
346 auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{2}, {1, 2});
350 auto bc = make_shared<TypeParam>(param, target_shape, axes_mapping);
351 FAIL() << "Broadcast: wrong target shape not detected";
353 catch (const NodeValidationFailure& error)
355 EXPECT_HAS_SUBSTRING(error.what(), "Broadcast target[axes_mapping[1]] Expected 1. Got 3");
359 FAIL() << "Deduced type check failed for unexpected reason";
363 TYPED_TEST_P(BroadcastTests, broadcast_axes_wrong_rank)
365 auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
366 auto bc_shape = make_shared<op::Parameter>(element::i64, Shape{1});
367 auto bc_axes = make_shared<op::Parameter>(element::i64, Shape{2, 2});
371 auto bc = make_shared<TypeParam>(arg, bc_shape, bc_axes);
372 FAIL() << "Broadcast: axes shape rank not detected";
374 catch (const NodeValidationFailure& error)
376 EXPECT_HAS_SUBSTRING(error.what(), "Broadcast axes rank must be 1");
380 FAIL() << "Deduced type check failed for unexpected reason";
384 TYPED_TEST_P(BroadcastTests, broadcast_fully_dynamic_target_shape)
386 auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
387 auto bc_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
388 auto bc_axes = make_shared<op::Parameter>(element::i64, Shape{2});
390 auto bc = make_shared<TypeParam>(arg, bc_shape, bc_axes);
391 ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
393 bc_shape = make_shared<op::Parameter>(element::i64, Shape{1});
394 bc = make_shared<TypeParam>(arg, bc_shape, bc_axes);
395 ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
398 TYPED_TEST_P(BroadcastTests, broadcast_broadcast_shape_et_wrong)
400 auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
401 // wrong element type
402 auto bc_shape = make_shared<op::Parameter>(element::boolean, Shape{1});
403 auto bc_axes = make_shared<op::Parameter>(element::i64, Shape{2});
407 auto bc = make_shared<TypeParam>(arg, bc_shape, bc_axes);
408 FAIL() << "Broadcast: did not detect shape element type not integral number";
410 catch (const NodeValidationFailure& error)
412 EXPECT_HAS_SUBSTRING(error.what(),
413 std::string("Broadcast shape must be an integral number"));
417 FAIL() << "Deduced type check failed for unexpected reason";
421 TYPED_TEST_P(BroadcastTests, broadcast_axes_et_wrong)
423 auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
424 auto bc_shape = make_shared<op::Parameter>(element::i64, Shape{1});
425 // wrong element type
426 auto bc_axes = make_shared<op::Parameter>(element::f32, Shape{2});
430 auto bc = make_shared<TypeParam>(arg, bc_shape, bc_axes);
431 FAIL() << "Broadcast: did not detect axes element type not integral numbers";
433 catch (const NodeValidationFailure& error)
435 EXPECT_HAS_SUBSTRING(error.what(),
436 std::string("Broadcast axes must be integral numbers, but are:"));
440 FAIL() << "Deduced type check failed for unexpected reason";
446 TYPED_TEST_P(BroadcastTests, broadcast_explicit_all_inputs_dynamic)
448 const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
449 const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
450 const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
452 auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
453 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
455 // const axes mapping
456 const auto axes_mapping_const =
457 op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 1, 2});
458 bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
459 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
462 TYPED_TEST_P(BroadcastTests, broadcast_explicit_target_shape_static_rank)
464 const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
465 const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
466 const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
468 auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
469 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
471 // const axes mapping
472 const auto axes_mapping_const =
473 op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 1, 2});
474 bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
475 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
478 TYPED_TEST_P(BroadcastTests, broadcast_explicit_const_target_shape)
480 const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
481 const auto target_shape =
482 op::Constant::create(element::i64, Shape{3}, vector<int64_t>{1, 2, 3});
483 const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
485 auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
487 ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
488 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
489 ASSERT_EQ(bc->get_shape(), (Shape{1, 2, 3}));
491 // const axes mapping
492 const auto axes_mapping_const =
493 op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 2, 1});
494 bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
496 ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
497 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
498 ASSERT_EQ(bc->get_shape(), (Shape{1, 2, 3}));
501 TYPED_TEST_P(BroadcastTests, broadcast_explicit_input_rank_static)
503 const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
504 const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
505 const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
507 auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
508 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
510 // const axes mapping
511 const auto axes_mapping_const =
512 op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 2, 1});
513 bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
514 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
517 TYPED_TEST_P(BroadcastTests, broadcast_explicit_target_shape_and_input_data_rank_static)
520 const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
521 const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
522 auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
524 auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
525 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
527 // const axes mapping
528 const auto axes_mapping_const =
529 op::Constant::create(element::i64, Shape{3}, vector<int64_t>{0, 2, 1});
530 bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
531 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
534 TYPED_TEST_P(BroadcastTests, broadcast_explicit_const_target_shape_static_rank_input)
536 const auto target_shape =
537 op::Constant::create(element::i64, Shape{4}, vector<int64_t>{1, 1, 5, 10});
539 const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
540 auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
542 auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
543 ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
544 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
545 ASSERT_EQ(bc->get_shape(), (Shape{1, 1, 5, 10}));
547 // const axes mapping
548 const auto axes_mapping_const =
549 op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 2, 1, 3});
550 bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
551 ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
552 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
553 ASSERT_EQ(bc->get_shape(), (Shape{1, 1, 5, 10}));
556 TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_input_shape)
558 const auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4});
559 // dynamic target shape and axes mapping
560 auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
561 auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
563 auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
564 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
566 // const axes mapping
567 const auto axes_mapping_const =
568 op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 2, 1, 3});
569 bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
570 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
572 // static rank target shape
573 target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
574 bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
575 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
577 // static rank target shape and const axes mapping
578 target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
579 bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
580 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
583 TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_input_shape_const_target_shape)
585 const auto data = make_shared<op::Parameter>(element::f32, PartialShape{4});
586 auto target_shape = op::Constant::create(element::i64, Shape{4}, vector<int64_t>{1, 4, 2, 3});
587 // dynamic axes mapping
588 const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
590 auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
591 ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
592 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
593 ASSERT_EQ(bc->get_shape(), (Shape{1, 4, 2, 3}));
595 // const axes mapping
596 const auto axes_mapping_const =
597 op::Constant::create(element::i64, Shape{1}, vector<int64_t>{1});
598 bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
599 ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
600 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
601 ASSERT_EQ(bc->get_shape(), (Shape{1, 4, 2, 3}));
604 TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_target_shape)
607 auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
608 const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape{4});
609 const auto axes_mapping = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
611 auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
612 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
613 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
614 ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
617 data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(2));
618 bc = make_shared<TypeParam>(data, target_shape, axes_mapping, "EXPLICIT");
619 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
620 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
621 ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic());
626 TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_shape_dynamic)
628 const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
629 // dynamic output shape
630 auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
632 auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
633 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
635 // static rank target shape
636 target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
637 bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
638 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
641 TYPED_TEST_P(BroadcastTests, broadcast_numpy_target_shape_constant)
644 auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
645 const auto target_shape =
646 op::Constant::create(element::i64, Shape{3}, vector<int64_t>{1, 2, 3});
648 auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
649 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
650 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
653 data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(2));
654 bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
655 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
656 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
659 TYPED_TEST_P(BroadcastTests, broadcast_numpy_target_shape_dynamic)
662 auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
663 const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
665 auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
666 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
669 data = make_shared<op::Parameter>(element::f32, PartialShape{3, 4, 5, 6});
670 bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
671 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
674 TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_target_shape_static_rank)
676 const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
677 const auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
679 const auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
680 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
683 TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_static_shape)
685 const auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
686 // static rank target_shape
687 auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
689 auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
690 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic());
692 // constant target_shape
693 const auto target_shape_const =
694 op::Constant::create(element::i64, Shape{3}, vector<int64_t>{3, 2, 3});
695 bc = make_shared<TypeParam>(data, target_shape_const, "NUMPY");
696 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
697 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3);
698 ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
699 ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{3, 2, 3}));
702 TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_partially_dynamic)
704 const Shape expected_target_shape{1, 2, 3, 4};
705 const auto target_shape = op::Constant::create(
707 {expected_target_shape.size()},
708 std::vector<int64_t>(expected_target_shape.begin(), expected_target_shape.end()));
710 auto data = make_shared<op::Parameter>(element::f32, PartialShape{2, 3, Dimension::dynamic()});
711 auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
712 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
713 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
714 ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape);
716 data = make_shared<op::Parameter>(element::f32,
717 PartialShape{Dimension::dynamic(), 3, Dimension::dynamic()});
718 bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
719 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
720 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
721 ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape);
723 data = make_shared<op::Parameter>(element::f32,
724 PartialShape{2, Dimension::dynamic(), Dimension::dynamic()});
725 bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
726 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
727 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
728 ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape);
730 data = make_shared<op::Parameter>(
732 PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});
733 bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
734 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
735 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
736 ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape);
739 TYPED_TEST_P(BroadcastTests, broadcast_numpy_static_dims_incorrect)
741 const auto target_shape = op::Constant::create(element::i64, Shape{4}, {1, 2, 3, 4});
744 make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 999, 3, 4});
747 auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
749 catch (const NodeValidationFailure& error)
751 EXPECT_HAS_SUBSTRING(error.what(),
752 "Input shape dimension equal 999 cannot be broadcasted (numpy mode) "
753 "to 2. Allowed input dimension value would be 1 or 2");
757 FAIL() << "Deduced type check failed for unexpected reason";
760 data = make_shared<op::Parameter>(
762 PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 888});
765 auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
767 catch (const NodeValidationFailure& error)
769 EXPECT_HAS_SUBSTRING(error.what(),
770 "Input shape dimension equal 888 cannot be broadcasted (numpy mode) "
771 "to 4. Allowed input dimension value would be 1 or 4");
775 FAIL() << "Deduced type check failed for unexpected reason";
778 data = make_shared<op::Parameter>(
780 PartialShape{5, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});
783 auto bc = make_shared<TypeParam>(data, target_shape, "NUMPY");
785 catch (const NodeValidationFailure& error)
787 EXPECT_HAS_SUBSTRING(error.what(),
788 "Input shape dimension equal 5 cannot be broadcasted (numpy mode) to "
789 "1. Allowed input dimension value would be 1");
793 FAIL() << "Deduced type check failed for unexpected reason";
797 REGISTER_TYPED_TEST_CASE_P(BroadcastTests,
799 broadcast_axes_mapping,
800 broadcast_target_shape_as_concat_with_constants,
801 broadcast_target_shape_as_concat_with_node,
803 broadcast_fail_transpose,
804 broadcast_fail_axes_map,
805 broadcast_fail_axes_map_shape,
806 broadcast_axes_wrong_rank,
807 broadcast_fully_dynamic_target_shape,
808 broadcast_broadcast_shape_et_wrong,
809 broadcast_axes_et_wrong,
810 broadcast_explicit_all_inputs_dynamic,
811 broadcast_explicit_target_shape_static_rank,
812 broadcast_explicit_const_target_shape,
813 broadcast_explicit_input_rank_static,
814 broadcast_explicit_target_shape_and_input_data_rank_static,
815 broadcast_explicit_const_target_shape_static_rank_input,
816 broadcast_explicit_static_input_shape,
817 broadcast_explicit_static_input_shape_const_target_shape,
818 broadcast_explicit_static_target_shape,
819 broadcast_numpy_input_shape_dynamic,
820 broadcast_numpy_target_shape_constant,
821 broadcast_numpy_target_shape_dynamic,
822 broadcast_numpy_input_target_shape_static_rank,
823 broadcast_numpy_input_static_shape,
824 broadcast_numpy_input_partially_dynamic,
825 broadcast_numpy_static_dims_incorrect);
827 typedef ::testing::Types<op::v1::Broadcast, op::v3::Broadcast> BroadcastTypes;
828 // the last empty argument resolves compiler warning on MAC:
829 // `must specify at least one argument for '...'` (variadic macro)
830 INSTANTIATE_TYPED_TEST_CASE_P(type_prop, BroadcastTests, BroadcastTypes, );
// v1 takes an AutoBroadcastSpec while v3 takes a BroadcastModeSpec, so the PDPD tests are written separately for each version instead of in the typed suite.
833 TEST(type_prop, broadcast_v1_pdpd)
835 auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
836 auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 6});
838 auto bc = make_shared<op::v1::Broadcast>(
839 param, target_shape, op::AutoBroadcastSpec(op::AutoBroadcastType::PDPD, 1));
840 ASSERT_EQ(bc->get_element_type(), element::f32);
841 ASSERT_EQ(bc->get_shape(), (Shape{2, 3, 6}));
844 TEST(type_prop, broadcast_v3_pdpd)
846 auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
847 auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 6});
849 auto bc = make_shared<op::v3::Broadcast>(
850 param, target_shape, op::BroadcastModeSpec(op::BroadcastType::PDPD, 1));
851 ASSERT_EQ(bc->get_element_type(), element::f32);
852 ASSERT_EQ(bc->get_shape(), (Shape{2, 3, 6}));
855 TEST(type_prop, broadcast_v3_bidirectional_mode_string)
857 const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 4, 1});
858 const auto shape = make_shared<op::Parameter>(element::i32, Shape{2});
860 const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, "BIDIRECTIONAL");
862 ASSERT_EQ(broadcast_v3->get_broadcast_spec(), op::BroadcastType::BIDIRECTIONAL);
863 ASSERT_EQ(broadcast_v3->get_version(), 3);
866 TEST(type_prop, broadcast_v3_shape_unexpected_axes_mapping_input)
868 const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 4, 1});
869 const auto shape = make_shared<op::Parameter>(element::i16, Shape{2});
870 const auto axes_mapping = make_shared<op::Parameter>(element::f32, Shape{3});
871 const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
875 const auto broadcast_v3 =
876 make_shared<op::v3::Broadcast>(arg, shape, axes_mapping, broadcast_spec);
877 FAIL() << "Unexpected axes mapping input exception not thrown";
879 catch (const NodeValidationFailure& error)
881 EXPECT_HAS_SUBSTRING(
883 std::string("axes_mapping input should not be provided for mode other than explicit"));
887 FAIL() << "Deduced type check failed for unexpected reason";
891 TEST(type_prop, broadcast_v3_not_provided_axes_input_for_explicit_mode)
893 const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 4, 1});
894 const auto shape = make_shared<op::Parameter>(element::i16, Shape{2});
895 const auto broadcast_spec = op::BroadcastType::EXPLICIT;
899 const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
900 FAIL() << "axes_mapping input should be provided if explicit mode is used";
902 catch (const NodeValidationFailure& error)
904 EXPECT_HAS_SUBSTRING(
906 std::string("axes_mapping input should be provided if explicit mode is used"));
910 FAIL() << "Deduced type check failed for unexpected reason";
914 TEST(type_prop, broadcast_v3_shape)
916 const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 4, 1});
917 const auto shape = op::Constant::create(element::i64, {2}, {1, 4});
918 const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
920 const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
922 ASSERT_EQ(broadcast_v3->get_element_type(), element::f32);
923 ASSERT_EQ(broadcast_v3->get_shape(), (Shape{1, 4, 4}));
924 ASSERT_EQ(broadcast_v3->get_broadcast_axes(), (make_pair<bool, AxisSet>(true, AxisSet{2})));
927 TEST(type_prop, broadcast_v3_shape_2)
929 const auto arg = make_shared<op::Parameter>(element::f32, Shape{3, 1});
930 const auto shape = op::Constant::create(element::i64, {3}, {2, 1, 6});
931 const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
933 const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
935 ASSERT_EQ(broadcast_v3->get_element_type(), element::f32);
936 ASSERT_EQ(broadcast_v3->get_shape(), (Shape{2, 3, 6}));
937 ASSERT_EQ(broadcast_v3->get_broadcast_axes(), (make_pair<bool, AxisSet>(true, AxisSet{0, 2})));
940 TEST(type_prop, broadcast_v3_shape_3)
942 const auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 1});
943 const auto shape = op::Constant::create(element::i64, {2}, {2, 4});
944 const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
946 const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
948 ASSERT_EQ(broadcast_v3->get_element_type(), element::f32);
949 ASSERT_EQ(broadcast_v3->get_shape(), (Shape{2, 4}));
950 ASSERT_EQ(broadcast_v3->get_broadcast_axes(), (make_pair<bool, AxisSet>(true, AxisSet{1})));
953 TEST(type_prop, broadcast_v3_shape_4)
955 const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 3, 1});
956 const auto shape = op::Constant::create(element::i64, {2}, {3, 1});
957 const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
959 const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
961 ASSERT_EQ(broadcast_v3->get_element_type(), element::f32);
962 ASSERT_EQ(broadcast_v3->get_shape(), (Shape{1, 3, 1}));
963 ASSERT_EQ(broadcast_v3->get_broadcast_axes(), (make_pair<bool, AxisSet>(true, AxisSet{})));
966 TEST(type_prop, broadcast_v3_shape_5)
968 const auto arg = make_shared<op::Parameter>(element::f32, Shape{16, 1, 1});
969 const auto shape = op::Constant::create(element::i64, {4}, {1, 1, 50, 50});
970 const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
972 const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
974 ASSERT_EQ(broadcast_v3->get_element_type(), element::f32);
975 ASSERT_EQ(broadcast_v3->get_shape(), (Shape{1, 16, 50, 50}));
976 ASSERT_EQ(broadcast_v3->get_broadcast_axes(),
977 (make_pair<bool, AxisSet>(true, AxisSet{0, 2, 3})));
980 TEST(type_prop, broadcast_v3_shape_6)
982 const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 3, 1});
983 const auto shape = op::Constant::create(element::i64, {3}, {3, 1, 3});
984 const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
986 const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
988 ASSERT_EQ(broadcast_v3->get_element_type(), element::f32);
989 ASSERT_EQ(broadcast_v3->get_shape(), (Shape{3, 3, 3}));
990 ASSERT_EQ(broadcast_v3->get_broadcast_axes(), (make_pair<bool, AxisSet>(true, AxisSet{0, 2})));
993 TEST(type_prop, broadcast_v3_shape_6_type_infer)
995 const auto arg = make_shared<op::Parameter>(element::u16, Shape{1, 3, 1});
996 const auto shape = op::Constant::create(element::i64, {3}, {3, 1, 3});
997 const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
999 const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
1001 ASSERT_EQ(broadcast_v3->get_element_type(), element::u16);
1002 ASSERT_EQ(broadcast_v3->get_shape(), (Shape{3, 3, 3}));
1003 ASSERT_EQ(broadcast_v3->get_broadcast_axes(), (make_pair<bool, AxisSet>(true, AxisSet{0, 2})));
1006 TEST(type_prop, broadcast_v3_incorrect_target_shape)
1008 const auto arg = make_shared<op::Parameter>(element::f32, Shape{4, 3, 2});
1009 const auto shape = op::Constant::create(element::i64, {3}, {8, 6, 4});
1010 const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
1014 const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
1015 FAIL() << "Not applicable breadcast exception not thrown";
1017 catch (const NodeValidationFailure& error)
1019 EXPECT_HAS_SUBSTRING(
1021 std::string("Broadcast incorrect target shape. Expecting either 1 or 4. Got 8"));
1025 FAIL() << "Deduced type check failed for unexpected reason";
1029 TEST(type_prop, broadcast_v3_incorrect_target_shape_2)
1031 const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 1, 2});
1032 const auto shape = op::Constant::create(element::i64, {2}, {2, 3});
1033 const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
1037 const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
1038 FAIL() << "Not applicable breadcast exception not thrown";
1040 catch (const NodeValidationFailure& error)
1042 EXPECT_HAS_SUBSTRING(
1044 std::string("Broadcast incorrect target shape. Expecting either 1 or 2. Got 3"));
1048 FAIL() << "Deduced type check failed for unexpected reason";
1052 TEST(type_prop, broadcast_v3_output_rank_not_deduced)
1054 const auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
1055 const auto shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
1056 const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
1058 const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
1060 ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
1063 TEST(type_prop, broadcast_v3_output_rank_deduced_from_arg)
1065 const auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
1066 const auto shape = op::Constant::create(element::i64, {3}, {8, 6, 4});
1067 const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
1069 const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
1070 ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme(
1071 PartialShape{Dimension::dynamic(), 8, 6, 4}));
1074 TEST(type_prop, broadcast_v3_output_rank_deduced_from_new_shape_input)
1076 const auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
1077 const auto shape = op::Constant::create(element::i64, {5}, {8, 6, 1, 5, 1});
1078 const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL;
1080 const auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, shape, broadcast_spec);
1081 ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static());
1082 ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 5);
1083 ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme(
1084 PartialShape{8, 6, Dimension::dynamic(), 5, Dimension::dynamic()}));
1087 TEST(type_prop, broadcast_v3_bidirectional_dynamic_input)
1089 const auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
1091 // dynamic target shape
1092 auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
1093 auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
1094 ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
1096 // static rank target shape
1097 target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
1098 broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
1099 ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
1101 // constant target shape
1102 const auto target_shape_const = op::Constant::create(element::i64, {3}, {2, 4, 6});
1103 broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape_const, "BIDIRECTIONAL");
1104 ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
1107 TEST(type_prop, broadcast_v3_bidirectional_static_rank_input)
1109 const auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
1111 // dynamic target shape
1112 auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
1113 auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
1114 ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
1116 // static rank target shape
1117 target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
1118 broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
1119 ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
1121 // constant target shape
1122 const auto target_shape_const = op::Constant::create(element::i64, {3}, {2, 4, 6});
1123 broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape_const, "BIDIRECTIONAL");
1124 ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static());
1125 ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 4);
1126 ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).is_dynamic());
1129 TEST(type_prop, broadcast_v3_bidirectional_static_shape_input)
1131 const auto arg = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 1});
1133 // dynamic target shape
1134 auto target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
1135 auto broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
1136 ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
1138 // static rank target shape
1139 target_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
1140 broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape, "BIDIRECTIONAL");
1141 ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic());
1143 // constant target shape
1144 auto target_shape_const = op::Constant::create(element::i64, {4}, {2, 2, 3, 2});
1145 broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape_const, "BIDIRECTIONAL");
1146 ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static());
1147 ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 4);
1148 ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).is_static());
1149 ASSERT_EQ(broadcast_v3->get_output_partial_shape(0), (PartialShape{2, 2, 3, 2}));
1151 target_shape_const = op::Constant::create(element::i64, {4}, {5, 2, 3, 7});
1152 broadcast_v3 = make_shared<op::v3::Broadcast>(arg, target_shape_const, "BIDIRECTIONAL");
1153 ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static());
1154 ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 4);
1155 ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).is_static());
1156 ASSERT_EQ(broadcast_v3->get_output_partial_shape(0), (PartialShape{5, 2, 3, 7}));
1159 TEST(type_prop, broadcast_v3_bidirectional_partially_dynamic_input)
1161 const auto target_shape =
1162 op::Constant::create(element::i64, Shape{4}, vector<int64_t>{1, 1, 50, 50});
1164 auto data = make_shared<op::Parameter>(element::f32, PartialShape{16, 1, Dimension::dynamic()});
1165 auto bc = make_shared<op::v3::Broadcast>(data, target_shape, "BIDIRECTIONAL");
1166 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
1167 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
1168 ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, 16, 50, 50}));
1170 data = make_shared<op::Parameter>(element::f32,
1171 PartialShape{Dimension::dynamic(), 1, Dimension::dynamic()});
1172 bc = make_shared<op::v3::Broadcast>(data, target_shape, "BIDIRECTIONAL");
1173 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
1174 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
1175 ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic(), 50, 50}));
1177 data = make_shared<op::Parameter>(element::f32,
1178 PartialShape{16, Dimension::dynamic(), Dimension::dynamic()});
1179 bc = make_shared<op::v3::Broadcast>(data, target_shape, "BIDIRECTIONAL");
1180 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
1181 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
1182 ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, 16, 50, 50}));
1184 data = make_shared<op::Parameter>(
1186 PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()});
1187 bc = make_shared<op::v3::Broadcast>(data, target_shape, "BIDIRECTIONAL");
1188 ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static());
1189 ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
1190 ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic(), 50, 50}));