1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
17 #include "gtest/gtest.h"
18 #include "ngraph/ngraph.hpp"
19 #include "util/type_prop.hpp"
21 NGRAPH_SUPPRESS_DEPRECATED_START
24 using namespace ngraph;
26 TEST(type_prop, reshape_deduce_s2v)
28 auto param = make_shared<op::Parameter>(element::f32, Shape{});
29 auto r = make_shared<op::Reshape>(param, AxisVector{}, Shape{1});
30 ASSERT_EQ(r->get_element_type(), element::f32);
31 ASSERT_EQ(r->get_shape(), (Shape{1}));
34 TEST(type_prop, reshape_deduce_s2m)
36 auto param = make_shared<op::Parameter>(element::f32, Shape{});
37 auto r = make_shared<op::Reshape>(param, AxisVector{}, Shape{1, 1});
38 ASSERT_EQ(r->get_element_type(), element::f32);
39 ASSERT_EQ(r->get_shape(), (Shape{1, 1}));
42 TEST(type_prop, reshape_deduce_s2t)
44 auto param = make_shared<op::Parameter>(element::f32, Shape{});
45 auto r = make_shared<op::Reshape>(param, AxisVector{}, Shape{1, 1, 1});
46 ASSERT_EQ(r->get_element_type(), element::f32);
47 ASSERT_EQ(r->get_shape(), (Shape{1, 1, 1}));
50 TEST(type_prop, reshape_deduce_v2s)
52 auto param = make_shared<op::Parameter>(element::f32, Shape{1});
53 auto r = make_shared<op::Reshape>(param, AxisVector{0}, Shape{});
54 ASSERT_EQ(r->get_element_type(), element::f32);
55 ASSERT_EQ(r->get_shape(), (Shape{}));
58 TEST(type_prop, reshape_deduce_m2s)
60 auto param = make_shared<op::Parameter>(element::f32, Shape{1, 1});
61 auto r = make_shared<op::Reshape>(param, AxisVector{0, 1}, Shape{});
62 ASSERT_EQ(r->get_element_type(), element::f32);
63 ASSERT_EQ(r->get_shape(), (Shape{}));
66 TEST(type_prop, reshape_deduce_t2s)
68 auto param = make_shared<op::Parameter>(element::f32, Shape{1, 1, 1});
69 auto r = make_shared<op::Reshape>(param, AxisVector{0, 1, 2}, Shape{});
70 ASSERT_EQ(r->get_element_type(), element::f32);
71 ASSERT_EQ(r->get_shape(), (Shape{}));
74 TEST(type_prop, reshape_deduce_m2v_01)
76 auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4});
77 auto r = make_shared<op::Reshape>(param, AxisVector{0, 1}, Shape{12});
78 ASSERT_EQ(r->get_element_type(), element::f32);
79 ASSERT_EQ(r->get_shape(), (Shape{12}));
82 TEST(type_prop, reshape_deduce_m2v_10)
84 auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4});
85 auto r = make_shared<op::Reshape>(param, AxisVector{1, 0}, Shape{12});
86 ASSERT_EQ(r->get_element_type(), element::f32);
87 ASSERT_EQ(r->get_shape(), (Shape{12}));
90 TEST(type_prop, reshape_deduce_t2v_012)
92 auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
93 auto r = make_shared<op::Reshape>(param, AxisVector{0, 1, 2}, Shape{60});
94 ASSERT_EQ(r->get_element_type(), element::f32);
95 ASSERT_EQ(r->get_shape(), (Shape{60}));
98 TEST(type_prop, reshape_deduce_t2v_120)
100 auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
101 auto r = make_shared<op::Reshape>(param, AxisVector{1, 2, 0}, Shape{60});
102 ASSERT_EQ(r->get_element_type(), element::f32);
103 ASSERT_EQ(r->get_shape(), (Shape{60}));
106 TEST(type_prop, reshape_deduce_not_enough_axes)
108 auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
111 auto r = make_shared<op::Reshape>(param, AxisVector{1, 0}, Shape{60});
112 // Should have thrown, so fail if it didn't
113 FAIL() << "Not enough axes not detected";
115 catch (const NodeValidationFailure& error)
117 EXPECT_HAS_SUBSTRING(
119 std::string("Input axis order is not a permutation of argument's axis indices"));
123 FAIL() << "Deduced type check failed for unexpected reason";
127 TEST(type_prop, reshape_deduce_too_many_axes)
129 auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
132 auto r = make_shared<op::Reshape>(param, AxisVector{1, 2, 0, 3}, Shape{60});
133 // Should have thrown, so fail if it didn't
134 FAIL() << "Too many axes not detected";
136 catch (const NodeValidationFailure& error)
138 EXPECT_HAS_SUBSTRING(
140 std::string("Input axis order is not a permutation of argument's axis indices"));
144 FAIL() << "Deduced type check failed for unexpected reason";
148 TEST(type_prop, reshape_deduce_duplicate_axes)
150 auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
153 auto r = make_shared<op::Reshape>(param, AxisVector{1, 1, 0}, Shape{60});
154 // Should have thrown, so fail if it didn't
155 FAIL() << "Too many axes not detected";
157 catch (const NodeValidationFailure& error)
159 EXPECT_HAS_SUBSTRING(
161 std::string("Input axis order is not a permutation of argument's axis indices"));
165 FAIL() << "Deduced type check failed for unexpected reason";
169 TEST(type_prop, reshape_deduce_wrong_output_shape)
171 auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
174 auto r = make_shared<op::Reshape>(param, AxisVector{1, 2, 0}, Shape{3, 3, 3});
175 // Should have thrown, so fail if it didn't
176 FAIL() << "Too many axes not detected";
178 catch (const NodeValidationFailure& error)
180 EXPECT_HAS_SUBSTRING(error.what(),
181 std::string("Product of output shape dimensions does not match "
182 "product of argument shape dimensions"));
186 FAIL() << "Deduced type check failed for unexpected reason";
// Input shape rank is dynamic, so the desired output shape should be set as long as the axis
// vector is not known to be invalid (invalid means it is not a permutation of {0,...,n-1} for any n).
194 TEST(type_prop, reshape_partial_rank_dynamic_axisvector_ok)
196 auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
197 auto r = make_shared<op::Reshape>(param, AxisVector{2, 1, 0, 3}, Shape{3, 1, 8, 2});
198 ASSERT_EQ(r->get_element_type(), element::f32);
199 ASSERT_TRUE(r->get_output_partial_shape(0).is_static());
200 ASSERT_EQ(r->get_shape(), (Shape{3, 1, 8, 2}));
203 TEST(type_prop, reshape_partial_rank_dynamic_axisvector_not_ok)
205 auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
208 auto r = make_shared<op::Reshape>(param, AxisVector{2, 1, 0, 4}, Shape{3, 1, 8, 2});
209 // Should have thrown, so fail if it didn't
210 FAIL() << "Did not detect malformed AxisVector (input shape rank dynamic)";
212 catch (const NodeValidationFailure& error)
214 EXPECT_HAS_SUBSTRING(
216 std::string("Input axis order is not a permutation of argument's axis indices"));
220 FAIL() << "Deduced type check failed for unexpected reason";
// Input shape rank is static but the shape itself is dynamic, so the desired output shape should
// be set if the axis vector is consistent with the static rank.
228 TEST(type_prop, reshape_partial_rank_static_dynamic_axisvector_ok)
231 PartialShape{Dimension::dynamic(), 6, Dimension::dynamic(), Dimension::dynamic()};
232 auto param = make_shared<op::Parameter>(element::f32, param_shape);
233 auto r = make_shared<op::Reshape>(param, AxisVector{2, 1, 0, 3}, Shape{3, 1, 8, 2});
234 ASSERT_EQ(r->get_element_type(), element::f32);
235 ASSERT_TRUE(r->get_output_partial_shape(0).is_static());
236 ASSERT_EQ(r->get_shape(), (Shape{3, 1, 8, 2}));
239 TEST(type_prop, reshape_partial_rank_static_dynamic_axisvector_not_ok)
242 PartialShape{Dimension::dynamic(), 6, Dimension::dynamic(), Dimension::dynamic()};
243 auto param = make_shared<op::Parameter>(element::f32, param_shape);
246 auto r = make_shared<op::Reshape>(param, AxisVector{2, 1, 0}, Shape{3, 1, 8, 2});
247 // Should have thrown, so fail if it didn't
248 FAIL() << "Did not detect AxisVector inconsistent with rank (rank-static dynamic shape)";
250 catch (const NodeValidationFailure& error)
252 EXPECT_HAS_SUBSTRING(
254 std::string("Input axis order is not a permutation of argument's axis indices"));
258 FAIL() << "Deduced type check failed for unexpected reason";
// Input shape rank is static and the shape itself is dynamic, _but_ one of its static dimensions
// is zero, so the desired output shape should be set only if it also has zero elements.
266 TEST(type_prop, reshape_partial_rank_static_dynamic_but_zero_ok)
269 PartialShape{Dimension::dynamic(), 0, Dimension::dynamic(), Dimension::dynamic()};
270 auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
271 auto r = make_shared<op::Reshape>(param, AxisVector{2, 1, 0, 3}, Shape{3, 1, 0, 2});
272 ASSERT_EQ(r->get_element_type(), element::f32);
273 ASSERT_TRUE(r->get_output_partial_shape(0).is_static());
274 ASSERT_EQ(r->get_shape(), (Shape{3, 1, 0, 2}));
277 TEST(type_prop, reshape_partial_rank_static_dynamic_but_zero_not_ok)
280 PartialShape{Dimension::dynamic(), 0, Dimension::dynamic(), Dimension::dynamic()};
281 auto param = make_shared<op::Parameter>(element::f32, param_shape);
284 auto r = make_shared<op::Reshape>(param, AxisVector{2, 1, 0}, Shape{3, 1, 8, 2});
285 // Should have thrown, so fail if it didn't
286 FAIL() << "Did not detect inconsistent output shape with static-zero-element rank-dynamic"
287 " static input shape";
289 catch (const NodeValidationFailure& error)
291 EXPECT_HAS_SUBSTRING(
293 std::string("Input axis order is not a permutation of argument's axis indices"));
297 FAIL() << "Deduced type check failed for unexpected reason";