Deprecate nGraph v0 ops and builders (#1856)
[platform/upstream/dldt.git] / ngraph / test / type_prop / select.cpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #include "gtest/gtest.h"
18 #include "ngraph/ngraph.hpp"
19 #include "util/type_prop.hpp"
20
21 NGRAPH_SUPPRESS_DEPRECATED_START
22
23 using namespace std;
24 using namespace ngraph;
25
26 TEST(type_prop, select_deduce)
27 {
28     auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
29     auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
30     auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
31     auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
32     ASSERT_EQ(bc->get_element_type(), element::f32);
33     ASSERT_EQ(bc->get_shape(), (Shape{2, 4}));
34 }
35
36 TEST(type_prop, select_shape_mismatch_a)
37 {
38     auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::boolean, Shape{3, 5});
39     auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
40     auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
41     try
42     {
43         auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
44         // Should have thrown, so fail if it didn't
45         FAIL() << "Did not detect incorrect element types for arithmetic operator";
46     }
47     catch (const NodeValidationFailure& error)
48     {
49         EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
50     }
51     catch (...)
52     {
53         FAIL() << "Deduced type check failed for unexpected reason";
54     }
55 }
56
57 TEST(type_prop, select_shape_mismatch_b)
58 {
59     auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
60     auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::f32, Shape{3, 5});
61     auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
62     try
63     {
64         auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
65         // Should have thrown, so fail if it didn't
66         FAIL() << "Did not detect incorrect element types for arithmetic operator";
67     }
68     catch (const NodeValidationFailure& error)
69     {
70         EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
71     }
72     catch (...)
73     {
74         FAIL() << "Deduced type check failed for unexpected reason";
75     }
76 }
77
78 TEST(type_prop, select_shape_mismatch_c)
79 {
80     auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
81     auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
82     auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::f32, Shape{3, 5});
83     try
84     {
85         auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
86         // Should have thrown, so fail if it didn't
87         FAIL() << "Did not detect incorrect element types for arithmetic operator";
88     }
89     catch (const NodeValidationFailure& error)
90     {
91         EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
92     }
93     catch (...)
94     {
95         FAIL() << "Deduced type check failed for unexpected reason";
96     }
97 }
98
99 TEST(type_prop, select_elem_mismatch_a)
100 {
101     auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
102     auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
103     auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
104     try
105     {
106         auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
107         // Should have thrown, so fail if it didn't
108         FAIL() << "Did not detect incorrect element types for arithmetic operator";
109     }
110     catch (const NodeValidationFailure& error)
111     {
112         EXPECT_HAS_SUBSTRING(error.what(),
113                              std::string("Argument 0 must have boolean element type"));
114     }
115     catch (...)
116     {
117         FAIL() << "Deduced type check failed for unexpected reason";
118     }
119 }
120
121 TEST(type_prop, select_elem_mismatch_bc)
122 {
123     auto tv0_2_4_param_0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
124     auto tv0_2_4_param_1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
125     auto tv0_2_4_param_2 = make_shared<op::Parameter>(element::i32, Shape{2, 4});
126     try
127     {
128         auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
129         // Should have thrown, so fail if it didn't
130         FAIL() << "Did not detect incorrect element types for arithmetic operator";
131     }
132     catch (const NodeValidationFailure& error)
133     {
134         EXPECT_HAS_SUBSTRING(error.what(),
135                              std::string("Argument 1 and 2 element types are inconsistent"));
136     }
137     catch (...)
138     {
139         FAIL() << "Deduced type check failed for unexpected reason";
140     }
141 }
142
143 TEST(type_prop, select_partial_all_rank_dynamic)
144 {
145     auto param0 = make_shared<op::Parameter>(element::boolean, PartialShape::dynamic());
146     auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
147     auto param2 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
148
149     auto sel = make_shared<op::Select>(param0, param1, param2);
150
151     ASSERT_EQ(sel->get_output_element_type(0), element::f32);
152     ASSERT_TRUE(sel->get_output_partial_shape(0).rank().is_dynamic());
153 }
154
155 TEST(type_prop, select_partial_all_rank_dynamic_arg0_et_dynamic_arg1_arg2_et_mismatch)
156 {
157     auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
158     auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
159     auto param2 = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
160
161     try
162     {
163         auto sel = make_shared<op::Select>(param0, param1, param2);
164         FAIL() << "Did not detect mismatched element types for args 1 and 2 (element type-dynamic "
165                   "arg0)";
166     }
167     catch (const NodeValidationFailure& error)
168     {
169         EXPECT_HAS_SUBSTRING(error.what(),
170                              std::string("Argument 1 and 2 element types are inconsistent"));
171     }
172     catch (...)
173     {
174         FAIL() << "Deduced type check failed for unexpected reason";
175     }
176 }
177
178 TEST(type_prop, select_partial_all_rank_dynamic_arg0_arg1_et_dynamic)
179 {
180     auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
181     auto param1 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
182     auto param2 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
183
184     auto sel = make_shared<op::Select>(param0, param1, param2);
185
186     ASSERT_EQ(sel->get_output_element_type(0), element::f32);
187     ASSERT_TRUE(sel->get_output_partial_shape(0).rank().is_dynamic());
188 }
189
190 TEST(type_prop, select_partial_all_rank_dynamic_arg0_arg2_et_dynamic)
191 {
192     auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
193     auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
194     auto param2 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
195
196     auto sel = make_shared<op::Select>(param0, param1, param2);
197
198     ASSERT_EQ(sel->get_output_element_type(0), element::f32);
199     ASSERT_TRUE(sel->get_output_partial_shape(0).rank().is_dynamic());
200 }
201
202 TEST(type_prop, select_partial_all_rank_dynamic_arg0_arg1_arg2_et_dynamic)
203 {
204     auto param0 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
205     auto param1 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
206     auto param2 = make_shared<op::Parameter>(element::dynamic, PartialShape::dynamic());
207
208     auto sel = make_shared<op::Select>(param0, param1, param2);
209
210     ASSERT_EQ(sel->get_output_element_type(0), element::dynamic);
211     ASSERT_TRUE(sel->get_output_partial_shape(0).rank().is_dynamic());
212 }
213
214 TEST(type_prop, select_partial_arg0_rank_dynamic_static_arg1_arg2_rank_dynamic_ok)
215 {
216     auto param0 =
217         make_shared<op::Parameter>(element::boolean, PartialShape{2, Dimension::dynamic(), 3});
218     auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
219     auto param2 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
220
221     auto sel = make_shared<op::Select>(param0, param1, param2);
222
223     ASSERT_EQ(sel->get_output_element_type(0), element::f32);
224     ASSERT_TRUE(
225         sel->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic(), 3}));
226 }
227
228 TEST(type_prop, select_partial_arg1_rank_dynamic_static_arg0_arg2_rank_dynamic_ok)
229 {
230     auto param0 = make_shared<op::Parameter>(element::boolean, PartialShape::dynamic());
231     auto param1 =
232         make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 3});
233     auto param2 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
234
235     auto sel = make_shared<op::Select>(param0, param1, param2);
236
237     ASSERT_EQ(sel->get_output_element_type(0), element::f32);
238     ASSERT_TRUE(
239         sel->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic(), 3}));
240 }
241
242 TEST(type_prop, select_partial_arg2_rank_dynamic_static_arg0_arg1_rank_dynamic_ok)
243 {
244     auto param0 = make_shared<op::Parameter>(element::boolean, PartialShape::dynamic());
245     auto param1 = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
246     auto param2 =
247         make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), 3});
248
249     auto sel = make_shared<op::Select>(param0, param1, param2);
250
251     ASSERT_EQ(sel->get_output_element_type(0), element::f32);
252     ASSERT_TRUE(
253         sel->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic(), 3}));
254 }
255
256 TEST(type_prop, select_partial_all_rank_static_dynamic_ok)
257 {
258     auto param0 = make_shared<op::Parameter>(
259         element::boolean, PartialShape{2, Dimension::dynamic(), Dimension::dynamic()});
260     auto param1 = make_shared<op::Parameter>(
261         element::f32, PartialShape{Dimension::dynamic(), 8, Dimension::dynamic()});
262     auto param2 = make_shared<op::Parameter>(
263         element::f32, PartialShape{Dimension::dynamic(), Dimension::dynamic(), 3});
264
265     auto sel = make_shared<op::Select>(param0, param1, param2);
266
267     ASSERT_EQ(sel->get_output_element_type(0), element::f32);
268     ASSERT_TRUE(sel->get_output_partial_shape(0).is_static());
269     ASSERT_EQ(sel->get_output_shape(0), (Shape{2, 8, 3}));
270 }
271
272 TEST(type_prop, select_partial_all_rank_static_intransitive_incompatibility)
273 {
274     auto param0 = make_shared<op::Parameter>(
275         element::boolean, PartialShape{2, Dimension::dynamic(), Dimension::dynamic()});
276     auto param1 = make_shared<op::Parameter>(
277         element::f32, PartialShape{Dimension::dynamic(), 8, Dimension::dynamic()});
278     auto param2 =
279         make_shared<op::Parameter>(element::f32, PartialShape{3, Dimension::dynamic(), 3});
280
281     try
282     {
283         auto sel = make_shared<op::Select>(param0, param1, param2);
284         FAIL() << "Did not detect intransitive partial-shape incompatibility";
285     }
286     catch (const NodeValidationFailure& error)
287     {
288         EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent"));
289     }
290     catch (...)
291     {
292         FAIL() << "Deduced type check failed for unexpected reason";
293     }
294 }
295
296 //------------------------------ v1::Select ---------------------------------//
297 //
298 //
299 struct SelectParams
300 {
301     std::vector<Shape> shapes;
302     std::vector<element::Type> ets;
303     op::AutoBroadcastSpec auto_broadcast;
304
305     SelectParams(const std::vector<Shape>& shape,
306                  const std::vector<element::Type>& et,
307                  const op::AutoBroadcastSpec& auto_broadcast)
308         : shapes(shape)
309         , ets(et)
310         , auto_broadcast(auto_broadcast)
311     {
312     }
313 };
314
315 struct DeduceV1SelectTest : ::testing::TestWithParam<SelectParams>
316 {
317 };
318
319 TEST_P(DeduceV1SelectTest, output_shape)
320 {
321     auto tp = GetParam();
322     auto cond = make_shared<op::Parameter>(tp.ets[0], tp.shapes[0]);
323     auto ptrue = make_shared<op::Parameter>(tp.ets[1], tp.shapes[1]);
324     auto pfalse = make_shared<op::Parameter>(tp.ets[2], tp.shapes[2]);
325     auto select = make_shared<op::v1::Select>(cond, ptrue, pfalse, tp.auto_broadcast);
326
327     ASSERT_EQ(select->get_shape(), tp.shapes[3]);
328     ASSERT_EQ(select->get_element_type(), tp.ets[3]);
329 }
330
331 INSTANTIATE_TEST_CASE_P(
332     type_prop,
333     DeduceV1SelectTest,
334     ::testing::Values(SelectParams({{2, 4}, {2, 4}, {2, 4}, {2, 4}},
335                                    {element::boolean, element::f32, element::f32, element::f32},
336                                    op::AutoBroadcastType::NONE),
337                       SelectParams({{2, 4}, {2, 4}, {2, 4}, {2, 4}},
338                                    {element::boolean, element::f32, element::f32, element::f32},
339                                    op::AutoBroadcastType::NUMPY),
340                       SelectParams({{}, {2, 4}, {2, 4}, {2, 4}},
341                                    {element::boolean, element::f32, element::f32, element::f32},
342                                    op::AutoBroadcastType::NUMPY),
343                       SelectParams({{}, {4}, {2, 4}, {2, 4}},
344                                    {element::boolean, element::f32, element::dynamic, element::f32},
345                                    op::AutoBroadcastType::NUMPY),
346                       SelectParams({{}, {2, 4}, {4}, {2, 4}},
347                                    {element::boolean, element::f32, element::f32, element::f32},
348                                    op::AutoBroadcastType::NUMPY),
349                       SelectParams({{4}, {2, 4}, {4}, {2, 4}},
350                                    {element::boolean, element::i8, element::dynamic, element::i8},
351                                    op::AutoBroadcastType::NUMPY),
352                       SelectParams({{4}, {4}, {2, 4}, {2, 4}},
353                                    {element::dynamic, element::dynamic, element::i8, element::i8},
354                                    op::AutoBroadcastType::NUMPY),
355                       SelectParams({{2}, {2}, {2, 4}, {2, 4}},
356                                    {element::boolean, element::f32, element::dynamic, element::f32},
357                                    {op::AutoBroadcastType::PDPD, 0}),
358                       // TODO: Whats the right behavior here?
359                       // SelectParams({{2}, {2, 4}, {2}, {2, 4}}, {element::boolean, element::f32,
360                       // element::dynamic, element::f32}, {op::AutoBroadcastType::PDPD, 0}),
361                       SelectParams({{4}, {4}, {2, 4}, {2, 4}},
362                                    {element::boolean, element::f32, element::dynamic, element::f32},
363                                    {op::AutoBroadcastType::PDPD, 1})),
364     PrintToDummyParamName());
365
366 TEST(type_prop, select_v1_partial_shape)
367 {
368     auto a = make_shared<op::Parameter>(element::boolean, PartialShape::dynamic());
369     auto b = make_shared<op::Parameter>(element::f32, Shape{2, 4});
370     auto c = make_shared<op::Parameter>(element::f32, Shape{2, 4});
371
372     auto select = make_shared<op::v1::Select>(a, b, c, op::AutoBroadcastType::NONE);
373     ASSERT_EQ(select->get_shape(), (Shape{2, 4}));
374 }
375
376 TEST(type_prop, select_v1_partial_shape_autob)
377 {
378     auto a = make_shared<op::Parameter>(element::boolean, PartialShape{Dimension::dynamic()});
379     auto b = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic()});
380     auto c = make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic()});
381
382     auto select = make_shared<op::v1::Select>(a, b, c);
383     ASSERT_TRUE(
384         select->get_output_partial_shape(0).same_scheme(PartialShape{2, Dimension::dynamic()}));
385 }
386
387 TEST(type_prop, select_v1_wrong_et)
388 {
389     auto param0 = make_shared<op::Parameter>(element::i8, Shape{2, 4});
390     auto param1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
391     auto param2 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
392
393     try
394     {
395         auto sel = make_shared<op::v1::Select>(param0, param1, param2);
396         FAIL() << "Did not detect wrong element type";
397     }
398     catch (const NodeValidationFailure& error)
399     {
400         EXPECT_HAS_SUBSTRING(error.what(),
401                              std::string("Argument 0 must have boolean element type"));
402     }
403     catch (...)
404     {
405         FAIL() << "Deduced type check failed for unexpected reason";
406     }
407 }
408
409 TEST(type_prop, select_v1_et_mismatch)
410 {
411     auto param0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
412     auto param1 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
413     auto param2 = make_shared<op::Parameter>(element::i8, Shape{2, 4});
414
415     try
416     {
417         auto sel = make_shared<op::v1::Select>(param0, param1, param2);
418         FAIL() << "Did not detect element type mismatch";
419     }
420     catch (const NodeValidationFailure& error)
421     {
422         EXPECT_HAS_SUBSTRING(error.what(),
423                              std::string("Argument 1 and 2 element types must match."));
424     }
425     catch (...)
426     {
427         FAIL() << "Deduced type check failed for unexpected reason";
428     }
429 }
430
431 TEST(type_prop, select_v1_shape_mismatch)
432 {
433     auto param0 = make_shared<op::Parameter>(element::boolean, Shape{2, 4});
434     auto param1 = make_shared<op::Parameter>(element::f32, Shape{2, 3});
435     auto param2 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
436
437     try
438     {
439         auto sel = make_shared<op::v1::Select>(param0, param1, param2);
440         FAIL() << "Did not detect shape mismatch";
441     }
442     catch (const NodeValidationFailure& error)
443     {
444         EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent."));
445     }
446     catch (...)
447     {
448         FAIL() << "Deduced type check failed for unexpected reason";
449     }
450 }
451
452 TEST(type_prop, select_v1_partial_shape_mismatch)
453 {
454     auto param0 =
455         make_shared<op::Parameter>(element::boolean, PartialShape{3, Dimension::dynamic()});
456     auto param1 = make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic()});
457     auto param2 = make_shared<op::Parameter>(element::f32, Shape{2, 4});
458
459     try
460     {
461         auto sel = make_shared<op::v1::Select>(param0, param1, param2);
462         FAIL() << "Did not detect shape mismatch";
463     }
464     catch (const NodeValidationFailure& error)
465     {
466         EXPECT_HAS_SUBSTRING(error.what(), std::string("Argument shapes are inconsistent."));
467     }
468     catch (...)
469     {
470         FAIL() << "Deduced type check failed for unexpected reason";
471     }
472 }