1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
20 #include "gtest/gtest.h"
22 #include "ngraph/ngraph.hpp"
23 #include "util/ndarray.hpp"
24 #include "util/test_tools.hpp"
26 NGRAPH_SUPPRESS_DEPRECATED_START
29 using namespace ngraph;
// Generic copy check for ops constructed from a single input.
// Builds an OP over one f32 Parameter, copies it onto a freshly created
// Parameter via copy_with_new_inputs, and yields true only when the copy is
// non-null and its input values equal the replacement arguments.
// NOTE(review): the function signature and the declaration of `shape` are not
// visible in this excerpt; presumably this is the check_unary<OP>() helper that
// the tests call -- confirm against the full file.
31 template <typename OP>
// arg0 feeds the original node; new_args holds a distinct Parameter of the same
// element type and shape that the copy is expected to use as its input.
35 auto arg0 = make_shared<op::Parameter>(element::f32, shape);
36 OutputVector new_args{make_shared<op::Parameter>(element::f32, shape)};
38 auto node = make_shared<OP>(arg0);
39 auto new_node = node->copy_with_new_inputs(new_args);
// Short-circuits on a null copy before touching its inputs.
41 return (nullptr != new_node) && (new_args == new_node->input_values());
// Generic copy check for ops constructed from two inputs.
// Builds an OP over two f32 Parameters, copies it onto two freshly created
// Parameters via copy_with_new_inputs, and yields true only when the copy is
// non-null and its input values equal the replacement arguments.
// NOTE(review): the function signature and the declaration of `shape` are not
// visible in this excerpt; presumably this is the check_binary<OP>() helper that
// the tests call -- confirm against the full file.
44 template <typename OP>
// Inputs of the original node.
48 auto arg0 = make_shared<op::Parameter>(element::f32, shape);
49 auto arg1 = make_shared<op::Parameter>(element::f32, shape);
// Replacement inputs: two new Parameter nodes, separate from arg0/arg1.
50 OutputVector new_args{make_shared<op::Parameter>(element::f32, shape),
51 make_shared<op::Parameter>(element::f32, shape)};
53 auto node = make_shared<OP>(arg0, arg1);
54 auto new_node = node->copy_with_new_inputs(new_args);
// Short-circuits on a null copy before touching its inputs.
56 return (nullptr != new_node) && (new_args == new_node->input_values());
// Copy checks for Abs, Acos, Add, Asin and Atan. Each assertion delegates to the
// generic helper matching the op's input count (check_unary or check_binary) and
// fails its test unless the helper returns true.
// NOTE(review): the enclosing TEST(...) headers and braces are not visible in
// this excerpt; each line presumably belongs to its own test case.
61 ASSERT_TRUE(check_unary<op::Abs>());
66 ASSERT_TRUE(check_unary<op::Acos>());
71 ASSERT_TRUE(check_binary<op::Add>());
76 ASSERT_TRUE(check_unary<op::Asin>());
81 ASSERT_TRUE(check_unary<op::Atan>());
// Broadcast copy check: copies a Broadcast node onto a replacement Parameter and
// verifies that the copy is still a Broadcast with the expected inputs,
// broadcast shape and broadcast axes.
// NOTE(review): the TEST header and the declarations of `shape1`, `shape` and
// `axes` are not visible in this excerpt.
87 auto arg0 = make_shared<op::Parameter>(element::f32, shape1);
88 OutputVector new_args{make_shared<op::Parameter>(element::f32, shape1)};
93 auto node = make_shared<op::Broadcast>(arg0, shape, axes);
// This check goes through copy_with_new_inputs, whereas the other op-specific
// checks visible in this excerpt call clone_with_new_inputs.
94 auto new_node = node->copy_with_new_inputs(new_args);
// Downcast so the Broadcast-specific getters can be compared below.
95 auto node_cast = as_type_ptr<op::Broadcast>(new_node);
96 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node is null-checked only here, after it has already been
// handed to as_type_ptr above -- consider moving this check before the cast.
98 ASSERT_TRUE(nullptr != new_node);
99 ASSERT_TRUE(new_args == new_node->input_values());
// The copy must carry over the construction-time attributes unchanged.
100 ASSERT_TRUE(shape == node_cast->get_broadcast_shape());
101 ASSERT_TRUE(axes == node_cast->get_broadcast_axes());
// Ceiling copy check, delegated to the single-input helper.
// NOTE(review): the enclosing TEST(...) header is not visible in this excerpt.
106 ASSERT_TRUE(check_unary<op::Ceiling>());
// Concat copy check: clones a two-input Concat onto two replacement Parameters
// and verifies the clone is still a Concat with the expected inputs and the same
// concatenation axis.
// NOTE(review): the TEST header and the declarations of `shape` and `axis` are
// not visible in this excerpt.
112 auto arg0 = make_shared<op::Parameter>(element::f32, shape);
113 auto arg1 = make_shared<op::Parameter>(element::f32, shape);
// Replacement inputs: two new Parameter nodes, separate from arg0/arg1.
114 OutputVector new_args{make_shared<op::Parameter>(element::f32, shape),
115 make_shared<op::Parameter>(element::f32, shape)};
117 auto node = make_shared<op::Concat>(NodeVector{arg0, arg1}, axis);
118 auto new_node = node->clone_with_new_inputs(new_args);
// Downcast so the Concat-specific getter can be compared below.
119 auto node_cast = as_type_ptr<op::Concat>(new_node);
120 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node is null-checked only after it was passed to
// as_type_ptr above -- consider moving this check before the cast.
122 ASSERT_TRUE(nullptr != new_node);
123 ASSERT_TRUE(new_args == new_node->input_values());
124 ASSERT_TRUE(node_cast->get_concatenation_axis() == axis);
// Constant copy check: clones a single-value f32 Constant with an empty input
// list and verifies the clone is still a Constant that has no inputs and keeps
// the same data, shape and element type.
// NOTE(review): the TEST header and the declaration of `shape` are not visible
// in this excerpt.
130 vector<float> c{2.4f};
131 auto& et = element::f32;
132 auto node = op::Constant::create(et, shape, c);
// A Constant is cloned with no inputs at all.
133 auto new_node = node->clone_with_new_inputs(OutputVector{});
134 auto node_cast = as_type_ptr<op::Constant>(new_node);
135 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node is null-checked only after it was passed to
// as_type_ptr above -- consider moving this check before the cast.
136 ASSERT_TRUE(nullptr != new_node);
137 ASSERT_TRUE(OutputVector{} == new_node->input_values());
// The payload and type information must be carried over by the clone.
138 ASSERT_TRUE(node_cast->get_vector<float>() == c);
139 ASSERT_TRUE(node_cast->get_shape() == shape);
140 ASSERT_TRUE(node_cast->get_element_type() == et);
// Convert copy check: clones a Convert node (f32 input, f64 passed as the
// conversion type) onto a replacement Parameter and verifies the clone is still
// a Convert with the expected inputs and the same convert element type.
// NOTE(review): the TEST header and the declaration of `shape` are not visible
// in this excerpt.
146 auto& et = element::f64;
147 auto arg0 = make_shared<op::Parameter>(element::f32, shape);
148 OutputVector new_args{make_shared<op::Parameter>(element::f32, shape)};
150 auto node = make_shared<op::Convert>(arg0, et);
151 auto new_node = node->clone_with_new_inputs(new_args);
// Downcast so the Convert-specific getter can be compared below.
152 auto node_cast = as_type_ptr<op::Convert>(new_node);
153 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node is null-checked only after it was passed to
// as_type_ptr above -- consider moving this check before the cast.
155 ASSERT_TRUE(nullptr != new_node);
156 ASSERT_TRUE(new_args == new_node->input_values());
157 ASSERT_TRUE(et == node_cast->get_convert_element_type());
// Copy checks for Cos, Cosh, Divide, Dot, Equal, Exp and Floor. Each assertion
// delegates to the generic helper matching the op's input count (check_unary or
// check_binary) and fails its test unless the helper returns true.
// NOTE(review): the enclosing TEST(...) headers and braces are not visible in
// this excerpt; each line presumably belongs to its own test case.
162 ASSERT_TRUE(check_unary<op::Cos>());
167 ASSERT_TRUE(check_unary<op::Cosh>());
172 ASSERT_TRUE(check_binary<op::Divide>());
177 ASSERT_TRUE(check_binary<op::Dot>());
182 ASSERT_TRUE(check_binary<op::Equal>());
187 ASSERT_TRUE(check_unary<op::Exp>());
192 ASSERT_TRUE(check_unary<op::Floor>());
// GreaterEq copy check, delegated to the two-input helper.
// NOTE(review): the braces around the test body are not visible in this excerpt.
195 TEST(copy, greater_eq)
197 ASSERT_TRUE(check_binary<op::GreaterEq>());
// Copy checks for Greater, LessEq, Less, Log, Maximum, Minimum, Multiply and
// Negative. Each assertion delegates to the generic helper matching the op's
// input count (check_unary or check_binary) and fails its test unless the helper
// returns true.
// NOTE(review): the enclosing TEST(...) headers and braces are not visible in
// this excerpt; each line presumably belongs to its own test case.
202 ASSERT_TRUE(check_binary<op::Greater>());
207 ASSERT_TRUE(check_binary<op::LessEq>());
212 ASSERT_TRUE(check_binary<op::Less>());
217 ASSERT_TRUE(check_unary<op::Log>());
222 ASSERT_TRUE(check_binary<op::Maximum>());
227 ASSERT_TRUE(check_binary<op::Minimum>());
232 ASSERT_TRUE(check_binary<op::Multiply>());
237 ASSERT_TRUE(check_unary<op::Negative>());
// NotEqual copy check, delegated to the two-input helper.
// NOTE(review): the braces around the test body are not visible in this excerpt.
240 TEST(copy, not_equal)
242 ASSERT_TRUE(check_binary<op::NotEqual>());
// Parameter copy check: clones an f32 Parameter with an empty input list and
// verifies the clone is still a Parameter, has no inputs, and reports the same
// type as the original node.
// NOTE(review): the braces and the declaration of `shape` are not visible in
// this excerpt.
245 TEST(copy, parameter)
248 auto node = make_shared<op::Parameter>(element::f32, shape);
// A Parameter is cloned with no inputs at all.
249 auto new_node = node->clone_with_new_inputs({});
250 auto node_cast = as_type_ptr<op::Parameter>(new_node);
251 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node is null-checked only after it was passed to
// as_type_ptr above -- consider moving this check before the cast.
253 ASSERT_TRUE(nullptr != new_node);
254 ASSERT_TRUE(new_node->input_values().size() == 0);
255 ASSERT_TRUE(node->has_same_type(new_node));
// Power copy check, delegated to the two-input helper.
// NOTE(review): the enclosing TEST(...) header is not visible in this excerpt.
260 ASSERT_TRUE(check_binary<op::Power>());
// Reshape copy check: clones a Reshape node ({2, 3, 4} input, input order
// {0, 1, 2}, output shape {6, 4}) onto a replacement Parameter and verifies the
// clone is still a Reshape with the expected inputs, input order and output shape.
// NOTE(review): the TEST header is not visible in this excerpt.
265 Shape shape_in{2, 3, 4};
266 AxisVector axes{0, 1, 2};
267 Shape shape_out{6, 4};
269 auto arg0 = make_shared<op::Parameter>(element::f32, shape_in);
270 OutputVector new_args{make_shared<op::Parameter>(element::f32, shape_in)};
272 auto node = make_shared<op::Reshape>(arg0, axes, shape_out);
273 auto new_node = node->clone_with_new_inputs(new_args);
// Downcast so the Reshape-specific getter can be compared below.
274 auto node_cast = as_type_ptr<op::Reshape>(new_node);
275 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node is null-checked only after it was passed to
// as_type_ptr above -- consider moving this check before the cast.
277 ASSERT_TRUE(nullptr != new_node);
278 ASSERT_TRUE(new_args == new_node->input_values());
279 ASSERT_TRUE(axes == node_cast->get_input_order());
// Output 0 of the clone must have the requested target shape.
280 ASSERT_TRUE(shape_out == node_cast->get_output_shape(0));
// Select copy check: clones a three-input Select (one boolean Parameter followed
// by two f32 Parameters) onto three replacement Parameters of matching element
// types and verifies the clone is still a Select wired to the replacements.
// No op-specific attributes are compared here.
// NOTE(review): the TEST header and the declaration of `shape` are not visible
// in this excerpt.
286 auto arg0 = make_shared<op::Parameter>(element::boolean, shape);
287 auto arg1 = make_shared<op::Parameter>(element::f32, shape);
288 auto arg2 = make_shared<op::Parameter>(element::f32, shape);
// Replacement inputs keep the same element-type order: boolean, f32, f32.
289 OutputVector new_args{make_shared<op::Parameter>(element::boolean, shape),
290 make_shared<op::Parameter>(element::f32, shape),
291 make_shared<op::Parameter>(element::f32, shape)};
293 auto node = make_shared<op::Select>(arg0, arg1, arg2);
294 auto new_node = node->clone_with_new_inputs(new_args);
295 auto node_cast = as_type_ptr<op::Select>(new_node);
296 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node is null-checked only after it was passed to
// as_type_ptr above -- consider moving this check before the cast.
298 ASSERT_TRUE(nullptr != new_node);
299 ASSERT_TRUE(new_args == new_node->input_values());
// Copy checks for Sign, Sin and Sinh, each delegated to the single-input helper;
// an assertion fails its test unless the helper returns true.
// NOTE(review): the enclosing TEST(...) headers and braces are not visible in
// this excerpt; each line presumably belongs to its own test case.
304 ASSERT_TRUE(check_unary<op::Sign>());
309 ASSERT_TRUE(check_unary<op::Sin>());
314 ASSERT_TRUE(check_unary<op::Sinh>());
// Slice copy check: clones a Slice node onto a replacement Parameter and
// verifies the clone is still a Slice with the expected inputs and the same
// lower bounds, upper bounds and strides.
// NOTE(review): the TEST header is not visible in this excerpt.
319 Shape shape_in{2, 3, 4};
// The bounds run from {0, 0, 0} up to the values of shape_in, with unit strides.
320 Coordinate lower{0, 0, 0};
321 Coordinate upper{2, 3, 4};
322 Strides strides{1, 1, 1};
324 auto arg0 = make_shared<op::Parameter>(element::f32, shape_in);
325 OutputVector new_args{make_shared<op::Parameter>(element::f32, shape_in)};
327 auto node = make_shared<op::Slice>(arg0, lower, upper, strides);
328 auto new_node = node->clone_with_new_inputs(new_args);
// Downcast so the Slice-specific getters can be compared below.
329 auto node_cast = as_type_ptr<op::Slice>(new_node);
330 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node is null-checked only after it was passed to
// as_type_ptr above -- consider moving this check before the cast.
332 ASSERT_TRUE(nullptr != new_node);
333 ASSERT_TRUE(new_args == new_node->input_values());
334 ASSERT_TRUE(lower == node_cast->get_lower_bounds());
335 ASSERT_TRUE(upper == node_cast->get_upper_bounds());
336 ASSERT_TRUE(strides == node_cast->get_strides());
// Subtract copy check, delegated to the two-input helper.
// NOTE(review): the enclosing TEST(...) header is not visible in this excerpt.
341 ASSERT_TRUE(check_binary<op::Subtract>());
// Sum copy check: clones a Sum node onto a new data Parameter while reusing the
// original node's second input, then verifies the clone is still a Sum with the
// expected inputs and the same reduction axes.
// NOTE(review): the TEST header and the declarations of `shape` and `axes` are
// not visible in this excerpt.
348 auto arg0 = make_shared<op::Parameter>(element::f32, shape);
350 auto node = make_shared<op::Sum>(arg0, axes);
// The replacement list is built after the node exists because its second entry
// is taken from the node itself (input index 1).
// NOTE(review): presumably that input is the axes node created by the Sum
// constructor -- confirm against the op definition.
351 OutputVector new_args{make_shared<op::Parameter>(element::f32, shape),
352 node->input_value(1).get_node_shared_ptr()};
353 auto new_node = node->clone_with_new_inputs(new_args);
354 auto node_cast = as_type_ptr<op::Sum>(new_node);
355 ASSERT_NE(node_cast, nullptr);
// NOTE(review): new_node is null-checked only after it was passed to
// as_type_ptr above -- consider moving this check before the cast.
357 ASSERT_TRUE(nullptr != new_node);
358 ASSERT_TRUE(new_args == new_node->input_values());
359 ASSERT_TRUE(axes == node_cast->get_reduction_axes());
// Copy checks for Tan and Tanh, each delegated to the single-input helper; an
// assertion fails its test unless the helper returns true.
// NOTE(review): the enclosing TEST(...) headers and braces are not visible in
// this excerpt; each line presumably belongs to its own test case.
364 ASSERT_TRUE(check_unary<op::Tan>());
369 ASSERT_TRUE(check_unary<op::Tanh>());