1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
17 #include "gtest/gtest.h"
18 #include "ngraph/ngraph.hpp"
19 #include "util/type_prop.hpp"
21 NGRAPH_SUPPRESS_DEPRECATED_START
24 using namespace ngraph;
26 TEST(type_prop, one_hot_deduce_scalar)
28 auto param = make_shared<op::Parameter>(element::i32, Shape{});
29 auto oh = make_shared<op::OneHot>(param, Shape{9}, 0);
30 ASSERT_EQ(oh->get_element_type(), element::i32);
31 ASSERT_EQ(oh->get_shape(), (Shape{9}));
34 TEST(type_prop, one_hot_deduce_vector_0)
36 auto param = make_shared<op::Parameter>(element::i32, Shape{8});
37 auto oh = make_shared<op::OneHot>(param, Shape{9, 8}, 0);
38 ASSERT_EQ(oh->get_element_type(), element::i32);
39 ASSERT_EQ(oh->get_shape(), (Shape{9, 8}));
42 TEST(type_prop, one_hot_deduce_vector_1)
44 auto param = make_shared<op::Parameter>(element::i32, Shape{8});
45 auto oh = make_shared<op::OneHot>(param, Shape{8, 9}, 1);
46 ASSERT_EQ(oh->get_element_type(), element::i32);
47 ASSERT_EQ(oh->get_shape(), (Shape{8, 9}));
50 TEST(type_prop, one_hot_deduce_matrix_0)
52 auto param = make_shared<op::Parameter>(element::i32, Shape{12, 24});
53 auto oh = make_shared<op::OneHot>(param, Shape{2, 12, 24}, 0);
54 ASSERT_EQ(oh->get_element_type(), element::i32);
55 ASSERT_EQ(oh->get_shape(), (Shape{2, 12, 24}));
58 TEST(type_prop, one_hot_deduce_matrix_1)
60 auto param = make_shared<op::Parameter>(element::i32, Shape{12, 24});
61 auto oh = make_shared<op::OneHot>(param, Shape{12, 2, 24}, 1);
62 ASSERT_EQ(oh->get_element_type(), element::i32);
63 ASSERT_EQ(oh->get_shape(), (Shape{12, 2, 24}));
66 TEST(type_prop, one_hot_deduce_matrix_2)
68 auto param = make_shared<op::Parameter>(element::i32, Shape{12, 24});
69 auto oh = make_shared<op::OneHot>(param, Shape{12, 24, 2}, 2);
70 ASSERT_EQ(oh->get_element_type(), element::i32);
71 ASSERT_EQ(oh->get_shape(), (Shape{12, 24, 2}));
74 TEST(type_prop, one_hot_deduce_et_dynamic)
76 auto param = make_shared<op::Parameter>(element::dynamic, Shape{12, 24});
77 auto oh = make_shared<op::OneHot>(param, Shape{12, 24, 2}, 2);
78 ASSERT_EQ(oh->get_element_type(), element::dynamic);
79 ASSERT_EQ(oh->get_shape(), (Shape{12, 24, 2}));
82 TEST(type_prop, one_hot_deduce_floating_point)
84 auto param = make_shared<op::Parameter>(element::f32, Shape{12, 24});
87 auto oh = make_shared<op::OneHot>(param, Shape{12, 24, 8}, 3);
88 // Should have thrown, so fail if it didn't
89 FAIL() << "Invalid floating-point element type not detected.";
91 catch (const NodeValidationFailure& error)
93 EXPECT_HAS_SUBSTRING(error.what(),
94 std::string("Argument does not have integral element type."));
98 FAIL() << "Deduced type check failed for unexpected reason";
102 TEST(type_prop, one_hot_deduce_axis_oob)
104 auto param = make_shared<op::Parameter>(element::i32, Shape{12, 24});
107 auto oh = make_shared<op::OneHot>(param, Shape{12, 24, 8}, 3);
108 // Should have thrown, so fail if it didn't
109 FAIL() << "One-hot axis out of bounds not detected.";
111 catch (const NodeValidationFailure& error)
113 EXPECT_HAS_SUBSTRING(error.what(), std::string("One-hot axis (3) is out of bounds"));
117 FAIL() << "Deduced type check failed for unexpected reason";
121 TEST(type_prop, one_hot_deduce_shape_incompatible)
123 auto param = make_shared<op::Parameter>(element::i32, Shape{12, 24});
126 auto oh = make_shared<op::OneHot>(param, Shape{12, 22, 8}, 2);
127 // Should have thrown, so fail if it didn't
128 FAIL() << "Incompatible one-hot output shape not detected.";
130 catch (const ngraph_error& error)
132 EXPECT_HAS_SUBSTRING(
133 error.what(), std::string("Argument shape {12,24} does not match the expected shape"));
137 FAIL() << "Deduced type check failed for unexpected reason";
141 TEST(type_prop, one_hot_partial_rank_dynamic_rank_dynamic)
143 PartialShape input_shape{PartialShape::dynamic()};
144 PartialShape requested_shape{PartialShape::dynamic()};
145 size_t one_hot_axis{3000};
147 auto param = make_shared<op::Parameter>(element::i32, input_shape);
150 auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
151 // Should have thrown, so fail if it didn't
152 FAIL() << "Dynamic rank for requested result shape not detected";
154 catch (const ngraph_error& error)
156 EXPECT_HAS_SUBSTRING(error.what(), std::string("Requested result shape has dynamic rank"));
160 FAIL() << "Deduced type check failed for unexpected reason";
164 TEST(type_prop, one_hot_partial_rank_dynamic_rank_static_dynamic_ok)
166 PartialShape input_shape{PartialShape::dynamic()};
167 PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic()};
168 size_t one_hot_axis{2};
170 auto param = make_shared<op::Parameter>(element::i32, input_shape);
171 auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
173 ASSERT_EQ(oh->get_output_element_type(0), element::i32);
174 ASSERT_TRUE(oh->get_output_partial_shape(0).same_scheme(
175 PartialShape{Dimension::dynamic(), 2, 3, Dimension::dynamic()}));
178 TEST(type_prop, one_hot_partial_rank_dynamic_rank_static_dynamic_one_hot_dim_dynamic)
180 PartialShape input_shape{PartialShape::dynamic()};
181 PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic()};
182 size_t one_hot_axis{3};
184 auto param = make_shared<op::Parameter>(element::i32, input_shape);
187 auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
188 // Should have thrown, so fail if it didn't
189 FAIL() << "Dynamic one-hot dimension not detected";
191 catch (const ngraph_error& error)
193 EXPECT_HAS_SUBSTRING(error.what(),
194 std::string("Requested result shape ({?,2,3,?}) has dynamic dimension "
195 "at the one-hot axis (3)"));
199 FAIL() << "Deduced type check failed for unexpected reason";
203 TEST(type_prop, one_hot_partial_rank_dynamic_rank_static_dynamic_one_hot_axis_oob)
205 PartialShape input_shape{PartialShape::dynamic()};
206 PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic()};
207 size_t one_hot_axis{4};
209 auto param = make_shared<op::Parameter>(element::i32, input_shape);
212 auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
213 // Should have thrown, so fail if it didn't
214 FAIL() << "One-hot axis out of bounds not detected (rank-dynamic argument, rank-static "
215 "dynamic result shape)";
217 catch (const ngraph_error& error)
219 EXPECT_HAS_SUBSTRING(
221 std::string("One-hot axis (4) is out of bounds (requested result shape: {?,2,3,?})"));
225 FAIL() << "Deduced type check failed for unexpected reason";
229 TEST(type_prop, one_hot_partial_rank_static_dynamic_rank_static_dynamic_ok)
231 PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic(), 4};
232 PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic(), 4};
233 size_t one_hot_axis{2};
235 auto param = make_shared<op::Parameter>(element::i32, input_shape);
236 auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
238 ASSERT_EQ(oh->get_output_element_type(0), element::i32);
239 ASSERT_TRUE(oh->get_output_partial_shape(0).same_scheme(
240 PartialShape{3, 2, 3, Dimension::dynamic(), 4}));
244 one_hot_partial_rank_static_dynamic_rank_static_dynamic_incompatible_rank_input_short)
246 PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic()};
247 PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic(), 4};
248 size_t one_hot_axis{2};
250 auto param = make_shared<op::Parameter>(element::i32, input_shape);
253 auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
254 // Should have thrown, so fail if it didn't
255 FAIL() << "Incompatible input/output ranks not detected (rank-static dynamic argument, "
256 "rank-static dynamic result shape)";
258 catch (const ngraph_error& error)
260 EXPECT_HAS_SUBSTRING(
262 std::string("Argument shape {3,?,?} does not match the expected shape of {?,2,?,4}"));
266 FAIL() << "Deduced type check failed for unexpected reason";
271 one_hot_partial_rank_static_dynamic_rank_static_dynamic_incompatible_rank_input_long)
273 PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic(), 4, 5};
274 PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic(), 4};
275 size_t one_hot_axis{2};
277 auto param = make_shared<op::Parameter>(element::i32, input_shape);
280 auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
281 // Should have thrown, so fail if it didn't
282 FAIL() << "Incompatible input/output ranks not detected (rank-static dynamic argument, "
283 "rank-static dynamic result shape)";
285 catch (const ngraph_error& error)
287 EXPECT_HAS_SUBSTRING(
290 "Argument shape {3,?,?,4,5} does not match the expected shape of {?,2,?,4}"));
294 FAIL() << "Deduced type check failed for unexpected reason";
298 TEST(type_prop, one_hot_partial_rank_static_dynamic_rank_static_dynamic_incompatible_dim)
300 PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic(), 5};
301 PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic(), 4};
302 size_t one_hot_axis{2};
304 auto param = make_shared<op::Parameter>(element::i32, input_shape);
307 auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
308 // Should have thrown, so fail if it didn't
309 FAIL() << "Incompatible input/output dimensions not detected (rank-static dynamic "
310 "argument, rank-static dynamic result shape)";
312 catch (const ngraph_error& error)
314 EXPECT_HAS_SUBSTRING(
316 std::string("Argument shape {3,?,?,5} does not match the expected shape of {?,2,?,4}"));
320 FAIL() << "Deduced type check failed for unexpected reason";
324 TEST(type_prop, one_hot_partial_rank_static_dynamic_rank_static_dynamic_one_hot_dim_dynamic)
326 PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic(), 4};
327 PartialShape requested_shape{
328 Dimension::dynamic(), 2, Dimension::dynamic(), Dimension::dynamic(), 4};
329 size_t one_hot_axis{2};
331 auto param = make_shared<op::Parameter>(element::i32, input_shape);
334 auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
335 // Should have thrown, so fail if it didn't
336 FAIL() << "Dynamic one-hot dimension not detected (rank-static dynamic argument, "
337 "rank-static dynamic result shape)";
339 catch (const ngraph_error& error)
341 EXPECT_HAS_SUBSTRING(error.what(),
342 std::string("Requested result shape ({?,2,?,?,4}) has dynamic "
343 "dimension at the one-hot axis (2)"));
347 FAIL() << "Deduced type check failed for unexpected reason";
351 TEST(type_prop, one_hot_partial_rank_static_dynamic_rank_static_dynamic_one_hot_axis_oob)
353 PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic(), 4};
354 PartialShape requested_shape{
355 Dimension::dynamic(), 2, Dimension::dynamic(), Dimension::dynamic(), 4};
356 size_t one_hot_axis{2};
358 auto param = make_shared<op::Parameter>(element::i32, input_shape);
361 auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
362 // Should have thrown, so fail if it didn't
363 FAIL() << "One-hot axis out of bounds not detected (rank-static dynamic argument, "
364 "rank-static dynamic result shape)";
366 catch (const ngraph_error& error)
368 EXPECT_HAS_SUBSTRING(error.what(),
369 std::string("Requested result shape ({?,2,?,?,4}) has dynamic "
370 "dimension at the one-hot axis (2)"));
374 FAIL() << "Deduced type check failed for unexpected reason";
378 TEST(type_prop, one_hot_v1_output_shape)
380 auto indices = make_shared<op::Parameter>(element::i64, Shape{3});
381 auto depth = op::Constant::create(element::i64, Shape{}, {2});
382 auto on_value = op::Constant::create(element::u32, Shape{}, {5});
383 auto off_value = op::Constant::create(element::u32, Shape{}, {10});
385 auto ont_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
386 ASSERT_EQ(ont_hot->get_element_type(), element::u32);
387 ASSERT_EQ(ont_hot->get_shape(), (Shape{3, 2}));
390 TEST(type_prop, one_hot_v1_output_shape_2)
392 auto indices = make_shared<op::Parameter>(element::i64, Shape{1, 3, 2, 3});
393 auto depth = op::Constant::create(element::i64, Shape{}, {4});
394 auto on_value = op::Constant::create(element::f32, Shape{}, {1.0f});
395 auto off_value = op::Constant::create(element::f32, Shape{}, {0.0f});
397 auto ont_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
398 ASSERT_EQ(ont_hot->get_element_type(), element::f32);
399 ASSERT_EQ(ont_hot->get_shape(), (Shape{1, 3, 2, 4, 3}));
402 TEST(type_prop, one_hot_v1_indices_elem_not_integral)
404 auto indices = make_shared<op::Parameter>(element::f16, Shape{2, 2});
405 auto depth = make_shared<op::Parameter>(element::i64, Shape{});
406 auto on_value = make_shared<op::Parameter>(element::u32, Shape{});
407 auto off_value = make_shared<op::Parameter>(element::u32, Shape{});
411 auto ont_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
412 // Should have thrown, so fail if it didn't
413 FAIL() << "Incorrect indices element type not detected";
415 catch (const ngraph_error& error)
417 EXPECT_HAS_SUBSTRING(error.what(), std::string("Indices must be integral element type."));
421 FAIL() << "Deduced type check failed for unexpected reason";
425 TEST(type_prop, one_hot_v1_depth_elem_not_integral)
427 auto indices = make_shared<op::Parameter>(element::i64, Shape{2, 2});
428 auto depth = make_shared<op::Parameter>(element::f16, Shape{});
429 auto on_value = make_shared<op::Parameter>(element::u32, Shape{});
430 auto off_value = make_shared<op::Parameter>(element::u32, Shape{});
434 auto ont_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
435 // Should have thrown, so fail if it didn't
436 FAIL() << "Incorrect depth element type not detected";
438 catch (const ngraph_error& error)
440 EXPECT_HAS_SUBSTRING(error.what(), std::string("Depth must be integral element type."));
444 FAIL() << "Deduced type check failed for unexpected reason";
448 TEST(type_prop, one_hot_v1_on_off_values_not_compatible)
450 auto indices = make_shared<op::Parameter>(element::i64, Shape{2, 2});
451 auto depth = make_shared<op::Parameter>(element::i64, Shape{});
452 auto on_value = make_shared<op::Parameter>(element::bf16, Shape{});
453 auto off_value = make_shared<op::Parameter>(element::f16, Shape{});
457 auto ont_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
458 // Should have thrown, so fail if it didn't
459 FAIL() << "Incompatible on/off element types not detected";
461 catch (const ngraph_error& error)
463 EXPECT_HAS_SUBSTRING(
465 std::string("on_value element type must be compatible with off_value element type."));
469 FAIL() << "Deduced type check failed for unexpected reason";
473 TEST(type_prop, one_hot_v1_depth_not_scalar)
475 auto indices = make_shared<op::Parameter>(element::i64, Shape{2, 2});
476 auto depth = make_shared<op::Parameter>(element::i64, Shape{1});
477 auto on_value = make_shared<op::Parameter>(element::bf16, Shape{});
478 auto off_value = make_shared<op::Parameter>(element::bf16, Shape{});
482 auto ont_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
483 // Should have thrown, so fail if it didn't
484 FAIL() << "Not scalar depth input not detected.";
486 catch (const ngraph_error& error)
488 EXPECT_HAS_SUBSTRING(error.what(), std::string("depth input must be scalar."));
492 FAIL() << "Deduced type check failed for unexpected reason";
496 TEST(type_prop, one_hot_v1_on_value_not_scalar)
498 auto indices = make_shared<op::Parameter>(element::i64, Shape{2, 2});
499 auto depth = make_shared<op::Parameter>(element::i64, Shape{});
500 auto on_value = make_shared<op::Parameter>(element::bf16, Shape{2});
501 auto off_value = make_shared<op::Parameter>(element::bf16, Shape{});
505 auto ont_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
506 // Should have thrown, so fail if it didn't
507 FAIL() << "Not scalar on_value input not detected.";
509 catch (const ngraph_error& error)
511 EXPECT_HAS_SUBSTRING(error.what(), std::string("on_value input must be scalar."));
515 FAIL() << "Deduced type check failed for unexpected reason";
519 TEST(type_prop, one_hot_v1_off_value_not_scalar)
521 auto indices = make_shared<op::Parameter>(element::i64, Shape{2, 2});
522 auto depth = make_shared<op::Parameter>(element::i64, Shape{});
523 auto on_value = make_shared<op::Parameter>(element::bf16, Shape{});
524 auto off_value = make_shared<op::Parameter>(element::bf16, Shape{3});
528 auto ont_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
529 // Should have thrown, so fail if it didn't
530 FAIL() << "Not scalar off_value input not detected.";
532 catch (const ngraph_error& error)
534 EXPECT_HAS_SUBSTRING(error.what(), std::string("off_value input must be scalar."));
538 FAIL() << "Deduced type check failed for unexpected reason";