1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
17 #include "gtest/gtest.h"
18 #include "ngraph/ngraph.hpp"
19 #include "util/type_prop.hpp"
21 NGRAPH_SUPPRESS_DEPRECATED_START
24 using namespace ngraph;
26 TEST(type_prop, select_deduce)
28 auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
29 auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
30 auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
31 auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
32 ASSERT_EQ(bc->get_element_type(), element::f32);
33 ASSERT_EQ(bc->get_shape(), (Shape{2, 4}));
36 TEST(type_prop, select_shape_mismatch_a)
38 auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::boolean, Shape{3, 5});
39 auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
40 auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
43 auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
44 // Should have thrown, so fail if it didn't
45 FAIL() << "Did not detect incorrect element types for arithmetic operator";
47 catch (const NodeValidationFailure& error)
49 EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
53 FAIL() << "Deduced type check failed for unexpected reason";
57 TEST(type_prop, select_shape_mismatch_b)
59 auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
60 auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::f32, Shape{3, 5});
61 auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
64 auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
65 // Should have thrown, so fail if it didn't
66 FAIL() << "Did not detect incorrect element types for arithmetic operator";
68 catch (const NodeValidationFailure& error)
70 EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
74 FAIL() << "Deduced type check failed for unexpected reason";
78 TEST(type_prop, select_shape_mismatch_c)
80 auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
81 auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
82 auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::f32, Shape{3, 5});
85 auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
86 // Should have thrown, so fail if it didn't
87 FAIL() << "Did not detect incorrect element types for arithmetic operator";
89 catch (const NodeValidationFailure& error)
91 EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
95 FAIL() << "Deduced type check failed for unexpected reason";
99 TEST(type_prop, select_elem_mismatch_a)
101 auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
102 auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
103 auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
106 auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
107 // Should have thrown, so fail if it didn't
108 FAIL() << "Did not detect incorrect element types for arithmetic operator";
110 catch (const NodeValidationFailure& error)
112 EXPECT_HAS_SUBSTRING(error.what(),
113 std::string("Argument 0 must have boolean element type"));
117 FAIL() << "Deduced type check failed for unexpected reason";
121 TEST(type_prop, select_elem_mismatch_bc)
123 auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
124 auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
125 auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::i32, Shape{2, 4});
128 auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
129 // Should have thrown, so fail if it didn't
130 FAIL() << "Did not detect incorrect element types for arithmetic operator";
132 catch (const NodeValidationFailure& error)
134 EXPECT_HAS_SUBSTRING(error.what(),
135 std::string("Argument 1 and 2 element types are inconsistent"));
139 FAIL() << "Deduced type check failed for unexpected reason";
143 TEST(type_prop, select_partial_all_rank_dynamic)
145 auto param0 = make_shared<op::Parameter>(element::boolean, PartialShape::dynamic());
146 auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
147 auto param2 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
149 auto sel = make_shared<op::Select>(param0, param1, param2);
151 ASSERT_EQ(sel->get_output_element_type(0), element::f32);
152 ASSERT_TRUE(sel->get_output_partial_shape(0).rank().is_dynamic());
155 TEST(type_prop, select_partial_all_rank_dynamic_arg0_et_dynamic_arg1_arg2_et_mismatch)
157 auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
158 auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
159 auto param2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
163 auto sel = make_shared<op::Select>(param0, param1, param2);
164 FAIL() << "Did not detect mismatched element types for args 1 and 2 (element type-dynamic "
167 catch (const NodeValidationFailure& error)
169 EXPECT_HAS_SUBSTRING(error.what(),
170 std::string("Argument 1 and 2 element types are inconsistent"));
174 FAIL() << "Deduced type check failed for unexpected reason";
178 TEST(type_prop, select_partial_all_rank_dynamic_arg0_arg1_et_dynamic)
180 auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
181 auto param1 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
182 auto param2 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
184 auto sel = make_shared<op::Select>(param0, param1, param2);
186 ASSERT_EQ(sel->get_output_element_type(0), element::f32);
187 ASSERT_TRUE(sel->get_output_partial_shape(0).rank().is_dynamic());
190 TEST(type_prop, select_partial_all_rank_dynamic_arg0_arg2_et_dynamic)
192 auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
193 auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
194 auto param2 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
196 auto sel = make_shared<op::Select>(param0, param1, param2);
198 ASSERT_EQ(sel->get_output_element_type(0), element::f32);
199 ASSERT_TRUE(sel->get_output_partial_shape(0).rank().is_dynamic());
202 TEST(type_prop, select_partial_all_rank_dynamic_arg0_arg1_arg2_et_dynamic)
204 auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
205 auto param1 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
206 auto param2 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
208 auto sel = make_shared<op::Select>(param0, param1, param2);
210 ASSERT_EQ(sel->get_output_element_type(0), element::dynamic);
211 ASSERT_TRUE(sel->get_output_partial_shape(0).rank().is_dynamic());
214 TEST(type_prop, select_partial_arg0_rank_dynamic_static_arg1_arg2_rank_dynamic_ok)
217 make_shared<op::Parameter>(element::boolean, PartialShape{2, Dimension::dynamic(), 3});
218 auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
219 auto param2 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
221 auto sel = make_shared<op::Select>(param0, param1, param2);
223 ASSERT_EQ(sel->get_output_element_type(0), element::f32);
225 sel->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic(), 3}));
228 TEST(type_prop, select_partial_arg1_rank_dynamic_static_arg0_arg2_rank_dynamic_ok)
230 auto param0 = make_shared<op::Parameter>(element::boolean, PartialShape::dynamic());
232 make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 3});
233 auto param2 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
235 auto sel = make_shared<op::Select>(param0, param1, param2);
237 ASSERT_EQ(sel->get_output_element_type(0), element::f32);
239 sel->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic(), 3}));
242 TEST(type_prop, select_partial_arg2_rank_dynamic_static_arg0_arg1_rank_dynamic_ok)
244 auto param0 = make_shared<op::Parameter>(element::boolean, PartialShape::dynamic());
245 auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
247 make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 3});
249 auto sel = make_shared<op::Select>(param0, param1, param2);
251 ASSERT_EQ(sel->get_output_element_type(0), element::f32);
253 sel->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic(), 3}));
256 TEST(type_prop, select_partial_all_rank_static_dynamic_ok)
258 auto param0 = make_shared<op::Parameter>(
259 element::boolean, PartialShape{2, Dimension::dynamic(), Dimension::dynamic()});
260 auto param1 = make_shared<op::Parameter>(
261 element::f32, PartialShape{Dimension::dynamic(), 8, Dimension::dynamic()});
262 auto param2 = make_shared<op::Parameter>(
263 element::f32, PartialShape{Dimension::dynamic(), Dimension::dynamic(), 3});
265 auto sel = make_shared<op::Select>(param0, param1, param2);
267 ASSERT_EQ(sel->get_output_element_type(0), element::f32);
268 ASSERT_TRUE(sel->get_output_partial_shape(0).is_static());
269 ASSERT_EQ(sel->get_output_shape(0), (Shape{2, 8, 3}));
272 TEST(type_prop, select_partial_all_rank_static_intransitive_incompatibility)
274 auto param0 = make_shared<op::Parameter>(
275 element::boolean, PartialShape{2, Dimension::dynamic(), Dimension::dynamic()});
276 auto param1 = make_shared<op::Parameter>(
277 element::f32, PartialShape{Dimension::dynamic(), 8, Dimension::dynamic()});
279 make_shared<op::Parameter>(element::f32, PartialShape{3, Dimension::dynamic(), 3});
283 auto sel = make_shared<op::Select>(param0, param1, param2);
284 FAIL() << "Did not detect intransitive partial-shape incompatibility";
286 catch (const NodeValidationFailure& error)
288 EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
292 FAIL() << "Deduced type check failed for unexpected reason";
296 //------------------------------ v1::Select ---------------------------------//
301 std::vector<Shape> shapes;
302 std::vector<element::Type> ets;
303 op::AutoBroadcastSpec auto_broadcast;
305 SelectParams(const std::vector<Shape>& shape,
306 const std::vector<element::Type>& et,
307 const op::AutoBroadcastSpec& auto_broadcast)
310 , auto_broadcast(auto_broadcast)
315 struct DeduceV1SelectTest : ::testing::TestWithParam<SelectParams>
319 TEST_P(DeduceV1SelectTest, output_shape)
321 auto tp = GetParam();
322 auto cond = make_shared<op::Parameter>(tp.ets[0], tp.shapes[0]);
323 auto ptrue = make_shared<op::Parameter>(tp.ets[1], tp.shapes[1]);
324 auto pfalse = make_shared<op::Parameter>(tp.ets[2], tp.shapes[2]);
325 auto select = make_shared<op::v1::Select>(cond, ptrue, pfalse, tp.auto_broadcast);
327 ASSERT_EQ(select->get_shape(), tp.shapes[3]);
328 ASSERT_EQ(select->get_element_type(), tp.ets[3]);
331 INSTANTIATE_TEST_CASE_P(
334 ::testing::Values(SelectParams({{2, 4}, {2, 4}, {2, 4}, {2, 4}},
335 {element::boolean, element::f32, element::f32, element::f32},
336 op::AutoBroadcastType::NONE),
337 SelectParams({{2, 4}, {2, 4}, {2, 4}, {2, 4}},
338 {element::boolean, element::f32, element::f32, element::f32},
339 op::AutoBroadcastType::NUMPY),
340 SelectParams({{}, {2, 4}, {2, 4}, {2, 4}},
341 {element::boolean, element::f32, element::f32, element::f32},
342 op::AutoBroadcastType::NUMPY),
343 SelectParams({{}, {4}, {2, 4}, {2, 4}},
344 {element::boolean, element::f32, element::dynamic, element::f32},
345 op::AutoBroadcastType::NUMPY),
346 SelectParams({{}, {2, 4}, {4}, {2, 4}},
347 {element::boolean, element::f32, element::f32, element::f32},
348 op::AutoBroadcastType::NUMPY),
349 SelectParams({{4}, {2, 4}, {4}, {2, 4}},
350 {element::boolean, element::i8, element::dynamic, element::i8},
351 op::AutoBroadcastType::NUMPY),
352 SelectParams({{4}, {4}, {2, 4}, {2, 4}},
353 {element::dynamic, element::dynamic, element::i8, element::i8},
354 op::AutoBroadcastType::NUMPY),
355 SelectParams({{2}, {2}, {2, 4}, {2, 4}},
356 {element::boolean, element::f32, element::dynamic, element::f32},
357 {op::AutoBroadcastType::PDPD, 0}),
358 // TODO: Whats the right behavior here?
359 // SelectParams({{2}, {2, 4}, {2}, {2, 4}}, {element::boolean, element::f32,
360 // element::dynamic, element::f32}, {op::AutoBroadcastType::PDPD, 0}),
361 SelectParams({{4}, {4}, {2, 4}, {2, 4}},
362 {element::boolean, element::f32, element::dynamic, element::f32},
363 {op::AutoBroadcastType::PDPD, 1})),
364 PrintToDummyParamName());
366 TEST(type_prop, select_v1_partial_shape)
368 auto a = make_shared<op::Parameter>(element::boolean, PartialShape::dynamic());
369 auto b = make_shared<op::Parameter>(element::f32, Shape{2, 4});
370 auto c = make_shared<op::Parameter>(element::f32, Shape{2, 4});
372 auto select = make_shared<op::v1::Select>(a, b, c, op::AutoBroadcastType::NONE);
373 ASSERT_EQ(select->get_shape(), (Shape{2, 4}));
376 TEST(type_prop, select_v1_partial_shape_autob)
378 auto a = make_shared<op::Parameter>(element::boolean, PartialShape{Dimension::dynamic()});
379 auto b = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic()});
380 auto c = make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic()});
382 auto select = make_shared<op::v1::Select>(a, b, c);
384 select->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic()}));
387 TEST(type_prop, select_v1_wrong_et)
389 auto param0 = make_shared<op::Parameter>(element::i8, Shape{2, 4});
390 auto param1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
391 auto param2 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
395 auto sel = make_shared<op::v1::Select>(param0, param1, param2);
396 FAIL() << "Did not detect wrong element type";
398 catch (const NodeValidationFailure& error)
400 EXPECT_HAS_SUBSTRING(error.what(),
401 std::string("Argument 0 must have boolean element type"));
405 FAIL() << "Deduced type check failed for unexpected reason";
409 TEST(type_prop, select_v1_et_mismatch)
411 auto param0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
412 auto param1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
413 auto param2 = make_shared<op::Parameter>(element::i8, Shape{2, 4});
417 auto sel = make_shared<op::v1::Select>(param0, param1, param2);
418 FAIL() << "Did not detect element type mismatch";
420 catch (const NodeValidationFailure& error)
422 EXPECT_HAS_SUBSTRING(error.what(),
423 std::string("Argument 1 and 2 element types must match."));
427 FAIL() << "Deduced type check failed for unexpected reason";
431 TEST(type_prop, select_v1_shape_mismatch)
433 auto param0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
434 auto param1 = make_shared<op::Parameter>(element::f32, Shape{2, 3});
435 auto param2 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
439 auto sel = make_shared<op::v1::Select>(param0, param1, param2);
440 FAIL() << "Did not detect shape mismatch";
442 catch (const NodeValidationFailure& error)
444 EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent."));
448 FAIL() << "Deduced type check failed for unexpected reason";
452 TEST(type_prop, select_v1_partial_shape_mismatch)
455 make_shared<op::Parameter>(element::boolean, PartialShape{3, Dimension::dynamic()});
456 auto param1 = make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic()});
457 auto param2 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
461 auto sel = make_shared<op::v1::Select>(param0, param1, param2);
462 FAIL() << "Did not detect shape mismatch";
464 catch (const NodeValidationFailure& error)
466 EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent."));
470 FAIL() << "Deduced type check failed for unexpected reason";