1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
20 #include "gtest/gtest.h"
22 #include "ngraph/ngraph.hpp"
23 #include "util/ndarray.hpp"
24 #include "util/test_tools.hpp"
27 using namespace ngraph;
// Helper for single-input ops: builds OP on one f32 Parameter, copies the node onto a
// freshly created f32 Parameter, and yields true only when a copy was returned and its
// input values are exactly the replacement inputs.
// NOTE(review): presumably this is the check_unary<OP>() used by the tests below; the
// declaration of `shape` is not among the lines shown here -- confirm.
29 template <typename OP>
33 auto arg0 = make_shared<op::Parameter>(element::f32, shape);
34 OutputVector new_args{make_shared<op::Parameter>(element::f32, shape)};
// Construct the op on the original input, then copy it onto new_args.
36 auto node = make_shared<OP>(arg0);
37 auto new_node = node->copy_with_new_inputs(new_args);
// Success requires a non-null copy whose input_values() compare equal to new_args.
39 return (nullptr != new_node) && (new_args == new_node->input_values());
// Helper for two-input ops: builds OP on two f32 Parameters, copies the node onto two
// freshly created f32 Parameters, and yields true only when a copy was returned and its
// input values are exactly the replacement inputs.
// NOTE(review): presumably this is the check_binary<OP>() used by the tests below; the
// declaration of `shape` is not among the lines shown here -- confirm.
42 template <typename OP>
46 auto arg0 = make_shared<op::Parameter>(element::f32, shape);
47 auto arg1 = make_shared<op::Parameter>(element::f32, shape);
48 OutputVector new_args{make_shared<op::Parameter>(element::f32, shape),
49 make_shared<op::Parameter>(element::f32, shape)};
// Construct the op on the original inputs, then copy it onto new_args.
51 auto node = make_shared<OP>(arg0, arg1);
52 auto new_node = node->copy_with_new_inputs(new_args);
// Success requires a non-null copy whose input_values() compare equal to new_args.
54 return (nullptr != new_node) && (new_args == new_node->input_values());
// Copy checks for Abs, Acos, Add, Asin, Atan and Atan2. Each assertion passes when the
// matching helper returns true: check_unary for one-input ops, check_binary for
// two-input ops.
// NOTE(review): presumably each line is the body of its own TEST case -- confirm.
59 ASSERT_TRUE(check_unary<op::Abs>());
64 ASSERT_TRUE(check_unary<op::Acos>());
69 ASSERT_TRUE(check_binary<op::Add>());
74 ASSERT_TRUE(check_unary<op::Asin>());
79 ASSERT_TRUE(check_unary<op::Atan>());
84 ASSERT_TRUE(check_binary<op::Atan2>());
// Broadcast: copy the node onto a new input and verify that the copy keeps the same
// broadcast shape and broadcast axes.
// NOTE(review): `shape1`, `shape` and `axes` are declared outside the lines shown here.
90 auto arg0 = make_shared<op::Parameter>(element::f32, shape1);
91 OutputVector new_args{make_shared<op::Parameter>(element::f32, shape1)};
// NOTE(review): this test calls copy_with_new_inputs while the other per-op tests in
// this file call clone_with_new_inputs -- confirm the difference is intentional.
96 auto node = make_shared<op::Broadcast>(arg0, shape, axes);
97 auto new_node = node->copy_with_new_inputs(new_args);
// Downcast so the Broadcast-specific getters can be queried on the copy.
98 auto node_cast = as_type_ptr<op::Broadcast>(new_node);
99 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node was already handed to as_type_ptr above, so this null check
// comes after its first use -- consider moving it before the cast.
101 ASSERT_TRUE(nullptr != new_node);
// The copy must be wired to the replacement input and carry identical attributes.
102 ASSERT_TRUE(new_args == new_node->input_values());
103 ASSERT_TRUE(shape == node_cast->get_broadcast_shape());
104 ASSERT_TRUE(axes == node_cast->get_broadcast_axes());
// Ceiling: one-input op, verified through the check_unary helper.
109 ASSERT_TRUE(check_unary<op::Ceiling>());
// Concat: clone a two-input Concat onto two new f32 Parameters and verify that the
// clone keeps the same concatenation axis.
// NOTE(review): `shape` and `axis` are declared outside the lines shown here.
115 auto arg0 = make_shared<op::Parameter>(element::f32, shape);
116 auto arg1 = make_shared<op::Parameter>(element::f32, shape);
117 OutputVector new_args{make_shared<op::Parameter>(element::f32, shape),
118 make_shared<op::Parameter>(element::f32, shape)};
// Build the op from the original inputs, then clone it onto new_args.
120 auto node = make_shared<op::Concat>(NodeVector{arg0, arg1}, axis);
121 auto new_node = node->clone_with_new_inputs(new_args);
// Downcast so the Concat-specific getter can be queried on the clone.
122 auto node_cast = as_type_ptr<op::Concat>(new_node);
123 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node was already handed to as_type_ptr above, so this null check
// comes after its first use -- consider moving it before the cast.
125 ASSERT_TRUE(nullptr != new_node);
// The clone must be wired to the replacement inputs and keep the axis.
126 ASSERT_TRUE(new_args == new_node->input_values());
127 ASSERT_TRUE(node_cast->get_concatenation_axis() == axis);
// Constant: clone an f32 Constant holding the single value 2.4f with an empty input
// list and verify that data, shape and element type are carried over.
// NOTE(review): `shape` is declared outside the lines shown here.
133 vector<float> c{2.4f};
134 auto& et = element::f32;
135 auto node = op::Constant::create(et, shape, c);
// The node is cloned with no replacement inputs.
136 auto new_node = node->clone_with_new_inputs(OutputVector{});
// Downcast so the Constant-specific getters can be queried on the clone.
137 auto node_cast = as_type_ptr<op::Constant>(new_node);
138 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node was already handed to as_type_ptr above, so this null check
// comes after its first use -- consider moving it before the cast.
139 ASSERT_TRUE(nullptr != new_node);
// The clone must have no arguments and the same values, shape and element type.
140 ASSERT_TRUE(NodeVector{} == new_node->get_arguments());
141 ASSERT_TRUE(node_cast->get_vector<float>() == c);
142 ASSERT_TRUE(node_cast->get_shape() == shape);
143 ASSERT_TRUE(node_cast->get_element_type() == et);
// Convert: clone a Convert (f32 Parameter input, element::f64 passed as the target
// type) onto a new f32 Parameter and verify the clone reports the same convert
// element type.
// NOTE(review): `shape` is declared outside the lines shown here.
149 auto& et = element::f64;
150 auto arg0 = make_shared<op::Parameter>(element::f32, shape);
151 OutputVector new_args{make_shared<op::Parameter>(element::f32, shape)};
// Build the op on the original input, then clone it onto new_args.
153 auto node = make_shared<op::Convert>(arg0, et);
154 auto new_node = node->clone_with_new_inputs(new_args);
// Downcast so the Convert-specific getter can be queried on the clone.
155 auto node_cast = as_type_ptr<op::Convert>(new_node);
156 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node was already handed to as_type_ptr above, so this null check
// comes after its first use -- consider moving it before the cast.
158 ASSERT_TRUE(nullptr != new_node);
// Inputs are compared via get_arguments() converted with as_output_vector here, whereas
// some other tests in this file compare against input_values() directly.
159 ASSERT_TRUE(new_args == as_output_vector(new_node->get_arguments()));
160 ASSERT_TRUE(et == node_cast->get_convert_element_type());
// Copy checks for Cos, Cosh, Divide, Dot, Equal, Exp and Floor. Each assertion passes
// when the matching helper returns true: check_unary for one-input ops, check_binary
// for two-input ops.
// NOTE(review): presumably each line is the body of its own TEST case -- confirm.
165 ASSERT_TRUE(check_unary<op::Cos>());
170 ASSERT_TRUE(check_unary<op::Cosh>());
175 ASSERT_TRUE(check_binary<op::Divide>());
180 ASSERT_TRUE(check_binary<op::Dot>());
185 ASSERT_TRUE(check_binary<op::Equal>());
190 ASSERT_TRUE(check_unary<op::Exp>());
195 ASSERT_TRUE(check_unary<op::Floor>());
// Copy checks for GreaterEq, Greater, LessEq, Less, Log, Maximum, Minimum, Multiply,
// Negative and NotEqual. Log and Negative go through check_unary; the rest are
// two-input ops and go through check_binary.
// NOTE(review): only the greater_eq and not_equal TEST headers appear in the lines
// shown here; presumably the other assertions sit in their own TEST cases -- confirm.
198 TEST(copy, greater_eq)
200 ASSERT_TRUE(check_binary<op::GreaterEq>());
205 ASSERT_TRUE(check_binary<op::Greater>());
210 ASSERT_TRUE(check_binary<op::LessEq>());
215 ASSERT_TRUE(check_binary<op::Less>());
220 ASSERT_TRUE(check_unary<op::Log>());
225 ASSERT_TRUE(check_binary<op::Maximum>());
230 ASSERT_TRUE(check_binary<op::Minimum>());
235 ASSERT_TRUE(check_binary<op::Multiply>());
240 ASSERT_TRUE(check_unary<op::Negative>());
243 TEST(copy, not_equal)
245 ASSERT_TRUE(check_binary<op::NotEqual>());
// Parameter: clone an f32 Parameter with an empty input list and verify the clone has
// no arguments and that has_same_type() holds between the original and the clone.
// NOTE(review): `shape` is declared outside the lines shown here.
248 TEST(copy, parameter)
251 auto node = make_shared<op::Parameter>(element::f32, shape);
// The node is cloned with no replacement inputs.
252 auto new_node = node->clone_with_new_inputs({});
253 auto node_cast = as_type_ptr<op::Parameter>(new_node);
254 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node was already handed to as_type_ptr above, so this null check
// comes after its first use -- consider moving it before the cast.
256 ASSERT_TRUE(nullptr != new_node);
257 ASSERT_TRUE(new_node->get_arguments().size() == 0);
258 ASSERT_TRUE(node->has_same_type(new_node));
// Power: two-input op, verified through the check_binary helper.
263 ASSERT_TRUE(check_binary<op::Power>());
// Reshape: clone a Reshape built with input shape {2, 3, 4}, input order {0, 1, 2} and
// output shape {6, 4} onto a new f32 Parameter, and verify the clone keeps the same
// input order and output shape.
268 Shape shape_in{2, 3, 4};
269 AxisVector axes{0, 1, 2};
270 Shape shape_out{6, 4};
272 auto arg0 = make_shared<op::Parameter>(element::f32, shape_in);
273 OutputVector new_args{make_shared<op::Parameter>(element::f32, shape_in)};
// Build the op on the original input, then clone it onto new_args.
275 auto node = make_shared<op::Reshape>(arg0, axes, shape_out);
276 auto new_node = node->clone_with_new_inputs(new_args);
// Downcast so the Reshape-specific getters can be queried on the clone.
277 auto node_cast = as_type_ptr<op::Reshape>(new_node);
278 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node was already handed to as_type_ptr above, so this null check
// comes after its first use -- consider moving it before the cast.
280 ASSERT_TRUE(nullptr != new_node);
// The clone must be wired to the replacement input and keep order and output shape.
281 ASSERT_TRUE(new_args == as_output_vector(new_node->get_arguments()));
282 ASSERT_TRUE(axes == node_cast->get_input_order());
283 ASSERT_TRUE(shape_out == node_cast->get_output_shape(0));
// Select: clone a three-input Select (one boolean Parameter followed by two f32
// Parameters) onto three new Parameters of the same element types, and verify the
// clone is wired to exactly those replacements.
// NOTE(review): `shape` is declared outside the lines shown here.
289 auto arg0 = make_shared<op::Parameter>(element::boolean, shape);
290 auto arg1 = make_shared<op::Parameter>(element::f32, shape);
291 auto arg2 = make_shared<op::Parameter>(element::f32, shape);
292 OutputVector new_args{make_shared<op::Parameter>(element::boolean, shape),
293 make_shared<op::Parameter>(element::f32, shape),
294 make_shared<op::Parameter>(element::f32, shape)};
// Build the op on the original inputs, then clone it onto new_args.
296 auto node = make_shared<op::Select>(arg0, arg1, arg2);
297 auto new_node = node->clone_with_new_inputs(new_args);
298 auto node_cast = as_type_ptr<op::Select>(new_node);
299 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node was already handed to as_type_ptr above, so this null check
// comes after its first use -- consider moving it before the cast.
301 ASSERT_TRUE(nullptr != new_node);
302 ASSERT_TRUE(new_args == new_node->input_values());
// Copy checks for Sign, Sin and Sinh, all one-input ops verified through check_unary.
// NOTE(review): presumably each line is the body of its own TEST case -- confirm.
307 ASSERT_TRUE(check_unary<op::Sign>());
312 ASSERT_TRUE(check_unary<op::Sin>());
317 ASSERT_TRUE(check_unary<op::Sinh>());
// Slice: clone a Slice built on a {2, 3, 4} f32 Parameter with lower bounds {0, 0, 0},
// upper bounds {2, 3, 4} and strides {1, 1, 1} onto a new Parameter, and verify the
// clone keeps the same bounds and strides.
322 Shape shape_in{2, 3, 4};
323 Coordinate lower{0, 0, 0};
324 Coordinate upper{2, 3, 4};
325 Strides strides{1, 1, 1};
327 auto arg0 = make_shared<op::Parameter>(element::f32, shape_in);
328 OutputVector new_args{make_shared<op::Parameter>(element::f32, shape_in)};
// Build the op on the original input, then clone it onto new_args.
330 auto node = make_shared<op::Slice>(arg0, lower, upper, strides);
331 auto new_node = node->clone_with_new_inputs(new_args);
// Downcast so the Slice-specific getters can be queried on the clone.
332 auto node_cast = as_type_ptr<op::Slice>(new_node);
333 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node was already handed to as_type_ptr above, so this null check
// comes after its first use -- consider moving it before the cast.
335 ASSERT_TRUE(nullptr != new_node);
// The clone must be wired to the replacement input and keep bounds and strides.
336 ASSERT_TRUE(new_args == as_output_vector(new_node->get_arguments()));
337 ASSERT_TRUE(lower == node_cast->get_lower_bounds());
338 ASSERT_TRUE(upper == node_cast->get_upper_bounds());
339 ASSERT_TRUE(strides == node_cast->get_strides());
// Subtract: two-input op, verified through the check_binary helper.
344 ASSERT_TRUE(check_binary<op::Subtract>());
// Sum: clone a Sum onto a new data input and verify the clone reports the same
// reduction axes.
// NOTE(review): `shape` and `axes` are declared outside the lines shown here.
351 auto arg0 = make_shared<op::Parameter>(element::f32, shape);
353 auto node = make_shared<op::Sum>(arg0, axes);
// Only the first input is replaced; the second replacement is the original node's own
// argument at index 1.
// NOTE(review): presumably that argument is the axes input created by the Sum
// constructor -- confirm.
354 OutputVector new_args{make_shared<op::Parameter>(element::f32, shape), node->get_argument(1)};
355 auto new_node = node->clone_with_new_inputs(new_args);
// Downcast so the Sum-specific getter can be queried on the clone.
356 auto node_cast = as_type_ptr<op::Sum>(new_node);
357 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node was already handed to as_type_ptr above, so this null check
// comes after its first use -- consider moving it before the cast.
359 ASSERT_TRUE(nullptr != new_node);
360 ASSERT_TRUE(new_args == as_output_vector(new_node->get_arguments()));
361 ASSERT_TRUE(axes == node_cast->get_reduction_axes());
// Copy checks for Tan and Tanh, both one-input ops verified through check_unary.
// NOTE(review): presumably each line is the body of its own TEST case -- confirm.
366 ASSERT_TRUE(check_unary<op::Tan>());
371 ASSERT_TRUE(check_unary<op::Tanh>());