1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
17 #include "gtest/gtest.h"
18 #include "ngraph/ngraph.hpp"
19 #include "util/type_prop.hpp"
21 NGRAPH_SUPPRESS_DEPRECATED_START
24 using namespace ngraph;
27 // Tests for binary elementwise ops.
// Shared driver for the binary elementwise arithmetic tests.
// `f` builds the op under test from two input nodes; the first parameter (the
// node type name) is deliberately unnamed and unused.
// NOTE(review): several original lines (braces, `try`, the catch-all handlers
// and the statements that define `node`) are not visible in this extract;
// confirm against the full file before relying on the notes below.
29 void test_binary(std::string /* node_type */,
30 shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y))
32 // Check for bad arguments
// Fixtures: three {2, 4} parameters (two f32, one i32) plus one f32 parameter
// with shape {4, 2}.
33 auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
34 auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
35 auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::i32, Shape{2, 4});
36 auto tv0_4_2_param = make_shared<op::Parameter>(element::f32, Shape{4, 2});
// Case 1: mismatched shapes. The test fails unless a NodeValidationFailure is
// caught whose message contains "Argument shapes are inconsistent".
38 auto test_binary_bad_arguments_view_shapes = [&](const shared_ptr<Node>& x,
39 const shared_ptr<Node>& y) {
43 // Should have thrown, so fail if it didn't
44 FAIL() << "Incompatible view arguments not detected.";
46 catch (const NodeValidationFailure& error)
48 EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
// Presumably the body of a catch-all handler (its header is not visible here).
52 FAIL() << "Deduced type check failed for unexpected reason";
// Exercised with {2, 4} versus {4, 2}, both f32.
55 test_binary_bad_arguments_view_shapes(tv0_2_4_param_0, tv0_4_2_param);
// Case 2: mismatched element types. Expects a NodeValidationFailure whose
// message contains "Argument element types are inconsistent".
57 auto test_binary_bad_arguments_view_element_types = [&](const shared_ptr<Node>& x,
58 const shared_ptr<Node>& y) {
62 // Should have thrown, so fail if it didn't
63 FAIL() << "Incompatible view arguments not detected.";
65 catch (const NodeValidationFailure& error)
67 EXPECT_HAS_SUBSTRING(error.what(),
68 std::string("Argument element types are inconsistent"));
72 FAIL() << "Deduced type check failed for unexpected reason";
// Exercised with f32 versus i32, both of shape {2, 4}.
76 test_binary_bad_arguments_view_element_types(tv0_2_4_param_0, tv0_2_4_param_2);
// Case 3: compatible inputs. The resulting node must report the same type as
// the node feeding its first input.
78 auto test_binary_good_arguments = [&](const shared_ptr<Node>& x, const shared_ptr<Node>& y) {
80 EXPECT_TRUE(node->has_same_type(node->input_values()[0].get_node_shared_ptr()));
82 test_binary_good_arguments(tv0_2_4_param_0, tv0_2_4_param_1);
// Supplies a factory lambda that wraps its two inputs in op::Add.
// NOTE(review): the line that hands this lambda to the shared driver
// (presumably test_binary) is not visible in this extract.
85 TEST(type_prop, add_bad_arguments)
88 [](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
89 return make_shared<op::Add>(x, y);
// Supplies a factory lambda that wraps its two inputs in op::Divide.
// NOTE(review): the line that hands this lambda to the shared driver
// (presumably test_binary) is not visible in this extract.
93 TEST(type_prop, divide_bad_arguments)
96 [](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
97 return make_shared<op::Divide>(x, y);
// Runs the shared test_binary checks (shape mismatch, element type mismatch,
// and the compatible-input case) with a factory that builds op::Multiply.
101 TEST(type_prop, multiply_bad_arguments)
103 test_binary("Multiply",
104 [](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
105 return make_shared<op::Multiply>(x, y);
// Runs the shared test_binary checks (shape mismatch, element type mismatch,
// and the compatible-input case) with a factory that builds op::Subtract.
109 TEST(type_prop, subtract_bad_arguments)
111 test_binary("Subtract",
112 [](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
113 return make_shared<op::Subtract>(x, y);
118 // Tests for binary elementwise logical ops.
// Shared driver for the binary elementwise logical tests.
// `f` builds the op under test from two input nodes; the first parameter (the
// node type name) is deliberately unnamed and unused.
// NOTE(review): several original lines (braces, `try`, the catch-all handlers
// and the statements that define `node`) are not visible in this extract;
// confirm against the full file before relying on the notes below.
120 void test_binary_logical(std::string /* node_type */,
121 shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y))
123 // Check for bad arguments
// Fixtures: two boolean and two i32 parameters of shape {2, 4}, plus one
// boolean parameter of shape {4, 2}.
124 auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
125 auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
126 auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::i32, Shape{2, 4});
127 auto tv0_2_4_param_3 = make_shared<op::Parameter>(element::i32, Shape{2, 4});
128 auto tv0_4_2_param = make_shared<op::Parameter>(element::boolean, Shape{4, 2});
// Case 1: mismatched shapes. Expects a NodeValidationFailure whose message
// contains "Argument shapes are inconsistent".
130 auto test_binary_bad_arguments_view_shapes = [&](const shared_ptr<Node>& x,
131 const shared_ptr<Node>& y) {
135 // Should have thrown, so fail if it didn't
136 FAIL() << "Incompatible view arguments not detected.";
138 catch (const NodeValidationFailure& error)
140 EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
// Presumably the body of a catch-all handler (its header is not visible here).
144 FAIL() << "Deduced type check failed for unexpected reason";
147 test_binary_bad_arguments_view_shapes(tv0_2_4_param_0, tv0_4_2_param);
// Case 2: the two inputs have different element types. Expects a
// NodeValidationFailure mentioning "Argument element types are inconsistent".
149 auto test_binary_differ_arguments_view_element_types = [&](const shared_ptr<Node>& x,
150 const shared_ptr<Node>& y) {
154 // Should have thrown, so fail if it didn't
155 FAIL() << "Incompatible view arguments not detected.";
157 catch (const NodeValidationFailure& error)
159 EXPECT_HAS_SUBSTRING(error.what(),
160 std::string("Argument element types are inconsistent"));
164 FAIL() << "Deduced type check failed for unexpected reason";
// Case 3: both inputs share a non-boolean element type. Unlike the cases above
// this one catches the broader ngraph_error and looks for
// "must have boolean element type".
168 auto test_binary_non_bool_arguments_view_element_types = [&](const shared_ptr<Node>& x,
169 const shared_ptr<Node>& y) {
173 // Should have thrown, so fail if it didn't
174 FAIL() << "Incompatible view arguments not detected.";
176 catch (const ngraph_error& error)
178 EXPECT_HAS_SUBSTRING(error.what(), "must have boolean element type");
182 FAIL() << "Deduced type check failed for unexpected reason";
// The differing-type case is run in both argument orders (boolean/i32 and
// i32/boolean); the non-boolean case uses two i32 inputs.
187 test_binary_differ_arguments_view_element_types(tv0_2_4_param_0, tv0_2_4_param_2);
188 test_binary_differ_arguments_view_element_types(tv0_2_4_param_2, tv0_2_4_param_0);
189 test_binary_non_bool_arguments_view_element_types(tv0_2_4_param_2, tv0_2_4_param_3);
// Case 4: compatible boolean inputs. The resulting node must report the same
// type as the node feeding its first input.
191 auto test_binary_good_arguments = [&](const shared_ptr<Node>& x, const shared_ptr<Node>& y) {
193 EXPECT_TRUE(node->has_same_type(node->input_values()[0].get_node_shared_ptr()));
195 test_binary_good_arguments(tv0_2_4_param_0, tv0_2_4_param_1);
// Supplies the name "Or" and a factory lambda that wraps its inputs in op::Or.
// NOTE(review): the opening line of the call that receives these arguments
// (presumably test_binary_logical) is not visible in this extract.
198 TEST(type_prop, or_bad_arguments)
201 "Or", [](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
202 return make_shared<op::Or>(x, y);
// Supplies the name "Xor" and a factory lambda that wraps its inputs in op::Xor.
// NOTE(review): the opening line of the call that receives these arguments
// (presumably test_binary_logical) is not visible in this extract.
206 TEST(type_prop, xor_bad_arguments)
209 "Xor", [](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
210 return make_shared<op::Xor>(x, y);
// Checks the output shape of binary op `T` when its inputs have different
// shapes and the op is constructed with the auto-broadcast spec `autob`.
// `et` is the element type used for every parameter.
214 template <typename T>
215 void test_binary_eltwise_numpy(const element::Type& et, const op::AutoBroadcastSpec& autob)
// Static input shapes of differing rank and extent.
217 auto param1 = make_shared<op::Parameter>(et, Shape{1, 3, 6});
218 auto param2 = make_shared<op::Parameter>(et, Shape{3, 1});
219 auto param3 = make_shared<op::Parameter>(et, Shape{2, 3, 6});
220 auto param4 = make_shared<op::Parameter>(et, Shape{6});
// Expected results: {1,3,6} with {3,1} gives {1,3,6}; {1,3,6} with {2,3,6}
// gives {2,3,6}; {6} with {2,3,6} gives {2,3,6}.
221 EXPECT_EQ(make_shared<T>(param1, param2, autob)->get_shape(), (Shape{1, 3, 6}));
222 EXPECT_EQ(make_shared<T>(param1, param3, autob)->get_shape(), (Shape{2, 3, 6}));
223 EXPECT_EQ(make_shared<T>(param4, param3, autob)->get_shape(), (Shape{2, 3, 6}));
// Partially dynamic input: the dynamic middle dimension of {1, ?, 6} is
// expected to resolve to 3 when combined with {3, 1}.
225 auto pp1 = make_shared<op::Parameter>(et, PartialShape{1, Dimension::dynamic(), 6});
226 auto pp2 = make_shared<op::Parameter>(et, PartialShape{3, 1});
227 EXPECT_EQ(make_shared<T>(pp1, pp2, autob)->get_shape(), (Shape{1, 3, 6}));
// Runs the broadcast shape checks from test_binary_eltwise_numpy with
// AutoBroadcastType::NUMPY for each listed op. The arithmetic and comparison
// ops are instantiated with f32; Or and Xor are instantiated with boolean.
// Add is referenced as op::v1::Add; the other ops are named directly under op::.
230 TEST(type_prop, eltwise_auto_bcast)
232 test_binary_eltwise_numpy<op::v1::Add>(element::f32, op::AutoBroadcastType::NUMPY);
233 test_binary_eltwise_numpy<op::Divide>(element::f32, op::AutoBroadcastType::NUMPY);
234 test_binary_eltwise_numpy<op::Equal>(element::f32, op::AutoBroadcastType::NUMPY);
235 test_binary_eltwise_numpy<op::Greater>(element::f32, op::AutoBroadcastType::NUMPY);
236 test_binary_eltwise_numpy<op::GreaterEq>(element::f32, op::AutoBroadcastType::NUMPY);
237 test_binary_eltwise_numpy<op::Less>(element::f32, op::AutoBroadcastType::NUMPY);
238 test_binary_eltwise_numpy<op::LessEq>(element::f32, op::AutoBroadcastType::NUMPY);
239 test_binary_eltwise_numpy<op::Maximum>(element::f32, op::AutoBroadcastType::NUMPY);
240 test_binary_eltwise_numpy<op::Minimum>(element::f32, op::AutoBroadcastType::NUMPY);
241 test_binary_eltwise_numpy<op::Multiply>(element::f32, op::AutoBroadcastType::NUMPY);
242 test_binary_eltwise_numpy<op::NotEqual>(element::f32, op::AutoBroadcastType::NUMPY);
243 test_binary_eltwise_numpy<op::Or>(element::boolean, op::AutoBroadcastType::NUMPY);
244 test_binary_eltwise_numpy<op::Power>(element::f32, op::AutoBroadcastType::NUMPY);
245 test_binary_eltwise_numpy<op::Subtract>(element::f32, op::AutoBroadcastType::NUMPY);
246 test_binary_eltwise_numpy<op::Xor>(element::boolean, op::AutoBroadcastType::NUMPY);
// op::Equal on two f32 inputs of shape {2, 4} must produce a boolean output
// with the same {2, 4} shape.
249 TEST(type_prop, comparison_good)
251 auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
252 auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
253 auto eq = make_shared<op::Equal>(tv0_2_4_param_0, tv0_2_4_param_1);
254 EXPECT_EQ(eq->get_element_type(), element::boolean);
255 EXPECT_EQ(eq->get_shape(), (Shape{2, 4}));
// Constructing op::Add from two boolean inputs is expected to throw a
// NodeValidationFailure whose message contains
// "Arguments cannot have boolean element type".
// NOTE(review): the `try` line and the header of the final handler are not
// visible in this extract.
258 TEST(type_prop, binary_arithmetic_bad_argument_element_types)
260 auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
261 auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
264 auto bc = make_shared<op::Add>(tv0_2_4_param_0, tv0_2_4_param_1);
265 // Should have thrown, so fail if it didn't
266 FAIL() << "Did not detect incorrect element types for arithmetic operator";
268 catch (const NodeValidationFailure& error)
270 EXPECT_HAS_SUBSTRING(error.what(),
271 std::string("Arguments cannot have boolean element type"));
// Presumably the body of a catch-all handler (its header is not visible here).
275 FAIL() << "Deduced type check failed for unexpected reason";
// Both inputs are f32 with a fully dynamic PartialShape; the Add output is
// expected to have dynamic rank as well.
279 TEST(type_prop, binary_elementwise_arithmetic_both_dynamic)
281 auto a = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
282 auto b = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
283 auto add = make_shared<op::Add>(a, b);
285 ASSERT_TRUE(add->get_output_partial_shape(0).rank().is_dynamic());
// Left input has a fully dynamic shape, right input is static {1, 2, 3}; the
// Add output is expected to be static with shape {1, 2, 3}.
288 TEST(type_prop, binary_elementwise_arithmetic_left_rank_dynamic_right_static)
290 auto a = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
291 auto b = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
292 auto add = make_shared<op::Add>(a, b);
294 ASSERT_TRUE(add->get_output_partial_shape(0).is_static());
295 ASSERT_EQ(add->get_shape(), (Shape{1, 2, 3}));
// Mirror of the previous case: left input is static {1, 2, 3}, right input has
// a fully dynamic shape; the Add output is expected to be static {1, 2, 3}.
298 TEST(type_prop, binary_elementwise_arithmetic_left_static_right_rank_dynamic)
300 auto a = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
301 auto b = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
302 auto add = make_shared<op::Add>(a, b);
304 ASSERT_TRUE(add->get_output_partial_shape(0).is_static());
305 ASSERT_EQ(add->get_shape(), (Shape{1, 2, 3}));
// Left input is {1, ?, 3} (static rank, one dynamic dimension), right input is
// fully dynamic. The output is expected to keep a static rank while remaining
// dynamic, and its shape is compared with {1, ?, 3} via same_scheme.
// NOTE(review): the opening line of the assertion that wraps the same_scheme
// call is not visible in this extract.
308 TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_right_rank_dynamic)
310 auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 3});
311 auto b = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
312 auto add = make_shared<op::Add>(a, b);
314 ASSERT_TRUE(add->get_output_partial_shape(0).rank().is_static());
315 ASSERT_TRUE(add->get_output_partial_shape(0).is_dynamic());
317 add->get_output_partial_shape(0).same_scheme(PartialShape{1, Dimension::dynamic(), 3}));
// Mirror of the previous case: left input is fully dynamic, right input is
// {1, ?, 3}. The output is expected to keep a static rank while remaining
// dynamic, and its shape is compared with {1, ?, 3} via same_scheme.
// NOTE(review): the opening line of the assertion that wraps the same_scheme
// call is not visible in this extract.
320 TEST(type_prop, binary_elementwise_arithmetic_left_rank_dynamic_right_rank_static_dynamic)
322 auto a = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
323 auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 3});
324 auto add = make_shared<op::Add>(a, b);
326 ASSERT_TRUE(add->get_output_partial_shape(0).rank().is_static());
327 ASSERT_TRUE(add->get_output_partial_shape(0).is_dynamic());
329 add->get_output_partial_shape(0).same_scheme(PartialShape{1, Dimension::dynamic(), 3}));
// Inputs {1, ?, 3} and {1, 2, ?}: each dynamic dimension is matched by a static
// one on the other side, so the Add output is expected to be static {1, 2, 3}.
// NOTE(review): the macro line that opens this test (presumably
// `TEST(type_prop,`) is not visible in this extract.
333 binary_elementwise_arithmetic_left_rank_static_dynamic_right_rank_static_dynamic_result_static)
335 auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 3});
336 auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
337 auto add = make_shared<op::Add>(a, b);
339 ASSERT_TRUE(add->get_output_partial_shape(0).is_static());
340 ASSERT_EQ(add->get_shape(), (Shape{1, 2, 3}));
// Inputs {1, ?, ?} and {1, 2, ?}: the last dimension is dynamic on both sides,
// so the output is expected to have static rank but stay dynamic, and its
// shape is compared with {1, 2, ?} via same_scheme.
// NOTE(review): the macro line that opens this test (presumably
// `TEST(type_prop,`) and the opening line of the assertion wrapping the
// same_scheme call are not visible in this extract.
345 binary_elementwise_arithmetic_left_rank_static_dynamic_right_rank_static_dynamic_result_rank_static_dynamic)
347 auto a = make_shared<op::Parameter>(
348 element::f32, PartialShape{1, Dimension::dynamic(), Dimension::dynamic()});
349 auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
350 auto add = make_shared<op::Add>(a, b);
352 ASSERT_TRUE(add->get_output_partial_shape(0).rank().is_static());
353 ASSERT_TRUE(add->get_output_partial_shape(0).is_dynamic());
355 add->get_output_partial_shape(0).same_scheme(PartialShape{1, 2, Dimension::dynamic()}));
// Left input is fully static {1, 2, 3}, right input is {1, 2, ?}; the Add
// output is expected to be static with shape {1, 2, 3}.
358 TEST(type_prop, binary_elementwise_arithmetic_left_static_right_rank_static_dynamic)
360 auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
361 auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
362 auto add = make_shared<op::Add>(a, b);
364 ASSERT_TRUE(add->get_output_partial_shape(0).is_static());
365 ASSERT_EQ(add->get_shape(), (Shape{1, 2, 3}));
// Mirror of the previous case: left input is {1, 2, ?}, right input is fully
// static {1, 2, 3}; the Add output is expected to be static {1, 2, 3}.
368 TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_right_static)
370 auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
371 auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
372 auto add = make_shared<op::Add>(a, b);
374 ASSERT_TRUE(add->get_output_partial_shape(0).is_static());
375 ASSERT_EQ(add->get_shape(), (Shape{1, 2, 3}));
// Inputs {1, 2, ?} and {1, 3, 3} disagree in their second dimension (2 vs 3).
// Building the Add is expected to throw a NodeValidationFailure whose message
// contains "Argument shapes are inconsistent".
// NOTE(review): the `try` line and the header of the final handler are not
// visible in this extract.
378 TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_inconsistent)
380 auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
381 auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 3, 3});
385 auto add = make_shared<op::Add>(a, b);
// Reached only when construction did not throw.
386 FAIL() << "Inconsistent partial shapes not detected";
388 catch (const NodeValidationFailure& error)
390 EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
394 FAIL() << "Deduced type check failed for unexpected reason";
// Mirror of the previous case: inputs {1, 3, 3} and {1, 2, ?} disagree in
// their second dimension (3 vs 2). Building the Add is expected to throw a
// NodeValidationFailure mentioning "Argument shapes are inconsistent".
// NOTE(review): the `try` line and the header of the final handler are not
// visible in this extract.
398 TEST(type_prop, binary_elementwise_arithmetic_right_rank_static_dynamic_inconsistent)
400 auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 3, 3});
401 auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
405 auto add = make_shared<op::Add>(a, b);
// Reached only when construction did not throw.
406 FAIL() << "Inconsistent partial shapes not detected";
408 catch (const NodeValidationFailure& error)
410 EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
414 FAIL() << "Deduced type check failed for unexpected reason";
// Both inputs have a dynamic dimension ({?, 3, 3} and {1, 2, ?}), but their
// static second dimensions still disagree (3 vs 2). Building the Add is
// expected to throw a NodeValidationFailure mentioning
// "Argument shapes are inconsistent".
// NOTE(review): the `try` line and the header of the final handler are not
// visible in this extract.
418 TEST(type_prop, binary_elementwise_arithmetic_both_rank_static_dynamic_inconsistent)
420 auto a = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 3, 3});
421 auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
425 auto add = make_shared<op::Add>(a, b);
// Reached only when construction did not throw.
426 FAIL() << "Inconsistent partial shapes not detected";
428 catch (const NodeValidationFailure& error)
430 EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
434 FAIL() << "Deduced type check failed for unexpected reason";
// Inputs of different rank: {1, 2, ?} (three dimensions) against {1, 2, 3, 4}
// (four dimensions). Building the Add is expected to throw a
// NodeValidationFailure mentioning "Argument shapes are inconsistent".
// NOTE(review): the `try` line and the header of the final handler are not
// visible in this extract.
438 TEST(type_prop, binary_elementwise_arithmetic_left_rank_static_dynamic_different_rank)
440 auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
441 auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4});
445 auto add = make_shared<op::Add>(a, b);
// Reached only when construction did not throw.
446 FAIL() << "Inconsistent partial shapes not detected";
448 catch (const NodeValidationFailure& error)
450 EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
454 FAIL() << "Deduced type check failed for unexpected reason";
// Mirror of the previous case: {1, 2, 3, 4} (four dimensions) against
// {1, 2, ?} (three dimensions). Building the Add is expected to throw a
// NodeValidationFailure mentioning "Argument shapes are inconsistent".
// NOTE(review): the `try` line and the header of the final handler are not
// visible in this extract.
458 TEST(type_prop, binary_elementwise_arithmetic_right_rank_static_dynamic_different_rank)
460 auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3, 4});
461 auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
465 auto add = make_shared<op::Add>(a, b);
// Reached only when construction did not throw.
466 FAIL() << "Inconsistent partial shapes not detected";
468 catch (const NodeValidationFailure& error)
470 EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
474 FAIL() << "Deduced type check failed for unexpected reason";
// Both inputs have a dynamic dimension but different ranks: {1, ?, 3, 4}
// (four dimensions) against {1, 2, ?} (three dimensions). Building the Add is
// expected to throw a NodeValidationFailure mentioning
// "Argument shapes are inconsistent".
// NOTE(review): the `try` line and the header of the final handler are not
// visible in this extract.
478 TEST(type_prop, binary_elementwise_arithmetic_both_rank_static_dynamic_different_rank)
480 auto a = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 3, 4});
481 auto b = make_shared<op::Parameter>(element::f32, PartialShape{1, 2, Dimension::dynamic()});
485 auto add = make_shared<op::Add>(a, b);
// Reached only when construction did not throw.
486 FAIL() << "Inconsistent partial shapes not detected";
488 catch (const NodeValidationFailure& error)
490 EXPECT_HAS_SUBSTRING(error.what(), "Argument shapes are inconsistent");
494 FAIL() << "Deduced type check failed for unexpected reason";
// Both inputs use element::dynamic with the same static shape; the Add output
// element type is expected to be dynamic too.
498 TEST(type_prop, binary_elementwise_arithmetic_both_et_dynamic)
500 auto a = make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3, 4});
501 auto b = make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3, 4});
502 auto add = make_shared<op::Add>(a, b);
504 ASSERT_TRUE(add->get_output_element_type(0).is_dynamic());
// Left input uses element::dynamic, right input is u32; the Add output element
// type is expected to be u32.
507 TEST(type_prop, binary_elementwise_arithmetic_left_et_dynamic)
509 auto a = make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3, 4});
510 auto b = make_shared<op::Parameter>(element::u32, Shape{1, 2, 3, 4});
511 auto add = make_shared<op::Add>(a, b);
513 ASSERT_EQ(add->get_output_element_type(0), element::u32);
// Left input is i64, right input uses element::dynamic; the Add output element
// type is expected to be i64.
516 TEST(type_prop, binary_elementwise_arithmetic_right_et_dynamic)
518 auto a = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
519 auto b = make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3, 4});
520 auto add = make_shared<op::Add>(a, b);
522 ASSERT_EQ(add->get_output_element_type(0), element::i64);
// Checks element type deduction for an arithmetic op (Add), a comparison op
// (Greater) and logical negation (Not) across combinations of i32, boolean and
// dynamic input element types.
// NOTE(review): this extract drops several original lines, including the
// second argument of some ASSERT_EQ calls below and the closing lines of the
// lambdas; confirm expected values against the full file.
525 TEST(type_prop, logic_arith_compare_partial_et)
// Builds op::Add over two {1, 2, 3} parameters with the given element types.
527 auto test_arith = [](element::Type et0, element::Type et1) -> std::shared_ptr<Node> {
528 auto param0 = std::make_shared<op::Parameter>(et0, Shape{1, 2, 3});
529 auto param1 = std::make_shared<op::Parameter>(et1, Shape{1, 2, 3});
530 return std::make_shared<op::Add>(param0, param1);
// Builds op::Greater over two {1, 2, 3} parameters with the given element types.
533 auto test_compare = [](element::Type et0, element::Type et1) -> std::shared_ptr<Node> {
534 auto param0 = std::make_shared<op::Parameter>(et0, Shape{1, 2, 3});
535 auto param1 = std::make_shared<op::Parameter>(et1, Shape{1, 2, 3});
536 return std::make_shared<op::Greater>(param0, param1);
// Builds op::Not over a single {1, 2, 3} parameter with the given element type.
539 auto test_not = [](element::Type et) -> std::shared_ptr<Node> {
540 auto param = std::make_shared<op::Parameter>(et, Shape{1, 2, 3});
541 return std::make_shared<op::Not>(param);
// Arithmetic: any boolean input must throw; i32 paired with i32 or dynamic
// yields i32; dynamic paired with dynamic stays dynamic.
555 ASSERT_EQ(test_arith(element::i32, element::i32)->get_element_type(), element::i32);
556 ASSERT_ANY_THROW({ test_arith(element::i32, element::boolean); });
557 ASSERT_EQ(test_arith(element::i32, element::dynamic)->get_element_type(), element::i32);
558 ASSERT_ANY_THROW({ test_arith(element::boolean, element::i32); });
559 ASSERT_ANY_THROW({ test_arith(element::boolean, element::boolean); });
560 ASSERT_ANY_THROW({ test_arith(element::boolean, element::dynamic); });
561 ASSERT_EQ(test_arith(element::dynamic, element::i32)->get_element_type(), element::i32);
562 ASSERT_ANY_THROW({ test_arith(element::dynamic, element::boolean); });
563 ASSERT_EQ(test_arith(element::dynamic, element::dynamic)->get_element_type(), element::dynamic);
// Comparison: mixing i32 with boolean must throw; the i32/i32, i32/dynamic and
// dynamic/i32 cases yield boolean. The expected values for the remaining
// combinations sit on lines that are not visible in this extract.
576 ASSERT_EQ(test_compare(element::i32, element::i32)->get_element_type(), element::boolean);
577 ASSERT_ANY_THROW({ test_compare(element::i32, element::boolean); });
578 ASSERT_EQ(test_compare(element::i32, element::dynamic)->get_element_type(), element::boolean);
579 ASSERT_ANY_THROW({ test_compare(element::boolean, element::i32); });
580 ASSERT_EQ(test_compare(element::boolean, element::boolean)->get_element_type(),
582 ASSERT_EQ(test_compare(element::boolean, element::dynamic)->get_element_type(),
584 ASSERT_EQ(test_compare(element::dynamic, element::i32)->get_element_type(), element::boolean);
585 ASSERT_EQ(test_compare(element::dynamic, element::boolean)->get_element_type(),
587 ASSERT_EQ(test_compare(element::dynamic, element::dynamic)->get_element_type(),
590 // Logical negation op:
597 // TODO(amprocte): I believe the behavior should actually be:
// As asserted here, Not reports the same element type as its input for i32,
// boolean and dynamic.
601 ASSERT_EQ(test_not(element::i32)->get_element_type(), element::i32);
602 ASSERT_EQ(test_not(element::boolean)->get_element_type(), element::boolean);
603 ASSERT_EQ(test_not(element::dynamic)->get_element_type(), element::dynamic);