1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
17 #include "gtest/gtest.h"
18 #include "ngraph/ngraph.hpp"
19 #include "util/type_prop.hpp"
21 NGRAPH_SUPPRESS_DEPRECATED_START
24 using namespace ngraph;
26 TEST(type_prop, gather_nd_scalar_from_2d)
28 Shape params_shape{2, 2};
29 Shape indices_shape{2, 2};
31 auto P = make_shared<op::Parameter>(element::f32, params_shape);
32 auto I = make_shared<op::Parameter>(element::i32, indices_shape);
33 auto G = make_shared<op::GatherND>(P, I);
34 ASSERT_EQ(G->get_element_type(), element::f32);
35 ASSERT_EQ(G->get_shape(), out_shape);
38 TEST(type_prop, gather_nd_1d_from_2d)
40 Shape params_shape{2, 2};
41 Shape indices_shape{2, 1};
42 Shape out_shape{2, 2};
43 auto P = make_shared<op::Parameter>(element::f32, params_shape);
44 auto I = make_shared<op::Parameter>(element::i32, indices_shape);
45 auto G = make_shared<op::GatherND>(P, I);
46 ASSERT_EQ(G->get_element_type(), element::f32);
47 ASSERT_EQ(G->get_shape(), out_shape);
50 TEST(type_prop, gather_nd_scalar_from_3d)
52 Shape params_shape{2, 2, 2};
53 Shape indices_shape{2, 3};
55 auto P = make_shared<op::Parameter>(element::f32, params_shape);
56 auto I = make_shared<op::Parameter>(element::i32, indices_shape);
57 auto G = make_shared<op::GatherND>(P, I);
58 ASSERT_EQ(G->get_element_type(), element::f32);
59 ASSERT_EQ(G->get_shape(), out_shape);
62 TEST(type_prop, gather_nd_1d_from_3d)
64 Shape params_shape{2, 2, 2};
65 Shape indices_shape{2, 2};
66 Shape out_shape{2, 2};
67 auto P = make_shared<op::Parameter>(element::f32, params_shape);
68 auto I = make_shared<op::Parameter>(element::i32, indices_shape);
69 auto G = make_shared<op::GatherND>(P, I);
70 ASSERT_EQ(G->get_element_type(), element::f32);
71 ASSERT_EQ(G->get_shape(), out_shape);
74 TEST(type_prop, gather_nd_2d_from_3d)
76 Shape params_shape{2, 2, 2};
77 Shape indices_shape{1, 1};
78 Shape out_shape{1, 2, 2};
79 auto P = make_shared<op::Parameter>(element::f32, params_shape);
80 auto I = make_shared<op::Parameter>(element::i32, indices_shape);
81 auto G = make_shared<op::GatherND>(P, I);
82 ASSERT_EQ(G->get_element_type(), element::f32);
83 ASSERT_EQ(G->get_shape(), out_shape);
86 TEST(type_prop, gather_nd_batch_scalar_from_2d)
88 Shape params_shape{2, 2};
89 Shape indices_shape{2, 1, 2};
90 Shape out_shape{2, 1};
91 auto P = make_shared<op::Parameter>(element::f32, params_shape);
92 auto I = make_shared<op::Parameter>(element::i32, indices_shape);
93 auto G = make_shared<op::GatherND>(P, I);
94 ASSERT_EQ(G->get_element_type(), element::f32);
95 ASSERT_EQ(G->get_shape(), out_shape);
98 TEST(type_prop, gather_nd_batch_1d_from_2d)
100 Shape params_shape{2, 2};
101 Shape indices_shape{2, 1, 1};
102 Shape out_shape{2, 1, 2};
103 auto P = make_shared<op::Parameter>(element::f32, params_shape);
104 auto I = make_shared<op::Parameter>(element::i32, indices_shape);
105 auto G = make_shared<op::GatherND>(P, I);
106 ASSERT_EQ(G->get_element_type(), element::f32);
107 ASSERT_EQ(G->get_shape(), out_shape);
110 TEST(type_prop, gather_nd_batch_scalar_from_3d)
112 Shape params_shape{2, 2, 2};
113 Shape indices_shape{2, 2, 3};
114 Shape out_shape{2, 2};
115 auto P = make_shared<op::Parameter>(element::f32, params_shape);
116 auto I = make_shared<op::Parameter>(element::i32, indices_shape);
117 auto G = make_shared<op::GatherND>(P, I);
118 ASSERT_EQ(G->get_element_type(), element::f32);
119 ASSERT_EQ(G->get_shape(), out_shape);
122 TEST(type_prop, gather_nd_batch_1d_from_3d)
124 Shape params_shape{2, 2, 2};
125 Shape indices_shape{2, 2, 2};
126 Shape out_shape{2, 2, 2};
127 auto P = make_shared<op::Parameter>(element::f32, params_shape);
128 auto I = make_shared<op::Parameter>(element::i32, indices_shape);
129 auto G = make_shared<op::GatherND>(P, I);
130 ASSERT_EQ(G->get_element_type(), element::f32);
131 ASSERT_EQ(G->get_shape(), out_shape);
134 TEST(type_prop, gather_nd_batch_2d_from_3d)
136 Shape params_shape{2, 2, 2};
137 Shape indices_shape{2, 1, 1};
138 Shape out_shape{2, 1, 2, 2};
139 auto P = make_shared<op::Parameter>(element::f32, params_shape);
140 auto I = make_shared<op::Parameter>(element::i32, indices_shape);
141 auto G = make_shared<op::GatherND>(P, I);
142 ASSERT_EQ(G->get_element_type(), element::f32);
143 ASSERT_EQ(G->get_shape(), out_shape);
146 TEST(type_prop, gather_nd_fail_params_rank)
148 Shape params_shape{};
149 Shape indices_shape{2, 1, 1};
150 Shape out_shape{2, 1, 2, 2};
151 auto P = make_shared<op::Parameter>(element::f32, params_shape);
152 auto I = make_shared<op::Parameter>(element::i32, indices_shape);
155 auto G = make_shared<op::GatherND>(P, I);
156 // Should have thrown, so fail if it didn't
157 FAIL() << "Incorrect params rank";
159 catch (const NodeValidationFailure& error)
161 EXPECT_HAS_SUBSTRING(error.what(), std::string("params rank is expected to be at least 1"));
165 FAIL() << "Deduced type check failed for unexpected reason";
169 TEST(type_prop, gather_nd_fail_indices_rank)
171 Shape params_shape{2, 2, 2};
172 Shape indices_shape{};
173 Shape out_shape{2, 1, 2, 2};
174 auto P = make_shared<op::Parameter>(element::f32, params_shape);
175 auto I = make_shared<op::Parameter>(element::i32, indices_shape);
178 auto G = make_shared<op::GatherND>(P, I);
179 // Should have thrown, so fail if it didn't
180 FAIL() << "Incorrect indices rank";
182 catch (const NodeValidationFailure& error)
184 EXPECT_HAS_SUBSTRING(error.what(),
185 std::string("indices rank is expected to be at least 1"));
189 FAIL() << "Deduced type check failed for unexpected reason";
193 TEST(type_prop, gather_nd_fail_indices_element_type)
195 Shape params_shape{2, 2, 2};
196 Shape indices_shape{2, 1, 1};
197 Shape out_shape{2, 1, 2, 2};
198 auto P = make_shared<op::Parameter>(element::f32, params_shape);
199 auto I = make_shared<op::Parameter>(element::i16, indices_shape);
202 auto G = make_shared<op::GatherND>(P, I);
203 // Should have thrown, so fail if it didn't
204 FAIL() << "Incorrect indices element type";
206 catch (const NodeValidationFailure& error)
208 EXPECT_HAS_SUBSTRING(error.what(), std::string("Indices element type must be i64 or i32"));
212 FAIL() << "Deduced type check failed for unexpected reason";